Update webservice with cherrypy

Fix playback issues that was causing Kodi to hang up
This commit is contained in:
angelblue05 2019-01-30 06:43:14 -06:00
parent b2bc90cb06
commit 158a736360
164 changed files with 42855 additions and 174 deletions

View file

@ -0,0 +1,6 @@
"""High-performance, pure-Python HTTP server used by CherryPy."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
__version__ = '6.4.0'

View file

@ -0,0 +1,6 @@
"""Stub for accessing the Cheroot CLI tool."""
from .cli import main
if __name__ == '__main__':
main()

View file

@ -0,0 +1,66 @@
"""Compatibility code for using Cheroot with various versions of Python."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import re
import six
if six.PY3:
def ntob(n, encoding='ISO-8859-1'):
"""Return the native string as bytes in the given encoding."""
assert_native(n)
# In Python 3, the native string type is unicode
return n.encode(encoding)
def ntou(n, encoding='ISO-8859-1'):
"""Return the native string as unicode with the given encoding."""
assert_native(n)
# In Python 3, the native string type is unicode
return n
def bton(b, encoding='ISO-8859-1'):
"""Return the byte string as native string in the given encoding."""
return b.decode(encoding)
else:
# Python 2
def ntob(n, encoding='ISO-8859-1'):
"""Return the native string as bytes in the given encoding."""
assert_native(n)
# In Python 2, the native string type is bytes. Assume it's already
# in the given encoding, which for ISO-8859-1 is almost always what
# was intended.
return n
def ntou(n, encoding='ISO-8859-1'):
"""Return the native string as unicode with the given encoding."""
assert_native(n)
# In Python 2, the native string type is bytes.
# First, check for the special encoding 'escape'. The test suite uses
# this to signal that it wants to pass a string with embedded \uXXXX
# escapes, but without having to prefix it with u'' for Python 2,
# but no prefix for Python 3.
if encoding == 'escape':
return six.u(
re.sub(r'\\u([0-9a-zA-Z]{4})',
lambda m: six.unichr(int(m.group(1), 16)),
n.decode('ISO-8859-1')))
# Assume it's already in the given encoding, which for ISO-8859-1
# is almost always what was intended.
return n.decode(encoding)
def bton(b, encoding='ISO-8859-1'):
"""Return the byte string as native string in the given encoding."""
return b
def assert_native(n):
"""Check whether the input is of nativ ``str`` type.
Raises:
TypeError: in case of failed check
"""
if not isinstance(n, str):
raise TypeError('n must be a native str (got %s)' % type(n).__name__)

233
libraries/cheroot/cli.py Normal file
View file

@ -0,0 +1,233 @@
"""Command line tool for starting a Cheroot WSGI/HTTP server instance.
Basic usage::
# Start a server on 127.0.0.1:8000 with the default settings
# for the WSGI app myapp/wsgi.py:application()
cheroot myapp.wsgi
# Start a server on 0.0.0.0:9000 with 8 threads
# for the WSGI app myapp/wsgi.py:main_app()
cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8
# Start a server for the cheroot.server.Gateway subclass
# myapp/gateway.py:HTTPGateway
cheroot myapp.gateway:HTTPGateway
# Start a server on the UNIX socket /var/spool/myapp.sock
cheroot myapp.wsgi --bind /var/spool/myapp.sock
# Start a server on the abstract UNIX socket CherootServer
cheroot myapp.wsgi --bind @CherootServer
"""
import argparse
from importlib import import_module
import os
import sys
import contextlib
import six
from . import server
from . import wsgi
__metaclass__ = type
class BindLocation:
"""A class for storing the bind location for a Cheroot instance."""
class TCPSocket(BindLocation):
"""TCPSocket."""
def __init__(self, address, port):
"""Initialize.
Args:
address (str): Host name or IP address
port (int): TCP port number
"""
self.bind_addr = address, port
class UnixSocket(BindLocation):
"""UnixSocket."""
def __init__(self, path):
"""Initialize."""
self.bind_addr = path
class AbstractSocket(BindLocation):
"""AbstractSocket."""
def __init__(self, addr):
"""Initialize."""
self.bind_addr = '\0{}'.format(self.abstract_socket)
class Application:
"""Application."""
@classmethod
def resolve(cls, full_path):
"""Read WSGI app/Gateway path string and import application module."""
mod_path, _, app_path = full_path.partition(':')
app = getattr(import_module(mod_path), app_path or 'application')
with contextlib.suppress(TypeError):
if issubclass(app, server.Gateway):
return GatewayYo(app)
return cls(app)
def __init__(self, wsgi_app):
"""Initialize."""
if not callable(wsgi_app):
raise TypeError(
'Application must be a callable object or '
'cheroot.server.Gateway subclass'
)
self.wsgi_app = wsgi_app
def server_args(self, parsed_args):
"""Return keyword args for Server class."""
args = {
arg: value
for arg, value in vars(parsed_args).items()
if not arg.startswith('_') and value is not None
}
args.update(vars(self))
return args
def server(self, parsed_args):
"""Server."""
return wsgi.Server(**self.server_args(parsed_args))
class GatewayYo:
"""Gateway."""
def __init__(self, gateway):
"""Init."""
self.gateway = gateway
def server(self, parsed_args):
"""Server."""
server_args = vars(self)
server_args['bind_addr'] = parsed_args['bind_addr']
if parsed_args.max is not None:
server_args['maxthreads'] = parsed_args.max
if parsed_args.numthreads is not None:
server_args['minthreads'] = parsed_args.numthreads
return server.HTTPServer(**server_args)
def parse_wsgi_bind_location(bind_addr_string):
"""Convert bind address string to a BindLocation."""
# try and match for an IP/hostname and port
match = six.moves.urllib.parse.urlparse('//{}'.format(bind_addr_string))
try:
addr = match.hostname
port = match.port
if addr is not None or port is not None:
return TCPSocket(addr, port)
except ValueError:
pass
# else, assume a UNIX socket path
# if the string begins with an @ symbol, use an abstract socket
if bind_addr_string.startswith('@'):
return AbstractSocket(bind_addr_string[1:])
return UnixSocket(path=bind_addr_string)
def parse_wsgi_bind_addr(bind_addr_string):
"""Convert bind address string to bind address parameter."""
return parse_wsgi_bind_location(bind_addr_string).bind_addr
_arg_spec = {
'_wsgi_app': dict(
metavar='APP_MODULE',
type=Application.resolve,
help='WSGI application callable or cheroot.server.Gateway subclass',
),
'--bind': dict(
metavar='ADDRESS',
dest='bind_addr',
type=parse_wsgi_bind_addr,
default='[::1]:8000',
help='Network interface to listen on (default: [::1]:8000)',
),
'--chdir': dict(
metavar='PATH',
type=os.chdir,
help='Set the working directory',
),
'--server-name': dict(
dest='server_name',
type=str,
help='Web server name to be advertised via Server HTTP header',
),
'--threads': dict(
metavar='INT',
dest='numthreads',
type=int,
help='Minimum number of worker threads',
),
'--max-threads': dict(
metavar='INT',
dest='max',
type=int,
help='Maximum number of worker threads',
),
'--timeout': dict(
metavar='INT',
dest='timeout',
type=int,
help='Timeout in seconds for accepted connections',
),
'--shutdown-timeout': dict(
metavar='INT',
dest='shutdown_timeout',
type=int,
help='Time in seconds to wait for worker threads to cleanly exit',
),
'--request-queue-size': dict(
metavar='INT',
dest='request_queue_size',
type=int,
help='Maximum number of queued connections',
),
'--accepted-queue-size': dict(
metavar='INT',
dest='accepted_queue_size',
type=int,
help='Maximum number of active requests in queue',
),
'--accepted-queue-timeout': dict(
metavar='INT',
dest='accepted_queue_timeout',
type=int,
help='Timeout in seconds for putting requests into queue',
),
}
def main():
"""Create a new Cheroot instance with arguments from the command line."""
parser = argparse.ArgumentParser(
description='Start an instance of the Cheroot WSGI/HTTP server.')
for arg, spec in _arg_spec.items():
parser.add_argument(arg, **spec)
raw_args = parser.parse_args()
# ensure cwd in sys.path
'' in sys.path or sys.path.insert(0, '')
# create a server based on the arguments provided
raw_args._wsgi_app.server(raw_args).safe_start()

View file

@ -0,0 +1,58 @@
"""Collection of exceptions raised and/or processed by Cheroot."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import errno
import sys
class MaxSizeExceeded(Exception):
"""Exception raised when a client sends more data then acceptable within limit.
Depends on ``request.body.maxbytes`` config option if used within CherryPy
"""
class NoSSLError(Exception):
"""Exception raised when a client speaks HTTP to an HTTPS socket."""
class FatalSSLAlert(Exception):
"""Exception raised when the SSL implementation signals a fatal alert."""
def plat_specific_errors(*errnames):
"""Return error numbers for all errors in errnames on this platform.
The 'errno' module contains different global constants depending on
the specific platform (OS). This function will return the list of
numeric values for a given list of potential names.
"""
errno_names = dir(errno)
nums = [getattr(errno, k) for k in errnames if k in errno_names]
# de-dupe the list
return list(dict.fromkeys(nums).keys())
socket_error_eintr = plat_specific_errors('EINTR', 'WSAEINTR')
socket_errors_to_ignore = plat_specific_errors(
'EPIPE',
'EBADF', 'WSAEBADF',
'ENOTSOCK', 'WSAENOTSOCK',
'ETIMEDOUT', 'WSAETIMEDOUT',
'ECONNREFUSED', 'WSAECONNREFUSED',
'ECONNRESET', 'WSAECONNRESET',
'ECONNABORTED', 'WSAECONNABORTED',
'ENETRESET', 'WSAENETRESET',
'EHOSTDOWN', 'EHOSTUNREACH',
)
socket_errors_to_ignore.append('timed out')
socket_errors_to_ignore.append('The read operation timed out')
socket_errors_nonblocking = plat_specific_errors(
'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
if sys.platform == 'darwin':
socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE'))
socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE'))

View file

@ -0,0 +1,387 @@
"""Socket file object."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket
try:
# prefer slower Python-based io module
import _pyio as io
except ImportError:
# Python 2.6
import io
import six
from . import errors
class BufferedWriter(io.BufferedWriter):
"""Faux file object attached to a socket object."""
def write(self, b):
"""Write bytes to buffer."""
self._checkClosed()
if isinstance(b, str):
raise TypeError("can't write str to binary stream")
with self._write_lock:
self._write_buf.extend(b)
self._flush_unlocked()
return len(b)
def _flush_unlocked(self):
self._checkClosed('flush of closed file')
while self._write_buf:
try:
# ssl sockets only except 'bytes', not bytearrays
# so perhaps we should conditionally wrap this for perf?
n = self.raw.write(bytes(self._write_buf))
except io.BlockingIOError as e:
n = e.characters_written
del self._write_buf[:n]
def MakeFile_PY3(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
"""File object attached to a socket object."""
if 'r' in mode:
return io.BufferedReader(socket.SocketIO(sock, mode), bufsize)
else:
return BufferedWriter(socket.SocketIO(sock, mode), bufsize)
class MakeFile_PY2(getattr(socket, '_fileobject', object)):
"""Faux file object attached to a socket object."""
def __init__(self, *args, **kwargs):
"""Initialize faux file object."""
self.bytes_read = 0
self.bytes_written = 0
socket._fileobject.__init__(self, *args, **kwargs)
def write(self, data):
"""Sendall for non-blocking sockets."""
while data:
try:
bytes_sent = self.send(data)
data = data[bytes_sent:]
except socket.error as e:
if e.args[0] not in errors.socket_errors_nonblocking:
raise
def send(self, data):
"""Send some part of message to the socket."""
bytes_sent = self._sock.send(data)
self.bytes_written += bytes_sent
return bytes_sent
def flush(self):
"""Write all data from buffer to socket and reset write buffer."""
if self._wbuf:
buffer = ''.join(self._wbuf)
self._wbuf = []
self.write(buffer)
def recv(self, size):
"""Receive message of a size from the socket."""
while True:
try:
data = self._sock.recv(size)
self.bytes_read += len(data)
return data
except socket.error as e:
what = (
e.args[0] not in errors.socket_errors_nonblocking
and e.args[0] not in errors.socket_error_eintr
)
if what:
raise
class FauxSocket:
"""Faux socket with the minimal interface required by pypy."""
def _reuse(self):
pass
_fileobject_uses_str_type = six.PY2 and isinstance(
socket._fileobject(FauxSocket())._rbuf, six.string_types)
# FauxSocket is no longer needed
del FauxSocket
if not _fileobject_uses_str_type:
def read(self, size=-1):
"""Read data from the socket to buffer."""
# Use max, disallow tiny reads in a loop as they are very
# inefficient.
# We never leave read() with any leftover data from a new recv()
# call in our internal buffer.
rbufsize = max(self._rbufsize, self.default_bufsize)
# Our use of StringIO rather than lists of string objects returned
# by recv() minimizes memory usage and fragmentation that occurs
# when rbufsize is large compared to the typical return value of
# recv().
buf = self._rbuf
buf.seek(0, 2) # seek end
if size < 0:
# Read until EOF
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
while True:
data = self.recv(rbufsize)
if not data:
break
buf.write(data)
return buf.getvalue()
else:
# Read until size bytes or EOF seen, whichever comes first
buf_len = buf.tell()
if buf_len >= size:
# Already have size bytes in our buffer? Extract and
# return.
buf.seek(0)
rv = buf.read(size)
self._rbuf = io.BytesIO()
self._rbuf.write(buf.read())
return rv
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
while True:
left = size - buf_len
# recv() will malloc the amount of memory given as its
# parameter even though it often returns much less data
# than that. The returned data string is short lived
# as we copy it into a StringIO and free it. This avoids
# fragmentation issues on many platforms.
data = self.recv(left)
if not data:
break
n = len(data)
if n == size and not buf_len:
# Shortcut. Avoid buffer data copies when:
# - We have no data in our buffer.
# AND
# - Our call to recv returned exactly the
# number of bytes we were asked to read.
return data
if n == left:
buf.write(data)
del data # explicit free
break
assert n <= left, 'recv(%d) returned %d bytes' % (left, n)
buf.write(data)
buf_len += n
del data # explicit free
# assert buf_len == buf.tell()
return buf.getvalue()
def readline(self, size=-1):
"""Read line from the socket to buffer."""
buf = self._rbuf
buf.seek(0, 2) # seek end
if buf.tell() > 0:
# check if we already have it in our buffer
buf.seek(0)
bline = buf.readline(size)
if bline.endswith('\n') or len(bline) == size:
self._rbuf = io.BytesIO()
self._rbuf.write(buf.read())
return bline
del bline
if size < 0:
# Read until \n or EOF, whichever comes first
if self._rbufsize <= 1:
# Speed up unbuffered case
buf.seek(0)
buffers = [buf.read()]
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
data = None
recv = self.recv
while data != '\n':
data = recv(1)
if not data:
break
buffers.append(data)
return ''.join(buffers)
buf.seek(0, 2) # seek end
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
while True:
data = self.recv(self._rbufsize)
if not data:
break
nl = data.find('\n')
if nl >= 0:
nl += 1
buf.write(data[:nl])
self._rbuf.write(data[nl:])
del data
break
buf.write(data)
return buf.getvalue()
else:
# Read until size bytes or \n or EOF seen, whichever comes
# first
buf.seek(0, 2) # seek end
buf_len = buf.tell()
if buf_len >= size:
buf.seek(0)
rv = buf.read(size)
self._rbuf = io.BytesIO()
self._rbuf.write(buf.read())
return rv
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
while True:
data = self.recv(self._rbufsize)
if not data:
break
left = size - buf_len
# did we just receive a newline?
nl = data.find('\n', 0, left)
if nl >= 0:
nl += 1
# save the excess data to _rbuf
self._rbuf.write(data[nl:])
if buf_len:
buf.write(data[:nl])
break
else:
# Shortcut. Avoid data copy through buf when
# returning a substring of our first recv().
return data[:nl]
n = len(data)
if n == size and not buf_len:
# Shortcut. Avoid data copy through buf when
# returning exactly all of our first recv().
return data
if n >= left:
buf.write(data[:left])
self._rbuf.write(data[left:])
break
buf.write(data)
buf_len += n
# assert buf_len == buf.tell()
return buf.getvalue()
else:
def read(self, size=-1):
"""Read data from the socket to buffer."""
if size < 0:
# Read until EOF
buffers = [self._rbuf]
self._rbuf = ''
if self._rbufsize <= 1:
recv_size = self.default_bufsize
else:
recv_size = self._rbufsize
while True:
data = self.recv(recv_size)
if not data:
break
buffers.append(data)
return ''.join(buffers)
else:
# Read until size bytes or EOF seen, whichever comes first
data = self._rbuf
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ''
while True:
left = size - buf_len
recv_size = max(self._rbufsize, left)
data = self.recv(recv_size)
if not data:
break
buffers.append(data)
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return ''.join(buffers)
def readline(self, size=-1):
"""Read line from the socket to buffer."""
data = self._rbuf
if size < 0:
# Read until \n or EOF, whichever comes first
if self._rbufsize <= 1:
# Speed up unbuffered case
assert data == ''
buffers = []
while data != '\n':
data = self.recv(1)
if not data:
break
buffers.append(data)
return ''.join(buffers)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buffers = []
if data:
buffers.append(data)
self._rbuf = ''
while True:
data = self.recv(self._rbufsize)
if not data:
break
buffers.append(data)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
return ''.join(buffers)
else:
# Read until size bytes or \n or EOF seen, whichever comes
# first
nl = data.find('\n', 0, size)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ''
while True:
data = self.recv(self._rbufsize)
if not data:
break
buffers.append(data)
left = size - buf_len
nl = data.find('\n', 0, left)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return ''.join(buffers)
MakeFile = MakeFile_PY2 if six.PY2 else MakeFile_PY3

2001
libraries/cheroot/server.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,51 @@
"""Implementation of the SSL adapter base interface."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from abc import ABCMeta, abstractmethod
from six import add_metaclass
@add_metaclass(ABCMeta)
class Adapter:
"""Base class for SSL driver library adapters.
Required methods:
* ``wrap(sock) -> (wrapped socket, ssl environ dict)``
* ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) ->
socket file object``
"""
@abstractmethod
def __init__(
self, certificate, private_key, certificate_chain=None,
ciphers=None):
"""Set up certificates, private key ciphers and reset context."""
self.certificate = certificate
self.private_key = private_key
self.certificate_chain = certificate_chain
self.ciphers = ciphers
self.context = None
@abstractmethod
def bind(self, sock):
"""Wrap and return the given socket."""
return sock
@abstractmethod
def wrap(self, sock):
"""Wrap and return the given socket, plus WSGI environ entries."""
raise NotImplementedError
@abstractmethod
def get_environ(self):
"""Return WSGI environ entries to be merged into each request."""
raise NotImplementedError
@abstractmethod
def makefile(self, sock, mode='r', bufsize=-1):
"""Return socket file object."""
raise NotImplementedError

View file

@ -0,0 +1,162 @@
"""
A library for integrating Python's builtin ``ssl`` library with Cheroot.
The ssl module must be importable for SSL functionality.
To use this module, set ``HTTPServer.ssl_adapter`` to an instance of
``BuiltinSSLAdapter``.
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
try:
import ssl
except ImportError:
ssl = None
try:
from _pyio import DEFAULT_BUFFER_SIZE
except ImportError:
try:
from io import DEFAULT_BUFFER_SIZE
except ImportError:
DEFAULT_BUFFER_SIZE = -1
import six
from . import Adapter
from .. import errors
from ..makefile import MakeFile
if six.PY3:
generic_socket_error = OSError
else:
import socket
generic_socket_error = socket.error
del socket
def _assert_ssl_exc_contains(exc, *msgs):
"""Check whether SSL exception contains either of messages provided."""
if len(msgs) < 1:
raise TypeError(
'_assert_ssl_exc_contains() requires '
'at least one message to be passed.'
)
err_msg_lower = exc.args[1].lower()
return any(m.lower() in err_msg_lower for m in msgs)
class BuiltinSSLAdapter(Adapter):
"""A wrapper for integrating Python's builtin ssl module with Cheroot."""
certificate = None
"""The filename of the server SSL certificate."""
private_key = None
"""The filename of the server's private key file."""
certificate_chain = None
"""The filename of the certificate chain file."""
context = None
"""The ssl.SSLContext that will be used to wrap sockets."""
ciphers = None
"""The ciphers list of SSL."""
def __init__(
self, certificate, private_key, certificate_chain=None,
ciphers=None):
"""Set up context in addition to base class properties if available."""
if ssl is None:
raise ImportError('You must install the ssl module to use HTTPS.')
super(BuiltinSSLAdapter, self).__init__(
certificate, private_key, certificate_chain, ciphers)
self.context = ssl.create_default_context(
purpose=ssl.Purpose.CLIENT_AUTH,
cafile=certificate_chain
)
self.context.load_cert_chain(certificate, private_key)
if self.ciphers is not None:
self.context.set_ciphers(ciphers)
def bind(self, sock):
"""Wrap and return the given socket."""
return super(BuiltinSSLAdapter, self).bind(sock)
def wrap(self, sock):
"""Wrap and return the given socket, plus WSGI environ entries."""
EMPTY_RESULT = None, {}
try:
s = self.context.wrap_socket(
sock, do_handshake_on_connect=True, server_side=True,
)
except ssl.SSLError as ex:
if ex.errno == ssl.SSL_ERROR_EOF:
# This is almost certainly due to the cherrypy engine
# 'pinging' the socket to assert it's connectable;
# the 'ping' isn't SSL.
return EMPTY_RESULT
elif ex.errno == ssl.SSL_ERROR_SSL:
if _assert_ssl_exc_contains(ex, 'http request'):
# The client is speaking HTTP to an HTTPS server.
raise errors.NoSSLError
# Check if it's one of the known errors
# Errors that are caught by PyOpenSSL, but thrown by
# built-in ssl
_block_errors = (
'unknown protocol', 'unknown ca', 'unknown_ca',
'unknown error',
'https proxy request', 'inappropriate fallback',
'wrong version number',
'no shared cipher', 'certificate unknown',
'ccs received early',
)
if _assert_ssl_exc_contains(ex, *_block_errors):
# Accepted error, let's pass
return EMPTY_RESULT
elif _assert_ssl_exc_contains(ex, 'handshake operation timed out'):
# This error is thrown by builtin SSL after a timeout
# when client is speaking HTTP to an HTTPS server.
# The connection can safely be dropped.
return EMPTY_RESULT
raise
except generic_socket_error as exc:
"""It is unclear why exactly this happens.
It's reproducible only under Python 2 with openssl>1.0 and stdlib
``ssl`` wrapper, and only with CherryPy.
So it looks like some healthcheck tries to connect to this socket
during startup (from the same process).
Ref: https://github.com/cherrypy/cherrypy/issues/1618
"""
if six.PY2 and exc.args == (0, 'Error'):
return EMPTY_RESULT
raise
return s, self.get_environ(s)
# TODO: fill this out more with mod ssl env
def get_environ(self, sock):
"""Create WSGI environ entries to be merged into each request."""
cipher = sock.cipher()
ssl_environ = {
'wsgi.url_scheme': 'https',
'HTTPS': 'on',
'SSL_PROTOCOL': cipher[1],
'SSL_CIPHER': cipher[0]
# SSL_VERSION_INTERFACE string The mod_ssl program version
# SSL_VERSION_LIBRARY string The OpenSSL program version
}
return ssl_environ
def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):
"""Return socket file object."""
return MakeFile(sock, mode, bufsize)

View file

@ -0,0 +1,267 @@
"""
A library for integrating pyOpenSSL with Cheroot.
The OpenSSL module must be importable for SSL functionality.
You can obtain it from `here <https://launchpad.net/pyopenssl>`_.
To use this module, set HTTPServer.ssl_adapter to an instance of
ssl.Adapter. There are two ways to use SSL:
Method One
----------
* ``ssl_adapter.context``: an instance of SSL.Context.
If this is not None, it is assumed to be an SSL.Context instance,
and will be passed to SSL.Connection on bind(). The developer is
responsible for forming a valid Context object. This approach is
to be preferred for more flexibility, e.g. if the cert and key are
streams instead of files, or need decryption, or SSL.SSLv3_METHOD
is desired instead of the default SSL.SSLv23_METHOD, etc. Consult
the pyOpenSSL documentation for complete options.
Method Two (shortcut)
---------------------
* ``ssl_adapter.certificate``: the filename of the server SSL certificate.
* ``ssl_adapter.private_key``: the filename of the server's private key file.
Both are None by default. If ssl_adapter.context is None, but .private_key
and .certificate are both given and valid, they will be read, and the
context will be automatically created from them.
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket
import threading
import time
try:
from OpenSSL import SSL
from OpenSSL import crypto
except ImportError:
SSL = None
from . import Adapter
from .. import errors, server as cheroot_server
from ..makefile import MakeFile
class SSL_fileobject(MakeFile):
"""SSL file object attached to a socket object."""
ssl_timeout = 3
ssl_retry = .01
def _safe_call(self, is_reader, call, *args, **kwargs):
"""Wrap the given call with SSL error-trapping.
is_reader: if False EOF errors will be raised. If True, EOF errors
will return "" (to emulate normal sockets).
"""
start = time.time()
while True:
try:
return call(*args, **kwargs)
except SSL.WantReadError:
# Sleep and try again. This is dangerous, because it means
# the rest of the stack has no way of differentiating
# between a "new handshake" error and "client dropped".
# Note this isn't an endless loop: there's a timeout below.
time.sleep(self.ssl_retry)
except SSL.WantWriteError:
time.sleep(self.ssl_retry)
except SSL.SysCallError as e:
if is_reader and e.args == (-1, 'Unexpected EOF'):
return ''
errnum = e.args[0]
if is_reader and errnum in errors.socket_errors_to_ignore:
return ''
raise socket.error(errnum)
except SSL.Error as e:
if is_reader and e.args == (-1, 'Unexpected EOF'):
return ''
thirdarg = None
try:
thirdarg = e.args[0][0][2]
except IndexError:
pass
if thirdarg == 'http request':
# The client is talking HTTP to an HTTPS server.
raise errors.NoSSLError()
raise errors.FatalSSLAlert(*e.args)
if time.time() - start > self.ssl_timeout:
raise socket.timeout('timed out')
def recv(self, size):
"""Receive message of a size from the socket."""
return self._safe_call(True, super(SSL_fileobject, self).recv, size)
def sendall(self, *args, **kwargs):
"""Send whole message to the socket."""
return self._safe_call(False, super(SSL_fileobject, self).sendall,
*args, **kwargs)
def send(self, *args, **kwargs):
"""Send some part of message to the socket."""
return self._safe_call(False, super(SSL_fileobject, self).send,
*args, **kwargs)
class SSLConnection:
"""A thread-safe wrapper for an SSL.Connection.
``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``.
"""
def __init__(self, *args):
"""Initialize SSLConnection instance."""
self._ssl_conn = SSL.Connection(*args)
self._lock = threading.RLock()
for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read',
'renegotiate', 'bind', 'listen', 'connect', 'accept',
'setblocking', 'fileno', 'close', 'get_cipher_list',
'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
'makefile', 'get_app_data', 'set_app_data', 'state_string',
'sock_shutdown', 'get_peer_certificate', 'want_read',
'want_write', 'set_connect_state', 'set_accept_state',
'connect_ex', 'sendall', 'settimeout', 'gettimeout'):
exec("""def %s(self, *args):
self._lock.acquire()
try:
return self._ssl_conn.%s(*args)
finally:
self._lock.release()
""" % (f, f))
def shutdown(self, *args):
"""Shutdown the SSL connection.
Ignore all incoming args since pyOpenSSL.socket.shutdown takes no args.
"""
self._lock.acquire()
try:
return self._ssl_conn.shutdown()
finally:
self._lock.release()
class pyOpenSSLAdapter(Adapter):
"""A wrapper for integrating pyOpenSSL with Cheroot."""
certificate = None
"""The filename of the server SSL certificate."""
private_key = None
"""The filename of the server's private key file."""
certificate_chain = None
"""Optional. The filename of CA's intermediate certificate bundle.
This is needed for cheaper "chained root" SSL certificates, and should be
left as None if not required."""
context = None
"""An instance of SSL.Context."""
ciphers = None
"""The ciphers list of SSL."""
def __init__(
self, certificate, private_key, certificate_chain=None,
ciphers=None):
"""Initialize OpenSSL Adapter instance."""
if SSL is None:
raise ImportError('You must install pyOpenSSL to use HTTPS.')
super(pyOpenSSLAdapter, self).__init__(
certificate, private_key, certificate_chain, ciphers)
self._environ = None
def bind(self, sock):
"""Wrap and return the given socket."""
if self.context is None:
self.context = self.get_context()
conn = SSLConnection(self.context, sock)
self._environ = self.get_environ()
return conn
def wrap(self, sock):
"""Wrap and return the given socket, plus WSGI environ entries."""
return sock, self._environ.copy()
def get_context(self):
"""Return an SSL.Context from self attributes."""
# See https://code.activestate.com/recipes/442473/
c = SSL.Context(SSL.SSLv23_METHOD)
c.use_privatekey_file(self.private_key)
if self.certificate_chain:
c.load_verify_locations(self.certificate_chain)
c.use_certificate_file(self.certificate)
return c
def get_environ(self):
"""Return WSGI environ entries to be merged into each request."""
ssl_environ = {
'HTTPS': 'on',
# pyOpenSSL doesn't provide access to any of these AFAICT
# 'SSL_PROTOCOL': 'SSLv2',
# SSL_CIPHER string The cipher specification name
# SSL_VERSION_INTERFACE string The mod_ssl program version
# SSL_VERSION_LIBRARY string The OpenSSL program version
}
if self.certificate:
# Server certificate attributes
cert = open(self.certificate, 'rb').read()
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
ssl_environ.update({
'SSL_SERVER_M_VERSION': cert.get_version(),
'SSL_SERVER_M_SERIAL': cert.get_serial_number(),
# 'SSL_SERVER_V_START':
# Validity of server's certificate (start time),
# 'SSL_SERVER_V_END':
# Validity of server's certificate (end time),
})
for prefix, dn in [('I', cert.get_issuer()),
('S', cert.get_subject())]:
# X509Name objects don't seem to have a way to get the
# complete DN string. Use str() and slice it instead,
# because str(dn) == "<X509Name object '/C=US/ST=...'>"
dnstr = str(dn)[18:-2]
wsgikey = 'SSL_SERVER_%s_DN' % prefix
ssl_environ[wsgikey] = dnstr
# The DN should be of the form: /k1=v1/k2=v2, but we must allow
# for any value to contain slashes itself (in a URL).
while dnstr:
pos = dnstr.rfind('=')
dnstr, value = dnstr[:pos], dnstr[pos + 1:]
pos = dnstr.rfind('/')
dnstr, key = dnstr[:pos], dnstr[pos + 1:]
if key and value:
wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key)
ssl_environ[wsgikey] = value
return ssl_environ
def makefile(self, sock, mode='r', bufsize=-1):
"""Return socket file object."""
if SSL and isinstance(sock, SSL.ConnectionType):
timeout = sock.gettimeout()
f = SSL_fileobject(sock, mode, bufsize)
f.ssl_timeout = timeout
return f
else:
return cheroot_server.CP_fileobject(sock, mode, bufsize)

View file

@ -0,0 +1 @@
"""Cheroot test suite."""

View file

@ -0,0 +1,27 @@
"""Pytest configuration module.
Contains fixtures, which are tightly bound to the Cheroot framework
itself, useless for end-users' app testing.
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
from ..testing import ( # noqa: F401
native_server, wsgi_server,
)
from ..testing import get_server_client
@pytest.fixture # noqa: F811
def wsgi_server_client(wsgi_server):
"""Create a test client out of given WSGI server."""
return get_server_client(wsgi_server)
@pytest.fixture # noqa: F811
def native_server_client(native_server):
"""Create a test client out of given HTTP server."""
return get_server_client(native_server)

View file

@ -0,0 +1,169 @@
"""A library of helper functions for the Cheroot test suite."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import datetime
import logging
import os
import sys
import time
import threading
import types
from six.moves import http_client
import six
import cheroot.server
import cheroot.wsgi
from cheroot.test import webtest
log = logging.getLogger(__name__)
thisdir = os.path.abspath(os.path.dirname(__file__))
serverpem = os.path.join(os.getcwd(), thisdir, 'test.pem')
config = {
'bind_addr': ('127.0.0.1', 54583),
'server': 'wsgi',
'wsgi_app': None,
}
class CherootWebCase(webtest.WebCase):
"""Helper class for a web app test suite."""
script_name = ''
scheme = 'http'
available_servers = {
'wsgi': cheroot.wsgi.Server,
'native': cheroot.server.HTTPServer,
}
@classmethod
def setup_class(cls):
"""Create and run one HTTP server per class."""
conf = config.copy()
conf.update(getattr(cls, 'config', {}))
s_class = conf.pop('server', 'wsgi')
server_factory = cls.available_servers.get(s_class)
if server_factory is None:
raise RuntimeError('Unknown server in config: %s' % conf['server'])
cls.httpserver = server_factory(**conf)
cls.HOST, cls.PORT = cls.httpserver.bind_addr
if cls.httpserver.ssl_adapter is None:
ssl = ''
cls.scheme = 'http'
else:
ssl = ' (ssl)'
cls.HTTP_CONN = http_client.HTTPSConnection
cls.scheme = 'https'
v = sys.version.split()[0]
log.info('Python version used to run this test script: %s' % v)
log.info('Cheroot version: %s' % cheroot.__version__)
log.info('HTTP server version: %s%s' % (cls.httpserver.protocol, ssl))
log.info('PID: %s' % os.getpid())
if hasattr(cls, 'setup_server'):
# Clear the wsgi server so that
# it can be updated with the new root
cls.setup_server()
cls.start()
@classmethod
def teardown_class(cls):
"""Cleanup HTTP server."""
if hasattr(cls, 'setup_server'):
cls.stop()
@classmethod
def start(cls):
"""Load and start the HTTP server."""
threading.Thread(target=cls.httpserver.safe_start).start()
while not cls.httpserver.ready:
time.sleep(0.1)
@classmethod
def stop(cls):
"""Terminate HTTP server."""
cls.httpserver.stop()
td = getattr(cls, 'teardown', None)
if td:
td()
date_tolerance = 2
def assertEqualDates(self, dt1, dt2, seconds=None):
"""Assert abs(dt1 - dt2) is within Y seconds."""
if seconds is None:
seconds = self.date_tolerance
if dt1 > dt2:
diff = dt1 - dt2
else:
diff = dt2 - dt1
if not diff < datetime.timedelta(seconds=seconds):
raise AssertionError('%r and %r are not within %r seconds.' %
(dt1, dt2, seconds))
class Request:
"""HTTP request container."""
def __init__(self, environ):
"""Initialize HTTP request."""
self.environ = environ
class Response:
"""HTTP response container."""
def __init__(self):
"""Initialize HTTP response."""
self.status = '200 OK'
self.headers = {'Content-Type': 'text/html'}
self.body = None
def output(self):
"""Generate iterable response body object."""
if self.body is None:
return []
elif isinstance(self.body, six.text_type):
return [self.body.encode('iso-8859-1')]
elif isinstance(self.body, six.binary_type):
return [self.body]
else:
return [x.encode('iso-8859-1') for x in self.body]
class Controller:
"""WSGI app for tests."""
def __call__(self, environ, start_response):
"""WSGI request handler."""
req, resp = Request(environ), Response()
try:
# Python 3 supports unicode attribute names
# Python 2 encodes them
handler = self.handlers[environ['PATH_INFO']]
except KeyError:
resp.status = '404 Not Found'
else:
output = handler(req, resp)
if (output is not None and
not any(resp.status.startswith(status_code)
for status_code in ('204', '304'))):
resp.body = output
try:
resp.headers.setdefault('Content-Length', str(len(output)))
except TypeError:
if not isinstance(output, types.GeneratorType):
raise
start_response(resp.status, resp.headers.items())
return resp.output()

View file

@ -0,0 +1,38 @@
-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQDBKo554mzIMY+AByUNpaUOP9bJnQ7ZLQe9XgHwoLJR4VzpyZZZ
R9L4WtImEew05FY3Izerfm3MN3+MC0tJ6yQU9sOiU3vBW6RrLIMlfKsnRwBRZ0Kn
da+O6xldVSosu8Ev3z9VZ94iC/ZgKzrH7Mjj/U8/MQO7RBS/LAqee8bFNQIDAQAB
AoGAWOCF0ZrWxn3XMucWq2LNwPKqlvVGwbIwX3cDmX22zmnM4Fy6arXbYh4XlyCj
9+ofqRrxIFz5k/7tFriTmZ0xag5+Jdx+Kwg0/twiP7XCNKipFogwe1Hznw8OFAoT
enKBdj2+/n2o0Bvo/tDB59m9L/538d46JGQUmJlzMyqYikECQQDyoq+8CtMNvE18
8VgHcR/KtApxWAjj4HpaHYL637ATjThetUZkW92mgDgowyplthusxdNqhHWyv7E8
tWNdYErZAkEAy85ShTR0M5aWmrE7o0r0SpWInAkNBH9aXQRRARFYsdBtNfRu6I0i
0lvU9wiu3eF57FMEC86yViZ5UBnQfTu7vQJAVesj/Zt7pwaCDfdMa740OsxMUlyR
MVhhGx4OLpYdPJ8qUecxGQKq13XZ7R1HGyNEY4bd2X80Smq08UFuATfC6QJAH8UB
yBHtKz2GLIcELOg6PIYizW/7v3+6rlVF60yw7sb2vzpjL40QqIn4IKoR2DSVtOkb
8FtAIX3N21aq0VrGYQJBAIPiaEc2AZ8Bq2GC4F3wOz/BxJ/izvnkiotR12QK4fh5
yjZMhTjWCas5zwHR5PDjlD88AWGDMsZ1PicD4348xJQ=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIIDxTCCAy6gAwIBAgIJAI18BD7eQxlGMA0GCSqGSIb3DQEBBAUAMIGeMQswCQYD
VQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTESMBAGA1UEBxMJU2FuIERpZWdv
MRkwFwYDVQQKExBDaGVycnlQeSBQcm9qZWN0MREwDwYDVQQLEwhkZXYtdGVzdDEW
MBQGA1UEAxMNQ2hlcnJ5UHkgVGVhbTEgMB4GCSqGSIb3DQEJARYRcmVtaUBjaGVy
cnlweS5vcmcwHhcNMDYwOTA5MTkyMDIwWhcNMzQwMTI0MTkyMDIwWjCBnjELMAkG
A1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEjAQBgNVBAcTCVNhbiBEaWVn
bzEZMBcGA1UEChMQQ2hlcnJ5UHkgUHJvamVjdDERMA8GA1UECxMIZGV2LXRlc3Qx
FjAUBgNVBAMTDUNoZXJyeVB5IFRlYW0xIDAeBgkqhkiG9w0BCQEWEXJlbWlAY2hl
cnJ5cHkub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDBKo554mzIMY+A
ByUNpaUOP9bJnQ7ZLQe9XgHwoLJR4VzpyZZZR9L4WtImEew05FY3Izerfm3MN3+M
C0tJ6yQU9sOiU3vBW6RrLIMlfKsnRwBRZ0Knda+O6xldVSosu8Ev3z9VZ94iC/Zg
KzrH7Mjj/U8/MQO7RBS/LAqee8bFNQIDAQABo4IBBzCCAQMwHQYDVR0OBBYEFDIQ
2feb71tVZCWpU0qJ/Tw+wdtoMIHTBgNVHSMEgcswgciAFDIQ2feb71tVZCWpU0qJ
/Tw+wdtooYGkpIGhMIGeMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5p
YTESMBAGA1UEBxMJU2FuIERpZWdvMRkwFwYDVQQKExBDaGVycnlQeSBQcm9qZWN0
MREwDwYDVQQLEwhkZXYtdGVzdDEWMBQGA1UEAxMNQ2hlcnJ5UHkgVGVhbTEgMB4G
CSqGSIb3DQEJARYRcmVtaUBjaGVycnlweS5vcmeCCQCNfAQ+3kMZRjAMBgNVHRME
BTADAQH/MA0GCSqGSIb3DQEBBAUAA4GBAL7AAQz7IePV48ZTAFHKr88ntPALsL5S
8vHCZPNMevNkLTj3DYUw2BcnENxMjm1kou2F2BkvheBPNZKIhc6z4hAml3ed1xa2
D7w6e6OTcstdK/+KrPDDHeOP1dhMWNs2JE1bNlfF1LiXzYKSXpe88eCKjCXsCT/T
NluCaWQys3MS
-----END CERTIFICATE-----

View file

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
"""Test suite for cross-python compatibility helpers."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest
import six
from cheroot._compat import ntob, ntou, bton
@pytest.mark.parametrize(
'func,inp,out',
[
(ntob, 'bar', b'bar'),
(ntou, 'bar', u'bar'),
(bton, b'bar', 'bar'),
],
)
def test_compat_functions_positive(func, inp, out):
"""Check that compat functions work with correct input."""
assert func(inp, encoding='utf-8') == out
@pytest.mark.parametrize(
'func',
[
ntob,
ntou,
],
)
def test_compat_functions_negative_nonnative(func):
"""Check that compat functions fail loudly for incorrect input."""
non_native_test_str = b'bar' if six.PY3 else u'bar'
with pytest.raises(TypeError):
func(non_native_test_str, encoding='utf-8')
@pytest.mark.skip(reason='This test does not work now')
@pytest.mark.skipif(
six.PY3,
reason='This code path only appears in Python 2 version.',
)
def test_ntou_escape():
"""Check that ntou supports escape-encoding under Python 2."""
expected = u''
actual = ntou('hi'.encode('ISO-8859-1'), encoding='escape')
assert actual == expected

View file

@ -0,0 +1,897 @@
"""Tests for TCP connection handling, including proper and timely close."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket
import time
from six.moves import range, http_client, urllib
import six
import pytest
from cheroot.test import helper, webtest
timeout = 1
pov = 'pPeErRsSiIsStTeEnNcCeE oOfF vViIsSiIoOnN'
class Controller(helper.Controller):
"""Controller for serving WSGI apps."""
def hello(req, resp):
"""Render Hello world."""
return 'Hello, world!'
def pov(req, resp):
"""Render pov value."""
return pov
def stream(req, resp):
"""Render streaming response."""
if 'set_cl' in req.environ['QUERY_STRING']:
resp.headers['Content-Length'] = str(10)
def content():
for x in range(10):
yield str(x)
return content()
def upload(req, resp):
"""Process file upload and render thank."""
if not req.environ['REQUEST_METHOD'] == 'POST':
raise AssertionError("'POST' != request.method %r" %
req.environ['REQUEST_METHOD'])
return "thanks for '%s'" % req.environ['wsgi.input'].read()
def custom_204(req, resp):
"""Render response with status 204."""
resp.status = '204'
return 'Code = 204'
def custom_304(req, resp):
"""Render response with status 304."""
resp.status = '304'
return 'Code = 304'
def err_before_read(req, resp):
"""Render response with status 500."""
resp.status = '500 Internal Server Error'
return 'ok'
def one_megabyte_of_a(req, resp):
"""Render 1MB response."""
return ['a' * 1024] * 1024
def wrong_cl_buffered(req, resp):
"""Render buffered response with invalid length value."""
resp.headers['Content-Length'] = '5'
return 'I have too many bytes'
def wrong_cl_unbuffered(req, resp):
"""Render unbuffered response with invalid length value."""
resp.headers['Content-Length'] = '5'
return ['I too', ' have too many bytes']
def _munge(string):
"""Encode PATH_INFO correctly depending on Python version.
WSGI 1.0 is a mess around unicode. Create endpoints
that match the PATH_INFO that it produces.
"""
if six.PY3:
return string.encode('utf-8').decode('latin-1')
return string
handlers = {
'/hello': hello,
'/pov': pov,
'/page1': pov,
'/page2': pov,
'/page3': pov,
'/stream': stream,
'/upload': upload,
'/custom/204': custom_204,
'/custom/304': custom_304,
'/err_before_read': err_before_read,
'/one_megabyte_of_a': one_megabyte_of_a,
'/wrong_cl_buffered': wrong_cl_buffered,
'/wrong_cl_unbuffered': wrong_cl_unbuffered,
}
@pytest.fixture
def testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
app = Controller()
def _timeout(req, resp):
return str(wsgi_server.timeout)
app.handlers['/timeout'] = _timeout
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = app
wsgi_server.max_request_body_size = 1001
wsgi_server.timeout = timeout
wsgi_server.server_client = wsgi_server_client
return wsgi_server
@pytest.fixture
def test_client(testing_server):
"""Get and return a test client out of the given server."""
return testing_server.server_client
def header_exists(header_name, headers):
"""Check that a header is present."""
return header_name.lower() in (k.lower() for (k, _) in headers)
def header_has_value(header_name, header_value, headers):
"""Check that a header with a given value is present."""
return header_name.lower() in (k.lower() for (k, v) in headers
if v == header_value)
def test_HTTP11_persistent_connections(test_client):
"""Test persistent HTTP/1.1 connections."""
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Make the first request and assert there's no "Connection: close".
status_line, actual_headers, actual_resp_body = test_client.get(
'/pov', http_conn=http_connection
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Make another request on the same connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page1', http_conn=http_connection
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Test client-side close.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page2', http_conn=http_connection,
headers=[('Connection', 'close')]
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'close', actual_headers)
# Make another request on the same connection, which should error.
with pytest.raises(http_client.NotConnected):
test_client.get('/pov', http_conn=http_connection)
@pytest.mark.parametrize(
'set_cl',
(
False, # Without Content-Length
True, # With Content-Length
)
)
def test_streaming_11(test_client, set_cl):
"""Test serving of streaming responses with HTTP/1.1 protocol."""
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Make the first request and assert there's no "Connection: close".
status_line, actual_headers, actual_resp_body = test_client.get(
'/pov', http_conn=http_connection
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Make another, streamed request on the same connection.
if set_cl:
# When a Content-Length is provided, the content should stream
# without closing the connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/stream?set_cl=Yes', http_conn=http_connection
)
assert header_exists('Content-Length', actual_headers)
assert not header_has_value('Connection', 'close', actual_headers)
assert not header_exists('Transfer-Encoding', actual_headers)
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b'0123456789'
else:
# When no Content-Length response header is provided,
# streamed output will either close the connection, or use
# chunked encoding, to determine transfer-length.
status_line, actual_headers, actual_resp_body = test_client.get(
'/stream', http_conn=http_connection
)
assert not header_exists('Content-Length', actual_headers)
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b'0123456789'
chunked_response = False
for k, v in actual_headers:
if k.lower() == 'transfer-encoding':
if str(v) == 'chunked':
chunked_response = True
if chunked_response:
assert not header_has_value('Connection', 'close', actual_headers)
else:
assert header_has_value('Connection', 'close', actual_headers)
# Make another request on the same connection, which should
# error.
with pytest.raises(http_client.NotConnected):
test_client.get('/pov', http_conn=http_connection)
# Try HEAD.
# See https://www.bitbucket.org/cherrypy/cherrypy/issue/864.
# TODO: figure out how can this be possible on an closed connection
# (chunked_response case)
status_line, actual_headers, actual_resp_body = test_client.head(
'/stream', http_conn=http_connection
)
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b''
assert not header_exists('Transfer-Encoding', actual_headers)
@pytest.mark.parametrize(
'set_cl',
(
False, # Without Content-Length
True, # With Content-Length
)
)
def test_streaming_10(test_client, set_cl):
"""Test serving of streaming responses with HTTP/1.0 protocol."""
original_server_protocol = test_client.server_instance.protocol
test_client.server_instance.protocol = 'HTTP/1.0'
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Make the first request and assert Keep-Alive.
status_line, actual_headers, actual_resp_body = test_client.get(
'/pov', http_conn=http_connection,
headers=[('Connection', 'Keep-Alive')],
protocol='HTTP/1.0',
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
# Make another, streamed request on the same connection.
if set_cl:
# When a Content-Length is provided, the content should
# stream without closing the connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/stream?set_cl=Yes', http_conn=http_connection,
headers=[('Connection', 'Keep-Alive')],
protocol='HTTP/1.0',
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b'0123456789'
assert header_exists('Content-Length', actual_headers)
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
assert not header_exists('Transfer-Encoding', actual_headers)
else:
# When a Content-Length is not provided,
# the server should close the connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/stream', http_conn=http_connection,
headers=[('Connection', 'Keep-Alive')],
protocol='HTTP/1.0',
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == b'0123456789'
assert not header_exists('Content-Length', actual_headers)
assert not header_has_value('Connection', 'Keep-Alive', actual_headers)
assert not header_exists('Transfer-Encoding', actual_headers)
# Make another request on the same connection, which should error.
with pytest.raises(http_client.NotConnected):
test_client.get(
'/pov', http_conn=http_connection,
protocol='HTTP/1.0',
)
test_client.server_instance.protocol = original_server_protocol
@pytest.mark.parametrize(
'http_server_protocol',
(
'HTTP/1.0',
'HTTP/1.1',
)
)
def test_keepalive(test_client, http_server_protocol):
"""Test Keep-Alive enabled connections."""
original_server_protocol = test_client.server_instance.protocol
test_client.server_instance.protocol = http_server_protocol
http_client_protocol = 'HTTP/1.0'
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Test a normal HTTP/1.0 request.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page2',
protocol=http_client_protocol,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Test a keep-alive HTTP/1.0 request.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page3', headers=[('Connection', 'Keep-Alive')],
http_conn=http_connection, protocol=http_client_protocol,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
# Remove the keep-alive header again.
status_line, actual_headers, actual_resp_body = test_client.get(
'/page3', http_conn=http_connection,
protocol=http_client_protocol,
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
test_client.server_instance.protocol = original_server_protocol
@pytest.mark.parametrize(
'timeout_before_headers',
(
True,
False,
)
)
def test_HTTP11_Timeout(test_client, timeout_before_headers):
"""Check timeout without sending any data.
The server will close the conn with a 408.
"""
conn = test_client.get_connection()
conn.auto_open = False
conn.connect()
if not timeout_before_headers:
# Connect but send half the headers only.
conn.send(b'GET /hello HTTP/1.1')
conn.send(('Host: %s' % conn.host).encode('ascii'))
# else: Connect but send nothing.
# Wait for our socket timeout
time.sleep(timeout * 2)
# The request should have returned 408 already.
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 408
conn.close()
def test_HTTP11_Timeout_after_request(test_client):
"""Check timeout after at least one request has succeeded.
The server should close the connection without 408.
"""
fail_msg = "Writing to timed out socket didn't fail as it should have: %s"
# Make an initial request
conn = test_client.get_connection()
conn.putrequest('GET', '/timeout?t=%s' % timeout, skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 200
actual_body = response.read()
expected_body = str(timeout).encode()
assert actual_body == expected_body
# Make a second request on the same socket
conn._output(b'GET /hello HTTP/1.1')
conn._output(('Host: %s' % conn.host).encode('ascii'))
conn._send_output()
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 200
actual_body = response.read()
expected_body = b'Hello, world!'
assert actual_body == expected_body
# Wait for our socket timeout
time.sleep(timeout * 2)
# Make another request on the same socket, which should error
conn._output(b'GET /hello HTTP/1.1')
conn._output(('Host: %s' % conn.host).encode('ascii'))
conn._send_output()
response = conn.response_class(conn.sock, method='GET')
try:
response.begin()
except (socket.error, http_client.BadStatusLine):
pass
except Exception as ex:
pytest.fail(fail_msg % ex)
else:
if response.status != 408:
pytest.fail(fail_msg % response.read())
conn.close()
# Make another request on a new socket, which should work
conn = test_client.get_connection()
conn.putrequest('GET', '/pov', skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 200
actual_body = response.read()
expected_body = pov.encode()
assert actual_body == expected_body
# Make another request on the same socket,
# but timeout on the headers
conn.send(b'GET /hello HTTP/1.1')
# Wait for our socket timeout
time.sleep(timeout * 2)
response = conn.response_class(conn.sock, method='GET')
try:
response.begin()
except (socket.error, http_client.BadStatusLine):
pass
except Exception as ex:
pytest.fail(fail_msg % ex)
else:
if response.status != 408:
pytest.fail(fail_msg % response.read())
conn.close()
# Retry the request on a new connection, which should work
conn = test_client.get_connection()
conn.putrequest('GET', '/pov', skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
response = conn.response_class(conn.sock, method='GET')
response.begin()
assert response.status == 200
actual_body = response.read()
expected_body = pov.encode()
assert actual_body == expected_body
conn.close()
def test_HTTP11_pipelining(test_client):
"""Test HTTP/1.1 pipelining.
httplib doesn't support this directly.
"""
conn = test_client.get_connection()
# Put request 1
conn.putrequest('GET', '/hello', skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
for trial in range(5):
# Put next request
conn._output(
('GET /hello?%s HTTP/1.1' % trial).encode('iso-8859-1')
)
conn._output(('Host: %s' % conn.host).encode('ascii'))
conn._send_output()
# Retrieve previous response
response = conn.response_class(conn.sock, method='GET')
# there is a bug in python3 regarding the buffering of
# ``conn.sock``. Until that bug get's fixed we will
# monkey patch the ``reponse`` instance.
# https://bugs.python.org/issue23377
if six.PY3:
response.fp = conn.sock.makefile('rb', 0)
response.begin()
body = response.read(13)
assert response.status == 200
assert body == b'Hello, world!'
# Retrieve final response
response = conn.response_class(conn.sock, method='GET')
response.begin()
body = response.read()
assert response.status == 200
assert body == b'Hello, world!'
conn.close()
def test_100_Continue(test_client):
"""Test 100-continue header processing."""
conn = test_client.get_connection()
# Try a page without an Expect request header first.
# Note that httplib's response.begin automatically ignores
# 100 Continue responses, so we must manually check for it.
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Content-Type', 'text/plain')
conn.putheader('Content-Length', '4')
conn.endheaders()
conn.send(b"d'oh")
response = conn.response_class(conn.sock, method='POST')
version, status, reason = response._read_status()
assert status != 100
conn.close()
# Now try a page with an Expect header...
conn.connect()
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Content-Type', 'text/plain')
conn.putheader('Content-Length', '17')
conn.putheader('Expect', '100-continue')
conn.endheaders()
response = conn.response_class(conn.sock, method='POST')
# ...assert and then skip the 100 response
version, status, reason = response._read_status()
assert status == 100
while True:
line = response.fp.readline().strip()
if line:
pytest.fail(
'100 Continue should not output any headers. Got %r' %
line)
else:
break
# ...send the body
body = b'I am a small file'
conn.send(body)
# ...get the final response
response.begin()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 200
expected_resp_body = ("thanks for '%s'" % body).encode()
assert actual_resp_body == expected_resp_body
conn.close()
@pytest.mark.parametrize(
'max_request_body_size',
(
0,
1001,
)
)
def test_readall_or_close(test_client, max_request_body_size):
"""Test a max_request_body_size of 0 (the default) and 1001."""
old_max = test_client.server_instance.max_request_body_size
test_client.server_instance.max_request_body_size = max_request_body_size
conn = test_client.get_connection()
# Get a POST page with an error
conn.putrequest('POST', '/err_before_read', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Content-Type', 'text/plain')
conn.putheader('Content-Length', '1000')
conn.putheader('Expect', '100-continue')
conn.endheaders()
response = conn.response_class(conn.sock, method='POST')
# ...assert and then skip the 100 response
version, status, reason = response._read_status()
assert status == 100
skip = True
while skip:
skip = response.fp.readline().strip()
# ...send the body
conn.send(b'x' * 1000)
# ...get the final response
response.begin()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 500
# Now try a working page with an Expect header...
conn._output(b'POST /upload HTTP/1.1')
conn._output(('Host: %s' % conn.host).encode('ascii'))
conn._output(b'Content-Type: text/plain')
conn._output(b'Content-Length: 17')
conn._output(b'Expect: 100-continue')
conn._send_output()
response = conn.response_class(conn.sock, method='POST')
# ...assert and then skip the 100 response
version, status, reason = response._read_status()
assert status == 100
skip = True
while skip:
skip = response.fp.readline().strip()
# ...send the body
body = b'I am a small file'
conn.send(body)
# ...get the final response
response.begin()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 200
expected_resp_body = ("thanks for '%s'" % body).encode()
assert actual_resp_body == expected_resp_body
conn.close()
test_client.server_instance.max_request_body_size = old_max
def test_No_Message_Body(test_client):
"""Test HTTP queries with an empty response body."""
# Initialize a persistent HTTP connection
http_connection = test_client.get_connection()
http_connection.auto_open = False
http_connection.connect()
# Make the first request and assert there's no "Connection: close".
status_line, actual_headers, actual_resp_body = test_client.get(
'/pov', http_conn=http_connection
)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
assert actual_resp_body == pov.encode()
assert not header_exists('Connection', actual_headers)
# Make a 204 request on the same connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/custom/204', http_conn=http_connection
)
actual_status = int(status_line[:3])
assert actual_status == 204
assert not header_exists('Content-Length', actual_headers)
assert actual_resp_body == b''
assert not header_exists('Connection', actual_headers)
# Make a 304 request on the same connection.
status_line, actual_headers, actual_resp_body = test_client.get(
'/custom/304', http_conn=http_connection
)
actual_status = int(status_line[:3])
assert actual_status == 304
assert not header_exists('Content-Length', actual_headers)
assert actual_resp_body == b''
assert not header_exists('Connection', actual_headers)
@pytest.mark.xfail(
reason='Server does not correctly read trailers/ending of the previous '
'HTTP request, thus the second request fails as the server tries '
r"to parse b'Content-Type: application/json\r\n' as a "
'Request-Line. This results in HTTP status code 400, instead of 413'
'Ref: https://github.com/cherrypy/cheroot/issues/69'
)
def test_Chunked_Encoding(test_client):
"""Test HTTP uploads with chunked transfer-encoding."""
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
# Try a normal chunked request (with extensions)
body = (
b'8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n'
b'Content-Type: application/json\r\n'
b'\r\n'
)
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Transfer-Encoding', 'chunked')
conn.putheader('Trailer', 'Content-Type')
# Note that this is somewhat malformed:
# we shouldn't be sending Content-Length.
# RFC 2616 says the server should ignore it.
conn.putheader('Content-Length', '3')
conn.endheaders()
conn.send(body)
response = conn.getresponse()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 200
assert status_line[4:] == 'OK'
expected_resp_body = ("thanks for '%s'" % b'xx\r\nxxxxyyyyy').encode()
assert actual_resp_body == expected_resp_body
# Try a chunked request that exceeds server.max_request_body_size.
# Note that the delimiters and trailer are included.
body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n'
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Transfer-Encoding', 'chunked')
conn.putheader('Content-Type', 'text/plain')
# Chunked requests don't need a content-length
# conn.putheader("Content-Length", len(body))
conn.endheaders()
conn.send(body)
response = conn.getresponse()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 413
conn.close()
def test_Content_Length_in(test_client):
"""Try a non-chunked request where Content-Length exceeds limit.
(server.max_request_body_size).
Assert error before body send.
"""
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
conn.putrequest('POST', '/upload', skip_host=True)
conn.putheader('Host', conn.host)
conn.putheader('Content-Type', 'text/plain')
conn.putheader('Content-Length', '9999')
conn.endheaders()
response = conn.getresponse()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 413
expected_resp_body = (
b'The entity sent with the request exceeds '
b'the maximum allowed bytes.'
)
assert actual_resp_body == expected_resp_body
conn.close()
def test_Content_Length_not_int(test_client):
"""Test that malicious Content-Length header returns 400."""
status_line, actual_headers, actual_resp_body = test_client.post(
'/upload',
headers=[
('Content-Type', 'text/plain'),
('Content-Length', 'not-an-integer'),
],
)
actual_status = int(status_line[:3])
assert actual_status == 400
assert actual_resp_body == b'Malformed Content-Length Header.'
@pytest.mark.parametrize(
'uri,expected_resp_status,expected_resp_body',
(
('/wrong_cl_buffered', 500,
(b'The requested resource returned more bytes than the '
b'declared Content-Length.')),
('/wrong_cl_unbuffered', 200, b'I too'),
)
)
def test_Content_Length_out(
test_client,
uri, expected_resp_status, expected_resp_body
):
"""Test response with Content-Length less than the response body.
(non-chunked response)
"""
conn = test_client.get_connection()
conn.putrequest('GET', uri, skip_host=True)
conn.putheader('Host', conn.host)
conn.endheaders()
response = conn.getresponse()
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == expected_resp_status
assert actual_resp_body == expected_resp_body
conn.close()
@pytest.mark.xfail(
reason='Sometimes this test fails due to low timeout. '
'Ref: https://github.com/cherrypy/cherrypy/issues/598'
)
def test_598(test_client):
"""Test serving large file with a read timeout in place."""
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
remote_data_conn = urllib.request.urlopen(
'%s://%s:%s/one_megabyte_of_a'
% ('http', conn.host, conn.port)
)
buf = remote_data_conn.read(512)
time.sleep(timeout * 0.6)
remaining = (1024 * 1024) - 512
while remaining:
data = remote_data_conn.read(remaining)
if not data:
break
buf += data
remaining -= len(data)
assert len(buf) == 1024 * 1024
assert buf == b'a' * 1024 * 1024
assert remaining == 0
remote_data_conn.close()
@pytest.mark.parametrize(
'invalid_terminator',
(
b'\n\n',
b'\r\n\n',
)
)
def test_No_CRLF(test_client, invalid_terminator):
"""Test HTTP queries with no valid CRLF terminators."""
# Initialize a persistent HTTP connection
conn = test_client.get_connection()
# (b'%s' % b'') is not supported in Python 3.4, so just use +
conn.send(b'GET /hello HTTP/1.1' + invalid_terminator)
response = conn.response_class(conn.sock, method='GET')
response.begin()
actual_resp_body = response.read()
expected_resp_body = b'HTTP requires CRLF terminators'
assert actual_resp_body == expected_resp_body
conn.close()

View file

@ -0,0 +1,405 @@
"""Tests for managing HTTP issues (malformed requests, etc)."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import errno
import socket
import pytest
import six
from six.moves import urllib
from cheroot.test import helper
HTTP_BAD_REQUEST = 400
HTTP_LENGTH_REQUIRED = 411
HTTP_NOT_FOUND = 404
HTTP_OK = 200
HTTP_VERSION_NOT_SUPPORTED = 505
class HelloController(helper.Controller):
"""Controller for serving WSGI apps."""
def hello(req, resp):
"""Render Hello world."""
return 'Hello world!'
def body_required(req, resp):
"""Render Hello world or set 411."""
if req.environ.get('Content-Length', None) is None:
resp.status = '411 Length Required'
return
return 'Hello world!'
def query_string(req, resp):
"""Render QUERY_STRING value."""
return req.environ.get('QUERY_STRING', '')
def asterisk(req, resp):
"""Render request method value."""
method = req.environ.get('REQUEST_METHOD', 'NO METHOD FOUND')
tmpl = 'Got asterisk URI path with {method} method'
return tmpl.format(**locals())
def _munge(string):
"""Encode PATH_INFO correctly depending on Python version.
WSGI 1.0 is a mess around unicode. Create endpoints
that match the PATH_INFO that it produces.
"""
if six.PY3:
return string.encode('utf-8').decode('latin-1')
return string
handlers = {
'/hello': hello,
'/no_body': hello,
'/body_required': body_required,
'/query_string': query_string,
_munge('/привіт'): hello,
_munge('/Юххууу'): hello,
'/\xa0Ðblah key 0 900 4 data': hello,
'/*': asterisk,
}
def _get_http_response(connection, method='GET'):
c = connection
kwargs = {'strict': c.strict} if hasattr(c, 'strict') else {}
# Python 3.2 removed the 'strict' feature, saying:
# "http.client now always assumes HTTP/1.x compliant servers."
return c.response_class(c.sock, method=method, **kwargs)
@pytest.fixture
def testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = HelloController()
wsgi_server.max_request_body_size = 30000000
wsgi_server.server_client = wsgi_server_client
return wsgi_server
@pytest.fixture
def test_client(testing_server):
"""Get and return a test client out of the given server."""
return testing_server.server_client
def test_http_connect_request(test_client):
"""Check that CONNECT query results in Method Not Allowed status."""
status_line = test_client.connect('/anything')[0]
actual_status = int(status_line[:3])
assert actual_status == 405
def test_normal_request(test_client):
"""Check that normal GET query succeeds."""
status_line, _, actual_resp_body = test_client.get('/hello')
actual_status = int(status_line[:3])
assert actual_status == HTTP_OK
assert actual_resp_body == b'Hello world!'
def test_query_string_request(test_client):
"""Check that GET param is parsed well."""
status_line, _, actual_resp_body = test_client.get(
'/query_string?test=True'
)
actual_status = int(status_line[:3])
assert actual_status == HTTP_OK
assert actual_resp_body == b'test=True'
@pytest.mark.parametrize(
'uri',
(
'/hello', # plain
'/query_string?test=True', # query
'/{0}?{1}={2}'.format( # quoted unicode
*map(urllib.parse.quote, ('Юххууу', 'ї', 'йо'))
),
)
)
def test_parse_acceptable_uri(test_client, uri):
"""Check that server responds with OK to valid GET queries."""
status_line = test_client.get(uri)[0]
actual_status = int(status_line[:3])
assert actual_status == HTTP_OK
@pytest.mark.xfail(six.PY2, reason='Fails on Python 2')
def test_parse_uri_unsafe_uri(test_client):
"""Test that malicious URI does not allow HTTP injection.
This effectively checks that sending GET request with URL
/%A0%D0blah%20key%200%20900%204%20data
is not converted into
GET /
blah key 0 900 4 data
HTTP/1.1
which would be a security issue otherwise.
"""
c = test_client.get_connection()
resource = '/\xa0Ðblah key 0 900 4 data'.encode('latin-1')
quoted = urllib.parse.quote(resource)
assert quoted == '/%A0%D0blah%20key%200%20900%204%20data'
request = 'GET {quoted} HTTP/1.1'.format(**locals())
c._output(request.encode('utf-8'))
c._send_output()
response = _get_http_response(c, method='GET')
response.begin()
assert response.status == HTTP_OK
assert response.fp.read(12) == b'Hello world!'
c.close()
def test_parse_uri_invalid_uri(test_client):
"""Check that server responds with Bad Request to invalid GET queries.
Invalid request line test case: it should only contain US-ASCII.
"""
c = test_client.get_connection()
c._output(u'GET /йопта! HTTP/1.1'.encode('utf-8'))
c._send_output()
response = _get_http_response(c, method='GET')
response.begin()
assert response.status == HTTP_BAD_REQUEST
assert response.fp.read(21) == b'Malformed Request-URI'
c.close()
@pytest.mark.parametrize(
'uri',
(
'hello', # ascii
'привіт', # non-ascii
)
)
def test_parse_no_leading_slash_invalid(test_client, uri):
"""Check that server responds with Bad Request to invalid GET queries.
Invalid request line test case: it should have leading slash (be absolute).
"""
status_line, _, actual_resp_body = test_client.get(
urllib.parse.quote(uri)
)
actual_status = int(status_line[:3])
assert actual_status == HTTP_BAD_REQUEST
assert b'starting with a slash' in actual_resp_body
def test_parse_uri_absolute_uri(test_client):
"""Check that server responds with Bad Request to Absolute URI.
Only proxy servers should allow this.
"""
status_line, _, actual_resp_body = test_client.get('http://google.com/')
actual_status = int(status_line[:3])
assert actual_status == HTTP_BAD_REQUEST
expected_body = b'Absolute URI not allowed if server is not a proxy.'
assert actual_resp_body == expected_body
def test_parse_uri_asterisk_uri(test_client):
"""Check that server responds with OK to OPTIONS with "*" Absolute URI."""
status_line, _, actual_resp_body = test_client.options('*')
actual_status = int(status_line[:3])
assert actual_status == HTTP_OK
expected_body = b'Got asterisk URI path with OPTIONS method'
assert actual_resp_body == expected_body
def test_parse_uri_fragment_uri(test_client):
"""Check that server responds with Bad Request to URI with fragment."""
status_line, _, actual_resp_body = test_client.get(
'/hello?test=something#fake',
)
actual_status = int(status_line[:3])
assert actual_status == HTTP_BAD_REQUEST
expected_body = b'Illegal #fragment in Request-URI.'
assert actual_resp_body == expected_body
def test_no_content_length(test_client):
"""Test POST query with an empty body being successful."""
# "The presence of a message-body in a request is signaled by the
# inclusion of a Content-Length or Transfer-Encoding header field in
# the request's message-headers."
#
# Send a message with neither header and no body.
c = test_client.get_connection()
c.request('POST', '/no_body')
response = c.getresponse()
actual_resp_body = response.fp.read()
actual_status = response.status
assert actual_status == HTTP_OK
assert actual_resp_body == b'Hello world!'
def test_content_length_required(test_client):
"""Test POST query with body failing because of missing Content-Length."""
# Now send a message that has no Content-Length, but does send a body.
# Verify that CP times out the socket and responds
# with 411 Length Required.
c = test_client.get_connection()
c.request('POST', '/body_required')
response = c.getresponse()
response.fp.read()
actual_status = response.status
assert actual_status == HTTP_LENGTH_REQUIRED
@pytest.mark.parametrize(
'request_line,status_code,expected_body',
(
(b'GET /', # missing proto
HTTP_BAD_REQUEST, b'Malformed Request-Line'),
(b'GET / HTTPS/1.1', # invalid proto
HTTP_BAD_REQUEST, b'Malformed Request-Line: bad protocol'),
(b'GET / HTTP/2.15', # invalid ver
HTTP_VERSION_NOT_SUPPORTED, b'Cannot fulfill request'),
)
)
def test_malformed_request_line(
test_client, request_line,
status_code, expected_body
):
"""Test missing or invalid HTTP version in Request-Line."""
c = test_client.get_connection()
c._output(request_line)
c._send_output()
response = _get_http_response(c, method='GET')
response.begin()
assert response.status == status_code
assert response.fp.read(len(expected_body)) == expected_body
c.close()
def test_malformed_http_method(test_client):
"""Test non-uppercase HTTP method."""
c = test_client.get_connection()
c.putrequest('GeT', '/malformed_method_case')
c.putheader('Content-Type', 'text/plain')
c.endheaders()
response = c.getresponse()
actual_status = response.status
assert actual_status == HTTP_BAD_REQUEST
actual_resp_body = response.fp.read(21)
assert actual_resp_body == b'Malformed method name'
def test_malformed_header(test_client):
"""Check that broken HTTP header results in Bad Request."""
c = test_client.get_connection()
c.putrequest('GET', '/')
c.putheader('Content-Type', 'text/plain')
# See https://www.bitbucket.org/cherrypy/cherrypy/issue/941
c._output(b'Re, 1.2.3.4#015#012')
c.endheaders()
response = c.getresponse()
actual_status = response.status
assert actual_status == HTTP_BAD_REQUEST
actual_resp_body = response.fp.read(20)
assert actual_resp_body == b'Illegal header line.'
def test_request_line_split_issue_1220(test_client):
"""Check that HTTP request line of exactly 256 chars length is OK."""
Request_URI = (
'/hello?'
'intervenant-entreprise-evenement_classaction='
'evenement-mailremerciements'
'&_path=intervenant-entreprise-evenement'
'&intervenant-entreprise-evenement_action-id=19404'
'&intervenant-entreprise-evenement_id=19404'
'&intervenant-entreprise_id=28092'
)
assert len('GET %s HTTP/1.1\r\n' % Request_URI) == 256
actual_resp_body = test_client.get(Request_URI)[2]
assert actual_resp_body == b'Hello world!'
def test_garbage_in(test_client):
"""Test that server sends an error for garbage received over TCP."""
# Connect without SSL regardless of server.scheme
c = test_client.get_connection()
c._output(b'gjkgjklsgjklsgjkljklsg')
c._send_output()
response = c.response_class(c.sock, method='GET')
try:
response.begin()
actual_status = response.status
assert actual_status == HTTP_BAD_REQUEST
actual_resp_body = response.fp.read(22)
assert actual_resp_body == b'Malformed Request-Line'
c.close()
except socket.error as ex:
# "Connection reset by peer" is also acceptable.
if ex.errno != errno.ECONNRESET:
raise
class CloseController:
"""Controller for testing the close callback."""
def __call__(self, environ, start_response):
"""Get the req to know header sent status."""
self.req = start_response.__self__.req
resp = CloseResponse(self.close)
start_response(resp.status, resp.headers.items())
return resp
def close(self):
"""Close, writing hello."""
self.req.write(b'hello')
class CloseResponse:
"""Dummy empty response to trigger the no body status."""
def __init__(self, close):
"""Use some defaults to ensure we have a header."""
self.status = '200 OK'
self.headers = {'Content-Type': 'text/html'}
self.close = close
def __getitem__(self, index):
"""Ensure we don't have a body."""
raise IndexError()
def output(self):
"""Return self to hook the close method."""
return self
@pytest.fixture
def testing_server_close(wsgi_server_client):
"""Attach a WSGI app to the given server and pre-configure it."""
wsgi_server = wsgi_server_client.server_instance
wsgi_server.wsgi_app = CloseController()
wsgi_server.max_request_body_size = 30000000
wsgi_server.server_client = wsgi_server_client
return wsgi_server
def test_send_header_before_closing(testing_server_close):
"""Test we are actually sending the headers before calling 'close'."""
_, _, resp_body = testing_server_close.server_client.get('/')
assert resp_body == b'hello'

View file

@ -0,0 +1,193 @@
"""Tests for the HTTP server."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import os
import socket
import tempfile
import threading
import time
import pytest
from .._compat import bton
from ..server import Gateway, HTTPServer
from ..testing import (
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
EPHEMERAL_PORT,
get_server_client,
)
def make_http_server(bind_addr):
"""Create and start an HTTP server bound to bind_addr."""
httpserver = HTTPServer(
bind_addr=bind_addr,
gateway=Gateway,
)
threading.Thread(target=httpserver.safe_start).start()
while not httpserver.ready:
time.sleep(0.1)
return httpserver
non_windows_sock_test = pytest.mark.skipif(
not hasattr(socket, 'AF_UNIX'),
reason='UNIX domain sockets are only available under UNIX-based OS',
)
@pytest.fixture
def http_server():
"""Provision a server creator as a fixture."""
def start_srv():
bind_addr = yield
httpserver = make_http_server(bind_addr)
yield httpserver
yield httpserver
srv_creator = iter(start_srv())
next(srv_creator)
yield srv_creator
try:
while True:
httpserver = next(srv_creator)
if httpserver is not None:
httpserver.stop()
except StopIteration:
pass
@pytest.fixture
def unix_sock_file():
"""Check that bound UNIX socket address is stored in server."""
tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()
yield tmp_sock_fname
os.close(tmp_sock_fh)
os.unlink(tmp_sock_fname)
def test_prepare_makes_server_ready():
"""Check that prepare() makes the server ready, and stop() clears it."""
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
gateway=Gateway,
)
assert not httpserver.ready
assert not httpserver.requests._threads
httpserver.prepare()
assert httpserver.ready
assert httpserver.requests._threads
for thr in httpserver.requests._threads:
assert thr.ready
httpserver.stop()
assert not httpserver.requests._threads
assert not httpserver.ready
def test_stop_interrupts_serve():
"""Check that stop() interrupts running of serve()."""
httpserver = HTTPServer(
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
gateway=Gateway,
)
httpserver.prepare()
serve_thread = threading.Thread(target=httpserver.serve)
serve_thread.start()
serve_thread.join(0.5)
assert serve_thread.is_alive()
httpserver.stop()
serve_thread.join(0.5)
assert not serve_thread.is_alive()
@pytest.mark.parametrize(
'ip_addr',
(
ANY_INTERFACE_IPV4,
ANY_INTERFACE_IPV6,
)
)
def test_bind_addr_inet(http_server, ip_addr):
"""Check that bound IP address is stored in server."""
httpserver = http_server.send((ip_addr, EPHEMERAL_PORT))
assert httpserver.bind_addr[0] == ip_addr
assert httpserver.bind_addr[1] != EPHEMERAL_PORT
@non_windows_sock_test
def test_bind_addr_unix(http_server, unix_sock_file):
"""Check that bound UNIX socket address is stored in server."""
httpserver = http_server.send(unix_sock_file)
assert httpserver.bind_addr == unix_sock_file
@pytest.mark.skip(reason="Abstract sockets don't work currently")
@non_windows_sock_test
def test_bind_addr_unix_abstract(http_server):
"""Check that bound UNIX socket address is stored in server."""
unix_abstract_sock = b'\x00cheroot/test/socket/here.sock'
httpserver = http_server.send(unix_abstract_sock)
assert httpserver.bind_addr == unix_abstract_sock
PEERCRED_IDS_URI = '/peer_creds/ids'
PEERCRED_TEXTS_URI = '/peer_creds/texts'
class _TestGateway(Gateway):
def respond(self):
req = self.req
conn = req.conn
req_uri = bton(req.uri)
if req_uri == PEERCRED_IDS_URI:
peer_creds = conn.peer_pid, conn.peer_uid, conn.peer_gid
return ['|'.join(map(str, peer_creds))]
elif req_uri == PEERCRED_TEXTS_URI:
return ['!'.join((conn.peer_user, conn.peer_group))]
return super(_TestGateway, self).respond()
@pytest.mark.skip(
reason='Test HTTP client is not able to work through UNIX socket currently'
)
@non_windows_sock_test
def test_peercreds_unix_sock(http_server, unix_sock_file):
"""Check that peercred lookup and resolution work when enabled."""
httpserver = http_server.send(unix_sock_file)
httpserver.gateway = _TestGateway
httpserver.peercreds_enabled = True
testclient = get_server_client(httpserver)
expected_peercreds = os.getpid(), os.getuid(), os.getgid()
expected_peercreds = '|'.join(map(str, expected_peercreds))
assert testclient.get(PEERCRED_IDS_URI) == expected_peercreds
assert 'RuntimeError' in testclient.get(PEERCRED_TEXTS_URI)
httpserver.peercreds_resolve_enabled = True
import grp
expected_textcreds = os.getlogin(), grp.getgrgid(os.getgid()).gr_name
expected_textcreds = '!'.join(map(str, expected_textcreds))
assert testclient.get(PEERCRED_TEXTS_URI) == expected_textcreds

View file

@ -0,0 +1,581 @@
"""Extensions to unittest for web frameworks.
Use the WebCase.getPage method to request a page from your HTTP server.
Framework Integration
=====================
If you have control over your server process, you can handle errors
in the server-side of the HTTP conversation a bit better. You must run
both the client (your WebCase tests) and the server in the same process
(but in separate threads, obviously).
When an error occurs in the framework, call server_error. It will print
the traceback to stdout, and keep any assertions you have from running
(the assumption is that, if the server errors, the page output will not
be of further significance to your tests).
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pprint
import re
import socket
import sys
import time
import traceback
import os
import json
import unittest
import warnings
from six.moves import range, http_client, map, urllib_parse
import six
from more_itertools.more import always_iterable
def interface(host):
"""Return an IP address for a client connection given the server host.
If the server is listening on '0.0.0.0' (INADDR_ANY)
or '::' (IN6ADDR_ANY), this will return the proper localhost.
"""
if host == '0.0.0.0':
# INADDR_ANY, which should respond on localhost.
return '127.0.0.1'
if host == '::':
# IN6ADDR_ANY, which should respond on localhost.
return '::1'
return host
try:
# Jython support
if sys.platform[:4] == 'java':
def getchar():
"""Get a key press."""
# Hopefully this is enough
return sys.stdin.read(1)
else:
# On Windows, msvcrt.getch reads a single char without output.
import msvcrt
def getchar():
"""Get a key press."""
return msvcrt.getch()
except ImportError:
# Unix getchr
import tty
import termios
def getchar():
"""Get a key press."""
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(sys.stdin.fileno())
ch = sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
return ch
# from jaraco.properties
class NonDataProperty:
"""Non-data property decorator."""
def __init__(self, fget):
"""Initialize a non-data property."""
assert fget is not None, 'fget cannot be none'
assert callable(fget), 'fget must be callable'
self.fget = fget
def __get__(self, obj, objtype=None):
"""Return a class property."""
if obj is None:
return self
return self.fget(obj)
class WebCase(unittest.TestCase):
"""Helper web test suite base."""
HOST = '127.0.0.1'
PORT = 8000
HTTP_CONN = http_client.HTTPConnection
PROTOCOL = 'HTTP/1.1'
scheme = 'http'
url = None
status = None
headers = None
body = None
encoding = 'utf-8'
time = None
@property
def _Conn(self):
"""Return HTTPConnection or HTTPSConnection based on self.scheme.
* from http.client.
"""
cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper())
return getattr(http_client, cls_name)
def get_conn(self, auto_open=False):
"""Return a connection to our HTTP server."""
conn = self._Conn(self.interface(), self.PORT)
# Automatically re-connect?
conn.auto_open = auto_open
conn.connect()
return conn
def set_persistent(self, on=True, auto_open=False):
"""Make our HTTP_CONN persistent (or not).
If the 'on' argument is True (the default), then self.HTTP_CONN
will be set to an instance of HTTP(S)?Connection
to persist across requests.
As this class only allows for a single open connection, if
self already has an open connection, it will be closed.
"""
try:
self.HTTP_CONN.close()
except (TypeError, AttributeError):
pass
self.HTTP_CONN = (
self.get_conn(auto_open=auto_open)
if on
else self._Conn
)
@property
def persistent(self): # noqa: D401; irrelevant for properties
"""Presense of the persistent HTTP connection."""
return hasattr(self.HTTP_CONN, '__class__')
@persistent.setter
def persistent(self, on):
self.set_persistent(on)
def interface(self):
"""Return an IP address for a client connection.
If the server is listening on '0.0.0.0' (INADDR_ANY)
or '::' (IN6ADDR_ANY), this will return the proper localhost.
"""
return interface(self.HOST)
def getPage(self, url, headers=None, method='GET', body=None,
protocol=None, raise_subcls=None):
"""Open the url with debugging support. Return status, headers, body.
url should be the identifier passed to the server, typically a
server-absolute path and query string (sent between method and
protocol), and should only be an absolute URI if proxy support is
enabled in the server.
If the application under test generates absolute URIs, be sure
to wrap them first with strip_netloc::
class MyAppWebCase(WebCase):
def getPage(url, *args, **kwargs):
super(MyAppWebCase, self).getPage(
cheroot.test.webtest.strip_netloc(url),
*args, **kwargs
)
`raise_subcls` must be a tuple with the exceptions classes
or a single exception class that are not going to be considered
a socket.error regardless that they were are subclass of a
socket.error and therefore not considered for a connection retry.
"""
ServerError.on = False
if isinstance(url, six.text_type):
url = url.encode('utf-8')
if isinstance(body, six.text_type):
body = body.encode('utf-8')
self.url = url
self.time = None
start = time.time()
result = openURL(url, headers, method, body, self.HOST, self.PORT,
self.HTTP_CONN, protocol or self.PROTOCOL,
raise_subcls)
self.time = time.time() - start
self.status, self.headers, self.body = result
# Build a list of request cookies from the previous response cookies.
self.cookies = [('Cookie', v) for k, v in self.headers
if k.lower() == 'set-cookie']
if ServerError.on:
raise ServerError()
return result
@NonDataProperty
def interactive(self):
"""Determine whether tests are run in interactive mode.
Load interactivity setting from environment, where
the value can be numeric or a string like true or
False or 1 or 0.
"""
env_str = os.environ.get('WEBTEST_INTERACTIVE', 'True')
is_interactive = bool(json.loads(env_str.lower()))
if is_interactive:
warnings.warn(
'Interactive test failure interceptor support via '
'WEBTEST_INTERACTIVE environment variable is deprecated.',
DeprecationWarning
)
return is_interactive
console_height = 30
def _handlewebError(self, msg):
print('')
print(' ERROR: %s' % msg)
if not self.interactive:
raise self.failureException(msg)
p = (' Show: '
'[B]ody [H]eaders [S]tatus [U]RL; '
'[I]gnore, [R]aise, or sys.e[X]it >> ')
sys.stdout.write(p)
sys.stdout.flush()
while True:
i = getchar().upper()
if not isinstance(i, type('')):
i = i.decode('ascii')
if i not in 'BHSUIRX':
continue
print(i.upper()) # Also prints new line
if i == 'B':
for x, line in enumerate(self.body.splitlines()):
if (x + 1) % self.console_height == 0:
# The \r and comma should make the next line overwrite
sys.stdout.write('<-- More -->\r')
m = getchar().lower()
# Erase our "More" prompt
sys.stdout.write(' \r')
if m == 'q':
break
print(line)
elif i == 'H':
pprint.pprint(self.headers)
elif i == 'S':
print(self.status)
elif i == 'U':
print(self.url)
elif i == 'I':
# return without raising the normal exception
return
elif i == 'R':
raise self.failureException(msg)
elif i == 'X':
sys.exit()
sys.stdout.write(p)
sys.stdout.flush()
@property
def status_code(self): # noqa: D401; irrelevant for properties
"""Integer HTTP status code."""
return int(self.status[:3])
def status_matches(self, expected):
"""Check whether actual status matches expected."""
actual = (
self.status_code
if isinstance(expected, int) else
self.status
)
return expected == actual
def assertStatus(self, status, msg=None):
"""Fail if self.status != status.
status may be integer code, exact string status, or
iterable of allowed possibilities.
"""
if any(map(self.status_matches, always_iterable(status))):
return
tmpl = 'Status {self.status} does not match {status}'
msg = msg or tmpl.format(**locals())
self._handlewebError(msg)
def assertHeader(self, key, value=None, msg=None):
"""Fail if (key, [value]) not in self.headers."""
lowkey = key.lower()
for k, v in self.headers:
if k.lower() == lowkey:
if value is None or str(value) == v:
return v
if msg is None:
if value is None:
msg = '%r not in headers' % key
else:
msg = '%r:%r not in headers' % (key, value)
self._handlewebError(msg)
def assertHeaderIn(self, key, values, msg=None):
"""Fail if header indicated by key doesn't have one of the values."""
lowkey = key.lower()
for k, v in self.headers:
if k.lower() == lowkey:
matches = [value for value in values if str(value) == v]
if matches:
return matches
if msg is None:
msg = '%(key)r not in %(values)r' % vars()
self._handlewebError(msg)
def assertHeaderItemValue(self, key, value, msg=None):
"""Fail if the header does not contain the specified value."""
actual_value = self.assertHeader(key, msg=msg)
header_values = map(str.strip, actual_value.split(','))
if value in header_values:
return value
if msg is None:
msg = '%r not in %r' % (value, header_values)
self._handlewebError(msg)
def assertNoHeader(self, key, msg=None):
"""Fail if key in self.headers."""
lowkey = key.lower()
matches = [k for k, v in self.headers if k.lower() == lowkey]
if matches:
if msg is None:
msg = '%r in headers' % key
self._handlewebError(msg)
def assertNoHeaderItemValue(self, key, value, msg=None):
"""Fail if the header contains the specified value."""
lowkey = key.lower()
hdrs = self.headers
matches = [k for k, v in hdrs if k.lower() == lowkey and v == value]
if matches:
if msg is None:
msg = '%r:%r in %r' % (key, value, hdrs)
self._handlewebError(msg)
def assertBody(self, value, msg=None):
"""Fail if value != self.body."""
if isinstance(value, six.text_type):
value = value.encode(self.encoding)
if value != self.body:
if msg is None:
msg = 'expected body:\n%r\n\nactual body:\n%r' % (
value, self.body)
self._handlewebError(msg)
def assertInBody(self, value, msg=None):
"""Fail if value not in self.body."""
if isinstance(value, six.text_type):
value = value.encode(self.encoding)
if value not in self.body:
if msg is None:
msg = '%r not in body: %s' % (value, self.body)
self._handlewebError(msg)
def assertNotInBody(self, value, msg=None):
"""Fail if value in self.body."""
if isinstance(value, six.text_type):
value = value.encode(self.encoding)
if value in self.body:
if msg is None:
msg = '%r found in body' % value
self._handlewebError(msg)
def assertMatchesBody(self, pattern, msg=None, flags=0):
"""Fail if value (a regex pattern) is not in self.body."""
if isinstance(pattern, six.text_type):
pattern = pattern.encode(self.encoding)
if re.search(pattern, self.body, flags) is None:
if msg is None:
msg = 'No match for %r in body' % pattern
self._handlewebError(msg)
methods_with_bodies = ('POST', 'PUT', 'PATCH')
def cleanHeaders(headers, method, body, host, port):
"""Return request headers, with required headers added (if missing)."""
if headers is None:
headers = []
# Add the required Host request header if not present.
# [This specifies the host:port of the server, not the client.]
found = False
for k, v in headers:
if k.lower() == 'host':
found = True
break
if not found:
if port == 80:
headers.append(('Host', host))
else:
headers.append(('Host', '%s:%s' % (host, port)))
if method in methods_with_bodies:
# Stick in default type and length headers if not present
found = False
for k, v in headers:
if k.lower() == 'content-type':
found = True
break
if not found:
headers.append(
('Content-Type', 'application/x-www-form-urlencoded'))
headers.append(('Content-Length', str(len(body or ''))))
return headers
def shb(response):
"""Return status, headers, body the way we like from a response."""
if six.PY3:
h = response.getheaders()
else:
h = []
key, value = None, None
for line in response.msg.headers:
if line:
if line[0] in ' \t':
value += line.strip()
else:
if key and value:
h.append((key, value))
key, value = line.split(':', 1)
key = key.strip()
value = value.strip()
if key and value:
h.append((key, value))
return '%s %s' % (response.status, response.reason), h, response.read()
def openURL(url, headers=None, method='GET', body=None,
host='127.0.0.1', port=8000, http_conn=http_client.HTTPConnection,
protocol='HTTP/1.1', raise_subcls=None):
"""
Open the given HTTP resource and return status, headers, and body.
`raise_subcls` must be a tuple with the exceptions classes
or a single exception class that are not going to be considered
a socket.error regardless that they were are subclass of a
socket.error and therefore not considered for a connection retry.
"""
headers = cleanHeaders(headers, method, body, host, port)
# Trying 10 times is simply in case of socket errors.
# Normal case--it should run once.
for trial in range(10):
try:
# Allow http_conn to be a class or an instance
if hasattr(http_conn, 'host'):
conn = http_conn
else:
conn = http_conn(interface(host), port)
conn._http_vsn_str = protocol
conn._http_vsn = int(''.join([x for x in protocol if x.isdigit()]))
if six.PY3 and isinstance(url, bytes):
url = url.decode()
conn.putrequest(method.upper(), url, skip_host=True,
skip_accept_encoding=True)
for key, value in headers:
conn.putheader(key, value.encode('Latin-1'))
conn.endheaders()
if body is not None:
conn.send(body)
# Handle response
response = conn.getresponse()
s, h, b = shb(response)
if not hasattr(http_conn, 'host'):
# We made our own conn instance. Close it.
conn.close()
return s, h, b
except socket.error as e:
if raise_subcls is not None and isinstance(e, raise_subcls):
raise
else:
time.sleep(0.5)
if trial == 9:
raise
def strip_netloc(url):
"""Return absolute-URI path from URL.
Strip the scheme and host from the URL, returning the
server-absolute portion.
Useful for wrapping an absolute-URI for which only the
path is expected (such as in calls to getPage).
>>> strip_netloc('https://google.com/foo/bar?bing#baz')
'/foo/bar?bing'
>>> strip_netloc('//google.com/foo/bar?bing#baz')
'/foo/bar?bing'
>>> strip_netloc('/foo/bar?bing#baz')
'/foo/bar?bing'
"""
parsed = urllib_parse.urlparse(url)
scheme, netloc, path, params, query, fragment = parsed
stripped = '', '', path, params, query, ''
return urllib_parse.urlunparse(stripped)
# Add any exceptions which your web framework handles
# normally (that you don't want server_error to trap).
ignored_exceptions = []
# You'll want set this to True when you can't guarantee
# that each response will immediately follow each request;
# for example, when handling requests via multiple threads.
ignore_all = False
class ServerError(Exception):
"""Exception for signalling server error."""
on = False
def server_error(exc=None):
"""Server debug hook.
Return True if exception handled, False if ignored.
You probably want to wrap this, so you can still handle an error using
your framework when it's ignored.
"""
if exc is None:
exc = sys.exc_info()
if ignore_all or exc[0] in ignored_exceptions:
return False
else:
ServerError.on = True
print('')
print(''.join(traceback.format_exception(*exc)))
return True

View file

@ -0,0 +1,144 @@
"""Pytest fixtures and other helpers for doing testing by end-users."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from contextlib import closing
import errno
import socket
import threading
import time
import pytest
from six.moves import http_client
import cheroot.server
from cheroot.test import webtest
import cheroot.wsgi
EPHEMERAL_PORT = 0
NO_INTERFACE = None # Using this or '' will cause an exception
ANY_INTERFACE_IPV4 = '0.0.0.0'
ANY_INTERFACE_IPV6 = '::'
config = {
cheroot.wsgi.Server: {
'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT),
'wsgi_app': None,
},
cheroot.server.HTTPServer: {
'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT),
'gateway': cheroot.server.Gateway,
},
}
def cheroot_server(server_factory):
"""Set up and tear down a Cheroot server instance."""
conf = config[server_factory].copy()
bind_port = conf.pop('bind_addr')[-1]
for interface in ANY_INTERFACE_IPV6, ANY_INTERFACE_IPV4:
try:
actual_bind_addr = (interface, bind_port)
httpserver = server_factory( # create it
bind_addr=actual_bind_addr,
**conf
)
except OSError:
pass
else:
break
threading.Thread(target=httpserver.safe_start).start() # spawn it
while not httpserver.ready: # wait until fully initialized and bound
time.sleep(0.1)
yield httpserver
httpserver.stop() # destroy it
@pytest.fixture(scope='module')
def wsgi_server():
"""Set up and tear down a Cheroot WSGI server instance."""
for srv in cheroot_server(cheroot.wsgi.Server):
yield srv
@pytest.fixture(scope='module')
def native_server():
"""Set up and tear down a Cheroot HTTP server instance."""
for srv in cheroot_server(cheroot.server.HTTPServer):
yield srv
class _TestClient:
def __init__(self, server):
self._interface, self._host, self._port = _get_conn_data(server)
self._http_connection = self.get_connection()
self.server_instance = server
def get_connection(self):
name = '{interface}:{port}'.format(
interface=self._interface,
port=self._port,
)
return http_client.HTTPConnection(name)
def request(
self, uri, method='GET', headers=None, http_conn=None,
protocol='HTTP/1.1',
):
return webtest.openURL(
uri, method=method,
headers=headers,
host=self._host, port=self._port,
http_conn=http_conn or self._http_connection,
protocol=protocol,
)
def __getattr__(self, attr_name):
def _wrapper(uri, **kwargs):
http_method = attr_name.upper()
return self.request(uri, method=http_method, **kwargs)
return _wrapper
def _probe_ipv6_sock(interface):
# Alternate way is to check IPs on interfaces using glibc, like:
# github.com/Gautier/minifail/blob/master/minifail/getifaddrs.py
try:
with closing(socket.socket(family=socket.AF_INET6)) as sock:
sock.bind((interface, 0))
except (OSError, socket.error) as sock_err:
# In Python 3 socket.error is an alias for OSError
# In Python 2 socket.error is a subclass of IOError
if sock_err.errno != errno.EADDRNOTAVAIL:
raise
else:
return True
return False
def _get_conn_data(server):
if isinstance(server.bind_addr, tuple):
host, port = server.bind_addr
else:
host, port = server.bind_addr, 0
interface = webtest.interface(host)
if ':' in interface and not _probe_ipv6_sock(interface):
interface = '127.0.0.1'
if ':' in host:
host = interface
return interface, host, port
def get_server_client(server):
"""Create and return a test client for the given server."""
return _TestClient(server)

View file

@ -0,0 +1 @@
"""HTTP workers pool."""

View file

@ -0,0 +1,271 @@
"""A thread-based worker pool."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import threading
import time
import socket
from six.moves import queue
__all__ = ('WorkerThread', 'ThreadPool')
class TrueyZero:
"""Object which equals and does math like the integer 0 but evals True."""
def __add__(self, other):
return other
def __radd__(self, other):
return other
trueyzero = TrueyZero()
_SHUTDOWNREQUEST = None
class WorkerThread(threading.Thread):
"""Thread which continuously polls a Queue for Connection objects.
Due to the timing issues of polling a Queue, a WorkerThread does not
check its own 'ready' flag after it has started. To stop the thread,
it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue
(one for each running WorkerThread).
"""
conn = None
"""The current connection pulled off the Queue, or None."""
server = None
"""The HTTP Server which spawned this thread, and which owns the
Queue and is placing active connections into it."""
ready = False
"""A simple flag for the calling server to know when this thread
has begun polling the Queue."""
def __init__(self, server):
"""Initialize WorkerThread instance.
Args:
server (cheroot.server.HTTPServer): web server object
receiving this request
"""
self.ready = False
self.server = server
self.requests_seen = 0
self.bytes_read = 0
self.bytes_written = 0
self.start_time = None
self.work_time = 0
self.stats = {
'Requests': lambda s: self.requests_seen + (
(self.start_time is None) and
trueyzero or
self.conn.requests_seen
),
'Bytes Read': lambda s: self.bytes_read + (
(self.start_time is None) and
trueyzero or
self.conn.rfile.bytes_read
),
'Bytes Written': lambda s: self.bytes_written + (
(self.start_time is None) and
trueyzero or
self.conn.wfile.bytes_written
),
'Work Time': lambda s: self.work_time + (
(self.start_time is None) and
trueyzero or
time.time() - self.start_time
),
'Read Throughput': lambda s: s['Bytes Read'](s) / (
s['Work Time'](s) or 1e-6),
'Write Throughput': lambda s: s['Bytes Written'](s) / (
s['Work Time'](s) or 1e-6),
}
threading.Thread.__init__(self)
def run(self):
"""Process incoming HTTP connections.
Retrieves incoming connections from thread pool.
"""
self.server.stats['Worker Threads'][self.getName()] = self.stats
try:
self.ready = True
while True:
conn = self.server.requests.get()
if conn is _SHUTDOWNREQUEST:
return
self.conn = conn
if self.server.stats['Enabled']:
self.start_time = time.time()
try:
conn.communicate()
finally:
conn.close()
if self.server.stats['Enabled']:
self.requests_seen += self.conn.requests_seen
self.bytes_read += self.conn.rfile.bytes_read
self.bytes_written += self.conn.wfile.bytes_written
self.work_time += time.time() - self.start_time
self.start_time = None
self.conn = None
except (KeyboardInterrupt, SystemExit) as ex:
self.server.interrupt = ex
class ThreadPool:
"""A Request Queue for an HTTPServer which pools threads.
ThreadPool objects must provide min, get(), put(obj), start()
and stop(timeout) attributes.
"""
def __init__(
self, server, min=10, max=-1, accepted_queue_size=-1,
accepted_queue_timeout=10):
"""Initialize HTTP requests queue instance.
Args:
server (cheroot.server.HTTPServer): web server object
receiving this request
min (int): minimum number of worker threads
max (int): maximum number of worker threads
accepted_queue_size (int): maximum number of active
requests in queue
accepted_queue_timeout (int): timeout for putting request
into queue
"""
self.server = server
self.min = min
self.max = max
self._threads = []
self._queue = queue.Queue(maxsize=accepted_queue_size)
self._queue_put_timeout = accepted_queue_timeout
self.get = self._queue.get
def start(self):
"""Start the pool of threads."""
for i in range(self.min):
self._threads.append(WorkerThread(self.server))
for worker in self._threads:
worker.setName('CP Server ' + worker.getName())
worker.start()
for worker in self._threads:
while not worker.ready:
time.sleep(.1)
@property
def idle(self): # noqa: D401; irrelevant for properties
"""Number of worker threads which are idle. Read-only."""
return len([t for t in self._threads if t.conn is None])
def put(self, obj):
"""Put request into queue.
Args:
obj (cheroot.server.HTTPConnection): HTTP connection
waiting to be processed
"""
self._queue.put(obj, block=True, timeout=self._queue_put_timeout)
if obj is _SHUTDOWNREQUEST:
return
def grow(self, amount):
"""Spawn new worker threads (not above self.max)."""
if self.max > 0:
budget = max(self.max - len(self._threads), 0)
else:
# self.max <= 0 indicates no maximum
budget = float('inf')
n_new = min(amount, budget)
workers = [self._spawn_worker() for i in range(n_new)]
while not all(worker.ready for worker in workers):
time.sleep(.1)
self._threads.extend(workers)
def _spawn_worker(self):
worker = WorkerThread(self.server)
worker.setName('CP Server ' + worker.getName())
worker.start()
return worker
def shrink(self, amount):
"""Kill off worker threads (not below self.min)."""
# Grow/shrink the pool if necessary.
# Remove any dead threads from our list
for t in self._threads:
if not t.isAlive():
self._threads.remove(t)
amount -= 1
# calculate the number of threads above the minimum
n_extra = max(len(self._threads) - self.min, 0)
# don't remove more than amount
n_to_remove = min(amount, n_extra)
# put shutdown requests on the queue equal to the number of threads
# to remove. As each request is processed by a worker, that worker
# will terminate and be culled from the list.
for n in range(n_to_remove):
self._queue.put(_SHUTDOWNREQUEST)
def stop(self, timeout=5):
"""Terminate all worker threads.
Args:
timeout (int): time to wait for threads to stop gracefully
"""
# Must shut down threads here so the code that calls
# this method can know when all threads are stopped.
for worker in self._threads:
self._queue.put(_SHUTDOWNREQUEST)
# Don't join currentThread (when stop is called inside a request).
current = threading.currentThread()
if timeout is not None and timeout >= 0:
endtime = time.time() + timeout
while self._threads:
worker = self._threads.pop()
if worker is not current and worker.isAlive():
try:
if timeout is None or timeout < 0:
worker.join()
else:
remaining_time = endtime - time.time()
if remaining_time > 0:
worker.join(remaining_time)
if worker.isAlive():
# We exhausted the timeout.
# Forcibly shut down the socket.
c = worker.conn
if c and not c.rfile.closed:
try:
c.socket.shutdown(socket.SHUT_RD)
except TypeError:
# pyOpenSSL sockets don't take an arg
c.socket.shutdown()
worker.join()
except (AssertionError,
# Ignore repeated Ctrl-C.
# See
# https://github.com/cherrypy/cherrypy/issues/691.
KeyboardInterrupt):
pass
@property
def qsize(self):
"""Return the queue size."""
return self._queue.qsize()

423
libraries/cheroot/wsgi.py Normal file
View file

@ -0,0 +1,423 @@
"""This class holds Cheroot WSGI server implementation.
Simplest example on how to use this server::
from cheroot import wsgi
def my_crazy_app(environ, start_response):
status = '200 OK'
response_headers = [('Content-type','text/plain')]
start_response(status, response_headers)
return [b'Hello world!']
addr = '0.0.0.0', 8070
server = wsgi.Server(addr, my_crazy_app)
server.start()
The Cheroot WSGI server can serve as many WSGI applications
as you want in one instance by using a PathInfoDispatcher::
path_map = {
'/': my_crazy_app,
'/blog': my_blog_app,
}
d = wsgi.PathInfoDispatcher(path_map)
server = wsgi.Server(addr, d)
"""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import sys
import six
from six.moves import filter
from . import server
from .workers import threadpool
from ._compat import ntob, bton
class Server(server.HTTPServer):
"""A subclass of HTTPServer which calls a WSGI application."""
wsgi_version = (1, 0)
"""The version of WSGI to produce."""
def __init__(
self, bind_addr, wsgi_app, numthreads=10, server_name=None,
max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5,
accepted_queue_size=-1, accepted_queue_timeout=10,
peercreds_enabled=False, peercreds_resolve_enabled=False,
):
"""Initialize WSGI Server instance.
Args:
bind_addr (tuple): network interface to listen to
wsgi_app (callable): WSGI application callable
numthreads (int): number of threads for WSGI thread pool
server_name (str): web server name to be advertised via
Server HTTP header
max (int): maximum number of worker threads
request_queue_size (int): the 'backlog' arg to
socket.listen(); max queued connections
timeout (int): the timeout in seconds for accepted connections
shutdown_timeout (int): the total time, in seconds, to
wait for worker threads to cleanly exit
accepted_queue_size (int): maximum number of active
requests in queue
accepted_queue_timeout (int): timeout for putting request
into queue
"""
super(Server, self).__init__(
bind_addr,
gateway=wsgi_gateways[self.wsgi_version],
server_name=server_name,
peercreds_enabled=peercreds_enabled,
peercreds_resolve_enabled=peercreds_resolve_enabled,
)
self.wsgi_app = wsgi_app
self.request_queue_size = request_queue_size
self.timeout = timeout
self.shutdown_timeout = shutdown_timeout
self.requests = threadpool.ThreadPool(
self, min=numthreads or 1, max=max,
accepted_queue_size=accepted_queue_size,
accepted_queue_timeout=accepted_queue_timeout)
@property
def numthreads(self):
"""Set minimum number of threads."""
return self.requests.min
@numthreads.setter
def numthreads(self, value):
self.requests.min = value
class Gateway(server.Gateway):
"""A base class to interface HTTPServer with WSGI."""
def __init__(self, req):
"""Initialize WSGI Gateway instance with request.
Args:
req (HTTPRequest): current HTTP request
"""
super(Gateway, self).__init__(req)
self.started_response = False
self.env = self.get_environ()
self.remaining_bytes_out = None
@classmethod
def gateway_map(cls):
"""Create a mapping of gateways and their versions.
Returns:
dict[tuple[int,int],class]: map of gateway version and
corresponding class
"""
return dict(
(gw.version, gw)
for gw in cls.__subclasses__()
)
def get_environ(self):
"""Return a new environ dict targeting the given wsgi.version."""
raise NotImplementedError
def respond(self):
"""Process the current request.
From :pep:`333`:
The start_response callable must not actually transmit
the response headers. Instead, it must store them for the
server or gateway to transmit only after the first
iteration of the application return value that yields
a NON-EMPTY string, or upon the application's first
invocation of the write() callable.
"""
response = self.req.server.wsgi_app(self.env, self.start_response)
try:
for chunk in filter(None, response):
if not isinstance(chunk, six.binary_type):
raise ValueError('WSGI Applications must yield bytes')
self.write(chunk)
finally:
# Send headers if not already sent
self.req.ensure_headers_sent()
if hasattr(response, 'close'):
response.close()
def start_response(self, status, headers, exc_info=None):
"""WSGI callable to begin the HTTP response."""
# "The application may call start_response more than once,
# if and only if the exc_info argument is provided."
if self.started_response and not exc_info:
raise AssertionError('WSGI start_response called a second '
'time with no exc_info.')
self.started_response = True
# "if exc_info is provided, and the HTTP headers have already been
# sent, start_response must raise an error, and should raise the
# exc_info tuple."
if self.req.sent_headers:
try:
six.reraise(*exc_info)
finally:
exc_info = None
self.req.status = self._encode_status(status)
for k, v in headers:
if not isinstance(k, str):
raise TypeError(
'WSGI response header key %r is not of type str.' % k)
if not isinstance(v, str):
raise TypeError(
'WSGI response header value %r is not of type str.' % v)
if k.lower() == 'content-length':
self.remaining_bytes_out = int(v)
out_header = ntob(k), ntob(v)
self.req.outheaders.append(out_header)
return self.write
@staticmethod
def _encode_status(status):
"""Cast status to bytes representation of current Python version.
According to :pep:`3333`, when using Python 3, the response status
and headers must be bytes masquerading as unicode; that is, they
must be of type "str" but are restricted to code points in the
"latin-1" set.
"""
if six.PY2:
return status
if not isinstance(status, str):
raise TypeError('WSGI response status is not of type str.')
return status.encode('ISO-8859-1')
def write(self, chunk):
"""WSGI callable to write unbuffered data to the client.
This method is also used internally by start_response (to write
data from the iterable returned by the WSGI application).
"""
if not self.started_response:
raise AssertionError('WSGI write called before start_response.')
chunklen = len(chunk)
rbo = self.remaining_bytes_out
if rbo is not None and chunklen > rbo:
if not self.req.sent_headers:
# Whew. We can send a 500 to the client.
self.req.simple_response(
'500 Internal Server Error',
'The requested resource returned more bytes than the '
'declared Content-Length.')
else:
# Dang. We have probably already sent data. Truncate the chunk
# to fit (so the client doesn't hang) and raise an error later.
chunk = chunk[:rbo]
self.req.ensure_headers_sent()
self.req.write(chunk)
if rbo is not None:
rbo -= chunklen
if rbo < 0:
raise ValueError(
'Response body exceeds the declared Content-Length.')
class Gateway_10(Gateway):
"""A Gateway class to interface HTTPServer with WSGI 1.0.x."""
version = 1, 0
def get_environ(self):
"""Return a new environ dict targeting the given wsgi.version."""
req = self.req
req_conn = req.conn
env = {
# set a non-standard environ entry so the WSGI app can know what
# the *real* server protocol is (and what features to support).
# See http://www.faqs.org/rfcs/rfc2145.html.
'ACTUAL_SERVER_PROTOCOL': req.server.protocol,
'PATH_INFO': bton(req.path),
'QUERY_STRING': bton(req.qs),
'REMOTE_ADDR': req_conn.remote_addr or '',
'REMOTE_PORT': str(req_conn.remote_port or ''),
'REQUEST_METHOD': bton(req.method),
'REQUEST_URI': bton(req.uri),
'SCRIPT_NAME': '',
'SERVER_NAME': req.server.server_name,
# Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol.
'SERVER_PROTOCOL': bton(req.request_protocol),
'SERVER_SOFTWARE': req.server.software,
'wsgi.errors': sys.stderr,
'wsgi.input': req.rfile,
'wsgi.input_terminated': bool(req.chunked_read),
'wsgi.multiprocess': False,
'wsgi.multithread': True,
'wsgi.run_once': False,
'wsgi.url_scheme': bton(req.scheme),
'wsgi.version': self.version,
}
if isinstance(req.server.bind_addr, six.string_types):
# AF_UNIX. This isn't really allowed by WSGI, which doesn't
# address unix domain sockets. But it's better than nothing.
env['SERVER_PORT'] = ''
try:
env['X_REMOTE_PID'] = str(req_conn.peer_pid)
env['X_REMOTE_UID'] = str(req_conn.peer_uid)
env['X_REMOTE_GID'] = str(req_conn.peer_gid)
env['X_REMOTE_USER'] = str(req_conn.peer_user)
env['X_REMOTE_GROUP'] = str(req_conn.peer_group)
env['REMOTE_USER'] = env['X_REMOTE_USER']
except RuntimeError:
"""Unable to retrieve peer creds data.
Unsupported by current kernel or socket error happened, or
unsupported socket type, or disabled.
"""
else:
env['SERVER_PORT'] = str(req.server.bind_addr[1])
# Request headers
env.update(
('HTTP_' + bton(k).upper().replace('-', '_'), bton(v))
for k, v in req.inheaders.items()
)
# CONTENT_TYPE/CONTENT_LENGTH
ct = env.pop('HTTP_CONTENT_TYPE', None)
if ct is not None:
env['CONTENT_TYPE'] = ct
cl = env.pop('HTTP_CONTENT_LENGTH', None)
if cl is not None:
env['CONTENT_LENGTH'] = cl
if req.conn.ssl_env:
env.update(req.conn.ssl_env)
return env
class Gateway_u0(Gateway_10):
"""A Gateway class to interface HTTPServer with WSGI u.0.
WSGI u.0 is an experimental protocol, which uses unicode for keys
and values in both Python 2 and Python 3.
"""
version = 'u', 0
def get_environ(self):
"""Return a new environ dict targeting the given wsgi.version."""
req = self.req
env_10 = super(Gateway_u0, self).get_environ()
env = dict(map(self._decode_key, env_10.items()))
# Request-URI
enc = env.setdefault(six.u('wsgi.url_encoding'), six.u('utf-8'))
try:
env['PATH_INFO'] = req.path.decode(enc)
env['QUERY_STRING'] = req.qs.decode(enc)
except UnicodeDecodeError:
# Fall back to latin 1 so apps can transcode if needed.
env['wsgi.url_encoding'] = 'ISO-8859-1'
env['PATH_INFO'] = env_10['PATH_INFO']
env['QUERY_STRING'] = env_10['QUERY_STRING']
env.update(map(self._decode_value, env.items()))
return env
@staticmethod
def _decode_key(item):
k, v = item
if six.PY2:
k = k.decode('ISO-8859-1')
return k, v
@staticmethod
def _decode_value(item):
k, v = item
skip_keys = 'REQUEST_URI', 'wsgi.input'
if six.PY3 or not isinstance(v, bytes) or k in skip_keys:
return k, v
return k, v.decode('ISO-8859-1')
wsgi_gateways = Gateway.gateway_map()
class PathInfoDispatcher:
"""A WSGI dispatcher for dispatch based on the PATH_INFO."""
def __init__(self, apps):
"""Initialize path info WSGI app dispatcher.
Args:
apps (dict[str,object]|list[tuple[str,object]]): URI prefix
and WSGI app pairs
"""
try:
apps = list(apps.items())
except AttributeError:
pass
# Sort the apps by len(path), descending
def by_path_len(app):
return len(app[0])
apps.sort(key=by_path_len, reverse=True)
# The path_prefix strings must start, but not end, with a slash.
# Use "" instead of "/".
self.apps = [(p.rstrip('/'), a) for p, a in apps]
def __call__(self, environ, start_response):
"""Process incoming WSGI request.
Ref: :pep:`3333`
Args:
environ (Mapping): a dict containing WSGI environment variables
start_response (callable): function, which sets response
status and headers
Returns:
list[bytes]: iterable containing bytes to be returned in
HTTP response body
"""
path = environ['PATH_INFO'] or '/'
for p, app in self.apps:
# The apps list should be sorted by length, descending.
if path.startswith(p + '/') or path == p:
environ = environ.copy()
environ['SCRIPT_NAME'] = environ['SCRIPT_NAME'] + p
environ['PATH_INFO'] = path[len(p):]
return app(environ, start_response)
start_response('404 Not Found', [('Content-Type', 'text/plain'),
('Content-Length', '0')])
return ['']
# compatibility aliases
globals().update(
WSGIServer=Server,
WSGIGateway=Gateway,
WSGIGateway_u0=Gateway_u0,
WSGIGateway_10=Gateway_10,
WSGIPathInfoDispatcher=PathInfoDispatcher,
)