mirror of
https://github.com/jellyfin/jellyfin-kodi.git
synced 2025-06-25 17:40:31 +00:00
Update webservice with cherrypy
Fix playback issues that was causing Kodi to hang up
This commit is contained in:
parent
b2bc90cb06
commit
158a736360
164 changed files with 42855 additions and 174 deletions
6
libraries/cheroot/__init__.py
Normal file
6
libraries/cheroot/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
"""High-performance, pure-Python HTTP server used by CherryPy."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
__version__ = '6.4.0'
|
6
libraries/cheroot/__main__.py
Normal file
6
libraries/cheroot/__main__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
"""Stub for accessing the Cheroot CLI tool."""
|
||||
|
||||
from .cli import main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
66
libraries/cheroot/_compat.py
Normal file
66
libraries/cheroot/_compat.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
"""Compatibility code for using Cheroot with various versions of Python."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import re
|
||||
|
||||
import six
|
||||
|
||||
if six.PY3:
|
||||
def ntob(n, encoding='ISO-8859-1'):
|
||||
"""Return the native string as bytes in the given encoding."""
|
||||
assert_native(n)
|
||||
# In Python 3, the native string type is unicode
|
||||
return n.encode(encoding)
|
||||
|
||||
def ntou(n, encoding='ISO-8859-1'):
|
||||
"""Return the native string as unicode with the given encoding."""
|
||||
assert_native(n)
|
||||
# In Python 3, the native string type is unicode
|
||||
return n
|
||||
|
||||
def bton(b, encoding='ISO-8859-1'):
|
||||
"""Return the byte string as native string in the given encoding."""
|
||||
return b.decode(encoding)
|
||||
else:
|
||||
# Python 2
|
||||
def ntob(n, encoding='ISO-8859-1'):
|
||||
"""Return the native string as bytes in the given encoding."""
|
||||
assert_native(n)
|
||||
# In Python 2, the native string type is bytes. Assume it's already
|
||||
# in the given encoding, which for ISO-8859-1 is almost always what
|
||||
# was intended.
|
||||
return n
|
||||
|
||||
def ntou(n, encoding='ISO-8859-1'):
|
||||
"""Return the native string as unicode with the given encoding."""
|
||||
assert_native(n)
|
||||
# In Python 2, the native string type is bytes.
|
||||
# First, check for the special encoding 'escape'. The test suite uses
|
||||
# this to signal that it wants to pass a string with embedded \uXXXX
|
||||
# escapes, but without having to prefix it with u'' for Python 2,
|
||||
# but no prefix for Python 3.
|
||||
if encoding == 'escape':
|
||||
return six.u(
|
||||
re.sub(r'\\u([0-9a-zA-Z]{4})',
|
||||
lambda m: six.unichr(int(m.group(1), 16)),
|
||||
n.decode('ISO-8859-1')))
|
||||
# Assume it's already in the given encoding, which for ISO-8859-1
|
||||
# is almost always what was intended.
|
||||
return n.decode(encoding)
|
||||
|
||||
def bton(b, encoding='ISO-8859-1'):
|
||||
"""Return the byte string as native string in the given encoding."""
|
||||
return b
|
||||
|
||||
|
||||
def assert_native(n):
|
||||
"""Check whether the input is of nativ ``str`` type.
|
||||
|
||||
Raises:
|
||||
TypeError: in case of failed check
|
||||
|
||||
"""
|
||||
if not isinstance(n, str):
|
||||
raise TypeError('n must be a native str (got %s)' % type(n).__name__)
|
233
libraries/cheroot/cli.py
Normal file
233
libraries/cheroot/cli.py
Normal file
|
@ -0,0 +1,233 @@
|
|||
"""Command line tool for starting a Cheroot WSGI/HTTP server instance.
|
||||
|
||||
Basic usage::
|
||||
|
||||
# Start a server on 127.0.0.1:8000 with the default settings
|
||||
# for the WSGI app myapp/wsgi.py:application()
|
||||
cheroot myapp.wsgi
|
||||
|
||||
# Start a server on 0.0.0.0:9000 with 8 threads
|
||||
# for the WSGI app myapp/wsgi.py:main_app()
|
||||
cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8
|
||||
|
||||
# Start a server for the cheroot.server.Gateway subclass
|
||||
# myapp/gateway.py:HTTPGateway
|
||||
cheroot myapp.gateway:HTTPGateway
|
||||
|
||||
# Start a server on the UNIX socket /var/spool/myapp.sock
|
||||
cheroot myapp.wsgi --bind /var/spool/myapp.sock
|
||||
|
||||
# Start a server on the abstract UNIX socket CherootServer
|
||||
cheroot myapp.wsgi --bind @CherootServer
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from importlib import import_module
|
||||
import os
|
||||
import sys
|
||||
import contextlib
|
||||
|
||||
import six
|
||||
|
||||
from . import server
|
||||
from . import wsgi
|
||||
|
||||
|
||||
__metaclass__ = type
|
||||
|
||||
|
||||
class BindLocation:
|
||||
"""A class for storing the bind location for a Cheroot instance."""
|
||||
|
||||
|
||||
class TCPSocket(BindLocation):
|
||||
"""TCPSocket."""
|
||||
|
||||
def __init__(self, address, port):
|
||||
"""Initialize.
|
||||
|
||||
Args:
|
||||
address (str): Host name or IP address
|
||||
port (int): TCP port number
|
||||
"""
|
||||
self.bind_addr = address, port
|
||||
|
||||
|
||||
class UnixSocket(BindLocation):
|
||||
"""UnixSocket."""
|
||||
|
||||
def __init__(self, path):
|
||||
"""Initialize."""
|
||||
self.bind_addr = path
|
||||
|
||||
|
||||
class AbstractSocket(BindLocation):
|
||||
"""AbstractSocket."""
|
||||
|
||||
def __init__(self, addr):
|
||||
"""Initialize."""
|
||||
self.bind_addr = '\0{}'.format(self.abstract_socket)
|
||||
|
||||
|
||||
class Application:
|
||||
"""Application."""
|
||||
|
||||
@classmethod
|
||||
def resolve(cls, full_path):
|
||||
"""Read WSGI app/Gateway path string and import application module."""
|
||||
mod_path, _, app_path = full_path.partition(':')
|
||||
app = getattr(import_module(mod_path), app_path or 'application')
|
||||
|
||||
with contextlib.suppress(TypeError):
|
||||
if issubclass(app, server.Gateway):
|
||||
return GatewayYo(app)
|
||||
|
||||
return cls(app)
|
||||
|
||||
def __init__(self, wsgi_app):
|
||||
"""Initialize."""
|
||||
if not callable(wsgi_app):
|
||||
raise TypeError(
|
||||
'Application must be a callable object or '
|
||||
'cheroot.server.Gateway subclass'
|
||||
)
|
||||
self.wsgi_app = wsgi_app
|
||||
|
||||
def server_args(self, parsed_args):
|
||||
"""Return keyword args for Server class."""
|
||||
args = {
|
||||
arg: value
|
||||
for arg, value in vars(parsed_args).items()
|
||||
if not arg.startswith('_') and value is not None
|
||||
}
|
||||
args.update(vars(self))
|
||||
return args
|
||||
|
||||
def server(self, parsed_args):
|
||||
"""Server."""
|
||||
return wsgi.Server(**self.server_args(parsed_args))
|
||||
|
||||
|
||||
class GatewayYo:
|
||||
"""Gateway."""
|
||||
|
||||
def __init__(self, gateway):
|
||||
"""Init."""
|
||||
self.gateway = gateway
|
||||
|
||||
def server(self, parsed_args):
|
||||
"""Server."""
|
||||
server_args = vars(self)
|
||||
server_args['bind_addr'] = parsed_args['bind_addr']
|
||||
if parsed_args.max is not None:
|
||||
server_args['maxthreads'] = parsed_args.max
|
||||
if parsed_args.numthreads is not None:
|
||||
server_args['minthreads'] = parsed_args.numthreads
|
||||
return server.HTTPServer(**server_args)
|
||||
|
||||
|
||||
def parse_wsgi_bind_location(bind_addr_string):
|
||||
"""Convert bind address string to a BindLocation."""
|
||||
# try and match for an IP/hostname and port
|
||||
match = six.moves.urllib.parse.urlparse('//{}'.format(bind_addr_string))
|
||||
try:
|
||||
addr = match.hostname
|
||||
port = match.port
|
||||
if addr is not None or port is not None:
|
||||
return TCPSocket(addr, port)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# else, assume a UNIX socket path
|
||||
# if the string begins with an @ symbol, use an abstract socket
|
||||
if bind_addr_string.startswith('@'):
|
||||
return AbstractSocket(bind_addr_string[1:])
|
||||
return UnixSocket(path=bind_addr_string)
|
||||
|
||||
|
||||
def parse_wsgi_bind_addr(bind_addr_string):
|
||||
"""Convert bind address string to bind address parameter."""
|
||||
return parse_wsgi_bind_location(bind_addr_string).bind_addr
|
||||
|
||||
|
||||
_arg_spec = {
|
||||
'_wsgi_app': dict(
|
||||
metavar='APP_MODULE',
|
||||
type=Application.resolve,
|
||||
help='WSGI application callable or cheroot.server.Gateway subclass',
|
||||
),
|
||||
'--bind': dict(
|
||||
metavar='ADDRESS',
|
||||
dest='bind_addr',
|
||||
type=parse_wsgi_bind_addr,
|
||||
default='[::1]:8000',
|
||||
help='Network interface to listen on (default: [::1]:8000)',
|
||||
),
|
||||
'--chdir': dict(
|
||||
metavar='PATH',
|
||||
type=os.chdir,
|
||||
help='Set the working directory',
|
||||
),
|
||||
'--server-name': dict(
|
||||
dest='server_name',
|
||||
type=str,
|
||||
help='Web server name to be advertised via Server HTTP header',
|
||||
),
|
||||
'--threads': dict(
|
||||
metavar='INT',
|
||||
dest='numthreads',
|
||||
type=int,
|
||||
help='Minimum number of worker threads',
|
||||
),
|
||||
'--max-threads': dict(
|
||||
metavar='INT',
|
||||
dest='max',
|
||||
type=int,
|
||||
help='Maximum number of worker threads',
|
||||
),
|
||||
'--timeout': dict(
|
||||
metavar='INT',
|
||||
dest='timeout',
|
||||
type=int,
|
||||
help='Timeout in seconds for accepted connections',
|
||||
),
|
||||
'--shutdown-timeout': dict(
|
||||
metavar='INT',
|
||||
dest='shutdown_timeout',
|
||||
type=int,
|
||||
help='Time in seconds to wait for worker threads to cleanly exit',
|
||||
),
|
||||
'--request-queue-size': dict(
|
||||
metavar='INT',
|
||||
dest='request_queue_size',
|
||||
type=int,
|
||||
help='Maximum number of queued connections',
|
||||
),
|
||||
'--accepted-queue-size': dict(
|
||||
metavar='INT',
|
||||
dest='accepted_queue_size',
|
||||
type=int,
|
||||
help='Maximum number of active requests in queue',
|
||||
),
|
||||
'--accepted-queue-timeout': dict(
|
||||
metavar='INT',
|
||||
dest='accepted_queue_timeout',
|
||||
type=int,
|
||||
help='Timeout in seconds for putting requests into queue',
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""Create a new Cheroot instance with arguments from the command line."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Start an instance of the Cheroot WSGI/HTTP server.')
|
||||
for arg, spec in _arg_spec.items():
|
||||
parser.add_argument(arg, **spec)
|
||||
raw_args = parser.parse_args()
|
||||
|
||||
# ensure cwd in sys.path
|
||||
'' in sys.path or sys.path.insert(0, '')
|
||||
|
||||
# create a server based on the arguments provided
|
||||
raw_args._wsgi_app.server(raw_args).safe_start()
|
58
libraries/cheroot/errors.py
Normal file
58
libraries/cheroot/errors.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
"""Collection of exceptions raised and/or processed by Cheroot."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import errno
|
||||
import sys
|
||||
|
||||
|
||||
class MaxSizeExceeded(Exception):
|
||||
"""Exception raised when a client sends more data then acceptable within limit.
|
||||
|
||||
Depends on ``request.body.maxbytes`` config option if used within CherryPy
|
||||
"""
|
||||
|
||||
|
||||
class NoSSLError(Exception):
|
||||
"""Exception raised when a client speaks HTTP to an HTTPS socket."""
|
||||
|
||||
|
||||
class FatalSSLAlert(Exception):
|
||||
"""Exception raised when the SSL implementation signals a fatal alert."""
|
||||
|
||||
|
||||
def plat_specific_errors(*errnames):
|
||||
"""Return error numbers for all errors in errnames on this platform.
|
||||
|
||||
The 'errno' module contains different global constants depending on
|
||||
the specific platform (OS). This function will return the list of
|
||||
numeric values for a given list of potential names.
|
||||
"""
|
||||
errno_names = dir(errno)
|
||||
nums = [getattr(errno, k) for k in errnames if k in errno_names]
|
||||
# de-dupe the list
|
||||
return list(dict.fromkeys(nums).keys())
|
||||
|
||||
|
||||
socket_error_eintr = plat_specific_errors('EINTR', 'WSAEINTR')
|
||||
|
||||
socket_errors_to_ignore = plat_specific_errors(
|
||||
'EPIPE',
|
||||
'EBADF', 'WSAEBADF',
|
||||
'ENOTSOCK', 'WSAENOTSOCK',
|
||||
'ETIMEDOUT', 'WSAETIMEDOUT',
|
||||
'ECONNREFUSED', 'WSAECONNREFUSED',
|
||||
'ECONNRESET', 'WSAECONNRESET',
|
||||
'ECONNABORTED', 'WSAECONNABORTED',
|
||||
'ENETRESET', 'WSAENETRESET',
|
||||
'EHOSTDOWN', 'EHOSTUNREACH',
|
||||
)
|
||||
socket_errors_to_ignore.append('timed out')
|
||||
socket_errors_to_ignore.append('The read operation timed out')
|
||||
socket_errors_nonblocking = plat_specific_errors(
|
||||
'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
|
||||
|
||||
if sys.platform == 'darwin':
|
||||
socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE'))
|
||||
socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE'))
|
387
libraries/cheroot/makefile.py
Normal file
387
libraries/cheroot/makefile.py
Normal file
|
@ -0,0 +1,387 @@
|
|||
"""Socket file object."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import socket
|
||||
|
||||
try:
|
||||
# prefer slower Python-based io module
|
||||
import _pyio as io
|
||||
except ImportError:
|
||||
# Python 2.6
|
||||
import io
|
||||
|
||||
import six
|
||||
|
||||
from . import errors
|
||||
|
||||
|
||||
class BufferedWriter(io.BufferedWriter):
|
||||
"""Faux file object attached to a socket object."""
|
||||
|
||||
def write(self, b):
|
||||
"""Write bytes to buffer."""
|
||||
self._checkClosed()
|
||||
if isinstance(b, str):
|
||||
raise TypeError("can't write str to binary stream")
|
||||
|
||||
with self._write_lock:
|
||||
self._write_buf.extend(b)
|
||||
self._flush_unlocked()
|
||||
return len(b)
|
||||
|
||||
def _flush_unlocked(self):
|
||||
self._checkClosed('flush of closed file')
|
||||
while self._write_buf:
|
||||
try:
|
||||
# ssl sockets only except 'bytes', not bytearrays
|
||||
# so perhaps we should conditionally wrap this for perf?
|
||||
n = self.raw.write(bytes(self._write_buf))
|
||||
except io.BlockingIOError as e:
|
||||
n = e.characters_written
|
||||
del self._write_buf[:n]
|
||||
|
||||
|
||||
def MakeFile_PY3(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
|
||||
"""File object attached to a socket object."""
|
||||
if 'r' in mode:
|
||||
return io.BufferedReader(socket.SocketIO(sock, mode), bufsize)
|
||||
else:
|
||||
return BufferedWriter(socket.SocketIO(sock, mode), bufsize)
|
||||
|
||||
|
||||
class MakeFile_PY2(getattr(socket, '_fileobject', object)):
|
||||
"""Faux file object attached to a socket object."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize faux file object."""
|
||||
self.bytes_read = 0
|
||||
self.bytes_written = 0
|
||||
socket._fileobject.__init__(self, *args, **kwargs)
|
||||
|
||||
def write(self, data):
|
||||
"""Sendall for non-blocking sockets."""
|
||||
while data:
|
||||
try:
|
||||
bytes_sent = self.send(data)
|
||||
data = data[bytes_sent:]
|
||||
except socket.error as e:
|
||||
if e.args[0] not in errors.socket_errors_nonblocking:
|
||||
raise
|
||||
|
||||
def send(self, data):
|
||||
"""Send some part of message to the socket."""
|
||||
bytes_sent = self._sock.send(data)
|
||||
self.bytes_written += bytes_sent
|
||||
return bytes_sent
|
||||
|
||||
def flush(self):
|
||||
"""Write all data from buffer to socket and reset write buffer."""
|
||||
if self._wbuf:
|
||||
buffer = ''.join(self._wbuf)
|
||||
self._wbuf = []
|
||||
self.write(buffer)
|
||||
|
||||
def recv(self, size):
|
||||
"""Receive message of a size from the socket."""
|
||||
while True:
|
||||
try:
|
||||
data = self._sock.recv(size)
|
||||
self.bytes_read += len(data)
|
||||
return data
|
||||
except socket.error as e:
|
||||
what = (
|
||||
e.args[0] not in errors.socket_errors_nonblocking
|
||||
and e.args[0] not in errors.socket_error_eintr
|
||||
)
|
||||
if what:
|
||||
raise
|
||||
|
||||
class FauxSocket:
|
||||
"""Faux socket with the minimal interface required by pypy."""
|
||||
|
||||
def _reuse(self):
|
||||
pass
|
||||
|
||||
_fileobject_uses_str_type = six.PY2 and isinstance(
|
||||
socket._fileobject(FauxSocket())._rbuf, six.string_types)
|
||||
|
||||
# FauxSocket is no longer needed
|
||||
del FauxSocket
|
||||
|
||||
if not _fileobject_uses_str_type:
|
||||
def read(self, size=-1):
|
||||
"""Read data from the socket to buffer."""
|
||||
# Use max, disallow tiny reads in a loop as they are very
|
||||
# inefficient.
|
||||
# We never leave read() with any leftover data from a new recv()
|
||||
# call in our internal buffer.
|
||||
rbufsize = max(self._rbufsize, self.default_bufsize)
|
||||
# Our use of StringIO rather than lists of string objects returned
|
||||
# by recv() minimizes memory usage and fragmentation that occurs
|
||||
# when rbufsize is large compared to the typical return value of
|
||||
# recv().
|
||||
buf = self._rbuf
|
||||
buf.seek(0, 2) # seek end
|
||||
if size < 0:
|
||||
# Read until EOF
|
||||
# reset _rbuf. we consume it via buf.
|
||||
self._rbuf = io.BytesIO()
|
||||
while True:
|
||||
data = self.recv(rbufsize)
|
||||
if not data:
|
||||
break
|
||||
buf.write(data)
|
||||
return buf.getvalue()
|
||||
else:
|
||||
# Read until size bytes or EOF seen, whichever comes first
|
||||
buf_len = buf.tell()
|
||||
if buf_len >= size:
|
||||
# Already have size bytes in our buffer? Extract and
|
||||
# return.
|
||||
buf.seek(0)
|
||||
rv = buf.read(size)
|
||||
self._rbuf = io.BytesIO()
|
||||
self._rbuf.write(buf.read())
|
||||
return rv
|
||||
|
||||
# reset _rbuf. we consume it via buf.
|
||||
self._rbuf = io.BytesIO()
|
||||
while True:
|
||||
left = size - buf_len
|
||||
# recv() will malloc the amount of memory given as its
|
||||
# parameter even though it often returns much less data
|
||||
# than that. The returned data string is short lived
|
||||
# as we copy it into a StringIO and free it. This avoids
|
||||
# fragmentation issues on many platforms.
|
||||
data = self.recv(left)
|
||||
if not data:
|
||||
break
|
||||
n = len(data)
|
||||
if n == size and not buf_len:
|
||||
# Shortcut. Avoid buffer data copies when:
|
||||
# - We have no data in our buffer.
|
||||
# AND
|
||||
# - Our call to recv returned exactly the
|
||||
# number of bytes we were asked to read.
|
||||
return data
|
||||
if n == left:
|
||||
buf.write(data)
|
||||
del data # explicit free
|
||||
break
|
||||
assert n <= left, 'recv(%d) returned %d bytes' % (left, n)
|
||||
buf.write(data)
|
||||
buf_len += n
|
||||
del data # explicit free
|
||||
# assert buf_len == buf.tell()
|
||||
return buf.getvalue()
|
||||
|
||||
def readline(self, size=-1):
|
||||
"""Read line from the socket to buffer."""
|
||||
buf = self._rbuf
|
||||
buf.seek(0, 2) # seek end
|
||||
if buf.tell() > 0:
|
||||
# check if we already have it in our buffer
|
||||
buf.seek(0)
|
||||
bline = buf.readline(size)
|
||||
if bline.endswith('\n') or len(bline) == size:
|
||||
self._rbuf = io.BytesIO()
|
||||
self._rbuf.write(buf.read())
|
||||
return bline
|
||||
del bline
|
||||
if size < 0:
|
||||
# Read until \n or EOF, whichever comes first
|
||||
if self._rbufsize <= 1:
|
||||
# Speed up unbuffered case
|
||||
buf.seek(0)
|
||||
buffers = [buf.read()]
|
||||
# reset _rbuf. we consume it via buf.
|
||||
self._rbuf = io.BytesIO()
|
||||
data = None
|
||||
recv = self.recv
|
||||
while data != '\n':
|
||||
data = recv(1)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
return ''.join(buffers)
|
||||
|
||||
buf.seek(0, 2) # seek end
|
||||
# reset _rbuf. we consume it via buf.
|
||||
self._rbuf = io.BytesIO()
|
||||
while True:
|
||||
data = self.recv(self._rbufsize)
|
||||
if not data:
|
||||
break
|
||||
nl = data.find('\n')
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
buf.write(data[:nl])
|
||||
self._rbuf.write(data[nl:])
|
||||
del data
|
||||
break
|
||||
buf.write(data)
|
||||
return buf.getvalue()
|
||||
else:
|
||||
# Read until size bytes or \n or EOF seen, whichever comes
|
||||
# first
|
||||
buf.seek(0, 2) # seek end
|
||||
buf_len = buf.tell()
|
||||
if buf_len >= size:
|
||||
buf.seek(0)
|
||||
rv = buf.read(size)
|
||||
self._rbuf = io.BytesIO()
|
||||
self._rbuf.write(buf.read())
|
||||
return rv
|
||||
# reset _rbuf. we consume it via buf.
|
||||
self._rbuf = io.BytesIO()
|
||||
while True:
|
||||
data = self.recv(self._rbufsize)
|
||||
if not data:
|
||||
break
|
||||
left = size - buf_len
|
||||
# did we just receive a newline?
|
||||
nl = data.find('\n', 0, left)
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
# save the excess data to _rbuf
|
||||
self._rbuf.write(data[nl:])
|
||||
if buf_len:
|
||||
buf.write(data[:nl])
|
||||
break
|
||||
else:
|
||||
# Shortcut. Avoid data copy through buf when
|
||||
# returning a substring of our first recv().
|
||||
return data[:nl]
|
||||
n = len(data)
|
||||
if n == size and not buf_len:
|
||||
# Shortcut. Avoid data copy through buf when
|
||||
# returning exactly all of our first recv().
|
||||
return data
|
||||
if n >= left:
|
||||
buf.write(data[:left])
|
||||
self._rbuf.write(data[left:])
|
||||
break
|
||||
buf.write(data)
|
||||
buf_len += n
|
||||
# assert buf_len == buf.tell()
|
||||
return buf.getvalue()
|
||||
else:
|
||||
def read(self, size=-1):
|
||||
"""Read data from the socket to buffer."""
|
||||
if size < 0:
|
||||
# Read until EOF
|
||||
buffers = [self._rbuf]
|
||||
self._rbuf = ''
|
||||
if self._rbufsize <= 1:
|
||||
recv_size = self.default_bufsize
|
||||
else:
|
||||
recv_size = self._rbufsize
|
||||
|
||||
while True:
|
||||
data = self.recv(recv_size)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
return ''.join(buffers)
|
||||
else:
|
||||
# Read until size bytes or EOF seen, whichever comes first
|
||||
data = self._rbuf
|
||||
buf_len = len(data)
|
||||
if buf_len >= size:
|
||||
self._rbuf = data[size:]
|
||||
return data[:size]
|
||||
buffers = []
|
||||
if data:
|
||||
buffers.append(data)
|
||||
self._rbuf = ''
|
||||
while True:
|
||||
left = size - buf_len
|
||||
recv_size = max(self._rbufsize, left)
|
||||
data = self.recv(recv_size)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
n = len(data)
|
||||
if n >= left:
|
||||
self._rbuf = data[left:]
|
||||
buffers[-1] = data[:left]
|
||||
break
|
||||
buf_len += n
|
||||
return ''.join(buffers)
|
||||
|
||||
def readline(self, size=-1):
|
||||
"""Read line from the socket to buffer."""
|
||||
data = self._rbuf
|
||||
if size < 0:
|
||||
# Read until \n or EOF, whichever comes first
|
||||
if self._rbufsize <= 1:
|
||||
# Speed up unbuffered case
|
||||
assert data == ''
|
||||
buffers = []
|
||||
while data != '\n':
|
||||
data = self.recv(1)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
return ''.join(buffers)
|
||||
nl = data.find('\n')
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
self._rbuf = data[nl:]
|
||||
return data[:nl]
|
||||
buffers = []
|
||||
if data:
|
||||
buffers.append(data)
|
||||
self._rbuf = ''
|
||||
while True:
|
||||
data = self.recv(self._rbufsize)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
nl = data.find('\n')
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
self._rbuf = data[nl:]
|
||||
buffers[-1] = data[:nl]
|
||||
break
|
||||
return ''.join(buffers)
|
||||
else:
|
||||
# Read until size bytes or \n or EOF seen, whichever comes
|
||||
# first
|
||||
nl = data.find('\n', 0, size)
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
self._rbuf = data[nl:]
|
||||
return data[:nl]
|
||||
buf_len = len(data)
|
||||
if buf_len >= size:
|
||||
self._rbuf = data[size:]
|
||||
return data[:size]
|
||||
buffers = []
|
||||
if data:
|
||||
buffers.append(data)
|
||||
self._rbuf = ''
|
||||
while True:
|
||||
data = self.recv(self._rbufsize)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
left = size - buf_len
|
||||
nl = data.find('\n', 0, left)
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
self._rbuf = data[nl:]
|
||||
buffers[-1] = data[:nl]
|
||||
break
|
||||
n = len(data)
|
||||
if n >= left:
|
||||
self._rbuf = data[left:]
|
||||
buffers[-1] = data[:left]
|
||||
break
|
||||
buf_len += n
|
||||
return ''.join(buffers)
|
||||
|
||||
|
||||
MakeFile = MakeFile_PY2 if six.PY2 else MakeFile_PY3
|
2001
libraries/cheroot/server.py
Normal file
2001
libraries/cheroot/server.py
Normal file
File diff suppressed because it is too large
Load diff
51
libraries/cheroot/ssl/__init__.py
Normal file
51
libraries/cheroot/ssl/__init__.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
"""Implementation of the SSL adapter base interface."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from six import add_metaclass
|
||||
|
||||
|
||||
@add_metaclass(ABCMeta)
|
||||
class Adapter:
|
||||
"""Base class for SSL driver library adapters.
|
||||
|
||||
Required methods:
|
||||
|
||||
* ``wrap(sock) -> (wrapped socket, ssl environ dict)``
|
||||
* ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) ->
|
||||
socket file object``
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self, certificate, private_key, certificate_chain=None,
|
||||
ciphers=None):
|
||||
"""Set up certificates, private key ciphers and reset context."""
|
||||
self.certificate = certificate
|
||||
self.private_key = private_key
|
||||
self.certificate_chain = certificate_chain
|
||||
self.ciphers = ciphers
|
||||
self.context = None
|
||||
|
||||
@abstractmethod
|
||||
def bind(self, sock):
|
||||
"""Wrap and return the given socket."""
|
||||
return sock
|
||||
|
||||
@abstractmethod
|
||||
def wrap(self, sock):
|
||||
"""Wrap and return the given socket, plus WSGI environ entries."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_environ(self):
|
||||
"""Return WSGI environ entries to be merged into each request."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def makefile(self, sock, mode='r', bufsize=-1):
|
||||
"""Return socket file object."""
|
||||
raise NotImplementedError
|
162
libraries/cheroot/ssl/builtin.py
Normal file
162
libraries/cheroot/ssl/builtin.py
Normal file
|
@ -0,0 +1,162 @@
|
|||
"""
|
||||
A library for integrating Python's builtin ``ssl`` library with Cheroot.
|
||||
|
||||
The ssl module must be importable for SSL functionality.
|
||||
|
||||
To use this module, set ``HTTPServer.ssl_adapter`` to an instance of
|
||||
``BuiltinSSLAdapter``.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
ssl = None
|
||||
|
||||
try:
|
||||
from _pyio import DEFAULT_BUFFER_SIZE
|
||||
except ImportError:
|
||||
try:
|
||||
from io import DEFAULT_BUFFER_SIZE
|
||||
except ImportError:
|
||||
DEFAULT_BUFFER_SIZE = -1
|
||||
|
||||
import six
|
||||
|
||||
from . import Adapter
|
||||
from .. import errors
|
||||
from ..makefile import MakeFile
|
||||
|
||||
|
||||
if six.PY3:
|
||||
generic_socket_error = OSError
|
||||
else:
|
||||
import socket
|
||||
generic_socket_error = socket.error
|
||||
del socket
|
||||
|
||||
|
||||
def _assert_ssl_exc_contains(exc, *msgs):
|
||||
"""Check whether SSL exception contains either of messages provided."""
|
||||
if len(msgs) < 1:
|
||||
raise TypeError(
|
||||
'_assert_ssl_exc_contains() requires '
|
||||
'at least one message to be passed.'
|
||||
)
|
||||
err_msg_lower = exc.args[1].lower()
|
||||
return any(m.lower() in err_msg_lower for m in msgs)
|
||||
|
||||
|
||||
class BuiltinSSLAdapter(Adapter):
|
||||
"""A wrapper for integrating Python's builtin ssl module with Cheroot."""
|
||||
|
||||
certificate = None
|
||||
"""The filename of the server SSL certificate."""
|
||||
|
||||
private_key = None
|
||||
"""The filename of the server's private key file."""
|
||||
|
||||
certificate_chain = None
|
||||
"""The filename of the certificate chain file."""
|
||||
|
||||
context = None
|
||||
"""The ssl.SSLContext that will be used to wrap sockets."""
|
||||
|
||||
ciphers = None
|
||||
"""The ciphers list of SSL."""
|
||||
|
||||
def __init__(
|
||||
self, certificate, private_key, certificate_chain=None,
|
||||
ciphers=None):
|
||||
"""Set up context in addition to base class properties if available."""
|
||||
if ssl is None:
|
||||
raise ImportError('You must install the ssl module to use HTTPS.')
|
||||
|
||||
super(BuiltinSSLAdapter, self).__init__(
|
||||
certificate, private_key, certificate_chain, ciphers)
|
||||
|
||||
self.context = ssl.create_default_context(
|
||||
purpose=ssl.Purpose.CLIENT_AUTH,
|
||||
cafile=certificate_chain
|
||||
)
|
||||
self.context.load_cert_chain(certificate, private_key)
|
||||
if self.ciphers is not None:
|
||||
self.context.set_ciphers(ciphers)
|
||||
|
||||
def bind(self, sock):
|
||||
"""Wrap and return the given socket."""
|
||||
return super(BuiltinSSLAdapter, self).bind(sock)
|
||||
|
||||
def wrap(self, sock):
|
||||
"""Wrap and return the given socket, plus WSGI environ entries."""
|
||||
EMPTY_RESULT = None, {}
|
||||
try:
|
||||
s = self.context.wrap_socket(
|
||||
sock, do_handshake_on_connect=True, server_side=True,
|
||||
)
|
||||
except ssl.SSLError as ex:
|
||||
if ex.errno == ssl.SSL_ERROR_EOF:
|
||||
# This is almost certainly due to the cherrypy engine
|
||||
# 'pinging' the socket to assert it's connectable;
|
||||
# the 'ping' isn't SSL.
|
||||
return EMPTY_RESULT
|
||||
elif ex.errno == ssl.SSL_ERROR_SSL:
|
||||
if _assert_ssl_exc_contains(ex, 'http request'):
|
||||
# The client is speaking HTTP to an HTTPS server.
|
||||
raise errors.NoSSLError
|
||||
|
||||
# Check if it's one of the known errors
|
||||
# Errors that are caught by PyOpenSSL, but thrown by
|
||||
# built-in ssl
|
||||
_block_errors = (
|
||||
'unknown protocol', 'unknown ca', 'unknown_ca',
|
||||
'unknown error',
|
||||
'https proxy request', 'inappropriate fallback',
|
||||
'wrong version number',
|
||||
'no shared cipher', 'certificate unknown',
|
||||
'ccs received early',
|
||||
)
|
||||
if _assert_ssl_exc_contains(ex, *_block_errors):
|
||||
# Accepted error, let's pass
|
||||
return EMPTY_RESULT
|
||||
elif _assert_ssl_exc_contains(ex, 'handshake operation timed out'):
|
||||
# This error is thrown by builtin SSL after a timeout
|
||||
# when client is speaking HTTP to an HTTPS server.
|
||||
# The connection can safely be dropped.
|
||||
return EMPTY_RESULT
|
||||
raise
|
||||
except generic_socket_error as exc:
|
||||
"""It is unclear why exactly this happens.
|
||||
|
||||
It's reproducible only under Python 2 with openssl>1.0 and stdlib
|
||||
``ssl`` wrapper, and only with CherryPy.
|
||||
So it looks like some healthcheck tries to connect to this socket
|
||||
during startup (from the same process).
|
||||
|
||||
|
||||
Ref: https://github.com/cherrypy/cherrypy/issues/1618
|
||||
"""
|
||||
if six.PY2 and exc.args == (0, 'Error'):
|
||||
return EMPTY_RESULT
|
||||
raise
|
||||
return s, self.get_environ(s)
|
||||
|
||||
# TODO: fill this out more with mod ssl env
|
||||
def get_environ(self, sock):
|
||||
"""Create WSGI environ entries to be merged into each request."""
|
||||
cipher = sock.cipher()
|
||||
ssl_environ = {
|
||||
'wsgi.url_scheme': 'https',
|
||||
'HTTPS': 'on',
|
||||
'SSL_PROTOCOL': cipher[1],
|
||||
'SSL_CIPHER': cipher[0]
|
||||
# SSL_VERSION_INTERFACE string The mod_ssl program version
|
||||
# SSL_VERSION_LIBRARY string The OpenSSL program version
|
||||
}
|
||||
return ssl_environ
|
||||
|
||||
def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):
|
||||
"""Return socket file object."""
|
||||
return MakeFile(sock, mode, bufsize)
|
267
libraries/cheroot/ssl/pyopenssl.py
Normal file
267
libraries/cheroot/ssl/pyopenssl.py
Normal file
|
@ -0,0 +1,267 @@
|
|||
"""
|
||||
A library for integrating pyOpenSSL with Cheroot.
|
||||
|
||||
The OpenSSL module must be importable for SSL functionality.
|
||||
You can obtain it from `here <https://launchpad.net/pyopenssl>`_.
|
||||
|
||||
To use this module, set HTTPServer.ssl_adapter to an instance of
|
||||
ssl.Adapter. There are two ways to use SSL:
|
||||
|
||||
Method One
|
||||
----------
|
||||
|
||||
* ``ssl_adapter.context``: an instance of SSL.Context.
|
||||
|
||||
If this is not None, it is assumed to be an SSL.Context instance,
|
||||
and will be passed to SSL.Connection on bind(). The developer is
|
||||
responsible for forming a valid Context object. This approach is
|
||||
to be preferred for more flexibility, e.g. if the cert and key are
|
||||
streams instead of files, or need decryption, or SSL.SSLv3_METHOD
|
||||
is desired instead of the default SSL.SSLv23_METHOD, etc. Consult
|
||||
the pyOpenSSL documentation for complete options.
|
||||
|
||||
Method Two (shortcut)
|
||||
---------------------
|
||||
|
||||
* ``ssl_adapter.certificate``: the filename of the server SSL certificate.
|
||||
* ``ssl_adapter.private_key``: the filename of the server's private key file.
|
||||
|
||||
Both are None by default. If ssl_adapter.context is None, but .private_key
|
||||
and .certificate are both given and valid, they will be read, and the
|
||||
context will be automatically created from them.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
try:
|
||||
from OpenSSL import SSL
|
||||
from OpenSSL import crypto
|
||||
except ImportError:
|
||||
SSL = None
|
||||
|
||||
from . import Adapter
|
||||
from .. import errors, server as cheroot_server
|
||||
from ..makefile import MakeFile
|
||||
|
||||
|
||||
class SSL_fileobject(MakeFile):
|
||||
"""SSL file object attached to a socket object."""
|
||||
|
||||
ssl_timeout = 3
|
||||
ssl_retry = .01
|
||||
|
||||
def _safe_call(self, is_reader, call, *args, **kwargs):
|
||||
"""Wrap the given call with SSL error-trapping.
|
||||
|
||||
is_reader: if False EOF errors will be raised. If True, EOF errors
|
||||
will return "" (to emulate normal sockets).
|
||||
"""
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
return call(*args, **kwargs)
|
||||
except SSL.WantReadError:
|
||||
# Sleep and try again. This is dangerous, because it means
|
||||
# the rest of the stack has no way of differentiating
|
||||
# between a "new handshake" error and "client dropped".
|
||||
# Note this isn't an endless loop: there's a timeout below.
|
||||
time.sleep(self.ssl_retry)
|
||||
except SSL.WantWriteError:
|
||||
time.sleep(self.ssl_retry)
|
||||
except SSL.SysCallError as e:
|
||||
if is_reader and e.args == (-1, 'Unexpected EOF'):
|
||||
return ''
|
||||
|
||||
errnum = e.args[0]
|
||||
if is_reader and errnum in errors.socket_errors_to_ignore:
|
||||
return ''
|
||||
raise socket.error(errnum)
|
||||
except SSL.Error as e:
|
||||
if is_reader and e.args == (-1, 'Unexpected EOF'):
|
||||
return ''
|
||||
|
||||
thirdarg = None
|
||||
try:
|
||||
thirdarg = e.args[0][0][2]
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
if thirdarg == 'http request':
|
||||
# The client is talking HTTP to an HTTPS server.
|
||||
raise errors.NoSSLError()
|
||||
|
||||
raise errors.FatalSSLAlert(*e.args)
|
||||
|
||||
if time.time() - start > self.ssl_timeout:
|
||||
raise socket.timeout('timed out')
|
||||
|
||||
def recv(self, size):
|
||||
"""Receive message of a size from the socket."""
|
||||
return self._safe_call(True, super(SSL_fileobject, self).recv, size)
|
||||
|
||||
def sendall(self, *args, **kwargs):
|
||||
"""Send whole message to the socket."""
|
||||
return self._safe_call(False, super(SSL_fileobject, self).sendall,
|
||||
*args, **kwargs)
|
||||
|
||||
def send(self, *args, **kwargs):
|
||||
"""Send some part of message to the socket."""
|
||||
return self._safe_call(False, super(SSL_fileobject, self).send,
|
||||
*args, **kwargs)
|
||||
|
||||
|
||||
class SSLConnection:
|
||||
"""A thread-safe wrapper for an SSL.Connection.
|
||||
|
||||
``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
"""Initialize SSLConnection instance."""
|
||||
self._ssl_conn = SSL.Connection(*args)
|
||||
self._lock = threading.RLock()
|
||||
|
||||
for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read',
|
||||
'renegotiate', 'bind', 'listen', 'connect', 'accept',
|
||||
'setblocking', 'fileno', 'close', 'get_cipher_list',
|
||||
'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
|
||||
'makefile', 'get_app_data', 'set_app_data', 'state_string',
|
||||
'sock_shutdown', 'get_peer_certificate', 'want_read',
|
||||
'want_write', 'set_connect_state', 'set_accept_state',
|
||||
'connect_ex', 'sendall', 'settimeout', 'gettimeout'):
|
||||
exec("""def %s(self, *args):
|
||||
self._lock.acquire()
|
||||
try:
|
||||
return self._ssl_conn.%s(*args)
|
||||
finally:
|
||||
self._lock.release()
|
||||
""" % (f, f))
|
||||
|
||||
def shutdown(self, *args):
|
||||
"""Shutdown the SSL connection.
|
||||
|
||||
Ignore all incoming args since pyOpenSSL.socket.shutdown takes no args.
|
||||
"""
|
||||
self._lock.acquire()
|
||||
try:
|
||||
return self._ssl_conn.shutdown()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
class pyOpenSSLAdapter(Adapter):
|
||||
"""A wrapper for integrating pyOpenSSL with Cheroot."""
|
||||
|
||||
certificate = None
|
||||
"""The filename of the server SSL certificate."""
|
||||
|
||||
private_key = None
|
||||
"""The filename of the server's private key file."""
|
||||
|
||||
certificate_chain = None
|
||||
"""Optional. The filename of CA's intermediate certificate bundle.
|
||||
|
||||
This is needed for cheaper "chained root" SSL certificates, and should be
|
||||
left as None if not required."""
|
||||
|
||||
context = None
|
||||
"""An instance of SSL.Context."""
|
||||
|
||||
ciphers = None
|
||||
"""The ciphers list of SSL."""
|
||||
|
||||
def __init__(
|
||||
self, certificate, private_key, certificate_chain=None,
|
||||
ciphers=None):
|
||||
"""Initialize OpenSSL Adapter instance."""
|
||||
if SSL is None:
|
||||
raise ImportError('You must install pyOpenSSL to use HTTPS.')
|
||||
|
||||
super(pyOpenSSLAdapter, self).__init__(
|
||||
certificate, private_key, certificate_chain, ciphers)
|
||||
|
||||
self._environ = None
|
||||
|
||||
def bind(self, sock):
|
||||
"""Wrap and return the given socket."""
|
||||
if self.context is None:
|
||||
self.context = self.get_context()
|
||||
conn = SSLConnection(self.context, sock)
|
||||
self._environ = self.get_environ()
|
||||
return conn
|
||||
|
||||
def wrap(self, sock):
|
||||
"""Wrap and return the given socket, plus WSGI environ entries."""
|
||||
return sock, self._environ.copy()
|
||||
|
||||
def get_context(self):
|
||||
"""Return an SSL.Context from self attributes."""
|
||||
# See https://code.activestate.com/recipes/442473/
|
||||
c = SSL.Context(SSL.SSLv23_METHOD)
|
||||
c.use_privatekey_file(self.private_key)
|
||||
if self.certificate_chain:
|
||||
c.load_verify_locations(self.certificate_chain)
|
||||
c.use_certificate_file(self.certificate)
|
||||
return c
|
||||
|
||||
def get_environ(self):
|
||||
"""Return WSGI environ entries to be merged into each request."""
|
||||
ssl_environ = {
|
||||
'HTTPS': 'on',
|
||||
# pyOpenSSL doesn't provide access to any of these AFAICT
|
||||
# 'SSL_PROTOCOL': 'SSLv2',
|
||||
# SSL_CIPHER string The cipher specification name
|
||||
# SSL_VERSION_INTERFACE string The mod_ssl program version
|
||||
# SSL_VERSION_LIBRARY string The OpenSSL program version
|
||||
}
|
||||
|
||||
if self.certificate:
|
||||
# Server certificate attributes
|
||||
cert = open(self.certificate, 'rb').read()
|
||||
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
|
||||
ssl_environ.update({
|
||||
'SSL_SERVER_M_VERSION': cert.get_version(),
|
||||
'SSL_SERVER_M_SERIAL': cert.get_serial_number(),
|
||||
# 'SSL_SERVER_V_START':
|
||||
# Validity of server's certificate (start time),
|
||||
# 'SSL_SERVER_V_END':
|
||||
# Validity of server's certificate (end time),
|
||||
})
|
||||
|
||||
for prefix, dn in [('I', cert.get_issuer()),
|
||||
('S', cert.get_subject())]:
|
||||
# X509Name objects don't seem to have a way to get the
|
||||
# complete DN string. Use str() and slice it instead,
|
||||
# because str(dn) == "<X509Name object '/C=US/ST=...'>"
|
||||
dnstr = str(dn)[18:-2]
|
||||
|
||||
wsgikey = 'SSL_SERVER_%s_DN' % prefix
|
||||
ssl_environ[wsgikey] = dnstr
|
||||
|
||||
# The DN should be of the form: /k1=v1/k2=v2, but we must allow
|
||||
# for any value to contain slashes itself (in a URL).
|
||||
while dnstr:
|
||||
pos = dnstr.rfind('=')
|
||||
dnstr, value = dnstr[:pos], dnstr[pos + 1:]
|
||||
pos = dnstr.rfind('/')
|
||||
dnstr, key = dnstr[:pos], dnstr[pos + 1:]
|
||||
if key and value:
|
||||
wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key)
|
||||
ssl_environ[wsgikey] = value
|
||||
|
||||
return ssl_environ
|
||||
|
||||
def makefile(self, sock, mode='r', bufsize=-1):
|
||||
"""Return socket file object."""
|
||||
if SSL and isinstance(sock, SSL.ConnectionType):
|
||||
timeout = sock.gettimeout()
|
||||
f = SSL_fileobject(sock, mode, bufsize)
|
||||
f.ssl_timeout = timeout
|
||||
return f
|
||||
else:
|
||||
return cheroot_server.CP_fileobject(sock, mode, bufsize)
|
1
libraries/cheroot/test/__init__.py
Normal file
1
libraries/cheroot/test/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
"""Cheroot test suite."""
|
27
libraries/cheroot/test/conftest.py
Normal file
27
libraries/cheroot/test/conftest.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
"""Pytest configuration module.
|
||||
|
||||
Contains fixtures, which are tightly bound to the Cheroot framework
|
||||
itself, useless for end-users' app testing.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import pytest
|
||||
|
||||
from ..testing import ( # noqa: F401
|
||||
native_server, wsgi_server,
|
||||
)
|
||||
from ..testing import get_server_client
|
||||
|
||||
|
||||
@pytest.fixture # noqa: F811
|
||||
def wsgi_server_client(wsgi_server):
|
||||
"""Create a test client out of given WSGI server."""
|
||||
return get_server_client(wsgi_server)
|
||||
|
||||
|
||||
@pytest.fixture # noqa: F811
|
||||
def native_server_client(native_server):
|
||||
"""Create a test client out of given HTTP server."""
|
||||
return get_server_client(native_server)
|
169
libraries/cheroot/test/helper.py
Normal file
169
libraries/cheroot/test/helper.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
"""A library of helper functions for the Cheroot test suite."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import types
|
||||
|
||||
from six.moves import http_client
|
||||
|
||||
import six
|
||||
|
||||
import cheroot.server
|
||||
import cheroot.wsgi
|
||||
|
||||
from cheroot.test import webtest
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
thisdir = os.path.abspath(os.path.dirname(__file__))
|
||||
serverpem = os.path.join(os.getcwd(), thisdir, 'test.pem')
|
||||
|
||||
|
||||
config = {
|
||||
'bind_addr': ('127.0.0.1', 54583),
|
||||
'server': 'wsgi',
|
||||
'wsgi_app': None,
|
||||
}
|
||||
|
||||
|
||||
class CherootWebCase(webtest.WebCase):
|
||||
"""Helper class for a web app test suite."""
|
||||
|
||||
script_name = ''
|
||||
scheme = 'http'
|
||||
|
||||
available_servers = {
|
||||
'wsgi': cheroot.wsgi.Server,
|
||||
'native': cheroot.server.HTTPServer,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
"""Create and run one HTTP server per class."""
|
||||
conf = config.copy()
|
||||
conf.update(getattr(cls, 'config', {}))
|
||||
|
||||
s_class = conf.pop('server', 'wsgi')
|
||||
server_factory = cls.available_servers.get(s_class)
|
||||
if server_factory is None:
|
||||
raise RuntimeError('Unknown server in config: %s' % conf['server'])
|
||||
cls.httpserver = server_factory(**conf)
|
||||
|
||||
cls.HOST, cls.PORT = cls.httpserver.bind_addr
|
||||
if cls.httpserver.ssl_adapter is None:
|
||||
ssl = ''
|
||||
cls.scheme = 'http'
|
||||
else:
|
||||
ssl = ' (ssl)'
|
||||
cls.HTTP_CONN = http_client.HTTPSConnection
|
||||
cls.scheme = 'https'
|
||||
|
||||
v = sys.version.split()[0]
|
||||
log.info('Python version used to run this test script: %s' % v)
|
||||
log.info('Cheroot version: %s' % cheroot.__version__)
|
||||
log.info('HTTP server version: %s%s' % (cls.httpserver.protocol, ssl))
|
||||
log.info('PID: %s' % os.getpid())
|
||||
|
||||
if hasattr(cls, 'setup_server'):
|
||||
# Clear the wsgi server so that
|
||||
# it can be updated with the new root
|
||||
cls.setup_server()
|
||||
cls.start()
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
"""Cleanup HTTP server."""
|
||||
if hasattr(cls, 'setup_server'):
|
||||
cls.stop()
|
||||
|
||||
@classmethod
|
||||
def start(cls):
|
||||
"""Load and start the HTTP server."""
|
||||
threading.Thread(target=cls.httpserver.safe_start).start()
|
||||
while not cls.httpserver.ready:
|
||||
time.sleep(0.1)
|
||||
|
||||
@classmethod
|
||||
def stop(cls):
|
||||
"""Terminate HTTP server."""
|
||||
cls.httpserver.stop()
|
||||
td = getattr(cls, 'teardown', None)
|
||||
if td:
|
||||
td()
|
||||
|
||||
date_tolerance = 2
|
||||
|
||||
def assertEqualDates(self, dt1, dt2, seconds=None):
|
||||
"""Assert abs(dt1 - dt2) is within Y seconds."""
|
||||
if seconds is None:
|
||||
seconds = self.date_tolerance
|
||||
|
||||
if dt1 > dt2:
|
||||
diff = dt1 - dt2
|
||||
else:
|
||||
diff = dt2 - dt1
|
||||
if not diff < datetime.timedelta(seconds=seconds):
|
||||
raise AssertionError('%r and %r are not within %r seconds.' %
|
||||
(dt1, dt2, seconds))
|
||||
|
||||
|
||||
class Request:
|
||||
"""HTTP request container."""
|
||||
|
||||
def __init__(self, environ):
|
||||
"""Initialize HTTP request."""
|
||||
self.environ = environ
|
||||
|
||||
|
||||
class Response:
|
||||
"""HTTP response container."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize HTTP response."""
|
||||
self.status = '200 OK'
|
||||
self.headers = {'Content-Type': 'text/html'}
|
||||
self.body = None
|
||||
|
||||
def output(self):
|
||||
"""Generate iterable response body object."""
|
||||
if self.body is None:
|
||||
return []
|
||||
elif isinstance(self.body, six.text_type):
|
||||
return [self.body.encode('iso-8859-1')]
|
||||
elif isinstance(self.body, six.binary_type):
|
||||
return [self.body]
|
||||
else:
|
||||
return [x.encode('iso-8859-1') for x in self.body]
|
||||
|
||||
|
||||
class Controller:
|
||||
"""WSGI app for tests."""
|
||||
|
||||
def __call__(self, environ, start_response):
|
||||
"""WSGI request handler."""
|
||||
req, resp = Request(environ), Response()
|
||||
try:
|
||||
# Python 3 supports unicode attribute names
|
||||
# Python 2 encodes them
|
||||
handler = self.handlers[environ['PATH_INFO']]
|
||||
except KeyError:
|
||||
resp.status = '404 Not Found'
|
||||
else:
|
||||
output = handler(req, resp)
|
||||
if (output is not None and
|
||||
not any(resp.status.startswith(status_code)
|
||||
for status_code in ('204', '304'))):
|
||||
resp.body = output
|
||||
try:
|
||||
resp.headers.setdefault('Content-Length', str(len(output)))
|
||||
except TypeError:
|
||||
if not isinstance(output, types.GeneratorType):
|
||||
raise
|
||||
start_response(resp.status, resp.headers.items())
|
||||
return resp.output()
|
38
libraries/cheroot/test/test.pem
Normal file
38
libraries/cheroot/test/test.pem
Normal file
|
@ -0,0 +1,38 @@
|
|||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXAIBAAKBgQDBKo554mzIMY+AByUNpaUOP9bJnQ7ZLQe9XgHwoLJR4VzpyZZZ
|
||||
R9L4WtImEew05FY3Izerfm3MN3+MC0tJ6yQU9sOiU3vBW6RrLIMlfKsnRwBRZ0Kn
|
||||
da+O6xldVSosu8Ev3z9VZ94iC/ZgKzrH7Mjj/U8/MQO7RBS/LAqee8bFNQIDAQAB
|
||||
AoGAWOCF0ZrWxn3XMucWq2LNwPKqlvVGwbIwX3cDmX22zmnM4Fy6arXbYh4XlyCj
|
||||
9+ofqRrxIFz5k/7tFriTmZ0xag5+Jdx+Kwg0/twiP7XCNKipFogwe1Hznw8OFAoT
|
||||
enKBdj2+/n2o0Bvo/tDB59m9L/538d46JGQUmJlzMyqYikECQQDyoq+8CtMNvE18
|
||||
8VgHcR/KtApxWAjj4HpaHYL637ATjThetUZkW92mgDgowyplthusxdNqhHWyv7E8
|
||||
tWNdYErZAkEAy85ShTR0M5aWmrE7o0r0SpWInAkNBH9aXQRRARFYsdBtNfRu6I0i
|
||||
0lvU9wiu3eF57FMEC86yViZ5UBnQfTu7vQJAVesj/Zt7pwaCDfdMa740OsxMUlyR
|
||||
MVhhGx4OLpYdPJ8qUecxGQKq13XZ7R1HGyNEY4bd2X80Smq08UFuATfC6QJAH8UB
|
||||
yBHtKz2GLIcELOg6PIYizW/7v3+6rlVF60yw7sb2vzpjL40QqIn4IKoR2DSVtOkb
|
||||
8FtAIX3N21aq0VrGYQJBAIPiaEc2AZ8Bq2GC4F3wOz/BxJ/izvnkiotR12QK4fh5
|
||||
yjZMhTjWCas5zwHR5PDjlD88AWGDMsZ1PicD4348xJQ=
|
||||
-----END RSA PRIVATE KEY-----
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDxTCCAy6gAwIBAgIJAI18BD7eQxlGMA0GCSqGSIb3DQEBBAUAMIGeMQswCQYD
|
||||
VQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTESMBAGA1UEBxMJU2FuIERpZWdv
|
||||
MRkwFwYDVQQKExBDaGVycnlQeSBQcm9qZWN0MREwDwYDVQQLEwhkZXYtdGVzdDEW
|
||||
MBQGA1UEAxMNQ2hlcnJ5UHkgVGVhbTEgMB4GCSqGSIb3DQEJARYRcmVtaUBjaGVy
|
||||
cnlweS5vcmcwHhcNMDYwOTA5MTkyMDIwWhcNMzQwMTI0MTkyMDIwWjCBnjELMAkG
|
||||
A1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEjAQBgNVBAcTCVNhbiBEaWVn
|
||||
bzEZMBcGA1UEChMQQ2hlcnJ5UHkgUHJvamVjdDERMA8GA1UECxMIZGV2LXRlc3Qx
|
||||
FjAUBgNVBAMTDUNoZXJyeVB5IFRlYW0xIDAeBgkqhkiG9w0BCQEWEXJlbWlAY2hl
|
||||
cnJ5cHkub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDBKo554mzIMY+A
|
||||
ByUNpaUOP9bJnQ7ZLQe9XgHwoLJR4VzpyZZZR9L4WtImEew05FY3Izerfm3MN3+M
|
||||
C0tJ6yQU9sOiU3vBW6RrLIMlfKsnRwBRZ0Knda+O6xldVSosu8Ev3z9VZ94iC/Zg
|
||||
KzrH7Mjj/U8/MQO7RBS/LAqee8bFNQIDAQABo4IBBzCCAQMwHQYDVR0OBBYEFDIQ
|
||||
2feb71tVZCWpU0qJ/Tw+wdtoMIHTBgNVHSMEgcswgciAFDIQ2feb71tVZCWpU0qJ
|
||||
/Tw+wdtooYGkpIGhMIGeMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5p
|
||||
YTESMBAGA1UEBxMJU2FuIERpZWdvMRkwFwYDVQQKExBDaGVycnlQeSBQcm9qZWN0
|
||||
MREwDwYDVQQLEwhkZXYtdGVzdDEWMBQGA1UEAxMNQ2hlcnJ5UHkgVGVhbTEgMB4G
|
||||
CSqGSIb3DQEJARYRcmVtaUBjaGVycnlweS5vcmeCCQCNfAQ+3kMZRjAMBgNVHRME
|
||||
BTADAQH/MA0GCSqGSIb3DQEBBAUAA4GBAL7AAQz7IePV48ZTAFHKr88ntPALsL5S
|
||||
8vHCZPNMevNkLTj3DYUw2BcnENxMjm1kou2F2BkvheBPNZKIhc6z4hAml3ed1xa2
|
||||
D7w6e6OTcstdK/+KrPDDHeOP1dhMWNs2JE1bNlfF1LiXzYKSXpe88eCKjCXsCT/T
|
||||
NluCaWQys3MS
|
||||
-----END CERTIFICATE-----
|
49
libraries/cheroot/test/test__compat.py
Normal file
49
libraries/cheroot/test/test__compat.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""Test suite for cross-python compatibility helpers."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import pytest
|
||||
import six
|
||||
|
||||
from cheroot._compat import ntob, ntou, bton
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'func,inp,out',
|
||||
[
|
||||
(ntob, 'bar', b'bar'),
|
||||
(ntou, 'bar', u'bar'),
|
||||
(bton, b'bar', 'bar'),
|
||||
],
|
||||
)
|
||||
def test_compat_functions_positive(func, inp, out):
|
||||
"""Check that compat functions work with correct input."""
|
||||
assert func(inp, encoding='utf-8') == out
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'func',
|
||||
[
|
||||
ntob,
|
||||
ntou,
|
||||
],
|
||||
)
|
||||
def test_compat_functions_negative_nonnative(func):
|
||||
"""Check that compat functions fail loudly for incorrect input."""
|
||||
non_native_test_str = b'bar' if six.PY3 else u'bar'
|
||||
with pytest.raises(TypeError):
|
||||
func(non_native_test_str, encoding='utf-8')
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='This test does not work now')
|
||||
@pytest.mark.skipif(
|
||||
six.PY3,
|
||||
reason='This code path only appears in Python 2 version.',
|
||||
)
|
||||
def test_ntou_escape():
|
||||
"""Check that ntou supports escape-encoding under Python 2."""
|
||||
expected = u''
|
||||
actual = ntou('hi'.encode('ISO-8859-1'), encoding='escape')
|
||||
assert actual == expected
|
897
libraries/cheroot/test/test_conn.py
Normal file
897
libraries/cheroot/test/test_conn.py
Normal file
|
@ -0,0 +1,897 @@
|
|||
"""Tests for TCP connection handling, including proper and timely close."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import socket
|
||||
import time
|
||||
|
||||
from six.moves import range, http_client, urllib
|
||||
|
||||
import six
|
||||
import pytest
|
||||
|
||||
from cheroot.test import helper, webtest
|
||||
|
||||
|
||||
timeout = 1
|
||||
pov = 'pPeErRsSiIsStTeEnNcCeE oOfF vViIsSiIoOnN'
|
||||
|
||||
|
||||
class Controller(helper.Controller):
|
||||
"""Controller for serving WSGI apps."""
|
||||
|
||||
def hello(req, resp):
|
||||
"""Render Hello world."""
|
||||
return 'Hello, world!'
|
||||
|
||||
def pov(req, resp):
|
||||
"""Render pov value."""
|
||||
return pov
|
||||
|
||||
def stream(req, resp):
|
||||
"""Render streaming response."""
|
||||
if 'set_cl' in req.environ['QUERY_STRING']:
|
||||
resp.headers['Content-Length'] = str(10)
|
||||
|
||||
def content():
|
||||
for x in range(10):
|
||||
yield str(x)
|
||||
|
||||
return content()
|
||||
|
||||
def upload(req, resp):
|
||||
"""Process file upload and render thank."""
|
||||
if not req.environ['REQUEST_METHOD'] == 'POST':
|
||||
raise AssertionError("'POST' != request.method %r" %
|
||||
req.environ['REQUEST_METHOD'])
|
||||
return "thanks for '%s'" % req.environ['wsgi.input'].read()
|
||||
|
||||
def custom_204(req, resp):
|
||||
"""Render response with status 204."""
|
||||
resp.status = '204'
|
||||
return 'Code = 204'
|
||||
|
||||
def custom_304(req, resp):
|
||||
"""Render response with status 304."""
|
||||
resp.status = '304'
|
||||
return 'Code = 304'
|
||||
|
||||
def err_before_read(req, resp):
|
||||
"""Render response with status 500."""
|
||||
resp.status = '500 Internal Server Error'
|
||||
return 'ok'
|
||||
|
||||
def one_megabyte_of_a(req, resp):
|
||||
"""Render 1MB response."""
|
||||
return ['a' * 1024] * 1024
|
||||
|
||||
def wrong_cl_buffered(req, resp):
|
||||
"""Render buffered response with invalid length value."""
|
||||
resp.headers['Content-Length'] = '5'
|
||||
return 'I have too many bytes'
|
||||
|
||||
def wrong_cl_unbuffered(req, resp):
|
||||
"""Render unbuffered response with invalid length value."""
|
||||
resp.headers['Content-Length'] = '5'
|
||||
return ['I too', ' have too many bytes']
|
||||
|
||||
def _munge(string):
|
||||
"""Encode PATH_INFO correctly depending on Python version.
|
||||
|
||||
WSGI 1.0 is a mess around unicode. Create endpoints
|
||||
that match the PATH_INFO that it produces.
|
||||
"""
|
||||
if six.PY3:
|
||||
return string.encode('utf-8').decode('latin-1')
|
||||
return string
|
||||
|
||||
handlers = {
|
||||
'/hello': hello,
|
||||
'/pov': pov,
|
||||
'/page1': pov,
|
||||
'/page2': pov,
|
||||
'/page3': pov,
|
||||
'/stream': stream,
|
||||
'/upload': upload,
|
||||
'/custom/204': custom_204,
|
||||
'/custom/304': custom_304,
|
||||
'/err_before_read': err_before_read,
|
||||
'/one_megabyte_of_a': one_megabyte_of_a,
|
||||
'/wrong_cl_buffered': wrong_cl_buffered,
|
||||
'/wrong_cl_unbuffered': wrong_cl_unbuffered,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def testing_server(wsgi_server_client):
|
||||
"""Attach a WSGI app to the given server and pre-configure it."""
|
||||
app = Controller()
|
||||
|
||||
def _timeout(req, resp):
|
||||
return str(wsgi_server.timeout)
|
||||
app.handlers['/timeout'] = _timeout
|
||||
wsgi_server = wsgi_server_client.server_instance
|
||||
wsgi_server.wsgi_app = app
|
||||
wsgi_server.max_request_body_size = 1001
|
||||
wsgi_server.timeout = timeout
|
||||
wsgi_server.server_client = wsgi_server_client
|
||||
return wsgi_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client(testing_server):
|
||||
"""Get and return a test client out of the given server."""
|
||||
return testing_server.server_client
|
||||
|
||||
|
||||
def header_exists(header_name, headers):
|
||||
"""Check that a header is present."""
|
||||
return header_name.lower() in (k.lower() for (k, _) in headers)
|
||||
|
||||
|
||||
def header_has_value(header_name, header_value, headers):
|
||||
"""Check that a header with a given value is present."""
|
||||
return header_name.lower() in (k.lower() for (k, v) in headers
|
||||
if v == header_value)
|
||||
|
||||
|
||||
def test_HTTP11_persistent_connections(test_client):
|
||||
"""Test persistent HTTP/1.1 connections."""
|
||||
# Initialize a persistent HTTP connection
|
||||
http_connection = test_client.get_connection()
|
||||
http_connection.auto_open = False
|
||||
http_connection.connect()
|
||||
|
||||
# Make the first request and assert there's no "Connection: close".
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/pov', http_conn=http_connection
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
|
||||
# Make another request on the same connection.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/page1', http_conn=http_connection
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
|
||||
# Test client-side close.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/page2', http_conn=http_connection,
|
||||
headers=[('Connection', 'close')]
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert header_has_value('Connection', 'close', actual_headers)
|
||||
|
||||
# Make another request on the same connection, which should error.
|
||||
with pytest.raises(http_client.NotConnected):
|
||||
test_client.get('/pov', http_conn=http_connection)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'set_cl',
|
||||
(
|
||||
False, # Without Content-Length
|
||||
True, # With Content-Length
|
||||
)
|
||||
)
|
||||
def test_streaming_11(test_client, set_cl):
|
||||
"""Test serving of streaming responses with HTTP/1.1 protocol."""
|
||||
# Initialize a persistent HTTP connection
|
||||
http_connection = test_client.get_connection()
|
||||
http_connection.auto_open = False
|
||||
http_connection.connect()
|
||||
|
||||
# Make the first request and assert there's no "Connection: close".
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/pov', http_conn=http_connection
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
|
||||
# Make another, streamed request on the same connection.
|
||||
if set_cl:
|
||||
# When a Content-Length is provided, the content should stream
|
||||
# without closing the connection.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/stream?set_cl=Yes', http_conn=http_connection
|
||||
)
|
||||
assert header_exists('Content-Length', actual_headers)
|
||||
assert not header_has_value('Connection', 'close', actual_headers)
|
||||
assert not header_exists('Transfer-Encoding', actual_headers)
|
||||
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == b'0123456789'
|
||||
else:
|
||||
# When no Content-Length response header is provided,
|
||||
# streamed output will either close the connection, or use
|
||||
# chunked encoding, to determine transfer-length.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/stream', http_conn=http_connection
|
||||
)
|
||||
assert not header_exists('Content-Length', actual_headers)
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == b'0123456789'
|
||||
|
||||
chunked_response = False
|
||||
for k, v in actual_headers:
|
||||
if k.lower() == 'transfer-encoding':
|
||||
if str(v) == 'chunked':
|
||||
chunked_response = True
|
||||
|
||||
if chunked_response:
|
||||
assert not header_has_value('Connection', 'close', actual_headers)
|
||||
else:
|
||||
assert header_has_value('Connection', 'close', actual_headers)
|
||||
|
||||
# Make another request on the same connection, which should
|
||||
# error.
|
||||
with pytest.raises(http_client.NotConnected):
|
||||
test_client.get('/pov', http_conn=http_connection)
|
||||
|
||||
# Try HEAD.
|
||||
# See https://www.bitbucket.org/cherrypy/cherrypy/issue/864.
|
||||
# TODO: figure out how can this be possible on an closed connection
|
||||
# (chunked_response case)
|
||||
status_line, actual_headers, actual_resp_body = test_client.head(
|
||||
'/stream', http_conn=http_connection
|
||||
)
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == b''
|
||||
assert not header_exists('Transfer-Encoding', actual_headers)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'set_cl',
|
||||
(
|
||||
False, # Without Content-Length
|
||||
True, # With Content-Length
|
||||
)
|
||||
)
|
||||
def test_streaming_10(test_client, set_cl):
|
||||
"""Test serving of streaming responses with HTTP/1.0 protocol."""
|
||||
original_server_protocol = test_client.server_instance.protocol
|
||||
test_client.server_instance.protocol = 'HTTP/1.0'
|
||||
|
||||
# Initialize a persistent HTTP connection
|
||||
http_connection = test_client.get_connection()
|
||||
http_connection.auto_open = False
|
||||
http_connection.connect()
|
||||
|
||||
# Make the first request and assert Keep-Alive.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/pov', http_conn=http_connection,
|
||||
headers=[('Connection', 'Keep-Alive')],
|
||||
protocol='HTTP/1.0',
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
|
||||
|
||||
# Make another, streamed request on the same connection.
|
||||
if set_cl:
|
||||
# When a Content-Length is provided, the content should
|
||||
# stream without closing the connection.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/stream?set_cl=Yes', http_conn=http_connection,
|
||||
headers=[('Connection', 'Keep-Alive')],
|
||||
protocol='HTTP/1.0',
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == b'0123456789'
|
||||
|
||||
assert header_exists('Content-Length', actual_headers)
|
||||
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
|
||||
assert not header_exists('Transfer-Encoding', actual_headers)
|
||||
else:
|
||||
# When a Content-Length is not provided,
|
||||
# the server should close the connection.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/stream', http_conn=http_connection,
|
||||
headers=[('Connection', 'Keep-Alive')],
|
||||
protocol='HTTP/1.0',
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == b'0123456789'
|
||||
|
||||
assert not header_exists('Content-Length', actual_headers)
|
||||
assert not header_has_value('Connection', 'Keep-Alive', actual_headers)
|
||||
assert not header_exists('Transfer-Encoding', actual_headers)
|
||||
|
||||
# Make another request on the same connection, which should error.
|
||||
with pytest.raises(http_client.NotConnected):
|
||||
test_client.get(
|
||||
'/pov', http_conn=http_connection,
|
||||
protocol='HTTP/1.0',
|
||||
)
|
||||
|
||||
test_client.server_instance.protocol = original_server_protocol
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'http_server_protocol',
|
||||
(
|
||||
'HTTP/1.0',
|
||||
'HTTP/1.1',
|
||||
)
|
||||
)
|
||||
def test_keepalive(test_client, http_server_protocol):
|
||||
"""Test Keep-Alive enabled connections."""
|
||||
original_server_protocol = test_client.server_instance.protocol
|
||||
test_client.server_instance.protocol = http_server_protocol
|
||||
|
||||
http_client_protocol = 'HTTP/1.0'
|
||||
|
||||
# Initialize a persistent HTTP connection
|
||||
http_connection = test_client.get_connection()
|
||||
http_connection.auto_open = False
|
||||
http_connection.connect()
|
||||
|
||||
# Test a normal HTTP/1.0 request.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/page2',
|
||||
protocol=http_client_protocol,
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
|
||||
# Test a keep-alive HTTP/1.0 request.
|
||||
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/page3', headers=[('Connection', 'Keep-Alive')],
|
||||
http_conn=http_connection, protocol=http_client_protocol,
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert header_has_value('Connection', 'Keep-Alive', actual_headers)
|
||||
|
||||
# Remove the keep-alive header again.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/page3', http_conn=http_connection,
|
||||
protocol=http_client_protocol,
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
|
||||
test_client.server_instance.protocol = original_server_protocol
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'timeout_before_headers',
|
||||
(
|
||||
True,
|
||||
False,
|
||||
)
|
||||
)
|
||||
def test_HTTP11_Timeout(test_client, timeout_before_headers):
|
||||
"""Check timeout without sending any data.
|
||||
|
||||
The server will close the conn with a 408.
|
||||
"""
|
||||
conn = test_client.get_connection()
|
||||
conn.auto_open = False
|
||||
conn.connect()
|
||||
|
||||
if not timeout_before_headers:
|
||||
# Connect but send half the headers only.
|
||||
conn.send(b'GET /hello HTTP/1.1')
|
||||
conn.send(('Host: %s' % conn.host).encode('ascii'))
|
||||
# else: Connect but send nothing.
|
||||
|
||||
# Wait for our socket timeout
|
||||
time.sleep(timeout * 2)
|
||||
|
||||
# The request should have returned 408 already.
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
response.begin()
|
||||
assert response.status == 408
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_HTTP11_Timeout_after_request(test_client):
|
||||
"""Check timeout after at least one request has succeeded.
|
||||
|
||||
The server should close the connection without 408.
|
||||
"""
|
||||
fail_msg = "Writing to timed out socket didn't fail as it should have: %s"
|
||||
|
||||
# Make an initial request
|
||||
conn = test_client.get_connection()
|
||||
conn.putrequest('GET', '/timeout?t=%s' % timeout, skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.endheaders()
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
response.begin()
|
||||
assert response.status == 200
|
||||
actual_body = response.read()
|
||||
expected_body = str(timeout).encode()
|
||||
assert actual_body == expected_body
|
||||
|
||||
# Make a second request on the same socket
|
||||
conn._output(b'GET /hello HTTP/1.1')
|
||||
conn._output(('Host: %s' % conn.host).encode('ascii'))
|
||||
conn._send_output()
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
response.begin()
|
||||
assert response.status == 200
|
||||
actual_body = response.read()
|
||||
expected_body = b'Hello, world!'
|
||||
assert actual_body == expected_body
|
||||
|
||||
# Wait for our socket timeout
|
||||
time.sleep(timeout * 2)
|
||||
|
||||
# Make another request on the same socket, which should error
|
||||
conn._output(b'GET /hello HTTP/1.1')
|
||||
conn._output(('Host: %s' % conn.host).encode('ascii'))
|
||||
conn._send_output()
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
try:
|
||||
response.begin()
|
||||
except (socket.error, http_client.BadStatusLine):
|
||||
pass
|
||||
except Exception as ex:
|
||||
pytest.fail(fail_msg % ex)
|
||||
else:
|
||||
if response.status != 408:
|
||||
pytest.fail(fail_msg % response.read())
|
||||
|
||||
conn.close()
|
||||
|
||||
# Make another request on a new socket, which should work
|
||||
conn = test_client.get_connection()
|
||||
conn.putrequest('GET', '/pov', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.endheaders()
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
response.begin()
|
||||
assert response.status == 200
|
||||
actual_body = response.read()
|
||||
expected_body = pov.encode()
|
||||
assert actual_body == expected_body
|
||||
|
||||
# Make another request on the same socket,
|
||||
# but timeout on the headers
|
||||
conn.send(b'GET /hello HTTP/1.1')
|
||||
# Wait for our socket timeout
|
||||
time.sleep(timeout * 2)
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
try:
|
||||
response.begin()
|
||||
except (socket.error, http_client.BadStatusLine):
|
||||
pass
|
||||
except Exception as ex:
|
||||
pytest.fail(fail_msg % ex)
|
||||
else:
|
||||
if response.status != 408:
|
||||
pytest.fail(fail_msg % response.read())
|
||||
|
||||
conn.close()
|
||||
|
||||
# Retry the request on a new connection, which should work
|
||||
conn = test_client.get_connection()
|
||||
conn.putrequest('GET', '/pov', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.endheaders()
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
response.begin()
|
||||
assert response.status == 200
|
||||
actual_body = response.read()
|
||||
expected_body = pov.encode()
|
||||
assert actual_body == expected_body
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_HTTP11_pipelining(test_client):
|
||||
"""Test HTTP/1.1 pipelining.
|
||||
|
||||
httplib doesn't support this directly.
|
||||
"""
|
||||
conn = test_client.get_connection()
|
||||
|
||||
# Put request 1
|
||||
conn.putrequest('GET', '/hello', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.endheaders()
|
||||
|
||||
for trial in range(5):
|
||||
# Put next request
|
||||
conn._output(
|
||||
('GET /hello?%s HTTP/1.1' % trial).encode('iso-8859-1')
|
||||
)
|
||||
conn._output(('Host: %s' % conn.host).encode('ascii'))
|
||||
conn._send_output()
|
||||
|
||||
# Retrieve previous response
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
# there is a bug in python3 regarding the buffering of
|
||||
# ``conn.sock``. Until that bug get's fixed we will
|
||||
# monkey patch the ``reponse`` instance.
|
||||
# https://bugs.python.org/issue23377
|
||||
if six.PY3:
|
||||
response.fp = conn.sock.makefile('rb', 0)
|
||||
response.begin()
|
||||
body = response.read(13)
|
||||
assert response.status == 200
|
||||
assert body == b'Hello, world!'
|
||||
|
||||
# Retrieve final response
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
response.begin()
|
||||
body = response.read()
|
||||
assert response.status == 200
|
||||
assert body == b'Hello, world!'
|
||||
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_100_Continue(test_client):
|
||||
"""Test 100-continue header processing."""
|
||||
conn = test_client.get_connection()
|
||||
|
||||
# Try a page without an Expect request header first.
|
||||
# Note that httplib's response.begin automatically ignores
|
||||
# 100 Continue responses, so we must manually check for it.
|
||||
conn.putrequest('POST', '/upload', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.putheader('Content-Type', 'text/plain')
|
||||
conn.putheader('Content-Length', '4')
|
||||
conn.endheaders()
|
||||
conn.send(b"d'oh")
|
||||
response = conn.response_class(conn.sock, method='POST')
|
||||
version, status, reason = response._read_status()
|
||||
assert status != 100
|
||||
conn.close()
|
||||
|
||||
# Now try a page with an Expect header...
|
||||
conn.connect()
|
||||
conn.putrequest('POST', '/upload', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.putheader('Content-Type', 'text/plain')
|
||||
conn.putheader('Content-Length', '17')
|
||||
conn.putheader('Expect', '100-continue')
|
||||
conn.endheaders()
|
||||
response = conn.response_class(conn.sock, method='POST')
|
||||
|
||||
# ...assert and then skip the 100 response
|
||||
version, status, reason = response._read_status()
|
||||
assert status == 100
|
||||
while True:
|
||||
line = response.fp.readline().strip()
|
||||
if line:
|
||||
pytest.fail(
|
||||
'100 Continue should not output any headers. Got %r' %
|
||||
line)
|
||||
else:
|
||||
break
|
||||
|
||||
# ...send the body
|
||||
body = b'I am a small file'
|
||||
conn.send(body)
|
||||
|
||||
# ...get the final response
|
||||
response.begin()
|
||||
status_line, actual_headers, actual_resp_body = webtest.shb(response)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
expected_resp_body = ("thanks for '%s'" % body).encode()
|
||||
assert actual_resp_body == expected_resp_body
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'max_request_body_size',
|
||||
(
|
||||
0,
|
||||
1001,
|
||||
)
|
||||
)
|
||||
def test_readall_or_close(test_client, max_request_body_size):
|
||||
"""Test a max_request_body_size of 0 (the default) and 1001."""
|
||||
old_max = test_client.server_instance.max_request_body_size
|
||||
|
||||
test_client.server_instance.max_request_body_size = max_request_body_size
|
||||
|
||||
conn = test_client.get_connection()
|
||||
|
||||
# Get a POST page with an error
|
||||
conn.putrequest('POST', '/err_before_read', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.putheader('Content-Type', 'text/plain')
|
||||
conn.putheader('Content-Length', '1000')
|
||||
conn.putheader('Expect', '100-continue')
|
||||
conn.endheaders()
|
||||
response = conn.response_class(conn.sock, method='POST')
|
||||
|
||||
# ...assert and then skip the 100 response
|
||||
version, status, reason = response._read_status()
|
||||
assert status == 100
|
||||
skip = True
|
||||
while skip:
|
||||
skip = response.fp.readline().strip()
|
||||
|
||||
# ...send the body
|
||||
conn.send(b'x' * 1000)
|
||||
|
||||
# ...get the final response
|
||||
response.begin()
|
||||
status_line, actual_headers, actual_resp_body = webtest.shb(response)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 500
|
||||
|
||||
# Now try a working page with an Expect header...
|
||||
conn._output(b'POST /upload HTTP/1.1')
|
||||
conn._output(('Host: %s' % conn.host).encode('ascii'))
|
||||
conn._output(b'Content-Type: text/plain')
|
||||
conn._output(b'Content-Length: 17')
|
||||
conn._output(b'Expect: 100-continue')
|
||||
conn._send_output()
|
||||
response = conn.response_class(conn.sock, method='POST')
|
||||
|
||||
# ...assert and then skip the 100 response
|
||||
version, status, reason = response._read_status()
|
||||
assert status == 100
|
||||
skip = True
|
||||
while skip:
|
||||
skip = response.fp.readline().strip()
|
||||
|
||||
# ...send the body
|
||||
body = b'I am a small file'
|
||||
conn.send(body)
|
||||
|
||||
# ...get the final response
|
||||
response.begin()
|
||||
status_line, actual_headers, actual_resp_body = webtest.shb(response)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
expected_resp_body = ("thanks for '%s'" % body).encode()
|
||||
assert actual_resp_body == expected_resp_body
|
||||
conn.close()
|
||||
|
||||
test_client.server_instance.max_request_body_size = old_max
|
||||
|
||||
|
||||
def test_No_Message_Body(test_client):
|
||||
"""Test HTTP queries with an empty response body."""
|
||||
# Initialize a persistent HTTP connection
|
||||
http_connection = test_client.get_connection()
|
||||
http_connection.auto_open = False
|
||||
http_connection.connect()
|
||||
|
||||
# Make the first request and assert there's no "Connection: close".
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/pov', http_conn=http_connection
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
assert actual_resp_body == pov.encode()
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
|
||||
# Make a 204 request on the same connection.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/custom/204', http_conn=http_connection
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 204
|
||||
assert not header_exists('Content-Length', actual_headers)
|
||||
assert actual_resp_body == b''
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
|
||||
# Make a 304 request on the same connection.
|
||||
status_line, actual_headers, actual_resp_body = test_client.get(
|
||||
'/custom/304', http_conn=http_connection
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 304
|
||||
assert not header_exists('Content-Length', actual_headers)
|
||||
assert actual_resp_body == b''
|
||||
assert not header_exists('Connection', actual_headers)
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason='Server does not correctly read trailers/ending of the previous '
|
||||
'HTTP request, thus the second request fails as the server tries '
|
||||
r"to parse b'Content-Type: application/json\r\n' as a "
|
||||
'Request-Line. This results in HTTP status code 400, instead of 413'
|
||||
'Ref: https://github.com/cherrypy/cheroot/issues/69'
|
||||
)
|
||||
def test_Chunked_Encoding(test_client):
|
||||
"""Test HTTP uploads with chunked transfer-encoding."""
|
||||
# Initialize a persistent HTTP connection
|
||||
conn = test_client.get_connection()
|
||||
|
||||
# Try a normal chunked request (with extensions)
|
||||
body = (
|
||||
b'8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n'
|
||||
b'Content-Type: application/json\r\n'
|
||||
b'\r\n'
|
||||
)
|
||||
conn.putrequest('POST', '/upload', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.putheader('Transfer-Encoding', 'chunked')
|
||||
conn.putheader('Trailer', 'Content-Type')
|
||||
# Note that this is somewhat malformed:
|
||||
# we shouldn't be sending Content-Length.
|
||||
# RFC 2616 says the server should ignore it.
|
||||
conn.putheader('Content-Length', '3')
|
||||
conn.endheaders()
|
||||
conn.send(body)
|
||||
response = conn.getresponse()
|
||||
status_line, actual_headers, actual_resp_body = webtest.shb(response)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 200
|
||||
assert status_line[4:] == 'OK'
|
||||
expected_resp_body = ("thanks for '%s'" % b'xx\r\nxxxxyyyyy').encode()
|
||||
assert actual_resp_body == expected_resp_body
|
||||
|
||||
# Try a chunked request that exceeds server.max_request_body_size.
|
||||
# Note that the delimiters and trailer are included.
|
||||
body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n'
|
||||
conn.putrequest('POST', '/upload', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.putheader('Transfer-Encoding', 'chunked')
|
||||
conn.putheader('Content-Type', 'text/plain')
|
||||
# Chunked requests don't need a content-length
|
||||
# conn.putheader("Content-Length", len(body))
|
||||
conn.endheaders()
|
||||
conn.send(body)
|
||||
response = conn.getresponse()
|
||||
status_line, actual_headers, actual_resp_body = webtest.shb(response)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 413
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_Content_Length_in(test_client):
|
||||
"""Try a non-chunked request where Content-Length exceeds limit.
|
||||
|
||||
(server.max_request_body_size).
|
||||
Assert error before body send.
|
||||
"""
|
||||
# Initialize a persistent HTTP connection
|
||||
conn = test_client.get_connection()
|
||||
|
||||
conn.putrequest('POST', '/upload', skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.putheader('Content-Type', 'text/plain')
|
||||
conn.putheader('Content-Length', '9999')
|
||||
conn.endheaders()
|
||||
response = conn.getresponse()
|
||||
status_line, actual_headers, actual_resp_body = webtest.shb(response)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 413
|
||||
expected_resp_body = (
|
||||
b'The entity sent with the request exceeds '
|
||||
b'the maximum allowed bytes.'
|
||||
)
|
||||
assert actual_resp_body == expected_resp_body
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_Content_Length_not_int(test_client):
|
||||
"""Test that malicious Content-Length header returns 400."""
|
||||
status_line, actual_headers, actual_resp_body = test_client.post(
|
||||
'/upload',
|
||||
headers=[
|
||||
('Content-Type', 'text/plain'),
|
||||
('Content-Length', 'not-an-integer'),
|
||||
],
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
|
||||
assert actual_status == 400
|
||||
assert actual_resp_body == b'Malformed Content-Length Header.'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'uri,expected_resp_status,expected_resp_body',
|
||||
(
|
||||
('/wrong_cl_buffered', 500,
|
||||
(b'The requested resource returned more bytes than the '
|
||||
b'declared Content-Length.')),
|
||||
('/wrong_cl_unbuffered', 200, b'I too'),
|
||||
)
|
||||
)
|
||||
def test_Content_Length_out(
|
||||
test_client,
|
||||
uri, expected_resp_status, expected_resp_body
|
||||
):
|
||||
"""Test response with Content-Length less than the response body.
|
||||
|
||||
(non-chunked response)
|
||||
"""
|
||||
conn = test_client.get_connection()
|
||||
conn.putrequest('GET', uri, skip_host=True)
|
||||
conn.putheader('Host', conn.host)
|
||||
conn.endheaders()
|
||||
|
||||
response = conn.getresponse()
|
||||
status_line, actual_headers, actual_resp_body = webtest.shb(response)
|
||||
actual_status = int(status_line[:3])
|
||||
|
||||
assert actual_status == expected_resp_status
|
||||
assert actual_resp_body == expected_resp_body
|
||||
|
||||
conn.close()
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason='Sometimes this test fails due to low timeout. '
|
||||
'Ref: https://github.com/cherrypy/cherrypy/issues/598'
|
||||
)
|
||||
def test_598(test_client):
|
||||
"""Test serving large file with a read timeout in place."""
|
||||
# Initialize a persistent HTTP connection
|
||||
conn = test_client.get_connection()
|
||||
remote_data_conn = urllib.request.urlopen(
|
||||
'%s://%s:%s/one_megabyte_of_a'
|
||||
% ('http', conn.host, conn.port)
|
||||
)
|
||||
buf = remote_data_conn.read(512)
|
||||
time.sleep(timeout * 0.6)
|
||||
remaining = (1024 * 1024) - 512
|
||||
while remaining:
|
||||
data = remote_data_conn.read(remaining)
|
||||
if not data:
|
||||
break
|
||||
buf += data
|
||||
remaining -= len(data)
|
||||
|
||||
assert len(buf) == 1024 * 1024
|
||||
assert buf == b'a' * 1024 * 1024
|
||||
assert remaining == 0
|
||||
remote_data_conn.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'invalid_terminator',
|
||||
(
|
||||
b'\n\n',
|
||||
b'\r\n\n',
|
||||
)
|
||||
)
|
||||
def test_No_CRLF(test_client, invalid_terminator):
|
||||
"""Test HTTP queries with no valid CRLF terminators."""
|
||||
# Initialize a persistent HTTP connection
|
||||
conn = test_client.get_connection()
|
||||
|
||||
# (b'%s' % b'') is not supported in Python 3.4, so just use +
|
||||
conn.send(b'GET /hello HTTP/1.1' + invalid_terminator)
|
||||
response = conn.response_class(conn.sock, method='GET')
|
||||
response.begin()
|
||||
actual_resp_body = response.read()
|
||||
expected_resp_body = b'HTTP requires CRLF terminators'
|
||||
assert actual_resp_body == expected_resp_body
|
||||
conn.close()
|
405
libraries/cheroot/test/test_core.py
Normal file
405
libraries/cheroot/test/test_core.py
Normal file
|
@ -0,0 +1,405 @@
|
|||
"""Tests for managing HTTP issues (malformed requests, etc)."""
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim: set fileencoding=utf-8 :
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import errno
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
import six
|
||||
from six.moves import urllib
|
||||
|
||||
from cheroot.test import helper
|
||||
|
||||
|
||||
HTTP_BAD_REQUEST = 400
|
||||
HTTP_LENGTH_REQUIRED = 411
|
||||
HTTP_NOT_FOUND = 404
|
||||
HTTP_OK = 200
|
||||
HTTP_VERSION_NOT_SUPPORTED = 505
|
||||
|
||||
|
||||
class HelloController(helper.Controller):
|
||||
"""Controller for serving WSGI apps."""
|
||||
|
||||
def hello(req, resp):
|
||||
"""Render Hello world."""
|
||||
return 'Hello world!'
|
||||
|
||||
def body_required(req, resp):
|
||||
"""Render Hello world or set 411."""
|
||||
if req.environ.get('Content-Length', None) is None:
|
||||
resp.status = '411 Length Required'
|
||||
return
|
||||
return 'Hello world!'
|
||||
|
||||
def query_string(req, resp):
|
||||
"""Render QUERY_STRING value."""
|
||||
return req.environ.get('QUERY_STRING', '')
|
||||
|
||||
def asterisk(req, resp):
|
||||
"""Render request method value."""
|
||||
method = req.environ.get('REQUEST_METHOD', 'NO METHOD FOUND')
|
||||
tmpl = 'Got asterisk URI path with {method} method'
|
||||
return tmpl.format(**locals())
|
||||
|
||||
def _munge(string):
|
||||
"""Encode PATH_INFO correctly depending on Python version.
|
||||
|
||||
WSGI 1.0 is a mess around unicode. Create endpoints
|
||||
that match the PATH_INFO that it produces.
|
||||
"""
|
||||
if six.PY3:
|
||||
return string.encode('utf-8').decode('latin-1')
|
||||
return string
|
||||
|
||||
handlers = {
|
||||
'/hello': hello,
|
||||
'/no_body': hello,
|
||||
'/body_required': body_required,
|
||||
'/query_string': query_string,
|
||||
_munge('/привіт'): hello,
|
||||
_munge('/Юххууу'): hello,
|
||||
'/\xa0Ðblah key 0 900 4 data': hello,
|
||||
'/*': asterisk,
|
||||
}
|
||||
|
||||
|
||||
def _get_http_response(connection, method='GET'):
|
||||
c = connection
|
||||
kwargs = {'strict': c.strict} if hasattr(c, 'strict') else {}
|
||||
# Python 3.2 removed the 'strict' feature, saying:
|
||||
# "http.client now always assumes HTTP/1.x compliant servers."
|
||||
return c.response_class(c.sock, method=method, **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def testing_server(wsgi_server_client):
|
||||
"""Attach a WSGI app to the given server and pre-configure it."""
|
||||
wsgi_server = wsgi_server_client.server_instance
|
||||
wsgi_server.wsgi_app = HelloController()
|
||||
wsgi_server.max_request_body_size = 30000000
|
||||
wsgi_server.server_client = wsgi_server_client
|
||||
return wsgi_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client(testing_server):
|
||||
"""Get and return a test client out of the given server."""
|
||||
return testing_server.server_client
|
||||
|
||||
|
||||
def test_http_connect_request(test_client):
|
||||
"""Check that CONNECT query results in Method Not Allowed status."""
|
||||
status_line = test_client.connect('/anything')[0]
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == 405
|
||||
|
||||
|
||||
def test_normal_request(test_client):
|
||||
"""Check that normal GET query succeeds."""
|
||||
status_line, _, actual_resp_body = test_client.get('/hello')
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == HTTP_OK
|
||||
assert actual_resp_body == b'Hello world!'
|
||||
|
||||
|
||||
def test_query_string_request(test_client):
|
||||
"""Check that GET param is parsed well."""
|
||||
status_line, _, actual_resp_body = test_client.get(
|
||||
'/query_string?test=True'
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == HTTP_OK
|
||||
assert actual_resp_body == b'test=True'
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'uri',
|
||||
(
|
||||
'/hello', # plain
|
||||
'/query_string?test=True', # query
|
||||
'/{0}?{1}={2}'.format( # quoted unicode
|
||||
*map(urllib.parse.quote, ('Юххууу', 'ї', 'йо'))
|
||||
),
|
||||
)
|
||||
)
|
||||
def test_parse_acceptable_uri(test_client, uri):
|
||||
"""Check that server responds with OK to valid GET queries."""
|
||||
status_line = test_client.get(uri)[0]
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == HTTP_OK
|
||||
|
||||
|
||||
@pytest.mark.xfail(six.PY2, reason='Fails on Python 2')
|
||||
def test_parse_uri_unsafe_uri(test_client):
|
||||
"""Test that malicious URI does not allow HTTP injection.
|
||||
|
||||
This effectively checks that sending GET request with URL
|
||||
|
||||
/%A0%D0blah%20key%200%20900%204%20data
|
||||
|
||||
is not converted into
|
||||
|
||||
GET /
|
||||
blah key 0 900 4 data
|
||||
HTTP/1.1
|
||||
|
||||
which would be a security issue otherwise.
|
||||
"""
|
||||
c = test_client.get_connection()
|
||||
resource = '/\xa0Ðblah key 0 900 4 data'.encode('latin-1')
|
||||
quoted = urllib.parse.quote(resource)
|
||||
assert quoted == '/%A0%D0blah%20key%200%20900%204%20data'
|
||||
request = 'GET {quoted} HTTP/1.1'.format(**locals())
|
||||
c._output(request.encode('utf-8'))
|
||||
c._send_output()
|
||||
response = _get_http_response(c, method='GET')
|
||||
response.begin()
|
||||
assert response.status == HTTP_OK
|
||||
assert response.fp.read(12) == b'Hello world!'
|
||||
c.close()
|
||||
|
||||
|
||||
def test_parse_uri_invalid_uri(test_client):
|
||||
"""Check that server responds with Bad Request to invalid GET queries.
|
||||
|
||||
Invalid request line test case: it should only contain US-ASCII.
|
||||
"""
|
||||
c = test_client.get_connection()
|
||||
c._output(u'GET /йопта! HTTP/1.1'.encode('utf-8'))
|
||||
c._send_output()
|
||||
response = _get_http_response(c, method='GET')
|
||||
response.begin()
|
||||
assert response.status == HTTP_BAD_REQUEST
|
||||
assert response.fp.read(21) == b'Malformed Request-URI'
|
||||
c.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'uri',
|
||||
(
|
||||
'hello', # ascii
|
||||
'привіт', # non-ascii
|
||||
)
|
||||
)
|
||||
def test_parse_no_leading_slash_invalid(test_client, uri):
|
||||
"""Check that server responds with Bad Request to invalid GET queries.
|
||||
|
||||
Invalid request line test case: it should have leading slash (be absolute).
|
||||
"""
|
||||
status_line, _, actual_resp_body = test_client.get(
|
||||
urllib.parse.quote(uri)
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == HTTP_BAD_REQUEST
|
||||
assert b'starting with a slash' in actual_resp_body
|
||||
|
||||
|
||||
def test_parse_uri_absolute_uri(test_client):
|
||||
"""Check that server responds with Bad Request to Absolute URI.
|
||||
|
||||
Only proxy servers should allow this.
|
||||
"""
|
||||
status_line, _, actual_resp_body = test_client.get('http://google.com/')
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == HTTP_BAD_REQUEST
|
||||
expected_body = b'Absolute URI not allowed if server is not a proxy.'
|
||||
assert actual_resp_body == expected_body
|
||||
|
||||
|
||||
def test_parse_uri_asterisk_uri(test_client):
|
||||
"""Check that server responds with OK to OPTIONS with "*" Absolute URI."""
|
||||
status_line, _, actual_resp_body = test_client.options('*')
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == HTTP_OK
|
||||
expected_body = b'Got asterisk URI path with OPTIONS method'
|
||||
assert actual_resp_body == expected_body
|
||||
|
||||
|
||||
def test_parse_uri_fragment_uri(test_client):
|
||||
"""Check that server responds with Bad Request to URI with fragment."""
|
||||
status_line, _, actual_resp_body = test_client.get(
|
||||
'/hello?test=something#fake',
|
||||
)
|
||||
actual_status = int(status_line[:3])
|
||||
assert actual_status == HTTP_BAD_REQUEST
|
||||
expected_body = b'Illegal #fragment in Request-URI.'
|
||||
assert actual_resp_body == expected_body
|
||||
|
||||
|
||||
def test_no_content_length(test_client):
|
||||
"""Test POST query with an empty body being successful."""
|
||||
# "The presence of a message-body in a request is signaled by the
|
||||
# inclusion of a Content-Length or Transfer-Encoding header field in
|
||||
# the request's message-headers."
|
||||
#
|
||||
# Send a message with neither header and no body.
|
||||
c = test_client.get_connection()
|
||||
c.request('POST', '/no_body')
|
||||
response = c.getresponse()
|
||||
actual_resp_body = response.fp.read()
|
||||
actual_status = response.status
|
||||
assert actual_status == HTTP_OK
|
||||
assert actual_resp_body == b'Hello world!'
|
||||
|
||||
|
||||
def test_content_length_required(test_client):
|
||||
"""Test POST query with body failing because of missing Content-Length."""
|
||||
# Now send a message that has no Content-Length, but does send a body.
|
||||
# Verify that CP times out the socket and responds
|
||||
# with 411 Length Required.
|
||||
|
||||
c = test_client.get_connection()
|
||||
c.request('POST', '/body_required')
|
||||
response = c.getresponse()
|
||||
response.fp.read()
|
||||
|
||||
actual_status = response.status
|
||||
assert actual_status == HTTP_LENGTH_REQUIRED
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'request_line,status_code,expected_body',
|
||||
(
|
||||
(b'GET /', # missing proto
|
||||
HTTP_BAD_REQUEST, b'Malformed Request-Line'),
|
||||
(b'GET / HTTPS/1.1', # invalid proto
|
||||
HTTP_BAD_REQUEST, b'Malformed Request-Line: bad protocol'),
|
||||
(b'GET / HTTP/2.15', # invalid ver
|
||||
HTTP_VERSION_NOT_SUPPORTED, b'Cannot fulfill request'),
|
||||
)
|
||||
)
|
||||
def test_malformed_request_line(
|
||||
test_client, request_line,
|
||||
status_code, expected_body
|
||||
):
|
||||
"""Test missing or invalid HTTP version in Request-Line."""
|
||||
c = test_client.get_connection()
|
||||
c._output(request_line)
|
||||
c._send_output()
|
||||
response = _get_http_response(c, method='GET')
|
||||
response.begin()
|
||||
assert response.status == status_code
|
||||
assert response.fp.read(len(expected_body)) == expected_body
|
||||
c.close()
|
||||
|
||||
|
||||
def test_malformed_http_method(test_client):
|
||||
"""Test non-uppercase HTTP method."""
|
||||
c = test_client.get_connection()
|
||||
c.putrequest('GeT', '/malformed_method_case')
|
||||
c.putheader('Content-Type', 'text/plain')
|
||||
c.endheaders()
|
||||
|
||||
response = c.getresponse()
|
||||
actual_status = response.status
|
||||
assert actual_status == HTTP_BAD_REQUEST
|
||||
actual_resp_body = response.fp.read(21)
|
||||
assert actual_resp_body == b'Malformed method name'
|
||||
|
||||
|
||||
def test_malformed_header(test_client):
|
||||
"""Check that broken HTTP header results in Bad Request."""
|
||||
c = test_client.get_connection()
|
||||
c.putrequest('GET', '/')
|
||||
c.putheader('Content-Type', 'text/plain')
|
||||
# See https://www.bitbucket.org/cherrypy/cherrypy/issue/941
|
||||
c._output(b'Re, 1.2.3.4#015#012')
|
||||
c.endheaders()
|
||||
|
||||
response = c.getresponse()
|
||||
actual_status = response.status
|
||||
assert actual_status == HTTP_BAD_REQUEST
|
||||
actual_resp_body = response.fp.read(20)
|
||||
assert actual_resp_body == b'Illegal header line.'
|
||||
|
||||
|
||||
def test_request_line_split_issue_1220(test_client):
|
||||
"""Check that HTTP request line of exactly 256 chars length is OK."""
|
||||
Request_URI = (
|
||||
'/hello?'
|
||||
'intervenant-entreprise-evenement_classaction='
|
||||
'evenement-mailremerciements'
|
||||
'&_path=intervenant-entreprise-evenement'
|
||||
'&intervenant-entreprise-evenement_action-id=19404'
|
||||
'&intervenant-entreprise-evenement_id=19404'
|
||||
'&intervenant-entreprise_id=28092'
|
||||
)
|
||||
assert len('GET %s HTTP/1.1\r\n' % Request_URI) == 256
|
||||
|
||||
actual_resp_body = test_client.get(Request_URI)[2]
|
||||
assert actual_resp_body == b'Hello world!'
|
||||
|
||||
|
||||
def test_garbage_in(test_client):
|
||||
"""Test that server sends an error for garbage received over TCP."""
|
||||
# Connect without SSL regardless of server.scheme
|
||||
|
||||
c = test_client.get_connection()
|
||||
c._output(b'gjkgjklsgjklsgjkljklsg')
|
||||
c._send_output()
|
||||
response = c.response_class(c.sock, method='GET')
|
||||
try:
|
||||
response.begin()
|
||||
actual_status = response.status
|
||||
assert actual_status == HTTP_BAD_REQUEST
|
||||
actual_resp_body = response.fp.read(22)
|
||||
assert actual_resp_body == b'Malformed Request-Line'
|
||||
c.close()
|
||||
except socket.error as ex:
|
||||
# "Connection reset by peer" is also acceptable.
|
||||
if ex.errno != errno.ECONNRESET:
|
||||
raise
|
||||
|
||||
|
||||
class CloseController:
|
||||
"""Controller for testing the close callback."""
|
||||
|
||||
def __call__(self, environ, start_response):
|
||||
"""Get the req to know header sent status."""
|
||||
self.req = start_response.__self__.req
|
||||
resp = CloseResponse(self.close)
|
||||
start_response(resp.status, resp.headers.items())
|
||||
return resp
|
||||
|
||||
def close(self):
|
||||
"""Close, writing hello."""
|
||||
self.req.write(b'hello')
|
||||
|
||||
|
||||
class CloseResponse:
|
||||
"""Dummy empty response to trigger the no body status."""
|
||||
|
||||
def __init__(self, close):
|
||||
"""Use some defaults to ensure we have a header."""
|
||||
self.status = '200 OK'
|
||||
self.headers = {'Content-Type': 'text/html'}
|
||||
self.close = close
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Ensure we don't have a body."""
|
||||
raise IndexError()
|
||||
|
||||
def output(self):
|
||||
"""Return self to hook the close method."""
|
||||
return self
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def testing_server_close(wsgi_server_client):
|
||||
"""Attach a WSGI app to the given server and pre-configure it."""
|
||||
wsgi_server = wsgi_server_client.server_instance
|
||||
wsgi_server.wsgi_app = CloseController()
|
||||
wsgi_server.max_request_body_size = 30000000
|
||||
wsgi_server.server_client = wsgi_server_client
|
||||
return wsgi_server
|
||||
|
||||
|
||||
def test_send_header_before_closing(testing_server_close):
|
||||
"""Test we are actually sending the headers before calling 'close'."""
|
||||
_, _, resp_body = testing_server_close.server_client.get('/')
|
||||
assert resp_body == b'hello'
|
193
libraries/cheroot/test/test_server.py
Normal file
193
libraries/cheroot/test/test_server.py
Normal file
|
@ -0,0 +1,193 @@
|
|||
"""Tests for the HTTP server."""
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim: set fileencoding=utf-8 :
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import os
|
||||
import socket
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from .._compat import bton
|
||||
from ..server import Gateway, HTTPServer
|
||||
from ..testing import (
|
||||
ANY_INTERFACE_IPV4,
|
||||
ANY_INTERFACE_IPV6,
|
||||
EPHEMERAL_PORT,
|
||||
get_server_client,
|
||||
)
|
||||
|
||||
|
||||
def make_http_server(bind_addr):
|
||||
"""Create and start an HTTP server bound to bind_addr."""
|
||||
httpserver = HTTPServer(
|
||||
bind_addr=bind_addr,
|
||||
gateway=Gateway,
|
||||
)
|
||||
|
||||
threading.Thread(target=httpserver.safe_start).start()
|
||||
|
||||
while not httpserver.ready:
|
||||
time.sleep(0.1)
|
||||
|
||||
return httpserver
|
||||
|
||||
|
||||
non_windows_sock_test = pytest.mark.skipif(
|
||||
not hasattr(socket, 'AF_UNIX'),
|
||||
reason='UNIX domain sockets are only available under UNIX-based OS',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def http_server():
|
||||
"""Provision a server creator as a fixture."""
|
||||
def start_srv():
|
||||
bind_addr = yield
|
||||
httpserver = make_http_server(bind_addr)
|
||||
yield httpserver
|
||||
yield httpserver
|
||||
|
||||
srv_creator = iter(start_srv())
|
||||
next(srv_creator)
|
||||
yield srv_creator
|
||||
try:
|
||||
while True:
|
||||
httpserver = next(srv_creator)
|
||||
if httpserver is not None:
|
||||
httpserver.stop()
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unix_sock_file():
|
||||
"""Check that bound UNIX socket address is stored in server."""
|
||||
tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp()
|
||||
|
||||
yield tmp_sock_fname
|
||||
|
||||
os.close(tmp_sock_fh)
|
||||
os.unlink(tmp_sock_fname)
|
||||
|
||||
|
||||
def test_prepare_makes_server_ready():
|
||||
"""Check that prepare() makes the server ready, and stop() clears it."""
|
||||
httpserver = HTTPServer(
|
||||
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
|
||||
gateway=Gateway,
|
||||
)
|
||||
|
||||
assert not httpserver.ready
|
||||
assert not httpserver.requests._threads
|
||||
|
||||
httpserver.prepare()
|
||||
|
||||
assert httpserver.ready
|
||||
assert httpserver.requests._threads
|
||||
for thr in httpserver.requests._threads:
|
||||
assert thr.ready
|
||||
|
||||
httpserver.stop()
|
||||
|
||||
assert not httpserver.requests._threads
|
||||
assert not httpserver.ready
|
||||
|
||||
|
||||
def test_stop_interrupts_serve():
|
||||
"""Check that stop() interrupts running of serve()."""
|
||||
httpserver = HTTPServer(
|
||||
bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT),
|
||||
gateway=Gateway,
|
||||
)
|
||||
|
||||
httpserver.prepare()
|
||||
serve_thread = threading.Thread(target=httpserver.serve)
|
||||
serve_thread.start()
|
||||
|
||||
serve_thread.join(0.5)
|
||||
assert serve_thread.is_alive()
|
||||
|
||||
httpserver.stop()
|
||||
|
||||
serve_thread.join(0.5)
|
||||
assert not serve_thread.is_alive()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'ip_addr',
|
||||
(
|
||||
ANY_INTERFACE_IPV4,
|
||||
ANY_INTERFACE_IPV6,
|
||||
)
|
||||
)
|
||||
def test_bind_addr_inet(http_server, ip_addr):
|
||||
"""Check that bound IP address is stored in server."""
|
||||
httpserver = http_server.send((ip_addr, EPHEMERAL_PORT))
|
||||
|
||||
assert httpserver.bind_addr[0] == ip_addr
|
||||
assert httpserver.bind_addr[1] != EPHEMERAL_PORT
|
||||
|
||||
|
||||
@non_windows_sock_test
|
||||
def test_bind_addr_unix(http_server, unix_sock_file):
|
||||
"""Check that bound UNIX socket address is stored in server."""
|
||||
httpserver = http_server.send(unix_sock_file)
|
||||
|
||||
assert httpserver.bind_addr == unix_sock_file
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Abstract sockets don't work currently")
|
||||
@non_windows_sock_test
|
||||
def test_bind_addr_unix_abstract(http_server):
|
||||
"""Check that bound UNIX socket address is stored in server."""
|
||||
unix_abstract_sock = b'\x00cheroot/test/socket/here.sock'
|
||||
httpserver = http_server.send(unix_abstract_sock)
|
||||
|
||||
assert httpserver.bind_addr == unix_abstract_sock
|
||||
|
||||
|
||||
PEERCRED_IDS_URI = '/peer_creds/ids'
|
||||
PEERCRED_TEXTS_URI = '/peer_creds/texts'
|
||||
|
||||
|
||||
class _TestGateway(Gateway):
|
||||
def respond(self):
|
||||
req = self.req
|
||||
conn = req.conn
|
||||
req_uri = bton(req.uri)
|
||||
if req_uri == PEERCRED_IDS_URI:
|
||||
peer_creds = conn.peer_pid, conn.peer_uid, conn.peer_gid
|
||||
return ['|'.join(map(str, peer_creds))]
|
||||
elif req_uri == PEERCRED_TEXTS_URI:
|
||||
return ['!'.join((conn.peer_user, conn.peer_group))]
|
||||
return super(_TestGateway, self).respond()
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason='Test HTTP client is not able to work through UNIX socket currently'
|
||||
)
|
||||
@non_windows_sock_test
|
||||
def test_peercreds_unix_sock(http_server, unix_sock_file):
|
||||
"""Check that peercred lookup and resolution work when enabled."""
|
||||
httpserver = http_server.send(unix_sock_file)
|
||||
httpserver.gateway = _TestGateway
|
||||
httpserver.peercreds_enabled = True
|
||||
|
||||
testclient = get_server_client(httpserver)
|
||||
|
||||
expected_peercreds = os.getpid(), os.getuid(), os.getgid()
|
||||
expected_peercreds = '|'.join(map(str, expected_peercreds))
|
||||
assert testclient.get(PEERCRED_IDS_URI) == expected_peercreds
|
||||
assert 'RuntimeError' in testclient.get(PEERCRED_TEXTS_URI)
|
||||
|
||||
httpserver.peercreds_resolve_enabled = True
|
||||
import grp
|
||||
expected_textcreds = os.getlogin(), grp.getgrgid(os.getgid()).gr_name
|
||||
expected_textcreds = '!'.join(map(str, expected_textcreds))
|
||||
assert testclient.get(PEERCRED_TEXTS_URI) == expected_textcreds
|
581
libraries/cheroot/test/webtest.py
Normal file
581
libraries/cheroot/test/webtest.py
Normal file
|
@ -0,0 +1,581 @@
|
|||
"""Extensions to unittest for web frameworks.
|
||||
|
||||
Use the WebCase.getPage method to request a page from your HTTP server.
|
||||
Framework Integration
|
||||
=====================
|
||||
If you have control over your server process, you can handle errors
|
||||
in the server-side of the HTTP conversation a bit better. You must run
|
||||
both the client (your WebCase tests) and the server in the same process
|
||||
(but in separate threads, obviously).
|
||||
When an error occurs in the framework, call server_error. It will print
|
||||
the traceback to stdout, and keep any assertions you have from running
|
||||
(the assumption is that, if the server errors, the page output will not
|
||||
be of further significance to your tests).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import pprint
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import os
|
||||
import json
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from six.moves import range, http_client, map, urllib_parse
|
||||
import six
|
||||
|
||||
from more_itertools.more import always_iterable
|
||||
|
||||
|
||||
def interface(host):
|
||||
"""Return an IP address for a client connection given the server host.
|
||||
|
||||
If the server is listening on '0.0.0.0' (INADDR_ANY)
|
||||
or '::' (IN6ADDR_ANY), this will return the proper localhost.
|
||||
"""
|
||||
if host == '0.0.0.0':
|
||||
# INADDR_ANY, which should respond on localhost.
|
||||
return '127.0.0.1'
|
||||
if host == '::':
|
||||
# IN6ADDR_ANY, which should respond on localhost.
|
||||
return '::1'
|
||||
return host
|
||||
|
||||
|
||||
try:
|
||||
# Jython support
|
||||
if sys.platform[:4] == 'java':
|
||||
def getchar():
|
||||
"""Get a key press."""
|
||||
# Hopefully this is enough
|
||||
return sys.stdin.read(1)
|
||||
else:
|
||||
# On Windows, msvcrt.getch reads a single char without output.
|
||||
import msvcrt
|
||||
|
||||
def getchar():
|
||||
"""Get a key press."""
|
||||
return msvcrt.getch()
|
||||
except ImportError:
|
||||
# Unix getchr
|
||||
import tty
|
||||
import termios
|
||||
|
||||
def getchar():
|
||||
"""Get a key press."""
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
ch = sys.stdin.read(1)
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
||||
return ch
|
||||
|
||||
|
||||
# from jaraco.properties
|
||||
class NonDataProperty:
|
||||
"""Non-data property decorator."""
|
||||
|
||||
def __init__(self, fget):
|
||||
"""Initialize a non-data property."""
|
||||
assert fget is not None, 'fget cannot be none'
|
||||
assert callable(fget), 'fget must be callable'
|
||||
self.fget = fget
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
"""Return a class property."""
|
||||
if obj is None:
|
||||
return self
|
||||
return self.fget(obj)
|
||||
|
||||
|
||||
class WebCase(unittest.TestCase):
|
||||
"""Helper web test suite base."""
|
||||
|
||||
HOST = '127.0.0.1'
|
||||
PORT = 8000
|
||||
HTTP_CONN = http_client.HTTPConnection
|
||||
PROTOCOL = 'HTTP/1.1'
|
||||
|
||||
scheme = 'http'
|
||||
url = None
|
||||
|
||||
status = None
|
||||
headers = None
|
||||
body = None
|
||||
|
||||
encoding = 'utf-8'
|
||||
|
||||
time = None
|
||||
|
||||
@property
|
||||
def _Conn(self):
|
||||
"""Return HTTPConnection or HTTPSConnection based on self.scheme.
|
||||
|
||||
* from http.client.
|
||||
"""
|
||||
cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper())
|
||||
return getattr(http_client, cls_name)
|
||||
|
||||
def get_conn(self, auto_open=False):
|
||||
"""Return a connection to our HTTP server."""
|
||||
conn = self._Conn(self.interface(), self.PORT)
|
||||
# Automatically re-connect?
|
||||
conn.auto_open = auto_open
|
||||
conn.connect()
|
||||
return conn
|
||||
|
||||
def set_persistent(self, on=True, auto_open=False):
|
||||
"""Make our HTTP_CONN persistent (or not).
|
||||
|
||||
If the 'on' argument is True (the default), then self.HTTP_CONN
|
||||
will be set to an instance of HTTP(S)?Connection
|
||||
to persist across requests.
|
||||
As this class only allows for a single open connection, if
|
||||
self already has an open connection, it will be closed.
|
||||
"""
|
||||
try:
|
||||
self.HTTP_CONN.close()
|
||||
except (TypeError, AttributeError):
|
||||
pass
|
||||
|
||||
self.HTTP_CONN = (
|
||||
self.get_conn(auto_open=auto_open)
|
||||
if on
|
||||
else self._Conn
|
||||
)
|
||||
|
||||
@property
|
||||
def persistent(self): # noqa: D401; irrelevant for properties
|
||||
"""Presense of the persistent HTTP connection."""
|
||||
return hasattr(self.HTTP_CONN, '__class__')
|
||||
|
||||
@persistent.setter
|
||||
def persistent(self, on):
|
||||
self.set_persistent(on)
|
||||
|
||||
def interface(self):
|
||||
"""Return an IP address for a client connection.
|
||||
|
||||
If the server is listening on '0.0.0.0' (INADDR_ANY)
|
||||
or '::' (IN6ADDR_ANY), this will return the proper localhost.
|
||||
"""
|
||||
return interface(self.HOST)
|
||||
|
||||
def getPage(self, url, headers=None, method='GET', body=None,
|
||||
protocol=None, raise_subcls=None):
|
||||
"""Open the url with debugging support. Return status, headers, body.
|
||||
|
||||
url should be the identifier passed to the server, typically a
|
||||
server-absolute path and query string (sent between method and
|
||||
protocol), and should only be an absolute URI if proxy support is
|
||||
enabled in the server.
|
||||
|
||||
If the application under test generates absolute URIs, be sure
|
||||
to wrap them first with strip_netloc::
|
||||
|
||||
class MyAppWebCase(WebCase):
|
||||
def getPage(url, *args, **kwargs):
|
||||
super(MyAppWebCase, self).getPage(
|
||||
cheroot.test.webtest.strip_netloc(url),
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
`raise_subcls` must be a tuple with the exceptions classes
|
||||
or a single exception class that are not going to be considered
|
||||
a socket.error regardless that they were are subclass of a
|
||||
socket.error and therefore not considered for a connection retry.
|
||||
"""
|
||||
ServerError.on = False
|
||||
|
||||
if isinstance(url, six.text_type):
|
||||
url = url.encode('utf-8')
|
||||
if isinstance(body, six.text_type):
|
||||
body = body.encode('utf-8')
|
||||
|
||||
self.url = url
|
||||
self.time = None
|
||||
start = time.time()
|
||||
result = openURL(url, headers, method, body, self.HOST, self.PORT,
|
||||
self.HTTP_CONN, protocol or self.PROTOCOL,
|
||||
raise_subcls)
|
||||
self.time = time.time() - start
|
||||
self.status, self.headers, self.body = result
|
||||
|
||||
# Build a list of request cookies from the previous response cookies.
|
||||
self.cookies = [('Cookie', v) for k, v in self.headers
|
||||
if k.lower() == 'set-cookie']
|
||||
|
||||
if ServerError.on:
|
||||
raise ServerError()
|
||||
return result
|
||||
|
||||
@NonDataProperty
|
||||
def interactive(self):
|
||||
"""Determine whether tests are run in interactive mode.
|
||||
|
||||
Load interactivity setting from environment, where
|
||||
the value can be numeric or a string like true or
|
||||
False or 1 or 0.
|
||||
"""
|
||||
env_str = os.environ.get('WEBTEST_INTERACTIVE', 'True')
|
||||
is_interactive = bool(json.loads(env_str.lower()))
|
||||
if is_interactive:
|
||||
warnings.warn(
|
||||
'Interactive test failure interceptor support via '
|
||||
'WEBTEST_INTERACTIVE environment variable is deprecated.',
|
||||
DeprecationWarning
|
||||
)
|
||||
return is_interactive
|
||||
|
||||
console_height = 30
|
||||
|
||||
def _handlewebError(self, msg):
|
||||
print('')
|
||||
print(' ERROR: %s' % msg)
|
||||
|
||||
if not self.interactive:
|
||||
raise self.failureException(msg)
|
||||
|
||||
p = (' Show: '
|
||||
'[B]ody [H]eaders [S]tatus [U]RL; '
|
||||
'[I]gnore, [R]aise, or sys.e[X]it >> ')
|
||||
sys.stdout.write(p)
|
||||
sys.stdout.flush()
|
||||
while True:
|
||||
i = getchar().upper()
|
||||
if not isinstance(i, type('')):
|
||||
i = i.decode('ascii')
|
||||
if i not in 'BHSUIRX':
|
||||
continue
|
||||
print(i.upper()) # Also prints new line
|
||||
if i == 'B':
|
||||
for x, line in enumerate(self.body.splitlines()):
|
||||
if (x + 1) % self.console_height == 0:
|
||||
# The \r and comma should make the next line overwrite
|
||||
sys.stdout.write('<-- More -->\r')
|
||||
m = getchar().lower()
|
||||
# Erase our "More" prompt
|
||||
sys.stdout.write(' \r')
|
||||
if m == 'q':
|
||||
break
|
||||
print(line)
|
||||
elif i == 'H':
|
||||
pprint.pprint(self.headers)
|
||||
elif i == 'S':
|
||||
print(self.status)
|
||||
elif i == 'U':
|
||||
print(self.url)
|
||||
elif i == 'I':
|
||||
# return without raising the normal exception
|
||||
return
|
||||
elif i == 'R':
|
||||
raise self.failureException(msg)
|
||||
elif i == 'X':
|
||||
sys.exit()
|
||||
sys.stdout.write(p)
|
||||
sys.stdout.flush()
|
||||
|
||||
@property
|
||||
def status_code(self): # noqa: D401; irrelevant for properties
|
||||
"""Integer HTTP status code."""
|
||||
return int(self.status[:3])
|
||||
|
||||
def status_matches(self, expected):
|
||||
"""Check whether actual status matches expected."""
|
||||
actual = (
|
||||
self.status_code
|
||||
if isinstance(expected, int) else
|
||||
self.status
|
||||
)
|
||||
return expected == actual
|
||||
|
||||
def assertStatus(self, status, msg=None):
|
||||
"""Fail if self.status != status.
|
||||
|
||||
status may be integer code, exact string status, or
|
||||
iterable of allowed possibilities.
|
||||
"""
|
||||
if any(map(self.status_matches, always_iterable(status))):
|
||||
return
|
||||
|
||||
tmpl = 'Status {self.status} does not match {status}'
|
||||
msg = msg or tmpl.format(**locals())
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertHeader(self, key, value=None, msg=None):
|
||||
"""Fail if (key, [value]) not in self.headers."""
|
||||
lowkey = key.lower()
|
||||
for k, v in self.headers:
|
||||
if k.lower() == lowkey:
|
||||
if value is None or str(value) == v:
|
||||
return v
|
||||
|
||||
if msg is None:
|
||||
if value is None:
|
||||
msg = '%r not in headers' % key
|
||||
else:
|
||||
msg = '%r:%r not in headers' % (key, value)
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertHeaderIn(self, key, values, msg=None):
|
||||
"""Fail if header indicated by key doesn't have one of the values."""
|
||||
lowkey = key.lower()
|
||||
for k, v in self.headers:
|
||||
if k.lower() == lowkey:
|
||||
matches = [value for value in values if str(value) == v]
|
||||
if matches:
|
||||
return matches
|
||||
|
||||
if msg is None:
|
||||
msg = '%(key)r not in %(values)r' % vars()
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertHeaderItemValue(self, key, value, msg=None):
|
||||
"""Fail if the header does not contain the specified value."""
|
||||
actual_value = self.assertHeader(key, msg=msg)
|
||||
header_values = map(str.strip, actual_value.split(','))
|
||||
if value in header_values:
|
||||
return value
|
||||
|
||||
if msg is None:
|
||||
msg = '%r not in %r' % (value, header_values)
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertNoHeader(self, key, msg=None):
|
||||
"""Fail if key in self.headers."""
|
||||
lowkey = key.lower()
|
||||
matches = [k for k, v in self.headers if k.lower() == lowkey]
|
||||
if matches:
|
||||
if msg is None:
|
||||
msg = '%r in headers' % key
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertNoHeaderItemValue(self, key, value, msg=None):
|
||||
"""Fail if the header contains the specified value."""
|
||||
lowkey = key.lower()
|
||||
hdrs = self.headers
|
||||
matches = [k for k, v in hdrs if k.lower() == lowkey and v == value]
|
||||
if matches:
|
||||
if msg is None:
|
||||
msg = '%r:%r in %r' % (key, value, hdrs)
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertBody(self, value, msg=None):
|
||||
"""Fail if value != self.body."""
|
||||
if isinstance(value, six.text_type):
|
||||
value = value.encode(self.encoding)
|
||||
if value != self.body:
|
||||
if msg is None:
|
||||
msg = 'expected body:\n%r\n\nactual body:\n%r' % (
|
||||
value, self.body)
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertInBody(self, value, msg=None):
|
||||
"""Fail if value not in self.body."""
|
||||
if isinstance(value, six.text_type):
|
||||
value = value.encode(self.encoding)
|
||||
if value not in self.body:
|
||||
if msg is None:
|
||||
msg = '%r not in body: %s' % (value, self.body)
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertNotInBody(self, value, msg=None):
|
||||
"""Fail if value in self.body."""
|
||||
if isinstance(value, six.text_type):
|
||||
value = value.encode(self.encoding)
|
||||
if value in self.body:
|
||||
if msg is None:
|
||||
msg = '%r found in body' % value
|
||||
self._handlewebError(msg)
|
||||
|
||||
def assertMatchesBody(self, pattern, msg=None, flags=0):
|
||||
"""Fail if value (a regex pattern) is not in self.body."""
|
||||
if isinstance(pattern, six.text_type):
|
||||
pattern = pattern.encode(self.encoding)
|
||||
if re.search(pattern, self.body, flags) is None:
|
||||
if msg is None:
|
||||
msg = 'No match for %r in body' % pattern
|
||||
self._handlewebError(msg)
|
||||
|
||||
|
||||
methods_with_bodies = ('POST', 'PUT', 'PATCH')
|
||||
|
||||
|
||||
def cleanHeaders(headers, method, body, host, port):
|
||||
"""Return request headers, with required headers added (if missing)."""
|
||||
if headers is None:
|
||||
headers = []
|
||||
|
||||
# Add the required Host request header if not present.
|
||||
# [This specifies the host:port of the server, not the client.]
|
||||
found = False
|
||||
for k, v in headers:
|
||||
if k.lower() == 'host':
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
if port == 80:
|
||||
headers.append(('Host', host))
|
||||
else:
|
||||
headers.append(('Host', '%s:%s' % (host, port)))
|
||||
|
||||
if method in methods_with_bodies:
|
||||
# Stick in default type and length headers if not present
|
||||
found = False
|
||||
for k, v in headers:
|
||||
if k.lower() == 'content-type':
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
headers.append(
|
||||
('Content-Type', 'application/x-www-form-urlencoded'))
|
||||
headers.append(('Content-Length', str(len(body or ''))))
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def shb(response):
|
||||
"""Return status, headers, body the way we like from a response."""
|
||||
if six.PY3:
|
||||
h = response.getheaders()
|
||||
else:
|
||||
h = []
|
||||
key, value = None, None
|
||||
for line in response.msg.headers:
|
||||
if line:
|
||||
if line[0] in ' \t':
|
||||
value += line.strip()
|
||||
else:
|
||||
if key and value:
|
||||
h.append((key, value))
|
||||
key, value = line.split(':', 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if key and value:
|
||||
h.append((key, value))
|
||||
|
||||
return '%s %s' % (response.status, response.reason), h, response.read()
|
||||
|
||||
|
||||
def openURL(url, headers=None, method='GET', body=None,
|
||||
host='127.0.0.1', port=8000, http_conn=http_client.HTTPConnection,
|
||||
protocol='HTTP/1.1', raise_subcls=None):
|
||||
"""
|
||||
Open the given HTTP resource and return status, headers, and body.
|
||||
|
||||
`raise_subcls` must be a tuple with the exceptions classes
|
||||
or a single exception class that are not going to be considered
|
||||
a socket.error regardless that they were are subclass of a
|
||||
socket.error and therefore not considered for a connection retry.
|
||||
"""
|
||||
headers = cleanHeaders(headers, method, body, host, port)
|
||||
|
||||
# Trying 10 times is simply in case of socket errors.
|
||||
# Normal case--it should run once.
|
||||
for trial in range(10):
|
||||
try:
|
||||
# Allow http_conn to be a class or an instance
|
||||
if hasattr(http_conn, 'host'):
|
||||
conn = http_conn
|
||||
else:
|
||||
conn = http_conn(interface(host), port)
|
||||
|
||||
conn._http_vsn_str = protocol
|
||||
conn._http_vsn = int(''.join([x for x in protocol if x.isdigit()]))
|
||||
|
||||
if six.PY3 and isinstance(url, bytes):
|
||||
url = url.decode()
|
||||
conn.putrequest(method.upper(), url, skip_host=True,
|
||||
skip_accept_encoding=True)
|
||||
|
||||
for key, value in headers:
|
||||
conn.putheader(key, value.encode('Latin-1'))
|
||||
conn.endheaders()
|
||||
|
||||
if body is not None:
|
||||
conn.send(body)
|
||||
|
||||
# Handle response
|
||||
response = conn.getresponse()
|
||||
|
||||
s, h, b = shb(response)
|
||||
|
||||
if not hasattr(http_conn, 'host'):
|
||||
# We made our own conn instance. Close it.
|
||||
conn.close()
|
||||
|
||||
return s, h, b
|
||||
except socket.error as e:
|
||||
if raise_subcls is not None and isinstance(e, raise_subcls):
|
||||
raise
|
||||
else:
|
||||
time.sleep(0.5)
|
||||
if trial == 9:
|
||||
raise
|
||||
|
||||
|
||||
def strip_netloc(url):
|
||||
"""Return absolute-URI path from URL.
|
||||
|
||||
Strip the scheme and host from the URL, returning the
|
||||
server-absolute portion.
|
||||
|
||||
Useful for wrapping an absolute-URI for which only the
|
||||
path is expected (such as in calls to getPage).
|
||||
|
||||
>>> strip_netloc('https://google.com/foo/bar?bing#baz')
|
||||
'/foo/bar?bing'
|
||||
|
||||
>>> strip_netloc('//google.com/foo/bar?bing#baz')
|
||||
'/foo/bar?bing'
|
||||
|
||||
>>> strip_netloc('/foo/bar?bing#baz')
|
||||
'/foo/bar?bing'
|
||||
"""
|
||||
parsed = urllib_parse.urlparse(url)
|
||||
scheme, netloc, path, params, query, fragment = parsed
|
||||
stripped = '', '', path, params, query, ''
|
||||
return urllib_parse.urlunparse(stripped)
|
||||
|
||||
|
||||
# Add any exceptions which your web framework handles
|
||||
# normally (that you don't want server_error to trap).
|
||||
ignored_exceptions = []
|
||||
|
||||
# You'll want set this to True when you can't guarantee
|
||||
# that each response will immediately follow each request;
|
||||
# for example, when handling requests via multiple threads.
|
||||
ignore_all = False
|
||||
|
||||
|
||||
class ServerError(Exception):
|
||||
"""Exception for signalling server error."""
|
||||
|
||||
on = False
|
||||
|
||||
|
||||
def server_error(exc=None):
|
||||
"""Server debug hook.
|
||||
|
||||
Return True if exception handled, False if ignored.
|
||||
You probably want to wrap this, so you can still handle an error using
|
||||
your framework when it's ignored.
|
||||
"""
|
||||
if exc is None:
|
||||
exc = sys.exc_info()
|
||||
|
||||
if ignore_all or exc[0] in ignored_exceptions:
|
||||
return False
|
||||
else:
|
||||
ServerError.on = True
|
||||
print('')
|
||||
print(''.join(traceback.format_exception(*exc)))
|
||||
return True
|
144
libraries/cheroot/testing.py
Normal file
144
libraries/cheroot/testing.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
"""Pytest fixtures and other helpers for doing testing by end-users."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
from contextlib import closing
|
||||
import errno
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from six.moves import http_client
|
||||
|
||||
import cheroot.server
|
||||
from cheroot.test import webtest
|
||||
import cheroot.wsgi
|
||||
|
||||
EPHEMERAL_PORT = 0
|
||||
NO_INTERFACE = None # Using this or '' will cause an exception
|
||||
ANY_INTERFACE_IPV4 = '0.0.0.0'
|
||||
ANY_INTERFACE_IPV6 = '::'
|
||||
|
||||
config = {
|
||||
cheroot.wsgi.Server: {
|
||||
'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT),
|
||||
'wsgi_app': None,
|
||||
},
|
||||
cheroot.server.HTTPServer: {
|
||||
'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT),
|
||||
'gateway': cheroot.server.Gateway,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def cheroot_server(server_factory):
|
||||
"""Set up and tear down a Cheroot server instance."""
|
||||
conf = config[server_factory].copy()
|
||||
bind_port = conf.pop('bind_addr')[-1]
|
||||
|
||||
for interface in ANY_INTERFACE_IPV6, ANY_INTERFACE_IPV4:
|
||||
try:
|
||||
actual_bind_addr = (interface, bind_port)
|
||||
httpserver = server_factory( # create it
|
||||
bind_addr=actual_bind_addr,
|
||||
**conf
|
||||
)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
|
||||
threading.Thread(target=httpserver.safe_start).start() # spawn it
|
||||
while not httpserver.ready: # wait until fully initialized and bound
|
||||
time.sleep(0.1)
|
||||
|
||||
yield httpserver
|
||||
|
||||
httpserver.stop() # destroy it
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def wsgi_server():
|
||||
"""Set up and tear down a Cheroot WSGI server instance."""
|
||||
for srv in cheroot_server(cheroot.wsgi.Server):
|
||||
yield srv
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def native_server():
|
||||
"""Set up and tear down a Cheroot HTTP server instance."""
|
||||
for srv in cheroot_server(cheroot.server.HTTPServer):
|
||||
yield srv
|
||||
|
||||
|
||||
class _TestClient:
|
||||
def __init__(self, server):
|
||||
self._interface, self._host, self._port = _get_conn_data(server)
|
||||
self._http_connection = self.get_connection()
|
||||
self.server_instance = server
|
||||
|
||||
def get_connection(self):
|
||||
name = '{interface}:{port}'.format(
|
||||
interface=self._interface,
|
||||
port=self._port,
|
||||
)
|
||||
return http_client.HTTPConnection(name)
|
||||
|
||||
def request(
|
||||
self, uri, method='GET', headers=None, http_conn=None,
|
||||
protocol='HTTP/1.1',
|
||||
):
|
||||
return webtest.openURL(
|
||||
uri, method=method,
|
||||
headers=headers,
|
||||
host=self._host, port=self._port,
|
||||
http_conn=http_conn or self._http_connection,
|
||||
protocol=protocol,
|
||||
)
|
||||
|
||||
def __getattr__(self, attr_name):
|
||||
def _wrapper(uri, **kwargs):
|
||||
http_method = attr_name.upper()
|
||||
return self.request(uri, method=http_method, **kwargs)
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def _probe_ipv6_sock(interface):
|
||||
# Alternate way is to check IPs on interfaces using glibc, like:
|
||||
# github.com/Gautier/minifail/blob/master/minifail/getifaddrs.py
|
||||
try:
|
||||
with closing(socket.socket(family=socket.AF_INET6)) as sock:
|
||||
sock.bind((interface, 0))
|
||||
except (OSError, socket.error) as sock_err:
|
||||
# In Python 3 socket.error is an alias for OSError
|
||||
# In Python 2 socket.error is a subclass of IOError
|
||||
if sock_err.errno != errno.EADDRNOTAVAIL:
|
||||
raise
|
||||
else:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _get_conn_data(server):
|
||||
if isinstance(server.bind_addr, tuple):
|
||||
host, port = server.bind_addr
|
||||
else:
|
||||
host, port = server.bind_addr, 0
|
||||
|
||||
interface = webtest.interface(host)
|
||||
|
||||
if ':' in interface and not _probe_ipv6_sock(interface):
|
||||
interface = '127.0.0.1'
|
||||
if ':' in host:
|
||||
host = interface
|
||||
|
||||
return interface, host, port
|
||||
|
||||
|
||||
def get_server_client(server):
|
||||
"""Create and return a test client for the given server."""
|
||||
return _TestClient(server)
|
1
libraries/cheroot/workers/__init__.py
Normal file
1
libraries/cheroot/workers/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
"""HTTP workers pool."""
|
271
libraries/cheroot/workers/threadpool.py
Normal file
271
libraries/cheroot/workers/threadpool.py
Normal file
|
@ -0,0 +1,271 @@
|
|||
"""A thread-based worker pool."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
|
||||
import threading
|
||||
import time
|
||||
import socket
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
|
||||
__all__ = ('WorkerThread', 'ThreadPool')
|
||||
|
||||
|
||||
class TrueyZero:
|
||||
"""Object which equals and does math like the integer 0 but evals True."""
|
||||
|
||||
def __add__(self, other):
|
||||
return other
|
||||
|
||||
def __radd__(self, other):
|
||||
return other
|
||||
|
||||
|
||||
trueyzero = TrueyZero()
|
||||
|
||||
_SHUTDOWNREQUEST = None
|
||||
|
||||
|
||||
class WorkerThread(threading.Thread):
|
||||
"""Thread which continuously polls a Queue for Connection objects.
|
||||
|
||||
Due to the timing issues of polling a Queue, a WorkerThread does not
|
||||
check its own 'ready' flag after it has started. To stop the thread,
|
||||
it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue
|
||||
(one for each running WorkerThread).
|
||||
"""
|
||||
|
||||
conn = None
|
||||
"""The current connection pulled off the Queue, or None."""
|
||||
|
||||
server = None
|
||||
"""The HTTP Server which spawned this thread, and which owns the
|
||||
Queue and is placing active connections into it."""
|
||||
|
||||
ready = False
|
||||
"""A simple flag for the calling server to know when this thread
|
||||
has begun polling the Queue."""
|
||||
|
||||
def __init__(self, server):
|
||||
"""Initialize WorkerThread instance.
|
||||
|
||||
Args:
|
||||
server (cheroot.server.HTTPServer): web server object
|
||||
receiving this request
|
||||
"""
|
||||
self.ready = False
|
||||
self.server = server
|
||||
|
||||
self.requests_seen = 0
|
||||
self.bytes_read = 0
|
||||
self.bytes_written = 0
|
||||
self.start_time = None
|
||||
self.work_time = 0
|
||||
self.stats = {
|
||||
'Requests': lambda s: self.requests_seen + (
|
||||
(self.start_time is None) and
|
||||
trueyzero or
|
||||
self.conn.requests_seen
|
||||
),
|
||||
'Bytes Read': lambda s: self.bytes_read + (
|
||||
(self.start_time is None) and
|
||||
trueyzero or
|
||||
self.conn.rfile.bytes_read
|
||||
),
|
||||
'Bytes Written': lambda s: self.bytes_written + (
|
||||
(self.start_time is None) and
|
||||
trueyzero or
|
||||
self.conn.wfile.bytes_written
|
||||
),
|
||||
'Work Time': lambda s: self.work_time + (
|
||||
(self.start_time is None) and
|
||||
trueyzero or
|
||||
time.time() - self.start_time
|
||||
),
|
||||
'Read Throughput': lambda s: s['Bytes Read'](s) / (
|
||||
s['Work Time'](s) or 1e-6),
|
||||
'Write Throughput': lambda s: s['Bytes Written'](s) / (
|
||||
s['Work Time'](s) or 1e-6),
|
||||
}
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
"""Process incoming HTTP connections.
|
||||
|
||||
Retrieves incoming connections from thread pool.
|
||||
"""
|
||||
self.server.stats['Worker Threads'][self.getName()] = self.stats
|
||||
try:
|
||||
self.ready = True
|
||||
while True:
|
||||
conn = self.server.requests.get()
|
||||
if conn is _SHUTDOWNREQUEST:
|
||||
return
|
||||
|
||||
self.conn = conn
|
||||
if self.server.stats['Enabled']:
|
||||
self.start_time = time.time()
|
||||
try:
|
||||
conn.communicate()
|
||||
finally:
|
||||
conn.close()
|
||||
if self.server.stats['Enabled']:
|
||||
self.requests_seen += self.conn.requests_seen
|
||||
self.bytes_read += self.conn.rfile.bytes_read
|
||||
self.bytes_written += self.conn.wfile.bytes_written
|
||||
self.work_time += time.time() - self.start_time
|
||||
self.start_time = None
|
||||
self.conn = None
|
||||
except (KeyboardInterrupt, SystemExit) as ex:
|
||||
self.server.interrupt = ex
|
||||
|
||||
|
||||
class ThreadPool:
|
||||
"""A Request Queue for an HTTPServer which pools threads.
|
||||
|
||||
ThreadPool objects must provide min, get(), put(obj), start()
|
||||
and stop(timeout) attributes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, server, min=10, max=-1, accepted_queue_size=-1,
|
||||
accepted_queue_timeout=10):
|
||||
"""Initialize HTTP requests queue instance.
|
||||
|
||||
Args:
|
||||
server (cheroot.server.HTTPServer): web server object
|
||||
receiving this request
|
||||
min (int): minimum number of worker threads
|
||||
max (int): maximum number of worker threads
|
||||
accepted_queue_size (int): maximum number of active
|
||||
requests in queue
|
||||
accepted_queue_timeout (int): timeout for putting request
|
||||
into queue
|
||||
"""
|
||||
self.server = server
|
||||
self.min = min
|
||||
self.max = max
|
||||
self._threads = []
|
||||
self._queue = queue.Queue(maxsize=accepted_queue_size)
|
||||
self._queue_put_timeout = accepted_queue_timeout
|
||||
self.get = self._queue.get
|
||||
|
||||
def start(self):
|
||||
"""Start the pool of threads."""
|
||||
for i in range(self.min):
|
||||
self._threads.append(WorkerThread(self.server))
|
||||
for worker in self._threads:
|
||||
worker.setName('CP Server ' + worker.getName())
|
||||
worker.start()
|
||||
for worker in self._threads:
|
||||
while not worker.ready:
|
||||
time.sleep(.1)
|
||||
|
||||
@property
|
||||
def idle(self): # noqa: D401; irrelevant for properties
|
||||
"""Number of worker threads which are idle. Read-only."""
|
||||
return len([t for t in self._threads if t.conn is None])
|
||||
|
||||
def put(self, obj):
|
||||
"""Put request into queue.
|
||||
|
||||
Args:
|
||||
obj (cheroot.server.HTTPConnection): HTTP connection
|
||||
waiting to be processed
|
||||
"""
|
||||
self._queue.put(obj, block=True, timeout=self._queue_put_timeout)
|
||||
if obj is _SHUTDOWNREQUEST:
|
||||
return
|
||||
|
||||
def grow(self, amount):
|
||||
"""Spawn new worker threads (not above self.max)."""
|
||||
if self.max > 0:
|
||||
budget = max(self.max - len(self._threads), 0)
|
||||
else:
|
||||
# self.max <= 0 indicates no maximum
|
||||
budget = float('inf')
|
||||
|
||||
n_new = min(amount, budget)
|
||||
|
||||
workers = [self._spawn_worker() for i in range(n_new)]
|
||||
while not all(worker.ready for worker in workers):
|
||||
time.sleep(.1)
|
||||
self._threads.extend(workers)
|
||||
|
||||
def _spawn_worker(self):
|
||||
worker = WorkerThread(self.server)
|
||||
worker.setName('CP Server ' + worker.getName())
|
||||
worker.start()
|
||||
return worker
|
||||
|
||||
def shrink(self, amount):
|
||||
"""Kill off worker threads (not below self.min)."""
|
||||
# Grow/shrink the pool if necessary.
|
||||
# Remove any dead threads from our list
|
||||
for t in self._threads:
|
||||
if not t.isAlive():
|
||||
self._threads.remove(t)
|
||||
amount -= 1
|
||||
|
||||
# calculate the number of threads above the minimum
|
||||
n_extra = max(len(self._threads) - self.min, 0)
|
||||
|
||||
# don't remove more than amount
|
||||
n_to_remove = min(amount, n_extra)
|
||||
|
||||
# put shutdown requests on the queue equal to the number of threads
|
||||
# to remove. As each request is processed by a worker, that worker
|
||||
# will terminate and be culled from the list.
|
||||
for n in range(n_to_remove):
|
||||
self._queue.put(_SHUTDOWNREQUEST)
|
||||
|
||||
def stop(self, timeout=5):
|
||||
"""Terminate all worker threads.
|
||||
|
||||
Args:
|
||||
timeout (int): time to wait for threads to stop gracefully
|
||||
"""
|
||||
# Must shut down threads here so the code that calls
|
||||
# this method can know when all threads are stopped.
|
||||
for worker in self._threads:
|
||||
self._queue.put(_SHUTDOWNREQUEST)
|
||||
|
||||
# Don't join currentThread (when stop is called inside a request).
|
||||
current = threading.currentThread()
|
||||
if timeout is not None and timeout >= 0:
|
||||
endtime = time.time() + timeout
|
||||
while self._threads:
|
||||
worker = self._threads.pop()
|
||||
if worker is not current and worker.isAlive():
|
||||
try:
|
||||
if timeout is None or timeout < 0:
|
||||
worker.join()
|
||||
else:
|
||||
remaining_time = endtime - time.time()
|
||||
if remaining_time > 0:
|
||||
worker.join(remaining_time)
|
||||
if worker.isAlive():
|
||||
# We exhausted the timeout.
|
||||
# Forcibly shut down the socket.
|
||||
c = worker.conn
|
||||
if c and not c.rfile.closed:
|
||||
try:
|
||||
c.socket.shutdown(socket.SHUT_RD)
|
||||
except TypeError:
|
||||
# pyOpenSSL sockets don't take an arg
|
||||
c.socket.shutdown()
|
||||
worker.join()
|
||||
except (AssertionError,
|
||||
# Ignore repeated Ctrl-C.
|
||||
# See
|
||||
# https://github.com/cherrypy/cherrypy/issues/691.
|
||||
KeyboardInterrupt):
|
||||
pass
|
||||
|
||||
@property
|
||||
def qsize(self):
|
||||
"""Return the queue size."""
|
||||
return self._queue.qsize()
|
423
libraries/cheroot/wsgi.py
Normal file
423
libraries/cheroot/wsgi.py
Normal file
|
@ -0,0 +1,423 @@
|
|||
"""This class holds Cheroot WSGI server implementation.
|
||||
|
||||
Simplest example on how to use this server::
|
||||
|
||||
from cheroot import wsgi
|
||||
|
||||
def my_crazy_app(environ, start_response):
|
||||
status = '200 OK'
|
||||
response_headers = [('Content-type','text/plain')]
|
||||
start_response(status, response_headers)
|
||||
return [b'Hello world!']
|
||||
|
||||
addr = '0.0.0.0', 8070
|
||||
server = wsgi.Server(addr, my_crazy_app)
|
||||
server.start()
|
||||
|
||||
The Cheroot WSGI server can serve as many WSGI applications
|
||||
as you want in one instance by using a PathInfoDispatcher::
|
||||
|
||||
path_map = {
|
||||
'/': my_crazy_app,
|
||||
'/blog': my_blog_app,
|
||||
}
|
||||
d = wsgi.PathInfoDispatcher(path_map)
|
||||
server = wsgi.Server(addr, d)
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
__metaclass__ = type
|
||||
|
||||
import sys
|
||||
|
||||
import six
|
||||
from six.moves import filter
|
||||
|
||||
from . import server
|
||||
from .workers import threadpool
|
||||
from ._compat import ntob, bton
|
||||
|
||||
|
||||
class Server(server.HTTPServer):
|
||||
"""A subclass of HTTPServer which calls a WSGI application."""
|
||||
|
||||
wsgi_version = (1, 0)
|
||||
"""The version of WSGI to produce."""
|
||||
|
||||
def __init__(
|
||||
self, bind_addr, wsgi_app, numthreads=10, server_name=None,
|
||||
max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5,
|
||||
accepted_queue_size=-1, accepted_queue_timeout=10,
|
||||
peercreds_enabled=False, peercreds_resolve_enabled=False,
|
||||
):
|
||||
"""Initialize WSGI Server instance.
|
||||
|
||||
Args:
|
||||
bind_addr (tuple): network interface to listen to
|
||||
wsgi_app (callable): WSGI application callable
|
||||
numthreads (int): number of threads for WSGI thread pool
|
||||
server_name (str): web server name to be advertised via
|
||||
Server HTTP header
|
||||
max (int): maximum number of worker threads
|
||||
request_queue_size (int): the 'backlog' arg to
|
||||
socket.listen(); max queued connections
|
||||
timeout (int): the timeout in seconds for accepted connections
|
||||
shutdown_timeout (int): the total time, in seconds, to
|
||||
wait for worker threads to cleanly exit
|
||||
accepted_queue_size (int): maximum number of active
|
||||
requests in queue
|
||||
accepted_queue_timeout (int): timeout for putting request
|
||||
into queue
|
||||
"""
|
||||
super(Server, self).__init__(
|
||||
bind_addr,
|
||||
gateway=wsgi_gateways[self.wsgi_version],
|
||||
server_name=server_name,
|
||||
peercreds_enabled=peercreds_enabled,
|
||||
peercreds_resolve_enabled=peercreds_resolve_enabled,
|
||||
)
|
||||
self.wsgi_app = wsgi_app
|
||||
self.request_queue_size = request_queue_size
|
||||
self.timeout = timeout
|
||||
self.shutdown_timeout = shutdown_timeout
|
||||
self.requests = threadpool.ThreadPool(
|
||||
self, min=numthreads or 1, max=max,
|
||||
accepted_queue_size=accepted_queue_size,
|
||||
accepted_queue_timeout=accepted_queue_timeout)
|
||||
|
||||
@property
|
||||
def numthreads(self):
|
||||
"""Set minimum number of threads."""
|
||||
return self.requests.min
|
||||
|
||||
@numthreads.setter
|
||||
def numthreads(self, value):
|
||||
self.requests.min = value
|
||||
|
||||
|
||||
class Gateway(server.Gateway):
|
||||
"""A base class to interface HTTPServer with WSGI."""
|
||||
|
||||
def __init__(self, req):
|
||||
"""Initialize WSGI Gateway instance with request.
|
||||
|
||||
Args:
|
||||
req (HTTPRequest): current HTTP request
|
||||
"""
|
||||
super(Gateway, self).__init__(req)
|
||||
self.started_response = False
|
||||
self.env = self.get_environ()
|
||||
self.remaining_bytes_out = None
|
||||
|
||||
@classmethod
|
||||
def gateway_map(cls):
|
||||
"""Create a mapping of gateways and their versions.
|
||||
|
||||
Returns:
|
||||
dict[tuple[int,int],class]: map of gateway version and
|
||||
corresponding class
|
||||
|
||||
"""
|
||||
return dict(
|
||||
(gw.version, gw)
|
||||
for gw in cls.__subclasses__()
|
||||
)
|
||||
|
||||
def get_environ(self):
|
||||
"""Return a new environ dict targeting the given wsgi.version."""
|
||||
raise NotImplementedError
|
||||
|
||||
def respond(self):
|
||||
"""Process the current request.
|
||||
|
||||
From :pep:`333`:
|
||||
|
||||
The start_response callable must not actually transmit
|
||||
the response headers. Instead, it must store them for the
|
||||
server or gateway to transmit only after the first
|
||||
iteration of the application return value that yields
|
||||
a NON-EMPTY string, or upon the application's first
|
||||
invocation of the write() callable.
|
||||
"""
|
||||
response = self.req.server.wsgi_app(self.env, self.start_response)
|
||||
try:
|
||||
for chunk in filter(None, response):
|
||||
if not isinstance(chunk, six.binary_type):
|
||||
raise ValueError('WSGI Applications must yield bytes')
|
||||
self.write(chunk)
|
||||
finally:
|
||||
# Send headers if not already sent
|
||||
self.req.ensure_headers_sent()
|
||||
if hasattr(response, 'close'):
|
||||
response.close()
|
||||
|
||||
def start_response(self, status, headers, exc_info=None):
|
||||
"""WSGI callable to begin the HTTP response."""
|
||||
# "The application may call start_response more than once,
|
||||
# if and only if the exc_info argument is provided."
|
||||
if self.started_response and not exc_info:
|
||||
raise AssertionError('WSGI start_response called a second '
|
||||
'time with no exc_info.')
|
||||
self.started_response = True
|
||||
|
||||
# "if exc_info is provided, and the HTTP headers have already been
|
||||
# sent, start_response must raise an error, and should raise the
|
||||
# exc_info tuple."
|
||||
if self.req.sent_headers:
|
||||
try:
|
||||
six.reraise(*exc_info)
|
||||
finally:
|
||||
exc_info = None
|
||||
|
||||
self.req.status = self._encode_status(status)
|
||||
|
||||
for k, v in headers:
|
||||
if not isinstance(k, str):
|
||||
raise TypeError(
|
||||
'WSGI response header key %r is not of type str.' % k)
|
||||
if not isinstance(v, str):
|
||||
raise TypeError(
|
||||
'WSGI response header value %r is not of type str.' % v)
|
||||
if k.lower() == 'content-length':
|
||||
self.remaining_bytes_out = int(v)
|
||||
out_header = ntob(k), ntob(v)
|
||||
self.req.outheaders.append(out_header)
|
||||
|
||||
return self.write
|
||||
|
||||
@staticmethod
|
||||
def _encode_status(status):
|
||||
"""Cast status to bytes representation of current Python version.
|
||||
|
||||
According to :pep:`3333`, when using Python 3, the response status
|
||||
and headers must be bytes masquerading as unicode; that is, they
|
||||
must be of type "str" but are restricted to code points in the
|
||||
"latin-1" set.
|
||||
"""
|
||||
if six.PY2:
|
||||
return status
|
||||
if not isinstance(status, str):
|
||||
raise TypeError('WSGI response status is not of type str.')
|
||||
return status.encode('ISO-8859-1')
|
||||
|
||||
def write(self, chunk):
|
||||
"""WSGI callable to write unbuffered data to the client.
|
||||
|
||||
This method is also used internally by start_response (to write
|
||||
data from the iterable returned by the WSGI application).
|
||||
"""
|
||||
if not self.started_response:
|
||||
raise AssertionError('WSGI write called before start_response.')
|
||||
|
||||
chunklen = len(chunk)
|
||||
rbo = self.remaining_bytes_out
|
||||
if rbo is not None and chunklen > rbo:
|
||||
if not self.req.sent_headers:
|
||||
# Whew. We can send a 500 to the client.
|
||||
self.req.simple_response(
|
||||
'500 Internal Server Error',
|
||||
'The requested resource returned more bytes than the '
|
||||
'declared Content-Length.')
|
||||
else:
|
||||
# Dang. We have probably already sent data. Truncate the chunk
|
||||
# to fit (so the client doesn't hang) and raise an error later.
|
||||
chunk = chunk[:rbo]
|
||||
|
||||
self.req.ensure_headers_sent()
|
||||
|
||||
self.req.write(chunk)
|
||||
|
||||
if rbo is not None:
|
||||
rbo -= chunklen
|
||||
if rbo < 0:
|
||||
raise ValueError(
|
||||
'Response body exceeds the declared Content-Length.')
|
||||
|
||||
|
||||
class Gateway_10(Gateway):
|
||||
"""A Gateway class to interface HTTPServer with WSGI 1.0.x."""
|
||||
|
||||
version = 1, 0
|
||||
|
||||
def get_environ(self):
|
||||
"""Return a new environ dict targeting the given wsgi.version."""
|
||||
req = self.req
|
||||
req_conn = req.conn
|
||||
env = {
|
||||
# set a non-standard environ entry so the WSGI app can know what
|
||||
# the *real* server protocol is (and what features to support).
|
||||
# See http://www.faqs.org/rfcs/rfc2145.html.
|
||||
'ACTUAL_SERVER_PROTOCOL': req.server.protocol,
|
||||
'PATH_INFO': bton(req.path),
|
||||
'QUERY_STRING': bton(req.qs),
|
||||
'REMOTE_ADDR': req_conn.remote_addr or '',
|
||||
'REMOTE_PORT': str(req_conn.remote_port or ''),
|
||||
'REQUEST_METHOD': bton(req.method),
|
||||
'REQUEST_URI': bton(req.uri),
|
||||
'SCRIPT_NAME': '',
|
||||
'SERVER_NAME': req.server.server_name,
|
||||
# Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol.
|
||||
'SERVER_PROTOCOL': bton(req.request_protocol),
|
||||
'SERVER_SOFTWARE': req.server.software,
|
||||
'wsgi.errors': sys.stderr,
|
||||
'wsgi.input': req.rfile,
|
||||
'wsgi.input_terminated': bool(req.chunked_read),
|
||||
'wsgi.multiprocess': False,
|
||||
'wsgi.multithread': True,
|
||||
'wsgi.run_once': False,
|
||||
'wsgi.url_scheme': bton(req.scheme),
|
||||
'wsgi.version': self.version,
|
||||
}
|
||||
|
||||
if isinstance(req.server.bind_addr, six.string_types):
|
||||
# AF_UNIX. This isn't really allowed by WSGI, which doesn't
|
||||
# address unix domain sockets. But it's better than nothing.
|
||||
env['SERVER_PORT'] = ''
|
||||
try:
|
||||
env['X_REMOTE_PID'] = str(req_conn.peer_pid)
|
||||
env['X_REMOTE_UID'] = str(req_conn.peer_uid)
|
||||
env['X_REMOTE_GID'] = str(req_conn.peer_gid)
|
||||
|
||||
env['X_REMOTE_USER'] = str(req_conn.peer_user)
|
||||
env['X_REMOTE_GROUP'] = str(req_conn.peer_group)
|
||||
|
||||
env['REMOTE_USER'] = env['X_REMOTE_USER']
|
||||
except RuntimeError:
|
||||
"""Unable to retrieve peer creds data.
|
||||
|
||||
Unsupported by current kernel or socket error happened, or
|
||||
unsupported socket type, or disabled.
|
||||
"""
|
||||
else:
|
||||
env['SERVER_PORT'] = str(req.server.bind_addr[1])
|
||||
|
||||
# Request headers
|
||||
env.update(
|
||||
('HTTP_' + bton(k).upper().replace('-', '_'), bton(v))
|
||||
for k, v in req.inheaders.items()
|
||||
)
|
||||
|
||||
# CONTENT_TYPE/CONTENT_LENGTH
|
||||
ct = env.pop('HTTP_CONTENT_TYPE', None)
|
||||
if ct is not None:
|
||||
env['CONTENT_TYPE'] = ct
|
||||
cl = env.pop('HTTP_CONTENT_LENGTH', None)
|
||||
if cl is not None:
|
||||
env['CONTENT_LENGTH'] = cl
|
||||
|
||||
if req.conn.ssl_env:
|
||||
env.update(req.conn.ssl_env)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class Gateway_u0(Gateway_10):
|
||||
"""A Gateway class to interface HTTPServer with WSGI u.0.
|
||||
|
||||
WSGI u.0 is an experimental protocol, which uses unicode for keys
|
||||
and values in both Python 2 and Python 3.
|
||||
"""
|
||||
|
||||
version = 'u', 0
|
||||
|
||||
def get_environ(self):
|
||||
"""Return a new environ dict targeting the given wsgi.version."""
|
||||
req = self.req
|
||||
env_10 = super(Gateway_u0, self).get_environ()
|
||||
env = dict(map(self._decode_key, env_10.items()))
|
||||
|
||||
# Request-URI
|
||||
enc = env.setdefault(six.u('wsgi.url_encoding'), six.u('utf-8'))
|
||||
try:
|
||||
env['PATH_INFO'] = req.path.decode(enc)
|
||||
env['QUERY_STRING'] = req.qs.decode(enc)
|
||||
except UnicodeDecodeError:
|
||||
# Fall back to latin 1 so apps can transcode if needed.
|
||||
env['wsgi.url_encoding'] = 'ISO-8859-1'
|
||||
env['PATH_INFO'] = env_10['PATH_INFO']
|
||||
env['QUERY_STRING'] = env_10['QUERY_STRING']
|
||||
|
||||
env.update(map(self._decode_value, env.items()))
|
||||
|
||||
return env
|
||||
|
||||
@staticmethod
|
||||
def _decode_key(item):
|
||||
k, v = item
|
||||
if six.PY2:
|
||||
k = k.decode('ISO-8859-1')
|
||||
return k, v
|
||||
|
||||
@staticmethod
|
||||
def _decode_value(item):
|
||||
k, v = item
|
||||
skip_keys = 'REQUEST_URI', 'wsgi.input'
|
||||
if six.PY3 or not isinstance(v, bytes) or k in skip_keys:
|
||||
return k, v
|
||||
return k, v.decode('ISO-8859-1')
|
||||
|
||||
|
||||
wsgi_gateways = Gateway.gateway_map()
|
||||
|
||||
|
||||
class PathInfoDispatcher:
|
||||
"""A WSGI dispatcher for dispatch based on the PATH_INFO."""
|
||||
|
||||
def __init__(self, apps):
|
||||
"""Initialize path info WSGI app dispatcher.
|
||||
|
||||
Args:
|
||||
apps (dict[str,object]|list[tuple[str,object]]): URI prefix
|
||||
and WSGI app pairs
|
||||
"""
|
||||
try:
|
||||
apps = list(apps.items())
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Sort the apps by len(path), descending
|
||||
def by_path_len(app):
|
||||
return len(app[0])
|
||||
apps.sort(key=by_path_len, reverse=True)
|
||||
|
||||
# The path_prefix strings must start, but not end, with a slash.
|
||||
# Use "" instead of "/".
|
||||
self.apps = [(p.rstrip('/'), a) for p, a in apps]
|
||||
|
||||
def __call__(self, environ, start_response):
|
||||
"""Process incoming WSGI request.
|
||||
|
||||
Ref: :pep:`3333`
|
||||
|
||||
Args:
|
||||
environ (Mapping): a dict containing WSGI environment variables
|
||||
start_response (callable): function, which sets response
|
||||
status and headers
|
||||
|
||||
Returns:
|
||||
list[bytes]: iterable containing bytes to be returned in
|
||||
HTTP response body
|
||||
|
||||
"""
|
||||
path = environ['PATH_INFO'] or '/'
|
||||
for p, app in self.apps:
|
||||
# The apps list should be sorted by length, descending.
|
||||
if path.startswith(p + '/') or path == p:
|
||||
environ = environ.copy()
|
||||
environ['SCRIPT_NAME'] = environ['SCRIPT_NAME'] + p
|
||||
environ['PATH_INFO'] = path[len(p):]
|
||||
return app(environ, start_response)
|
||||
|
||||
start_response('404 Not Found', [('Content-Type', 'text/plain'),
|
||||
('Content-Length', '0')])
|
||||
return ['']
|
||||
|
||||
|
||||
# compatibility aliases
|
||||
globals().update(
|
||||
WSGIServer=Server,
|
||||
WSGIGateway=Gateway,
|
||||
WSGIGateway_u0=Gateway_u0,
|
||||
WSGIGateway_10=Gateway_10,
|
||||
WSGIPathInfoDispatcher=PathInfoDispatcher,
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue