lucasvl.nl/lib/python3.11/site-packages/werkzeug/middleware/lint.py
2024-03-04 21:31:20 -06:00

421 lines
14 KiB
Python

"""
WSGI Protocol Linter
====================
This module provides a middleware that performs sanity checks on the
behavior of the WSGI server and application. It checks that the
:pep:`3333` WSGI spec is properly implemented. It also warns on some
common HTTP errors such as non-empty responses for 304 status codes.
.. autoclass:: LintMiddleware
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
from __future__ import annotations
import typing as t
from types import TracebackType
from urllib.parse import urlparse
from warnings import warn
from ..datastructures import Headers
from ..http import is_entity_header
from ..wsgi import FileWrapper
if t.TYPE_CHECKING:
from _typeshed.wsgi import StartResponse
from _typeshed.wsgi import WSGIApplication
from _typeshed.wsgi import WSGIEnvironment
class WSGIWarning(Warning):
"""Warning class for WSGI warnings."""
class HTTPWarning(Warning):
"""Warning class for HTTP warnings."""
def check_type(context: str, obj: object, need: t.Type = str) -> None:
if type(obj) is not need:
warn(
f"{context!r} requires {need.__name__!r}, got {type(obj).__name__!r}.",
WSGIWarning,
stacklevel=3,
)
class InputStream:
def __init__(self, stream: t.IO[bytes]) -> None:
self._stream = stream
def read(self, *args: t.Any) -> bytes:
if len(args) == 0:
warn(
"WSGI does not guarantee an EOF marker on the input stream, thus making"
" calls to 'wsgi.input.read()' unsafe. Conforming servers may never"
" return from this call.",
WSGIWarning,
stacklevel=2,
)
elif len(args) != 1:
warn(
"Too many parameters passed to 'wsgi.input.read()'.",
WSGIWarning,
stacklevel=2,
)
return self._stream.read(*args)
def readline(self, *args: t.Any) -> bytes:
if len(args) == 0:
warn(
"Calls to 'wsgi.input.readline()' without arguments are unsafe. Use"
" 'wsgi.input.read()' instead.",
WSGIWarning,
stacklevel=2,
)
elif len(args) == 1:
warn(
"'wsgi.input.readline()' was called with a size hint. WSGI does not"
" support this, although it's available on all major servers.",
WSGIWarning,
stacklevel=2,
)
else:
raise TypeError("Too many arguments passed to 'wsgi.input.readline()'.")
return self._stream.readline(*args)
def __iter__(self) -> t.Iterator[bytes]:
try:
return iter(self._stream)
except TypeError:
warn("'wsgi.input' is not iterable.", WSGIWarning, stacklevel=2)
return iter(())
def close(self) -> None:
warn("The application closed the input stream!", WSGIWarning, stacklevel=2)
self._stream.close()
class ErrorStream:
def __init__(self, stream: t.IO[str]) -> None:
self._stream = stream
def write(self, s: str) -> None:
check_type("wsgi.error.write()", s, str)
self._stream.write(s)
def flush(self) -> None:
self._stream.flush()
def writelines(self, seq: t.Iterable[str]) -> None:
for line in seq:
self.write(line)
def close(self) -> None:
warn("The application closed the error stream!", WSGIWarning, stacklevel=2)
self._stream.close()
class GuardedWrite:
def __init__(self, write: t.Callable[[bytes], object], chunks: list[int]) -> None:
self._write = write
self._chunks = chunks
def __call__(self, s: bytes) -> None:
check_type("write()", s, bytes)
self._write(s)
self._chunks.append(len(s))
class GuardedIterator:
def __init__(
self,
iterator: t.Iterable[bytes],
headers_set: tuple[int, Headers],
chunks: list[int],
) -> None:
self._iterator = iterator
self._next = iter(iterator).__next__
self.closed = False
self.headers_set = headers_set
self.chunks = chunks
def __iter__(self) -> GuardedIterator:
return self
def __next__(self) -> bytes:
if self.closed:
warn("Iterated over closed 'app_iter'.", WSGIWarning, stacklevel=2)
rv = self._next()
if not self.headers_set:
warn(
"The application returned before it started the response.",
WSGIWarning,
stacklevel=2,
)
check_type("application iterator items", rv, bytes)
self.chunks.append(len(rv))
return rv
def close(self) -> None:
self.closed = True
if hasattr(self._iterator, "close"):
self._iterator.close()
if self.headers_set:
status_code, headers = self.headers_set
bytes_sent = sum(self.chunks)
content_length = headers.get("content-length", type=int)
if status_code == 304:
for key, _value in headers:
key = key.lower()
if key not in ("expires", "content-location") and is_entity_header(
key
):
warn(
f"Entity header {key!r} found in 304 response.", HTTPWarning
)
if bytes_sent:
warn("304 responses must not have a body.", HTTPWarning)
elif 100 <= status_code < 200 or status_code == 204:
if content_length != 0:
warn(
f"{status_code} responses must have an empty content length.",
HTTPWarning,
)
if bytes_sent:
warn(f"{status_code} responses must not have a body.", HTTPWarning)
elif content_length is not None and content_length != bytes_sent:
warn(
"Content-Length and the number of bytes sent to the"
" client do not match.",
WSGIWarning,
)
def __del__(self) -> None:
if not self.closed:
try:
warn(
"Iterator was garbage collected before it was closed.", WSGIWarning
)
except Exception:
pass
class LintMiddleware:
"""Warns about common errors in the WSGI and HTTP behavior of the
server and wrapped application. Some of the issues it checks are:
- invalid status codes
- non-bytes sent to the WSGI server
- strings returned from the WSGI application
- non-empty conditional responses
- unquoted etags
- relative URLs in the Location header
- unsafe calls to wsgi.input
- unclosed iterators
Error information is emitted using the :mod:`warnings` module.
:param app: The WSGI application to wrap.
.. code-block:: python
from werkzeug.middleware.lint import LintMiddleware
app = LintMiddleware(app)
"""
def __init__(self, app: WSGIApplication) -> None:
self.app = app
def check_environ(self, environ: WSGIEnvironment) -> None:
if type(environ) is not dict:
warn(
"WSGI environment is not a standard Python dict.",
WSGIWarning,
stacklevel=4,
)
for key in (
"REQUEST_METHOD",
"SERVER_NAME",
"SERVER_PORT",
"wsgi.version",
"wsgi.input",
"wsgi.errors",
"wsgi.multithread",
"wsgi.multiprocess",
"wsgi.run_once",
):
if key not in environ:
warn(
f"Required environment key {key!r} not found",
WSGIWarning,
stacklevel=3,
)
if environ["wsgi.version"] != (1, 0):
warn("Environ is not a WSGI 1.0 environ.", WSGIWarning, stacklevel=3)
script_name = environ.get("SCRIPT_NAME", "")
path_info = environ.get("PATH_INFO", "")
if script_name and script_name[0] != "/":
warn(
f"'SCRIPT_NAME' does not start with a slash: {script_name!r}",
WSGIWarning,
stacklevel=3,
)
if path_info and path_info[0] != "/":
warn(
f"'PATH_INFO' does not start with a slash: {path_info!r}",
WSGIWarning,
stacklevel=3,
)
def check_start_response(
self,
status: str,
headers: list[tuple[str, str]],
exc_info: None | (tuple[type[BaseException], BaseException, TracebackType]),
) -> tuple[int, Headers]:
check_type("status", status, str)
status_code_str = status.split(None, 1)[0]
if len(status_code_str) != 3 or not status_code_str.isdecimal():
warn("Status code must be three digits.", WSGIWarning, stacklevel=3)
if len(status) < 4 or status[3] != " ":
warn(
f"Invalid value for status {status!r}. Valid status strings are three"
" digits, a space and a status explanation.",
WSGIWarning,
stacklevel=3,
)
status_code = int(status_code_str)
if status_code < 100:
warn("Status code < 100 detected.", WSGIWarning, stacklevel=3)
if type(headers) is not list:
warn("Header list is not a list.", WSGIWarning, stacklevel=3)
for item in headers:
if type(item) is not tuple or len(item) != 2:
warn("Header items must be 2-item tuples.", WSGIWarning, stacklevel=3)
name, value = item
if type(name) is not str or type(value) is not str:
warn(
"Header keys and values must be strings.", WSGIWarning, stacklevel=3
)
if name.lower() == "status":
warn(
"The status header is not supported due to"
" conflicts with the CGI spec.",
WSGIWarning,
stacklevel=3,
)
if exc_info is not None and not isinstance(exc_info, tuple):
warn("Invalid value for exc_info.", WSGIWarning, stacklevel=3)
headers = Headers(headers)
self.check_headers(headers)
return status_code, headers
def check_headers(self, headers: Headers) -> None:
etag = headers.get("etag")
if etag is not None:
if etag.startswith(("W/", "w/")):
if etag.startswith("w/"):
warn(
"Weak etag indicator should be upper case.",
HTTPWarning,
stacklevel=4,
)
etag = etag[2:]
if not (etag[:1] == etag[-1:] == '"'):
warn("Unquoted etag emitted.", HTTPWarning, stacklevel=4)
location = headers.get("location")
if location is not None:
if not urlparse(location).netloc:
warn(
"Absolute URLs required for location header.",
HTTPWarning,
stacklevel=4,
)
def check_iterator(self, app_iter: t.Iterable[bytes]) -> None:
if isinstance(app_iter, str):
warn(
"The application returned a string. The response will send one"
" character at a time to the client, which will kill performance."
" Return a list or iterable instead.",
WSGIWarning,
stacklevel=3,
)
def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Iterable[bytes]:
if len(args) != 2:
warn("A WSGI app takes two arguments.", WSGIWarning, stacklevel=2)
if kwargs:
warn(
"A WSGI app does not take keyword arguments.", WSGIWarning, stacklevel=2
)
environ: WSGIEnvironment = args[0]
start_response: StartResponse = args[1]
self.check_environ(environ)
environ["wsgi.input"] = InputStream(environ["wsgi.input"])
environ["wsgi.errors"] = ErrorStream(environ["wsgi.errors"])
# Hook our own file wrapper in so that applications will always
# iterate to the end and we can check the content length.
environ["wsgi.file_wrapper"] = FileWrapper
headers_set: list[t.Any] = []
chunks: list[int] = []
def checking_start_response(
*args: t.Any, **kwargs: t.Any
) -> t.Callable[[bytes], None]:
if len(args) not in {2, 3}:
warn(
f"Invalid number of arguments: {len(args)}, expected 2 or 3.",
WSGIWarning,
stacklevel=2,
)
if kwargs:
warn("'start_response' does not take keyword arguments.", WSGIWarning)
status: str = args[0]
headers: list[tuple[str, str]] = args[1]
exc_info: None | (
tuple[type[BaseException], BaseException, TracebackType]
) = (args[2] if len(args) == 3 else None)
headers_set[:] = self.check_start_response(status, headers, exc_info)
return GuardedWrite(start_response(status, headers, exc_info), chunks)
app_iter = self.app(environ, t.cast("StartResponse", checking_start_response))
self.check_iterator(app_iter)
return GuardedIterator(
app_iter, t.cast(t.Tuple[int, Headers], headers_set), chunks
)