diff --git a/libraries/backports/__init__.py b/libraries/backports/__init__.py new file mode 100644 index 00000000..69e3be50 --- /dev/null +++ b/libraries/backports/__init__.py @@ -0,0 +1 @@ +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/libraries/backports/functools_lru_cache.py b/libraries/backports/functools_lru_cache.py new file mode 100644 index 00000000..707c6c76 --- /dev/null +++ b/libraries/backports/functools_lru_cache.py @@ -0,0 +1,184 @@ +from __future__ import absolute_import + +import functools +from collections import namedtuple +from threading import RLock + +_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"]) + + +@functools.wraps(functools.update_wrapper) +def update_wrapper(wrapper, + wrapped, + assigned = functools.WRAPPER_ASSIGNMENTS, + updated = functools.WRAPPER_UPDATES): + """ + Patch two bugs in functools.update_wrapper. + """ + # workaround for http://bugs.python.org/issue3445 + assigned = tuple(attr for attr in assigned if hasattr(wrapped, attr)) + wrapper = functools.update_wrapper(wrapper, wrapped, assigned, updated) + # workaround for https://bugs.python.org/issue17482 + wrapper.__wrapped__ = wrapped + return wrapper + + +class _HashedSeq(list): + __slots__ = 'hashvalue' + + def __init__(self, tup, hash=hash): + self[:] = tup + self.hashvalue = hash(tup) + + def __hash__(self): + return self.hashvalue + + +def _make_key(args, kwds, typed, + kwd_mark=(object(),), + fasttypes=set([int, str, frozenset, type(None)]), + sorted=sorted, tuple=tuple, type=type, len=len): + 'Make a cache key from optionally typed positional and keyword arguments' + key = args + if kwds: + sorted_items = sorted(kwds.items()) + key += kwd_mark + for item in sorted_items: + key += item + if typed: + key += tuple(type(v) for v in args) + if kwds: + key += tuple(type(v) for k, v in sorted_items) + elif len(key) == 1 and type(key[0]) in fasttypes: + return key[0] + return _HashedSeq(key) + + +def lru_cache(maxsize=100, typed=False): + """Least-recently-used cache decorator. + + If *maxsize* is set to None, the LRU features are disabled and the cache + can grow without bound. + + If *typed* is True, arguments of different types will be cached separately. + For example, f(3.0) and f(3) will be treated as distinct calls with + distinct results. + + Arguments to the cached function must be hashable. + + View the cache statistics named tuple (hits, misses, maxsize, currsize) with + f.cache_info(). Clear the cache and statistics with f.cache_clear(). + Access the underlying function with f.__wrapped__. + + See: http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used + + """ + + # Users should only access the lru_cache through its public API: + # cache_info, cache_clear, and f.__wrapped__ + # The internals of the lru_cache are encapsulated for thread safety and + # to allow the implementation to change (including a possible C version). 
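+
+ # Illustrative usage sketch (an assumption, not part of the upstream
+ # backport): the decorator behaves like Python 3's functools.lru_cache,
+ # shown here with a hypothetical fib() helper:
+ #
+ #     @lru_cache(maxsize=32)
+ #     def fib(n):
+ #         return n if n < 2 else fib(n - 1) + fib(n - 2)
+ #
+ #     fib(10)
+ #     fib.cache_info()   # CacheInfo(hits=8, misses=11, maxsize=32, currsize=11)
+ #     fib.cache_clear()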
+ + def decorating_function(user_function): + + cache = dict() + stats = [0, 0] # make statistics updateable non-locally + HITS, MISSES = 0, 1 # names for the stats fields + make_key = _make_key + cache_get = cache.get # bound method to lookup key or return None + _len = len # localize the global len() function + lock = RLock() # because linkedlist updates aren't threadsafe + root = [] # root of the circular doubly linked list + root[:] = [root, root, None, None] # initialize by pointing to self + nonlocal_root = [root] # make updateable non-locally + PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields + + if maxsize == 0: + + def wrapper(*args, **kwds): + # no caching, just do a statistics update after a successful call + result = user_function(*args, **kwds) + stats[MISSES] += 1 + return result + + elif maxsize is None: + + def wrapper(*args, **kwds): + # simple caching without ordering or size limit + key = make_key(args, kwds, typed) + result = cache_get(key, root) # root used here as a unique not-found sentinel + if result is not root: + stats[HITS] += 1 + return result + result = user_function(*args, **kwds) + cache[key] = result + stats[MISSES] += 1 + return result + + else: + + def wrapper(*args, **kwds): + # size limited caching that tracks accesses by recency + key = make_key(args, kwds, typed) if kwds or typed else args + with lock: + link = cache_get(key) + if link is not None: + # record recent use of the key by moving it to the front of the list + root, = nonlocal_root + link_prev, link_next, key, result = link + link_prev[NEXT] = link_next + link_next[PREV] = link_prev + last = root[PREV] + last[NEXT] = root[PREV] = link + link[PREV] = last + link[NEXT] = root + stats[HITS] += 1 + return result + result = user_function(*args, **kwds) + with lock: + root, = nonlocal_root + if key in cache: + # getting here means that this same key was added to the + # cache while the lock was released. since the link + # update is already done, we need only return the + # computed result and update the count of misses. 
+ pass + elif _len(cache) >= maxsize: + # use the old root to store the new key and result + oldroot = root + oldroot[KEY] = key + oldroot[RESULT] = result + # empty the oldest link and make it the new root + root = nonlocal_root[0] = oldroot[NEXT] + oldkey = root[KEY] + root[KEY] = root[RESULT] = None + # now update the cache dictionary for the new links + del cache[oldkey] + cache[key] = oldroot + else: + # put result in a new link at the front of the list + last = root[PREV] + link = [last, root, key, result] + last[NEXT] = root[PREV] = cache[key] = link + stats[MISSES] += 1 + return result + + def cache_info(): + """Report cache statistics""" + with lock: + return _CacheInfo(stats[HITS], stats[MISSES], maxsize, len(cache)) + + def cache_clear(): + """Clear the cache and cache statistics""" + with lock: + cache.clear() + root = nonlocal_root[0] + root[:] = [root, root, None, None] + stats[:] = [0, 0] + + wrapper.__wrapped__ = user_function + wrapper.cache_info = cache_info + wrapper.cache_clear = cache_clear + return update_wrapper(wrapper, user_function) + + return decorating_function diff --git a/libraries/cheroot/__init__.py b/libraries/cheroot/__init__.py new file mode 100644 index 00000000..a313660e --- /dev/null +++ b/libraries/cheroot/__init__.py @@ -0,0 +1,6 @@ +"""High-performance, pure-Python HTTP server used by CherryPy.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +__version__ = '6.4.0' diff --git a/libraries/cheroot/__main__.py b/libraries/cheroot/__main__.py new file mode 100644 index 00000000..d2e27c10 --- /dev/null +++ b/libraries/cheroot/__main__.py @@ -0,0 +1,6 @@ +"""Stub for accessing the Cheroot CLI tool.""" + +from .cli import main + +if __name__ == '__main__': + main() diff --git a/libraries/cheroot/_compat.py b/libraries/cheroot/_compat.py new file mode 100644 index 00000000..e98f91f9 --- /dev/null +++ b/libraries/cheroot/_compat.py @@ -0,0 +1,66 @@ +"""Compatibility code for using Cheroot with various versions of Python.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import re + +import six + +if six.PY3: + def ntob(n, encoding='ISO-8859-1'): + """Return the native string as bytes in the given encoding.""" + assert_native(n) + # In Python 3, the native string type is unicode + return n.encode(encoding) + + def ntou(n, encoding='ISO-8859-1'): + """Return the native string as unicode with the given encoding.""" + assert_native(n) + # In Python 3, the native string type is unicode + return n + + def bton(b, encoding='ISO-8859-1'): + """Return the byte string as native string in the given encoding.""" + return b.decode(encoding) +else: + # Python 2 + def ntob(n, encoding='ISO-8859-1'): + """Return the native string as bytes in the given encoding.""" + assert_native(n) + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + + def ntou(n, encoding='ISO-8859-1'): + """Return the native string as unicode with the given encoding.""" + assert_native(n) + # In Python 2, the native string type is bytes. + # First, check for the special encoding 'escape'. The test suite uses + # this to signal that it wants to pass a string with embedded \uXXXX + # escapes, but without having to prefix it with u'' for Python 2, + # but no prefix for Python 3. 
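+ # For illustration (an assumption, not an upstream example): with
+ # encoding='escape' a native str containing a literal backslash escape,
+ # such as 'caf\u00e9', comes back as the unicode string u'caf\xe9', while
+ # any other encoding simply decodes the byte string, e.g. ntou('abc') == u'abc'.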
+ if encoding == 'escape': + return six.u( + re.sub(r'\\u([0-9a-zA-Z]{4})', + lambda m: six.unichr(int(m.group(1), 16)), + n.decode('ISO-8859-1'))) + # Assume it's already in the given encoding, which for ISO-8859-1 + # is almost always what was intended. + return n.decode(encoding) + + def bton(b, encoding='ISO-8859-1'): + """Return the byte string as native string in the given encoding.""" + return b + + +def assert_native(n): + """Check whether the input is of nativ ``str`` type. + + Raises: + TypeError: in case of failed check + + """ + if not isinstance(n, str): + raise TypeError('n must be a native str (got %s)' % type(n).__name__) diff --git a/libraries/cheroot/cli.py b/libraries/cheroot/cli.py new file mode 100644 index 00000000..6d59fb5c --- /dev/null +++ b/libraries/cheroot/cli.py @@ -0,0 +1,233 @@ +"""Command line tool for starting a Cheroot WSGI/HTTP server instance. + +Basic usage:: + + # Start a server on 127.0.0.1:8000 with the default settings + # for the WSGI app myapp/wsgi.py:application() + cheroot myapp.wsgi + + # Start a server on 0.0.0.0:9000 with 8 threads + # for the WSGI app myapp/wsgi.py:main_app() + cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8 + + # Start a server for the cheroot.server.Gateway subclass + # myapp/gateway.py:HTTPGateway + cheroot myapp.gateway:HTTPGateway + + # Start a server on the UNIX socket /var/spool/myapp.sock + cheroot myapp.wsgi --bind /var/spool/myapp.sock + + # Start a server on the abstract UNIX socket CherootServer + cheroot myapp.wsgi --bind @CherootServer +""" + +import argparse +from importlib import import_module +import os +import sys +import contextlib + +import six + +from . import server +from . import wsgi + + +__metaclass__ = type + + +class BindLocation: + """A class for storing the bind location for a Cheroot instance.""" + + +class TCPSocket(BindLocation): + """TCPSocket.""" + + def __init__(self, address, port): + """Initialize. 
+ + Args: + address (str): Host name or IP address + port (int): TCP port number + """ + self.bind_addr = address, port + + +class UnixSocket(BindLocation): + """UnixSocket.""" + + def __init__(self, path): + """Initialize.""" + self.bind_addr = path + + +class AbstractSocket(BindLocation): + """AbstractSocket.""" + + def __init__(self, addr): + """Initialize.""" + self.bind_addr = '\0{}'.format(self.abstract_socket) + + +class Application: + """Application.""" + + @classmethod + def resolve(cls, full_path): + """Read WSGI app/Gateway path string and import application module.""" + mod_path, _, app_path = full_path.partition(':') + app = getattr(import_module(mod_path), app_path or 'application') + + with contextlib.suppress(TypeError): + if issubclass(app, server.Gateway): + return GatewayYo(app) + + return cls(app) + + def __init__(self, wsgi_app): + """Initialize.""" + if not callable(wsgi_app): + raise TypeError( + 'Application must be a callable object or ' + 'cheroot.server.Gateway subclass' + ) + self.wsgi_app = wsgi_app + + def server_args(self, parsed_args): + """Return keyword args for Server class.""" + args = { + arg: value + for arg, value in vars(parsed_args).items() + if not arg.startswith('_') and value is not None + } + args.update(vars(self)) + return args + + def server(self, parsed_args): + """Server.""" + return wsgi.Server(**self.server_args(parsed_args)) + + +class GatewayYo: + """Gateway.""" + + def __init__(self, gateway): + """Init.""" + self.gateway = gateway + + def server(self, parsed_args): + """Server.""" + server_args = vars(self) + server_args['bind_addr'] = parsed_args['bind_addr'] + if parsed_args.max is not None: + server_args['maxthreads'] = parsed_args.max + if parsed_args.numthreads is not None: + server_args['minthreads'] = parsed_args.numthreads + return server.HTTPServer(**server_args) + + +def parse_wsgi_bind_location(bind_addr_string): + """Convert bind address string to a BindLocation.""" + # try and match for an IP/hostname and port + match = six.moves.urllib.parse.urlparse('//{}'.format(bind_addr_string)) + try: + addr = match.hostname + port = match.port + if addr is not None or port is not None: + return TCPSocket(addr, port) + except ValueError: + pass + + # else, assume a UNIX socket path + # if the string begins with an @ symbol, use an abstract socket + if bind_addr_string.startswith('@'): + return AbstractSocket(bind_addr_string[1:]) + return UnixSocket(path=bind_addr_string) + + +def parse_wsgi_bind_addr(bind_addr_string): + """Convert bind address string to bind address parameter.""" + return parse_wsgi_bind_location(bind_addr_string).bind_addr + + +_arg_spec = { + '_wsgi_app': dict( + metavar='APP_MODULE', + type=Application.resolve, + help='WSGI application callable or cheroot.server.Gateway subclass', + ), + '--bind': dict( + metavar='ADDRESS', + dest='bind_addr', + type=parse_wsgi_bind_addr, + default='[::1]:8000', + help='Network interface to listen on (default: [::1]:8000)', + ), + '--chdir': dict( + metavar='PATH', + type=os.chdir, + help='Set the working directory', + ), + '--server-name': dict( + dest='server_name', + type=str, + help='Web server name to be advertised via Server HTTP header', + ), + '--threads': dict( + metavar='INT', + dest='numthreads', + type=int, + help='Minimum number of worker threads', + ), + '--max-threads': dict( + metavar='INT', + dest='max', + type=int, + help='Maximum number of worker threads', + ), + '--timeout': dict( + metavar='INT', + dest='timeout', + type=int, + help='Timeout in seconds for 
accepted connections', + ), + '--shutdown-timeout': dict( + metavar='INT', + dest='shutdown_timeout', + type=int, + help='Time in seconds to wait for worker threads to cleanly exit', + ), + '--request-queue-size': dict( + metavar='INT', + dest='request_queue_size', + type=int, + help='Maximum number of queued connections', + ), + '--accepted-queue-size': dict( + metavar='INT', + dest='accepted_queue_size', + type=int, + help='Maximum number of active requests in queue', + ), + '--accepted-queue-timeout': dict( + metavar='INT', + dest='accepted_queue_timeout', + type=int, + help='Timeout in seconds for putting requests into queue', + ), +} + + +def main(): + """Create a new Cheroot instance with arguments from the command line.""" + parser = argparse.ArgumentParser( + description='Start an instance of the Cheroot WSGI/HTTP server.') + for arg, spec in _arg_spec.items(): + parser.add_argument(arg, **spec) + raw_args = parser.parse_args() + + # ensure cwd in sys.path + '' in sys.path or sys.path.insert(0, '') + + # create a server based on the arguments provided + raw_args._wsgi_app.server(raw_args).safe_start() diff --git a/libraries/cheroot/errors.py b/libraries/cheroot/errors.py new file mode 100644 index 00000000..82412b42 --- /dev/null +++ b/libraries/cheroot/errors.py @@ -0,0 +1,58 @@ +"""Collection of exceptions raised and/or processed by Cheroot.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import errno +import sys + + +class MaxSizeExceeded(Exception): + """Exception raised when a client sends more data then acceptable within limit. + + Depends on ``request.body.maxbytes`` config option if used within CherryPy + """ + + +class NoSSLError(Exception): + """Exception raised when a client speaks HTTP to an HTTPS socket.""" + + +class FatalSSLAlert(Exception): + """Exception raised when the SSL implementation signals a fatal alert.""" + + +def plat_specific_errors(*errnames): + """Return error numbers for all errors in errnames on this platform. + + The 'errno' module contains different global constants depending on + the specific platform (OS). This function will return the list of + numeric values for a given list of potential names. 
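+
+ A hedged example (exact numbers differ per platform): on a typical Linux
+ build plat_specific_errors('EAGAIN', 'EWOULDBLOCK') collapses to a single
+ de-duplicated errno value, while Windows-only names such as 'WSAEINTR'
+ do not exist in the errno module there and are silently skipped.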
+ """ + errno_names = dir(errno) + nums = [getattr(errno, k) for k in errnames if k in errno_names] + # de-dupe the list + return list(dict.fromkeys(nums).keys()) + + +socket_error_eintr = plat_specific_errors('EINTR', 'WSAEINTR') + +socket_errors_to_ignore = plat_specific_errors( + 'EPIPE', + 'EBADF', 'WSAEBADF', + 'ENOTSOCK', 'WSAENOTSOCK', + 'ETIMEDOUT', 'WSAETIMEDOUT', + 'ECONNREFUSED', 'WSAECONNREFUSED', + 'ECONNRESET', 'WSAECONNRESET', + 'ECONNABORTED', 'WSAECONNABORTED', + 'ENETRESET', 'WSAENETRESET', + 'EHOSTDOWN', 'EHOSTUNREACH', +) +socket_errors_to_ignore.append('timed out') +socket_errors_to_ignore.append('The read operation timed out') +socket_errors_nonblocking = plat_specific_errors( + 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') + +if sys.platform == 'darwin': + socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE')) + socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE')) diff --git a/libraries/cheroot/makefile.py b/libraries/cheroot/makefile.py new file mode 100644 index 00000000..a76f2eda --- /dev/null +++ b/libraries/cheroot/makefile.py @@ -0,0 +1,387 @@ +"""Socket file object.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket + +try: + # prefer slower Python-based io module + import _pyio as io +except ImportError: + # Python 2.6 + import io + +import six + +from . import errors + + +class BufferedWriter(io.BufferedWriter): + """Faux file object attached to a socket object.""" + + def write(self, b): + """Write bytes to buffer.""" + self._checkClosed() + if isinstance(b, str): + raise TypeError("can't write str to binary stream") + + with self._write_lock: + self._write_buf.extend(b) + self._flush_unlocked() + return len(b) + + def _flush_unlocked(self): + self._checkClosed('flush of closed file') + while self._write_buf: + try: + # ssl sockets only except 'bytes', not bytearrays + # so perhaps we should conditionally wrap this for perf? 
+ n = self.raw.write(bytes(self._write_buf)) + except io.BlockingIOError as e: + n = e.characters_written + del self._write_buf[:n] + + +def MakeFile_PY3(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): + """File object attached to a socket object.""" + if 'r' in mode: + return io.BufferedReader(socket.SocketIO(sock, mode), bufsize) + else: + return BufferedWriter(socket.SocketIO(sock, mode), bufsize) + + +class MakeFile_PY2(getattr(socket, '_fileobject', object)): + """Faux file object attached to a socket object.""" + + def __init__(self, *args, **kwargs): + """Initialize faux file object.""" + self.bytes_read = 0 + self.bytes_written = 0 + socket._fileobject.__init__(self, *args, **kwargs) + + def write(self, data): + """Sendall for non-blocking sockets.""" + while data: + try: + bytes_sent = self.send(data) + data = data[bytes_sent:] + except socket.error as e: + if e.args[0] not in errors.socket_errors_nonblocking: + raise + + def send(self, data): + """Send some part of message to the socket.""" + bytes_sent = self._sock.send(data) + self.bytes_written += bytes_sent + return bytes_sent + + def flush(self): + """Write all data from buffer to socket and reset write buffer.""" + if self._wbuf: + buffer = ''.join(self._wbuf) + self._wbuf = [] + self.write(buffer) + + def recv(self, size): + """Receive message of a size from the socket.""" + while True: + try: + data = self._sock.recv(size) + self.bytes_read += len(data) + return data + except socket.error as e: + what = ( + e.args[0] not in errors.socket_errors_nonblocking + and e.args[0] not in errors.socket_error_eintr + ) + if what: + raise + + class FauxSocket: + """Faux socket with the minimal interface required by pypy.""" + + def _reuse(self): + pass + + _fileobject_uses_str_type = six.PY2 and isinstance( + socket._fileobject(FauxSocket())._rbuf, six.string_types) + + # FauxSocket is no longer needed + del FauxSocket + + if not _fileobject_uses_str_type: + def read(self, size=-1): + """Read data from the socket to buffer.""" + # Use max, disallow tiny reads in a loop as they are very + # inefficient. + # We never leave read() with any leftover data from a new recv() + # call in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned + # by recv() minimizes memory usage and fragmentation that occurs + # when rbufsize is large compared to the typical return value of + # recv(). + buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + data = self.recv(rbufsize) + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and + # return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return rv + + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. + data = self.recv(left) + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. 
Avoid buffer data copies when: + # - We have no data in our buffer. + # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. + return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, 'recv(%d) returned %d bytes' % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + # assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + """Read line from the socket to buffer.""" + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + data = None + recv = self.recv + while data != '\n': + data = recv(1) + if not data: + break + buffers.append(data) + return ''.join(buffers) + + buf.seek(0, 2) # seek end + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + data = self.recv(self._rbufsize) + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or \n or EOF seen, whichever comes + # first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = io.BytesIO() + self._rbuf.write(buf.read()) + return rv + # reset _rbuf. we consume it via buf. + self._rbuf = io.BytesIO() + while True: + data = self.recv(self._rbufsize) + if not data: + break + left = size - buf_len + # did we just receive a newline? + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. Avoid data copy through buf when + # returning a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). 
+ return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + # assert buf_len == buf.tell() + return buf.getvalue() + else: + def read(self, size=-1): + """Read data from the socket to buffer.""" + if size < 0: + # Read until EOF + buffers = [self._rbuf] + self._rbuf = '' + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + + while True: + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + return ''.join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + data = self._rbuf + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return ''.join(buffers) + + def readline(self, size=-1): + """Read line from the socket to buffer.""" + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == '' + buffers = [] + while data != '\n': + data = self.recv(1) + if not data: + break + buffers.append(data) + return ''.join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return ''.join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes + # first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = '' + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return ''.join(buffers) + + +MakeFile = MakeFile_PY2 if six.PY2 else MakeFile_PY3 diff --git a/libraries/cheroot/server.py b/libraries/cheroot/server.py new file mode 100644 index 00000000..44070490 --- /dev/null +++ b/libraries/cheroot/server.py @@ -0,0 +1,2001 @@ +""" +A high-speed, production ready, thread pooled, generic HTTP server. + +For those of you wanting to understand internals of this module, here's the +basic call flow. The server's listening thread runs a very tight loop, +sticking incoming connections onto a Queue:: + + server = HTTPServer(...) + server.start() + -> while True: + tick() + # This blocks until a request comes in: + child = socket.accept() + conn = HTTPConnection(child, ...) + server.requests.put(conn) + +Worker threads are kept in a pool and poll the Queue, popping off and then +handling each connection in turn. 
Each connection can consist of an arbitrary +number of requests and their responses, so we run a nested loop:: + + while True: + conn = server.requests.get() + conn.communicate() + -> while True: + req = HTTPRequest(...) + req.parse_request() + -> # Read the Request-Line, e.g. "GET /page HTTP/1.1" + req.rfile.readline() + read_headers(req.rfile, req.inheaders) + req.respond() + -> response = app(...) + try: + for chunk in response: + if chunk: + req.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + if req.close_connection: + return + +For running a server you can invoke :func:`start() <HTTPServer.start()>` (it +will run the server forever) or use invoking :func:`prepare() +<HTTPServer.prepare()>` and :func:`serve() <HTTPServer.serve()>` like this:: + + server = HTTPServer(...) + server.prepare() + try: + threading.Thread(target=server.serve).start() + + # waiting/detecting some appropriate stop condition here + ... + + finally: + server.stop() + +And now for a trivial doctest to exercise the test suite + +>>> 'HTTPServer' in globals() +True + +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import os +import io +import re +import email.utils +import socket +import sys +import time +import traceback as traceback_ +import logging +import platform +import xbmc + +try: + from functools import lru_cache +except ImportError: + from backports.functools_lru_cache import lru_cache + +import six +from six.moves import queue +from six.moves import urllib + +from . import errors, __version__ +from ._compat import bton, ntou +from .workers import threadpool +from .makefile import MakeFile + + +__all__ = ('HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', + 'Gateway', 'get_ssl_adapter_class') + +""" +Special KODI case: +Android does not have support for grp and pwd +But Python has issues reporting that this is running on Android (it shows as Linux2). +We're instead using xbmc library to detect that. +""" +IS_WINDOWS = platform.system() == 'Windows' +IS_ANDROID = xbmc.getCondVisibility('system.platform.linux') and xbmc.getCondVisibility('system.platform.android') + +if not (IS_WINDOWS or IS_ANDROID): + import grp + import pwd + import struct + + +if IS_WINDOWS and hasattr(socket, 'AF_INET6'): + if not hasattr(socket, 'IPPROTO_IPV6'): + socket.IPPROTO_IPV6 = 41 + if not hasattr(socket, 'IPV6_V6ONLY'): + socket.IPV6_V6ONLY = 27 + + +if not hasattr(socket, 'SO_PEERCRED'): + """ + NOTE: the value for SO_PEERCRED can be architecture specific, in + which case the getsockopt() will hopefully fail. The arch + specific value could be derived from platform.processor() + """ + socket.SO_PEERCRED = 17 + + +LF = b'\n' +CRLF = b'\r\n' +TAB = b'\t' +SPACE = b' ' +COLON = b':' +SEMICOLON = b';' +EMPTY = b'' +ASTERISK = b'*' +FORWARD_SLASH = b'/' +QUOTED_SLASH = b'%2F' +QUOTED_SLASH_REGEX = re.compile(b'(?i)' + QUOTED_SLASH) + + +comma_separated_headers = [ + b'Accept', b'Accept-Charset', b'Accept-Encoding', + b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control', + b'Connection', b'Content-Encoding', b'Content-Language', b'Expect', + b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE', + b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning', + b'WWW-Authenticate', +] + + +if not hasattr(logging, 'statistics'): + logging.statistics = {} + + +class HeaderReader: + """Object for reading headers from an HTTP request. 
+ + Interface and default implementation. + """ + + def __call__(self, rfile, hdict=None): + """ + Read headers from the given stream into the given header dict. + + If hdict is None, a new header dict is created. Returns the populated + header dict. + + Headers which are repeated are folded together using a comma if their + specification so dictates. + + This function raises ValueError when the read bytes violate the HTTP + spec. + You should probably return "400 Bad Request" if this happens. + """ + if hdict is None: + hdict = {} + + while True: + line = rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError('Illegal end of headers.') + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError('HTTP requires CRLF terminators') + + if line[0] in (SPACE, TAB): + # It's a continuation line. + v = line.strip() + else: + try: + k, v = line.split(COLON, 1) + except ValueError: + raise ValueError('Illegal header line.') + v = v.strip() + k = self._transform_key(k) + hname = k + + if not self._allow_header(k): + continue + + if k in comma_separated_headers: + existing = hdict.get(hname) + if existing: + v = b', '.join((existing, v)) + hdict[hname] = v + + return hdict + + def _allow_header(self, key_name): + return True + + def _transform_key(self, key_name): + # TODO: what about TE and WWW-Authenticate? + return key_name.strip().title() + + +class DropUnderscoreHeaderReader(HeaderReader): + """Custom HeaderReader to exclude any headers with underscores in them.""" + + def _allow_header(self, key_name): + orig = super(DropUnderscoreHeaderReader, self)._allow_header(key_name) + return orig and '_' not in key_name + + +class SizeCheckWrapper: + """Wraps a file-like object, raising MaxSizeExceeded if too large.""" + + def __init__(self, rfile, maxlen): + """Initialize SizeCheckWrapper instance. + + Args: + rfile (file): file of a limited size + maxlen (int): maximum length of the file being read + """ + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + + def _check_length(self): + if self.maxlen and self.bytes_read > self.maxlen: + raise errors.MaxSizeExceeded() + + def read(self, size=None): + """Read a chunk from rfile buffer and return it. + + Args: + size (int): amount of data to read + + Returns: + bytes: Chunk from rfile, limited by size if specified. + + """ + data = self.rfile.read(size) + self.bytes_read += len(data) + self._check_length() + return data + + def readline(self, size=None): + """Read a single line from rfile buffer and return it. + + Args: + size (int): minimum amount of data to read + + Returns: + bytes: One line from rfile. + + """ + if size is not None: + data = self.rfile.readline(size) + self.bytes_read += len(data) + self._check_length() + return data + + # User didn't specify a size ... + # We read the line in chunks to make sure it's not a 100MB line ! + res = [] + while True: + data = self.rfile.readline(256) + self.bytes_read += len(data) + self._check_length() + res.append(data) + # See https://github.com/cherrypy/cherrypy/issues/421 + if len(data) < 256 or data[-1:] == LF: + return EMPTY.join(res) + + def readlines(self, sizehint=0): + """Read all lines from rfile buffer and return them. + + Args: + sizehint (int): hint of minimum amount of data to read + + Returns: + list[bytes]: Lines of bytes read from rfile. 
+ + """ + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + """Release resources allocated for rfile.""" + self.rfile.close() + + def __iter__(self): + """Return file iterator.""" + return self + + def __next__(self): + """Generate next file chunk.""" + data = next(self.rfile) + self.bytes_read += len(data) + self._check_length() + return data + + next = __next__ + + +class KnownLengthRFile: + """Wraps a file-like object, returning an empty string when exhausted.""" + + def __init__(self, rfile, content_length): + """Initialize KnownLengthRFile instance. + + Args: + rfile (file): file of a known size + content_length (int): length of the file being read + + """ + self.rfile = rfile + self.remaining = content_length + + def read(self, size=None): + """Read a chunk from rfile buffer and return it. + + Args: + size (int): amount of data to read + + Returns: + bytes: Chunk from rfile, limited by size if specified. + + """ + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.read(size) + self.remaining -= len(data) + return data + + def readline(self, size=None): + """Read a single line from rfile buffer and return it. + + Args: + size (int): minimum amount of data to read + + Returns: + bytes: One line from rfile. + + """ + if self.remaining == 0: + return b'' + if size is None: + size = self.remaining + else: + size = min(size, self.remaining) + + data = self.rfile.readline(size) + self.remaining -= len(data) + return data + + def readlines(self, sizehint=0): + """Read all lines from rfile buffer and return them. + + Args: + sizehint (int): hint of minimum amount of data to read + + Returns: + list[bytes]: Lines of bytes read from rfile. + + """ + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def close(self): + """Release resources allocated for rfile.""" + self.rfile.close() + + def __iter__(self): + """Return file iterator.""" + return self + + def __next__(self): + """Generate next file chunk.""" + data = next(self.rfile) + self.remaining -= len(data) + return data + + next = __next__ + + +class ChunkedRFile: + """Wraps a file-like object, returning an empty string when exhausted. + + This class is intended to provide a conforming wsgi.input value for + request entities that have been encoded with the 'chunked' transfer + encoding. + """ + + def __init__(self, rfile, maxlen, bufsize=8192): + """Initialize ChunkedRFile instance. 
+ + Args: + rfile (file): file encoded with the 'chunked' transfer encoding + maxlen (int): maximum length of the file being read + bufsize (int): size of the buffer used to read the file + """ + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + self.buffer = EMPTY + self.bufsize = bufsize + self.closed = False + + def _fetch(self): + if self.closed: + return + + line = self.rfile.readline() + self.bytes_read += len(line) + + if self.maxlen and self.bytes_read > self.maxlen: + raise errors.MaxSizeExceeded( + 'Request Entity Too Large', self.maxlen) + + line = line.strip().split(SEMICOLON, 1) + + try: + chunk_size = line.pop(0) + chunk_size = int(chunk_size, 16) + except ValueError: + raise ValueError('Bad chunked transfer size: ' + repr(chunk_size)) + + if chunk_size <= 0: + self.closed = True + return + +# if line: chunk_extension = line[0] + + if self.maxlen and self.bytes_read + chunk_size > self.maxlen: + raise IOError('Request Entity Too Large') + + chunk = self.rfile.read(chunk_size) + self.bytes_read += len(chunk) + self.buffer += chunk + + crlf = self.rfile.read(2) + if crlf != CRLF: + raise ValueError( + "Bad chunked transfer coding (expected '\\r\\n', " + 'got ' + repr(crlf) + ')') + + def read(self, size=None): + """Read a chunk from rfile buffer and return it. + + Args: + size (int): amount of data to read + + Returns: + bytes: Chunk from rfile, limited by size if specified. + + """ + data = EMPTY + + if size == 0: + return data + + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + if size: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + data += self.buffer + self.buffer = EMPTY + + def readline(self, size=None): + """Read a single line from rfile buffer and return it. + + Args: + size (int): minimum amount of data to read + + Returns: + bytes: One line from rfile. + + """ + data = EMPTY + + if size == 0: + return data + + while True: + if size and len(data) >= size: + return data + + if not self.buffer: + self._fetch() + if not self.buffer: + # EOF + return data + + newline_pos = self.buffer.find(LF) + if size: + if newline_pos == -1: + remaining = size - len(data) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + remaining = min(size - len(data), newline_pos) + data += self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + else: + if newline_pos == -1: + data += self.buffer + self.buffer = EMPTY + else: + data += self.buffer[:newline_pos] + self.buffer = self.buffer[newline_pos:] + + def readlines(self, sizehint=0): + """Read all lines from rfile buffer and return them. + + Args: + sizehint (int): hint of minimum amount of data to read + + Returns: + list[bytes]: Lines of bytes read from rfile. + + """ + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline(sizehint) + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline(sizehint) + return lines + + def read_trailer_lines(self): + """Read HTTP headers and yield them. + + Returns: + Generator: yields CRLF separated lines. 
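+
+ Illustrative note (an assumption, not upstream text): trailers appear only
+ after the terminating zero-size chunk, e.g. a line such as
+ b'X-Checksum: 1a2b\r\n' followed by a bare CRLF; each raw header line is
+ yielded and the blank CRLF line ends the iteration.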
+ + """ + if not self.closed: + raise ValueError( + 'Cannot read trailers until the request body has been read.') + + while True: + line = self.rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError('Illegal end of headers.') + + self.bytes_read += len(line) + if self.maxlen and self.bytes_read > self.maxlen: + raise IOError('Request Entity Too Large') + + if line == CRLF: + # Normal end of headers + break + if not line.endswith(CRLF): + raise ValueError('HTTP requires CRLF terminators') + + yield line + + def close(self): + """Release resources allocated for rfile.""" + self.rfile.close() + + +class HTTPRequest: + """An HTTP Request (and response). + + A single HTTP connection may consist of multiple request/response pairs. + """ + + server = None + """The HTTPServer object which is receiving this request.""" + + conn = None + """The HTTPConnection object on which this request connected.""" + + inheaders = {} + """A dict of request headers.""" + + outheaders = [] + """A list of header tuples to write in the response.""" + + ready = False + """When True, the request has been parsed and is ready to begin generating + the response. When False, signals the calling Connection that the response + should not be generated and the connection should close.""" + + close_connection = False + """Signals the calling Connection that the request should close. This does + not imply an error! The client and/or server may each request that the + connection be closed.""" + + chunked_write = False + """If True, output will be encoded with the "chunked" transfer-coding. + + This value is set automatically inside send_headers.""" + + header_reader = HeaderReader() + """ + A HeaderReader instance or compatible reader. + """ + + def __init__(self, server, conn, proxy_mode=False, strict_mode=True): + """Initialize HTTP request container instance. + + Args: + server (HTTPServer): web server object receiving this request + conn (HTTPConnection): HTTP connection object for this request + proxy_mode (bool): whether this HTTPServer should behave as a PROXY + server for certain requests + strict_mode (bool): whether we should return a 400 Bad Request when + we encounter a request that a HTTP compliant client should not be + making + """ + self.server = server + self.conn = conn + + self.ready = False + self.started_request = False + self.scheme = b'http' + if self.server.ssl_adapter is not None: + self.scheme = b'https' + # Use the lowest-common protocol in case read_request_line errors. 
+ self.response_protocol = 'HTTP/1.0' + self.inheaders = {} + + self.status = '' + self.outheaders = [] + self.sent_headers = False + self.close_connection = self.__class__.close_connection + self.chunked_read = False + self.chunked_write = self.__class__.chunked_write + self.proxy_mode = proxy_mode + self.strict_mode = strict_mode + + def parse_request(self): + """Parse the next HTTP request start-line and message-headers.""" + self.rfile = SizeCheckWrapper(self.conn.rfile, + self.server.max_request_header_size) + try: + success = self.read_request_line() + except errors.MaxSizeExceeded: + self.simple_response( + '414 Request-URI Too Long', + 'The Request-URI sent with the request exceeds the maximum ' + 'allowed bytes.') + return + else: + if not success: + return + + try: + success = self.read_request_headers() + except errors.MaxSizeExceeded: + self.simple_response( + '413 Request Entity Too Large', + 'The headers sent with the request exceed the maximum ' + 'allowed bytes.') + return + else: + if not success: + return + + self.ready = True + + def read_request_line(self): + """Read and parse first line of the HTTP request. + + Returns: + bool: True if the request line is valid or False if it's malformed. + + """ + # HTTP/1.1 connections are persistent by default. If a client + # requests a page, then idles (leaves the connection open), + # then rfile.readline() will raise socket.error("timed out"). + # Note that it does this based on the value given to settimeout(), + # and doesn't need the client to request or acknowledge the close + # (although your TCP stack might suffer for it: cf Apache's history + # with FIN_WAIT_2). + request_line = self.rfile.readline() + + # Set started_request to True so communicate() knows to send 408 + # from here on out. + self.started_request = True + if not request_line: + return False + + if request_line == CRLF: + # RFC 2616 sec 4.1: "...if the server is reading the protocol + # stream at the beginning of a message and receives a CRLF + # first, it should ignore the CRLF." + # But only ignore one leading line! else we enable a DoS. + request_line = self.rfile.readline() + if not request_line: + return False + + if not request_line.endswith(CRLF): + self.simple_response( + '400 Bad Request', 'HTTP requires CRLF terminators') + return False + + try: + method, uri, req_protocol = request_line.strip().split(SPACE, 2) + if not req_protocol.startswith(b'HTTP/'): + self.simple_response( + '400 Bad Request', 'Malformed Request-Line: bad protocol' + ) + return False + rp = req_protocol[5:].split(b'.', 1) + rp = tuple(map(int, rp)) # Minor.Major must be threat as integers + if rp > (1, 1): + self.simple_response( + '505 HTTP Version Not Supported', 'Cannot fulfill request' + ) + return False + except (ValueError, IndexError): + self.simple_response('400 Bad Request', 'Malformed Request-Line') + return False + + self.uri = uri + self.method = method.upper() + + if self.strict_mode and method != self.method: + resp = ( + 'Malformed method name: According to RFC 2616 ' + '(section 5.1.1) and its successors ' + 'RFC 7230 (section 3.1.1) and RFC 7231 (section 4.1) ' + 'method names are case-sensitive and uppercase.' + ) + self.simple_response('400 Bad Request', resp) + return False + + try: + if six.PY2: # FIXME: Figure out better way to do this + # Ref: https://stackoverflow.com/a/196392/595220 (like this?) 
+ """This is a dummy check for unicode in URI.""" + ntou(bton(uri, 'ascii'), 'ascii') + scheme, authority, path, qs, fragment = urllib.parse.urlsplit(uri) + except UnicodeError: + self.simple_response('400 Bad Request', 'Malformed Request-URI') + return False + + if self.method == b'OPTIONS': + # TODO: cover this branch with tests + path = (uri + # https://tools.ietf.org/html/rfc7230#section-5.3.4 + if self.proxy_mode or uri == ASTERISK + else path) + elif self.method == b'CONNECT': + # TODO: cover this branch with tests + if not self.proxy_mode: + self.simple_response('405 Method Not Allowed') + return False + + # `urlsplit()` above parses "example.com:3128" as path part of URI. + # this is a workaround, which makes it detect netloc correctly + uri_split = urllib.parse.urlsplit(b'//' + uri) + _scheme, _authority, _path, _qs, _fragment = uri_split + _port = EMPTY + try: + _port = uri_split.port + except ValueError: + pass + + # FIXME: use third-party validation to make checks against RFC + # the validation doesn't take into account, that urllib parses + # invalid URIs without raising errors + # https://tools.ietf.org/html/rfc7230#section-5.3.3 + invalid_path = ( + _authority != uri + or not _port + or any((_scheme, _path, _qs, _fragment)) + ) + if invalid_path: + self.simple_response('400 Bad Request', + 'Invalid path in Request-URI: request-' + 'target must match authority-form.') + return False + + authority = path = _authority + scheme = qs = fragment = EMPTY + else: + uri_is_absolute_form = (scheme or authority) + + disallowed_absolute = ( + self.strict_mode + and not self.proxy_mode + and uri_is_absolute_form + ) + if disallowed_absolute: + # https://tools.ietf.org/html/rfc7230#section-5.3.2 + # (absolute form) + """Absolute URI is only allowed within proxies.""" + self.simple_response( + '400 Bad Request', + 'Absolute URI not allowed if server is not a proxy.', + ) + return False + + invalid_path = ( + self.strict_mode + and not uri.startswith(FORWARD_SLASH) + and not uri_is_absolute_form + ) + if invalid_path: + # https://tools.ietf.org/html/rfc7230#section-5.3.1 + # (origin_form) and + """Path should start with a forward slash.""" + resp = ( + 'Invalid path in Request-URI: request-target must contain ' + 'origin-form which starts with absolute-path (URI ' + 'starting with a slash "/").' + ) + self.simple_response('400 Bad Request', resp) + return False + + if fragment: + self.simple_response('400 Bad Request', + 'Illegal #fragment in Request-URI.') + return False + + if path is None: + # FIXME: It looks like this case cannot happen + self.simple_response('400 Bad Request', + 'Invalid path in Request-URI.') + return False + + # Unquote the path+params (e.g. "/this%20path" -> "/this path"). + # https://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + # + # But note that "...a URI must be separated into its components + # before the escaped characters within those components can be + # safely decoded." https://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 + # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not + # "/this/path". + try: + # TODO: Figure out whether exception can really happen here. + # It looks like it's caught on urlsplit() call above. 
+ atoms = [ + urllib.parse.unquote_to_bytes(x) + for x in QUOTED_SLASH_REGEX.split(path) + ] + except ValueError as ex: + self.simple_response('400 Bad Request', ex.args[0]) + return False + path = QUOTED_SLASH.join(atoms) + + if not path.startswith(FORWARD_SLASH): + path = FORWARD_SLASH + path + + if scheme is not EMPTY: + self.scheme = scheme + self.authority = authority + self.path = path + + # Note that, like wsgiref and most other HTTP servers, + # we "% HEX HEX"-unquote the path but not the query string. + self.qs = qs + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + sp = int(self.server.protocol[5]), int(self.server.protocol[7]) + + if sp[0] != rp[0]: + self.simple_response('505 HTTP Version Not Supported') + return False + + self.request_protocol = req_protocol + self.response_protocol = 'HTTP/%s.%s' % min(rp, sp) + + return True + + def read_request_headers(self): + """Read self.rfile into self.inheaders. Return success.""" + # then all the http headers + try: + self.header_reader(self.rfile, self.inheaders) + except ValueError as ex: + self.simple_response('400 Bad Request', ex.args[0]) + return False + + mrbs = self.server.max_request_body_size + + try: + cl = int(self.inheaders.get(b'Content-Length', 0)) + except ValueError: + self.simple_response( + '400 Bad Request', + 'Malformed Content-Length Header.') + return False + + if mrbs and cl > mrbs: + self.simple_response( + '413 Request Entity Too Large', + 'The entity sent with the request exceeds the maximum ' + 'allowed bytes.') + return False + + # Persistent connection support + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 + if self.inheaders.get(b'Connection', b'') == b'close': + self.close_connection = True + else: + # Either the server or client (or both) are HTTP/1.0 + if self.inheaders.get(b'Connection', b'') != b'Keep-Alive': + self.close_connection = True + + # Transfer-Encoding support + te = None + if self.response_protocol == 'HTTP/1.1': + te = self.inheaders.get(b'Transfer-Encoding') + if te: + te = [x.strip().lower() for x in te.split(b',') if x.strip()] + + self.chunked_read = False + + if te: + for enc in te: + if enc == b'chunked': + self.chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. + self.simple_response('501 Unimplemented') + self.close_connection = True + return False + + # From PEP 333: + # "Servers and gateways that implement HTTP 1.1 must provide + # transparent support for HTTP 1.1's "expect/continue" mechanism. + # This may be done in any of several ways: + # 1. Respond to requests containing an Expect: 100-continue request + # with an immediate "100 Continue" response, and proceed normally. + # 2. Proceed with the request normally, but provide the application + # with a wsgi.input stream that will send the "100 Continue" + # response if/when the application first attempts to read from + # the input stream. 
The read request must then remain blocked + # until the client responds. + # 3. Wait until the client decides that the server does not support + # expect/continue, and sends the request body on its own. + # (This is suboptimal, and is not recommended.) + # + # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, + # but it seems like it would be a big slowdown for such a rare case. + if self.inheaders.get(b'Expect', b'') == b'100-continue': + # Don't use simple_response here, because it emits headers + # we don't want. See + # https://github.com/cherrypy/cherrypy/issues/951 + msg = self.server.protocol.encode('ascii') + msg += b' 100 Continue\r\n\r\n' + try: + self.conn.wfile.write(msg) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + return True + + def respond(self): + """Call the gateway and write its iterable output.""" + mrbs = self.server.max_request_body_size + if self.chunked_read: + self.rfile = ChunkedRFile(self.conn.rfile, mrbs) + else: + cl = int(self.inheaders.get(b'Content-Length', 0)) + if mrbs and mrbs < cl: + if not self.sent_headers: + self.simple_response( + '413 Request Entity Too Large', + 'The entity sent with the request exceeds the ' + 'maximum allowed bytes.') + return + self.rfile = KnownLengthRFile(self.conn.rfile, cl) + + self.server.gateway(self).respond() + self.ready and self.ensure_headers_sent() + + if self.chunked_write: + self.conn.wfile.write(b'0\r\n\r\n') + + def simple_response(self, status, msg=''): + """Write a simple response back to the client.""" + status = str(status) + proto_status = '%s %s\r\n' % (self.server.protocol, status) + content_length = 'Content-Length: %s\r\n' % len(msg) + content_type = 'Content-Type: text/plain\r\n' + buf = [ + proto_status.encode('ISO-8859-1'), + content_length.encode('ISO-8859-1'), + content_type.encode('ISO-8859-1'), + ] + + if status[:3] in ('413', '414'): + # Request Entity Too Large / Request-URI Too Long + self.close_connection = True + if self.response_protocol == 'HTTP/1.1': + # This will not be true for 414, since read_request_line + # usually raises 414 before reading the whole line, and we + # therefore cannot know the proper response_protocol. + buf.append(b'Connection: close\r\n') + else: + # HTTP/1.0 had no 413/414 status nor Connection header. + # Emit 400 instead and trust the message body is enough. + status = '400 Bad Request' + + buf.append(CRLF) + if msg: + if isinstance(msg, six.text_type): + msg = msg.encode('ISO-8859-1') + buf.append(msg) + + try: + self.conn.wfile.write(EMPTY.join(buf)) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + + def ensure_headers_sent(self): + """Ensure headers are sent to the client if not already sent.""" + if not self.sent_headers: + self.sent_headers = True + self.send_headers() + + def write(self, chunk): + """Write unbuffered data to the client.""" + if self.chunked_write and chunk: + chunk_size_hex = hex(len(chunk))[2:].encode('ascii') + buf = [chunk_size_hex, CRLF, chunk, CRLF] + self.conn.wfile.write(EMPTY.join(buf)) + else: + self.conn.wfile.write(chunk) + + def send_headers(self): + """Assert, process, and send the HTTP response message-headers. + + You must set self.status, and self.outheaders before calling this. + """ + hkeys = [key.lower() for key, value in self.outheaders] + status = int(self.status[:3]) + + if status == 413: + # Request Entity Too Large. Close conn to avoid garbage. 
+ self.close_connection = True + elif b'content-length' not in hkeys: + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." So no point chunking. + if status < 200 or status in (204, 205, 304): + pass + else: + needs_chunked = ( + self.response_protocol == 'HTTP/1.1' + and self.method != b'HEAD' + ) + if needs_chunked: + # Use the chunked transfer-coding + self.chunked_write = True + self.outheaders.append((b'Transfer-Encoding', b'chunked')) + else: + # Closing the conn is the only way to determine len. + self.close_connection = True + + if b'connection' not in hkeys: + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 or better + if self.close_connection: + self.outheaders.append((b'Connection', b'close')) + else: + # Server and/or client are HTTP/1.0 + if not self.close_connection: + self.outheaders.append((b'Connection', b'Keep-Alive')) + + if (not self.close_connection) and (not self.chunked_read): + # Read any remaining request body data on the socket. + # "If an origin server receives a request that does not include an + # Expect request-header field with the "100-continue" expectation, + # the request includes a request body, and the server responds + # with a final status code before reading the entire request body + # from the transport connection, then the server SHOULD NOT close + # the transport connection until it has read the entire request, + # or until the client closes the connection. Otherwise, the client + # might not reliably receive the response message. However, this + # requirement is not be construed as preventing a server from + # defending itself against denial-of-service attacks, or from + # badly broken client implementations." + remaining = getattr(self.rfile, 'remaining', 0) + if remaining > 0: + self.rfile.read(remaining) + + if b'date' not in hkeys: + self.outheaders.append(( + b'Date', + email.utils.formatdate(usegmt=True).encode('ISO-8859-1'), + )) + + if b'server' not in hkeys: + self.outheaders.append(( + b'Server', + self.server.server_name.encode('ISO-8859-1'), + )) + + proto = self.server.protocol.encode('ascii') + buf = [proto + SPACE + self.status + CRLF] + for k, v in self.outheaders: + buf.append(k + COLON + SPACE + v + CRLF) + buf.append(CRLF) + self.conn.wfile.write(EMPTY.join(buf)) + + +class HTTPConnection: + """An HTTP connection (active socket).""" + + remote_addr = None + remote_port = None + ssl_env = None + rbufsize = io.DEFAULT_BUFFER_SIZE + wbufsize = io.DEFAULT_BUFFER_SIZE + RequestHandlerClass = HTTPRequest + peercreds_enabled = False + peercreds_resolve_enabled = False + + def __init__(self, server, sock, makefile=MakeFile): + """Initialize HTTPConnection instance. 
+ + Args: + server (HTTPServer): web server object receiving this request + socket (socket._socketobject): the raw socket object (usually + TCP) for this connection + makefile (file): a fileobject class for reading from the socket + """ + self.server = server + self.socket = sock + self.rfile = makefile(sock, 'rb', self.rbufsize) + self.wfile = makefile(sock, 'wb', self.wbufsize) + self.requests_seen = 0 + + self.peercreds_enabled = self.server.peercreds_enabled + self.peercreds_resolve_enabled = self.server.peercreds_resolve_enabled + + # LRU cached methods: + # Ref: https://stackoverflow.com/a/14946506/595220 + self.resolve_peer_creds = ( + lru_cache(maxsize=1)(self.resolve_peer_creds) + ) + self.get_peer_creds = ( + lru_cache(maxsize=1)(self.get_peer_creds) + ) + + def communicate(self): + """Read each request and respond appropriately.""" + request_seen = False + try: + while True: + # (re)set req to None so that if something goes wrong in + # the RequestHandlerClass constructor, the error doesn't + # get written to the previous request. + req = None + req = self.RequestHandlerClass(self.server, self) + + # This order of operations should guarantee correct pipelining. + req.parse_request() + if self.server.stats['Enabled']: + self.requests_seen += 1 + if not req.ready: + # Something went wrong in the parsing (and the server has + # probably already made a simple_response). Return and + # let the conn close. + return + + request_seen = True + req.respond() + if req.close_connection: + return + except socket.error as ex: + errnum = ex.args[0] + # sadly SSL sockets return a different (longer) time out string + timeout_errs = 'timed out', 'The read operation timed out' + if errnum in timeout_errs: + # Don't error if we're between requests; only error + # if 1) no request has been started at all, or 2) we're + # in the middle of a request. + # See https://github.com/cherrypy/cherrypy/issues/853 + if (not request_seen) or (req and req.started_request): + self._conditional_error(req, '408 Request Timeout') + elif errnum not in errors.socket_errors_to_ignore: + self.server.error_log('socket.error %s' % repr(errnum), + level=logging.WARNING, traceback=True) + self._conditional_error(req, '500 Internal Server Error') + except (KeyboardInterrupt, SystemExit): + raise + except errors.FatalSSLAlert: + pass + except errors.NoSSLError: + self._handle_no_ssl(req) + except Exception as ex: + self.server.error_log( + repr(ex), level=logging.ERROR, traceback=True) + self._conditional_error(req, '500 Internal Server Error') + + linger = False + + def _handle_no_ssl(self, req): + if not req or req.sent_headers: + return + # Unwrap wfile + self.wfile = MakeFile(self.socket._sock, 'wb', self.wbufsize) + msg = ( + 'The client sent a plain HTTP request, but ' + 'this server only speaks HTTPS on this port.' + ) + req.simple_response('400 Bad Request', msg) + self.linger = True + + def _conditional_error(self, req, response): + """Respond with an error. + + Don't bother writing if a response + has already started being written. 
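+
+        A fatal SSL alert raised while writing the error is swallowed, and
+        a plain-HTTP-request-on-HTTPS-port error is delegated to
+        ``_handle_no_ssl``, since the connection is being torn down anyway.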
+ """ + if not req or req.sent_headers: + return + + try: + req.simple_response(response) + except errors.FatalSSLAlert: + pass + except errors.NoSSLError: + self._handle_no_ssl(req) + + def close(self): + """Close the socket underlying this connection.""" + self.rfile.close() + + if not self.linger: + self._close_kernel_socket() + self.socket.close() + else: + # On the other hand, sometimes we want to hang around for a bit + # to make sure the client has a chance to read our entire + # response. Skipping the close() calls here delays the FIN + # packet until the socket object is garbage-collected later. + # Someday, perhaps, we'll do the full lingering_close that + # Apache does, but not today. + pass + + def get_peer_creds(self): # LRU cached on per-instance basis, see __init__ + """Return the PID/UID/GID tuple of the peer socket for UNIX sockets. + + This function uses SO_PEERCRED to query the UNIX PID, UID, GID + of the peer, which is only available if the bind address is + a UNIX domain socket. + + Raises: + NotImplementedError: in case of unsupported socket type + RuntimeError: in case of SO_PEERCRED lookup unsupported or disabled + + """ + PEERCRED_STRUCT_DEF = '3i' + + if IS_WINDOWS or self.socket.family != socket.AF_UNIX: + raise NotImplementedError( + 'SO_PEERCRED is only supported in Linux kernel and WSL' + ) + elif not self.peercreds_enabled: + raise RuntimeError( + 'Peer creds lookup is disabled within this server' + ) + + try: + peer_creds = self.socket.getsockopt( + socket.SOL_SOCKET, socket.SO_PEERCRED, + struct.calcsize(PEERCRED_STRUCT_DEF) + ) + except socket.error as socket_err: + """Non-Linux kernels don't support SO_PEERCRED. + + Refs: + http://welz.org.za/notes/on-peer-cred.html + https://github.com/daveti/tcpSockHack + msdn.microsoft.com/en-us/commandline/wsl/release_notes#build-15025 + """ + six.raise_from( # 3.6+: raise RuntimeError from socket_err + RuntimeError, + socket_err, + ) + else: + pid, uid, gid = struct.unpack(PEERCRED_STRUCT_DEF, peer_creds) + return pid, uid, gid + + @property + def peer_pid(self): + """Return the id of the connected peer process.""" + pid, _, _ = self.get_peer_creds() + return pid + + @property + def peer_uid(self): + """Return the user id of the connected peer process.""" + _, uid, _ = self.get_peer_creds() + return uid + + @property + def peer_gid(self): + """Return the group id of the connected peer process.""" + _, _, gid = self.get_peer_creds() + return gid + + def resolve_peer_creds(self): # LRU cached on per-instance basis + """Return the username and group tuple of the peercreds if available. + + Raises: + NotImplementedError: in case of unsupported OS + RuntimeError: in case of UID/GID lookup unsupported or disabled + + """ + if (IS_WINDOWS or IS_ANDROID): + raise NotImplementedError( + 'UID/GID lookup can only be done under UNIX-like OS' + ) + elif not self.peercreds_resolve_enabled: + raise RuntimeError( + 'UID/GID lookup is disabled within this server' + ) + + user = pwd.getpwuid(self.peer_uid).pw_name # [0] + group = grp.getgrgid(self.peer_gid).gr_name # [0] + + return user, group + + @property + def peer_user(self): + """Return the username of the connected peer process.""" + user, _ = self.resolve_peer_creds() + return user + + @property + def peer_group(self): + """Return the group of the connected peer process.""" + _, group = self.resolve_peer_creds() + return group + + def _close_kernel_socket(self): + """Close kernel socket in outdated Python versions. 
+ + On old Python versions, + Python's socket module does NOT call close on the kernel + socket when you call socket.close(). We do so manually here + because we want this server to send a FIN TCP segment + immediately. Note this must be called *before* calling + socket.close(), because the latter drops its reference to + the kernel socket. + """ + if six.PY2 and hasattr(self.socket, '_sock'): + self.socket._sock.close() + + +def prevent_socket_inheritance(sock): + """Stub inheritance prevention. + + Dummy function, since neither fcntl nor ctypes are available. + """ + pass + + +class HTTPServer: + """An HTTP server.""" + + _bind_addr = '127.0.0.1' + _interrupt = None + + gateway = None + """A Gateway instance.""" + + minthreads = None + """The minimum number of worker threads to create (default 10).""" + + maxthreads = None + """The maximum number of worker threads to create. + + (default -1 = no limit)""" + + server_name = None + """The name of the server; defaults to ``self.version``.""" + + protocol = 'HTTP/1.1' + """The version string to write in the Status-Line of all HTTP responses. + + For example, "HTTP/1.1" is the default. This also limits the supported + features used in the response.""" + + request_queue_size = 5 + """The 'backlog' arg to socket.listen(); max queued connections. + + (default 5).""" + + shutdown_timeout = 5 + """The total time to wait for worker threads to cleanly exit. + + Specified in seconds.""" + + timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + version = 'Cheroot/' + __version__ + """A version string for the HTTPServer.""" + + software = None + """The value to set for the SERVER_SOFTWARE entry in the WSGI environ. + + If None, this defaults to ``'%s Server' % self.version``. + """ + + ready = False + """Internal flag which indicating the socket is accepting connections.""" + + max_request_header_size = 0 + """The maximum size, in bytes, for request headers, or 0 for no limit.""" + + max_request_body_size = 0 + """The maximum size, in bytes, for request bodies, or 0 for no limit.""" + + nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" + + ConnectionClass = HTTPConnection + """The class to use for handling HTTP connections.""" + + ssl_adapter = None + """An instance of ssl.Adapter (or a subclass). + + You must have the corresponding SSL driver library installed. + """ + + peercreds_enabled = False + """If True, peer cred lookup can be performed via UNIX domain socket.""" + + peercreds_resolve_enabled = False + """If True, username/group will be looked up in the OS from peercreds.""" + + def __init__( + self, bind_addr, gateway, + minthreads=10, maxthreads=-1, server_name=None, + peercreds_enabled=False, peercreds_resolve_enabled=False, + ): + """Initialize HTTPServer instance. 
+ + Args: + bind_addr (tuple): network interface to listen to + gateway (Gateway): gateway for processing HTTP requests + minthreads (int): minimum number of threads for HTTP thread pool + maxthreads (int): maximum number of threads for HTTP thread pool + server_name (str): web server name to be advertised via Server + HTTP header + """ + self.bind_addr = bind_addr + self.gateway = gateway + + self.requests = threadpool.ThreadPool( + self, min=minthreads or 1, max=maxthreads) + + if not server_name: + server_name = self.version + self.server_name = server_name + self.peercreds_enabled = peercreds_enabled + self.peercreds_resolve_enabled = ( + peercreds_resolve_enabled and peercreds_enabled + ) + self.clear_stats() + + def clear_stats(self): + """Reset server stat counters..""" + self._start_time = None + self._run_time = 0 + self.stats = { + 'Enabled': False, + 'Bind Address': lambda s: repr(self.bind_addr), + 'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(), + 'Accepts': 0, + 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), + 'Queue': lambda s: getattr(self.requests, 'qsize', None), + 'Threads': lambda s: len(getattr(self.requests, '_threads', [])), + 'Threads Idle': lambda s: getattr(self.requests, 'idle', None), + 'Socket Errors': 0, + 'Requests': lambda s: (not s['Enabled']) and -1 or sum( + [w['Requests'](w) for w in s['Worker Threads'].values()], 0), + 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) for w in s['Worker Threads'].values()], 0), + 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) for w in s['Worker Threads'].values()], + 0), + 'Work Time': lambda s: (not s['Enabled']) and -1 or sum( + [w['Work Time'](w) for w in s['Worker Threads'].values()], 0), + 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( + [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values()], 0), + 'Worker Threads': {}, + } + logging.statistics['Cheroot HTTPServer %d' % id(self)] = self.stats + + def runtime(self): + """Return server uptime.""" + if self._start_time is None: + return self._run_time + else: + return self._run_time + (time.time() - self._start_time) + + def __str__(self): + """Render Server instance representing bind address.""" + return '%s.%s(%r)' % (self.__module__, self.__class__.__name__, + self.bind_addr) + + @property + def bind_addr(self): + """Return the interface on which to listen for connections. + + For TCP sockets, a (host, port) tuple. Host values may be any IPv4 + or IPv6 address, or any valid hostname. The string 'localhost' is a + synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). + The string '0.0.0.0' is a special IPv4 entry meaning "any active + interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for + IPv6. The empty string or None are not allowed. + + For UNIX sockets, supply the filename as a string. + + Systemd socket activation is automatic and doesn't require tempering + with this variable. + """ + return self._bind_addr + + @bind_addr.setter + def bind_addr(self, value): + """Set the interface on which to listen for connections.""" + if isinstance(value, tuple) and value[0] in ('', None): + # Despite the socket module docs, using '' does not + # allow AI_PASSIVE to work. Passing None instead + # returns '0.0.0.0' like we want. 
In other words: + # host AI_PASSIVE result + # '' Y 192.168.x.y + # '' N 192.168.x.y + # None Y 0.0.0.0 + # None N 127.0.0.1 + # But since you can get the same effect with an explicit + # '0.0.0.0', we deny both the empty string and None as values. + raise ValueError("Host values of '' or None are not allowed. " + "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " + 'to listen on all active interfaces.') + self._bind_addr = value + + def safe_start(self): + """Run the server forever, and stop it cleanly on exit.""" + try: + self.start() + except (KeyboardInterrupt, IOError): + # The time.sleep call might raise + # "IOError: [Errno 4] Interrupted function call" on KBInt. + self.error_log('Keyboard Interrupt: shutting down') + self.stop() + raise + except SystemExit: + self.error_log('SystemExit raised: shutting down') + self.stop() + raise + + def prepare(self): + """Prepare server to serving requests. + + It binds a socket's port, setups the socket to ``listen()`` and does + other preparing things. + """ + self._interrupt = None + + if self.software is None: + self.software = '%s Server' % self.version + + # Select the appropriate socket + self.socket = None + if os.getenv('LISTEN_PID', None): + # systemd socket activation + self.socket = socket.fromfd(3, socket.AF_INET, socket.SOCK_STREAM) + elif isinstance(self.bind_addr, six.string_types): + # AF_UNIX socket + + # So we can reuse the socket... + try: + os.unlink(self.bind_addr) + except Exception: + pass + + # So everyone can access the socket... + try: + os.chmod(self.bind_addr, 0o777) + except Exception: + pass + + info = [ + (socket.AF_UNIX, socket.SOCK_STREAM, 0, '', self.bind_addr)] + else: + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 + # addresses) + host, port = self.bind_addr + try: + info = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + except socket.gaierror: + sock_type = socket.AF_INET + bind_addr = self.bind_addr + + if ':' in host: + sock_type = socket.AF_INET6 + bind_addr = bind_addr + (0, 0) + + info = [(sock_type, socket.SOCK_STREAM, 0, '', bind_addr)] + + if not self.socket: + msg = 'No socket could be created' + for res in info: + af, socktype, proto, canonname, sa = res + try: + self.bind(af, socktype, proto) + break + except socket.error as serr: + msg = '%s -- (%s: %s)' % (msg, sa, serr) + if self.socket: + self.socket.close() + self.socket = None + + if not self.socket: + raise socket.error(msg) + + # Timeout so KeyboardInterrupt can be caught on Win32 + self.socket.settimeout(1) + self.socket.listen(self.request_queue_size) + + # Create worker threads + self.requests.start() + + self.ready = True + self._start_time = time.time() + + def serve(self): + """Serve requests, after invoking :func:`prepare()`.""" + while self.ready: + try: + self.tick() + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + self.error_log('Error in HTTPServer.tick', level=logging.ERROR, + traceback=True) + + if self.interrupt: + while self.interrupt is True: + # Wait for self.stop() to complete. See _set_interrupt. + time.sleep(0.1) + if self.interrupt: + raise self.interrupt + + def start(self): + """Run the server forever. + + It is shortcut for invoking :func:`prepare()` then :func:`serve()`. + """ + # We don't have to trap KeyboardInterrupt or SystemExit here, + # because cherrypy.server already does so, calling self.stop() for us. 
+ # If you're using this server with another framework, you should + # trap those exceptions in whatever code block calls start(). + self.prepare() + self.serve() + + def error_log(self, msg='', level=20, traceback=False): + """Write error message to log. + + Args: + msg (str): error message + level (int): logging level + traceback (bool): add traceback to output or not + """ + # Override this in subclasses as desired + sys.stderr.write(msg + '\n') + sys.stderr.flush() + if traceback: + tblines = traceback_.format_exc() + sys.stderr.write(tblines) + sys.stderr.flush() + + def bind(self, family, type, proto=0): + """Create (or recreate) the actual socket object.""" + self.socket = socket.socket(family, type, proto) + prevent_socket_inheritance(self.socket) + if not IS_WINDOWS: + # Windows has different semantics for SO_REUSEADDR, + # so don't set it. + # https://msdn.microsoft.com/en-us/library/ms740621(v=vs.85).aspx + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if self.nodelay and not isinstance(self.bind_addr, str): + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + if self.ssl_adapter is not None: + self.socket = self.ssl_adapter.bind(self.socket) + + host, port = self.bind_addr[:2] + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See + # https://github.com/cherrypy/cherrypy/issues/871. + listening_ipv6 = ( + hasattr(socket, 'AF_INET6') + and family == socket.AF_INET6 + and host in ('::', '::0', '::0.0.0.0') + ) + if listening_ipv6: + try: + self.socket.setsockopt( + socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass + + self.socket.bind(self.bind_addr) + # TODO: keep requested bind_addr separate real bound_addr (port is + # different in case of ephemeral port 0) + self.bind_addr = self.socket.getsockname() + if family in ( + # Windows doesn't have socket.AF_UNIX, so not using it in check + socket.AF_INET, + socket.AF_INET6, + ): + """UNIX domain sockets are strings or bytes. + + In case of bytes with a leading null-byte it's an abstract socket. 
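+
+            For AF_INET6, ``getsockname()`` may return a 4-tuple that also
+            carries flowinfo and scope-id, so the freshly bound address is
+            trimmed back to a plain (host, port) pair here.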
+ """ + self.bind_addr = self.bind_addr[:2] + + def tick(self): + """Accept a new connection and put it on the Queue.""" + try: + s, addr = self.socket.accept() + if self.stats['Enabled']: + self.stats['Accepts'] += 1 + if not self.ready: + return + + prevent_socket_inheritance(s) + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + mf = MakeFile + ssl_env = {} + # if ssl cert and key are set, we try to be a secure HTTP server + if self.ssl_adapter is not None: + try: + s, ssl_env = self.ssl_adapter.wrap(s) + except errors.NoSSLError: + msg = ('The client sent a plain HTTP request, but ' + 'this server only speaks HTTPS on this port.') + buf = ['%s 400 Bad Request\r\n' % self.protocol, + 'Content-Length: %s\r\n' % len(msg), + 'Content-Type: text/plain\r\n\r\n', + msg] + + sock_to_make = s if six.PY3 else s._sock + wfile = mf(sock_to_make, 'wb', io.DEFAULT_BUFFER_SIZE) + try: + wfile.write(''.join(buf).encode('ISO-8859-1')) + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + raise + return + if not s: + return + mf = self.ssl_adapter.makefile + # Re-apply our timeout since we may have a new socket object + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + conn = self.ConnectionClass(self, s, mf) + + if not isinstance(self.bind_addr, six.string_types): + # optional values + # Until we do DNS lookups, omit REMOTE_HOST + if addr is None: # sometimes this can happen + # figure out if AF_INET or AF_INET6. + if len(s.getsockname()) == 2: + # AF_INET + addr = ('0.0.0.0', 0) + else: + # AF_INET6 + addr = ('::', 0) + conn.remote_addr = addr[0] + conn.remote_port = addr[1] + + conn.ssl_env = ssl_env + + try: + self.requests.put(conn) + except queue.Full: + # Just drop the conn. TODO: write 503 back? + conn.close() + return + except socket.timeout: + # The only reason for the timeout in start() is so we can + # notice keyboard interrupts on Win32, which don't interrupt + # accept() by default + return + except socket.error as ex: + if self.stats['Enabled']: + self.stats['Socket Errors'] += 1 + if ex.args[0] in errors.socket_error_eintr: + # I *think* this is right. EINTR should occur when a signal + # is received during the accept() call; all docs say retry + # the call, and I *think* I'm reading it right that Python + # will then go ahead and poll for and handle the signal + # elsewhere. See + # https://github.com/cherrypy/cherrypy/issues/707. + return + if ex.args[0] in errors.socket_errors_nonblocking: + # Just try again. See + # https://github.com/cherrypy/cherrypy/issues/479. + return + if ex.args[0] in errors.socket_errors_to_ignore: + # Our socket was closed. + # See https://github.com/cherrypy/cherrypy/issues/686. + return + raise + + @property + def interrupt(self): + """Flag interrupt of the server.""" + return self._interrupt + + @interrupt.setter + def interrupt(self, interrupt): + """Perform the shutdown of this server and save the exception.""" + self._interrupt = True + self.stop() + self._interrupt = interrupt + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + self.ready = False + if self._start_time is not None: + self._run_time += (time.time() - self._start_time) + self._start_time = None + + sock = getattr(self, 'socket', None) + if sock: + if not isinstance(self.bind_addr, six.string_types): + # Touch our own socket to make accept() return immediately. 
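+                # The serve loop may be blocked in accept() (with the
+                # 1-second timeout set in prepare()); making a throwaway
+                # client connection to our own listening address below wakes
+                # it up, so tick() returns and serve() sees ready == False.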
+ try: + host, port = sock.getsockname()[:2] + except socket.error as ex: + if ex.args[0] not in errors.socket_errors_to_ignore: + # Changed to use error code and not message + # See + # https://github.com/cherrypy/cherrypy/issues/860. + raise + else: + # Note that we're explicitly NOT using AI_PASSIVE, + # here, because we want an actual IP to touch. + # localhost won't work if we've bound to a public IP, + # but it will if we bound to '0.0.0.0' (INADDR_ANY). + for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See + # https://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + if hasattr(sock, 'close'): + sock.close() + self.socket = None + + self.requests.stop(self.shutdown_timeout) + + +class Gateway: + """Base class to interface HTTPServer with other systems, such as WSGI.""" + + def __init__(self, req): + """Initialize Gateway instance with request. + + Args: + req (HTTPRequest): current HTTP request + """ + self.req = req + + def respond(self): + """Process the current request. Must be overridden in a subclass.""" + raise NotImplementedError + + +# These may either be ssl.Adapter subclasses or the string names +# of such classes (in which case they will be lazily loaded). +ssl_adapters = { + 'builtin': 'cheroot.ssl.builtin.BuiltinSSLAdapter', + 'pyopenssl': 'cheroot.ssl.pyopenssl.pyOpenSSLAdapter', +} + + +def get_ssl_adapter_class(name='builtin'): + """Return an SSL adapter class for the given name.""" + adapter = ssl_adapters[name.lower()] + if isinstance(adapter, six.string_types): + last_dot = adapter.rfind('.') + attr_name = adapter[last_dot + 1:] + mod_path = adapter[:last_dot] + + try: + mod = sys.modules[mod_path] + if mod is None: + raise KeyError() + except KeyError: + # The last [''] is important. + mod = __import__(mod_path, globals(), locals(), ['']) + + # Let an AttributeError propagate outward. + try: + adapter = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + return adapter diff --git a/libraries/cheroot/ssl/__init__.py b/libraries/cheroot/ssl/__init__.py new file mode 100644 index 00000000..ec1a0d90 --- /dev/null +++ b/libraries/cheroot/ssl/__init__.py @@ -0,0 +1,51 @@ +"""Implementation of the SSL adapter base interface.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from abc import ABCMeta, abstractmethod + +from six import add_metaclass + + +@add_metaclass(ABCMeta) +class Adapter: + """Base class for SSL driver library adapters. 
+ + Required methods: + + * ``wrap(sock) -> (wrapped socket, ssl environ dict)`` + * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> + socket file object`` + """ + + @abstractmethod + def __init__( + self, certificate, private_key, certificate_chain=None, + ciphers=None): + """Set up certificates, private key ciphers and reset context.""" + self.certificate = certificate + self.private_key = private_key + self.certificate_chain = certificate_chain + self.ciphers = ciphers + self.context = None + + @abstractmethod + def bind(self, sock): + """Wrap and return the given socket.""" + return sock + + @abstractmethod + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + raise NotImplementedError + + @abstractmethod + def get_environ(self): + """Return WSGI environ entries to be merged into each request.""" + raise NotImplementedError + + @abstractmethod + def makefile(self, sock, mode='r', bufsize=-1): + """Return socket file object.""" + raise NotImplementedError diff --git a/libraries/cheroot/ssl/builtin.py b/libraries/cheroot/ssl/builtin.py new file mode 100644 index 00000000..a19f7eef --- /dev/null +++ b/libraries/cheroot/ssl/builtin.py @@ -0,0 +1,162 @@ +""" +A library for integrating Python's builtin ``ssl`` library with Cheroot. + +The ssl module must be importable for SSL functionality. + +To use this module, set ``HTTPServer.ssl_adapter`` to an instance of +``BuiltinSSLAdapter``. +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +try: + import ssl +except ImportError: + ssl = None + +try: + from _pyio import DEFAULT_BUFFER_SIZE +except ImportError: + try: + from io import DEFAULT_BUFFER_SIZE + except ImportError: + DEFAULT_BUFFER_SIZE = -1 + +import six + +from . import Adapter +from .. import errors +from ..makefile import MakeFile + + +if six.PY3: + generic_socket_error = OSError +else: + import socket + generic_socket_error = socket.error + del socket + + +def _assert_ssl_exc_contains(exc, *msgs): + """Check whether SSL exception contains either of messages provided.""" + if len(msgs) < 1: + raise TypeError( + '_assert_ssl_exc_contains() requires ' + 'at least one message to be passed.' 
+ ) + err_msg_lower = exc.args[1].lower() + return any(m.lower() in err_msg_lower for m in msgs) + + +class BuiltinSSLAdapter(Adapter): + """A wrapper for integrating Python's builtin ssl module with Cheroot.""" + + certificate = None + """The filename of the server SSL certificate.""" + + private_key = None + """The filename of the server's private key file.""" + + certificate_chain = None + """The filename of the certificate chain file.""" + + context = None + """The ssl.SSLContext that will be used to wrap sockets.""" + + ciphers = None + """The ciphers list of SSL.""" + + def __init__( + self, certificate, private_key, certificate_chain=None, + ciphers=None): + """Set up context in addition to base class properties if available.""" + if ssl is None: + raise ImportError('You must install the ssl module to use HTTPS.') + + super(BuiltinSSLAdapter, self).__init__( + certificate, private_key, certificate_chain, ciphers) + + self.context = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH, + cafile=certificate_chain + ) + self.context.load_cert_chain(certificate, private_key) + if self.ciphers is not None: + self.context.set_ciphers(ciphers) + + def bind(self, sock): + """Wrap and return the given socket.""" + return super(BuiltinSSLAdapter, self).bind(sock) + + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + EMPTY_RESULT = None, {} + try: + s = self.context.wrap_socket( + sock, do_handshake_on_connect=True, server_side=True, + ) + except ssl.SSLError as ex: + if ex.errno == ssl.SSL_ERROR_EOF: + # This is almost certainly due to the cherrypy engine + # 'pinging' the socket to assert it's connectable; + # the 'ping' isn't SSL. + return EMPTY_RESULT + elif ex.errno == ssl.SSL_ERROR_SSL: + if _assert_ssl_exc_contains(ex, 'http request'): + # The client is speaking HTTP to an HTTPS server. + raise errors.NoSSLError + + # Check if it's one of the known errors + # Errors that are caught by PyOpenSSL, but thrown by + # built-in ssl + _block_errors = ( + 'unknown protocol', 'unknown ca', 'unknown_ca', + 'unknown error', + 'https proxy request', 'inappropriate fallback', + 'wrong version number', + 'no shared cipher', 'certificate unknown', + 'ccs received early', + ) + if _assert_ssl_exc_contains(ex, *_block_errors): + # Accepted error, let's pass + return EMPTY_RESULT + elif _assert_ssl_exc_contains(ex, 'handshake operation timed out'): + # This error is thrown by builtin SSL after a timeout + # when client is speaking HTTP to an HTTPS server. + # The connection can safely be dropped. + return EMPTY_RESULT + raise + except generic_socket_error as exc: + """It is unclear why exactly this happens. + + It's reproducible only under Python 2 with openssl>1.0 and stdlib + ``ssl`` wrapper, and only with CherryPy. + So it looks like some healthcheck tries to connect to this socket + during startup (from the same process). 
+ + + Ref: https://github.com/cherrypy/cherrypy/issues/1618 + """ + if six.PY2 and exc.args == (0, 'Error'): + return EMPTY_RESULT + raise + return s, self.get_environ(s) + + # TODO: fill this out more with mod ssl env + def get_environ(self, sock): + """Create WSGI environ entries to be merged into each request.""" + cipher = sock.cipher() + ssl_environ = { + 'wsgi.url_scheme': 'https', + 'HTTPS': 'on', + 'SSL_PROTOCOL': cipher[1], + 'SSL_CIPHER': cipher[0] + # SSL_VERSION_INTERFACE string The mod_ssl program version + # SSL_VERSION_LIBRARY string The OpenSSL program version + } + return ssl_environ + + def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + """Return socket file object.""" + return MakeFile(sock, mode, bufsize) diff --git a/libraries/cheroot/ssl/pyopenssl.py b/libraries/cheroot/ssl/pyopenssl.py new file mode 100644 index 00000000..2185f851 --- /dev/null +++ b/libraries/cheroot/ssl/pyopenssl.py @@ -0,0 +1,267 @@ +""" +A library for integrating pyOpenSSL with Cheroot. + +The OpenSSL module must be importable for SSL functionality. +You can obtain it from `here <https://launchpad.net/pyopenssl>`_. + +To use this module, set HTTPServer.ssl_adapter to an instance of +ssl.Adapter. There are two ways to use SSL: + +Method One +---------- + + * ``ssl_adapter.context``: an instance of SSL.Context. + +If this is not None, it is assumed to be an SSL.Context instance, +and will be passed to SSL.Connection on bind(). The developer is +responsible for forming a valid Context object. This approach is +to be preferred for more flexibility, e.g. if the cert and key are +streams instead of files, or need decryption, or SSL.SSLv3_METHOD +is desired instead of the default SSL.SSLv23_METHOD, etc. Consult +the pyOpenSSL documentation for complete options. + +Method Two (shortcut) +--------------------- + + * ``ssl_adapter.certificate``: the filename of the server SSL certificate. + * ``ssl_adapter.private_key``: the filename of the server's private key file. + +Both are None by default. If ssl_adapter.context is None, but .private_key +and .certificate are both given and valid, they will be read, and the +context will be automatically created from them. +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket +import threading +import time + +try: + from OpenSSL import SSL + from OpenSSL import crypto +except ImportError: + SSL = None + +from . import Adapter +from .. import errors, server as cheroot_server +from ..makefile import MakeFile + + +class SSL_fileobject(MakeFile): + """SSL file object attached to a socket object.""" + + ssl_timeout = 3 + ssl_retry = .01 + + def _safe_call(self, is_reader, call, *args, **kwargs): + """Wrap the given call with SSL error-trapping. + + is_reader: if False EOF errors will be raised. If True, EOF errors + will return "" (to emulate normal sockets). + """ + start = time.time() + while True: + try: + return call(*args, **kwargs) + except SSL.WantReadError: + # Sleep and try again. This is dangerous, because it means + # the rest of the stack has no way of differentiating + # between a "new handshake" error and "client dropped". + # Note this isn't an endless loop: there's a timeout below. 
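+                # Each retry sleeps for ssl_retry (10 ms by default); the
+                # surrounding loop raises socket.timeout once ssl_timeout
+                # (3 seconds by default) has elapsed since the first attempt.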
+ time.sleep(self.ssl_retry) + except SSL.WantWriteError: + time.sleep(self.ssl_retry) + except SSL.SysCallError as e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return '' + + errnum = e.args[0] + if is_reader and errnum in errors.socket_errors_to_ignore: + return '' + raise socket.error(errnum) + except SSL.Error as e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return '' + + thirdarg = None + try: + thirdarg = e.args[0][0][2] + except IndexError: + pass + + if thirdarg == 'http request': + # The client is talking HTTP to an HTTPS server. + raise errors.NoSSLError() + + raise errors.FatalSSLAlert(*e.args) + + if time.time() - start > self.ssl_timeout: + raise socket.timeout('timed out') + + def recv(self, size): + """Receive message of a size from the socket.""" + return self._safe_call(True, super(SSL_fileobject, self).recv, size) + + def sendall(self, *args, **kwargs): + """Send whole message to the socket.""" + return self._safe_call(False, super(SSL_fileobject, self).sendall, + *args, **kwargs) + + def send(self, *args, **kwargs): + """Send some part of message to the socket.""" + return self._safe_call(False, super(SSL_fileobject, self).send, + *args, **kwargs) + + +class SSLConnection: + """A thread-safe wrapper for an SSL.Connection. + + ``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``. + """ + + def __init__(self, *args): + """Initialize SSLConnection instance.""" + self._ssl_conn = SSL.Connection(*args) + self._lock = threading.RLock() + + for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read', + 'renegotiate', 'bind', 'listen', 'connect', 'accept', + 'setblocking', 'fileno', 'close', 'get_cipher_list', + 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', + 'makefile', 'get_app_data', 'set_app_data', 'state_string', + 'sock_shutdown', 'get_peer_certificate', 'want_read', + 'want_write', 'set_connect_state', 'set_accept_state', + 'connect_ex', 'sendall', 'settimeout', 'gettimeout'): + exec("""def %s(self, *args): + self._lock.acquire() + try: + return self._ssl_conn.%s(*args) + finally: + self._lock.release() +""" % (f, f)) + + def shutdown(self, *args): + """Shutdown the SSL connection. + + Ignore all incoming args since pyOpenSSL.socket.shutdown takes no args. + """ + self._lock.acquire() + try: + return self._ssl_conn.shutdown() + finally: + self._lock.release() + + +class pyOpenSSLAdapter(Adapter): + """A wrapper for integrating pyOpenSSL with Cheroot.""" + + certificate = None + """The filename of the server SSL certificate.""" + + private_key = None + """The filename of the server's private key file.""" + + certificate_chain = None + """Optional. The filename of CA's intermediate certificate bundle. 
+ + This is needed for cheaper "chained root" SSL certificates, and should be + left as None if not required.""" + + context = None + """An instance of SSL.Context.""" + + ciphers = None + """The ciphers list of SSL.""" + + def __init__( + self, certificate, private_key, certificate_chain=None, + ciphers=None): + """Initialize OpenSSL Adapter instance.""" + if SSL is None: + raise ImportError('You must install pyOpenSSL to use HTTPS.') + + super(pyOpenSSLAdapter, self).__init__( + certificate, private_key, certificate_chain, ciphers) + + self._environ = None + + def bind(self, sock): + """Wrap and return the given socket.""" + if self.context is None: + self.context = self.get_context() + conn = SSLConnection(self.context, sock) + self._environ = self.get_environ() + return conn + + def wrap(self, sock): + """Wrap and return the given socket, plus WSGI environ entries.""" + return sock, self._environ.copy() + + def get_context(self): + """Return an SSL.Context from self attributes.""" + # See https://code.activestate.com/recipes/442473/ + c = SSL.Context(SSL.SSLv23_METHOD) + c.use_privatekey_file(self.private_key) + if self.certificate_chain: + c.load_verify_locations(self.certificate_chain) + c.use_certificate_file(self.certificate) + return c + + def get_environ(self): + """Return WSGI environ entries to be merged into each request.""" + ssl_environ = { + 'HTTPS': 'on', + # pyOpenSSL doesn't provide access to any of these AFAICT + # 'SSL_PROTOCOL': 'SSLv2', + # SSL_CIPHER string The cipher specification name + # SSL_VERSION_INTERFACE string The mod_ssl program version + # SSL_VERSION_LIBRARY string The OpenSSL program version + } + + if self.certificate: + # Server certificate attributes + cert = open(self.certificate, 'rb').read() + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + ssl_environ.update({ + 'SSL_SERVER_M_VERSION': cert.get_version(), + 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), + # 'SSL_SERVER_V_START': + # Validity of server's certificate (start time), + # 'SSL_SERVER_V_END': + # Validity of server's certificate (end time), + }) + + for prefix, dn in [('I', cert.get_issuer()), + ('S', cert.get_subject())]: + # X509Name objects don't seem to have a way to get the + # complete DN string. Use str() and slice it instead, + # because str(dn) == "<X509Name object '/C=US/ST=...'>" + dnstr = str(dn)[18:-2] + + wsgikey = 'SSL_SERVER_%s_DN' % prefix + ssl_environ[wsgikey] = dnstr + + # The DN should be of the form: /k1=v1/k2=v2, but we must allow + # for any value to contain slashes itself (in a URL). 
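+                # Parse from the right: split off the last '=' to get a
+                # value, then the last '/' to get its key, and repeat.
+                # E.g. a subject DN of '/C=US/O=Org/CN=example.com' yields
+                # SSL_SERVER_S_DN_C='US', SSL_SERVER_S_DN_O='Org' and
+                # SSL_SERVER_S_DN_CN='example.com'.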
+ while dnstr: + pos = dnstr.rfind('=') + dnstr, value = dnstr[:pos], dnstr[pos + 1:] + pos = dnstr.rfind('/') + dnstr, key = dnstr[:pos], dnstr[pos + 1:] + if key and value: + wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) + ssl_environ[wsgikey] = value + + return ssl_environ + + def makefile(self, sock, mode='r', bufsize=-1): + """Return socket file object.""" + if SSL and isinstance(sock, SSL.ConnectionType): + timeout = sock.gettimeout() + f = SSL_fileobject(sock, mode, bufsize) + f.ssl_timeout = timeout + return f + else: + return cheroot_server.CP_fileobject(sock, mode, bufsize) diff --git a/libraries/cheroot/test/__init__.py b/libraries/cheroot/test/__init__.py new file mode 100644 index 00000000..e2a7b348 --- /dev/null +++ b/libraries/cheroot/test/__init__.py @@ -0,0 +1 @@ +"""Cheroot test suite.""" diff --git a/libraries/cheroot/test/conftest.py b/libraries/cheroot/test/conftest.py new file mode 100644 index 00000000..9f5f9284 --- /dev/null +++ b/libraries/cheroot/test/conftest.py @@ -0,0 +1,27 @@ +"""Pytest configuration module. + +Contains fixtures, which are tightly bound to the Cheroot framework +itself, useless for end-users' app testing. +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pytest + +from ..testing import ( # noqa: F401 + native_server, wsgi_server, +) +from ..testing import get_server_client + + +@pytest.fixture # noqa: F811 +def wsgi_server_client(wsgi_server): + """Create a test client out of given WSGI server.""" + return get_server_client(wsgi_server) + + +@pytest.fixture # noqa: F811 +def native_server_client(native_server): + """Create a test client out of given HTTP server.""" + return get_server_client(native_server) diff --git a/libraries/cheroot/test/helper.py b/libraries/cheroot/test/helper.py new file mode 100644 index 00000000..38f40b26 --- /dev/null +++ b/libraries/cheroot/test/helper.py @@ -0,0 +1,169 @@ +"""A library of helper functions for the Cheroot test suite.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import datetime +import logging +import os +import sys +import time +import threading +import types + +from six.moves import http_client + +import six + +import cheroot.server +import cheroot.wsgi + +from cheroot.test import webtest + +log = logging.getLogger(__name__) +thisdir = os.path.abspath(os.path.dirname(__file__)) +serverpem = os.path.join(os.getcwd(), thisdir, 'test.pem') + + +config = { + 'bind_addr': ('127.0.0.1', 54583), + 'server': 'wsgi', + 'wsgi_app': None, +} + + +class CherootWebCase(webtest.WebCase): + """Helper class for a web app test suite.""" + + script_name = '' + scheme = 'http' + + available_servers = { + 'wsgi': cheroot.wsgi.Server, + 'native': cheroot.server.HTTPServer, + } + + @classmethod + def setup_class(cls): + """Create and run one HTTP server per class.""" + conf = config.copy() + conf.update(getattr(cls, 'config', {})) + + s_class = conf.pop('server', 'wsgi') + server_factory = cls.available_servers.get(s_class) + if server_factory is None: + raise RuntimeError('Unknown server in config: %s' % conf['server']) + cls.httpserver = server_factory(**conf) + + cls.HOST, cls.PORT = cls.httpserver.bind_addr + if cls.httpserver.ssl_adapter is None: + ssl = '' + cls.scheme = 'http' + else: + ssl = ' (ssl)' + cls.HTTP_CONN = http_client.HTTPSConnection + cls.scheme = 'https' + + v = sys.version.split()[0] + log.info('Python version used to run this test script: %s' % v) + log.info('Cheroot version: %s' % 
cheroot.__version__) + log.info('HTTP server version: %s%s' % (cls.httpserver.protocol, ssl)) + log.info('PID: %s' % os.getpid()) + + if hasattr(cls, 'setup_server'): + # Clear the wsgi server so that + # it can be updated with the new root + cls.setup_server() + cls.start() + + @classmethod + def teardown_class(cls): + """Cleanup HTTP server.""" + if hasattr(cls, 'setup_server'): + cls.stop() + + @classmethod + def start(cls): + """Load and start the HTTP server.""" + threading.Thread(target=cls.httpserver.safe_start).start() + while not cls.httpserver.ready: + time.sleep(0.1) + + @classmethod + def stop(cls): + """Terminate HTTP server.""" + cls.httpserver.stop() + td = getattr(cls, 'teardown', None) + if td: + td() + + date_tolerance = 2 + + def assertEqualDates(self, dt1, dt2, seconds=None): + """Assert abs(dt1 - dt2) is within Y seconds.""" + if seconds is None: + seconds = self.date_tolerance + + if dt1 > dt2: + diff = dt1 - dt2 + else: + diff = dt2 - dt1 + if not diff < datetime.timedelta(seconds=seconds): + raise AssertionError('%r and %r are not within %r seconds.' % + (dt1, dt2, seconds)) + + +class Request: + """HTTP request container.""" + + def __init__(self, environ): + """Initialize HTTP request.""" + self.environ = environ + + +class Response: + """HTTP response container.""" + + def __init__(self): + """Initialize HTTP response.""" + self.status = '200 OK' + self.headers = {'Content-Type': 'text/html'} + self.body = None + + def output(self): + """Generate iterable response body object.""" + if self.body is None: + return [] + elif isinstance(self.body, six.text_type): + return [self.body.encode('iso-8859-1')] + elif isinstance(self.body, six.binary_type): + return [self.body] + else: + return [x.encode('iso-8859-1') for x in self.body] + + +class Controller: + """WSGI app for tests.""" + + def __call__(self, environ, start_response): + """WSGI request handler.""" + req, resp = Request(environ), Response() + try: + # Python 3 supports unicode attribute names + # Python 2 encodes them + handler = self.handlers[environ['PATH_INFO']] + except KeyError: + resp.status = '404 Not Found' + else: + output = handler(req, resp) + if (output is not None and + not any(resp.status.startswith(status_code) + for status_code in ('204', '304'))): + resp.body = output + try: + resp.headers.setdefault('Content-Length', str(len(output))) + except TypeError: + if not isinstance(output, types.GeneratorType): + raise + start_response(resp.status, resp.headers.items()) + return resp.output() diff --git a/libraries/cheroot/test/test.pem b/libraries/cheroot/test/test.pem new file mode 100644 index 00000000..47a47042 --- /dev/null +++ b/libraries/cheroot/test/test.pem @@ -0,0 +1,38 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQDBKo554mzIMY+AByUNpaUOP9bJnQ7ZLQe9XgHwoLJR4VzpyZZZ +R9L4WtImEew05FY3Izerfm3MN3+MC0tJ6yQU9sOiU3vBW6RrLIMlfKsnRwBRZ0Kn +da+O6xldVSosu8Ev3z9VZ94iC/ZgKzrH7Mjj/U8/MQO7RBS/LAqee8bFNQIDAQAB +AoGAWOCF0ZrWxn3XMucWq2LNwPKqlvVGwbIwX3cDmX22zmnM4Fy6arXbYh4XlyCj +9+ofqRrxIFz5k/7tFriTmZ0xag5+Jdx+Kwg0/twiP7XCNKipFogwe1Hznw8OFAoT +enKBdj2+/n2o0Bvo/tDB59m9L/538d46JGQUmJlzMyqYikECQQDyoq+8CtMNvE18 +8VgHcR/KtApxWAjj4HpaHYL637ATjThetUZkW92mgDgowyplthusxdNqhHWyv7E8 +tWNdYErZAkEAy85ShTR0M5aWmrE7o0r0SpWInAkNBH9aXQRRARFYsdBtNfRu6I0i +0lvU9wiu3eF57FMEC86yViZ5UBnQfTu7vQJAVesj/Zt7pwaCDfdMa740OsxMUlyR +MVhhGx4OLpYdPJ8qUecxGQKq13XZ7R1HGyNEY4bd2X80Smq08UFuATfC6QJAH8UB +yBHtKz2GLIcELOg6PIYizW/7v3+6rlVF60yw7sb2vzpjL40QqIn4IKoR2DSVtOkb +8FtAIX3N21aq0VrGYQJBAIPiaEc2AZ8Bq2GC4F3wOz/BxJ/izvnkiotR12QK4fh5 
+yjZMhTjWCas5zwHR5PDjlD88AWGDMsZ1PicD4348xJQ= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIDxTCCAy6gAwIBAgIJAI18BD7eQxlGMA0GCSqGSIb3DQEBBAUAMIGeMQswCQYD +VQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTESMBAGA1UEBxMJU2FuIERpZWdv +MRkwFwYDVQQKExBDaGVycnlQeSBQcm9qZWN0MREwDwYDVQQLEwhkZXYtdGVzdDEW +MBQGA1UEAxMNQ2hlcnJ5UHkgVGVhbTEgMB4GCSqGSIb3DQEJARYRcmVtaUBjaGVy +cnlweS5vcmcwHhcNMDYwOTA5MTkyMDIwWhcNMzQwMTI0MTkyMDIwWjCBnjELMAkG +A1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEjAQBgNVBAcTCVNhbiBEaWVn +bzEZMBcGA1UEChMQQ2hlcnJ5UHkgUHJvamVjdDERMA8GA1UECxMIZGV2LXRlc3Qx +FjAUBgNVBAMTDUNoZXJyeVB5IFRlYW0xIDAeBgkqhkiG9w0BCQEWEXJlbWlAY2hl +cnJ5cHkub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDBKo554mzIMY+A +ByUNpaUOP9bJnQ7ZLQe9XgHwoLJR4VzpyZZZR9L4WtImEew05FY3Izerfm3MN3+M +C0tJ6yQU9sOiU3vBW6RrLIMlfKsnRwBRZ0Knda+O6xldVSosu8Ev3z9VZ94iC/Zg +KzrH7Mjj/U8/MQO7RBS/LAqee8bFNQIDAQABo4IBBzCCAQMwHQYDVR0OBBYEFDIQ +2feb71tVZCWpU0qJ/Tw+wdtoMIHTBgNVHSMEgcswgciAFDIQ2feb71tVZCWpU0qJ +/Tw+wdtooYGkpIGhMIGeMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5p +YTESMBAGA1UEBxMJU2FuIERpZWdvMRkwFwYDVQQKExBDaGVycnlQeSBQcm9qZWN0 +MREwDwYDVQQLEwhkZXYtdGVzdDEWMBQGA1UEAxMNQ2hlcnJ5UHkgVGVhbTEgMB4G +CSqGSIb3DQEJARYRcmVtaUBjaGVycnlweS5vcmeCCQCNfAQ+3kMZRjAMBgNVHRME +BTADAQH/MA0GCSqGSIb3DQEBBAUAA4GBAL7AAQz7IePV48ZTAFHKr88ntPALsL5S +8vHCZPNMevNkLTj3DYUw2BcnENxMjm1kou2F2BkvheBPNZKIhc6z4hAml3ed1xa2 +D7w6e6OTcstdK/+KrPDDHeOP1dhMWNs2JE1bNlfF1LiXzYKSXpe88eCKjCXsCT/T +NluCaWQys3MS +-----END CERTIFICATE----- diff --git a/libraries/cheroot/test/test__compat.py b/libraries/cheroot/test/test__compat.py new file mode 100644 index 00000000..d34e5eb8 --- /dev/null +++ b/libraries/cheroot/test/test__compat.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +"""Test suite for cross-python compatibility helpers.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pytest +import six + +from cheroot._compat import ntob, ntou, bton + + +@pytest.mark.parametrize( + 'func,inp,out', + [ + (ntob, 'bar', b'bar'), + (ntou, 'bar', u'bar'), + (bton, b'bar', 'bar'), + ], +) +def test_compat_functions_positive(func, inp, out): + """Check that compat functions work with correct input.""" + assert func(inp, encoding='utf-8') == out + + +@pytest.mark.parametrize( + 'func', + [ + ntob, + ntou, + ], +) +def test_compat_functions_negative_nonnative(func): + """Check that compat functions fail loudly for incorrect input.""" + non_native_test_str = b'bar' if six.PY3 else u'bar' + with pytest.raises(TypeError): + func(non_native_test_str, encoding='utf-8') + + +@pytest.mark.skip(reason='This test does not work now') +@pytest.mark.skipif( + six.PY3, + reason='This code path only appears in Python 2 version.', +) +def test_ntou_escape(): + """Check that ntou supports escape-encoding under Python 2.""" + expected = u'' + actual = ntou('hi'.encode('ISO-8859-1'), encoding='escape') + assert actual == expected diff --git a/libraries/cheroot/test/test_conn.py b/libraries/cheroot/test/test_conn.py new file mode 100644 index 00000000..f543dd9b --- /dev/null +++ b/libraries/cheroot/test/test_conn.py @@ -0,0 +1,897 @@ +"""Tests for TCP connection handling, including proper and timely close.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import socket +import time + +from six.moves import range, http_client, urllib + +import six +import pytest + +from cheroot.test import helper, webtest + + +timeout = 1 +pov = 'pPeErRsSiIsStTeEnNcCeE oOfF vViIsSiIoOnN' + + +class Controller(helper.Controller): + 
"""Controller for serving WSGI apps.""" + + def hello(req, resp): + """Render Hello world.""" + return 'Hello, world!' + + def pov(req, resp): + """Render pov value.""" + return pov + + def stream(req, resp): + """Render streaming response.""" + if 'set_cl' in req.environ['QUERY_STRING']: + resp.headers['Content-Length'] = str(10) + + def content(): + for x in range(10): + yield str(x) + + return content() + + def upload(req, resp): + """Process file upload and render thank.""" + if not req.environ['REQUEST_METHOD'] == 'POST': + raise AssertionError("'POST' != request.method %r" % + req.environ['REQUEST_METHOD']) + return "thanks for '%s'" % req.environ['wsgi.input'].read() + + def custom_204(req, resp): + """Render response with status 204.""" + resp.status = '204' + return 'Code = 204' + + def custom_304(req, resp): + """Render response with status 304.""" + resp.status = '304' + return 'Code = 304' + + def err_before_read(req, resp): + """Render response with status 500.""" + resp.status = '500 Internal Server Error' + return 'ok' + + def one_megabyte_of_a(req, resp): + """Render 1MB response.""" + return ['a' * 1024] * 1024 + + def wrong_cl_buffered(req, resp): + """Render buffered response with invalid length value.""" + resp.headers['Content-Length'] = '5' + return 'I have too many bytes' + + def wrong_cl_unbuffered(req, resp): + """Render unbuffered response with invalid length value.""" + resp.headers['Content-Length'] = '5' + return ['I too', ' have too many bytes'] + + def _munge(string): + """Encode PATH_INFO correctly depending on Python version. + + WSGI 1.0 is a mess around unicode. Create endpoints + that match the PATH_INFO that it produces. + """ + if six.PY3: + return string.encode('utf-8').decode('latin-1') + return string + + handlers = { + '/hello': hello, + '/pov': pov, + '/page1': pov, + '/page2': pov, + '/page3': pov, + '/stream': stream, + '/upload': upload, + '/custom/204': custom_204, + '/custom/304': custom_304, + '/err_before_read': err_before_read, + '/one_megabyte_of_a': one_megabyte_of_a, + '/wrong_cl_buffered': wrong_cl_buffered, + '/wrong_cl_unbuffered': wrong_cl_unbuffered, + } + + +@pytest.fixture +def testing_server(wsgi_server_client): + """Attach a WSGI app to the given server and pre-configure it.""" + app = Controller() + + def _timeout(req, resp): + return str(wsgi_server.timeout) + app.handlers['/timeout'] = _timeout + wsgi_server = wsgi_server_client.server_instance + wsgi_server.wsgi_app = app + wsgi_server.max_request_body_size = 1001 + wsgi_server.timeout = timeout + wsgi_server.server_client = wsgi_server_client + return wsgi_server + + +@pytest.fixture +def test_client(testing_server): + """Get and return a test client out of the given server.""" + return testing_server.server_client + + +def header_exists(header_name, headers): + """Check that a header is present.""" + return header_name.lower() in (k.lower() for (k, _) in headers) + + +def header_has_value(header_name, header_value, headers): + """Check that a header with a given value is present.""" + return header_name.lower() in (k.lower() for (k, v) in headers + if v == header_value) + + +def test_HTTP11_persistent_connections(test_client): + """Test persistent HTTP/1.1 connections.""" + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Make the first request and assert there's no "Connection: close". 
+ status_line, actual_headers, actual_resp_body = test_client.get( + '/pov', http_conn=http_connection + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Make another request on the same connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/page1', http_conn=http_connection + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Test client-side close. + status_line, actual_headers, actual_resp_body = test_client.get( + '/page2', http_conn=http_connection, + headers=[('Connection', 'close')] + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert header_has_value('Connection', 'close', actual_headers) + + # Make another request on the same connection, which should error. + with pytest.raises(http_client.NotConnected): + test_client.get('/pov', http_conn=http_connection) + + +@pytest.mark.parametrize( + 'set_cl', + ( + False, # Without Content-Length + True, # With Content-Length + ) +) +def test_streaming_11(test_client, set_cl): + """Test serving of streaming responses with HTTP/1.1 protocol.""" + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Make the first request and assert there's no "Connection: close". + status_line, actual_headers, actual_resp_body = test_client.get( + '/pov', http_conn=http_connection + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Make another, streamed request on the same connection. + if set_cl: + # When a Content-Length is provided, the content should stream + # without closing the connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/stream?set_cl=Yes', http_conn=http_connection + ) + assert header_exists('Content-Length', actual_headers) + assert not header_has_value('Connection', 'close', actual_headers) + assert not header_exists('Transfer-Encoding', actual_headers) + + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'0123456789' + else: + # When no Content-Length response header is provided, + # streamed output will either close the connection, or use + # chunked encoding, to determine transfer-length. + status_line, actual_headers, actual_resp_body = test_client.get( + '/stream', http_conn=http_connection + ) + assert not header_exists('Content-Length', actual_headers) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'0123456789' + + chunked_response = False + for k, v in actual_headers: + if k.lower() == 'transfer-encoding': + if str(v) == 'chunked': + chunked_response = True + + if chunked_response: + assert not header_has_value('Connection', 'close', actual_headers) + else: + assert header_has_value('Connection', 'close', actual_headers) + + # Make another request on the same connection, which should + # error. + with pytest.raises(http_client.NotConnected): + test_client.get('/pov', http_conn=http_connection) + + # Try HEAD. 
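+    # A response to HEAD must not carry a message body, and send_headers()
+    # also skips chunked encoding for HEAD, hence the assertions below.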
+ # See https://www.bitbucket.org/cherrypy/cherrypy/issue/864. + # TODO: figure out how can this be possible on an closed connection + # (chunked_response case) + status_line, actual_headers, actual_resp_body = test_client.head( + '/stream', http_conn=http_connection + ) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'' + assert not header_exists('Transfer-Encoding', actual_headers) + + +@pytest.mark.parametrize( + 'set_cl', + ( + False, # Without Content-Length + True, # With Content-Length + ) +) +def test_streaming_10(test_client, set_cl): + """Test serving of streaming responses with HTTP/1.0 protocol.""" + original_server_protocol = test_client.server_instance.protocol + test_client.server_instance.protocol = 'HTTP/1.0' + + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Make the first request and assert Keep-Alive. + status_line, actual_headers, actual_resp_body = test_client.get( + '/pov', http_conn=http_connection, + headers=[('Connection', 'Keep-Alive')], + protocol='HTTP/1.0', + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert header_has_value('Connection', 'Keep-Alive', actual_headers) + + # Make another, streamed request on the same connection. + if set_cl: + # When a Content-Length is provided, the content should + # stream without closing the connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/stream?set_cl=Yes', http_conn=http_connection, + headers=[('Connection', 'Keep-Alive')], + protocol='HTTP/1.0', + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'0123456789' + + assert header_exists('Content-Length', actual_headers) + assert header_has_value('Connection', 'Keep-Alive', actual_headers) + assert not header_exists('Transfer-Encoding', actual_headers) + else: + # When a Content-Length is not provided, + # the server should close the connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/stream', http_conn=http_connection, + headers=[('Connection', 'Keep-Alive')], + protocol='HTTP/1.0', + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == b'0123456789' + + assert not header_exists('Content-Length', actual_headers) + assert not header_has_value('Connection', 'Keep-Alive', actual_headers) + assert not header_exists('Transfer-Encoding', actual_headers) + + # Make another request on the same connection, which should error. + with pytest.raises(http_client.NotConnected): + test_client.get( + '/pov', http_conn=http_connection, + protocol='HTTP/1.0', + ) + + test_client.server_instance.protocol = original_server_protocol + + +@pytest.mark.parametrize( + 'http_server_protocol', + ( + 'HTTP/1.0', + 'HTTP/1.1', + ) +) +def test_keepalive(test_client, http_server_protocol): + """Test Keep-Alive enabled connections.""" + original_server_protocol = test_client.server_instance.protocol + test_client.server_instance.protocol = http_server_protocol + + http_client_protocol = 'HTTP/1.0' + + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Test a normal HTTP/1.0 request. 
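+    # Without an explicit 'Connection: Keep-Alive' request header, an
+    # HTTP/1.0 exchange is non-persistent and the server emits no
+    # Connection header at all.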
+ status_line, actual_headers, actual_resp_body = test_client.get( + '/page2', + protocol=http_client_protocol, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Test a keep-alive HTTP/1.0 request. + + status_line, actual_headers, actual_resp_body = test_client.get( + '/page3', headers=[('Connection', 'Keep-Alive')], + http_conn=http_connection, protocol=http_client_protocol, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert header_has_value('Connection', 'Keep-Alive', actual_headers) + + # Remove the keep-alive header again. + status_line, actual_headers, actual_resp_body = test_client.get( + '/page3', http_conn=http_connection, + protocol=http_client_protocol, + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + test_client.server_instance.protocol = original_server_protocol + + +@pytest.mark.parametrize( + 'timeout_before_headers', + ( + True, + False, + ) +) +def test_HTTP11_Timeout(test_client, timeout_before_headers): + """Check timeout without sending any data. + + The server will close the conn with a 408. + """ + conn = test_client.get_connection() + conn.auto_open = False + conn.connect() + + if not timeout_before_headers: + # Connect but send half the headers only. + conn.send(b'GET /hello HTTP/1.1') + conn.send(('Host: %s' % conn.host).encode('ascii')) + # else: Connect but send nothing. + + # Wait for our socket timeout + time.sleep(timeout * 2) + + # The request should have returned 408 already. + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 408 + conn.close() + + +def test_HTTP11_Timeout_after_request(test_client): + """Check timeout after at least one request has succeeded. + + The server should close the connection without 408. + """ + fail_msg = "Writing to timed out socket didn't fail as it should have: %s" + + # Make an initial request + conn = test_client.get_connection() + conn.putrequest('GET', '/timeout?t=%s' % timeout, skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 200 + actual_body = response.read() + expected_body = str(timeout).encode() + assert actual_body == expected_body + + # Make a second request on the same socket + conn._output(b'GET /hello HTTP/1.1') + conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._send_output() + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 200 + actual_body = response.read() + expected_body = b'Hello, world!' 
+ assert actual_body == expected_body + + # Wait for our socket timeout + time.sleep(timeout * 2) + + # Make another request on the same socket, which should error + conn._output(b'GET /hello HTTP/1.1') + conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._send_output() + response = conn.response_class(conn.sock, method='GET') + try: + response.begin() + except (socket.error, http_client.BadStatusLine): + pass + except Exception as ex: + pytest.fail(fail_msg % ex) + else: + if response.status != 408: + pytest.fail(fail_msg % response.read()) + + conn.close() + + # Make another request on a new socket, which should work + conn = test_client.get_connection() + conn.putrequest('GET', '/pov', skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 200 + actual_body = response.read() + expected_body = pov.encode() + assert actual_body == expected_body + + # Make another request on the same socket, + # but timeout on the headers + conn.send(b'GET /hello HTTP/1.1') + # Wait for our socket timeout + time.sleep(timeout * 2) + response = conn.response_class(conn.sock, method='GET') + try: + response.begin() + except (socket.error, http_client.BadStatusLine): + pass + except Exception as ex: + pytest.fail(fail_msg % ex) + else: + if response.status != 408: + pytest.fail(fail_msg % response.read()) + + conn.close() + + # Retry the request on a new connection, which should work + conn = test_client.get_connection() + conn.putrequest('GET', '/pov', skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + assert response.status == 200 + actual_body = response.read() + expected_body = pov.encode() + assert actual_body == expected_body + conn.close() + + +def test_HTTP11_pipelining(test_client): + """Test HTTP/1.1 pipelining. + + httplib doesn't support this directly. + """ + conn = test_client.get_connection() + + # Put request 1 + conn.putrequest('GET', '/hello', skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + + for trial in range(5): + # Put next request + conn._output( + ('GET /hello?%s HTTP/1.1' % trial).encode('iso-8859-1') + ) + conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._send_output() + + # Retrieve previous response + response = conn.response_class(conn.sock, method='GET') + # there is a bug in python3 regarding the buffering of + # ``conn.sock``. Until that bug get's fixed we will + # monkey patch the ``reponse`` instance. + # https://bugs.python.org/issue23377 + if six.PY3: + response.fp = conn.sock.makefile('rb', 0) + response.begin() + body = response.read(13) + assert response.status == 200 + assert body == b'Hello, world!' + + # Retrieve final response + response = conn.response_class(conn.sock, method='GET') + response.begin() + body = response.read() + assert response.status == 200 + assert body == b'Hello, world!' + + conn.close() + + +def test_100_Continue(test_client): + """Test 100-continue header processing.""" + conn = test_client.get_connection() + + # Try a page without an Expect request header first. + # Note that httplib's response.begin automatically ignores + # 100 Continue responses, so we must manually check for it. 
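# Editor's note -- illustrative sketch only, not part of the patch above; the
# helper name is hypothetical.  Because httplib/http.client skips interim 1xx
# responses, the test below inspects the status line itself (via the private
# _read_status() helper).  A status line is "<version> SP <code> SP <reason>":
def parse_status_line(line):
    version, code, reason = line.rstrip(b'\r\n').split(b' ', 2)
    return version.decode('ascii'), int(code), reason.decode('ascii')


assert parse_status_line(b'HTTP/1.1 100 Continue\r\n') == ('HTTP/1.1', 100, 'Continue')
assert parse_status_line(b'HTTP/1.1 200 OK\r\n') == ('HTTP/1.1', 200, 'OK')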
+ conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '4') + conn.endheaders() + conn.send(b"d'oh") + response = conn.response_class(conn.sock, method='POST') + version, status, reason = response._read_status() + assert status != 100 + conn.close() + + # Now try a page with an Expect header... + conn.connect() + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '17') + conn.putheader('Expect', '100-continue') + conn.endheaders() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + assert status == 100 + while True: + line = response.fp.readline().strip() + if line: + pytest.fail( + '100 Continue should not output any headers. Got %r' % + line) + else: + break + + # ...send the body + body = b'I am a small file' + conn.send(body) + + # ...get the final response + response.begin() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 200 + expected_resp_body = ("thanks for '%s'" % body).encode() + assert actual_resp_body == expected_resp_body + conn.close() + + +@pytest.mark.parametrize( + 'max_request_body_size', + ( + 0, + 1001, + ) +) +def test_readall_or_close(test_client, max_request_body_size): + """Test a max_request_body_size of 0 (the default) and 1001.""" + old_max = test_client.server_instance.max_request_body_size + + test_client.server_instance.max_request_body_size = max_request_body_size + + conn = test_client.get_connection() + + # Get a POST page with an error + conn.putrequest('POST', '/err_before_read', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '1000') + conn.putheader('Expect', '100-continue') + conn.endheaders() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + assert status == 100 + skip = True + while skip: + skip = response.fp.readline().strip() + + # ...send the body + conn.send(b'x' * 1000) + + # ...get the final response + response.begin() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 500 + + # Now try a working page with an Expect header... 
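# Editor's note -- illustrative sketch only, not part of the patch above; the
# helper name is hypothetical.  The "skip = response.fp.readline().strip()"
# loops in these tests simply consume header lines up to the blank line that
# terminates the interim 100 Continue response.  The same idea standalone:
import io


def drain_header_block(fp):
    lines = []
    while True:
        line = fp.readline().strip()
        if not line:
            return lines
        lines.append(line)


assert drain_header_block(io.BytesIO(b'\r\nrest')) == []
assert drain_header_block(io.BytesIO(b'X-One: 1\r\n\r\nrest')) == [b'X-One: 1']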
+ conn._output(b'POST /upload HTTP/1.1') + conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._output(b'Content-Type: text/plain') + conn._output(b'Content-Length: 17') + conn._output(b'Expect: 100-continue') + conn._send_output() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + assert status == 100 + skip = True + while skip: + skip = response.fp.readline().strip() + + # ...send the body + body = b'I am a small file' + conn.send(body) + + # ...get the final response + response.begin() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 200 + expected_resp_body = ("thanks for '%s'" % body).encode() + assert actual_resp_body == expected_resp_body + conn.close() + + test_client.server_instance.max_request_body_size = old_max + + +def test_No_Message_Body(test_client): + """Test HTTP queries with an empty response body.""" + # Initialize a persistent HTTP connection + http_connection = test_client.get_connection() + http_connection.auto_open = False + http_connection.connect() + + # Make the first request and assert there's no "Connection: close". + status_line, actual_headers, actual_resp_body = test_client.get( + '/pov', http_conn=http_connection + ) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + assert actual_resp_body == pov.encode() + assert not header_exists('Connection', actual_headers) + + # Make a 204 request on the same connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/custom/204', http_conn=http_connection + ) + actual_status = int(status_line[:3]) + assert actual_status == 204 + assert not header_exists('Content-Length', actual_headers) + assert actual_resp_body == b'' + assert not header_exists('Connection', actual_headers) + + # Make a 304 request on the same connection. + status_line, actual_headers, actual_resp_body = test_client.get( + '/custom/304', http_conn=http_connection + ) + actual_status = int(status_line[:3]) + assert actual_status == 304 + assert not header_exists('Content-Length', actual_headers) + assert actual_resp_body == b'' + assert not header_exists('Connection', actual_headers) + + +@pytest.mark.xfail( + reason='Server does not correctly read trailers/ending of the previous ' + 'HTTP request, thus the second request fails as the server tries ' + r"to parse b'Content-Type: application/json\r\n' as a " + 'Request-Line. This results in HTTP status code 400, instead of 413' + 'Ref: https://github.com/cherrypy/cheroot/issues/69' +) +def test_Chunked_Encoding(test_client): + """Test HTTP uploads with chunked transfer-encoding.""" + # Initialize a persistent HTTP connection + conn = test_client.get_connection() + + # Try a normal chunked request (with extensions) + body = ( + b'8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n' + b'Content-Type: application/json\r\n' + b'\r\n' + ) + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Transfer-Encoding', 'chunked') + conn.putheader('Trailer', 'Content-Type') + # Note that this is somewhat malformed: + # we shouldn't be sending Content-Length. + # RFC 2616 says the server should ignore it. 
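# Editor's note -- illustrative sketch only, not part of the patch above; the
# helper name is hypothetical.  The chunked request body defined above follows
# the chunked transfer coding grammar (RFC 2616 section 3.6.1): each chunk is
# "<hex size>[;extension]\r\n<data>\r\n", a zero-size chunk ends the body, and
# optional trailer fields may follow before the final blank line.
def make_chunked(chunks, trailers=()):
    out = b''
    for data, extension in chunks:
        size = format(len(data), 'x').encode('ascii')
        if extension:
            size += b';' + extension
        out += size + b'\r\n' + data + b'\r\n'
    out += b'0\r\n'
    for trailer in trailers:
        out += trailer + b'\r\n'
    out += b'\r\n'
    return out


assert make_chunked(
    [(b'xx\r\nxxxx', b'key=value'), (b'yyyyy', b'')],
    trailers=(b'Content-Type: application/json',),
) == (
    b'8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n'
    b'Content-Type: application/json\r\n\r\n'
)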
+ conn.putheader('Content-Length', '3') + conn.endheaders() + conn.send(body) + response = conn.getresponse() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 200 + assert status_line[4:] == 'OK' + expected_resp_body = ("thanks for '%s'" % b'xx\r\nxxxxyyyyy').encode() + assert actual_resp_body == expected_resp_body + + # Try a chunked request that exceeds server.max_request_body_size. + # Note that the delimiters and trailer are included. + body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n' + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Transfer-Encoding', 'chunked') + conn.putheader('Content-Type', 'text/plain') + # Chunked requests don't need a content-length + # conn.putheader("Content-Length", len(body)) + conn.endheaders() + conn.send(body) + response = conn.getresponse() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 413 + conn.close() + + +def test_Content_Length_in(test_client): + """Try a non-chunked request where Content-Length exceeds limit. + + (server.max_request_body_size). + Assert error before body send. + """ + # Initialize a persistent HTTP connection + conn = test_client.get_connection() + + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', conn.host) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '9999') + conn.endheaders() + response = conn.getresponse() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + assert actual_status == 413 + expected_resp_body = ( + b'The entity sent with the request exceeds ' + b'the maximum allowed bytes.' + ) + assert actual_resp_body == expected_resp_body + conn.close() + + +def test_Content_Length_not_int(test_client): + """Test that malicious Content-Length header returns 400.""" + status_line, actual_headers, actual_resp_body = test_client.post( + '/upload', + headers=[ + ('Content-Type', 'text/plain'), + ('Content-Length', 'not-an-integer'), + ], + ) + actual_status = int(status_line[:3]) + + assert actual_status == 400 + assert actual_resp_body == b'Malformed Content-Length Header.' + + +@pytest.mark.parametrize( + 'uri,expected_resp_status,expected_resp_body', + ( + ('/wrong_cl_buffered', 500, + (b'The requested resource returned more bytes than the ' + b'declared Content-Length.')), + ('/wrong_cl_unbuffered', 200, b'I too'), + ) +) +def test_Content_Length_out( + test_client, + uri, expected_resp_status, expected_resp_body +): + """Test response with Content-Length less than the response body. + + (non-chunked response) + """ + conn = test_client.get_connection() + conn.putrequest('GET', uri, skip_host=True) + conn.putheader('Host', conn.host) + conn.endheaders() + + response = conn.getresponse() + status_line, actual_headers, actual_resp_body = webtest.shb(response) + actual_status = int(status_line[:3]) + + assert actual_status == expected_resp_status + assert actual_resp_body == expected_resp_body + + conn.close() + + +@pytest.mark.xfail( + reason='Sometimes this test fails due to low timeout. 
' + 'Ref: https://github.com/cherrypy/cherrypy/issues/598' +) +def test_598(test_client): + """Test serving large file with a read timeout in place.""" + # Initialize a persistent HTTP connection + conn = test_client.get_connection() + remote_data_conn = urllib.request.urlopen( + '%s://%s:%s/one_megabyte_of_a' + % ('http', conn.host, conn.port) + ) + buf = remote_data_conn.read(512) + time.sleep(timeout * 0.6) + remaining = (1024 * 1024) - 512 + while remaining: + data = remote_data_conn.read(remaining) + if not data: + break + buf += data + remaining -= len(data) + + assert len(buf) == 1024 * 1024 + assert buf == b'a' * 1024 * 1024 + assert remaining == 0 + remote_data_conn.close() + + +@pytest.mark.parametrize( + 'invalid_terminator', + ( + b'\n\n', + b'\r\n\n', + ) +) +def test_No_CRLF(test_client, invalid_terminator): + """Test HTTP queries with no valid CRLF terminators.""" + # Initialize a persistent HTTP connection + conn = test_client.get_connection() + + # (b'%s' % b'') is not supported in Python 3.4, so just use + + conn.send(b'GET /hello HTTP/1.1' + invalid_terminator) + response = conn.response_class(conn.sock, method='GET') + response.begin() + actual_resp_body = response.read() + expected_resp_body = b'HTTP requires CRLF terminators' + assert actual_resp_body == expected_resp_body + conn.close() diff --git a/libraries/cheroot/test/test_core.py b/libraries/cheroot/test/test_core.py new file mode 100644 index 00000000..7c91b13e --- /dev/null +++ b/libraries/cheroot/test/test_core.py @@ -0,0 +1,405 @@ +"""Tests for managing HTTP issues (malformed requests, etc).""" +# -*- coding: utf-8 -*- +# vim: set fileencoding=utf-8 : + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import errno +import socket + +import pytest +import six +from six.moves import urllib + +from cheroot.test import helper + + +HTTP_BAD_REQUEST = 400 +HTTP_LENGTH_REQUIRED = 411 +HTTP_NOT_FOUND = 404 +HTTP_OK = 200 +HTTP_VERSION_NOT_SUPPORTED = 505 + + +class HelloController(helper.Controller): + """Controller for serving WSGI apps.""" + + def hello(req, resp): + """Render Hello world.""" + return 'Hello world!' + + def body_required(req, resp): + """Render Hello world or set 411.""" + if req.environ.get('Content-Length', None) is None: + resp.status = '411 Length Required' + return + return 'Hello world!' + + def query_string(req, resp): + """Render QUERY_STRING value.""" + return req.environ.get('QUERY_STRING', '') + + def asterisk(req, resp): + """Render request method value.""" + method = req.environ.get('REQUEST_METHOD', 'NO METHOD FOUND') + tmpl = 'Got asterisk URI path with {method} method' + return tmpl.format(**locals()) + + def _munge(string): + """Encode PATH_INFO correctly depending on Python version. + + WSGI 1.0 is a mess around unicode. Create endpoints + that match the PATH_INFO that it produces. + """ + if six.PY3: + return string.encode('utf-8').decode('latin-1') + return string + + handlers = { + '/hello': hello, + '/no_body': hello, + '/body_required': body_required, + '/query_string': query_string, + _munge('/привіт'): hello, + _munge('/Юххууу'): hello, + '/\xa0Ðblah key 0 900 4 data': hello, + '/*': asterisk, + } + + +def _get_http_response(connection, method='GET'): + c = connection + kwargs = {'strict': c.strict} if hasattr(c, 'strict') else {} + # Python 3.2 removed the 'strict' feature, saying: + # "http.client now always assumes HTTP/1.x compliant servers." 
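# Editor's note -- illustrative sketch only, not part of the patch above.  The
# _munge() helper above mirrors PEP 3333's "bytes as latin-1" rule: on
# Python 3 the UTF-8 bytes of the request target reach the application decoded
# as latin-1, so handler keys must be stored the same way for PATH_INFO
# lookups to match.
wire_path = u'/привіт'.encode('utf-8')    # request-target bytes on the wire
path_info = wire_path.decode('latin-1')   # PATH_INFO as a WSGI native string (Python 3)
assert path_info.encode('latin-1') == wire_path
assert path_info.encode('latin-1').decode('utf-8') == u'/привіт'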
+ return c.response_class(c.sock, method=method, **kwargs) + + +@pytest.fixture +def testing_server(wsgi_server_client): + """Attach a WSGI app to the given server and pre-configure it.""" + wsgi_server = wsgi_server_client.server_instance + wsgi_server.wsgi_app = HelloController() + wsgi_server.max_request_body_size = 30000000 + wsgi_server.server_client = wsgi_server_client + return wsgi_server + + +@pytest.fixture +def test_client(testing_server): + """Get and return a test client out of the given server.""" + return testing_server.server_client + + +def test_http_connect_request(test_client): + """Check that CONNECT query results in Method Not Allowed status.""" + status_line = test_client.connect('/anything')[0] + actual_status = int(status_line[:3]) + assert actual_status == 405 + + +def test_normal_request(test_client): + """Check that normal GET query succeeds.""" + status_line, _, actual_resp_body = test_client.get('/hello') + actual_status = int(status_line[:3]) + assert actual_status == HTTP_OK + assert actual_resp_body == b'Hello world!' + + +def test_query_string_request(test_client): + """Check that GET param is parsed well.""" + status_line, _, actual_resp_body = test_client.get( + '/query_string?test=True' + ) + actual_status = int(status_line[:3]) + assert actual_status == HTTP_OK + assert actual_resp_body == b'test=True' + + +@pytest.mark.parametrize( + 'uri', + ( + '/hello', # plain + '/query_string?test=True', # query + '/{0}?{1}={2}'.format( # quoted unicode + *map(urllib.parse.quote, ('Юххууу', 'ї', 'йо')) + ), + ) +) +def test_parse_acceptable_uri(test_client, uri): + """Check that server responds with OK to valid GET queries.""" + status_line = test_client.get(uri)[0] + actual_status = int(status_line[:3]) + assert actual_status == HTTP_OK + + +@pytest.mark.xfail(six.PY2, reason='Fails on Python 2') +def test_parse_uri_unsafe_uri(test_client): + """Test that malicious URI does not allow HTTP injection. + + This effectively checks that sending GET request with URL + + /%A0%D0blah%20key%200%20900%204%20data + + is not converted into + + GET / + blah key 0 900 4 data + HTTP/1.1 + + which would be a security issue otherwise. + """ + c = test_client.get_connection() + resource = '/\xa0Ðblah key 0 900 4 data'.encode('latin-1') + quoted = urllib.parse.quote(resource) + assert quoted == '/%A0%D0blah%20key%200%20900%204%20data' + request = 'GET {quoted} HTTP/1.1'.format(**locals()) + c._output(request.encode('utf-8')) + c._send_output() + response = _get_http_response(c, method='GET') + response.begin() + assert response.status == HTTP_OK + assert response.fp.read(12) == b'Hello world!' + c.close() + + +def test_parse_uri_invalid_uri(test_client): + """Check that server responds with Bad Request to invalid GET queries. + + Invalid request line test case: it should only contain US-ASCII. + """ + c = test_client.get_connection() + c._output(u'GET /йопта! HTTP/1.1'.encode('utf-8')) + c._send_output() + response = _get_http_response(c, method='GET') + response.begin() + assert response.status == HTTP_BAD_REQUEST + assert response.fp.read(21) == b'Malformed Request-URI' + c.close() + + +@pytest.mark.parametrize( + 'uri', + ( + 'hello', # ascii + 'привіт', # non-ascii + ) +) +def test_parse_no_leading_slash_invalid(test_client, uri): + """Check that server responds with Bad Request to invalid GET queries. + + Invalid request line test case: it should have leading slash (be absolute). 
+ """ + status_line, _, actual_resp_body = test_client.get( + urllib.parse.quote(uri) + ) + actual_status = int(status_line[:3]) + assert actual_status == HTTP_BAD_REQUEST + assert b'starting with a slash' in actual_resp_body + + +def test_parse_uri_absolute_uri(test_client): + """Check that server responds with Bad Request to Absolute URI. + + Only proxy servers should allow this. + """ + status_line, _, actual_resp_body = test_client.get('http://google.com/') + actual_status = int(status_line[:3]) + assert actual_status == HTTP_BAD_REQUEST + expected_body = b'Absolute URI not allowed if server is not a proxy.' + assert actual_resp_body == expected_body + + +def test_parse_uri_asterisk_uri(test_client): + """Check that server responds with OK to OPTIONS with "*" Absolute URI.""" + status_line, _, actual_resp_body = test_client.options('*') + actual_status = int(status_line[:3]) + assert actual_status == HTTP_OK + expected_body = b'Got asterisk URI path with OPTIONS method' + assert actual_resp_body == expected_body + + +def test_parse_uri_fragment_uri(test_client): + """Check that server responds with Bad Request to URI with fragment.""" + status_line, _, actual_resp_body = test_client.get( + '/hello?test=something#fake', + ) + actual_status = int(status_line[:3]) + assert actual_status == HTTP_BAD_REQUEST + expected_body = b'Illegal #fragment in Request-URI.' + assert actual_resp_body == expected_body + + +def test_no_content_length(test_client): + """Test POST query with an empty body being successful.""" + # "The presence of a message-body in a request is signaled by the + # inclusion of a Content-Length or Transfer-Encoding header field in + # the request's message-headers." + # + # Send a message with neither header and no body. + c = test_client.get_connection() + c.request('POST', '/no_body') + response = c.getresponse() + actual_resp_body = response.fp.read() + actual_status = response.status + assert actual_status == HTTP_OK + assert actual_resp_body == b'Hello world!' + + +def test_content_length_required(test_client): + """Test POST query with body failing because of missing Content-Length.""" + # Now send a message that has no Content-Length, but does send a body. + # Verify that CP times out the socket and responds + # with 411 Length Required. 
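# Editor's note -- illustrative sketch only, not part of the patch above; the
# helper name and host value are hypothetical.  The tests above and below
# encode the message-body rule quoted earlier: a request carries a body only
# if it has Content-Length or Transfer-Encoding, and a server that needs a
# length may answer "411 Length Required" when neither is present.
def request_signals_body(headers):
    names = set(k.lower() for k, _ in headers)
    return 'content-length' in names or 'transfer-encoding' in names


assert not request_signals_body([('Host', 'example.invalid')])
assert request_signals_body([('Content-Length', '4')])
assert request_signals_body([('Transfer-Encoding', 'chunked')])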
+ + c = test_client.get_connection() + c.request('POST', '/body_required') + response = c.getresponse() + response.fp.read() + + actual_status = response.status + assert actual_status == HTTP_LENGTH_REQUIRED + + +@pytest.mark.parametrize( + 'request_line,status_code,expected_body', + ( + (b'GET /', # missing proto + HTTP_BAD_REQUEST, b'Malformed Request-Line'), + (b'GET / HTTPS/1.1', # invalid proto + HTTP_BAD_REQUEST, b'Malformed Request-Line: bad protocol'), + (b'GET / HTTP/2.15', # invalid ver + HTTP_VERSION_NOT_SUPPORTED, b'Cannot fulfill request'), + ) +) +def test_malformed_request_line( + test_client, request_line, + status_code, expected_body +): + """Test missing or invalid HTTP version in Request-Line.""" + c = test_client.get_connection() + c._output(request_line) + c._send_output() + response = _get_http_response(c, method='GET') + response.begin() + assert response.status == status_code + assert response.fp.read(len(expected_body)) == expected_body + c.close() + + +def test_malformed_http_method(test_client): + """Test non-uppercase HTTP method.""" + c = test_client.get_connection() + c.putrequest('GeT', '/malformed_method_case') + c.putheader('Content-Type', 'text/plain') + c.endheaders() + + response = c.getresponse() + actual_status = response.status + assert actual_status == HTTP_BAD_REQUEST + actual_resp_body = response.fp.read(21) + assert actual_resp_body == b'Malformed method name' + + +def test_malformed_header(test_client): + """Check that broken HTTP header results in Bad Request.""" + c = test_client.get_connection() + c.putrequest('GET', '/') + c.putheader('Content-Type', 'text/plain') + # See https://www.bitbucket.org/cherrypy/cherrypy/issue/941 + c._output(b'Re, 1.2.3.4#015#012') + c.endheaders() + + response = c.getresponse() + actual_status = response.status + assert actual_status == HTTP_BAD_REQUEST + actual_resp_body = response.fp.read(20) + assert actual_resp_body == b'Illegal header line.' + + +def test_request_line_split_issue_1220(test_client): + """Check that HTTP request line of exactly 256 chars length is OK.""" + Request_URI = ( + '/hello?' + 'intervenant-entreprise-evenement_classaction=' + 'evenement-mailremerciements' + '&_path=intervenant-entreprise-evenement' + '&intervenant-entreprise-evenement_action-id=19404' + '&intervenant-entreprise-evenement_id=19404' + '&intervenant-entreprise_id=28092' + ) + assert len('GET %s HTTP/1.1\r\n' % Request_URI) == 256 + + actual_resp_body = test_client.get(Request_URI)[2] + assert actual_resp_body == b'Hello world!' + + +def test_garbage_in(test_client): + """Test that server sends an error for garbage received over TCP.""" + # Connect without SSL regardless of server.scheme + + c = test_client.get_connection() + c._output(b'gjkgjklsgjklsgjkljklsg') + c._send_output() + response = c.response_class(c.sock, method='GET') + try: + response.begin() + actual_status = response.status + assert actual_status == HTTP_BAD_REQUEST + actual_resp_body = response.fp.read(22) + assert actual_resp_body == b'Malformed Request-Line' + c.close() + except socket.error as ex: + # "Connection reset by peer" is also acceptable. 
+ if ex.errno != errno.ECONNRESET: + raise + + +class CloseController: + """Controller for testing the close callback.""" + + def __call__(self, environ, start_response): + """Get the req to know header sent status.""" + self.req = start_response.__self__.req + resp = CloseResponse(self.close) + start_response(resp.status, resp.headers.items()) + return resp + + def close(self): + """Close, writing hello.""" + self.req.write(b'hello') + + +class CloseResponse: + """Dummy empty response to trigger the no body status.""" + + def __init__(self, close): + """Use some defaults to ensure we have a header.""" + self.status = '200 OK' + self.headers = {'Content-Type': 'text/html'} + self.close = close + + def __getitem__(self, index): + """Ensure we don't have a body.""" + raise IndexError() + + def output(self): + """Return self to hook the close method.""" + return self + + +@pytest.fixture +def testing_server_close(wsgi_server_client): + """Attach a WSGI app to the given server and pre-configure it.""" + wsgi_server = wsgi_server_client.server_instance + wsgi_server.wsgi_app = CloseController() + wsgi_server.max_request_body_size = 30000000 + wsgi_server.server_client = wsgi_server_client + return wsgi_server + + +def test_send_header_before_closing(testing_server_close): + """Test we are actually sending the headers before calling 'close'.""" + _, _, resp_body = testing_server_close.server_client.get('/') + assert resp_body == b'hello' diff --git a/libraries/cheroot/test/test_server.py b/libraries/cheroot/test/test_server.py new file mode 100644 index 00000000..c53f7a81 --- /dev/null +++ b/libraries/cheroot/test/test_server.py @@ -0,0 +1,193 @@ +"""Tests for the HTTP server.""" +# -*- coding: utf-8 -*- +# vim: set fileencoding=utf-8 : + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import os +import socket +import tempfile +import threading +import time + +import pytest + +from .._compat import bton +from ..server import Gateway, HTTPServer +from ..testing import ( + ANY_INTERFACE_IPV4, + ANY_INTERFACE_IPV6, + EPHEMERAL_PORT, + get_server_client, +) + + +def make_http_server(bind_addr): + """Create and start an HTTP server bound to bind_addr.""" + httpserver = HTTPServer( + bind_addr=bind_addr, + gateway=Gateway, + ) + + threading.Thread(target=httpserver.safe_start).start() + + while not httpserver.ready: + time.sleep(0.1) + + return httpserver + + +non_windows_sock_test = pytest.mark.skipif( + not hasattr(socket, 'AF_UNIX'), + reason='UNIX domain sockets are only available under UNIX-based OS', +) + + +@pytest.fixture +def http_server(): + """Provision a server creator as a fixture.""" + def start_srv(): + bind_addr = yield + httpserver = make_http_server(bind_addr) + yield httpserver + yield httpserver + + srv_creator = iter(start_srv()) + next(srv_creator) + yield srv_creator + try: + while True: + httpserver = next(srv_creator) + if httpserver is not None: + httpserver.stop() + except StopIteration: + pass + + +@pytest.fixture +def unix_sock_file(): + """Check that bound UNIX socket address is stored in server.""" + tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp() + + yield tmp_sock_fname + + os.close(tmp_sock_fh) + os.unlink(tmp_sock_fname) + + +def test_prepare_makes_server_ready(): + """Check that prepare() makes the server ready, and stop() clears it.""" + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), + gateway=Gateway, + ) + + assert not httpserver.ready + assert not httpserver.requests._threads + + 
httpserver.prepare() + + assert httpserver.ready + assert httpserver.requests._threads + for thr in httpserver.requests._threads: + assert thr.ready + + httpserver.stop() + + assert not httpserver.requests._threads + assert not httpserver.ready + + +def test_stop_interrupts_serve(): + """Check that stop() interrupts running of serve().""" + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), + gateway=Gateway, + ) + + httpserver.prepare() + serve_thread = threading.Thread(target=httpserver.serve) + serve_thread.start() + + serve_thread.join(0.5) + assert serve_thread.is_alive() + + httpserver.stop() + + serve_thread.join(0.5) + assert not serve_thread.is_alive() + + +@pytest.mark.parametrize( + 'ip_addr', + ( + ANY_INTERFACE_IPV4, + ANY_INTERFACE_IPV6, + ) +) +def test_bind_addr_inet(http_server, ip_addr): + """Check that bound IP address is stored in server.""" + httpserver = http_server.send((ip_addr, EPHEMERAL_PORT)) + + assert httpserver.bind_addr[0] == ip_addr + assert httpserver.bind_addr[1] != EPHEMERAL_PORT + + +@non_windows_sock_test +def test_bind_addr_unix(http_server, unix_sock_file): + """Check that bound UNIX socket address is stored in server.""" + httpserver = http_server.send(unix_sock_file) + + assert httpserver.bind_addr == unix_sock_file + + +@pytest.mark.skip(reason="Abstract sockets don't work currently") +@non_windows_sock_test +def test_bind_addr_unix_abstract(http_server): + """Check that bound UNIX socket address is stored in server.""" + unix_abstract_sock = b'\x00cheroot/test/socket/here.sock' + httpserver = http_server.send(unix_abstract_sock) + + assert httpserver.bind_addr == unix_abstract_sock + + +PEERCRED_IDS_URI = '/peer_creds/ids' +PEERCRED_TEXTS_URI = '/peer_creds/texts' + + +class _TestGateway(Gateway): + def respond(self): + req = self.req + conn = req.conn + req_uri = bton(req.uri) + if req_uri == PEERCRED_IDS_URI: + peer_creds = conn.peer_pid, conn.peer_uid, conn.peer_gid + return ['|'.join(map(str, peer_creds))] + elif req_uri == PEERCRED_TEXTS_URI: + return ['!'.join((conn.peer_user, conn.peer_group))] + return super(_TestGateway, self).respond() + + +@pytest.mark.skip( + reason='Test HTTP client is not able to work through UNIX socket currently' +) +@non_windows_sock_test +def test_peercreds_unix_sock(http_server, unix_sock_file): + """Check that peercred lookup and resolution work when enabled.""" + httpserver = http_server.send(unix_sock_file) + httpserver.gateway = _TestGateway + httpserver.peercreds_enabled = True + + testclient = get_server_client(httpserver) + + expected_peercreds = os.getpid(), os.getuid(), os.getgid() + expected_peercreds = '|'.join(map(str, expected_peercreds)) + assert testclient.get(PEERCRED_IDS_URI) == expected_peercreds + assert 'RuntimeError' in testclient.get(PEERCRED_TEXTS_URI) + + httpserver.peercreds_resolve_enabled = True + import grp + expected_textcreds = os.getlogin(), grp.getgrgid(os.getgid()).gr_name + expected_textcreds = '!'.join(map(str, expected_textcreds)) + assert testclient.get(PEERCRED_TEXTS_URI) == expected_textcreds diff --git a/libraries/cheroot/test/webtest.py b/libraries/cheroot/test/webtest.py new file mode 100644 index 00000000..43448f5b --- /dev/null +++ b/libraries/cheroot/test/webtest.py @@ -0,0 +1,581 @@ +"""Extensions to unittest for web frameworks. + +Use the WebCase.getPage method to request a page from your HTTP server. 
+Framework Integration +===================== +If you have control over your server process, you can handle errors +in the server-side of the HTTP conversation a bit better. You must run +both the client (your WebCase tests) and the server in the same process +(but in separate threads, obviously). +When an error occurs in the framework, call server_error. It will print +the traceback to stdout, and keep any assertions you have from running +(the assumption is that, if the server errors, the page output will not +be of further significance to your tests). +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pprint +import re +import socket +import sys +import time +import traceback +import os +import json +import unittest +import warnings + +from six.moves import range, http_client, map, urllib_parse +import six + +from more_itertools.more import always_iterable + + +def interface(host): + """Return an IP address for a client connection given the server host. + + If the server is listening on '0.0.0.0' (INADDR_ANY) + or '::' (IN6ADDR_ANY), this will return the proper localhost. + """ + if host == '0.0.0.0': + # INADDR_ANY, which should respond on localhost. + return '127.0.0.1' + if host == '::': + # IN6ADDR_ANY, which should respond on localhost. + return '::1' + return host + + +try: + # Jython support + if sys.platform[:4] == 'java': + def getchar(): + """Get a key press.""" + # Hopefully this is enough + return sys.stdin.read(1) + else: + # On Windows, msvcrt.getch reads a single char without output. + import msvcrt + + def getchar(): + """Get a key press.""" + return msvcrt.getch() +except ImportError: + # Unix getchr + import tty + import termios + + def getchar(): + """Get a key press.""" + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(sys.stdin.fileno()) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch + + +# from jaraco.properties +class NonDataProperty: + """Non-data property decorator.""" + + def __init__(self, fget): + """Initialize a non-data property.""" + assert fget is not None, 'fget cannot be none' + assert callable(fget), 'fget must be callable' + self.fget = fget + + def __get__(self, obj, objtype=None): + """Return a class property.""" + if obj is None: + return self + return self.fget(obj) + + +class WebCase(unittest.TestCase): + """Helper web test suite base.""" + + HOST = '127.0.0.1' + PORT = 8000 + HTTP_CONN = http_client.HTTPConnection + PROTOCOL = 'HTTP/1.1' + + scheme = 'http' + url = None + + status = None + headers = None + body = None + + encoding = 'utf-8' + + time = None + + @property + def _Conn(self): + """Return HTTPConnection or HTTPSConnection based on self.scheme. + + * from http.client. + """ + cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper()) + return getattr(http_client, cls_name) + + def get_conn(self, auto_open=False): + """Return a connection to our HTTP server.""" + conn = self._Conn(self.interface(), self.PORT) + # Automatically re-connect? + conn.auto_open = auto_open + conn.connect() + return conn + + def set_persistent(self, on=True, auto_open=False): + """Make our HTTP_CONN persistent (or not). + + If the 'on' argument is True (the default), then self.HTTP_CONN + will be set to an instance of HTTP(S)?Connection + to persist across requests. + As this class only allows for a single open connection, if + self already has an open connection, it will be closed. 
+ """ + try: + self.HTTP_CONN.close() + except (TypeError, AttributeError): + pass + + self.HTTP_CONN = ( + self.get_conn(auto_open=auto_open) + if on + else self._Conn + ) + + @property + def persistent(self): # noqa: D401; irrelevant for properties + """Presense of the persistent HTTP connection.""" + return hasattr(self.HTTP_CONN, '__class__') + + @persistent.setter + def persistent(self, on): + self.set_persistent(on) + + def interface(self): + """Return an IP address for a client connection. + + If the server is listening on '0.0.0.0' (INADDR_ANY) + or '::' (IN6ADDR_ANY), this will return the proper localhost. + """ + return interface(self.HOST) + + def getPage(self, url, headers=None, method='GET', body=None, + protocol=None, raise_subcls=None): + """Open the url with debugging support. Return status, headers, body. + + url should be the identifier passed to the server, typically a + server-absolute path and query string (sent between method and + protocol), and should only be an absolute URI if proxy support is + enabled in the server. + + If the application under test generates absolute URIs, be sure + to wrap them first with strip_netloc:: + + class MyAppWebCase(WebCase): + def getPage(url, *args, **kwargs): + super(MyAppWebCase, self).getPage( + cheroot.test.webtest.strip_netloc(url), + *args, **kwargs + ) + + `raise_subcls` must be a tuple with the exceptions classes + or a single exception class that are not going to be considered + a socket.error regardless that they were are subclass of a + socket.error and therefore not considered for a connection retry. + """ + ServerError.on = False + + if isinstance(url, six.text_type): + url = url.encode('utf-8') + if isinstance(body, six.text_type): + body = body.encode('utf-8') + + self.url = url + self.time = None + start = time.time() + result = openURL(url, headers, method, body, self.HOST, self.PORT, + self.HTTP_CONN, protocol or self.PROTOCOL, + raise_subcls) + self.time = time.time() - start + self.status, self.headers, self.body = result + + # Build a list of request cookies from the previous response cookies. + self.cookies = [('Cookie', v) for k, v in self.headers + if k.lower() == 'set-cookie'] + + if ServerError.on: + raise ServerError() + return result + + @NonDataProperty + def interactive(self): + """Determine whether tests are run in interactive mode. + + Load interactivity setting from environment, where + the value can be numeric or a string like true or + False or 1 or 0. 
+ """ + env_str = os.environ.get('WEBTEST_INTERACTIVE', 'True') + is_interactive = bool(json.loads(env_str.lower())) + if is_interactive: + warnings.warn( + 'Interactive test failure interceptor support via ' + 'WEBTEST_INTERACTIVE environment variable is deprecated.', + DeprecationWarning + ) + return is_interactive + + console_height = 30 + + def _handlewebError(self, msg): + print('') + print(' ERROR: %s' % msg) + + if not self.interactive: + raise self.failureException(msg) + + p = (' Show: ' + '[B]ody [H]eaders [S]tatus [U]RL; ' + '[I]gnore, [R]aise, or sys.e[X]it >> ') + sys.stdout.write(p) + sys.stdout.flush() + while True: + i = getchar().upper() + if not isinstance(i, type('')): + i = i.decode('ascii') + if i not in 'BHSUIRX': + continue + print(i.upper()) # Also prints new line + if i == 'B': + for x, line in enumerate(self.body.splitlines()): + if (x + 1) % self.console_height == 0: + # The \r and comma should make the next line overwrite + sys.stdout.write('<-- More -->\r') + m = getchar().lower() + # Erase our "More" prompt + sys.stdout.write(' \r') + if m == 'q': + break + print(line) + elif i == 'H': + pprint.pprint(self.headers) + elif i == 'S': + print(self.status) + elif i == 'U': + print(self.url) + elif i == 'I': + # return without raising the normal exception + return + elif i == 'R': + raise self.failureException(msg) + elif i == 'X': + sys.exit() + sys.stdout.write(p) + sys.stdout.flush() + + @property + def status_code(self): # noqa: D401; irrelevant for properties + """Integer HTTP status code.""" + return int(self.status[:3]) + + def status_matches(self, expected): + """Check whether actual status matches expected.""" + actual = ( + self.status_code + if isinstance(expected, int) else + self.status + ) + return expected == actual + + def assertStatus(self, status, msg=None): + """Fail if self.status != status. + + status may be integer code, exact string status, or + iterable of allowed possibilities. 
+ """ + if any(map(self.status_matches, always_iterable(status))): + return + + tmpl = 'Status {self.status} does not match {status}' + msg = msg or tmpl.format(**locals()) + self._handlewebError(msg) + + def assertHeader(self, key, value=None, msg=None): + """Fail if (key, [value]) not in self.headers.""" + lowkey = key.lower() + for k, v in self.headers: + if k.lower() == lowkey: + if value is None or str(value) == v: + return v + + if msg is None: + if value is None: + msg = '%r not in headers' % key + else: + msg = '%r:%r not in headers' % (key, value) + self._handlewebError(msg) + + def assertHeaderIn(self, key, values, msg=None): + """Fail if header indicated by key doesn't have one of the values.""" + lowkey = key.lower() + for k, v in self.headers: + if k.lower() == lowkey: + matches = [value for value in values if str(value) == v] + if matches: + return matches + + if msg is None: + msg = '%(key)r not in %(values)r' % vars() + self._handlewebError(msg) + + def assertHeaderItemValue(self, key, value, msg=None): + """Fail if the header does not contain the specified value.""" + actual_value = self.assertHeader(key, msg=msg) + header_values = map(str.strip, actual_value.split(',')) + if value in header_values: + return value + + if msg is None: + msg = '%r not in %r' % (value, header_values) + self._handlewebError(msg) + + def assertNoHeader(self, key, msg=None): + """Fail if key in self.headers.""" + lowkey = key.lower() + matches = [k for k, v in self.headers if k.lower() == lowkey] + if matches: + if msg is None: + msg = '%r in headers' % key + self._handlewebError(msg) + + def assertNoHeaderItemValue(self, key, value, msg=None): + """Fail if the header contains the specified value.""" + lowkey = key.lower() + hdrs = self.headers + matches = [k for k, v in hdrs if k.lower() == lowkey and v == value] + if matches: + if msg is None: + msg = '%r:%r in %r' % (key, value, hdrs) + self._handlewebError(msg) + + def assertBody(self, value, msg=None): + """Fail if value != self.body.""" + if isinstance(value, six.text_type): + value = value.encode(self.encoding) + if value != self.body: + if msg is None: + msg = 'expected body:\n%r\n\nactual body:\n%r' % ( + value, self.body) + self._handlewebError(msg) + + def assertInBody(self, value, msg=None): + """Fail if value not in self.body.""" + if isinstance(value, six.text_type): + value = value.encode(self.encoding) + if value not in self.body: + if msg is None: + msg = '%r not in body: %s' % (value, self.body) + self._handlewebError(msg) + + def assertNotInBody(self, value, msg=None): + """Fail if value in self.body.""" + if isinstance(value, six.text_type): + value = value.encode(self.encoding) + if value in self.body: + if msg is None: + msg = '%r found in body' % value + self._handlewebError(msg) + + def assertMatchesBody(self, pattern, msg=None, flags=0): + """Fail if value (a regex pattern) is not in self.body.""" + if isinstance(pattern, six.text_type): + pattern = pattern.encode(self.encoding) + if re.search(pattern, self.body, flags) is None: + if msg is None: + msg = 'No match for %r in body' % pattern + self._handlewebError(msg) + + +methods_with_bodies = ('POST', 'PUT', 'PATCH') + + +def cleanHeaders(headers, method, body, host, port): + """Return request headers, with required headers added (if missing).""" + if headers is None: + headers = [] + + # Add the required Host request header if not present. + # [This specifies the host:port of the server, not the client.] 
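# Editor's note -- illustrative sketch only, not part of the patch above; the
# helper name is hypothetical.  The code below appends a Host header only when
# the caller did not supply one, omitting the port when it is the default 80:
def default_host_header(host, port):
    return host if port == 80 else '%s:%s' % (host, port)


assert default_host_header('127.0.0.1', 80) == '127.0.0.1'
assert default_host_header('127.0.0.1', 8000) == '127.0.0.1:8000'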
+ found = False + for k, v in headers: + if k.lower() == 'host': + found = True + break + if not found: + if port == 80: + headers.append(('Host', host)) + else: + headers.append(('Host', '%s:%s' % (host, port))) + + if method in methods_with_bodies: + # Stick in default type and length headers if not present + found = False + for k, v in headers: + if k.lower() == 'content-type': + found = True + break + if not found: + headers.append( + ('Content-Type', 'application/x-www-form-urlencoded')) + headers.append(('Content-Length', str(len(body or '')))) + + return headers + + +def shb(response): + """Return status, headers, body the way we like from a response.""" + if six.PY3: + h = response.getheaders() + else: + h = [] + key, value = None, None + for line in response.msg.headers: + if line: + if line[0] in ' \t': + value += line.strip() + else: + if key and value: + h.append((key, value)) + key, value = line.split(':', 1) + key = key.strip() + value = value.strip() + if key and value: + h.append((key, value)) + + return '%s %s' % (response.status, response.reason), h, response.read() + + +def openURL(url, headers=None, method='GET', body=None, + host='127.0.0.1', port=8000, http_conn=http_client.HTTPConnection, + protocol='HTTP/1.1', raise_subcls=None): + """ + Open the given HTTP resource and return status, headers, and body. + + `raise_subcls` must be a tuple with the exceptions classes + or a single exception class that are not going to be considered + a socket.error regardless that they were are subclass of a + socket.error and therefore not considered for a connection retry. + """ + headers = cleanHeaders(headers, method, body, host, port) + + # Trying 10 times is simply in case of socket errors. + # Normal case--it should run once. + for trial in range(10): + try: + # Allow http_conn to be a class or an instance + if hasattr(http_conn, 'host'): + conn = http_conn + else: + conn = http_conn(interface(host), port) + + conn._http_vsn_str = protocol + conn._http_vsn = int(''.join([x for x in protocol if x.isdigit()])) + + if six.PY3 and isinstance(url, bytes): + url = url.decode() + conn.putrequest(method.upper(), url, skip_host=True, + skip_accept_encoding=True) + + for key, value in headers: + conn.putheader(key, value.encode('Latin-1')) + conn.endheaders() + + if body is not None: + conn.send(body) + + # Handle response + response = conn.getresponse() + + s, h, b = shb(response) + + if not hasattr(http_conn, 'host'): + # We made our own conn instance. Close it. + conn.close() + + return s, h, b + except socket.error as e: + if raise_subcls is not None and isinstance(e, raise_subcls): + raise + else: + time.sleep(0.5) + if trial == 9: + raise + + +def strip_netloc(url): + """Return absolute-URI path from URL. + + Strip the scheme and host from the URL, returning the + server-absolute portion. + + Useful for wrapping an absolute-URI for which only the + path is expected (such as in calls to getPage). + + >>> strip_netloc('https://google.com/foo/bar?bing#baz') + '/foo/bar?bing' + + >>> strip_netloc('//google.com/foo/bar?bing#baz') + '/foo/bar?bing' + + >>> strip_netloc('/foo/bar?bing#baz') + '/foo/bar?bing' + """ + parsed = urllib_parse.urlparse(url) + scheme, netloc, path, params, query, fragment = parsed + stripped = '', '', path, params, query, '' + return urllib_parse.urlunparse(stripped) + + +# Add any exceptions which your web framework handles +# normally (that you don't want server_error to trap). 
+ignored_exceptions = [] + +# You'll want set this to True when you can't guarantee +# that each response will immediately follow each request; +# for example, when handling requests via multiple threads. +ignore_all = False + + +class ServerError(Exception): + """Exception for signalling server error.""" + + on = False + + +def server_error(exc=None): + """Server debug hook. + + Return True if exception handled, False if ignored. + You probably want to wrap this, so you can still handle an error using + your framework when it's ignored. + """ + if exc is None: + exc = sys.exc_info() + + if ignore_all or exc[0] in ignored_exceptions: + return False + else: + ServerError.on = True + print('') + print(''.join(traceback.format_exception(*exc))) + return True diff --git a/libraries/cheroot/testing.py b/libraries/cheroot/testing.py new file mode 100644 index 00000000..f01d0aa1 --- /dev/null +++ b/libraries/cheroot/testing.py @@ -0,0 +1,144 @@ +"""Pytest fixtures and other helpers for doing testing by end-users.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +from contextlib import closing +import errno +import socket +import threading +import time + +import pytest +from six.moves import http_client + +import cheroot.server +from cheroot.test import webtest +import cheroot.wsgi + +EPHEMERAL_PORT = 0 +NO_INTERFACE = None # Using this or '' will cause an exception +ANY_INTERFACE_IPV4 = '0.0.0.0' +ANY_INTERFACE_IPV6 = '::' + +config = { + cheroot.wsgi.Server: { + 'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT), + 'wsgi_app': None, + }, + cheroot.server.HTTPServer: { + 'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT), + 'gateway': cheroot.server.Gateway, + }, +} + + +def cheroot_server(server_factory): + """Set up and tear down a Cheroot server instance.""" + conf = config[server_factory].copy() + bind_port = conf.pop('bind_addr')[-1] + + for interface in ANY_INTERFACE_IPV6, ANY_INTERFACE_IPV4: + try: + actual_bind_addr = (interface, bind_port) + httpserver = server_factory( # create it + bind_addr=actual_bind_addr, + **conf + ) + except OSError: + pass + else: + break + + threading.Thread(target=httpserver.safe_start).start() # spawn it + while not httpserver.ready: # wait until fully initialized and bound + time.sleep(0.1) + + yield httpserver + + httpserver.stop() # destroy it + + +@pytest.fixture(scope='module') +def wsgi_server(): + """Set up and tear down a Cheroot WSGI server instance.""" + for srv in cheroot_server(cheroot.wsgi.Server): + yield srv + + +@pytest.fixture(scope='module') +def native_server(): + """Set up and tear down a Cheroot HTTP server instance.""" + for srv in cheroot_server(cheroot.server.HTTPServer): + yield srv + + +class _TestClient: + def __init__(self, server): + self._interface, self._host, self._port = _get_conn_data(server) + self._http_connection = self.get_connection() + self.server_instance = server + + def get_connection(self): + name = '{interface}:{port}'.format( + interface=self._interface, + port=self._port, + ) + return http_client.HTTPConnection(name) + + def request( + self, uri, method='GET', headers=None, http_conn=None, + protocol='HTTP/1.1', + ): + return webtest.openURL( + uri, method=method, + headers=headers, + host=self._host, port=self._port, + http_conn=http_conn or self._http_connection, + protocol=protocol, + ) + + def __getattr__(self, attr_name): + def _wrapper(uri, **kwargs): + http_method = attr_name.upper() + return self.request(uri, method=http_method, **kwargs) + + return _wrapper + + +def 
_probe_ipv6_sock(interface): + # Alternate way is to check IPs on interfaces using glibc, like: + # github.com/Gautier/minifail/blob/master/minifail/getifaddrs.py + try: + with closing(socket.socket(family=socket.AF_INET6)) as sock: + sock.bind((interface, 0)) + except (OSError, socket.error) as sock_err: + # In Python 3 socket.error is an alias for OSError + # In Python 2 socket.error is a subclass of IOError + if sock_err.errno != errno.EADDRNOTAVAIL: + raise + else: + return True + + return False + + +def _get_conn_data(server): + if isinstance(server.bind_addr, tuple): + host, port = server.bind_addr + else: + host, port = server.bind_addr, 0 + + interface = webtest.interface(host) + + if ':' in interface and not _probe_ipv6_sock(interface): + interface = '127.0.0.1' + if ':' in host: + host = interface + + return interface, host, port + + +def get_server_client(server): + """Create and return a test client for the given server.""" + return _TestClient(server) diff --git a/libraries/cheroot/workers/__init__.py b/libraries/cheroot/workers/__init__.py new file mode 100644 index 00000000..098b8f25 --- /dev/null +++ b/libraries/cheroot/workers/__init__.py @@ -0,0 +1 @@ +"""HTTP workers pool.""" diff --git a/libraries/cheroot/workers/threadpool.py b/libraries/cheroot/workers/threadpool.py new file mode 100644 index 00000000..ff8fbcee --- /dev/null +++ b/libraries/cheroot/workers/threadpool.py @@ -0,0 +1,271 @@ +"""A thread-based worker pool.""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + + +import threading +import time +import socket + +from six.moves import queue + + +__all__ = ('WorkerThread', 'ThreadPool') + + +class TrueyZero: + """Object which equals and does math like the integer 0 but evals True.""" + + def __add__(self, other): + return other + + def __radd__(self, other): + return other + + +trueyzero = TrueyZero() + +_SHUTDOWNREQUEST = None + + +class WorkerThread(threading.Thread): + """Thread which continuously polls a Queue for Connection objects. + + Due to the timing issues of polling a Queue, a WorkerThread does not + check its own 'ready' flag after it has started. To stop the thread, + it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue + (one for each running WorkerThread). + """ + + conn = None + """The current connection pulled off the Queue, or None.""" + + server = None + """The HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it.""" + + ready = False + """A simple flag for the calling server to know when this thread + has begun polling the Queue.""" + + def __init__(self, server): + """Initialize WorkerThread instance. 
+ + Args: + server (cheroot.server.HTTPServer): web server object + receiving this request + """ + self.ready = False + self.server = server + + self.requests_seen = 0 + self.bytes_read = 0 + self.bytes_written = 0 + self.start_time = None + self.work_time = 0 + self.stats = { + 'Requests': lambda s: self.requests_seen + ( + (self.start_time is None) and + trueyzero or + self.conn.requests_seen + ), + 'Bytes Read': lambda s: self.bytes_read + ( + (self.start_time is None) and + trueyzero or + self.conn.rfile.bytes_read + ), + 'Bytes Written': lambda s: self.bytes_written + ( + (self.start_time is None) and + trueyzero or + self.conn.wfile.bytes_written + ), + 'Work Time': lambda s: self.work_time + ( + (self.start_time is None) and + trueyzero or + time.time() - self.start_time + ), + 'Read Throughput': lambda s: s['Bytes Read'](s) / ( + s['Work Time'](s) or 1e-6), + 'Write Throughput': lambda s: s['Bytes Written'](s) / ( + s['Work Time'](s) or 1e-6), + } + threading.Thread.__init__(self) + + def run(self): + """Process incoming HTTP connections. + + Retrieves incoming connections from thread pool. + """ + self.server.stats['Worker Threads'][self.getName()] = self.stats + try: + self.ready = True + while True: + conn = self.server.requests.get() + if conn is _SHUTDOWNREQUEST: + return + + self.conn = conn + if self.server.stats['Enabled']: + self.start_time = time.time() + try: + conn.communicate() + finally: + conn.close() + if self.server.stats['Enabled']: + self.requests_seen += self.conn.requests_seen + self.bytes_read += self.conn.rfile.bytes_read + self.bytes_written += self.conn.wfile.bytes_written + self.work_time += time.time() - self.start_time + self.start_time = None + self.conn = None + except (KeyboardInterrupt, SystemExit) as ex: + self.server.interrupt = ex + + +class ThreadPool: + """A Request Queue for an HTTPServer which pools threads. + + ThreadPool objects must provide min, get(), put(obj), start() + and stop(timeout) attributes. + """ + + def __init__( + self, server, min=10, max=-1, accepted_queue_size=-1, + accepted_queue_timeout=10): + """Initialize HTTP requests queue instance. + + Args: + server (cheroot.server.HTTPServer): web server object + receiving this request + min (int): minimum number of worker threads + max (int): maximum number of worker threads + accepted_queue_size (int): maximum number of active + requests in queue + accepted_queue_timeout (int): timeout for putting request + into queue + """ + self.server = server + self.min = min + self.max = max + self._threads = [] + self._queue = queue.Queue(maxsize=accepted_queue_size) + self._queue_put_timeout = accepted_queue_timeout + self.get = self._queue.get + + def start(self): + """Start the pool of threads.""" + for i in range(self.min): + self._threads.append(WorkerThread(self.server)) + for worker in self._threads: + worker.setName('CP Server ' + worker.getName()) + worker.start() + for worker in self._threads: + while not worker.ready: + time.sleep(.1) + + @property + def idle(self): # noqa: D401; irrelevant for properties + """Number of worker threads which are idle. Read-only.""" + return len([t for t in self._threads if t.conn is None]) + + def put(self, obj): + """Put request into queue. 
+ + Args: + obj (cheroot.server.HTTPConnection): HTTP connection + waiting to be processed + """ + self._queue.put(obj, block=True, timeout=self._queue_put_timeout) + if obj is _SHUTDOWNREQUEST: + return + + def grow(self, amount): + """Spawn new worker threads (not above self.max).""" + if self.max > 0: + budget = max(self.max - len(self._threads), 0) + else: + # self.max <= 0 indicates no maximum + budget = float('inf') + + n_new = min(amount, budget) + + workers = [self._spawn_worker() for i in range(n_new)] + while not all(worker.ready for worker in workers): + time.sleep(.1) + self._threads.extend(workers) + + def _spawn_worker(self): + worker = WorkerThread(self.server) + worker.setName('CP Server ' + worker.getName()) + worker.start() + return worker + + def shrink(self, amount): + """Kill off worker threads (not below self.min).""" + # Grow/shrink the pool if necessary. + # Remove any dead threads from our list + for t in self._threads: + if not t.isAlive(): + self._threads.remove(t) + amount -= 1 + + # calculate the number of threads above the minimum + n_extra = max(len(self._threads) - self.min, 0) + + # don't remove more than amount + n_to_remove = min(amount, n_extra) + + # put shutdown requests on the queue equal to the number of threads + # to remove. As each request is processed by a worker, that worker + # will terminate and be culled from the list. + for n in range(n_to_remove): + self._queue.put(_SHUTDOWNREQUEST) + + def stop(self, timeout=5): + """Terminate all worker threads. + + Args: + timeout (int): time to wait for threads to stop gracefully + """ + # Must shut down threads here so the code that calls + # this method can know when all threads are stopped. + for worker in self._threads: + self._queue.put(_SHUTDOWNREQUEST) + + # Don't join currentThread (when stop is called inside a request). + current = threading.currentThread() + if timeout is not None and timeout >= 0: + endtime = time.time() + timeout + while self._threads: + worker = self._threads.pop() + if worker is not current and worker.isAlive(): + try: + if timeout is None or timeout < 0: + worker.join() + else: + remaining_time = endtime - time.time() + if remaining_time > 0: + worker.join(remaining_time) + if worker.isAlive(): + # We exhausted the timeout. + # Forcibly shut down the socket. + c = worker.conn + if c and not c.rfile.closed: + try: + c.socket.shutdown(socket.SHUT_RD) + except TypeError: + # pyOpenSSL sockets don't take an arg + c.socket.shutdown() + worker.join() + except (AssertionError, + # Ignore repeated Ctrl-C. + # See + # https://github.com/cherrypy/cherrypy/issues/691. + KeyboardInterrupt): + pass + + @property + def qsize(self): + """Return the queue size.""" + return self._queue.qsize() diff --git a/libraries/cheroot/wsgi.py b/libraries/cheroot/wsgi.py new file mode 100644 index 00000000..a04c9438 --- /dev/null +++ b/libraries/cheroot/wsgi.py @@ -0,0 +1,423 @@ +"""This class holds Cheroot WSGI server implementation. 
+ +Simplest example on how to use this server:: + + from cheroot import wsgi + + def my_crazy_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type','text/plain')] + start_response(status, response_headers) + return [b'Hello world!'] + + addr = '0.0.0.0', 8070 + server = wsgi.Server(addr, my_crazy_app) + server.start() + +The Cheroot WSGI server can serve as many WSGI applications +as you want in one instance by using a PathInfoDispatcher:: + + path_map = { + '/': my_crazy_app, + '/blog': my_blog_app, + } + d = wsgi.PathInfoDispatcher(path_map) + server = wsgi.Server(addr, d) +""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import sys + +import six +from six.moves import filter + +from . import server +from .workers import threadpool +from ._compat import ntob, bton + + +class Server(server.HTTPServer): + """A subclass of HTTPServer which calls a WSGI application.""" + + wsgi_version = (1, 0) + """The version of WSGI to produce.""" + + def __init__( + self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5, + accepted_queue_size=-1, accepted_queue_timeout=10, + peercreds_enabled=False, peercreds_resolve_enabled=False, + ): + """Initialize WSGI Server instance. + + Args: + bind_addr (tuple): network interface to listen to + wsgi_app (callable): WSGI application callable + numthreads (int): number of threads for WSGI thread pool + server_name (str): web server name to be advertised via + Server HTTP header + max (int): maximum number of worker threads + request_queue_size (int): the 'backlog' arg to + socket.listen(); max queued connections + timeout (int): the timeout in seconds for accepted connections + shutdown_timeout (int): the total time, in seconds, to + wait for worker threads to cleanly exit + accepted_queue_size (int): maximum number of active + requests in queue + accepted_queue_timeout (int): timeout for putting request + into queue + """ + super(Server, self).__init__( + bind_addr, + gateway=wsgi_gateways[self.wsgi_version], + server_name=server_name, + peercreds_enabled=peercreds_enabled, + peercreds_resolve_enabled=peercreds_resolve_enabled, + ) + self.wsgi_app = wsgi_app + self.request_queue_size = request_queue_size + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + self.requests = threadpool.ThreadPool( + self, min=numthreads or 1, max=max, + accepted_queue_size=accepted_queue_size, + accepted_queue_timeout=accepted_queue_timeout) + + @property + def numthreads(self): + """Set minimum number of threads.""" + return self.requests.min + + @numthreads.setter + def numthreads(self, value): + self.requests.min = value + + +class Gateway(server.Gateway): + """A base class to interface HTTPServer with WSGI.""" + + def __init__(self, req): + """Initialize WSGI Gateway instance with request. + + Args: + req (HTTPRequest): current HTTP request + """ + super(Gateway, self).__init__(req) + self.started_response = False + self.env = self.get_environ() + self.remaining_bytes_out = None + + @classmethod + def gateway_map(cls): + """Create a mapping of gateways and their versions. 
+ + Returns: + dict[tuple[int,int],class]: map of gateway version and + corresponding class + + """ + return dict( + (gw.version, gw) + for gw in cls.__subclasses__() + ) + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + raise NotImplementedError + + def respond(self): + """Process the current request. + + From :pep:`333`: + + The start_response callable must not actually transmit + the response headers. Instead, it must store them for the + server or gateway to transmit only after the first + iteration of the application return value that yields + a NON-EMPTY string, or upon the application's first + invocation of the write() callable. + """ + response = self.req.server.wsgi_app(self.env, self.start_response) + try: + for chunk in filter(None, response): + if not isinstance(chunk, six.binary_type): + raise ValueError('WSGI Applications must yield bytes') + self.write(chunk) + finally: + # Send headers if not already sent + self.req.ensure_headers_sent() + if hasattr(response, 'close'): + response.close() + + def start_response(self, status, headers, exc_info=None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError('WSGI start_response called a second ' + 'time with no exc_info.') + self.started_response = True + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.req.sent_headers: + try: + six.reraise(*exc_info) + finally: + exc_info = None + + self.req.status = self._encode_status(status) + + for k, v in headers: + if not isinstance(k, str): + raise TypeError( + 'WSGI response header key %r is not of type str.' % k) + if not isinstance(v, str): + raise TypeError( + 'WSGI response header value %r is not of type str.' % v) + if k.lower() == 'content-length': + self.remaining_bytes_out = int(v) + out_header = ntob(k), ntob(v) + self.req.outheaders.append(out_header) + + return self.write + + @staticmethod + def _encode_status(status): + """Cast status to bytes representation of current Python version. + + According to :pep:`3333`, when using Python 3, the response status + and headers must be bytes masquerading as unicode; that is, they + must be of type "str" but are restricted to code points in the + "latin-1" set. + """ + if six.PY2: + return status + if not isinstance(status, str): + raise TypeError('WSGI response status is not of type str.') + return status.encode('ISO-8859-1') + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). + """ + if not self.started_response: + raise AssertionError('WSGI write called before start_response.') + + chunklen = len(chunk) + rbo = self.remaining_bytes_out + if rbo is not None and chunklen > rbo: + if not self.req.sent_headers: + # Whew. We can send a 500 to the client. + self.req.simple_response( + '500 Internal Server Error', + 'The requested resource returned more bytes than the ' + 'declared Content-Length.') + else: + # Dang. We have probably already sent data. Truncate the chunk + # to fit (so the client doesn't hang) and raise an error later. 
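+                # Note that ``chunklen`` still reflects the original,
+                # untruncated chunk, so the bookkeeping further down drives
+                # ``rbo`` negative and the ValueError at the end of this
+                # method fires once the truncated data has been flushed.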
+ chunk = chunk[:rbo] + + self.req.ensure_headers_sent() + + self.req.write(chunk) + + if rbo is not None: + rbo -= chunklen + if rbo < 0: + raise ValueError( + 'Response body exceeds the declared Content-Length.') + + +class Gateway_10(Gateway): + """A Gateway class to interface HTTPServer with WSGI 1.0.x.""" + + version = 1, 0 + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + req = self.req + req_conn = req.conn + env = { + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, + 'PATH_INFO': bton(req.path), + 'QUERY_STRING': bton(req.qs), + 'REMOTE_ADDR': req_conn.remote_addr or '', + 'REMOTE_PORT': str(req_conn.remote_port or ''), + 'REQUEST_METHOD': bton(req.method), + 'REQUEST_URI': bton(req.uri), + 'SCRIPT_NAME': '', + 'SERVER_NAME': req.server.server_name, + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. + 'SERVER_PROTOCOL': bton(req.request_protocol), + 'SERVER_SOFTWARE': req.server.software, + 'wsgi.errors': sys.stderr, + 'wsgi.input': req.rfile, + 'wsgi.input_terminated': bool(req.chunked_read), + 'wsgi.multiprocess': False, + 'wsgi.multithread': True, + 'wsgi.run_once': False, + 'wsgi.url_scheme': bton(req.scheme), + 'wsgi.version': self.version, + } + + if isinstance(req.server.bind_addr, six.string_types): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. + env['SERVER_PORT'] = '' + try: + env['X_REMOTE_PID'] = str(req_conn.peer_pid) + env['X_REMOTE_UID'] = str(req_conn.peer_uid) + env['X_REMOTE_GID'] = str(req_conn.peer_gid) + + env['X_REMOTE_USER'] = str(req_conn.peer_user) + env['X_REMOTE_GROUP'] = str(req_conn.peer_group) + + env['REMOTE_USER'] = env['X_REMOTE_USER'] + except RuntimeError: + """Unable to retrieve peer creds data. + + Unsupported by current kernel or socket error happened, or + unsupported socket type, or disabled. + """ + else: + env['SERVER_PORT'] = str(req.server.bind_addr[1]) + + # Request headers + env.update( + ('HTTP_' + bton(k).upper().replace('-', '_'), bton(v)) + for k, v in req.inheaders.items() + ) + + # CONTENT_TYPE/CONTENT_LENGTH + ct = env.pop('HTTP_CONTENT_TYPE', None) + if ct is not None: + env['CONTENT_TYPE'] = ct + cl = env.pop('HTTP_CONTENT_LENGTH', None) + if cl is not None: + env['CONTENT_LENGTH'] = cl + + if req.conn.ssl_env: + env.update(req.conn.ssl_env) + + return env + + +class Gateway_u0(Gateway_10): + """A Gateway class to interface HTTPServer with WSGI u.0. + + WSGI u.0 is an experimental protocol, which uses unicode for keys + and values in both Python 2 and Python 3. + """ + + version = 'u', 0 + + def get_environ(self): + """Return a new environ dict targeting the given wsgi.version.""" + req = self.req + env_10 = super(Gateway_u0, self).get_environ() + env = dict(map(self._decode_key, env_10.items())) + + # Request-URI + enc = env.setdefault(six.u('wsgi.url_encoding'), six.u('utf-8')) + try: + env['PATH_INFO'] = req.path.decode(enc) + env['QUERY_STRING'] = req.qs.decode(enc) + except UnicodeDecodeError: + # Fall back to latin 1 so apps can transcode if needed. 
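+            # The WSGI 1.0 environ already carries these values as native
+            # strings, so they are reused unchanged and the application can
+            # transcode them later using the advertised 'wsgi.url_encoding'.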
+ env['wsgi.url_encoding'] = 'ISO-8859-1' + env['PATH_INFO'] = env_10['PATH_INFO'] + env['QUERY_STRING'] = env_10['QUERY_STRING'] + + env.update(map(self._decode_value, env.items())) + + return env + + @staticmethod + def _decode_key(item): + k, v = item + if six.PY2: + k = k.decode('ISO-8859-1') + return k, v + + @staticmethod + def _decode_value(item): + k, v = item + skip_keys = 'REQUEST_URI', 'wsgi.input' + if six.PY3 or not isinstance(v, bytes) or k in skip_keys: + return k, v + return k, v.decode('ISO-8859-1') + + +wsgi_gateways = Gateway.gateway_map() + + +class PathInfoDispatcher: + """A WSGI dispatcher for dispatch based on the PATH_INFO.""" + + def __init__(self, apps): + """Initialize path info WSGI app dispatcher. + + Args: + apps (dict[str,object]|list[tuple[str,object]]): URI prefix + and WSGI app pairs + """ + try: + apps = list(apps.items()) + except AttributeError: + pass + + # Sort the apps by len(path), descending + def by_path_len(app): + return len(app[0]) + apps.sort(key=by_path_len, reverse=True) + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip('/'), a) for p, a in apps] + + def __call__(self, environ, start_response): + """Process incoming WSGI request. + + Ref: :pep:`3333` + + Args: + environ (Mapping): a dict containing WSGI environment variables + start_response (callable): function, which sets response + status and headers + + Returns: + list[bytes]: iterable containing bytes to be returned in + HTTP response body + + """ + path = environ['PATH_INFO'] or '/' + for p, app in self.apps: + # The apps list should be sorted by length, descending. + if path.startswith(p + '/') or path == p: + environ = environ.copy() + environ['SCRIPT_NAME'] = environ['SCRIPT_NAME'] + p + environ['PATH_INFO'] = path[len(p):] + return app(environ, start_response) + + start_response('404 Not Found', [('Content-Type', 'text/plain'), + ('Content-Length', '0')]) + return [''] + + +# compatibility aliases +globals().update( + WSGIServer=Server, + WSGIGateway=Gateway, + WSGIGateway_u0=Gateway_u0, + WSGIGateway_10=Gateway_10, + WSGIPathInfoDispatcher=PathInfoDispatcher, +) diff --git a/libraries/cherrypy/__init__.py b/libraries/cherrypy/__init__.py new file mode 100644 index 00000000..c5925980 --- /dev/null +++ b/libraries/cherrypy/__init__.py @@ -0,0 +1,362 @@ +"""CherryPy is a pythonic, object-oriented HTTP framework. + +CherryPy consists of not one, but four separate API layers. + +The APPLICATION LAYER is the simplest. CherryPy applications are written as +a tree of classes and methods, where each branch in the tree corresponds to +a branch in the URL path. Each method is a 'page handler', which receives +GET and POST params as keyword arguments, and returns or yields the (HTML) +body of the response. The special method name 'index' is used for paths +that end in a slash, and the special method name 'default' is used to +handle multiple paths via a single handler. This layer also includes: + + * the 'exposed' attribute (and cherrypy.expose) + * cherrypy.quickstart() + * _cp_config attributes + * cherrypy.tools (including cherrypy.session) + * cherrypy.url() + +The ENVIRONMENT LAYER is used by developers at all levels. 
It provides +information about the current request and response, plus the application +and server environment, via a (default) set of top-level objects: + + * cherrypy.request + * cherrypy.response + * cherrypy.engine + * cherrypy.server + * cherrypy.tree + * cherrypy.config + * cherrypy.thread_data + * cherrypy.log + * cherrypy.HTTPError, NotFound, and HTTPRedirect + * cherrypy.lib + +The EXTENSION LAYER allows advanced users to construct and share their own +plugins. It consists of: + + * Hook API + * Tool API + * Toolbox API + * Dispatch API + * Config Namespace API + +Finally, there is the CORE LAYER, which uses the core API's to construct +the default components which are available at higher layers. You can think +of the default components as the 'reference implementation' for CherryPy. +Megaframeworks (and advanced users) may replace the default components +with customized or extended components. The core API's are: + + * Application API + * Engine API + * Request API + * Server API + * WSGI API + +These API's are described in the `CherryPy specification +<https://github.com/cherrypy/cherrypy/wiki/CherryPySpec>`_. +""" + +from threading import local as _local + +from ._cperror import ( + HTTPError, HTTPRedirect, InternalRedirect, + NotFound, CherryPyException, +) + +from . import _cpdispatch as dispatch + +from ._cptools import default_toolbox as tools, Tool +from ._helper import expose, popargs, url + +from . import _cprequest, _cpserver, _cptree, _cplogging, _cpconfig + +import cherrypy.lib.httputil as _httputil + +from ._cptree import Application +from . import _cpwsgi as wsgi + +from . import process +try: + from .process import win32 + engine = win32.Win32Bus() + engine.console_control_handler = win32.ConsoleCtrlHandler(engine) + del win32 +except ImportError: + engine = process.bus + +from . import _cpchecker + +__all__ = ( + 'HTTPError', 'HTTPRedirect', 'InternalRedirect', + 'NotFound', 'CherryPyException', + 'dispatch', 'tools', 'Tool', 'Application', + 'wsgi', 'process', 'tree', 'engine', + 'quickstart', 'serving', 'request', 'response', 'thread_data', + 'log', 'expose', 'popargs', 'url', 'config', +) + + +__import__('cherrypy._cptools') +__import__('cherrypy._cprequest') + + +tree = _cptree.Tree() + + +__version__ = '17.4.0' + + +engine.listeners['before_request'] = set() +engine.listeners['after_request'] = set() + + +engine.autoreload = process.plugins.Autoreloader(engine) +engine.autoreload.subscribe() + +engine.thread_manager = process.plugins.ThreadManager(engine) +engine.thread_manager.subscribe() + +engine.signal_handler = process.plugins.SignalHandler(engine) + + +class _HandleSignalsPlugin(object): + """Handle signals from other processes. + + Based on the configured platform handlers above. + """ + + def __init__(self, bus): + self.bus = bus + + def subscribe(self): + """Add the handlers based on the platform.""" + if hasattr(self.bus, 'signal_handler'): + self.bus.signal_handler.subscribe() + if hasattr(self.bus, 'console_control_handler'): + self.bus.console_control_handler.subscribe() + + +engine.signals = _HandleSignalsPlugin(engine) + + +server = _cpserver.Server() +server.subscribe() + + +def quickstart(root=None, script_name='', config=None): + """Mount the given root, start the builtin server (and engine), then block. + + root: an instance of a "controller class" (a collection of page handler + methods) which represents the root of the application. + script_name: a string containing the "mount point" of the application. 
+ This should start with a slash, and be the path portion of the URL + at which to mount the given root. For example, if root.index() will + handle requests to "http://www.example.com:8080/dept/app1/", then + the script_name argument would be "/dept/app1". + + It MUST NOT end in a slash. If the script_name refers to the root + of the URI, it MUST be an empty string (not "/"). + config: a file or dict containing application config. If this contains + a [global] section, those entries will be used in the global + (site-wide) config. + """ + if config: + _global_conf_alias.update(config) + + tree.mount(root, script_name, config) + + engine.signals.subscribe() + engine.start() + engine.block() + + +class _Serving(_local): + """An interface for registering request and response objects. + + Rather than have a separate "thread local" object for the request and + the response, this class works as a single threadlocal container for + both objects (and any others which developers wish to define). In this + way, we can easily dump those objects when we stop/start a new HTTP + conversation, yet still refer to them as module-level globals in a + thread-safe way. + """ + + request = _cprequest.Request(_httputil.Host('127.0.0.1', 80), + _httputil.Host('127.0.0.1', 1111)) + """ + The request object for the current thread. In the main thread, + and any threads which are not receiving HTTP requests, this is None.""" + + response = _cprequest.Response() + """ + The response object for the current thread. In the main thread, + and any threads which are not receiving HTTP requests, this is None.""" + + def load(self, request, response): + self.request = request + self.response = response + + def clear(self): + """Remove all attributes of self.""" + self.__dict__.clear() + + +serving = _Serving() + + +class _ThreadLocalProxy(object): + + __slots__ = ['__attrname__', '__dict__'] + + def __init__(self, attrname): + self.__attrname__ = attrname + + def __getattr__(self, name): + child = getattr(serving, self.__attrname__) + return getattr(child, name) + + def __setattr__(self, name, value): + if name in ('__attrname__', ): + object.__setattr__(self, name, value) + else: + child = getattr(serving, self.__attrname__) + setattr(child, name, value) + + def __delattr__(self, name): + child = getattr(serving, self.__attrname__) + delattr(child, name) + + @property + def __dict__(self): + child = getattr(serving, self.__attrname__) + d = child.__class__.__dict__.copy() + d.update(child.__dict__) + return d + + def __getitem__(self, key): + child = getattr(serving, self.__attrname__) + return child[key] + + def __setitem__(self, key, value): + child = getattr(serving, self.__attrname__) + child[key] = value + + def __delitem__(self, key): + child = getattr(serving, self.__attrname__) + del child[key] + + def __contains__(self, key): + child = getattr(serving, self.__attrname__) + return key in child + + def __len__(self): + child = getattr(serving, self.__attrname__) + return len(child) + + def __nonzero__(self): + child = getattr(serving, self.__attrname__) + return bool(child) + # Python 3 + __bool__ = __nonzero__ + + +# Create request and response object (the same objects will be used +# throughout the entire life of the webserver, but will redirect +# to the "serving" object) +request = _ThreadLocalProxy('request') +response = _ThreadLocalProxy('response') + +# Create thread_data object as a thread-specific all-purpose storage + + +class _ThreadData(_local): + """A container for thread-specific data.""" + + 
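+# Because _ThreadData subclasses threading.local, each worker thread sees its
+# own independent attribute namespace on this single module-level object.
+# Illustrative use (hypothetical names, not part of this module): subscribe a
+# callback on the 'start_thread' channel that sets, say,
+# ``cherrypy.thread_data.db = connect()``, then read
+# ``cherrypy.thread_data.db`` from any page handler running on that same
+# worker thread.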
+thread_data = _ThreadData() + + +# Monkeypatch pydoc to allow help() to go through the threadlocal proxy. +# Jan 2007: no Googleable examples of anyone else replacing pydoc.resolve. +# The only other way would be to change what is returned from type(request) +# and that's not possible in pure Python (you'd have to fake ob_type). +def _cherrypy_pydoc_resolve(thing, forceload=0): + """Given an object or a path to an object, get the object and its name.""" + if isinstance(thing, _ThreadLocalProxy): + thing = getattr(serving, thing.__attrname__) + return _pydoc._builtin_resolve(thing, forceload) + + +try: + import pydoc as _pydoc + _pydoc._builtin_resolve = _pydoc.resolve + _pydoc.resolve = _cherrypy_pydoc_resolve +except ImportError: + pass + + +class _GlobalLogManager(_cplogging.LogManager): + """A site-wide LogManager; routes to app.log or global log as appropriate. + + This :class:`LogManager<cherrypy._cplogging.LogManager>` implements + cherrypy.log() and cherrypy.log.access(). If either + function is called during a request, the message will be sent to the + logger for the current Application. If they are called outside of a + request, the message will be sent to the site-wide logger. + """ + + def __call__(self, *args, **kwargs): + """Log the given message to the app.log or global log. + + Log the given message to the app.log or global + log as appropriate. + """ + # Do NOT use try/except here. See + # https://github.com/cherrypy/cherrypy/issues/945 + if hasattr(request, 'app') and hasattr(request.app, 'log'): + log = request.app.log + else: + log = self + return log.error(*args, **kwargs) + + def access(self): + """Log an access message to the app.log or global log. + + Log the given message to the app.log or global + log as appropriate. + """ + try: + return request.app.log.access() + except AttributeError: + return _cplogging.LogManager.access(self) + + +log = _GlobalLogManager() +# Set a default screen handler on the global log. +log.screen = True +log.error_file = '' +# Using an access file makes CP about 10% slower. Leave off by default. +log.access_file = '' + + +@engine.subscribe('log') +def _buslog(msg, level): + log.error(msg, 'ENGINE', severity=level) + + +# Use _global_conf_alias so quickstart can use 'config' as an arg +# without shadowing cherrypy.config. +config = _global_conf_alias = _cpconfig.Config() +config.defaults = { + 'tools.log_tracebacks.on': True, + 'tools.log_headers.on': True, + 'tools.trailing_slash.on': True, + 'tools.encode.on': True +} +config.namespaces['log'] = lambda k, v: setattr(log, k, v) +config.namespaces['checker'] = lambda k, v: setattr(checker, k, v) +# Must reset to get our defaults applied. +config.reset() + +checker = _cpchecker.Checker() +engine.subscribe('start', checker) diff --git a/libraries/cherrypy/__main__.py b/libraries/cherrypy/__main__.py new file mode 100644 index 00000000..6674f7cb --- /dev/null +++ b/libraries/cherrypy/__main__.py @@ -0,0 +1,5 @@ +"""CherryPy'd cherryd daemon runner.""" +from cherrypy.daemon import run + + +__name__ == '__main__' and run() diff --git a/libraries/cherrypy/_cpchecker.py b/libraries/cherrypy/_cpchecker.py new file mode 100644 index 00000000..39b7c972 --- /dev/null +++ b/libraries/cherrypy/_cpchecker.py @@ -0,0 +1,325 @@ +"""Checker for CherryPy sites and mounted apps.""" +import os +import warnings + +import six +from six.moves import builtins + +import cherrypy + + +class Checker(object): + """A checker for CherryPy sites and their mounted applications. 
+ + When this object is called at engine startup, it executes each + of its own methods whose names start with ``check_``. If you wish + to disable selected checks, simply add a line in your global + config which sets the appropriate method to False:: + + [global] + checker.check_skipped_app_config = False + + You may also dynamically add or replace ``check_*`` methods in this way. + """ + + on = True + """If True (the default), run all checks; if False, turn off all checks.""" + + def __init__(self): + """Initialize Checker instance.""" + self._populate_known_types() + + def __call__(self): + """Run all check_* methods.""" + if self.on: + oldformatwarning = warnings.formatwarning + warnings.formatwarning = self.formatwarning + try: + for name in dir(self): + if name.startswith('check_'): + method = getattr(self, name) + if method and hasattr(method, '__call__'): + method() + finally: + warnings.formatwarning = oldformatwarning + + def formatwarning(self, message, category, filename, lineno, line=None): + """Format a warning.""" + return 'CherryPy Checker:\n%s\n\n' % message + + # This value should be set inside _cpconfig. + global_config_contained_paths = False + + def check_app_config_entries_dont_start_with_script_name(self): + """Check for App config with sections that repeat script_name.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + continue + if sn == '': + continue + sn_atoms = sn.strip('/').split('/') + for key in app.config.keys(): + key_atoms = key.strip('/').split('/') + if key_atoms[:len(sn_atoms)] == sn_atoms: + warnings.warn( + 'The application mounted at %r has config ' + 'entries that start with its script name: %r' % (sn, + key)) + + def check_site_config_entries_in_app_config(self): + """Check for mounted Applications that have site-scoped config.""" + for sn, app in six.iteritems(cherrypy.tree.apps): + if not isinstance(app, cherrypy.Application): + continue + + msg = [] + for section, entries in six.iteritems(app.config): + if section.startswith('/'): + for key, value in six.iteritems(entries): + for n in ('engine.', 'server.', 'tree.', 'checker.'): + if key.startswith(n): + msg.append('[%s] %s = %s' % + (section, key, value)) + if msg: + msg.insert(0, + 'The application mounted at %r contains the ' + 'following config entries, which are only allowed ' + 'in site-wide config. Move them to a [global] ' + 'section and pass them to cherrypy.config.update() ' + 'instead of tree.mount().' % sn) + warnings.warn(os.linesep.join(msg)) + + def check_skipped_app_config(self): + """Check for mounted Applications that have no config.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + msg = 'The Application mounted at %r has an empty config.' % sn + if self.global_config_contained_paths: + msg += (' It looks like the config you passed to ' + 'cherrypy.config.update() contains application-' + 'specific sections. 
You must explicitly pass ' + 'application config via ' + 'cherrypy.tree.mount(..., config=app_config)') + warnings.warn(msg) + return + + def check_app_config_brackets(self): + """Check for App config with extraneous brackets in section names.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + if not app.config: + continue + for key in app.config.keys(): + if key.startswith('[') or key.endswith(']'): + warnings.warn( + 'The application mounted at %r has config ' + 'section names with extraneous brackets: %r. ' + 'Config *files* need brackets; config *dicts* ' + '(e.g. passed to tree.mount) do not.' % (sn, key)) + + def check_static_paths(self): + """Check Application config for incorrect static paths.""" + # Use the dummy Request object in the main thread. + request = cherrypy.request + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + request.app = app + for section in app.config: + # get_resource will populate request.config + request.get_resource(section + '/dummy.html') + conf = request.config.get + + if conf('tools.staticdir.on', False): + msg = '' + root = conf('tools.staticdir.root') + dir = conf('tools.staticdir.dir') + if dir is None: + msg = 'tools.staticdir.dir is not set.' + else: + fulldir = '' + if os.path.isabs(dir): + fulldir = dir + if root: + msg = ('dir is an absolute path, even ' + 'though a root is provided.') + testdir = os.path.join(root, dir[1:]) + if os.path.exists(testdir): + msg += ( + '\nIf you meant to serve the ' + 'filesystem folder at %r, remove the ' + 'leading slash from dir.' % (testdir,)) + else: + if not root: + msg = ( + 'dir is a relative path and ' + 'no root provided.') + else: + fulldir = os.path.join(root, dir) + if not os.path.isabs(fulldir): + msg = ('%r is not an absolute path.' % ( + fulldir,)) + + if fulldir and not os.path.exists(fulldir): + if msg: + msg += '\n' + msg += ('%r (root + dir) is not an existing ' + 'filesystem path.' % fulldir) + + if msg: + warnings.warn('%s\nsection: [%s]\nroot: %r\ndir: %r' + % (msg, section, root, dir)) + + # -------------------------- Compatibility -------------------------- # + obsolete = { + 'server.default_content_type': 'tools.response_headers.headers', + 'log_access_file': 'log.access_file', + 'log_config_options': None, + 'log_file': 'log.error_file', + 'log_file_not_found': None, + 'log_request_headers': 'tools.log_headers.on', + 'log_to_screen': 'log.screen', + 'show_tracebacks': 'request.show_tracebacks', + 'throw_errors': 'request.throw_errors', + 'profiler.on': ('cherrypy.tree.mount(profiler.make_app(' + 'cherrypy.Application(Root())))'), + } + + deprecated = {} + + def _compat(self, config): + """Process config and warn on each obsolete or deprecated entry.""" + for section, conf in config.items(): + if isinstance(conf, dict): + for k in conf: + if k in self.obsolete: + warnings.warn('%r is obsolete. Use %r instead.\n' + 'section: [%s]' % + (k, self.obsolete[k], section)) + elif k in self.deprecated: + warnings.warn('%r is deprecated. Use %r instead.\n' + 'section: [%s]' % + (k, self.deprecated[k], section)) + else: + if section in self.obsolete: + warnings.warn('%r is obsolete. Use %r instead.' + % (section, self.obsolete[section])) + elif section in self.deprecated: + warnings.warn('%r is deprecated. Use %r instead.' 
+ % (section, self.deprecated[section])) + + def check_compatibility(self): + """Process config and warn on each obsolete or deprecated entry.""" + self._compat(cherrypy.config) + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._compat(app.config) + + # ------------------------ Known Namespaces ------------------------ # + extra_config_namespaces = [] + + def _known_ns(self, app): + ns = ['wsgi'] + ns.extend(app.toolboxes) + ns.extend(app.namespaces) + ns.extend(app.request_class.namespaces) + ns.extend(cherrypy.config.namespaces) + ns += self.extra_config_namespaces + + for section, conf in app.config.items(): + is_path_section = section.startswith('/') + if is_path_section and isinstance(conf, dict): + for k in conf: + atoms = k.split('.') + if len(atoms) > 1: + if atoms[0] not in ns: + # Spit out a special warning if a known + # namespace is preceded by "cherrypy." + if atoms[0] == 'cherrypy' and atoms[1] in ns: + msg = ( + 'The config entry %r is invalid; ' + 'try %r instead.\nsection: [%s]' + % (k, '.'.join(atoms[1:]), section)) + else: + msg = ( + 'The config entry %r is invalid, ' + 'because the %r config namespace ' + 'is unknown.\n' + 'section: [%s]' % (k, atoms[0], section)) + warnings.warn(msg) + elif atoms[0] == 'tools': + if atoms[1] not in dir(cherrypy.tools): + msg = ( + 'The config entry %r may be invalid, ' + 'because the %r tool was not found.\n' + 'section: [%s]' % (k, atoms[1], section)) + warnings.warn(msg) + + def check_config_namespaces(self): + """Process config and warn on each unknown config namespace.""" + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._known_ns(app) + + # -------------------------- Config Types -------------------------- # + known_config_types = {} + + def _populate_known_types(self): + b = [x for x in vars(builtins).values() + if type(x) is type(str)] + + def traverse(obj, namespace): + for name in dir(obj): + # Hack for 3.2's warning about body_params + if name == 'body_params': + continue + vtype = type(getattr(obj, name, None)) + if vtype in b: + self.known_config_types[namespace + '.' + name] = vtype + + traverse(cherrypy.request, 'request') + traverse(cherrypy.response, 'response') + traverse(cherrypy.server, 'server') + traverse(cherrypy.engine, 'engine') + traverse(cherrypy.log, 'log') + + def _known_types(self, config): + msg = ('The config entry %r in section %r is of type %r, ' + 'which does not match the expected type %r.') + + for section, conf in config.items(): + if not isinstance(conf, dict): + conf = {section: conf} + for k, v in conf.items(): + if v is not None: + expected_type = self.known_config_types.get(k, None) + vtype = type(v) + if expected_type and vtype != expected_type: + warnings.warn(msg % (k, section, vtype.__name__, + expected_type.__name__)) + + def check_config_types(self): + """Assert that config values are of the same type as default values.""" + self._known_types(cherrypy.config) + for sn, app in cherrypy.tree.apps.items(): + if not isinstance(app, cherrypy.Application): + continue + self._known_types(app.config) + + # -------------------- Specific config warnings -------------------- # + def check_localhost(self): + """Warn if any socket_host is 'localhost'. 
See #711.""" + for k, v in cherrypy.config.items(): + if k == 'server.socket_host' and v == 'localhost': + warnings.warn("The use of 'localhost' as a socket host can " + 'cause problems on newer systems, since ' + "'localhost' can map to either an IPv4 or an " + "IPv6 address. You should use '127.0.0.1' " + "or '[::1]' instead.") diff --git a/libraries/cherrypy/_cpcompat.py b/libraries/cherrypy/_cpcompat.py new file mode 100644 index 00000000..f454505c --- /dev/null +++ b/libraries/cherrypy/_cpcompat.py @@ -0,0 +1,162 @@ +"""Compatibility code for using CherryPy with various versions of Python. + +To retain compatibility with older Python versions, this module provides a +useful abstraction over the differences between Python versions, sometimes by +preferring a newer idiom, sometimes an older one, and sometimes a custom one. + +In particular, Python 2 uses str and '' for byte strings, while Python 3 +uses str and '' for unicode strings. We will call each of these the 'native +string' type for each version. Because of this major difference, this module +provides +two functions: 'ntob', which translates native strings (of type 'str') into +byte strings regardless of Python version, and 'ntou', which translates native +strings to unicode strings. + +Try not to use the compatibility functions 'ntob', 'ntou', 'tonative'. +They were created with Python 2.3-2.5 compatibility in mind. +Instead, use unicode literals (from __future__) and bytes literals +and their .encode/.decode methods as needed. +""" + +import re +import sys +import threading + +import six +from six.moves import urllib + + +if six.PY3: + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + assert_native(n) + # In Python 3, the native string type is unicode + return n.encode(encoding) + + def ntou(n, encoding='ISO-8859-1'): + """Return the given native string as a unicode string with the given + encoding. + """ + assert_native(n) + # In Python 3, the native string type is unicode + return n + + def tonative(n, encoding='ISO-8859-1'): + """Return the given string as a native string in the given encoding.""" + # In Python 3, the native string type is unicode + if isinstance(n, bytes): + return n.decode(encoding) + return n +else: + # Python 2 + def ntob(n, encoding='ISO-8859-1'): + """Return the given native string as a byte string in the given + encoding. + """ + assert_native(n) + # In Python 2, the native string type is bytes. Assume it's already + # in the given encoding, which for ISO-8859-1 is almost always what + # was intended. + return n + + def ntou(n, encoding='ISO-8859-1'): + """Return the given native string as a unicode string with the given + encoding. + """ + assert_native(n) + # In Python 2, the native string type is bytes. + # First, check for the special encoding 'escape'. The test suite uses + # this to signal that it wants to pass a string with embedded \uXXXX + # escapes, but without having to prefix it with u'' for Python 2, + # but no prefix for Python 3. + if encoding == 'escape': + return six.text_type( # unicode for Python 2 + re.sub(r'\\u([0-9a-zA-Z]{4})', + lambda m: six.unichr(int(m.group(1), 16)), + n.decode('ISO-8859-1'))) + # Assume it's already in the given encoding, which for ISO-8859-1 + # is almost always what was intended. + return n.decode(encoding) + + def tonative(n, encoding='ISO-8859-1'): + """Return the given string as a native string in the given encoding.""" + # In Python 2, the native string type is bytes. 
+ if isinstance(n, six.text_type): # unicode for Python 2 + return n.encode(encoding) + return n + + +def assert_native(n): + if not isinstance(n, str): + raise TypeError('n must be a native str (got %s)' % type(n).__name__) + + +# Some platforms don't expose HTTPSConnection, so handle it separately +HTTPSConnection = getattr(six.moves.http_client, 'HTTPSConnection', None) + + +def _unquote_plus_compat(string, encoding='utf-8', errors='replace'): + return urllib.parse.unquote_plus(string).decode(encoding, errors) + + +def _unquote_compat(string, encoding='utf-8', errors='replace'): + return urllib.parse.unquote(string).decode(encoding, errors) + + +def _quote_compat(string, encoding='utf-8', errors='replace'): + return urllib.parse.quote(string.encode(encoding, errors)) + + +unquote_plus = urllib.parse.unquote_plus if six.PY3 else _unquote_plus_compat +unquote = urllib.parse.unquote if six.PY3 else _unquote_compat +quote = urllib.parse.quote if six.PY3 else _quote_compat + +try: + # Prefer simplejson + import simplejson as json +except ImportError: + import json + + +json_decode = json.JSONDecoder().decode +_json_encode = json.JSONEncoder().iterencode + + +if six.PY3: + # Encode to bytes on Python 3 + def json_encode(value): + for chunk in _json_encode(value): + yield chunk.encode('utf-8') +else: + json_encode = _json_encode + + +text_or_bytes = six.text_type, bytes + + +if sys.version_info >= (3, 3): + Timer = threading.Timer + Event = threading.Event +else: + # Python 3.2 and earlier + Timer = threading._Timer + Event = threading._Event + +# html module come in 3.2 version +try: + from html import escape +except ImportError: + from cgi import escape + + +# html module needed the argument quote=False because in cgi the default +# is False. With quote=True the results differ. + +def escape_html(s, escape_quote=False): + """Replace special characters "&", "<" and ">" to HTML-safe sequences. + + When escape_quote=True, escape (') and (") chars. + """ + return escape(s, quote=escape_quote) diff --git a/libraries/cherrypy/_cpconfig.py b/libraries/cherrypy/_cpconfig.py new file mode 100644 index 00000000..79d9d911 --- /dev/null +++ b/libraries/cherrypy/_cpconfig.py @@ -0,0 +1,300 @@ +""" +Configuration system for CherryPy. + +Configuration in CherryPy is implemented via dictionaries. Keys are strings +which name the mapped value, which may be of any type. + + +Architecture +------------ + +CherryPy Requests are part of an Application, which runs in a global context, +and configuration data may apply to any of those three scopes: + +Global + Configuration entries which apply everywhere are stored in + cherrypy.config. + +Application + Entries which apply to each mounted application are stored + on the Application object itself, as 'app.config'. This is a two-level + dict where each key is a path, or "relative URL" (for example, "/" or + "/path/to/my/page"), and each value is a config dict. Usually, this + data is provided in the call to tree.mount(root(), config=conf), + although you may also use app.merge(conf). + +Request + Each Request object possesses a single 'Request.config' dict. + Early in the request process, this dict is populated by merging global + config entries, Application entries (whose path equals or is a parent + of Request.path_info), and any config acquired while looking up the + page handler (see next). + + +Declaration +----------- + +Configuration data may be supplied as a Python dictionary, as a filename, +or as an open file object. 
When you supply a filename or file, CherryPy +uses Python's builtin ConfigParser; you declare Application config by +writing each path as a section header:: + + [/path/to/my/page] + request.stream = True + +To declare global configuration entries, place them in a [global] section. + +You may also declare config entries directly on the classes and methods +(page handlers) that make up your CherryPy application via the ``_cp_config`` +attribute, set with the ``cherrypy.config`` decorator. For example:: + + @cherrypy.config(**{'tools.gzip.on': True}) + class Demo: + + @cherrypy.expose + @cherrypy.config(**{'request.show_tracebacks': False}) + def index(self): + return "Hello world" + +.. note:: + + This behavior is only guaranteed for the default dispatcher. + Other dispatchers may have different restrictions on where + you can attach config attributes. + + +Namespaces +---------- + +Configuration keys are separated into namespaces by the first "." in the key. +Current namespaces: + +engine + Controls the 'application engine', including autoreload. + These can only be declared in the global config. + +tree + Grafts cherrypy.Application objects onto cherrypy.tree. + These can only be declared in the global config. + +hooks + Declares additional request-processing functions. + +log + Configures the logging for each application. + These can only be declared in the global or / config. + +request + Adds attributes to each Request. + +response + Adds attributes to each Response. + +server + Controls the default HTTP server via cherrypy.server. + These can only be declared in the global config. + +tools + Runs and configures additional request-processing packages. + +wsgi + Adds WSGI middleware to an Application's "pipeline". + These can only be declared in the app's root config ("/"). + +checker + Controls the 'checker', which looks for common errors in + app state (including config) when the engine starts. + Global config only. + +The only key that does not exist in a namespace is the "environment" entry. +This special entry 'imports' other config entries from a template stored in +cherrypy._cpconfig.environments[environment]. It only applies to the global +config, and only when you use cherrypy.config.update. + +You can define your own namespaces to be called at the Global, Application, +or Request level, by adding a named handler to cherrypy.config.namespaces, +app.namespaces, or app.request_class.namespaces. The name can +be any string, and the handler must be either a callable or a (Python 2.5 +style) context manager. +""" + +import cherrypy +from cherrypy._cpcompat import text_or_bytes +from cherrypy.lib import reprconf + + +def _if_filename_register_autoreload(ob): + """Register for autoreload if ob is a string (presumed filename).""" + is_filename = isinstance(ob, text_or_bytes) + is_filename and cherrypy.engine.autoreload.files.add(ob) + + +def merge(base, other): + """Merge one app config (from a dict, file, or filename) into another. + + If the given config is a filename, it will be appended to + the list of files to monitor for "autoreload" changes. + """ + _if_filename_register_autoreload(other) + + # Load other into base + for section, value_map in reprconf.Parser.load(other).items(): + if not isinstance(value_map, dict): + raise ValueError( + 'Application config must include section headers, but the ' + "config you tried to merge doesn't have any sections. 
" + 'Wrap your config in another dict with paths as section ' + "headers, for example: {'/': config}.") + base.setdefault(section, {}).update(value_map) + + +class Config(reprconf.Config): + """The 'global' configuration data for the entire CherryPy process.""" + + def update(self, config): + """Update self from a dict, file or filename.""" + _if_filename_register_autoreload(config) + super(Config, self).update(config) + + def _apply(self, config): + """Update self from a dict.""" + if isinstance(config.get('global'), dict): + if len(config) > 1: + cherrypy.checker.global_config_contained_paths = True + config = config['global'] + if 'tools.staticdir.dir' in config: + config['tools.staticdir.section'] = 'global' + super(Config, self)._apply(config) + + @staticmethod + def __call__(**kwargs): + """Decorate for page handlers to set _cp_config.""" + def tool_decorator(f): + _Vars(f).setdefault('_cp_config', {}).update(kwargs) + return f + return tool_decorator + + +class _Vars(object): + """Adapter allowing setting a default attribute on a function or class.""" + + def __init__(self, target): + self.target = target + + def setdefault(self, key, default): + if not hasattr(self.target, key): + setattr(self.target, key, default) + return getattr(self.target, key) + + +# Sphinx begin config.environments +Config.environments = environments = { + 'staging': { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + }, + 'production': { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + 'log.screen': False, + }, + 'embedded': { + # For use with CherryPy embedded in another deployment stack. + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': False, + 'request.show_mismatched_params': False, + 'log.screen': False, + 'engine.SIGHUP': None, + 'engine.SIGTERM': None, + }, + 'test_suite': { + 'engine.autoreload.on': False, + 'checker.on': False, + 'tools.log_headers.on': False, + 'request.show_tracebacks': True, + 'request.show_mismatched_params': True, + 'log.screen': False, + }, +} +# Sphinx end config.environments + + +def _server_namespace_handler(k, v): + """Config handler for the "server" namespace.""" + atoms = k.split('.', 1) + if len(atoms) > 1: + # Special-case config keys of the form 'server.servername.socket_port' + # to configure additional HTTP servers. + if not hasattr(cherrypy, 'servers'): + cherrypy.servers = {} + + servername, k = atoms + if servername not in cherrypy.servers: + from cherrypy import _cpserver + cherrypy.servers[servername] = _cpserver.Server() + # On by default, but 'on = False' can unsubscribe it (see below). + cherrypy.servers[servername].subscribe() + + if k == 'on': + if v: + cherrypy.servers[servername].subscribe() + else: + cherrypy.servers[servername].unsubscribe() + else: + setattr(cherrypy.servers[servername], k, v) + else: + setattr(cherrypy.server, k, v) + + +Config.namespaces['server'] = _server_namespace_handler + + +def _engine_namespace_handler(k, v): + """Config handler for the "engine" namespace.""" + engine = cherrypy.engine + + if k in {'SIGHUP', 'SIGTERM'}: + engine.subscribe(k, v) + return + + if '.' 
in k: + plugin, attrname = k.split('.', 1) + try: + plugin = getattr(engine, plugin) + except Exception as error: + setattr(engine, k, v) + else: + op = 'subscribe' if v else 'unsubscribe' + sub_unsub = getattr(plugin, op, None) + if attrname == 'on' and callable(sub_unsub): + sub_unsub() + return + setattr(plugin, attrname, v) + else: + setattr(engine, k, v) + + +Config.namespaces['engine'] = _engine_namespace_handler + + +def _tree_namespace_handler(k, v): + """Namespace handler for the 'tree' config namespace.""" + if isinstance(v, dict): + for script_name, app in v.items(): + cherrypy.tree.graft(app, script_name) + msg = 'Mounted: %s on %s' % (app, script_name or '/') + cherrypy.engine.log(msg) + else: + cherrypy.tree.graft(v, v.script_name) + cherrypy.engine.log('Mounted: %s on %s' % (v, v.script_name or '/')) + + +Config.namespaces['tree'] = _tree_namespace_handler diff --git a/libraries/cherrypy/_cpdispatch.py b/libraries/cherrypy/_cpdispatch.py new file mode 100644 index 00000000..83eb79cb --- /dev/null +++ b/libraries/cherrypy/_cpdispatch.py @@ -0,0 +1,686 @@ +"""CherryPy dispatchers. + +A 'dispatcher' is the object which looks up the 'page handler' callable +and collects config for the current request based on the path_info, other +request attributes, and the application architecture. The core calls the +dispatcher as early as possible, passing it a 'path_info' argument. + +The default dispatcher discovers the page handler by matching path_info +to a hierarchical arrangement of objects, starting at request.app.root. +""" + +import string +import sys +import types +try: + classtype = (type, types.ClassType) +except AttributeError: + classtype = type + +import cherrypy + + +class PageHandler(object): + + """Callable which sets response.body.""" + + def __init__(self, callable, *args, **kwargs): + self.callable = callable + self.args = args + self.kwargs = kwargs + + @property + def args(self): + """The ordered args should be accessible from post dispatch hooks.""" + return cherrypy.serving.request.args + + @args.setter + def args(self, args): + cherrypy.serving.request.args = args + return cherrypy.serving.request.args + + @property + def kwargs(self): + """The named kwargs should be accessible from post dispatch hooks.""" + return cherrypy.serving.request.kwargs + + @kwargs.setter + def kwargs(self, kwargs): + cherrypy.serving.request.kwargs = kwargs + return cherrypy.serving.request.kwargs + + def __call__(self): + try: + return self.callable(*self.args, **self.kwargs) + except TypeError: + x = sys.exc_info()[1] + try: + test_callable_spec(self.callable, self.args, self.kwargs) + except cherrypy.HTTPError: + raise sys.exc_info()[1] + except Exception: + raise x + raise + + +def test_callable_spec(callable, callable_args, callable_kwargs): + """ + Inspect callable and test to see if the given args are suitable for it. + + When an error occurs during the handler's invoking stage there are 2 + erroneous cases: + 1. Too many parameters passed to a function which doesn't define + one of *args or **kwargs. + 2. Too little parameters are passed to the function. + + There are 3 sources of parameters to a cherrypy handler. + 1. query string parameters are passed as keyword parameters to the + handler. + 2. body parameters are also passed as keyword parameters. + 3. when partial matching occurs, the final path atoms are passed as + positional args. + Both the query string and path atoms are part of the URI. If they are + incorrect, then a 404 Not Found should be raised. 
Conversely the body + parameters are part of the request; if they are invalid a 400 Bad Request. + """ + show_mismatched_params = getattr( + cherrypy.serving.request, 'show_mismatched_params', False) + try: + (args, varargs, varkw, defaults) = getargspec(callable) + except TypeError: + if isinstance(callable, object) and hasattr(callable, '__call__'): + (args, varargs, varkw, + defaults) = getargspec(callable.__call__) + else: + # If it wasn't one of our own types, re-raise + # the original error + raise + + if args and ( + # For callable objects, which have a __call__(self) method + hasattr(callable, '__call__') or + # For normal methods + inspect.ismethod(callable) + ): + # Strip 'self' + args = args[1:] + + arg_usage = dict([(arg, 0,) for arg in args]) + vararg_usage = 0 + varkw_usage = 0 + extra_kwargs = set() + + for i, value in enumerate(callable_args): + try: + arg_usage[args[i]] += 1 + except IndexError: + vararg_usage += 1 + + for key in callable_kwargs.keys(): + try: + arg_usage[key] += 1 + except KeyError: + varkw_usage += 1 + extra_kwargs.add(key) + + # figure out which args have defaults. + args_with_defaults = args[-len(defaults or []):] + for i, val in enumerate(defaults or []): + # Defaults take effect only when the arg hasn't been used yet. + if arg_usage[args_with_defaults[i]] == 0: + arg_usage[args_with_defaults[i]] += 1 + + missing_args = [] + multiple_args = [] + for key, usage in arg_usage.items(): + if usage == 0: + missing_args.append(key) + elif usage > 1: + multiple_args.append(key) + + if missing_args: + # In the case where the method allows body arguments + # there are 3 potential errors: + # 1. not enough query string parameters -> 404 + # 2. not enough body parameters -> 400 + # 3. not enough path parts (partial matches) -> 404 + # + # We can't actually tell which case it is, + # so I'm raising a 404 because that covers 2/3 of the + # possibilities + # + # In the case where the method does not allow body + # arguments it's definitely a 404. 
+ message = None + if show_mismatched_params: + message = 'Missing parameters: %s' % ','.join(missing_args) + raise cherrypy.HTTPError(404, message=message) + + # the extra positional arguments come from the path - 404 Not Found + if not varargs and vararg_usage > 0: + raise cherrypy.HTTPError(404) + + body_params = cherrypy.serving.request.body.params or {} + body_params = set(body_params.keys()) + qs_params = set(callable_kwargs.keys()) - body_params + + if multiple_args: + if qs_params.intersection(set(multiple_args)): + # If any of the multiple parameters came from the query string then + # it's a 404 Not Found + error = 404 + else: + # Otherwise it's a 400 Bad Request + error = 400 + + message = None + if show_mismatched_params: + message = 'Multiple values for parameters: '\ + '%s' % ','.join(multiple_args) + raise cherrypy.HTTPError(error, message=message) + + if not varkw and varkw_usage > 0: + + # If there were extra query string parameters, it's a 404 Not Found + extra_qs_params = set(qs_params).intersection(extra_kwargs) + if extra_qs_params: + message = None + if show_mismatched_params: + message = 'Unexpected query string '\ + 'parameters: %s' % ', '.join(extra_qs_params) + raise cherrypy.HTTPError(404, message=message) + + # If there were any extra body parameters, it's a 400 Not Found + extra_body_params = set(body_params).intersection(extra_kwargs) + if extra_body_params: + message = None + if show_mismatched_params: + message = 'Unexpected body parameters: '\ + '%s' % ', '.join(extra_body_params) + raise cherrypy.HTTPError(400, message=message) + + +try: + import inspect +except ImportError: + def test_callable_spec(callable, args, kwargs): # noqa: F811 + return None +else: + getargspec = inspect.getargspec + # Python 3 requires using getfullargspec if + # keyword-only arguments are present + if hasattr(inspect, 'getfullargspec'): + def getargspec(callable): + return inspect.getfullargspec(callable)[:4] + + +class LateParamPageHandler(PageHandler): + + """When passing cherrypy.request.params to the page handler, we do not + want to capture that dict too early; we want to give tools like the + decoding tool a chance to modify the params dict in-between the lookup + of the handler and the actual calling of the handler. This subclass + takes that into account, and allows request.params to be 'bound late' + (it's more complicated than that, but that's the effect). + """ + + @property + def kwargs(self): + """Page handler kwargs (with cherrypy.request.params copied in).""" + kwargs = cherrypy.serving.request.params.copy() + if self._kwargs: + kwargs.update(self._kwargs) + return kwargs + + @kwargs.setter + def kwargs(self, kwargs): + cherrypy.serving.request.kwargs = kwargs + self._kwargs = kwargs + + +if sys.version_info < (3, 0): + punctuation_to_underscores = string.maketrans( + string.punctuation, '_' * len(string.punctuation)) + + def validate_translator(t): + if not isinstance(t, str) or len(t) != 256: + raise ValueError( + 'The translate argument must be a str of len 256.') +else: + punctuation_to_underscores = str.maketrans( + string.punctuation, '_' * len(string.punctuation)) + + def validate_translator(t): + if not isinstance(t, dict): + raise ValueError('The translate argument must be a dict.') + + +class Dispatcher(object): + + """CherryPy Dispatcher which walks a tree of objects to find a handler. 
+ + The tree is rooted at cherrypy.request.app.root, and each hierarchical + component in the path_info argument is matched to a corresponding nested + attribute of the root object. Matching handlers must have an 'exposed' + attribute which evaluates to True. The special method name "index" + matches a URI which ends in a slash ("/"). The special method name + "default" may match a portion of the path_info (but only when no longer + substring of the path_info matches some other object). + + This is the default, built-in dispatcher for CherryPy. + """ + + dispatch_method_name = '_cp_dispatch' + """ + The name of the dispatch method that nodes may optionally implement + to provide their own dynamic dispatch algorithm. + """ + + def __init__(self, dispatch_method_name=None, + translate=punctuation_to_underscores): + validate_translator(translate) + self.translate = translate + if dispatch_method_name: + self.dispatch_method_name = dispatch_method_name + + def __call__(self, path_info): + """Set handler and config for the current request.""" + request = cherrypy.serving.request + func, vpath = self.find_handler(path_info) + + if func: + # Decode any leftover %2F in the virtual_path atoms. + vpath = [x.replace('%2F', '/') for x in vpath] + request.handler = LateParamPageHandler(func, *vpath) + else: + request.handler = cherrypy.NotFound() + + def find_handler(self, path): + """Return the appropriate page handler, plus any virtual path. + + This will return two objects. The first will be a callable, + which can be used to generate page output. Any parameters from + the query string or request body will be sent to that callable + as keyword arguments. + + The callable is found by traversing the application's tree, + starting from cherrypy.request.app.root, and matching path + components to successive objects in the tree. For example, the + URL "/path/to/handler" might return root.path.to.handler. + + The second object returned will be a list of names which are + 'virtual path' components: parts of the URL which are dynamic, + and were not used when looking up the handler. + These virtual path components are passed to the handler as + positional arguments. + """ + request = cherrypy.serving.request + app = request.app + root = app.root + dispatch_name = self.dispatch_method_name + + # Get config for the root object/path. + fullpath = [x for x in path.strip('/').split('/') if x] + ['index'] + fullpath_len = len(fullpath) + segleft = fullpath_len + nodeconf = {} + if hasattr(root, '_cp_config'): + nodeconf.update(root._cp_config) + if '/' in app.config: + nodeconf.update(app.config['/']) + object_trail = [['root', root, nodeconf, segleft]] + + node = root + iternames = fullpath[:] + while iternames: + name = iternames[0] + # map to legal Python identifiers (e.g. replace '.' with '_') + objname = name.translate(self.translate) + + nodeconf = {} + subnode = getattr(node, objname, None) + pre_len = len(iternames) + if subnode is None: + dispatch = getattr(node, dispatch_name, None) + if dispatch and hasattr(dispatch, '__call__') and not \ + getattr(dispatch, 'exposed', False) and \ + pre_len > 1: + # Don't expose the hidden 'index' token to _cp_dispatch + # We skip this if pre_len == 1 since it makes no sense + # to call a dispatcher when we have no tokens left. + index_name = iternames.pop() + subnode = dispatch(vpath=iternames) + iternames.append(index_name) + else: + # We didn't find a path, but keep processing in case there + # is a default() handler. 
+ iternames.pop(0) + else: + # We found the path, remove the vpath entry + iternames.pop(0) + segleft = len(iternames) + if segleft > pre_len: + # No path segment was removed. Raise an error. + raise cherrypy.CherryPyException( + 'A vpath segment was added. Custom dispatchers may only ' + 'remove elements. While trying to process ' + '{0} in {1}'.format(name, fullpath) + ) + elif segleft == pre_len: + # Assume that the handler used the current path segment, but + # did not pop it. This allows things like + # return getattr(self, vpath[0], None) + iternames.pop(0) + segleft -= 1 + node = subnode + + if node is not None: + # Get _cp_config attached to this node. + if hasattr(node, '_cp_config'): + nodeconf.update(node._cp_config) + + # Mix in values from app.config for this path. + existing_len = fullpath_len - pre_len + if existing_len != 0: + curpath = '/' + '/'.join(fullpath[0:existing_len]) + else: + curpath = '' + new_segs = fullpath[fullpath_len - pre_len:fullpath_len - segleft] + for seg in new_segs: + curpath += '/' + seg + if curpath in app.config: + nodeconf.update(app.config[curpath]) + + object_trail.append([name, node, nodeconf, segleft]) + + def set_conf(): + """Collapse all object_trail config into cherrypy.request.config. + """ + base = cherrypy.config.copy() + # Note that we merge the config from each node + # even if that node was None. + for name, obj, conf, segleft in object_trail: + base.update(conf) + if 'tools.staticdir.dir' in conf: + base['tools.staticdir.section'] = '/' + \ + '/'.join(fullpath[0:fullpath_len - segleft]) + return base + + # Try successive objects (reverse order) + num_candidates = len(object_trail) - 1 + for i in range(num_candidates, -1, -1): + + name, candidate, nodeconf, segleft = object_trail[i] + if candidate is None: + continue + + # Try a "default" method on the current leaf. + if hasattr(candidate, 'default'): + defhandler = candidate.default + if getattr(defhandler, 'exposed', False): + # Insert any extra _cp_config from the default handler. + conf = getattr(defhandler, '_cp_config', {}) + object_trail.insert( + i + 1, ['default', defhandler, conf, segleft]) + request.config = set_conf() + # See https://github.com/cherrypy/cherrypy/issues/613 + request.is_index = path.endswith('/') + return defhandler, fullpath[fullpath_len - segleft:-1] + + # Uncomment the next line to restrict positional params to + # "default". + # if i < num_candidates - 2: continue + + # Try the current leaf. + if getattr(candidate, 'exposed', False): + request.config = set_conf() + if i == num_candidates: + # We found the extra ".index". Mark request so tools + # can redirect if path_info has no trailing slash. + request.is_index = True + else: + # We're not at an 'index' handler. Mark request so tools + # can redirect if path_info has NO trailing slash. + # Note that this also includes handlers which take + # positional parameters (virtual paths). + request.is_index = False + return candidate, fullpath[fullpath_len - segleft:-1] + + # We didn't find anything + request.config = set_conf() + return None, [] + + +class MethodDispatcher(Dispatcher): + + """Additional dispatch based on cherrypy.request.method.upper(). + + Methods named GET, POST, etc will be called on an exposed class. + The method names must be all caps; the appropriate Allow header + will be output showing all capitalized method names as allowable + HTTP verbs. + + Note that the containing class must be exposed, not the methods. 
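+
+    A minimal illustrative sketch of such a resource (the class name,
+    mount path and returned strings are hypothetical)::
+
+        class Resource(object):
+
+            exposed = True
+
+            def GET(self):
+                return 'read'
+
+            def PUT(self, **kwargs):
+                return 'updated'
+
+        dispatcher = cherrypy.dispatch.MethodDispatcher()
+        conf = {'/': {'request.dispatch': dispatcher}}
+        cherrypy.tree.mount(Resource(), '/resource', conf)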
+ """ + + def __call__(self, path_info): + """Set handler and config for the current request.""" + request = cherrypy.serving.request + resource, vpath = self.find_handler(path_info) + + if resource: + # Set Allow header + avail = [m for m in dir(resource) if m.isupper()] + if 'GET' in avail and 'HEAD' not in avail: + avail.append('HEAD') + avail.sort() + cherrypy.serving.response.headers['Allow'] = ', '.join(avail) + + # Find the subhandler + meth = request.method.upper() + func = getattr(resource, meth, None) + if func is None and meth == 'HEAD': + func = getattr(resource, 'GET', None) + if func: + # Grab any _cp_config on the subhandler. + if hasattr(func, '_cp_config'): + request.config.update(func._cp_config) + + # Decode any leftover %2F in the virtual_path atoms. + vpath = [x.replace('%2F', '/') for x in vpath] + request.handler = LateParamPageHandler(func, *vpath) + else: + request.handler = cherrypy.HTTPError(405) + else: + request.handler = cherrypy.NotFound() + + +class RoutesDispatcher(object): + + """A Routes based dispatcher for CherryPy.""" + + def __init__(self, full_result=False, **mapper_options): + """ + Routes dispatcher + + Set full_result to True if you wish the controller + and the action to be passed on to the page handler + parameters. By default they won't be. + """ + import routes + self.full_result = full_result + self.controllers = {} + self.mapper = routes.Mapper(**mapper_options) + self.mapper.controller_scan = self.controllers.keys + + def connect(self, name, route, controller, **kwargs): + self.controllers[name] = controller + self.mapper.connect(name, route, controller=name, **kwargs) + + def redirect(self, url): + raise cherrypy.HTTPRedirect(url) + + def __call__(self, path_info): + """Set handler and config for the current request.""" + func = self.find_handler(path_info) + if func: + cherrypy.serving.request.handler = LateParamPageHandler(func) + else: + cherrypy.serving.request.handler = cherrypy.NotFound() + + def find_handler(self, path_info): + """Find the right page handler, and set request.config.""" + import routes + + request = cherrypy.serving.request + + config = routes.request_config() + config.mapper = self.mapper + if hasattr(request, 'wsgi_environ'): + config.environ = request.wsgi_environ + config.host = request.headers.get('Host', None) + config.protocol = request.scheme + config.redirect = self.redirect + + result = self.mapper.match(path_info) + + config.mapper_dict = result + params = {} + if result: + params = result.copy() + if not self.full_result: + params.pop('controller', None) + params.pop('action', None) + request.params.update(params) + + # Get config for the root object/path. + request.config = base = cherrypy.config.copy() + curpath = '' + + def merge(nodeconf): + if 'tools.staticdir.dir' in nodeconf: + nodeconf['tools.staticdir.section'] = curpath or '/' + base.update(nodeconf) + + app = request.app + root = app.root + if hasattr(root, '_cp_config'): + merge(root._cp_config) + if '/' in app.config: + merge(app.config['/']) + + # Mix in values from app.config. + atoms = [x for x in path_info.split('/') if x] + if atoms: + last = atoms.pop() + else: + last = None + for atom in atoms: + curpath = '/'.join((curpath, atom)) + if curpath in app.config: + merge(app.config[curpath]) + + handler = None + if result: + controller = result.get('controller') + controller = self.controllers.get(controller, controller) + if controller: + if isinstance(controller, classtype): + controller = controller() + # Get config from the controller. 
+ if hasattr(controller, '_cp_config'): + merge(controller._cp_config) + + action = result.get('action') + if action is not None: + handler = getattr(controller, action, None) + # Get config from the handler + if hasattr(handler, '_cp_config'): + merge(handler._cp_config) + else: + handler = controller + + # Do the last path atom here so it can + # override the controller's _cp_config. + if last: + curpath = '/'.join((curpath, last)) + if curpath in app.config: + merge(app.config[curpath]) + + return handler + + +def XMLRPCDispatcher(next_dispatcher=Dispatcher()): + from cherrypy.lib import xmlrpcutil + + def xmlrpc_dispatch(path_info): + path_info = xmlrpcutil.patched_path(path_info) + return next_dispatcher(path_info) + return xmlrpc_dispatch + + +def VirtualHost(next_dispatcher=Dispatcher(), use_x_forwarded_host=True, + **domains): + """ + Select a different handler based on the Host header. + + This can be useful when running multiple sites within one CP server. + It allows several domains to point to different parts of a single + website structure. For example:: + + http://www.domain.example -> root + http://www.domain2.example -> root/domain2/ + http://www.domain2.example:443 -> root/secure + + can be accomplished via the following config:: + + [/] + request.dispatch = cherrypy.dispatch.VirtualHost( + **{'www.domain2.example': '/domain2', + 'www.domain2.example:443': '/secure', + }) + + next_dispatcher + The next dispatcher object in the dispatch chain. + The VirtualHost dispatcher adds a prefix to the URL and calls + another dispatcher. Defaults to cherrypy.dispatch.Dispatcher(). + + use_x_forwarded_host + If True (the default), any "X-Forwarded-Host" + request header will be used instead of the "Host" header. This + is commonly added by HTTP servers (such as Apache) when proxying. + + ``**domains`` + A dict of {host header value: virtual prefix} pairs. + The incoming "Host" request header is looked up in this dict, + and, if a match is found, the corresponding "virtual prefix" + value will be prepended to the URL path before calling the + next dispatcher. Note that you often need separate entries + for "example.com" and "www.example.com". In addition, "Host" + headers may contain the port number. + """ + from cherrypy.lib import httputil + + def vhost_dispatch(path_info): + request = cherrypy.serving.request + header = request.headers.get + + domain = header('Host', '') + if use_x_forwarded_host: + domain = header('X-Forwarded-Host', domain) + + prefix = domains.get(domain, '') + if prefix: + path_info = httputil.urljoin(prefix, path_info) + + result = next_dispatcher(path_info) + + # Touch up staticdir config. See + # https://github.com/cherrypy/cherrypy/issues/614. + section = request.config.get('tools.staticdir.section') + if section: + section = section[len(prefix):] + request.config['tools.staticdir.section'] = section + + return result + return vhost_dispatch diff --git a/libraries/cherrypy/_cperror.py b/libraries/cherrypy/_cperror.py new file mode 100644 index 00000000..e2a8fad8 --- /dev/null +++ b/libraries/cherrypy/_cperror.py @@ -0,0 +1,619 @@ +"""Exception classes for CherryPy. + +CherryPy provides (and uses) exceptions for declaring that the HTTP response +should be a status other than the default "200 OK". You can ``raise`` them like +normal Python exceptions. 
You can also call them and they will raise +themselves; this means you can set an +:class:`HTTPError<cherrypy._cperror.HTTPError>` +or :class:`HTTPRedirect<cherrypy._cperror.HTTPRedirect>` as the +:attr:`request.handler<cherrypy._cprequest.Request.handler>`. + +.. _redirectingpost: + +Redirecting POST +================ + +When you GET a resource and are redirected by the server to another Location, +there's generally no problem since GET is both a "safe method" (there should +be no side-effects) and an "idempotent method" (multiple calls are no different +than a single call). + +POST, however, is neither safe nor idempotent--if you +charge a credit card, you don't want to be charged twice by a redirect! + +For this reason, *none* of the 3xx responses permit a user-agent (browser) to +resubmit a POST on redirection without first confirming the action with the +user: + +===== ================================= =========== +300 Multiple Choices Confirm with the user +301 Moved Permanently Confirm with the user +302 Found (Object moved temporarily) Confirm with the user +303 See Other GET the new URI; no confirmation +304 Not modified for conditional GET only; + POST should not raise this error +305 Use Proxy Confirm with the user +307 Temporary Redirect Confirm with the user +===== ================================= =========== + +However, browsers have historically implemented these restrictions poorly; +in particular, many browsers do not force the user to confirm 301, 302 +or 307 when redirecting POST. For this reason, CherryPy defaults to 303, +which most user-agents appear to have implemented correctly. Therefore, if +you raise HTTPRedirect for a POST request, the user-agent will most likely +attempt to GET the new URI (without asking for confirmation from the user). +We realize this is confusing for developers, but it's the safest thing we +could do. You are of course free to raise ``HTTPRedirect(uri, status=302)`` +or any other 3xx status if you know what you're doing, but given the +environment, we couldn't let any of those be the default. + +Custom Error Handling +===================== + +.. image:: /refman/cperrors.gif + +Anticipated HTTP responses +-------------------------- + +The 'error_page' config namespace can be used to provide custom HTML output for +expected responses (like 404 Not Found). Supply a filename from which the +output will be read. The contents will be interpolated with the values +%(status)s, %(message)s, %(traceback)s, and %(version)s using plain old Python +`string formatting +<http://docs.python.org/2/library/stdtypes.html#string-formatting-operations>`_. + +:: + + _cp_config = { + 'error_page.404': os.path.join(localDir, "static/index.html") + } + + +Beginning in version 3.1, you may also provide a function or other callable as +an error_page entry. It will be passed the same status, message, traceback and +version arguments that are interpolated into templates:: + + def error_page_402(status, message, traceback, version): + return "Error %s - Well, I'm very sorry but you haven't paid!" % status + cherrypy.config.update({'error_page.402': error_page_402}) + +Also in 3.1, in addition to the numbered error codes, you may also supply +"error_page.default" to handle all codes which do not have their own error_page +entry. 
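+
+For instance, a catch-all page could be configured like this (an illustrative
+sketch; the function name is arbitrary)::
+
+    def error_page_default(status, message, traceback, version):
+        return "Error %s - %s" % (status, message)
+    cherrypy.config.update({'error_page.default': error_page_default})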
+ + + +Unanticipated errors +-------------------- + +CherryPy also has a generic error handling mechanism: whenever an unanticipated +error occurs in your code, it will call +:func:`Request.error_response<cherrypy._cprequest.Request.error_response>` to +set the response status, headers, and body. By default, this is the same +output as +:class:`HTTPError(500) <cherrypy._cperror.HTTPError>`. If you want to provide +some other behavior, you generally replace "request.error_response". + +Here is some sample code that shows how to display a custom error message and +send an e-mail containing the error:: + + from cherrypy import _cperror + + def handle_error(): + cherrypy.response.status = 500 + cherrypy.response.body = [ + "<html><body>Sorry, an error occurred</body></html>" + ] + sendMail('error@domain.com', + 'Error in your web app', + _cperror.format_exc()) + + @cherrypy.config(**{'request.error_response': handle_error}) + class Root: + pass + +Note that you have to explicitly set +:attr:`response.body <cherrypy._cprequest.Response.body>` +and not simply return an error message as a result. +""" + +import io +import contextlib +from sys import exc_info as _exc_info +from traceback import format_exception as _format_exception +from xml.sax import saxutils + +import six +from six.moves import urllib + +from more_itertools import always_iterable + +import cherrypy +from cherrypy._cpcompat import escape_html +from cherrypy._cpcompat import ntob +from cherrypy._cpcompat import tonative +from cherrypy._helper import classproperty +from cherrypy.lib import httputil as _httputil + + +class CherryPyException(Exception): + + """A base class for CherryPy exceptions.""" + pass + + +class InternalRedirect(CherryPyException): + + """Exception raised to switch to the handler for a different URL. + + This exception will redirect processing to another path within the site + (without informing the client). Provide the new path as an argument when + raising the exception. Provide any params in the querystring for the new + URL. + """ + + def __init__(self, path, query_string=''): + self.request = cherrypy.serving.request + + self.query_string = query_string + if '?' in path: + # Separate any params included in the path + path, self.query_string = path.split('?', 1) + + # Note that urljoin will "do the right thing" whether url is: + # 1. a URL relative to root (e.g. "/dummy") + # 2. a URL relative to the current path + # Note that any query string will be discarded. + path = urllib.parse.urljoin(self.request.path_info, path) + + # Set a 'path' member attribute so that code which traps this + # error can have access to it. + self.path = path + + CherryPyException.__init__(self, path, self.query_string) + + +class HTTPRedirect(CherryPyException): + + """Exception raised when the request should be redirected. + + This exception will force a HTTP redirect to the URL or URL's you give it. + The new URL must be passed as the first argument to the Exception, + e.g., HTTPRedirect(newUrl). Multiple URLs are allowed in a list. + If a URL is absolute, it will be used as-is. If it is relative, it is + assumed to be relative to the current cherrypy.request.path_info. + + If one of the provided URL is a unicode object, it will be encoded + using the default encoding or the one passed in parameter. + + There are multiple types of redirect, from which you can select via the + ``status`` argument. If you do not provide a ``status`` arg, it defaults to + 303 (or 302 if responding with HTTP/1.0). 
+ + Examples:: + + raise cherrypy.HTTPRedirect("") + raise cherrypy.HTTPRedirect("/abs/path", 307) + raise cherrypy.HTTPRedirect(["path1", "path2?a=1&b=2"], 301) + + See :ref:`redirectingpost` for additional caveats. + """ + + urls = None + """The list of URL's to emit.""" + + encoding = 'utf-8' + """The encoding when passed urls are not native strings""" + + def __init__(self, urls, status=None, encoding=None): + self.urls = abs_urls = [ + # Note that urljoin will "do the right thing" whether url is: + # 1. a complete URL with host (e.g. "http://www.example.com/test") + # 2. a URL relative to root (e.g. "/dummy") + # 3. a URL relative to the current path + # Note that any query string in cherrypy.request is discarded. + urllib.parse.urljoin( + cherrypy.url(), + tonative(url, encoding or self.encoding), + ) + for url in always_iterable(urls) + ] + + status = ( + int(status) + if status is not None + else self.default_status + ) + if not 300 <= status <= 399: + raise ValueError('status must be between 300 and 399.') + + CherryPyException.__init__(self, abs_urls, status) + + @classproperty + def default_status(cls): + """ + The default redirect status for the request. + + RFC 2616 indicates a 301 response code fits our goal; however, + browser support for 301 is quite messy. Use 302/303 instead. See + http://www.alanflavell.org.uk/www/post-redirect.html + """ + return 303 if cherrypy.serving.request.protocol >= (1, 1) else 302 + + @property + def status(self): + """The integer HTTP status code to emit.""" + _, status = self.args[:2] + return status + + def set_response(self): + """Modify cherrypy.response status, headers, and body to represent + self. + + CherryPy uses this internally, but you can also use it to create an + HTTPRedirect object and set its output without *raising* the exception. + """ + response = cherrypy.serving.response + response.status = status = self.status + + if status in (300, 301, 302, 303, 307): + response.headers['Content-Type'] = 'text/html;charset=utf-8' + # "The ... URI SHOULD be given by the Location field + # in the response." + response.headers['Location'] = self.urls[0] + + # "Unless the request method was HEAD, the entity of the response + # SHOULD contain a short hypertext note with a hyperlink to the + # new URI(s)." + msg = { + 300: 'This resource can be found at ', + 301: 'This resource has permanently moved to ', + 302: 'This resource resides temporarily at ', + 303: 'This resource can be found at ', + 307: 'This resource has moved temporarily to ', + }[status] + msg += '<a href=%s>%s</a>.' + msgs = [ + msg % (saxutils.quoteattr(u), escape_html(u)) + for u in self.urls + ] + response.body = ntob('<br />\n'.join(msgs), 'utf-8') + # Previous code may have set C-L, so we have to reset it + # (allow finalize to set it). + response.headers.pop('Content-Length', None) + elif status == 304: + # Not Modified. + # "The response MUST include the following header fields: + # Date, unless its omission is required by section 14.18.1" + # The "Date" header should have been set in Response.__init__ + + # "...the response SHOULD NOT include other entity-headers." + for key in ('Allow', 'Content-Encoding', 'Content-Language', + 'Content-Length', 'Content-Location', 'Content-MD5', + 'Content-Range', 'Content-Type', 'Expires', + 'Last-Modified'): + if key in response.headers: + del response.headers[key] + + # "The 304 response MUST NOT contain a message-body." + response.body = None + # Previous code may have set C-L, so we have to reset it. 
+ response.headers.pop('Content-Length', None) + elif status == 305: + # Use Proxy. + # self.urls[0] should be the URI of the proxy. + response.headers['Location'] = ntob(self.urls[0], 'utf-8') + response.body = None + # Previous code may have set C-L, so we have to reset it. + response.headers.pop('Content-Length', None) + else: + raise ValueError('The %s status code is unknown.' % status) + + def __call__(self): + """Use this exception as a request.handler (raise self).""" + raise self + + +def clean_headers(status): + """Remove any headers which should not apply to an error response.""" + response = cherrypy.serving.response + + # Remove headers which applied to the original content, + # but do not apply to the error page. + respheaders = response.headers + for key in ['Accept-Ranges', 'Age', 'ETag', 'Location', 'Retry-After', + 'Vary', 'Content-Encoding', 'Content-Length', 'Expires', + 'Content-Location', 'Content-MD5', 'Last-Modified']: + if key in respheaders: + del respheaders[key] + + if status != 416: + # A server sending a response with status code 416 (Requested + # range not satisfiable) SHOULD include a Content-Range field + # with a byte-range-resp-spec of "*". The instance-length + # specifies the current length of the selected resource. + # A response with status code 206 (Partial Content) MUST NOT + # include a Content-Range field with a byte-range- resp-spec of "*". + if 'Content-Range' in respheaders: + del respheaders['Content-Range'] + + +class HTTPError(CherryPyException): + + """Exception used to return an HTTP error code (4xx-5xx) to the client. + + This exception can be used to automatically send a response using a + http status code, with an appropriate error page. It takes an optional + ``status`` argument (which must be between 400 and 599); it defaults to 500 + ("Internal Server Error"). It also takes an optional ``message`` argument, + which will be returned in the response body. See + `RFC2616 <http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4>`_ + for a complete list of available error codes and when to use them. + + Examples:: + + raise cherrypy.HTTPError(403) + raise cherrypy.HTTPError( + "403 Forbidden", "You are not allowed to access this resource.") + """ + + status = None + """The HTTP status code. May be of type int or str (with a Reason-Phrase). + """ + + code = None + """The integer HTTP status code.""" + + reason = None + """The HTTP Reason-Phrase string.""" + + def __init__(self, status=500, message=None): + self.status = status + try: + self.code, self.reason, defaultmsg = _httputil.valid_status(status) + except ValueError: + raise self.__class__(500, _exc_info()[1].args[0]) + + if self.code < 400 or self.code > 599: + raise ValueError('status must be between 400 and 599.') + + # See http://www.python.org/dev/peps/pep-0352/ + # self.message = message + self._message = message or defaultmsg + CherryPyException.__init__(self, status, message) + + def set_response(self): + """Modify cherrypy.response status, headers, and body to represent + self. + + CherryPy uses this internally, but you can also use it to create an + HTTPError object and set its output without *raising* the exception. + """ + response = cherrypy.serving.response + + clean_headers(self.code) + + # In all cases, finalize will be called after this method, + # so don't bother cleaning up response values here. 
+ response.status = self.status + tb = None + if cherrypy.serving.request.show_tracebacks: + tb = format_exc() + + response.headers.pop('Content-Length', None) + + content = self.get_error_page(self.status, traceback=tb, + message=self._message) + response.body = content + + _be_ie_unfriendly(self.code) + + def get_error_page(self, *args, **kwargs): + return get_error_page(*args, **kwargs) + + def __call__(self): + """Use this exception as a request.handler (raise self).""" + raise self + + @classmethod + @contextlib.contextmanager + def handle(cls, exception, status=500, message=''): + """Translate exception into an HTTPError.""" + try: + yield + except exception as exc: + raise cls(status, message or str(exc)) + + +class NotFound(HTTPError): + + """Exception raised when a URL could not be mapped to any handler (404). + + This is equivalent to raising + :class:`HTTPError("404 Not Found") <cherrypy._cperror.HTTPError>`. + """ + + def __init__(self, path=None): + if path is None: + request = cherrypy.serving.request + path = request.script_name + request.path_info + self.args = (path,) + HTTPError.__init__(self, 404, "The path '%s' was not found." % path) + + +_HTTPErrorTemplate = '''<!DOCTYPE html PUBLIC +"-//W3C//DTD XHTML 1.0 Transitional//EN" +"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> +<html> +<head> + <meta http-equiv="Content-Type" content="text/html; charset=utf-8"></meta> + <title>%(status)s</title> + <style type="text/css"> + #powered_by { + margin-top: 20px; + border-top: 2px solid black; + font-style: italic; + } + + #traceback { + color: red; + } + </style> +</head> + <body> + <h2>%(status)s</h2> + <p>%(message)s</p> + <pre id="traceback">%(traceback)s</pre> + <div id="powered_by"> + <span> + Powered by <a href="http://www.cherrypy.org">CherryPy %(version)s</a> + </span> + </div> + </body> +</html> +''' + + +def get_error_page(status, **kwargs): + """Return an HTML page, containing a pretty error response. + + status should be an int or a str. + kwargs will be interpolated into the page template. + """ + try: + code, reason, message = _httputil.valid_status(status) + except ValueError: + raise cherrypy.HTTPError(500, _exc_info()[1].args[0]) + + # We can't use setdefault here, because some + # callers send None for kwarg values. + if kwargs.get('status') is None: + kwargs['status'] = '%s %s' % (code, reason) + if kwargs.get('message') is None: + kwargs['message'] = message + if kwargs.get('traceback') is None: + kwargs['traceback'] = '' + if kwargs.get('version') is None: + kwargs['version'] = cherrypy.__version__ + + for k, v in six.iteritems(kwargs): + if v is None: + kwargs[k] = '' + else: + kwargs[k] = escape_html(kwargs[k]) + + # Use a custom template or callable for the error page? + pages = cherrypy.serving.request.error_page + error_page = pages.get(code) or pages.get('default') + + # Default template, can be overridden below. + template = _HTTPErrorTemplate + if error_page: + try: + if hasattr(error_page, '__call__'): + # The caller function may be setting headers manually, + # so we delegate to it completely. We may be returning + # an iterator as well as a string here. + # + # We *must* make sure any content is not unicode. 
+ result = error_page(**kwargs) + if cherrypy.lib.is_iterator(result): + from cherrypy.lib.encoding import UTF8StreamEncoder + return UTF8StreamEncoder(result) + elif isinstance(result, six.text_type): + return result.encode('utf-8') + else: + if not isinstance(result, bytes): + raise ValueError( + 'error page function did not ' + 'return a bytestring, six.text_type or an ' + 'iterator - returned object of type %s.' + % (type(result).__name__)) + return result + else: + # Load the template from this path. + template = io.open(error_page, newline='').read() + except Exception: + e = _format_exception(*_exc_info())[-1] + m = kwargs['message'] + if m: + m += '<br />' + m += 'In addition, the custom error page failed:\n<br />%s' % e + kwargs['message'] = m + + response = cherrypy.serving.response + response.headers['Content-Type'] = 'text/html;charset=utf-8' + result = template % kwargs + return result.encode('utf-8') + + +_ie_friendly_error_sizes = { + 400: 512, 403: 256, 404: 512, 405: 256, + 406: 512, 408: 512, 409: 512, 410: 256, + 500: 512, 501: 512, 505: 512, +} + + +def _be_ie_unfriendly(status): + response = cherrypy.serving.response + + # For some statuses, Internet Explorer 5+ shows "friendly error + # messages" instead of our response.body if the body is smaller + # than a given size. Fix this by returning a body over that size + # (by adding whitespace). + # See http://support.microsoft.com/kb/q218155/ + s = _ie_friendly_error_sizes.get(status, 0) + if s: + s += 1 + # Since we are issuing an HTTP error status, we assume that + # the entity is short, and we should just collapse it. + content = response.collapse_body() + content_length = len(content) + if content_length and content_length < s: + # IN ADDITION: the response must be written to IE + # in one chunk or it will still get replaced! Bah. + content = content + (b' ' * (s - content_length)) + response.body = content + response.headers['Content-Length'] = str(len(content)) + + +def format_exc(exc=None): + """Return exc (or sys.exc_info if None), formatted.""" + try: + if exc is None: + exc = _exc_info() + if exc == (None, None, None): + return '' + import traceback + return ''.join(traceback.format_exception(*exc)) + finally: + del exc + + +def bare_error(extrabody=None): + """Produce status, headers, body for a critical error. + + Returns a triple without calling any other questionable functions, + so it should be as error-free as possible. Call it from an HTTP server + if you get errors outside of the request. + + If extrabody is None, a friendly but rather unhelpful error message + is set in the body. If extrabody is a string, it will be appended + as-is to the body. + """ + + # The whole point of this function is to be a last line-of-defense + # in handling errors. That is, it must not raise any errors itself; + # it cannot be allowed to fail. Therefore, don't add to it! + # In particular, don't call any other CP functions. + + body = b'Unrecoverable error in the server.' 
+ if extrabody is not None: + if not isinstance(extrabody, bytes): + extrabody = extrabody.encode('utf-8') + body += b'\n' + extrabody + + return (b'500 Internal Server Error', + [(b'Content-Type', b'text/plain'), + (b'Content-Length', ntob(str(len(body)), 'ISO-8859-1'))], + [body]) diff --git a/libraries/cherrypy/_cplogging.py b/libraries/cherrypy/_cplogging.py new file mode 100644 index 00000000..53b9addb --- /dev/null +++ b/libraries/cherrypy/_cplogging.py @@ -0,0 +1,482 @@ +""" +Simple config +============= + +Although CherryPy uses the :mod:`Python logging module <logging>`, it does so +behind the scenes so that simple logging is simple, but complicated logging +is still possible. "Simple" logging means that you can log to the screen +(i.e. console/stdout) or to a file, and that you can easily have separate +error and access log files. + +Here are the simplified logging settings. You use these by adding lines to +your config file or dict. You should set these at either the global level or +per application (see next), but generally not both. + + * ``log.screen``: Set this to True to have both "error" and "access" messages + printed to stdout. + * ``log.access_file``: Set this to an absolute filename where you want + "access" messages written. + * ``log.error_file``: Set this to an absolute filename where you want "error" + messages written. + +Many events are automatically logged; to log your own application events, call +:func:`cherrypy.log`. + +Architecture +============ + +Separate scopes +--------------- + +CherryPy provides log managers at both the global and application layers. +This means you can have one set of logging rules for your entire site, +and another set of rules specific to each application. The global log +manager is found at :func:`cherrypy.log`, and the log manager for each +application is found at :attr:`app.log<cherrypy._cptree.Application.log>`. +If you're inside a request, the latter is reachable from +``cherrypy.request.app.log``; if you're outside a request, you'll have to +obtain a reference to the ``app``: either the return value of +:func:`tree.mount()<cherrypy._cptree.Tree.mount>` or, if you used +:func:`quickstart()<cherrypy.quickstart>` instead, via +``cherrypy.tree.apps['/']``. + +By default, the global logs are named "cherrypy.error" and "cherrypy.access", +and the application logs are named "cherrypy.error.2378745" and +"cherrypy.access.2378745" (the number is the id of the Application object). +This means that the application logs "bubble up" to the site logs, so if your +application has no log handlers, the site-level handlers will still log the +messages. + +Errors vs. Access +----------------- + +Each log manager handles both "access" messages (one per HTTP request) and +"error" messages (everything else). Note that the "error" log is not just for +errors! The format of access messages is highly formalized, but the error log +isn't--it receives messages from a variety of sources (including full error +tracebacks, if enabled). + +If you are logging the access log and error log to the same source, then there +is a possibility that a specially crafted error message may replicate an access +log message as described in CWE-117. In this case it is the application +developer's responsibility to manually escape data before +using CherryPy's log() +functionality, or they may create an application that is vulnerable to CWE-117. +This would be achieved by using a custom handler escape any special characters, +and attached as described below. 
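+
+A minimal sketch of such a handler (illustrative only; the class name is
+arbitrary) might be::
+
+    import logging
+
+    class SanitizingHandler(logging.StreamHandler):
+        # Replace newline-like control characters so a crafted value
+        # cannot forge additional log lines (CWE-117).
+        def format(self, record):
+            msg = logging.StreamHandler.format(self, record)
+            for ch in map(chr, (10, 13)):  # LF, CR
+                msg = msg.replace(ch, ' ')
+            return msg
+
+    cherrypy.log.access_log.addHandler(SanitizingHandler())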
+ +Custom Handlers +=============== + +The simple settings above work by manipulating Python's standard :mod:`logging` +module. So when you need something more complex, the full power of the standard +module is yours to exploit. You can borrow or create custom handlers, formats, +filters, and much more. Here's an example that skips the standard FileHandler +and uses a RotatingFileHandler instead: + +:: + + #python + log = app.log + + # Remove the default FileHandlers if present. + log.error_file = "" + log.access_file = "" + + maxBytes = getattr(log, "rot_maxBytes", 10000000) + backupCount = getattr(log, "rot_backupCount", 1000) + + # Make a new RotatingFileHandler for the error log. + fname = getattr(log, "rot_error_file", "error.log") + h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount) + h.setLevel(DEBUG) + h.setFormatter(_cplogging.logfmt) + log.error_log.addHandler(h) + + # Make a new RotatingFileHandler for the access log. + fname = getattr(log, "rot_access_file", "access.log") + h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount) + h.setLevel(DEBUG) + h.setFormatter(_cplogging.logfmt) + log.access_log.addHandler(h) + + +The ``rot_*`` attributes are pulled straight from the application log object. +Since "log.*" config entries simply set attributes on the log object, you can +add custom attributes to your heart's content. Note that these handlers are +used ''instead'' of the default, simple handlers outlined above (so don't set +the "log.error_file" config entry, for example). +""" + +import datetime +import logging +import os +import sys + +import six + +import cherrypy +from cherrypy import _cperror + + +# Silence the no-handlers "warning" (stderr write!) in stdlib logging +logging.Logger.manager.emittedNoHandlerWarning = 1 +logfmt = logging.Formatter('%(message)s') + + +class NullHandler(logging.Handler): + + """A no-op logging handler to silence the logging.lastResort handler.""" + + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None + + +class LogManager(object): + + """An object to assist both simple and advanced logging. + + ``cherrypy.log`` is an instance of this class. + """ + + appid = None + """The id() of the Application object which owns this log manager. If this + is a global log manager, appid is None.""" + + error_log = None + """The actual :class:`logging.Logger` instance for error messages.""" + + access_log = None + """The actual :class:`logging.Logger` instance for access messages.""" + + access_log_format = ( + '{h} {l} {u} {t} "{r}" {s} {b} "{f}" "{a}"' + if six.PY3 else + '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"' + ) + + logger_root = None + """The "top-level" logger name. + + This string will be used as the first segment in the Logger names. 
+ The default is "cherrypy", for example, in which case the Logger names + will be of the form:: + + cherrypy.error.<appid> + cherrypy.access.<appid> + """ + + def __init__(self, appid=None, logger_root='cherrypy'): + self.logger_root = logger_root + self.appid = appid + if appid is None: + self.error_log = logging.getLogger('%s.error' % logger_root) + self.access_log = logging.getLogger('%s.access' % logger_root) + else: + self.error_log = logging.getLogger( + '%s.error.%s' % (logger_root, appid)) + self.access_log = logging.getLogger( + '%s.access.%s' % (logger_root, appid)) + self.error_log.setLevel(logging.INFO) + self.access_log.setLevel(logging.INFO) + + # Silence the no-handlers "warning" (stderr write!) in stdlib logging + self.error_log.addHandler(NullHandler()) + self.access_log.addHandler(NullHandler()) + + cherrypy.engine.subscribe('graceful', self.reopen_files) + + def reopen_files(self): + """Close and reopen all file handlers.""" + for log in (self.error_log, self.access_log): + for h in log.handlers: + if isinstance(h, logging.FileHandler): + h.acquire() + h.stream.close() + h.stream = open(h.baseFilename, h.mode) + h.release() + + def error(self, msg='', context='', severity=logging.INFO, + traceback=False): + """Write the given ``msg`` to the error log. + + This is not just for errors! Applications may call this at any time + to log application-specific information. + + If ``traceback`` is True, the traceback of the current exception + (if any) will be appended to ``msg``. + """ + exc_info = None + if traceback: + exc_info = _cperror._exc_info() + + self.error_log.log( + severity, + ' '.join((self.time(), context, msg)), + exc_info=exc_info, + ) + + def __call__(self, *args, **kwargs): + """An alias for ``error``.""" + return self.error(*args, **kwargs) + + def access(self): + """Write to the access log (in Apache/NCSA Combined Log format). + + See the + `apache documentation + <http://httpd.apache.org/docs/current/logs.html#combined>`_ + for format details. + + CherryPy calls this automatically for you. Note there are no arguments; + it collects the data itself from + :class:`cherrypy.request<cherrypy._cprequest.Request>`. + + Like Apache started doing in 2.0.46, non-printable and other special + characters in %r (and we expand that to all parts) are escaped using + \\xhh sequences, where hh stands for the hexadecimal representation + of the raw byte. Exceptions from this rule are " and \\, which are + escaped by prepending a backslash, and all whitespace characters, + which are written in their C-style notation (\\n, \\t, etc). 
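+
+        An illustrative line in this format (values are made up)::
+
+            192.0.2.1 - - [10/Oct/2018:13:55:36] "GET / HTTP/1.1" 200 512 "" "curl/7.58"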
+ """ + request = cherrypy.serving.request + remote = request.remote + response = cherrypy.serving.response + outheaders = response.headers + inheaders = request.headers + if response.output_status is None: + status = '-' + else: + status = response.output_status.split(b' ', 1)[0] + if six.PY3: + status = status.decode('ISO-8859-1') + + atoms = {'h': remote.name or remote.ip, + 'l': '-', + 'u': getattr(request, 'login', None) or '-', + 't': self.time(), + 'r': request.request_line, + 's': status, + 'b': dict.get(outheaders, 'Content-Length', '') or '-', + 'f': dict.get(inheaders, 'Referer', ''), + 'a': dict.get(inheaders, 'User-Agent', ''), + 'o': dict.get(inheaders, 'Host', '-'), + 'i': request.unique_id, + 'z': LazyRfc3339UtcTime(), + } + if six.PY3: + for k, v in atoms.items(): + if not isinstance(v, str): + v = str(v) + v = v.replace('"', '\\"').encode('utf8') + # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc + # and backslash for us. All we have to do is strip the quotes. + v = repr(v)[2:-1] + + # in python 3.0 the repr of bytes (as returned by encode) + # uses double \'s. But then the logger escapes them yet, again + # resulting in quadruple slashes. Remove the extra one here. + v = v.replace('\\\\', '\\') + + # Escape double-quote. + atoms[k] = v + + try: + self.access_log.log( + logging.INFO, self.access_log_format.format(**atoms)) + except Exception: + self(traceback=True) + else: + for k, v in atoms.items(): + if isinstance(v, six.text_type): + v = v.encode('utf8') + elif not isinstance(v, str): + v = str(v) + # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc + # and backslash for us. All we have to do is strip the quotes. + v = repr(v)[1:-1] + # Escape double-quote. + atoms[k] = v.replace('"', '\\"') + + try: + self.access_log.log( + logging.INFO, self.access_log_format % atoms) + except Exception: + self(traceback=True) + + def time(self): + """Return now() in Apache Common Log Format (no timezone).""" + now = datetime.datetime.now() + monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', + 'jul', 'aug', 'sep', 'oct', 'nov', 'dec'] + month = monthnames[now.month - 1].capitalize() + return ('[%02d/%s/%04d:%02d:%02d:%02d]' % + (now.day, month, now.year, now.hour, now.minute, now.second)) + + def _get_builtin_handler(self, log, key): + for h in log.handlers: + if getattr(h, '_cpbuiltin', None) == key: + return h + + # ------------------------- Screen handlers ------------------------- # + def _set_screen_handler(self, log, enable, stream=None): + h = self._get_builtin_handler(log, 'screen') + if enable: + if not h: + if stream is None: + stream = sys.stderr + h = logging.StreamHandler(stream) + h.setFormatter(logfmt) + h._cpbuiltin = 'screen' + log.addHandler(h) + elif h: + log.handlers.remove(h) + + @property + def screen(self): + """Turn stderr/stdout logging on or off. + + If you set this to True, it'll add the appropriate StreamHandler for + you. If you set it to False, it will remove the handler. 
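+
+        Typically this is driven by config (an illustrative entry)::
+
+            cherrypy.config.update({'log.screen': True})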
+ """ + h = self._get_builtin_handler + has_h = h(self.error_log, 'screen') or h(self.access_log, 'screen') + return bool(has_h) + + @screen.setter + def screen(self, newvalue): + self._set_screen_handler(self.error_log, newvalue, stream=sys.stderr) + self._set_screen_handler(self.access_log, newvalue, stream=sys.stdout) + + # -------------------------- File handlers -------------------------- # + + def _add_builtin_file_handler(self, log, fname): + h = logging.FileHandler(fname) + h.setFormatter(logfmt) + h._cpbuiltin = 'file' + log.addHandler(h) + + def _set_file_handler(self, log, filename): + h = self._get_builtin_handler(log, 'file') + if filename: + if h: + if h.baseFilename != os.path.abspath(filename): + h.close() + log.handlers.remove(h) + self._add_builtin_file_handler(log, filename) + else: + self._add_builtin_file_handler(log, filename) + else: + if h: + h.close() + log.handlers.remove(h) + + @property + def error_file(self): + """The filename for self.error_log. + + If you set this to a string, it'll add the appropriate FileHandler for + you. If you set it to ``None`` or ``''``, it will remove the handler. + """ + h = self._get_builtin_handler(self.error_log, 'file') + if h: + return h.baseFilename + return '' + + @error_file.setter + def error_file(self, newvalue): + self._set_file_handler(self.error_log, newvalue) + + @property + def access_file(self): + """The filename for self.access_log. + + If you set this to a string, it'll add the appropriate FileHandler for + you. If you set it to ``None`` or ``''``, it will remove the handler. + """ + h = self._get_builtin_handler(self.access_log, 'file') + if h: + return h.baseFilename + return '' + + @access_file.setter + def access_file(self, newvalue): + self._set_file_handler(self.access_log, newvalue) + + # ------------------------- WSGI handlers ------------------------- # + + def _set_wsgi_handler(self, log, enable): + h = self._get_builtin_handler(log, 'wsgi') + if enable: + if not h: + h = WSGIErrorHandler() + h.setFormatter(logfmt) + h._cpbuiltin = 'wsgi' + log.addHandler(h) + elif h: + log.handlers.remove(h) + + @property + def wsgi(self): + """Write errors to wsgi.errors. + + If you set this to True, it'll add the appropriate + :class:`WSGIErrorHandler<cherrypy._cplogging.WSGIErrorHandler>` for you + (which writes errors to ``wsgi.errors``). + If you set it to False, it will remove the handler. + """ + return bool(self._get_builtin_handler(self.error_log, 'wsgi')) + + @wsgi.setter + def wsgi(self, newvalue): + self._set_wsgi_handler(self.error_log, newvalue) + + +class WSGIErrorHandler(logging.Handler): + + "A handler class which writes logging records to environ['wsgi.errors']." + + def flush(self): + """Flushes the stream.""" + try: + stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors') + except (AttributeError, KeyError): + pass + else: + stream.flush() + + def emit(self, record): + """Emit a record.""" + try: + stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors') + except (AttributeError, KeyError): + pass + else: + try: + msg = self.format(record) + fs = '%s\n' + import types + # if no unicode support... 
+ if not hasattr(types, 'UnicodeType'): + stream.write(fs % msg) + else: + try: + stream.write(fs % msg) + except UnicodeError: + stream.write(fs % msg.encode('UTF-8')) + self.flush() + except Exception: + self.handleError(record) + + +class LazyRfc3339UtcTime(object): + def __str__(self): + """Return now() in RFC3339 UTC Format.""" + now = datetime.datetime.now() + return now.isoformat('T') + 'Z' diff --git a/libraries/cherrypy/_cpmodpy.py b/libraries/cherrypy/_cpmodpy.py new file mode 100644 index 00000000..ac91e625 --- /dev/null +++ b/libraries/cherrypy/_cpmodpy.py @@ -0,0 +1,356 @@ +"""Native adapter for serving CherryPy via mod_python + +Basic usage: + +########################################## +# Application in a module called myapp.py +########################################## + +import cherrypy + +class Root: + @cherrypy.expose + def index(self): + return 'Hi there, Ho there, Hey there' + + +# We will use this method from the mod_python configuration +# as the entry point to our application +def setup_server(): + cherrypy.tree.mount(Root()) + cherrypy.config.update({'environment': 'production', + 'log.screen': False, + 'show_tracebacks': False}) + +########################################## +# mod_python settings for apache2 +# This should reside in your httpd.conf +# or a file that will be loaded at +# apache startup +########################################## + +# Start +DocumentRoot "/" +Listen 8080 +LoadModule python_module /usr/lib/apache2/modules/mod_python.so + +<Location "/"> + PythonPath "sys.path+['/path/to/my/application']" + SetHandler python-program + PythonHandler cherrypy._cpmodpy::handler + PythonOption cherrypy.setup myapp::setup_server + PythonDebug On +</Location> +# End + +The actual path to your mod_python.so is dependent on your +environment. In this case we suppose a global mod_python +installation on a Linux distribution such as Ubuntu. + +We do set the PythonPath configuration setting so that +your application can be found by from the user running +the apache2 instance. Of course if your application +resides in the global site-package this won't be needed. + +Then restart apache2 and access http://127.0.0.1:8080 +""" + +import io +import logging +import os +import re +import sys + +import six + +from more_itertools import always_iterable + +import cherrypy +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil + + +# ------------------------------ Request-handling + + +def setup(req): + from mod_python import apache + + # Run any setup functions defined by a "PythonOption cherrypy.setup" + # directive. 
+ options = req.get_options() + if 'cherrypy.setup' in options: + for function in options['cherrypy.setup'].split(): + atoms = function.split('::', 1) + if len(atoms) == 1: + mod = __import__(atoms[0], globals(), locals()) + else: + modname, fname = atoms + mod = __import__(modname, globals(), locals(), [fname]) + func = getattr(mod, fname) + func() + + cherrypy.config.update({'log.screen': False, + 'tools.ignore_headers.on': True, + 'tools.ignore_headers.headers': ['Range'], + }) + + engine = cherrypy.engine + if hasattr(engine, 'signal_handler'): + engine.signal_handler.unsubscribe() + if hasattr(engine, 'console_control_handler'): + engine.console_control_handler.unsubscribe() + engine.autoreload.unsubscribe() + cherrypy.server.unsubscribe() + + @engine.subscribe('log') + def _log(msg, level): + newlevel = apache.APLOG_ERR + if logging.DEBUG >= level: + newlevel = apache.APLOG_DEBUG + elif logging.INFO >= level: + newlevel = apache.APLOG_INFO + elif logging.WARNING >= level: + newlevel = apache.APLOG_WARNING + # On Windows, req.server is required or the msg will vanish. See + # http://www.modpython.org/pipermail/mod_python/2003-October/014291.html + # Also, "When server is not specified...LogLevel does not apply..." + apache.log_error(msg, newlevel, req.server) + + engine.start() + + def cherrypy_cleanup(data): + engine.exit() + try: + # apache.register_cleanup wasn't available until 3.1.4. + apache.register_cleanup(cherrypy_cleanup) + except AttributeError: + req.server.register_cleanup(req, cherrypy_cleanup) + + +class _ReadOnlyRequest: + expose = ('read', 'readline', 'readlines') + + def __init__(self, req): + for method in self.expose: + self.__dict__[method] = getattr(req, method) + + +recursive = False + +_isSetUp = False + + +def handler(req): + from mod_python import apache + try: + global _isSetUp + if not _isSetUp: + setup(req) + _isSetUp = True + + # Obtain a Request object from CherryPy + local = req.connection.local_addr + local = httputil.Host( + local[0], local[1], req.connection.local_host or '') + remote = req.connection.remote_addr + remote = httputil.Host( + remote[0], remote[1], req.connection.remote_host or '') + + scheme = req.parsed_uri[0] or 'http' + req.get_basic_auth_pw() + + try: + # apache.mpm_query only became available in mod_python 3.1 + q = apache.mpm_query + threaded = q(apache.AP_MPMQ_IS_THREADED) + forked = q(apache.AP_MPMQ_IS_FORKED) + except AttributeError: + bad_value = ("You must provide a PythonOption '%s', " + "either 'on' or 'off', when running a version " + 'of mod_python < 3.1') + + options = req.get_options() + + threaded = options.get('multithread', '').lower() + if threaded == 'on': + threaded = True + elif threaded == 'off': + threaded = False + else: + raise ValueError(bad_value % 'multithread') + + forked = options.get('multiprocess', '').lower() + if forked == 'on': + forked = True + elif forked == 'off': + forked = False + else: + raise ValueError(bad_value % 'multiprocess') + + sn = cherrypy.tree.script_name(req.uri or '/') + if sn is None: + send_response(req, '404 Not Found', [], '') + else: + app = cherrypy.tree.apps[sn] + method = req.method + path = req.uri + qs = req.args or '' + reqproto = req.protocol + headers = list(six.iteritems(req.headers_in)) + rfile = _ReadOnlyRequest(req) + prev = None + + try: + redirections = [] + while True: + request, response = app.get_serving(local, remote, scheme, + 'HTTP/1.1') + request.login = req.user + request.multithread = bool(threaded) + request.multiprocess = bool(forked) + request.app = 
app + request.prev = prev + + # Run the CherryPy Request object and obtain the response + try: + request.run(method, path, qs, reqproto, headers, rfile) + break + except cherrypy.InternalRedirect: + ir = sys.exc_info()[1] + app.release_serving() + prev = request + + if not recursive: + if ir.path in redirections: + raise RuntimeError( + 'InternalRedirector visited the same URL ' + 'twice: %r' % ir.path) + else: + # Add the *previous* path_info + qs to + # redirections. + if qs: + qs = '?' + qs + redirections.append(sn + path + qs) + + # Munge environment and try again. + method = 'GET' + path = ir.path + qs = ir.query_string + rfile = io.BytesIO() + + send_response( + req, response.output_status, response.header_list, + response.body, response.stream) + finally: + app.release_serving() + except Exception: + tb = format_exc() + cherrypy.log(tb, 'MOD_PYTHON', severity=logging.ERROR) + s, h, b = bare_error() + send_response(req, s, h, b) + return apache.OK + + +def send_response(req, status, headers, body, stream=False): + # Set response status + req.status = int(status[:3]) + + # Set response headers + req.content_type = 'text/plain' + for header, value in headers: + if header.lower() == 'content-type': + req.content_type = value + continue + req.headers_out.add(header, value) + + if stream: + # Flush now so the status and headers are sent immediately. + req.flush() + + # Set response body + for seg in always_iterable(body): + req.write(seg) + + +# --------------- Startup tools for CherryPy + mod_python --------------- # +try: + import subprocess + + def popen(fullcmd): + p = subprocess.Popen(fullcmd, shell=True, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + close_fds=True) + return p.stdout +except ImportError: + def popen(fullcmd): + pipein, pipeout = os.popen4(fullcmd) + return pipeout + + +def read_process(cmd, args=''): + fullcmd = '%s %s' % (cmd, args) + pipeout = popen(fullcmd) + try: + firstline = pipeout.readline() + cmd_not_found = re.search( + b'(not recognized|No such file|not found)', + firstline, + re.IGNORECASE + ) + if cmd_not_found: + raise IOError('%s must be on your system path.' % cmd) + output = firstline + pipeout.read() + finally: + pipeout.close() + return output + + +class ModPythonServer(object): + + template = """ +# Apache2 server configuration file for running CherryPy with mod_python. 
+ +DocumentRoot "/" +Listen %(port)s +LoadModule python_module modules/mod_python.so + +<Location %(loc)s> + SetHandler python-program + PythonHandler %(handler)s + PythonDebug On +%(opts)s +</Location> +""" + + def __init__(self, loc='/', port=80, opts=None, apache_path='apache', + handler='cherrypy._cpmodpy::handler'): + self.loc = loc + self.port = port + self.opts = opts + self.apache_path = apache_path + self.handler = handler + + def start(self): + opts = ''.join([' PythonOption %s %s\n' % (k, v) + for k, v in self.opts]) + conf_data = self.template % {'port': self.port, + 'loc': self.loc, + 'opts': opts, + 'handler': self.handler, + } + + mpconf = os.path.join(os.path.dirname(__file__), 'cpmodpy.conf') + f = open(mpconf, 'wb') + try: + f.write(conf_data) + finally: + f.close() + + response = read_process(self.apache_path, '-k start -f %s' % mpconf) + self.ready = True + return response + + def stop(self): + os.popen('apache -k stop') + self.ready = False diff --git a/libraries/cherrypy/_cpnative_server.py b/libraries/cherrypy/_cpnative_server.py new file mode 100644 index 00000000..55653c35 --- /dev/null +++ b/libraries/cherrypy/_cpnative_server.py @@ -0,0 +1,160 @@ +"""Native adapter for serving CherryPy via its builtin server.""" + +import logging +import sys +import io + +import cheroot.server + +import cherrypy +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil + + +class NativeGateway(cheroot.server.Gateway): + """Native gateway implementation allowing to bypass WSGI.""" + + recursive = False + + def respond(self): + """Obtain response from CherryPy machinery and then send it.""" + req = self.req + try: + # Obtain a Request object from CherryPy + local = req.server.bind_addr + local = httputil.Host(local[0], local[1], '') + remote = req.conn.remote_addr, req.conn.remote_port + remote = httputil.Host(remote[0], remote[1], '') + + scheme = req.scheme + sn = cherrypy.tree.script_name(req.uri or '/') + if sn is None: + self.send_response('404 Not Found', [], ['']) + else: + app = cherrypy.tree.apps[sn] + method = req.method + path = req.path + qs = req.qs or '' + headers = req.inheaders.items() + rfile = req.rfile + prev = None + + try: + redirections = [] + while True: + request, response = app.get_serving( + local, remote, scheme, 'HTTP/1.1') + request.multithread = True + request.multiprocess = False + request.app = app + request.prev = prev + + # Run the CherryPy Request object and obtain the + # response + try: + request.run(method, path, qs, + req.request_protocol, headers, rfile) + break + except cherrypy.InternalRedirect: + ir = sys.exc_info()[1] + app.release_serving() + prev = request + + if not self.recursive: + if ir.path in redirections: + raise RuntimeError( + 'InternalRedirector visited the same ' + 'URL twice: %r' % ir.path) + else: + # Add the *previous* path_info + qs to + # redirections. + if qs: + qs = '?' + qs + redirections.append(sn + path + qs) + + # Munge environment and try again. 
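+                                # Replay the redirect target as a fresh GET
+                                # with an empty body, then loop to serve it
+                                # on the same connection.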
+ method = 'GET' + path = ir.path + qs = ir.query_string + rfile = io.BytesIO() + + self.send_response( + response.output_status, response.header_list, + response.body) + finally: + app.release_serving() + except Exception: + tb = format_exc() + # print tb + cherrypy.log(tb, 'NATIVE_ADAPTER', severity=logging.ERROR) + s, h, b = bare_error() + self.send_response(s, h, b) + + def send_response(self, status, headers, body): + """Send response to HTTP request.""" + req = self.req + + # Set response status + req.status = status or b'500 Server Error' + + # Set response headers + for header, value in headers: + req.outheaders.append((header, value)) + if (req.ready and not req.sent_headers): + req.sent_headers = True + req.send_headers() + + # Set response body + for seg in body: + req.write(seg) + + +class CPHTTPServer(cheroot.server.HTTPServer): + """Wrapper for cheroot.server.HTTPServer. + + cheroot has been designed to not reference CherryPy in any way, + so that it can be used in other frameworks and applications. + Therefore, we wrap it here, so we can apply some attributes + from config -> cherrypy.server -> HTTPServer. + """ + + def __init__(self, server_adapter=cherrypy.server): + """Initialize CPHTTPServer.""" + self.server_adapter = server_adapter + + server_name = (self.server_adapter.socket_host or + self.server_adapter.socket_file or + None) + + cheroot.server.HTTPServer.__init__( + self, server_adapter.bind_addr, NativeGateway, + minthreads=server_adapter.thread_pool, + maxthreads=server_adapter.thread_pool_max, + server_name=server_name) + + self.max_request_header_size = ( + self.server_adapter.max_request_header_size or 0) + self.max_request_body_size = ( + self.server_adapter.max_request_body_size or 0) + self.request_queue_size = self.server_adapter.socket_queue_size + self.timeout = self.server_adapter.socket_timeout + self.shutdown_timeout = self.server_adapter.shutdown_timeout + self.protocol = self.server_adapter.protocol_version + self.nodelay = self.server_adapter.nodelay + + ssl_module = self.server_adapter.ssl_module or 'pyopenssl' + if self.server_adapter.ssl_context: + adapter_class = cheroot.server.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain, + self.server_adapter.ssl_ciphers) + self.ssl_adapter.context = self.server_adapter.ssl_context + elif self.server_adapter.ssl_certificate: + adapter_class = cheroot.server.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain, + self.server_adapter.ssl_ciphers) diff --git a/libraries/cherrypy/_cpreqbody.py b/libraries/cherrypy/_cpreqbody.py new file mode 100644 index 00000000..893fe5f5 --- /dev/null +++ b/libraries/cherrypy/_cpreqbody.py @@ -0,0 +1,1000 @@ +"""Request body processing for CherryPy. + +.. versionadded:: 3.2 + +Application authors have complete control over the parsing of HTTP request +entities. In short, +:attr:`cherrypy.request.body<cherrypy._cprequest.Request.body>` +is now always set to an instance of +:class:`RequestBody<cherrypy._cpreqbody.RequestBody>`, +and *that* class is a subclass of :class:`Entity<cherrypy._cpreqbody.Entity>`. + +When an HTTP request includes an entity body, it is often desirable to +provide that information to applications in a form other than the raw bytes. 
+Different content types demand different approaches. Examples: + + * For a GIF file, we want the raw bytes in a stream. + * An HTML form is better parsed into its component fields, and each text field + decoded from bytes to unicode. + * A JSON body should be deserialized into a Python dict or list. + +When the request contains a Content-Type header, the media type is used as a +key to look up a value in the +:attr:`request.body.processors<cherrypy._cpreqbody.Entity.processors>` dict. +If the full media +type is not found, then the major type is tried; for example, if no processor +is found for the 'image/jpeg' type, then we look for a processor for the +'image' types altogether. If neither the full type nor the major type has a +matching processor, then a default processor is used +(:func:`default_proc<cherrypy._cpreqbody.Entity.default_proc>`). For most +types, this means no processing is done, and the body is left unread as a +raw byte stream. Processors are configurable in an 'on_start_resource' hook. + +Some processors, especially those for the 'text' types, attempt to decode bytes +to unicode. If the Content-Type request header includes a 'charset' parameter, +this is used to decode the entity. Otherwise, one or more default charsets may +be attempted, although this decision is up to each processor. If a processor +successfully decodes an Entity or Part, it should set the +:attr:`charset<cherrypy._cpreqbody.Entity.charset>` attribute +on the Entity or Part to the name of the successful charset, so that +applications can easily re-encode or transcode the value if they wish. + +If the Content-Type of the request entity is of major type 'multipart', then +the above parsing process, and possibly a decoding process, is performed for +each part. + +For both the full entity and multipart parts, a Content-Disposition header may +be used to fill :attr:`name<cherrypy._cpreqbody.Entity.name>` and +:attr:`filename<cherrypy._cpreqbody.Entity.filename>` attributes on the +request.body or the Part. + +.. _custombodyprocessors: + +Custom Processors +================= + +You can add your own processors for any specific or major MIME type. Simply add +it to the :attr:`processors<cherrypy._cprequest.Entity.processors>` dict in a +hook/tool that runs at ``on_start_resource`` or ``before_request_body``. +Here's the built-in JSON tool for an example:: + + def json_in(force=True, debug=False): + request = cherrypy.serving.request + def json_processor(entity): + '''Read application/json data into request.json.''' + if not entity.headers.get("Content-Length", ""): + raise cherrypy.HTTPError(411) + + body = entity.fp.read() + try: + request.json = json_decode(body) + except ValueError: + raise cherrypy.HTTPError(400, 'Invalid JSON document') + if force: + request.body.processors.clear() + request.body.default_proc = cherrypy.HTTPError( + 415, 'Expected an application/json content type') + request.body.processors['application/json'] = json_processor + +We begin by defining a new ``json_processor`` function to stick in the +``processors`` dictionary. All processor functions take a single argument, +the ``Entity`` instance they are to process. It will be called whenever a +request is received (for those URI's where the tool is turned on) which +has a ``Content-Type`` of "application/json". + +First, it checks for a valid ``Content-Length`` (raising 411 if not valid), +then reads the remaining bytes on the socket. The ``fp`` object knows its +own length, so it won't hang waiting for data that never arrives. 
It will +return when all data has been read. Then, we decode those bytes using +Python's built-in ``json`` module, and stick the decoded result onto +``request.json`` . If it cannot be decoded, we raise 400. + +If the "force" argument is True (the default), the ``Tool`` clears the +``processors`` dict so that request entities of other ``Content-Types`` +aren't parsed at all. Since there's no entry for those invalid MIME +types, the ``default_proc`` method of ``cherrypy.request.body`` is +called. But this does nothing by default (usually to provide the page +handler an opportunity to handle it.) +But in our case, we want to raise 415, so we replace +``request.body.default_proc`` +with the error (``HTTPError`` instances, when called, raise themselves). + +If we were defining a custom processor, we can do so without making a ``Tool``. +Just add the config entry:: + + request.body.processors = {'application/json': json_processor} + +Note that you can only replace the ``processors`` dict wholesale this way, +not update the existing one. +""" + +try: + from io import DEFAULT_BUFFER_SIZE +except ImportError: + DEFAULT_BUFFER_SIZE = 8192 +import re +import sys +import tempfile +try: + from urllib import unquote_plus +except ImportError: + def unquote_plus(bs): + """Bytes version of urllib.parse.unquote_plus.""" + bs = bs.replace(b'+', b' ') + atoms = bs.split(b'%') + for i in range(1, len(atoms)): + item = atoms[i] + try: + pct = int(item[:2], 16) + atoms[i] = bytes([pct]) + item[2:] + except ValueError: + pass + return b''.join(atoms) + +import six +import cheroot.server + +import cherrypy +from cherrypy._cpcompat import ntou, unquote +from cherrypy.lib import httputil + + +# ------------------------------- Processors -------------------------------- # + +def process_urlencoded(entity): + """Read application/x-www-form-urlencoded data into entity.params.""" + qs = entity.fp.read() + for charset in entity.attempt_charsets: + try: + params = {} + for aparam in qs.split(b'&'): + for pair in aparam.split(b';'): + if not pair: + continue + + atoms = pair.split(b'=', 1) + if len(atoms) == 1: + atoms.append(b'') + + key = unquote_plus(atoms[0]).decode(charset) + value = unquote_plus(atoms[1]).decode(charset) + + if key in params: + if not isinstance(params[key], list): + params[key] = [params[key]] + params[key].append(value) + else: + params[key] = value + except UnicodeDecodeError: + pass + else: + entity.charset = charset + break + else: + raise cherrypy.HTTPError( + 400, 'The request entity could not be decoded. The following ' + 'charsets were attempted: %s' % repr(entity.attempt_charsets)) + + # Now that all values have been successfully parsed and decoded, + # apply them to the entity.params dict. 
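+    # Merge the decoded values into entity.params; a key seen more than once
+    # accumulates its values into a list (e.g. 'tag=a&tag=b' ends up as ['a', 'b']).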
+ for key, value in params.items(): + if key in entity.params: + if not isinstance(entity.params[key], list): + entity.params[key] = [entity.params[key]] + entity.params[key].append(value) + else: + entity.params[key] = value + + +def process_multipart(entity): + """Read all multipart parts into entity.parts.""" + ib = '' + if 'boundary' in entity.content_type.params: + # http://tools.ietf.org/html/rfc2046#section-5.1.1 + # "The grammar for parameters on the Content-type field is such that it + # is often necessary to enclose the boundary parameter values in quotes + # on the Content-type line" + ib = entity.content_type.params['boundary'].strip('"') + + if not re.match('^[ -~]{0,200}[!-~]$', ib): + raise ValueError('Invalid boundary in multipart form: %r' % (ib,)) + + ib = ('--' + ib).encode('ascii') + + # Find the first marker + while True: + b = entity.readline() + if not b: + return + + b = b.strip() + if b == ib: + break + + # Read all parts + while True: + part = entity.part_class.from_fp(entity.fp, ib) + entity.parts.append(part) + part.process() + if part.fp.done: + break + + +def process_multipart_form_data(entity): + """Read all multipart/form-data parts into entity.parts or entity.params. + """ + process_multipart(entity) + + kept_parts = [] + for part in entity.parts: + if part.name is None: + kept_parts.append(part) + else: + if part.filename is None: + # It's a regular field + value = part.fullvalue() + else: + # It's a file upload. Retain the whole part so consumer code + # has access to its .file and .filename attributes. + value = part + + if part.name in entity.params: + if not isinstance(entity.params[part.name], list): + entity.params[part.name] = [entity.params[part.name]] + entity.params[part.name].append(value) + else: + entity.params[part.name] = value + + entity.parts = kept_parts + + +def _old_process_multipart(entity): + """The behavior of 3.2 and lower. Deprecated and will be changed in 3.3.""" + process_multipart(entity) + + params = entity.params + + for part in entity.parts: + if part.name is None: + key = ntou('parts') + else: + key = part.name + + if part.filename is None: + # It's a regular field + value = part.fullvalue() + else: + # It's a file upload. Retain the whole part so consumer code + # has access to its .file and .filename attributes. + value = part + + if key in params: + if not isinstance(params[key], list): + params[key] = [params[key]] + params[key].append(value) + else: + params[key] = value + + +# -------------------------------- Entities --------------------------------- # +class Entity(object): + + """An HTTP request body, or MIME multipart body. + + This class collects information about the HTTP request entity. When a + given entity is of MIME type "multipart", each part is parsed into its own + Entity instance, and the set of parts stored in + :attr:`entity.parts<cherrypy._cpreqbody.Entity.parts>`. + + Between the ``before_request_body`` and ``before_handler`` tools, CherryPy + tries to process the request body (if any) by calling + :func:`request.body.process<cherrypy._cpreqbody.RequestBody.process>`. + This uses the ``content_type`` of the Entity to look up a suitable + processor in + :attr:`Entity.processors<cherrypy._cpreqbody.Entity.processors>`, + a dict. + If a matching processor cannot be found for the complete Content-Type, + it tries again using the major type. 
For example, if a request with an + entity of type "image/jpeg" arrives, but no processor can be found for + that complete type, then one is sought for the major type "image". If a + processor is still not found, then the + :func:`default_proc<cherrypy._cpreqbody.Entity.default_proc>` method + of the Entity is called (which does nothing by default; you can + override this too). + + CherryPy includes processors for the "application/x-www-form-urlencoded" + type, the "multipart/form-data" type, and the "multipart" major type. + CherryPy 3.2 processes these types almost exactly as older versions. + Parts are passed as arguments to the page handler using their + ``Content-Disposition.name`` if given, otherwise in a generic "parts" + argument. Each such part is either a string, or the + :class:`Part<cherrypy._cpreqbody.Part>` itself if it's a file. (In this + case it will have ``file`` and ``filename`` attributes, or possibly a + ``value`` attribute). Each Part is itself a subclass of + Entity, and has its own ``process`` method and ``processors`` dict. + + There is a separate processor for the "multipart" major type which is more + flexible, and simply stores all multipart parts in + :attr:`request.body.parts<cherrypy._cpreqbody.Entity.parts>`. You can + enable it with:: + + cherrypy.request.body.processors['multipart'] = \ + _cpreqbody.process_multipart + + in an ``on_start_resource`` tool. + """ + + # http://tools.ietf.org/html/rfc2046#section-4.1.2: + # "The default character set, which must be assumed in the + # absence of a charset parameter, is US-ASCII." + # However, many browsers send data in utf-8 with no charset. + attempt_charsets = ['utf-8'] + r"""A list of strings, each of which should be a known encoding. + + When the Content-Type of the request body warrants it, each of the given + encodings will be tried in order. The first one to successfully decode the + entity without raising an error is stored as + :attr:`entity.charset<cherrypy._cpreqbody.Entity.charset>`. This defaults + to ``['utf-8']`` (plus 'ISO-8859-1' for "text/\*" types, as required by + `HTTP/1.1 + <http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7.1>`_), + but ``['us-ascii', 'utf-8']`` for multipart parts. + """ + + charset = None + """The successful decoding; see "attempt_charsets" above.""" + + content_type = None + """The value of the Content-Type request header. + + If the Entity is part of a multipart payload, this will be the Content-Type + given in the MIME headers for this part. + """ + + default_content_type = 'application/x-www-form-urlencoded' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however, the MIME spec + declares that a part with no Content-Type defaults to "text/plain" + (see :class:`Part<cherrypy._cpreqbody.Part>`). + """ + + filename = None + """The ``Content-Disposition.filename`` header, if available.""" + + fp = None + """The readable socket file object.""" + + headers = None + """A dict of request/multipart header names and values. + + This is a copy of the ``request.headers`` for the ``request.body``; + for multipart parts, it is the set of headers for that part. 
+ """ + + length = None + """The value of the ``Content-Length`` header, if provided.""" + + name = None + """The "name" parameter of the ``Content-Disposition`` header, if any.""" + + params = None + """ + If the request Content-Type is 'application/x-www-form-urlencoded' or + multipart, this will be a dict of the params pulled from the entity + body; that is, it will be the portion of request.params that come + from the message body (sometimes called "POST params", although they + can be sent with various HTTP method verbs). This value is set between + the 'before_request_body' and 'before_handler' hooks (assuming that + process_request_body is True).""" + + processors = {'application/x-www-form-urlencoded': process_urlencoded, + 'multipart/form-data': process_multipart_form_data, + 'multipart': process_multipart, + } + """A dict of Content-Type names to processor methods.""" + + parts = None + """A list of Part instances if ``Content-Type`` is of major type + "multipart".""" + + part_class = None + """The class used for multipart parts. + + You can replace this with custom subclasses to alter the processing of + multipart parts. + """ + + def __init__(self, fp, headers, params=None, parts=None): + # Make an instance-specific copy of the class processors + # so Tools, etc. can replace them per-request. + self.processors = self.processors.copy() + + self.fp = fp + self.headers = headers + + if params is None: + params = {} + self.params = params + + if parts is None: + parts = [] + self.parts = parts + + # Content-Type + self.content_type = headers.elements('Content-Type') + if self.content_type: + self.content_type = self.content_type[0] + else: + self.content_type = httputil.HeaderElement.from_str( + self.default_content_type) + + # Copy the class 'attempt_charsets', prepending any Content-Type + # charset + dec = self.content_type.params.get('charset', None) + if dec: + self.attempt_charsets = [dec] + [c for c in self.attempt_charsets + if c != dec] + else: + self.attempt_charsets = self.attempt_charsets[:] + + # Length + self.length = None + clen = headers.get('Content-Length', None) + # If Transfer-Encoding is 'chunked', ignore any Content-Length. + if ( + clen is not None and + 'chunked' not in headers.get('Transfer-Encoding', '') + ): + try: + self.length = int(clen) + except ValueError: + pass + + # Content-Disposition + self.name = None + self.filename = None + disp = headers.elements('Content-Disposition') + if disp: + disp = disp[0] + if 'name' in disp.params: + self.name = disp.params['name'] + if self.name.startswith('"') and self.name.endswith('"'): + self.name = self.name[1:-1] + if 'filename' in disp.params: + self.filename = disp.params['filename'] + if ( + self.filename.startswith('"') and + self.filename.endswith('"') + ): + self.filename = self.filename[1:-1] + if 'filename*' in disp.params: + # @see https://tools.ietf.org/html/rfc5987 + encoding, lang, filename = disp.params['filename*'].split("'") + self.filename = unquote(str(filename), encoding) + + def read(self, size=None, fp_out=None): + return self.fp.read(size, fp_out) + + def readline(self, size=None): + return self.fp.readline(size) + + def readlines(self, sizehint=None): + return self.fp.readlines(sizehint) + + def __iter__(self): + return self + + def __next__(self): + line = self.readline() + if not line: + raise StopIteration + return line + + def next(self): + return self.__next__() + + def read_into_file(self, fp_out=None): + """Read the request body into fp_out (or make_file() if None). 
+ + Return fp_out. + """ + if fp_out is None: + fp_out = self.make_file() + self.read(fp_out=fp_out) + return fp_out + + def make_file(self): + """Return a file-like object into which the request body will be read. + + By default, this will return a TemporaryFile. Override as needed. + See also :attr:`cherrypy._cpreqbody.Part.maxrambytes`.""" + return tempfile.TemporaryFile() + + def fullvalue(self): + """Return this entity as a string, whether stored in a file or not.""" + if self.file: + # It was stored in a tempfile. Read it. + self.file.seek(0) + value = self.file.read() + self.file.seek(0) + else: + value = self.value + value = self.decode_entity(value) + return value + + def decode_entity(self, value): + """Return a given byte encoded value as a string""" + for charset in self.attempt_charsets: + try: + value = value.decode(charset) + except UnicodeDecodeError: + pass + else: + self.charset = charset + return value + else: + raise cherrypy.HTTPError( + 400, + 'The request entity could not be decoded. The following ' + 'charsets were attempted: %s' % repr(self.attempt_charsets) + ) + + def process(self): + """Execute the best-match processor for the given media type.""" + proc = None + ct = self.content_type.value + try: + proc = self.processors[ct] + except KeyError: + toptype = ct.split('/', 1)[0] + try: + proc = self.processors[toptype] + except KeyError: + pass + if proc is None: + self.default_proc() + else: + proc(self) + + def default_proc(self): + """Called if a more-specific processor is not found for the + ``Content-Type``. + """ + # Leave the fp alone for someone else to read. This works fine + # for request.body, but the Part subclasses need to override this + # so they can move on to the next part. + pass + + +class Part(Entity): + + """A MIME part entity, part of a multipart entity.""" + + # "The default character set, which must be assumed in the absence of a + # charset parameter, is US-ASCII." + attempt_charsets = ['us-ascii', 'utf-8'] + r"""A list of strings, each of which should be a known encoding. + + When the Content-Type of the request body warrants it, each of the given + encodings will be tried in order. The first one to successfully decode the + entity without raising an error is stored as + :attr:`entity.charset<cherrypy._cpreqbody.Entity.charset>`. This defaults + to ``['utf-8']`` (plus 'ISO-8859-1' for "text/\*" types, as required by + `HTTP/1.1 + <http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7.1>`_), + but ``['us-ascii', 'utf-8']`` for multipart parts. + """ + + boundary = None + """The MIME multipart boundary.""" + + default_content_type = 'text/plain' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however (this class), + the MIME spec declares that a part with no Content-Type defaults to + "text/plain". + """ + + # This is the default in stdlib cgi. We may want to increase it. + maxrambytes = 1000 + """The threshold of bytes after which point the ``Part`` will store + its data in a file (generated by + :func:`make_file<cherrypy._cprequest.Entity.make_file>`) + instead of a string. Defaults to 1000, just like the :mod:`cgi` + module in Python's standard library. 
+ """ + + def __init__(self, fp, headers, boundary): + Entity.__init__(self, fp, headers) + self.boundary = boundary + self.file = None + self.value = None + + @classmethod + def from_fp(cls, fp, boundary): + headers = cls.read_headers(fp) + return cls(fp, headers, boundary) + + @classmethod + def read_headers(cls, fp): + headers = httputil.HeaderMap() + while True: + line = fp.readline() + if not line: + # No more data--illegal end of headers + raise EOFError('Illegal end of headers.') + + if line == b'\r\n': + # Normal end of headers + break + if not line.endswith(b'\r\n'): + raise ValueError('MIME requires CRLF terminators: %r' % line) + + if line[0] in b' \t': + # It's a continuation line. + v = line.strip().decode('ISO-8859-1') + else: + k, v = line.split(b':', 1) + k = k.strip().decode('ISO-8859-1') + v = v.strip().decode('ISO-8859-1') + + existing = headers.get(k) + if existing: + v = ', '.join((existing, v)) + headers[k] = v + + return headers + + def read_lines_to_boundary(self, fp_out=None): + """Read bytes from self.fp and return or write them to a file. + + If the 'fp_out' argument is None (the default), all bytes read are + returned in a single byte string. + + If the 'fp_out' argument is not None, it must be a file-like + object that supports the 'write' method; all bytes read will be + written to the fp, and that fp is returned. + """ + endmarker = self.boundary + b'--' + delim = b'' + prev_lf = True + lines = [] + seen = 0 + while True: + line = self.fp.readline(1 << 16) + if not line: + raise EOFError('Illegal end of multipart body.') + if line.startswith(b'--') and prev_lf: + strippedline = line.strip() + if strippedline == self.boundary: + break + if strippedline == endmarker: + self.fp.finish() + break + + line = delim + line + + if line.endswith(b'\r\n'): + delim = b'\r\n' + line = line[:-2] + prev_lf = True + elif line.endswith(b'\n'): + delim = b'\n' + line = line[:-1] + prev_lf = True + else: + delim = b'' + prev_lf = False + + if fp_out is None: + lines.append(line) + seen += len(line) + if seen > self.maxrambytes: + fp_out = self.make_file() + for line in lines: + fp_out.write(line) + else: + fp_out.write(line) + + if fp_out is None: + result = b''.join(lines) + return result + else: + fp_out.seek(0) + return fp_out + + def default_proc(self): + """Called if a more-specific processor is not found for the + ``Content-Type``. + """ + if self.filename: + # Always read into a file if a .filename was given. + self.file = self.read_into_file() + else: + result = self.read_lines_to_boundary() + if isinstance(result, bytes): + self.value = result + else: + self.file = result + + def read_into_file(self, fp_out=None): + """Read the request body into fp_out (or make_file() if None). + + Return fp_out. + """ + if fp_out is None: + fp_out = self.make_file() + self.read_lines_to_boundary(fp_out=fp_out) + return fp_out + + +Entity.part_class = Part + +inf = float('inf') + + +class SizedReader: + + def __init__(self, fp, length, maxbytes, bufsize=DEFAULT_BUFFER_SIZE, + has_trailers=False): + # Wrap our fp in a buffer so peek() works + self.fp = fp + self.length = length + self.maxbytes = maxbytes + self.buffer = b'' + self.bufsize = bufsize + self.bytes_read = 0 + self.done = False + self.has_trailers = has_trailers + + def read(self, size=None, fp_out=None): + """Read bytes from the request body and return or write them to a file. + + A number of bytes less than or equal to the 'size' argument are read + off the socket. 
The actual number of bytes read are tracked in + self.bytes_read. The number may be smaller than 'size' when 1) the + client sends fewer bytes, 2) the 'Content-Length' request header + specifies fewer bytes than requested, or 3) the number of bytes read + exceeds self.maxbytes (in which case, 413 is raised). + + If the 'fp_out' argument is None (the default), all bytes read are + returned in a single byte string. + + If the 'fp_out' argument is not None, it must be a file-like + object that supports the 'write' method; all bytes read will be + written to the fp, and None is returned. + """ + + if self.length is None: + if size is None: + remaining = inf + else: + remaining = size + else: + remaining = self.length - self.bytes_read + if size and size < remaining: + remaining = size + if remaining == 0: + self.finish() + if fp_out is None: + return b'' + else: + return None + + chunks = [] + + # Read bytes from the buffer. + if self.buffer: + if remaining is inf: + data = self.buffer + self.buffer = b'' + else: + data = self.buffer[:remaining] + self.buffer = self.buffer[remaining:] + datalen = len(data) + remaining -= datalen + + # Check lengths. + self.bytes_read += datalen + if self.maxbytes and self.bytes_read > self.maxbytes: + raise cherrypy.HTTPError(413) + + # Store the data. + if fp_out is None: + chunks.append(data) + else: + fp_out.write(data) + + # Read bytes from the socket. + while remaining > 0: + chunksize = min(remaining, self.bufsize) + try: + data = self.fp.read(chunksize) + except Exception: + e = sys.exc_info()[1] + if e.__class__.__name__ == 'MaxSizeExceeded': + # Post data is too big + raise cherrypy.HTTPError( + 413, 'Maximum request length: %r' % e.args[1]) + else: + raise + if not data: + self.finish() + break + datalen = len(data) + remaining -= datalen + + # Check lengths. + self.bytes_read += datalen + if self.maxbytes and self.bytes_read > self.maxbytes: + raise cherrypy.HTTPError(413) + + # Store the data. + if fp_out is None: + chunks.append(data) + else: + fp_out.write(data) + + if fp_out is None: + return b''.join(chunks) + + def readline(self, size=None): + """Read a line from the request body and return it.""" + chunks = [] + while size is None or size > 0: + chunksize = self.bufsize + if size is not None and size < self.bufsize: + chunksize = size + data = self.read(chunksize) + if not data: + break + pos = data.find(b'\n') + 1 + if pos: + chunks.append(data[:pos]) + remainder = data[pos:] + self.buffer += remainder + self.bytes_read -= len(remainder) + break + else: + chunks.append(data) + return b''.join(chunks) + + def readlines(self, sizehint=None): + """Read lines from the request body and return them.""" + if self.length is not None: + if sizehint is None: + sizehint = self.length - self.bytes_read + else: + sizehint = min(sizehint, self.length - self.bytes_read) + + lines = [] + seen = 0 + while True: + line = self.readline() + if not line: + break + lines.append(line) + seen += len(line) + if seen >= sizehint: + break + return lines + + def finish(self): + self.done = True + if self.has_trailers and hasattr(self.fp, 'read_trailer_lines'): + self.trailers = {} + + try: + for line in self.fp.read_trailer_lines(): + if line[0] in b' \t': + # It's a continuation line. 
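+                        # Reuse the header name parsed from the previous
+                        # trailer line; only the value is replaced here.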
+ v = line.strip() + else: + try: + k, v = line.split(b':', 1) + except ValueError: + raise ValueError('Illegal header line.') + k = k.strip().title() + v = v.strip() + + if k in cheroot.server.comma_separated_headers: + existing = self.trailers.get(k) + if existing: + v = b', '.join((existing, v)) + self.trailers[k] = v + except Exception: + e = sys.exc_info()[1] + if e.__class__.__name__ == 'MaxSizeExceeded': + # Post data is too big + raise cherrypy.HTTPError( + 413, 'Maximum request length: %r' % e.args[1]) + else: + raise + + +class RequestBody(Entity): + + """The entity of the HTTP request.""" + + bufsize = 8 * 1024 + """The buffer size used when reading the socket.""" + + # Don't parse the request body at all if the client didn't provide + # a Content-Type header. See + # https://github.com/cherrypy/cherrypy/issues/790 + default_content_type = '' + """This defines a default ``Content-Type`` to use if no Content-Type header + is given. The empty string is used for RequestBody, which results in the + request body not being read or parsed at all. This is by design; a missing + ``Content-Type`` header in the HTTP request entity is an error at best, + and a security hole at worst. For multipart parts, however, the MIME spec + declares that a part with no Content-Type defaults to "text/plain" + (see :class:`Part<cherrypy._cpreqbody.Part>`). + """ + + maxbytes = None + """Raise ``MaxSizeExceeded`` if more bytes than this are read from + the socket. + """ + + def __init__(self, fp, headers, params=None, request_params=None): + Entity.__init__(self, fp, headers, params) + + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7.1 + # When no explicit charset parameter is provided by the + # sender, media subtypes of the "text" type are defined + # to have a default charset value of "ISO-8859-1" when + # received via HTTP. + if self.content_type.value.startswith('text/'): + for c in ('ISO-8859-1', 'iso-8859-1', 'Latin-1', 'latin-1'): + if c in self.attempt_charsets: + break + else: + self.attempt_charsets.append('ISO-8859-1') + + # Temporary fix while deprecating passing .parts as .params. + self.processors['multipart'] = _old_process_multipart + + if request_params is None: + request_params = {} + self.request_params = request_params + + def process(self): + """Process the request entity based on its Content-Type.""" + # "The presence of a message-body in a request is signaled by the + # inclusion of a Content-Length or Transfer-Encoding header field in + # the request's message-headers." + # It is possible to send a POST request with no body, for example; + # however, app developers are responsible in that case to set + # cherrypy.request.process_body to False so this method isn't called. + h = cherrypy.serving.request.headers + if 'Content-Length' not in h and 'Transfer-Encoding' not in h: + raise cherrypy.HTTPError(411) + + self.fp = SizedReader(self.fp, self.length, + self.maxbytes, bufsize=self.bufsize, + has_trailers='Trailer' in h) + super(RequestBody, self).process() + + # Body params should also be a part of the request_params + # add them in here. + request_params = self.request_params + for key, value in self.params.items(): + # Python 2 only: keyword arguments must be byte strings (type + # 'str'). 
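+            # Encode unicode keys to native byte strings so they can be
+            # passed to the page handler as **kwargs on Python 2.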
+ if sys.version_info < (3, 0): + if isinstance(key, six.text_type): + key = key.encode('ISO-8859-1') + + if key in request_params: + if not isinstance(request_params[key], list): + request_params[key] = [request_params[key]] + request_params[key].append(value) + else: + request_params[key] = value diff --git a/libraries/cherrypy/_cprequest.py b/libraries/cherrypy/_cprequest.py new file mode 100644 index 00000000..3cc0c811 --- /dev/null +++ b/libraries/cherrypy/_cprequest.py @@ -0,0 +1,930 @@ +import sys +import time + +import uuid + +import six +from six.moves.http_cookies import SimpleCookie, CookieError + +from more_itertools import consume + +import cherrypy +from cherrypy._cpcompat import ntob +from cherrypy import _cpreqbody +from cherrypy._cperror import format_exc, bare_error +from cherrypy.lib import httputil, reprconf, encoding + + +class Hook(object): + + """A callback and its metadata: failsafe, priority, and kwargs.""" + + callback = None + """ + The bare callable that this Hook object is wrapping, which will + be called when the Hook is called.""" + + failsafe = False + """ + If True, the callback is guaranteed to run even if other callbacks + from the same call point raise exceptions.""" + + priority = 50 + """ + Defines the order of execution for a list of Hooks. Priority numbers + should be limited to the closed interval [0, 100], but values outside + this range are acceptable, as are fractional values.""" + + kwargs = {} + """ + A set of keyword arguments that will be passed to the + callable on each call.""" + + def __init__(self, callback, failsafe=None, priority=None, **kwargs): + self.callback = callback + + if failsafe is None: + failsafe = getattr(callback, 'failsafe', False) + self.failsafe = failsafe + + if priority is None: + priority = getattr(callback, 'priority', 50) + self.priority = priority + + self.kwargs = kwargs + + def __lt__(self, other): + """ + Hooks sort by priority, ascending, such that + hooks of lower priority are run first. + """ + return self.priority < other.priority + + def __call__(self): + """Run self.callback(**self.kwargs).""" + return self.callback(**self.kwargs) + + def __repr__(self): + cls = self.__class__ + return ('%s.%s(callback=%r, failsafe=%r, priority=%r, %s)' + % (cls.__module__, cls.__name__, self.callback, + self.failsafe, self.priority, + ', '.join(['%s=%r' % (k, v) + for k, v in self.kwargs.items()]))) + + +class HookMap(dict): + + """A map of call points to lists of callbacks (Hook objects).""" + + def __new__(cls, points=None): + d = dict.__new__(cls) + for p in points or []: + d[p] = [] + return d + + def __init__(self, *a, **kw): + pass + + def attach(self, point, callback, failsafe=None, priority=None, **kwargs): + """Append a new Hook made from the supplied arguments.""" + self[point].append(Hook(callback, failsafe, priority, **kwargs)) + + def run(self, point): + """Execute all registered Hooks (callbacks) for the given point.""" + exc = None + hooks = self[point] + hooks.sort() + for hook in hooks: + # Some hooks are guaranteed to run even if others at + # the same hookpoint fail. We will still log the failure, + # but proceed on to the next hook. The only way + # to stop all processing from one of these hooks is + # to raise SystemExit and stop the whole server. 
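+            # After the first failure, only failsafe hooks still run; the
+            # original exception is re-raised once the loop completes.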
+ if exc is None or hook.failsafe: + try: + hook() + except (KeyboardInterrupt, SystemExit): + raise + except (cherrypy.HTTPError, cherrypy.HTTPRedirect, + cherrypy.InternalRedirect): + exc = sys.exc_info()[1] + except Exception: + exc = sys.exc_info()[1] + cherrypy.log(traceback=True, severity=40) + if exc: + raise exc + + def __copy__(self): + newmap = self.__class__() + # We can't just use 'update' because we want copies of the + # mutable values (each is a list) as well. + for k, v in self.items(): + newmap[k] = v[:] + return newmap + copy = __copy__ + + def __repr__(self): + cls = self.__class__ + return '%s.%s(points=%r)' % ( + cls.__module__, + cls.__name__, + list(self) + ) + + +# Config namespace handlers + +def hooks_namespace(k, v): + """Attach bare hooks declared in config.""" + # Use split again to allow multiple hooks for a single + # hookpoint per path (e.g. "hooks.before_handler.1"). + # Little-known fact you only get from reading source ;) + hookpoint = k.split('.', 1)[0] + if isinstance(v, six.string_types): + v = cherrypy.lib.reprconf.attributes(v) + if not isinstance(v, Hook): + v = Hook(v) + cherrypy.serving.request.hooks[hookpoint].append(v) + + +def request_namespace(k, v): + """Attach request attributes declared in config.""" + # Provides config entries to set request.body attrs (like + # attempt_charsets). + if k[:5] == 'body.': + setattr(cherrypy.serving.request.body, k[5:], v) + else: + setattr(cherrypy.serving.request, k, v) + + +def response_namespace(k, v): + """Attach response attributes declared in config.""" + # Provides config entries to set default response headers + # http://cherrypy.org/ticket/889 + if k[:8] == 'headers.': + cherrypy.serving.response.headers[k.split('.', 1)[1]] = v + else: + setattr(cherrypy.serving.response, k, v) + + +def error_page_namespace(k, v): + """Attach error pages declared in config.""" + if k != 'default': + k = int(k) + cherrypy.serving.request.error_page[k] = v + + +hookpoints = ['on_start_resource', 'before_request_body', + 'before_handler', 'before_finalize', + 'on_end_resource', 'on_end_request', + 'before_error_response', 'after_error_response'] + + +class Request(object): + + """An HTTP request. + + This object represents the metadata of an HTTP request message; + that is, it contains attributes which describe the environment + in which the request URL, headers, and body were sent (if you + want tools to interpret the headers and body, those are elsewhere, + mostly in Tools). This 'metadata' consists of socket data, + transport characteristics, and the Request-Line. This object + also contains data regarding the configuration in effect for + the given URL, and the execution plan for generating a response. + """ + + prev = None + """ + The previous Request object (if any). This should be None + unless we are processing an InternalRedirect.""" + + # Conversation/connection attributes + local = httputil.Host('127.0.0.1', 80) + 'An httputil.Host(ip, port, hostname) object for the server socket.' + + remote = httputil.Host('127.0.0.1', 1111) + 'An httputil.Host(ip, port, hostname) object for the client socket.' + + scheme = 'http' + """ + The protocol used between client and server. In most cases, + this will be either 'http' or 'https'.""" + + server_protocol = 'HTTP/1.1' + """ + The HTTP version for which the HTTP server is at least + conditionally compliant.""" + + base = '' + """The (scheme://host) portion of the requested URL. + In some cases (e.g. 
when proxying via mod_rewrite), this may contain + path segments which cherrypy.url uses when constructing url's, but + which otherwise are ignored by CherryPy. Regardless, this value + MUST NOT end in a slash.""" + + # Request-Line attributes + request_line = '' + """ + The complete Request-Line received from the client. This is a + single string consisting of the request method, URI, and protocol + version (joined by spaces). Any final CRLF is removed.""" + + method = 'GET' + """ + Indicates the HTTP method to be performed on the resource identified + by the Request-URI. Common methods include GET, HEAD, POST, PUT, and + DELETE. CherryPy allows any extension method; however, various HTTP + servers and gateways may restrict the set of allowable methods. + CherryPy applications SHOULD restrict the set (on a per-URI basis).""" + + query_string = '' + """ + The query component of the Request-URI, a string of information to be + interpreted by the resource. The query portion of a URI follows the + path component, and is separated by a '?'. For example, the URI + 'http://www.cherrypy.org/wiki?a=3&b=4' has the query component, + 'a=3&b=4'.""" + + query_string_encoding = 'utf8' + """ + The encoding expected for query string arguments after % HEX HEX decoding). + If a query string is provided that cannot be decoded with this encoding, + 404 is raised (since technically it's a different URI). If you want + arbitrary encodings to not error, set this to 'Latin-1'; you can then + encode back to bytes and re-decode to whatever encoding you like later. + """ + + protocol = (1, 1) + """The HTTP protocol version corresponding to the set + of features which should be allowed in the response. If BOTH + the client's request message AND the server's level of HTTP + compliance is HTTP/1.1, this attribute will be the tuple (1, 1). + If either is 1.0, this attribute will be the tuple (1, 0). + Lower HTTP protocol versions are not explicitly supported.""" + + params = {} + """ + A dict which combines query string (GET) and request entity (POST) + variables. This is populated in two stages: GET params are added + before the 'on_start_resource' hook, and POST params are added + between the 'before_request_body' and 'before_handler' hooks.""" + + # Message attributes + header_list = [] + """ + A list of the HTTP request headers as (name, value) tuples. + In general, you should use request.headers (a dict) instead.""" + + headers = httputil.HeaderMap() + """ + A dict-like object containing the request headers. Keys are header + names (in Title-Case format); however, you may get and set them in + a case-insensitive manner. That is, headers['Content-Type'] and + headers['content-type'] refer to the same value. Values are header + values (decoded according to :rfc:`2047` if necessary). See also: + httputil.HeaderMap, httputil.HeaderElement.""" + + cookie = SimpleCookie() + """See help(Cookie).""" + + rfile = None + """ + If the request included an entity (body), it will be available + as a stream in this attribute. However, the rfile will normally + be read for you between the 'before_request_body' hook and the + 'before_handler' hook, and the resulting string is placed into + either request.params or the request.body attribute. + + You may disable the automatic consumption of the rfile by setting + request.process_request_body to False, either in config for the desired + path, or in an 'on_start_resource' or 'before_request_body' hook. 
+ + WARNING: In almost every case, you should not attempt to read from the + rfile stream after CherryPy's automatic mechanism has read it. If you + turn off the automatic parsing of rfile, you should read exactly the + number of bytes specified in request.headers['Content-Length']. + Ignoring either of these warnings may result in a hung request thread + or in corruption of the next (pipelined) request. + """ + + process_request_body = True + """ + If True, the rfile (if any) is automatically read and parsed, + and the result placed into request.params or request.body.""" + + methods_with_bodies = ('POST', 'PUT', 'PATCH') + """ + A sequence of HTTP methods for which CherryPy will automatically + attempt to read a body from the rfile. If you are going to change + this property, modify it on the configuration (recommended) + or on the "hook point" `on_start_resource`. + """ + + body = None + """ + If the request Content-Type is 'application/x-www-form-urlencoded' + or multipart, this will be None. Otherwise, this will be an instance + of :class:`RequestBody<cherrypy._cpreqbody.RequestBody>` (which you + can .read()); this value is set between the 'before_request_body' and + 'before_handler' hooks (assuming that process_request_body is True).""" + + # Dispatch attributes + dispatch = cherrypy.dispatch.Dispatcher() + """ + The object which looks up the 'page handler' callable and collects + config for the current request based on the path_info, other + request attributes, and the application architecture. The core + calls the dispatcher as early as possible, passing it a 'path_info' + argument. + + The default dispatcher discovers the page handler by matching path_info + to a hierarchical arrangement of objects, starting at request.app.root. + See help(cherrypy.dispatch) for more information.""" + + script_name = '' + """ + The 'mount point' of the application which is handling this request. + + This attribute MUST NOT end in a slash. If the script_name refers to + the root of the URI, it MUST be an empty string (not "/"). + """ + + path_info = '/' + """ + The 'relative path' portion of the Request-URI. This is relative + to the script_name ('mount point') of the application which is + handling this request.""" + + login = None + """ + When authentication is used during the request processing this is + set to 'False' if it failed and to the 'username' value if it succeeded. + The default 'None' implies that no authentication happened.""" + + # Note that cherrypy.url uses "if request.app:" to determine whether + # the call is during a real HTTP request or not. So leave this None. + app = None + """The cherrypy.Application object which is handling this request.""" + + handler = None + """ + The function, method, or other callable which CherryPy will call to + produce the response. The discovery of the handler and the arguments + it will receive are determined by the request.dispatch object. + By default, the handler is discovered by walking a tree of objects + starting at request.app.root, and is then passed all HTTP params + (from the query string and POST body) as keyword arguments.""" + + toolmaps = {} + """ + A nested dict of all Toolboxes and Tools in effect for this request, + of the form: {Toolbox.namespace: {Tool.name: config dict}}.""" + + config = None + """ + A flat dict of all configuration entries which apply to the + current request. 
These entries are collected from global config, + application config (based on request.path_info), and from handler + config (exactly how is governed by the request.dispatch object in + effect for this request; by default, handler config can be attached + anywhere in the tree between request.app.root and the final handler, + and inherits downward).""" + + is_index = None + """ + This will be True if the current request is mapped to an 'index' + resource handler (also, a 'default' handler if path_info ends with + a slash). The value may be used to automatically redirect the + user-agent to a 'more canonical' URL which either adds or removes + the trailing slash. See cherrypy.tools.trailing_slash.""" + + hooks = HookMap(hookpoints) + """ + A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}. + Each key is a str naming the hook point, and each value is a list + of hooks which will be called at that hook point during this request. + The list of hooks is generally populated as early as possible (mostly + from Tools specified in config), but may be extended at any time. + See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools.""" + + error_response = cherrypy.HTTPError(500).set_response + """ + The no-arg callable which will handle unexpected, untrapped errors + during request processing. This is not used for expected exceptions + (like NotFound, HTTPError, or HTTPRedirect) which are raised in + response to expected conditions (those should be customized either + via request.error_page or by overriding HTTPError.set_response). + By default, error_response uses HTTPError(500) to return a generic + error response to the user-agent.""" + + error_page = {} + """ + A dict of {error code: response filename or callable} pairs. + + The error code must be an int representing a given HTTP error code, + or the string 'default', which will be used if no matching entry + is found for a given numeric code. + + If a filename is provided, the file should contain a Python string- + formatting template, and can expect by default to receive format + values with the mapping keys %(status)s, %(message)s, %(traceback)s, + and %(version)s. The set of format mappings can be extended by + overriding HTTPError.set_response. + + If a callable is provided, it will be called by default with keyword + arguments 'status', 'message', 'traceback', and 'version', as for a + string-formatting template. The callable must return a string or + iterable of strings which will be set to response.body. It may also + override headers or perform any other processing. + + If no entry is given for an error code, and no 'default' entry exists, + a default template will be used. + """ + + show_tracebacks = True + """ + If True, unexpected errors encountered during request processing will + include a traceback in the response body.""" + + show_mismatched_params = True + """ + If True, mismatched parameters encountered during PageHandler invocation + processing will be included in the response body.""" + + throws = (KeyboardInterrupt, SystemExit, cherrypy.InternalRedirect) + """The sequence of exceptions which Request.run does not trap.""" + + throw_errors = False + """ + If True, Request.run will not trap any errors (except HTTPRedirect and + HTTPError, which are more properly called 'exceptions', not errors).""" + + closed = False + """True once the close method has been called, False otherwise.""" + + stage = None + """ + A string containing the stage reached in the request-handling process. 
+ This is useful when debugging a live server with hung requests.""" + + unique_id = None + """A lazy object generating and memorizing UUID4 on ``str()`` render.""" + + namespaces = reprconf.NamespaceSet( + **{'hooks': hooks_namespace, + 'request': request_namespace, + 'response': response_namespace, + 'error_page': error_page_namespace, + 'tools': cherrypy.tools, + }) + + def __init__(self, local_host, remote_host, scheme='http', + server_protocol='HTTP/1.1'): + """Populate a new Request object. + + local_host should be an httputil.Host object with the server info. + remote_host should be an httputil.Host object with the client info. + scheme should be a string, either "http" or "https". + """ + self.local = local_host + self.remote = remote_host + self.scheme = scheme + self.server_protocol = server_protocol + + self.closed = False + + # Put a *copy* of the class error_page into self. + self.error_page = self.error_page.copy() + + # Put a *copy* of the class namespaces into self. + self.namespaces = self.namespaces.copy() + + self.stage = None + + self.unique_id = LazyUUID4() + + def close(self): + """Run cleanup code. (Core)""" + if not self.closed: + self.closed = True + self.stage = 'on_end_request' + self.hooks.run('on_end_request') + self.stage = 'close' + + def run(self, method, path, query_string, req_protocol, headers, rfile): + r"""Process the Request. (Core) + + method, path, query_string, and req_protocol should be pulled directly + from the Request-Line (e.g. "GET /path?key=val HTTP/1.0"). + + path + This should be %XX-unquoted, but query_string should not be. + + When using Python 2, they both MUST be byte strings, + not unicode strings. + + When using Python 3, they both MUST be unicode strings, + not byte strings, and preferably not bytes \x00-\xFF + disguised as unicode. + + headers + A list of (name, value) tuples. + + rfile + A file-like object containing the HTTP request entity. + + When run() is done, the returned object should have 3 attributes: + + * status, e.g. "200 OK" + * header_list, a list of (name, value) tuples + * body, an iterable yielding strings + + Consumer code (HTTP servers) should then access these response + attributes to build the outbound stream. + + """ + response = cherrypy.serving.response + self.stage = 'run' + try: + self.error_response = cherrypy.HTTPError(500).set_response + + self.method = method + path = path or '/' + self.query_string = query_string or '' + self.params = {} + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. + rp = int(req_protocol[5]), int(req_protocol[7]) + sp = int(self.server_protocol[5]), int(self.server_protocol[7]) + self.protocol = min(rp, sp) + response.headers.protocol = self.protocol + + # Rebuild first line of the request (e.g. "GET /path HTTP/1.0"). + url = path + if query_string: + url += '?' 
+ query_string + self.request_line = '%s %s %s' % (method, url, req_protocol) + + self.header_list = list(headers) + self.headers = httputil.HeaderMap() + + self.rfile = rfile + self.body = None + + self.cookie = SimpleCookie() + self.handler = None + + # path_info should be the path from the + # app root (script_name) to the handler. + self.script_name = self.app.script_name + self.path_info = pi = path[len(self.script_name):] + + self.stage = 'respond' + self.respond(pi) + + except self.throws: + raise + except Exception: + if self.throw_errors: + raise + else: + # Failure in setup, error handler or finalize. Bypass them. + # Can't use handle_error because we may not have hooks yet. + cherrypy.log(traceback=True, severity=40) + if self.show_tracebacks: + body = format_exc() + else: + body = '' + r = bare_error(body) + response.output_status, response.header_list, response.body = r + + if self.method == 'HEAD': + # HEAD requests MUST NOT return a message-body in the response. + response.body = [] + + try: + cherrypy.log.access() + except Exception: + cherrypy.log.error(traceback=True) + + return response + + def respond(self, path_info): + """Generate a response for the resource at self.path_info. (Core)""" + try: + try: + try: + self._do_respond(path_info) + except (cherrypy.HTTPRedirect, cherrypy.HTTPError): + inst = sys.exc_info()[1] + inst.set_response() + self.stage = 'before_finalize (HTTPError)' + self.hooks.run('before_finalize') + cherrypy.serving.response.finalize() + finally: + self.stage = 'on_end_resource' + self.hooks.run('on_end_resource') + except self.throws: + raise + except Exception: + if self.throw_errors: + raise + self.handle_error() + + def _do_respond(self, path_info): + response = cherrypy.serving.response + + if self.app is None: + raise cherrypy.NotFound() + + self.hooks = self.__class__.hooks.copy() + self.toolmaps = {} + + # Get the 'Host' header, so we can HTTPRedirect properly. + self.stage = 'process_headers' + self.process_headers() + + self.stage = 'get_resource' + self.get_resource(path_info) + + self.body = _cpreqbody.RequestBody( + self.rfile, self.headers, request_params=self.params) + + self.namespaces(self.config) + + self.stage = 'on_start_resource' + self.hooks.run('on_start_resource') + + # Parse the querystring + self.stage = 'process_query_string' + self.process_query_string() + + # Process the body + if self.process_request_body: + if self.method not in self.methods_with_bodies: + self.process_request_body = False + self.stage = 'before_request_body' + self.hooks.run('before_request_body') + if self.process_request_body: + self.body.process() + + # Run the handler + self.stage = 'before_handler' + self.hooks.run('before_handler') + if self.handler: + self.stage = 'handler' + response.body = self.handler() + + # Finalize + self.stage = 'before_finalize' + self.hooks.run('before_finalize') + response.finalize() + + def process_query_string(self): + """Parse the query string into Python structures. (Core)""" + try: + p = httputil.parse_query_string( + self.query_string, encoding=self.query_string_encoding) + except UnicodeDecodeError: + raise cherrypy.HTTPError( + 404, 'The given query string could not be processed. Query ' + 'strings for this resource must be encoded with %r.' % + self.query_string_encoding) + + # Python 2 only: keyword arguments must be byte strings (type 'str'). 
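# --- editor's example -------------------------------------------------
# A small worked illustration of the version negotiation above: the
# effective protocol is min(request, server), compared as (major, minor)
# tuples pulled from the 'HTTP/x.y' strings.
req_protocol, server_protocol = 'HTTP/1.0', 'HTTP/1.1'
rp = (int(req_protocol[5]), int(req_protocol[7]))        # (1, 0)
sp = (int(server_protocol[5]), int(server_protocol[7]))  # (1, 1)
assert min(rp, sp) == (1, 0)  # case (b): limit features to the 1.0 set
# ----------------------------------------------------------------------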
+ if six.PY2: + for key, value in p.items(): + if isinstance(key, six.text_type): + del p[key] + p[key.encode(self.query_string_encoding)] = value + self.params.update(p) + + def process_headers(self): + """Parse HTTP header data into Python structures. (Core)""" + # Process the headers into self.headers + headers = self.headers + for name, value in self.header_list: + # Call title() now (and use dict.__method__(headers)) + # so title doesn't have to be called twice. + name = name.title() + value = value.strip() + + headers[name] = httputil.decode_TEXT_maybe(value) + + # Some clients, notably Konquoror, supply multiple + # cookies on different lines with the same key. To + # handle this case, store all cookies in self.cookie. + if name == 'Cookie': + try: + self.cookie.load(value) + except CookieError as exc: + raise cherrypy.HTTPError(400, str(exc)) + + if not dict.__contains__(headers, 'Host'): + # All Internet-based HTTP/1.1 servers MUST respond with a 400 + # (Bad Request) status code to any HTTP/1.1 request message + # which lacks a Host header field. + if self.protocol >= (1, 1): + msg = "HTTP/1.1 requires a 'Host' request header." + raise cherrypy.HTTPError(400, msg) + host = dict.get(headers, 'Host') + if not host: + host = self.local.name or self.local.ip + self.base = '%s://%s' % (self.scheme, host) + + def get_resource(self, path): + """Call a dispatcher (which sets self.handler and .config). (Core)""" + # First, see if there is a custom dispatch at this URI. Custom + # dispatchers can only be specified in app.config, not in _cp_config + # (since custom dispatchers may not even have an app.root). + dispatch = self.app.find_config( + path, 'request.dispatch', self.dispatch) + + # dispatch() should set self.handler and self.config + dispatch(path) + + def handle_error(self): + """Handle the last unanticipated exception. (Core)""" + try: + self.hooks.run('before_error_response') + if self.error_response: + self.error_response() + self.hooks.run('after_error_response') + cherrypy.serving.response.finalize() + except cherrypy.HTTPRedirect: + inst = sys.exc_info()[1] + inst.set_response() + cherrypy.serving.response.finalize() + + +class ResponseBody(object): + + """The body of the HTTP response (the response entity).""" + + unicode_err = ('Page handlers MUST return bytes. Use tools.encode ' + 'if you wish to return unicode.') + + def __get__(self, obj, objclass=None): + if obj is None: + # When calling on the class instead of an instance... + return self + else: + return obj._body + + def __set__(self, obj, value): + # Convert the given value to an iterable object. + if isinstance(value, six.text_type): + raise ValueError(self.unicode_err) + elif isinstance(value, list): + # every item in a list must be bytes... + if any(isinstance(item, six.text_type) for item in value): + raise ValueError(self.unicode_err) + + obj._body = encoding.prepare_iter(value) + + +class Response(object): + + """An HTTP Response, including status, headers, and body.""" + + status = '' + """The HTTP Status-Code and Reason-Phrase.""" + + header_list = [] + """ + A list of the HTTP response headers as (name, value) tuples. + In general, you should use response.headers (a dict) instead. This + attribute is generated from response.headers and is not valid until + after the finalize phase.""" + + headers = httputil.HeaderMap() + """ + A dict-like object containing the response headers. Keys are header + names (in Title-Case format); however, you may get and set them in + a case-insensitive manner. 
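# --- editor's example -------------------------------------------------
# A brief sketch of the two behaviours described here: HeaderMap keys
# are case-insensitive, and the response body is expected to be bytes
# (tools.encode handles text for you).  The handler name is illustrative.
import cherrypy

class Root(object):
    @cherrypy.expose
    def ping(self):
        headers = cherrypy.response.headers
        headers['content-type'] = 'text/plain'          # set lower-case...
        assert headers['Content-Type'] == 'text/plain'  # ...read Title-Case
        return b'pong'                                  # bytes pass through as-is
# ----------------------------------------------------------------------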
That is, headers['Content-Type'] and + headers['content-type'] refer to the same value. Values are header + values (decoded according to :rfc:`2047` if necessary). + + .. seealso:: classes :class:`HeaderMap`, :class:`HeaderElement` + """ + + cookie = SimpleCookie() + """See help(Cookie).""" + + body = ResponseBody() + """The body (entity) of the HTTP response.""" + + time = None + """The value of time.time() when created. Use in HTTP dates.""" + + stream = False + """If False, buffer the response body.""" + + def __init__(self): + self.status = None + self.header_list = None + self._body = [] + self.time = time.time() + + self.headers = httputil.HeaderMap() + # Since we know all our keys are titled strings, we can + # bypass HeaderMap.update and get a big speed boost. + dict.update(self.headers, { + 'Content-Type': 'text/html', + 'Server': 'CherryPy/' + cherrypy.__version__, + 'Date': httputil.HTTPDate(self.time), + }) + self.cookie = SimpleCookie() + + def collapse_body(self): + """Collapse self.body to a single string; replace it and return it.""" + new_body = b''.join(self.body) + self.body = new_body + return new_body + + def _flush_body(self): + """ + Discard self.body but consume any generator such that + any finalization can occur, such as is required by + caching.tee_output(). + """ + consume(iter(self.body)) + + def finalize(self): + """Transform headers (and cookies) into self.header_list. (Core)""" + try: + code, reason, _ = httputil.valid_status(self.status) + except ValueError: + raise cherrypy.HTTPError(500, sys.exc_info()[1].args[0]) + + headers = self.headers + + self.status = '%s %s' % (code, reason) + self.output_status = ntob(str(code), 'ascii') + \ + b' ' + headers.encode(reason) + + if self.stream: + # The upshot: wsgiserver will chunk the response if + # you pop Content-Length (or set it explicitly to None). + # Note that lib.static sets C-L to the file's st_size. + if dict.get(headers, 'Content-Length') is None: + dict.pop(headers, 'Content-Length', None) + elif code < 200 or code in (204, 205, 304): + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." + dict.pop(headers, 'Content-Length', None) + self._flush_body() + self.body = b'' + else: + # Responses which are not streamed should have a Content-Length, + # but allow user code to set Content-Length if desired. + if dict.get(headers, 'Content-Length') is None: + content = self.collapse_body() + dict.__setitem__(headers, 'Content-Length', len(content)) + + # Transform our header dict into a list of tuples. + self.header_list = h = headers.output() + + cookie = self.cookie.output() + if cookie: + for line in cookie.split('\r\n'): + name, value = line.split(': ', 1) + if isinstance(name, six.text_type): + name = name.encode('ISO-8859-1') + if isinstance(value, six.text_type): + value = headers.encode(value) + h.append((name, value)) + + +class LazyUUID4(object): + def __str__(self): + """Return UUID4 and keep it for future calls.""" + return str(self.uuid4) + + @property + def uuid4(self): + """Provide unique id on per-request basis using UUID4. + + It's evaluated lazily on render. 
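# --- editor's example -------------------------------------------------
# A sketch of the streaming branch handled in finalize() above: with
# response.stream enabled and no Content-Length set, the generator
# output can be chunked by the server.  Names are illustrative.
import cherrypy

class Root(object):
    @cherrypy.expose
    def feed(self):
        def generate():
            for i in range(3):
                yield ('line %d\n' % i).encode('ascii')
        return generate()
    feed._cp_config = {'response.stream': True}
# ----------------------------------------------------------------------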
+ """ + try: + self._uuid4 + except AttributeError: + # evaluate on first access + self._uuid4 = uuid.uuid4() + + return self._uuid4 diff --git a/libraries/cherrypy/_cpserver.py b/libraries/cherrypy/_cpserver.py new file mode 100644 index 00000000..0f60e2c8 --- /dev/null +++ b/libraries/cherrypy/_cpserver.py @@ -0,0 +1,252 @@ +"""Manage HTTP servers with CherryPy.""" + +import six + +import cherrypy +from cherrypy.lib.reprconf import attributes +from cherrypy._cpcompat import text_or_bytes +from cherrypy.process.servers import ServerAdapter + + +__all__ = ('Server', ) + + +class Server(ServerAdapter): + """An adapter for an HTTP server. + + You can set attributes (like socket_host and socket_port) + on *this* object (which is probably cherrypy.server), and call + quickstart. For example:: + + cherrypy.server.socket_port = 80 + cherrypy.quickstart() + """ + + socket_port = 8080 + """The TCP port on which to listen for connections.""" + + _socket_host = '127.0.0.1' + + @property + def socket_host(self): # noqa: D401; irrelevant for properties + """The hostname or IP address on which to listen for connections. + + Host values may be any IPv4 or IPv6 address, or any valid hostname. + The string 'localhost' is a synonym for '127.0.0.1' (or '::1', if + your hosts file prefers IPv6). The string '0.0.0.0' is a special + IPv4 entry meaning "any active interface" (INADDR_ANY), and '::' + is the similar IN6ADDR_ANY for IPv6. The empty string or None are + not allowed. + """ + return self._socket_host + + @socket_host.setter + def socket_host(self, value): + if value == '': + raise ValueError("The empty string ('') is not an allowed value. " + "Use '0.0.0.0' instead to listen on all active " + 'interfaces (INADDR_ANY).') + self._socket_host = value + + socket_file = None + """If given, the name of the UNIX socket to use instead of TCP/IP. + + When this option is not None, the `socket_host` and `socket_port` options + are ignored.""" + + socket_queue_size = 5 + """The 'backlog' argument to socket.listen(); specifies the maximum number + of queued connections (default 5).""" + + socket_timeout = 10 + """The timeout in seconds for accepted connections (default 10).""" + + accepted_queue_size = -1 + """The maximum number of requests which will be queued up before + the server refuses to accept it (default -1, meaning no limit).""" + + accepted_queue_timeout = 10 + """The timeout in seconds for attempting to add a request to the + queue when the queue is full (default 10).""" + + shutdown_timeout = 5 + """The time to wait for HTTP worker threads to clean up.""" + + protocol_version = 'HTTP/1.1' + """The version string to write in the Status-Line of all HTTP responses, + for example, "HTTP/1.1" (the default). Depending on the HTTP server used, + this should also limit the supported features used in the response.""" + + thread_pool = 10 + """The number of worker threads to start up in the pool.""" + + thread_pool_max = -1 + """The maximum size of the worker-thread pool. Use -1 to indicate no limit. + """ + + max_request_header_size = 500 * 1024 + """The maximum number of bytes allowable in the request headers. + If exceeded, the HTTP server should return "413 Request Entity Too Large". + """ + + max_request_body_size = 100 * 1024 * 1024 + """The maximum number of bytes allowable in the request body. 
If exceeded, + the HTTP server should return "413 Request Entity Too Large".""" + + instance = None + """If not None, this should be an HTTP server instance (such as + cheroot.wsgi.Server) which cherrypy.server will control. + Use this when you need + more control over object instantiation than is available in the various + configuration options.""" + + ssl_context = None + """When using PyOpenSSL, an instance of SSL.Context.""" + + ssl_certificate = None + """The filename of the SSL certificate to use.""" + + ssl_certificate_chain = None + """When using PyOpenSSL, the certificate chain to pass to + Context.load_verify_locations.""" + + ssl_private_key = None + """The filename of the private key to use with SSL.""" + + ssl_ciphers = None + """The ciphers list of SSL.""" + + if six.PY3: + ssl_module = 'builtin' + """The name of a registered SSL adaptation module to use with + the builtin WSGI server. Builtin options are: 'builtin' (to + use the SSL library built into recent versions of Python). + You may also register your own classes in the + cheroot.server.ssl_adapters dict.""" + else: + ssl_module = 'pyopenssl' + """The name of a registered SSL adaptation module to use with the + builtin WSGI server. Builtin options are 'builtin' (to use the SSL + library built into recent versions of Python) and 'pyopenssl' (to + use the PyOpenSSL project, which you must install separately). You + may also register your own classes in the cheroot.server.ssl_adapters + dict.""" + + statistics = False + """Turns statistics-gathering on or off for aware HTTP servers.""" + + nodelay = True + """If True (the default since 3.1), sets the TCP_NODELAY socket option.""" + + wsgi_version = (1, 0) + """The WSGI version tuple to use with the builtin WSGI server. + The provided options are (1, 0) [which includes support for PEP 3333, + which declares it covers WSGI version 1.0.1 but still mandates the + wsgi.version (1, 0)] and ('u', 0), an experimental unicode version. + You may create and register your own experimental versions of the WSGI + protocol by adding custom classes to the cheroot.server.wsgi_gateways dict. + """ + + peercreds = False + """If True, peer cred lookup for UNIX domain socket will put to WSGI env. + + This information will then be available through WSGI env vars: + * X_REMOTE_PID + * X_REMOTE_UID + * X_REMOTE_GID + """ + + peercreds_resolve = False + """If True, username/group will be looked up in the OS from peercreds. + + This information will then be available through WSGI env vars: + * REMOTE_USER + * X_REMOTE_USER + * X_REMOTE_GROUP + """ + + def __init__(self): + """Initialize Server instance.""" + self.bus = cherrypy.engine + self.httpserver = None + self.interrupt = None + self.running = False + + def httpserver_from_self(self, httpserver=None): + """Return a (httpserver, bind_addr) pair based on self attributes.""" + if httpserver is None: + httpserver = self.instance + if httpserver is None: + from cherrypy import _cpwsgi_server + httpserver = _cpwsgi_server.CPWSGIServer(self) + if isinstance(httpserver, text_or_bytes): + # Is anyone using this? Can I add an arg? + httpserver = attributes(httpserver)(self) + return httpserver, self.bind_addr + + def start(self): + """Start the HTTP server.""" + if not self.httpserver: + self.httpserver, self.bind_addr = self.httpserver_from_self() + super(Server, self).start() + start.priority = 75 + + @property + def bind_addr(self): + """Return bind address. + + A (host, port) tuple for TCP sockets or a str for Unix domain sockts. 
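# --- editor's example -------------------------------------------------
# A hypothetical sketch of enabling TLS through the attributes documented
# above, using the builtin adapter; the certificate paths are placeholders.
import cherrypy

cherrypy.server.ssl_module = 'builtin'
cherrypy.server.ssl_certificate = '/etc/ssl/certs/example.pem'
cherrypy.server.ssl_private_key = '/etc/ssl/private/example.key'
# With ssl_module = 'pyopenssl' you may also set ssl_certificate_chain
# and/or supply a ready-made ssl_context.
# ----------------------------------------------------------------------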
+ """ + if self.socket_file: + return self.socket_file + if self.socket_host is None and self.socket_port is None: + return None + return (self.socket_host, self.socket_port) + + @bind_addr.setter + def bind_addr(self, value): + if value is None: + self.socket_file = None + self.socket_host = None + self.socket_port = None + elif isinstance(value, text_or_bytes): + self.socket_file = value + self.socket_host = None + self.socket_port = None + else: + try: + self.socket_host, self.socket_port = value + self.socket_file = None + except ValueError: + raise ValueError('bind_addr must be a (host, port) tuple ' + '(for TCP sockets) or a string (for Unix ' + 'domain sockets), not %r' % value) + + def base(self): + """Return the base for this server. + + e.i. scheme://host[:port] or sock file + """ + if self.socket_file: + return self.socket_file + + host = self.socket_host + if host in ('0.0.0.0', '::'): + # 0.0.0.0 is INADDR_ANY and :: is IN6ADDR_ANY. + # Look up the host name, which should be the + # safest thing to spit out in a URL. + import socket + host = socket.gethostname() + + port = self.socket_port + + if self.ssl_certificate: + scheme = 'https' + if port != 443: + host += ':%s' % port + else: + scheme = 'http' + if port != 80: + host += ':%s' % port + + return '%s://%s' % (scheme, host) diff --git a/libraries/cherrypy/_cptools.py b/libraries/cherrypy/_cptools.py new file mode 100644 index 00000000..57460285 --- /dev/null +++ b/libraries/cherrypy/_cptools.py @@ -0,0 +1,509 @@ +"""CherryPy tools. A "tool" is any helper, adapted to CP. + +Tools are usually designed to be used in a variety of ways (although some +may only offer one if they choose): + + Library calls + All tools are callables that can be used wherever needed. + The arguments are straightforward and should be detailed within the + docstring. + + Function decorators + All tools, when called, may be used as decorators which configure + individual CherryPy page handlers (methods on the CherryPy tree). + That is, "@tools.anytool()" should "turn on" the tool via the + decorated function's _cp_config attribute. + + CherryPy config + If a tool exposes a "_setup" callable, it will be called + once per Request (if the feature is "turned on" via config). + +Tools may be implemented as any object with a namespace. The builtins +are generally either modules or instances of the tools.Tool class. +""" + +import six + +import cherrypy +from cherrypy._helper import expose + +from cherrypy.lib import cptools, encoding, static, jsontools +from cherrypy.lib import sessions as _sessions, xmlrpcutil as _xmlrpc +from cherrypy.lib import caching as _caching +from cherrypy.lib import auth_basic, auth_digest + + +def _getargs(func): + """Return the names of all static arguments to the given function.""" + # Use this instead of importing inspect for less mem overhead. + import types + if six.PY3: + if isinstance(func, types.MethodType): + func = func.__func__ + co = func.__code__ + else: + if isinstance(func, types.MethodType): + func = func.im_func + co = func.func_code + return co.co_varnames[:co.co_argcount] + + +_attr_error = ( + 'CherryPy Tools cannot be turned on directly. Instead, turn them ' + 'on via config, or use them as decorators on your page handlers.' +) + + +class Tool(object): + + """A registered function for use with CherryPy request-processing hooks. + + help(tool.callable) should give you more information about this Tool. 
+ """ + + namespace = 'tools' + + def __init__(self, point, callable, name=None, priority=50): + self._point = point + self.callable = callable + self._name = name + self._priority = priority + self.__doc__ = self.callable.__doc__ + self._setargs() + + @property + def on(self): + raise AttributeError(_attr_error) + + @on.setter + def on(self, value): + raise AttributeError(_attr_error) + + def _setargs(self): + """Copy func parameter names to obj attributes.""" + try: + for arg in _getargs(self.callable): + setattr(self, arg, None) + except (TypeError, AttributeError): + if hasattr(self.callable, '__call__'): + for arg in _getargs(self.callable.__call__): + setattr(self, arg, None) + # IronPython 1.0 raises NotImplementedError because + # inspect.getargspec tries to access Python bytecode + # in co_code attribute. + except NotImplementedError: + pass + # IronPython 1B1 may raise IndexError in some cases, + # but if we trap it here it doesn't prevent CP from working. + except IndexError: + pass + + def _merged_args(self, d=None): + """Return a dict of configuration entries for this Tool.""" + if d: + conf = d.copy() + else: + conf = {} + + tm = cherrypy.serving.request.toolmaps[self.namespace] + if self._name in tm: + conf.update(tm[self._name]) + + if 'on' in conf: + del conf['on'] + + return conf + + def __call__(self, *args, **kwargs): + """Compile-time decorator (turn on the tool in config). + + For example:: + + @expose + @tools.proxy() + def whats_my_base(self): + return cherrypy.request.base + """ + if args: + raise TypeError('The %r Tool does not accept positional ' + 'arguments; you must use keyword arguments.' + % self._name) + + def tool_decorator(f): + if not hasattr(f, '_cp_config'): + f._cp_config = {} + subspace = self.namespace + '.' + self._name + '.' + f._cp_config[subspace + 'on'] = True + for k, v in kwargs.items(): + f._cp_config[subspace + k] = v + return f + return tool_decorator + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + conf = self._merged_args() + p = conf.pop('priority', None) + if p is None: + p = getattr(self.callable, 'priority', self._priority) + cherrypy.serving.request.hooks.attach(self._point, self.callable, + priority=p, **conf) + + +class HandlerTool(Tool): + + """Tool which is called 'before main', that may skip normal handlers. + + If the tool successfully handles the request (by setting response.body), + if should return True. This will cause CherryPy to skip any 'normal' page + handler. If the tool did not handle the request, it should return False + to tell CherryPy to continue on and call the normal page handler. If the + tool is declared AS a page handler (see the 'handler' method), returning + False will raise NotFound. + """ + + def __init__(self, callable, name=None): + Tool.__init__(self, 'before_handler', callable, name) + + def handler(self, *args, **kwargs): + """Use this tool as a CherryPy page handler. + + For example:: + + class Root: + nav = tools.staticdir.handler(section="/nav", dir="nav", + root=absDir) + """ + @expose + def handle_func(*a, **kw): + handled = self.callable(*args, **self._merged_args(kwargs)) + if not handled: + raise cherrypy.NotFound() + return cherrypy.serving.response.body + return handle_func + + def _wrapper(self, **kwargs): + if self.callable(**kwargs): + cherrypy.serving.request.handler = None + + def _setup(self): + """Hook this tool into cherrypy.request. 
+ + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + conf = self._merged_args() + p = conf.pop('priority', None) + if p is None: + p = getattr(self.callable, 'priority', self._priority) + cherrypy.serving.request.hooks.attach(self._point, self._wrapper, + priority=p, **conf) + + +class HandlerWrapperTool(Tool): + + """Tool which wraps request.handler in a provided wrapper function. + + The 'newhandler' arg must be a handler wrapper function that takes a + 'next_handler' argument, plus ``*args`` and ``**kwargs``. Like all + page handler + functions, it must return an iterable for use as cherrypy.response.body. + + For example, to allow your 'inner' page handlers to return dicts + which then get interpolated into a template:: + + def interpolator(next_handler, *args, **kwargs): + filename = cherrypy.request.config.get('template') + cherrypy.response.template = env.get_template(filename) + response_dict = next_handler(*args, **kwargs) + return cherrypy.response.template.render(**response_dict) + cherrypy.tools.jinja = HandlerWrapperTool(interpolator) + """ + + def __init__(self, newhandler, point='before_handler', name=None, + priority=50): + self.newhandler = newhandler + self._point = point + self._name = name + self._priority = priority + + def callable(self, *args, **kwargs): + innerfunc = cherrypy.serving.request.handler + + def wrap(*args, **kwargs): + return self.newhandler(innerfunc, *args, **kwargs) + cherrypy.serving.request.handler = wrap + + +class ErrorTool(Tool): + + """Tool which is used to replace the default request.error_response.""" + + def __init__(self, callable, name=None): + Tool.__init__(self, None, callable, name) + + def _wrapper(self): + self.callable(**self._merged_args()) + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + cherrypy.serving.request.error_response = self._wrapper + + +# Builtin tools # + + +class SessionTool(Tool): + + """Session Tool for CherryPy. + + sessions.locking + When 'implicit' (the default), the session will be locked for you, + just before running the page handler. + + When 'early', the session will be locked before reading the request + body. This is off by default for safety reasons; for example, + a large upload would block the session, denying an AJAX + progress meter + (`issue <https://github.com/cherrypy/cherrypy/issues/630>`_). + + When 'explicit' (or any other value), you need to call + cherrypy.session.acquire_lock() yourself before using + session data. + """ + + def __init__(self): + # _sessions.init must be bound after headers are read + Tool.__init__(self, 'before_request_body', _sessions.init) + + def _lock_session(self): + cherrypy.serving.session.acquire_lock() + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + hooks = cherrypy.serving.request.hooks + + conf = self._merged_args() + + p = conf.pop('priority', None) + if p is None: + p = getattr(self.callable, 'priority', self._priority) + + hooks.attach(self._point, self.callable, priority=p, **conf) + + locking = conf.pop('locking', 'implicit') + if locking == 'implicit': + hooks.attach('before_handler', self._lock_session) + elif locking == 'early': + # Lock before the request body (but after _sessions.init runs!) 
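# --- editor's example -------------------------------------------------
# A sketch of the 'explicit' locking mode described above: the handler
# takes the session lock itself, and the on_end_request close hook
# attached below releases it at the end of the request.  Names are
# illustrative.
import cherrypy

session_conf = {'/': {
    'tools.sessions.on': True,
    'tools.sessions.locking': 'explicit',
}}

class Root(object):
    @cherrypy.expose
    def index(self):
        cherrypy.session.acquire_lock()
        count = cherrypy.session.get('count', 0) + 1
        cherrypy.session['count'] = count
        return 'visit %d' % count
# cherrypy.quickstart(Root(), '/', session_conf)
# ----------------------------------------------------------------------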
+ hooks.attach('before_request_body', self._lock_session, + priority=60) + else: + # Don't lock + pass + + hooks.attach('before_finalize', _sessions.save) + hooks.attach('on_end_request', _sessions.close) + + def regenerate(self): + """Drop the current session and make a new one (with a new id).""" + sess = cherrypy.serving.session + sess.regenerate() + + # Grab cookie-relevant tool args + relevant = 'path', 'path_header', 'name', 'timeout', 'domain', 'secure' + conf = dict( + (k, v) + for k, v in self._merged_args().items() + if k in relevant + ) + _sessions.set_response_cookie(**conf) + + +class XMLRPCController(object): + + """A Controller (page handler collection) for XML-RPC. + + To use it, have your controllers subclass this base class (it will + turn on the tool for you). + + You can also supply the following optional config entries:: + + tools.xmlrpc.encoding: 'utf-8' + tools.xmlrpc.allow_none: 0 + + XML-RPC is a rather discontinuous layer over HTTP; dispatching to the + appropriate handler must first be performed according to the URL, and + then a second dispatch step must take place according to the RPC method + specified in the request body. It also allows a superfluous "/RPC2" + prefix in the URL, supplies its own handler args in the body, and + requires a 200 OK "Fault" response instead of 404 when the desired + method is not found. + + Therefore, XML-RPC cannot be implemented for CherryPy via a Tool alone. + This Controller acts as the dispatch target for the first half (based + on the URL); it then reads the RPC method from the request body and + does its own second dispatch step based on that method. It also reads + body params, and returns a Fault on error. + + The XMLRPCDispatcher strips any /RPC2 prefix; if you aren't using /RPC2 + in your URL's, you can safely skip turning on the XMLRPCDispatcher. + Otherwise, you need to use declare it in config:: + + request.dispatch: cherrypy.dispatch.XMLRPCDispatcher() + """ + + # Note we're hard-coding this into the 'tools' namespace. We could do + # a huge amount of work to make it relocatable, but the only reason why + # would be if someone actually disabled the default_toolbox. Meh. 
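# --- editor's example -------------------------------------------------
# A minimal sketch of the controller usage described above; the method
# name is illustrative.  The dispatcher line is only needed when clients
# prefix calls with /RPC2.
import cherrypy
from cherrypy._cptools import XMLRPCController

class RPCRoot(XMLRPCController):
    @cherrypy.expose
    def add(self, a, b):
        return a + b

rpc_conf = {'/': {
    'request.dispatch': cherrypy.dispatch.XMLRPCDispatcher(),
}}
# cherrypy.quickstart(RPCRoot(), '/', rpc_conf)
# An xmlrpc.client.ServerProxy('http://localhost:8080/') could then
# call proxy.add(1, 2).
# ----------------------------------------------------------------------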
+ _cp_config = {'tools.xmlrpc.on': True} + + @expose + def default(self, *vpath, **params): + rpcparams, rpcmethod = _xmlrpc.process_body() + + subhandler = self + for attr in str(rpcmethod).split('.'): + subhandler = getattr(subhandler, attr, None) + + if subhandler and getattr(subhandler, 'exposed', False): + body = subhandler(*(vpath + rpcparams), **params) + + else: + # https://github.com/cherrypy/cherrypy/issues/533 + # if a method is not found, an xmlrpclib.Fault should be returned + # raising an exception here will do that; see + # cherrypy.lib.xmlrpcutil.on_error + raise Exception('method "%s" is not supported' % attr) + + conf = cherrypy.serving.request.toolmaps['tools'].get('xmlrpc', {}) + _xmlrpc.respond(body, + conf.get('encoding', 'utf-8'), + conf.get('allow_none', 0)) + return cherrypy.serving.response.body + + +class SessionAuthTool(HandlerTool): + pass + + +class CachingTool(Tool): + + """Caching Tool for CherryPy.""" + + def _wrapper(self, **kwargs): + request = cherrypy.serving.request + if _caching.get(**kwargs): + request.handler = None + else: + if request.cacheable: + # Note the devious technique here of adding hooks on the fly + request.hooks.attach('before_finalize', _caching.tee_output, + priority=100) + _wrapper.priority = 90 + + def _setup(self): + """Hook caching into cherrypy.request.""" + conf = self._merged_args() + + p = conf.pop('priority', None) + cherrypy.serving.request.hooks.attach('before_handler', self._wrapper, + priority=p, **conf) + + +class Toolbox(object): + + """A collection of Tools. + + This object also functions as a config namespace handler for itself. + Custom toolboxes should be added to each Application's toolboxes dict. + """ + + def __init__(self, namespace): + self.namespace = namespace + + def __setattr__(self, name, value): + # If the Tool._name is None, supply it from the attribute name. + if isinstance(value, Tool): + if value._name is None: + value._name = name + value.namespace = self.namespace + object.__setattr__(self, name, value) + + def __enter__(self): + """Populate request.toolmaps from tools specified in config.""" + cherrypy.serving.request.toolmaps[self.namespace] = map = {} + + def populate(k, v): + toolname, arg = k.split('.', 1) + bucket = map.setdefault(toolname, {}) + bucket[arg] = v + return populate + + def __exit__(self, exc_type, exc_val, exc_tb): + """Run tool._setup() for each tool in our toolmap.""" + map = cherrypy.serving.request.toolmaps.get(self.namespace) + if map: + for name, settings in map.items(): + if settings.get('on', False): + tool = getattr(self, name) + tool._setup() + + def register(self, point, **kwargs): + """ + Return a decorator which registers the function + at the given hook point. 
+ """ + def decorator(func): + attr_name = kwargs.get('name', func.__name__) + tool = Tool(point, func, **kwargs) + setattr(self, attr_name, tool) + return func + return decorator + + +default_toolbox = _d = Toolbox('tools') +_d.session_auth = SessionAuthTool(cptools.session_auth) +_d.allow = Tool('on_start_resource', cptools.allow) +_d.proxy = Tool('before_request_body', cptools.proxy, priority=30) +_d.response_headers = Tool('on_start_resource', cptools.response_headers) +_d.log_tracebacks = Tool('before_error_response', cptools.log_traceback) +_d.log_headers = Tool('before_error_response', cptools.log_request_headers) +_d.log_hooks = Tool('on_end_request', cptools.log_hooks, priority=100) +_d.err_redirect = ErrorTool(cptools.redirect) +_d.etags = Tool('before_finalize', cptools.validate_etags, priority=75) +_d.decode = Tool('before_request_body', encoding.decode) +# the order of encoding, gzip, caching is important +_d.encode = Tool('before_handler', encoding.ResponseEncoder, priority=70) +_d.gzip = Tool('before_finalize', encoding.gzip, priority=80) +_d.staticdir = HandlerTool(static.staticdir) +_d.staticfile = HandlerTool(static.staticfile) +_d.sessions = SessionTool() +_d.xmlrpc = ErrorTool(_xmlrpc.on_error) +_d.caching = CachingTool('before_handler', _caching.get, 'caching') +_d.expires = Tool('before_finalize', _caching.expires) +_d.ignore_headers = Tool('before_request_body', cptools.ignore_headers) +_d.referer = Tool('before_request_body', cptools.referer) +_d.trailing_slash = Tool('before_handler', cptools.trailing_slash, priority=60) +_d.flatten = Tool('before_finalize', cptools.flatten) +_d.accept = Tool('on_start_resource', cptools.accept) +_d.redirect = Tool('on_start_resource', cptools.redirect) +_d.autovary = Tool('on_start_resource', cptools.autovary, priority=0) +_d.json_in = Tool('before_request_body', jsontools.json_in, priority=30) +_d.json_out = Tool('before_handler', jsontools.json_out, priority=30) +_d.auth_basic = Tool('before_handler', auth_basic.basic_auth, priority=1) +_d.auth_digest = Tool('before_handler', auth_digest.digest_auth, priority=1) +_d.params = Tool('before_handler', cptools.convert_params, priority=15) + +del _d, cptools, encoding, static diff --git a/libraries/cherrypy/_cptree.py b/libraries/cherrypy/_cptree.py new file mode 100644 index 00000000..ceb54379 --- /dev/null +++ b/libraries/cherrypy/_cptree.py @@ -0,0 +1,313 @@ +"""CherryPy Application and Tree objects.""" + +import os + +import six + +import cherrypy +from cherrypy._cpcompat import ntou +from cherrypy import _cpconfig, _cplogging, _cprequest, _cpwsgi, tools +from cherrypy.lib import httputil, reprconf + + +class Application(object): + """A CherryPy Application. + + Servers and gateways should not instantiate Request objects directly. + Instead, they should ask an Application object for a request object. + + An instance of this class may also be used as a WSGI callable + (WSGI application object) for itself. + """ + + root = None + """The top-most container of page handlers for this app. Handlers should + be arranged in a hierarchy of attributes, matching the expected URI + hierarchy; the default dispatcher then searches this hierarchy for a + matching handler. When using a dispatcher other than the default, + this value may be None.""" + + config = {} + """A dict of {path: pathconf} pairs, where 'pathconf' is itself a dict + of {key: value} pairs.""" + + namespaces = reprconf.NamespaceSet() + toolboxes = {'tools': cherrypy.tools} + + log = None + """A LogManager instance. 
See _cplogging.""" + + wsgiapp = None + """A CPWSGIApp instance. See _cpwsgi.""" + + request_class = _cprequest.Request + response_class = _cprequest.Response + + relative_urls = False + + def __init__(self, root, script_name='', config=None): + """Initialize Application with given root.""" + self.log = _cplogging.LogManager(id(self), cherrypy.log.logger_root) + self.root = root + self.script_name = script_name + self.wsgiapp = _cpwsgi.CPWSGIApp(self) + + self.namespaces = self.namespaces.copy() + self.namespaces['log'] = lambda k, v: setattr(self.log, k, v) + self.namespaces['wsgi'] = self.wsgiapp.namespace_handler + + self.config = self.__class__.config.copy() + if config: + self.merge(config) + + def __repr__(self): + """Generate a representation of the Application instance.""" + return '%s.%s(%r, %r)' % (self.__module__, self.__class__.__name__, + self.root, self.script_name) + + script_name_doc = """The URI "mount point" for this app. A mount point + is that portion of the URI which is constant for all URIs that are + serviced by this application; it does not include scheme, host, or proxy + ("virtual host") portions of the URI. + + For example, if script_name is "/my/cool/app", then the URL + "http://www.example.com/my/cool/app/page1" might be handled by a + "page1" method on the root object. + + The value of script_name MUST NOT end in a slash. If the script_name + refers to the root of the URI, it MUST be an empty string (not "/"). + + If script_name is explicitly set to None, then the script_name will be + provided for each call from request.wsgi_environ['SCRIPT_NAME']. + """ + + @property + def script_name(self): # noqa: D401; irrelevant for properties + """The URI "mount point" for this app. + + A mount point is that portion of the URI which is constant for all URIs + that are serviced by this application; it does not include scheme, + host, or proxy ("virtual host") portions of the URI. + + For example, if script_name is "/my/cool/app", then the URL + "http://www.example.com/my/cool/app/page1" might be handled by a + "page1" method on the root object. + + The value of script_name MUST NOT end in a slash. If the script_name + refers to the root of the URI, it MUST be an empty string (not "/"). + + If script_name is explicitly set to None, then the script_name will be + provided for each call from request.wsgi_environ['SCRIPT_NAME']. + """ + if self._script_name is not None: + return self._script_name + + # A `_script_name` with a value of None signals that the script name + # should be pulled from WSGI environ. + return cherrypy.serving.request.wsgi_environ['SCRIPT_NAME'].rstrip('/') + + @script_name.setter + def script_name(self, value): + if value: + value = value.rstrip('/') + self._script_name = value + + def merge(self, config): + """Merge the given config into self.config.""" + _cpconfig.merge(self.config, config) + + # Handle namespaces specified in config. 
+ self.namespaces(self.config.get('/', {})) + + def find_config(self, path, key, default=None): + """Return the most-specific value for key along path, or default.""" + trail = path or '/' + while trail: + nodeconf = self.config.get(trail, {}) + + if key in nodeconf: + return nodeconf[key] + + lastslash = trail.rfind('/') + if lastslash == -1: + break + elif lastslash == 0 and trail != '/': + trail = '/' + else: + trail = trail[:lastslash] + + return default + + def get_serving(self, local, remote, scheme, sproto): + """Create and return a Request and Response object.""" + req = self.request_class(local, remote, scheme, sproto) + req.app = self + + for name, toolbox in self.toolboxes.items(): + req.namespaces[name] = toolbox + + resp = self.response_class() + cherrypy.serving.load(req, resp) + cherrypy.engine.publish('acquire_thread') + cherrypy.engine.publish('before_request') + + return req, resp + + def release_serving(self): + """Release the current serving (request and response).""" + req = cherrypy.serving.request + + cherrypy.engine.publish('after_request') + + try: + req.close() + except Exception: + cherrypy.log(traceback=True, severity=40) + + cherrypy.serving.clear() + + def __call__(self, environ, start_response): + """Call a WSGI-callable.""" + return self.wsgiapp(environ, start_response) + + +class Tree(object): + """A registry of CherryPy applications, mounted at diverse points. + + An instance of this class may also be used as a WSGI callable + (WSGI application object), in which case it dispatches to all + mounted apps. + """ + + apps = {} + """ + A dict of the form {script name: application}, where "script name" + is a string declaring the URI mount point (no trailing slash), and + "application" is an instance of cherrypy.Application (or an arbitrary + WSGI callable if you happen to be using a WSGI server).""" + + def __init__(self): + """Initialize registry Tree.""" + self.apps = {} + + def mount(self, root, script_name='', config=None): + """Mount a new app from a root object, script_name, and config. + + root + An instance of a "controller class" (a collection of page + handler methods) which represents the root of the application. + This may also be an Application instance, or None if using + a dispatcher other than the default. + + script_name + A string containing the "mount point" of the application. + This should start with a slash, and be the path portion of the + URL at which to mount the given root. For example, if root.index() + will handle requests to "http://www.example.com:8080/dept/app1/", + then the script_name argument would be "/dept/app1". + + It MUST NOT end in a slash. If the script_name refers to the + root of the URI, it MUST be an empty string (not "/"). + + config + A file or dict containing application config. + """ + if script_name is None: + raise TypeError( + "The 'script_name' argument may not be None. Application " + 'objects may, however, possess a script_name of None (in ' + 'order to inpect the WSGI environ for SCRIPT_NAME upon each ' + 'request). You cannot mount such Applications on this Tree; ' + 'you must pass them to a WSGI server interface directly.') + + # Next line both 1) strips trailing slash and 2) maps "/" -> "". 
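# --- editor's example -------------------------------------------------
# A small sketch of mounting two applications at different script names,
# following the rules above (no trailing slash; '' for the site root).
# Class names are illustrative.
import cherrypy

class Root(object):
    @cherrypy.expose
    def index(self):
        return 'main site'

class Admin(object):
    @cherrypy.expose
    def index(self):
        return 'admin area'

cherrypy.tree.mount(Root(), '')          # handles http://host:8080/
cherrypy.tree.mount(Admin(), '/admin')   # handles http://host:8080/admin/
# cherrypy.engine.start(); cherrypy.engine.block()
# ----------------------------------------------------------------------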
+ script_name = script_name.rstrip('/') + + if isinstance(root, Application): + app = root + if script_name != '' and script_name != app.script_name: + raise ValueError( + 'Cannot specify a different script name and pass an ' + 'Application instance to cherrypy.mount') + script_name = app.script_name + else: + app = Application(root, script_name) + + # If mounted at "", add favicon.ico + needs_favicon = ( + script_name == '' + and root is not None + and not hasattr(root, 'favicon_ico') + ) + if needs_favicon: + favicon = os.path.join( + os.getcwd(), + os.path.dirname(__file__), + 'favicon.ico', + ) + root.favicon_ico = tools.staticfile.handler(favicon) + + if config: + app.merge(config) + + self.apps[script_name] = app + + return app + + def graft(self, wsgi_callable, script_name=''): + """Mount a wsgi callable at the given script_name.""" + # Next line both 1) strips trailing slash and 2) maps "/" -> "". + script_name = script_name.rstrip('/') + self.apps[script_name] = wsgi_callable + + def script_name(self, path=None): + """Return the script_name of the app at the given path, or None. + + If path is None, cherrypy.request is used. + """ + if path is None: + try: + request = cherrypy.serving.request + path = httputil.urljoin(request.script_name, + request.path_info) + except AttributeError: + return None + + while True: + if path in self.apps: + return path + + if path == '': + return None + + # Move one node up the tree and try again. + path = path[:path.rfind('/')] + + def __call__(self, environ, start_response): + """Pre-initialize WSGI env and call WSGI-callable.""" + # If you're calling this, then you're probably setting SCRIPT_NAME + # to '' (some WSGI servers always set SCRIPT_NAME to ''). + # Try to look up the app using the full path. + env1x = environ + if six.PY2 and environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + env1x = _cpwsgi.downgrade_wsgi_ux_to_1x(environ) + path = httputil.urljoin(env1x.get('SCRIPT_NAME', ''), + env1x.get('PATH_INFO', '')) + sn = self.script_name(path or '/') + if sn is None: + start_response('404 Not Found', []) + return [] + + app = self.apps[sn] + + # Correct the SCRIPT_NAME and PATH_INFO environ entries. + environ = environ.copy() + if six.PY2 and environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + # Python 2/WSGI u.0: all strings MUST be of type unicode + enc = environ[ntou('wsgi.url_encoding')] + environ[ntou('SCRIPT_NAME')] = sn.decode(enc) + environ[ntou('PATH_INFO')] = path[len(sn.rstrip('/')):].decode(enc) + else: + environ['SCRIPT_NAME'] = sn + environ['PATH_INFO'] = path[len(sn.rstrip('/')):] + return app(environ, start_response) diff --git a/libraries/cherrypy/_cpwsgi.py b/libraries/cherrypy/_cpwsgi.py new file mode 100644 index 00000000..0b4942ff --- /dev/null +++ b/libraries/cherrypy/_cpwsgi.py @@ -0,0 +1,467 @@ +"""WSGI interface (see PEP 333 and 3333). + +Note that WSGI environ keys and values are 'native strings'; that is, +whatever the type of "" is. For Python 2, that's a byte string; for Python 3, +it's a unicode string. But PEP 3333 says: "even if Python's str type is +actually Unicode "under the hood", the content of native strings must +still be translatable to bytes via the Latin-1 encoding!" 
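# --- editor's example -------------------------------------------------
# A sketch of graft() from above: it mounts a plain WSGI callable with
# no CherryPy Application wrapper (so CherryPy config and tools do not
# apply to it).
import cherrypy

def raw_wsgi_app(environ, start_response):
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [b'plain WSGI response']

cherrypy.tree.graft(raw_wsgi_app, '/raw')
# ----------------------------------------------------------------------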
+""" + +import sys as _sys +import io + +import six + +import cherrypy as _cherrypy +from cherrypy._cpcompat import ntou +from cherrypy import _cperror +from cherrypy.lib import httputil +from cherrypy.lib import is_closable_iterator + + +def downgrade_wsgi_ux_to_1x(environ): + """Return a new environ dict for WSGI 1.x from the given WSGI u.x environ. + """ + env1x = {} + + url_encoding = environ[ntou('wsgi.url_encoding')] + for k, v in list(environ.items()): + if k in [ntou('PATH_INFO'), ntou('SCRIPT_NAME'), ntou('QUERY_STRING')]: + v = v.encode(url_encoding) + elif isinstance(v, six.text_type): + v = v.encode('ISO-8859-1') + env1x[k.encode('ISO-8859-1')] = v + + return env1x + + +class VirtualHost(object): + + """Select a different WSGI application based on the Host header. + + This can be useful when running multiple sites within one CP server. + It allows several domains to point to different applications. For example:: + + root = Root() + RootApp = cherrypy.Application(root) + Domain2App = cherrypy.Application(root) + SecureApp = cherrypy.Application(Secure()) + + vhost = cherrypy._cpwsgi.VirtualHost( + RootApp, + domains={ + 'www.domain2.example': Domain2App, + 'www.domain2.example:443': SecureApp, + }, + ) + + cherrypy.tree.graft(vhost) + """ + default = None + """Required. The default WSGI application.""" + + use_x_forwarded_host = True + """If True (the default), any "X-Forwarded-Host" + request header will be used instead of the "Host" header. This + is commonly added by HTTP servers (such as Apache) when proxying.""" + + domains = {} + """A dict of {host header value: application} pairs. + The incoming "Host" request header is looked up in this dict, + and, if a match is found, the corresponding WSGI application + will be called instead of the default. Note that you often need + separate entries for "example.com" and "www.example.com". + In addition, "Host" headers may contain the port number. + """ + + def __init__(self, default, domains=None, use_x_forwarded_host=True): + self.default = default + self.domains = domains or {} + self.use_x_forwarded_host = use_x_forwarded_host + + def __call__(self, environ, start_response): + domain = environ.get('HTTP_HOST', '') + if self.use_x_forwarded_host: + domain = environ.get('HTTP_X_FORWARDED_HOST', domain) + + nextapp = self.domains.get(domain) + if nextapp is None: + nextapp = self.default + return nextapp(environ, start_response) + + +class InternalRedirector(object): + + """WSGI middleware that handles raised cherrypy.InternalRedirect.""" + + def __init__(self, nextapp, recursive=False): + self.nextapp = nextapp + self.recursive = recursive + + def __call__(self, environ, start_response): + redirections = [] + while True: + environ = environ.copy() + try: + return self.nextapp(environ, start_response) + except _cherrypy.InternalRedirect: + ir = _sys.exc_info()[1] + sn = environ.get('SCRIPT_NAME', '') + path = environ.get('PATH_INFO', '') + qs = environ.get('QUERY_STRING', '') + + # Add the *previous* path_info + qs to redirections. + old_uri = sn + path + if qs: + old_uri += '?' + qs + redirections.append(old_uri) + + if not self.recursive: + # Check to see if the new URI has been redirected to + # already + new_uri = sn + ir.path + if ir.query_string: + new_uri += '?' + ir.query_string + if new_uri in redirections: + ir.request.close() + tmpl = ( + 'InternalRedirector visited the same URL twice: %r' + ) + raise RuntimeError(tmpl % new_uri) + + # Munge the environment and try again. 
+ environ['REQUEST_METHOD'] = 'GET' + environ['PATH_INFO'] = ir.path + environ['QUERY_STRING'] = ir.query_string + environ['wsgi.input'] = io.BytesIO() + environ['CONTENT_LENGTH'] = '0' + environ['cherrypy.previous_request'] = ir.request + + +class ExceptionTrapper(object): + + """WSGI middleware that traps exceptions.""" + + def __init__(self, nextapp, throws=(KeyboardInterrupt, SystemExit)): + self.nextapp = nextapp + self.throws = throws + + def __call__(self, environ, start_response): + return _TrappedResponse( + self.nextapp, + environ, + start_response, + self.throws + ) + + +class _TrappedResponse(object): + + response = iter([]) + + def __init__(self, nextapp, environ, start_response, throws): + self.nextapp = nextapp + self.environ = environ + self.start_response = start_response + self.throws = throws + self.started_response = False + self.response = self.trap( + self.nextapp, self.environ, self.start_response, + ) + self.iter_response = iter(self.response) + + def __iter__(self): + self.started_response = True + return self + + def __next__(self): + return self.trap(next, self.iter_response) + + # todo: https://pythonhosted.org/six/#six.Iterator + if six.PY2: + next = __next__ + + def close(self): + if hasattr(self.response, 'close'): + self.response.close() + + def trap(self, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except self.throws: + raise + except StopIteration: + raise + except Exception: + tb = _cperror.format_exc() + _cherrypy.log(tb, severity=40) + if not _cherrypy.request.show_tracebacks: + tb = '' + s, h, b = _cperror.bare_error(tb) + if six.PY3: + # What fun. + s = s.decode('ISO-8859-1') + h = [ + (k.decode('ISO-8859-1'), v.decode('ISO-8859-1')) + for k, v in h + ] + if self.started_response: + # Empty our iterable (so future calls raise StopIteration) + self.iter_response = iter([]) + else: + self.iter_response = iter(b) + + try: + self.start_response(s, h, _sys.exc_info()) + except Exception: + # "The application must not trap any exceptions raised by + # start_response, if it called start_response with exc_info. + # Instead, it should allow such exceptions to propagate + # back to the server or gateway." + # But we still log and call close() to clean up ourselves. + _cherrypy.log(traceback=True, severity=40) + raise + + if self.started_response: + return b''.join(b) + else: + return b + + +# WSGI-to-CP Adapter # + + +class AppResponse(object): + + """WSGI response iterable for CherryPy applications.""" + + def __init__(self, environ, start_response, cpapp): + self.cpapp = cpapp + try: + if six.PY2: + if environ.get(ntou('wsgi.version')) == (ntou('u'), 0): + environ = downgrade_wsgi_ux_to_1x(environ) + self.environ = environ + self.run() + + r = _cherrypy.serving.response + + outstatus = r.output_status + if not isinstance(outstatus, bytes): + raise TypeError('response.output_status is not a byte string.') + + outheaders = [] + for k, v in r.header_list: + if not isinstance(k, bytes): + tmpl = 'response.header_list key %r is not a byte string.' + raise TypeError(tmpl % k) + if not isinstance(v, bytes): + tmpl = ( + 'response.header_list value %r is not a byte string.' + ) + raise TypeError(tmpl % v) + outheaders.append((k, v)) + + if six.PY3: + # According to PEP 3333, when using Python 3, the response + # status and headers must be bytes masquerading as unicode; + # that is, they must be of type "str" but are restricted to + # code points in the "latin-1" set. 
+ outstatus = outstatus.decode('ISO-8859-1') + outheaders = [ + (k.decode('ISO-8859-1'), v.decode('ISO-8859-1')) + for k, v in outheaders + ] + + self.iter_response = iter(r.body) + self.write = start_response(outstatus, outheaders) + except BaseException: + self.close() + raise + + def __iter__(self): + return self + + def __next__(self): + return next(self.iter_response) + + # todo: https://pythonhosted.org/six/#six.Iterator + if six.PY2: + next = __next__ + + def close(self): + """Close and de-reference the current request and response. (Core)""" + streaming = _cherrypy.serving.response.stream + self.cpapp.release_serving() + + # We avoid the expense of examining the iterator to see if it's + # closable unless we are streaming the response, as that's the + # only situation where we are going to have an iterator which + # may not have been exhausted yet. + if streaming and is_closable_iterator(self.iter_response): + iter_close = self.iter_response.close + try: + iter_close() + except Exception: + _cherrypy.log(traceback=True, severity=40) + + def run(self): + """Create a Request object using environ.""" + env = self.environ.get + + local = httputil.Host( + '', + int(env('SERVER_PORT', 80) or -1), + env('SERVER_NAME', ''), + ) + remote = httputil.Host( + env('REMOTE_ADDR', ''), + int(env('REMOTE_PORT', -1) or -1), + env('REMOTE_HOST', ''), + ) + scheme = env('wsgi.url_scheme') + sproto = env('ACTUAL_SERVER_PROTOCOL', 'HTTP/1.1') + request, resp = self.cpapp.get_serving(local, remote, scheme, sproto) + + # LOGON_USER is served by IIS, and is the name of the + # user after having been mapped to a local account. + # Both IIS and Apache set REMOTE_USER, when possible. + request.login = env('LOGON_USER') or env('REMOTE_USER') or None + request.multithread = self.environ['wsgi.multithread'] + request.multiprocess = self.environ['wsgi.multiprocess'] + request.wsgi_environ = self.environ + request.prev = env('cherrypy.previous_request', None) + + meth = self.environ['REQUEST_METHOD'] + + path = httputil.urljoin( + self.environ.get('SCRIPT_NAME', ''), + self.environ.get('PATH_INFO', ''), + ) + qs = self.environ.get('QUERY_STRING', '') + + path, qs = self.recode_path_qs(path, qs) or (path, qs) + + rproto = self.environ.get('SERVER_PROTOCOL') + headers = self.translate_headers(self.environ) + rfile = self.environ['wsgi.input'] + request.run(meth, path, qs, rproto, headers, rfile) + + headerNames = { + 'HTTP_CGI_AUTHORIZATION': 'Authorization', + 'CONTENT_LENGTH': 'Content-Length', + 'CONTENT_TYPE': 'Content-Type', + 'REMOTE_HOST': 'Remote-Host', + 'REMOTE_ADDR': 'Remote-Addr', + } + + def recode_path_qs(self, path, qs): + if not six.PY3: + return + + # This isn't perfect; if the given PATH_INFO is in the + # wrong encoding, it may fail to match the appropriate config + # section URI. But meh. + old_enc = self.environ.get('wsgi.url_encoding', 'ISO-8859-1') + new_enc = self.cpapp.find_config( + self.environ.get('PATH_INFO', ''), + 'request.uri_encoding', 'utf-8', + ) + if new_enc.lower() == old_enc.lower(): + return + + # Even though the path and qs are unicode, the WSGI server + # is required by PEP 3333 to coerce them to ISO-8859-1 + # masquerading as unicode. So we have to encode back to + # bytes and then decode again using the "correct" encoding. + try: + return ( + path.encode(old_enc).decode(new_enc), + qs.encode(old_enc).decode(new_enc), + ) + except (UnicodeEncodeError, UnicodeDecodeError): + # Just pass them through without transcoding and hope. 
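# --- editor's example -------------------------------------------------
# The per-path lookup above reads the 'request.uri_encoding' config key;
# a hypothetical override for one subtree:
encoding_conf = {'/legacy': {'request.uri_encoding': 'shift_jis'}}
# ----------------------------------------------------------------------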
+ pass + + def translate_headers(self, environ): + """Translate CGI-environ header names to HTTP header names.""" + for cgiName in environ: + # We assume all incoming header keys are uppercase already. + if cgiName in self.headerNames: + yield self.headerNames[cgiName], environ[cgiName] + elif cgiName[:5] == 'HTTP_': + # Hackish attempt at recovering original header names. + translatedHeader = cgiName[5:].replace('_', '-') + yield translatedHeader, environ[cgiName] + + +class CPWSGIApp(object): + + """A WSGI application object for a CherryPy Application.""" + + pipeline = [ + ('ExceptionTrapper', ExceptionTrapper), + ('InternalRedirector', InternalRedirector), + ] + """A list of (name, wsgiapp) pairs. Each 'wsgiapp' MUST be a + constructor that takes an initial, positional 'nextapp' argument, + plus optional keyword arguments, and returns a WSGI application + (that takes environ and start_response arguments). The 'name' can + be any you choose, and will correspond to keys in self.config.""" + + head = None + """Rather than nest all apps in the pipeline on each call, it's only + done the first time, and the result is memoized into self.head. Set + this to None again if you change self.pipeline after calling self.""" + + config = {} + """A dict whose keys match names listed in the pipeline. Each + value is a further dict which will be passed to the corresponding + named WSGI callable (from the pipeline) as keyword arguments.""" + + response_class = AppResponse + """The class to instantiate and return as the next app in the WSGI chain. + """ + + def __init__(self, cpapp, pipeline=None): + self.cpapp = cpapp + self.pipeline = self.pipeline[:] + if pipeline: + self.pipeline.extend(pipeline) + self.config = self.config.copy() + + def tail(self, environ, start_response): + """WSGI application callable for the actual CherryPy application. + + You probably shouldn't call this; call self.__call__ instead, + so that any WSGI middleware in self.pipeline can run first. + """ + return self.response_class(environ, start_response, self.cpapp) + + def __call__(self, environ, start_response): + head = self.head + if head is None: + # Create and nest the WSGI apps in our pipeline (in reverse order). + # Then memoize the result in self.head. + head = self.tail + for name, callable in self.pipeline[::-1]: + conf = self.config.get(name, {}) + head = callable(head, **conf) + self.head = head + return head(environ, start_response) + + def namespace_handler(self, k, v): + """Config handler for the 'wsgi' namespace.""" + if k == 'pipeline': + # Note this allows multiple 'wsgi.pipeline' config entries + # (but each entry will be processed in a 'random' order). + # It should also allow developers to set default middleware + # in code (passed to self.__init__) that deployers can add to + # (but not remove) via config. + self.pipeline.extend(v) + elif k == 'response_class': + self.response_class = v + else: + name, arg = k.split('.', 1) + bucket = self.config.setdefault(name, {}) + bucket[arg] = v diff --git a/libraries/cherrypy/_cpwsgi_server.py b/libraries/cherrypy/_cpwsgi_server.py new file mode 100644 index 00000000..11dd846a --- /dev/null +++ b/libraries/cherrypy/_cpwsgi_server.py @@ -0,0 +1,110 @@ +""" +WSGI server interface (see PEP 333). + +This adds some CP-specific bits to the framework-agnostic cheroot package. +""" +import sys + +import cheroot.wsgi +import cheroot.server + +import cherrypy + + +class CPWSGIHTTPRequest(cheroot.server.HTTPRequest): + """Wrapper for cheroot.server.HTTPRequest. 
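# --- editor's example -------------------------------------------------
# A hypothetical sketch of adding middleware through the 'wsgi' config
# namespace handled above; the middleware class and pipeline name are
# illustrative.
class PathNormalizer(object):
    """Tiny WSGI middleware that collapses duplicate slashes in PATH_INFO."""

    def __init__(self, nextapp):
        self.nextapp = nextapp

    def __call__(self, environ, start_response):
        path = environ.get('PATH_INFO', '')
        while '//' in path:
            path = path.replace('//', '/')
        environ['PATH_INFO'] = path
        return self.nextapp(environ, start_response)

pipeline_conf = {'/': {
    'wsgi.pipeline': [('normalize', PathNormalizer)],
}}
# cherrypy.quickstart(Root(), '/', pipeline_conf)  # Root assumed elsewhere
# ----------------------------------------------------------------------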
+ + This is a layer, which preserves URI parsing mode like it which was + before Cheroot v5.8.0. + """ + + def __init__(self, server, conn): + """Initialize HTTP request container instance. + + Args: + server (cheroot.server.HTTPServer): + web server object receiving this request + conn (cheroot.server.HTTPConnection): + HTTP connection object for this request + """ + super(CPWSGIHTTPRequest, self).__init__( + server, conn, proxy_mode=True + ) + + +class CPWSGIServer(cheroot.wsgi.Server): + """Wrapper for cheroot.wsgi.Server. + + cheroot has been designed to not reference CherryPy in any way, + so that it can be used in other frameworks and applications. Therefore, + we wrap it here, so we can set our own mount points from cherrypy.tree + and apply some attributes from config -> cherrypy.server -> wsgi.Server. + """ + + fmt = 'CherryPy/{cherrypy.__version__} {cheroot.wsgi.Server.version}' + version = fmt.format(**globals()) + + def __init__(self, server_adapter=cherrypy.server): + """Initialize CPWSGIServer instance. + + Args: + server_adapter (cherrypy._cpserver.Server): ... + """ + self.server_adapter = server_adapter + self.max_request_header_size = ( + self.server_adapter.max_request_header_size or 0 + ) + self.max_request_body_size = ( + self.server_adapter.max_request_body_size or 0 + ) + + server_name = (self.server_adapter.socket_host or + self.server_adapter.socket_file or + None) + + self.wsgi_version = self.server_adapter.wsgi_version + + super(CPWSGIServer, self).__init__( + server_adapter.bind_addr, cherrypy.tree, + self.server_adapter.thread_pool, + server_name, + max=self.server_adapter.thread_pool_max, + request_queue_size=self.server_adapter.socket_queue_size, + timeout=self.server_adapter.socket_timeout, + shutdown_timeout=self.server_adapter.shutdown_timeout, + accepted_queue_size=self.server_adapter.accepted_queue_size, + accepted_queue_timeout=self.server_adapter.accepted_queue_timeout, + peercreds_enabled=self.server_adapter.peercreds, + peercreds_resolve_enabled=self.server_adapter.peercreds_resolve, + ) + self.ConnectionClass.RequestHandlerClass = CPWSGIHTTPRequest + + self.protocol = self.server_adapter.protocol_version + self.nodelay = self.server_adapter.nodelay + + if sys.version_info >= (3, 0): + ssl_module = self.server_adapter.ssl_module or 'builtin' + else: + ssl_module = self.server_adapter.ssl_module or 'pyopenssl' + if self.server_adapter.ssl_context: + adapter_class = cheroot.server.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain, + self.server_adapter.ssl_ciphers) + self.ssl_adapter.context = self.server_adapter.ssl_context + elif self.server_adapter.ssl_certificate: + adapter_class = cheroot.server.get_ssl_adapter_class(ssl_module) + self.ssl_adapter = adapter_class( + self.server_adapter.ssl_certificate, + self.server_adapter.ssl_private_key, + self.server_adapter.ssl_certificate_chain, + self.server_adapter.ssl_ciphers) + + self.stats['Enabled'] = getattr( + self.server_adapter, 'statistics', False) + + def error_log(self, msg='', level=20, traceback=False): + """Write given message to the error log.""" + cherrypy.engine.log(msg, level, traceback) diff --git a/libraries/cherrypy/_helper.py b/libraries/cherrypy/_helper.py new file mode 100644 index 00000000..314550cb --- /dev/null +++ b/libraries/cherrypy/_helper.py @@ -0,0 +1,344 @@ +"""Helper functions for CP apps.""" + +import six +from six.moves import 
urllib + +from cherrypy._cpcompat import text_or_bytes + +import cherrypy + + +def expose(func=None, alias=None): + """Expose the function or class. + + Optionally provide an alias or set of aliases. + """ + def expose_(func): + func.exposed = True + if alias is not None: + if isinstance(alias, text_or_bytes): + parents[alias.replace('.', '_')] = func + else: + for a in alias: + parents[a.replace('.', '_')] = func + return func + + import sys + import types + decoratable_types = types.FunctionType, types.MethodType, type, + if six.PY2: + # Old-style classes are type types.ClassType. + decoratable_types += types.ClassType, + if isinstance(func, decoratable_types): + if alias is None: + # @expose + func.exposed = True + return func + else: + # func = expose(func, alias) + parents = sys._getframe(1).f_locals + return expose_(func) + elif func is None: + if alias is None: + # @expose() + parents = sys._getframe(1).f_locals + return expose_ + else: + # @expose(alias="alias") or + # @expose(alias=["alias1", "alias2"]) + parents = sys._getframe(1).f_locals + return expose_ + else: + # @expose("alias") or + # @expose(["alias1", "alias2"]) + parents = sys._getframe(1).f_locals + alias = func + return expose_ + + +def popargs(*args, **kwargs): + """Decorate _cp_dispatch. + + (cherrypy.dispatch.Dispatcher.dispatch_method_name) + + Optional keyword argument: handler=(Object or Function) + + Provides a _cp_dispatch function that pops off path segments into + cherrypy.request.params under the names specified. The dispatch + is then forwarded on to the next vpath element. + + Note that any existing (and exposed) member function of the class that + popargs is applied to will override that value of the argument. For + instance, if you have a method named "list" on the class decorated with + popargs, then accessing "/list" will call that function instead of popping + it off as the requested parameter. This restriction applies to all + _cp_dispatch functions. The only way around this restriction is to create + a "blank class" whose only function is to provide _cp_dispatch. + + If there are path elements after the arguments, or more arguments + are requested than are available in the vpath, then the 'handler' + keyword argument specifies the next object to handle the parameterized + request. If handler is not specified or is None, then self is used. + If handler is a function rather than an instance, then that function + will be called with the args specified and the return value from that + function used as the next object INSTEAD of adding the parameters to + cherrypy.request.args. + + This decorator may be used in one of two ways: + + As a class decorator: + @cherrypy.popargs('year', 'month', 'day') + class Blog: + def index(self, year=None, month=None, day=None): + #Process the parameters here; any url like + #/, /2009, /2009/12, or /2009/12/31 + #will fill in the appropriate parameters. + + def create(self): + #This link will still be available at /create. Defined functions + #take precedence over arguments. + + Or as a member of a class: + class Blog: + _cp_dispatch = cherrypy.popargs('year', 'month', 'day') + #... + + The handler argument may be used to mix arguments with built in functions. 
+ For instance, the following setup allows different activities at the + day, month, and year level: + + class DayHandler: + def index(self, year, month, day): + #Do something with this day; probably list entries + + def delete(self, year, month, day): + #Delete all entries for this day + + @cherrypy.popargs('day', handler=DayHandler()) + class MonthHandler: + def index(self, year, month): + #Do something with this month; probably list entries + + def delete(self, year, month): + #Delete all entries for this month + + @cherrypy.popargs('month', handler=MonthHandler()) + class YearHandler: + def index(self, year): + #Do something with this year + + #... + + @cherrypy.popargs('year', handler=YearHandler()) + class Root: + def index(self): + #... + + """ + # Since keyword arg comes after *args, we have to process it ourselves + # for lower versions of python. + + handler = None + handler_call = False + for k, v in kwargs.items(): + if k == 'handler': + handler = v + else: + tm = "cherrypy.popargs() got an unexpected keyword argument '{0}'" + raise TypeError(tm.format(k)) + + import inspect + + if handler is not None \ + and (hasattr(handler, '__call__') or inspect.isclass(handler)): + handler_call = True + + def decorated(cls_or_self=None, vpath=None): + if inspect.isclass(cls_or_self): + # cherrypy.popargs is a class decorator + cls = cls_or_self + name = cherrypy.dispatch.Dispatcher.dispatch_method_name + setattr(cls, name, decorated) + return cls + + # We're in the actual function + self = cls_or_self + parms = {} + for arg in args: + if not vpath: + break + parms[arg] = vpath.pop(0) + + if handler is not None: + if handler_call: + return handler(**parms) + else: + cherrypy.request.params.update(parms) + return handler + + cherrypy.request.params.update(parms) + + # If we are the ultimate handler, then to prevent our _cp_dispatch + # from being called again, we will resolve remaining elements through + # getattr() directly. + if vpath: + return getattr(self, vpath.pop(0), None) + else: + return self + + return decorated + + +def url(path='', qs='', script_name=None, base=None, relative=None): + """Create an absolute URL for the given path. + + If 'path' starts with a slash ('/'), this will return + (base + script_name + path + qs). + If it does not start with a slash, this returns + (base + script_name [+ request.path_info] + path + qs). + + If script_name is None, cherrypy.request will be used + to find a script_name, if available. + + If base is None, cherrypy.request.base will be used (if available). + Note that you can use cherrypy.tools.proxy to change this. + + Finally, note that this function can be used to obtain an absolute URL + for the current request path (minus the querystring) by passing no args. + If you call url(qs=cherrypy.request.query_string), you should get the + original browser URL (assuming no internal redirections). + + If relative is None or not provided, request.app.relative_urls will + be used (if available, else False). If False, the output will be an + absolute URL (including the scheme, host, vhost, and script_name). + If True, the output will instead be a URL that is relative to the + current request path, perhaps including '..' atoms. If relative is + the string 'server', the output will instead be a URL that is + relative to the server root; i.e., it will start with a slash. + """ + if isinstance(qs, (tuple, list, dict)): + qs = urllib.parse.urlencode(qs) + if qs: + qs = '?' 
+ qs + + if cherrypy.request.app: + if not path.startswith('/'): + # Append/remove trailing slash from path_info as needed + # (this is to support mistyped URL's without redirecting; + # if you want to redirect, use tools.trailing_slash). + pi = cherrypy.request.path_info + if cherrypy.request.is_index is True: + if not pi.endswith('/'): + pi = pi + '/' + elif cherrypy.request.is_index is False: + if pi.endswith('/') and pi != '/': + pi = pi[:-1] + + if path == '': + path = pi + else: + path = urllib.parse.urljoin(pi, path) + + if script_name is None: + script_name = cherrypy.request.script_name + if base is None: + base = cherrypy.request.base + + newurl = base + script_name + normalize_path(path) + qs + else: + # No request.app (we're being called outside a request). + # We'll have to guess the base from server.* attributes. + # This will produce very different results from the above + # if you're using vhosts or tools.proxy. + if base is None: + base = cherrypy.server.base() + + path = (script_name or '') + path + newurl = base + normalize_path(path) + qs + + # At this point, we should have a fully-qualified absolute URL. + + if relative is None: + relative = getattr(cherrypy.request.app, 'relative_urls', False) + + # See http://www.ietf.org/rfc/rfc2396.txt + if relative == 'server': + # "A relative reference beginning with a single slash character is + # termed an absolute-path reference, as defined by <abs_path>..." + # This is also sometimes called "server-relative". + newurl = '/' + '/'.join(newurl.split('/', 3)[3:]) + elif relative: + # "A relative reference that does not begin with a scheme name + # or a slash character is termed a relative-path reference." + old = url(relative=False).split('/')[:-1] + new = newurl.split('/') + while old and new: + a, b = old[0], new[0] + if a != b: + break + old.pop(0) + new.pop(0) + new = (['..'] * len(old)) + new + newurl = '/'.join(new) + + return newurl + + +def normalize_path(path): + """Resolve given path from relative into absolute form.""" + if './' not in path: + return path + + # Normalize the URL by removing ./ and ../ + atoms = [] + for atom in path.split('/'): + if atom == '.': + pass + elif atom == '..': + # Don't pop from empty list + # (i.e. ignore redundant '..') + if atoms: + atoms.pop() + elif atom: + atoms.append(atom) + + newpath = '/'.join(atoms) + # Preserve leading '/' + if path.startswith('/'): + newpath = '/' + newpath + + return newpath + + +#### +# Inlined from jaraco.classes 1.4.3 +# Ref #1673 +class _ClassPropertyDescriptor(object): + """Descript for read-only class-based property. + + Turns a classmethod-decorated func into a read-only property of that class + type (means the value cannot be set). + """ + + def __init__(self, fget, fset=None): + """Initialize a class property descriptor. + + Instantiated by ``_helper.classproperty``. 
+ """ + self.fget = fget + self.fset = fset + + def __get__(self, obj, klass=None): + """Return property value.""" + if klass is None: + klass = type(obj) + return self.fget.__get__(obj, klass)() + + +def classproperty(func): # noqa: D401; irrelevant for properties + """Decorator like classmethod to implement a static class property.""" + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + + return _ClassPropertyDescriptor(func) +#### diff --git a/libraries/cherrypy/daemon.py b/libraries/cherrypy/daemon.py new file mode 100644 index 00000000..74488c06 --- /dev/null +++ b/libraries/cherrypy/daemon.py @@ -0,0 +1,107 @@ +"""The CherryPy daemon.""" + +import sys + +import cherrypy +from cherrypy.process import plugins, servers +from cherrypy import Application + + +def start(configfiles=None, daemonize=False, environment=None, + fastcgi=False, scgi=False, pidfile=None, imports=None, + cgi=False): + """Subscribe all engine plugins and start the engine.""" + sys.path = [''] + sys.path + for i in imports or []: + exec('import %s' % i) + + for c in configfiles or []: + cherrypy.config.update(c) + # If there's only one app mounted, merge config into it. + if len(cherrypy.tree.apps) == 1: + for app in cherrypy.tree.apps.values(): + if isinstance(app, Application): + app.merge(c) + + engine = cherrypy.engine + + if environment is not None: + cherrypy.config.update({'environment': environment}) + + # Only daemonize if asked to. + if daemonize: + # Don't print anything to stdout/sterr. + cherrypy.config.update({'log.screen': False}) + plugins.Daemonizer(engine).subscribe() + + if pidfile: + plugins.PIDFile(engine, pidfile).subscribe() + + if hasattr(engine, 'signal_handler'): + engine.signal_handler.subscribe() + if hasattr(engine, 'console_control_handler'): + engine.console_control_handler.subscribe() + + if (fastcgi and (scgi or cgi)) or (scgi and cgi): + cherrypy.log.error('You may only specify one of the cgi, fastcgi, and ' + 'scgi options.', 'ENGINE') + sys.exit(1) + elif fastcgi or scgi or cgi: + # Turn off autoreload when using *cgi. + cherrypy.config.update({'engine.autoreload.on': False}) + # Turn off the default HTTP server (which is subscribed by default). + cherrypy.server.unsubscribe() + + addr = cherrypy.server.bind_addr + cls = ( + servers.FlupFCGIServer if fastcgi else + servers.FlupSCGIServer if scgi else + servers.FlupCGIServer + ) + f = cls(application=cherrypy.tree, bindAddress=addr) + s = servers.ServerAdapter(engine, httpserver=f, bind_addr=addr) + s.subscribe() + + # Always start the engine; this will start all other services + try: + engine.start() + except Exception: + # Assume the error has been logged already via bus.log. 
+ sys.exit(1) + else: + engine.block() + + +def run(): + """Run cherryd CLI.""" + from optparse import OptionParser + + p = OptionParser() + p.add_option('-c', '--config', action='append', dest='config', + help='specify config file(s)') + p.add_option('-d', action='store_true', dest='daemonize', + help='run the server as a daemon') + p.add_option('-e', '--environment', dest='environment', default=None, + help='apply the given config environment') + p.add_option('-f', action='store_true', dest='fastcgi', + help='start a fastcgi server instead of the default HTTP ' + 'server') + p.add_option('-s', action='store_true', dest='scgi', + help='start a scgi server instead of the default HTTP server') + p.add_option('-x', action='store_true', dest='cgi', + help='start a cgi server instead of the default HTTP server') + p.add_option('-i', '--import', action='append', dest='imports', + help='specify modules to import') + p.add_option('-p', '--pidfile', dest='pidfile', default=None, + help='store the process id in the given file') + p.add_option('-P', '--Path', action='append', dest='Path', + help='add the given paths to sys.path') + options, args = p.parse_args() + + if options.Path: + for p in options.Path: + sys.path.insert(0, p) + + start(options.config, options.daemonize, + options.environment, options.fastcgi, options.scgi, + options.pidfile, options.imports, options.cgi) diff --git a/libraries/cherrypy/favicon.ico b/libraries/cherrypy/favicon.ico new file mode 100644 index 00000000..f0d7e61b Binary files /dev/null and b/libraries/cherrypy/favicon.ico differ diff --git a/libraries/cherrypy/lib/__init__.py b/libraries/cherrypy/lib/__init__.py new file mode 100644 index 00000000..f815f76a --- /dev/null +++ b/libraries/cherrypy/lib/__init__.py @@ -0,0 +1,96 @@ +"""CherryPy Library.""" + + +def is_iterator(obj): + """Detect if the object provided implements the iterator protocol. + + (i.e. like a generator). + + This will return False for objects which are iterable, + but not iterators themselves. + """ + from types import GeneratorType + if isinstance(obj, GeneratorType): + return True + elif not hasattr(obj, '__iter__'): + return False + else: + # Types which implement the protocol must return themselves when + # invoking 'iter' upon them. + return iter(obj) is obj + + +def is_closable_iterator(obj): + """Detect if the given object is both closable and iterator.""" + # Not an iterator. + if not is_iterator(obj): + return False + + # A generator - the easiest thing to deal with. + import inspect + if inspect.isgenerator(obj): + return True + + # A custom iterator. Look for a close method... + if not (hasattr(obj, 'close') and callable(obj.close)): + return False + + # ... which doesn't require any arguments. + try: + inspect.getcallargs(obj.close) + except TypeError: + return False + else: + return True + + +class file_generator(object): + """Yield the given input (a file object) in chunks (default 64k). + + (Core) + """ + + def __init__(self, input, chunkSize=65536): + """Initialize file_generator with file ``input`` for chunked access.""" + self.input = input + self.chunkSize = chunkSize + + def __iter__(self): + """Return iterator.""" + return self + + def __next__(self): + """Return next chunk of file.""" + chunk = self.input.read(self.chunkSize) + if chunk: + return chunk + else: + if hasattr(self.input, 'close'): + self.input.close() + raise StopIteration() + next = __next__ + + +def file_generator_limited(fileobj, count, chunk_size=65536): + """Yield the given file object in chunks. 
+ + Stopps after `count` bytes has been emitted. + Default chunk size is 64kB. (Core) + """ + remaining = count + while remaining > 0: + chunk = fileobj.read(min(chunk_size, remaining)) + chunklen = len(chunk) + if chunklen == 0: + return + remaining -= chunklen + yield chunk + + +def set_vary_header(response, header_name): + """Add a Vary header to a response.""" + varies = response.headers.get('Vary', '') + varies = [x.strip() for x in varies.split(',') if x.strip()] + if header_name not in varies: + varies.append(header_name) + response.headers['Vary'] = ', '.join(varies) diff --git a/libraries/cherrypy/lib/auth_basic.py b/libraries/cherrypy/lib/auth_basic.py new file mode 100644 index 00000000..ad379a26 --- /dev/null +++ b/libraries/cherrypy/lib/auth_basic.py @@ -0,0 +1,120 @@ +# This file is part of CherryPy <http://www.cherrypy.org/> +# -*- coding: utf-8 -*- +# vim:ts=4:sw=4:expandtab:fileencoding=utf-8 +"""HTTP Basic Authentication tool. + +This module provides a CherryPy 3.x tool which implements +the server-side of HTTP Basic Access Authentication, as described in +:rfc:`2617`. + +Example usage, using the built-in checkpassword_dict function which uses a dict +as the credentials store:: + + userpassdict = {'bird' : 'bebop', 'ornette' : 'wayout'} + checkpassword = cherrypy.lib.auth_basic.checkpassword_dict(userpassdict) + basic_auth = {'tools.auth_basic.on': True, + 'tools.auth_basic.realm': 'earth', + 'tools.auth_basic.checkpassword': checkpassword, + 'tools.auth_basic.accept_charset': 'UTF-8', + } + app_config = { '/' : basic_auth } + +""" + +import binascii +import unicodedata +import base64 + +import cherrypy +from cherrypy._cpcompat import ntou, tonative + + +__author__ = 'visteya' +__date__ = 'April 2009' + + +def checkpassword_dict(user_password_dict): + """Returns a checkpassword function which checks credentials + against a dictionary of the form: {username : password}. + + If you want a simple dictionary-based authentication scheme, use + checkpassword_dict(my_credentials_dict) as the value for the + checkpassword argument to basic_auth(). + """ + def checkpassword(realm, user, password): + p = user_password_dict.get(user) + return p and p == password or False + + return checkpassword + + +def basic_auth(realm, checkpassword, debug=False, accept_charset='utf-8'): + """A CherryPy tool which hooks at before_handler to perform + HTTP Basic Access Authentication, as specified in :rfc:`2617` + and :rfc:`7617`. + + If the request has an 'authorization' header with a 'Basic' scheme, this + tool attempts to authenticate the credentials supplied in that header. If + the request has no 'authorization' header, or if it does but the scheme is + not 'Basic', or if authentication fails, the tool sends a 401 response with + a 'WWW-Authenticate' Basic header. + + realm + A string containing the authentication realm. + + checkpassword + A callable which checks the authentication credentials. + Its signature is checkpassword(realm, username, password). where + username and password are the values obtained from the request's + 'authorization' header. If authentication succeeds, checkpassword + returns True, else it returns False. 
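+
+    accept_charset
+        The charset advertised in the 'WWW-Authenticate' challenge and
+        tried first when decoding the supplied credentials; ISO-8859-1 is
+        used as a fallback (see _try_decode below).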
+ + """ + + fallback_charset = 'ISO-8859-1' + + if '"' in realm: + raise ValueError('Realm cannot contain the " (quote) character.') + request = cherrypy.serving.request + + auth_header = request.headers.get('authorization') + if auth_header is not None: + # split() error, base64.decodestring() error + msg = 'Bad Request' + with cherrypy.HTTPError.handle((ValueError, binascii.Error), 400, msg): + scheme, params = auth_header.split(' ', 1) + if scheme.lower() == 'basic': + charsets = accept_charset, fallback_charset + decoded_params = base64.b64decode(params.encode('ascii')) + decoded_params = _try_decode(decoded_params, charsets) + decoded_params = ntou(decoded_params) + decoded_params = unicodedata.normalize('NFC', decoded_params) + decoded_params = tonative(decoded_params) + username, password = decoded_params.split(':', 1) + if checkpassword(realm, username, password): + if debug: + cherrypy.log('Auth succeeded', 'TOOLS.AUTH_BASIC') + request.login = username + return # successful authentication + + charset = accept_charset.upper() + charset_declaration = ( + (', charset="%s"' % charset) + if charset != fallback_charset + else '' + ) + # Respond with 401 status and a WWW-Authenticate header + cherrypy.serving.response.headers['www-authenticate'] = ( + 'Basic realm="%s"%s' % (realm, charset_declaration) + ) + raise cherrypy.HTTPError( + 401, 'You are not authorized to access that resource') + + +def _try_decode(subject, charsets): + for charset in charsets[:-1]: + try: + return tonative(subject, charset) + except ValueError: + pass + return tonative(subject, charsets[-1]) diff --git a/libraries/cherrypy/lib/auth_digest.py b/libraries/cherrypy/lib/auth_digest.py new file mode 100644 index 00000000..9b4f55c8 --- /dev/null +++ b/libraries/cherrypy/lib/auth_digest.py @@ -0,0 +1,464 @@ +# This file is part of CherryPy <http://www.cherrypy.org/> +# -*- coding: utf-8 -*- +# vim:ts=4:sw=4:expandtab:fileencoding=utf-8 +"""HTTP Digest Authentication tool. + +An implementation of the server-side of HTTP Digest Access +Authentication, which is described in :rfc:`2617`. + +Example usage, using the built-in get_ha1_dict_plain function which uses a dict +of plaintext passwords as the credentials store:: + + userpassdict = {'alice' : '4x5istwelve'} + get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(userpassdict) + digest_auth = {'tools.auth_digest.on': True, + 'tools.auth_digest.realm': 'wonderland', + 'tools.auth_digest.get_ha1': get_ha1, + 'tools.auth_digest.key': 'a565c27146791cfb', + 'tools.auth_digest.accept_charset': 'UTF-8', + } + app_config = { '/' : digest_auth } +""" + +import time +import functools +from hashlib import md5 + +from six.moves.urllib.request import parse_http_list, parse_keqv_list + +import cherrypy +from cherrypy._cpcompat import ntob, tonative + + +__author__ = 'visteya' +__date__ = 'April 2009' + + +def md5_hex(s): + return md5(ntob(s, 'utf-8')).hexdigest() + + +qop_auth = 'auth' +qop_auth_int = 'auth-int' +valid_qops = (qop_auth, qop_auth_int) + +valid_algorithms = ('MD5', 'MD5-sess') + +FALLBACK_CHARSET = 'ISO-8859-1' +DEFAULT_CHARSET = 'UTF-8' + + +def TRACE(msg): + cherrypy.log(msg, context='TOOLS.AUTH_DIGEST') + +# Three helper functions for users of the tool, providing three variants +# of get_ha1() functions for three different kinds of credential stores. + + +def get_ha1_dict_plain(user_password_dict): + """Returns a get_ha1 function which obtains a plaintext password from a + dictionary of the form: {username : password}. 
+ + If you want a simple dictionary-based authentication scheme, with plaintext + passwords, use get_ha1_dict_plain(my_userpass_dict) as the value for the + get_ha1 argument to digest_auth(). + """ + def get_ha1(realm, username): + password = user_password_dict.get(username) + if password: + return md5_hex('%s:%s:%s' % (username, realm, password)) + return None + + return get_ha1 + + +def get_ha1_dict(user_ha1_dict): + """Returns a get_ha1 function which obtains a HA1 password hash from a + dictionary of the form: {username : HA1}. + + If you want a dictionary-based authentication scheme, but with + pre-computed HA1 hashes instead of plain-text passwords, use + get_ha1_dict(my_userha1_dict) as the value for the get_ha1 + argument to digest_auth(). + """ + def get_ha1(realm, username): + return user_ha1_dict.get(username) + + return get_ha1 + + +def get_ha1_file_htdigest(filename): + """Returns a get_ha1 function which obtains a HA1 password hash from a + flat file with lines of the same format as that produced by the Apache + htdigest utility. For example, for realm 'wonderland', username 'alice', + and password '4x5istwelve', the htdigest line would be:: + + alice:wonderland:3238cdfe91a8b2ed8e39646921a02d4c + + If you want to use an Apache htdigest file as the credentials store, + then use get_ha1_file_htdigest(my_htdigest_file) as the value for the + get_ha1 argument to digest_auth(). It is recommended that the filename + argument be an absolute path, to avoid problems. + """ + def get_ha1(realm, username): + result = None + f = open(filename, 'r') + for line in f: + u, r, ha1 = line.rstrip().split(':') + if u == username and r == realm: + result = ha1 + break + f.close() + return result + + return get_ha1 + + +def synthesize_nonce(s, key, timestamp=None): + """Synthesize a nonce value which resists spoofing and can be checked + for staleness. Returns a string suitable as the value for 'nonce' in + the www-authenticate header. + + s + A string related to the resource, such as the hostname of the server. + + key + A secret string known only to the server. + + timestamp + An integer seconds-since-the-epoch timestamp + + """ + if timestamp is None: + timestamp = int(time.time()) + h = md5_hex('%s:%s:%s' % (timestamp, s, key)) + nonce = '%s:%s' % (timestamp, h) + return nonce + + +def H(s): + """The hash function H""" + return md5_hex(s) + + +def _try_decode_header(header, charset): + global FALLBACK_CHARSET + + for enc in (charset, FALLBACK_CHARSET): + try: + return tonative(ntob(tonative(header, 'latin1'), 'latin1'), enc) + except ValueError as ve: + last_err = ve + else: + raise last_err + + +class HttpDigestAuthorization(object): + """ + Parses a Digest Authorization header and performs + re-calculation of the digest. 
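+
+    Instances are normally created by digest_auth() once matches() has
+    confirmed that the Authorization header uses the Digest scheme;
+    __init__ raises ValueError for any malformed or unsupported header.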
+ """ + + scheme = 'digest' + + def errmsg(self, s): + return 'Digest Authorization header: %s' % s + + @classmethod + def matches(cls, header): + scheme, _, _ = header.partition(' ') + return scheme.lower() == cls.scheme + + def __init__( + self, auth_header, http_method, + debug=False, accept_charset=DEFAULT_CHARSET[:], + ): + self.http_method = http_method + self.debug = debug + + if not self.matches(auth_header): + raise ValueError('Authorization scheme is not "Digest"') + + self.auth_header = _try_decode_header(auth_header, accept_charset) + + scheme, params = self.auth_header.split(' ', 1) + + # make a dict of the params + items = parse_http_list(params) + paramsd = parse_keqv_list(items) + + self.realm = paramsd.get('realm') + self.username = paramsd.get('username') + self.nonce = paramsd.get('nonce') + self.uri = paramsd.get('uri') + self.method = paramsd.get('method') + self.response = paramsd.get('response') # the response digest + self.algorithm = paramsd.get('algorithm', 'MD5').upper() + self.cnonce = paramsd.get('cnonce') + self.opaque = paramsd.get('opaque') + self.qop = paramsd.get('qop') # qop + self.nc = paramsd.get('nc') # nonce count + + # perform some correctness checks + if self.algorithm not in valid_algorithms: + raise ValueError( + self.errmsg("Unsupported value for algorithm: '%s'" % + self.algorithm)) + + has_reqd = ( + self.username and + self.realm and + self.nonce and + self.uri and + self.response + ) + if not has_reqd: + raise ValueError( + self.errmsg('Not all required parameters are present.')) + + if self.qop: + if self.qop not in valid_qops: + raise ValueError( + self.errmsg("Unsupported value for qop: '%s'" % self.qop)) + if not (self.cnonce and self.nc): + raise ValueError( + self.errmsg('If qop is sent then ' + 'cnonce and nc MUST be present')) + else: + if self.cnonce or self.nc: + raise ValueError( + self.errmsg('If qop is not sent, ' + 'neither cnonce nor nc can be present')) + + def __str__(self): + return 'authorization : %s' % self.auth_header + + def validate_nonce(self, s, key): + """Validate the nonce. + Returns True if nonce was generated by synthesize_nonce() and the + timestamp is not spoofed, else returns False. + + s + A string related to the resource, such as the hostname of + the server. + + key + A secret string known only to the server. + + Both s and key must be the same values which were used to synthesize + the nonce we are trying to validate. + """ + try: + timestamp, hashpart = self.nonce.split(':', 1) + s_timestamp, s_hashpart = synthesize_nonce( + s, key, timestamp).split(':', 1) + is_valid = s_hashpart == hashpart + if self.debug: + TRACE('validate_nonce: %s' % is_valid) + return is_valid + except ValueError: # split() error + pass + return False + + def is_nonce_stale(self, max_age_seconds=600): + """Returns True if a validated nonce is stale. The nonce contains a + timestamp in plaintext and also a secure hash of the timestamp. + You should first validate the nonce to ensure the plaintext + timestamp is not spoofed. + """ + try: + timestamp, hashpart = self.nonce.split(':', 1) + if int(timestamp) + max_age_seconds > int(time.time()): + return False + except ValueError: # int() error + pass + if self.debug: + TRACE('nonce is stale') + return True + + def HA2(self, entity_body=''): + """Returns the H(A2) string. 
See :rfc:`2617` section 3.2.2.3.""" + # RFC 2617 3.2.2.3 + # If the "qop" directive's value is "auth" or is unspecified, + # then A2 is: + # A2 = method ":" digest-uri-value + # + # If the "qop" value is "auth-int", then A2 is: + # A2 = method ":" digest-uri-value ":" H(entity-body) + if self.qop is None or self.qop == 'auth': + a2 = '%s:%s' % (self.http_method, self.uri) + elif self.qop == 'auth-int': + a2 = '%s:%s:%s' % (self.http_method, self.uri, H(entity_body)) + else: + # in theory, this should never happen, since I validate qop in + # __init__() + raise ValueError(self.errmsg('Unrecognized value for qop!')) + return H(a2) + + def request_digest(self, ha1, entity_body=''): + """Calculates the Request-Digest. See :rfc:`2617` section 3.2.2.1. + + ha1 + The HA1 string obtained from the credentials store. + + entity_body + If 'qop' is set to 'auth-int', then A2 includes a hash + of the "entity body". The entity body is the part of the + message which follows the HTTP headers. See :rfc:`2617` section + 4.3. This refers to the entity the user agent sent in the + request which has the Authorization header. Typically GET + requests don't have an entity, and POST requests do. + + """ + ha2 = self.HA2(entity_body) + # Request-Digest -- RFC 2617 3.2.2.1 + if self.qop: + req = '%s:%s:%s:%s:%s' % ( + self.nonce, self.nc, self.cnonce, self.qop, ha2) + else: + req = '%s:%s' % (self.nonce, ha2) + + # RFC 2617 3.2.2.2 + # + # If the "algorithm" directive's value is "MD5" or is unspecified, + # then A1 is: + # A1 = unq(username-value) ":" unq(realm-value) ":" passwd + # + # If the "algorithm" directive's value is "MD5-sess", then A1 is + # calculated only once - on the first request by the client following + # receipt of a WWW-Authenticate challenge from the server. + # A1 = H( unq(username-value) ":" unq(realm-value) ":" passwd ) + # ":" unq(nonce-value) ":" unq(cnonce-value) + if self.algorithm == 'MD5-sess': + ha1 = H('%s:%s:%s' % (ha1, self.nonce, self.cnonce)) + + digest = H('%s:%s' % (ha1, req)) + return digest + + +def _get_charset_declaration(charset): + global FALLBACK_CHARSET + charset = charset.upper() + return ( + (', charset="%s"' % charset) + if charset != FALLBACK_CHARSET + else '' + ) + + +def www_authenticate( + realm, key, algorithm='MD5', nonce=None, qop=qop_auth, + stale=False, accept_charset=DEFAULT_CHARSET[:], +): + """Constructs a WWW-Authenticate header for Digest authentication.""" + if qop not in valid_qops: + raise ValueError("Unsupported value for qop: '%s'" % qop) + if algorithm not in valid_algorithms: + raise ValueError("Unsupported value for algorithm: '%s'" % algorithm) + + HEADER_PATTERN = ( + 'Digest realm="%s", nonce="%s", algorithm="%s", qop="%s"%s%s' + ) + + if nonce is None: + nonce = synthesize_nonce(realm, key) + + stale_param = ', stale="true"' if stale else '' + + charset_declaration = _get_charset_declaration(accept_charset) + + return HEADER_PATTERN % ( + realm, nonce, algorithm, qop, stale_param, charset_declaration, + ) + + +def digest_auth(realm, get_ha1, key, debug=False, accept_charset='utf-8'): + """A CherryPy tool that hooks at before_handler to perform + HTTP Digest Access Authentication, as specified in :rfc:`2617`. + + If the request has an 'authorization' header with a 'Digest' scheme, + this tool authenticates the credentials supplied in that header. + If the request has no 'authorization' header, or if it does but the + scheme is not "Digest", or if authentication fails, the tool sends + a 401 response with a 'WWW-Authenticate' Digest header. 
+ + realm + A string containing the authentication realm. + + get_ha1 + A callable that looks up a username in a credentials store + and returns the HA1 string, which is defined in the RFC to be + MD5(username : realm : password). The function's signature is: + ``get_ha1(realm, username)`` + where username is obtained from the request's 'authorization' header. + If username is not found in the credentials store, get_ha1() returns + None. + + key + A secret string known only to the server, used in the synthesis + of nonces. + + """ + request = cherrypy.serving.request + + auth_header = request.headers.get('authorization') + + respond_401 = functools.partial( + _respond_401, realm, key, accept_charset, debug) + + if not HttpDigestAuthorization.matches(auth_header or ''): + respond_401() + + msg = 'The Authorization header could not be parsed.' + with cherrypy.HTTPError.handle(ValueError, 400, msg): + auth = HttpDigestAuthorization( + auth_header, request.method, + debug=debug, accept_charset=accept_charset, + ) + + if debug: + TRACE(str(auth)) + + if not auth.validate_nonce(realm, key): + respond_401() + + ha1 = get_ha1(realm, auth.username) + + if ha1 is None: + respond_401() + + # note that for request.body to be available we need to + # hook in at before_handler, not on_start_resource like + # 3.1.x digest_auth does. + digest = auth.request_digest(ha1, entity_body=request.body) + if digest != auth.response: + respond_401() + + # authenticated + if debug: + TRACE('digest matches auth.response') + # Now check if nonce is stale. + # The choice of ten minutes' lifetime for nonce is somewhat + # arbitrary + if auth.is_nonce_stale(max_age_seconds=600): + respond_401(stale=True) + + request.login = auth.username + if debug: + TRACE('authentication of %s successful' % auth.username) + + +def _respond_401(realm, key, accept_charset, debug, **kwargs): + """ + Respond with 401 status and a WWW-Authenticate header + """ + header = www_authenticate( + realm, key, + accept_charset=accept_charset, + **kwargs + ) + if debug: + TRACE(header) + cherrypy.serving.response.headers['WWW-Authenticate'] = header + raise cherrypy.HTTPError( + 401, 'You are not authorized to access that resource') diff --git a/libraries/cherrypy/lib/caching.py b/libraries/cherrypy/lib/caching.py new file mode 100644 index 00000000..fed325a6 --- /dev/null +++ b/libraries/cherrypy/lib/caching.py @@ -0,0 +1,482 @@ +""" +CherryPy implements a simple caching system as a pluggable Tool. This tool +tries to be an (in-process) HTTP/1.1-compliant cache. It's not quite there +yet, but it's probably good enough for most sites. + +In general, GET responses are cached (along with selecting headers) and, if +another request arrives for the same resource, the caching Tool will return 304 +Not Modified if possible, or serve the cached response otherwise. It also sets +request.cached to True if serving a cached representation, and sets +request.cacheable to False (so it doesn't get cached again). + +If POST, PUT, or DELETE requests are made for a cached resource, they +invalidate (delete) any cached response. + +Usage +===== + +Configuration file example:: + + [/] + tools.caching.on = True + tools.caching.delay = 3600 + +You may use a class other than the default +:class:`MemoryCache<cherrypy.lib.caching.MemoryCache>` by supplying the config +entry ``cache_class``; supply the full dotted name of the replacement class +as the config value. It must implement the basic methods ``get``, ``put``, +``delete``, and ``clear``. 
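+
+For example (``mypkg.RedisCache`` is only a placeholder for whatever class
+you provide)::
+
+    [/]
+    tools.caching.on = True
+    tools.caching.delay = 3600
+    tools.caching.cache_class = mypkg.RedisCache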
+ +You may set any attribute, including overriding methods, on the cache +instance by providing them in config. The above sets the +:attr:`delay<cherrypy.lib.caching.MemoryCache.delay>` attribute, for example. +""" + +import datetime +import sys +import threading +import time + +import six + +import cherrypy +from cherrypy.lib import cptools, httputil +from cherrypy._cpcompat import Event + + +class Cache(object): + + """Base class for Cache implementations.""" + + def get(self): + """Return the current variant if in the cache, else None.""" + raise NotImplemented + + def put(self, obj, size): + """Store the current variant in the cache.""" + raise NotImplemented + + def delete(self): + """Remove ALL cached variants of the current resource.""" + raise NotImplemented + + def clear(self): + """Reset the cache to its initial, empty state.""" + raise NotImplemented + + +# ------------------------------ Memory Cache ------------------------------- # +class AntiStampedeCache(dict): + + """A storage system for cached items which reduces stampede collisions.""" + + def wait(self, key, timeout=5, debug=False): + """Return the cached value for the given key, or None. + + If timeout is not None, and the value is already + being calculated by another thread, wait until the given timeout has + elapsed. If the value is available before the timeout expires, it is + returned. If not, None is returned, and a sentinel placed in the cache + to signal other threads to wait. + + If timeout is None, no waiting is performed nor sentinels used. + """ + value = self.get(key) + if isinstance(value, Event): + if timeout is None: + # Ignore the other thread and recalc it ourselves. + if debug: + cherrypy.log('No timeout', 'TOOLS.CACHING') + return None + + # Wait until it's done or times out. + if debug: + cherrypy.log('Waiting up to %s seconds' % + timeout, 'TOOLS.CACHING') + value.wait(timeout) + if value.result is not None: + # The other thread finished its calculation. Use it. + if debug: + cherrypy.log('Result!', 'TOOLS.CACHING') + return value.result + # Timed out. Stick an Event in the slot so other threads wait + # on this one to finish calculating the value. + if debug: + cherrypy.log('Timed out', 'TOOLS.CACHING') + e = threading.Event() + e.result = None + dict.__setitem__(self, key, e) + + return None + elif value is None: + # Stick an Event in the slot so other threads wait + # on this one to finish calculating the value. + if debug: + cherrypy.log('Timed out', 'TOOLS.CACHING') + e = threading.Event() + e.result = None + dict.__setitem__(self, key, e) + return value + + def __setitem__(self, key, value): + """Set the cached value for the given key.""" + existing = self.get(key) + dict.__setitem__(self, key, value) + if isinstance(existing, Event): + # Set Event.result so other threads waiting on it have + # immediate access without needing to poll the cache again. + existing.result = value + existing.set() + + +class MemoryCache(Cache): + + """An in-memory cache for varying response content. + + Each key in self.store is a URI, and each value is an AntiStampedeCache. + The response for any given URI may vary based on the values of + "selecting request headers"; that is, those named in the Vary + response header. We assume the list of header names to be constant + for each URI throughout the lifetime of the application, and store + that list in ``self.store[uri].selecting_headers``. 
+ + The items contained in ``self.store[uri]`` have keys which are tuples of + request header values (in the same order as the names in its + selecting_headers), and values which are the actual responses. + """ + + maxobjects = 1000 + """The maximum number of cached objects; defaults to 1000.""" + + maxobj_size = 100000 + """The maximum size of each cached object in bytes; defaults to 100 KB.""" + + maxsize = 10000000 + """The maximum size of the entire cache in bytes; defaults to 10 MB.""" + + delay = 600 + """Seconds until the cached content expires; defaults to 600 (10 minutes). + """ + + antistampede_timeout = 5 + """Seconds to wait for other threads to release a cache lock.""" + + expire_freq = 0.1 + """Seconds to sleep between cache expiration sweeps.""" + + debug = False + + def __init__(self): + self.clear() + + # Run self.expire_cache in a separate daemon thread. + t = threading.Thread(target=self.expire_cache, name='expire_cache') + self.expiration_thread = t + t.daemon = True + t.start() + + def clear(self): + """Reset the cache to its initial, empty state.""" + self.store = {} + self.expirations = {} + self.tot_puts = 0 + self.tot_gets = 0 + self.tot_hist = 0 + self.tot_expires = 0 + self.tot_non_modified = 0 + self.cursize = 0 + + def expire_cache(self): + """Continuously examine cached objects, expiring stale ones. + + This function is designed to be run in its own daemon thread, + referenced at ``self.expiration_thread``. + """ + # It's possible that "time" will be set to None + # arbitrarily, so we check "while time" to avoid exceptions. + # See tickets #99 and #180 for more information. + while time: + now = time.time() + # Must make a copy of expirations so it doesn't change size + # during iteration + items = list(six.iteritems(self.expirations)) + for expiration_time, objects in items: + if expiration_time <= now: + for obj_size, uri, sel_header_values in objects: + try: + del self.store[uri][tuple(sel_header_values)] + self.tot_expires += 1 + self.cursize -= obj_size + except KeyError: + # the key may have been deleted elsewhere + pass + del self.expirations[expiration_time] + time.sleep(self.expire_freq) + + def get(self): + """Return the current variant if in the cache, else None.""" + request = cherrypy.serving.request + self.tot_gets += 1 + + uri = cherrypy.url(qs=request.query_string) + uricache = self.store.get(uri) + if uricache is None: + return None + + header_values = [request.headers.get(h, '') + for h in uricache.selecting_headers] + variant = uricache.wait(key=tuple(sorted(header_values)), + timeout=self.antistampede_timeout, + debug=self.debug) + if variant is not None: + self.tot_hist += 1 + return variant + + def put(self, variant, size): + """Store the current variant in the cache.""" + request = cherrypy.serving.request + response = cherrypy.serving.response + + uri = cherrypy.url(qs=request.query_string) + uricache = self.store.get(uri) + if uricache is None: + uricache = AntiStampedeCache() + uricache.selecting_headers = [ + e.value for e in response.headers.elements('Vary')] + self.store[uri] = uricache + + if len(self.store) < self.maxobjects: + total_size = self.cursize + size + + # checks if there's space for the object + if (size < self.maxobj_size and total_size < self.maxsize): + # add to the expirations list + expiration_time = response.time + self.delay + bucket = self.expirations.setdefault(expiration_time, []) + bucket.append((size, uri, uricache.selecting_headers)) + + # add to the cache + header_values = [request.headers.get(h, '') + 
for h in uricache.selecting_headers] + uricache[tuple(sorted(header_values))] = variant + self.tot_puts += 1 + self.cursize = total_size + + def delete(self): + """Remove ALL cached variants of the current resource.""" + uri = cherrypy.url(qs=cherrypy.serving.request.query_string) + self.store.pop(uri, None) + + +def get(invalid_methods=('POST', 'PUT', 'DELETE'), debug=False, **kwargs): + """Try to obtain cached output. If fresh enough, raise HTTPError(304). + + If POST, PUT, or DELETE: + * invalidates (deletes) any cached response for this resource + * sets request.cached = False + * sets request.cacheable = False + + else if a cached copy exists: + * sets request.cached = True + * sets request.cacheable = False + * sets response.headers to the cached values + * checks the cached Last-Modified response header against the + current If-(Un)Modified-Since request headers; raises 304 + if necessary. + * sets response.status and response.body to the cached values + * returns True + + otherwise: + * sets request.cached = False + * sets request.cacheable = True + * returns False + """ + request = cherrypy.serving.request + response = cherrypy.serving.response + + if not hasattr(cherrypy, '_cache'): + # Make a process-wide Cache object. + cherrypy._cache = kwargs.pop('cache_class', MemoryCache)() + + # Take all remaining kwargs and set them on the Cache object. + for k, v in kwargs.items(): + setattr(cherrypy._cache, k, v) + cherrypy._cache.debug = debug + + # POST, PUT, DELETE should invalidate (delete) the cached copy. + # See http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.10. + if request.method in invalid_methods: + if debug: + cherrypy.log('request.method %r in invalid_methods %r' % + (request.method, invalid_methods), 'TOOLS.CACHING') + cherrypy._cache.delete() + request.cached = False + request.cacheable = False + return False + + if 'no-cache' in [e.value for e in request.headers.elements('Pragma')]: + request.cached = False + request.cacheable = True + return False + + cache_data = cherrypy._cache.get() + request.cached = bool(cache_data) + request.cacheable = not request.cached + if request.cached: + # Serve the cached copy. + max_age = cherrypy._cache.delay + for v in [e.value for e in request.headers.elements('Cache-Control')]: + atoms = v.split('=', 1) + directive = atoms.pop(0) + if directive == 'max-age': + if len(atoms) != 1 or not atoms[0].isdigit(): + raise cherrypy.HTTPError( + 400, 'Invalid Cache-Control header') + max_age = int(atoms[0]) + break + elif directive == 'no-cache': + if debug: + cherrypy.log( + 'Ignoring cache due to Cache-Control: no-cache', + 'TOOLS.CACHING') + request.cached = False + request.cacheable = True + return False + + if debug: + cherrypy.log('Reading response from cache', 'TOOLS.CACHING') + s, h, b, create_time = cache_data + age = int(response.time - create_time) + if (age > max_age): + if debug: + cherrypy.log('Ignoring cache due to age > %d' % max_age, + 'TOOLS.CACHING') + request.cached = False + request.cacheable = True + return False + + # Copy the response headers. See + # https://github.com/cherrypy/cherrypy/issues/721. + response.headers = rh = httputil.HeaderMap() + for k in h: + dict.__setitem__(rh, k, dict.__getitem__(h, k)) + + # Add the required Age header + response.headers['Age'] = str(age) + + try: + # Note that validate_since depends on a Last-Modified header; + # this was put into the cached copy, and should have been + # resurrected just above (response.headers = cache_data[1]). 
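+            # validate_since() raises HTTPRedirect with a 304 status when
+            # the client's copy is still current; the except clause below
+            # counts that as a non-modified hit before re-raising.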
+ cptools.validate_since() + except cherrypy.HTTPRedirect: + x = sys.exc_info()[1] + if x.status == 304: + cherrypy._cache.tot_non_modified += 1 + raise + + # serve it & get out from the request + response.status = s + response.body = b + else: + if debug: + cherrypy.log('request is not cached', 'TOOLS.CACHING') + return request.cached + + +def tee_output(): + """Tee response output to cache storage. Internal.""" + # Used by CachingTool by attaching to request.hooks + + request = cherrypy.serving.request + if 'no-store' in request.headers.values('Cache-Control'): + return + + def tee(body): + """Tee response.body into a list.""" + if ('no-cache' in response.headers.values('Pragma') or + 'no-store' in response.headers.values('Cache-Control')): + for chunk in body: + yield chunk + return + + output = [] + for chunk in body: + output.append(chunk) + yield chunk + + # Save the cache data, but only if the body isn't empty. + # e.g. a 304 Not Modified on a static file response will + # have an empty body. + # If the body is empty, delete the cache because it + # contains a stale Threading._Event object that will + # stall all consecutive requests until the _Event times + # out + body = b''.join(output) + if not body: + cherrypy._cache.delete() + else: + cherrypy._cache.put((response.status, response.headers or {}, + body, response.time), len(body)) + + response = cherrypy.serving.response + response.body = tee(response.body) + + +def expires(secs=0, force=False, debug=False): + """Tool for influencing cache mechanisms using the 'Expires' header. + + secs + Must be either an int or a datetime.timedelta, and indicates the + number of seconds between response.time and when the response should + expire. The 'Expires' header will be set to response.time + secs. + If secs is zero, the 'Expires' header is set one year in the past, and + the following "cache prevention" headers are also set: + + * Pragma: no-cache + * Cache-Control': no-cache, must-revalidate + + force + If False, the following headers are checked: + + * Etag + * Last-Modified + * Age + * Expires + + If any are already present, none of the above response headers are set. + + """ + + response = cherrypy.serving.response + headers = response.headers + + cacheable = False + if not force: + # some header names that indicate that the response can be cached + for indicator in ('Etag', 'Last-Modified', 'Age', 'Expires'): + if indicator in headers: + cacheable = True + break + + if not cacheable and not force: + if debug: + cherrypy.log('request is not cacheable', 'TOOLS.EXPIRES') + else: + if debug: + cherrypy.log('request is cacheable', 'TOOLS.EXPIRES') + if isinstance(secs, datetime.timedelta): + secs = (86400 * secs.days) + secs.seconds + + if secs == 0: + if force or ('Pragma' not in headers): + headers['Pragma'] = 'no-cache' + if cherrypy.serving.request.protocol >= (1, 1): + if force or 'Cache-Control' not in headers: + headers['Cache-Control'] = 'no-cache, must-revalidate' + # Set an explicit Expires date in the past. + expiry = httputil.HTTPDate(1169942400.0) + else: + expiry = httputil.HTTPDate(response.time + secs) + if force or 'Expires' not in headers: + headers['Expires'] = expiry diff --git a/libraries/cherrypy/lib/covercp.py b/libraries/cherrypy/lib/covercp.py new file mode 100644 index 00000000..0bafca13 --- /dev/null +++ b/libraries/cherrypy/lib/covercp.py @@ -0,0 +1,391 @@ +"""Code-coverage tools for CherryPy. 
+ +To use this module, or the coverage tools in the test suite, +you need to download 'coverage.py', either Gareth Rees' `original +implementation <http://www.garethrees.org/2001/12/04/python-coverage/>`_ +or Ned Batchelder's `enhanced version: +<http://www.nedbatchelder.com/code/modules/coverage.html>`_ + +To turn on coverage tracing, use the following code:: + + cherrypy.engine.subscribe('start', covercp.start) + +DO NOT subscribe anything on the 'start_thread' channel, as previously +recommended. Calling start once in the main thread should be sufficient +to start coverage on all threads. Calling start again in each thread +effectively clears any coverage data gathered up to that point. + +Run your code, then use the ``covercp.serve()`` function to browse the +results in a web browser. If you run this module from the command line, +it will call ``serve()`` for you. +""" + +import re +import sys +import cgi +import os +import os.path + +from six.moves import urllib + +import cherrypy + + +localFile = os.path.join(os.path.dirname(__file__), 'coverage.cache') + +the_coverage = None +try: + from coverage import coverage + the_coverage = coverage(data_file=localFile) + + def start(): + the_coverage.start() +except ImportError: + # Setting the_coverage to None will raise errors + # that need to be trapped downstream. + the_coverage = None + + import warnings + warnings.warn( + 'No code coverage will be performed; ' + 'coverage.py could not be imported.') + + def start(): + pass +start.priority = 20 + +TEMPLATE_MENU = """<html> +<head> + <title>CherryPy Coverage Menu</title> + <style> + body {font: 9pt Arial, serif;} + #tree { + font-size: 8pt; + font-family: Andale Mono, monospace; + white-space: pre; + } + #tree a:active, a:focus { + background-color: black; + padding: 1px; + color: white; + border: 0px solid #9999FF; + -moz-outline-style: none; + } + .fail { color: red;} + .pass { color: #888;} + #pct { text-align: right;} + h3 { + font-size: small; + font-weight: bold; + font-style: italic; + margin-top: 5px; + } + input { border: 1px solid #ccc; padding: 2px; } + .directory { + color: #933; + font-style: italic; + font-weight: bold; + font-size: 10pt; + } + .file { + color: #400; + } + a { text-decoration: none; } + #crumbs { + color: white; + font-size: 8pt; + font-family: Andale Mono, monospace; + width: 100%; + background-color: black; + } + #crumbs a { + color: #f88; + } + #options { + line-height: 2.3em; + border: 1px solid black; + background-color: #eee; + padding: 4px; + } + #exclude { + width: 100%; + margin-bottom: 3px; + border: 1px solid #999; + } + #submit { + background-color: black; + color: white; + border: 0; + margin-bottom: -9px; + } + </style> +</head> +<body> +<h2>CherryPy Coverage</h2>""" + +TEMPLATE_FORM = """ +<div id="options"> +<form action='menu' method=GET> + <input type='hidden' name='base' value='%(base)s' /> + Show percentages + <input type='checkbox' %(showpct)s name='showpct' value='checked' /><br /> + Hide files over + <input type='text' id='pct' name='pct' value='%(pct)s' size='3' />%%<br /> + Exclude files matching<br /> + <input type='text' id='exclude' name='exclude' + value='%(exclude)s' size='20' /> + <br /> + + <input type='submit' value='Change view' id="submit"/> +</form> +</div>""" + +TEMPLATE_FRAMESET = """<html> +<head><title>CherryPy coverage data</title></head> +<frameset cols='250, 1*'> + <frame src='menu?base=%s' /> + <frame name='main' src='' /> +</frameset> +</html> +""" + +TEMPLATE_COVERAGE = """<html> +<head> + <title>Coverage for 
%(name)s</title> + <style> + h2 { margin-bottom: .25em; } + p { margin: .25em; } + .covered { color: #000; background-color: #fff; } + .notcovered { color: #fee; background-color: #500; } + .excluded { color: #00f; background-color: #fff; } + table .covered, table .notcovered, table .excluded + { font-family: Andale Mono, monospace; + font-size: 10pt; white-space: pre; } + + .lineno { background-color: #eee;} + .notcovered .lineno { background-color: #000;} + table { border-collapse: collapse; + </style> +</head> +<body> +<h2>%(name)s</h2> +<p>%(fullpath)s</p> +<p>Coverage: %(pc)s%%</p>""" + +TEMPLATE_LOC_COVERED = """<tr class="covered"> + <td class="lineno">%s </td> + <td>%s</td> +</tr>\n""" +TEMPLATE_LOC_NOT_COVERED = """<tr class="notcovered"> + <td class="lineno">%s </td> + <td>%s</td> +</tr>\n""" +TEMPLATE_LOC_EXCLUDED = """<tr class="excluded"> + <td class="lineno">%s </td> + <td>%s</td> +</tr>\n""" + +TEMPLATE_ITEM = ( + "%s%s<a class='file' href='report?name=%s' target='main'>%s</a>\n" +) + + +def _percent(statements, missing): + s = len(statements) + e = s - len(missing) + if s > 0: + return int(round(100.0 * e / s)) + return 0 + + +def _show_branch(root, base, path, pct=0, showpct=False, exclude='', + coverage=the_coverage): + + # Show the directory name and any of our children + dirs = [k for k, v in root.items() if v] + dirs.sort() + for name in dirs: + newpath = os.path.join(path, name) + + if newpath.lower().startswith(base): + relpath = newpath[len(base):] + yield '| ' * relpath.count(os.sep) + yield ( + "<a class='directory' " + "href='menu?base=%s&exclude=%s'>%s</a>\n" % + (newpath, urllib.parse.quote_plus(exclude), name) + ) + + for chunk in _show_branch( + root[name], base, newpath, pct, showpct, + exclude, coverage=coverage + ): + yield chunk + + # Now list the files + if path.lower().startswith(base): + relpath = path[len(base):] + files = [k for k, v in root.items() if not v] + files.sort() + for name in files: + newpath = os.path.join(path, name) + + pc_str = '' + if showpct: + try: + _, statements, _, missing, _ = coverage.analysis2(newpath) + except Exception: + # Yes, we really want to pass on all errors. + pass + else: + pc = _percent(statements, missing) + pc_str = ('%3d%% ' % pc).replace(' ', ' ') + if pc < float(pct) or pc == -1: + pc_str = "<span class='fail'>%s</span>" % pc_str + else: + pc_str = "<span class='pass'>%s</span>" % pc_str + + yield TEMPLATE_ITEM % ('| ' * (relpath.count(os.sep) + 1), + pc_str, newpath, name) + + +def _skip_file(path, exclude): + if exclude: + return bool(re.search(exclude, path)) + + +def _graft(path, tree): + d = tree + + p = path + atoms = [] + while True: + p, tail = os.path.split(p) + if not tail: + break + atoms.append(tail) + atoms.append(p) + if p != '/': + atoms.append('/') + + atoms.reverse() + for node in atoms: + if node: + d = d.setdefault(node, {}) + + +def get_tree(base, exclude, coverage=the_coverage): + """Return covered module names as a nested dict.""" + tree = {} + runs = coverage.data.executed_files() + for path in runs: + if not _skip_file(path, exclude) and not os.path.isdir(path): + _graft(path, tree) + return tree + + +class CoverStats(object): + + def __init__(self, coverage, root=None): + self.coverage = coverage + if root is None: + # Guess initial depth. Files outside this path will not be + # reachable from the web interface. 
+ root = os.path.dirname(cherrypy.__file__) + self.root = root + + @cherrypy.expose + def index(self): + return TEMPLATE_FRAMESET % self.root.lower() + + @cherrypy.expose + def menu(self, base='/', pct='50', showpct='', + exclude=r'python\d\.\d|test|tut\d|tutorial'): + + # The coverage module uses all-lower-case names. + base = base.lower().rstrip(os.sep) + + yield TEMPLATE_MENU + yield TEMPLATE_FORM % locals() + + # Start by showing links for parent paths + yield "<div id='crumbs'>" + path = '' + atoms = base.split(os.sep) + atoms.pop() + for atom in atoms: + path += atom + os.sep + yield ("<a href='menu?base=%s&exclude=%s'>%s</a> %s" + % (path, urllib.parse.quote_plus(exclude), atom, os.sep)) + yield '</div>' + + yield "<div id='tree'>" + + # Then display the tree + tree = get_tree(base, exclude, self.coverage) + if not tree: + yield '<p>No modules covered.</p>' + else: + for chunk in _show_branch(tree, base, '/', pct, + showpct == 'checked', exclude, + coverage=self.coverage): + yield chunk + + yield '</div>' + yield '</body></html>' + + def annotated_file(self, filename, statements, excluded, missing): + source = open(filename, 'r') + buffer = [] + for lineno, line in enumerate(source.readlines()): + lineno += 1 + line = line.strip('\n\r') + empty_the_buffer = True + if lineno in excluded: + template = TEMPLATE_LOC_EXCLUDED + elif lineno in missing: + template = TEMPLATE_LOC_NOT_COVERED + elif lineno in statements: + template = TEMPLATE_LOC_COVERED + else: + empty_the_buffer = False + buffer.append((lineno, line)) + if empty_the_buffer: + for lno, pastline in buffer: + yield template % (lno, cgi.escape(pastline)) + buffer = [] + yield template % (lineno, cgi.escape(line)) + + @cherrypy.expose + def report(self, name): + filename, statements, excluded, missing, _ = self.coverage.analysis2( + name) + pc = _percent(statements, missing) + yield TEMPLATE_COVERAGE % dict(name=os.path.basename(name), + fullpath=name, + pc=pc) + yield '<table>\n' + for line in self.annotated_file(filename, statements, excluded, + missing): + yield line + yield '</table>' + yield '</body>' + yield '</html>' + + +def serve(path=localFile, port=8080, root=None): + if coverage is None: + raise ImportError('The coverage module could not be imported.') + from coverage import coverage + cov = coverage(data_file=path) + cov.load() + + cherrypy.config.update({'server.socket_port': int(port), + 'server.thread_pool': 10, + 'environment': 'production', + }) + cherrypy.quickstart(CoverStats(cov, root)) + + +if __name__ == '__main__': + serve(*tuple(sys.argv[1:])) diff --git a/libraries/cherrypy/lib/cpstats.py b/libraries/cherrypy/lib/cpstats.py new file mode 100644 index 00000000..ae9f7475 --- /dev/null +++ b/libraries/cherrypy/lib/cpstats.py @@ -0,0 +1,696 @@ +"""CPStats, a package for collecting and reporting on program statistics. + +Overview +======== + +Statistics about program operation are an invaluable monitoring and debugging +tool. Unfortunately, the gathering and reporting of these critical values is +usually ad-hoc. This package aims to add a centralized place for gathering +statistical performance data, a structure for recording that data which +provides for extrapolation of that data into more useful information, +and a method of serving that data to both human investigators and +monitoring software. Let's examine each of those in more detail. 
+ +Data Gathering +-------------- + +Just as Python's `logging` module provides a common importable for gathering +and sending messages, performance statistics would benefit from a similar +common mechanism, and one that does *not* require each package which wishes +to collect stats to import a third-party module. Therefore, we choose to +re-use the `logging` module by adding a `statistics` object to it. + +That `logging.statistics` object is a nested dict. It is not a custom class, +because that would: + + 1. require libraries and applications to import a third-party module in + order to participate + 2. inhibit innovation in extrapolation approaches and in reporting tools, and + 3. be slow. + +There are, however, some specifications regarding the structure of the dict.:: + + { + +----"SQLAlchemy": { + | "Inserts": 4389745, + | "Inserts per Second": + | lambda s: s["Inserts"] / (time() - s["Start"]), + | C +---"Table Statistics": { + | o | "widgets": {-----------+ + N | l | "Rows": 1.3M, | Record + a | l | "Inserts": 400, | + m | e | },---------------------+ + e | c | "froobles": { + s | t | "Rows": 7845, + p | i | "Inserts": 0, + a | o | }, + c | n +---}, + e | "Slow Queries": + | [{"Query": "SELECT * FROM widgets;", + | "Processing Time": 47.840923343, + | }, + | ], + +----}, + } + +The `logging.statistics` dict has four levels. The topmost level is nothing +more than a set of names to introduce modularity, usually along the lines of +package names. If the SQLAlchemy project wanted to participate, for example, +it might populate the item `logging.statistics['SQLAlchemy']`, whose value +would be a second-layer dict we call a "namespace". Namespaces help multiple +packages to avoid collisions over key names, and make reports easier to read, +to boot. The maintainers of SQLAlchemy should feel free to use more than one +namespace if needed (such as 'SQLAlchemy ORM'). Note that there are no case +or other syntax constraints on the namespace names; they should be chosen +to be maximally readable by humans (neither too short nor too long). + +Each namespace, then, is a dict of named statistical values, such as +'Requests/sec' or 'Uptime'. You should choose names which will look +good on a report: spaces and capitalization are just fine. + +In addition to scalars, values in a namespace MAY be a (third-layer) +dict, or a list, called a "collection". For example, the CherryPy +:class:`StatsTool` keeps track of what each request is doing (or has most +recently done) in a 'Requests' collection, where each key is a thread ID; each +value in the subdict MUST be a fourth dict (whew!) of statistical data about +each thread. We call each subdict in the collection a "record". Similarly, +the :class:`StatsTool` also keeps a list of slow queries, where each record +contains data about each slow query, in order. + +Values in a namespace or record may also be functions, which brings us to: + +Extrapolation +------------- + +The collection of statistical data needs to be fast, as close to unnoticeable +as possible to the host program. That requires us to minimize I/O, for example, +but in Python it also means we need to minimize function calls. So when you +are designing your namespace and record values, try to insert the most basic +scalar values you already have on hand. + +When it comes time to report on the gathered data, however, we usually have +much more freedom in what we can calculate. 
Therefore, whenever reporting +tools (like the provided :class:`StatsPage` CherryPy class) fetch the contents +of `logging.statistics` for reporting, they first call +`extrapolate_statistics` (passing the whole `statistics` dict as the only +argument). This makes a deep copy of the statistics dict so that the +reporting tool can both iterate over it and even change it without harming +the original. But it also expands any functions in the dict by calling them. +For example, you might have a 'Current Time' entry in the namespace with the +value "lambda scope: time.time()". The "scope" parameter is the current +namespace dict (or record, if we're currently expanding one of those +instead), allowing you access to existing static entries. If you're truly +evil, you can even modify more than one entry at a time. + +However, don't try to calculate an entry and then use its value in further +extrapolations; the order in which the functions are called is not guaranteed. +This can lead to a certain amount of duplicated work (or a redesign of your +schema), but that's better than complicating the spec. + +After the whole thing has been extrapolated, it's time for: + +Reporting +--------- + +The :class:`StatsPage` class grabs the `logging.statistics` dict, extrapolates +it all, and then transforms it to HTML for easy viewing. Each namespace gets +its own header and attribute table, plus an extra table for each collection. +This is NOT part of the statistics specification; other tools can format how +they like. + +You can control which columns are output and how they are formatted by updating +StatsPage.formatting, which is a dict that mirrors the keys and nesting of +`logging.statistics`. The difference is that, instead of data values, it has +formatting values. Use None for a given key to indicate to the StatsPage that a +given column should not be output. Use a string with formatting +(such as '%.3f') to interpolate the value(s), or use a callable (such as +lambda v: v.isoformat()) for more advanced formatting. Any entry which is not +mentioned in the formatting dict is output unchanged. + +Monitoring +---------- + +Although the HTML output takes pains to assign unique id's to each <td> with +statistical data, you're probably better off fetching /cpstats/data, which +outputs the whole (extrapolated) `logging.statistics` dict in JSON format. +That is probably easier to parse, and doesn't have any formatting controls, +so you get the "original" data in a consistently-serialized format. +Note: there's no treatment yet for datetime objects. Try time.time() instead +for now if you can. Nagios will probably thank you. + +Turning Collection Off +---------------------- + +It is recommended each namespace have an "Enabled" item which, if False, +stops collection (but not reporting) of statistical data. Applications +SHOULD provide controls to pause and resume collection by setting these +entries to False or True, if present. 
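+
+As a purely illustrative sketch (the 'Example' namespace and its entries are
+hypothetical, not part of any shipped namespace), pausing just flips that
+flag, and extrapolation turns callable entries into plain values::
+
+    import logging
+    import time
+
+    from cherrypy.lib.cpstats import extrapolate_statistics
+
+    if not hasattr(logging, 'statistics'):
+        logging.statistics = {}
+    example = logging.statistics.setdefault('Example', {
+        'Enabled': True,
+        'Start Time': time.time(),
+        'Uptime': lambda s: time.time() - s['Start Time'],
+    })
+
+    example['Enabled'] = False   # pause collection for this namespace
+    example['Enabled'] = True    # ...and resume it
+
+    # Reporting tools work on a deep, extrapolated copy: in `snapshot`,
+    # 'Uptime' is a plain float; the original dict still holds the lambda.
+    snapshot = extrapolate_statistics(logging.statistics)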
+ + +Usage +===== + +To collect statistics on CherryPy applications:: + + from cherrypy.lib import cpstats + appconfig['/']['tools.cpstats.on'] = True + +To collect statistics on your own code:: + + import logging + # Initialize the repository + if not hasattr(logging, 'statistics'): logging.statistics = {} + # Initialize my namespace + mystats = logging.statistics.setdefault('My Stuff', {}) + # Initialize my namespace's scalars and collections + mystats.update({ + 'Enabled': True, + 'Start Time': time.time(), + 'Important Events': 0, + 'Events/Second': lambda s: ( + (s['Important Events'] / (time.time() - s['Start Time']))), + }) + ... + for event in events: + ... + # Collect stats + if mystats.get('Enabled', False): + mystats['Important Events'] += 1 + +To report statistics:: + + root.cpstats = cpstats.StatsPage() + +To format statistics reports:: + + See 'Reporting', above. + +""" + +import logging +import os +import sys +import threading +import time + +import six + +import cherrypy +from cherrypy._cpcompat import json + +# ------------------------------- Statistics -------------------------------- # + +if not hasattr(logging, 'statistics'): + logging.statistics = {} + + +def extrapolate_statistics(scope): + """Return an extrapolated copy of the given scope.""" + c = {} + for k, v in list(scope.items()): + if isinstance(v, dict): + v = extrapolate_statistics(v) + elif isinstance(v, (list, tuple)): + v = [extrapolate_statistics(record) for record in v] + elif hasattr(v, '__call__'): + v = v(scope) + c[k] = v + return c + + +# -------------------- CherryPy Applications Statistics --------------------- # + +appstats = logging.statistics.setdefault('CherryPy Applications', {}) +appstats.update({ + 'Enabled': True, + 'Bytes Read/Request': lambda s: ( + s['Total Requests'] and + (s['Total Bytes Read'] / float(s['Total Requests'])) or + 0.0 + ), + 'Bytes Read/Second': lambda s: s['Total Bytes Read'] / s['Uptime'](s), + 'Bytes Written/Request': lambda s: ( + s['Total Requests'] and + (s['Total Bytes Written'] / float(s['Total Requests'])) or + 0.0 + ), + 'Bytes Written/Second': lambda s: ( + s['Total Bytes Written'] / s['Uptime'](s) + ), + 'Current Time': lambda s: time.time(), + 'Current Requests': 0, + 'Requests/Second': lambda s: float(s['Total Requests']) / s['Uptime'](s), + 'Server Version': cherrypy.__version__, + 'Start Time': time.time(), + 'Total Bytes Read': 0, + 'Total Bytes Written': 0, + 'Total Requests': 0, + 'Total Time': 0, + 'Uptime': lambda s: time.time() - s['Start Time'], + 'Requests': {}, +}) + + +def proc_time(s): + return time.time() - s['Start Time'] + + +class ByteCountWrapper(object): + + """Wraps a file-like object, counting the number of bytes read.""" + + def __init__(self, rfile): + self.rfile = rfile + self.bytes_read = 0 + + def read(self, size=-1): + data = self.rfile.read(size) + self.bytes_read += len(data) + return data + + def readline(self, size=-1): + data = self.rfile.readline(size) + self.bytes_read += len(data) + return data + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def next(self): + data = self.rfile.next() + self.bytes_read += len(data) + return data + + +def average_uriset_time(s): + return s['Count'] and (s['Sum'] / s['Count']) or 0 + + +def 
_get_threading_ident(): + if sys.version_info >= (3, 3): + return threading.get_ident() + return threading._get_ident() + + +class StatsTool(cherrypy.Tool): + + """Record various information about the current request.""" + + def __init__(self): + cherrypy.Tool.__init__(self, 'on_end_request', self.record_stop) + + def _setup(self): + """Hook this tool into cherrypy.request. + + The standard CherryPy request object will automatically call this + method when the tool is "turned on" in config. + """ + if appstats.get('Enabled', False): + cherrypy.Tool._setup(self) + self.record_start() + + def record_start(self): + """Record the beginning of a request.""" + request = cherrypy.serving.request + if not hasattr(request.rfile, 'bytes_read'): + request.rfile = ByteCountWrapper(request.rfile) + request.body.fp = request.rfile + + r = request.remote + + appstats['Current Requests'] += 1 + appstats['Total Requests'] += 1 + appstats['Requests'][_get_threading_ident()] = { + 'Bytes Read': None, + 'Bytes Written': None, + # Use a lambda so the ip gets updated by tools.proxy later + 'Client': lambda s: '%s:%s' % (r.ip, r.port), + 'End Time': None, + 'Processing Time': proc_time, + 'Request-Line': request.request_line, + 'Response Status': None, + 'Start Time': time.time(), + } + + def record_stop( + self, uriset=None, slow_queries=1.0, slow_queries_count=100, + debug=False, **kwargs): + """Record the end of a request.""" + resp = cherrypy.serving.response + w = appstats['Requests'][_get_threading_ident()] + + r = cherrypy.request.rfile.bytes_read + w['Bytes Read'] = r + appstats['Total Bytes Read'] += r + + if resp.stream: + w['Bytes Written'] = 'chunked' + else: + cl = int(resp.headers.get('Content-Length', 0)) + w['Bytes Written'] = cl + appstats['Total Bytes Written'] += cl + + w['Response Status'] = getattr( + resp, 'output_status', None) or resp.status + + w['End Time'] = time.time() + p = w['End Time'] - w['Start Time'] + w['Processing Time'] = p + appstats['Total Time'] += p + + appstats['Current Requests'] -= 1 + + if debug: + cherrypy.log('Stats recorded: %s' % repr(w), 'TOOLS.CPSTATS') + + if uriset: + rs = appstats.setdefault('URI Set Tracking', {}) + r = rs.setdefault(uriset, { + 'Min': None, 'Max': None, 'Count': 0, 'Sum': 0, + 'Avg': average_uriset_time}) + if r['Min'] is None or p < r['Min']: + r['Min'] = p + if r['Max'] is None or p > r['Max']: + r['Max'] = p + r['Count'] += 1 + r['Sum'] += p + + if slow_queries and p > slow_queries: + sq = appstats.setdefault('Slow Queries', []) + sq.append(w.copy()) + if len(sq) > slow_queries_count: + sq.pop(0) + + +cherrypy.tools.cpstats = StatsTool() + + +# ---------------------- CherryPy Statistics Reporting ---------------------- # + +thisdir = os.path.abspath(os.path.dirname(__file__)) + +missing = object() + + +def locale_date(v): + return time.strftime('%c', time.gmtime(v)) + + +def iso_format(v): + return time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(v)) + + +def pause_resume(ns): + def _pause_resume(enabled): + pause_disabled = '' + resume_disabled = '' + if enabled: + resume_disabled = 'disabled="disabled" ' + else: + pause_disabled = 'disabled="disabled" ' + return """ + <form action="pause" method="POST" style="display:inline"> + <input type="hidden" name="namespace" value="%s" /> + <input type="submit" value="Pause" %s/> + </form> + <form action="resume" method="POST" style="display:inline"> + <input type="hidden" name="namespace" value="%s" /> + <input type="submit" value="Resume" %s/> + </form> + """ % (ns, pause_disabled, ns, 
resume_disabled) + return _pause_resume + + +class StatsPage(object): + + formatting = { + 'CherryPy Applications': { + 'Enabled': pause_resume('CherryPy Applications'), + 'Bytes Read/Request': '%.3f', + 'Bytes Read/Second': '%.3f', + 'Bytes Written/Request': '%.3f', + 'Bytes Written/Second': '%.3f', + 'Current Time': iso_format, + 'Requests/Second': '%.3f', + 'Start Time': iso_format, + 'Total Time': '%.3f', + 'Uptime': '%.3f', + 'Slow Queries': { + 'End Time': None, + 'Processing Time': '%.3f', + 'Start Time': iso_format, + }, + 'URI Set Tracking': { + 'Avg': '%.3f', + 'Max': '%.3f', + 'Min': '%.3f', + 'Sum': '%.3f', + }, + 'Requests': { + 'Bytes Read': '%s', + 'Bytes Written': '%s', + 'End Time': None, + 'Processing Time': '%.3f', + 'Start Time': None, + }, + }, + 'CherryPy WSGIServer': { + 'Enabled': pause_resume('CherryPy WSGIServer'), + 'Connections/second': '%.3f', + 'Start time': iso_format, + }, + } + + @cherrypy.expose + def index(self): + # Transform the raw data into pretty output for HTML + yield """ +<html> +<head> + <title>Statistics</title> +<style> + +th, td { + padding: 0.25em 0.5em; + border: 1px solid #666699; +} + +table { + border-collapse: collapse; +} + +table.stats1 { + width: 100%; +} + +table.stats1 th { + font-weight: bold; + text-align: right; + background-color: #CCD5DD; +} + +table.stats2, h2 { + margin-left: 50px; +} + +table.stats2 th { + font-weight: bold; + text-align: center; + background-color: #CCD5DD; +} + +</style> +</head> +<body> +""" + for title, scalars, collections in self.get_namespaces(): + yield """ +<h1>%s</h1> + +<table class='stats1'> + <tbody> +""" % title + for i, (key, value) in enumerate(scalars): + colnum = i % 3 + if colnum == 0: + yield """ + <tr>""" + yield ( + """ + <th>%(key)s</th><td id='%(title)s-%(key)s'>%(value)s</td>""" % + vars() + ) + if colnum == 2: + yield """ + </tr>""" + + if colnum == 0: + yield """ + <th></th><td></td> + <th></th><td></td> + </tr>""" + elif colnum == 1: + yield """ + <th></th><td></td> + </tr>""" + yield """ + </tbody> +</table>""" + + for subtitle, headers, subrows in collections: + yield """ +<h2>%s</h2> +<table class='stats2'> + <thead> + <tr>""" % subtitle + for key in headers: + yield """ + <th>%s</th>""" % key + yield """ + </tr> + </thead> + <tbody>""" + for subrow in subrows: + yield """ + <tr>""" + for value in subrow: + yield """ + <td>%s</td>""" % value + yield """ + </tr>""" + yield """ + </tbody> +</table>""" + yield """ +</body> +</html> +""" + + def get_namespaces(self): + """Yield (title, scalars, collections) for each namespace.""" + s = extrapolate_statistics(logging.statistics) + for title, ns in sorted(s.items()): + scalars = [] + collections = [] + ns_fmt = self.formatting.get(title, {}) + for k, v in sorted(ns.items()): + fmt = ns_fmt.get(k, {}) + if isinstance(v, dict): + headers, subrows = self.get_dict_collection(v, fmt) + collections.append((k, ['ID'] + headers, subrows)) + elif isinstance(v, (list, tuple)): + headers, subrows = self.get_list_collection(v, fmt) + collections.append((k, headers, subrows)) + else: + format = ns_fmt.get(k, missing) + if format is None: + # Don't output this column. + continue + if hasattr(format, '__call__'): + v = format(v) + elif format is not missing: + v = format % v + scalars.append((k, v)) + yield title, scalars, collections + + def get_dict_collection(self, v, formatting): + """Return ([headers], [rows]) for the given collection.""" + # E.g., the 'Requests' dict. 
+ headers = [] + vals = six.itervalues(v) + for record in vals: + for k3 in record: + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if k3 not in headers: + headers.append(k3) + headers.sort() + + subrows = [] + for k2, record in sorted(v.items()): + subrow = [k2] + for k3 in headers: + v3 = record.get(k3, '') + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if hasattr(format, '__call__'): + v3 = format(v3) + elif format is not missing: + v3 = format % v3 + subrow.append(v3) + subrows.append(subrow) + + return headers, subrows + + def get_list_collection(self, v, formatting): + """Return ([headers], [subrows]) for the given collection.""" + # E.g., the 'Slow Queries' list. + headers = [] + for record in v: + for k3 in record: + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if k3 not in headers: + headers.append(k3) + headers.sort() + + subrows = [] + for record in v: + subrow = [] + for k3 in headers: + v3 = record.get(k3, '') + format = formatting.get(k3, missing) + if format is None: + # Don't output this column. + continue + if hasattr(format, '__call__'): + v3 = format(v3) + elif format is not missing: + v3 = format % v3 + subrow.append(v3) + subrows.append(subrow) + + return headers, subrows + + if json is not None: + @cherrypy.expose + def data(self): + s = extrapolate_statistics(logging.statistics) + cherrypy.response.headers['Content-Type'] = 'application/json' + return json.dumps(s, sort_keys=True, indent=4) + + @cherrypy.expose + def pause(self, namespace): + logging.statistics.get(namespace, {})['Enabled'] = False + raise cherrypy.HTTPRedirect('./') + pause.cp_config = {'tools.allow.on': True, + 'tools.allow.methods': ['POST']} + + @cherrypy.expose + def resume(self, namespace): + logging.statistics.get(namespace, {})['Enabled'] = True + raise cherrypy.HTTPRedirect('./') + resume.cp_config = {'tools.allow.on': True, + 'tools.allow.methods': ['POST']} diff --git a/libraries/cherrypy/lib/cptools.py b/libraries/cherrypy/lib/cptools.py new file mode 100644 index 00000000..1c079634 --- /dev/null +++ b/libraries/cherrypy/lib/cptools.py @@ -0,0 +1,640 @@ +"""Functions for builtin CherryPy tools.""" + +import logging +import re +from hashlib import md5 + +import six +from six.moves import urllib + +import cherrypy +from cherrypy._cpcompat import text_or_bytes +from cherrypy.lib import httputil as _httputil +from cherrypy.lib import is_iterator + + +# Conditional HTTP request support # + +def validate_etags(autotags=False, debug=False): + """Validate the current ETag against If-Match, If-None-Match headers. + + If autotags is True, an ETag response-header value will be provided + from an MD5 hash of the response body (unless some other code has + already provided an ETag header). If False (the default), the ETag + will not be automatic. + + WARNING: the autotags feature is not designed for URL's which allow + methods other than GET. For example, if a POST to the same URL returns + no content, the automatic ETag will be incorrect, breaking a fundamental + use for entity tags in a possibly destructive fashion. Likewise, if you + raise 304 Not Modified, the response body will be empty, the ETag hash + will be incorrect, and your application will break. + See :rfc:`2616` Section 14.24. + """ + response = cherrypy.serving.response + + # Guard against being run twice. 
+ if hasattr(response, 'ETag'): + return + + status, reason, msg = _httputil.valid_status(response.status) + + etag = response.headers.get('ETag') + + # Automatic ETag generation. See warning in docstring. + if etag: + if debug: + cherrypy.log('ETag already set: %s' % etag, 'TOOLS.ETAGS') + elif not autotags: + if debug: + cherrypy.log('Autotags off', 'TOOLS.ETAGS') + elif status != 200: + if debug: + cherrypy.log('Status not 200', 'TOOLS.ETAGS') + else: + etag = response.collapse_body() + etag = '"%s"' % md5(etag).hexdigest() + if debug: + cherrypy.log('Setting ETag: %s' % etag, 'TOOLS.ETAGS') + response.headers['ETag'] = etag + + response.ETag = etag + + # "If the request would, without the If-Match header field, result in + # anything other than a 2xx or 412 status, then the If-Match header + # MUST be ignored." + if debug: + cherrypy.log('Status: %s' % status, 'TOOLS.ETAGS') + if status >= 200 and status <= 299: + request = cherrypy.serving.request + + conditions = request.headers.elements('If-Match') or [] + conditions = [str(x) for x in conditions] + if debug: + cherrypy.log('If-Match conditions: %s' % repr(conditions), + 'TOOLS.ETAGS') + if conditions and not (conditions == ['*'] or etag in conditions): + raise cherrypy.HTTPError(412, 'If-Match failed: ETag %r did ' + 'not match %r' % (etag, conditions)) + + conditions = request.headers.elements('If-None-Match') or [] + conditions = [str(x) for x in conditions] + if debug: + cherrypy.log('If-None-Match conditions: %s' % repr(conditions), + 'TOOLS.ETAGS') + if conditions == ['*'] or etag in conditions: + if debug: + cherrypy.log('request.method: %s' % + request.method, 'TOOLS.ETAGS') + if request.method in ('GET', 'HEAD'): + raise cherrypy.HTTPRedirect([], 304) + else: + raise cherrypy.HTTPError(412, 'If-None-Match failed: ETag %r ' + 'matched %r' % (etag, conditions)) + + +def validate_since(): + """Validate the current Last-Modified against If-Modified-Since headers. + + If no code has set the Last-Modified response header, then no validation + will be performed. + """ + response = cherrypy.serving.response + lastmod = response.headers.get('Last-Modified') + if lastmod: + status, reason, msg = _httputil.valid_status(response.status) + + request = cherrypy.serving.request + + since = request.headers.get('If-Unmodified-Since') + if since and since != lastmod: + if (status >= 200 and status <= 299) or status == 412: + raise cherrypy.HTTPError(412) + + since = request.headers.get('If-Modified-Since') + if since and since == lastmod: + if (status >= 200 and status <= 299) or status == 304: + if request.method in ('GET', 'HEAD'): + raise cherrypy.HTTPRedirect([], 304) + else: + raise cherrypy.HTTPError(412) + + +# Tool code # + +def allow(methods=None, debug=False): + """Raise 405 if request.method not in methods (default ['GET', 'HEAD']). + + The given methods are case-insensitive, and may be in any order. + If only one method is allowed, you may supply a single string; + if more than one, supply a list of strings. + + Regardless of whether the current method is allowed or not, this + also emits an 'Allow' response header, containing the given methods. 
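+
+    For example, a purely illustrative config (the '/submit' path and Root
+    class are hypothetical) that accepts only POST for one path::
+
+        conf = {'/submit': {'tools.allow.on': True,
+                            'tools.allow.methods': ['POST']}}
+        cherrypy.quickstart(Root(), '/', conf)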
+ """ + if not isinstance(methods, (tuple, list)): + methods = [methods] + methods = [m.upper() for m in methods if m] + if not methods: + methods = ['GET', 'HEAD'] + elif 'GET' in methods and 'HEAD' not in methods: + methods.append('HEAD') + + cherrypy.response.headers['Allow'] = ', '.join(methods) + if cherrypy.request.method not in methods: + if debug: + cherrypy.log('request.method %r not in methods %r' % + (cherrypy.request.method, methods), 'TOOLS.ALLOW') + raise cherrypy.HTTPError(405) + else: + if debug: + cherrypy.log('request.method %r in methods %r' % + (cherrypy.request.method, methods), 'TOOLS.ALLOW') + + +def proxy(base=None, local='X-Forwarded-Host', remote='X-Forwarded-For', + scheme='X-Forwarded-Proto', debug=False): + """Change the base URL (scheme://host[:port][/path]). + + For running a CP server behind Apache, lighttpd, or other HTTP server. + + For Apache and lighttpd, you should leave the 'local' argument at the + default value of 'X-Forwarded-Host'. For Squid, you probably want to set + tools.proxy.local = 'Origin'. + + If you want the new request.base to include path info (not just the host), + you must explicitly set base to the full base path, and ALSO set 'local' + to '', so that the X-Forwarded-Host request header (which never includes + path info) does not override it. Regardless, the value for 'base' MUST + NOT end in a slash. + + cherrypy.request.remote.ip (the IP address of the client) will be + rewritten if the header specified by the 'remote' arg is valid. + By default, 'remote' is set to 'X-Forwarded-For'. If you do not + want to rewrite remote.ip, set the 'remote' arg to an empty string. + """ + + request = cherrypy.serving.request + + if scheme: + s = request.headers.get(scheme, None) + if debug: + cherrypy.log('Testing scheme %r:%r' % (scheme, s), 'TOOLS.PROXY') + if s == 'on' and 'ssl' in scheme.lower(): + # This handles e.g. webfaction's 'X-Forwarded-Ssl: on' header + scheme = 'https' + else: + # This is for lighttpd/pound/Mongrel's 'X-Forwarded-Proto: https' + scheme = s + if not scheme: + scheme = request.base[:request.base.find('://')] + + if local: + lbase = request.headers.get(local, None) + if debug: + cherrypy.log('Testing local %r:%r' % (local, lbase), 'TOOLS.PROXY') + if lbase is not None: + base = lbase.split(',')[0] + if not base: + default = urllib.parse.urlparse(request.base).netloc + base = request.headers.get('Host', default) + + if base.find('://') == -1: + # add http:// or https:// if needed + base = scheme + '://' + base + + request.base = base + + if remote: + xff = request.headers.get(remote) + if debug: + cherrypy.log('Testing remote %r:%r' % (remote, xff), 'TOOLS.PROXY') + if xff: + if remote == 'X-Forwarded-For': + # Grab the first IP in a comma-separated list. Ref #1268. + xff = next(ip.strip() for ip in xff.split(',')) + request.remote.ip = xff + + +def ignore_headers(headers=('Range',), debug=False): + """Delete request headers whose field names are included in 'headers'. + + This is a useful tool for working behind certain HTTP servers; + for example, Apache duplicates the work that CP does for 'Range' + headers, and will doubly-truncate the response. 
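+
+    For example (a sketch; the '/static' path is illustrative, and the tool
+    is assumed to be registered under its usual name)::
+
+        conf = {'/static': {'tools.ignore_headers.on': True,
+                            'tools.ignore_headers.headers': ['Range']}}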
+ """ + request = cherrypy.serving.request + for name in headers: + if name in request.headers: + if debug: + cherrypy.log('Ignoring request header %r' % name, + 'TOOLS.IGNORE_HEADERS') + del request.headers[name] + + +def response_headers(headers=None, debug=False): + """Set headers on the response.""" + if debug: + cherrypy.log('Setting response headers: %s' % repr(headers), + 'TOOLS.RESPONSE_HEADERS') + for name, value in (headers or []): + cherrypy.serving.response.headers[name] = value + + +response_headers.failsafe = True + + +def referer(pattern, accept=True, accept_missing=False, error=403, + message='Forbidden Referer header.', debug=False): + """Raise HTTPError if Referer header does/does not match the given pattern. + + pattern + A regular expression pattern to test against the Referer. + + accept + If True, the Referer must match the pattern; if False, + the Referer must NOT match the pattern. + + accept_missing + If True, permit requests with no Referer header. + + error + The HTTP error code to return to the client on failure. + + message + A string to include in the response body on failure. + + """ + try: + ref = cherrypy.serving.request.headers['Referer'] + match = bool(re.match(pattern, ref)) + if debug: + cherrypy.log('Referer %r matches %r' % (ref, pattern), + 'TOOLS.REFERER') + if accept == match: + return + except KeyError: + if debug: + cherrypy.log('No Referer header', 'TOOLS.REFERER') + if accept_missing: + return + + raise cherrypy.HTTPError(error, message) + + +class SessionAuth(object): + + """Assert that the user is logged in.""" + + session_key = 'username' + debug = False + + def check_username_and_password(self, username, password): + pass + + def anonymous(self): + """Provide a temporary user name for anonymous users.""" + pass + + def on_login(self, username): + pass + + def on_logout(self, username): + pass + + def on_check(self, username): + pass + + def login_screen(self, from_page='..', username='', error_msg='', + **kwargs): + return (six.text_type("""<html><body> +Message: %(error_msg)s +<form method="post" action="do_login"> + Login: <input type="text" name="username" value="%(username)s" size="10" /> + <br /> + Password: <input type="password" name="password" size="10" /> + <br /> + <input type="hidden" name="from_page" value="%(from_page)s" /> + <br /> + <input type="submit" /> +</form> +</body></html>""") % vars()).encode('utf-8') + + def do_login(self, username, password, from_page='..', **kwargs): + """Login. May raise redirect, or return True if request handled.""" + response = cherrypy.serving.response + error_msg = self.check_username_and_password(username, password) + if error_msg: + body = self.login_screen(from_page, username, error_msg) + response.body = body + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers['Content-Length'] + return True + else: + cherrypy.serving.request.login = username + cherrypy.session[self.session_key] = username + self.on_login(username) + raise cherrypy.HTTPRedirect(from_page or '/') + + def do_logout(self, from_page='..', **kwargs): + """Logout. May raise redirect, or return True if request handled.""" + sess = cherrypy.session + username = sess.get(self.session_key) + sess[self.session_key] = None + if username: + cherrypy.serving.request.login = None + self.on_logout(username) + raise cherrypy.HTTPRedirect(from_page) + + def do_check(self): + """Assert username. Raise redirect, or return True if request handled. 
+ """ + sess = cherrypy.session + request = cherrypy.serving.request + response = cherrypy.serving.response + + username = sess.get(self.session_key) + if not username: + sess[self.session_key] = username = self.anonymous() + self._debug_message('No session[username], trying anonymous') + if not username: + url = cherrypy.url(qs=request.query_string) + self._debug_message( + 'No username, routing to login_screen with from_page %(url)r', + locals(), + ) + response.body = self.login_screen(url) + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers['Content-Length'] + return True + self._debug_message('Setting request.login to %(username)r', locals()) + request.login = username + self.on_check(username) + + def _debug_message(self, template, context={}): + if not self.debug: + return + cherrypy.log(template % context, 'TOOLS.SESSAUTH') + + def run(self): + request = cherrypy.serving.request + response = cherrypy.serving.response + + path = request.path_info + if path.endswith('login_screen'): + self._debug_message('routing %(path)r to login_screen', locals()) + response.body = self.login_screen() + return True + elif path.endswith('do_login'): + if request.method != 'POST': + response.headers['Allow'] = 'POST' + self._debug_message('do_login requires POST') + raise cherrypy.HTTPError(405) + self._debug_message('routing %(path)r to do_login', locals()) + return self.do_login(**request.params) + elif path.endswith('do_logout'): + if request.method != 'POST': + response.headers['Allow'] = 'POST' + raise cherrypy.HTTPError(405) + self._debug_message('routing %(path)r to do_logout', locals()) + return self.do_logout(**request.params) + else: + self._debug_message('No special path, running do_check') + return self.do_check() + + +def session_auth(**kwargs): + sa = SessionAuth() + for k, v in kwargs.items(): + setattr(sa, k, v) + return sa.run() + + +session_auth.__doc__ = ( + """Session authentication hook. + + Any attribute of the SessionAuth class may be overridden via a keyword arg + to this function: + + """ + '\n'.join(['%s: %s' % (k, type(getattr(SessionAuth, k)).__name__) + for k in dir(SessionAuth) if not k.startswith('__')]) +) + + +def log_traceback(severity=logging.ERROR, debug=False): + """Write the last error's traceback to the cherrypy error log.""" + cherrypy.log('', 'HTTP', severity=severity, traceback=True) + + +def log_request_headers(debug=False): + """Write request headers to the cherrypy error log.""" + h = [' %s: %s' % (k, v) for k, v in cherrypy.serving.request.header_list] + cherrypy.log('\nRequest Headers:\n' + '\n'.join(h), 'HTTP') + + +def log_hooks(debug=False): + """Write request.hooks to the cherrypy error log.""" + request = cherrypy.serving.request + + msg = [] + # Sort by the standard points if possible. 
+ from cherrypy import _cprequest + points = _cprequest.hookpoints + for k in request.hooks.keys(): + if k not in points: + points.append(k) + + for k in points: + msg.append(' %s:' % k) + v = request.hooks.get(k, []) + v.sort() + for h in v: + msg.append(' %r' % h) + cherrypy.log('\nRequest Hooks for ' + cherrypy.url() + + ':\n' + '\n'.join(msg), 'HTTP') + + +def redirect(url='', internal=True, debug=False): + """Raise InternalRedirect or HTTPRedirect to the given url.""" + if debug: + cherrypy.log('Redirecting %sto: %s' % + ({True: 'internal ', False: ''}[internal], url), + 'TOOLS.REDIRECT') + if internal: + raise cherrypy.InternalRedirect(url) + else: + raise cherrypy.HTTPRedirect(url) + + +def trailing_slash(missing=True, extra=False, status=None, debug=False): + """Redirect if path_info has (missing|extra) trailing slash.""" + request = cherrypy.serving.request + pi = request.path_info + + if debug: + cherrypy.log('is_index: %r, missing: %r, extra: %r, path_info: %r' % + (request.is_index, missing, extra, pi), + 'TOOLS.TRAILING_SLASH') + if request.is_index is True: + if missing: + if not pi.endswith('/'): + new_url = cherrypy.url(pi + '/', request.query_string) + raise cherrypy.HTTPRedirect(new_url, status=status or 301) + elif request.is_index is False: + if extra: + # If pi == '/', don't redirect to ''! + if pi.endswith('/') and pi != '/': + new_url = cherrypy.url(pi[:-1], request.query_string) + raise cherrypy.HTTPRedirect(new_url, status=status or 301) + + +def flatten(debug=False): + """Wrap response.body in a generator that recursively iterates over body. + + This allows cherrypy.response.body to consist of 'nested generators'; + that is, a set of generators that yield generators. + """ + def flattener(input): + numchunks = 0 + for x in input: + if not is_iterator(x): + numchunks += 1 + yield x + else: + for y in flattener(x): + numchunks += 1 + yield y + if debug: + cherrypy.log('Flattened %d chunks' % numchunks, 'TOOLS.FLATTEN') + response = cherrypy.serving.response + response.body = flattener(response.body) + + +def accept(media=None, debug=False): + """Return the client's preferred media-type (from the given Content-Types). + + If 'media' is None (the default), no test will be performed. + + If 'media' is provided, it should be the Content-Type value (as a string) + or values (as a list or tuple of strings) which the current resource + can emit. The client's acceptable media ranges (as declared in the + Accept request header) will be matched in order to these Content-Type + values; the first such string is returned. That is, the return value + will always be one of the strings provided in the 'media' arg (or None + if 'media' is None). + + If no match is found, then HTTPError 406 (Not Acceptable) is raised. + Note that most web browsers send */* as a (low-quality) acceptable + media range, which should match any Content-Type. In addition, "...if + no Accept header field is present, then it is assumed that the client + accepts all media types." + + Matching types are checked in order of client preference first, + and then in the order of the given 'media' values. + + Note that this function does not honor accept-params (other than "q"). + """ + if not media: + return + if isinstance(media, text_or_bytes): + media = [media] + request = cherrypy.serving.request + + # Parse the Accept request header, and try to match one + # of the requested media-ranges (in order of preference). + ranges = request.headers.elements('Accept') + if not ranges: + # Any media type is acceptable. 
+ if debug: + cherrypy.log('No Accept header elements', 'TOOLS.ACCEPT') + return media[0] + else: + # Note that 'ranges' is sorted in order of preference + for element in ranges: + if element.qvalue > 0: + if element.value == '*/*': + # Matches any type or subtype + if debug: + cherrypy.log('Match due to */*', 'TOOLS.ACCEPT') + return media[0] + elif element.value.endswith('/*'): + # Matches any subtype + mtype = element.value[:-1] # Keep the slash + for m in media: + if m.startswith(mtype): + if debug: + cherrypy.log('Match due to %s' % element.value, + 'TOOLS.ACCEPT') + return m + else: + # Matches exact value + if element.value in media: + if debug: + cherrypy.log('Match due to %s' % element.value, + 'TOOLS.ACCEPT') + return element.value + + # No suitable media-range found. + ah = request.headers.get('Accept') + if ah is None: + msg = 'Your client did not send an Accept header.' + else: + msg = 'Your client sent this Accept header: %s.' % ah + msg += (' But this resource only emits these media types: %s.' % + ', '.join(media)) + raise cherrypy.HTTPError(406, msg) + + +class MonitoredHeaderMap(_httputil.HeaderMap): + + def transform_key(self, key): + self.accessed_headers.add(key) + return super(MonitoredHeaderMap, self).transform_key(key) + + def __init__(self): + self.accessed_headers = set() + super(MonitoredHeaderMap, self).__init__() + + +def autovary(ignore=None, debug=False): + """Auto-populate the Vary response header based on request.header access. + """ + request = cherrypy.serving.request + + req_h = request.headers + request.headers = MonitoredHeaderMap() + request.headers.update(req_h) + if ignore is None: + ignore = set(['Content-Disposition', 'Content-Length', 'Content-Type']) + + def set_response_header(): + resp_h = cherrypy.serving.response.headers + v = set([e.value for e in resp_h.elements('Vary')]) + if debug: + cherrypy.log( + 'Accessed headers: %s' % request.headers.accessed_headers, + 'TOOLS.AUTOVARY') + v = v.union(request.headers.accessed_headers) + v = v.difference(ignore) + v = list(v) + v.sort() + resp_h['Vary'] = ', '.join(v) + request.hooks.attach('before_finalize', set_response_header, 95) + + +def convert_params(exception=ValueError, error=400): + """Convert request params based on function annotations, with error handling. + + exception + Exception class to catch. + + status + The HTTP error code to return to the client on failure. + """ + request = cherrypy.serving.request + types = request.handler.callable.__annotations__ + with cherrypy.HTTPError.handle(exception, error): + for key in set(types).intersection(request.params): + request.params[key] = types[key](request.params[key]) diff --git a/libraries/cherrypy/lib/encoding.py b/libraries/cherrypy/lib/encoding.py new file mode 100644 index 00000000..3d001ca6 --- /dev/null +++ b/libraries/cherrypy/lib/encoding.py @@ -0,0 +1,436 @@ +import struct +import time +import io + +import six + +import cherrypy +from cherrypy._cpcompat import text_or_bytes +from cherrypy.lib import file_generator +from cherrypy.lib import is_closable_iterator +from cherrypy.lib import set_vary_header + + +def decode(encoding=None, default_encoding='utf-8'): + """Replace or extend the list of charsets used to decode a request entity. + + Either argument may be a single string or a list of strings. + + encoding + If not None, restricts the set of charsets attempted while decoding + a request entity to the given set (even if a different charset is + given in the Content-Type request header). 
+ + default_encoding + Only in effect if the 'encoding' argument is not given. + If given, the set of charsets attempted while decoding a request + entity is *extended* with the given value(s). + + """ + body = cherrypy.request.body + if encoding is not None: + if not isinstance(encoding, list): + encoding = [encoding] + body.attempt_charsets = encoding + elif default_encoding: + if not isinstance(default_encoding, list): + default_encoding = [default_encoding] + body.attempt_charsets = body.attempt_charsets + default_encoding + + +class UTF8StreamEncoder: + def __init__(self, iterator): + self._iterator = iterator + + def __iter__(self): + return self + + def next(self): + return self.__next__() + + def __next__(self): + res = next(self._iterator) + if isinstance(res, six.text_type): + res = res.encode('utf-8') + return res + + def close(self): + if is_closable_iterator(self._iterator): + self._iterator.close() + + def __getattr__(self, attr): + if attr.startswith('__'): + raise AttributeError(self, attr) + return getattr(self._iterator, attr) + + +class ResponseEncoder: + + default_encoding = 'utf-8' + failmsg = 'Response body could not be encoded with %r.' + encoding = None + errors = 'strict' + text_only = True + add_charset = True + debug = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + self.attempted_charsets = set() + request = cherrypy.serving.request + if request.handler is not None: + # Replace request.handler with self + if self.debug: + cherrypy.log('Replacing request.handler', 'TOOLS.ENCODE') + self.oldhandler = request.handler + request.handler = self + + def encode_stream(self, encoding): + """Encode a streaming response body. + + Use a generator wrapper, and just pray it works as the stream is + being written out. + """ + if encoding in self.attempted_charsets: + return False + self.attempted_charsets.add(encoding) + + def encoder(body): + for chunk in body: + if isinstance(chunk, six.text_type): + chunk = chunk.encode(encoding, self.errors) + yield chunk + self.body = encoder(self.body) + return True + + def encode_string(self, encoding): + """Encode a buffered response body.""" + if encoding in self.attempted_charsets: + return False + self.attempted_charsets.add(encoding) + body = [] + for chunk in self.body: + if isinstance(chunk, six.text_type): + try: + chunk = chunk.encode(encoding, self.errors) + except (LookupError, UnicodeError): + return False + body.append(chunk) + self.body = body + return True + + def find_acceptable_charset(self): + request = cherrypy.serving.request + response = cherrypy.serving.response + + if self.debug: + cherrypy.log('response.stream %r' % + response.stream, 'TOOLS.ENCODE') + if response.stream: + encoder = self.encode_stream + else: + encoder = self.encode_string + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + # Encoded strings may be of different lengths from their + # unicode equivalents, and even from each other. For example: + # >>> t = u"\u7007\u3040" + # >>> len(t) + # 2 + # >>> len(t.encode("UTF-8")) + # 6 + # >>> len(t.encode("utf7")) + # 8 + del response.headers['Content-Length'] + + # Parse the Accept-Charset request header, and try to provide one + # of the requested charsets (in order of user preference). 
+ encs = request.headers.elements('Accept-Charset') + charsets = [enc.value.lower() for enc in encs] + if self.debug: + cherrypy.log('charsets %s' % repr(charsets), 'TOOLS.ENCODE') + + if self.encoding is not None: + # If specified, force this encoding to be used, or fail. + encoding = self.encoding.lower() + if self.debug: + cherrypy.log('Specified encoding %r' % + encoding, 'TOOLS.ENCODE') + if (not charsets) or '*' in charsets or encoding in charsets: + if self.debug: + cherrypy.log('Attempting encoding %r' % + encoding, 'TOOLS.ENCODE') + if encoder(encoding): + return encoding + else: + if not encs: + if self.debug: + cherrypy.log('Attempting default encoding %r' % + self.default_encoding, 'TOOLS.ENCODE') + # Any character-set is acceptable. + if encoder(self.default_encoding): + return self.default_encoding + else: + raise cherrypy.HTTPError(500, self.failmsg % + self.default_encoding) + else: + for element in encs: + if element.qvalue > 0: + if element.value == '*': + # Matches any charset. Try our default. + if self.debug: + cherrypy.log('Attempting default encoding due ' + 'to %r' % element, 'TOOLS.ENCODE') + if encoder(self.default_encoding): + return self.default_encoding + else: + encoding = element.value + if self.debug: + cherrypy.log('Attempting encoding %s (qvalue >' + '0)' % element, 'TOOLS.ENCODE') + if encoder(encoding): + return encoding + + if '*' not in charsets: + # If no "*" is present in an Accept-Charset field, then all + # character sets not explicitly mentioned get a quality + # value of 0, except for ISO-8859-1, which gets a quality + # value of 1 if not explicitly mentioned. + iso = 'iso-8859-1' + if iso not in charsets: + if self.debug: + cherrypy.log('Attempting ISO-8859-1 encoding', + 'TOOLS.ENCODE') + if encoder(iso): + return iso + + # No suitable encoding found. + ac = request.headers.get('Accept-Charset') + if ac is None: + msg = 'Your client did not send an Accept-Charset header.' + else: + msg = 'Your client sent this Accept-Charset header: %s.' % ac + _charsets = ', '.join(sorted(self.attempted_charsets)) + msg += ' We tried these charsets: %s.' % (_charsets,) + raise cherrypy.HTTPError(406, msg) + + def __call__(self, *args, **kwargs): + response = cherrypy.serving.response + self.body = self.oldhandler(*args, **kwargs) + + self.body = prepare_iter(self.body) + + ct = response.headers.elements('Content-Type') + if self.debug: + cherrypy.log('Content-Type: %r' % [str(h) + for h in ct], 'TOOLS.ENCODE') + if ct and self.add_charset: + ct = ct[0] + if self.text_only: + if ct.value.lower().startswith('text/'): + if self.debug: + cherrypy.log( + 'Content-Type %s starts with "text/"' % ct, + 'TOOLS.ENCODE') + do_find = True + else: + if self.debug: + cherrypy.log('Not finding because Content-Type %s ' + 'does not start with "text/"' % ct, + 'TOOLS.ENCODE') + do_find = False + else: + if self.debug: + cherrypy.log('Finding because not text_only', + 'TOOLS.ENCODE') + do_find = True + + if do_find: + # Set "charset=..." param on response Content-Type header + ct.params['charset'] = self.find_acceptable_charset() + if self.debug: + cherrypy.log('Setting Content-Type %s' % ct, + 'TOOLS.ENCODE') + response.headers['Content-Type'] = str(ct) + + return self.body + + +def prepare_iter(value): + """ + Ensure response body is iterable and resolves to False when empty. + """ + if isinstance(value, text_or_bytes): + # strings get wrapped in a list because iterating over a single + # item list is much faster than iterating over every character + # in a long string. 
+ if value: + value = [value] + else: + # [''] doesn't evaluate to False, so replace it with []. + value = [] + # Don't use isinstance here; io.IOBase which has an ABC takes + # 1000 times as long as, say, isinstance(value, str) + elif hasattr(value, 'read'): + value = file_generator(value) + elif value is None: + value = [] + return value + + +# GZIP + + +def compress(body, compress_level): + """Compress 'body' at the given compress_level.""" + import zlib + + # See http://www.gzip.org/zlib/rfc-gzip.html + yield b'\x1f\x8b' # ID1 and ID2: gzip marker + yield b'\x08' # CM: compression method + yield b'\x00' # FLG: none set + # MTIME: 4 bytes + yield struct.pack('<L', int(time.time()) & int('FFFFFFFF', 16)) + yield b'\x02' # XFL: max compression, slowest algo + yield b'\xff' # OS: unknown + + crc = zlib.crc32(b'') + size = 0 + zobj = zlib.compressobj(compress_level, + zlib.DEFLATED, -zlib.MAX_WBITS, + zlib.DEF_MEM_LEVEL, 0) + for line in body: + size += len(line) + crc = zlib.crc32(line, crc) + yield zobj.compress(line) + yield zobj.flush() + + # CRC32: 4 bytes + yield struct.pack('<L', crc & int('FFFFFFFF', 16)) + # ISIZE: 4 bytes + yield struct.pack('<L', size & int('FFFFFFFF', 16)) + + +def decompress(body): + import gzip + + zbuf = io.BytesIO() + zbuf.write(body) + zbuf.seek(0) + zfile = gzip.GzipFile(mode='rb', fileobj=zbuf) + data = zfile.read() + zfile.close() + return data + + +def gzip(compress_level=5, mime_types=['text/html', 'text/plain'], + debug=False): + """Try to gzip the response body if Content-Type in mime_types. + + cherrypy.response.headers['Content-Type'] must be set to one of the + values in the mime_types arg before calling this function. + + The provided list of mime-types must be of one of the following form: + * `type/subtype` + * `type/*` + * `type/*+subtype` + + No compression is performed if any of the following hold: + * The client sends no Accept-Encoding request header + * No 'gzip' or 'x-gzip' is present in the Accept-Encoding header + * No 'gzip' or 'x-gzip' with a qvalue > 0 is present + * The 'identity' value is given with a qvalue > 0. + + """ + request = cherrypy.serving.request + response = cherrypy.serving.response + + set_vary_header(response, 'Accept-Encoding') + + if not response.body: + # Response body is empty (might be a 304 for instance) + if debug: + cherrypy.log('No response body', context='TOOLS.GZIP') + return + + # If returning cached content (which should already have been gzipped), + # don't re-zip. + if getattr(request, 'cached', False): + if debug: + cherrypy.log('Not gzipping cached response', context='TOOLS.GZIP') + return + + acceptable = request.headers.elements('Accept-Encoding') + if not acceptable: + # If no Accept-Encoding field is present in a request, + # the server MAY assume that the client will accept any + # content coding. In this case, if "identity" is one of + # the available content-codings, then the server SHOULD use + # the "identity" content-coding, unless it has additional + # information that a different content-coding is meaningful + # to the client. 
+ if debug: + cherrypy.log('No Accept-Encoding', context='TOOLS.GZIP') + return + + ct = response.headers.get('Content-Type', '').split(';')[0] + for coding in acceptable: + if coding.value == 'identity' and coding.qvalue != 0: + if debug: + cherrypy.log('Non-zero identity qvalue: %s' % coding, + context='TOOLS.GZIP') + return + if coding.value in ('gzip', 'x-gzip'): + if coding.qvalue == 0: + if debug: + cherrypy.log('Zero gzip qvalue: %s' % coding, + context='TOOLS.GZIP') + return + + if ct not in mime_types: + # If the list of provided mime-types contains tokens + # such as 'text/*' or 'application/*+xml', + # we go through them and find the most appropriate one + # based on the given content-type. + # The pattern matching is only caring about the most + # common cases, as stated above, and doesn't support + # for extra parameters. + found = False + if '/' in ct: + ct_media_type, ct_sub_type = ct.split('/') + for mime_type in mime_types: + if '/' in mime_type: + media_type, sub_type = mime_type.split('/') + if ct_media_type == media_type: + if sub_type == '*': + found = True + break + elif '+' in sub_type and '+' in ct_sub_type: + ct_left, ct_right = ct_sub_type.split('+') + left, right = sub_type.split('+') + if left == '*' and ct_right == right: + found = True + break + + if not found: + if debug: + cherrypy.log('Content-Type %s not in mime_types %r' % + (ct, mime_types), context='TOOLS.GZIP') + return + + if debug: + cherrypy.log('Gzipping', context='TOOLS.GZIP') + # Return a generator that compresses the page + response.headers['Content-Encoding'] = 'gzip' + response.body = compress(response.body, compress_level) + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers['Content-Length'] + + return + + if debug: + cherrypy.log('No acceptable encoding found.', context='GZIP') + cherrypy.HTTPError(406, 'identity, gzip').set_response() diff --git a/libraries/cherrypy/lib/gctools.py b/libraries/cherrypy/lib/gctools.py new file mode 100644 index 00000000..26746d78 --- /dev/null +++ b/libraries/cherrypy/lib/gctools.py @@ -0,0 +1,218 @@ +import gc +import inspect +import sys +import time + +try: + import objgraph +except ImportError: + objgraph = None + +import cherrypy +from cherrypy import _cprequest, _cpwsgi +from cherrypy.process.plugins import SimplePlugin + + +class ReferrerTree(object): + + """An object which gathers all referrers of an object to a given depth.""" + + peek_length = 40 + + def __init__(self, ignore=None, maxdepth=2, maxparents=10): + self.ignore = ignore or [] + self.ignore.append(inspect.currentframe().f_back) + self.maxdepth = maxdepth + self.maxparents = maxparents + + def ascend(self, obj, depth=1): + """Return a nested list containing referrers of the given object.""" + depth += 1 + parents = [] + + # Gather all referrers in one step to minimize + # cascading references due to repr() logic. 
+ refs = gc.get_referrers(obj) + self.ignore.append(refs) + if len(refs) > self.maxparents: + return [('[%s referrers]' % len(refs), [])] + + try: + ascendcode = self.ascend.__code__ + except AttributeError: + ascendcode = self.ascend.im_func.func_code + for parent in refs: + if inspect.isframe(parent) and parent.f_code is ascendcode: + continue + if parent in self.ignore: + continue + if depth <= self.maxdepth: + parents.append((parent, self.ascend(parent, depth))) + else: + parents.append((parent, [])) + + return parents + + def peek(self, s): + """Return s, restricted to a sane length.""" + if len(s) > (self.peek_length + 3): + half = self.peek_length // 2 + return s[:half] + '...' + s[-half:] + else: + return s + + def _format(self, obj, descend=True): + """Return a string representation of a single object.""" + if inspect.isframe(obj): + filename, lineno, func, context, index = inspect.getframeinfo(obj) + return "<frame of function '%s'>" % func + + if not descend: + return self.peek(repr(obj)) + + if isinstance(obj, dict): + return '{' + ', '.join(['%s: %s' % (self._format(k, descend=False), + self._format(v, descend=False)) + for k, v in obj.items()]) + '}' + elif isinstance(obj, list): + return '[' + ', '.join([self._format(item, descend=False) + for item in obj]) + ']' + elif isinstance(obj, tuple): + return '(' + ', '.join([self._format(item, descend=False) + for item in obj]) + ')' + + r = self.peek(repr(obj)) + if isinstance(obj, (str, int, float)): + return r + return '%s: %s' % (type(obj), r) + + def format(self, tree): + """Return a list of string reprs from a nested list of referrers.""" + output = [] + + def ascend(branch, depth=1): + for parent, grandparents in branch: + output.append((' ' * depth) + self._format(parent)) + if grandparents: + ascend(grandparents, depth + 1) + ascend(tree) + return output + + +def get_instances(cls): + return [x for x in gc.get_objects() if isinstance(x, cls)] + + +class RequestCounter(SimplePlugin): + + def start(self): + self.count = 0 + + def before_request(self): + self.count += 1 + + def after_request(self): + self.count -= 1 + + +request_counter = RequestCounter(cherrypy.engine) +request_counter.subscribe() + + +def get_context(obj): + if isinstance(obj, _cprequest.Request): + return 'path=%s;stage=%s' % (obj.path_info, obj.stage) + elif isinstance(obj, _cprequest.Response): + return 'status=%s' % obj.status + elif isinstance(obj, _cpwsgi.AppResponse): + return 'PATH_INFO=%s' % obj.environ.get('PATH_INFO', '') + elif hasattr(obj, 'tb_lineno'): + return 'tb_lineno=%s' % obj.tb_lineno + return '' + + +class GCRoot(object): + + """A CherryPy page handler for testing reference leaks.""" + + classes = [ + (_cprequest.Request, 2, 2, + 'Should be 1 in this request thread and 1 in the main thread.'), + (_cprequest.Response, 2, 2, + 'Should be 1 in this request thread and 1 in the main thread.'), + (_cpwsgi.AppResponse, 1, 1, + 'Should be 1 in this request thread only.'), + ] + + @cherrypy.expose + def index(self): + return 'Hello, world!' + + @cherrypy.expose + def stats(self): + output = ['Statistics:'] + + for trial in range(10): + if request_counter.count > 0: + break + time.sleep(0.5) + else: + output.append('\nNot all requests closed properly.') + + # gc_collect isn't perfectly synchronous, because it may + # break reference cycles that then take time to fully + # finalize. Call it thrice and hope for the best. 
+ gc.collect() + gc.collect() + unreachable = gc.collect() + if unreachable: + if objgraph is not None: + final = objgraph.by_type('Nondestructible') + if final: + objgraph.show_backrefs(final, filename='finalizers.png') + + trash = {} + for x in gc.garbage: + trash[type(x)] = trash.get(type(x), 0) + 1 + if trash: + output.insert(0, '\n%s unreachable objects:' % unreachable) + trash = [(v, k) for k, v in trash.items()] + trash.sort() + for pair in trash: + output.append(' ' + repr(pair)) + + # Check declared classes to verify uncollected instances. + # These don't have to be part of a cycle; they can be + # any objects that have unanticipated referrers that keep + # them from being collected. + allobjs = {} + for cls, minobj, maxobj, msg in self.classes: + allobjs[cls] = get_instances(cls) + + for cls, minobj, maxobj, msg in self.classes: + objs = allobjs[cls] + lenobj = len(objs) + if lenobj < minobj or lenobj > maxobj: + if minobj == maxobj: + output.append( + '\nExpected %s %r references, got %s.' % + (minobj, cls, lenobj)) + else: + output.append( + '\nExpected %s to %s %r references, got %s.' % + (minobj, maxobj, cls, lenobj)) + + for obj in objs: + if objgraph is not None: + ig = [id(objs), id(inspect.currentframe())] + fname = 'graph_%s_%s.png' % (cls.__name__, id(obj)) + objgraph.show_backrefs( + obj, extra_ignore=ig, max_depth=4, too_many=20, + filename=fname, extra_info=get_context) + output.append('\nReferrers for %s (refcount=%s):' % + (repr(obj), sys.getrefcount(obj))) + t = ReferrerTree(ignore=[objs], maxdepth=3) + tree = t.ascend(obj) + output.extend(t.format(tree)) + + return '\n'.join(output) diff --git a/libraries/cherrypy/lib/httputil.py b/libraries/cherrypy/lib/httputil.py new file mode 100644 index 00000000..b68d8dd5 --- /dev/null +++ b/libraries/cherrypy/lib/httputil.py @@ -0,0 +1,581 @@ +"""HTTP library functions. + +This module contains functions for building an HTTP application +framework: any one, not just one whose name starts with "Ch". ;) If you +reference any modules from some popular framework inside *this* module, +FuManChu will personally hang you up by your thumbs and submit you +to a public caning. +""" + +import functools +import email.utils +import re +from binascii import b2a_base64 +from cgi import parse_header +from email.header import decode_header + +import six +from six.moves import range, builtins, map +from six.moves.BaseHTTPServer import BaseHTTPRequestHandler + +import cherrypy +from cherrypy._cpcompat import ntob, ntou +from cherrypy._cpcompat import unquote_plus + +response_codes = BaseHTTPRequestHandler.responses.copy() + +# From https://github.com/cherrypy/cherrypy/issues/361 +response_codes[500] = ('Internal Server Error', + 'The server encountered an unexpected condition ' + 'which prevented it from fulfilling the request.') +response_codes[503] = ('Service Unavailable', + 'The server is currently unable to handle the ' + 'request due to a temporary overloading or ' + 'maintenance of the server.') + + +HTTPDate = functools.partial(email.utils.formatdate, usegmt=True) + + +def urljoin(*atoms): + r"""Return the given path \*atoms, joined into a single URL. + + This will correctly join a SCRIPT_NAME and PATH_INFO into the + original URL, even if either atom is blank. + """ + url = '/'.join([x for x in atoms if x]) + while '//' in url: + url = url.replace('//', '/') + # Special-case the final url of "", and return "/" instead. + return url or '/' + + +def urljoin_bytes(*atoms): + """Return the given path `*atoms`, joined into a single URL. 
+ + This will correctly join a SCRIPT_NAME and PATH_INFO into the + original URL, even if either atom is blank. + """ + url = b'/'.join([x for x in atoms if x]) + while b'//' in url: + url = url.replace(b'//', b'/') + # Special-case the final url of "", and return "/" instead. + return url or b'/' + + +def protocol_from_http(protocol_str): + """Return a protocol tuple from the given 'HTTP/x.y' string.""" + return int(protocol_str[5]), int(protocol_str[7]) + + +def get_ranges(headervalue, content_length): + """Return a list of (start, stop) indices from a Range header, or None. + + Each (start, stop) tuple will be composed of two ints, which are suitable + for use in a slicing operation. That is, the header "Range: bytes=3-6", + if applied against a Python string, is requesting resource[3:7]. This + function will return the list [(3, 7)]. + + If this function returns an empty list, you should return HTTP 416. + """ + + if not headervalue: + return None + + result = [] + bytesunit, byteranges = headervalue.split('=', 1) + for brange in byteranges.split(','): + start, stop = [x.strip() for x in brange.split('-', 1)] + if start: + if not stop: + stop = content_length - 1 + start, stop = int(start), int(stop) + if start >= content_length: + # From rfc 2616 sec 14.16: + # "If the server receives a request (other than one + # including an If-Range request-header field) with an + # unsatisfiable Range request-header field (that is, + # all of whose byte-range-spec values have a first-byte-pos + # value greater than the current length of the selected + # resource), it SHOULD return a response code of 416 + # (Requested range not satisfiable)." + continue + if stop < start: + # From rfc 2616 sec 14.16: + # "If the server ignores a byte-range-spec because it + # is syntactically invalid, the server SHOULD treat + # the request as if the invalid Range header field + # did not exist. (Normally, this means return a 200 + # response containing the full entity)." + return None + result.append((start, stop + 1)) + else: + if not stop: + # See rfc quote above. + return None + # Negative subscript (last N bytes) + # + # RFC 2616 Section 14.35.1: + # If the entity is shorter than the specified suffix-length, + # the entire entity-body is used. 
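+            # For example, "Range: bytes=-500" against a 1000-byte entity
+            # yields (500, 1000), while "bytes=-2000" is clamped to the whole
+            # entity, (0, 1000).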
+ if int(stop) > content_length: + result.append((0, content_length)) + else: + result.append((content_length - int(stop), content_length)) + + return result + + +class HeaderElement(object): + + """An element (with parameters) from an HTTP header's element list.""" + + def __init__(self, value, params=None): + self.value = value + if params is None: + params = {} + self.params = params + + def __cmp__(self, other): + return builtins.cmp(self.value, other.value) + + def __lt__(self, other): + return self.value < other.value + + def __str__(self): + p = [';%s=%s' % (k, v) for k, v in six.iteritems(self.params)] + return str('%s%s' % (self.value, ''.join(p))) + + def __bytes__(self): + return ntob(self.__str__()) + + def __unicode__(self): + return ntou(self.__str__()) + + @staticmethod + def parse(elementstr): + """Transform 'token;key=val' to ('token', {'key': 'val'}).""" + initial_value, params = parse_header(elementstr) + return initial_value, params + + @classmethod + def from_str(cls, elementstr): + """Construct an instance from a string of the form 'token;key=val'.""" + ival, params = cls.parse(elementstr) + return cls(ival, params) + + +q_separator = re.compile(r'; *q *=') + + +class AcceptElement(HeaderElement): + + """An element (with parameters) from an Accept* header's element list. + + AcceptElement objects are comparable; the more-preferred object will be + "less than" the less-preferred object. They are also therefore sortable; + if you sort a list of AcceptElement objects, they will be listed in + priority order; the most preferred value will be first. Yes, it should + have been the other way around, but it's too late to fix now. + """ + + @classmethod + def from_str(cls, elementstr): + qvalue = None + # The first "q" parameter (if any) separates the initial + # media-range parameter(s) (if any) from the accept-params. + atoms = q_separator.split(elementstr, 1) + media_range = atoms.pop(0).strip() + if atoms: + # The qvalue for an Accept header can have extensions. The other + # headers cannot, but it's easier to parse them as if they did. + qvalue = HeaderElement.from_str(atoms[0].strip()) + + media_type, params = cls.parse(media_range) + if qvalue is not None: + params['q'] = qvalue + return cls(media_type, params) + + @property + def qvalue(self): + 'The qvalue, or priority, of this value.' + val = self.params.get('q', '1') + if isinstance(val, HeaderElement): + val = val.value + try: + return float(val) + except ValueError as val_err: + """Fail client requests with invalid quality value. + + Ref: https://github.com/cherrypy/cherrypy/issues/1370 + """ + six.raise_from( + cherrypy.HTTPError( + 400, + 'Malformed HTTP header: `{}`'. + format(str(self)), + ), + val_err, + ) + + def __cmp__(self, other): + diff = builtins.cmp(self.qvalue, other.qvalue) + if diff == 0: + diff = builtins.cmp(str(self), str(other)) + return diff + + def __lt__(self, other): + if self.qvalue == other.qvalue: + return str(self) < str(other) + else: + return self.qvalue < other.qvalue + + +RE_HEADER_SPLIT = re.compile(',(?=(?:[^"]*"[^"]*")*[^"]*$)') + + +def header_elements(fieldname, fieldvalue): + """Return a sorted HeaderElement list from a comma-separated header string. 
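+
+    For example, the most-preferred element comes first:
+
+    >>> [e.value for e in header_elements(
+    ...     'Accept', 'text/html;q=0.5, application/json')]
+    ['application/json', 'text/html']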
+ """ + if not fieldvalue: + return [] + + result = [] + for element in RE_HEADER_SPLIT.split(fieldvalue): + if fieldname.startswith('Accept') or fieldname == 'TE': + hv = AcceptElement.from_str(element) + else: + hv = HeaderElement.from_str(element) + result.append(hv) + + return list(reversed(sorted(result))) + + +def decode_TEXT(value): + r""" + Decode :rfc:`2047` TEXT + + >>> decode_TEXT("=?utf-8?q?f=C3=BCr?=") == b'f\xfcr'.decode('latin-1') + True + """ + atoms = decode_header(value) + decodedvalue = '' + for atom, charset in atoms: + if charset is not None: + atom = atom.decode(charset) + decodedvalue += atom + return decodedvalue + + +def decode_TEXT_maybe(value): + """ + Decode the text but only if '=?' appears in it. + """ + return decode_TEXT(value) if '=?' in value else value + + +def valid_status(status): + """Return legal HTTP status Code, Reason-phrase and Message. + + The status arg must be an int, a str that begins with an int + or the constant from ``http.client`` stdlib module. + + If status has no reason-phrase is supplied, a default reason- + phrase will be provided. + + >>> from six.moves import http_client + >>> from six.moves.BaseHTTPServer import BaseHTTPRequestHandler + >>> valid_status(http_client.ACCEPTED) == ( + ... int(http_client.ACCEPTED), + ... ) + BaseHTTPRequestHandler.responses[http_client.ACCEPTED] + True + """ + + if not status: + status = 200 + + code, reason = status, None + if isinstance(status, six.string_types): + code, _, reason = status.partition(' ') + reason = reason.strip() or None + + try: + code = int(code) + except (TypeError, ValueError): + raise ValueError('Illegal response status from server ' + '(%s is non-numeric).' % repr(code)) + + if code < 100 or code > 599: + raise ValueError('Illegal response status from server ' + '(%s is out of range).' % repr(code)) + + if code not in response_codes: + # code is unknown but not illegal + default_reason, message = '', '' + else: + default_reason, message = response_codes[code] + + if reason is None: + reason = default_reason + + return code, reason, message + + +# NOTE: the parse_qs functions that follow are modified version of those +# in the python3.0 source - we need to pass through an encoding to the unquote +# method, but the default parse_qs function doesn't allow us to. These do. + +def _parse_qs(qs, keep_blank_values=0, strict_parsing=0, encoding='utf-8'): + """Parse a query given as a string argument. + + Arguments: + + qs: URL-encoded query string to be parsed + + keep_blank_values: flag indicating whether blank values in + URL encoded queries should be treated as blank strings. A + true value indicates that blanks should be retained as blank + strings. The default false value indicates that blank values + are to be ignored and treated as if they were not included. + + strict_parsing: flag indicating what to do with parsing errors. If + false (the default), errors are silently ignored. If true, + errors raise a ValueError exception. + + Returns a dict, as G-d intended. 
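+
+    For example, repeated keys collapse into a list:
+
+    >>> _parse_qs('a=1&a=2&b=3') == {'a': ['1', '2'], 'b': '3'}
+    True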
+ """ + pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')] + d = {} + for name_value in pairs: + if not name_value and not strict_parsing: + continue + nv = name_value.split('=', 1) + if len(nv) != 2: + if strict_parsing: + raise ValueError('bad query field: %r' % (name_value,)) + # Handle case of a control-name with no equal sign + if keep_blank_values: + nv.append('') + else: + continue + if len(nv[1]) or keep_blank_values: + name = unquote_plus(nv[0], encoding, errors='strict') + value = unquote_plus(nv[1], encoding, errors='strict') + if name in d: + if not isinstance(d[name], list): + d[name] = [d[name]] + d[name].append(value) + else: + d[name] = value + return d + + +image_map_pattern = re.compile(r'[0-9]+,[0-9]+') + + +def parse_query_string(query_string, keep_blank_values=True, encoding='utf-8'): + """Build a params dictionary from a query_string. + + Duplicate key/value pairs in the provided query_string will be + returned as {'key': [val1, val2, ...]}. Single key/values will + be returned as strings: {'key': 'value'}. + """ + if image_map_pattern.match(query_string): + # Server-side image map. Map the coords to 'x' and 'y' + # (like CGI::Request does). + pm = query_string.split(',') + pm = {'x': int(pm[0]), 'y': int(pm[1])} + else: + pm = _parse_qs(query_string, keep_blank_values, encoding=encoding) + return pm + + +#### +# Inlined from jaraco.collections 1.5.2 +# Ref #1673 +class KeyTransformingDict(dict): + """ + A dict subclass that transforms the keys before they're used. + Subclasses may override the default transform_key to customize behavior. + """ + @staticmethod + def transform_key(key): + return key + + def __init__(self, *args, **kargs): + super(KeyTransformingDict, self).__init__() + # build a dictionary using the default constructs + d = dict(*args, **kargs) + # build this dictionary using transformed keys. + for item in d.items(): + self.__setitem__(*item) + + def __setitem__(self, key, val): + key = self.transform_key(key) + super(KeyTransformingDict, self).__setitem__(key, val) + + def __getitem__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__getitem__(key) + + def __contains__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__contains__(key) + + def __delitem__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__delitem__(key) + + def get(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).get(key, *args, **kwargs) + + def setdefault(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).setdefault( + key, *args, **kwargs) + + def pop(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).pop(key, *args, **kwargs) + + def matching_key_for(self, key): + """ + Given a key, return the actual key stored in self that matches. + Raise KeyError if the key isn't found. + """ + try: + return next(e_key for e_key in self.keys() if e_key == key) + except StopIteration: + raise KeyError(key) +#### + + +class CaseInsensitiveDict(KeyTransformingDict): + + """A case-insensitive dict subclass. + + Each key is changed on entry to str(key).title(). + """ + + @staticmethod + def transform_key(key): + return str(key).title() + + +# TEXT = <any OCTET except CTLs, but including LWS> +# +# A CRLF is allowed in the definition of TEXT only as part of a header +# field continuation. 
It is expected that the folding LWS will be +# replaced with a single SP before interpretation of the TEXT value." +if str == bytes: + header_translate_table = ''.join([chr(i) for i in range(256)]) + header_translate_deletechars = ''.join( + [chr(i) for i in range(32)]) + chr(127) +else: + header_translate_table = None + header_translate_deletechars = bytes(range(32)) + bytes([127]) + + +class HeaderMap(CaseInsensitiveDict): + + """A dict subclass for HTTP request and response headers. + + Each key is changed on entry to str(key).title(). This allows headers + to be case-insensitive and avoid duplicates. + + Values are header values (decoded according to :rfc:`2047` if necessary). + """ + + protocol = (1, 1) + encodings = ['ISO-8859-1'] + + # Someday, when http-bis is done, this will probably get dropped + # since few servers, clients, or intermediaries do it. But until then, + # we're going to obey the spec as is. + # "Words of *TEXT MAY contain characters from character sets other than + # ISO-8859-1 only when encoded according to the rules of RFC 2047." + use_rfc_2047 = True + + def elements(self, key): + """Return a sorted list of HeaderElements for the given header.""" + key = str(key).title() + value = self.get(key) + return header_elements(key, value) + + def values(self, key): + """Return a sorted list of HeaderElement.value for the given header.""" + return [e.value for e in self.elements(key)] + + def output(self): + """Transform self into a list of (name, value) tuples.""" + return list(self.encode_header_items(self.items())) + + @classmethod + def encode_header_items(cls, header_items): + """ + Prepare the sequence of name, value tuples into a form suitable for + transmitting on the wire for HTTP. + """ + for k, v in header_items: + if not isinstance(v, six.string_types): + v = six.text_type(v) + + yield tuple(map(cls.encode_header_item, (k, v))) + + @classmethod + def encode_header_item(cls, item): + if isinstance(item, six.text_type): + item = cls.encode(item) + + # See header_translate_* constants above. + # Replace only if you really know what you're doing. + return item.translate( + header_translate_table, header_translate_deletechars) + + @classmethod + def encode(cls, v): + """Return the given header name or value, encoded for HTTP output.""" + for enc in cls.encodings: + try: + return v.encode(enc) + except UnicodeEncodeError: + continue + + if cls.protocol == (1, 1) and cls.use_rfc_2047: + # Encode RFC-2047 TEXT + # (e.g. u"\u8200" -> "=?utf-8?b?6IiA?="). + # We do our own here instead of using the email module + # because we never want to fold lines--folding has + # been deprecated by the HTTP working group. + v = b2a_base64(v.encode('utf-8')) + return (b'=?utf-8?b?' + v.strip(b'\n') + b'?=') + + raise ValueError('Could not encode header part %r using ' + 'any of the encodings %r.' % + (v, cls.encodings)) + + +class Host(object): + + """An internet address. + + name + Should be the client's host name. If not available (because no DNS + lookup is performed), the IP address should be used instead. 
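+
+    For example, when no name is given, the IP address is reused:
+
+    >>> Host('192.0.2.1', 8080)
+    httputil.Host('192.0.2.1', 8080, '192.0.2.1')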
+ + """ + + ip = '0.0.0.0' + port = 80 + name = 'unknown.tld' + + def __init__(self, ip, port, name=None): + self.ip = ip + self.port = port + if name is None: + name = ip + self.name = name + + def __repr__(self): + return 'httputil.Host(%r, %r, %r)' % (self.ip, self.port, self.name) diff --git a/libraries/cherrypy/lib/jsontools.py b/libraries/cherrypy/lib/jsontools.py new file mode 100644 index 00000000..48683097 --- /dev/null +++ b/libraries/cherrypy/lib/jsontools.py @@ -0,0 +1,88 @@ +import cherrypy +from cherrypy._cpcompat import text_or_bytes, ntou, json_encode, json_decode + + +def json_processor(entity): + """Read application/json data into request.json.""" + if not entity.headers.get(ntou('Content-Length'), ntou('')): + raise cherrypy.HTTPError(411) + + body = entity.fp.read() + with cherrypy.HTTPError.handle(ValueError, 400, 'Invalid JSON document'): + cherrypy.serving.request.json = json_decode(body.decode('utf-8')) + + +def json_in(content_type=[ntou('application/json'), ntou('text/javascript')], + force=True, debug=False, processor=json_processor): + """Add a processor to parse JSON request entities: + The default processor places the parsed data into request.json. + + Incoming request entities which match the given content_type(s) will + be deserialized from JSON to the Python equivalent, and the result + stored at cherrypy.request.json. The 'content_type' argument may + be a Content-Type string or a list of allowable Content-Type strings. + + If the 'force' argument is True (the default), then entities of other + content types will not be allowed; "415 Unsupported Media Type" is + raised instead. + + Supply your own processor to use a custom decoder, or to handle the parsed + data differently. The processor can be configured via + tools.json_in.processor or via the decorator method. + + Note that the deserializer requires the client send a Content-Length + request header, or it will raise "411 Length Required". If for any + other reason the request entity cannot be deserialized from JSON, + it will raise "400 Bad Request: Invalid JSON document". + """ + request = cherrypy.serving.request + if isinstance(content_type, text_or_bytes): + content_type = [content_type] + + if force: + if debug: + cherrypy.log('Removing body processors %s' % + repr(request.body.processors.keys()), 'TOOLS.JSON_IN') + request.body.processors.clear() + request.body.default_proc = cherrypy.HTTPError( + 415, 'Expected an entity of content type %s' % + ', '.join(content_type)) + + for ct in content_type: + if debug: + cherrypy.log('Adding body processor for %s' % ct, 'TOOLS.JSON_IN') + request.body.processors[ct] = processor + + +def json_handler(*args, **kwargs): + value = cherrypy.serving.request._json_inner_handler(*args, **kwargs) + return json_encode(value) + + +def json_out(content_type='application/json', debug=False, + handler=json_handler): + """Wrap request.handler to serialize its output to JSON. Sets Content-Type. + + If the given content_type is None, the Content-Type response header + is not set. + + Provide your own handler to use a custom encoder. For example + cherrypy.config['tools.json_out.handler'] = <function>, or + @json_out(handler=function). + """ + request = cherrypy.serving.request + # request.handler may be set to None by e.g. the caching tool + # to signal to all components that a response body has already + # been attached, in which case we don't need to wrap anything. 
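+    # (For example, on a cache hit the caching tool delivers the cached
+    # response itself and sets request.handler to None, so there is nothing
+    # left to wrap here.)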
+ if request.handler is None: + return + if debug: + cherrypy.log('Replacing %s with JSON handler' % request.handler, + 'TOOLS.JSON_OUT') + request._json_inner_handler = request.handler + request.handler = handler + if content_type is not None: + if debug: + cherrypy.log('Setting Content-Type to %s' % + content_type, 'TOOLS.JSON_OUT') + cherrypy.serving.response.headers['Content-Type'] = content_type diff --git a/libraries/cherrypy/lib/locking.py b/libraries/cherrypy/lib/locking.py new file mode 100644 index 00000000..317fb58c --- /dev/null +++ b/libraries/cherrypy/lib/locking.py @@ -0,0 +1,47 @@ +import datetime + + +class NeverExpires(object): + def expired(self): + return False + + +class Timer(object): + """ + A simple timer that will indicate when an expiration time has passed. + """ + def __init__(self, expiration): + 'Create a timer that expires at `expiration` (UTC datetime)' + self.expiration = expiration + + @classmethod + def after(cls, elapsed): + """ + Return a timer that will expire after `elapsed` passes. + """ + return cls(datetime.datetime.utcnow() + elapsed) + + def expired(self): + return datetime.datetime.utcnow() >= self.expiration + + +class LockTimeout(Exception): + 'An exception when a lock could not be acquired before a timeout period' + + +class LockChecker(object): + """ + Keep track of the time and detect if a timeout has expired + """ + def __init__(self, session_id, timeout): + self.session_id = session_id + if timeout: + self.timer = Timer.after(timeout) + else: + self.timer = NeverExpires() + + def expired(self): + if self.timer.expired(): + raise LockTimeout( + 'Timeout acquiring lock for %(session_id)s' % vars(self)) + return False diff --git a/libraries/cherrypy/lib/profiler.py b/libraries/cherrypy/lib/profiler.py new file mode 100644 index 00000000..fccf2eb8 --- /dev/null +++ b/libraries/cherrypy/lib/profiler.py @@ -0,0 +1,221 @@ +"""Profiler tools for CherryPy. + +CherryPy users +============== + +You can profile any of your pages as follows:: + + from cherrypy.lib import profiler + + class Root: + p = profiler.Profiler("/path/to/profile/dir") + + @cherrypy.expose + def index(self): + self.p.run(self._index) + + def _index(self): + return "Hello, world!" + + cherrypy.tree.mount(Root()) + +You can also turn on profiling for all requests +using the ``make_app`` function as WSGI middleware. + +CherryPy developers +=================== + +This module can be used whenever you make changes to CherryPy, +to get a quick sanity-check on overall CP performance. Use the +``--profile`` flag when running the test suite. Then, use the ``serve()`` +function to browse the results in a web browser. If you run this +module from the command line, it will call ``serve()`` for you. 
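+
+To profile every request of an existing WSGI application, a sketch along these
+lines can be used (``wsgiapp`` stands in for your own WSGI app)::
+
+    from cherrypy.lib import profiler
+
+    app = profiler.make_app(wsgiapp, path='/tmp/profile', aggregate=False)
+
+The resulting ``cp_*.prof`` files can then be browsed with ``serve()``.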
+ +""" + +import io +import os +import os.path +import sys +import warnings + +import cherrypy + + +try: + import profile + import pstats + + def new_func_strip_path(func_name): + """Make profiler output more readable by adding `__init__` modules' parents + """ + filename, line, name = func_name + if filename.endswith('__init__.py'): + return ( + os.path.basename(filename[:-12]) + filename[-12:], + line, + name, + ) + return os.path.basename(filename), line, name + + pstats.func_strip_path = new_func_strip_path +except ImportError: + profile = None + pstats = None + + +_count = 0 + + +class Profiler(object): + + def __init__(self, path=None): + if not path: + path = os.path.join(os.path.dirname(__file__), 'profile') + self.path = path + if not os.path.exists(path): + os.makedirs(path) + + def run(self, func, *args, **params): + """Dump profile data into self.path.""" + global _count + c = _count = _count + 1 + path = os.path.join(self.path, 'cp_%04d.prof' % c) + prof = profile.Profile() + result = prof.runcall(func, *args, **params) + prof.dump_stats(path) + return result + + def statfiles(self): + """:rtype: list of available profiles. + """ + return [f for f in os.listdir(self.path) + if f.startswith('cp_') and f.endswith('.prof')] + + def stats(self, filename, sortby='cumulative'): + """:rtype stats(index): output of print_stats() for the given profile. + """ + sio = io.StringIO() + if sys.version_info >= (2, 5): + s = pstats.Stats(os.path.join(self.path, filename), stream=sio) + s.strip_dirs() + s.sort_stats(sortby) + s.print_stats() + else: + # pstats.Stats before Python 2.5 didn't take a 'stream' arg, + # but just printed to stdout. So re-route stdout. + s = pstats.Stats(os.path.join(self.path, filename)) + s.strip_dirs() + s.sort_stats(sortby) + oldout = sys.stdout + try: + sys.stdout = sio + s.print_stats() + finally: + sys.stdout = oldout + response = sio.getvalue() + sio.close() + return response + + @cherrypy.expose + def index(self): + return """<html> + <head><title>CherryPy profile data</title></head> + <frameset cols='200, 1*'> + <frame src='menu' /> + <frame name='main' src='' /> + </frameset> + </html> + """ + + @cherrypy.expose + def menu(self): + yield '<h2>Profiling runs</h2>' + yield '<p>Click on one of the runs below to see profiling data.</p>' + runs = self.statfiles() + runs.sort() + for i in runs: + yield "<a href='report?filename=%s' target='main'>%s</a><br />" % ( + i, i) + + @cherrypy.expose + def report(self, filename): + cherrypy.response.headers['Content-Type'] = 'text/plain' + return self.stats(filename) + + +class ProfileAggregator(Profiler): + + def __init__(self, path=None): + Profiler.__init__(self, path) + global _count + self.count = _count = _count + 1 + self.profiler = profile.Profile() + + def run(self, func, *args, **params): + path = os.path.join(self.path, 'cp_%04d.prof' % self.count) + result = self.profiler.runcall(func, *args, **params) + self.profiler.dump_stats(path) + return result + + +class make_app: + + def __init__(self, nextapp, path=None, aggregate=False): + """Make a WSGI middleware app which wraps 'nextapp' with profiling. + + nextapp + the WSGI application to wrap, usually an instance of + cherrypy.Application. + + path + where to dump the profiling output. + + aggregate + if True, profile data for all HTTP requests will go in + a single file. If False (the default), each HTTP request will + dump its profile data into a separate file. 
+ + """ + if profile is None or pstats is None: + msg = ('Your installation of Python does not have a profile ' + "module. If you're on Debian, try " + '`sudo apt-get install python-profiler`. ' + 'See http://www.cherrypy.org/wiki/ProfilingOnDebian ' + 'for details.') + warnings.warn(msg) + + self.nextapp = nextapp + self.aggregate = aggregate + if aggregate: + self.profiler = ProfileAggregator(path) + else: + self.profiler = Profiler(path) + + def __call__(self, environ, start_response): + def gather(): + result = [] + for line in self.nextapp(environ, start_response): + result.append(line) + return result + return self.profiler.run(gather) + + +def serve(path=None, port=8080): + if profile is None or pstats is None: + msg = ('Your installation of Python does not have a profile module. ' + "If you're on Debian, try " + '`sudo apt-get install python-profiler`. ' + 'See http://www.cherrypy.org/wiki/ProfilingOnDebian ' + 'for details.') + warnings.warn(msg) + + cherrypy.config.update({'server.socket_port': int(port), + 'server.thread_pool': 10, + 'environment': 'production', + }) + cherrypy.quickstart(Profiler(path)) + + +if __name__ == '__main__': + serve(*tuple(sys.argv[1:])) diff --git a/libraries/cherrypy/lib/reprconf.py b/libraries/cherrypy/lib/reprconf.py new file mode 100644 index 00000000..291ab663 --- /dev/null +++ b/libraries/cherrypy/lib/reprconf.py @@ -0,0 +1,514 @@ +"""Generic configuration system using unrepr. + +Configuration data may be supplied as a Python dictionary, as a filename, +or as an open file object. When you supply a filename or file, Python's +builtin ConfigParser is used (with some extensions). + +Namespaces +---------- + +Configuration keys are separated into namespaces by the first "." in the key. + +The only key that cannot exist in a namespace is the "environment" entry. +This special entry 'imports' other config entries from a template stored in +the Config.environments dict. + +You can define your own namespaces to be called when new config is merged +by adding a named handler to Config.namespaces. The name can be any string, +and the handler must be either a callable or a context manager. +""" + +from cherrypy._cpcompat import text_or_bytes +from six.moves import configparser +from six.moves import builtins + +import operator +import sys + + +class NamespaceSet(dict): + + """A dict of config namespace names and handlers. + + Each config entry should begin with a namespace name; the corresponding + namespace handler will be called once for each config entry in that + namespace, and will be passed two arguments: the config key (with the + namespace removed) and the config value. + + Namespace handlers may be any Python callable; they may also be + Python 2.5-style 'context managers', in which case their __enter__ + method should return a callable to be used as the handler. + See cherrypy.tools (the Toolbox class) for an example. + """ + + def __call__(self, config): + """Iterate through config and pass it to each namespace handler. + + config + A flat dict, where keys use dots to separate + namespaces, and values are arbitrary. + + The first name in each config key is used to look up the corresponding + namespace handler. For example, a config entry of {'tools.gzip.on': v} + will call the 'tools' namespace handler with the args: ('gzip.on', v) + """ + # Separate the given config into namespaces + ns_confs = {} + for k in config: + if '.' 
in k: + ns, name = k.split('.', 1) + bucket = ns_confs.setdefault(ns, {}) + bucket[name] = config[k] + + # I chose __enter__ and __exit__ so someday this could be + # rewritten using Python 2.5's 'with' statement: + # for ns, handler in six.iteritems(self): + # with handler as callable: + # for k, v in six.iteritems(ns_confs.get(ns, {})): + # callable(k, v) + for ns, handler in self.items(): + exit = getattr(handler, '__exit__', None) + if exit: + callable = handler.__enter__() + no_exc = True + try: + try: + for k, v in ns_confs.get(ns, {}).items(): + callable(k, v) + except Exception: + # The exceptional case is handled here + no_exc = False + if exit is None: + raise + if not exit(*sys.exc_info()): + raise + # The exception is swallowed if exit() returns true + finally: + # The normal and non-local-goto cases are handled here + if no_exc and exit: + exit(None, None, None) + else: + for k, v in ns_confs.get(ns, {}).items(): + handler(k, v) + + def __repr__(self): + return '%s.%s(%s)' % (self.__module__, self.__class__.__name__, + dict.__repr__(self)) + + def __copy__(self): + newobj = self.__class__() + newobj.update(self) + return newobj + copy = __copy__ + + +class Config(dict): + + """A dict-like set of configuration data, with defaults and namespaces. + + May take a file, filename, or dict. + """ + + defaults = {} + environments = {} + namespaces = NamespaceSet() + + def __init__(self, file=None, **kwargs): + self.reset() + if file is not None: + self.update(file) + if kwargs: + self.update(kwargs) + + def reset(self): + """Reset self to default values.""" + self.clear() + dict.update(self, self.defaults) + + def update(self, config): + """Update self from a dict, file, or filename.""" + self._apply(Parser.load(config)) + + def _apply(self, config): + """Update self from a dict.""" + which_env = config.get('environment') + if which_env: + env = self.environments[which_env] + for k in env: + if k not in config: + config[k] = env[k] + + dict.update(self, config) + self.namespaces(config) + + def __setitem__(self, k, v): + dict.__setitem__(self, k, v) + self.namespaces({k: v}) + + +class Parser(configparser.ConfigParser): + + """Sub-class of ConfigParser that keeps the case of options and that + raises an exception if the file cannot be read. + """ + + def optionxform(self, optionstr): + return optionstr + + def read(self, filenames): + if isinstance(filenames, text_or_bytes): + filenames = [filenames] + for filename in filenames: + # try: + # fp = open(filename) + # except IOError: + # continue + fp = open(filename) + try: + self._read(fp, filename) + finally: + fp.close() + + def as_dict(self, raw=False, vars=None): + """Convert an INI file to a dictionary""" + # Load INI file into a dict + result = {} + for section in self.sections(): + if section not in result: + result[section] = {} + for option in self.options(section): + value = self.get(section, option, raw=raw, vars=vars) + try: + value = unrepr(value) + except Exception: + x = sys.exc_info()[1] + msg = ('Config error in section: %r, option: %r, ' + 'value: %r. Config values must be valid Python.' 
% + (section, option, value)) + raise ValueError(msg, x.__class__.__name__, x.args) + result[section][option] = value + return result + + def dict_from_file(self, file): + if hasattr(file, 'read'): + self.readfp(file) + else: + self.read(file) + return self.as_dict() + + @classmethod + def load(self, input): + """Resolve 'input' to dict from a dict, file, or filename.""" + is_file = ( + # Filename + isinstance(input, text_or_bytes) + # Open file object + or hasattr(input, 'read') + ) + return Parser().dict_from_file(input) if is_file else input.copy() + + +# public domain "unrepr" implementation, found on the web and then improved. + + +class _Builder2: + + def build(self, o): + m = getattr(self, 'build_' + o.__class__.__name__, None) + if m is None: + raise TypeError('unrepr does not recognize %s' % + repr(o.__class__.__name__)) + return m(o) + + def astnode(self, s): + """Return a Python2 ast Node compiled from a string.""" + try: + import compiler + except ImportError: + # Fallback to eval when compiler package is not available, + # e.g. IronPython 1.0. + return eval(s) + + p = compiler.parse('__tempvalue__ = ' + s) + return p.getChildren()[1].getChildren()[0].getChildren()[1] + + def build_Subscript(self, o): + expr, flags, subs = o.getChildren() + expr = self.build(expr) + subs = self.build(subs) + return expr[subs] + + def build_CallFunc(self, o): + children = o.getChildren() + # Build callee from first child + callee = self.build(children[0]) + # Build args and kwargs from remaining children + args = [] + kwargs = {} + for child in children[1:]: + class_name = child.__class__.__name__ + # None is ignored + if class_name == 'NoneType': + continue + # Keywords become kwargs + if class_name == 'Keyword': + kwargs.update(self.build(child)) + # Everything else becomes args + else: + args.append(self.build(child)) + + return callee(*args, **kwargs) + + def build_Keyword(self, o): + key, value_obj = o.getChildren() + value = self.build(value_obj) + kw_dict = {key: value} + return kw_dict + + def build_List(self, o): + return map(self.build, o.getChildren()) + + def build_Const(self, o): + return o.value + + def build_Dict(self, o): + d = {} + i = iter(map(self.build, o.getChildren())) + for el in i: + d[el] = i.next() + return d + + def build_Tuple(self, o): + return tuple(self.build_List(o)) + + def build_Name(self, o): + name = o.name + if name == 'None': + return None + if name == 'True': + return True + if name == 'False': + return False + + # See if the Name is a package or module. If it is, import it. + try: + return modules(name) + except ImportError: + pass + + # See if the Name is in builtins. 
+ try: + return getattr(builtins, name) + except AttributeError: + pass + + raise TypeError('unrepr could not resolve the name %s' % repr(name)) + + def build_Add(self, o): + left, right = map(self.build, o.getChildren()) + return left + right + + def build_Mul(self, o): + left, right = map(self.build, o.getChildren()) + return left * right + + def build_Getattr(self, o): + parent = self.build(o.expr) + return getattr(parent, o.attrname) + + def build_NoneType(self, o): + return None + + def build_UnarySub(self, o): + return -self.build(o.getChildren()[0]) + + def build_UnaryAdd(self, o): + return self.build(o.getChildren()[0]) + + +class _Builder3: + + def build(self, o): + m = getattr(self, 'build_' + o.__class__.__name__, None) + if m is None: + raise TypeError('unrepr does not recognize %s' % + repr(o.__class__.__name__)) + return m(o) + + def astnode(self, s): + """Return a Python3 ast Node compiled from a string.""" + try: + import ast + except ImportError: + # Fallback to eval when ast package is not available, + # e.g. IronPython 1.0. + return eval(s) + + p = ast.parse('__tempvalue__ = ' + s) + return p.body[0].value + + def build_Subscript(self, o): + return self.build(o.value)[self.build(o.slice)] + + def build_Index(self, o): + return self.build(o.value) + + def _build_call35(self, o): + """ + Workaround for python 3.5 _ast.Call signature, docs found here + https://greentreesnakes.readthedocs.org/en/latest/nodes.html + """ + import ast + callee = self.build(o.func) + args = [] + if o.args is not None: + for a in o.args: + if isinstance(a, ast.Starred): + args.append(self.build(a.value)) + else: + args.append(self.build(a)) + kwargs = {} + for kw in o.keywords: + if kw.arg is None: # double asterix `**` + rst = self.build(kw.value) + if not isinstance(rst, dict): + raise TypeError('Invalid argument for call.' + 'Must be a mapping object.') + # give preference to the keys set directly from arg=value + for k, v in rst.items(): + if k not in kwargs: + kwargs[k] = v + else: # defined on the call as: arg=value + kwargs[kw.arg] = self.build(kw.value) + return callee(*args, **kwargs) + + def build_Call(self, o): + if sys.version_info >= (3, 5): + return self._build_call35(o) + + callee = self.build(o.func) + + if o.args is None: + args = () + else: + args = tuple([self.build(a) for a in o.args]) + + if o.starargs is None: + starargs = () + else: + starargs = tuple(self.build(o.starargs)) + + if o.kwargs is None: + kwargs = {} + else: + kwargs = self.build(o.kwargs) + if o.keywords is not None: # direct a=b keywords + for kw in o.keywords: + # preference because is a direct keyword against **kwargs + kwargs[kw.arg] = self.build(kw.value) + return callee(*(args + starargs), **kwargs) + + def build_List(self, o): + return list(map(self.build, o.elts)) + + def build_Str(self, o): + return o.s + + def build_Num(self, o): + return o.n + + def build_Dict(self, o): + return dict([(self.build(k), self.build(v)) + for k, v in zip(o.keys, o.values)]) + + def build_Tuple(self, o): + return tuple(self.build_List(o)) + + def build_Name(self, o): + name = o.id + if name == 'None': + return None + if name == 'True': + return True + if name == 'False': + return False + + # See if the Name is a package or module. If it is, import it. + try: + return modules(name) + except ImportError: + pass + + # See if the Name is in builtins. 
+ try: + import builtins + return getattr(builtins, name) + except AttributeError: + pass + + raise TypeError('unrepr could not resolve the name %s' % repr(name)) + + def build_NameConstant(self, o): + return o.value + + def build_UnaryOp(self, o): + op, operand = map(self.build, [o.op, o.operand]) + return op(operand) + + def build_BinOp(self, o): + left, op, right = map(self.build, [o.left, o.op, o.right]) + return op(left, right) + + def build_Add(self, o): + return operator.add + + def build_Mult(self, o): + return operator.mul + + def build_USub(self, o): + return operator.neg + + def build_Attribute(self, o): + parent = self.build(o.value) + return getattr(parent, o.attr) + + def build_NoneType(self, o): + return None + + +def unrepr(s): + """Return a Python object compiled from a string.""" + if not s: + return s + if sys.version_info < (3, 0): + b = _Builder2() + else: + b = _Builder3() + obj = b.astnode(s) + return b.build(obj) + + +def modules(modulePath): + """Load a module and retrieve a reference to that module.""" + __import__(modulePath) + return sys.modules[modulePath] + + +def attributes(full_attribute_name): + """Load a module and retrieve an attribute of that module.""" + + # Parse out the path, module, and attribute + last_dot = full_attribute_name.rfind('.') + attr_name = full_attribute_name[last_dot + 1:] + mod_path = full_attribute_name[:last_dot] + + mod = modules(mod_path) + # Let an AttributeError propagate outward. + try: + attr = getattr(mod, attr_name) + except AttributeError: + raise AttributeError("'%s' object has no attribute '%s'" + % (mod_path, attr_name)) + + # Return a reference to the attribute. + return attr diff --git a/libraries/cherrypy/lib/sessions.py b/libraries/cherrypy/lib/sessions.py new file mode 100644 index 00000000..5b49ee13 --- /dev/null +++ b/libraries/cherrypy/lib/sessions.py @@ -0,0 +1,919 @@ +"""Session implementation for CherryPy. + +You need to edit your config file to use sessions. Here's an example:: + + [/] + tools.sessions.on = True + tools.sessions.storage_class = cherrypy.lib.sessions.FileSession + tools.sessions.storage_path = "/home/site/sessions" + tools.sessions.timeout = 60 + +This sets the session to be stored in files in the directory +/home/site/sessions, and the session timeout to 60 minutes. If you omit +``storage_class``, the sessions will be saved in RAM. +``tools.sessions.on`` is the only required line for working sessions, +the rest are optional. + +By default, the session ID is passed in a cookie, so the client's browser must +have cookies enabled for your site. + +To set data for the current session, use +``cherrypy.session['fieldname'] = 'fieldvalue'``; +to get data use ``cherrypy.session.get('fieldname')``. + +================ +Locking sessions +================ + +By default, the ``'locking'`` mode of sessions is ``'implicit'``, which means +the session is locked early and unlocked late. Be mindful of this default mode +for any requests that take a long time to process (streaming responses, +expensive calculations, database lookups, API calls, etc), as other concurrent +requests that also utilize sessions will hang until the session is unlocked. + +If you want to control when the session data is locked and unlocked, +set ``tools.sessions.locking = 'explicit'``. Then call +``cherrypy.session.acquire_lock()`` and ``cherrypy.session.release_lock()``. +Regardless of which mode you use, the session is guaranteed to be unlocked when +the request is complete. 
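+
+For example, a sketch of a handler under explicit locking
+(``do_expensive_work`` is a placeholder for your own code)::
+
+    @cherrypy.expose
+    def slow_page(self):
+        cherrypy.session.acquire_lock()
+        try:
+            cherrypy.session['visits'] = cherrypy.session.get('visits', 0) + 1
+        finally:
+            cherrypy.session.release_lock()
+        # The expensive part now runs without holding the session lock.
+        return do_expensive_work()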
+ +================= +Expiring Sessions +================= + +You can force a session to expire with :func:`cherrypy.lib.sessions.expire`. +Simply call that function at the point you want the session to expire, and it +will cause the session cookie to expire client-side. + +=========================== +Session Fixation Protection +=========================== + +If CherryPy receives, via a request cookie, a session id that it does not +recognize, it will reject that id and create a new one to return in the +response cookie. This `helps prevent session fixation attacks +<http://en.wikipedia.org/wiki/Session_fixation#Regenerate_SID_on_each_request>`_. +However, CherryPy "recognizes" a session id by looking up the saved session +data for that id. Therefore, if you never save any session data, +**you will get a new session id for every request**. + +A side effect of CherryPy overwriting unrecognised session ids is that if you +have multiple, separate CherryPy applications running on a single domain (e.g. +on different ports), each app will overwrite the other's session id because by +default they use the same cookie name (``"session_id"``) but do not recognise +each others sessions. It is therefore a good idea to use a different name for +each, for example:: + + [/] + ... + tools.sessions.name = "my_app_session_id" + +================ +Sharing Sessions +================ + +If you run multiple instances of CherryPy (for example via mod_python behind +Apache prefork), you most likely cannot use the RAM session backend, since each +instance of CherryPy will have its own memory space. Use a different backend +instead, and verify that all instances are pointing at the same file or db +location. Alternately, you might try a load balancer which makes sessions +"sticky". Google is your friend, there. + +================ +Expiration Dates +================ + +The response cookie will possess an expiration date to inform the client at +which point to stop sending the cookie back in requests. If the server time +and client time differ, expect sessions to be unreliable. **Make sure the +system time of your server is accurate**. + +CherryPy defaults to a 60-minute session timeout, which also applies to the +cookie which is sent to the client. Unfortunately, some versions of Safari +("4 public beta" on Windows XP at least) appear to have a bug in their parsing +of the GMT expiration date--they appear to interpret the date as one hour in +the past. Sixty minutes minus one hour is pretty close to zero, so you may +experience this bug as a new session id for every request, unless the requests +are less than one second apart. To fix, try increasing the session.timeout. + +On the other extreme, some users report Firefox sending cookies after their +expiration date, although this was on a system with an inaccurate system time. +Maybe FF doesn't trust system time. +""" +import sys +import datetime +import os +import time +import threading +import binascii + +import six +from six.moves import cPickle as pickle +import contextlib2 + +import zc.lockfile + +import cherrypy +from cherrypy.lib import httputil +from cherrypy.lib import locking +from cherrypy.lib import is_iterator + + +if six.PY2: + FileNotFoundError = OSError + + +missing = object() + + +class Session(object): + + """A CherryPy dict-like Session object (one per request).""" + + _id = None + + id_observers = None + "A list of callbacks to which to pass new id's." 
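+
+    # An observer is any callable taking the new id, e.g.
+    #   cherrypy.session.id_observers.append(
+    #       lambda new_id: cherrypy.log('session id set to %s' % new_id))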
+ + @property + def id(self): + """Return the current session id.""" + return self._id + + @id.setter + def id(self, value): + self._id = value + for o in self.id_observers: + o(value) + + timeout = 60 + 'Number of minutes after which to delete session data.' + + locked = False + """ + If True, this session instance has exclusive read/write access + to session data.""" + + loaded = False + """ + If True, data has been retrieved from storage. This should happen + automatically on the first attempt to access session data.""" + + clean_thread = None + 'Class-level Monitor which calls self.clean_up.' + + clean_freq = 5 + 'The poll rate for expired session cleanup in minutes.' + + originalid = None + 'The session id passed by the client. May be missing or unsafe.' + + missing = False + 'True if the session requested by the client did not exist.' + + regenerated = False + """ + True if the application called session.regenerate(). This is not set by + internal calls to regenerate the session id.""" + + debug = False + 'If True, log debug information.' + + # --------------------- Session management methods --------------------- # + + def __init__(self, id=None, **kwargs): + self.id_observers = [] + self._data = {} + + for k, v in kwargs.items(): + setattr(self, k, v) + + self.originalid = id + self.missing = False + if id is None: + if self.debug: + cherrypy.log('No id given; making a new one', 'TOOLS.SESSIONS') + self._regenerate() + else: + self.id = id + if self._exists(): + if self.debug: + cherrypy.log('Set id to %s.' % id, 'TOOLS.SESSIONS') + else: + if self.debug: + cherrypy.log('Expired or malicious session %r; ' + 'making a new one' % id, 'TOOLS.SESSIONS') + # Expired or malicious session. Make a new one. + # See https://github.com/cherrypy/cherrypy/issues/709. + self.id = None + self.missing = True + self._regenerate() + + def now(self): + """Generate the session specific concept of 'now'. + + Other session providers can override this to use alternative, + possibly timezone aware, versions of 'now'. + """ + return datetime.datetime.now() + + def regenerate(self): + """Replace the current session (with a new id).""" + self.regenerated = True + self._regenerate() + + def _regenerate(self): + if self.id is not None: + if self.debug: + cherrypy.log( + 'Deleting the existing session %r before ' + 'regeneration.' % self.id, + 'TOOLS.SESSIONS') + self.delete() + + old_session_was_locked = self.locked + if old_session_was_locked: + self.release_lock() + if self.debug: + cherrypy.log('Old lock released.', 'TOOLS.SESSIONS') + + self.id = None + while self.id is None: + self.id = self.generate_id() + # Assert that the generated id is not already stored. + if self._exists(): + self.id = None + if self.debug: + cherrypy.log('Set id to generated %s.' 
% self.id, + 'TOOLS.SESSIONS') + + if old_session_was_locked: + self.acquire_lock() + if self.debug: + cherrypy.log('Regenerated lock acquired.', 'TOOLS.SESSIONS') + + def clean_up(self): + """Clean up expired sessions.""" + pass + + def generate_id(self): + """Return a new session id.""" + return binascii.hexlify(os.urandom(20)).decode('ascii') + + def save(self): + """Save session data.""" + try: + # If session data has never been loaded then it's never been + # accessed: no need to save it + if self.loaded: + t = datetime.timedelta(seconds=self.timeout * 60) + expiration_time = self.now() + t + if self.debug: + cherrypy.log('Saving session %r with expiry %s' % + (self.id, expiration_time), + 'TOOLS.SESSIONS') + self._save(expiration_time) + else: + if self.debug: + cherrypy.log( + 'Skipping save of session %r (no session loaded).' % + self.id, 'TOOLS.SESSIONS') + finally: + if self.locked: + # Always release the lock if the user didn't release it + self.release_lock() + if self.debug: + cherrypy.log('Lock released after save.', 'TOOLS.SESSIONS') + + def load(self): + """Copy stored session data into this session instance.""" + data = self._load() + # data is either None or a tuple (session_data, expiration_time) + if data is None or data[1] < self.now(): + if self.debug: + cherrypy.log('Expired session %r, flushing data.' % self.id, + 'TOOLS.SESSIONS') + self._data = {} + else: + if self.debug: + cherrypy.log('Data loaded for session %r.' % self.id, + 'TOOLS.SESSIONS') + self._data = data[0] + self.loaded = True + + # Stick the clean_thread in the class, not the instance. + # The instances are created and destroyed per-request. + cls = self.__class__ + if self.clean_freq and not cls.clean_thread: + # clean_up is an instancemethod and not a classmethod, + # so that tool config can be accessed inside the method. + t = cherrypy.process.plugins.Monitor( + cherrypy.engine, self.clean_up, self.clean_freq * 60, + name='Session cleanup') + t.subscribe() + cls.clean_thread = t + t.start() + if self.debug: + cherrypy.log('Started cleanup thread.', 'TOOLS.SESSIONS') + + def delete(self): + """Delete stored session data.""" + self._delete() + if self.debug: + cherrypy.log('Deleted session %s.' % self.id, + 'TOOLS.SESSIONS') + + # -------------------- Application accessor methods -------------------- # + + def __getitem__(self, key): + if not self.loaded: + self.load() + return self._data[key] + + def __setitem__(self, key, value): + if not self.loaded: + self.load() + self._data[key] = value + + def __delitem__(self, key): + if not self.loaded: + self.load() + del self._data[key] + + def pop(self, key, default=missing): + """Remove the specified key and return the corresponding value. + If key is not found, default is returned if given, + otherwise KeyError is raised. + """ + if not self.loaded: + self.load() + if default is missing: + return self._data.pop(key) + else: + return self._data.pop(key, default) + + def __contains__(self, key): + if not self.loaded: + self.load() + return key in self._data + + def get(self, key, default=None): + """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.""" + if not self.loaded: + self.load() + return self._data.get(key, default) + + def update(self, d): + """D.update(E) -> None. 
Update D from E: for k in E: D[k] = E[k].""" + if not self.loaded: + self.load() + self._data.update(d) + + def setdefault(self, key, default=None): + """D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D.""" + if not self.loaded: + self.load() + return self._data.setdefault(key, default) + + def clear(self): + """D.clear() -> None. Remove all items from D.""" + if not self.loaded: + self.load() + self._data.clear() + + def keys(self): + """D.keys() -> list of D's keys.""" + if not self.loaded: + self.load() + return self._data.keys() + + def items(self): + """D.items() -> list of D's (key, value) pairs, as 2-tuples.""" + if not self.loaded: + self.load() + return self._data.items() + + def values(self): + """D.values() -> list of D's values.""" + if not self.loaded: + self.load() + return self._data.values() + + +class RamSession(Session): + + # Class-level objects. Don't rebind these! + cache = {} + locks = {} + + def clean_up(self): + """Clean up expired sessions.""" + + now = self.now() + for _id, (data, expiration_time) in list(six.iteritems(self.cache)): + if expiration_time <= now: + try: + del self.cache[_id] + except KeyError: + pass + try: + if self.locks[_id].acquire(blocking=False): + lock = self.locks.pop(_id) + lock.release() + except KeyError: + pass + + # added to remove obsolete lock objects + for _id in list(self.locks): + locked = ( + _id not in self.cache + and self.locks[_id].acquire(blocking=False) + ) + if locked: + lock = self.locks.pop(_id) + lock.release() + + def _exists(self): + return self.id in self.cache + + def _load(self): + return self.cache.get(self.id) + + def _save(self, expiration_time): + self.cache[self.id] = (self._data, expiration_time) + + def _delete(self): + self.cache.pop(self.id, None) + + def acquire_lock(self): + """Acquire an exclusive lock on the currently-loaded session data.""" + self.locked = True + self.locks.setdefault(self.id, threading.RLock()).acquire() + + def release_lock(self): + """Release the lock on the currently-loaded session data.""" + self.locks[self.id].release() + self.locked = False + + def __len__(self): + """Return the number of active sessions.""" + return len(self.cache) + + +class FileSession(Session): + + """Implementation of the File backend for sessions + + storage_path + The folder where session data will be saved. Each session + will be saved as pickle.dump(data, expiration_time) in its own file; + the filename will be self.SESSION_PREFIX + self.id. + + lock_timeout + A timedelta or numeric seconds indicating how long + to block acquiring a lock. If None (default), acquiring a lock + will block indefinitely. + """ + + SESSION_PREFIX = 'session-' + LOCK_SUFFIX = '.lock' + pickle_protocol = pickle.HIGHEST_PROTOCOL + + def __init__(self, id=None, **kwargs): + # The 'storage_path' arg is required for file-based sessions. + kwargs['storage_path'] = os.path.abspath(kwargs['storage_path']) + kwargs.setdefault('lock_timeout', None) + + Session.__init__(self, id=id, **kwargs) + + # validate self.lock_timeout + if isinstance(self.lock_timeout, (int, float)): + self.lock_timeout = datetime.timedelta(seconds=self.lock_timeout) + if not isinstance(self.lock_timeout, (datetime.timedelta, type(None))): + raise ValueError( + 'Lock timeout must be numeric seconds or a timedelta instance.' + ) + + @classmethod + def setup(cls, **kwargs): + """Set up the storage system for file-based sessions. 
+ + This should only be called once per process; this will be done + automatically when using sessions.init (as the built-in Tool does). + """ + # The 'storage_path' arg is required for file-based sessions. + kwargs['storage_path'] = os.path.abspath(kwargs['storage_path']) + + for k, v in kwargs.items(): + setattr(cls, k, v) + + def _get_file_path(self): + f = os.path.join(self.storage_path, self.SESSION_PREFIX + self.id) + if not os.path.abspath(f).startswith(self.storage_path): + raise cherrypy.HTTPError(400, 'Invalid session id in cookie.') + return f + + def _exists(self): + path = self._get_file_path() + return os.path.exists(path) + + def _load(self, path=None): + assert self.locked, ('The session load without being locked. ' + "Check your tools' priority levels.") + if path is None: + path = self._get_file_path() + try: + f = open(path, 'rb') + try: + return pickle.load(f) + finally: + f.close() + except (IOError, EOFError): + e = sys.exc_info()[1] + if self.debug: + cherrypy.log('Error loading the session pickle: %s' % + e, 'TOOLS.SESSIONS') + return None + + def _save(self, expiration_time): + assert self.locked, ('The session was saved without being locked. ' + "Check your tools' priority levels.") + f = open(self._get_file_path(), 'wb') + try: + pickle.dump((self._data, expiration_time), f, self.pickle_protocol) + finally: + f.close() + + def _delete(self): + assert self.locked, ('The session deletion without being locked. ' + "Check your tools' priority levels.") + try: + os.unlink(self._get_file_path()) + except OSError: + pass + + def acquire_lock(self, path=None): + """Acquire an exclusive lock on the currently-loaded session data.""" + if path is None: + path = self._get_file_path() + path += self.LOCK_SUFFIX + checker = locking.LockChecker(self.id, self.lock_timeout) + while not checker.expired(): + try: + self.lock = zc.lockfile.LockFile(path) + except zc.lockfile.LockError: + time.sleep(0.1) + else: + break + self.locked = True + if self.debug: + cherrypy.log('Lock acquired.', 'TOOLS.SESSIONS') + + def release_lock(self, path=None): + """Release the lock on the currently-loaded session data.""" + self.lock.close() + with contextlib2.suppress(FileNotFoundError): + os.remove(self.lock._path) + self.locked = False + + def clean_up(self): + """Clean up expired sessions.""" + now = self.now() + # Iterate over all session files in self.storage_path + for fname in os.listdir(self.storage_path): + have_session = ( + fname.startswith(self.SESSION_PREFIX) + and not fname.endswith(self.LOCK_SUFFIX) + ) + if have_session: + # We have a session file: lock and load it and check + # if it's expired. If it fails, nevermind. + path = os.path.join(self.storage_path, fname) + self.acquire_lock(path) + if self.debug: + # This is a bit of a hack, since we're calling clean_up + # on the first instance rather than the entire class, + # so depending on whether you have "debug" set on the + # path of the first session called, this may not run. 
+ cherrypy.log('Cleanup lock acquired.', 'TOOLS.SESSIONS') + + try: + contents = self._load(path) + # _load returns None on IOError + if contents is not None: + data, expiration_time = contents + if expiration_time < now: + # Session expired: deleting it + os.unlink(path) + finally: + self.release_lock(path) + + def __len__(self): + """Return the number of active sessions.""" + return len([fname for fname in os.listdir(self.storage_path) + if (fname.startswith(self.SESSION_PREFIX) and + not fname.endswith(self.LOCK_SUFFIX))]) + + +class MemcachedSession(Session): + + # The most popular memcached client for Python isn't thread-safe. + # Wrap all .get and .set operations in a single lock. + mc_lock = threading.RLock() + + # This is a separate set of locks per session id. + locks = {} + + servers = ['127.0.0.1:11211'] + + @classmethod + def setup(cls, **kwargs): + """Set up the storage system for memcached-based sessions. + + This should only be called once per process; this will be done + automatically when using sessions.init (as the built-in Tool does). + """ + for k, v in kwargs.items(): + setattr(cls, k, v) + + import memcache + cls.cache = memcache.Client(cls.servers) + + def _exists(self): + self.mc_lock.acquire() + try: + return bool(self.cache.get(self.id)) + finally: + self.mc_lock.release() + + def _load(self): + self.mc_lock.acquire() + try: + return self.cache.get(self.id) + finally: + self.mc_lock.release() + + def _save(self, expiration_time): + # Send the expiration time as "Unix time" (seconds since 1/1/1970) + td = int(time.mktime(expiration_time.timetuple())) + self.mc_lock.acquire() + try: + if not self.cache.set(self.id, (self._data, expiration_time), td): + raise AssertionError( + 'Session data for id %r not set.' % self.id) + finally: + self.mc_lock.release() + + def _delete(self): + self.cache.delete(self.id) + + def acquire_lock(self): + """Acquire an exclusive lock on the currently-loaded session data.""" + self.locked = True + self.locks.setdefault(self.id, threading.RLock()).acquire() + if self.debug: + cherrypy.log('Lock acquired.', 'TOOLS.SESSIONS') + + def release_lock(self): + """Release the lock on the currently-loaded session data.""" + self.locks[self.id].release() + self.locked = False + + def __len__(self): + """Return the number of active sessions.""" + raise NotImplementedError + + +# Hook functions (for CherryPy tools) + +def save(): + """Save any changed session data.""" + + if not hasattr(cherrypy.serving, 'session'): + return + request = cherrypy.serving.request + response = cherrypy.serving.response + + # Guard against running twice + if hasattr(request, '_sessionsaved'): + return + request._sessionsaved = True + + if response.stream: + # If the body is being streamed, we have to save the data + # *after* the response has been written out + request.hooks.attach('on_end_request', cherrypy.session.save) + else: + # If the body is not being streamed, we save the data now + # (so we can release the lock). 
+ if is_iterator(response.body): + response.collapse_body() + cherrypy.session.save() + + +save.failsafe = True + + +def close(): + """Close the session object for this request.""" + sess = getattr(cherrypy.serving, 'session', None) + if getattr(sess, 'locked', False): + # If the session is still locked we release the lock + sess.release_lock() + if sess.debug: + cherrypy.log('Lock released on close.', 'TOOLS.SESSIONS') + + +close.failsafe = True +close.priority = 90 + + +def init(storage_type=None, path=None, path_header=None, name='session_id', + timeout=60, domain=None, secure=False, clean_freq=5, + persistent=True, httponly=False, debug=False, + # Py27 compat + # *, storage_class=RamSession, + **kwargs): + """Initialize session object (using cookies). + + storage_class + The Session subclass to use. Defaults to RamSession. + + storage_type + (deprecated) + One of 'ram', 'file', memcached'. This will be + used to look up the corresponding class in cherrypy.lib.sessions + globals. For example, 'file' will use the FileSession class. + + path + The 'path' value to stick in the response cookie metadata. + + path_header + If 'path' is None (the default), then the response + cookie 'path' will be pulled from request.headers[path_header]. + + name + The name of the cookie. + + timeout + The expiration timeout (in minutes) for the stored session data. + If 'persistent' is True (the default), this is also the timeout + for the cookie. + + domain + The cookie domain. + + secure + If False (the default) the cookie 'secure' value will not + be set. If True, the cookie 'secure' value will be set (to 1). + + clean_freq (minutes) + The poll rate for expired session cleanup. + + persistent + If True (the default), the 'timeout' argument will be used + to expire the cookie. If False, the cookie will not have an expiry, + and the cookie will be a "session cookie" which expires when the + browser is closed. + + httponly + If False (the default) the cookie 'httponly' value will not be set. + If True, the cookie 'httponly' value will be set (to 1). + + Any additional kwargs will be bound to the new Session instance, + and may be specific to the storage type. See the subclass of Session + you're using for more information. + """ + + # Py27 compat + storage_class = kwargs.pop('storage_class', RamSession) + + request = cherrypy.serving.request + + # Guard against running twice + if hasattr(request, '_session_init_flag'): + return + request._session_init_flag = True + + # Check if request came with a session ID + id = None + if name in request.cookie: + id = request.cookie[name].value + if debug: + cherrypy.log('ID obtained from request.cookie: %r' % id, + 'TOOLS.SESSIONS') + + first_time = not hasattr(cherrypy, 'session') + + if storage_type: + if first_time: + msg = 'storage_type is deprecated. Supply storage_class instead' + cherrypy.log(msg) + storage_class = storage_type.title() + 'Session' + storage_class = globals()[storage_class] + + # call setup first time only + if first_time: + if hasattr(storage_class, 'setup'): + storage_class.setup(**kwargs) + + # Create and attach a new Session instance to cherrypy.serving. + # It will possess a reference to (and lock, and lazily load) + # the requested session data. 
+ kwargs['timeout'] = timeout + kwargs['clean_freq'] = clean_freq + cherrypy.serving.session = sess = storage_class(id, **kwargs) + sess.debug = debug + + def update_cookie(id): + """Update the cookie every time the session id changes.""" + cherrypy.serving.response.cookie[name] = id + sess.id_observers.append(update_cookie) + + # Create cherrypy.session which will proxy to cherrypy.serving.session + if not hasattr(cherrypy, 'session'): + cherrypy.session = cherrypy._ThreadLocalProxy('session') + + if persistent: + cookie_timeout = timeout + else: + # See http://support.microsoft.com/kb/223799/EN-US/ + # and http://support.mozilla.com/en-US/kb/Cookies + cookie_timeout = None + set_response_cookie(path=path, path_header=path_header, name=name, + timeout=cookie_timeout, domain=domain, secure=secure, + httponly=httponly) + + +def set_response_cookie(path=None, path_header=None, name='session_id', + timeout=60, domain=None, secure=False, httponly=False): + """Set a response cookie for the client. + + path + the 'path' value to stick in the response cookie metadata. + + path_header + if 'path' is None (the default), then the response + cookie 'path' will be pulled from request.headers[path_header]. + + name + the name of the cookie. + + timeout + the expiration timeout for the cookie. If 0 or other boolean + False, no 'expires' param will be set, and the cookie will be a + "session cookie" which expires when the browser is closed. + + domain + the cookie domain. + + secure + if False (the default) the cookie 'secure' value will not + be set. If True, the cookie 'secure' value will be set (to 1). + + httponly + If False (the default) the cookie 'httponly' value will not be set. + If True, the cookie 'httponly' value will be set (to 1). + + """ + # Set response cookie + cookie = cherrypy.serving.response.cookie + cookie[name] = cherrypy.serving.session.id + cookie[name]['path'] = ( + path or + cherrypy.serving.request.headers.get(path_header) or + '/' + ) + + if timeout: + cookie[name]['max-age'] = timeout * 60 + _add_MSIE_max_age_workaround(cookie[name], timeout) + if domain is not None: + cookie[name]['domain'] = domain + if secure: + cookie[name]['secure'] = 1 + if httponly: + if not cookie[name].isReservedKey('httponly'): + raise ValueError('The httponly cookie token is not supported.') + cookie[name]['httponly'] = 1 + + +def _add_MSIE_max_age_workaround(cookie, timeout): + """ + We'd like to use the "max-age" param as indicated in + http://www.faqs.org/rfcs/rfc2109.html but IE doesn't + save it to disk and the session is lost if people close + the browser. So we have to use the old "expires" ... sigh ... 
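+ The "expires" value is simply the current time plus the same timeout + (in minutes) that feeds "max-age".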
+ """ + expires = time.time() + timeout * 60 + cookie['expires'] = httputil.HTTPDate(expires) + + +def expire(): + """Expire the current session cookie.""" + name = cherrypy.serving.request.config.get( + 'tools.sessions.name', 'session_id') + one_year = 60 * 60 * 24 * 365 + e = time.time() - one_year + cherrypy.serving.response.cookie[name]['expires'] = httputil.HTTPDate(e) + cherrypy.serving.response.cookie[name].pop('max-age', None) diff --git a/libraries/cherrypy/lib/static.py b/libraries/cherrypy/lib/static.py new file mode 100644 index 00000000..da9d9373 --- /dev/null +++ b/libraries/cherrypy/lib/static.py @@ -0,0 +1,390 @@ +"""Module with helpers for serving static files.""" + +import os +import platform +import re +import stat +import mimetypes + +from email.generator import _make_boundary as make_boundary +from io import UnsupportedOperation + +from six.moves import urllib + +import cherrypy +from cherrypy._cpcompat import ntob +from cherrypy.lib import cptools, httputil, file_generator_limited + + +def _setup_mimetypes(): + """Pre-initialize global mimetype map.""" + if not mimetypes.inited: + mimetypes.init() + mimetypes.types_map['.dwg'] = 'image/x-dwg' + mimetypes.types_map['.ico'] = 'image/x-icon' + mimetypes.types_map['.bz2'] = 'application/x-bzip2' + mimetypes.types_map['.gz'] = 'application/x-gzip' + + +_setup_mimetypes() + + +def serve_file(path, content_type=None, disposition=None, name=None, + debug=False): + """Set status, headers, and body in order to serve the given path. + + The Content-Type header will be set to the content_type arg, if provided. + If not provided, the Content-Type will be guessed by the file extension + of the 'path' argument. + + If disposition is not None, the Content-Disposition header will be set + to "<disposition>; filename=<name>". If name is None, it will be set + to the basename of path. If disposition is None, no Content-Disposition + header will be written. + """ + response = cherrypy.serving.response + + # If path is relative, users should fix it by making path absolute. + # That is, CherryPy should not guess where the application root is. + # It certainly should *not* use cwd (since CP may be invoked from a + # variety of paths). If using tools.staticdir, you can make your relative + # paths become absolute by supplying a value for "tools.staticdir.root". + if not os.path.isabs(path): + msg = "'%s' is not an absolute path." % path + if debug: + cherrypy.log(msg, 'TOOLS.STATICFILE') + raise ValueError(msg) + + try: + st = os.stat(path) + except (OSError, TypeError, ValueError): + # OSError when file fails to stat + # TypeError on Python 2 when there's a null byte + # ValueError on Python 3 when there's a null byte + if debug: + cherrypy.log('os.stat(%r) failed' % path, 'TOOLS.STATIC') + raise cherrypy.NotFound() + + # Check if path is a directory. + if stat.S_ISDIR(st.st_mode): + # Let the caller deal with it as they like. + if debug: + cherrypy.log('%r is a directory' % path, 'TOOLS.STATIC') + raise cherrypy.NotFound() + + # Set the Last-Modified response header, so that + # modified-since validation code can work. 
+ response.headers['Last-Modified'] = httputil.HTTPDate(st.st_mtime) + cptools.validate_since() + + if content_type is None: + # Set content-type based on filename extension + ext = '' + i = path.rfind('.') + if i != -1: + ext = path[i:].lower() + content_type = mimetypes.types_map.get(ext, None) + if content_type is not None: + response.headers['Content-Type'] = content_type + if debug: + cherrypy.log('Content-Type: %r' % content_type, 'TOOLS.STATIC') + + cd = None + if disposition is not None: + if name is None: + name = os.path.basename(path) + cd = '%s; filename="%s"' % (disposition, name) + response.headers['Content-Disposition'] = cd + if debug: + cherrypy.log('Content-Disposition: %r' % cd, 'TOOLS.STATIC') + + # Set Content-Length and use an iterable (file object) + # this way CP won't load the whole file in memory + content_length = st.st_size + fileobj = open(path, 'rb') + return _serve_fileobj(fileobj, content_type, content_length, debug=debug) + + +def serve_fileobj(fileobj, content_type=None, disposition=None, name=None, + debug=False): + """Set status, headers, and body in order to serve the given file object. + + The Content-Type header will be set to the content_type arg, if provided. + + If disposition is not None, the Content-Disposition header will be set + to "<disposition>; filename=<name>". If name is None, 'filename' will + not be set. If disposition is None, no Content-Disposition header will + be written. + + CAUTION: If the request contains a 'Range' header, one or more seek()s will + be performed on the file object. This may cause undesired behavior if + the file object is not seekable. It could also produce undesired results + if the caller set the read position of the file object prior to calling + serve_fileobj(), expecting that the data would be served starting from that + position. + """ + response = cherrypy.serving.response + + try: + st = os.fstat(fileobj.fileno()) + except AttributeError: + if debug: + cherrypy.log('os has no fstat attribute', 'TOOLS.STATIC') + content_length = None + except UnsupportedOperation: + content_length = None + else: + # Set the Last-Modified response header, so that + # modified-since validation code can work. + response.headers['Last-Modified'] = httputil.HTTPDate(st.st_mtime) + cptools.validate_since() + content_length = st.st_size + + if content_type is not None: + response.headers['Content-Type'] = content_type + if debug: + cherrypy.log('Content-Type: %r' % content_type, 'TOOLS.STATIC') + + cd = None + if disposition is not None: + if name is None: + cd = disposition + else: + cd = '%s; filename="%s"' % (disposition, name) + response.headers['Content-Disposition'] = cd + if debug: + cherrypy.log('Content-Disposition: %r' % cd, 'TOOLS.STATIC') + + return _serve_fileobj(fileobj, content_type, content_length, debug=debug) + + +def _serve_fileobj(fileobj, content_type, content_length, debug=False): + """Internal. 
Set response.body to the given file object, perhaps ranged.""" + response = cherrypy.serving.response + + # HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code + request = cherrypy.serving.request + if request.protocol >= (1, 1): + response.headers['Accept-Ranges'] = 'bytes' + r = httputil.get_ranges(request.headers.get('Range'), content_length) + if r == []: + response.headers['Content-Range'] = 'bytes */%s' % content_length + message = ('Invalid Range (first-byte-pos greater than ' + 'Content-Length)') + if debug: + cherrypy.log(message, 'TOOLS.STATIC') + raise cherrypy.HTTPError(416, message) + + if r: + if len(r) == 1: + # Return a single-part response. + start, stop = r[0] + if stop > content_length: + stop = content_length + r_len = stop - start + if debug: + cherrypy.log( + 'Single part; start: %r, stop: %r' % (start, stop), + 'TOOLS.STATIC') + response.status = '206 Partial Content' + response.headers['Content-Range'] = ( + 'bytes %s-%s/%s' % (start, stop - 1, content_length)) + response.headers['Content-Length'] = r_len + fileobj.seek(start) + response.body = file_generator_limited(fileobj, r_len) + else: + # Return a multipart/byteranges response. + response.status = '206 Partial Content' + boundary = make_boundary() + ct = 'multipart/byteranges; boundary=%s' % boundary + response.headers['Content-Type'] = ct + if 'Content-Length' in response.headers: + # Delete Content-Length header so finalize() recalcs it. + del response.headers['Content-Length'] + + def file_ranges(): + # Apache compatibility: + yield b'\r\n' + + for start, stop in r: + if debug: + cherrypy.log( + 'Multipart; start: %r, stop: %r' % ( + start, stop), + 'TOOLS.STATIC') + yield ntob('--' + boundary, 'ascii') + yield ntob('\r\nContent-type: %s' % content_type, + 'ascii') + yield ntob( + '\r\nContent-range: bytes %s-%s/%s\r\n\r\n' % ( + start, stop - 1, content_length), + 'ascii') + fileobj.seek(start) + gen = file_generator_limited(fileobj, stop - start) + for chunk in gen: + yield chunk + yield b'\r\n' + # Final boundary + yield ntob('--' + boundary + '--', 'ascii') + + # Apache compatibility: + yield b'\r\n' + response.body = file_ranges() + return response.body + else: + if debug: + cherrypy.log('No byteranges requested', 'TOOLS.STATIC') + + # Set Content-Length and use an iterable (file object) + # this way CP won't load the whole file in memory + response.headers['Content-Length'] = content_length + response.body = fileobj + return response.body + + +def serve_download(path, name=None): + """Serve 'path' as an application/x-download attachment.""" + # This is such a common idiom I felt it deserved its own wrapper. + return serve_file(path, 'application/x-download', 'attachment', name) + + +def _attempt(filename, content_types, debug=False): + if debug: + cherrypy.log('Attempting %r (content_types %r)' % + (filename, content_types), 'TOOLS.STATICDIR') + try: + # you can set the content types for a + # complete directory per extension + content_type = None + if content_types: + r, ext = os.path.splitext(filename) + content_type = content_types.get(ext[1:], None) + serve_file(filename, content_type=content_type, debug=debug) + return True + except cherrypy.NotFound: + # If we didn't find the static file, continue handling the + # request. We might find a dynamic handler instead. 
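+ # Returning False tells the caller (staticdir/staticfile) that nothing + # was served, so the request falls through to the normal page handler.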
+ if debug: + cherrypy.log('NotFound', 'TOOLS.STATICFILE') + return False + + +def staticdir(section, dir, root='', match='', content_types=None, index='', + debug=False): + """Serve a static resource from the given (root +) dir. + + match + If given, request.path_info will be searched for the given + regular expression before attempting to serve static content. + + content_types + If given, it should be a Python dictionary of + {file-extension: content-type} pairs, where 'file-extension' is + a string (e.g. "gif") and 'content-type' is the value to write + out in the Content-Type response header (e.g. "image/gif"). + + index + If provided, it should be the (relative) name of a file to + serve for directory requests. For example, if the dir argument is + '/home/me', the Request-URI is 'myapp', and the index arg is + 'index.html', the file '/home/me/myapp/index.html' will be sought. + """ + request = cherrypy.serving.request + if request.method not in ('GET', 'HEAD'): + if debug: + cherrypy.log('request.method not GET or HEAD', 'TOOLS.STATICDIR') + return False + + if match and not re.search(match, request.path_info): + if debug: + cherrypy.log('request.path_info %r does not match pattern %r' % + (request.path_info, match), 'TOOLS.STATICDIR') + return False + + # Allow the use of '~' to refer to a user's home directory. + dir = os.path.expanduser(dir) + + # If dir is relative, make absolute using "root". + if not os.path.isabs(dir): + if not root: + msg = 'Static dir requires an absolute dir (or root).' + if debug: + cherrypy.log(msg, 'TOOLS.STATICDIR') + raise ValueError(msg) + dir = os.path.join(root, dir) + + # Determine where we are in the object tree relative to 'section' + # (where the static tool was defined). + if section == 'global': + section = '/' + section = section.rstrip(r'\/') + branch = request.path_info[len(section) + 1:] + branch = urllib.parse.unquote(branch.lstrip(r'\/')) + + # Requesting a file in sub-dir of the staticdir results + # in mixing of delimiter styles, e.g. C:\static\js/script.js. + # Windows accepts this form except not when the path is + # supplied in extended-path notation, e.g. \\?\C:\static\js/script.js. + # http://bit.ly/1vdioCX + if platform.system() == 'Windows': + branch = branch.replace('/', '\\') + + # If branch is "", filename will end in a slash + filename = os.path.join(dir, branch) + if debug: + cherrypy.log('Checking file %r to fulfill %r' % + (filename, request.path_info), 'TOOLS.STATICDIR') + + # There's a chance that the branch pulled from the URL might + # have ".." or similar uplevel attacks in it. Check that the final + # filename is a child of dir. + if not os.path.normpath(filename).startswith(os.path.normpath(dir)): + raise cherrypy.HTTPError(403) # Forbidden + + handled = _attempt(filename, content_types) + if not handled: + # Check for an index file if a folder was requested. + if index: + handled = _attempt(os.path.join(filename, index), content_types) + if handled: + request.is_index = filename[-1] in (r'\/') + return handled + + +def staticfile(filename, root=None, match='', content_types=None, debug=False): + """Serve a static resource from the given (root +) filename. + + match + If given, request.path_info will be searched for the given + regular expression before attempting to serve static content. + + content_types + If given, it should be a Python dictionary of + {file-extension: content-type} pairs, where 'file-extension' is + a string (e.g. 
"gif") and 'content-type' is the value to write + out in the Content-Type response header (e.g. "image/gif"). + + """ + request = cherrypy.serving.request + if request.method not in ('GET', 'HEAD'): + if debug: + cherrypy.log('request.method not GET or HEAD', 'TOOLS.STATICFILE') + return False + + if match and not re.search(match, request.path_info): + if debug: + cherrypy.log('request.path_info %r does not match pattern %r' % + (request.path_info, match), 'TOOLS.STATICFILE') + return False + + # If filename is relative, make absolute using "root". + if not os.path.isabs(filename): + if not root: + msg = "Static tool requires an absolute filename (got '%s')." % ( + filename,) + if debug: + cherrypy.log(msg, 'TOOLS.STATICFILE') + raise ValueError(msg) + filename = os.path.join(root, filename) + + return _attempt(filename, content_types, debug=debug) diff --git a/libraries/cherrypy/lib/xmlrpcutil.py b/libraries/cherrypy/lib/xmlrpcutil.py new file mode 100644 index 00000000..ddaac86a --- /dev/null +++ b/libraries/cherrypy/lib/xmlrpcutil.py @@ -0,0 +1,61 @@ +"""XML-RPC tool helpers.""" +import sys + +from six.moves.xmlrpc_client import ( + loads as xmlrpc_loads, dumps as xmlrpc_dumps, + Fault as XMLRPCFault +) + +import cherrypy +from cherrypy._cpcompat import ntob + + +def process_body(): + """Return (params, method) from request body.""" + try: + return xmlrpc_loads(cherrypy.request.body.read()) + except Exception: + return ('ERROR PARAMS', ), 'ERRORMETHOD' + + +def patched_path(path): + """Return 'path', doctored for RPC.""" + if not path.endswith('/'): + path += '/' + if path.startswith('/RPC2/'): + # strip the first /rpc2 + path = path[5:] + return path + + +def _set_response(body): + """Set up HTTP status, headers and body within CherryPy.""" + # The XML-RPC spec (http://www.xmlrpc.com/spec) says: + # "Unless there's a lower-level error, always return 200 OK." + # Since Python's xmlrpc_client interprets a non-200 response + # as a "Protocol Error", we'll just return 200 every time. + response = cherrypy.response + response.status = '200 OK' + response.body = ntob(body, 'utf-8') + response.headers['Content-Type'] = 'text/xml' + response.headers['Content-Length'] = len(body) + + +def respond(body, encoding='utf-8', allow_none=0): + """Construct HTTP response body.""" + if not isinstance(body, XMLRPCFault): + body = (body,) + + _set_response( + xmlrpc_dumps( + body, methodresponse=1, + encoding=encoding, + allow_none=allow_none + ) + ) + + +def on_error(*args, **kwargs): + """Construct HTTP response body for an error response.""" + body = str(sys.exc_info()[1]) + _set_response(xmlrpc_dumps(XMLRPCFault(1, body))) diff --git a/libraries/cherrypy/process/__init__.py b/libraries/cherrypy/process/__init__.py new file mode 100644 index 00000000..f242d226 --- /dev/null +++ b/libraries/cherrypy/process/__init__.py @@ -0,0 +1,17 @@ +"""Site container for an HTTP server. + +A Web Site Process Bus object is used to connect applications, servers, +and frameworks with site-wide services such as daemonization, process +reload, signal handling, drop privileges, PID file management, logging +for all of these, and many more. + +The 'plugins' module defines a few abstract and concrete services for +use with the bus. Some use tool-specific channels; see the documentation +for each class. +""" + +from .wspbus import bus +from . 
import plugins, servers + + +__all__ = ('bus', 'plugins', 'servers') diff --git a/libraries/cherrypy/process/plugins.py b/libraries/cherrypy/process/plugins.py new file mode 100644 index 00000000..8c246c81 --- /dev/null +++ b/libraries/cherrypy/process/plugins.py @@ -0,0 +1,752 @@ +"""Site services for use with a Web Site Process Bus.""" + +import os +import re +import signal as _signal +import sys +import time +import threading + +from six.moves import _thread + +from cherrypy._cpcompat import text_or_bytes +from cherrypy._cpcompat import ntob, Timer + +# _module__file__base is used by Autoreload to make +# absolute any filenames retrieved from sys.modules which are not +# already absolute paths. This is to work around Python's quirk +# of importing the startup script and using a relative filename +# for it in sys.modules. +# +# Autoreload examines sys.modules afresh every time it runs. If an application +# changes the current directory by executing os.chdir(), then the next time +# Autoreload runs, it will not be able to find any filenames which are +# not absolute paths, because the current directory is not the same as when the +# module was first imported. Autoreload will then wrongly conclude the file +# has "changed", and initiate the shutdown/re-exec sequence. +# See ticket #917. +# For this workaround to have a decent probability of success, this module +# needs to be imported as early as possible, before the app has much chance +# to change the working directory. +_module__file__base = os.getcwd() + + +class SimplePlugin(object): + + """Plugin base class which auto-subscribes methods for known channels.""" + + bus = None + """A :class:`Bus <cherrypy.process.wspbus.Bus>`, usually cherrypy.engine. + """ + + def __init__(self, bus): + self.bus = bus + + def subscribe(self): + """Register this object as a (multi-channel) listener on the bus.""" + for channel in self.bus.listeners: + # Subscribe self.start, self.exit, etc. if present. + method = getattr(self, channel, None) + if method is not None: + self.bus.subscribe(channel, method) + + def unsubscribe(self): + """Unregister this object as a listener on the bus.""" + for channel in self.bus.listeners: + # Unsubscribe self.start, self.exit, etc. if present. + method = getattr(self, channel, None) + if method is not None: + self.bus.unsubscribe(channel, method) + + +class SignalHandler(object): + + """Register bus channels (and listeners) for system signals. + + You can modify what signals your application listens for, and what it does + when it receives signals, by modifying :attr:`SignalHandler.handlers`, + a dict of {signal name: callback} pairs. The default set is:: + + handlers = {'SIGTERM': self.bus.exit, + 'SIGHUP': self.handle_SIGHUP, + 'SIGUSR1': self.bus.graceful, + } + + The :func:`SignalHandler.handle_SIGHUP`` method calls + :func:`bus.restart()<cherrypy.process.wspbus.Bus.restart>` + if the process is daemonized, but + :func:`bus.exit()<cherrypy.process.wspbus.Bus.exit>` + if the process is attached to a TTY. This is because Unix window + managers tend to send SIGHUP to terminal windows when the user closes them. + + Feel free to add signals which are not available on every platform. + The :class:`SignalHandler` will ignore errors raised from attempting + to register handlers for unknown signals. + """ + + handlers = {} + """A map from signal names (e.g. 'SIGTERM') to handlers (e.g. 
bus.exit).""" + + signals = {} + """A map from signal numbers to names.""" + + for k, v in vars(_signal).items(): + if k.startswith('SIG') and not k.startswith('SIG_'): + signals[v] = k + del k, v + + def __init__(self, bus): + self.bus = bus + # Set default handlers + self.handlers = {'SIGTERM': self.bus.exit, + 'SIGHUP': self.handle_SIGHUP, + 'SIGUSR1': self.bus.graceful, + } + + if sys.platform[:4] == 'java': + del self.handlers['SIGUSR1'] + self.handlers['SIGUSR2'] = self.bus.graceful + self.bus.log('SIGUSR1 cannot be set on the JVM platform. ' + 'Using SIGUSR2 instead.') + self.handlers['SIGINT'] = self._jython_SIGINT_handler + + self._previous_handlers = {} + # used to determine is the process is a daemon in `self._is_daemonized` + self._original_pid = os.getpid() + + def _jython_SIGINT_handler(self, signum=None, frame=None): + # See http://bugs.jython.org/issue1313 + self.bus.log('Keyboard Interrupt: shutting down bus') + self.bus.exit() + + def _is_daemonized(self): + """Return boolean indicating if the current process is + running as a daemon. + + The criteria to determine the `daemon` condition is to verify + if the current pid is not the same as the one that got used on + the initial construction of the plugin *and* the stdin is not + connected to a terminal. + + The sole validation of the tty is not enough when the plugin + is executing inside other process like in a CI tool + (Buildbot, Jenkins). + """ + return ( + self._original_pid != os.getpid() and + not os.isatty(sys.stdin.fileno()) + ) + + def subscribe(self): + """Subscribe self.handlers to signals.""" + for sig, func in self.handlers.items(): + try: + self.set_handler(sig, func) + except ValueError: + pass + + def unsubscribe(self): + """Unsubscribe self.handlers from signals.""" + for signum, handler in self._previous_handlers.items(): + signame = self.signals[signum] + + if handler is None: + self.bus.log('Restoring %s handler to SIG_DFL.' % signame) + handler = _signal.SIG_DFL + else: + self.bus.log('Restoring %s handler %r.' % (signame, handler)) + + try: + our_handler = _signal.signal(signum, handler) + if our_handler is None: + self.bus.log('Restored old %s handler %r, but our ' + 'handler was not registered.' % + (signame, handler), level=30) + except ValueError: + self.bus.log('Unable to restore %s handler %r.' % + (signame, handler), level=40, traceback=True) + + def set_handler(self, signal, listener=None): + """Subscribe a handler for the given signal (number or name). + + If the optional 'listener' argument is provided, it will be + subscribed as a listener for the given signal's channel. + + If the given signal name or number is not available on the current + platform, ValueError is raised. + """ + if isinstance(signal, text_or_bytes): + signum = getattr(_signal, signal, None) + if signum is None: + raise ValueError('No such signal: %r' % signal) + signame = signal + else: + try: + signame = self.signals[signal] + except KeyError: + raise ValueError('No such signal: %r' % signal) + signum = signal + + prev = _signal.signal(signum, self._handle_signal) + self._previous_handlers[signum] = prev + + if listener is not None: + self.bus.log('Listening for %s.' % signame) + self.bus.subscribe(signame, listener) + + def _handle_signal(self, signum=None, frame=None): + """Python signal handler (self.set_handler subscribes it for you).""" + signame = self.signals[signum] + self.bus.log('Caught signal %s.' 
% signame) + self.bus.publish(signame) + + def handle_SIGHUP(self): + """Restart if daemonized, else exit.""" + if self._is_daemonized(): + self.bus.log('SIGHUP caught while daemonized. Restarting.') + self.bus.restart() + else: + # not daemonized (may be foreground or background) + self.bus.log('SIGHUP caught but not daemonized. Exiting.') + self.bus.exit() + + +try: + import pwd + import grp +except ImportError: + pwd, grp = None, None + + +class DropPrivileges(SimplePlugin): + + """Drop privileges. uid/gid arguments not available on Windows. + + Special thanks to `Gavin Baker + <http://antonym.org/2005/12/dropping-privileges-in-python.html>`_ + """ + + def __init__(self, bus, umask=None, uid=None, gid=None): + SimplePlugin.__init__(self, bus) + self.finalized = False + self.uid = uid + self.gid = gid + self.umask = umask + + @property + def uid(self): + """The uid under which to run. Availability: Unix.""" + return self._uid + + @uid.setter + def uid(self, val): + if val is not None: + if pwd is None: + self.bus.log('pwd module not available; ignoring uid.', + level=30) + val = None + elif isinstance(val, text_or_bytes): + val = pwd.getpwnam(val)[2] + self._uid = val + + @property + def gid(self): + """The gid under which to run. Availability: Unix.""" + return self._gid + + @gid.setter + def gid(self, val): + if val is not None: + if grp is None: + self.bus.log('grp module not available; ignoring gid.', + level=30) + val = None + elif isinstance(val, text_or_bytes): + val = grp.getgrnam(val)[2] + self._gid = val + + @property + def umask(self): + """The default permission mode for newly created files and directories. + + Usually expressed in octal format, for example, ``0644``. + Availability: Unix, Windows. + """ + return self._umask + + @umask.setter + def umask(self, val): + if val is not None: + try: + os.umask + except AttributeError: + self.bus.log('umask function not available; ignoring umask.', + level=30) + val = None + self._umask = val + + def start(self): + # uid/gid + def current_ids(): + """Return the current (uid, gid) if available.""" + name, group = None, None + if pwd: + name = pwd.getpwuid(os.getuid())[0] + if grp: + group = grp.getgrgid(os.getgid())[0] + return name, group + + if self.finalized: + if not (self.uid is None and self.gid is None): + self.bus.log('Already running as uid: %r gid: %r' % + current_ids()) + else: + if self.uid is None and self.gid is None: + if pwd or grp: + self.bus.log('uid/gid not set', level=30) + else: + self.bus.log('Started as uid: %r gid: %r' % current_ids()) + if self.gid is not None: + os.setgid(self.gid) + os.setgroups([]) + if self.uid is not None: + os.setuid(self.uid) + self.bus.log('Running as uid: %r gid: %r' % current_ids()) + + # umask + if self.finalized: + if self.umask is not None: + self.bus.log('umask already set to: %03o' % self.umask) + else: + if self.umask is None: + self.bus.log('umask not set', level=30) + else: + old_umask = os.umask(self.umask) + self.bus.log('umask old: %03o, new: %03o' % + (old_umask, self.umask)) + + self.finalized = True + # This is slightly higher than the priority for server.start + # in order to facilitate the most common use: starting on a low + # port (which requires root) and then dropping to another user. + start.priority = 77 + + +class Daemonizer(SimplePlugin): + + """Daemonize the running script. + + Use this with a Web Site Process Bus via:: + + Daemonizer(bus).subscribe() + + When this component finishes, the process is completely decoupled from + the parent environment. 
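+ Decoupling is done with the classic double fork()/setsid() sequence, with + stdin, stdout and stderr redirected to the paths given to the constructor.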
Please note that when this component is used, + the return code from the parent process will still be 0 if a startup + error occurs in the forked children. Errors in the initial daemonizing + process still return proper exit codes. Therefore, if you use this + plugin to daemonize, don't use the return code as an accurate indicator + of whether the process fully started. In fact, that return code only + indicates if the process successfully finished the first fork. + """ + + def __init__(self, bus, stdin='/dev/null', stdout='/dev/null', + stderr='/dev/null'): + SimplePlugin.__init__(self, bus) + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr + self.finalized = False + + def start(self): + if self.finalized: + self.bus.log('Already deamonized.') + + # forking has issues with threads: + # http://www.opengroup.org/onlinepubs/000095399/functions/fork.html + # "The general problem with making fork() work in a multi-threaded + # world is what to do with all of the threads..." + # So we check for active threads: + if threading.activeCount() != 1: + self.bus.log('There are %r active threads. ' + 'Daemonizing now may cause strange failures.' % + threading.enumerate(), level=30) + + self.daemonize(self.stdin, self.stdout, self.stderr, self.bus.log) + + self.finalized = True + start.priority = 65 + + @staticmethod + def daemonize( + stdin='/dev/null', stdout='/dev/null', stderr='/dev/null', + logger=lambda msg: None): + # See http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16 + # (or http://www.faqs.org/faqs/unix-faq/programmer/faq/ section 1.7) + # and http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66012 + + # Finish up with the current stdout/stderr + sys.stdout.flush() + sys.stderr.flush() + + error_tmpl = ( + '{sys.argv[0]}: fork #{n} failed: ({exc.errno}) {exc.strerror}\n' + ) + + for fork in range(2): + msg = ['Forking once.', 'Forking twice.'][fork] + try: + pid = os.fork() + if pid > 0: + # This is the parent; exit. + logger(msg) + os._exit(0) + except OSError as exc: + # Python raises OSError rather than returning negative numbers. + sys.exit(error_tmpl.format(sys=sys, exc=exc, n=fork + 1)) + if fork == 0: + os.setsid() + + os.umask(0) + + si = open(stdin, 'r') + so = open(stdout, 'a+') + se = open(stderr, 'a+') + + # os.dup2(fd, fd2) will close fd2 if necessary, + # so we don't explicitly close stdin/out/err. + # See http://docs.python.org/lib/os-fd-ops.html + os.dup2(si.fileno(), sys.stdin.fileno()) + os.dup2(so.fileno(), sys.stdout.fileno()) + os.dup2(se.fileno(), sys.stderr.fileno()) + + logger('Daemonized to PID: %s' % os.getpid()) + + +class PIDFile(SimplePlugin): + + """Maintain a PID file via a WSPBus.""" + + def __init__(self, bus, pidfile): + SimplePlugin.__init__(self, bus) + self.pidfile = pidfile + self.finalized = False + + def start(self): + pid = os.getpid() + if self.finalized: + self.bus.log('PID %r already written to %r.' % (pid, self.pidfile)) + else: + open(self.pidfile, 'wb').write(ntob('%s\n' % pid, 'utf8')) + self.bus.log('PID %r written to %r.' % (pid, self.pidfile)) + self.finalized = True + start.priority = 70 + + def exit(self): + try: + os.remove(self.pidfile) + self.bus.log('PID file removed: %r.' % self.pidfile) + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + pass + + +class PerpetualTimer(Timer): + + """A responsive subclass of threading.Timer whose run() method repeats. 
+ + Use this timer only when you really need a very interruptible timer; + this checks its 'finished' condition up to 20 times a second, which can + results in pretty high CPU usage + """ + + def __init__(self, *args, **kwargs): + "Override parent constructor to allow 'bus' to be provided." + self.bus = kwargs.pop('bus', None) + super(PerpetualTimer, self).__init__(*args, **kwargs) + + def run(self): + while True: + self.finished.wait(self.interval) + if self.finished.isSet(): + return + try: + self.function(*self.args, **self.kwargs) + except Exception: + if self.bus: + self.bus.log( + 'Error in perpetual timer thread function %r.' % + self.function, level=40, traceback=True) + # Quit on first error to avoid massive logs. + raise + + +class BackgroundTask(threading.Thread): + + """A subclass of threading.Thread whose run() method repeats. + + Use this class for most repeating tasks. It uses time.sleep() to wait + for each interval, which isn't very responsive; that is, even if you call + self.cancel(), you'll have to wait until the sleep() call finishes before + the thread stops. To compensate, it defaults to being daemonic, which means + it won't delay stopping the whole process. + """ + + def __init__(self, interval, function, args=[], kwargs={}, bus=None): + super(BackgroundTask, self).__init__() + self.interval = interval + self.function = function + self.args = args + self.kwargs = kwargs + self.running = False + self.bus = bus + + # default to daemonic + self.daemon = True + + def cancel(self): + self.running = False + + def run(self): + self.running = True + while self.running: + time.sleep(self.interval) + if not self.running: + return + try: + self.function(*self.args, **self.kwargs) + except Exception: + if self.bus: + self.bus.log('Error in background task thread function %r.' + % self.function, level=40, traceback=True) + # Quit on first error to avoid massive logs. + raise + + +class Monitor(SimplePlugin): + + """WSPBus listener to periodically run a callback in its own thread.""" + + callback = None + """The function to call at intervals.""" + + frequency = 60 + """The time in seconds between callback runs.""" + + thread = None + """A :class:`BackgroundTask<cherrypy.process.plugins.BackgroundTask>` + thread. + """ + + def __init__(self, bus, callback, frequency=60, name=None): + SimplePlugin.__init__(self, bus) + self.callback = callback + self.frequency = frequency + self.thread = None + self.name = name + + def start(self): + """Start our callback in its own background thread.""" + if self.frequency > 0: + threadname = self.name or self.__class__.__name__ + if self.thread is None: + self.thread = BackgroundTask(self.frequency, self.callback, + bus=self.bus) + self.thread.setName(threadname) + self.thread.start() + self.bus.log('Started monitor thread %r.' % threadname) + else: + self.bus.log('Monitor thread %r already started.' % threadname) + start.priority = 70 + + def stop(self): + """Stop our callback's background task thread.""" + if self.thread is None: + self.bus.log('No thread running for %s.' % + self.name or self.__class__.__name__) + else: + if self.thread is not threading.currentThread(): + name = self.thread.getName() + self.thread.cancel() + if not self.thread.daemon: + self.bus.log('Joining %r' % name) + self.thread.join() + self.bus.log('Stopped thread %r.' 
% name) + self.thread = None + + def graceful(self): + """Stop the callback's background task thread and restart it.""" + self.stop() + self.start() + + +class Autoreloader(Monitor): + + """Monitor which re-executes the process when files change. + + This :ref:`plugin<plugins>` restarts the process (via :func:`os.execv`) + if any of the files it monitors change (or is deleted). By default, the + autoreloader monitors all imported modules; you can add to the + set by adding to ``autoreload.files``:: + + cherrypy.engine.autoreload.files.add(myFile) + + If there are imported files you do *not* wish to monitor, you can + adjust the ``match`` attribute, a regular expression. For example, + to stop monitoring cherrypy itself:: + + cherrypy.engine.autoreload.match = r'^(?!cherrypy).+' + + Like all :class:`Monitor<cherrypy.process.plugins.Monitor>` plugins, + the autoreload plugin takes a ``frequency`` argument. The default is + 1 second; that is, the autoreloader will examine files once each second. + """ + + files = None + """The set of files to poll for modifications.""" + + frequency = 1 + """The interval in seconds at which to poll for modified files.""" + + match = '.*' + """A regular expression by which to match filenames.""" + + def __init__(self, bus, frequency=1, match='.*'): + self.mtimes = {} + self.files = set() + self.match = match + Monitor.__init__(self, bus, self.run, frequency) + + def start(self): + """Start our own background task thread for self.run.""" + if self.thread is None: + self.mtimes = {} + Monitor.start(self) + start.priority = 70 + + def sysfiles(self): + """Return a Set of sys.modules filenames to monitor.""" + search_mod_names = filter(re.compile(self.match).match, sys.modules) + mods = map(sys.modules.get, search_mod_names) + return set(filter(None, map(self._file_for_module, mods))) + + @classmethod + def _file_for_module(cls, module): + """Return the relevant file for the module.""" + return ( + cls._archive_for_zip_module(module) + or cls._file_for_file_module(module) + ) + + @staticmethod + def _archive_for_zip_module(module): + """Return the archive filename for the module if relevant.""" + try: + return module.__loader__.archive + except AttributeError: + pass + + @classmethod + def _file_for_file_module(cls, module): + """Return the file for the module.""" + try: + return module.__file__ and cls._make_absolute(module.__file__) + except AttributeError: + pass + + @staticmethod + def _make_absolute(filename): + """Ensure filename is absolute to avoid effect of os.chdir.""" + return filename if os.path.isabs(filename) else ( + os.path.normpath(os.path.join(_module__file__base, filename)) + ) + + def run(self): + """Reload the process if registered files have been modified.""" + for filename in self.sysfiles() | self.files: + if filename: + if filename.endswith('.pyc'): + filename = filename[:-1] + + oldtime = self.mtimes.get(filename, 0) + if oldtime is None: + # Module with no .py file. Skip it. + continue + + try: + mtime = os.stat(filename).st_mtime + except OSError: + # Either a module with no .py file, or it's been deleted. + mtime = None + + if filename not in self.mtimes: + # If a module has no .py file, this will be None. + self.mtimes[filename] = mtime + else: + if mtime is None or mtime > oldtime: + # The file has been deleted or modified. + self.bus.log('Restarting because %s changed.' % + filename) + self.thread.cancel() + self.bus.log('Stopped thread %r.' 
% + self.thread.getName()) + self.bus.restart() + return + + +class ThreadManager(SimplePlugin): + + """Manager for HTTP request threads. + + If you have control over thread creation and destruction, publish to + the 'acquire_thread' and 'release_thread' channels (for each thread). + This will register/unregister the current thread and publish to + 'start_thread' and 'stop_thread' listeners in the bus as needed. + + If threads are created and destroyed by code you do not control + (e.g., Apache), then, at the beginning of every HTTP request, + publish to 'acquire_thread' only. You should not publish to + 'release_thread' in this case, since you do not know whether + the thread will be re-used or not. The bus will call + 'stop_thread' listeners for you when it stops. + """ + + threads = None + """A map of {thread ident: index number} pairs.""" + + def __init__(self, bus): + self.threads = {} + SimplePlugin.__init__(self, bus) + self.bus.listeners.setdefault('acquire_thread', set()) + self.bus.listeners.setdefault('start_thread', set()) + self.bus.listeners.setdefault('release_thread', set()) + self.bus.listeners.setdefault('stop_thread', set()) + + def acquire_thread(self): + """Run 'start_thread' listeners for the current thread. + + If the current thread has already been seen, any 'start_thread' + listeners will not be run again. + """ + thread_ident = _thread.get_ident() + if thread_ident not in self.threads: + # We can't just use get_ident as the thread ID + # because some platforms reuse thread ID's. + i = len(self.threads) + 1 + self.threads[thread_ident] = i + self.bus.publish('start_thread', i) + + def release_thread(self): + """Release the current thread and run 'stop_thread' listeners.""" + thread_ident = _thread.get_ident() + i = self.threads.pop(thread_ident, None) + if i is not None: + self.bus.publish('stop_thread', i) + + def stop(self): + """Release all threads and run all 'stop_thread' listeners.""" + for thread_ident, i in self.threads.items(): + self.bus.publish('stop_thread', i) + self.threads.clear() + graceful = stop diff --git a/libraries/cherrypy/process/servers.py b/libraries/cherrypy/process/servers.py new file mode 100644 index 00000000..dcb34de6 --- /dev/null +++ b/libraries/cherrypy/process/servers.py @@ -0,0 +1,416 @@ +r""" +Starting in CherryPy 3.1, cherrypy.server is implemented as an +:ref:`Engine Plugin<plugins>`. It's an instance of +:class:`cherrypy._cpserver.Server`, which is a subclass of +:class:`cherrypy.process.servers.ServerAdapter`. The ``ServerAdapter`` class +is designed to control other servers, as well. + +Multiple servers/ports +====================== + +If you need to start more than one HTTP server (to serve on multiple ports, or +protocols, etc.), you can manually register each one and then start them all +with engine.start:: + + s1 = ServerAdapter( + cherrypy.engine, + MyWSGIServer(host='0.0.0.0', port=80) + ) + s2 = ServerAdapter( + cherrypy.engine, + another.HTTPServer(host='127.0.0.1', SSL=True) + ) + s1.subscribe() + s2.subscribe() + cherrypy.engine.start() + +.. index:: SCGI + +FastCGI/SCGI +============ + +There are also Flup\ **F**\ CGIServer and Flup\ **S**\ CGIServer classes in +:mod:`cherrypy.process.servers`. 
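+Both are thin wrappers around flup's WSGI servers, exposing start() and stop() +so a ServerAdapter can manage them.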
To start an fcgi server, for example, +wrap an instance of it in a ServerAdapter:: + + addr = ('0.0.0.0', 4000) + f = servers.FlupFCGIServer(application=cherrypy.tree, bindAddress=addr) + s = servers.ServerAdapter(cherrypy.engine, httpserver=f, bind_addr=addr) + s.subscribe() + +The :doc:`cherryd</deployguide/cherryd>` startup script will do the above for +you via its `-f` flag. +Note that you need to download and install `flup <http://trac.saddi.com/flup>`_ +yourself, whether you use ``cherryd`` or not. + +.. _fastcgi: +.. index:: FastCGI + +FastCGI +------- + +A very simple setup lets your cherry run with FastCGI. +You just need the flup library, +plus a running Apache server (with ``mod_fastcgi``) or lighttpd server. + +CherryPy code +^^^^^^^^^^^^^ + +hello.py:: + + #!/usr/bin/python + import cherrypy + + class HelloWorld: + '''Sample request handler class.''' + @cherrypy.expose + def index(self): + return "Hello world!" + + cherrypy.tree.mount(HelloWorld()) + # CherryPy autoreload must be disabled for the flup server to work + cherrypy.config.update({'engine.autoreload.on':False}) + +Then run :doc:`/deployguide/cherryd` with the '-f' arg:: + + cherryd -c <myconfig> -d -f -i hello.py + +Apache +^^^^^^ + +At the top level in httpd.conf:: + + FastCgiIpcDir /tmp + FastCgiServer /path/to/cherry.fcgi -idle-timeout 120 -processes 4 + +And inside the relevant VirtualHost section:: + + # FastCGI config + AddHandler fastcgi-script .fcgi + ScriptAliasMatch (.*$) /path/to/cherry.fcgi$1 + +Lighttpd +^^^^^^^^ + +For `Lighttpd <http://www.lighttpd.net/>`_ you can follow these +instructions. Within ``lighttpd.conf`` make sure ``mod_fastcgi`` is +active within ``server.modules``. Then, within your ``$HTTP["host"]`` +directive, configure your fastcgi script like the following:: + + $HTTP["url"] =~ "" { + fastcgi.server = ( + "/" => ( + "script.fcgi" => ( + "bin-path" => "/path/to/your/script.fcgi", + "socket" => "/tmp/script.sock", + "check-local" => "disable", + "disable-time" => 1, + "min-procs" => 1, + "max-procs" => 1, # adjust as needed + ), + ), + ) + } # end of $HTTP["url"] =~ "^/" + +Please see `Lighttpd FastCGI Docs +<http://redmine.lighttpd.net/wiki/lighttpd/Docs:ModFastCGI>`_ for +an explanation of the possible configuration options. +""" + +import os +import sys +import time +import warnings +import contextlib + +import portend + + +class Timeouts: + occupied = 5 + free = 1 + + +class ServerAdapter(object): + + """Adapter for an HTTP server. 
+ + If you need to start more than one HTTP server (to serve on multiple + ports, or protocols, etc.), you can manually register each one and then + start them all with bus.start:: + + s1 = ServerAdapter(bus, MyWSGIServer(host='0.0.0.0', port=80)) + s2 = ServerAdapter(bus, another.HTTPServer(host='127.0.0.1', SSL=True)) + s1.subscribe() + s2.subscribe() + bus.start() + """ + + def __init__(self, bus, httpserver=None, bind_addr=None): + self.bus = bus + self.httpserver = httpserver + self.bind_addr = bind_addr + self.interrupt = None + self.running = False + + def subscribe(self): + self.bus.subscribe('start', self.start) + self.bus.subscribe('stop', self.stop) + + def unsubscribe(self): + self.bus.unsubscribe('start', self.start) + self.bus.unsubscribe('stop', self.stop) + + def start(self): + """Start the HTTP server.""" + if self.running: + self.bus.log('Already serving on %s' % self.description) + return + + self.interrupt = None + if not self.httpserver: + raise ValueError('No HTTP server has been created.') + + if not os.environ.get('LISTEN_PID', None): + # Start the httpserver in a new thread. + if isinstance(self.bind_addr, tuple): + portend.free(*self.bind_addr, timeout=Timeouts.free) + + import threading + t = threading.Thread(target=self._start_http_thread) + t.setName('HTTPServer ' + t.getName()) + t.start() + + self.wait() + self.running = True + self.bus.log('Serving on %s' % self.description) + start.priority = 75 + + @property + def description(self): + """ + A description about where this server is bound. + """ + if self.bind_addr is None: + on_what = 'unknown interface (dynamic?)' + elif isinstance(self.bind_addr, tuple): + on_what = self._get_base() + else: + on_what = 'socket file: %s' % self.bind_addr + return on_what + + def _get_base(self): + if not self.httpserver: + return '' + host, port = self.bound_addr + if getattr(self.httpserver, 'ssl_adapter', None): + scheme = 'https' + if port != 443: + host += ':%s' % port + else: + scheme = 'http' + if port != 80: + host += ':%s' % port + + return '%s://%s' % (scheme, host) + + def _start_http_thread(self): + """HTTP servers MUST be running in new threads, so that the + main thread persists to receive KeyboardInterrupt's. If an + exception is raised in the httpserver's thread then it's + trapped here, and the bus (and therefore our httpserver) + are shut down. 
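+ The exception instance is also kept on self.interrupt so that wait() + can re-raise it in the main thread.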
+ """ + try: + self.httpserver.start() + except KeyboardInterrupt: + self.bus.log('<Ctrl-C> hit: shutting down HTTP server') + self.interrupt = sys.exc_info()[1] + self.bus.exit() + except SystemExit: + self.bus.log('SystemExit raised: shutting down HTTP server') + self.interrupt = sys.exc_info()[1] + self.bus.exit() + raise + except Exception: + self.interrupt = sys.exc_info()[1] + self.bus.log('Error in HTTP server: shutting down', + traceback=True, level=40) + self.bus.exit() + raise + + def wait(self): + """Wait until the HTTP server is ready to receive requests.""" + while not getattr(self.httpserver, 'ready', False): + if self.interrupt: + raise self.interrupt + time.sleep(.1) + + # bypass check when LISTEN_PID is set + if os.environ.get('LISTEN_PID', None): + return + + # bypass check when running via socket-activation + # (for socket-activation the port will be managed by systemd) + if not isinstance(self.bind_addr, tuple): + return + + # wait for port to be occupied + with _safe_wait(*self.bound_addr): + portend.occupied(*self.bound_addr, timeout=Timeouts.occupied) + + @property + def bound_addr(self): + """ + The bind address, or if it's an ephemeral port and the + socket has been bound, return the actual port bound. + """ + host, port = self.bind_addr + if port == 0 and self.httpserver.socket: + # Bound to ephemeral port. Get the actual port allocated. + port = self.httpserver.socket.getsockname()[1] + return host, port + + def stop(self): + """Stop the HTTP server.""" + if self.running: + # stop() MUST block until the server is *truly* stopped. + self.httpserver.stop() + # Wait for the socket to be truly freed. + if isinstance(self.bind_addr, tuple): + portend.free(*self.bound_addr, timeout=Timeouts.free) + self.running = False + self.bus.log('HTTP Server %s shut down' % self.httpserver) + else: + self.bus.log('HTTP Server %s already shut down' % self.httpserver) + stop.priority = 25 + + def restart(self): + """Restart the HTTP server.""" + self.stop() + self.start() + + +class FlupCGIServer(object): + + """Adapter for a flup.server.cgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the CGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.cgi import WSGIServer + + self.cgiserver = WSGIServer(*self.args, **self.kwargs) + self.ready = True + self.cgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + self.ready = False + + +class FlupFCGIServer(object): + + """Adapter for a flup.server.fcgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + if kwargs.get('bindAddress', None) is None: + import socket + if not hasattr(socket, 'fromfd'): + raise ValueError( + 'Dynamic FCGI server not available on this platform. ' + 'You must use a static or external one by providing a ' + 'legal bindAddress.') + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the FCGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.fcgi import WSGIServer + self.fcgiserver = WSGIServer(*self.args, **self.kwargs) + # TODO: report this bug upstream to flup. 
+ # If we don't set _oldSIGs on Windows, we get: + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 108, in run + # self._restoreSignalHandlers() + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 156, in _restoreSignalHandlers + # for signum,handler in self._oldSIGs: + # AttributeError: 'WSGIServer' object has no attribute '_oldSIGs' + self.fcgiserver._installSignalHandlers = lambda: None + self.fcgiserver._oldSIGs = [] + self.ready = True + self.fcgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + # Forcibly stop the fcgi server main event loop. + self.fcgiserver._keepGoing = False + # Force all worker threads to die off. + self.fcgiserver._threadPool.maxSpare = ( + self.fcgiserver._threadPool._idleCount) + self.ready = False + + +class FlupSCGIServer(object): + + """Adapter for a flup.server.scgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the SCGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.scgi import WSGIServer + self.scgiserver = WSGIServer(*self.args, **self.kwargs) + # TODO: report this bug upstream to flup. + # If we don't set _oldSIGs on Windows, we get: + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 108, in run + # self._restoreSignalHandlers() + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 156, in _restoreSignalHandlers + # for signum,handler in self._oldSIGs: + # AttributeError: 'WSGIServer' object has no attribute '_oldSIGs' + self.scgiserver._installSignalHandlers = lambda: None + self.scgiserver._oldSIGs = [] + self.ready = True + self.scgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + self.ready = False + # Forcibly stop the scgi server main event loop. + self.scgiserver._keepGoing = False + # Force all worker threads to die off. + self.scgiserver._threadPool.maxSpare = 0 + + +@contextlib.contextmanager +def _safe_wait(host, port): + """ + On systems where a loopback interface is not available and the + server is bound to all interfaces, it's difficult to determine + whether the server is in fact occupying the port. In this case, + just issue a warning and move on. See issue #1100. + """ + try: + yield + except portend.Timeout: + if host == portend.client_host(host): + raise + msg = 'Unable to verify that the server is bound on %r' % port + warnings.warn(msg) diff --git a/libraries/cherrypy/process/win32.py b/libraries/cherrypy/process/win32.py new file mode 100644 index 00000000..096b0278 --- /dev/null +++ b/libraries/cherrypy/process/win32.py @@ -0,0 +1,183 @@ +"""Windows service. 
Requires pywin32.""" + +import os +import win32api +import win32con +import win32event +import win32service +import win32serviceutil + +from cherrypy.process import wspbus, plugins + + +class ConsoleCtrlHandler(plugins.SimplePlugin): + + """A WSPBus plugin for handling Win32 console events (like Ctrl-C).""" + + def __init__(self, bus): + self.is_set = False + plugins.SimplePlugin.__init__(self, bus) + + def start(self): + if self.is_set: + self.bus.log('Handler for console events already set.', level=40) + return + + result = win32api.SetConsoleCtrlHandler(self.handle, 1) + if result == 0: + self.bus.log('Could not SetConsoleCtrlHandler (error %r)' % + win32api.GetLastError(), level=40) + else: + self.bus.log('Set handler for console events.', level=40) + self.is_set = True + + def stop(self): + if not self.is_set: + self.bus.log('Handler for console events already off.', level=40) + return + + try: + result = win32api.SetConsoleCtrlHandler(self.handle, 0) + except ValueError: + # "ValueError: The object has not been registered" + result = 1 + + if result == 0: + self.bus.log('Could not remove SetConsoleCtrlHandler (error %r)' % + win32api.GetLastError(), level=40) + else: + self.bus.log('Removed handler for console events.', level=40) + self.is_set = False + + def handle(self, event): + """Handle console control events (like Ctrl-C).""" + if event in (win32con.CTRL_C_EVENT, win32con.CTRL_LOGOFF_EVENT, + win32con.CTRL_BREAK_EVENT, win32con.CTRL_SHUTDOWN_EVENT, + win32con.CTRL_CLOSE_EVENT): + self.bus.log('Console event %s: shutting down bus' % event) + + # Remove self immediately so repeated Ctrl-C doesn't re-call it. + try: + self.stop() + except ValueError: + pass + + self.bus.exit() + # 'First to return True stops the calls' + return 1 + return 0 + + +class Win32Bus(wspbus.Bus): + + """A Web Site Process Bus implementation for Win32. + + Instead of time.sleep, this bus blocks using native win32event objects. + """ + + def __init__(self): + self.events = {} + wspbus.Bus.__init__(self) + + def _get_state_event(self, state): + """Return a win32event for the given state (creating it if needed).""" + try: + return self.events[state] + except KeyError: + event = win32event.CreateEvent(None, 0, 0, + 'WSPBus %s Event (pid=%r)' % + (state.name, os.getpid())) + self.events[state] = event + return event + + @property + def state(self): + return self._state + + @state.setter + def state(self, value): + self._state = value + event = self._get_state_event(value) + win32event.PulseEvent(event) + + def wait(self, state, interval=0.1, channel=None): + """Wait for the given state(s), KeyboardInterrupt or SystemExit. + + Since this class uses native win32event objects, the interval + argument is ignored. + """ + if isinstance(state, (tuple, list)): + # Don't wait for an event that beat us to the punch ;) + if self.state not in state: + events = tuple([self._get_state_event(s) for s in state]) + win32event.WaitForMultipleObjects( + events, 0, win32event.INFINITE) + else: + # Don't wait for an event that beat us to the punch ;) + if self.state != state: + event = self._get_state_event(state) + win32event.WaitForSingleObject(event, win32event.INFINITE) + + +class _ControlCodes(dict): + + """Control codes used to "signal" a service via ControlService. + + User-defined control codes are in the range 128-255. We generally use + the standard Python value for the Linux signal and add 128. 
Example: + + >>> signal.SIGUSR1 + 10 + control_codes['graceful'] = 128 + 10 + """ + + def key_for(self, obj): + """For the given value, return its corresponding key.""" + for key, val in self.items(): + if val is obj: + return key + raise ValueError('The given object could not be found: %r' % obj) + + +control_codes = _ControlCodes({'graceful': 138}) + + +def signal_child(service, command): + if command == 'stop': + win32serviceutil.StopService(service) + elif command == 'restart': + win32serviceutil.RestartService(service) + else: + win32serviceutil.ControlService(service, control_codes[command]) + + +class PyWebService(win32serviceutil.ServiceFramework): + + """Python Web Service.""" + + _svc_name_ = 'Python Web Service' + _svc_display_name_ = 'Python Web Service' + _svc_deps_ = None # sequence of service names on which this depends + _exe_name_ = 'pywebsvc' + _exe_args_ = None # Default to no arguments + + # Only exists on Windows 2000 or later, ignored on windows NT + _svc_description_ = 'Python Web Service' + + def SvcDoRun(self): + from cherrypy import process + process.bus.start() + process.bus.block() + + def SvcStop(self): + from cherrypy import process + self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING) + process.bus.exit() + + def SvcOther(self, control): + from cherrypy import process + process.bus.publish(control_codes.key_for(control)) + + +if __name__ == '__main__': + win32serviceutil.HandleCommandLine(PyWebService) diff --git a/libraries/cherrypy/process/wspbus.py b/libraries/cherrypy/process/wspbus.py new file mode 100644 index 00000000..39ac45bf --- /dev/null +++ b/libraries/cherrypy/process/wspbus.py @@ -0,0 +1,590 @@ +r"""An implementation of the Web Site Process Bus. + +This module is completely standalone, depending only on the stdlib. + +Web Site Process Bus +-------------------- + +A Bus object is used to contain and manage site-wide behavior: +daemonization, HTTP server start/stop, process reload, signal handling, +drop privileges, PID file management, logging for all of these, +and many more. + +In addition, a Bus object provides a place for each web framework +to register code that runs in response to site-wide events (like +process start and stop), or which controls or otherwise interacts with +the site-wide components mentioned above. For example, a framework which +uses file-based templates would add known template filenames to an +autoreload component. + +Ideally, a Bus object will be flexible enough to be useful in a variety +of invocation scenarios: + + 1. The deployer starts a site from the command line via a + framework-neutral deployment script; applications from multiple frameworks + are mixed in a single site. Command-line arguments and configuration + files are used to define site-wide components such as the HTTP server, + WSGI component graph, autoreload behavior, signal handling, etc. + 2. The deployer starts a site via some other process, such as Apache; + applications from multiple frameworks are mixed in a single site. + Autoreload and signal handling (from Python at least) are disabled. + 3. The deployer starts a site via a framework-specific mechanism; + for example, when running tests, exploring tutorials, or deploying + single applications from a single framework. The framework controls + which site-wide components are enabled as it sees fit. + +The Bus object in this package uses topic-based publish-subscribe +messaging to accomplish all this. 
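For example (an illustrative sketch, given a Bus instance ``bus``), a
deployment script and a framework component never call each other directly;
they only agree on a channel name::

    def mount_my_app():
        pass  # e.g. hook the framework's application into the site

    bus.subscribe('start', mount_my_app)
    bus.start()    # publishes on the 'start' channel; mount_my_app runs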
A few topic channels are built in +('start', 'stop', 'exit', 'graceful', 'log', and 'main'). Frameworks and +site containers are free to define their own. If a message is sent to a +channel that has not been defined or has no listeners, there is no effect. + +In general, there should only ever be a single Bus object per process. +Frameworks and site containers share a single Bus object by publishing +messages and subscribing listeners. + +The Bus object works as a finite state machine which models the current +state of the process. Bus methods move it from one state to another; +those methods then publish to subscribed listeners on the channel for +the new state.:: + + O + | + V + STOPPING --> STOPPED --> EXITING -> X + A A | + | \___ | + | \ | + | V V + STARTED <-- STARTING + +""" + +import atexit + +try: + import ctypes +except (ImportError, MemoryError): + """Google AppEngine is shipped without ctypes + + :seealso: http://stackoverflow.com/a/6523777/70170 + """ + ctypes = None + +import operator +import os +import sys +import threading +import time +import traceback as _traceback +import warnings +import subprocess +import functools + +import six + + +# Here I save the value of os.getcwd(), which, if I am imported early enough, +# will be the directory from which the startup script was run. This is needed +# by _do_execv(), to change back to the original directory before execv()ing a +# new process. This is a defense against the application having changed the +# current working directory (which could make sys.executable "not found" if +# sys.executable is a relative-path, and/or cause other problems). +_startup_cwd = os.getcwd() + + +class ChannelFailures(Exception): + """Exception raised during errors on Bus.publish().""" + + delimiter = '\n' + + def __init__(self, *args, **kwargs): + """Initialize ChannelFailures errors wrapper.""" + super(ChannelFailures, self).__init__(*args, **kwargs) + self._exceptions = list() + + def handle_exception(self): + """Append the current exception to self.""" + self._exceptions.append(sys.exc_info()[1]) + + def get_instances(self): + """Return a list of seen exception instances.""" + return self._exceptions[:] + + def __str__(self): + """Render the list of errors, which happened in channel.""" + exception_strings = map(repr, self.get_instances()) + return self.delimiter.join(exception_strings) + + __repr__ = __str__ + + def __bool__(self): + """Determine whether any error happened in channel.""" + return bool(self._exceptions) + __nonzero__ = __bool__ + +# Use a flag to indicate the state of the bus. + + +class _StateEnum(object): + + class State(object): + name = None + + def __repr__(self): + return 'states.%s' % self.name + + def __setattr__(self, key, value): + if isinstance(value, self.State): + value.name = key + object.__setattr__(self, key, value) + + +states = _StateEnum() +states.STOPPED = states.State() +states.STARTING = states.State() +states.STARTED = states.State() +states.STOPPING = states.State() +states.EXITING = states.State() + + +try: + import fcntl +except ImportError: + max_files = 0 +else: + try: + max_files = os.sysconf('SC_OPEN_MAX') + except AttributeError: + max_files = 1024 + + +class Bus(object): + """Process state-machine and messenger for HTTP site deployment. + + All listeners for a given channel are guaranteed to be called even + if others at the same channel fail. Each failure is logged, but + execution proceeds on to the next listener. 
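    For instance (a sketch using only the ``subscribe``/``publish`` methods
    defined below), if one of two listeners on a channel raises, the other
    still runs and the failures are re-raised afterwards wrapped in a
    ChannelFailures exception::

        bus = Bus()
        calls = []
        bus.subscribe('graceful', lambda: 1 / 0)                    # will fail
        bus.subscribe('graceful', lambda: calls.append('ok'), priority=10)
        bus.publish('graceful')  # 'ok' is recorded, then ChannelFailures is raised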
The only way to stop all + processing from inside a listener is to raise SystemExit and stop the + whole server. + """ + + states = states + state = states.STOPPED + execv = False + max_cloexec_files = max_files + + def __init__(self): + """Initialize pub/sub bus.""" + self.execv = False + self.state = states.STOPPED + channels = 'start', 'stop', 'exit', 'graceful', 'log', 'main' + self.listeners = dict( + (channel, set()) + for channel in channels + ) + self._priorities = {} + + def subscribe(self, channel, callback=None, priority=None): + """Add the given callback at the given channel (if not present). + + If callback is None, return a partial suitable for decorating + the callback. + """ + if callback is None: + return functools.partial( + self.subscribe, + channel, + priority=priority, + ) + + ch_listeners = self.listeners.setdefault(channel, set()) + ch_listeners.add(callback) + + if priority is None: + priority = getattr(callback, 'priority', 50) + self._priorities[(channel, callback)] = priority + + def unsubscribe(self, channel, callback): + """Discard the given callback (if present).""" + listeners = self.listeners.get(channel) + if listeners and callback in listeners: + listeners.discard(callback) + del self._priorities[(channel, callback)] + + def publish(self, channel, *args, **kwargs): + """Return output of all subscribers for the given channel.""" + if channel not in self.listeners: + return [] + + exc = ChannelFailures() + output = [] + + raw_items = ( + (self._priorities[(channel, listener)], listener) + for listener in self.listeners[channel] + ) + items = sorted(raw_items, key=operator.itemgetter(0)) + for priority, listener in items: + try: + output.append(listener(*args, **kwargs)) + except KeyboardInterrupt: + raise + except SystemExit: + e = sys.exc_info()[1] + # If we have previous errors ensure the exit code is non-zero + if exc and e.code == 0: + e.code = 1 + raise + except Exception: + exc.handle_exception() + if channel == 'log': + # Assume any further messages to 'log' will fail. + pass + else: + self.log('Error in %r listener %r' % (channel, listener), + level=40, traceback=True) + if exc: + raise exc + return output + + def _clean_exit(self): + """Assert that the Bus is not running in atexit handler callback.""" + if self.state != states.EXITING: + warnings.warn( + 'The main thread is exiting, but the Bus is in the %r state; ' + 'shutting it down automatically now. You must either call ' + 'bus.block() after start(), or call bus.exit() before the ' + 'main thread exits.' % self.state, RuntimeWarning) + self.exit() + + def start(self): + """Start all services.""" + atexit.register(self._clean_exit) + + self.state = states.STARTING + self.log('Bus STARTING') + try: + self.publish('start') + self.state = states.STARTED + self.log('Bus STARTED') + except (KeyboardInterrupt, SystemExit): + raise + except Exception: + self.log('Shutting down due to error in start listener:', + level=40, traceback=True) + e_info = sys.exc_info()[1] + try: + self.exit() + except Exception: + # Any stop/exit errors will be logged inside publish(). + pass + # Re-raise the original error + raise e_info + + def exit(self): + """Stop all services and prepare to exit the process.""" + exitstate = self.state + EX_SOFTWARE = 70 + try: + self.stop() + + self.state = states.EXITING + self.log('Bus EXITING') + self.publish('exit') + # This isn't strictly necessary, but it's better than seeing + # "Waiting for child threads to terminate..." and then nothing. 
+ self.log('Bus EXITED') + except Exception: + # This method is often called asynchronously (whether thread, + # signal handler, console handler, or atexit handler), so we + # can't just let exceptions propagate out unhandled. + # Assume it's been logged and just die. + os._exit(EX_SOFTWARE) + + if exitstate == states.STARTING: + # exit() was called before start() finished, possibly due to + # Ctrl-C because a start listener got stuck. In this case, + # we could get stuck in a loop where Ctrl-C never exits the + # process, so we just call os.exit here. + os._exit(EX_SOFTWARE) + + def restart(self): + """Restart the process (may close connections). + + This method does not restart the process from the calling thread; + instead, it stops the bus and asks the main thread to call execv. + """ + self.execv = True + self.exit() + + def graceful(self): + """Advise all services to reload.""" + self.log('Bus graceful') + self.publish('graceful') + + def block(self, interval=0.1): + """Wait for the EXITING state, KeyboardInterrupt or SystemExit. + + This function is intended to be called only by the main thread. + After waiting for the EXITING state, it also waits for all threads + to terminate, and then calls os.execv if self.execv is True. This + design allows another thread to call bus.restart, yet have the main + thread perform the actual execv call (required on some platforms). + """ + try: + self.wait(states.EXITING, interval=interval, channel='main') + except (KeyboardInterrupt, IOError): + # The time.sleep call might raise + # "IOError: [Errno 4] Interrupted function call" on KBInt. + self.log('Keyboard Interrupt: shutting down bus') + self.exit() + except SystemExit: + self.log('SystemExit raised: shutting down bus') + self.exit() + raise + + # Waiting for ALL child threads to finish is necessary on OS X. + # See https://github.com/cherrypy/cherrypy/issues/581. + # It's also good to let them all shut down before allowing + # the main thread to call atexit handlers. + # See https://github.com/cherrypy/cherrypy/issues/751. + self.log('Waiting for child threads to terminate...') + for t in threading.enumerate(): + # Validate the we're not trying to join the MainThread + # that will cause a deadlock and the case exist when + # implemented as a windows service and in any other case + # that another thread executes cherrypy.engine.exit() + if ( + t != threading.currentThread() and + not isinstance(t, threading._MainThread) and + # Note that any dummy (external) threads are + # always daemonic. + not t.daemon + ): + self.log('Waiting for thread %s.' % t.getName()) + t.join() + + if self.execv: + self._do_execv() + + def wait(self, state, interval=0.1, channel=None): + """Poll for the given state(s) at intervals; publish to channel.""" + if isinstance(state, (tuple, list)): + states = state + else: + states = [state] + + while self.state not in states: + time.sleep(interval) + self.publish(channel) + + def _do_execv(self): + """Re-execute the current process. + + This must be called from the main thread, because certain platforms + (OS X) don't allow execv to be called in a child thread very well. 
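        The intended call pattern (a sketch, not additional API) is::

            bus.start()
            bus.block()     # the main thread waits here
            # Meanwhile any other thread may call bus.restart(), which sets
            # self.execv and exits the bus; block() then re-execs the
            # process from the main thread via this method.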
+ """ + try: + args = self._get_true_argv() + except NotImplementedError: + """It's probably win32 or GAE""" + args = [sys.executable] + self._get_interpreter_argv() + sys.argv + + self.log('Re-spawning %s' % ' '.join(args)) + + self._extend_pythonpath(os.environ) + + if sys.platform[:4] == 'java': + from _systemrestart import SystemRestart + raise SystemRestart + else: + if sys.platform == 'win32': + args = ['"%s"' % arg for arg in args] + + os.chdir(_startup_cwd) + if self.max_cloexec_files: + self._set_cloexec() + os.execv(sys.executable, args) + + @staticmethod + def _get_interpreter_argv(): + """Retrieve current Python interpreter's arguments. + + Returns empty tuple in case of frozen mode, uses built-in arguments + reproduction function otherwise. + + Frozen mode is possible for the app has been packaged into a binary + executable using py2exe. In this case the interpreter's arguments are + already built-in into that executable. + + :seealso: https://github.com/cherrypy/cherrypy/issues/1526 + Ref: https://pythonhosted.org/PyInstaller/runtime-information.html + """ + return ([] + if getattr(sys, 'frozen', False) + else subprocess._args_from_interpreter_flags()) + + @staticmethod + def _get_true_argv(): + """Retrieve all real arguments of the python interpreter. + + ...even those not listed in ``sys.argv`` + + :seealso: http://stackoverflow.com/a/28338254/595220 + :seealso: http://stackoverflow.com/a/6683222/595220 + :seealso: http://stackoverflow.com/a/28414807/595220 + """ + try: + char_p = ctypes.c_char_p if six.PY2 else ctypes.c_wchar_p + + argv = ctypes.POINTER(char_p)() + argc = ctypes.c_int() + + ctypes.pythonapi.Py_GetArgcArgv( + ctypes.byref(argc), + ctypes.byref(argv), + ) + + _argv = argv[:argc.value] + + # The code below is trying to correctly handle special cases. + # `-c`'s argument interpreted by Python itself becomes `-c` as + # well. Same applies to `-m`. This snippet is trying to survive + # at least the case with `-m` + # Ref: https://github.com/cherrypy/cherrypy/issues/1545 + # Ref: python/cpython@418baf9 + argv_len, is_command, is_module = len(_argv), False, False + + try: + m_ind = _argv.index('-m') + if m_ind < argv_len - 1 and _argv[m_ind + 1] in ('-c', '-m'): + """ + In some older Python versions `-m`'s argument may be + substituted with `-c`, not `-m` + """ + is_module = True + except (IndexError, ValueError): + m_ind = None + + try: + c_ind = _argv.index('-c') + if c_ind < argv_len - 1 and _argv[c_ind + 1] == '-c': + is_command = True + except (IndexError, ValueError): + c_ind = None + + if is_module: + """It's containing `-m -m` sequence of arguments""" + if is_command and c_ind < m_ind: + """There's `-c -c` before `-m`""" + raise RuntimeError( + "Cannot reconstruct command from '-c'. Ref: " + 'https://github.com/cherrypy/cherrypy/issues/1545') + # Survive module argument here + original_module = sys.argv[0] + if not os.access(original_module, os.R_OK): + """There's no such module exist""" + raise AttributeError( + "{} doesn't seem to be a module " + 'accessible by current user'.format(original_module)) + del _argv[m_ind:m_ind + 2] # remove `-m -m` + # ... and substitute it with the original module path: + _argv.insert(m_ind, original_module) + elif is_command: + """It's containing just `-c -c` sequence of arguments""" + raise RuntimeError( + "Cannot reconstruct command from '-c'. 
" + 'Ref: https://github.com/cherrypy/cherrypy/issues/1545') + except AttributeError: + """It looks Py_GetArgcArgv is completely absent in some environments + + It is known, that there's no Py_GetArgcArgv in MS Windows and + ``ctypes`` module is completely absent in Google AppEngine + + :seealso: https://github.com/cherrypy/cherrypy/issues/1506 + :seealso: https://github.com/cherrypy/cherrypy/issues/1512 + :ref: http://bit.ly/2gK6bXK + """ + raise NotImplementedError + else: + return _argv + + @staticmethod + def _extend_pythonpath(env): + """Prepend current working dir to PATH environment variable if needed. + + If sys.path[0] is an empty string, the interpreter was likely + invoked with -m and the effective path is about to change on + re-exec. Add the current directory to $PYTHONPATH to ensure + that the new process sees the same path. + + This issue cannot be addressed in the general case because + Python cannot reliably reconstruct the + original command line (http://bugs.python.org/issue14208). + + (This idea filched from tornado.autoreload) + """ + path_prefix = '.' + os.pathsep + existing_path = env.get('PYTHONPATH', '') + needs_patch = ( + sys.path[0] == '' and + not existing_path.startswith(path_prefix) + ) + + if needs_patch: + env['PYTHONPATH'] = path_prefix + existing_path + + def _set_cloexec(self): + """Set the CLOEXEC flag on all open files (except stdin/out/err). + + If self.max_cloexec_files is an integer (the default), then on + platforms which support it, it represents the max open files setting + for the operating system. This function will be called just before + the process is restarted via os.execv() to prevent open files + from persisting into the new process. + + Set self.max_cloexec_files to 0 to disable this behavior. + """ + for fd in range(3, self.max_cloexec_files): # skip stdin/out/err + try: + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + except IOError: + continue + fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + + def stop(self): + """Stop all services.""" + self.state = states.STOPPING + self.log('Bus STOPPING') + self.publish('stop') + self.state = states.STOPPED + self.log('Bus STOPPED') + + def start_with_callback(self, func, args=None, kwargs=None): + """Start 'func' in a new thread T, then start self (and return T).""" + if args is None: + args = () + if kwargs is None: + kwargs = {} + args = (func,) + args + + def _callback(func, *a, **kw): + self.wait(states.STARTED) + func(*a, **kw) + t = threading.Thread(target=_callback, args=args, kwargs=kwargs) + t.setName('Bus Callback ' + t.getName()) + t.start() + + self.start() + + return t + + def log(self, msg='', level=20, traceback=False): + """Log the given message. Append the last traceback if requested.""" + if traceback: + msg += '\n' + ''.join(_traceback.format_exception(*sys.exc_info())) + self.publish('log', msg, level) + + +bus = Bus() diff --git a/libraries/cherrypy/scaffold/__init__.py b/libraries/cherrypy/scaffold/__init__.py new file mode 100644 index 00000000..bcddba2d --- /dev/null +++ b/libraries/cherrypy/scaffold/__init__.py @@ -0,0 +1,63 @@ +"""<MyProject>, a CherryPy application. + +Use this as a base for creating new CherryPy applications. When you want +to make a new app, copy and paste this folder to some other location +(maybe site-packages) and rename it to the name of your project, +then tweak as desired. + +Even before any tweaking, this should serve a few demonstration pages. 
+Change to this directory and run: + + cherryd -c site.conf + +""" + +import cherrypy +from cherrypy import tools, url + +import os +local_dir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +@cherrypy.config(**{'tools.log_tracebacks.on': True}) +class Root: + """Declaration of the CherryPy app URI structure.""" + + @cherrypy.expose + def index(self): + """Render HTML-template at the root path of the web-app.""" + return """<html> +<body>Try some <a href='%s?a=7'>other</a> path, +or a <a href='%s?n=14'>default</a> path.<br /> +Or, just look at the pretty picture:<br /> +<img src='%s' /> +</body></html>""" % (url('other'), url('else'), + url('files/made_with_cherrypy_small.png')) + + @cherrypy.expose + def default(self, *args, **kwargs): + """Render catch-all args and kwargs.""" + return 'args: %s kwargs: %s' % (args, kwargs) + + @cherrypy.expose + def other(self, a=2, b='bananas', c=None): + """Render number of fruits based on third argument.""" + cherrypy.response.headers['Content-Type'] = 'text/plain' + if c is None: + return 'Have %d %s.' % (int(a), b) + else: + return 'Have %d %s, %s.' % (int(a), b, c) + + files = tools.staticdir.handler( + section='/files', + dir=os.path.join(local_dir, 'static'), + # Ignore .php files, etc. + match=r'\.(css|gif|html?|ico|jpe?g|js|png|swf|xml)$', + ) + + +root = Root() + +# Uncomment the following to use your own favicon instead of CP's default. +# favicon_path = os.path.join(local_dir, "favicon.ico") +# root.favicon_ico = tools.staticfile.handler(filename=favicon_path) diff --git a/libraries/cherrypy/scaffold/apache-fcgi.conf b/libraries/cherrypy/scaffold/apache-fcgi.conf new file mode 100644 index 00000000..6e4f144c --- /dev/null +++ b/libraries/cherrypy/scaffold/apache-fcgi.conf @@ -0,0 +1,22 @@ +# Apache2 server conf file for using CherryPy with mod_fcgid. + +# This doesn't have to be "C:/", but it has to be a directory somewhere, and +# MUST match the directory used in the FastCgiExternalServer directive, below. +DocumentRoot "C:/" + +ServerName 127.0.0.1 +Listen 80 +LoadModule fastcgi_module modules/mod_fastcgi.dll +LoadModule rewrite_module modules/mod_rewrite.so + +Options ExecCGI +SetHandler fastcgi-script +RewriteEngine On +# Send requests for any URI to our fastcgi handler. +RewriteRule ^(.*)$ /fastcgi.pyc [L] + +# The FastCgiExternalServer directive defines filename as an external FastCGI application. +# If filename does not begin with a slash (/) then it is assumed to be relative to the ServerRoot. +# The filename does not have to exist in the local filesystem. URIs that Apache resolves to this +# filename will be handled by this external FastCGI application. 
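# The directive below hands such requests to an external FastCGI process on
# 127.0.0.1:8088. On the CherryPy side that process would typically be a
# flup-based server registered with the engine, roughly (a sketch, assuming
# the application is already mounted on cherrypy.tree):
#
#   from cherrypy.process.servers import FlupFCGIServer, ServerAdapter
#   fcgi = FlupFCGIServer(application=cherrypy.tree,
#                         bindAddress=('127.0.0.1', 8088))
#   ServerAdapter(cherrypy.engine, fcgi, bind_addr=('127.0.0.1', 8088)).subscribe()
#   cherrypy.engine.start()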
+FastCgiExternalServer "C:/fastcgi.pyc" -host 127.0.0.1:8088 diff --git a/libraries/cherrypy/scaffold/example.conf b/libraries/cherrypy/scaffold/example.conf new file mode 100644 index 00000000..63250fe3 --- /dev/null +++ b/libraries/cherrypy/scaffold/example.conf @@ -0,0 +1,3 @@ +[/] +log.error_file: "error.log" +log.access_file: "access.log" diff --git a/libraries/cherrypy/scaffold/site.conf b/libraries/cherrypy/scaffold/site.conf new file mode 100644 index 00000000..6ed38983 --- /dev/null +++ b/libraries/cherrypy/scaffold/site.conf @@ -0,0 +1,14 @@ +[global] +# Uncomment this when you're done developing +#environment: "production" + +server.socket_host: "0.0.0.0" +server.socket_port: 8088 + +# Uncomment the following lines to run on HTTPS at the same time +#server.2.socket_host: "0.0.0.0" +#server.2.socket_port: 8433 +#server.2.ssl_certificate: '../test/test.pem' +#server.2.ssl_private_key: '../test/test.pem' + +tree.myapp: cherrypy.Application(scaffold.root, "/", "example.conf") diff --git a/libraries/cherrypy/scaffold/static/made_with_cherrypy_small.png b/libraries/cherrypy/scaffold/static/made_with_cherrypy_small.png new file mode 100644 index 00000000..724f9d72 Binary files /dev/null and b/libraries/cherrypy/scaffold/static/made_with_cherrypy_small.png differ diff --git a/libraries/cherrypy/test/__init__.py b/libraries/cherrypy/test/__init__.py new file mode 100644 index 00000000..068382be --- /dev/null +++ b/libraries/cherrypy/test/__init__.py @@ -0,0 +1,24 @@ +""" +Regression test suite for CherryPy. +""" + +import os +import sys + + +def newexit(): + os._exit(1) + + +def setup(): + # We want to monkey patch sys.exit so that we can get some + # information about where exit is being called. + newexit._old = sys.exit + sys.exit = newexit + + +def teardown(): + try: + sys.exit = sys.exit._old + except AttributeError: + sys.exit = sys._exit diff --git a/libraries/cherrypy/test/_test_decorators.py b/libraries/cherrypy/test/_test_decorators.py new file mode 100644 index 00000000..74832e40 --- /dev/null +++ b/libraries/cherrypy/test/_test_decorators.py @@ -0,0 +1,39 @@ +"""Test module for the @-decorator syntax, which is version-specific""" + +import cherrypy +from cherrypy import expose, tools + + +class ExposeExamples(object): + + @expose + def no_call(self): + return 'Mr E. R. Bradshaw' + + @expose() + def call_empty(self): + return 'Mrs. B.J. Smegma' + + @expose('call_alias') + def nesbitt(self): + return 'Mr Nesbitt' + + @expose(['alias1', 'alias2']) + def andrews(self): + return 'Mr Ken Andrews' + + @expose(alias='alias3') + def watson(self): + return 'Mr. and Mrs. Watson' + + +class ToolExamples(object): + + @expose + # This is here to demonstrate that using the config decorator + # does not overwrite other config attributes added by the Tool + # decorator (in this case response_headers). 
+ @cherrypy.config(**{'response.stream': True}) + @tools.response_headers(headers=[('Content-Type', 'application/data')]) + def blah(self): + yield b'blah' diff --git a/libraries/cherrypy/test/_test_states_demo.py b/libraries/cherrypy/test/_test_states_demo.py new file mode 100644 index 00000000..a49407ba --- /dev/null +++ b/libraries/cherrypy/test/_test_states_demo.py @@ -0,0 +1,69 @@ +import os +import sys +import time + +import cherrypy + +starttime = time.time() + + +class Root: + + @cherrypy.expose + def index(self): + return 'Hello World' + + @cherrypy.expose + def mtimes(self): + return repr(cherrypy.engine.publish('Autoreloader', 'mtimes')) + + @cherrypy.expose + def pid(self): + return str(os.getpid()) + + @cherrypy.expose + def start(self): + return repr(starttime) + + @cherrypy.expose + def exit(self): + # This handler might be called before the engine is STARTED if an + # HTTP worker thread handles it before the HTTP server returns + # control to engine.start. We avoid that race condition here + # by waiting for the Bus to be STARTED. + cherrypy.engine.wait(state=cherrypy.engine.states.STARTED) + cherrypy.engine.exit() + + +@cherrypy.engine.subscribe('start', priority=100) +def unsub_sig(): + cherrypy.log('unsubsig: %s' % cherrypy.config.get('unsubsig', False)) + if cherrypy.config.get('unsubsig', False): + cherrypy.log('Unsubscribing the default cherrypy signal handler') + cherrypy.engine.signal_handler.unsubscribe() + try: + from signal import signal, SIGTERM + except ImportError: + pass + else: + def old_term_handler(signum=None, frame=None): + cherrypy.log('I am an old SIGTERM handler.') + sys.exit(0) + cherrypy.log('Subscribing the new one.') + signal(SIGTERM, old_term_handler) + + +@cherrypy.engine.subscribe('start', priority=6) +def starterror(): + if cherrypy.config.get('starterror', False): + 1 / 0 + + +@cherrypy.engine.subscribe('start', priority=6) +def log_test_case_name(): + if cherrypy.config.get('test_case_name', False): + cherrypy.log('STARTED FROM: %s' % + cherrypy.config.get('test_case_name')) + + +cherrypy.tree.mount(Root(), '/', {'/': {}}) diff --git a/libraries/cherrypy/test/benchmark.py b/libraries/cherrypy/test/benchmark.py new file mode 100644 index 00000000..44dfeff1 --- /dev/null +++ b/libraries/cherrypy/test/benchmark.py @@ -0,0 +1,425 @@ +"""CherryPy Benchmark Tool + + Usage: + benchmark.py [options] + + --null: use a null Request object (to bench the HTTP server only) + --notests: start the server but do not run the tests; this allows + you to check the tested pages with a browser + --help: show this help message + --cpmodpy: run tests via apache on 54583 (with the builtin _cpmodpy) + --modpython: run tests via apache on 54583 (with modpython_gateway) + --ab=path: Use the ab script/executable at 'path' (see below) + --apache=path: Use the apache script/exe at 'path' (see below) + + To run the benchmarks, the Apache Benchmark tool "ab" must either be on + your system path, or specified via the --ab=path option. + + To run the modpython tests, the "apache" executable or script must be + on your system path, or provided via the --apache=path option. On some + platforms, "apache" may be called "apachectl" or "apache2ctl"--create + a symlink to them if needed. 
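    A typical local run (the paths are only examples) therefore looks like:

        benchmark.py --ab=/usr/local/bin/ab

    or, to leave the server running for manual inspection in a browser:

        benchmark.py --notests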
+""" + +import getopt +import os +import re +import sys +import time + +import cherrypy +from cherrypy import _cperror, _cpmodpy +from cherrypy.lib import httputil + + +curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + +AB_PATH = '' +APACHE_PATH = 'apache' +SCRIPT_NAME = '/cpbench/users/rdelon/apps/blog' + +__all__ = ['ABSession', 'Root', 'print_report', + 'run_standard_benchmarks', 'safe_threads', + 'size_report', 'thread_report', + ] + +size_cache = {} + + +class Root: + + @cherrypy.expose + def index(self): + return """<html> +<head> + <title>CherryPy Benchmark</title> +</head> +<body> + <ul> + <li><a href="hello">Hello, world! (14 byte dynamic)</a></li> + <li><a href="static/index.html">Static file (14 bytes static)</a></li> + <li><form action="sizer">Response of length: + <input type='text' name='size' value='10' /></form> + </li> + </ul> +</body> +</html>""" + + @cherrypy.expose + def hello(self): + return 'Hello, world\r\n' + + @cherrypy.expose + def sizer(self, size): + resp = size_cache.get(size, None) + if resp is None: + size_cache[size] = resp = 'X' * int(size) + return resp + + +def init(): + + cherrypy.config.update({ + 'log.error.file': '', + 'environment': 'production', + 'server.socket_host': '127.0.0.1', + 'server.socket_port': 54583, + 'server.max_request_header_size': 0, + 'server.max_request_body_size': 0, + }) + + # Cheat mode on ;) + del cherrypy.config['tools.log_tracebacks.on'] + del cherrypy.config['tools.log_headers.on'] + del cherrypy.config['tools.trailing_slash.on'] + + appconf = { + '/static': { + 'tools.staticdir.on': True, + 'tools.staticdir.dir': 'static', + 'tools.staticdir.root': curdir, + }, + } + globals().update( + app=cherrypy.tree.mount(Root(), SCRIPT_NAME, appconf), + ) + + +class NullRequest: + + """A null HTTP request class, returning 200 and an empty body.""" + + def __init__(self, local, remote, scheme='http'): + pass + + def close(self): + pass + + def run(self, method, path, query_string, protocol, headers, rfile): + cherrypy.response.status = '200 OK' + cherrypy.response.header_list = [('Content-Type', 'text/html'), + ('Server', 'Null CherryPy'), + ('Date', httputil.HTTPDate()), + ('Content-Length', '0'), + ] + cherrypy.response.body = [''] + return cherrypy.response + + +class NullResponse: + pass + + +class ABSession: + + """A session of 'ab', the Apache HTTP server benchmarking tool. 
+ +Example output from ab: + +This is ApacheBench, Version 2.0.40-dev <$Revision: 1.121.2.1 $> apache-2.0 +Copyright (c) 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ +Copyright (c) 1998-2002 The Apache Software Foundation, http://www.apache.org/ + +Benchmarking 127.0.0.1 (be patient) +Completed 100 requests +Completed 200 requests +Completed 300 requests +Completed 400 requests +Completed 500 requests +Completed 600 requests +Completed 700 requests +Completed 800 requests +Completed 900 requests + + +Server Software: CherryPy/3.1beta +Server Hostname: 127.0.0.1 +Server Port: 54583 + +Document Path: /static/index.html +Document Length: 14 bytes + +Concurrency Level: 10 +Time taken for tests: 9.643867 seconds +Complete requests: 1000 +Failed requests: 0 +Write errors: 0 +Total transferred: 189000 bytes +HTML transferred: 14000 bytes +Requests per second: 103.69 [#/sec] (mean) +Time per request: 96.439 [ms] (mean) +Time per request: 9.644 [ms] (mean, across all concurrent requests) +Transfer rate: 19.08 [Kbytes/sec] received + +Connection Times (ms) + min mean[+/-sd] median max +Connect: 0 0 2.9 0 10 +Processing: 20 94 7.3 90 130 +Waiting: 0 43 28.1 40 100 +Total: 20 95 7.3 100 130 + +Percentage of the requests served within a certain time (ms) + 50% 100 + 66% 100 + 75% 100 + 80% 100 + 90% 100 + 95% 100 + 98% 100 + 99% 110 + 100% 130 (longest request) +Finished 1000 requests +""" + + parse_patterns = [ + ('complete_requests', 'Completed', + br'^Complete requests:\s*(\d+)'), + ('failed_requests', 'Failed', + br'^Failed requests:\s*(\d+)'), + ('requests_per_second', 'req/sec', + br'^Requests per second:\s*([0-9.]+)'), + ('time_per_request_concurrent', 'msec/req', + br'^Time per request:\s*([0-9.]+).*concurrent requests\)$'), + ('transfer_rate', 'KB/sec', + br'^Transfer rate:\s*([0-9.]+)') + ] + + def __init__(self, path=SCRIPT_NAME + '/hello', requests=1000, + concurrency=10): + self.path = path + self.requests = requests + self.concurrency = concurrency + + def args(self): + port = cherrypy.server.socket_port + assert self.concurrency > 0 + assert self.requests > 0 + # Don't use "localhost". + # Cf + # http://mail.python.org/pipermail/python-win32/2008-March/007050.html + return ('-k -n %s -c %s http://127.0.0.1:%s%s' % + (self.requests, self.concurrency, port, self.path)) + + def run(self): + # Parse output of ab, setting attributes on self + try: + self.output = _cpmodpy.read_process(AB_PATH or 'ab', self.args()) + except Exception: + print(_cperror.format_exc()) + raise + + for attr, name, pattern in self.parse_patterns: + val = re.search(pattern, self.output, re.MULTILINE) + if val: + val = val.group(1) + setattr(self, attr, val) + else: + setattr(self, attr, None) + + +safe_threads = (25, 50, 100, 200, 400) +if sys.platform in ('win32',): + # For some reason, ab crashes with > 50 threads on my Win2k laptop. + safe_threads = (10, 20, 30, 40, 50) + + +def thread_report(path=SCRIPT_NAME + '/hello', concurrency=safe_threads): + sess = ABSession(path) + attrs, names, patterns = list(zip(*sess.parse_patterns)) + avg = dict.fromkeys(attrs, 0.0) + + yield ('threads',) + names + for c in concurrency: + sess.concurrency = c + sess.run() + row = [c] + for attr in attrs: + val = getattr(sess, attr) + if val is None: + print(sess.output) + row = None + break + val = float(val) + avg[attr] += float(val) + row.append(val) + if row: + yield row + + # Add a row of averages. 
+ yield ['Average'] + [str(avg[attr] / len(concurrency)) for attr in attrs] + + +def size_report(sizes=(10, 100, 1000, 10000, 100000, 100000000), + concurrency=50): + sess = ABSession(concurrency=concurrency) + attrs, names, patterns = list(zip(*sess.parse_patterns)) + yield ('bytes',) + names + for sz in sizes: + sess.path = '%s/sizer?size=%s' % (SCRIPT_NAME, sz) + sess.run() + yield [sz] + [getattr(sess, attr) for attr in attrs] + + +def print_report(rows): + for row in rows: + print('') + for val in row: + sys.stdout.write(str(val).rjust(10) + ' | ') + print('') + + +def run_standard_benchmarks(): + print('') + print('Client Thread Report (1000 requests, 14 byte response body, ' + '%s server threads):' % cherrypy.server.thread_pool) + print_report(thread_report()) + + print('') + print('Client Thread Report (1000 requests, 14 bytes via staticdir, ' + '%s server threads):' % cherrypy.server.thread_pool) + print_report(thread_report('%s/static/index.html' % SCRIPT_NAME)) + + print('') + print('Size Report (1000 requests, 50 client threads, ' + '%s server threads):' % cherrypy.server.thread_pool) + print_report(size_report()) + + +# modpython and other WSGI # + +def startup_modpython(req=None): + """Start the CherryPy app server in 'serverless' mode (for modpython/WSGI). + """ + if cherrypy.engine.state == cherrypy._cpengine.STOPPED: + if req: + if 'nullreq' in req.get_options(): + cherrypy.engine.request_class = NullRequest + cherrypy.engine.response_class = NullResponse + ab_opt = req.get_options().get('ab', '') + if ab_opt: + global AB_PATH + AB_PATH = ab_opt + cherrypy.engine.start() + if cherrypy.engine.state == cherrypy._cpengine.STARTING: + cherrypy.engine.wait() + return 0 # apache.OK + + +def run_modpython(use_wsgi=False): + print('Starting mod_python...') + pyopts = [] + + # Pass the null and ab=path options through Apache + if '--null' in opts: + pyopts.append(('nullreq', '')) + + if '--ab' in opts: + pyopts.append(('ab', opts['--ab'])) + + s = _cpmodpy.ModPythonServer + if use_wsgi: + pyopts.append(('wsgi.application', 'cherrypy::tree')) + pyopts.append( + ('wsgi.startup', 'cherrypy.test.benchmark::startup_modpython')) + handler = 'modpython_gateway::handler' + s = s(port=54583, opts=pyopts, + apache_path=APACHE_PATH, handler=handler) + else: + pyopts.append( + ('cherrypy.setup', 'cherrypy.test.benchmark::startup_modpython')) + s = s(port=54583, opts=pyopts, apache_path=APACHE_PATH) + + try: + s.start() + run() + finally: + s.stop() + + +if __name__ == '__main__': + init() + + longopts = ['cpmodpy', 'modpython', 'null', 'notests', + 'help', 'ab=', 'apache='] + try: + switches, args = getopt.getopt(sys.argv[1:], '', longopts) + opts = dict(switches) + except getopt.GetoptError: + print(__doc__) + sys.exit(2) + + if '--help' in opts: + print(__doc__) + sys.exit(0) + + if '--ab' in opts: + AB_PATH = opts['--ab'] + + if '--notests' in opts: + # Return without stopping the server, so that the pages + # can be tested from a standard web browser. 
+ def run(): + port = cherrypy.server.socket_port + print('You may now open http://127.0.0.1:%s%s/' % + (port, SCRIPT_NAME)) + + if '--null' in opts: + print('Using null Request object') + else: + def run(): + end = time.time() - start + print('Started in %s seconds' % end) + if '--null' in opts: + print('\nUsing null Request object') + try: + try: + run_standard_benchmarks() + except Exception: + print(_cperror.format_exc()) + raise + finally: + cherrypy.engine.exit() + + print('Starting CherryPy app server...') + + class NullWriter(object): + + """Suppresses the printing of socket errors.""" + + def write(self, data): + pass + sys.stderr = NullWriter() + + start = time.time() + + if '--cpmodpy' in opts: + run_modpython() + elif '--modpython' in opts: + run_modpython(use_wsgi=True) + else: + if '--null' in opts: + cherrypy.server.request_class = NullRequest + cherrypy.server.response_class = NullResponse + + cherrypy.engine.start_with_callback(run) + cherrypy.engine.block() diff --git a/libraries/cherrypy/test/checkerdemo.py b/libraries/cherrypy/test/checkerdemo.py new file mode 100644 index 00000000..3438bd0c --- /dev/null +++ b/libraries/cherrypy/test/checkerdemo.py @@ -0,0 +1,49 @@ +"""Demonstration app for cherrypy.checker. + +This application is intentionally broken and badly designed. +To demonstrate the output of the CherryPy Checker, simply execute +this module. +""" + +import os +import cherrypy +thisdir = os.path.dirname(os.path.abspath(__file__)) + + +class Root: + pass + + +if __name__ == '__main__': + conf = {'/base': {'tools.staticdir.root': thisdir, + # Obsolete key. + 'throw_errors': True, + }, + # This entry should be OK. + '/base/static': {'tools.staticdir.on': True, + 'tools.staticdir.dir': 'static'}, + # Warn on missing folder. + '/base/js': {'tools.staticdir.on': True, + 'tools.staticdir.dir': 'js'}, + # Warn on dir with an abs path even though we provide root. + '/base/static2': {'tools.staticdir.on': True, + 'tools.staticdir.dir': '/static'}, + # Warn on dir with a relative path with no root. + '/static3': {'tools.staticdir.on': True, + 'tools.staticdir.dir': 'static'}, + # Warn on unknown namespace + '/unknown': {'toobles.gzip.on': True}, + # Warn special on cherrypy.<known ns>.* + '/cpknown': {'cherrypy.tools.encode.on': True}, + # Warn on mismatched types + '/conftype': {'request.show_tracebacks': 14}, + # Warn on unknown tool. + '/web': {'tools.unknown.on': True}, + # Warn on server.* in app config. + '/app1': {'server.socket_host': '0.0.0.0'}, + # Warn on 'localhost' + 'global': {'server.socket_host': 'localhost'}, + # Warn on '[name]' + '[/extra_brackets]': {}, + } + cherrypy.quickstart(Root(), config=conf) diff --git a/libraries/cherrypy/test/fastcgi.conf b/libraries/cherrypy/test/fastcgi.conf new file mode 100644 index 00000000..e5c5163c --- /dev/null +++ b/libraries/cherrypy/test/fastcgi.conf @@ -0,0 +1,18 @@ + +# Apache2 server conf file for testing CherryPy with mod_fastcgi. 
+# fumanchu: I had to hard-code paths due to crazy Debian layouts :( +ServerRoot /usr/lib/apache2 +User #1000 +ErrorLog /usr/lib/python2.5/site-packages/cproot/trunk/cherrypy/test/mod_fastcgi.error.log + +DocumentRoot "/usr/lib/python2.5/site-packages/cproot/trunk/cherrypy/test" +ServerName 127.0.0.1 +Listen 8080 +LoadModule fastcgi_module modules/mod_fastcgi.so +LoadModule rewrite_module modules/mod_rewrite.so + +Options +ExecCGI +SetHandler fastcgi-script +RewriteEngine On +RewriteRule ^(.*)$ /fastcgi.pyc [L] +FastCgiExternalServer "/usr/lib/python2.5/site-packages/cproot/trunk/cherrypy/test/fastcgi.pyc" -host 127.0.0.1:4000 diff --git a/libraries/cherrypy/test/fcgi.conf b/libraries/cherrypy/test/fcgi.conf new file mode 100644 index 00000000..3062eb35 --- /dev/null +++ b/libraries/cherrypy/test/fcgi.conf @@ -0,0 +1,14 @@ + +# Apache2 server conf file for testing CherryPy with mod_fcgid. + +DocumentRoot "/usr/lib/python2.6/site-packages/cproot/trunk/cherrypy/test" +ServerName 127.0.0.1 +Listen 8080 +LoadModule fastcgi_module modules/mod_fastcgi.dll +LoadModule rewrite_module modules/mod_rewrite.so + +Options ExecCGI +SetHandler fastcgi-script +RewriteEngine On +RewriteRule ^(.*)$ /fastcgi.pyc [L] +FastCgiExternalServer "/usr/lib/python2.6/site-packages/cproot/trunk/cherrypy/test/fastcgi.pyc" -host 127.0.0.1:4000 diff --git a/libraries/cherrypy/test/helper.py b/libraries/cherrypy/test/helper.py new file mode 100644 index 00000000..01c5a0c0 --- /dev/null +++ b/libraries/cherrypy/test/helper.py @@ -0,0 +1,542 @@ +"""A library of helper functions for the CherryPy test suite.""" + +import datetime +import io +import logging +import os +import re +import subprocess +import sys +import time +import unittest +import warnings + +import portend +import pytest +import six + +from cheroot.test import webtest + +import cherrypy +from cherrypy._cpcompat import text_or_bytes, HTTPSConnection, ntob +from cherrypy.lib import httputil +from cherrypy.lib import gctools + +log = logging.getLogger(__name__) +thisdir = os.path.abspath(os.path.dirname(__file__)) +serverpem = os.path.join(os.getcwd(), thisdir, 'test.pem') + + +class Supervisor(object): + + """Base class for modeling and controlling servers during testing.""" + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if k == 'port': + setattr(self, k, int(v)) + setattr(self, k, v) + + +def log_to_stderr(msg, level): + return sys.stderr.write(msg + os.linesep) + + +class LocalSupervisor(Supervisor): + + """Base class for modeling/controlling servers which run in the same + process. + + When the server side runs in a different process, start/stop can dump all + state between each test module easily. When the server side runs in the + same process as the client, however, we have to do a bit more work to + ensure config and mounted apps are reset between tests. + """ + + using_apache = False + using_wsgi = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + cherrypy.server.httpserver = self.httpserver_class + + # This is perhaps the wrong place for this call but this is the only + # place that i've found so far that I KNOW is early enough to set this. 
+ cherrypy.config.update({'log.screen': False}) + engine = cherrypy.engine + if hasattr(engine, 'signal_handler'): + engine.signal_handler.subscribe() + if hasattr(engine, 'console_control_handler'): + engine.console_control_handler.subscribe() + + def start(self, modulename=None): + """Load and start the HTTP server.""" + if modulename: + # Unhook httpserver so cherrypy.server.start() creates a new + # one (with config from setup_server, if declared). + cherrypy.server.httpserver = None + + cherrypy.engine.start() + + self.sync_apps() + + def sync_apps(self): + """Tell the server about any apps which the setup functions mounted.""" + pass + + def stop(self): + td = getattr(self, 'teardown', None) + if td: + td() + + cherrypy.engine.exit() + + servers_copy = list(six.iteritems(getattr(cherrypy, 'servers', {}))) + for name, server in servers_copy: + server.unsubscribe() + del cherrypy.servers[name] + + +class NativeServerSupervisor(LocalSupervisor): + + """Server supervisor for the builtin HTTP server.""" + + httpserver_class = 'cherrypy._cpnative_server.CPHTTPServer' + using_apache = False + using_wsgi = False + + def __str__(self): + return 'Builtin HTTP Server on %s:%s' % (self.host, self.port) + + +class LocalWSGISupervisor(LocalSupervisor): + + """Server supervisor for the builtin WSGI server.""" + + httpserver_class = 'cherrypy._cpwsgi_server.CPWSGIServer' + using_apache = False + using_wsgi = True + + def __str__(self): + return 'Builtin WSGI Server on %s:%s' % (self.host, self.port) + + def sync_apps(self): + """Hook a new WSGI app into the origin server.""" + cherrypy.server.httpserver.wsgi_app = self.get_app() + + def get_app(self, app=None): + """Obtain a new (decorated) WSGI app to hook into the origin server.""" + if app is None: + app = cherrypy.tree + + if self.validate: + try: + from wsgiref import validate + except ImportError: + warnings.warn( + 'Error importing wsgiref. 
The validator will not run.') + else: + # wraps the app in the validator + app = validate.validator(app) + + return app + + +def get_cpmodpy_supervisor(**options): + from cherrypy.test import modpy + sup = modpy.ModPythonSupervisor(**options) + sup.template = modpy.conf_cpmodpy + return sup + + +def get_modpygw_supervisor(**options): + from cherrypy.test import modpy + sup = modpy.ModPythonSupervisor(**options) + sup.template = modpy.conf_modpython_gateway + sup.using_wsgi = True + return sup + + +def get_modwsgi_supervisor(**options): + from cherrypy.test import modwsgi + return modwsgi.ModWSGISupervisor(**options) + + +def get_modfcgid_supervisor(**options): + from cherrypy.test import modfcgid + return modfcgid.ModFCGISupervisor(**options) + + +def get_modfastcgi_supervisor(**options): + from cherrypy.test import modfastcgi + return modfastcgi.ModFCGISupervisor(**options) + + +def get_wsgi_u_supervisor(**options): + cherrypy.server.wsgi_version = ('u', 0) + return LocalWSGISupervisor(**options) + + +class CPWebCase(webtest.WebCase): + + script_name = '' + scheme = 'http' + + available_servers = {'wsgi': LocalWSGISupervisor, + 'wsgi_u': get_wsgi_u_supervisor, + 'native': NativeServerSupervisor, + 'cpmodpy': get_cpmodpy_supervisor, + 'modpygw': get_modpygw_supervisor, + 'modwsgi': get_modwsgi_supervisor, + 'modfcgid': get_modfcgid_supervisor, + 'modfastcgi': get_modfastcgi_supervisor, + } + default_server = 'wsgi' + + @classmethod + def _setup_server(cls, supervisor, conf): + v = sys.version.split()[0] + log.info('Python version used to run this test script: %s' % v) + log.info('CherryPy version: %s' % cherrypy.__version__) + if supervisor.scheme == 'https': + ssl = ' (ssl)' + else: + ssl = '' + log.info('HTTP server version: %s%s' % (supervisor.protocol, ssl)) + log.info('PID: %s' % os.getpid()) + + cherrypy.server.using_apache = supervisor.using_apache + cherrypy.server.using_wsgi = supervisor.using_wsgi + + if sys.platform[:4] == 'java': + cherrypy.config.update({'server.nodelay': False}) + + if isinstance(conf, text_or_bytes): + parser = cherrypy.lib.reprconf.Parser() + conf = parser.dict_from_file(conf).get('global', {}) + else: + conf = conf or {} + baseconf = conf.copy() + baseconf.update({'server.socket_host': supervisor.host, + 'server.socket_port': supervisor.port, + 'server.protocol_version': supervisor.protocol, + 'environment': 'test_suite', + }) + if supervisor.scheme == 'https': + # baseconf['server.ssl_module'] = 'builtin' + baseconf['server.ssl_certificate'] = serverpem + baseconf['server.ssl_private_key'] = serverpem + + # helper must be imported lazily so the coverage tool + # can run against module-level statements within cherrypy. + # Also, we have to do "from cherrypy.test import helper", + # exactly like each test module does, because a relative import + # would stick a second instance of webtest in sys.modules, + # and we wouldn't be able to globally override the port anymore. 
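        # For reference, a typical test module built on this helper (a sketch
        # mirroring the import pattern described above) looks like:
        #
        #     import cherrypy
        #     from cherrypy.test import helper
        #
        #     class SimpleCPTest(helper.CPWebCase):
        #         @staticmethod
        #         def setup_server():
        #             class Root(object):
        #                 @cherrypy.expose
        #                 def echo(self, msg):
        #                     return msg
        #             cherrypy.tree.mount(Root())
        #
        #         def test_echo(self):
        #             self.getPage('/echo?msg=hi')
        #             self.assertStatus('200 OK')
        #             self.assertBody('hi')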
+ if supervisor.scheme == 'https': + webtest.WebCase.HTTP_CONN = HTTPSConnection + return baseconf + + @classmethod + def setup_class(cls): + '' + # Creates a server + conf = { + 'scheme': 'http', + 'protocol': 'HTTP/1.1', + 'port': 54583, + 'host': '127.0.0.1', + 'validate': False, + 'server': 'wsgi', + } + supervisor_factory = cls.available_servers.get( + conf.get('server', 'wsgi')) + if supervisor_factory is None: + raise RuntimeError('Unknown server in config: %s' % conf['server']) + supervisor = supervisor_factory(**conf) + + # Copied from "run_test_suite" + cherrypy.config.reset() + baseconf = cls._setup_server(supervisor, conf) + cherrypy.config.update(baseconf) + setup_client() + + if hasattr(cls, 'setup_server'): + # Clear the cherrypy tree and clear the wsgi server so that + # it can be updated with the new root + cherrypy.tree = cherrypy._cptree.Tree() + cherrypy.server.httpserver = None + cls.setup_server() + # Add a resource for verifying there are no refleaks + # to *every* test class. + cherrypy.tree.mount(gctools.GCRoot(), '/gc') + cls.do_gc_test = True + supervisor.start(cls.__module__) + + cls.supervisor = supervisor + + @classmethod + def teardown_class(cls): + '' + if hasattr(cls, 'setup_server'): + cls.supervisor.stop() + + do_gc_test = False + + def test_gc(self): + if not self.do_gc_test: + return + + self.getPage('/gc/stats') + try: + self.assertBody('Statistics:') + except Exception: + 'Failures occur intermittently. See #1420' + + def prefix(self): + return self.script_name.rstrip('/') + + def base(self): + if ((self.scheme == 'http' and self.PORT == 80) or + (self.scheme == 'https' and self.PORT == 443)): + port = '' + else: + port = ':%s' % self.PORT + + return '%s://%s%s%s' % (self.scheme, self.HOST, port, + self.script_name.rstrip('/')) + + def exit(self): + sys.exit() + + def getPage(self, url, headers=None, method='GET', body=None, + protocol=None, raise_subcls=None): + """Open the url. Return status, headers, body. + + `raise_subcls` must be a tuple with the exceptions classes + or a single exception class that are not going to be considered + a socket.error regardless that they were are subclass of a + socket.error and therefore not considered for a connection retry. + """ + if self.script_name: + url = httputil.urljoin(self.script_name, url) + return webtest.WebCase.getPage(self, url, headers, method, body, + protocol, raise_subcls) + + def skip(self, msg='skipped '): + pytest.skip(msg) + + def assertErrorPage(self, status, message=None, pattern=''): + """Compare the response body with a built in error page. + + The function will optionally look for the regexp pattern, + within the exception embedded in the error page.""" + + # This will never contain a traceback + page = cherrypy._cperror.get_error_page(status, message=message) + + # First, test the response body without checking the traceback. + # Stick a match-all group (.*) in to grab the traceback. + def esc(text): + return re.escape(ntob(text)) + epage = re.escape(page) + epage = epage.replace( + esc('<pre id="traceback"></pre>'), + esc('<pre id="traceback">') + b'(.*)' + esc('</pre>')) + m = re.match(epage, self.body, re.DOTALL) + if not m: + self._handlewebError( + 'Error page does not match; expected:\n' + page) + return + + # Now test the pattern against the traceback + if pattern is None: + # Special-case None to mean that there should be *no* traceback. 
+ if m and m.group(1): + self._handlewebError('Error page contains traceback') + else: + if (m is None) or ( + not re.search(ntob(re.escape(pattern), self.encoding), + m.group(1))): + msg = 'Error page does not contain %s in traceback' + self._handlewebError(msg % repr(pattern)) + + date_tolerance = 2 + + def assertEqualDates(self, dt1, dt2, seconds=None): + """Assert abs(dt1 - dt2) is within Y seconds.""" + if seconds is None: + seconds = self.date_tolerance + + if dt1 > dt2: + diff = dt1 - dt2 + else: + diff = dt2 - dt1 + if not diff < datetime.timedelta(seconds=seconds): + raise AssertionError('%r and %r are not within %r seconds.' % + (dt1, dt2, seconds)) + + +def _test_method_sorter(_, x, y): + """Monkeypatch the test sorter to always run test_gc last in each suite.""" + if x == 'test_gc': + return 1 + if y == 'test_gc': + return -1 + if x > y: + return 1 + if x < y: + return -1 + return 0 + + +unittest.TestLoader.sortTestMethodsUsing = _test_method_sorter + + +def setup_client(): + """Set up the WebCase classes to match the server's socket settings.""" + webtest.WebCase.PORT = cherrypy.server.socket_port + webtest.WebCase.HOST = cherrypy.server.socket_host + if cherrypy.server.ssl_certificate: + CPWebCase.scheme = 'https' + +# --------------------------- Spawning helpers --------------------------- # + + +class CPProcess(object): + + pid_file = os.path.join(thisdir, 'test.pid') + config_file = os.path.join(thisdir, 'test.conf') + config_template = """[global] +server.socket_host: '%(host)s' +server.socket_port: %(port)s +checker.on: False +log.screen: False +log.error_file: r'%(error_log)s' +log.access_file: r'%(access_log)s' +%(ssl)s +%(extra)s +""" + error_log = os.path.join(thisdir, 'test.error.log') + access_log = os.path.join(thisdir, 'test.access.log') + + def __init__(self, wait=False, daemonize=False, ssl=False, + socket_host=None, socket_port=None): + self.wait = wait + self.daemonize = daemonize + self.ssl = ssl + self.host = socket_host or cherrypy.server.socket_host + self.port = socket_port or cherrypy.server.socket_port + + def write_conf(self, extra=''): + if self.ssl: + serverpem = os.path.join(thisdir, 'test.pem') + ssl = """ +server.ssl_certificate: r'%s' +server.ssl_private_key: r'%s' +""" % (serverpem, serverpem) + else: + ssl = '' + + conf = self.config_template % { + 'host': self.host, + 'port': self.port, + 'error_log': self.error_log, + 'access_log': self.access_log, + 'ssl': ssl, + 'extra': extra, + } + with io.open(self.config_file, 'w', encoding='utf-8') as f: + f.write(six.text_type(conf)) + + def start(self, imports=None): + """Start cherryd in a subprocess.""" + portend.free(self.host, self.port, timeout=1) + + args = [ + '-m', + 'cherrypy', + '-c', self.config_file, + '-p', self.pid_file, + ] + r""" + Command for running cherryd server with autoreload enabled + + Using + + ``` + ['-c', + "__requires__ = 'CherryPy'; \ + import pkg_resources, re, sys; \ + sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0]); \ + sys.exit(\ + pkg_resources.load_entry_point(\ + 'CherryPy', 'console_scripts', 'cherryd')())"] + ``` + + doesn't work as it's impossible to reconstruct the `-c`'s contents. + Ref: https://github.com/cherrypy/cherrypy/issues/1545 + """ + + if not isinstance(imports, (list, tuple)): + imports = [imports] + for i in imports: + if i: + args.append('-i') + args.append(i) + + if self.daemonize: + args.append('-d') + + env = os.environ.copy() + # Make sure we import the cherrypy package in which this module is + # defined. 
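(The `start()` method of `CPProcess` resumes immediately below this aside.) As a rough, Unix-oriented usage sketch of that helper: it renders `config_template` to a conf file, launches `cherryd` in a subprocess, and exposes the PID for later cleanup. The module passed to `-i` and the SIGTERM shutdown are illustrative assumptions, not something the suite itself does.

```
import os
import signal

from cherrypy.test.helper import CPProcess

# Assumptions: Unix, and a module that cherryd can import via -i.
p = CPProcess(wait=False, daemonize=True)
p.write_conf(extra='')                        # render config_template to test.conf
p.start(imports='cherrypy.test.sessiondemo')  # forwarded to cherryd as '-i <module>'
pid = p.get_pid()                             # daemonized: PID is read from the pid file
os.kill(pid, signal.SIGTERM)                  # ask the daemon to shut down
p.join()                                      # wait for / reap the exiting process
```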
+ grandparentdir = os.path.abspath(os.path.join(thisdir, '..', '..')) + if env.get('PYTHONPATH', ''): + env['PYTHONPATH'] = os.pathsep.join( + (grandparentdir, env['PYTHONPATH'])) + else: + env['PYTHONPATH'] = grandparentdir + self._proc = subprocess.Popen([sys.executable] + args, env=env) + if self.wait: + self.exit_code = self._proc.wait() + else: + portend.occupied(self.host, self.port, timeout=5) + + # Give the engine a wee bit more time to finish STARTING + if self.daemonize: + time.sleep(2) + else: + time.sleep(1) + + def get_pid(self): + if self.daemonize: + return int(open(self.pid_file, 'rb').read()) + return self._proc.pid + + def join(self): + """Wait for the process to exit.""" + if self.daemonize: + return self._join_daemon() + self._proc.wait() + + def _join_daemon(self): + try: + try: + # Mac, UNIX + os.wait() + except AttributeError: + # Windows + try: + pid = self.get_pid() + except IOError: + # Assume the subprocess deleted the pidfile on shutdown. + pass + else: + os.waitpid(pid, 0) + except OSError: + x = sys.exc_info()[1] + if x.args != (10, 'No child processes'): + raise diff --git a/libraries/cherrypy/test/logtest.py b/libraries/cherrypy/test/logtest.py new file mode 100644 index 00000000..ed8f1540 --- /dev/null +++ b/libraries/cherrypy/test/logtest.py @@ -0,0 +1,228 @@ +"""logtest, a unittest.TestCase helper for testing log output.""" + +import sys +import time +from uuid import UUID + +import six + +from cherrypy._cpcompat import text_or_bytes, ntob + + +try: + # On Windows, msvcrt.getch reads a single char without output. + import msvcrt + + def getchar(): + return msvcrt.getch() +except ImportError: + # Unix getchr + import tty + import termios + + def getchar(): + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(sys.stdin.fileno()) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch + + +class LogCase(object): + + """unittest.TestCase mixin for testing log messages. + + logfile: a filename for the desired log. Yes, I know modes are evil, + but it makes the test functions so much cleaner to set this once. + + lastmarker: the last marker in the log. This can be used to search for + messages since the last marker. + + markerPrefix: a string with which to prefix log markers. This should be + unique enough from normal log output to use for marker identification. 
+ """ + + logfile = None + lastmarker = None + markerPrefix = b'test suite marker: ' + + def _handleLogError(self, msg, data, marker, pattern): + print('') + print(' ERROR: %s' % msg) + + if not self.interactive: + raise self.failureException(msg) + + p = (' Show: ' + '[L]og [M]arker [P]attern; ' + '[I]gnore, [R]aise, or sys.e[X]it >> ') + sys.stdout.write(p + ' ') + # ARGH + sys.stdout.flush() + while True: + i = getchar().upper() + if i not in 'MPLIRX': + continue + print(i.upper()) # Also prints new line + if i == 'L': + for x, line in enumerate(data): + if (x + 1) % self.console_height == 0: + # The \r and comma should make the next line overwrite + sys.stdout.write('<-- More -->\r ') + m = getchar().lower() + # Erase our "More" prompt + sys.stdout.write(' \r ') + if m == 'q': + break + print(line.rstrip()) + elif i == 'M': + print(repr(marker or self.lastmarker)) + elif i == 'P': + print(repr(pattern)) + elif i == 'I': + # return without raising the normal exception + return + elif i == 'R': + raise self.failureException(msg) + elif i == 'X': + self.exit() + sys.stdout.write(p + ' ') + + def exit(self): + sys.exit() + + def emptyLog(self): + """Overwrite self.logfile with 0 bytes.""" + open(self.logfile, 'wb').write('') + + def markLog(self, key=None): + """Insert a marker line into the log and set self.lastmarker.""" + if key is None: + key = str(time.time()) + self.lastmarker = key + + open(self.logfile, 'ab+').write( + ntob('%s%s\n' % (self.markerPrefix, key), 'utf-8')) + + def _read_marked_region(self, marker=None): + """Return lines from self.logfile in the marked region. + + If marker is None, self.lastmarker is used. If the log hasn't + been marked (using self.markLog), the entire log will be returned. + """ +# Give the logger time to finish writing? +# time.sleep(0.5) + + logfile = self.logfile + marker = marker or self.lastmarker + if marker is None: + return open(logfile, 'rb').readlines() + + if isinstance(marker, six.text_type): + marker = marker.encode('utf-8') + data = [] + in_region = False + for line in open(logfile, 'rb'): + if in_region: + if line.startswith(self.markerPrefix) and marker not in line: + break + else: + data.append(line) + elif marker in line: + in_region = True + return data + + def assertInLog(self, line, marker=None): + """Fail if the given (partial) line is not in the log. + + The log will be searched from the given marker to the next marker. + If marker is None, self.lastmarker is used. If the log hasn't + been marked (using self.markLog), the entire log will be searched. + """ + data = self._read_marked_region(marker) + for logline in data: + if line in logline: + return + msg = '%r not found in log' % line + self._handleLogError(msg, data, marker, line) + + def assertNotInLog(self, line, marker=None): + """Fail if the given (partial) line is in the log. + + The log will be searched from the given marker to the next marker. + If marker is None, self.lastmarker is used. If the log hasn't + been marked (using self.markLog), the entire log will be searched. + """ + data = self._read_marked_region(marker) + for logline in data: + if line in logline: + msg = '%r found in log' % line + self._handleLogError(msg, data, marker, line) + + def assertValidUUIDv4(self, marker=None): + """Fail if the given UUIDv4 is not valid. + + The log will be searched from the given marker to the next marker. + If marker is None, self.lastmarker is used. If the log hasn't + been marked (using self.markLog), the entire log will be searched. 
+ """ + data = self._read_marked_region(marker) + data = [ + chunk.decode('utf-8').rstrip('\n').rstrip('\r') + for chunk in data + ] + for log_chunk in data: + try: + uuid_log = data[-1] + uuid_obj = UUID(uuid_log, version=4) + except (TypeError, ValueError): + pass # it might be in other chunk + else: + if str(uuid_obj) == uuid_log: + return + msg = '%r is not a valid UUIDv4' % uuid_log + self._handleLogError(msg, data, marker, log_chunk) + + msg = 'UUIDv4 not found in log' + self._handleLogError(msg, data, marker, log_chunk) + + def assertLog(self, sliceargs, lines, marker=None): + """Fail if log.readlines()[sliceargs] is not contained in 'lines'. + + The log will be searched from the given marker to the next marker. + If marker is None, self.lastmarker is used. If the log hasn't + been marked (using self.markLog), the entire log will be searched. + """ + data = self._read_marked_region(marker) + if isinstance(sliceargs, int): + # Single arg. Use __getitem__ and allow lines to be str or list. + if isinstance(lines, (tuple, list)): + lines = lines[0] + if isinstance(lines, six.text_type): + lines = lines.encode('utf-8') + if lines not in data[sliceargs]: + msg = '%r not found on log line %r' % (lines, sliceargs) + self._handleLogError( + msg, + [data[sliceargs], '--EXTRA CONTEXT--'] + data[ + sliceargs + 1:sliceargs + 6], + marker, + lines) + else: + # Multiple args. Use __getslice__ and require lines to be list. + if isinstance(lines, tuple): + lines = list(lines) + elif isinstance(lines, text_or_bytes): + raise TypeError("The 'lines' arg must be a list when " + "'sliceargs' is a tuple.") + + start, stop = sliceargs + for line, logline in zip(lines, data[start:stop]): + if isinstance(line, six.text_type): + line = line.encode('utf-8') + if line not in logline: + msg = '%r not found in log' % line + self._handleLogError(msg, data[start:stop], marker, line) diff --git a/libraries/cherrypy/test/modfastcgi.py b/libraries/cherrypy/test/modfastcgi.py new file mode 100644 index 00000000..79ec3d18 --- /dev/null +++ b/libraries/cherrypy/test/modfastcgi.py @@ -0,0 +1,136 @@ +"""Wrapper for mod_fastcgi, for use as a CherryPy HTTP server when testing. + +To autostart fastcgi, the "apache" executable or script must be +on your system path, or you must override the global APACHE_PATH. +On some platforms, "apache" may be called "apachectl", "apache2ctl", +or "httpd"--create a symlink to them if needed. + +You'll also need the WSGIServer from flup.servers. +See http://projects.amor.org/misc/wiki/ModPythonGateway + + +KNOWN BUGS +========== + +1. Apache processes Range headers automatically; CherryPy's truncated + output is then truncated again by Apache. See test_core.testRanges. + This was worked around in http://www.cherrypy.org/changeset/1319. +2. Apache does not allow custom HTTP methods like CONNECT as per the spec. + See test_core.testHTTPMethods. +3. Max request header and body settings do not work with Apache. +4. Apache replaces status "reason phrases" automatically. For example, + CherryPy may set "304 Not modified" but Apache will write out + "304 Not Modified" (capital "M"). +5. Apache does not allow custom error codes as per the spec. +6. Apache (or perhaps modpython, or modpython_gateway) unquotes %xx in the + Request-URI too early. +7. mod_python will not read request bodies which use the "chunked" + transfer-coding (it passes REQUEST_CHUNKED_ERROR to ap_setup_client_block + instead of REQUEST_CHUNKED_DECHUNK, see Apache2's http_protocol.c and + mod_python's requestobject.c). +8. 
Apache will output a "Content-Length: 0" response header even if there's + no response entity body. This isn't really a bug; it just differs from + the CherryPy default. +""" + +import os +import re + +import cherrypy +from cherrypy.process import servers +from cherrypy.test import helper + +curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +def read_process(cmd, args=''): + pipein, pipeout = os.popen4('%s %s' % (cmd, args)) + try: + firstline = pipeout.readline() + if (re.search(r'(not recognized|No such file|not found)', firstline, + re.IGNORECASE)): + raise IOError('%s must be on your system path.' % cmd) + output = firstline + pipeout.read() + finally: + pipeout.close() + return output + + +APACHE_PATH = 'apache2ctl' +CONF_PATH = 'fastcgi.conf' + +conf_fastcgi = """ +# Apache2 server conf file for testing CherryPy with mod_fastcgi. +# fumanchu: I had to hard-code paths due to crazy Debian layouts :( +ServerRoot /usr/lib/apache2 +User #1000 +ErrorLog %(root)s/mod_fastcgi.error.log + +DocumentRoot "%(root)s" +ServerName 127.0.0.1 +Listen %(port)s +LoadModule fastcgi_module modules/mod_fastcgi.so +LoadModule rewrite_module modules/mod_rewrite.so + +Options +ExecCGI +SetHandler fastcgi-script +RewriteEngine On +RewriteRule ^(.*)$ /fastcgi.pyc [L] +FastCgiExternalServer "%(server)s" -host 127.0.0.1:4000 +""" + + +def erase_script_name(environ, start_response): + environ['SCRIPT_NAME'] = '' + return cherrypy.tree(environ, start_response) + + +class ModFCGISupervisor(helper.LocalWSGISupervisor): + + httpserver_class = 'cherrypy.process.servers.FlupFCGIServer' + using_apache = True + using_wsgi = True + template = conf_fastcgi + + def __str__(self): + return 'FCGI Server on %s:%s' % (self.host, self.port) + + def start(self, modulename): + cherrypy.server.httpserver = servers.FlupFCGIServer( + application=erase_script_name, bindAddress=('127.0.0.1', 4000)) + cherrypy.server.httpserver.bind_addr = ('127.0.0.1', 4000) + cherrypy.server.socket_port = 4000 + # For FCGI, we both start apache... + self.start_apache() + # ...and our local server + cherrypy.engine.start() + self.sync_apps() + + def start_apache(self): + fcgiconf = CONF_PATH + if not os.path.isabs(fcgiconf): + fcgiconf = os.path.join(curdir, fcgiconf) + + # Write the Apache conf file. + f = open(fcgiconf, 'wb') + try: + server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1] + output = self.template % {'port': self.port, 'root': curdir, + 'server': server} + output = output.replace('\r\n', '\n') + f.write(output) + finally: + f.close() + + result = read_process(APACHE_PATH, '-k start -f %s' % fcgiconf) + if result: + print(result) + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + read_process(APACHE_PATH, '-k stop') + helper.LocalWSGISupervisor.stop(self) + + def sync_apps(self): + cherrypy.server.httpserver.fcgiserver.application = self.get_app( + erase_script_name) diff --git a/libraries/cherrypy/test/modfcgid.py b/libraries/cherrypy/test/modfcgid.py new file mode 100644 index 00000000..d101bd67 --- /dev/null +++ b/libraries/cherrypy/test/modfcgid.py @@ -0,0 +1,124 @@ +"""Wrapper for mod_fcgid, for use as a CherryPy HTTP server when testing. + +To autostart fcgid, the "apache" executable or script must be +on your system path, or you must override the global APACHE_PATH. +On some platforms, "apache" may be called "apachectl", "apache2ctl", +or "httpd"--create a symlink to them if needed. + +You'll also need the WSGIServer from flup.servers. 
+See http://projects.amor.org/misc/wiki/ModPythonGateway + + +KNOWN BUGS +========== + +1. Apache processes Range headers automatically; CherryPy's truncated + output is then truncated again by Apache. See test_core.testRanges. + This was worked around in http://www.cherrypy.org/changeset/1319. +2. Apache does not allow custom HTTP methods like CONNECT as per the spec. + See test_core.testHTTPMethods. +3. Max request header and body settings do not work with Apache. +4. Apache replaces status "reason phrases" automatically. For example, + CherryPy may set "304 Not modified" but Apache will write out + "304 Not Modified" (capital "M"). +5. Apache does not allow custom error codes as per the spec. +6. Apache (or perhaps modpython, or modpython_gateway) unquotes %xx in the + Request-URI too early. +7. mod_python will not read request bodies which use the "chunked" + transfer-coding (it passes REQUEST_CHUNKED_ERROR to ap_setup_client_block + instead of REQUEST_CHUNKED_DECHUNK, see Apache2's http_protocol.c and + mod_python's requestobject.c). +8. Apache will output a "Content-Length: 0" response header even if there's + no response entity body. This isn't really a bug; it just differs from + the CherryPy default. +""" + +import os +import re + +import cherrypy +from cherrypy._cpcompat import ntob +from cherrypy.process import servers +from cherrypy.test import helper + +curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +def read_process(cmd, args=''): + pipein, pipeout = os.popen4('%s %s' % (cmd, args)) + try: + firstline = pipeout.readline() + if (re.search(r'(not recognized|No such file|not found)', firstline, + re.IGNORECASE)): + raise IOError('%s must be on your system path.' % cmd) + output = firstline + pipeout.read() + finally: + pipeout.close() + return output + + +APACHE_PATH = 'httpd' +CONF_PATH = 'fcgi.conf' + +conf_fcgid = """ +# Apache2 server conf file for testing CherryPy with mod_fcgid. + +DocumentRoot "%(root)s" +ServerName 127.0.0.1 +Listen %(port)s +LoadModule fastcgi_module modules/mod_fastcgi.dll +LoadModule rewrite_module modules/mod_rewrite.so + +Options ExecCGI +SetHandler fastcgi-script +RewriteEngine On +RewriteRule ^(.*)$ /fastcgi.pyc [L] +FastCgiExternalServer "%(server)s" -host 127.0.0.1:4000 +""" + + +class ModFCGISupervisor(helper.LocalSupervisor): + + using_apache = True + using_wsgi = True + template = conf_fcgid + + def __str__(self): + return 'FCGI Server on %s:%s' % (self.host, self.port) + + def start(self, modulename): + cherrypy.server.httpserver = servers.FlupFCGIServer( + application=cherrypy.tree, bindAddress=('127.0.0.1', 4000)) + cherrypy.server.httpserver.bind_addr = ('127.0.0.1', 4000) + # For FCGI, we both start apache... + self.start_apache() + # ...and our local server + helper.LocalServer.start(self, modulename) + + def start_apache(self): + fcgiconf = CONF_PATH + if not os.path.isabs(fcgiconf): + fcgiconf = os.path.join(curdir, fcgiconf) + + # Write the Apache conf file. 
+ f = open(fcgiconf, 'wb') + try: + server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1] + output = self.template % {'port': self.port, 'root': curdir, + 'server': server} + output = ntob(output.replace('\r\n', '\n')) + f.write(output) + finally: + f.close() + + result = read_process(APACHE_PATH, '-k start -f %s' % fcgiconf) + if result: + print(result) + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + read_process(APACHE_PATH, '-k stop') + helper.LocalServer.stop(self) + + def sync_apps(self): + cherrypy.server.httpserver.fcgiserver.application = self.get_app() diff --git a/libraries/cherrypy/test/modpy.py b/libraries/cherrypy/test/modpy.py new file mode 100644 index 00000000..7c288d2c --- /dev/null +++ b/libraries/cherrypy/test/modpy.py @@ -0,0 +1,164 @@ +"""Wrapper for mod_python, for use as a CherryPy HTTP server when testing. + +To autostart modpython, the "apache" executable or script must be +on your system path, or you must override the global APACHE_PATH. +On some platforms, "apache" may be called "apachectl" or "apache2ctl"-- +create a symlink to them if needed. + +If you wish to test the WSGI interface instead of our _cpmodpy interface, +you also need the 'modpython_gateway' module at: +http://projects.amor.org/misc/wiki/ModPythonGateway + + +KNOWN BUGS +========== + +1. Apache processes Range headers automatically; CherryPy's truncated + output is then truncated again by Apache. See test_core.testRanges. + This was worked around in http://www.cherrypy.org/changeset/1319. +2. Apache does not allow custom HTTP methods like CONNECT as per the spec. + See test_core.testHTTPMethods. +3. Max request header and body settings do not work with Apache. +4. Apache replaces status "reason phrases" automatically. For example, + CherryPy may set "304 Not modified" but Apache will write out + "304 Not Modified" (capital "M"). +5. Apache does not allow custom error codes as per the spec. +6. Apache (or perhaps modpython, or modpython_gateway) unquotes %xx in the + Request-URI too early. +7. mod_python will not read request bodies which use the "chunked" + transfer-coding (it passes REQUEST_CHUNKED_ERROR to ap_setup_client_block + instead of REQUEST_CHUNKED_DECHUNK, see Apache2's http_protocol.c and + mod_python's requestobject.c). +8. Apache will output a "Content-Length: 0" response header even if there's + no response entity body. This isn't really a bug; it just differs from + the CherryPy default. +""" + +import os +import re + +import cherrypy +from cherrypy.test import helper + +curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +def read_process(cmd, args=''): + pipein, pipeout = os.popen4('%s %s' % (cmd, args)) + try: + firstline = pipeout.readline() + if (re.search(r'(not recognized|No such file|not found)', firstline, + re.IGNORECASE)): + raise IOError('%s must be on your system path.' % cmd) + output = firstline + pipeout.read() + finally: + pipeout.close() + return output + + +APACHE_PATH = 'httpd' +CONF_PATH = 'test_mp.conf' + +conf_modpython_gateway = """ +# Apache2 server conf file for testing CherryPy with modpython_gateway. 
+ +ServerName 127.0.0.1 +DocumentRoot "/" +Listen %(port)s +LoadModule python_module modules/mod_python.so + +SetHandler python-program +PythonFixupHandler cherrypy.test.modpy::wsgisetup +PythonOption testmod %(modulename)s +PythonHandler modpython_gateway::handler +PythonOption wsgi.application cherrypy::tree +PythonOption socket_host %(host)s +PythonDebug On +""" + +conf_cpmodpy = """ +# Apache2 server conf file for testing CherryPy with _cpmodpy. + +ServerName 127.0.0.1 +DocumentRoot "/" +Listen %(port)s +LoadModule python_module modules/mod_python.so + +SetHandler python-program +PythonFixupHandler cherrypy.test.modpy::cpmodpysetup +PythonHandler cherrypy._cpmodpy::handler +PythonOption cherrypy.setup cherrypy.test.%(modulename)s::setup_server +PythonOption socket_host %(host)s +PythonDebug On +""" + + +class ModPythonSupervisor(helper.Supervisor): + + using_apache = True + using_wsgi = False + template = None + + def __str__(self): + return 'ModPython Server on %s:%s' % (self.host, self.port) + + def start(self, modulename): + mpconf = CONF_PATH + if not os.path.isabs(mpconf): + mpconf = os.path.join(curdir, mpconf) + + f = open(mpconf, 'wb') + try: + f.write(self.template % + {'port': self.port, 'modulename': modulename, + 'host': self.host}) + finally: + f.close() + + result = read_process(APACHE_PATH, '-k start -f %s' % mpconf) + if result: + print(result) + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + read_process(APACHE_PATH, '-k stop') + + +loaded = False + + +def wsgisetup(req): + global loaded + if not loaded: + loaded = True + options = req.get_options() + + cherrypy.config.update({ + 'log.error_file': os.path.join(curdir, 'test.log'), + 'environment': 'test_suite', + 'server.socket_host': options['socket_host'], + }) + + modname = options['testmod'] + mod = __import__(modname, globals(), locals(), ['']) + mod.setup_server() + + cherrypy.server.unsubscribe() + cherrypy.engine.start() + from mod_python import apache + return apache.OK + + +def cpmodpysetup(req): + global loaded + if not loaded: + loaded = True + options = req.get_options() + + cherrypy.config.update({ + 'log.error_file': os.path.join(curdir, 'test.log'), + 'environment': 'test_suite', + 'server.socket_host': options['socket_host'], + }) + from mod_python import apache + return apache.OK diff --git a/libraries/cherrypy/test/modwsgi.py b/libraries/cherrypy/test/modwsgi.py new file mode 100644 index 00000000..f558e223 --- /dev/null +++ b/libraries/cherrypy/test/modwsgi.py @@ -0,0 +1,154 @@ +"""Wrapper for mod_wsgi, for use as a CherryPy HTTP server. + +To autostart modwsgi, the "apache" executable or script must be +on your system path, or you must override the global APACHE_PATH. +On some platforms, "apache" may be called "apachectl" or "apache2ctl"-- +create a symlink to them if needed. + + +KNOWN BUGS +========== + +##1. Apache processes Range headers automatically; CherryPy's truncated +## output is then truncated again by Apache. See test_core.testRanges. +## This was worked around in http://www.cherrypy.org/changeset/1319. +2. Apache does not allow custom HTTP methods like CONNECT as per the spec. + See test_core.testHTTPMethods. +3. Max request header and body settings do not work with Apache. +##4. Apache replaces status "reason phrases" automatically. For example, +## CherryPy may set "304 Not modified" but Apache will write out +## "304 Not Modified" (capital "M"). +##5. Apache does not allow custom error codes as per the spec. +##6. 
Apache (or perhaps modpython, or modpython_gateway) unquotes %xx in the +## Request-URI too early. +7. mod_wsgi will not read request bodies which use the "chunked" + transfer-coding (it passes REQUEST_CHUNKED_ERROR to ap_setup_client_block + instead of REQUEST_CHUNKED_DECHUNK, see Apache2's http_protocol.c and + mod_python's requestobject.c). +8. When responding with 204 No Content, mod_wsgi adds a Content-Length + header for you. +9. When an error is raised, mod_wsgi has no facility for printing a + traceback as the response content (it's sent to the Apache log instead). +10. Startup and shutdown of Apache when running mod_wsgi seems slow. +""" + +import os +import re +import sys +import time + +import portend + +from cheroot.test import webtest + +import cherrypy +from cherrypy.test import helper + +curdir = os.path.abspath(os.path.dirname(__file__)) + + +def read_process(cmd, args=''): + pipein, pipeout = os.popen4('%s %s' % (cmd, args)) + try: + firstline = pipeout.readline() + if (re.search(r'(not recognized|No such file|not found)', firstline, + re.IGNORECASE)): + raise IOError('%s must be on your system path.' % cmd) + output = firstline + pipeout.read() + finally: + pipeout.close() + return output + + +if sys.platform == 'win32': + APACHE_PATH = 'httpd' +else: + APACHE_PATH = 'apache' + +CONF_PATH = 'test_mw.conf' + +conf_modwsgi = r""" +# Apache2 server conf file for testing CherryPy with modpython_gateway. + +ServerName 127.0.0.1 +DocumentRoot "/" +Listen %(port)s + +AllowEncodedSlashes On +LoadModule rewrite_module modules/mod_rewrite.so +RewriteEngine on +RewriteMap escaping int:escape + +LoadModule log_config_module modules/mod_log_config.so +LogFormat "%%h %%l %%u %%t \"%%r\" %%>s %%b \"%%{Referer}i\" \"%%{User-agent}i\"" combined +CustomLog "%(curdir)s/apache.access.log" combined +ErrorLog "%(curdir)s/apache.error.log" +LogLevel debug + +LoadModule wsgi_module modules/mod_wsgi.so +LoadModule env_module modules/mod_env.so + +WSGIScriptAlias / "%(curdir)s/modwsgi.py" +SetEnv testmod %(testmod)s +""" # noqa E501 + + +class ModWSGISupervisor(helper.Supervisor): + + """Server Controller for ModWSGI and CherryPy.""" + + using_apache = True + using_wsgi = True + template = conf_modwsgi + + def __str__(self): + return 'ModWSGI Server on %s:%s' % (self.host, self.port) + + def start(self, modulename): + mpconf = CONF_PATH + if not os.path.isabs(mpconf): + mpconf = os.path.join(curdir, mpconf) + + f = open(mpconf, 'wb') + try: + output = (self.template % + {'port': self.port, 'testmod': modulename, + 'curdir': curdir}) + f.write(output) + finally: + f.close() + + result = read_process(APACHE_PATH, '-k start -f %s' % mpconf) + if result: + print(result) + + # Make a request so mod_wsgi starts up our app. + # If we don't, concurrent initial requests will 404. + portend.occupied('127.0.0.1', self.port, timeout=5) + webtest.openURL('/ihopetheresnodefault', port=self.port) + time.sleep(1) + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + read_process(APACHE_PATH, '-k stop') + + +loaded = False + + +def application(environ, start_response): + global loaded + if not loaded: + loaded = True + modname = 'cherrypy.test.' 
+ environ['testmod'] + mod = __import__(modname, globals(), locals(), ['']) + mod.setup_server() + + cherrypy.config.update({ + 'log.error_file': os.path.join(curdir, 'test.error.log'), + 'log.access_file': os.path.join(curdir, 'test.access.log'), + 'environment': 'test_suite', + 'engine.SIGHUP': None, + 'engine.SIGTERM': None, + }) + return cherrypy.tree(environ, start_response) diff --git a/libraries/cherrypy/test/sessiondemo.py b/libraries/cherrypy/test/sessiondemo.py new file mode 100644 index 00000000..8226c1b9 --- /dev/null +++ b/libraries/cherrypy/test/sessiondemo.py @@ -0,0 +1,161 @@ +#!/usr/bin/python +"""A session demonstration app.""" + +import calendar +from datetime import datetime +import sys + +import six + +import cherrypy +from cherrypy.lib import sessions + + +page = """ +<html> +<head> +<style type='text/css'> +table { border-collapse: collapse; border: 1px solid #663333; } +th { text-align: right; background-color: #663333; color: white; padding: 0.5em; } +td { white-space: pre-wrap; font-family: monospace; padding: 0.5em; + border: 1px solid #663333; } +.warn { font-family: serif; color: #990000; } +</style> +<script type="text/javascript"> +<!-- +function twodigit(d) { return d < 10 ? "0" + d : d; } +function formattime(t) { + var month = t.getUTCMonth() + 1; + var day = t.getUTCDate(); + var year = t.getUTCFullYear(); + var hours = t.getUTCHours(); + var minutes = t.getUTCMinutes(); + return (year + "/" + twodigit(month) + "/" + twodigit(day) + " " + + hours + ":" + twodigit(minutes) + " UTC"); +} + +function interval(s) { + // Return the given interval (in seconds) as an English phrase + var seconds = s %% 60; + s = Math.floor(s / 60); + var minutes = s %% 60; + s = Math.floor(s / 60); + var hours = s %% 24; + var v = twodigit(hours) + ":" + twodigit(minutes) + ":" + twodigit(seconds); + var days = Math.floor(s / 24); + if (days != 0) v = days + ' days, ' + v; + return v; +} + +var fudge_seconds = 5; + +function init() { + // Set the content of the 'btime' cell. + var currentTime = new Date(); + var bunixtime = Math.floor(currentTime.getTime() / 1000); + + var v = formattime(currentTime); + v += " (Unix time: " + bunixtime + ")"; + + var diff = Math.abs(%(serverunixtime)s - bunixtime); + if (diff > fudge_seconds) v += "<p class='warn'>Browser and Server times disagree.</p>"; + + document.getElementById('btime').innerHTML = v; + + // Warn if response cookie expires is not close to one hour in the future. + // Yes, we want this to happen when wit hit the 'Expire' link, too. + var expires = Date.parse("%(expires)s") / 1000; + var onehour = (60 * 60); + if (Math.abs(expires - (bunixtime + onehour)) > fudge_seconds) { + diff = Math.floor(expires - bunixtime); + if (expires > (bunixtime + onehour)) { + var msg = "Response cookie 'expires' date is " + interval(diff) + " in the future."; + } else { + var msg = "Response cookie 'expires' date is " + interval(0 - diff) + " in the past."; + } + document.getElementById('respcookiewarn').innerHTML = msg; + } +} +//--> +</script> +</head> + +<body onload='init()'> +<h2>Session Demo</h2> +<p>Reload this page. 
The session ID should not change from one reload to the next</p> +<p><a href='../'>Index</a> | <a href='expire'>Expire</a> | <a href='regen'>Regenerate</a></p> +<table> + <tr><th>Session ID:</th><td>%(sessionid)s<p class='warn'>%(changemsg)s</p></td></tr> + <tr><th>Request Cookie</th><td>%(reqcookie)s</td></tr> + <tr><th>Response Cookie</th><td>%(respcookie)s<p id='respcookiewarn' class='warn'></p></td></tr> + <tr><th>Session Data</th><td>%(sessiondata)s</td></tr> + <tr><th>Server Time</th><td id='stime'>%(servertime)s (Unix time: %(serverunixtime)s)</td></tr> + <tr><th>Browser Time</th><td id='btime'> </td></tr> + <tr><th>Cherrypy Version:</th><td>%(cpversion)s</td></tr> + <tr><th>Python Version:</th><td>%(pyversion)s</td></tr> +</table> +</body></html> +""" # noqa E501 + + +class Root(object): + + def page(self): + changemsg = [] + if cherrypy.session.id != cherrypy.session.originalid: + if cherrypy.session.originalid is None: + changemsg.append( + 'Created new session because no session id was given.') + if cherrypy.session.missing: + changemsg.append( + 'Created new session due to missing ' + '(expired or malicious) session.') + if cherrypy.session.regenerated: + changemsg.append('Application generated a new session.') + + try: + expires = cherrypy.response.cookie['session_id']['expires'] + except KeyError: + expires = '' + + return page % { + 'sessionid': cherrypy.session.id, + 'changemsg': '<br>'.join(changemsg), + 'respcookie': cherrypy.response.cookie.output(), + 'reqcookie': cherrypy.request.cookie.output(), + 'sessiondata': list(six.iteritems(cherrypy.session)), + 'servertime': ( + datetime.utcnow().strftime('%Y/%m/%d %H:%M') + ' UTC' + ), + 'serverunixtime': calendar.timegm(datetime.utcnow().timetuple()), + 'cpversion': cherrypy.__version__, + 'pyversion': sys.version, + 'expires': expires, + } + + @cherrypy.expose + def index(self): + # Must modify data or the session will not be saved. + cherrypy.session['color'] = 'green' + return self.page() + + @cherrypy.expose + def expire(self): + sessions.expire() + return self.page() + + @cherrypy.expose + def regen(self): + cherrypy.session.regenerate() + # Must modify data or the session will not be saved. 
+ cherrypy.session['color'] = 'yellow' + return self.page() + + +if __name__ == '__main__': + cherrypy.config.update({ + # 'environment': 'production', + 'log.screen': True, + 'tools.sessions.on': True, + }) + cherrypy.quickstart(Root()) diff --git a/libraries/cherrypy/test/static/404.html b/libraries/cherrypy/test/static/404.html new file mode 100644 index 00000000..01b17b09 --- /dev/null +++ b/libraries/cherrypy/test/static/404.html @@ -0,0 +1,5 @@ +<html> + <body> + <h1>I couldn't find that thing you were looking for!</h1> + </body> +</html> diff --git a/libraries/cherrypy/test/static/dirback.jpg b/libraries/cherrypy/test/static/dirback.jpg new file mode 100644 index 00000000..80403dc2 Binary files /dev/null and b/libraries/cherrypy/test/static/dirback.jpg differ diff --git a/libraries/cherrypy/test/static/index.html b/libraries/cherrypy/test/static/index.html new file mode 100644 index 00000000..a5c19667 --- /dev/null +++ b/libraries/cherrypy/test/static/index.html @@ -0,0 +1 @@ +Hello, world diff --git a/libraries/cherrypy/test/style.css b/libraries/cherrypy/test/style.css new file mode 100644 index 00000000..b266e93d --- /dev/null +++ b/libraries/cherrypy/test/style.css @@ -0,0 +1 @@ +Dummy stylesheet diff --git a/libraries/cherrypy/test/test.pem b/libraries/cherrypy/test/test.pem new file mode 100644 index 00000000..47a47042 --- /dev/null +++ b/libraries/cherrypy/test/test.pem @@ -0,0 +1,38 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQDBKo554mzIMY+AByUNpaUOP9bJnQ7ZLQe9XgHwoLJR4VzpyZZZ +R9L4WtImEew05FY3Izerfm3MN3+MC0tJ6yQU9sOiU3vBW6RrLIMlfKsnRwBRZ0Kn +da+O6xldVSosu8Ev3z9VZ94iC/ZgKzrH7Mjj/U8/MQO7RBS/LAqee8bFNQIDAQAB +AoGAWOCF0ZrWxn3XMucWq2LNwPKqlvVGwbIwX3cDmX22zmnM4Fy6arXbYh4XlyCj +9+ofqRrxIFz5k/7tFriTmZ0xag5+Jdx+Kwg0/twiP7XCNKipFogwe1Hznw8OFAoT +enKBdj2+/n2o0Bvo/tDB59m9L/538d46JGQUmJlzMyqYikECQQDyoq+8CtMNvE18 +8VgHcR/KtApxWAjj4HpaHYL637ATjThetUZkW92mgDgowyplthusxdNqhHWyv7E8 +tWNdYErZAkEAy85ShTR0M5aWmrE7o0r0SpWInAkNBH9aXQRRARFYsdBtNfRu6I0i +0lvU9wiu3eF57FMEC86yViZ5UBnQfTu7vQJAVesj/Zt7pwaCDfdMa740OsxMUlyR +MVhhGx4OLpYdPJ8qUecxGQKq13XZ7R1HGyNEY4bd2X80Smq08UFuATfC6QJAH8UB +yBHtKz2GLIcELOg6PIYizW/7v3+6rlVF60yw7sb2vzpjL40QqIn4IKoR2DSVtOkb +8FtAIX3N21aq0VrGYQJBAIPiaEc2AZ8Bq2GC4F3wOz/BxJ/izvnkiotR12QK4fh5 +yjZMhTjWCas5zwHR5PDjlD88AWGDMsZ1PicD4348xJQ= +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIIDxTCCAy6gAwIBAgIJAI18BD7eQxlGMA0GCSqGSIb3DQEBBAUAMIGeMQswCQYD +VQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTESMBAGA1UEBxMJU2FuIERpZWdv +MRkwFwYDVQQKExBDaGVycnlQeSBQcm9qZWN0MREwDwYDVQQLEwhkZXYtdGVzdDEW +MBQGA1UEAxMNQ2hlcnJ5UHkgVGVhbTEgMB4GCSqGSIb3DQEJARYRcmVtaUBjaGVy +cnlweS5vcmcwHhcNMDYwOTA5MTkyMDIwWhcNMzQwMTI0MTkyMDIwWjCBnjELMAkG +A1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEjAQBgNVBAcTCVNhbiBEaWVn +bzEZMBcGA1UEChMQQ2hlcnJ5UHkgUHJvamVjdDERMA8GA1UECxMIZGV2LXRlc3Qx +FjAUBgNVBAMTDUNoZXJyeVB5IFRlYW0xIDAeBgkqhkiG9w0BCQEWEXJlbWlAY2hl +cnJ5cHkub3JnMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDBKo554mzIMY+A +ByUNpaUOP9bJnQ7ZLQe9XgHwoLJR4VzpyZZZR9L4WtImEew05FY3Izerfm3MN3+M +C0tJ6yQU9sOiU3vBW6RrLIMlfKsnRwBRZ0Knda+O6xldVSosu8Ev3z9VZ94iC/Zg +KzrH7Mjj/U8/MQO7RBS/LAqee8bFNQIDAQABo4IBBzCCAQMwHQYDVR0OBBYEFDIQ +2feb71tVZCWpU0qJ/Tw+wdtoMIHTBgNVHSMEgcswgciAFDIQ2feb71tVZCWpU0qJ +/Tw+wdtooYGkpIGhMIGeMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5p +YTESMBAGA1UEBxMJU2FuIERpZWdvMRkwFwYDVQQKExBDaGVycnlQeSBQcm9qZWN0 +MREwDwYDVQQLEwhkZXYtdGVzdDEWMBQGA1UEAxMNQ2hlcnJ5UHkgVGVhbTEgMB4G +CSqGSIb3DQEJARYRcmVtaUBjaGVycnlweS5vcmeCCQCNfAQ+3kMZRjAMBgNVHRME +BTADAQH/MA0GCSqGSIb3DQEBBAUAA4GBAL7AAQz7IePV48ZTAFHKr88ntPALsL5S 
+8vHCZPNMevNkLTj3DYUw2BcnENxMjm1kou2F2BkvheBPNZKIhc6z4hAml3ed1xa2 +D7w6e6OTcstdK/+KrPDDHeOP1dhMWNs2JE1bNlfF1LiXzYKSXpe88eCKjCXsCT/T +NluCaWQys3MS +-----END CERTIFICATE----- diff --git a/libraries/cherrypy/test/test_auth_basic.py b/libraries/cherrypy/test/test_auth_basic.py new file mode 100644 index 00000000..d7e69a9b --- /dev/null +++ b/libraries/cherrypy/test/test_auth_basic.py @@ -0,0 +1,135 @@ +# This file is part of CherryPy <http://www.cherrypy.org/> +# -*- coding: utf-8 -*- +# vim:ts=4:sw=4:expandtab:fileencoding=utf-8 + +from hashlib import md5 + +import cherrypy +from cherrypy._cpcompat import ntob +from cherrypy.lib import auth_basic +from cherrypy.test import helper + + +class BasicAuthTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def index(self): + return 'This is public.' + + class BasicProtected: + + @cherrypy.expose + def index(self): + return "Hello %s, you've been authorized." % ( + cherrypy.request.login) + + class BasicProtected2: + + @cherrypy.expose + def index(self): + return "Hello %s, you've been authorized." % ( + cherrypy.request.login) + + class BasicProtected2_u: + + @cherrypy.expose + def index(self): + return "Hello %s, you've been authorized." % ( + cherrypy.request.login) + + userpassdict = {'xuser': 'xpassword'} + userhashdict = {'xuser': md5(b'xpassword').hexdigest()} + userhashdict_u = {'xюзер': md5(ntob('їжа', 'utf-8')).hexdigest()} + + def checkpasshash(realm, user, password): + p = userhashdict.get(user) + return p and p == md5(ntob(password)).hexdigest() or False + + def checkpasshash_u(realm, user, password): + p = userhashdict_u.get(user) + return p and p == md5(ntob(password, 'utf-8')).hexdigest() or False + + basic_checkpassword_dict = auth_basic.checkpassword_dict(userpassdict) + conf = { + '/basic': { + 'tools.auth_basic.on': True, + 'tools.auth_basic.realm': 'wonderland', + 'tools.auth_basic.checkpassword': basic_checkpassword_dict + }, + '/basic2': { + 'tools.auth_basic.on': True, + 'tools.auth_basic.realm': 'wonderland', + 'tools.auth_basic.checkpassword': checkpasshash, + 'tools.auth_basic.accept_charset': 'ISO-8859-1', + }, + '/basic2_u': { + 'tools.auth_basic.on': True, + 'tools.auth_basic.realm': 'wonderland', + 'tools.auth_basic.checkpassword': checkpasshash_u, + 'tools.auth_basic.accept_charset': 'UTF-8', + }, + } + + root = Root() + root.basic = BasicProtected() + root.basic2 = BasicProtected2() + root.basic2_u = BasicProtected2_u() + cherrypy.tree.mount(root, config=conf) + + def testPublic(self): + self.getPage('/') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/html;charset=utf-8') + self.assertBody('This is public.') + + def testBasic(self): + self.getPage('/basic/') + self.assertStatus(401) + self.assertHeader( + 'WWW-Authenticate', + 'Basic realm="wonderland", charset="UTF-8"' + ) + + self.getPage('/basic/', + [('Authorization', 'Basic eHVzZXI6eHBhc3N3b3JX')]) + self.assertStatus(401) + + self.getPage('/basic/', + [('Authorization', 'Basic eHVzZXI6eHBhc3N3b3Jk')]) + self.assertStatus('200 OK') + self.assertBody("Hello xuser, you've been authorized.") + + def testBasic2(self): + self.getPage('/basic2/') + self.assertStatus(401) + self.assertHeader('WWW-Authenticate', 'Basic realm="wonderland"') + + self.getPage('/basic2/', + [('Authorization', 'Basic eHVzZXI6eHBhc3N3b3JX')]) + self.assertStatus(401) + + self.getPage('/basic2/', + [('Authorization', 'Basic eHVzZXI6eHBhc3N3b3Jk')]) + self.assertStatus('200 OK') + self.assertBody("Hello xuser, you've been 
authorized.") + + def testBasic2_u(self): + self.getPage('/basic2_u/') + self.assertStatus(401) + self.assertHeader( + 'WWW-Authenticate', + 'Basic realm="wonderland", charset="UTF-8"' + ) + + self.getPage('/basic2_u/', + [('Authorization', 'Basic eNGO0LfQtdGAOtGX0LbRgw==')]) + self.assertStatus(401) + + self.getPage('/basic2_u/', + [('Authorization', 'Basic eNGO0LfQtdGAOtGX0LbQsA==')]) + self.assertStatus('200 OK') + self.assertBody("Hello xюзер, you've been authorized.") diff --git a/libraries/cherrypy/test/test_auth_digest.py b/libraries/cherrypy/test/test_auth_digest.py new file mode 100644 index 00000000..512e39a5 --- /dev/null +++ b/libraries/cherrypy/test/test_auth_digest.py @@ -0,0 +1,134 @@ +# This file is part of CherryPy <http://www.cherrypy.org/> +# -*- coding: utf-8 -*- +# vim:ts=4:sw=4:expandtab:fileencoding=utf-8 + +import six + + +import cherrypy +from cherrypy.lib import auth_digest +from cherrypy._cpcompat import ntob + +from cherrypy.test import helper + + +def _fetch_users(): + return {'test': 'test', '☃йюзер': 'їпароль'} + + +get_ha1 = cherrypy.lib.auth_digest.get_ha1_dict_plain(_fetch_users()) + + +class DigestAuthTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def index(self): + return 'This is public.' + + class DigestProtected: + + @cherrypy.expose + def index(self, *args, **kwargs): + return "Hello %s, you've been authorized." % ( + cherrypy.request.login) + + conf = {'/digest': {'tools.auth_digest.on': True, + 'tools.auth_digest.realm': 'localhost', + 'tools.auth_digest.get_ha1': get_ha1, + 'tools.auth_digest.key': 'a565c27146791cfb', + 'tools.auth_digest.debug': True, + 'tools.auth_digest.accept_charset': 'UTF-8'}} + + root = Root() + root.digest = DigestProtected() + cherrypy.tree.mount(root, config=conf) + + def testPublic(self): + self.getPage('/') + assert self.status == '200 OK' + self.assertHeader('Content-Type', 'text/html;charset=utf-8') + assert self.body == b'This is public.' 
+ + def _test_parametric_digest(self, username, realm): + test_uri = '/digest/?@/=%2F%40&%f0%9f%99%88=path' + + self.getPage(test_uri) + assert self.status_code == 401 + + msg = 'Digest authentification scheme was not found' + www_auth_digest = tuple(filter( + lambda kv: kv[0].lower() == 'www-authenticate' + and kv[1].startswith('Digest '), + self.headers, + )) + assert len(www_auth_digest) == 1, msg + + items = www_auth_digest[0][-1][7:].split(', ') + tokens = {} + for item in items: + key, value = item.split('=') + tokens[key.lower()] = value + + assert tokens['realm'] == '"localhost"' + assert tokens['algorithm'] == '"MD5"' + assert tokens['qop'] == '"auth"' + assert tokens['charset'] == '"UTF-8"' + + nonce = tokens['nonce'].strip('"') + + # Test user agent response with a wrong value for 'realm' + base_auth = ('Digest username="%s", ' + 'realm="%s", ' + 'nonce="%s", ' + 'uri="%s", ' + 'algorithm=MD5, ' + 'response="%s", ' + 'qop=auth, ' + 'nc=%s, ' + 'cnonce="1522e61005789929"') + + encoded_user = username + if six.PY3: + encoded_user = encoded_user.encode('utf-8') + encoded_user = encoded_user.decode('latin1') + auth_header = base_auth % ( + encoded_user, realm, nonce, test_uri, + '11111111111111111111111111111111', '00000001', + ) + auth = auth_digest.HttpDigestAuthorization(auth_header, 'GET') + # calculate the response digest + ha1 = get_ha1(auth.realm, auth.username) + response = auth.request_digest(ha1) + auth_header = base_auth % ( + encoded_user, realm, nonce, test_uri, + response, '00000001', + ) + self.getPage(test_uri, [('Authorization', auth_header)]) + + def test_wrong_realm(self): + # send response with correct response digest, but wrong realm + self._test_parametric_digest(username='test', realm='wrong realm') + assert self.status_code == 401 + + def test_ascii_user(self): + self._test_parametric_digest(username='test', realm='localhost') + assert self.status == '200 OK' + assert self.body == b"Hello test, you've been authorized." + + def test_unicode_user(self): + self._test_parametric_digest(username='☃йюзер', realm='localhost') + assert self.status == '200 OK' + assert self.body == ntob( + "Hello ☃йюзер, you've been authorized.", 'utf-8', + ) + + def test_wrong_scheme(self): + basic_auth = { + 'Authorization': 'Basic foo:bar', + } + self.getPage('/digest/', headers=list(basic_auth.items())) + assert self.status_code == 401 diff --git a/libraries/cherrypy/test/test_bus.py b/libraries/cherrypy/test/test_bus.py new file mode 100644 index 00000000..6026b47e --- /dev/null +++ b/libraries/cherrypy/test/test_bus.py @@ -0,0 +1,274 @@ +import threading +import time +import unittest + +from cherrypy.process import wspbus + + +msg = 'Listener %d on channel %s: %s.' 
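Before the test cases, a minimal sketch of the `wspbus` publish/subscribe API they exercise: listeners are attached to a channel with an optional priority (lower numbers run first, default 50), and `publish()` calls them in that order and returns their results.

```
from cherrypy.process import wspbus

bus = wspbus.Bus()

# Lower priority numbers run first (the default is 50).
bus.subscribe('my-channel', lambda arg=None: 'second: %s' % arg, priority=60)
bus.subscribe('my-channel', lambda arg=None: 'first: %s' % arg, priority=10)

# publish() returns the listeners' return values in the order they ran.
print(bus.publish('my-channel', 'hello'))
# -> ['first: hello', 'second: hello']
```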
+ + +class PublishSubscribeTests(unittest.TestCase): + + def get_listener(self, channel, index): + def listener(arg=None): + self.responses.append(msg % (index, channel, arg)) + return listener + + def test_builtin_channels(self): + b = wspbus.Bus() + + self.responses, expected = [], [] + + for channel in b.listeners: + for index, priority in enumerate([100, 50, 0, 51]): + b.subscribe(channel, + self.get_listener(channel, index), priority) + + for channel in b.listeners: + b.publish(channel) + expected.extend([msg % (i, channel, None) for i in (2, 1, 3, 0)]) + b.publish(channel, arg=79347) + expected.extend([msg % (i, channel, 79347) for i in (2, 1, 3, 0)]) + + self.assertEqual(self.responses, expected) + + def test_custom_channels(self): + b = wspbus.Bus() + + self.responses, expected = [], [] + + custom_listeners = ('hugh', 'louis', 'dewey') + for channel in custom_listeners: + for index, priority in enumerate([None, 10, 60, 40]): + b.subscribe(channel, + self.get_listener(channel, index), priority) + + for channel in custom_listeners: + b.publish(channel, 'ah so') + expected.extend([msg % (i, channel, 'ah so') + for i in (1, 3, 0, 2)]) + b.publish(channel) + expected.extend([msg % (i, channel, None) for i in (1, 3, 0, 2)]) + + self.assertEqual(self.responses, expected) + + def test_listener_errors(self): + b = wspbus.Bus() + + self.responses, expected = [], [] + channels = [c for c in b.listeners if c != 'log'] + + for channel in channels: + b.subscribe(channel, self.get_listener(channel, 1)) + # This will break since the lambda takes no args. + b.subscribe(channel, lambda: None, priority=20) + + for channel in channels: + self.assertRaises(wspbus.ChannelFailures, b.publish, channel, 123) + expected.append(msg % (1, channel, 123)) + + self.assertEqual(self.responses, expected) + + +class BusMethodTests(unittest.TestCase): + + def log(self, bus): + self._log_entries = [] + + def logit(msg, level): + self._log_entries.append(msg) + bus.subscribe('log', logit) + + def assertLog(self, entries): + self.assertEqual(self._log_entries, entries) + + def get_listener(self, channel, index): + def listener(arg=None): + self.responses.append(msg % (index, channel, arg)) + return listener + + def test_start(self): + b = wspbus.Bus() + self.log(b) + + self.responses = [] + num = 3 + for index in range(num): + b.subscribe('start', self.get_listener('start', index)) + + b.start() + try: + # The start method MUST call all 'start' listeners. + self.assertEqual( + set(self.responses), + set([msg % (i, 'start', None) for i in range(num)])) + # The start method MUST move the state to STARTED + # (or EXITING, if errors occur) + self.assertEqual(b.state, b.states.STARTED) + # The start method MUST log its states. + self.assertLog(['Bus STARTING', 'Bus STARTED']) + finally: + # Exit so the atexit handler doesn't complain. + b.exit() + + def test_stop(self): + b = wspbus.Bus() + self.log(b) + + self.responses = [] + num = 3 + for index in range(num): + b.subscribe('stop', self.get_listener('stop', index)) + + b.stop() + + # The stop method MUST call all 'stop' listeners. + self.assertEqual(set(self.responses), + set([msg % (i, 'stop', None) for i in range(num)])) + # The stop method MUST move the state to STOPPED + self.assertEqual(b.state, b.states.STOPPED) + # The stop method MUST log its states. 
+ self.assertLog(['Bus STOPPING', 'Bus STOPPED']) + + def test_graceful(self): + b = wspbus.Bus() + self.log(b) + + self.responses = [] + num = 3 + for index in range(num): + b.subscribe('graceful', self.get_listener('graceful', index)) + + b.graceful() + + # The graceful method MUST call all 'graceful' listeners. + self.assertEqual( + set(self.responses), + set([msg % (i, 'graceful', None) for i in range(num)])) + # The graceful method MUST log its states. + self.assertLog(['Bus graceful']) + + def test_exit(self): + b = wspbus.Bus() + self.log(b) + + self.responses = [] + num = 3 + for index in range(num): + b.subscribe('stop', self.get_listener('stop', index)) + b.subscribe('exit', self.get_listener('exit', index)) + + b.exit() + + # The exit method MUST call all 'stop' listeners, + # and then all 'exit' listeners. + self.assertEqual(set(self.responses), + set([msg % (i, 'stop', None) for i in range(num)] + + [msg % (i, 'exit', None) for i in range(num)])) + # The exit method MUST move the state to EXITING + self.assertEqual(b.state, b.states.EXITING) + # The exit method MUST log its states. + self.assertLog( + ['Bus STOPPING', 'Bus STOPPED', 'Bus EXITING', 'Bus EXITED']) + + def test_wait(self): + b = wspbus.Bus() + + def f(method): + time.sleep(0.2) + getattr(b, method)() + + for method, states in [('start', [b.states.STARTED]), + ('stop', [b.states.STOPPED]), + ('start', + [b.states.STARTING, b.states.STARTED]), + ('exit', [b.states.EXITING]), + ]: + threading.Thread(target=f, args=(method,)).start() + b.wait(states) + + # The wait method MUST wait for the given state(s). + if b.state not in states: + self.fail('State %r not in %r' % (b.state, states)) + + def test_block(self): + b = wspbus.Bus() + self.log(b) + + def f(): + time.sleep(0.2) + b.exit() + + def g(): + time.sleep(0.4) + threading.Thread(target=f).start() + threading.Thread(target=g).start() + threads = [t for t in threading.enumerate() if not t.daemon] + self.assertEqual(len(threads), 3) + + b.block() + + # The block method MUST wait for the EXITING state. + self.assertEqual(b.state, b.states.EXITING) + # The block method MUST wait for ALL non-main, non-daemon threads to + # finish. + threads = [t for t in threading.enumerate() if not t.daemon] + self.assertEqual(len(threads), 1) + # The last message will mention an indeterminable thread name; ignore + # it + self.assertEqual(self._log_entries[:-1], + ['Bus STOPPING', 'Bus STOPPED', + 'Bus EXITING', 'Bus EXITED', + 'Waiting for child threads to terminate...']) + + def test_start_with_callback(self): + b = wspbus.Bus() + self.log(b) + try: + events = [] + + def f(*args, **kwargs): + events.append(('f', args, kwargs)) + + def g(): + events.append('g') + b.subscribe('start', g) + b.start_with_callback(f, (1, 3, 5), {'foo': 'bar'}) + # Give wait() time to run f() + time.sleep(0.2) + + # The callback method MUST wait for the STARTED state. + self.assertEqual(b.state, b.states.STARTED) + # The callback method MUST run after all start methods. + self.assertEqual(events, ['g', ('f', (1, 3, 5), {'foo': 'bar'})]) + finally: + b.exit() + + def test_log(self): + b = wspbus.Bus() + self.log(b) + self.assertLog([]) + + # Try a normal message. 
+ expected = [] + for msg in ["O mah darlin'"] * 3 + ['Clementiiiiiiiine']: + b.log(msg) + expected.append(msg) + self.assertLog(expected) + + # Try an error message + try: + foo + except NameError: + b.log('You are lost and gone forever', traceback=True) + lastmsg = self._log_entries[-1] + if 'Traceback' not in lastmsg or 'NameError' not in lastmsg: + self.fail('Last log message %r did not contain ' + 'the expected traceback.' % lastmsg) + else: + self.fail('NameError was not raised as expected.') + + +if __name__ == '__main__': + unittest.main() diff --git a/libraries/cherrypy/test/test_caching.py b/libraries/cherrypy/test/test_caching.py new file mode 100644 index 00000000..1a6ed4f2 --- /dev/null +++ b/libraries/cherrypy/test/test_caching.py @@ -0,0 +1,392 @@ +import datetime +from itertools import count +import os +import threading +import time + +from six.moves import range +from six.moves import urllib + +import pytest + +import cherrypy +from cherrypy.lib import httputil + +from cherrypy.test import helper + + +curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + +gif_bytes = ( + b'GIF89a\x01\x00\x01\x00\x82\x00\x01\x99"\x1e\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00,\x00\x00\x00\x00\x01\x00\x01\x00\x02\x03\x02\x08\t\x00;' +) + + +class CacheTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + + @cherrypy.config(**{'tools.caching.on': True}) + class Root: + + def __init__(self): + self.counter = 0 + self.control_counter = 0 + self.longlock = threading.Lock() + + @cherrypy.expose + def index(self): + self.counter += 1 + msg = 'visit #%s' % self.counter + return msg + + @cherrypy.expose + def control(self): + self.control_counter += 1 + return 'visit #%s' % self.control_counter + + @cherrypy.expose + def a_gif(self): + cherrypy.response.headers[ + 'Last-Modified'] = httputil.HTTPDate() + return gif_bytes + + @cherrypy.expose + def long_process(self, seconds='1'): + try: + self.longlock.acquire() + time.sleep(float(seconds)) + finally: + self.longlock.release() + return 'success!' + + @cherrypy.expose + def clear_cache(self, path): + cherrypy._cache.store[cherrypy.request.base + path].clear() + + @cherrypy.config(**{ + 'tools.caching.on': True, + 'tools.response_headers.on': True, + 'tools.response_headers.headers': [ + ('Vary', 'Our-Varying-Header') + ], + }) + class VaryHeaderCachingServer(object): + + def __init__(self): + self.counter = count(1) + + @cherrypy.expose + def index(self): + return 'visit #%s' % next(self.counter) + + @cherrypy.config(**{ + 'tools.expires.on': True, + 'tools.expires.secs': 60, + 'tools.staticdir.on': True, + 'tools.staticdir.dir': 'static', + 'tools.staticdir.root': curdir, + }) + class UnCached(object): + + @cherrypy.expose + @cherrypy.config(**{'tools.expires.secs': 0}) + def force(self): + cherrypy.response.headers['Etag'] = 'bibbitybobbityboo' + self._cp_config['tools.expires.force'] = True + self._cp_config['tools.expires.secs'] = 0 + return 'being forceful' + + @cherrypy.expose + def dynamic(self): + cherrypy.response.headers['Etag'] = 'bibbitybobbityboo' + cherrypy.response.headers['Cache-Control'] = 'private' + return 'D-d-d-dynamic!' + + @cherrypy.expose + def cacheable(self): + cherrypy.response.headers['Etag'] = 'bibbitybobbityboo' + return "Hi, I'm cacheable." 
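(The `UnCached` test app continues below.) Its handlers toggle the `expires` tool per handler; the following hypothetical app shows the same two settings in isolation: a positive `secs` emits a future `Expires` header, while `secs=0` with `force=True` replaces any caching indicators with cache-prevention headers, as the assertions in `testExpiresTool` later confirm.

```
import cherrypy


class Clock(object):
    """Hypothetical app isolating the two expires-tool settings."""

    @cherrypy.expose
    @cherrypy.config(**{'tools.expires.on': True,
                        'tools.expires.secs': 3600})
    def cached_for_an_hour(self):
        # Emits an Expires header one hour in the future.
        return 'cache me'

    @cherrypy.expose
    @cherrypy.config(**{'tools.expires.on': True,
                        'tools.expires.secs': 0,
                        'tools.expires.force': True})
    def never_cache(self):
        # secs=0 plus force=True emits Pragma/Cache-Control/past-Expires
        # "cache prevention" headers instead of a future Expires date.
        return 'always fresh'


if __name__ == '__main__':
    cherrypy.quickstart(Clock())
```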
+ + @cherrypy.expose + @cherrypy.config(**{'tools.expires.secs': 86400}) + def specific(self): + cherrypy.response.headers[ + 'Etag'] = 'need_this_to_make_me_cacheable' + return 'I am being specific' + + class Foo(object): + pass + + @cherrypy.expose + @cherrypy.config(**{'tools.expires.secs': Foo()}) + def wrongtype(self): + cherrypy.response.headers[ + 'Etag'] = 'need_this_to_make_me_cacheable' + return 'Woops' + + @cherrypy.config(**{ + 'tools.gzip.mime_types': ['text/*', 'image/*'], + 'tools.caching.on': True, + 'tools.staticdir.on': True, + 'tools.staticdir.dir': 'static', + 'tools.staticdir.root': curdir + }) + class GzipStaticCache(object): + pass + + cherrypy.tree.mount(Root()) + cherrypy.tree.mount(UnCached(), '/expires') + cherrypy.tree.mount(VaryHeaderCachingServer(), '/varying_headers') + cherrypy.tree.mount(GzipStaticCache(), '/gzip_static_cache') + cherrypy.config.update({'tools.gzip.on': True}) + + def testCaching(self): + elapsed = 0.0 + for trial in range(10): + self.getPage('/') + # The response should be the same every time, + # except for the Age response header. + self.assertBody('visit #1') + if trial != 0: + age = int(self.assertHeader('Age')) + self.assert_(age >= elapsed) + elapsed = age + + # POST, PUT, DELETE should not be cached. + self.getPage('/', method='POST') + self.assertBody('visit #2') + # Because gzip is turned on, the Vary header should always Vary for + # content-encoding + self.assertHeader('Vary', 'Accept-Encoding') + # The previous request should have invalidated the cache, + # so this request will recalc the response. + self.getPage('/', method='GET') + self.assertBody('visit #3') + # ...but this request should get the cached copy. + self.getPage('/', method='GET') + self.assertBody('visit #3') + self.getPage('/', method='DELETE') + self.assertBody('visit #4') + + # The previous request should have invalidated the cache, + # so this request will recalc the response. + self.getPage('/', method='GET', headers=[('Accept-Encoding', 'gzip')]) + self.assertHeader('Content-Encoding', 'gzip') + self.assertHeader('Vary') + self.assertEqual( + cherrypy.lib.encoding.decompress(self.body), b'visit #5') + + # Now check that a second request gets the gzip header and gzipped body + # This also tests a bug in 3.0 to 3.0.2 whereby the cached, gzipped + # response body was being gzipped a second time. + self.getPage('/', method='GET', headers=[('Accept-Encoding', 'gzip')]) + self.assertHeader('Content-Encoding', 'gzip') + self.assertEqual( + cherrypy.lib.encoding.decompress(self.body), b'visit #5') + + # Now check that a third request that doesn't accept gzip + # skips the cache (because the 'Vary' header denies it). + self.getPage('/', method='GET') + self.assertNoHeader('Content-Encoding') + self.assertBody('visit #6') + + def testVaryHeader(self): + self.getPage('/varying_headers/') + self.assertStatus('200 OK') + self.assertHeaderItemValue('Vary', 'Our-Varying-Header') + self.assertBody('visit #1') + + # Now check that different 'Vary'-fields don't evict each other. + # This test creates 2 requests with different 'Our-Varying-Header' + # and then tests if the first one still exists. 
+ self.getPage('/varying_headers/', + headers=[('Our-Varying-Header', 'request 2')]) + self.assertStatus('200 OK') + self.assertBody('visit #2') + + self.getPage('/varying_headers/', + headers=[('Our-Varying-Header', 'request 2')]) + self.assertStatus('200 OK') + self.assertBody('visit #2') + + self.getPage('/varying_headers/') + self.assertStatus('200 OK') + self.assertBody('visit #1') + + def testExpiresTool(self): + # test setting an expires header + self.getPage('/expires/specific') + self.assertStatus('200 OK') + self.assertHeader('Expires') + + # test exceptions for bad time values + self.getPage('/expires/wrongtype') + self.assertStatus(500) + self.assertInBody('TypeError') + + # static content should not have "cache prevention" headers + self.getPage('/expires/index.html') + self.assertStatus('200 OK') + self.assertNoHeader('Pragma') + self.assertNoHeader('Cache-Control') + self.assertHeader('Expires') + + # dynamic content that sets indicators should not have + # "cache prevention" headers + self.getPage('/expires/cacheable') + self.assertStatus('200 OK') + self.assertNoHeader('Pragma') + self.assertNoHeader('Cache-Control') + self.assertHeader('Expires') + + self.getPage('/expires/dynamic') + self.assertBody('D-d-d-dynamic!') + # the Cache-Control header should be untouched + self.assertHeader('Cache-Control', 'private') + self.assertHeader('Expires') + + # configure the tool to ignore indicators and replace existing headers + self.getPage('/expires/force') + self.assertStatus('200 OK') + # This also gives us a chance to test 0 expiry with no other headers + self.assertHeader('Pragma', 'no-cache') + if cherrypy.server.protocol_version == 'HTTP/1.1': + self.assertHeader('Cache-Control', 'no-cache, must-revalidate') + self.assertHeader('Expires', 'Sun, 28 Jan 2007 00:00:00 GMT') + + # static content should now have "cache prevention" headers + self.getPage('/expires/index.html') + self.assertStatus('200 OK') + self.assertHeader('Pragma', 'no-cache') + if cherrypy.server.protocol_version == 'HTTP/1.1': + self.assertHeader('Cache-Control', 'no-cache, must-revalidate') + self.assertHeader('Expires', 'Sun, 28 Jan 2007 00:00:00 GMT') + + # the cacheable handler should now have "cache prevention" headers + self.getPage('/expires/cacheable') + self.assertStatus('200 OK') + self.assertHeader('Pragma', 'no-cache') + if cherrypy.server.protocol_version == 'HTTP/1.1': + self.assertHeader('Cache-Control', 'no-cache, must-revalidate') + self.assertHeader('Expires', 'Sun, 28 Jan 2007 00:00:00 GMT') + + self.getPage('/expires/dynamic') + self.assertBody('D-d-d-dynamic!') + # dynamic sets Cache-Control to private but it should be + # overwritten here ... + self.assertHeader('Pragma', 'no-cache') + if cherrypy.server.protocol_version == 'HTTP/1.1': + self.assertHeader('Cache-Control', 'no-cache, must-revalidate') + self.assertHeader('Expires', 'Sun, 28 Jan 2007 00:00:00 GMT') + + def _assert_resp_len_and_enc_for_gzip(self, uri): + """ + Test that after querying gzipped content it's remains valid in + cache and available non-gzipped as well. 
+ """ + ACCEPT_GZIP_HEADERS = [('Accept-Encoding', 'gzip')] + content_len = None + + for _ in range(3): + self.getPage(uri, method='GET', headers=ACCEPT_GZIP_HEADERS) + + if content_len is not None: + # all requests should get the same length + self.assertHeader('Content-Length', content_len) + self.assertHeader('Content-Encoding', 'gzip') + + content_len = dict(self.headers)['Content-Length'] + + # check that we can still get non-gzipped version + self.getPage(uri, method='GET') + self.assertNoHeader('Content-Encoding') + # non-gzipped version should have a different content length + self.assertNoHeaderItemValue('Content-Length', content_len) + + def testGzipStaticCache(self): + """Test that cache and gzip tools play well together when both enabled. + + Ref GitHub issue #1190. + """ + GZIP_STATIC_CACHE_TMPL = '/gzip_static_cache/{}' + resource_files = ('index.html', 'dirback.jpg') + + for f in resource_files: + uri = GZIP_STATIC_CACHE_TMPL.format(f) + self._assert_resp_len_and_enc_for_gzip(uri) + + def testLastModified(self): + self.getPage('/a.gif') + self.assertStatus(200) + self.assertBody(gif_bytes) + lm1 = self.assertHeader('Last-Modified') + + # this request should get the cached copy. + self.getPage('/a.gif') + self.assertStatus(200) + self.assertBody(gif_bytes) + self.assertHeader('Age') + lm2 = self.assertHeader('Last-Modified') + self.assertEqual(lm1, lm2) + + # this request should match the cached copy, but raise 304. + self.getPage('/a.gif', [('If-Modified-Since', lm1)]) + self.assertStatus(304) + self.assertNoHeader('Last-Modified') + if not getattr(cherrypy.server, 'using_apache', False): + self.assertHeader('Age') + + @pytest.mark.xfail(reason='#1536') + def test_antistampede(self): + SECONDS = 4 + slow_url = '/long_process?seconds={SECONDS}'.format(**locals()) + # We MUST make an initial synchronous request in order to create the + # AntiStampedeCache object, and populate its selecting_headers, + # before the actual stampede. 
+ self.getPage(slow_url) + self.assertBody('success!') + path = urllib.parse.quote(slow_url, safe='') + self.getPage('/clear_cache?path=' + path) + self.assertStatus(200) + + start = datetime.datetime.now() + + def run(): + self.getPage(slow_url) + # The response should be the same every time + self.assertBody('success!') + ts = [threading.Thread(target=run) for i in range(100)] + for t in ts: + t.start() + for t in ts: + t.join() + finish = datetime.datetime.now() + # Allow for overhead, two seconds for slow hosts + allowance = SECONDS + 2 + self.assertEqualDates(start, finish, seconds=allowance) + + def test_cache_control(self): + self.getPage('/control') + self.assertBody('visit #1') + self.getPage('/control') + self.assertBody('visit #1') + + self.getPage('/control', headers=[('Cache-Control', 'no-cache')]) + self.assertBody('visit #2') + self.getPage('/control') + self.assertBody('visit #2') + + self.getPage('/control', headers=[('Pragma', 'no-cache')]) + self.assertBody('visit #3') + self.getPage('/control') + self.assertBody('visit #3') + + time.sleep(1) + self.getPage('/control', headers=[('Cache-Control', 'max-age=0')]) + self.assertBody('visit #4') + self.getPage('/control') + self.assertBody('visit #4') diff --git a/libraries/cherrypy/test/test_compat.py b/libraries/cherrypy/test/test_compat.py new file mode 100644 index 00000000..44a9fa31 --- /dev/null +++ b/libraries/cherrypy/test/test_compat.py @@ -0,0 +1,34 @@ +"""Test Python 2/3 compatibility module.""" +from __future__ import unicode_literals + +import unittest + +import pytest +import six + +from cherrypy import _cpcompat as compat + + +class StringTester(unittest.TestCase): + """Tests for string conversion.""" + + @pytest.mark.skipif(six.PY3, reason='Only useful on Python 2') + def test_ntob_non_native(self): + """ntob should raise an Exception on unicode. + + (Python 2 only) + + See #1132 for discussion. 
+ """ + self.assertRaises(TypeError, compat.ntob, 'fight') + + +class EscapeTester(unittest.TestCase): + """Class to test escape_html function from _cpcompat.""" + + def test_escape_quote(self): + """test_escape_quote - Verify the output for &<>"' chars.""" + self.assertEqual( + """xx&<>"aa'""", + compat.escape_html("""xx&<>"aa'"""), + ) diff --git a/libraries/cherrypy/test/test_config.py b/libraries/cherrypy/test/test_config.py new file mode 100644 index 00000000..be17df90 --- /dev/null +++ b/libraries/cherrypy/test/test_config.py @@ -0,0 +1,303 @@ +"""Tests for the CherryPy configuration system.""" + +import io +import os +import sys +import unittest + +import six + +import cherrypy + +from cherrypy.test import helper + + +localDir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +def StringIOFromNative(x): + return io.StringIO(six.text_type(x)) + + +def setup_server(): + + @cherrypy.config(foo='this', bar='that') + class Root: + + def __init__(self): + cherrypy.config.namespaces['db'] = self.db_namespace + + def db_namespace(self, k, v): + if k == 'scheme': + self.db = v + + @cherrypy.expose(alias=('global_', 'xyz')) + def index(self, key): + return cherrypy.request.config.get(key, 'None') + + @cherrypy.expose + def repr(self, key): + return repr(cherrypy.request.config.get(key, None)) + + @cherrypy.expose + def dbscheme(self): + return self.db + + @cherrypy.expose + @cherrypy.config(**{'request.body.attempt_charsets': ['utf-16']}) + def plain(self, x): + return x + + favicon_ico = cherrypy.tools.staticfile.handler( + filename=os.path.join(localDir, '../favicon.ico')) + + @cherrypy.config(foo='this2', baz='that2') + class Foo: + + @cherrypy.expose + def index(self, key): + return cherrypy.request.config.get(key, 'None') + nex = index + + @cherrypy.expose + @cherrypy.config(**{'response.headers.X-silly': 'sillyval'}) + def silly(self): + return 'Hello world' + + # Test the expose and config decorators + @cherrypy.config(foo='this3', **{'bax': 'this4'}) + @cherrypy.expose + def bar(self, key): + return repr(cherrypy.request.config.get(key, None)) + + class Another: + + @cherrypy.expose + def index(self, key): + return str(cherrypy.request.config.get(key, 'None')) + + def raw_namespace(key, value): + if key == 'input.map': + handler = cherrypy.request.handler + + def wrapper(): + params = cherrypy.request.params + for name, coercer in list(value.items()): + try: + params[name] = coercer(params[name]) + except KeyError: + pass + return handler() + cherrypy.request.handler = wrapper + elif key == 'output': + handler = cherrypy.request.handler + + def wrapper(): + # 'value' is a type (like int or str). 
+ return value(handler()) + cherrypy.request.handler = wrapper + + @cherrypy.config(**{'raw.output': repr}) + class Raw: + + @cherrypy.expose + @cherrypy.config(**{'raw.input.map': {'num': int}}) + def incr(self, num): + return num + 1 + + if not six.PY3: + thing3 = "thing3: unicode('test', errors='ignore')" + else: + thing3 = '' + + ioconf = StringIOFromNative(""" +[/] +neg: -1234 +filename: os.path.join(sys.prefix, "hello.py") +thing1: cherrypy.lib.httputil.response_codes[404] +thing2: __import__('cherrypy.tutorial', globals(), locals(), ['']).thing2 +%s +complex: 3+2j +mul: 6*3 +ones: "11" +twos: "22" +stradd: %%(ones)s + %%(twos)s + "33" + +[/favicon.ico] +tools.staticfile.filename = %r +""" % (thing3, os.path.join(localDir, 'static/dirback.jpg'))) + + root = Root() + root.foo = Foo() + root.raw = Raw() + app = cherrypy.tree.mount(root, config=ioconf) + app.request_class.namespaces['raw'] = raw_namespace + + cherrypy.tree.mount(Another(), '/another') + cherrypy.config.update({'luxuryyacht': 'throatwobblermangrove', + 'db.scheme': r'sqlite///memory', + }) + + +# Client-side code # + + +class ConfigTests(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def testConfig(self): + tests = [ + ('/', 'nex', 'None'), + ('/', 'foo', 'this'), + ('/', 'bar', 'that'), + ('/xyz', 'foo', 'this'), + ('/foo/', 'foo', 'this2'), + ('/foo/', 'bar', 'that'), + ('/foo/', 'bax', 'None'), + ('/foo/bar', 'baz', "'that2'"), + ('/foo/nex', 'baz', 'that2'), + # If 'foo' == 'this', then the mount point '/another' leaks into + # '/'. + ('/another/', 'foo', 'None'), + ] + for path, key, expected in tests: + self.getPage(path + '?key=' + key) + self.assertBody(expected) + + expectedconf = { + # From CP defaults + 'tools.log_headers.on': False, + 'tools.log_tracebacks.on': True, + 'request.show_tracebacks': True, + 'log.screen': False, + 'environment': 'test_suite', + 'engine.autoreload.on': False, + # From global config + 'luxuryyacht': 'throatwobblermangrove', + # From Root._cp_config + 'bar': 'that', + # From Foo._cp_config + 'baz': 'that2', + # From Foo.bar._cp_config + 'foo': 'this3', + 'bax': 'this4', + } + for key, expected in expectedconf.items(): + self.getPage('/foo/bar?key=' + key) + self.assertBody(repr(expected)) + + def testUnrepr(self): + self.getPage('/repr?key=neg') + self.assertBody('-1234') + + self.getPage('/repr?key=filename') + self.assertBody(repr(os.path.join(sys.prefix, 'hello.py'))) + + self.getPage('/repr?key=thing1') + self.assertBody(repr(cherrypy.lib.httputil.response_codes[404])) + + if not getattr(cherrypy.server, 'using_apache', False): + # The object ID's won't match up when using Apache, since the + # server and client are running in different processes. + self.getPage('/repr?key=thing2') + from cherrypy.tutorial import thing2 + self.assertBody(repr(thing2)) + + if not six.PY3: + self.getPage('/repr?key=thing3') + self.assertBody(repr(six.text_type('test'))) + + self.getPage('/repr?key=complex') + self.assertBody('(3+2j)') + + self.getPage('/repr?key=mul') + self.assertBody('18') + + self.getPage('/repr?key=stradd') + self.assertBody(repr('112233')) + + def testRespNamespaces(self): + self.getPage('/foo/silly') + self.assertHeader('X-silly', 'sillyval') + self.assertBody('Hello world') + + def testCustomNamespaces(self): + self.getPage('/raw/incr?num=12') + self.assertBody('13') + + self.getPage('/dbscheme') + self.assertBody(r'sqlite///memory') + + def testHandlerToolConfigOverride(self): + # Assert that config overrides tool constructor args. 
Above, we set + # the favicon in the page handler to be '../favicon.ico', + # but then overrode it in config to be './static/dirback.jpg'. + self.getPage('/favicon.ico') + self.assertBody(open(os.path.join(localDir, 'static/dirback.jpg'), + 'rb').read()) + + def test_request_body_namespace(self): + self.getPage('/plain', method='POST', headers=[ + ('Content-Type', 'application/x-www-form-urlencoded'), + ('Content-Length', '13')], + body=b'\xff\xfex\x00=\xff\xfea\x00b\x00c\x00') + self.assertBody('abc') + + +class VariableSubstitutionTests(unittest.TestCase): + setup_server = staticmethod(setup_server) + + def test_config(self): + from textwrap import dedent + + # variable substitution with [DEFAULT] + conf = dedent(""" + [DEFAULT] + dir = "/some/dir" + my.dir = %(dir)s + "/sub" + + [my] + my.dir = %(dir)s + "/my/dir" + my.dir2 = %(my.dir)s + '/dir2' + + """) + + fp = StringIOFromNative(conf) + + cherrypy.config.update(fp) + self.assertEqual(cherrypy.config['my']['my.dir'], '/some/dir/my/dir') + self.assertEqual(cherrypy.config['my'] + ['my.dir2'], '/some/dir/my/dir/dir2') + + +class CallablesInConfigTest(unittest.TestCase): + setup_server = staticmethod(setup_server) + + def test_call_with_literal_dict(self): + from textwrap import dedent + conf = dedent(""" + [my] + value = dict(**{'foo': 'bar'}) + """) + fp = StringIOFromNative(conf) + cherrypy.config.update(fp) + self.assertEqual(cherrypy.config['my']['value'], {'foo': 'bar'}) + + def test_call_with_kwargs(self): + from textwrap import dedent + conf = dedent(""" + [my] + value = dict(foo="buzz", **cherrypy._test_dict) + """) + test_dict = { + 'foo': 'bar', + 'bar': 'foo', + 'fizz': 'buzz' + } + cherrypy._test_dict = test_dict + fp = StringIOFromNative(conf) + cherrypy.config.update(fp) + test_dict['foo'] = 'buzz' + self.assertEqual(cherrypy.config['my']['value']['foo'], 'buzz') + self.assertEqual(cherrypy.config['my']['value'], test_dict) + del cherrypy._test_dict diff --git a/libraries/cherrypy/test/test_config_server.py b/libraries/cherrypy/test/test_config_server.py new file mode 100644 index 00000000..7b183530 --- /dev/null +++ b/libraries/cherrypy/test/test_config_server.py @@ -0,0 +1,126 @@ +"""Tests for the CherryPy configuration system.""" + +import os + +import cherrypy +from cherrypy.test import helper + + +localDir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +# Client-side code # + + +class ServerConfigTests(helper.CPWebCase): + + @staticmethod + def setup_server(): + + class Root: + + @cherrypy.expose + def index(self): + return cherrypy.request.wsgi_environ['SERVER_PORT'] + + @cherrypy.expose + def upload(self, file): + return 'Size: %s' % len(file.file.read()) + + @cherrypy.expose + @cherrypy.config(**{'request.body.maxbytes': 100}) + def tinyupload(self): + return cherrypy.request.body.read() + + cherrypy.tree.mount(Root()) + + cherrypy.config.update({ + 'server.socket_host': '0.0.0.0', + 'server.socket_port': 9876, + 'server.max_request_body_size': 200, + 'server.max_request_header_size': 500, + 'server.socket_timeout': 0.5, + + # Test explicit server.instance + 'server.2.instance': 'cherrypy._cpwsgi_server.CPWSGIServer', + 'server.2.socket_port': 9877, + + # Test non-numeric <servername> + # Also test default server.instance = builtin server + 'server.yetanother.socket_port': 9878, + }) + + PORT = 9876 + + def testBasicConfig(self): + self.getPage('/') + self.assertBody(str(self.PORT)) + + def testAdditionalServers(self): + if self.scheme == 'https': + return self.skip('not available under ssl') + 
self.PORT = 9877 + self.getPage('/') + self.assertBody(str(self.PORT)) + self.PORT = 9878 + self.getPage('/') + self.assertBody(str(self.PORT)) + + def testMaxRequestSizePerHandler(self): + if getattr(cherrypy.server, 'using_apache', False): + return self.skip('skipped due to known Apache differences... ') + + self.getPage('/tinyupload', method='POST', + headers=[('Content-Type', 'text/plain'), + ('Content-Length', '100')], + body='x' * 100) + self.assertStatus(200) + self.assertBody('x' * 100) + + self.getPage('/tinyupload', method='POST', + headers=[('Content-Type', 'text/plain'), + ('Content-Length', '101')], + body='x' * 101) + self.assertStatus(413) + + def testMaxRequestSize(self): + if getattr(cherrypy.server, 'using_apache', False): + return self.skip('skipped due to known Apache differences... ') + + for size in (500, 5000, 50000): + self.getPage('/', headers=[('From', 'x' * 500)]) + self.assertStatus(413) + + # Test for https://github.com/cherrypy/cherrypy/issues/421 + # (Incorrect border condition in readline of SizeCheckWrapper). + # This hangs in rev 891 and earlier. + lines256 = 'x' * 248 + self.getPage('/', + headers=[('Host', '%s:%s' % (self.HOST, self.PORT)), + ('From', lines256)]) + + # Test upload + cd = ( + 'Content-Disposition: form-data; ' + 'name="file"; ' + 'filename="hello.txt"' + ) + body = '\r\n'.join([ + '--x', + cd, + 'Content-Type: text/plain', + '', + '%s', + '--x--']) + partlen = 200 - len(body) + b = body % ('x' * partlen) + h = [('Content-type', 'multipart/form-data; boundary=x'), + ('Content-Length', '%s' % len(b))] + self.getPage('/upload', h, 'POST', b) + self.assertBody('Size: %d' % partlen) + + b = body % ('x' * 200) + h = [('Content-type', 'multipart/form-data; boundary=x'), + ('Content-Length', '%s' % len(b))] + self.getPage('/upload', h, 'POST', b) + self.assertStatus(413) diff --git a/libraries/cherrypy/test/test_conn.py b/libraries/cherrypy/test/test_conn.py new file mode 100644 index 00000000..7d60c6fb --- /dev/null +++ b/libraries/cherrypy/test/test_conn.py @@ -0,0 +1,873 @@ +"""Tests for TCP connection handling, including proper and timely close.""" + +import errno +import socket +import sys +import time + +import six +from six.moves import urllib +from six.moves.http_client import BadStatusLine, HTTPConnection, NotConnected + +import pytest + +from cheroot.test import webtest + +import cherrypy +from cherrypy._cpcompat import HTTPSConnection, ntob, tonative +from cherrypy.test import helper + + +timeout = 1 +pov = 'pPeErRsSiIsStTeEnNcCeE oOfF vViIsSiIoOnN' + + +def setup_server(): + + def raise500(): + raise cherrypy.HTTPError(500) + + class Root: + + @cherrypy.expose + def index(self): + return pov + page1 = index + page2 = index + page3 = index + + @cherrypy.expose + def hello(self): + return 'Hello, world!' 
+ + @cherrypy.expose + def timeout(self, t): + return str(cherrypy.server.httpserver.timeout) + + @cherrypy.expose + @cherrypy.config(**{'response.stream': True}) + def stream(self, set_cl=False): + if set_cl: + cherrypy.response.headers['Content-Length'] = 10 + + def content(): + for x in range(10): + yield str(x) + + return content() + + @cherrypy.expose + def error(self, code=500): + raise cherrypy.HTTPError(code) + + @cherrypy.expose + def upload(self): + if not cherrypy.request.method == 'POST': + raise AssertionError("'POST' != request.method %r" % + cherrypy.request.method) + return "thanks for '%s'" % cherrypy.request.body.read() + + @cherrypy.expose + def custom(self, response_code): + cherrypy.response.status = response_code + return 'Code = %s' % response_code + + @cherrypy.expose + @cherrypy.config(**{'hooks.on_start_resource': raise500}) + def err_before_read(self): + return 'ok' + + @cherrypy.expose + def one_megabyte_of_a(self): + return ['a' * 1024] * 1024 + + @cherrypy.expose + # Turn off the encoding tool so it doens't collapse + # our response body and reclaculate the Content-Length. + @cherrypy.config(**{'tools.encode.on': False}) + def custom_cl(self, body, cl): + cherrypy.response.headers['Content-Length'] = cl + if not isinstance(body, list): + body = [body] + newbody = [] + for chunk in body: + if isinstance(chunk, six.text_type): + chunk = chunk.encode('ISO-8859-1') + newbody.append(chunk) + return newbody + + cherrypy.tree.mount(Root()) + cherrypy.config.update({ + 'server.max_request_body_size': 1001, + 'server.socket_timeout': timeout, + }) + + +class ConnectionCloseTests(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test_HTTP11(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + self.persistent = True + + # Make the first request and assert there's no "Connection: close". + self.getPage('/') + self.assertStatus('200 OK') + self.assertBody(pov) + self.assertNoHeader('Connection') + + # Make another request on the same connection. + self.getPage('/page1') + self.assertStatus('200 OK') + self.assertBody(pov) + self.assertNoHeader('Connection') + + # Test client-side close. + self.getPage('/page2', headers=[('Connection', 'close')]) + self.assertStatus('200 OK') + self.assertBody(pov) + self.assertHeader('Connection', 'close') + + # Make another request on the same connection, which should error. + self.assertRaises(NotConnected, self.getPage, '/') + + def test_Streaming_no_len(self): + try: + self._streaming(set_cl=False) + finally: + try: + self.HTTP_CONN.close() + except (TypeError, AttributeError): + pass + + def test_Streaming_with_len(self): + try: + self._streaming(set_cl=True) + finally: + try: + self.HTTP_CONN.close() + except (TypeError, AttributeError): + pass + + def _streaming(self, set_cl): + if cherrypy.server.protocol_version == 'HTTP/1.1': + self.PROTOCOL = 'HTTP/1.1' + + self.persistent = True + + # Make the first request and assert there's no "Connection: close". + self.getPage('/') + self.assertStatus('200 OK') + self.assertBody(pov) + self.assertNoHeader('Connection') + + # Make another, streamed request on the same connection. + if set_cl: + # When a Content-Length is provided, the content should stream + # without closing the connection. 
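+ # (With a declared length the client knows where the body ends, so
+ # the connection can remain open for reuse.)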
+ self.getPage('/stream?set_cl=Yes') + self.assertHeader('Content-Length') + self.assertNoHeader('Connection', 'close') + self.assertNoHeader('Transfer-Encoding') + + self.assertStatus('200 OK') + self.assertBody('0123456789') + else: + # When no Content-Length response header is provided, + # streamed output will either close the connection, or use + # chunked encoding, to determine transfer-length. + self.getPage('/stream') + self.assertNoHeader('Content-Length') + self.assertStatus('200 OK') + self.assertBody('0123456789') + + chunked_response = False + for k, v in self.headers: + if k.lower() == 'transfer-encoding': + if str(v) == 'chunked': + chunked_response = True + + if chunked_response: + self.assertNoHeader('Connection', 'close') + else: + self.assertHeader('Connection', 'close') + + # Make another request on the same connection, which should + # error. + self.assertRaises(NotConnected, self.getPage, '/') + + # Try HEAD. See + # https://github.com/cherrypy/cherrypy/issues/864. + self.getPage('/stream', method='HEAD') + self.assertStatus('200 OK') + self.assertBody('') + self.assertNoHeader('Transfer-Encoding') + else: + self.PROTOCOL = 'HTTP/1.0' + + self.persistent = True + + # Make the first request and assert Keep-Alive. + self.getPage('/', headers=[('Connection', 'Keep-Alive')]) + self.assertStatus('200 OK') + self.assertBody(pov) + self.assertHeader('Connection', 'Keep-Alive') + + # Make another, streamed request on the same connection. + if set_cl: + # When a Content-Length is provided, the content should + # stream without closing the connection. + self.getPage('/stream?set_cl=Yes', + headers=[('Connection', 'Keep-Alive')]) + self.assertHeader('Content-Length') + self.assertHeader('Connection', 'Keep-Alive') + self.assertNoHeader('Transfer-Encoding') + self.assertStatus('200 OK') + self.assertBody('0123456789') + else: + # When a Content-Length is not provided, + # the server should close the connection. + self.getPage('/stream', headers=[('Connection', 'Keep-Alive')]) + self.assertStatus('200 OK') + self.assertBody('0123456789') + + self.assertNoHeader('Content-Length') + self.assertNoHeader('Connection', 'Keep-Alive') + self.assertNoHeader('Transfer-Encoding') + + # Make another request on the same connection, which should + # error. + self.assertRaises(NotConnected, self.getPage, '/') + + def test_HTTP10_KeepAlive(self): + self.PROTOCOL = 'HTTP/1.0' + if self.scheme == 'https': + self.HTTP_CONN = HTTPSConnection + else: + self.HTTP_CONN = HTTPConnection + + # Test a normal HTTP/1.0 request. + self.getPage('/page2') + self.assertStatus('200 OK') + self.assertBody(pov) + # Apache, for example, may emit a Connection header even for HTTP/1.0 + # self.assertNoHeader("Connection") + + # Test a keep-alive HTTP/1.0 request. + self.persistent = True + + self.getPage('/page3', headers=[('Connection', 'Keep-Alive')]) + self.assertStatus('200 OK') + self.assertBody(pov) + self.assertHeader('Connection', 'Keep-Alive') + + # Remove the keep-alive header again. + self.getPage('/page3') + self.assertStatus('200 OK') + self.assertBody(pov) + # Apache, for example, may emit a Connection header even for HTTP/1.0 + # self.assertNoHeader("Connection") + + +class PipelineTests(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test_HTTP11_Timeout(self): + # If we timeout without sending any data, + # the server will close the conn with a 408. 
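+ # (408 Request Timeout signals that the server gave up waiting for
+ # the request; the client may retry it on a new connection.)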
+ if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + # Connect but send nothing. + self.persistent = True + conn = self.HTTP_CONN + conn.auto_open = False + conn.connect() + + # Wait for our socket timeout + time.sleep(timeout * 2) + + # The request should have returned 408 already. + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.assertEqual(response.status, 408) + conn.close() + + # Connect but send half the headers only. + self.persistent = True + conn = self.HTTP_CONN + conn.auto_open = False + conn.connect() + conn.send(b'GET /hello HTTP/1.1') + conn.send(('Host: %s' % self.HOST).encode('ascii')) + + # Wait for our socket timeout + time.sleep(timeout * 2) + + # The conn should have already sent 408. + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.assertEqual(response.status, 408) + conn.close() + + def test_HTTP11_Timeout_after_request(self): + # If we timeout after at least one request has succeeded, + # the server will close the conn without 408. + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + # Make an initial request + self.persistent = True + conn = self.HTTP_CONN + conn.putrequest('GET', '/timeout?t=%s' % timeout, skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.assertEqual(response.status, 200) + self.body = response.read() + self.assertBody(str(timeout)) + + # Make a second request on the same socket + conn._output(b'GET /hello HTTP/1.1') + conn._output(ntob('Host: %s' % self.HOST, 'ascii')) + conn._send_output() + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.assertEqual(response.status, 200) + self.body = response.read() + self.assertBody('Hello, world!') + + # Wait for our socket timeout + time.sleep(timeout * 2) + + # Make another request on the same socket, which should error + conn._output(b'GET /hello HTTP/1.1') + conn._output(ntob('Host: %s' % self.HOST, 'ascii')) + conn._send_output() + response = conn.response_class(conn.sock, method='GET') + try: + response.begin() + except Exception: + if not isinstance(sys.exc_info()[1], + (socket.error, BadStatusLine)): + self.fail("Writing to timed out socket didn't fail" + ' as it should have: %s' % sys.exc_info()[1]) + else: + if response.status != 408: + self.fail("Writing to timed out socket didn't fail" + ' as it should have: %s' % + response.read()) + + conn.close() + + # Make another request on a new socket, which should work + self.persistent = True + conn = self.HTTP_CONN + conn.putrequest('GET', '/', skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.assertEqual(response.status, 200) + self.body = response.read() + self.assertBody(pov) + + # Make another request on the same socket, + # but timeout on the headers + conn.send(b'GET /hello HTTP/1.1') + # Wait for our socket timeout + time.sleep(timeout * 2) + response = conn.response_class(conn.sock, method='GET') + try: + response.begin() + except Exception: + if not isinstance(sys.exc_info()[1], + (socket.error, BadStatusLine)): + self.fail("Writing to timed out socket didn't fail" + ' as it should have: %s' % sys.exc_info()[1]) + else: + self.fail("Writing to timed out socket didn't fail" + ' as it should have: %s' % + response.read()) + + 
conn.close() + + # Retry the request on a new connection, which should work + self.persistent = True + conn = self.HTTP_CONN + conn.putrequest('GET', '/', skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.assertEqual(response.status, 200) + self.body = response.read() + self.assertBody(pov) + conn.close() + + def test_HTTP11_pipelining(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + # Test pipelining. httplib doesn't support this directly. + self.persistent = True + conn = self.HTTP_CONN + + # Put request 1 + conn.putrequest('GET', '/hello', skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + + for trial in range(5): + # Put next request + conn._output(b'GET /hello HTTP/1.1') + conn._output(ntob('Host: %s' % self.HOST, 'ascii')) + conn._send_output() + + # Retrieve previous response + response = conn.response_class(conn.sock, method='GET') + # there is a bug in python3 regarding the buffering of + # ``conn.sock``. Until that bug get's fixed we will + # monkey patch the ``response`` instance. + # https://bugs.python.org/issue23377 + if six.PY3: + response.fp = conn.sock.makefile('rb', 0) + response.begin() + body = response.read(13) + self.assertEqual(response.status, 200) + self.assertEqual(body, b'Hello, world!') + + # Retrieve final response + response = conn.response_class(conn.sock, method='GET') + response.begin() + body = response.read() + self.assertEqual(response.status, 200) + self.assertEqual(body, b'Hello, world!') + + conn.close() + + def test_100_Continue(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + self.persistent = True + conn = self.HTTP_CONN + + # Try a page without an Expect request header first. + # Note that httplib's response.begin automatically ignores + # 100 Continue responses, so we must manually check for it. + try: + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '4') + conn.endheaders() + conn.send(ntob("d'oh")) + response = conn.response_class(conn.sock, method='POST') + version, status, reason = response._read_status() + self.assertNotEqual(status, 100) + finally: + conn.close() + + # Now try a page with an Expect header... + try: + conn.connect() + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '17') + conn.putheader('Expect', '100-continue') + conn.endheaders() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + self.assertEqual(status, 100) + while True: + line = response.fp.readline().strip() + if line: + self.fail( + '100 Continue should not output any headers. 
Got %r' % + line) + else: + break + + # ...send the body + body = b'I am a small file' + conn.send(body) + + # ...get the final response + response.begin() + self.status, self.headers, self.body = webtest.shb(response) + self.assertStatus(200) + self.assertBody("thanks for '%s'" % body) + finally: + conn.close() + + +class ConnectionTests(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test_readall_or_close(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + if self.scheme == 'https': + self.HTTP_CONN = HTTPSConnection + else: + self.HTTP_CONN = HTTPConnection + + # Test a max of 0 (the default) and then reset to what it was above. + old_max = cherrypy.server.max_request_body_size + for new_max in (0, old_max): + cherrypy.server.max_request_body_size = new_max + + self.persistent = True + conn = self.HTTP_CONN + + # Get a POST page with an error + conn.putrequest('POST', '/err_before_read', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '1000') + conn.putheader('Expect', '100-continue') + conn.endheaders() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + self.assertEqual(status, 100) + while True: + skip = response.fp.readline().strip() + if not skip: + break + + # ...send the body + conn.send(ntob('x' * 1000)) + + # ...get the final response + response.begin() + self.status, self.headers, self.body = webtest.shb(response) + self.assertStatus(500) + + # Now try a working page with an Expect header... + conn._output(b'POST /upload HTTP/1.1') + conn._output(ntob('Host: %s' % self.HOST, 'ascii')) + conn._output(b'Content-Type: text/plain') + conn._output(b'Content-Length: 17') + conn._output(b'Expect: 100-continue') + conn._send_output() + response = conn.response_class(conn.sock, method='POST') + + # ...assert and then skip the 100 response + version, status, reason = response._read_status() + self.assertEqual(status, 100) + while True: + skip = response.fp.readline().strip() + if not skip: + break + + # ...send the body + body = b'I am a small file' + conn.send(body) + + # ...get the final response + response.begin() + self.status, self.headers, self.body = webtest.shb(response) + self.assertStatus(200) + self.assertBody("thanks for '%s'" % body) + conn.close() + + def test_No_Message_Body(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + # Set our HTTP_CONN to an instance so it persists between requests. + self.persistent = True + + # Make the first request and assert there's no "Connection: close". + self.getPage('/') + self.assertStatus('200 OK') + self.assertBody(pov) + self.assertNoHeader('Connection') + + # Make a 204 request on the same connection. + self.getPage('/custom/204') + self.assertStatus(204) + self.assertNoHeader('Content-Length') + self.assertBody('') + self.assertNoHeader('Connection') + + # Make a 304 request on the same connection. 
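+ # (204 and 304 responses carry no message body; the assertions below
+ # also expect the server to omit Content-Length for them.)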
+ self.getPage('/custom/304') + self.assertStatus(304) + self.assertNoHeader('Content-Length') + self.assertBody('') + self.assertNoHeader('Connection') + + def test_Chunked_Encoding(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + if (hasattr(self, 'harness') and + 'modpython' in self.harness.__class__.__name__.lower()): + # mod_python forbids chunked encoding + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + # Set our HTTP_CONN to an instance so it persists between requests. + self.persistent = True + conn = self.HTTP_CONN + + # Try a normal chunked request (with extensions) + body = ntob('8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n' + 'Content-Type: application/json\r\n' + '\r\n') + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Transfer-Encoding', 'chunked') + conn.putheader('Trailer', 'Content-Type') + # Note that this is somewhat malformed: + # we shouldn't be sending Content-Length. + # RFC 2616 says the server should ignore it. + conn.putheader('Content-Length', '3') + conn.endheaders() + conn.send(body) + response = conn.getresponse() + self.status, self.headers, self.body = webtest.shb(response) + self.assertStatus('200 OK') + self.assertBody("thanks for '%s'" % b'xx\r\nxxxxyyyyy') + + # Try a chunked request that exceeds server.max_request_body_size. + # Note that the delimiters and trailer are included. + body = ntob('3e3\r\n' + ('x' * 995) + '\r\n0\r\n\r\n') + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Transfer-Encoding', 'chunked') + conn.putheader('Content-Type', 'text/plain') + # Chunked requests don't need a content-length + # # conn.putheader("Content-Length", len(body)) + conn.endheaders() + conn.send(body) + response = conn.getresponse() + self.status, self.headers, self.body = webtest.shb(response) + self.assertStatus(413) + conn.close() + + def test_Content_Length_in(self): + # Try a non-chunked request where Content-Length exceeds + # server.max_request_body_size. Assert error before body send. + self.persistent = True + conn = self.HTTP_CONN + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '9999') + conn.endheaders() + response = conn.getresponse() + self.status, self.headers, self.body = webtest.shb(response) + self.assertStatus(413) + self.assertBody('The entity sent with the request exceeds ' + 'the maximum allowed bytes.') + conn.close() + + def test_Content_Length_out_preheaders(self): + # Try a non-chunked response where Content-Length is less than + # the actual bytes in the response body. + self.persistent = True + conn = self.HTTP_CONN + conn.putrequest('GET', '/custom_cl?body=I+have+too+many+bytes&cl=5', + skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + response = conn.getresponse() + self.status, self.headers, self.body = webtest.shb(response) + self.assertStatus(500) + self.assertBody( + 'The requested resource returned more bytes than the ' + 'declared Content-Length.') + conn.close() + + def test_Content_Length_out_postheaders(self): + # Try a non-chunked response where Content-Length is less than + # the actual bytes in the response body. 
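+ # (Here the overrun is only detected after the response headers have
+ # been sent, so the status can no longer be changed to 500; instead
+ # the body is cut off at the declared length, as asserted below.)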
+ self.persistent = True + conn = self.HTTP_CONN + conn.putrequest( + 'GET', '/custom_cl?body=I+too&body=+have+too+many&cl=5', + skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + response = conn.getresponse() + self.status, self.headers, self.body = webtest.shb(response) + self.assertStatus(200) + self.assertBody('I too') + conn.close() + + def test_598(self): + tmpl = '{scheme}://{host}:{port}/one_megabyte_of_a/' + url = tmpl.format( + scheme=self.scheme, + host=self.HOST, + port=self.PORT, + ) + remote_data_conn = urllib.request.urlopen(url) + buf = remote_data_conn.read(512) + time.sleep(timeout * 0.6) + remaining = (1024 * 1024) - 512 + while remaining: + data = remote_data_conn.read(remaining) + if not data: + break + else: + buf += data + remaining -= len(data) + + self.assertEqual(len(buf), 1024 * 1024) + self.assertEqual(buf, ntob('a' * 1024 * 1024)) + self.assertEqual(remaining, 0) + remote_data_conn.close() + + +def setup_upload_server(): + + class Root: + @cherrypy.expose + def upload(self): + if not cherrypy.request.method == 'POST': + raise AssertionError("'POST' != request.method %r" % + cherrypy.request.method) + return "thanks for '%s'" % tonative(cherrypy.request.body.read()) + + cherrypy.tree.mount(Root()) + cherrypy.config.update({ + 'server.max_request_body_size': 1001, + 'server.socket_timeout': 10, + 'server.accepted_queue_size': 5, + 'server.accepted_queue_timeout': 0.1, + }) + + +reset_names = 'ECONNRESET', 'WSAECONNRESET' +socket_reset_errors = [ + getattr(errno, name) + for name in reset_names + if hasattr(errno, name) +] +'reset error numbers available on this platform' + +socket_reset_errors += [ + # Python 3.5 raises an http.client.RemoteDisconnected + # with this message + 'Remote end closed connection without response', +] + + +class LimitedRequestQueueTests(helper.CPWebCase): + setup_server = staticmethod(setup_upload_server) + + @pytest.mark.xfail(reason='#1535') + def test_queue_full(self): + conns = [] + overflow_conn = None + + try: + # Make 15 initial requests and leave them open, which should use + # all of wsgiserver's WorkerThreads and fill its Queue. + for i in range(15): + conn = self.HTTP_CONN(self.HOST, self.PORT) + conn.putrequest('POST', '/upload', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Content-Type', 'text/plain') + conn.putheader('Content-Length', '4') + conn.endheaders() + conns.append(conn) + + # Now try a 16th conn, which should be closed by the + # server immediately. + overflow_conn = self.HTTP_CONN(self.HOST, self.PORT) + # Manually connect since httplib won't let us set a timeout + for res in socket.getaddrinfo(self.HOST, self.PORT, 0, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + overflow_conn.sock = socket.socket(af, socktype, proto) + overflow_conn.sock.settimeout(5) + overflow_conn.sock.connect(sa) + break + + overflow_conn.putrequest('GET', '/', skip_host=True) + overflow_conn.putheader('Host', self.HOST) + overflow_conn.endheaders() + response = overflow_conn.response_class( + overflow_conn.sock, + method='GET', + ) + try: + response.begin() + except socket.error as exc: + if exc.args[0] in socket_reset_errors: + pass # Expected. + else: + tmpl = ( + 'Overflow conn did not get RST. ' + 'Got {exc.args!r} instead' + ) + raise AssertionError(tmpl.format(**locals())) + except BadStatusLine: + # This is a special case in OS X. Linux and Windows will + # RST correctly. 
+ assert sys.platform == 'darwin' + else: + raise AssertionError('Overflow conn did not get RST ') + finally: + for conn in conns: + conn.send(b'done') + response = conn.response_class(conn.sock, method='POST') + response.begin() + self.body = response.read() + self.assertBody("thanks for 'done'") + self.assertEqual(response.status, 200) + conn.close() + if overflow_conn: + overflow_conn.close() + + +class BadRequestTests(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test_No_CRLF(self): + self.persistent = True + + conn = self.HTTP_CONN + conn.send(b'GET /hello HTTP/1.1\n\n') + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.body = response.read() + self.assertBody('HTTP requires CRLF terminators') + conn.close() + + conn.connect() + conn.send(b'GET /hello HTTP/1.1\r\n\n') + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.body = response.read() + self.assertBody('HTTP requires CRLF terminators') + conn.close() diff --git a/libraries/cherrypy/test/test_core.py b/libraries/cherrypy/test/test_core.py new file mode 100644 index 00000000..9834c1f3 --- /dev/null +++ b/libraries/cherrypy/test/test_core.py @@ -0,0 +1,823 @@ +# coding: utf-8 + +"""Basic tests for the CherryPy core: request handling.""" + +import os +import sys +import types + +import six + +import cherrypy +from cherrypy._cpcompat import ntou +from cherrypy import _cptools, tools +from cherrypy.lib import httputil, static + +from cherrypy.test._test_decorators import ExposeExamples +from cherrypy.test import helper + + +localDir = os.path.dirname(__file__) +favicon_path = os.path.join(os.getcwd(), localDir, '../favicon.ico') + +# Client-side code # + + +class CoreRequestHandlingTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def index(self): + return 'hello' + + favicon_ico = tools.staticfile.handler(filename=favicon_path) + + @cherrypy.expose + def defct(self, newct): + newct = 'text/%s' % newct + cherrypy.config.update({'tools.response_headers.on': True, + 'tools.response_headers.headers': + [('Content-Type', newct)]}) + + @cherrypy.expose + def baseurl(self, path_info, relative=None): + return cherrypy.url(path_info, relative=bool(relative)) + + root = Root() + root.expose_dec = ExposeExamples() + + class TestType(type): + + """Metaclass which automatically exposes all functions in each + subclass, and adds an instance of the subclass as an attribute + of root. 
+ """ + def __init__(cls, name, bases, dct): + type.__init__(cls, name, bases, dct) + for value in six.itervalues(dct): + if isinstance(value, types.FunctionType): + value.exposed = True + setattr(root, name.lower(), cls()) + Test = TestType('Test', (object, ), {}) + + @cherrypy.config(**{'tools.trailing_slash.on': False}) + class URL(Test): + + def index(self, path_info, relative=None): + if relative != 'server': + relative = bool(relative) + return cherrypy.url(path_info, relative=relative) + + def leaf(self, path_info, relative=None): + if relative != 'server': + relative = bool(relative) + return cherrypy.url(path_info, relative=relative) + + def qs(self, qs): + return cherrypy.url(qs=qs) + + def log_status(): + Status.statuses.append(cherrypy.response.status) + cherrypy.tools.log_status = cherrypy.Tool( + 'on_end_resource', log_status) + + class Status(Test): + + def index(self): + return 'normal' + + def blank(self): + cherrypy.response.status = '' + + # According to RFC 2616, new status codes are OK as long as they + # are between 100 and 599. + + # Here is an illegal code... + def illegal(self): + cherrypy.response.status = 781 + return 'oops' + + # ...and here is an unknown but legal code. + def unknown(self): + cherrypy.response.status = '431 My custom error' + return 'funky' + + # Non-numeric code + def bad(self): + cherrypy.response.status = 'error' + return 'bad news' + + statuses = [] + + @cherrypy.config(**{'tools.log_status.on': True}) + def on_end_resource_stage(self): + return repr(self.statuses) + + class Redirect(Test): + + @cherrypy.config(**{ + 'tools.err_redirect.on': True, + 'tools.err_redirect.url': '/errpage', + 'tools.err_redirect.internal': False, + }) + class Error: + @cherrypy.expose + def index(self): + raise NameError('redirect_test') + + error = Error() + + def index(self): + return 'child' + + def custom(self, url, code): + raise cherrypy.HTTPRedirect(url, code) + + @cherrypy.config(**{'tools.trailing_slash.extra': True}) + def by_code(self, code): + raise cherrypy.HTTPRedirect('somewhere%20else', code) + + def nomodify(self): + raise cherrypy.HTTPRedirect('', 304) + + def proxy(self): + raise cherrypy.HTTPRedirect('proxy', 305) + + def stringify(self): + return str(cherrypy.HTTPRedirect('/')) + + def fragment(self, frag): + raise cherrypy.HTTPRedirect('/some/url#%s' % frag) + + def url_with_quote(self): + raise cherrypy.HTTPRedirect("/some\"url/that'we/want") + + def url_with_xss(self): + raise cherrypy.HTTPRedirect( + "/some<script>alert(1);</script>url/that'we/want") + + def url_with_unicode(self): + raise cherrypy.HTTPRedirect(ntou('тест', 'utf-8')) + + def login_redir(): + if not getattr(cherrypy.request, 'login', None): + raise cherrypy.InternalRedirect('/internalredirect/login') + tools.login_redir = _cptools.Tool('before_handler', login_redir) + + def redir_custom(): + raise cherrypy.InternalRedirect('/internalredirect/custom_err') + + class InternalRedirect(Test): + + def index(self): + raise cherrypy.InternalRedirect('/') + + @cherrypy.expose + @cherrypy.config(**{'hooks.before_error_response': redir_custom}) + def choke(self): + return 3 / 0 + + def relative(self, a, b): + raise cherrypy.InternalRedirect('cousin?t=6') + + def cousin(self, t): + assert cherrypy.request.prev.closed + return cherrypy.request.prev.query_string + + def petshop(self, user_id): + if user_id == 'parrot': + # Trade it for a slug when redirecting + raise cherrypy.InternalRedirect( + '/image/getImagesByUser?user_id=slug') + elif user_id == 'terrier': + # Trade it for a 
fish when redirecting + raise cherrypy.InternalRedirect( + '/image/getImagesByUser?user_id=fish') + else: + # This should pass the user_id through to getImagesByUser + raise cherrypy.InternalRedirect( + '/image/getImagesByUser?user_id=%s' % str(user_id)) + + # We support Python 2.3, but the @-deco syntax would look like + # this: + # @tools.login_redir() + def secure(self): + return 'Welcome!' + secure = tools.login_redir()(secure) + # Since calling the tool returns the same function you pass in, + # you could skip binding the return value, and just write: + # tools.login_redir()(secure) + + def login(self): + return 'Please log in' + + def custom_err(self): + return 'Something went horribly wrong.' + + @cherrypy.config(**{'hooks.before_request_body': redir_custom}) + def early_ir(self, arg): + return 'whatever' + + class Image(Test): + + def getImagesByUser(self, user_id): + return '0 images for %s' % user_id + + class Flatten(Test): + + def as_string(self): + return 'content' + + def as_list(self): + return ['con', 'tent'] + + def as_yield(self): + yield b'content' + + @cherrypy.config(**{'tools.flatten.on': True}) + def as_dblyield(self): + yield self.as_yield() + + def as_refyield(self): + for chunk in self.as_yield(): + yield chunk + + class Ranges(Test): + + def get_ranges(self, bytes): + return repr(httputil.get_ranges('bytes=%s' % bytes, 8)) + + def slice_file(self): + path = os.path.join(os.getcwd(), os.path.dirname(__file__)) + return static.serve_file( + os.path.join(path, 'static/index.html')) + + class Cookies(Test): + + def single(self, name): + cookie = cherrypy.request.cookie[name] + # Python2's SimpleCookie.__setitem__ won't take unicode keys. + cherrypy.response.cookie[str(name)] = cookie.value + + def multiple(self, names): + list(map(self.single, names)) + + def append_headers(header_list, debug=False): + if debug: + cherrypy.log( + 'Extending response headers with %s' % repr(header_list), + 'TOOLS.APPEND_HEADERS') + cherrypy.serving.response.header_list.extend(header_list) + cherrypy.tools.append_headers = cherrypy.Tool( + 'on_end_resource', append_headers) + + class MultiHeader(Test): + + def header_list(self): + pass + header_list = cherrypy.tools.append_headers(header_list=[ + (b'WWW-Authenticate', b'Negotiate'), + (b'WWW-Authenticate', b'Basic realm="foo"'), + ])(header_list) + + def commas(self): + cherrypy.response.headers[ + 'WWW-Authenticate'] = 'Negotiate,Basic realm="foo"' + + cherrypy.tree.mount(root) + + def testStatus(self): + self.getPage('/status/') + self.assertBody('normal') + self.assertStatus(200) + + self.getPage('/status/blank') + self.assertBody('') + self.assertStatus(200) + + self.getPage('/status/illegal') + self.assertStatus(500) + msg = 'Illegal response status from server (781 is out of range).' + self.assertErrorPage(500, msg) + + if not getattr(cherrypy.server, 'using_apache', False): + self.getPage('/status/unknown') + self.assertBody('funky') + self.assertStatus(431) + + self.getPage('/status/bad') + self.assertStatus(500) + msg = "Illegal response status from server ('error' is non-numeric)." + self.assertErrorPage(500, msg) + + def test_on_end_resource_status(self): + self.getPage('/status/on_end_resource_stage') + self.assertBody('[]') + self.getPage('/status/on_end_resource_stage') + self.assertBody(repr(['200 OK'])) + + def testSlashes(self): + # Test that requests for index methods without a trailing slash + # get redirected to the same URI path with a trailing slash. + # Make sure GET params are preserved. 
+ self.getPage('/redirect?id=3') + self.assertStatus(301) + self.assertMatchesBody( + '<a href=([\'"])%s/redirect/[?]id=3\\1>' + '%s/redirect/[?]id=3</a>' % (self.base(), self.base()) + ) + + if self.prefix(): + # Corner case: the "trailing slash" redirect could be tricky if + # we're using a virtual root and the URI is "/vroot" (no slash). + self.getPage('') + self.assertStatus(301) + self.assertMatchesBody("<a href=(['\"])%s/\\1>%s/</a>" % + (self.base(), self.base())) + + # Test that requests for NON-index methods WITH a trailing slash + # get redirected to the same URI path WITHOUT a trailing slash. + # Make sure GET params are preserved. + self.getPage('/redirect/by_code/?code=307') + self.assertStatus(301) + self.assertMatchesBody( + "<a href=(['\"])%s/redirect/by_code[?]code=307\\1>" + '%s/redirect/by_code[?]code=307</a>' + % (self.base(), self.base()) + ) + + # If the trailing_slash tool is off, CP should just continue + # as if the slashes were correct. But it needs some help + # inside cherrypy.url to form correct output. + self.getPage('/url?path_info=page1') + self.assertBody('%s/url/page1' % self.base()) + self.getPage('/url/leaf/?path_info=page1') + self.assertBody('%s/url/page1' % self.base()) + + def testRedirect(self): + self.getPage('/redirect/') + self.assertBody('child') + self.assertStatus(200) + + self.getPage('/redirect/by_code?code=300') + self.assertMatchesBody( + r"<a href=(['\"])(.*)somewhere%20else\1>\2somewhere%20else</a>") + self.assertStatus(300) + + self.getPage('/redirect/by_code?code=301') + self.assertMatchesBody( + r"<a href=(['\"])(.*)somewhere%20else\1>\2somewhere%20else</a>") + self.assertStatus(301) + + self.getPage('/redirect/by_code?code=302') + self.assertMatchesBody( + r"<a href=(['\"])(.*)somewhere%20else\1>\2somewhere%20else</a>") + self.assertStatus(302) + + self.getPage('/redirect/by_code?code=303') + self.assertMatchesBody( + r"<a href=(['\"])(.*)somewhere%20else\1>\2somewhere%20else</a>") + self.assertStatus(303) + + self.getPage('/redirect/by_code?code=307') + self.assertMatchesBody( + r"<a href=(['\"])(.*)somewhere%20else\1>\2somewhere%20else</a>") + self.assertStatus(307) + + self.getPage('/redirect/nomodify') + self.assertBody('') + self.assertStatus(304) + + self.getPage('/redirect/proxy') + self.assertBody('') + self.assertStatus(305) + + # HTTPRedirect on error + self.getPage('/redirect/error/') + self.assertStatus(('302 Found', '303 See Other')) + self.assertInBody('/errpage') + + # Make sure str(HTTPRedirect()) works. + self.getPage('/redirect/stringify', protocol='HTTP/1.0') + self.assertStatus(200) + self.assertBody("(['%s/'], 302)" % self.base()) + if cherrypy.server.protocol_version == 'HTTP/1.1': + self.getPage('/redirect/stringify', protocol='HTTP/1.1') + self.assertStatus(200) + self.assertBody("(['%s/'], 303)" % self.base()) + + # check that #fragments are handled properly + # http://skrb.org/ietf/http_errata.html#location-fragments + frag = 'foo' + self.getPage('/redirect/fragment/%s' % frag) + self.assertMatchesBody( + r"<a href=(['\"])(.*)\/some\/url\#%s\1>\2\/some\/url\#%s</a>" % ( + frag, frag)) + loc = self.assertHeader('Location') + assert loc.endswith('#%s' % frag) + self.assertStatus(('302 Found', '303 See Other')) + + # check injection protection + # See https://github.com/cherrypy/cherrypy/issues/1003 + self.getPage( + '/redirect/custom?' 
+ 'code=303&url=/foobar/%0d%0aSet-Cookie:%20somecookie=someval')
+ self.assertStatus(303)
+ loc = self.assertHeader('Location')
+ assert 'Set-Cookie' in loc
+ self.assertNoHeader('Set-Cookie')
+
+ def assertValidXHTML():
+ from xml.etree import ElementTree
+ try:
+ ElementTree.fromstring(
+ '<html><body>%s</body></html>' % self.body,
+ )
+ except ElementTree.ParseError:
+ self._handlewebError(
+ 'automatically generated redirect did not '
+ 'generate well-formed html',
+ )
+
+ # check redirects to URLs generated valid HTML - we check this
+ # by seeing if it appears as valid XHTML.
+ self.getPage('/redirect/by_code?code=303')
+ self.assertStatus(303)
+ assertValidXHTML()
+
+ # do the same with a url containing quote characters.
+ self.getPage('/redirect/url_with_quote')
+ self.assertStatus(303)
+ assertValidXHTML()
+
+ def test_redirect_with_xss(self):
+ """A redirect to a URL with HTML injected should result
+ in page contents escaped."""
+ self.getPage('/redirect/url_with_xss')
+ self.assertStatus(303)
+ assert b'<script>' not in self.body
+ assert b'&lt;script&gt;' in self.body
+
+ def test_redirect_with_unicode(self):
+ """
+ A redirect to a URL with Unicode should return a Location
+ header containing that Unicode URL.
+ """
+ # test disabled due to #1440
+ return
+ self.getPage('/redirect/url_with_unicode')
+ self.assertStatus(303)
+ loc = self.assertHeader('Location')
+ assert ntou('тест', encoding='utf-8') in loc
+
+ def test_InternalRedirect(self):
+ # InternalRedirect
+ self.getPage('/internalredirect/')
+ self.assertBody('hello')
+ self.assertStatus(200)
+
+ # Test passthrough
+ self.getPage(
+ '/internalredirect/petshop?user_id=Sir-not-appearing-in-this-film')
+ self.assertBody('0 images for Sir-not-appearing-in-this-film')
+ self.assertStatus(200)
+
+ # Test args
+ self.getPage('/internalredirect/petshop?user_id=parrot')
+ self.assertBody('0 images for slug')
+ self.assertStatus(200)
+
+ # Test POST
+ self.getPage('/internalredirect/petshop', method='POST',
+ body='user_id=terrier')
+ self.assertBody('0 images for fish')
+ self.assertStatus(200)
+
+ # Test ir before body read
+ self.getPage('/internalredirect/early_ir', method='POST',
+ body='arg=aha!')
+ self.assertBody('Something went horribly wrong.')
+ self.assertStatus(200)
+
+ self.getPage('/internalredirect/secure')
+ self.assertBody('Please log in')
+ self.assertStatus(200)
+
+ # Relative path in InternalRedirect.
+ # Also tests request.prev.
+ self.getPage('/internalredirect/relative?a=3&b=5')
+ self.assertBody('a=3&b=5')
+ self.assertStatus(200)
+
+ # InternalRedirect on error
+ self.getPage('/internalredirect/choke')
+ self.assertStatus(200)
+ self.assertBody('Something went horribly wrong.')
+
+ def testFlatten(self):
+ for url in ['/flatten/as_string', '/flatten/as_list',
+ '/flatten/as_yield', '/flatten/as_dblyield',
+ '/flatten/as_refyield']:
+ self.getPage(url)
+ self.assertBody('content')
+
+ def testRanges(self):
+ self.getPage('/ranges/get_ranges?bytes=3-6')
+ self.assertBody('[(3, 7)]')
+
+ # Test multiple ranges and a suffix-byte-range-spec, for good measure.
+ self.getPage('/ranges/get_ranges?bytes=2-4,-1')
+ self.assertBody('[(2, 5), (7, 8)]')
+
+ # Test a suffix-byte-range longer than the content
+ # length. Note that in this test, the content length
+ # is 8 bytes.
+ self.getPage('/ranges/get_ranges?bytes=-100')
+ self.assertBody('[(0, 8)]')
+
+ # Get a partial file. 
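+ # (Range requests are only exercised for HTTP/1.1 below; the HTTP/1.0
+ # branch further down expects a plain 200 with the full body instead.)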
+ if cherrypy.server.protocol_version == 'HTTP/1.1': + self.getPage('/ranges/slice_file', [('Range', 'bytes=2-5')]) + self.assertStatus(206) + self.assertHeader('Content-Type', 'text/html;charset=utf-8') + self.assertHeader('Content-Range', 'bytes 2-5/14') + self.assertBody('llo,') + + # What happens with overlapping ranges (and out of order, too)? + self.getPage('/ranges/slice_file', [('Range', 'bytes=4-6,2-5')]) + self.assertStatus(206) + ct = self.assertHeader('Content-Type') + expected_type = 'multipart/byteranges; boundary=' + self.assert_(ct.startswith(expected_type)) + boundary = ct[len(expected_type):] + expected_body = ('\r\n--%s\r\n' + 'Content-type: text/html\r\n' + 'Content-range: bytes 4-6/14\r\n' + '\r\n' + 'o, \r\n' + '--%s\r\n' + 'Content-type: text/html\r\n' + 'Content-range: bytes 2-5/14\r\n' + '\r\n' + 'llo,\r\n' + '--%s--\r\n' % (boundary, boundary, boundary)) + self.assertBody(expected_body) + self.assertHeader('Content-Length') + + # Test "416 Requested Range Not Satisfiable" + self.getPage('/ranges/slice_file', [('Range', 'bytes=2300-2900')]) + self.assertStatus(416) + # "When this status code is returned for a byte-range request, + # the response SHOULD include a Content-Range entity-header + # field specifying the current length of the selected resource" + self.assertHeader('Content-Range', 'bytes */14') + elif cherrypy.server.protocol_version == 'HTTP/1.0': + # Test Range behavior with HTTP/1.0 request + self.getPage('/ranges/slice_file', [('Range', 'bytes=2-5')]) + self.assertStatus(200) + self.assertBody('Hello, world\r\n') + + def testFavicon(self): + # favicon.ico is served by staticfile. + icofilename = os.path.join(localDir, '../favicon.ico') + icofile = open(icofilename, 'rb') + data = icofile.read() + icofile.close() + + self.getPage('/favicon.ico') + self.assertBody(data) + + def skip_if_bad_cookies(self): + """ + cookies module fails to reject invalid cookies + https://github.com/cherrypy/cherrypy/issues/1405 + """ + cookies = sys.modules.get('http.cookies') + _is_legal_key = getattr(cookies, '_is_legal_key', lambda x: False) + if not _is_legal_key(','): + return + issue = 'http://bugs.python.org/issue26302' + tmpl = 'Broken cookies module ({issue})' + self.skip(tmpl.format(**locals())) + + def testCookies(self): + self.skip_if_bad_cookies() + + self.getPage('/cookies/single?name=First', + [('Cookie', 'First=Dinsdale;')]) + self.assertHeader('Set-Cookie', 'First=Dinsdale') + + self.getPage('/cookies/multiple?names=First&names=Last', + [('Cookie', 'First=Dinsdale; Last=Piranha;'), + ]) + self.assertHeader('Set-Cookie', 'First=Dinsdale') + self.assertHeader('Set-Cookie', 'Last=Piranha') + + self.getPage('/cookies/single?name=Something-With%2CComma', + [('Cookie', 'Something-With,Comma=some-value')]) + self.assertStatus(400) + + def testDefaultContentType(self): + self.getPage('/') + self.assertHeader('Content-Type', 'text/html;charset=utf-8') + self.getPage('/defct/plain') + self.getPage('/') + self.assertHeader('Content-Type', 'text/plain;charset=utf-8') + self.getPage('/defct/html') + + def test_multiple_headers(self): + self.getPage('/multiheader/header_list') + self.assertEqual( + [(k, v) for k, v in self.headers if k == 'WWW-Authenticate'], + [('WWW-Authenticate', 'Negotiate'), + ('WWW-Authenticate', 'Basic realm="foo"'), + ]) + self.getPage('/multiheader/commas') + self.assertHeader('WWW-Authenticate', 'Negotiate,Basic realm="foo"') + + def test_cherrypy_url(self): + # Input relative to current + self.getPage('/url/leaf?path_info=page1') + 
self.assertBody('%s/url/page1' % self.base()) + self.getPage('/url/?path_info=page1') + self.assertBody('%s/url/page1' % self.base()) + # Other host header + host = 'www.mydomain.example' + self.getPage('/url/leaf?path_info=page1', + headers=[('Host', host)]) + self.assertBody('%s://%s/url/page1' % (self.scheme, host)) + + # Input is 'absolute'; that is, relative to script_name + self.getPage('/url/leaf?path_info=/page1') + self.assertBody('%s/page1' % self.base()) + self.getPage('/url/?path_info=/page1') + self.assertBody('%s/page1' % self.base()) + + # Single dots + self.getPage('/url/leaf?path_info=./page1') + self.assertBody('%s/url/page1' % self.base()) + self.getPage('/url/leaf?path_info=other/./page1') + self.assertBody('%s/url/other/page1' % self.base()) + self.getPage('/url/?path_info=/other/./page1') + self.assertBody('%s/other/page1' % self.base()) + self.getPage('/url/?path_info=/other/././././page1') + self.assertBody('%s/other/page1' % self.base()) + + # Double dots + self.getPage('/url/leaf?path_info=../page1') + self.assertBody('%s/page1' % self.base()) + self.getPage('/url/leaf?path_info=other/../page1') + self.assertBody('%s/url/page1' % self.base()) + self.getPage('/url/leaf?path_info=/other/../page1') + self.assertBody('%s/page1' % self.base()) + self.getPage('/url/leaf?path_info=/other/../../../page1') + self.assertBody('%s/page1' % self.base()) + self.getPage('/url/leaf?path_info=/other/../../../../../page1') + self.assertBody('%s/page1' % self.base()) + + # qs param is not normalized as a path + self.getPage('/url/qs?qs=/other') + self.assertBody('%s/url/qs?/other' % self.base()) + self.getPage('/url/qs?qs=/other/../page1') + self.assertBody('%s/url/qs?/other/../page1' % self.base()) + self.getPage('/url/qs?qs=../page1') + self.assertBody('%s/url/qs?../page1' % self.base()) + self.getPage('/url/qs?qs=../../page1') + self.assertBody('%s/url/qs?../../page1' % self.base()) + + # Output relative to current path or script_name + self.getPage('/url/?path_info=page1&relative=True') + self.assertBody('page1') + self.getPage('/url/leaf?path_info=/page1&relative=True') + self.assertBody('../page1') + self.getPage('/url/leaf?path_info=page1&relative=True') + self.assertBody('page1') + self.getPage('/url/leaf?path_info=leaf/page1&relative=True') + self.assertBody('leaf/page1') + self.getPage('/url/leaf?path_info=../page1&relative=True') + self.assertBody('../page1') + self.getPage('/url/?path_info=other/../page1&relative=True') + self.assertBody('page1') + + # Output relative to / + self.getPage('/baseurl?path_info=ab&relative=True') + self.assertBody('ab') + # Output relative to / + self.getPage('/baseurl?path_info=/ab&relative=True') + self.assertBody('ab') + + # absolute-path references ("server-relative") + # Input relative to current + self.getPage('/url/leaf?path_info=page1&relative=server') + self.assertBody('/url/page1') + self.getPage('/url/?path_info=page1&relative=server') + self.assertBody('/url/page1') + # Input is 'absolute'; that is, relative to script_name + self.getPage('/url/leaf?path_info=/page1&relative=server') + self.assertBody('/page1') + self.getPage('/url/?path_info=/page1&relative=server') + self.assertBody('/page1') + + def test_expose_decorator(self): + # Test @expose + self.getPage('/expose_dec/no_call') + self.assertStatus(200) + self.assertBody('Mr E. R. Bradshaw') + + # Test @expose() + self.getPage('/expose_dec/call_empty') + self.assertStatus(200) + self.assertBody('Mrs. B.J. 
Smegma')
+
+ # Test @expose("alias")
+ self.getPage('/expose_dec/call_alias')
+ self.assertStatus(200)
+ self.assertBody('Mr Nesbitt')
+ # Does the original name work?
+ self.getPage('/expose_dec/nesbitt')
+ self.assertStatus(200)
+ self.assertBody('Mr Nesbitt')
+
+ # Test @expose(["alias1", "alias2"])
+ self.getPage('/expose_dec/alias1')
+ self.assertStatus(200)
+ self.assertBody('Mr Ken Andrews')
+ self.getPage('/expose_dec/alias2')
+ self.assertStatus(200)
+ self.assertBody('Mr Ken Andrews')
+ # Does the original name work?
+ self.getPage('/expose_dec/andrews')
+ self.assertStatus(200)
+ self.assertBody('Mr Ken Andrews')
+
+ # Test @expose(alias="alias")
+ self.getPage('/expose_dec/alias3')
+ self.assertStatus(200)
+ self.assertBody('Mr. and Mrs. Watson')
+
+
+class ErrorTests(helper.CPWebCase):
+
+ @staticmethod
+ def setup_server():
+ def break_header():
+ # Add a header after finalize that is invalid
+ cherrypy.serving.response.header_list.append((2, 3))
+ cherrypy.tools.break_header = cherrypy.Tool(
+ 'on_end_resource', break_header)
+
+ class Root:
+
+ @cherrypy.expose
+ def index(self):
+ return 'hello'
+
+ @cherrypy.expose
+ @cherrypy.config(**{'tools.break_header.on': True})
+ def start_response_error(self):
+ return 'salud!'
+
+ @cherrypy.expose
+ def stat(self, path):
+ with cherrypy.HTTPError.handle(OSError, 404):
+ os.stat(path)
+
+ root = Root()
+
+ cherrypy.tree.mount(root)
+
+ def test_start_response_error(self):
+ self.getPage('/start_response_error')
+ self.assertStatus(500)
+ self.assertInBody(
+ 'TypeError: response.header_list key 2 is not a byte string.')
+
+ def test_contextmanager(self):
+ self.getPage('/stat/missing')
+ self.assertStatus(404)
+ body_text = self.body.decode('utf-8')
+ assert (
+ 'No such file or directory' in body_text or
+ 'cannot find the file specified' in body_text
+ )
+
+
+class TestBinding:
+ def test_bind_ephemeral_port(self):
+ """
+ A server configured to bind to port 0 will bind to an ephemeral
+ port and indicate that port number on startup.
+ """ + cherrypy.config.reset() + bind_ephemeral_conf = { + 'server.socket_port': 0, + } + cherrypy.config.update(bind_ephemeral_conf) + cherrypy.engine.start() + assert cherrypy.server.bound_addr != cherrypy.server.bind_addr + _host, port = cherrypy.server.bound_addr + assert port > 0 + cherrypy.engine.stop() + assert cherrypy.server.bind_addr == cherrypy.server.bound_addr diff --git a/libraries/cherrypy/test/test_dynamicobjectmapping.py b/libraries/cherrypy/test/test_dynamicobjectmapping.py new file mode 100644 index 00000000..725a3ce0 --- /dev/null +++ b/libraries/cherrypy/test/test_dynamicobjectmapping.py @@ -0,0 +1,424 @@ +import six + +import cherrypy +from cherrypy.test import helper + +script_names = ['', '/foo', '/users/fred/blog', '/corp/blog'] + + +def setup_server(): + class SubSubRoot: + + @cherrypy.expose + def index(self): + return 'SubSubRoot index' + + @cherrypy.expose + def default(self, *args): + return 'SubSubRoot default' + + @cherrypy.expose + def handler(self): + return 'SubSubRoot handler' + + @cherrypy.expose + def dispatch(self): + return 'SubSubRoot dispatch' + + subsubnodes = { + '1': SubSubRoot(), + '2': SubSubRoot(), + } + + class SubRoot: + + @cherrypy.expose + def index(self): + return 'SubRoot index' + + @cherrypy.expose + def default(self, *args): + return 'SubRoot %s' % (args,) + + @cherrypy.expose + def handler(self): + return 'SubRoot handler' + + def _cp_dispatch(self, vpath): + return subsubnodes.get(vpath[0], None) + + subnodes = { + '1': SubRoot(), + '2': SubRoot(), + } + + class Root: + + @cherrypy.expose + def index(self): + return 'index' + + @cherrypy.expose + def default(self, *args): + return 'default %s' % (args,) + + @cherrypy.expose + def handler(self): + return 'handler' + + def _cp_dispatch(self, vpath): + return subnodes.get(vpath[0]) + + # ------------------------------------------------------------------------- + # DynamicNodeAndMethodDispatcher example. + # This example exposes a fairly naive HTTP api + class User(object): + + def __init__(self, id, name): + self.id = id + self.name = name + + def __unicode__(self): + return six.text_type(self.name) + + def __str__(self): + return str(self.name) + + user_lookup = { + 1: User(1, 'foo'), + 2: User(2, 'bar'), + } + + def make_user(name, id=None): + if not id: + id = max(*list(user_lookup.keys())) + 1 + user_lookup[id] = User(id, name) + return id + + @cherrypy.expose + class UserContainerNode(object): + + def POST(self, name): + """ + Allow the creation of a new Object + """ + return 'POST %d' % make_user(name) + + def GET(self): + return six.text_type(sorted(user_lookup.keys())) + + def dynamic_dispatch(self, vpath): + try: + id = int(vpath[0]) + except (ValueError, IndexError): + return None + return UserInstanceNode(id) + + @cherrypy.expose + class UserInstanceNode(object): + + def __init__(self, id): + self.id = id + self.user = user_lookup.get(id, None) + + # For all but PUT methods there MUST be a valid user identified + # by self.id + if not self.user and cherrypy.request.method != 'PUT': + raise cherrypy.HTTPError(404) + + def GET(self, *args, **kwargs): + """ + Return the appropriate representation of the instance. + """ + return six.text_type(self.user) + + def POST(self, name): + """ + Update the fields of the user instance. 
+ """ + self.user.name = name + return 'POST %d' % self.user.id + + def PUT(self, name): + """ + Create a new user with the specified id, or edit it if it already + exists + """ + if self.user: + # Edit the current user + self.user.name = name + return 'PUT %d' % self.user.id + else: + # Make a new user with said attributes. + return 'PUT %d' % make_user(name, self.id) + + def DELETE(self): + """ + Delete the user specified at the id. + """ + id = self.user.id + del user_lookup[self.user.id] + del self.user + return 'DELETE %d' % id + + class ABHandler: + + class CustomDispatch: + + @cherrypy.expose + def index(self, a, b): + return 'custom' + + def _cp_dispatch(self, vpath): + """Make sure that if we don't pop anything from vpath, + processing still works. + """ + return self.CustomDispatch() + + @cherrypy.expose + def index(self, a, b=None): + body = ['a:' + str(a)] + if b is not None: + body.append(',b:' + str(b)) + return ''.join(body) + + @cherrypy.expose + def delete(self, a, b): + return 'deleting ' + str(a) + ' and ' + str(b) + + class IndexOnly: + + def _cp_dispatch(self, vpath): + """Make sure that popping ALL of vpath still shows the index + handler. + """ + while vpath: + vpath.pop() + return self + + @cherrypy.expose + def index(self): + return 'IndexOnly index' + + class DecoratedPopArgs: + + """Test _cp_dispatch with @cherrypy.popargs.""" + + @cherrypy.expose + def index(self): + return 'no params' + + @cherrypy.expose + def hi(self): + return "hi was not interpreted as 'a' param" + DecoratedPopArgs = cherrypy.popargs( + 'a', 'b', handler=ABHandler())(DecoratedPopArgs) + + class NonDecoratedPopArgs: + + """Test _cp_dispatch = cherrypy.popargs()""" + + _cp_dispatch = cherrypy.popargs('a') + + @cherrypy.expose + def index(self, a): + return 'index: ' + str(a) + + class ParameterizedHandler: + + """Special handler created for each request""" + + def __init__(self, a): + self.a = a + + @cherrypy.expose + def index(self): + if 'a' in cherrypy.request.params: + raise Exception( + 'Parameterized handler argument ended up in ' + 'request.params') + return self.a + + class ParameterizedPopArgs: + + """Test cherrypy.popargs() with a function call handler""" + ParameterizedPopArgs = cherrypy.popargs( + 'a', handler=ParameterizedHandler)(ParameterizedPopArgs) + + Root.decorated = DecoratedPopArgs() + Root.undecorated = NonDecoratedPopArgs() + Root.index_only = IndexOnly() + Root.parameter_test = ParameterizedPopArgs() + + Root.users = UserContainerNode() + + md = cherrypy.dispatch.MethodDispatcher('dynamic_dispatch') + for url in script_names: + conf = { + '/': { + 'user': (url or '/').split('/')[-2], + }, + '/users': { + 'request.dispatch': md + }, + } + cherrypy.tree.mount(Root(), url, conf) + + +class DynamicObjectMappingTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def testObjectMapping(self): + for url in script_names: + self.script_name = url + + self.getPage('/') + self.assertBody('index') + + self.getPage('/handler') + self.assertBody('handler') + + # Dynamic dispatch will succeed here for the subnodes + # so the subroot gets called + self.getPage('/1/') + self.assertBody('SubRoot index') + + self.getPage('/2/') + self.assertBody('SubRoot index') + + self.getPage('/1/handler') + self.assertBody('SubRoot handler') + + self.getPage('/2/handler') + self.assertBody('SubRoot handler') + + # Dynamic dispatch will fail here for the subnodes + # so the default gets called + self.getPage('/asdf/') + self.assertBody("default ('asdf',)") + + 
self.getPage('/asdf/asdf') + self.assertBody("default ('asdf', 'asdf')") + + self.getPage('/asdf/handler') + self.assertBody("default ('asdf', 'handler')") + + # Dynamic dispatch will succeed here for the subsubnodes + # so the subsubroot gets called + self.getPage('/1/1/') + self.assertBody('SubSubRoot index') + + self.getPage('/2/2/') + self.assertBody('SubSubRoot index') + + self.getPage('/1/1/handler') + self.assertBody('SubSubRoot handler') + + self.getPage('/2/2/handler') + self.assertBody('SubSubRoot handler') + + self.getPage('/2/2/dispatch') + self.assertBody('SubSubRoot dispatch') + + # The exposed dispatch will not be called as a dispatch + # method. + self.getPage('/2/2/foo/foo') + self.assertBody('SubSubRoot default') + + # Dynamic dispatch will fail here for the subsubnodes + # so the SubRoot gets called + self.getPage('/1/asdf/') + self.assertBody("SubRoot ('asdf',)") + + self.getPage('/1/asdf/asdf') + self.assertBody("SubRoot ('asdf', 'asdf')") + + self.getPage('/1/asdf/handler') + self.assertBody("SubRoot ('asdf', 'handler')") + + def testMethodDispatch(self): + # GET acts like a container + self.getPage('/users') + self.assertBody('[1, 2]') + self.assertHeader('Allow', 'GET, HEAD, POST') + + # POST to the container URI allows creation + self.getPage('/users', method='POST', body='name=baz') + self.assertBody('POST 3') + self.assertHeader('Allow', 'GET, HEAD, POST') + + # POST to a specific instanct URI results in a 404 + # as the resource does not exit. + self.getPage('/users/5', method='POST', body='name=baz') + self.assertStatus(404) + + # PUT to a specific instanct URI results in creation + self.getPage('/users/5', method='PUT', body='name=boris') + self.assertBody('PUT 5') + self.assertHeader('Allow', 'DELETE, GET, HEAD, POST, PUT') + + # GET acts like a container + self.getPage('/users') + self.assertBody('[1, 2, 3, 5]') + self.assertHeader('Allow', 'GET, HEAD, POST') + + test_cases = ( + (1, 'foo', 'fooupdated', 'DELETE, GET, HEAD, POST, PUT'), + (2, 'bar', 'barupdated', 'DELETE, GET, HEAD, POST, PUT'), + (3, 'baz', 'bazupdated', 'DELETE, GET, HEAD, POST, PUT'), + (5, 'boris', 'borisupdated', 'DELETE, GET, HEAD, POST, PUT'), + ) + for id, name, updatedname, headers in test_cases: + self.getPage('/users/%d' % id) + self.assertBody(name) + self.assertHeader('Allow', headers) + + # Make sure POSTs update already existings resources + self.getPage('/users/%d' % + id, method='POST', body='name=%s' % updatedname) + self.assertBody('POST %d' % id) + self.assertHeader('Allow', headers) + + # Make sure PUTs Update already existing resources. + self.getPage('/users/%d' % + id, method='PUT', body='name=%s' % updatedname) + self.assertBody('PUT %d' % id) + self.assertHeader('Allow', headers) + + # Make sure DELETES Remove already existing resources. 
+ self.getPage('/users/%d' % id, method='DELETE') + self.assertBody('DELETE %d' % id) + self.assertHeader('Allow', headers) + + # GET acts like a container + self.getPage('/users') + self.assertBody('[]') + self.assertHeader('Allow', 'GET, HEAD, POST') + + def testVpathDispatch(self): + self.getPage('/decorated/') + self.assertBody('no params') + + self.getPage('/decorated/hi') + self.assertBody("hi was not interpreted as 'a' param") + + self.getPage('/decorated/yo/') + self.assertBody('a:yo') + + self.getPage('/decorated/yo/there/') + self.assertBody('a:yo,b:there') + + self.getPage('/decorated/yo/there/delete') + self.assertBody('deleting yo and there') + + self.getPage('/decorated/yo/there/handled_by_dispatch/') + self.assertBody('custom') + + self.getPage('/undecorated/blah/') + self.assertBody('index: blah') + + self.getPage('/index_only/a/b/c/d/e/f/g/') + self.assertBody('IndexOnly index') + + self.getPage('/parameter_test/argument2/') + self.assertBody('argument2') diff --git a/libraries/cherrypy/test/test_encoding.py b/libraries/cherrypy/test/test_encoding.py new file mode 100644 index 00000000..ab24ab93 --- /dev/null +++ b/libraries/cherrypy/test/test_encoding.py @@ -0,0 +1,426 @@ +# coding: utf-8 + +import gzip +import io +from unittest import mock + +from six.moves.http_client import IncompleteRead +from six.moves.urllib.parse import quote as url_quote + +import cherrypy +from cherrypy._cpcompat import ntob, ntou + +from cherrypy.test import helper + + +europoundUnicode = ntou('£', encoding='utf-8') +sing = ntou('毛泽东: Sing, Little Birdie?', encoding='utf-8') + +sing8 = sing.encode('utf-8') +sing16 = sing.encode('utf-16') + + +class EncodingTests(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def index(self, param): + assert param == europoundUnicode, '%r != %r' % ( + param, europoundUnicode) + yield europoundUnicode + + @cherrypy.expose + def mao_zedong(self): + return sing + + @cherrypy.expose + @cherrypy.config(**{'tools.encode.encoding': 'utf-8'}) + def utf8(self): + return sing8 + + @cherrypy.expose + def cookies_and_headers(self): + # if the headers have non-ascii characters and a cookie has + # any part which is unicode (even ascii), the response + # should not fail. + cherrypy.response.cookie['candy'] = 'bar' + cherrypy.response.cookie['candy']['domain'] = 'cherrypy.org' + cherrypy.response.headers[ + 'Some-Header'] = 'My d\xc3\xb6g has fleas' + return 'Any content' + + @cherrypy.expose + def reqparams(self, *args, **kwargs): + return b', '.join( + [': '.join((k, v)).encode('utf8') + for k, v in sorted(cherrypy.request.params.items())] + ) + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.encode.text_only': False, + 'tools.encode.add_charset': True, + }) + def nontext(self, *args, **kwargs): + cherrypy.response.headers[ + 'Content-Type'] = 'application/binary' + return '\x00\x01\x02\x03' + + class GZIP: + + @cherrypy.expose + def index(self): + yield 'Hello, world' + + @cherrypy.expose + # Turn encoding off so the gzip tool is the one doing the collapse. + @cherrypy.config(**{'tools.encode.on': False}) + def noshow(self): + # Test for ticket #147, where yield showed no exceptions + # (content-encoding was still gzip even though traceback + # wasn't zipped). 
+ raise IndexError() + yield 'Here be dragons' + + @cherrypy.expose + @cherrypy.config(**{'response.stream': True}) + def noshow_stream(self): + # Test for ticket #147, where yield showed no exceptions + # (content-encoding was still gzip even though traceback + # wasn't zipped). + raise IndexError() + yield 'Here be dragons' + + class Decode: + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.decode.on': True, + 'tools.decode.default_encoding': ['utf-16'], + }) + def extra_charset(self, *args, **kwargs): + return ', '.join([': '.join((k, v)) + for k, v in cherrypy.request.params.items()]) + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.decode.on': True, + 'tools.decode.encoding': 'utf-16', + }) + def force_charset(self, *args, **kwargs): + return ', '.join([': '.join((k, v)) + for k, v in cherrypy.request.params.items()]) + + root = Root() + root.gzip = GZIP() + root.decode = Decode() + cherrypy.tree.mount(root, config={'/gzip': {'tools.gzip.on': True}}) + + def test_query_string_decoding(self): + URI_TMPL = '/reqparams?q={q}' + + europoundUtf8_2_bytes = europoundUnicode.encode('utf-8') + europoundUtf8_2nd_byte = europoundUtf8_2_bytes[1:2] + + # Encoded utf8 query strings MUST be parsed correctly. + # Here, q is the POUND SIGN U+00A3 encoded in utf8 and then %HEX + self.getPage(URI_TMPL.format(q=url_quote(europoundUtf8_2_bytes))) + # The return value will be encoded as utf8. + self.assertBody(b'q: ' + europoundUtf8_2_bytes) + + # Query strings that are incorrectly encoded MUST raise 404. + # Here, q is the second byte of POUND SIGN U+A3 encoded in utf8 + # and then %HEX + # TODO: check whether this shouldn't raise 400 Bad Request instead + self.getPage(URI_TMPL.format(q=url_quote(europoundUtf8_2nd_byte))) + self.assertStatus(404) + self.assertErrorPage( + 404, + 'The given query string could not be processed. Query ' + "strings for this resource must be encoded with 'utf8'.") + + def test_urlencoded_decoding(self): + # Test the decoding of an application/x-www-form-urlencoded entity. + europoundUtf8 = europoundUnicode.encode('utf-8') + body = b'param=' + europoundUtf8 + self.getPage('/', + method='POST', + headers=[ + ('Content-Type', 'application/x-www-form-urlencoded'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertBody(europoundUtf8) + + # Encoded utf8 entities MUST be parsed and decoded correctly. + # Here, q is the POUND SIGN U+00A3 encoded in utf8 + body = b'q=\xc2\xa3' + self.getPage('/reqparams', method='POST', + headers=[( + 'Content-Type', 'application/x-www-form-urlencoded'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertBody(b'q: \xc2\xa3') + + # ...and in utf16, which is not in the default attempt_charsets list: + body = b'\xff\xfeq\x00=\xff\xfe\xa3\x00' + self.getPage('/reqparams', + method='POST', + headers=[ + ('Content-Type', + 'application/x-www-form-urlencoded;charset=utf-16'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertBody(b'q: \xc2\xa3') + + # Entities that are incorrectly encoded MUST raise 400. + # Here, q is the POUND SIGN U+00A3 encoded in utf16, but + # the Content-Type incorrectly labels it utf-8. + body = b'\xff\xfeq\x00=\xff\xfe\xa3\x00' + self.getPage('/reqparams', + method='POST', + headers=[ + ('Content-Type', + 'application/x-www-form-urlencoded;charset=utf-8'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertStatus(400) + self.assertErrorPage( + 400, + 'The request entity could not be decoded. 
The following charsets ' + "were attempted: ['utf-8']") + + def test_decode_tool(self): + # An extra charset should be tried first, and succeed if it matches. + # Here, we add utf-16 as a charset and pass a utf-16 body. + body = b'\xff\xfeq\x00=\xff\xfe\xa3\x00' + self.getPage('/decode/extra_charset', method='POST', + headers=[( + 'Content-Type', 'application/x-www-form-urlencoded'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertBody(b'q: \xc2\xa3') + + # An extra charset should be tried first, and continue to other default + # charsets if it doesn't match. + # Here, we add utf-16 as a charset but still pass a utf-8 body. + body = b'q=\xc2\xa3' + self.getPage('/decode/extra_charset', method='POST', + headers=[( + 'Content-Type', 'application/x-www-form-urlencoded'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertBody(b'q: \xc2\xa3') + + # An extra charset should error if force is True and it doesn't match. + # Here, we force utf-16 as a charset but still pass a utf-8 body. + body = b'q=\xc2\xa3' + self.getPage('/decode/force_charset', method='POST', + headers=[( + 'Content-Type', 'application/x-www-form-urlencoded'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertErrorPage( + 400, + 'The request entity could not be decoded. The following charsets ' + "were attempted: ['utf-16']") + + def test_multipart_decoding(self): + # Test the decoding of a multipart entity when the charset (utf16) is + # explicitly given. + body = ntob('\r\n'.join([ + '--X', + 'Content-Type: text/plain;charset=utf-16', + 'Content-Disposition: form-data; name="text"', + '', + '\xff\xfea\x00b\x00\x1c c\x00', + '--X', + 'Content-Type: text/plain;charset=utf-16', + 'Content-Disposition: form-data; name="submit"', + '', + '\xff\xfeC\x00r\x00e\x00a\x00t\x00e\x00', + '--X--' + ])) + self.getPage('/reqparams', method='POST', + headers=[( + 'Content-Type', 'multipart/form-data;boundary=X'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertBody(b'submit: Create, text: ab\xe2\x80\x9cc') + + @mock.patch('cherrypy._cpreqbody.Part.maxrambytes', 1) + def test_multipart_decoding_bigger_maxrambytes(self): + """ + Decoding of a multipart entity should also pass when + the entity is bigger than maxrambytes. See ticket #1352. + """ + self.test_multipart_decoding() + + def test_multipart_decoding_no_charset(self): + # Test the decoding of a multipart entity when the charset (utf8) is + # NOT explicitly given, but is in the list of charsets to attempt. + body = ntob('\r\n'.join([ + '--X', + 'Content-Disposition: form-data; name="text"', + '', + '\xe2\x80\x9c', + '--X', + 'Content-Disposition: form-data; name="submit"', + '', + 'Create', + '--X--' + ])) + self.getPage('/reqparams', method='POST', + headers=[( + 'Content-Type', 'multipart/form-data;boundary=X'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertBody(b'submit: Create, text: \xe2\x80\x9c') + + def test_multipart_decoding_no_successful_charset(self): + # Test the decoding of a multipart entity when the charset (utf16) is + # NOT explicitly given, and is NOT in the list of charsets to attempt. 
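+ # The body below is utf-16, but only the default charsets (us-ascii
+ # and utf-8) are attempted, so the request is expected to fail with 400.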
+ body = ntob('\r\n'.join([ + '--X', + 'Content-Disposition: form-data; name="text"', + '', + '\xff\xfea\x00b\x00\x1c c\x00', + '--X', + 'Content-Disposition: form-data; name="submit"', + '', + '\xff\xfeC\x00r\x00e\x00a\x00t\x00e\x00', + '--X--' + ])) + self.getPage('/reqparams', method='POST', + headers=[( + 'Content-Type', 'multipart/form-data;boundary=X'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertStatus(400) + self.assertErrorPage( + 400, + 'The request entity could not be decoded. The following charsets ' + "were attempted: ['us-ascii', 'utf-8']") + + def test_nontext(self): + self.getPage('/nontext') + self.assertHeader('Content-Type', 'application/binary;charset=utf-8') + self.assertBody('\x00\x01\x02\x03') + + def testEncoding(self): + # Default encoding should be utf-8 + self.getPage('/mao_zedong') + self.assertBody(sing8) + + # Ask for utf-16. + self.getPage('/mao_zedong', [('Accept-Charset', 'utf-16')]) + self.assertHeader('Content-Type', 'text/html;charset=utf-16') + self.assertBody(sing16) + + # Ask for multiple encodings. ISO-8859-1 should fail, and utf-16 + # should be produced. + self.getPage('/mao_zedong', [('Accept-Charset', + 'iso-8859-1;q=1, utf-16;q=0.5')]) + self.assertBody(sing16) + + # The "*" value should default to our default_encoding, utf-8 + self.getPage('/mao_zedong', [('Accept-Charset', '*;q=1, utf-7;q=.2')]) + self.assertBody(sing8) + + # Only allow iso-8859-1, which should fail and raise 406. + self.getPage('/mao_zedong', [('Accept-Charset', 'iso-8859-1, *;q=0')]) + self.assertStatus('406 Not Acceptable') + self.assertInBody('Your client sent this Accept-Charset header: ' + 'iso-8859-1, *;q=0. We tried these charsets: ' + 'iso-8859-1.') + + # Ask for x-mac-ce, which should be unknown. See ticket #569. + self.getPage('/mao_zedong', [('Accept-Charset', + 'us-ascii, ISO-8859-1, x-mac-ce')]) + self.assertStatus('406 Not Acceptable') + self.assertInBody('Your client sent this Accept-Charset header: ' + 'us-ascii, ISO-8859-1, x-mac-ce. We tried these ' + 'charsets: ISO-8859-1, us-ascii, x-mac-ce.') + + # Test the 'encoding' arg to encode. + self.getPage('/utf8') + self.assertBody(sing8) + self.getPage('/utf8', [('Accept-Charset', 'us-ascii, ISO-8859-1')]) + self.assertStatus('406 Not Acceptable') + + # Test malformed quality value, which should raise 400. + self.getPage('/mao_zedong', [('Accept-Charset', + 'ISO-8859-1,utf-8;q=0.7,*;q=0.7)')]) + self.assertStatus('400 Bad Request') + + def testGzip(self): + zbuf = io.BytesIO() + zfile = gzip.GzipFile(mode='wb', fileobj=zbuf, compresslevel=9) + zfile.write(b'Hello, world') + zfile.close() + + self.getPage('/gzip/', headers=[('Accept-Encoding', 'gzip')]) + self.assertInBody(zbuf.getvalue()[:3]) + self.assertHeader('Vary', 'Accept-Encoding') + self.assertHeader('Content-Encoding', 'gzip') + + # Test when gzip is denied. 
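+ # A client that sends 'identity' (or 'gzip;q=0') is refusing gzip, so
+ # the body must come back unencoded, with no Content-Encoding header.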
+ self.getPage('/gzip/', headers=[('Accept-Encoding', 'identity')]) + self.assertHeader('Vary', 'Accept-Encoding') + self.assertNoHeader('Content-Encoding') + self.assertBody('Hello, world') + + self.getPage('/gzip/', headers=[('Accept-Encoding', 'gzip;q=0')]) + self.assertHeader('Vary', 'Accept-Encoding') + self.assertNoHeader('Content-Encoding') + self.assertBody('Hello, world') + + # Test that trailing comma doesn't cause IndexError + # Ref: https://github.com/cherrypy/cherrypy/issues/988 + self.getPage('/gzip/', headers=[('Accept-Encoding', 'gzip,deflate,')]) + self.assertStatus(200) + self.assertNotInBody('IndexError') + + self.getPage('/gzip/', headers=[('Accept-Encoding', '*;q=0')]) + self.assertStatus(406) + self.assertNoHeader('Content-Encoding') + self.assertErrorPage(406, 'identity, gzip') + + # Test for ticket #147 + self.getPage('/gzip/noshow', headers=[('Accept-Encoding', 'gzip')]) + self.assertNoHeader('Content-Encoding') + self.assertStatus(500) + self.assertErrorPage(500, pattern='IndexError\n') + + # In this case, there's nothing we can do to deliver a + # readable page, since 1) the gzip header is already set, + # and 2) we may have already written some of the body. + # The fix is to never stream yields when using gzip. + if (cherrypy.server.protocol_version == 'HTTP/1.0' or + getattr(cherrypy.server, 'using_apache', False)): + self.getPage('/gzip/noshow_stream', + headers=[('Accept-Encoding', 'gzip')]) + self.assertHeader('Content-Encoding', 'gzip') + self.assertInBody('\x1f\x8b\x08\x00') + else: + # The wsgiserver will simply stop sending data, and the HTTP client + # will error due to an incomplete chunk-encoded stream. + self.assertRaises((ValueError, IncompleteRead), self.getPage, + '/gzip/noshow_stream', + headers=[('Accept-Encoding', 'gzip')]) + + def test_UnicodeHeaders(self): + self.getPage('/cookies_and_headers') + self.assertBody('Any content') diff --git a/libraries/cherrypy/test/test_etags.py b/libraries/cherrypy/test/test_etags.py new file mode 100644 index 00000000..293eb866 --- /dev/null +++ b/libraries/cherrypy/test/test_etags.py @@ -0,0 +1,84 @@ +import cherrypy +from cherrypy._cpcompat import ntou +from cherrypy.test import helper + + +class ETagTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def resource(self): + return 'Oh wah ta goo Siam.' 
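+ # The fixed body above gives a stable ETag: tools.etags.autotags
+ # (enabled in the config further down) derives the tag from the
+ # response body.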
+ + @cherrypy.expose + def fail(self, code): + code = int(code) + if 300 <= code <= 399: + raise cherrypy.HTTPRedirect([], code) + else: + raise cherrypy.HTTPError(code) + + @cherrypy.expose + # In Python 3, tools.encode is on by default + @cherrypy.config(**{'tools.encode.on': True}) + def unicoded(self): + return ntou('I am a \u1ee4nicode string.', 'escape') + + conf = {'/': {'tools.etags.on': True, + 'tools.etags.autotags': True, + }} + cherrypy.tree.mount(Root(), config=conf) + + def test_etags(self): + self.getPage('/resource') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/html;charset=utf-8') + self.assertBody('Oh wah ta goo Siam.') + etag = self.assertHeader('ETag') + + # Test If-Match (both valid and invalid) + self.getPage('/resource', headers=[('If-Match', etag)]) + self.assertStatus('200 OK') + self.getPage('/resource', headers=[('If-Match', '*')]) + self.assertStatus('200 OK') + self.getPage('/resource', headers=[('If-Match', '*')], method='POST') + self.assertStatus('200 OK') + self.getPage('/resource', headers=[('If-Match', 'a bogus tag')]) + self.assertStatus('412 Precondition Failed') + + # Test If-None-Match (both valid and invalid) + self.getPage('/resource', headers=[('If-None-Match', etag)]) + self.assertStatus(304) + self.getPage('/resource', method='POST', + headers=[('If-None-Match', etag)]) + self.assertStatus('412 Precondition Failed') + self.getPage('/resource', headers=[('If-None-Match', '*')]) + self.assertStatus(304) + self.getPage('/resource', headers=[('If-None-Match', 'a bogus tag')]) + self.assertStatus('200 OK') + + def test_errors(self): + self.getPage('/resource') + self.assertStatus(200) + etag = self.assertHeader('ETag') + + # Test raising errors in page handler + self.getPage('/fail/412', headers=[('If-Match', etag)]) + self.assertStatus(412) + self.getPage('/fail/304', headers=[('If-Match', etag)]) + self.assertStatus(304) + self.getPage('/fail/412', headers=[('If-None-Match', '*')]) + self.assertStatus(412) + self.getPage('/fail/304', headers=[('If-None-Match', '*')]) + self.assertStatus(304) + + def test_unicode_body(self): + self.getPage('/unicoded') + self.assertStatus(200) + etag1 = self.assertHeader('ETag') + self.getPage('/unicoded', headers=[('If-Match', etag1)]) + self.assertStatus(200) + self.assertHeader('ETag', etag1) diff --git a/libraries/cherrypy/test/test_http.py b/libraries/cherrypy/test/test_http.py new file mode 100644 index 00000000..0899d4d0 --- /dev/null +++ b/libraries/cherrypy/test/test_http.py @@ -0,0 +1,307 @@ +# coding: utf-8 +"""Tests for managing HTTP issues (malformed requests, etc).""" + +import errno +import mimetypes +import socket +import sys +from unittest import mock + +import six +from six.moves.http_client import HTTPConnection +from six.moves import urllib + +import cherrypy +from cherrypy._cpcompat import HTTPSConnection, quote + +from cherrypy.test import helper + + +def is_ascii(text): + """ + Return True if the text encodes as ascii. + """ + try: + text.encode('ascii') + return True + except Exception: + pass + return False + + +def encode_filename(filename): + """ + Given a filename to be used in a multipart/form-data, + encode the name. Return the key and encoded filename. + """ + if is_ascii(filename): + return 'filename', '"{filename}"'.format(**locals()) + encoded = quote(filename, encoding='utf-8') + return 'filename*', "'".join(( + 'UTF-8', + '', # lang + encoded, + )) + + +def encode_multipart_formdata(files): + """Return (content_type, body) ready for httplib.HTTP instance. 
+ + files: a sequence of (name, filename, value) tuples for multipart uploads. + filename can be a string or a tuple ('filename string', 'encoding') + """ + BOUNDARY = '________ThIs_Is_tHe_bouNdaRY_$' + L = [] + for key, filename, value in files: + L.append('--' + BOUNDARY) + + fn_key, encoded = encode_filename(filename) + tmpl = \ + 'Content-Disposition: form-data; name="{key}"; {fn_key}={encoded}' + L.append(tmpl.format(**locals())) + ct = mimetypes.guess_type(filename)[0] or 'application/octet-stream' + L.append('Content-Type: %s' % ct) + L.append('') + L.append(value) + L.append('--' + BOUNDARY + '--') + L.append('') + body = '\r\n'.join(L) + content_type = 'multipart/form-data; boundary=%s' % BOUNDARY + return content_type, body + + +class HTTPTests(helper.CPWebCase): + + def make_connection(self): + if self.scheme == 'https': + return HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) + else: + return HTTPConnection('%s:%s' % (self.interface(), self.PORT)) + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def index(self, *args, **kwargs): + return 'Hello world!' + + @cherrypy.expose + @cherrypy.config(**{'request.process_request_body': False}) + def no_body(self, *args, **kwargs): + return 'Hello world!' + + @cherrypy.expose + def post_multipart(self, file): + """Return a summary ("a * 65536\nb * 65536") of the uploaded + file. + """ + contents = file.file.read() + summary = [] + curchar = None + count = 0 + for c in contents: + if c == curchar: + count += 1 + else: + if count: + if six.PY3: + curchar = chr(curchar) + summary.append('%s * %d' % (curchar, count)) + count = 1 + curchar = c + if count: + if six.PY3: + curchar = chr(curchar) + summary.append('%s * %d' % (curchar, count)) + return ', '.join(summary) + + @cherrypy.expose + def post_filename(self, myfile): + '''Return the name of the file which was uploaded.''' + return myfile.filename + + cherrypy.tree.mount(Root()) + cherrypy.config.update({'server.max_request_body_size': 30000000}) + + def test_no_content_length(self): + # "The presence of a message-body in a request is signaled by the + # inclusion of a Content-Length or Transfer-Encoding header field in + # the request's message-headers." + # + # Send a message with neither header and no body. Even though + # the request is of method POST, this should be OK because we set + # request.process_request_body to False for our handler. + c = self.make_connection() + c.request('POST', '/no_body') + response = c.getresponse() + self.body = response.fp.read() + self.status = str(response.status) + self.assertStatus(200) + self.assertBody(b'Hello world!') + + # Now send a message that has no Content-Length, but does send a body. + # Verify that CP times out the socket and responds + # with 411 Length Required. 
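+ # (The mock.patch.object calls below keep http.client from adding a
+ # Content-Length header on its own, so the POST really goes out
+ # without one and CherryPy can answer 411.)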
+ if self.scheme == 'https': + c = HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) + else: + c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) + + # `_get_content_length` is needed for Python 3.6+ + with mock.patch.object( + c, + '_get_content_length', + lambda body, method: None, + create=True): + # `_set_content_length` is needed for Python 2.7-3.5 + with mock.patch.object(c, '_set_content_length', create=True): + c.request('POST', '/') + + response = c.getresponse() + self.body = response.fp.read() + self.status = str(response.status) + self.assertStatus(411) + + def test_post_multipart(self): + alphabet = 'abcdefghijklmnopqrstuvwxyz' + # generate file contents for a large post + contents = ''.join([c * 65536 for c in alphabet]) + + # encode as multipart form data + files = [('file', 'file.txt', contents)] + content_type, body = encode_multipart_formdata(files) + body = body.encode('Latin-1') + + # post file + c = self.make_connection() + c.putrequest('POST', '/post_multipart') + c.putheader('Content-Type', content_type) + c.putheader('Content-Length', str(len(body))) + c.endheaders() + c.send(body) + + response = c.getresponse() + self.body = response.fp.read() + self.status = str(response.status) + self.assertStatus(200) + parts = ['%s * 65536' % ch for ch in alphabet] + self.assertBody(', '.join(parts)) + + def test_post_filename_with_special_characters(self): + '''Testing that we can handle filenames with special characters. This + was reported as a bug in: + https://github.com/cherrypy/cherrypy/issues/1146/ + https://github.com/cherrypy/cherrypy/issues/1397/ + https://github.com/cherrypy/cherrypy/issues/1694/ + ''' + # We'll upload a bunch of files with differing names. + fnames = [ + 'boop.csv', 'foo, bar.csv', 'bar, xxxx.csv', 'file"name.csv', + 'file;name.csv', 'file; name.csv', u'test_łóąä.txt', + ] + for fname in fnames: + files = [('myfile', fname, 'yunyeenyunyue')] + content_type, body = encode_multipart_formdata(files) + body = body.encode('Latin-1') + + # post file + c = self.make_connection() + c.putrequest('POST', '/post_filename') + c.putheader('Content-Type', content_type) + c.putheader('Content-Length', str(len(body))) + c.endheaders() + c.send(body) + + response = c.getresponse() + self.body = response.fp.read() + self.status = str(response.status) + self.assertStatus(200) + self.assertBody(fname) + + def test_malformed_request_line(self): + if getattr(cherrypy.server, 'using_apache', False): + return self.skip('skipped due to known Apache differences...') + + # Test missing version in Request-Line + c = self.make_connection() + c._output(b'geT /') + c._send_output() + if hasattr(c, 'strict'): + response = c.response_class(c.sock, strict=c.strict, method='GET') + else: + # Python 3.2 removed the 'strict' feature, saying: + # "http.client now always assumes HTTP/1.x compliant servers." + response = c.response_class(c.sock, method='GET') + response.begin() + self.assertEqual(response.status, 400) + self.assertEqual(response.fp.read(22), b'Malformed Request-Line') + c.close() + + def test_request_line_split_issue_1220(self): + params = { + 'intervenant-entreprise-evenement_classaction': + 'evenement-mailremerciements', + '_path': 'intervenant-entreprise-evenement', + 'intervenant-entreprise-evenement_action-id': 19404, + 'intervenant-entreprise-evenement_id': 19404, + 'intervenant-entreprise_id': 28092, + } + Request_URI = '/index?' 
+ urllib.parse.urlencode(params) + self.assertEqual(len('GET %s HTTP/1.1\r\n' % Request_URI), 256) + self.getPage(Request_URI) + self.assertBody('Hello world!') + + def test_malformed_header(self): + c = self.make_connection() + c.putrequest('GET', '/') + c.putheader('Content-Type', 'text/plain') + # See https://github.com/cherrypy/cherrypy/issues/941 + c._output(b're, 1.2.3.4#015#012') + c.endheaders() + + response = c.getresponse() + self.status = str(response.status) + self.assertStatus(400) + self.body = response.fp.read(20) + self.assertBody('Illegal header line.') + + def test_http_over_https(self): + if self.scheme != 'https': + return self.skip('skipped (not running HTTPS)... ') + + # Try connecting without SSL. + conn = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) + conn.putrequest('GET', '/', skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + try: + response.begin() + self.assertEqual(response.status, 400) + self.body = response.read() + self.assertBody('The client sent a plain HTTP request, but this ' + 'server only speaks HTTPS on this port.') + except socket.error: + e = sys.exc_info()[1] + # "Connection reset by peer" is also acceptable. + if e.errno != errno.ECONNRESET: + raise + + def test_garbage_in(self): + # Connect without SSL regardless of server.scheme + c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) + c._output(b'gjkgjklsgjklsgjkljklsg') + c._send_output() + response = c.response_class(c.sock, method='GET') + try: + response.begin() + self.assertEqual(response.status, 400) + self.assertEqual(response.fp.read(22), + b'Malformed Request-Line') + c.close() + except socket.error: + e = sys.exc_info()[1] + # "Connection reset by peer" is also acceptable. 
+ if e.errno != errno.ECONNRESET: + raise diff --git a/libraries/cherrypy/test/test_httputil.py b/libraries/cherrypy/test/test_httputil.py new file mode 100644 index 00000000..656b8a3d --- /dev/null +++ b/libraries/cherrypy/test/test_httputil.py @@ -0,0 +1,80 @@ +"""Test helpers from ``cherrypy.lib.httputil`` module.""" +import pytest +from six.moves import http_client + +from cherrypy.lib import httputil + + +@pytest.mark.parametrize( + 'script_name,path_info,expected_url', + [ + ('/sn/', '/pi/', '/sn/pi/'), + ('/sn/', '/pi', '/sn/pi'), + ('/sn/', '/', '/sn/'), + ('/sn/', '', '/sn/'), + ('/sn', '/pi/', '/sn/pi/'), + ('/sn', '/pi', '/sn/pi'), + ('/sn', '/', '/sn/'), + ('/sn', '', '/sn'), + ('/', '/pi/', '/pi/'), + ('/', '/pi', '/pi'), + ('/', '/', '/'), + ('/', '', '/'), + ('', '/pi/', '/pi/'), + ('', '/pi', '/pi'), + ('', '/', '/'), + ('', '', '/'), + ] +) +def test_urljoin(script_name, path_info, expected_url): + """Test all slash+atom combinations for SCRIPT_NAME and PATH_INFO.""" + actual_url = httputil.urljoin(script_name, path_info) + assert actual_url == expected_url + + +EXPECTED_200 = (200, 'OK', 'Request fulfilled, document follows') +EXPECTED_500 = ( + 500, + 'Internal Server Error', + 'The server encountered an unexpected condition which ' + 'prevented it from fulfilling the request.', +) +EXPECTED_404 = (404, 'Not Found', 'Nothing matches the given URI') +EXPECTED_444 = (444, 'Non-existent reason', '') + + +@pytest.mark.parametrize( + 'status,expected_status', + [ + (None, EXPECTED_200), + (200, EXPECTED_200), + ('500', EXPECTED_500), + (http_client.NOT_FOUND, EXPECTED_404), + ('444 Non-existent reason', EXPECTED_444), + ] +) +def test_valid_status(status, expected_status): + """Check valid int, string and http_client-constants + statuses processing.""" + assert httputil.valid_status(status) == expected_status + + +@pytest.mark.parametrize( + 'status_code,error_msg', + [ + ('hey', "Illegal response status from server ('hey' is non-numeric)."), + ( + {'hey': 'hi'}, + 'Illegal response status from server ' + "({'hey': 'hi'} is non-numeric).", + ), + (1, 'Illegal response status from server (1 is out of range).'), + (600, 'Illegal response status from server (600 is out of range).'), + ] +) +def test_invalid_status(status_code, error_msg): + """Check that invalid status cause certain errors.""" + with pytest.raises(ValueError) as excinfo: + httputil.valid_status(status_code) + + assert error_msg in str(excinfo) diff --git a/libraries/cherrypy/test/test_iterator.py b/libraries/cherrypy/test/test_iterator.py new file mode 100644 index 00000000..92f08e7c --- /dev/null +++ b/libraries/cherrypy/test/test_iterator.py @@ -0,0 +1,196 @@ +import six + +import cherrypy +from cherrypy.test import helper + + +class IteratorBase(object): + + created = 0 + datachunk = 'butternut squash' * 256 + + @classmethod + def incr(cls): + cls.created += 1 + + @classmethod + def decr(cls): + cls.created -= 1 + + +class OurGenerator(IteratorBase): + + def __iter__(self): + self.incr() + try: + for i in range(1024): + yield self.datachunk + finally: + self.decr() + + +class OurIterator(IteratorBase): + + started = False + closed_off = False + count = 0 + + def increment(self): + self.incr() + + def decrement(self): + if not self.closed_off: + self.closed_off = True + self.decr() + + def __iter__(self): + return self + + def __next__(self): + if not self.started: + self.started = True + self.increment() + self.count += 1 + if self.count > 1024: + raise StopIteration + return self.datachunk + + next = 
__next__ + + def __del__(self): + self.decrement() + + +class OurClosableIterator(OurIterator): + + def close(self): + self.decrement() + + +class OurNotClosableIterator(OurIterator): + + # We can't close something which requires an additional argument. + def close(self, somearg): + self.decrement() + + +class OurUnclosableIterator(OurIterator): + close = 'close' # not callable! + + +class IteratorTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + + class Root(object): + + @cherrypy.expose + def count(self, clsname): + cherrypy.response.headers['Content-Type'] = 'text/plain' + return six.text_type(globals()[clsname].created) + + @cherrypy.expose + def getall(self, clsname): + cherrypy.response.headers['Content-Type'] = 'text/plain' + return globals()[clsname]() + + @cherrypy.expose + @cherrypy.config(**{'response.stream': True}) + def stream(self, clsname): + return self.getall(clsname) + + cherrypy.tree.mount(Root()) + + def test_iterator(self): + try: + self._test_iterator() + except Exception: + 'Test fails intermittently. See #1419' + + def _test_iterator(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + # Check the counts of all the classes, they should be zero. + closables = ['OurClosableIterator', 'OurGenerator'] + unclosables = ['OurUnclosableIterator', 'OurNotClosableIterator'] + all_classes = closables + unclosables + + import random + random.shuffle(all_classes) + + for clsname in all_classes: + self.getPage('/count/' + clsname) + self.assertStatus(200) + self.assertBody('0') + + # We should also be able to read the entire content body + # successfully, though we don't need to, we just want to + # check the header. + for clsname in all_classes: + itr_conn = self.get_conn() + itr_conn.putrequest('GET', '/getall/' + clsname) + itr_conn.endheaders() + response = itr_conn.getresponse() + self.assertEqual(response.status, 200) + headers = response.getheaders() + for header_name, header_value in headers: + if header_name.lower() == 'content-length': + expected = six.text_type(1024 * 16 * 256) + assert header_value == expected, header_value + break + else: + raise AssertionError('No Content-Length header found') + + # As the response should be fully consumed by CherryPy + # before sending back, the count should still be at zero + # by the time the response has been sent. + self.getPage('/count/' + clsname) + self.assertStatus(200) + self.assertBody('0') + + # Now we do the same check with streaming - some classes will + # be automatically closed, while others cannot. + stream_counts = {} + for clsname in all_classes: + itr_conn = self.get_conn() + itr_conn.putrequest('GET', '/stream/' + clsname) + itr_conn.endheaders() + response = itr_conn.getresponse() + self.assertEqual(response.status, 200) + response.fp.read(65536) + + # Let's check the count - this should always be one. + self.getPage('/count/' + clsname) + self.assertBody('1') + + # Now if we close the connection, the count should go back + # to zero. + itr_conn.close() + self.getPage('/count/' + clsname) + + # If this is a response which should be easily closed, then + # we will test to see if the value has gone back down to + # zero. + if clsname in closables: + + # Sometimes we try to get the answer too quickly - we + # will wait for 100 ms before asking again if we didn't + # get the answer we wanted. 
+ if self.body != '0': + import time + time.sleep(0.1) + self.getPage('/count/' + clsname) + + stream_counts[clsname] = int(self.body) + + # Check that we closed off the classes which should provide + # easy mechanisms for doing so. + for clsname in closables: + assert stream_counts[clsname] == 0, ( + 'did not close off stream response correctly, expected ' + 'count of zero for %s: %s' % (clsname, stream_counts) + ) diff --git a/libraries/cherrypy/test/test_json.py b/libraries/cherrypy/test/test_json.py new file mode 100644 index 00000000..1585f6e6 --- /dev/null +++ b/libraries/cherrypy/test/test_json.py @@ -0,0 +1,102 @@ +import cherrypy +from cherrypy.test import helper + +from cherrypy._cpcompat import json + + +json_out = cherrypy.config(**{'tools.json_out.on': True}) +json_in = cherrypy.config(**{'tools.json_in.on': True}) + + +class JsonTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root(object): + + @cherrypy.expose + def plain(self): + return 'hello' + + @cherrypy.expose + @json_out + def json_string(self): + return 'hello' + + @cherrypy.expose + @json_out + def json_list(self): + return ['a', 'b', 42] + + @cherrypy.expose + @json_out + def json_dict(self): + return {'answer': 42} + + @cherrypy.expose + @json_in + def json_post(self): + if cherrypy.request.json == [13, 'c']: + return 'ok' + else: + return 'nok' + + @cherrypy.expose + @json_out + @cherrypy.config(**{'tools.caching.on': True}) + def json_cached(self): + return 'hello there' + + root = Root() + cherrypy.tree.mount(root) + + def test_json_output(self): + if json is None: + self.skip('json not found ') + return + + self.getPage('/plain') + self.assertBody('hello') + + self.getPage('/json_string') + self.assertBody('"hello"') + + self.getPage('/json_list') + self.assertBody('["a", "b", 42]') + + self.getPage('/json_dict') + self.assertBody('{"answer": 42}') + + def test_json_input(self): + if json is None: + self.skip('json not found ') + return + + body = '[13, "c"]' + headers = [('Content-Type', 'application/json'), + ('Content-Length', str(len(body)))] + self.getPage('/json_post', method='POST', headers=headers, body=body) + self.assertBody('ok') + + body = '[13, "c"]' + headers = [('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))] + self.getPage('/json_post', method='POST', headers=headers, body=body) + self.assertStatus(415, 'Expected an application/json content type') + + body = '[13, -]' + headers = [('Content-Type', 'application/json'), + ('Content-Length', str(len(body)))] + self.getPage('/json_post', method='POST', headers=headers, body=body) + self.assertStatus(400, 'Invalid JSON document') + + def test_cached(self): + if json is None: + self.skip('json not found ') + return + + self.getPage('/json_cached') + self.assertStatus(200, '"hello"') + + self.getPage('/json_cached') # 2'nd time to hit cache + self.assertStatus(200, '"hello"') diff --git a/libraries/cherrypy/test/test_logging.py b/libraries/cherrypy/test/test_logging.py new file mode 100644 index 00000000..c4948c20 --- /dev/null +++ b/libraries/cherrypy/test/test_logging.py @@ -0,0 +1,209 @@ +"""Basic tests for the CherryPy core: request handling.""" + +import os +from unittest import mock + +import six + +import cherrypy +from cherrypy._cpcompat import ntou +from cherrypy.test import helper, logtest + +localDir = os.path.dirname(__file__) +access_log = os.path.join(localDir, 'access.log') +error_log = os.path.join(localDir, 'error.log') + +# Some unicode strings. 
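+# (tartaros and erebos are Greek words; they supply non-ASCII values for
+# the request.login and remote.name fields exercised by the log tests.)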
+tartaros = ntou('\u03a4\u1f71\u03c1\u03c4\u03b1\u03c1\u03bf\u03c2', 'escape') +erebos = ntou('\u0388\u03c1\u03b5\u03b2\u03bf\u03c2.com', 'escape') + + +def setup_server(): + class Root: + + @cherrypy.expose + def index(self): + return 'hello' + + @cherrypy.expose + def uni_code(self): + cherrypy.request.login = tartaros + cherrypy.request.remote.name = erebos + + @cherrypy.expose + def slashes(self): + cherrypy.request.request_line = r'GET /slashed\path HTTP/1.1' + + @cherrypy.expose + def whitespace(self): + # User-Agent = "User-Agent" ":" 1*( product | comment ) + # comment = "(" *( ctext | quoted-pair | comment ) ")" + # ctext = <any TEXT excluding "(" and ")"> + # TEXT = <any OCTET except CTLs, but including LWS> + # LWS = [CRLF] 1*( SP | HT ) + cherrypy.request.headers['User-Agent'] = 'Browzuh (1.0\r\n\t\t.3)' + + @cherrypy.expose + def as_string(self): + return 'content' + + @cherrypy.expose + def as_yield(self): + yield 'content' + + @cherrypy.expose + @cherrypy.config(**{'tools.log_tracebacks.on': True}) + def error(self): + raise ValueError() + + root = Root() + + cherrypy.config.update({ + 'log.error_file': error_log, + 'log.access_file': access_log, + }) + cherrypy.tree.mount(root) + + +class AccessLogTests(helper.CPWebCase, logtest.LogCase): + setup_server = staticmethod(setup_server) + + logfile = access_log + + def testNormalReturn(self): + self.markLog() + self.getPage('/as_string', + headers=[('Referer', 'http://www.cherrypy.org/'), + ('User-Agent', 'Mozilla/5.0')]) + self.assertBody('content') + self.assertStatus(200) + + intro = '%s - - [' % self.interface() + + self.assertLog(-1, intro) + + if [k for k, v in self.headers if k.lower() == 'content-length']: + self.assertLog(-1, '] "GET %s/as_string HTTP/1.1" 200 7 ' + '"http://www.cherrypy.org/" "Mozilla/5.0"' + % self.prefix()) + else: + self.assertLog(-1, '] "GET %s/as_string HTTP/1.1" 200 - ' + '"http://www.cherrypy.org/" "Mozilla/5.0"' + % self.prefix()) + + def testNormalYield(self): + self.markLog() + self.getPage('/as_yield') + self.assertBody('content') + self.assertStatus(200) + + intro = '%s - - [' % self.interface() + + self.assertLog(-1, intro) + if [k for k, v in self.headers if k.lower() == 'content-length']: + self.assertLog(-1, '] "GET %s/as_yield HTTP/1.1" 200 7 "" ""' % + self.prefix()) + else: + self.assertLog(-1, '] "GET %s/as_yield HTTP/1.1" 200 - "" ""' + % self.prefix()) + + @mock.patch( + 'cherrypy._cplogging.LogManager.access_log_format', + '{h} {l} {u} {t} "{r}" {s} {b} "{f}" "{a}" {o}' + if six.PY3 else + '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" %(o)s' + ) + def testCustomLogFormat(self): + """Test a customized access_log_format string, which is a + feature of _cplogging.LogManager.access().""" + self.markLog() + self.getPage('/as_string', headers=[('Referer', 'REFERER'), + ('User-Agent', 'USERAGENT'), + ('Host', 'HOST')]) + self.assertLog(-1, '%s - - [' % self.interface()) + self.assertLog(-1, '] "GET /as_string HTTP/1.1" ' + '200 7 "REFERER" "USERAGENT" HOST') + + @mock.patch( + 'cherrypy._cplogging.LogManager.access_log_format', + '{h} {l} {u} {z} "{r}" {s} {b} "{f}" "{a}" {o}' + if six.PY3 else + '%(h)s %(l)s %(u)s %(z)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s" %(o)s' + ) + def testTimezLogFormat(self): + """Test a customized access_log_format string, which is a + feature of _cplogging.LogManager.access().""" + self.markLog() + + expected_time = str(cherrypy._cplogging.LazyRfc3339UtcTime()) + with mock.patch( + 'cherrypy._cplogging.LazyRfc3339UtcTime', + lambda: 
expected_time): + self.getPage('/as_string', headers=[('Referer', 'REFERER'), + ('User-Agent', 'USERAGENT'), + ('Host', 'HOST')]) + + self.assertLog(-1, '%s - - ' % self.interface()) + self.assertLog(-1, expected_time) + self.assertLog(-1, ' "GET /as_string HTTP/1.1" ' + '200 7 "REFERER" "USERAGENT" HOST') + + @mock.patch( + 'cherrypy._cplogging.LogManager.access_log_format', + '{i}' if six.PY3 else '%(i)s' + ) + def testUUIDv4ParameterLogFormat(self): + """Test rendering of UUID4 within access log.""" + self.markLog() + self.getPage('/as_string') + self.assertValidUUIDv4() + + def testEscapedOutput(self): + # Test unicode in access log pieces. + self.markLog() + self.getPage('/uni_code') + self.assertStatus(200) + if six.PY3: + # The repr of a bytestring in six.PY3 includes a b'' prefix + self.assertLog(-1, repr(tartaros.encode('utf8'))[2:-1]) + else: + self.assertLog(-1, repr(tartaros.encode('utf8'))[1:-1]) + # Test the erebos value. Included inline for your enlightenment. + # Note the 'r' prefix--those backslashes are literals. + self.assertLog(-1, r'\xce\x88\xcf\x81\xce\xb5\xce\xb2\xce\xbf\xcf\x82') + + # Test backslashes in output. + self.markLog() + self.getPage('/slashes') + self.assertStatus(200) + if six.PY3: + self.assertLog(-1, b'"GET /slashed\\path HTTP/1.1"') + else: + self.assertLog(-1, r'"GET /slashed\\path HTTP/1.1"') + + # Test whitespace in output. + self.markLog() + self.getPage('/whitespace') + self.assertStatus(200) + # Again, note the 'r' prefix. + self.assertLog(-1, r'"Browzuh (1.0\r\n\t\t.3)"') + + +class ErrorLogTests(helper.CPWebCase, logtest.LogCase): + setup_server = staticmethod(setup_server) + + logfile = error_log + + def testTracebacks(self): + # Test that tracebacks get written to the error log. + self.markLog() + ignore = helper.webtest.ignored_exceptions + ignore.append(ValueError) + try: + self.getPage('/error') + self.assertInBody('raise ValueError()') + self.assertLog(0, 'HTTP') + self.assertLog(1, 'Traceback (most recent call last):') + self.assertLog(-2, 'raise ValueError()') + finally: + ignore.pop() diff --git a/libraries/cherrypy/test/test_mime.py b/libraries/cherrypy/test/test_mime.py new file mode 100644 index 00000000..ef35d10e --- /dev/null +++ b/libraries/cherrypy/test/test_mime.py @@ -0,0 +1,134 @@ +"""Tests for various MIME issues, including the safe_multipart Tool.""" + +import cherrypy +from cherrypy._cpcompat import ntou +from cherrypy.test import helper + + +def setup_server(): + + class Root: + + @cherrypy.expose + def multipart(self, parts): + return repr(parts) + + @cherrypy.expose + def multipart_form_data(self, **kwargs): + return repr(list(sorted(kwargs.items()))) + + @cherrypy.expose + def flashupload(self, Filedata, Upload, Filename): + return ('Upload: %s, Filename: %s, Filedata: %r' % + (Upload, Filename, Filedata.file.read())) + + cherrypy.config.update({'server.max_request_body_size': 0}) + cherrypy.tree.mount(Root()) + + +# Client-side code # + + +class MultipartTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test_multipart(self): + text_part = ntou('This is the text version') + html_part = ntou( + """<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"> +<html> +<head> + <meta content="text/html;charset=ISO-8859-1" http-equiv="Content-Type"> +</head> +<body bgcolor="#ffffff" text="#000000"> + +This is the <strong>HTML</strong> version +</body> +</html> +""") + body = '\r\n'.join([ + '--123456789', + "Content-Type: text/plain; charset='ISO-8859-1'", + 'Content-Transfer-Encoding: 7bit', 
+ '', + text_part, + '--123456789', + "Content-Type: text/html; charset='ISO-8859-1'", + '', + html_part, + '--123456789--']) + headers = [ + ('Content-Type', 'multipart/mixed; boundary=123456789'), + ('Content-Length', str(len(body))), + ] + self.getPage('/multipart', headers, 'POST', body) + self.assertBody(repr([text_part, html_part])) + + def test_multipart_form_data(self): + body = '\r\n'.join([ + '--X', + 'Content-Disposition: form-data; name="foo"', + '', + 'bar', + '--X', + # Test a param with more than one value. + # See + # https://github.com/cherrypy/cherrypy/issues/1028 + 'Content-Disposition: form-data; name="baz"', + '', + '111', + '--X', + 'Content-Disposition: form-data; name="baz"', + '', + '333', + '--X--' + ]) + self.getPage('/multipart_form_data', method='POST', + headers=[( + 'Content-Type', 'multipart/form-data;boundary=X'), + ('Content-Length', str(len(body))), + ], + body=body), + self.assertBody( + repr([('baz', [ntou('111'), ntou('333')]), ('foo', ntou('bar'))])) + + +class SafeMultipartHandlingTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test_Flash_Upload(self): + headers = [ + ('Accept', 'text/*'), + ('Content-Type', 'multipart/form-data; ' + 'boundary=----------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6'), + ('User-Agent', 'Shockwave Flash'), + ('Host', 'www.example.com:54583'), + ('Content-Length', '499'), + ('Connection', 'Keep-Alive'), + ('Cache-Control', 'no-cache'), + ] + filedata = (b'<?xml version="1.0" encoding="UTF-8"?>\r\n' + b'<projectDescription>\r\n' + b'</projectDescription>\r\n') + body = ( + b'------------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6\r\n' + b'Content-Disposition: form-data; name="Filename"\r\n' + b'\r\n' + b'.project\r\n' + b'------------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6\r\n' + b'Content-Disposition: form-data; ' + b'name="Filedata"; filename=".project"\r\n' + b'Content-Type: application/octet-stream\r\n' + b'\r\n' + + filedata + + b'\r\n' + b'------------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6\r\n' + b'Content-Disposition: form-data; name="Upload"\r\n' + b'\r\n' + b'Submit Query\r\n' + # Flash apps omit the trailing \r\n on the last line: + b'------------KM7Ij5cH2KM7Ef1gL6ae0ae0cH2gL6--' + ) + self.getPage('/flashupload', headers, 'POST', body) + self.assertBody('Upload: Submit Query, Filename: .project, ' + 'Filedata: %r' % filedata) diff --git a/libraries/cherrypy/test/test_misc_tools.py b/libraries/cherrypy/test/test_misc_tools.py new file mode 100644 index 00000000..fb85b8f8 --- /dev/null +++ b/libraries/cherrypy/test/test_misc_tools.py @@ -0,0 +1,210 @@ +import os + +import cherrypy +from cherrypy import tools +from cherrypy.test import helper + + +localDir = os.path.dirname(__file__) +logfile = os.path.join(localDir, 'test_misc_tools.log') + + +def setup_server(): + class Root: + + @cherrypy.expose + def index(self): + yield 'Hello, world' + h = [('Content-Language', 'en-GB'), ('Content-Type', 'text/plain')] + tools.response_headers(headers=h)(index) + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.response_headers.on': True, + 'tools.response_headers.headers': [ + ('Content-Language', 'fr'), + ('Content-Type', 'text/plain'), + ], + 'tools.log_hooks.on': True, + }) + def other(self): + return 'salut' + + @cherrypy.config(**{'tools.accept.on': True}) + class Accept: + + @cherrypy.expose + def index(self): + return '<a href="feed">Atom feed</a>' + + @cherrypy.expose + @tools.accept(media='application/atom+xml') + def feed(self): + return """<?xml version="1.0" encoding="utf-8"?> +<feed xmlns="http://www.w3.org/2005/Atom"> + 
<title>Unknown Blog</title> +</feed>""" + + @cherrypy.expose + def select(self): + # We could also write this: mtype = cherrypy.lib.accept.accept(...) + mtype = tools.accept.callable(['text/html', 'text/plain']) + if mtype == 'text/html': + return '<h2>Page Title</h2>' + else: + return 'PAGE TITLE' + + class Referer: + + @cherrypy.expose + def accept(self): + return 'Accepted!' + reject = accept + + class AutoVary: + + @cherrypy.expose + def index(self): + # Read a header directly with 'get' + cherrypy.request.headers.get('Accept-Encoding') + # Read a header directly with '__getitem__' + cherrypy.request.headers['Host'] + # Read a header directly with '__contains__' + 'If-Modified-Since' in cherrypy.request.headers + # Read a header directly + 'Range' in cherrypy.request.headers + # Call a lib function + tools.accept.callable(['text/html', 'text/plain']) + return 'Hello, world!' + + conf = {'/referer': {'tools.referer.on': True, + 'tools.referer.pattern': r'http://[^/]*example\.com', + }, + '/referer/reject': {'tools.referer.accept': False, + 'tools.referer.accept_missing': True, + }, + '/autovary': {'tools.autovary.on': True}, + } + + root = Root() + root.referer = Referer() + root.accept = Accept() + root.autovary = AutoVary() + cherrypy.tree.mount(root, config=conf) + cherrypy.config.update({'log.error_file': logfile}) + + +class ResponseHeadersTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def testResponseHeadersDecorator(self): + self.getPage('/') + self.assertHeader('Content-Language', 'en-GB') + self.assertHeader('Content-Type', 'text/plain;charset=utf-8') + + def testResponseHeaders(self): + self.getPage('/other') + self.assertHeader('Content-Language', 'fr') + self.assertHeader('Content-Type', 'text/plain;charset=utf-8') + + +class RefererTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def testReferer(self): + self.getPage('/referer/accept') + self.assertErrorPage(403, 'Forbidden Referer header.') + + self.getPage('/referer/accept', + headers=[('Referer', 'http://www.example.com/')]) + self.assertStatus(200) + self.assertBody('Accepted!') + + # Reject + self.getPage('/referer/reject') + self.assertStatus(200) + self.assertBody('Accepted!') + + self.getPage('/referer/reject', + headers=[('Referer', 'http://www.example.com/')]) + self.assertErrorPage(403, 'Forbidden Referer header.') + + +class AcceptTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test_Accept_Tool(self): + # Test with no header provided + self.getPage('/accept/feed') + self.assertStatus(200) + self.assertInBody('<title>Unknown Blog</title>') + + # Specify exact media type + self.getPage('/accept/feed', + headers=[('Accept', 'application/atom+xml')]) + self.assertStatus(200) + self.assertInBody('<title>Unknown Blog</title>') + + # Specify matching media range + self.getPage('/accept/feed', headers=[('Accept', 'application/*')]) + self.assertStatus(200) + self.assertInBody('<title>Unknown Blog</title>') + + # Specify all media ranges + self.getPage('/accept/feed', headers=[('Accept', '*/*')]) + self.assertStatus(200) + self.assertInBody('<title>Unknown Blog</title>') + + # Specify unacceptable media types + self.getPage('/accept/feed', headers=[('Accept', 'text/html')]) + self.assertErrorPage(406, + 'Your client sent this Accept header: text/html. ' + 'But this resource only emits these media types: ' + 'application/atom+xml.') + + # Test resource where tool is 'on' but media is None (not set). 
+ self.getPage('/accept/') + self.assertStatus(200) + self.assertBody('<a href="feed">Atom feed</a>') + + def test_accept_selection(self): + # Try both our expected media types + self.getPage('/accept/select', [('Accept', 'text/html')]) + self.assertStatus(200) + self.assertBody('<h2>Page Title</h2>') + self.getPage('/accept/select', [('Accept', 'text/plain')]) + self.assertStatus(200) + self.assertBody('PAGE TITLE') + self.getPage('/accept/select', + [('Accept', 'text/plain, text/*;q=0.5')]) + self.assertStatus(200) + self.assertBody('PAGE TITLE') + + # text/* and */* should prefer text/html since it comes first + # in our 'media' argument to tools.accept + self.getPage('/accept/select', [('Accept', 'text/*')]) + self.assertStatus(200) + self.assertBody('<h2>Page Title</h2>') + self.getPage('/accept/select', [('Accept', '*/*')]) + self.assertStatus(200) + self.assertBody('<h2>Page Title</h2>') + + # Try unacceptable media types + self.getPage('/accept/select', [('Accept', 'application/xml')]) + self.assertErrorPage( + 406, + 'Your client sent this Accept header: application/xml. ' + 'But this resource only emits these media types: ' + 'text/html, text/plain.') + + +class AutoVaryTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def testAutoVary(self): + self.getPage('/autovary/') + self.assertHeader( + 'Vary', + 'Accept, Accept-Charset, Accept-Encoding, ' + 'Host, If-Modified-Since, Range' + ) diff --git a/libraries/cherrypy/test/test_native.py b/libraries/cherrypy/test/test_native.py new file mode 100644 index 00000000..caebc3f4 --- /dev/null +++ b/libraries/cherrypy/test/test_native.py @@ -0,0 +1,35 @@ +"""Test the native server.""" + +import pytest +from requests_toolbelt import sessions + +import cherrypy._cpnative_server + + +pytestmark = pytest.mark.skipif( + 'sys.platform == "win32"', + reason='tests fail on Windows', +) + + +@pytest.fixture +def cp_native_server(request): + """A native server.""" + class Root(object): + @cherrypy.expose + def index(self): + return 'Hello World!' 
+ + cls = cherrypy._cpnative_server.CPHTTPServer + cherrypy.server.httpserver = cls(cherrypy.server) + + cherrypy.tree.mount(Root(), '/') + cherrypy.engine.start() + request.addfinalizer(cherrypy.engine.stop) + url = 'http://localhost:{cherrypy.server.socket_port}'.format(**globals()) + return sessions.BaseUrlSession(url) + + +def test_basic_request(cp_native_server): + """A request to a native server should succeed.""" + cp_native_server.get('/') diff --git a/libraries/cherrypy/test/test_objectmapping.py b/libraries/cherrypy/test/test_objectmapping.py new file mode 100644 index 00000000..98402b8b --- /dev/null +++ b/libraries/cherrypy/test/test_objectmapping.py @@ -0,0 +1,430 @@ +import sys +import cherrypy +from cherrypy._cpcompat import ntou +from cherrypy._cptree import Application +from cherrypy.test import helper + +script_names = ['', '/foo', '/users/fred/blog', '/corp/blog'] + + +class ObjectMappingTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def index(self, name='world'): + return name + + @cherrypy.expose + def foobar(self): + return 'bar' + + @cherrypy.expose + def default(self, *params, **kwargs): + return 'default:' + repr(params) + + @cherrypy.expose + def other(self): + return 'other' + + @cherrypy.expose + def extra(self, *p): + return repr(p) + + @cherrypy.expose + def redirect(self): + raise cherrypy.HTTPRedirect('dir1/', 302) + + def notExposed(self): + return 'not exposed' + + @cherrypy.expose + def confvalue(self): + return cherrypy.request.config.get('user') + + @cherrypy.expose + def redirect_via_url(self, path): + raise cherrypy.HTTPRedirect(cherrypy.url(path)) + + @cherrypy.expose + def translate_html(self): + return 'OK' + + @cherrypy.expose + def mapped_func(self, ID=None): + return 'ID is %s' % ID + setattr(Root, 'Von B\xfclow', mapped_func) + + class Exposing: + + @cherrypy.expose + def base(self): + return 'expose works!' + cherrypy.expose(base, '1') + cherrypy.expose(base, '2') + + class ExposingNewStyle(object): + + @cherrypy.expose + def base(self): + return 'expose works!' 
+ cherrypy.expose(base, '1') + cherrypy.expose(base, '2') + + class Dir1: + + @cherrypy.expose + def index(self): + return 'index for dir1' + + @cherrypy.expose + @cherrypy.config(**{'tools.trailing_slash.extra': True}) + def myMethod(self): + return 'myMethod from dir1, path_info is:' + repr( + cherrypy.request.path_info) + + @cherrypy.expose + def default(self, *params): + return 'default for dir1, param is:' + repr(params) + + class Dir2: + + @cherrypy.expose + def index(self): + return 'index for dir2, path is:' + cherrypy.request.path_info + + @cherrypy.expose + def script_name(self): + return cherrypy.tree.script_name() + + @cherrypy.expose + def cherrypy_url(self): + return cherrypy.url('/extra') + + @cherrypy.expose + def posparam(self, *vpath): + return '/'.join(vpath) + + class Dir3: + + def default(self): + return 'default for dir3, not exposed' + + class Dir4: + + def index(self): + return 'index for dir4, not exposed' + + class DefNoIndex: + + @cherrypy.expose + def default(self, *args): + raise cherrypy.HTTPRedirect('contact') + + # MethodDispatcher code + @cherrypy.expose + class ByMethod: + + def __init__(self, *things): + self.things = list(things) + + def GET(self): + return repr(self.things) + + def POST(self, thing): + self.things.append(thing) + + class Collection: + default = ByMethod('a', 'bit') + + Root.exposing = Exposing() + Root.exposingnew = ExposingNewStyle() + Root.dir1 = Dir1() + Root.dir1.dir2 = Dir2() + Root.dir1.dir2.dir3 = Dir3() + Root.dir1.dir2.dir3.dir4 = Dir4() + Root.defnoindex = DefNoIndex() + Root.bymethod = ByMethod('another') + Root.collection = Collection() + + d = cherrypy.dispatch.MethodDispatcher() + for url in script_names: + conf = {'/': {'user': (url or '/').split('/')[-2]}, + '/bymethod': {'request.dispatch': d}, + '/collection': {'request.dispatch': d}, + } + cherrypy.tree.mount(Root(), url, conf) + + class Isolated: + + @cherrypy.expose + def index(self): + return 'made it!' + + cherrypy.tree.mount(Isolated(), '/isolated') + + @cherrypy.expose + class AnotherApp: + + def GET(self): + return 'milk' + + cherrypy.tree.mount(AnotherApp(), '/app', + {'/': {'request.dispatch': d}}) + + def testObjectMapping(self): + for url in script_names: + self.script_name = url + + self.getPage('/') + self.assertBody('world') + + self.getPage('/dir1/myMethod') + self.assertBody( + "myMethod from dir1, path_info is:'/dir1/myMethod'") + + self.getPage('/this/method/does/not/exist') + self.assertBody( + "default:('this', 'method', 'does', 'not', 'exist')") + + self.getPage('/extra/too/much') + self.assertBody("('too', 'much')") + + self.getPage('/other') + self.assertBody('other') + + self.getPage('/notExposed') + self.assertBody("default:('notExposed',)") + + self.getPage('/dir1/dir2/') + self.assertBody('index for dir2, path is:/dir1/dir2/') + + # Test omitted trailing slash (should be redirected by default). + self.getPage('/dir1/dir2') + self.assertStatus(301) + self.assertHeader('Location', '%s/dir1/dir2/' % self.base()) + + # Test extra trailing slash (should be redirected if configured). + self.getPage('/dir1/myMethod/') + self.assertStatus(301) + self.assertHeader('Location', '%s/dir1/myMethod' % self.base()) + + # Test that default method must be exposed in order to match. + self.getPage('/dir1/dir2/dir3/dir4/index') + self.assertBody( + "default for dir1, param is:('dir2', 'dir3', 'dir4', 'index')") + + # Test *vpath when default() is defined but not index() + # This also tests HTTPRedirect with default. 
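The requests that follow lean on two object-dispatch rules set up by the handlers above: path segments that match no exposed handler fall through to an exposed default(), arriving as positional arguments, and a relative HTTPRedirect raised from such a handler resolves against the path that was actually requested. A minimal sketch of both behaviours outside the test harness (class names and the quickstart call are illustrative):

    import cherrypy

    class Root(object):
        @cherrypy.expose
        def index(self):
            return 'index page'

        @cherrypy.expose
        def default(self, *vpath, **params):
            # e.g. GET /a/b/c ends up here with vpath == ('a', 'b', 'c').
            return 'default caught: %r' % (vpath,)

    class DefaultOnly(object):
        @cherrypy.expose
        def default(self, *vpath):
            # A relative redirect raised here is resolved against the
            # requested URL, mirroring the DefNoIndex handler above.
            raise cherrypy.HTTPRedirect('contact')

    if __name__ == '__main__':
        root = Root()
        root.defonly = DefaultOnly()
        cherrypy.quickstart(root)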
+ self.getPage('/defnoindex') + self.assertStatus((302, 303)) + self.assertHeader('Location', '%s/contact' % self.base()) + self.getPage('/defnoindex/') + self.assertStatus((302, 303)) + self.assertHeader('Location', '%s/defnoindex/contact' % + self.base()) + self.getPage('/defnoindex/page') + self.assertStatus((302, 303)) + self.assertHeader('Location', '%s/defnoindex/contact' % + self.base()) + + self.getPage('/redirect') + self.assertStatus('302 Found') + self.assertHeader('Location', '%s/dir1/' % self.base()) + + if not getattr(cherrypy.server, 'using_apache', False): + # Test that we can use URL's which aren't all valid Python + # identifiers + # This should also test the %XX-unquoting of URL's. + self.getPage('/Von%20B%fclow?ID=14') + self.assertBody('ID is 14') + + # Test that %2F in the path doesn't get unquoted too early; + # that is, it should not be used to separate path components. + # See ticket #393. + self.getPage('/page%2Fname') + self.assertBody("default:('page/name',)") + + self.getPage('/dir1/dir2/script_name') + self.assertBody(url) + self.getPage('/dir1/dir2/cherrypy_url') + self.assertBody('%s/extra' % self.base()) + + # Test that configs don't overwrite each other from different apps + self.getPage('/confvalue') + self.assertBody((url or '/').split('/')[-2]) + + self.script_name = '' + + # Test absoluteURI's in the Request-Line + self.getPage('http://%s:%s/' % (self.interface(), self.PORT)) + self.assertBody('world') + + self.getPage('http://%s:%s/abs/?service=http://192.168.0.1/x/y/z' % + (self.interface(), self.PORT)) + self.assertBody("default:('abs',)") + + self.getPage('/rel/?service=http://192.168.120.121:8000/x/y/z') + self.assertBody("default:('rel',)") + + # Test that the "isolated" app doesn't leak url's into the root app. + # If it did leak, Root.default() would answer with + # "default:('isolated', 'doesnt', 'exist')". + self.getPage('/isolated/') + self.assertStatus('200 OK') + self.assertBody('made it!') + self.getPage('/isolated/doesnt/exist') + self.assertStatus('404 Not Found') + + # Make sure /foobar maps to Root.foobar and not to the app + # mounted at /foo. 
See + # https://github.com/cherrypy/cherrypy/issues/573 + self.getPage('/foobar') + self.assertBody('bar') + + def test_translate(self): + self.getPage('/translate_html') + self.assertStatus('200 OK') + self.assertBody('OK') + + self.getPage('/translate.html') + self.assertStatus('200 OK') + self.assertBody('OK') + + self.getPage('/translate-html') + self.assertStatus('200 OK') + self.assertBody('OK') + + def test_redir_using_url(self): + for url in script_names: + self.script_name = url + + # Test the absolute path to the parent (leading slash) + self.getPage('/redirect_via_url?path=./') + self.assertStatus(('302 Found', '303 See Other')) + self.assertHeader('Location', '%s/' % self.base()) + + # Test the relative path to the parent (no leading slash) + self.getPage('/redirect_via_url?path=./') + self.assertStatus(('302 Found', '303 See Other')) + self.assertHeader('Location', '%s/' % self.base()) + + # Test the absolute path to the parent (leading slash) + self.getPage('/redirect_via_url/?path=./') + self.assertStatus(('302 Found', '303 See Other')) + self.assertHeader('Location', '%s/' % self.base()) + + # Test the relative path to the parent (no leading slash) + self.getPage('/redirect_via_url/?path=./') + self.assertStatus(('302 Found', '303 See Other')) + self.assertHeader('Location', '%s/' % self.base()) + + def testPositionalParams(self): + self.getPage('/dir1/dir2/posparam/18/24/hut/hike') + self.assertBody('18/24/hut/hike') + + # intermediate index methods should not receive posparams; + # only the "final" index method should do so. + self.getPage('/dir1/dir2/5/3/sir') + self.assertBody("default for dir1, param is:('dir2', '5', '3', 'sir')") + + # test that extra positional args raises an 404 Not Found + # See https://github.com/cherrypy/cherrypy/issues/733. + self.getPage('/dir1/dir2/script_name/extra/stuff') + self.assertStatus(404) + + def testExpose(self): + # Test the cherrypy.expose function/decorator + self.getPage('/exposing/base') + self.assertBody('expose works!') + + self.getPage('/exposing/1') + self.assertBody('expose works!') + + self.getPage('/exposing/2') + self.assertBody('expose works!') + + self.getPage('/exposingnew/base') + self.assertBody('expose works!') + + self.getPage('/exposingnew/1') + self.assertBody('expose works!') + + self.getPage('/exposingnew/2') + self.assertBody('expose works!') + + def testMethodDispatch(self): + self.getPage('/bymethod') + self.assertBody("['another']") + self.assertHeader('Allow', 'GET, HEAD, POST') + + self.getPage('/bymethod', method='HEAD') + self.assertBody('') + self.assertHeader('Allow', 'GET, HEAD, POST') + + self.getPage('/bymethod', method='POST', body='thing=one') + self.assertBody('') + self.assertHeader('Allow', 'GET, HEAD, POST') + + self.getPage('/bymethod') + self.assertBody(repr(['another', ntou('one')])) + self.assertHeader('Allow', 'GET, HEAD, POST') + + self.getPage('/bymethod', method='PUT') + self.assertErrorPage(405) + self.assertHeader('Allow', 'GET, HEAD, POST') + + # Test default with posparams + self.getPage('/collection/silly', method='POST') + self.getPage('/collection', method='GET') + self.assertBody("['a', 'bit', 'silly']") + + # Test custom dispatcher set on app root (see #737). + self.getPage('/app') + self.assertBody('milk') + + def testTreeMounting(self): + class Root(object): + + @cherrypy.expose + def hello(self): + return 'Hello world!' + + # When mounting an application instance, + # we can't specify a different script name in the call to mount. 
+ a = Application(Root(), '/somewhere') + self.assertRaises(ValueError, cherrypy.tree.mount, a, '/somewhereelse') + + # When mounting an application instance... + a = Application(Root(), '/somewhere') + # ...we MUST allow in identical script name in the call to mount... + cherrypy.tree.mount(a, '/somewhere') + self.getPage('/somewhere/hello') + self.assertStatus(200) + # ...and MUST allow a missing script_name. + del cherrypy.tree.apps['/somewhere'] + cherrypy.tree.mount(a) + self.getPage('/somewhere/hello') + self.assertStatus(200) + + # In addition, we MUST be able to create an Application using + # script_name == None for access to the wsgi_environ. + a = Application(Root(), script_name=None) + # However, this does not apply to tree.mount + self.assertRaises(TypeError, cherrypy.tree.mount, a, None) + + def testKeywords(self): + if sys.version_info < (3,): + return self.skip('skipped (Python 3 only)') + exec("""class Root(object): + @cherrypy.expose + def hello(self, *, name='world'): + return 'Hello %s!' % name +cherrypy.tree.mount(Application(Root(), '/keywords'))""") + + self.getPage('/keywords/hello') + self.assertStatus(200) + self.getPage('/keywords/hello/extra') + self.assertStatus(404) diff --git a/libraries/cherrypy/test/test_params.py b/libraries/cherrypy/test/test_params.py new file mode 100644 index 00000000..73b4cb4c --- /dev/null +++ b/libraries/cherrypy/test/test_params.py @@ -0,0 +1,61 @@ +import sys +import textwrap + +import cherrypy +from cherrypy.test import helper + + +class ParamsTest(helper.CPWebCase): + @staticmethod + def setup_server(): + class Root: + @cherrypy.expose + @cherrypy.tools.json_out() + @cherrypy.tools.params() + def resource(self, limit=None, sort=None): + return type(limit).__name__ + # for testing on Py 2 + resource.__annotations__ = {'limit': int} + conf = {'/': {'tools.params.on': True}} + cherrypy.tree.mount(Root(), config=conf) + + def test_pass(self): + self.getPage('/resource') + self.assertStatus(200) + self.assertBody('"NoneType"') + + self.getPage('/resource?limit=0') + self.assertStatus(200) + self.assertBody('"int"') + + def test_error(self): + self.getPage('/resource?limit=') + self.assertStatus(400) + self.assertInBody('invalid literal for int') + + cherrypy.config['tools.params.error'] = 422 + self.getPage('/resource?limit=') + self.assertStatus(422) + self.assertInBody('invalid literal for int') + + cherrypy.config['tools.params.exception'] = TypeError + self.getPage('/resource?limit=') + self.assertStatus(500) + + def test_syntax(self): + if sys.version_info < (3,): + return self.skip('skipped (Python 3 only)') + code = textwrap.dedent(""" + class Root: + @cherrypy.expose + @cherrypy.tools.params() + def resource(self, limit: int): + return type(limit).__name__ + conf = {'/': {'tools.params.on': True}} + cherrypy.tree.mount(Root(), config=conf) + """) + exec(code) + + self.getPage('/resource?limit=0') + self.assertStatus(200) + self.assertBody('int') diff --git a/libraries/cherrypy/test/test_plugins.py b/libraries/cherrypy/test/test_plugins.py new file mode 100644 index 00000000..4d3aa6b1 --- /dev/null +++ b/libraries/cherrypy/test/test_plugins.py @@ -0,0 +1,14 @@ +from cherrypy.process import plugins + + +__metaclass__ = type + + +class TestAutoreloader: + def test_file_for_file_module_when_None(self): + """No error when module.__file__ is None. 
+ """ + class test_module: + __file__ = None + + assert plugins.Autoreloader._file_for_file_module(test_module) is None diff --git a/libraries/cherrypy/test/test_proxy.py b/libraries/cherrypy/test/test_proxy.py new file mode 100644 index 00000000..4d34440a --- /dev/null +++ b/libraries/cherrypy/test/test_proxy.py @@ -0,0 +1,154 @@ +import cherrypy +from cherrypy.test import helper + +script_names = ['', '/path/to/myapp'] + + +class ProxyTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + + # Set up site + cherrypy.config.update({ + 'tools.proxy.on': True, + 'tools.proxy.base': 'www.mydomain.test', + }) + + # Set up application + + class Root: + + def __init__(self, sn): + # Calculate a URL outside of any requests. + self.thisnewpage = cherrypy.url( + '/this/new/page', script_name=sn) + + @cherrypy.expose + def pageurl(self): + return self.thisnewpage + + @cherrypy.expose + def index(self): + raise cherrypy.HTTPRedirect('dummy') + + @cherrypy.expose + def remoteip(self): + return cherrypy.request.remote.ip + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.proxy.local': 'X-Host', + 'tools.trailing_slash.extra': True, + }) + def xhost(self): + raise cherrypy.HTTPRedirect('blah') + + @cherrypy.expose + def base(self): + return cherrypy.request.base + + @cherrypy.expose + @cherrypy.config(**{'tools.proxy.scheme': 'X-Forwarded-Ssl'}) + def ssl(self): + return cherrypy.request.base + + @cherrypy.expose + def newurl(self): + return ("Browse to <a href='%s'>this page</a>." + % cherrypy.url('/this/new/page')) + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.proxy.base': None, + }) + def base_no_base(self): + return cherrypy.request.base + + for sn in script_names: + cherrypy.tree.mount(Root(sn), sn) + + def testProxy(self): + self.getPage('/') + self.assertHeader('Location', + '%s://www.mydomain.test%s/dummy' % + (self.scheme, self.prefix())) + + # Test X-Forwarded-Host (Apache 1.3.33+ and Apache 2) + self.getPage( + '/', headers=[('X-Forwarded-Host', 'http://www.example.test')]) + self.assertHeader('Location', 'http://www.example.test/dummy') + self.getPage('/', headers=[('X-Forwarded-Host', 'www.example.test')]) + self.assertHeader('Location', '%s://www.example.test/dummy' % + self.scheme) + # Test multiple X-Forwarded-Host headers + self.getPage('/', headers=[ + ('X-Forwarded-Host', 'http://www.example.test, www.cherrypy.test'), + ]) + self.assertHeader('Location', 'http://www.example.test/dummy') + + # Test X-Forwarded-For (Apache2) + self.getPage('/remoteip', + headers=[('X-Forwarded-For', '192.168.0.20')]) + self.assertBody('192.168.0.20') + # Fix bug #1268 + self.getPage('/remoteip', + headers=[ + ('X-Forwarded-For', '67.15.36.43, 192.168.0.20') + ]) + self.assertBody('67.15.36.43') + + # Test X-Host (lighttpd; see https://trac.lighttpd.net/trac/ticket/418) + self.getPage('/xhost', headers=[('X-Host', 'www.example.test')]) + self.assertHeader('Location', '%s://www.example.test/blah' % + self.scheme) + + # Test X-Forwarded-Proto (lighttpd) + self.getPage('/base', headers=[('X-Forwarded-Proto', 'https')]) + self.assertBody('https://www.mydomain.test') + + # Test X-Forwarded-Ssl (webfaction?) 
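Every check in this method goes through the proxy tool configured in setup_server above (tools.proxy.base, tools.proxy.local, tools.proxy.scheme). A stripped-down sketch of the same wiring outside the suite (the single handler and the quickstart call are illustrative):

    import cherrypy

    class Root(object):
        @cherrypy.expose
        def base(self):
            # With tools.proxy enabled, request.base is rebuilt from the
            # forwarded headers supplied by the front-end proxy instead of
            # from the local socket the request arrived on.
            return cherrypy.request.base

    config = {'/': {'tools.proxy.on': True,
                    'tools.proxy.base': 'www.mydomain.test'}}

    if __name__ == '__main__':
        cherrypy.quickstart(Root(), '/', config)

Because the /ssl handler above points 'tools.proxy.scheme' at X-Forwarded-Ssl, a request carrying 'X-Forwarded-Ssl: on' reports an https base, which is what the next assertion expects.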
+ self.getPage('/ssl', headers=[('X-Forwarded-Ssl', 'on')]) + self.assertBody('https://www.mydomain.test') + + # Test cherrypy.url() + for sn in script_names: + # Test the value inside requests + self.getPage(sn + '/newurl') + self.assertBody( + "Browse to <a href='%s://www.mydomain.test" % self.scheme + + sn + "/this/new/page'>this page</a>.") + self.getPage(sn + '/newurl', headers=[('X-Forwarded-Host', + 'http://www.example.test')]) + self.assertBody("Browse to <a href='http://www.example.test" + + sn + "/this/new/page'>this page</a>.") + + # Test the value outside requests + port = '' + if self.scheme == 'http' and self.PORT != 80: + port = ':%s' % self.PORT + elif self.scheme == 'https' and self.PORT != 443: + port = ':%s' % self.PORT + host = self.HOST + if host in ('0.0.0.0', '::'): + import socket + host = socket.gethostname() + expected = ('%s://%s%s%s/this/new/page' + % (self.scheme, host, port, sn)) + self.getPage(sn + '/pageurl') + self.assertBody(expected) + + # Test trailing slash (see + # https://github.com/cherrypy/cherrypy/issues/562). + self.getPage('/xhost/', headers=[('X-Host', 'www.example.test')]) + self.assertHeader('Location', '%s://www.example.test/xhost' + % self.scheme) + + def test_no_base_port_in_host(self): + """ + If no base is indicated, and the host header is used to resolve + the base, it should rely on the host header for the port also. + """ + headers = {'Host': 'localhost:8080'}.items() + self.getPage('/base_no_base', headers=headers) + self.assertBody('http://localhost:8080') diff --git a/libraries/cherrypy/test/test_refleaks.py b/libraries/cherrypy/test/test_refleaks.py new file mode 100644 index 00000000..c2fe9e66 --- /dev/null +++ b/libraries/cherrypy/test/test_refleaks.py @@ -0,0 +1,66 @@ +"""Tests for refleaks.""" + +import itertools +import platform +import threading + +from six.moves.http_client import HTTPConnection + +import cherrypy +from cherrypy._cpcompat import HTTPSConnection +from cherrypy.test import helper + + +data = object() + + +class ReferenceTests(helper.CPWebCase): + + @staticmethod + def setup_server(): + + class Root: + + @cherrypy.expose + def index(self, *args, **kwargs): + cherrypy.request.thing = data + return 'Hello world!' 
+ + cherrypy.tree.mount(Root()) + + def test_threadlocal_garbage(self): + if platform.system() == 'Darwin': + self.skip('queue issues; see #1474') + success = itertools.count() + + def getpage(): + host = '%s:%s' % (self.interface(), self.PORT) + if self.scheme == 'https': + c = HTTPSConnection(host) + else: + c = HTTPConnection(host) + try: + c.putrequest('GET', '/') + c.endheaders() + response = c.getresponse() + body = response.read() + self.assertEqual(response.status, 200) + self.assertEqual(body, b'Hello world!') + finally: + c.close() + next(success) + + ITERATIONS = 25 + + ts = [ + threading.Thread(target=getpage) + for _ in range(ITERATIONS) + ] + + for t in ts: + t.start() + + for t in ts: + t.join() + + self.assertEqual(next(success), ITERATIONS) diff --git a/libraries/cherrypy/test/test_request_obj.py b/libraries/cherrypy/test/test_request_obj.py new file mode 100644 index 00000000..6b93e13d --- /dev/null +++ b/libraries/cherrypy/test/test_request_obj.py @@ -0,0 +1,932 @@ +"""Basic tests for the cherrypy.Request object.""" + +from functools import wraps +import os +import sys +import types +import uuid + +import six +from six.moves.http_client import IncompleteRead + +import cherrypy +from cherrypy._cpcompat import ntou +from cherrypy.lib import httputil +from cherrypy.test import helper + +localDir = os.path.dirname(__file__) + +defined_http_methods = ('OPTIONS', 'GET', 'HEAD', 'POST', 'PUT', 'DELETE', + 'TRACE', 'PROPFIND', 'PATCH') + + +# Client-side code # + + +class RequestObjectTests(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def index(self): + return 'hello' + + @cherrypy.expose + def scheme(self): + return cherrypy.request.scheme + + @cherrypy.expose + def created_example_com_3128(self): + """Handle CONNECT method.""" + cherrypy.response.status = 204 + + @cherrypy.expose + def body_example_com_3128(self): + """Handle CONNECT method.""" + return ( + cherrypy.request.method + + 'ed to ' + + cherrypy.request.path_info + ) + + @cherrypy.expose + def request_uuid4(self): + return [ + str(cherrypy.request.unique_id), + ' ', + str(cherrypy.request.unique_id), + ] + + root = Root() + + class TestType(type): + """Metaclass which automatically exposes all functions in each + subclass, and adds an instance of the subclass as an attribute + of root. 
+ """ + def __init__(cls, name, bases, dct): + type.__init__(cls, name, bases, dct) + for value in dct.values(): + if isinstance(value, types.FunctionType): + value.exposed = True + setattr(root, name.lower(), cls()) + Test = TestType('Test', (object,), {}) + + class PathInfo(Test): + + def default(self, *args): + return cherrypy.request.path_info + + class Params(Test): + + def index(self, thing): + return repr(thing) + + def ismap(self, x, y): + return 'Coordinates: %s, %s' % (x, y) + + @cherrypy.config(**{'request.query_string_encoding': 'latin1'}) + def default(self, *args, **kwargs): + return 'args: %s kwargs: %s' % (args, sorted(kwargs.items())) + + @cherrypy.expose + class ParamErrorsCallable(object): + + def __call__(self): + return 'data' + + def handler_dec(f): + @wraps(f) + def wrapper(handler, *args, **kwargs): + return f(handler, *args, **kwargs) + return wrapper + + class ParamErrors(Test): + + @cherrypy.expose + def one_positional(self, param1): + return 'data' + + @cherrypy.expose + def one_positional_args(self, param1, *args): + return 'data' + + @cherrypy.expose + def one_positional_args_kwargs(self, param1, *args, **kwargs): + return 'data' + + @cherrypy.expose + def one_positional_kwargs(self, param1, **kwargs): + return 'data' + + @cherrypy.expose + def no_positional(self): + return 'data' + + @cherrypy.expose + def no_positional_args(self, *args): + return 'data' + + @cherrypy.expose + def no_positional_args_kwargs(self, *args, **kwargs): + return 'data' + + @cherrypy.expose + def no_positional_kwargs(self, **kwargs): + return 'data' + + callable_object = ParamErrorsCallable() + + @cherrypy.expose + def raise_type_error(self, **kwargs): + raise TypeError('Client Error') + + @cherrypy.expose + def raise_type_error_with_default_param(self, x, y=None): + return '%d' % 'a' # throw an exception + + @cherrypy.expose + @handler_dec + def raise_type_error_decorated(self, *args, **kwargs): + raise TypeError('Client Error') + + def callable_error_page(status, **kwargs): + return "Error %s - Well, I'm very sorry but you haven't paid!" % ( + status) + + @cherrypy.config(**{'tools.log_tracebacks.on': True}) + class Error(Test): + + def reason_phrase(self): + raise cherrypy.HTTPError("410 Gone fishin'") + + @cherrypy.config(**{ + 'error_page.404': os.path.join(localDir, 'static/index.html'), + 'error_page.401': callable_error_page, + }) + def custom(self, err='404'): + raise cherrypy.HTTPError( + int(err), 'No, <b>really</b>, not found!') + + @cherrypy.config(**{ + 'error_page.default': callable_error_page, + }) + def custom_default(self): + return 1 + 'a' # raise an unexpected error + + @cherrypy.config(**{'error_page.404': 'nonexistent.html'}) + def noexist(self): + raise cherrypy.HTTPError(404, 'No, <b>really</b>, not found!') + + def page_method(self): + raise ValueError() + + def page_yield(self): + yield 'howdy' + raise ValueError() + + @cherrypy.config(**{'response.stream': True}) + def page_streamed(self): + yield 'word up' + raise ValueError() + yield 'very oops' + + @cherrypy.config(**{'request.show_tracebacks': False}) + def cause_err_in_finalize(self): + # Since status must start with an int, this should error. + cherrypy.response.status = 'ZOO OK' + + @cherrypy.config(**{'request.throw_errors': True}) + def rethrow(self): + """Test that an error raised here will be thrown out to + the server. 
+ """ + raise ValueError() + + class Expect(Test): + + def expectation_failed(self): + expect = cherrypy.request.headers.elements('Expect') + if expect and expect[0].value != '100-continue': + raise cherrypy.HTTPError(400) + raise cherrypy.HTTPError(417, 'Expectation Failed') + + class Headers(Test): + + def default(self, headername): + """Spit back out the value for the requested header.""" + return cherrypy.request.headers[headername] + + def doubledheaders(self): + # From https://github.com/cherrypy/cherrypy/issues/165: + # "header field names should not be case sensitive sayes the + # rfc. if i set a headerfield in complete lowercase i end up + # with two header fields, one in lowercase, the other in + # mixed-case." + + # Set the most common headers + hMap = cherrypy.response.headers + hMap['content-type'] = 'text/html' + hMap['content-length'] = 18 + hMap['server'] = 'CherryPy headertest' + hMap['location'] = ('%s://%s:%s/headers/' + % (cherrypy.request.local.ip, + cherrypy.request.local.port, + cherrypy.request.scheme)) + + # Set a rare header for fun + hMap['Expires'] = 'Thu, 01 Dec 2194 16:00:00 GMT' + + return 'double header test' + + def ifmatch(self): + val = cherrypy.request.headers['If-Match'] + assert isinstance(val, six.text_type) + cherrypy.response.headers['ETag'] = val + return val + + class HeaderElements(Test): + + def get_elements(self, headername): + e = cherrypy.request.headers.elements(headername) + return '\n'.join([six.text_type(x) for x in e]) + + class Method(Test): + + def index(self): + m = cherrypy.request.method + if m in defined_http_methods or m == 'CONNECT': + return m + + if m == 'LINK': + raise cherrypy.HTTPError(405) + else: + raise cherrypy.HTTPError(501) + + def parameterized(self, data): + return data + + def request_body(self): + # This should be a file object (temp file), + # which CP will just pipe back out if we tell it to. + return cherrypy.request.body + + def reachable(self): + return 'success' + + class Divorce(Test): + + """HTTP Method handlers shouldn't collide with normal method names. + For example, a GET-handler shouldn't collide with a method named + 'get'. 
+ + If you build HTTP method dispatching into CherryPy, rewrite this + class to use your new dispatch mechanism and make sure that: + "GET /divorce HTTP/1.1" maps to divorce.index() and + "GET /divorce/get?ID=13 HTTP/1.1" maps to divorce.get() + """ + + documents = {} + + @cherrypy.expose + def index(self): + yield '<h1>Choose your document</h1>\n' + yield '<ul>\n' + for id, contents in self.documents.items(): + yield ( + " <li><a href='/divorce/get?ID=%s'>%s</a>:" + ' %s</li>\n' % (id, id, contents)) + yield '</ul>' + + @cherrypy.expose + def get(self, ID): + return ('Divorce document %s: %s' % + (ID, self.documents.get(ID, 'empty'))) + + class ThreadLocal(Test): + + def index(self): + existing = repr(getattr(cherrypy.request, 'asdf', None)) + cherrypy.request.asdf = 'rassfrassin' + return existing + + appconf = { + '/method': { + 'request.methods_with_bodies': ('POST', 'PUT', 'PROPFIND', + 'PATCH') + }, + } + cherrypy.tree.mount(root, config=appconf) + + def test_scheme(self): + self.getPage('/scheme') + self.assertBody(self.scheme) + + def test_per_request_uuid4(self): + self.getPage('/request_uuid4') + first_uuid4, _, second_uuid4 = self.body.decode().partition(' ') + assert ( + uuid.UUID(first_uuid4, version=4) + == uuid.UUID(second_uuid4, version=4) + ) + + self.getPage('/request_uuid4') + third_uuid4, _, _ = self.body.decode().partition(' ') + assert ( + uuid.UUID(first_uuid4, version=4) + != uuid.UUID(third_uuid4, version=4) + ) + + def testRelativeURIPathInfo(self): + self.getPage('/pathinfo/foo/bar') + self.assertBody('/pathinfo/foo/bar') + + def testAbsoluteURIPathInfo(self): + # http://cherrypy.org/ticket/1061 + self.getPage('http://localhost/pathinfo/foo/bar') + self.assertBody('/pathinfo/foo/bar') + + def testParams(self): + self.getPage('/params/?thing=a') + self.assertBody(repr(ntou('a'))) + + self.getPage('/params/?thing=a&thing=b&thing=c') + self.assertBody(repr([ntou('a'), ntou('b'), ntou('c')])) + + # Test friendly error message when given params are not accepted. 
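The switch being exercised here is request.show_mismatched_params; in isolation it looks roughly like this (the one-parameter handler is illustrative):

    import cherrypy

    class Root(object):
        @cherrypy.expose
        def index(self, thing):
            return repr(thing)

    # With the setting on, a request such as /?notathing=1 yields a 404
    # whose body names the missing and unexpected parameters; with it off,
    # the same request yields a plain "Not Found".
    cherrypy.config.update({'request.show_mismatched_params': True})

    if __name__ == '__main__':
        cherrypy.quickstart(Root())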
+ cherrypy.config.update({'request.show_mismatched_params': True}) + self.getPage('/params/?notathing=meeting') + self.assertInBody('Missing parameters: thing') + self.getPage('/params/?thing=meeting¬athing=meeting') + self.assertInBody('Unexpected query string parameters: notathing') + + # Test ability to turn off friendly error messages + cherrypy.config.update({'request.show_mismatched_params': False}) + self.getPage('/params/?notathing=meeting') + self.assertInBody('Not Found') + self.getPage('/params/?thing=meeting¬athing=meeting') + self.assertInBody('Not Found') + + # Test "% HEX HEX"-encoded URL, param keys, and values + self.getPage('/params/%d4%20%e3/cheese?Gruy%E8re=Bulgn%e9ville') + self.assertBody('args: %s kwargs: %s' % + (('\xd4 \xe3', 'cheese'), + [('Gruy\xe8re', ntou('Bulgn\xe9ville'))])) + + # Make sure that encoded = and & get parsed correctly + self.getPage( + '/params/code?url=http%3A//cherrypy.org/index%3Fa%3D1%26b%3D2') + self.assertBody('args: %s kwargs: %s' % + (('code',), + [('url', ntou('http://cherrypy.org/index?a=1&b=2'))])) + + # Test coordinates sent by <img ismap> + self.getPage('/params/ismap?223,114') + self.assertBody('Coordinates: 223, 114') + + # Test "name[key]" dict-like params + self.getPage('/params/dictlike?a[1]=1&a[2]=2&b=foo&b[bar]=baz') + self.assertBody('args: %s kwargs: %s' % + (('dictlike',), + [('a[1]', ntou('1')), ('a[2]', ntou('2')), + ('b', ntou('foo')), ('b[bar]', ntou('baz'))])) + + def testParamErrors(self): + + # test that all of the handlers work when given + # the correct parameters in order to ensure that the + # errors below aren't coming from some other source. + for uri in ( + '/paramerrors/one_positional?param1=foo', + '/paramerrors/one_positional_args?param1=foo', + '/paramerrors/one_positional_args/foo', + '/paramerrors/one_positional_args/foo/bar/baz', + '/paramerrors/one_positional_args_kwargs?' + 'param1=foo¶m2=bar', + '/paramerrors/one_positional_args_kwargs/foo?' + 'param2=bar¶m3=baz', + '/paramerrors/one_positional_args_kwargs/foo/bar/baz?' + 'param2=bar¶m3=baz', + '/paramerrors/one_positional_kwargs?' + 'param1=foo¶m2=bar¶m3=baz', + '/paramerrors/one_positional_kwargs/foo?' + 'param4=foo¶m2=bar¶m3=baz', + '/paramerrors/no_positional', + '/paramerrors/no_positional_args/foo', + '/paramerrors/no_positional_args/foo/bar/baz', + '/paramerrors/no_positional_args_kwargs?param1=foo¶m2=bar', + '/paramerrors/no_positional_args_kwargs/foo?param2=bar', + '/paramerrors/no_positional_args_kwargs/foo/bar/baz?' + 'param2=bar¶m3=baz', + '/paramerrors/no_positional_kwargs?param1=foo¶m2=bar', + '/paramerrors/callable_object', + ): + self.getPage(uri) + self.assertStatus(200) + + error_msgs = [ + 'Missing parameters', + 'Nothing matches the given URI', + 'Multiple values for parameters', + 'Unexpected query string parameters', + 'Unexpected body parameters', + 'Invalid path in Request-URI', + 'Illegal #fragment in Request-URI', + ] + + # uri should be tested for valid absolute path, the status must be 400. + for uri, error_idx in ( + ('invalid/path/without/leading/slash', 5), + ('/valid/path#invalid=fragment', 6), + ): + self.getPage(uri) + self.assertStatus(400) + self.assertInBody(error_msgs[error_idx]) + + # query string parameters are part of the URI, so if they are wrong + # for a particular handler, the status MUST be a 404. 
+ for uri, msg in ( + ('/paramerrors/one_positional', error_msgs[0]), + ('/paramerrors/one_positional?foo=foo', error_msgs[0]), + ('/paramerrors/one_positional/foo/bar/baz', error_msgs[1]), + ('/paramerrors/one_positional/foo?param1=foo', error_msgs[2]), + ('/paramerrors/one_positional/foo?param1=foo¶m2=foo', + error_msgs[2]), + ('/paramerrors/one_positional_args/foo?param1=foo¶m2=foo', + error_msgs[2]), + ('/paramerrors/one_positional_args/foo/bar/baz?param2=foo', + error_msgs[3]), + ('/paramerrors/one_positional_args_kwargs/foo/bar/baz?' + 'param1=bar¶m3=baz', + error_msgs[2]), + ('/paramerrors/one_positional_kwargs/foo?' + 'param1=foo¶m2=bar¶m3=baz', + error_msgs[2]), + ('/paramerrors/no_positional/boo', error_msgs[1]), + ('/paramerrors/no_positional?param1=foo', error_msgs[3]), + ('/paramerrors/no_positional_args/boo?param1=foo', error_msgs[3]), + ('/paramerrors/no_positional_kwargs/boo?param1=foo', + error_msgs[1]), + ('/paramerrors/callable_object?param1=foo', error_msgs[3]), + ('/paramerrors/callable_object/boo', error_msgs[1]), + ): + for show_mismatched_params in (True, False): + cherrypy.config.update( + {'request.show_mismatched_params': show_mismatched_params}) + self.getPage(uri) + self.assertStatus(404) + if show_mismatched_params: + self.assertInBody(msg) + else: + self.assertInBody('Not Found') + + # if body parameters are wrong, a 400 must be returned. + for uri, body, msg in ( + ('/paramerrors/one_positional/foo', + 'param1=foo', error_msgs[2]), + ('/paramerrors/one_positional/foo', + 'param1=foo¶m2=foo', error_msgs[2]), + ('/paramerrors/one_positional_args/foo', + 'param1=foo¶m2=foo', error_msgs[2]), + ('/paramerrors/one_positional_args/foo/bar/baz', + 'param2=foo', error_msgs[4]), + ('/paramerrors/one_positional_args_kwargs/foo/bar/baz', + 'param1=bar¶m3=baz', error_msgs[2]), + ('/paramerrors/one_positional_kwargs/foo', + 'param1=foo¶m2=bar¶m3=baz', error_msgs[2]), + ('/paramerrors/no_positional', 'param1=foo', error_msgs[4]), + ('/paramerrors/no_positional_args/boo', + 'param1=foo', error_msgs[4]), + ('/paramerrors/callable_object', 'param1=foo', error_msgs[4]), + ): + for show_mismatched_params in (True, False): + cherrypy.config.update( + {'request.show_mismatched_params': show_mismatched_params}) + self.getPage(uri, method='POST', body=body) + self.assertStatus(400) + if show_mismatched_params: + self.assertInBody(msg) + else: + self.assertInBody('400 Bad') + + # even if body parameters are wrong, if we get the uri wrong, then + # it's a 404 + for uri, body, msg in ( + ('/paramerrors/one_positional?param2=foo', + 'param1=foo', error_msgs[3]), + ('/paramerrors/one_positional/foo/bar', + 'param2=foo', error_msgs[1]), + ('/paramerrors/one_positional_args/foo/bar?param2=foo', + 'param3=foo', error_msgs[3]), + ('/paramerrors/one_positional_kwargs/foo/bar', + 'param2=bar¶m3=baz', error_msgs[1]), + ('/paramerrors/no_positional?param1=foo', + 'param2=foo', error_msgs[3]), + ('/paramerrors/no_positional_args/boo?param2=foo', + 'param1=foo', error_msgs[3]), + ('/paramerrors/callable_object?param2=bar', + 'param1=foo', error_msgs[3]), + ): + for show_mismatched_params in (True, False): + cherrypy.config.update( + {'request.show_mismatched_params': show_mismatched_params}) + self.getPage(uri, method='POST', body=body) + self.assertStatus(404) + if show_mismatched_params: + self.assertInBody(msg) + else: + self.assertInBody('Not Found') + + # In the case that a handler raises a TypeError we should + # let that type error through. 
+ for uri in ( + '/paramerrors/raise_type_error', + '/paramerrors/raise_type_error_with_default_param?x=0', + '/paramerrors/raise_type_error_with_default_param?x=0&y=0', + '/paramerrors/raise_type_error_decorated', + ): + self.getPage(uri, method='GET') + self.assertStatus(500) + self.assertTrue('Client Error', self.body) + + def testErrorHandling(self): + self.getPage('/error/missing') + self.assertStatus(404) + self.assertErrorPage(404, "The path '/error/missing' was not found.") + + ignore = helper.webtest.ignored_exceptions + ignore.append(ValueError) + try: + valerr = '\n raise ValueError()\nValueError' + self.getPage('/error/page_method') + self.assertErrorPage(500, pattern=valerr) + + self.getPage('/error/page_yield') + self.assertErrorPage(500, pattern=valerr) + + if (cherrypy.server.protocol_version == 'HTTP/1.0' or + getattr(cherrypy.server, 'using_apache', False)): + self.getPage('/error/page_streamed') + # Because this error is raised after the response body has + # started, the status should not change to an error status. + self.assertStatus(200) + self.assertBody('word up') + else: + # Under HTTP/1.1, the chunked transfer-coding is used. + # The HTTP client will choke when the output is incomplete. + self.assertRaises((ValueError, IncompleteRead), self.getPage, + '/error/page_streamed') + + # No traceback should be present + self.getPage('/error/cause_err_in_finalize') + msg = "Illegal response status from server ('ZOO' is non-numeric)." + self.assertErrorPage(500, msg, None) + finally: + ignore.pop() + + # Test HTTPError with a reason-phrase in the status arg. + self.getPage('/error/reason_phrase') + self.assertStatus("410 Gone fishin'") + + # Test custom error page for a specific error. + self.getPage('/error/custom') + self.assertStatus(404) + self.assertBody('Hello, world\r\n' + (' ' * 499)) + + # Test custom error page for a specific error. + self.getPage('/error/custom?err=401') + self.assertStatus(401) + self.assertBody( + 'Error 401 Unauthorized - ' + "Well, I'm very sorry but you haven't paid!") + + # Test default custom error page. + self.getPage('/error/custom_default') + self.assertStatus(500) + self.assertBody( + 'Error 500 Internal Server Error - ' + "Well, I'm very sorry but you haven't paid!".ljust(513)) + + # Test error in custom error page (ticket #305). + # Note that the message is escaped for HTML (ticket #310). + self.getPage('/error/noexist') + self.assertStatus(404) + if sys.version_info >= (3, 3): + exc_name = 'FileNotFoundError' + else: + exc_name = 'IOError' + msg = ('No, <b>really</b>, not found!<br />' + 'In addition, the custom error page failed:\n<br />' + '%s: [Errno 2] ' + "No such file or directory: 'nonexistent.html'") % (exc_name,) + self.assertInBody(msg) + + if getattr(cherrypy.server, 'using_apache', False): + pass + else: + # Test throw_errors (ticket #186). + self.getPage('/error/rethrow') + self.assertInBody('raise ValueError()') + + def testExpect(self): + e = ('Expect', '100-continue') + self.getPage('/headerelements/get_elements?headername=Expect', [e]) + self.assertBody('100-continue') + + self.getPage('/expect/expectation_failed', [e]) + self.assertStatus(417) + + def testHeaderElements(self): + # Accept-* header elements should be sorted, with most preferred first. 
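These assertions depend on CherryPy sorting Accept-style header elements by preference, most preferred first. The same ordering can be observed through cherrypy.lib.httputil.header_elements, the helper that request.headers.elements() delegates to (a small sketch; the expected output mirrors the assertion below):

    from cherrypy.lib import httputil

    # Parse a comma-separated header value into HeaderElement objects,
    # sorted with the highest q-value first.
    for element in httputil.header_elements('Accept',
                                            'audio/*; q=0.2, audio/basic'):
        print(element)
    # Expected order:
    #   audio/basic
    #   audio/*;q=0.2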
+ h = [('Accept', 'audio/*; q=0.2, audio/basic')] + self.getPage('/headerelements/get_elements?headername=Accept', h) + self.assertStatus(200) + self.assertBody('audio/basic\n' + 'audio/*;q=0.2') + + h = [ + ('Accept', + 'text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c') + ] + self.getPage('/headerelements/get_elements?headername=Accept', h) + self.assertStatus(200) + self.assertBody('text/x-c\n' + 'text/html\n' + 'text/x-dvi;q=0.8\n' + 'text/plain;q=0.5') + + # Test that more specific media ranges get priority. + h = [('Accept', 'text/*, text/html, text/html;level=1, */*')] + self.getPage('/headerelements/get_elements?headername=Accept', h) + self.assertStatus(200) + self.assertBody('text/html;level=1\n' + 'text/html\n' + 'text/*\n' + '*/*') + + # Test Accept-Charset + h = [('Accept-Charset', 'iso-8859-5, unicode-1-1;q=0.8')] + self.getPage( + '/headerelements/get_elements?headername=Accept-Charset', h) + self.assertStatus('200 OK') + self.assertBody('iso-8859-5\n' + 'unicode-1-1;q=0.8') + + # Test Accept-Encoding + h = [('Accept-Encoding', 'gzip;q=1.0, identity; q=0.5, *;q=0')] + self.getPage( + '/headerelements/get_elements?headername=Accept-Encoding', h) + self.assertStatus('200 OK') + self.assertBody('gzip;q=1.0\n' + 'identity;q=0.5\n' + '*;q=0') + + # Test Accept-Language + h = [('Accept-Language', 'da, en-gb;q=0.8, en;q=0.7')] + self.getPage( + '/headerelements/get_elements?headername=Accept-Language', h) + self.assertStatus('200 OK') + self.assertBody('da\n' + 'en-gb;q=0.8\n' + 'en;q=0.7') + + # Test malformed header parsing. See + # https://github.com/cherrypy/cherrypy/issues/763. + self.getPage('/headerelements/get_elements?headername=Content-Type', + # Note the illegal trailing ";" + headers=[('Content-Type', 'text/html; charset=utf-8;')]) + self.assertStatus(200) + self.assertBody('text/html;charset=utf-8') + + def test_repeated_headers(self): + # Test that two request headers are collapsed into one. + # See https://github.com/cherrypy/cherrypy/issues/542. + self.getPage('/headers/Accept-Charset', + headers=[('Accept-Charset', 'iso-8859-5'), + ('Accept-Charset', 'unicode-1-1;q=0.8')]) + self.assertBody('iso-8859-5, unicode-1-1;q=0.8') + + # Tests that each header only appears once, regardless of case. + self.getPage('/headers/doubledheaders') + self.assertBody('double header test') + hnames = [name.title() for name, val in self.headers] + for key in ['Content-Length', 'Content-Type', 'Date', + 'Expires', 'Location', 'Server']: + self.assertEqual(hnames.count(key), 1, self.headers) + + def test_encoded_headers(self): + # First, make sure the innards work like expected. + self.assertEqual( + httputil.decode_TEXT(ntou('=?utf-8?q?f=C3=BCr?=')), ntou('f\xfcr')) + + if cherrypy.server.protocol_version == 'HTTP/1.1': + # Test RFC-2047-encoded request and response header values + u = ntou('\u212bngstr\xf6m', 'escape') + c = ntou('=E2=84=ABngstr=C3=B6m') + self.getPage('/headers/ifmatch', + [('If-Match', ntou('=?utf-8?q?%s?=') % c)]) + # The body should be utf-8 encoded. + self.assertBody(b'\xe2\x84\xabngstr\xc3\xb6m') + # But the Etag header should be RFC-2047 encoded (binary) + self.assertHeader('ETag', ntou('=?utf-8?b?4oSrbmdzdHLDtm0=?=')) + + # Test a *LONG* RFC-2047-encoded request and response header value + self.getPage('/headers/ifmatch', + [('If-Match', ntou('=?utf-8?q?%s?=') % (c * 10))]) + self.assertBody(b'\xe2\x84\xabngstr\xc3\xb6m' * 10) + # Note: this is different output for Python3, but it decodes fine. 
+ etag = self.assertHeader( + 'ETag', + '=?utf-8?b?4oSrbmdzdHLDtm3ihKtuZ3N0csO2beKEq25nc3Ryw7Zt' + '4oSrbmdzdHLDtm3ihKtuZ3N0csO2beKEq25nc3Ryw7Zt' + '4oSrbmdzdHLDtm3ihKtuZ3N0csO2beKEq25nc3Ryw7Zt' + '4oSrbmdzdHLDtm0=?=') + self.assertEqual(httputil.decode_TEXT(etag), u * 10) + + def test_header_presence(self): + # If we don't pass a Content-Type header, it should not be present + # in cherrypy.request.headers + self.getPage('/headers/Content-Type', + headers=[]) + self.assertStatus(500) + + # If Content-Type is present in the request, it should be present in + # cherrypy.request.headers + self.getPage('/headers/Content-Type', + headers=[('Content-type', 'application/json')]) + self.assertBody('application/json') + + def test_basic_HTTPMethods(self): + helper.webtest.methods_with_bodies = ('POST', 'PUT', 'PROPFIND', + 'PATCH') + + # Test that all defined HTTP methods work. + for m in defined_http_methods: + self.getPage('/method/', method=m) + + # HEAD requests should not return any body. + if m == 'HEAD': + self.assertBody('') + elif m == 'TRACE': + # Some HTTP servers (like modpy) have their own TRACE support + self.assertEqual(self.body[:5], b'TRACE') + else: + self.assertBody(m) + + # test of PATCH requests + # Request a PATCH method with a form-urlencoded body + self.getPage('/method/parameterized', method='PATCH', + body='data=on+top+of+other+things') + self.assertBody('on top of other things') + + # Request a PATCH method with a file body + b = 'one thing on top of another' + h = [('Content-Type', 'text/plain'), + ('Content-Length', str(len(b)))] + self.getPage('/method/request_body', headers=h, method='PATCH', body=b) + self.assertStatus(200) + self.assertBody(b) + + # Request a PATCH method with a file body but no Content-Type. + # See https://github.com/cherrypy/cherrypy/issues/790. + b = b'one thing on top of another' + self.persistent = True + try: + conn = self.HTTP_CONN + conn.putrequest('PATCH', '/method/request_body', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Content-Length', str(len(b))) + conn.endheaders() + conn.send(b) + response = conn.response_class(conn.sock, method='PATCH') + response.begin() + self.assertEqual(response.status, 200) + self.body = response.read() + self.assertBody(b) + finally: + self.persistent = False + + # Request a PATCH method with no body whatsoever (not an empty one). + # See https://github.com/cherrypy/cherrypy/issues/650. + # Provide a C-T or webtest will provide one (and a C-L) for us. + h = [('Content-Type', 'text/plain')] + self.getPage('/method/reachable', headers=h, method='PATCH') + self.assertStatus(411) + + # HTTP PUT tests + # Request a PUT method with a form-urlencoded body + self.getPage('/method/parameterized', method='PUT', + body='data=on+top+of+other+things') + self.assertBody('on top of other things') + + # Request a PUT method with a file body + b = 'one thing on top of another' + h = [('Content-Type', 'text/plain'), + ('Content-Length', str(len(b)))] + self.getPage('/method/request_body', headers=h, method='PUT', body=b) + self.assertStatus(200) + self.assertBody(b) + + # Request a PUT method with a file body but no Content-Type. + # See https://github.com/cherrypy/cherrypy/issues/790. 
+ b = b'one thing on top of another' + self.persistent = True + try: + conn = self.HTTP_CONN + conn.putrequest('PUT', '/method/request_body', skip_host=True) + conn.putheader('Host', self.HOST) + conn.putheader('Content-Length', str(len(b))) + conn.endheaders() + conn.send(b) + response = conn.response_class(conn.sock, method='PUT') + response.begin() + self.assertEqual(response.status, 200) + self.body = response.read() + self.assertBody(b) + finally: + self.persistent = False + + # Request a PUT method with no body whatsoever (not an empty one). + # See https://github.com/cherrypy/cherrypy/issues/650. + # Provide a C-T or webtest will provide one (and a C-L) for us. + h = [('Content-Type', 'text/plain')] + self.getPage('/method/reachable', headers=h, method='PUT') + self.assertStatus(411) + + # Request a custom method with a request body + b = ('<?xml version="1.0" encoding="utf-8" ?>\n\n' + '<propfind xmlns="DAV:"><prop><getlastmodified/>' + '</prop></propfind>') + h = [('Content-Type', 'text/xml'), + ('Content-Length', str(len(b)))] + self.getPage('/method/request_body', headers=h, + method='PROPFIND', body=b) + self.assertStatus(200) + self.assertBody(b) + + # Request a disallowed method + self.getPage('/method/', method='LINK') + self.assertStatus(405) + + # Request an unknown method + self.getPage('/method/', method='SEARCH') + self.assertStatus(501) + + # For method dispatchers: make sure that an HTTP method doesn't + # collide with a virtual path atom. If you build HTTP-method + # dispatching into the core, rewrite these handlers to use + # your dispatch idioms. + self.getPage('/divorce/get?ID=13') + self.assertBody('Divorce document 13: empty') + self.assertStatus(200) + self.getPage('/divorce/', method='GET') + self.assertBody('<h1>Choose your document</h1>\n<ul>\n</ul>') + self.assertStatus(200) + + def test_CONNECT_method(self): + self.persistent = True + try: + conn = self.HTTP_CONN + conn.request('CONNECT', 'created.example.com:3128') + response = conn.response_class(conn.sock, method='CONNECT') + response.begin() + self.assertEqual(response.status, 204) + finally: + self.persistent = False + + self.persistent = True + try: + conn = self.HTTP_CONN + conn.request('CONNECT', 'body.example.com:3128') + response = conn.response_class(conn.sock, method='CONNECT') + response.begin() + self.assertEqual(response.status, 200) + self.body = response.read() + self.assertBody(b'CONNECTed to /body.example.com:3128') + finally: + self.persistent = False + + def test_CONNECT_method_invalid_authority(self): + for request_target in ['example.com', 'http://example.com:33', + '/path/', 'path/', '/?q=f', '#f']: + self.persistent = True + try: + conn = self.HTTP_CONN + conn.request('CONNECT', request_target) + response = conn.response_class(conn.sock, method='CONNECT') + response.begin() + self.assertEqual(response.status, 400) + self.body = response.read() + self.assertBody(b'Invalid path in Request-URI: request-target ' + b'must match authority-form.') + finally: + self.persistent = False + + def testEmptyThreadlocals(self): + results = [] + for x in range(20): + self.getPage('/threadlocal/') + results.append(self.body) + self.assertEqual(results, [b'None'] * 20) diff --git a/libraries/cherrypy/test/test_routes.py b/libraries/cherrypy/test/test_routes.py new file mode 100644 index 00000000..cc714765 --- /dev/null +++ b/libraries/cherrypy/test/test_routes.py @@ -0,0 +1,80 @@ +"""Test Routes dispatcher.""" +import os +import importlib + +import pytest + +import cherrypy +from cherrypy.test 
import helper + +curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +class RoutesDispatchTest(helper.CPWebCase): + """Routes dispatcher test suite.""" + + @staticmethod + def setup_server(): + """Set up cherrypy test instance.""" + try: + importlib.import_module('routes') + except ImportError: + pytest.skip('Install routes to test RoutesDispatcher code') + + class Dummy: + + def index(self): + return 'I said good day!' + + class City: + + def __init__(self, name): + self.name = name + self.population = 10000 + + @cherrypy.config(**{ + 'tools.response_headers.on': True, + 'tools.response_headers.headers': [ + ('Content-Language', 'en-GB'), + ], + }) + def index(self, **kwargs): + return 'Welcome to %s, pop. %s' % (self.name, self.population) + + def update(self, **kwargs): + self.population = kwargs['pop'] + return 'OK' + + d = cherrypy.dispatch.RoutesDispatcher() + d.connect(action='index', name='hounslow', route='/hounslow', + controller=City('Hounslow')) + d.connect( + name='surbiton', route='/surbiton', controller=City('Surbiton'), + action='index', conditions=dict(method=['GET'])) + d.mapper.connect('/surbiton', controller='surbiton', + action='update', conditions=dict(method=['POST'])) + d.connect('main', ':action', controller=Dummy()) + + conf = {'/': {'request.dispatch': d}} + cherrypy.tree.mount(root=None, config=conf) + + def test_Routes_Dispatch(self): + """Check that routes package based URI dispatching works correctly.""" + self.getPage('/hounslow') + self.assertStatus('200 OK') + self.assertBody('Welcome to Hounslow, pop. 10000') + + self.getPage('/foo') + self.assertStatus('404 Not Found') + + self.getPage('/surbiton') + self.assertStatus('200 OK') + self.assertBody('Welcome to Surbiton, pop. 10000') + + self.getPage('/surbiton', method='POST', body='pop=1327') + self.assertStatus('200 OK') + self.assertBody('OK') + self.getPage('/surbiton') + self.assertStatus('200 OK') + self.assertHeader('Content-Language', 'en-GB') + self.assertBody('Welcome to Surbiton, pop. 
1327') diff --git a/libraries/cherrypy/test/test_session.py b/libraries/cherrypy/test/test_session.py new file mode 100644 index 00000000..0083c97c --- /dev/null +++ b/libraries/cherrypy/test/test_session.py @@ -0,0 +1,512 @@ +import os +import threading +import time +import socket +import importlib + +from six.moves.http_client import HTTPConnection + +import pytest +from path import Path + +import cherrypy +from cherrypy._cpcompat import ( + json_decode, + HTTPSConnection, +) +from cherrypy.lib import sessions +from cherrypy.lib import reprconf +from cherrypy.lib.httputil import response_codes +from cherrypy.test import helper + +localDir = os.path.dirname(__file__) + + +def http_methods_allowed(methods=['GET', 'HEAD']): + method = cherrypy.request.method.upper() + if method not in methods: + cherrypy.response.headers['Allow'] = ', '.join(methods) + raise cherrypy.HTTPError(405) + + +cherrypy.tools.allow = cherrypy.Tool('on_start_resource', http_methods_allowed) + + +def setup_server(): + + @cherrypy.config(**{ + 'tools.sessions.on': True, + 'tools.sessions.storage_class': sessions.RamSession, + 'tools.sessions.storage_path': localDir, + 'tools.sessions.timeout': (1.0 / 60), + 'tools.sessions.clean_freq': (1.0 / 60), + }) + class Root: + + @cherrypy.expose + def clear(self): + cherrypy.session.cache.clear() + + @cherrypy.expose + def data(self): + cherrypy.session['aha'] = 'foo' + return repr(cherrypy.session._data) + + @cherrypy.expose + def testGen(self): + counter = cherrypy.session.get('counter', 0) + 1 + cherrypy.session['counter'] = counter + yield str(counter) + + @cherrypy.expose + def testStr(self): + counter = cherrypy.session.get('counter', 0) + 1 + cherrypy.session['counter'] = counter + return str(counter) + + @cherrypy.expose + @cherrypy.config(**{'tools.sessions.on': False}) + def set_session_cls(self, new_cls_name): + new_cls = reprconf.attributes(new_cls_name) + cfg = {'tools.sessions.storage_class': new_cls} + self.__class__._cp_config.update(cfg) + if hasattr(cherrypy, 'session'): + del cherrypy.session + if new_cls.clean_thread: + new_cls.clean_thread.stop() + new_cls.clean_thread.unsubscribe() + del new_cls.clean_thread + + @cherrypy.expose + def index(self): + sess = cherrypy.session + c = sess.get('counter', 0) + 1 + time.sleep(0.01) + sess['counter'] = c + return str(c) + + @cherrypy.expose + def keyin(self, key): + return str(key in cherrypy.session) + + @cherrypy.expose + def delete(self): + cherrypy.session.delete() + sessions.expire() + return 'done' + + @cherrypy.expose + def delkey(self, key): + del cherrypy.session[key] + return 'OK' + + @cherrypy.expose + def redir_target(self): + return self._cp_config['tools.sessions.storage_class'].__name__ + + @cherrypy.expose + def iredir(self): + raise cherrypy.InternalRedirect('/redir_target') + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.allow.on': True, + 'tools.allow.methods': ['GET'], + }) + def restricted(self): + return cherrypy.request.method + + @cherrypy.expose + def regen(self): + cherrypy.tools.sessions.regenerate() + return 'logged in' + + @cherrypy.expose + def length(self): + return str(len(cherrypy.session)) + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.sessions.path': '/session_cookie', + 'tools.sessions.name': 'temp', + 'tools.sessions.persistent': False, + }) + def session_cookie(self): + # Must load() to start the clean thread. 
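+ # With persistent=False this is a browser-session cookie, so no 'expires' attribute is set (checked in test_7_session_cookies).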
+ cherrypy.session.load() + return cherrypy.session.id + + cherrypy.tree.mount(Root()) + + +class SessionTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def tearDown(self): + # Clean up sessions. + for fname in os.listdir(localDir): + if fname.startswith(sessions.FileSession.SESSION_PREFIX): + path = Path(localDir) / fname + path.remove_p() + + @pytest.mark.xfail(reason='#1534') + def test_0_Session(self): + self.getPage('/set_session_cls/cherrypy.lib.sessions.RamSession') + self.getPage('/clear') + + # Test that a normal request gets the same id in the cookies. + # Note: this wouldn't work if /data didn't load the session. + self.getPage('/data') + self.assertBody("{'aha': 'foo'}") + c = self.cookies[0] + self.getPage('/data', self.cookies) + self.assertEqual(self.cookies[0], c) + + self.getPage('/testStr') + self.assertBody('1') + cookie_parts = dict([p.strip().split('=') + for p in self.cookies[0][1].split(';')]) + # Assert there is an 'expires' param + self.assertEqual(set(cookie_parts.keys()), + set(['session_id', 'expires', 'Path'])) + self.getPage('/testGen', self.cookies) + self.assertBody('2') + self.getPage('/testStr', self.cookies) + self.assertBody('3') + self.getPage('/data', self.cookies) + self.assertDictEqual(json_decode(self.body), + {'counter': 3, 'aha': 'foo'}) + self.getPage('/length', self.cookies) + self.assertBody('2') + self.getPage('/delkey?key=counter', self.cookies) + self.assertStatus(200) + + self.getPage('/set_session_cls/cherrypy.lib.sessions.FileSession') + self.getPage('/testStr') + self.assertBody('1') + self.getPage('/testGen', self.cookies) + self.assertBody('2') + self.getPage('/testStr', self.cookies) + self.assertBody('3') + self.getPage('/delkey?key=counter', self.cookies) + self.assertStatus(200) + + # Wait for the session.timeout (1 second) + time.sleep(2) + self.getPage('/') + self.assertBody('1') + self.getPage('/length', self.cookies) + self.assertBody('1') + + # Test session __contains__ + self.getPage('/keyin?key=counter', self.cookies) + self.assertBody('True') + cookieset1 = self.cookies + + # Make a new session and test __len__ again + self.getPage('/') + self.getPage('/length', self.cookies) + self.assertBody('2') + + # Test session delete + self.getPage('/delete', self.cookies) + self.assertBody('done') + self.getPage('/delete', cookieset1) + self.assertBody('done') + + def f(): + return [ + x + for x in os.listdir(localDir) + if x.startswith('session-') + ] + self.assertEqual(f(), []) + + # Wait for the cleanup thread to delete remaining session files + self.getPage('/') + self.assertNotEqual(f(), []) + time.sleep(2) + self.assertEqual(f(), []) + + def test_1_Ram_Concurrency(self): + self.getPage('/set_session_cls/cherrypy.lib.sessions.RamSession') + self._test_Concurrency() + + @pytest.mark.xfail(reason='#1306') + def test_2_File_Concurrency(self): + self.getPage('/set_session_cls/cherrypy.lib.sessions.FileSession') + self._test_Concurrency() + + def _test_Concurrency(self): + client_thread_count = 5 + request_count = 30 + + # Get initial cookie + self.getPage('/') + self.assertBody('1') + cookies = self.cookies + + data_dict = {} + errors = [] + + def request(index): + if self.scheme == 'https': + c = HTTPSConnection('%s:%s' % (self.interface(), self.PORT)) + else: + c = HTTPConnection('%s:%s' % (self.interface(), self.PORT)) + for i in range(request_count): + c.putrequest('GET', '/') + for k, v in cookies: + c.putheader(k, v) + c.endheaders() + response = c.getresponse() + body = response.read() + if 
response.status != 200 or not body.isdigit(): + errors.append((response.status, body)) + else: + data_dict[index] = max(data_dict[index], int(body)) + # Uncomment the following line to prove threads overlap. + # sys.stdout.write("%d " % index) + + # Start <request_count> requests from each of + # <client_thread_count> concurrent clients + ts = [] + for c in range(client_thread_count): + data_dict[c] = 0 + t = threading.Thread(target=request, args=(c,)) + ts.append(t) + t.start() + + for t in ts: + t.join() + + hitcount = max(data_dict.values()) + expected = 1 + (client_thread_count * request_count) + + for e in errors: + print(e) + self.assertEqual(hitcount, expected) + + def test_3_Redirect(self): + # Start a new session + self.getPage('/testStr') + self.getPage('/iredir', self.cookies) + self.assertBody('FileSession') + + def test_4_File_deletion(self): + # Start a new session + self.getPage('/testStr') + # Delete the session file manually and retry. + id = self.cookies[0][1].split(';', 1)[0].split('=', 1)[1] + path = os.path.join(localDir, 'session-' + id) + os.unlink(path) + self.getPage('/testStr', self.cookies) + + def test_5_Error_paths(self): + self.getPage('/unknown/page') + self.assertErrorPage(404, "The path '/unknown/page' was not found.") + + # Note: this path is *not* the same as above. The above + # takes a normal route through the session code; this one + # skips the session code's before_handler and only calls + # before_finalize (save) and on_end (close). So the session + # code has to survive calling save/close without init. + self.getPage('/restricted', self.cookies, method='POST') + self.assertErrorPage(405, response_codes[405][1]) + + def test_6_regenerate(self): + self.getPage('/testStr') + # grab the cookie ID + id1 = self.cookies[0][1].split(';', 1)[0].split('=', 1)[1] + self.getPage('/regen') + self.assertBody('logged in') + id2 = self.cookies[0][1].split(';', 1)[0].split('=', 1)[1] + self.assertNotEqual(id1, id2) + + self.getPage('/testStr') + # grab the cookie ID + id1 = self.cookies[0][1].split(';', 1)[0].split('=', 1)[1] + self.getPage('/testStr', + headers=[ + ('Cookie', + 'session_id=maliciousid; ' + 'expires=Sat, 27 Oct 2017 04:18:28 GMT; Path=/;')]) + id2 = self.cookies[0][1].split(';', 1)[0].split('=', 1)[1] + self.assertNotEqual(id1, id2) + self.assertNotEqual(id2, 'maliciousid') + + def test_7_session_cookies(self): + self.getPage('/set_session_cls/cherrypy.lib.sessions.RamSession') + self.getPage('/clear') + self.getPage('/session_cookie') + # grab the cookie ID + cookie_parts = dict([p.strip().split('=') + for p in self.cookies[0][1].split(';')]) + # Assert there is no 'expires' param + self.assertEqual(set(cookie_parts.keys()), set(['temp', 'Path'])) + id1 = cookie_parts['temp'] + self.assertEqual(list(sessions.RamSession.cache), [id1]) + + # Send another request in the same "browser session". 
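+ # Replaying the same cookie should return the same id and leave only one entry in the cache.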
+ self.getPage('/session_cookie', self.cookies) + cookie_parts = dict([p.strip().split('=') + for p in self.cookies[0][1].split(';')]) + # Assert there is no 'expires' param + self.assertEqual(set(cookie_parts.keys()), set(['temp', 'Path'])) + self.assertBody(id1) + self.assertEqual(list(sessions.RamSession.cache), [id1]) + + # Simulate a browser close by just not sending the cookies + self.getPage('/session_cookie') + # grab the cookie ID + cookie_parts = dict([p.strip().split('=') + for p in self.cookies[0][1].split(';')]) + # Assert there is no 'expires' param + self.assertEqual(set(cookie_parts.keys()), set(['temp', 'Path'])) + # Assert a new id has been generated... + id2 = cookie_parts['temp'] + self.assertNotEqual(id1, id2) + self.assertEqual(set(sessions.RamSession.cache.keys()), + set([id1, id2])) + + # Wait for the session.timeout on both sessions + time.sleep(2.5) + cache = list(sessions.RamSession.cache) + if cache: + if cache == [id2]: + self.fail('The second session did not time out.') + else: + self.fail('Unknown session id in cache: %r', cache) + + def test_8_Ram_Cleanup(self): + def lock(): + s1 = sessions.RamSession() + s1.acquire_lock() + time.sleep(1) + s1.release_lock() + + t = threading.Thread(target=lock) + t.start() + start = time.time() + while not sessions.RamSession.locks and time.time() - start < 5: + time.sleep(0.01) + assert len(sessions.RamSession.locks) == 1, 'Lock not acquired' + s2 = sessions.RamSession() + s2.clean_up() + msg = 'Clean up should not remove active lock' + assert len(sessions.RamSession.locks) == 1, msg + t.join() + + +try: + importlib.import_module('memcache') + + host, port = '127.0.0.1', 11211 + for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See http://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + raise + break +except (ImportError, socket.error): + class MemcachedSessionTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test(self): + return self.skip('memcached not reachable ') +else: + class MemcachedSessionTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def test_0_Session(self): + self.getPage('/set_session_cls/cherrypy.Sessions.MemcachedSession') + + self.getPage('/testStr') + self.assertBody('1') + self.getPage('/testGen', self.cookies) + self.assertBody('2') + self.getPage('/testStr', self.cookies) + self.assertBody('3') + self.getPage('/length', self.cookies) + self.assertErrorPage(500) + self.assertInBody('NotImplementedError') + self.getPage('/delkey?key=counter', self.cookies) + self.assertStatus(200) + + # Wait for the session.timeout (1 second) + time.sleep(1.25) + self.getPage('/') + self.assertBody('1') + + # Test session __contains__ + self.getPage('/keyin?key=counter', self.cookies) + self.assertBody('True') + + # Test session delete + self.getPage('/delete', self.cookies) + self.assertBody('done') + + def test_1_Concurrency(self): + client_thread_count = 5 + request_count = 30 + + # Get initial cookie + self.getPage('/') + self.assertBody('1') + cookies = self.cookies + + data_dict = {} + + def request(index): + for i in range(request_count): + self.getPage('/', cookies) + # Uncomment the following line to prove threads overlap. 
+ # sys.stdout.write("%d " % index) + if not self.body.isdigit(): + self.fail(self.body) + data_dict[index] = int(self.body) + + # Start <request_count> concurrent requests from + # each of <client_thread_count> clients + ts = [] + for c in range(client_thread_count): + data_dict[c] = 0 + t = threading.Thread(target=request, args=(c,)) + ts.append(t) + t.start() + + for t in ts: + t.join() + + hitcount = max(data_dict.values()) + expected = 1 + (client_thread_count * request_count) + self.assertEqual(hitcount, expected) + + def test_3_Redirect(self): + # Start a new session + self.getPage('/testStr') + self.getPage('/iredir', self.cookies) + self.assertBody('memcached') + + def test_5_Error_paths(self): + self.getPage('/unknown/page') + self.assertErrorPage( + 404, "The path '/unknown/page' was not found.") + + # Note: this path is *not* the same as above. The above + # takes a normal route through the session code; this one + # skips the session code's before_handler and only calls + # before_finalize (save) and on_end (close). So the session + # code has to survive calling save/close without init. + self.getPage('/restricted', self.cookies, method='POST') + self.assertErrorPage(405, response_codes[405][1]) diff --git a/libraries/cherrypy/test/test_sessionauthenticate.py b/libraries/cherrypy/test/test_sessionauthenticate.py new file mode 100644 index 00000000..63053fcb --- /dev/null +++ b/libraries/cherrypy/test/test_sessionauthenticate.py @@ -0,0 +1,61 @@ +import cherrypy +from cherrypy.test import helper + + +class SessionAuthenticateTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + + def check(username, password): + # Dummy check_username_and_password function + if username != 'test' or password != 'password': + return 'Wrong login/password' + + def augment_params(): + # A simple tool to add some things to request.params + # This is to check to make sure that session_auth can handle + # request params (ticket #780) + cherrypy.request.params['test'] = 'test' + + cherrypy.tools.augment_params = cherrypy.Tool( + 'before_handler', augment_params, None, priority=30) + + class Test: + + _cp_config = { + 'tools.sessions.on': True, + 'tools.session_auth.on': True, + 'tools.session_auth.check_username_and_password': check, + 'tools.augment_params.on': True, + } + + @cherrypy.expose + def index(self, **kwargs): + return 'Hi %s, you are logged in' % cherrypy.request.login + + cherrypy.tree.mount(Test()) + + def testSessionAuthenticate(self): + # request a page and check for login form + self.getPage('/') + self.assertInBody('<form method="post" action="do_login">') + + # setup credentials + login_body = 'username=test&password=password&from_page=/' + + # attempt a login + self.getPage('/do_login', method='POST', body=login_body) + self.assertStatus((302, 303)) + + # get the page now that we are logged in + self.getPage('/', self.cookies) + self.assertBody('Hi test, you are logged in') + + # do a logout + self.getPage('/do_logout', self.cookies, method='POST') + self.assertStatus((302, 303)) + + # verify we are logged out + self.getPage('/', self.cookies) + self.assertInBody('<form method="post" action="do_login">') diff --git a/libraries/cherrypy/test/test_states.py b/libraries/cherrypy/test/test_states.py new file mode 100644 index 00000000..606ca4f6 --- /dev/null +++ b/libraries/cherrypy/test/test_states.py @@ -0,0 +1,473 @@ +import os +import signal +import time +import unittest +import warnings + +from six.moves.http_client import BadStatusLine + +import pytest +import portend 
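+# portend is used below to probe whether the server's port is bound or free.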
+ +import cherrypy +import cherrypy.process.servers +from cherrypy.test import helper + +engine = cherrypy.engine +thisdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +class Dependency: + + def __init__(self, bus): + self.bus = bus + self.running = False + self.startcount = 0 + self.gracecount = 0 + self.threads = {} + + def subscribe(self): + self.bus.subscribe('start', self.start) + self.bus.subscribe('stop', self.stop) + self.bus.subscribe('graceful', self.graceful) + self.bus.subscribe('start_thread', self.startthread) + self.bus.subscribe('stop_thread', self.stopthread) + + def start(self): + self.running = True + self.startcount += 1 + + def stop(self): + self.running = False + + def graceful(self): + self.gracecount += 1 + + def startthread(self, thread_id): + self.threads[thread_id] = None + + def stopthread(self, thread_id): + del self.threads[thread_id] + + +db_connection = Dependency(engine) + + +def setup_server(): + class Root: + + @cherrypy.expose + def index(self): + return 'Hello World' + + @cherrypy.expose + def ctrlc(self): + raise KeyboardInterrupt() + + @cherrypy.expose + def graceful(self): + engine.graceful() + return 'app was (gracefully) restarted succesfully' + + cherrypy.tree.mount(Root()) + cherrypy.config.update({ + 'environment': 'test_suite', + }) + + db_connection.subscribe() + +# ------------ Enough helpers. Time for real live test cases. ------------ # + + +class ServerStateTests(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def setUp(self): + cherrypy.server.socket_timeout = 0.1 + self.do_gc_test = False + + def test_0_NormalStateFlow(self): + engine.stop() + # Our db_connection should not be running + self.assertEqual(db_connection.running, False) + self.assertEqual(db_connection.startcount, 1) + self.assertEqual(len(db_connection.threads), 0) + + # Test server start + engine.start() + self.assertEqual(engine.state, engine.states.STARTED) + + host = cherrypy.server.socket_host + port = cherrypy.server.socket_port + portend.occupied(host, port, timeout=0.1) + + # The db_connection should be running now + self.assertEqual(db_connection.running, True) + self.assertEqual(db_connection.startcount, 2) + self.assertEqual(len(db_connection.threads), 0) + + self.getPage('/') + self.assertBody('Hello World') + self.assertEqual(len(db_connection.threads), 1) + + # Test engine stop. This will also stop the HTTP server. + engine.stop() + self.assertEqual(engine.state, engine.states.STOPPED) + + # Verify that our custom stop function was called + self.assertEqual(db_connection.running, False) + self.assertEqual(len(db_connection.threads), 0) + + # Block the main thread now and verify that exit() works. 
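+ # start_with_callback() starts the engine and then runs exittest in its own thread.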
+ def exittest(): + self.getPage('/') + self.assertBody('Hello World') + engine.exit() + cherrypy.server.start() + engine.start_with_callback(exittest) + engine.block() + self.assertEqual(engine.state, engine.states.EXITING) + + def test_1_Restart(self): + cherrypy.server.start() + engine.start() + + # The db_connection should be running now + self.assertEqual(db_connection.running, True) + grace = db_connection.gracecount + + self.getPage('/') + self.assertBody('Hello World') + self.assertEqual(len(db_connection.threads), 1) + + # Test server restart from this thread + engine.graceful() + self.assertEqual(engine.state, engine.states.STARTED) + self.getPage('/') + self.assertBody('Hello World') + self.assertEqual(db_connection.running, True) + self.assertEqual(db_connection.gracecount, grace + 1) + self.assertEqual(len(db_connection.threads), 1) + + # Test server restart from inside a page handler + self.getPage('/graceful') + self.assertEqual(engine.state, engine.states.STARTED) + self.assertBody('app was (gracefully) restarted succesfully') + self.assertEqual(db_connection.running, True) + self.assertEqual(db_connection.gracecount, grace + 2) + # Since we are requesting synchronously, is only one thread used? + # Note that the "/graceful" request has been flushed. + self.assertEqual(len(db_connection.threads), 0) + + engine.stop() + self.assertEqual(engine.state, engine.states.STOPPED) + self.assertEqual(db_connection.running, False) + self.assertEqual(len(db_connection.threads), 0) + + def test_2_KeyboardInterrupt(self): + # Raise a keyboard interrupt in the HTTP server's main thread. + # We must start the server in this, the main thread + engine.start() + cherrypy.server.start() + + self.persistent = True + try: + # Make the first request and assert there's no "Connection: close". + self.getPage('/') + self.assertStatus('200 OK') + self.assertBody('Hello World') + self.assertNoHeader('Connection') + + cherrypy.server.httpserver.interrupt = KeyboardInterrupt + engine.block() + + self.assertEqual(db_connection.running, False) + self.assertEqual(len(db_connection.threads), 0) + self.assertEqual(engine.state, engine.states.EXITING) + finally: + self.persistent = False + + # Raise a keyboard interrupt in a page handler; on multithreaded + # servers, this should occur in one of the worker threads. + # This should raise a BadStatusLine error, since the worker + # thread will just die without writing a response. + engine.start() + cherrypy.server.start() + # From python3.5 a new exception is retuned when the connection + # ends abruptly: + # http.client.RemoteDisconnected + # RemoteDisconnected is a subclass of: + # (ConnectionResetError, http.client.BadStatusLine) + # and ConnectionResetError is an indirect subclass of: + # OSError + # From python 3.3 an up socket.error is an alias to OSError + # following PEP-3151, therefore http.client.RemoteDisconnected + # is considered a socket.error. + # + # raise_subcls specifies the classes that are not going + # to be considered as a socket.error for the retries. + # Given that RemoteDisconnected is part BadStatusLine + # we can use the same call for all py3 versions without + # sideffects. python < 3.5 will raise directly BadStatusLine + # which is not a subclass for socket.error/OSError. 
+ try: + self.getPage('/ctrlc', raise_subcls=BadStatusLine) + except BadStatusLine: + pass + else: + print(self.body) + self.fail('AssertionError: BadStatusLine not raised') + + engine.block() + self.assertEqual(db_connection.running, False) + self.assertEqual(len(db_connection.threads), 0) + + @pytest.mark.xfail( + 'sys.platform == "Darwin" ' + 'and sys.version_info > (3, 7) ' + 'and os.environ["TRAVIS"]', + reason='https://github.com/cherrypy/cherrypy/issues/1693', + ) + def test_4_Autoreload(self): + # If test_3 has not been executed, the server won't be stopped, + # so we'll have to do it. + if engine.state != engine.states.EXITING: + engine.exit() + + # Start the demo script in a new process + p = helper.CPProcess(ssl=(self.scheme.lower() == 'https')) + p.write_conf(extra='test_case_name: "test_4_Autoreload"') + p.start(imports='cherrypy.test._test_states_demo') + try: + self.getPage('/start') + start = float(self.body) + + # Give the autoreloader time to cache the file time. + time.sleep(2) + + # Touch the file + os.utime(os.path.join(thisdir, '_test_states_demo.py'), None) + + # Give the autoreloader time to re-exec the process + time.sleep(2) + host = cherrypy.server.socket_host + port = cherrypy.server.socket_port + portend.occupied(host, port, timeout=5) + + self.getPage('/start') + if not (float(self.body) > start): + raise AssertionError('start time %s not greater than %s' % + (float(self.body), start)) + finally: + # Shut down the spawned process + self.getPage('/exit') + p.join() + + def test_5_Start_Error(self): + # If test_3 has not been executed, the server won't be stopped, + # so we'll have to do it. + if engine.state != engine.states.EXITING: + engine.exit() + + # If a process errors during start, it should stop the engine + # and exit with a non-zero exit code. + p = helper.CPProcess(ssl=(self.scheme.lower() == 'https'), + wait=True) + p.write_conf( + extra="""starterror: True +test_case_name: "test_5_Start_Error" +""" + ) + p.start(imports='cherrypy.test._test_states_demo') + if p.exit_code == 0: + self.fail('Process failed to return nonzero exit code.') + + +class PluginTests(helper.CPWebCase): + + def test_daemonize(self): + if os.name not in ['posix']: + return self.skip('skipped (not on posix) ') + self.HOST = '127.0.0.1' + self.PORT = 8081 + # Spawn the process and wait, when this returns, the original process + # is finished. If it daemonized properly, we should still be able + # to access pages. + p = helper.CPProcess(ssl=(self.scheme.lower() == 'https'), + wait=True, daemonize=True, + socket_host='127.0.0.1', + socket_port=8081) + p.write_conf( + extra='test_case_name: "test_daemonize"') + p.start(imports='cherrypy.test._test_states_demo') + try: + # Just get the pid of the daemonization process. + self.getPage('/pid') + self.assertStatus(200) + page_pid = int(self.body) + self.assertEqual(page_pid, p.get_pid()) + finally: + # Shut down the spawned process + self.getPage('/exit') + p.join() + + # Wait until here to test the exit code because we want to ensure + # that we wait for the daemon to finish running before we fail. + if p.exit_code != 0: + self.fail('Daemonized parent process failed to exit cleanly.') + + +class SignalHandlingTests(helper.CPWebCase): + + def test_SIGHUP_tty(self): + # When not daemonized, SIGHUP should shut down the server. + try: + from signal import SIGHUP + except ImportError: + return self.skip('skipped (no SIGHUP) ') + + # Spawn the process. 
+ p = helper.CPProcess(ssl=(self.scheme.lower() == 'https')) + p.write_conf( + extra='test_case_name: "test_SIGHUP_tty"') + p.start(imports='cherrypy.test._test_states_demo') + # Send a SIGHUP + os.kill(p.get_pid(), SIGHUP) + # This might hang if things aren't working right, but meh. + p.join() + + def test_SIGHUP_daemonized(self): + # When daemonized, SIGHUP should restart the server. + try: + from signal import SIGHUP + except ImportError: + return self.skip('skipped (no SIGHUP) ') + + if os.name not in ['posix']: + return self.skip('skipped (not on posix) ') + + # Spawn the process and wait, when this returns, the original process + # is finished. If it daemonized properly, we should still be able + # to access pages. + p = helper.CPProcess(ssl=(self.scheme.lower() == 'https'), + wait=True, daemonize=True) + p.write_conf( + extra='test_case_name: "test_SIGHUP_daemonized"') + p.start(imports='cherrypy.test._test_states_demo') + + pid = p.get_pid() + try: + # Send a SIGHUP + os.kill(pid, SIGHUP) + # Give the server some time to restart + time.sleep(2) + self.getPage('/pid') + self.assertStatus(200) + new_pid = int(self.body) + self.assertNotEqual(new_pid, pid) + finally: + # Shut down the spawned process + self.getPage('/exit') + p.join() + + def _require_signal_and_kill(self, signal_name): + if not hasattr(signal, signal_name): + self.skip('skipped (no %(signal_name)s)' % vars()) + + if not hasattr(os, 'kill'): + self.skip('skipped (no os.kill)') + + def test_SIGTERM(self): + 'SIGTERM should shut down the server whether daemonized or not.' + self._require_signal_and_kill('SIGTERM') + + # Spawn a normal, undaemonized process. + p = helper.CPProcess(ssl=(self.scheme.lower() == 'https')) + p.write_conf( + extra='test_case_name: "test_SIGTERM"') + p.start(imports='cherrypy.test._test_states_demo') + # Send a SIGTERM + os.kill(p.get_pid(), signal.SIGTERM) + # This might hang if things aren't working right, but meh. + p.join() + + if os.name in ['posix']: + # Spawn a daemonized process and test again. + p = helper.CPProcess(ssl=(self.scheme.lower() == 'https'), + wait=True, daemonize=True) + p.write_conf( + extra='test_case_name: "test_SIGTERM_2"') + p.start(imports='cherrypy.test._test_states_demo') + # Send a SIGTERM + os.kill(p.get_pid(), signal.SIGTERM) + # This might hang if things aren't working right, but meh. + p.join() + + def test_signal_handler_unsubscribe(self): + self._require_signal_and_kill('SIGTERM') + + # Although Windows has `os.kill` and SIGTERM is defined, the + # platform does not implement signals and sending SIGTERM + # will result in a forced termination of the process. + # Therefore, this test is not suitable for Windows. + if os.name == 'nt': + self.skip('SIGTERM not available') + + # Spawn a normal, undaemonized process. + p = helper.CPProcess(ssl=(self.scheme.lower() == 'https')) + p.write_conf( + extra="""unsubsig: True +test_case_name: "test_signal_handler_unsubscribe" +""") + p.start(imports='cherrypy.test._test_states_demo') + # Ask the process to quit + os.kill(p.get_pid(), signal.SIGTERM) + # This might hang if things aren't working right, but meh. + p.join() + + # Assert the old handler ran. 
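+ # The demo script (run with unsubsig: True) writes this marker to its error log when the old SIGTERM handler fires.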
+ log_lines = list(open(p.error_log, 'rb')) + assert any( + line.endswith(b'I am an old SIGTERM handler.\n') + for line in log_lines + ) + + +class WaitTests(unittest.TestCase): + + def test_safe_wait_INADDR_ANY(self): + """ + Wait on INADDR_ANY should not raise IOError + + In cases where the loopback interface does not exist, CherryPy cannot + effectively determine if a port binding to INADDR_ANY was effected. + In this situation, CherryPy should assume that it failed to detect + the binding (not that the binding failed) and only warn that it could + not verify it. + """ + # At such a time that CherryPy can reliably determine one or more + # viable IP addresses of the host, this test may be removed. + + # Simulate the behavior we observe when no loopback interface is + # present by: finding a port that's not occupied, then wait on it. + + free_port = portend.find_available_local_port() + + servers = cherrypy.process.servers + + inaddr_any = '0.0.0.0' + + # Wait on the free port that's unbound + with warnings.catch_warnings(record=True) as w: + with servers._safe_wait(inaddr_any, free_port): + portend.occupied(inaddr_any, free_port, timeout=1) + self.assertEqual(len(w), 1) + self.assertTrue(isinstance(w[0], warnings.WarningMessage)) + self.assertTrue( + 'Unable to verify that the server is bound on ' in str(w[0])) + + # The wait should still raise an IO error if INADDR_ANY was + # not supplied. + with pytest.raises(IOError): + with servers._safe_wait('127.0.0.1', free_port): + portend.occupied('127.0.0.1', free_port, timeout=1) diff --git a/libraries/cherrypy/test/test_static.py b/libraries/cherrypy/test/test_static.py new file mode 100644 index 00000000..5dc5a144 --- /dev/null +++ b/libraries/cherrypy/test/test_static.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- +import contextlib +import io +import os +import sys +import platform +import tempfile + +from six import text_type as str +from six.moves import urllib +from six.moves.http_client import HTTPConnection + +import pytest +import py.path + +import cherrypy +from cherrypy.lib import static +from cherrypy._cpcompat import HTTPSConnection, ntou, tonative +from cherrypy.test import helper + + +@pytest.fixture +def unicode_filesystem(tmpdir): + filename = tmpdir / ntou('☃', 'utf-8') + tmpl = 'File system encoding ({encoding}) cannot support unicode filenames' + msg = tmpl.format(encoding=sys.getfilesystemencoding()) + try: + io.open(str(filename), 'w').close() + except UnicodeEncodeError: + pytest.skip(msg) + + +def ensure_unicode_filesystem(): + """ + TODO: replace with simply pytest fixtures once webtest.TestCase + no longer implies unittest. + """ + tmpdir = py.path.local(tempfile.mkdtemp()) + try: + unicode_filesystem(tmpdir) + finally: + tmpdir.remove() + + +curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) +has_space_filepath = os.path.join(curdir, 'static', 'has space.html') +bigfile_filepath = os.path.join(curdir, 'static', 'bigfile.log') + +# The file size needs to be big enough such that half the size of it +# won't be socket-buffered (or server-buffered) all in one go. See +# test_file_stream. 
+MB = 2 ** 20 +BIGFILE_SIZE = 32 * MB + + +class StaticTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + if not os.path.exists(has_space_filepath): + with open(has_space_filepath, 'wb') as f: + f.write(b'Hello, world\r\n') + needs_bigfile = ( + not os.path.exists(bigfile_filepath) or + os.path.getsize(bigfile_filepath) != BIGFILE_SIZE + ) + if needs_bigfile: + with open(bigfile_filepath, 'wb') as f: + f.write(b'x' * BIGFILE_SIZE) + + class Root: + + @cherrypy.expose + @cherrypy.config(**{'response.stream': True}) + def bigfile(self): + self.f = static.serve_file(bigfile_filepath) + return self.f + + @cherrypy.expose + def tell(self): + if self.f.input.closed: + return '' + return repr(self.f.input.tell()).rstrip('L') + + @cherrypy.expose + def fileobj(self): + f = open(os.path.join(curdir, 'style.css'), 'rb') + return static.serve_fileobj(f, content_type='text/css') + + @cherrypy.expose + def bytesio(self): + f = io.BytesIO(b'Fee\nfie\nfo\nfum') + return static.serve_fileobj(f, content_type='text/plain') + + class Static: + + @cherrypy.expose + def index(self): + return 'You want the Baron? You can have the Baron!' + + @cherrypy.expose + def dynamic(self): + return 'This is a DYNAMIC page' + + root = Root() + root.static = Static() + + rootconf = { + '/static': { + 'tools.staticdir.on': True, + 'tools.staticdir.dir': 'static', + 'tools.staticdir.root': curdir, + }, + '/static-long': { + 'tools.staticdir.on': True, + 'tools.staticdir.dir': r'\\?\%s' % curdir, + }, + '/style.css': { + 'tools.staticfile.on': True, + 'tools.staticfile.filename': os.path.join(curdir, 'style.css'), + }, + '/docroot': { + 'tools.staticdir.on': True, + 'tools.staticdir.root': curdir, + 'tools.staticdir.dir': 'static', + 'tools.staticdir.index': 'index.html', + }, + '/error': { + 'tools.staticdir.on': True, + 'request.show_tracebacks': True, + }, + '/404test': { + 'tools.staticdir.on': True, + 'tools.staticdir.root': curdir, + 'tools.staticdir.dir': 'static', + 'error_page.404': error_page_404, + } + } + rootApp = cherrypy.Application(root) + rootApp.merge(rootconf) + + test_app_conf = { + '/test': { + 'tools.staticdir.index': 'index.html', + 'tools.staticdir.on': True, + 'tools.staticdir.root': curdir, + 'tools.staticdir.dir': 'static', + }, + } + testApp = cherrypy.Application(Static()) + testApp.merge(test_app_conf) + + vhost = cherrypy._cpwsgi.VirtualHost(rootApp, {'virt.net': testApp}) + cherrypy.tree.graft(vhost) + + @staticmethod + def teardown_server(): + for f in (has_space_filepath, bigfile_filepath): + if os.path.exists(f): + try: + os.unlink(f) + except Exception: + pass + + def test_static(self): + self.getPage('/static/index.html') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/html') + self.assertBody('Hello, world\r\n') + + # Using a staticdir.root value in a subdir... 
+ self.getPage('/docroot/index.html') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/html') + self.assertBody('Hello, world\r\n') + + # Check a filename with spaces in it + self.getPage('/static/has%20space.html') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/html') + self.assertBody('Hello, world\r\n') + + self.getPage('/style.css') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/css') + # Note: The body should be exactly 'Dummy stylesheet\n', but + # unfortunately some tools such as WinZip sometimes turn \n + # into \r\n on Windows when extracting the CherryPy tarball so + # we just check the content + self.assertMatchesBody('^Dummy stylesheet') + + @pytest.mark.skipif(platform.system() != 'Windows', reason='Windows only') + def test_static_longpath(self): + """Test serving of a file in subdir of a Windows long-path + staticdir.""" + self.getPage('/static-long/static/index.html') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/html') + self.assertBody('Hello, world\r\n') + + def test_fallthrough(self): + # Test that NotFound will then try dynamic handlers (see [878]). + self.getPage('/static/dynamic') + self.assertBody('This is a DYNAMIC page') + + # Check a directory via fall-through to dynamic handler. + self.getPage('/static/') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/html;charset=utf-8') + self.assertBody('You want the Baron? You can have the Baron!') + + def test_index(self): + # Check a directory via "staticdir.index". + self.getPage('/docroot/') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/html') + self.assertBody('Hello, world\r\n') + # The same page should be returned even if redirected. + self.getPage('/docroot') + self.assertStatus(301) + self.assertHeader('Location', '%s/docroot/' % self.base()) + self.assertMatchesBody( + "This resource .* <a href=(['\"])%s/docroot/\\1>" + '%s/docroot/</a>.' 
+ % (self.base(), self.base()) + ) + + def test_config_errors(self): + # Check that we get an error if no .file or .dir + self.getPage('/error/thing.html') + self.assertErrorPage(500) + if sys.version_info >= (3, 3): + errmsg = ( + r'TypeError: staticdir\(\) missing 2 ' + 'required positional arguments' + ) + else: + errmsg = ( + r'TypeError: staticdir\(\) takes at least 2 ' + r'(positional )?arguments \(0 given\)' + ) + self.assertMatchesBody(errmsg.encode('ascii')) + + def test_security(self): + # Test up-level security + self.getPage('/static/../../test/style.css') + self.assertStatus((400, 403)) + + def test_modif(self): + # Test modified-since on a reasonably-large file + self.getPage('/static/dirback.jpg') + self.assertStatus('200 OK') + lastmod = '' + for k, v in self.headers: + if k == 'Last-Modified': + lastmod = v + ims = ('If-Modified-Since', lastmod) + self.getPage('/static/dirback.jpg', headers=[ims]) + self.assertStatus(304) + self.assertNoHeader('Content-Type') + self.assertNoHeader('Content-Length') + self.assertNoHeader('Content-Disposition') + self.assertBody('') + + def test_755_vhost(self): + self.getPage('/test/', [('Host', 'virt.net')]) + self.assertStatus(200) + self.getPage('/test', [('Host', 'virt.net')]) + self.assertStatus(301) + self.assertHeader('Location', self.scheme + '://virt.net/test/') + + def test_serve_fileobj(self): + self.getPage('/fileobj') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/css;charset=utf-8') + self.assertMatchesBody('^Dummy stylesheet') + + def test_serve_bytesio(self): + self.getPage('/bytesio') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/plain;charset=utf-8') + self.assertHeader('Content-Length', 14) + self.assertMatchesBody('Fee\nfie\nfo\nfum') + + @pytest.mark.xfail(reason='#1475') + def test_file_stream(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + # Make an initial request + self.persistent = True + conn = self.HTTP_CONN + conn.putrequest('GET', '/bigfile', skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.assertEqual(response.status, 200) + + body = b'' + remaining = BIGFILE_SIZE + while remaining > 0: + data = response.fp.read(65536) + if not data: + break + body += data + remaining -= len(data) + + if self.scheme == 'https': + newconn = HTTPSConnection + else: + newconn = HTTPConnection + s, h, b = helper.webtest.openURL( + b'/tell', headers=[], host=self.HOST, port=self.PORT, + http_conn=newconn) + if not b: + # The file was closed on the server. + tell_position = BIGFILE_SIZE + else: + tell_position = int(b) + + read_so_far = len(body) + + # It is difficult for us to force the server to only read + # the bytes that we ask for - there are going to be buffers + # inbetween. + # + # CherryPy will attempt to write as much data as it can to + # the socket, and we don't have a way to determine what that + # size will be. So we make the following assumption - by + # the time we have read in the entire file on the server, + # we will have at least received half of it. If this is not + # the case, then this is an indicator that either: + # - machines that are running this test are using buffer + # sizes greater than half of BIGFILE_SIZE; or + # - streaming is broken. 
+ # + # At the time of writing, we seem to have encountered + # buffer sizes bigger than 512K, so we've increased + # BIGFILE_SIZE to 4MB and in 2016 to 20MB and then 32MB. + # This test is going to keep failing according to the + # improvements in hardware and OS buffers. + if tell_position >= BIGFILE_SIZE: + if read_so_far < (BIGFILE_SIZE / 2): + self.fail( + 'The file should have advanced to position %r, but ' + 'has already advanced to the end of the file. It ' + 'may not be streamed as intended, or at the wrong ' + 'chunk size (64k)' % read_so_far) + elif tell_position < read_so_far: + self.fail( + 'The file should have advanced to position %r, but has ' + 'only advanced to position %r. It may not be streamed ' + 'as intended, or at the wrong chunk size (64k)' % + (read_so_far, tell_position)) + + if body != b'x' * BIGFILE_SIZE: + self.fail("Body != 'x' * %d. Got %r instead (%d bytes)." % + (BIGFILE_SIZE, body[:50], len(body))) + conn.close() + + def test_file_stream_deadlock(self): + if cherrypy.server.protocol_version != 'HTTP/1.1': + return self.skip() + + self.PROTOCOL = 'HTTP/1.1' + + # Make an initial request but abort early. + self.persistent = True + conn = self.HTTP_CONN + conn.putrequest('GET', '/bigfile', skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + response = conn.response_class(conn.sock, method='GET') + response.begin() + self.assertEqual(response.status, 200) + body = response.fp.read(65536) + if body != b'x' * len(body): + self.fail("Body != 'x' * %d. Got %r instead (%d bytes)." % + (65536, body[:50], len(body))) + response.close() + conn.close() + + # Make a second request, which should fetch the whole file. + self.persistent = False + self.getPage('/bigfile') + if self.body != b'x' * BIGFILE_SIZE: + self.fail("Body != 'x' * %d. Got %r instead (%d bytes)." 
% + (BIGFILE_SIZE, self.body[:50], len(body))) + + def test_error_page_with_serve_file(self): + self.getPage('/404test/yunyeen') + self.assertStatus(404) + self.assertInBody("I couldn't find that thing") + + def test_null_bytes(self): + self.getPage('/static/\x00') + self.assertStatus('404 Not Found') + + @staticmethod + @contextlib.contextmanager + def unicode_file(): + filename = ntou('Слава Україні.html', 'utf-8') + filepath = os.path.join(curdir, 'static', filename) + with io.open(filepath, 'w', encoding='utf-8') as strm: + strm.write(ntou('Героям Слава!', 'utf-8')) + try: + yield + finally: + os.remove(filepath) + + py27_on_windows = ( + platform.system() == 'Windows' and + sys.version_info < (3,) + ) + @pytest.mark.xfail(py27_on_windows, reason='#1544') # noqa: E301 + def test_unicode(self): + ensure_unicode_filesystem() + with self.unicode_file(): + url = ntou('/static/Слава Україні.html', 'utf-8') + # quote function requires str + url = tonative(url, 'utf-8') + url = urllib.parse.quote(url) + self.getPage(url) + + expected = ntou('Героям Слава!', 'utf-8') + self.assertInBody(expected) + + +def error_page_404(status, message, traceback, version): + path = os.path.join(curdir, 'static', '404.html') + return static.serve_file(path, content_type='text/html') diff --git a/libraries/cherrypy/test/test_tools.py b/libraries/cherrypy/test/test_tools.py new file mode 100644 index 00000000..a73a3898 --- /dev/null +++ b/libraries/cherrypy/test/test_tools.py @@ -0,0 +1,468 @@ +"""Test the various means of instantiating and invoking tools.""" + +import gzip +import io +import sys +import time +import types +import unittest +import operator + +import six +from six.moves import range, map +from six.moves.http_client import IncompleteRead + +import cherrypy +from cherrypy import tools +from cherrypy._cpcompat import ntou +from cherrypy.test import helper, _test_decorators + + +timeout = 0.2 +europoundUnicode = ntou('\x80\xa3') + + +# Client-side code # + + +class ToolTests(helper.CPWebCase): + + @staticmethod + def setup_server(): + + # Put check_access in a custom toolbox with its own namespace + myauthtools = cherrypy._cptools.Toolbox('myauth') + + def check_access(default=False): + if not getattr(cherrypy.request, 'userid', default): + raise cherrypy.HTTPError(401) + myauthtools.check_access = cherrypy.Tool( + 'before_request_body', check_access) + + def numerify(): + def number_it(body): + for chunk in body: + for k, v in cherrypy.request.numerify_map: + chunk = chunk.replace(k, v) + yield chunk + cherrypy.response.body = number_it(cherrypy.response.body) + + class NumTool(cherrypy.Tool): + + def _setup(self): + def makemap(): + m = self._merged_args().get('map', {}) + cherrypy.request.numerify_map = list(six.iteritems(m)) + cherrypy.request.hooks.attach('on_start_resource', makemap) + + def critical(): + cherrypy.request.error_response = cherrypy.HTTPError( + 502).set_response + critical.failsafe = True + + cherrypy.request.hooks.attach('on_start_resource', critical) + cherrypy.request.hooks.attach(self._point, self.callable) + + tools.numerify = NumTool('before_finalize', numerify) + + # It's not mandatory to inherit from cherrypy.Tool. 
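+ # Any object that defines _setup() to attach its hooks (and carries a _name) can act as a tool.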
+ class NadsatTool: + + def __init__(self): + self.ended = {} + self._name = 'nadsat' + + def nadsat(self): + def nadsat_it_up(body): + for chunk in body: + chunk = chunk.replace(b'good', b'horrorshow') + chunk = chunk.replace(b'piece', b'lomtick') + yield chunk + cherrypy.response.body = nadsat_it_up(cherrypy.response.body) + nadsat.priority = 0 + + def cleanup(self): + # This runs after the request has been completely written out. + cherrypy.response.body = [b'razdrez'] + id = cherrypy.request.params.get('id') + if id: + self.ended[id] = True + cleanup.failsafe = True + + def _setup(self): + cherrypy.request.hooks.attach('before_finalize', self.nadsat) + cherrypy.request.hooks.attach('on_end_request', self.cleanup) + tools.nadsat = NadsatTool() + + def pipe_body(): + cherrypy.request.process_request_body = False + clen = int(cherrypy.request.headers['Content-Length']) + cherrypy.request.body = cherrypy.request.rfile.read(clen) + + # Assert that we can use a callable object instead of a function. + class Rotator(object): + + def __call__(self, scale): + r = cherrypy.response + r.collapse_body() + if six.PY3: + r.body = [bytes([(x + scale) % 256 for x in r.body[0]])] + else: + r.body = [chr((ord(x) + scale) % 256) for x in r.body[0]] + cherrypy.tools.rotator = cherrypy.Tool('before_finalize', Rotator()) + + def stream_handler(next_handler, *args, **kwargs): + actual = cherrypy.request.config.get('tools.streamer.arg') + assert actual == 'arg value' + cherrypy.response.output = o = io.BytesIO() + try: + next_handler(*args, **kwargs) + # Ignore the response and return our accumulated output + # instead. + return o.getvalue() + finally: + o.close() + cherrypy.tools.streamer = cherrypy._cptools.HandlerWrapperTool( + stream_handler) + + class Root: + + @cherrypy.expose + def index(self): + return 'Howdy earth!' + + @cherrypy.expose + @cherrypy.config(**{ + 'tools.streamer.on': True, + 'tools.streamer.arg': 'arg value', + }) + def tarfile(self): + actual = cherrypy.request.config.get('tools.streamer.arg') + assert actual == 'arg value' + cherrypy.response.output.write(b'I am ') + cherrypy.response.output.write(b'a tarfile') + + @cherrypy.expose + def euro(self): + hooks = list(cherrypy.request.hooks['before_finalize']) + hooks.sort() + cbnames = [x.callback.__name__ for x in hooks] + assert cbnames == ['gzip'], cbnames + priorities = [x.priority for x in hooks] + assert priorities == [80], priorities + yield ntou('Hello,') + yield ntou('world') + yield europoundUnicode + + # Bare hooks + @cherrypy.expose + @cherrypy.config(**{'hooks.before_request_body': pipe_body}) + def pipe(self): + return cherrypy.request.body + + # Multiple decorators; include kwargs just for fun. + # Note that rotator must run before gzip. + @cherrypy.expose + def decorated_euro(self, *vpath): + yield ntou('Hello,') + yield ntou('world') + yield europoundUnicode + decorated_euro = tools.gzip(compress_level=6)(decorated_euro) + decorated_euro = tools.rotator(scale=3)(decorated_euro) + + root = Root() + + class TestType(type): + """Metaclass which automatically exposes all functions in each + subclass, and adds an instance of the subclass as an attribute + of root. 
+ """ + def __init__(cls, name, bases, dct): + type.__init__(cls, name, bases, dct) + for value in six.itervalues(dct): + if isinstance(value, types.FunctionType): + cherrypy.expose(value) + setattr(root, name.lower(), cls()) + Test = TestType('Test', (object,), {}) + + # METHOD ONE: + # Declare Tools in _cp_config + @cherrypy.config(**{'tools.nadsat.on': True}) + class Demo(Test): + + def index(self, id=None): + return 'A good piece of cherry pie' + + def ended(self, id): + return repr(tools.nadsat.ended[id]) + + def err(self, id=None): + raise ValueError() + + def errinstream(self, id=None): + yield 'nonconfidential' + raise ValueError() + yield 'confidential' + + # METHOD TWO: decorator using Tool() + # We support Python 2.3, but the @-deco syntax would look like + # this: + # @tools.check_access() + def restricted(self): + return 'Welcome!' + restricted = myauthtools.check_access()(restricted) + userid = restricted + + def err_in_onstart(self): + return 'success!' + + @cherrypy.config(**{'response.stream': True}) + def stream(self, id=None): + for x in range(100000000): + yield str(x) + + conf = { + # METHOD THREE: + # Declare Tools in detached config + '/demo': { + 'tools.numerify.on': True, + 'tools.numerify.map': {b'pie': b'3.14159'}, + }, + '/demo/restricted': { + 'request.show_tracebacks': False, + }, + '/demo/userid': { + 'request.show_tracebacks': False, + 'myauth.check_access.default': True, + }, + '/demo/errinstream': { + 'response.stream': True, + }, + '/demo/err_in_onstart': { + # Because this isn't a dict, on_start_resource will error. + 'tools.numerify.map': 'pie->3.14159' + }, + # Combined tools + '/euro': { + 'tools.gzip.on': True, + 'tools.encode.on': True, + }, + # Priority specified in config + '/decorated_euro/subpath': { + 'tools.gzip.priority': 10, + }, + # Handler wrappers + '/tarfile': {'tools.streamer.on': True} + } + app = cherrypy.tree.mount(root, config=conf) + app.request_class.namespaces['myauth'] = myauthtools + + root.tooldecs = _test_decorators.ToolExamples() + + def testHookErrors(self): + self.getPage('/demo/?id=1') + # If body is "razdrez", then on_end_request is being called too early. + self.assertBody('A horrorshow lomtick of cherry 3.14159') + # If this fails, then on_end_request isn't being called at all. + time.sleep(0.1) + self.getPage('/demo/ended/1') + self.assertBody('True') + + valerr = '\n raise ValueError()\nValueError' + self.getPage('/demo/err?id=3') + # If body is "razdrez", then on_end_request is being called too early. + self.assertErrorPage(502, pattern=valerr) + # If this fails, then on_end_request isn't being called at all. + time.sleep(0.1) + self.getPage('/demo/ended/3') + self.assertBody('True') + + # If body is "razdrez", then on_end_request is being called too early. + if (cherrypy.server.protocol_version == 'HTTP/1.0' or + getattr(cherrypy.server, 'using_apache', False)): + self.getPage('/demo/errinstream?id=5') + # Because this error is raised after the response body has + # started, the status should not change to an error status. + self.assertStatus('200 OK') + self.assertBody('nonconfidential') + else: + # Because this error is raised after the response body has + # started, and because it's chunked output, an error is raised by + # the HTTP client when it encounters incomplete output. + self.assertRaises((ValueError, IncompleteRead), self.getPage, + '/demo/errinstream?id=5') + # If this fails, then on_end_request isn't being called at all. 
+ time.sleep(0.1) + self.getPage('/demo/ended/5') + self.assertBody('True') + + # Test the "__call__" technique (compile-time decorator). + self.getPage('/demo/restricted') + self.assertErrorPage(401) + + # Test compile-time decorator with kwargs from config. + self.getPage('/demo/userid') + self.assertBody('Welcome!') + + def testEndRequestOnDrop(self): + old_timeout = None + try: + httpserver = cherrypy.server.httpserver + old_timeout = httpserver.timeout + except (AttributeError, IndexError): + return self.skip() + + try: + httpserver.timeout = timeout + + # Test that on_end_request is called even if the client drops. + self.persistent = True + try: + conn = self.HTTP_CONN + conn.putrequest('GET', '/demo/stream?id=9', skip_host=True) + conn.putheader('Host', self.HOST) + conn.endheaders() + # Skip the rest of the request and close the conn. This will + # cause the server's active socket to error, which *should* + # result in the request being aborted, and request.close being + # called all the way up the stack (including WSGI middleware), + # eventually calling our on_end_request hook. + finally: + self.persistent = False + time.sleep(timeout * 2) + # Test that the on_end_request hook was called. + self.getPage('/demo/ended/9') + self.assertBody('True') + finally: + if old_timeout is not None: + httpserver.timeout = old_timeout + + def testGuaranteedHooks(self): + # The 'critical' on_start_resource hook is 'failsafe' (guaranteed + # to run even if there are failures in other on_start methods). + # This is NOT true of the other hooks. + # Here, we have set up a failure in NumerifyTool.numerify_map, + # but our 'critical' hook should run and set the error to 502. + self.getPage('/demo/err_in_onstart') + self.assertErrorPage(502) + tmpl = "AttributeError: 'str' object has no attribute '{attr}'" + expected_msg = tmpl.format(attr='items' if six.PY3 else 'iteritems') + self.assertInBody(expected_msg) + + def testCombinedTools(self): + expectedResult = (ntou('Hello,world') + + europoundUnicode).encode('utf-8') + zbuf = io.BytesIO() + zfile = gzip.GzipFile(mode='wb', fileobj=zbuf, compresslevel=9) + zfile.write(expectedResult) + zfile.close() + + self.getPage('/euro', + headers=[ + ('Accept-Encoding', 'gzip'), + ('Accept-Charset', 'ISO-8859-1,utf-8;q=0.7,*;q=0.7')]) + self.assertInBody(zbuf.getvalue()[:3]) + + zbuf = io.BytesIO() + zfile = gzip.GzipFile(mode='wb', fileobj=zbuf, compresslevel=6) + zfile.write(expectedResult) + zfile.close() + + self.getPage('/decorated_euro', headers=[('Accept-Encoding', 'gzip')]) + self.assertInBody(zbuf.getvalue()[:3]) + + # This returns a different value because gzip's priority was + # lowered in conf, allowing the rotator to run after gzip. + # Of course, we don't want breakage in production apps, + # but it proves the priority was changed. 
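# --- Editor's sketch (not part of the patch): how the 'failsafe' and
# --- 'priority' attributes exercised by these tests are attached to a hook
# --- callback, mirroring NadsatTool.cleanup and nadsat above. The 'audit'
# --- name is invented for illustration.
import cherrypy

def audit():
    # Hypothetical bookkeeping hook. 'failsafe' means it still runs when an
    # earlier hook at the same point raises; a priority below the default of
    # 50 makes it run before the default-priority hooks at that point.
    cherrypy.request.audited = True

audit.failsafe = True
audit.priority = 10

# Attached per request, e.g. from a Tool's _setup, as NadsatTool does:
# cherrypy.request.hooks.attach('before_finalize', audit)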
+ self.getPage('/decorated_euro/subpath', + headers=[('Accept-Encoding', 'gzip')]) + if six.PY3: + self.assertInBody(bytes([(x + 3) % 256 for x in zbuf.getvalue()])) + else: + self.assertInBody(''.join([chr((ord(x) + 3) % 256) + for x in zbuf.getvalue()])) + + def testBareHooks(self): + content = 'bit of a pain in me gulliver' + self.getPage('/pipe', + headers=[('Content-Length', str(len(content))), + ('Content-Type', 'text/plain')], + method='POST', body=content) + self.assertBody(content) + + def testHandlerWrapperTool(self): + self.getPage('/tarfile') + self.assertBody('I am a tarfile') + + def testToolWithConfig(self): + if not sys.version_info >= (2, 5): + return self.skip('skipped (Python 2.5+ only)') + + self.getPage('/tooldecs/blah') + self.assertHeader('Content-Type', 'application/data') + + def testWarnToolOn(self): + # get + try: + cherrypy.tools.numerify.on + except AttributeError: + pass + else: + raise AssertionError('Tool.on did not error as it should have.') + + # set + try: + cherrypy.tools.numerify.on = True + except AttributeError: + pass + else: + raise AssertionError('Tool.on did not error as it should have.') + + def testDecorator(self): + @cherrypy.tools.register('on_start_resource') + def example(): + pass + self.assertTrue(isinstance(cherrypy.tools.example, cherrypy.Tool)) + self.assertEqual(cherrypy.tools.example._point, 'on_start_resource') + + @cherrypy.tools.register( # noqa: F811 + 'before_finalize', name='renamed', priority=60, + ) + def example(): + pass + self.assertTrue(isinstance(cherrypy.tools.renamed, cherrypy.Tool)) + self.assertEqual(cherrypy.tools.renamed._point, 'before_finalize') + self.assertEqual(cherrypy.tools.renamed._name, 'renamed') + self.assertEqual(cherrypy.tools.renamed._priority, 60) + + +class SessionAuthTest(unittest.TestCase): + + def test_login_screen_returns_bytes(self): + """ + login_screen must return bytes even if unicode parameters are passed. + Issue 1132 revealed that login_screen would return unicode if the + username and password were unicode. + """ + sa = cherrypy.lib.cptools.SessionAuth() + res = sa.login_screen(None, username=six.text_type('nobody'), + password=six.text_type('anypass')) + self.assertTrue(isinstance(res, bytes)) + + +class TestHooks: + def test_priorities(self): + """ + Hooks should sort by priority order. + """ + Hook = cherrypy._cprequest.Hook + hooks = [ + Hook(None, priority=48), + Hook(None), + Hook(None, priority=49), + ] + hooks.sort() + by_priority = operator.attrgetter('priority') + priorities = list(map(by_priority, hooks)) + assert priorities == [48, 49, 50] diff --git a/libraries/cherrypy/test/test_tutorials.py b/libraries/cherrypy/test/test_tutorials.py new file mode 100644 index 00000000..efa35b99 --- /dev/null +++ b/libraries/cherrypy/test/test_tutorials.py @@ -0,0 +1,210 @@ +import sys +import imp +import types +import importlib + +import six + +import cherrypy +from cherrypy.test import helper + + +class TutorialTest(helper.CPWebCase): + + @classmethod + def setup_server(cls): + """ + Mount something so the engine starts. + """ + class Dummy: + pass + cherrypy.tree.mount(Dummy()) + + @staticmethod + def load_module(name): + """ + Import or reload tutorial module as needed. + """ + target = 'cherrypy.tutorial.' 
+ name + if target in sys.modules: + module = imp.reload(sys.modules[target]) + else: + module = importlib.import_module(target) + return module + + @classmethod + def setup_tutorial(cls, name, root_name, config={}): + cherrypy.config.reset() + module = cls.load_module(name) + root = getattr(module, root_name) + conf = getattr(module, 'tutconf') + class_types = type, + if six.PY2: + class_types += types.ClassType, + if isinstance(root, class_types): + root = root() + cherrypy.tree.mount(root, config=conf) + cherrypy.config.update(config) + + def test01HelloWorld(self): + self.setup_tutorial('tut01_helloworld', 'HelloWorld') + self.getPage('/') + self.assertBody('Hello world!') + + def test02ExposeMethods(self): + self.setup_tutorial('tut02_expose_methods', 'HelloWorld') + self.getPage('/show_msg') + self.assertBody('Hello world!') + + def test03GetAndPost(self): + self.setup_tutorial('tut03_get_and_post', 'WelcomePage') + + # Try different GET queries + self.getPage('/greetUser?name=Bob') + self.assertBody("Hey Bob, what's up?") + + self.getPage('/greetUser') + self.assertBody('Please enter your name <a href="./">here</a>.') + + self.getPage('/greetUser?name=') + self.assertBody('No, really, enter your name <a href="./">here</a>.') + + # Try the same with POST + self.getPage('/greetUser', method='POST', body='name=Bob') + self.assertBody("Hey Bob, what's up?") + + self.getPage('/greetUser', method='POST', body='name=') + self.assertBody('No, really, enter your name <a href="./">here</a>.') + + def test04ComplexSite(self): + self.setup_tutorial('tut04_complex_site', 'root') + + msg = ''' + <p>Here are some extra useful links:</p> + + <ul> + <li><a href="http://del.icio.us">del.icio.us</a></li> + <li><a href="http://www.cherrypy.org">CherryPy</a></li> + </ul> + + <p>[<a href="../">Return to links page</a>]</p>''' + self.getPage('/links/extra/') + self.assertBody(msg) + + def test05DerivedObjects(self): + self.setup_tutorial('tut05_derived_objects', 'HomePage') + msg = ''' + <html> + <head> + <title>Another Page</title> + <head> + <body> + <h2>Another Page</h2> + + <p> + And this is the amazing second page! + </p> + + </body> + </html> + ''' + # the tutorial has some annoying spaces in otherwise blank lines + msg = msg.replace('</h2>\n\n', '</h2>\n \n') + msg = msg.replace('</p>\n\n', '</p>\n \n') + self.getPage('/another/') + self.assertBody(msg) + + def test06DefaultMethod(self): + self.setup_tutorial('tut06_default_method', 'UsersPage') + self.getPage('/hendrik') + self.assertBody('Hendrik Mans, CherryPy co-developer & crazy German ' + '(<a href="./">back</a>)') + + def test07Sessions(self): + self.setup_tutorial('tut07_sessions', 'HitCounter') + + self.getPage('/') + self.assertBody( + "\n During your current session, you've viewed this" + '\n page 1 times! Your life is a patio of fun!' + '\n ') + + self.getPage('/', self.cookies) + self.assertBody( + "\n During your current session, you've viewed this" + '\n page 2 times! Your life is a patio of fun!' 
+ '\n ') + + def test08GeneratorsAndYield(self): + self.setup_tutorial('tut08_generators_and_yield', 'GeneratorDemo') + self.getPage('/') + self.assertBody('<html><body><h2>Generators rule!</h2>' + '<h3>List of users:</h3>' + 'Remi<br/>Carlos<br/>Hendrik<br/>Lorenzo Lamas<br/>' + '</body></html>') + + def test09Files(self): + self.setup_tutorial('tut09_files', 'FileDemo') + + # Test upload + filesize = 5 + h = [('Content-type', 'multipart/form-data; boundary=x'), + ('Content-Length', str(105 + filesize))] + b = ('--x\n' + 'Content-Disposition: form-data; name="myFile"; ' + 'filename="hello.txt"\r\n' + 'Content-Type: text/plain\r\n' + '\r\n') + b += 'a' * filesize + '\n' + '--x--\n' + self.getPage('/upload', h, 'POST', b) + self.assertBody('''<html> + <body> + myFile length: %d<br /> + myFile filename: hello.txt<br /> + myFile mime-type: text/plain + </body> + </html>''' % filesize) + + # Test download + self.getPage('/download') + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'application/x-download') + self.assertHeader('Content-Disposition', + # Make sure the filename is quoted. + 'attachment; filename="pdf_file.pdf"') + self.assertEqual(len(self.body), 85698) + + def test10HTTPErrors(self): + self.setup_tutorial('tut10_http_errors', 'HTTPErrorDemo') + + @cherrypy.expose + def traceback_setting(): + return repr(cherrypy.request.show_tracebacks) + cherrypy.tree.mount(traceback_setting, '/traceback_setting') + + self.getPage('/') + self.assertInBody("""<a href="toggleTracebacks">""") + self.assertInBody("""<a href="/doesNotExist">""") + self.assertInBody("""<a href="/error?code=403">""") + self.assertInBody("""<a href="/error?code=500">""") + self.assertInBody("""<a href="/messageArg">""") + + self.getPage('/traceback_setting') + setting = self.body + self.getPage('/toggleTracebacks') + self.assertStatus((302, 303)) + self.getPage('/traceback_setting') + self.assertBody(str(not eval(setting))) + + self.getPage('/error?code=500') + self.assertStatus(500) + self.assertInBody('The server encountered an unexpected condition ' + 'which prevented it from fulfilling the request.') + + self.getPage('/error?code=403') + self.assertStatus(403) + self.assertInBody("<h2>You can't do that!</h2>") + + self.getPage('/messageArg') + self.assertStatus(500) + self.assertInBody("If you construct an HTTPError with a 'message'") diff --git a/libraries/cherrypy/test/test_virtualhost.py b/libraries/cherrypy/test/test_virtualhost.py new file mode 100644 index 00000000..de88f927 --- /dev/null +++ b/libraries/cherrypy/test/test_virtualhost.py @@ -0,0 +1,113 @@ +import os + +import cherrypy +from cherrypy.test import helper + +curdir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +class VirtualHostTest(helper.CPWebCase): + + @staticmethod + def setup_server(): + class Root: + + @cherrypy.expose + def index(self): + return 'Hello, world' + + @cherrypy.expose + def dom4(self): + return 'Under construction' + + @cherrypy.expose + def method(self, value): + return 'You sent %s' % value + + class VHost: + + def __init__(self, sitename): + self.sitename = sitename + + @cherrypy.expose + def index(self): + return 'Welcome to %s' % self.sitename + + @cherrypy.expose + def vmethod(self, value): + return 'You sent %s' % value + + @cherrypy.expose + def url(self): + return cherrypy.url('nextpage') + + # Test static as a handler (section must NOT include vhost prefix) + static = cherrypy.tools.staticdir.handler( + section='/static', dir=curdir) + + root = Root() + root.mydom2 = VHost('Domain 2') + 
root.mydom3 = VHost('Domain 3') + hostmap = {'www.mydom2.com': '/mydom2', + 'www.mydom3.com': '/mydom3', + 'www.mydom4.com': '/dom4', + } + cherrypy.tree.mount(root, config={ + '/': { + 'request.dispatch': cherrypy.dispatch.VirtualHost(**hostmap) + }, + # Test static in config (section must include vhost prefix) + '/mydom2/static2': { + 'tools.staticdir.on': True, + 'tools.staticdir.root': curdir, + 'tools.staticdir.dir': 'static', + 'tools.staticdir.index': 'index.html', + }, + }) + + def testVirtualHost(self): + self.getPage('/', [('Host', 'www.mydom1.com')]) + self.assertBody('Hello, world') + self.getPage('/mydom2/', [('Host', 'www.mydom1.com')]) + self.assertBody('Welcome to Domain 2') + + self.getPage('/', [('Host', 'www.mydom2.com')]) + self.assertBody('Welcome to Domain 2') + self.getPage('/', [('Host', 'www.mydom3.com')]) + self.assertBody('Welcome to Domain 3') + self.getPage('/', [('Host', 'www.mydom4.com')]) + self.assertBody('Under construction') + + # Test GET, POST, and positional params + self.getPage('/method?value=root') + self.assertBody('You sent root') + self.getPage('/vmethod?value=dom2+GET', [('Host', 'www.mydom2.com')]) + self.assertBody('You sent dom2 GET') + self.getPage('/vmethod', [('Host', 'www.mydom3.com')], method='POST', + body='value=dom3+POST') + self.assertBody('You sent dom3 POST') + self.getPage('/vmethod/pos', [('Host', 'www.mydom3.com')]) + self.assertBody('You sent pos') + + # Test that cherrypy.url uses the browser url, not the virtual url + self.getPage('/url', [('Host', 'www.mydom2.com')]) + self.assertBody('%s://www.mydom2.com/nextpage' % self.scheme) + + def test_VHost_plus_Static(self): + # Test static as a handler + self.getPage('/static/style.css', [('Host', 'www.mydom2.com')]) + self.assertStatus('200 OK') + self.assertHeader('Content-Type', 'text/css;charset=utf-8') + + # Test static in config + self.getPage('/static2/dirback.jpg', [('Host', 'www.mydom2.com')]) + self.assertStatus('200 OK') + self.assertHeaderIn('Content-Type', ['image/jpeg', 'image/pjpeg']) + + # Test static config with "index" arg + self.getPage('/static2/', [('Host', 'www.mydom2.com')]) + self.assertStatus('200 OK') + self.assertBody('Hello, world\r\n') + # Since tools.trailing_slash is on by default, this should redirect + self.getPage('/static2', [('Host', 'www.mydom2.com')]) + self.assertStatus(301) diff --git a/libraries/cherrypy/test/test_wsgi_ns.py b/libraries/cherrypy/test/test_wsgi_ns.py new file mode 100644 index 00000000..3545724c --- /dev/null +++ b/libraries/cherrypy/test/test_wsgi_ns.py @@ -0,0 +1,93 @@ +import cherrypy +from cherrypy.test import helper + + +class WSGI_Namespace_Test(helper.CPWebCase): + + @staticmethod + def setup_server(): + + class WSGIResponse(object): + + def __init__(self, appresults): + self.appresults = appresults + self.iter = iter(appresults) + + def __iter__(self): + return self + + def next(self): + return self.iter.next() + + def __next__(self): + return next(self.iter) + + def close(self): + if hasattr(self.appresults, 'close'): + self.appresults.close() + + class ChangeCase(object): + + def __init__(self, app, to=None): + self.app = app + self.to = to + + def __call__(self, environ, start_response): + res = self.app(environ, start_response) + + class CaseResults(WSGIResponse): + + def next(this): + return getattr(this.iter.next(), self.to)() + + def __next__(this): + return getattr(next(this.iter), self.to)() + return CaseResults(res) + + class Replacer(object): + + def __init__(self, app, map={}): + self.app = app + self.map 
= map + + def __call__(self, environ, start_response): + res = self.app(environ, start_response) + + class ReplaceResults(WSGIResponse): + + def next(this): + line = this.iter.next() + for k, v in self.map.iteritems(): + line = line.replace(k, v) + return line + + def __next__(this): + line = next(this.iter) + for k, v in self.map.items(): + line = line.replace(k, v) + return line + return ReplaceResults(res) + + class Root(object): + + @cherrypy.expose + def index(self): + return 'HellO WoRlD!' + + root_conf = {'wsgi.pipeline': [('replace', Replacer)], + 'wsgi.replace.map': {b'L': b'X', + b'l': b'r'}, + } + + app = cherrypy.Application(Root()) + app.wsgiapp.pipeline.append(('changecase', ChangeCase)) + app.wsgiapp.config['changecase'] = {'to': 'upper'} + cherrypy.tree.mount(app, config={'/': root_conf}) + + def test_pipeline(self): + if not cherrypy.server.httpserver: + return self.skip() + + self.getPage('/') + # If body is "HEXXO WORXD!", the middleware was applied out of order. + self.assertBody('HERRO WORRD!') diff --git a/libraries/cherrypy/test/test_wsgi_unix_socket.py b/libraries/cherrypy/test/test_wsgi_unix_socket.py new file mode 100644 index 00000000..8f1cc00b --- /dev/null +++ b/libraries/cherrypy/test/test_wsgi_unix_socket.py @@ -0,0 +1,93 @@ +import os +import socket +import atexit +import tempfile + +from six.moves.http_client import HTTPConnection + +import pytest + +import cherrypy +from cherrypy.test import helper + + +def usocket_path(): + fd, path = tempfile.mkstemp('cp_test.sock') + os.close(fd) + os.remove(path) + return path + + +USOCKET_PATH = usocket_path() + + +class USocketHTTPConnection(HTTPConnection): + """ + HTTPConnection over a unix socket. + """ + + def __init__(self, path): + HTTPConnection.__init__(self, 'localhost') + self.path = path + + def __call__(self, *args, **kwargs): + """ + Catch-all method just to present itself as a constructor for the + HTTPConnection. + """ + return self + + def connect(self): + """ + Override the connect method and assign a unix socket as a transport. + """ + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(self.path) + self.sock = sock + atexit.register(lambda: os.remove(self.path)) + + +@pytest.mark.skipif("sys.platform == 'win32'") +class WSGI_UnixSocket_Test(helper.CPWebCase): + """ + Test basic behavior on a cherrypy wsgi server listening + on a unix socket. + + It exercises the config option `server.socket_file`. 
+ """ + HTTP_CONN = USocketHTTPConnection(USOCKET_PATH) + + @staticmethod + def setup_server(): + class Root(object): + + @cherrypy.expose + def index(self): + return 'Test OK' + + @cherrypy.expose + def error(self): + raise Exception('Invalid page') + + config = { + 'server.socket_file': USOCKET_PATH + } + cherrypy.config.update(config) + cherrypy.tree.mount(Root()) + + def tearDown(self): + cherrypy.config.update({'server.socket_file': None}) + + def test_simple_request(self): + self.getPage('/') + self.assertStatus('200 OK') + self.assertInBody('Test OK') + + def test_not_found(self): + self.getPage('/invalid_path') + self.assertStatus('404 Not Found') + + def test_internal_error(self): + self.getPage('/error') + self.assertStatus('500 Internal Server Error') + self.assertInBody('Invalid page') diff --git a/libraries/cherrypy/test/test_wsgi_vhost.py b/libraries/cherrypy/test/test_wsgi_vhost.py new file mode 100644 index 00000000..2b6e5ba9 --- /dev/null +++ b/libraries/cherrypy/test/test_wsgi_vhost.py @@ -0,0 +1,35 @@ +import cherrypy +from cherrypy.test import helper + + +class WSGI_VirtualHost_Test(helper.CPWebCase): + + @staticmethod + def setup_server(): + + class ClassOfRoot(object): + + def __init__(self, name): + self.name = name + + @cherrypy.expose + def index(self): + return 'Welcome to the %s website!' % self.name + + default = cherrypy.Application(None) + + domains = {} + for year in range(1997, 2008): + app = cherrypy.Application(ClassOfRoot('Class of %s' % year)) + domains['www.classof%s.example' % year] = app + + cherrypy.tree.graft(cherrypy._cpwsgi.VirtualHost(default, domains)) + + def test_welcome(self): + if not cherrypy.server.using_wsgi: + return self.skip('skipped (not using WSGI)... ') + + for year in range(1997, 2008): + self.getPage( + '/', headers=[('Host', 'www.classof%s.example' % year)]) + self.assertBody('Welcome to the Class of %s website!' 
% year) diff --git a/libraries/cherrypy/test/test_wsgiapps.py b/libraries/cherrypy/test/test_wsgiapps.py new file mode 100644 index 00000000..1b3bf28f --- /dev/null +++ b/libraries/cherrypy/test/test_wsgiapps.py @@ -0,0 +1,120 @@ +import sys + +import cherrypy +from cherrypy._cpcompat import ntob +from cherrypy.test import helper + + +class WSGIGraftTests(helper.CPWebCase): + + @staticmethod + def setup_server(): + + def test_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type', 'text/plain')] + start_response(status, response_headers) + output = ['Hello, world!\n', + 'This is a wsgi app running within CherryPy!\n\n'] + keys = list(environ.keys()) + keys.sort() + for k in keys: + output.append('%s: %s\n' % (k, environ[k])) + return [ntob(x, 'utf-8') for x in output] + + def test_empty_string_app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type', 'text/plain')] + start_response(status, response_headers) + return [ + b'Hello', b'', b' ', b'', b'world', + ] + + class WSGIResponse(object): + + def __init__(self, appresults): + self.appresults = appresults + self.iter = iter(appresults) + + def __iter__(self): + return self + + if sys.version_info >= (3, 0): + def __next__(self): + return next(self.iter) + else: + def next(self): + return self.iter.next() + + def close(self): + if hasattr(self.appresults, 'close'): + self.appresults.close() + + class ReversingMiddleware(object): + + def __init__(self, app): + self.app = app + + def __call__(self, environ, start_response): + results = app(environ, start_response) + + class Reverser(WSGIResponse): + + if sys.version_info >= (3, 0): + def __next__(this): + line = list(next(this.iter)) + line.reverse() + return bytes(line) + else: + def next(this): + line = list(this.iter.next()) + line.reverse() + return ''.join(line) + + return Reverser(results) + + class Root: + + @cherrypy.expose + def index(self): + return ntob("I'm a regular CherryPy page handler!") + + cherrypy.tree.mount(Root()) + + cherrypy.tree.graft(test_app, '/hosted/app1') + cherrypy.tree.graft(test_empty_string_app, '/hosted/app3') + + # Set script_name explicitly to None to signal CP that it should + # be pulled from the WSGI environ each time. + app = cherrypy.Application(Root(), script_name=None) + cherrypy.tree.graft(ReversingMiddleware(app), '/hosted/app2') + + wsgi_output = '''Hello, world! +This is a wsgi app running within CherryPy!''' + + def test_01_standard_app(self): + self.getPage('/') + self.assertBody("I'm a regular CherryPy page handler!") + + def test_04_pure_wsgi(self): + if not cherrypy.server.using_wsgi: + return self.skip('skipped (not using WSGI)... ') + self.getPage('/hosted/app1') + self.assertHeader('Content-Type', 'text/plain') + self.assertInBody(self.wsgi_output) + + def test_05_wrapped_cp_app(self): + if not cherrypy.server.using_wsgi: + return self.skip('skipped (not using WSGI)... ') + self.getPage('/hosted/app2/') + body = list("I'm a regular CherryPy page handler!") + body.reverse() + body = ''.join(body) + self.assertInBody(body) + + def test_06_empty_string_app(self): + if not cherrypy.server.using_wsgi: + return self.skip('skipped (not using WSGI)... 
') + self.getPage('/hosted/app3') + self.assertHeader('Content-Type', 'text/plain') + self.assertInBody('Hello world') diff --git a/libraries/cherrypy/test/test_xmlrpc.py b/libraries/cherrypy/test/test_xmlrpc.py new file mode 100644 index 00000000..ad93b821 --- /dev/null +++ b/libraries/cherrypy/test/test_xmlrpc.py @@ -0,0 +1,183 @@ +import sys + +import six + +from six.moves.xmlrpc_client import ( + DateTime, Fault, + ProtocolError, ServerProxy, SafeTransport +) + +import cherrypy +from cherrypy import _cptools +from cherrypy.test import helper + +if six.PY3: + HTTPSTransport = SafeTransport + + # Python 3.0's SafeTransport still mistakenly checks for socket.ssl + import socket + if not hasattr(socket, 'ssl'): + socket.ssl = True +else: + class HTTPSTransport(SafeTransport): + + """Subclass of SafeTransport to fix sock.recv errors (by using file). + """ + + def request(self, host, handler, request_body, verbose=0): + # issue XML-RPC request + h = self.make_connection(host) + if verbose: + h.set_debuglevel(1) + + self.send_request(h, handler, request_body) + self.send_host(h, host) + self.send_user_agent(h) + self.send_content(h, request_body) + + errcode, errmsg, headers = h.getreply() + if errcode != 200: + raise ProtocolError(host + handler, errcode, errmsg, headers) + + self.verbose = verbose + + # Here's where we differ from the superclass. It says: + # try: + # sock = h._conn.sock + # except AttributeError: + # sock = None + # return self._parse_response(h.getfile(), sock) + + return self.parse_response(h.getfile()) + + +def setup_server(): + + class Root: + + @cherrypy.expose + def index(self): + return "I'm a standard index!" + + class XmlRpc(_cptools.XMLRPCController): + + @cherrypy.expose + def foo(self): + return 'Hello world!' + + @cherrypy.expose + def return_single_item_list(self): + return [42] + + @cherrypy.expose + def return_string(self): + return 'here is a string' + + @cherrypy.expose + def return_tuple(self): + return ('here', 'is', 1, 'tuple') + + @cherrypy.expose + def return_dict(self): + return dict(a=1, b=2, c=3) + + @cherrypy.expose + def return_composite(self): + return dict(a=1, z=26), 'hi', ['welcome', 'friend'] + + @cherrypy.expose + def return_int(self): + return 42 + + @cherrypy.expose + def return_float(self): + return 3.14 + + @cherrypy.expose + def return_datetime(self): + return DateTime((2003, 10, 7, 8, 1, 0, 1, 280, -1)) + + @cherrypy.expose + def return_boolean(self): + return True + + @cherrypy.expose + def test_argument_passing(self, num): + return num * 2 + + @cherrypy.expose + def test_returning_Fault(self): + return Fault(1, 'custom Fault response') + + root = Root() + root.xmlrpc = XmlRpc() + cherrypy.tree.mount(root, config={'/': { + 'request.dispatch': cherrypy.dispatch.XMLRPCDispatcher(), + 'tools.xmlrpc.allow_none': 0, + }}) + + +class XmlRpcTest(helper.CPWebCase): + setup_server = staticmethod(setup_server) + + def testXmlRpc(self): + + scheme = self.scheme + if scheme == 'https': + url = 'https://%s:%s/xmlrpc/' % (self.interface(), self.PORT) + proxy = ServerProxy(url, transport=HTTPSTransport()) + else: + url = 'http://%s:%s/xmlrpc/' % (self.interface(), self.PORT) + proxy = ServerProxy(url) + + # begin the tests ... 
+ self.getPage('/xmlrpc/foo') + self.assertBody('Hello world!') + + self.assertEqual(proxy.return_single_item_list(), [42]) + self.assertNotEqual(proxy.return_single_item_list(), 'one bazillion') + self.assertEqual(proxy.return_string(), 'here is a string') + self.assertEqual(proxy.return_tuple(), + list(('here', 'is', 1, 'tuple'))) + self.assertEqual(proxy.return_dict(), {'a': 1, 'c': 3, 'b': 2}) + self.assertEqual(proxy.return_composite(), + [{'a': 1, 'z': 26}, 'hi', ['welcome', 'friend']]) + self.assertEqual(proxy.return_int(), 42) + self.assertEqual(proxy.return_float(), 3.14) + self.assertEqual(proxy.return_datetime(), + DateTime((2003, 10, 7, 8, 1, 0, 1, 280, -1))) + self.assertEqual(proxy.return_boolean(), True) + self.assertEqual(proxy.test_argument_passing(22), 22 * 2) + + # Test an error in the page handler (should raise an xmlrpclib.Fault) + try: + proxy.test_argument_passing({}) + except Exception: + x = sys.exc_info()[1] + self.assertEqual(x.__class__, Fault) + self.assertEqual(x.faultString, ('unsupported operand type(s) ' + "for *: 'dict' and 'int'")) + else: + self.fail('Expected xmlrpclib.Fault') + + # https://github.com/cherrypy/cherrypy/issues/533 + # if a method is not found, an xmlrpclib.Fault should be raised + try: + proxy.non_method() + except Exception: + x = sys.exc_info()[1] + self.assertEqual(x.__class__, Fault) + self.assertEqual(x.faultString, + 'method "non_method" is not supported') + else: + self.fail('Expected xmlrpclib.Fault') + + # Test returning a Fault from the page handler. + try: + proxy.test_returning_Fault() + except Exception: + x = sys.exc_info()[1] + self.assertEqual(x.__class__, Fault) + self.assertEqual(x.faultString, ('custom Fault response')) + else: + self.fail('Expected xmlrpclib.Fault') diff --git a/libraries/cherrypy/test/webtest.py b/libraries/cherrypy/test/webtest.py new file mode 100644 index 00000000..9fb6ce62 --- /dev/null +++ b/libraries/cherrypy/test/webtest.py @@ -0,0 +1,11 @@ +# for compatibility, expose cheroot webtest here +import warnings + +from cheroot.test.webtest import ( # noqa + interface, + WebCase, cleanHeaders, shb, openURL, + ServerError, server_error, +) + + +warnings.warn('Use cheroot.test.webtest', DeprecationWarning) diff --git a/libraries/cherrypy/tutorial/README.rst b/libraries/cherrypy/tutorial/README.rst new file mode 100644 index 00000000..c47e7d32 --- /dev/null +++ b/libraries/cherrypy/tutorial/README.rst @@ -0,0 +1,16 @@ +CherryPy Tutorials +------------------ + +This is a series of tutorials explaining how to develop dynamic web +applications using CherryPy. A couple of notes: + + +- Each of these tutorials builds on the ones before it. If you're + new to CherryPy, we recommend you start with 01_helloworld.py and + work your way upwards. :) + +- In most of these tutorials, you will notice that all output is done + by returning normal Python strings, often using simple Python + variable substitution. In most real-world applications, you will + probably want to use a separate template package (like Cheetah, + CherryTemplate or XML/XSL). 
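- As a rough sketch of that string-substitution style (the class and
  handler names here are invented for illustration; a real application
  would more likely render a template instead)::

      import cherrypy

      class Greeting:
          @cherrypy.expose
          def index(self, name='world'):
              # plain Python string substitution, no template engine
              return '<h1>Hello, %s!</h1>' % name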
diff --git a/libraries/cherrypy/tutorial/__init__.py b/libraries/cherrypy/tutorial/__init__.py new file mode 100644 index 00000000..08c142c5 --- /dev/null +++ b/libraries/cherrypy/tutorial/__init__.py @@ -0,0 +1,3 @@ + +# This is used in test_config to test unrepr of "from A import B" +thing2 = object() diff --git a/libraries/cherrypy/tutorial/custom_error.html b/libraries/cherrypy/tutorial/custom_error.html new file mode 100644 index 00000000..d0f30c8a --- /dev/null +++ b/libraries/cherrypy/tutorial/custom_error.html @@ -0,0 +1,14 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd"> +<html> +<head> + <title>403 Unauthorized</title> +</head> + <body> + <h2>You can't do that!</h2> + <p>%(message)s</p> + <p>This is a custom error page that is read from a file.<p> + <pre>%(traceback)s</pre> + </body> +</html> diff --git a/libraries/cherrypy/tutorial/pdf_file.pdf b/libraries/cherrypy/tutorial/pdf_file.pdf new file mode 100644 index 00000000..38b4f15e Binary files /dev/null and b/libraries/cherrypy/tutorial/pdf_file.pdf differ diff --git a/libraries/cherrypy/tutorial/tut01_helloworld.py b/libraries/cherrypy/tutorial/tut01_helloworld.py new file mode 100644 index 00000000..e86793c8 --- /dev/null +++ b/libraries/cherrypy/tutorial/tut01_helloworld.py @@ -0,0 +1,34 @@ +""" +Tutorial - Hello World + +The most basic (working) CherryPy application possible. +""" + +import os.path + +# Import CherryPy global namespace +import cherrypy + + +class HelloWorld: + + """ Sample request handler class. """ + + # Expose the index method through the web. CherryPy will never + # publish methods that don't have the exposed attribute set to True. + @cherrypy.expose + def index(self): + # CherryPy will call this method for the root URI ("/") and send + # its return value to the client. Because this is tutorial + # lesson number 01, we'll just send something really simple. + # How about... + return 'Hello world!' + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). + cherrypy.quickstart(HelloWorld(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut02_expose_methods.py b/libraries/cherrypy/tutorial/tut02_expose_methods.py new file mode 100644 index 00000000..8afbf7d8 --- /dev/null +++ b/libraries/cherrypy/tutorial/tut02_expose_methods.py @@ -0,0 +1,32 @@ +""" +Tutorial - Multiple methods + +This tutorial shows you how to link to other methods of your request +handler. +""" + +import os.path + +import cherrypy + + +class HelloWorld: + + @cherrypy.expose + def index(self): + # Let's link to another method here. + return 'We have an <a href="show_msg">important message</a> for you!' + + @cherrypy.expose + def show_msg(self): + # Here's the important message! + return 'Hello world!' + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). 
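# --- Editor's sketch (not part of the patch): exposure is opt-in, as the
# --- comments above note. Only callables whose 'exposed' attribute is True
# --- are published; @cherrypy.expose simply sets that attribute. The class
# --- and method names below are invented for illustration.
import cherrypy

class ExposureDemo:
    @cherrypy.expose
    def public(self):
        return 'reachable at /public'

    def helper(self):
        # not exposed -> a request to /helper returns 404
        return 'internal only'

    def legacy(self):
        return 'reachable at /legacy'
    legacy.exposed = True  # the attribute the decorator sets, spelled out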
+ cherrypy.quickstart(HelloWorld(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut03_get_and_post.py b/libraries/cherrypy/tutorial/tut03_get_and_post.py new file mode 100644 index 00000000..0b3d4613 --- /dev/null +++ b/libraries/cherrypy/tutorial/tut03_get_and_post.py @@ -0,0 +1,51 @@ +""" +Tutorial - Passing variables + +This tutorial shows you how to pass GET/POST variables to methods. +""" + +import os.path + +import cherrypy + + +class WelcomePage: + + @cherrypy.expose + def index(self): + # Ask for the user's name. + return ''' + <form action="greetUser" method="GET"> + What is your name? + <input type="text" name="name" /> + <input type="submit" /> + </form>''' + + @cherrypy.expose + def greetUser(self, name=None): + # CherryPy passes all GET and POST variables as method parameters. + # It doesn't make a difference where the variables come from, how + # large their contents are, and so on. + # + # You can define default parameter values as usual. In this + # example, the "name" parameter defaults to None so we can check + # if a name was actually specified. + + if name: + # Greet the user! + return "Hey %s, what's up?" % name + else: + if name is None: + # No name was specified + return 'Please enter your name <a href="./">here</a>.' + else: + return 'No, really, enter your name <a href="./">here</a>.' + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). + cherrypy.quickstart(WelcomePage(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut04_complex_site.py b/libraries/cherrypy/tutorial/tut04_complex_site.py new file mode 100644 index 00000000..3caa1775 --- /dev/null +++ b/libraries/cherrypy/tutorial/tut04_complex_site.py @@ -0,0 +1,103 @@ +""" +Tutorial - Multiple objects + +This tutorial shows you how to create a site structure through multiple +possibly nested request handler objects. +""" + +import os.path + +import cherrypy + + +class HomePage: + + @cherrypy.expose + def index(self): + return ''' + <p>Hi, this is the home page! Check out the other + fun stuff on this site:</p> + + <ul> + <li><a href="/joke/">A silly joke</a></li> + <li><a href="/links/">Useful links</a></li> + </ul>''' + + +class JokePage: + + @cherrypy.expose + def index(self): + return ''' + <p>"In Python, how do you create a string of random + characters?" -- "Read a Perl file!"</p> + <p>[<a href="../">Return</a>]</p>''' + + +class LinksPage: + + def __init__(self): + # Request handler objects can create their own nested request + # handler objects. Simply create them inside their __init__ + # methods! + self.extra = ExtraLinksPage() + + @cherrypy.expose + def index(self): + # Note the way we link to the extra links page (and back). + # As you can see, this object doesn't really care about its + # absolute position in the site tree, since we use relative + # links exclusively. + return ''' + <p>Here are some useful links:</p> + + <ul> + <li> + <a href="http://www.cherrypy.org">The CherryPy Homepage</a> + </li> + <li> + <a href="http://www.python.org">The Python Homepage</a> + </li> + </ul> + + <p>You can check out some extra useful + links <a href="./extra/">here</a>.</p> + + <p>[<a href="../">Return</a>]</p> + ''' + + +class ExtraLinksPage: + + @cherrypy.expose + def index(self): + # Note the relative link back to the Links page! 
+ return ''' + <p>Here are some extra useful links:</p> + + <ul> + <li><a href="http://del.icio.us">del.icio.us</a></li> + <li><a href="http://www.cherrypy.org">CherryPy</a></li> + </ul> + + <p>[<a href="../">Return to links page</a>]</p>''' + + +# Of course we can also mount request handler objects right here! +root = HomePage() +root.joke = JokePage() +root.links = LinksPage() + +# Remember, we don't need to mount ExtraLinksPage here, because +# LinksPage does that itself on initialization. In fact, there is +# no reason why you shouldn't let your root object take care of +# creating all contained request handler objects. + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). + cherrypy.quickstart(root, config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut05_derived_objects.py b/libraries/cherrypy/tutorial/tut05_derived_objects.py new file mode 100644 index 00000000..f626e03f --- /dev/null +++ b/libraries/cherrypy/tutorial/tut05_derived_objects.py @@ -0,0 +1,80 @@ +""" +Tutorial - Object inheritance + +You are free to derive your request handler classes from any base +class you wish. In most real-world applications, you will probably +want to create a central base class used for all your pages, which takes +care of things like printing a common page header and footer. +""" + +import os.path + +import cherrypy + + +class Page: + # Store the page title in a class attribute + title = 'Untitled Page' + + def header(self): + return ''' + <html> + <head> + <title>%s</title> + <head> + <body> + <h2>%s</h2> + ''' % (self.title, self.title) + + def footer(self): + return ''' + </body> + </html> + ''' + + # Note that header and footer don't get their exposed attributes + # set to True. This isn't necessary since the user isn't supposed + # to call header or footer directly; instead, we'll call them from + # within the actually exposed handler methods defined in this + # class' subclasses. + + +class HomePage(Page): + # Different title for this page + title = 'Tutorial 5' + + def __init__(self): + # create a subpage + self.another = AnotherPage() + + @cherrypy.expose + def index(self): + # Note that we call the header and footer methods inherited + # from the Page class! + return self.header() + ''' + <p> + Isn't this exciting? There's + <a href="./another/">another page</a>, too! + </p> + ''' + self.footer() + + +class AnotherPage(Page): + title = 'Another Page' + + @cherrypy.expose + def index(self): + return self.header() + ''' + <p> + And this is the amazing second page! + </p> + ''' + self.footer() + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). 
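# --- Editor's note (not part of the patch): with the HomePage tree above,
# --- the default dispatcher walks attributes segment by segment, e.g.:
# ---   GET /          -> HomePage().index()
# ---   GET /another/  -> HomePage().another.index()
# --- which is exactly what test05DerivedObjects requests.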
+ cherrypy.quickstart(HomePage(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut06_default_method.py b/libraries/cherrypy/tutorial/tut06_default_method.py new file mode 100644 index 00000000..0ce4cabe --- /dev/null +++ b/libraries/cherrypy/tutorial/tut06_default_method.py @@ -0,0 +1,61 @@ +""" +Tutorial - The default method + +Request handler objects can implement a method called "default" that +is called when no other suitable method/object could be found. +Essentially, if CherryPy2 can't find a matching request handler object +for the given request URI, it will use the default method of the object +located deepest on the URI path. + +Using this mechanism you can easily simulate virtual URI structures +by parsing the extra URI string, which you can access through +cherrypy.request.virtualPath. + +The application in this tutorial simulates an URI structure looking +like /users/<username>. Since the <username> bit will not be found (as +there are no matching methods), it is handled by the default method. +""" + +import os.path + +import cherrypy + + +class UsersPage: + + @cherrypy.expose + def index(self): + # Since this is just a stupid little example, we'll simply + # display a list of links to random, made-up users. In a real + # application, this could be generated from a database result set. + return ''' + <a href="./remi">Remi Delon</a><br/> + <a href="./hendrik">Hendrik Mans</a><br/> + <a href="./lorenzo">Lorenzo Lamas</a><br/> + ''' + + @cherrypy.expose + def default(self, user): + # Here we react depending on the virtualPath -- the part of the + # path that could not be mapped to an object method. In a real + # application, we would probably do some database lookups here + # instead of the silly if/elif/else construct. + if user == 'remi': + out = 'Remi Delon, CherryPy lead developer' + elif user == 'hendrik': + out = 'Hendrik Mans, CherryPy co-developer & crazy German' + elif user == 'lorenzo': + out = 'Lorenzo Lamas, famous actor and singer!' + else: + out = 'Unknown user. :-(' + + return '%s (<a href="./">back</a>)' % out + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). + cherrypy.quickstart(UsersPage(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut07_sessions.py b/libraries/cherrypy/tutorial/tut07_sessions.py new file mode 100644 index 00000000..204322b5 --- /dev/null +++ b/libraries/cherrypy/tutorial/tut07_sessions.py @@ -0,0 +1,41 @@ +""" +Tutorial - Sessions + +Storing session data in CherryPy applications is very easy: cherrypy +provides a dictionary called "session" that represents the session +data for the current user. If you use RAM based sessions, you can store +any kind of object into that dictionary; otherwise, you are limited to +objects that can be pickled. +""" + +import os.path + +import cherrypy + + +class HitCounter: + + _cp_config = {'tools.sessions.on': True} + + @cherrypy.expose + def index(self): + # Increase the silly hit counter + count = cherrypy.session.get('count', 0) + 1 + + # Store the new value in the session dictionary + cherrypy.session['count'] = count + + # And display a silly hit count message! + return ''' + During your current session, you've viewed this + page %s times! Your life is a patio of fun! 
+ ''' % count + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). + cherrypy.quickstart(HitCounter(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut08_generators_and_yield.py b/libraries/cherrypy/tutorial/tut08_generators_and_yield.py new file mode 100644 index 00000000..18f42f93 --- /dev/null +++ b/libraries/cherrypy/tutorial/tut08_generators_and_yield.py @@ -0,0 +1,44 @@ +""" +Bonus Tutorial: Using generators to return result bodies + +Instead of returning a complete result string, you can use the yield +statement to return one result part after another. This may be convenient +in situations where using a template package like CherryPy or Cheetah +would be overkill, and messy string concatenation too uncool. ;-) +""" + +import os.path + +import cherrypy + + +class GeneratorDemo: + + def header(self): + return '<html><body><h2>Generators rule!</h2>' + + def footer(self): + return '</body></html>' + + @cherrypy.expose + def index(self): + # Let's make up a list of users for presentation purposes + users = ['Remi', 'Carlos', 'Hendrik', 'Lorenzo Lamas'] + + # Every yield line adds one part to the total result body. + yield self.header() + yield '<h3>List of users:</h3>' + + for user in users: + yield '%s<br/>' % user + + yield self.footer() + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). + cherrypy.quickstart(GeneratorDemo(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut09_files.py b/libraries/cherrypy/tutorial/tut09_files.py new file mode 100644 index 00000000..48585cbe --- /dev/null +++ b/libraries/cherrypy/tutorial/tut09_files.py @@ -0,0 +1,105 @@ +""" + +Tutorial: File upload and download + +Uploads +------- + +When a client uploads a file to a CherryPy application, it's placed +on disk immediately. CherryPy will pass it to your exposed method +as an argument (see "myFile" below); that arg will have a "file" +attribute, which is a handle to the temporary uploaded file. +If you wish to permanently save the file, you need to read() +from myFile.file and write() somewhere else. + +Note the use of 'enctype="multipart/form-data"' and 'input type="file"' +in the HTML which the client uses to upload the file. + + +Downloads +--------- + +If you wish to send a file to the client, you have two options: +First, you can simply return a file-like object from your page handler. +CherryPy will read the file and serve it as the content (HTTP body) +of the response. However, that doesn't tell the client that +the response is a file to be saved, rather than displayed. +Use cherrypy.lib.static.serve_file for that; it takes four +arguments: + +serve_file(path, content_type=None, disposition=None, name=None) + +Set "name" to the filename that you expect clients to use when they save +your file. Note that the "name" argument is ignored if you don't also +provide a "disposition" (usually "attachement"). You can manually set +"content_type", but be aware that if you also use the encoding tool, it +may choke if the file extension is not recognized as belonging to a known +Content-Type. 
Setting the content_type to "application/x-download" works +in most cases, and should prompt the user with an Open/Save dialog in +popular browsers. + +""" + +import os +import os.path + +import cherrypy +from cherrypy.lib import static + +localDir = os.path.dirname(__file__) +absDir = os.path.join(os.getcwd(), localDir) + + +class FileDemo(object): + + @cherrypy.expose + def index(self): + return """ + <html><body> + <h2>Upload a file</h2> + <form action="upload" method="post" enctype="multipart/form-data"> + filename: <input type="file" name="myFile" /><br /> + <input type="submit" /> + </form> + <h2>Download a file</h2> + <a href='download'>This one</a> + </body></html> + """ + + @cherrypy.expose + def upload(self, myFile): + out = """<html> + <body> + myFile length: %s<br /> + myFile filename: %s<br /> + myFile mime-type: %s + </body> + </html>""" + + # Although this just counts the file length, it demonstrates + # how to read large files in chunks instead of all at once. + # CherryPy reads the uploaded file into a temporary file; + # myFile.file.read reads from that. + size = 0 + while True: + data = myFile.file.read(8192) + if not data: + break + size += len(data) + + return out % (size, myFile.filename, myFile.content_type) + + @cherrypy.expose + def download(self): + path = os.path.join(absDir, 'pdf_file.pdf') + return static.serve_file(path, 'application/x-download', + 'attachment', os.path.basename(path)) + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). + cherrypy.quickstart(FileDemo(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tut10_http_errors.py b/libraries/cherrypy/tutorial/tut10_http_errors.py new file mode 100644 index 00000000..18f02fd0 --- /dev/null +++ b/libraries/cherrypy/tutorial/tut10_http_errors.py @@ -0,0 +1,84 @@ +""" + +Tutorial: HTTP errors + +HTTPError is used to return an error response to the client. +CherryPy has lots of options regarding how such errors are +logged, displayed, and formatted. + +""" + +import os +import os.path + +import cherrypy + +localDir = os.path.dirname(__file__) +curpath = os.path.normpath(os.path.join(os.getcwd(), localDir)) + + +class HTTPErrorDemo(object): + + # Set a custom response for 403 errors. + _cp_config = {'error_page.403': + os.path.join(curpath, 'custom_error.html')} + + @cherrypy.expose + def index(self): + # display some links that will result in errors + tracebacks = cherrypy.request.show_tracebacks + if tracebacks: + trace = 'off' + else: + trace = 'on' + + return """ + <html><body> + <p>Toggle tracebacks <a href="toggleTracebacks">%s</a></p> + <p><a href="/doesNotExist">Click me; I'm a broken link!</a></p> + <p> + <a href="/error?code=403"> + Use a custom error page from a file. 
+ </a> + </p> + <p>These errors are explicitly raised by the application:</p> + <ul> + <li><a href="/error?code=400">400</a></li> + <li><a href="/error?code=401">401</a></li> + <li><a href="/error?code=402">402</a></li> + <li><a href="/error?code=500">500</a></li> + </ul> + <p><a href="/messageArg">You can also set the response body + when you raise an error.</a></p> + </body></html> + """ % trace + + @cherrypy.expose + def toggleTracebacks(self): + # simple function to toggle tracebacks on and off + tracebacks = cherrypy.request.show_tracebacks + cherrypy.config.update({'request.show_tracebacks': not tracebacks}) + + # redirect back to the index + raise cherrypy.HTTPRedirect('/') + + @cherrypy.expose + def error(self, code): + # raise an error based on the get query + raise cherrypy.HTTPError(status=code) + + @cherrypy.expose + def messageArg(self): + message = ("If you construct an HTTPError with a 'message' " + 'argument, it wil be placed on the error page ' + '(underneath the status line by default).') + raise cherrypy.HTTPError(500, message=message) + + +tutconf = os.path.join(os.path.dirname(__file__), 'tutorial.conf') + +if __name__ == '__main__': + # CherryPy always starts with app.root when trying to map request URIs + # to objects, so we need to mount a request handler root. A request + # to '/' will be mapped to HelloWorld().index(). + cherrypy.quickstart(HTTPErrorDemo(), config=tutconf) diff --git a/libraries/cherrypy/tutorial/tutorial.conf b/libraries/cherrypy/tutorial/tutorial.conf new file mode 100644 index 00000000..43dfa60f --- /dev/null +++ b/libraries/cherrypy/tutorial/tutorial.conf @@ -0,0 +1,4 @@ +[global] +server.socket_host = "127.0.0.1" +server.socket_port = 8080 +server.thread_pool = 10 diff --git a/libraries/contextlib2.py b/libraries/contextlib2.py new file mode 100644 index 00000000..f08df14c --- /dev/null +++ b/libraries/contextlib2.py @@ -0,0 +1,436 @@ +"""contextlib2 - backports and enhancements to the contextlib module""" + +import sys +import warnings +from collections import deque +from functools import wraps + +__all__ = ["contextmanager", "closing", "ContextDecorator", "ExitStack", + "redirect_stdout", "redirect_stderr", "suppress"] + +# Backwards compatibility +__all__ += ["ContextStack"] + +class ContextDecorator(object): + "A base class or mixin that enables context managers to work as decorators." + + def refresh_cm(self): + """Returns the context manager used to actually wrap the call to the + decorated function. + + The default implementation just returns *self*. + + Overriding this method allows otherwise one-shot context managers + like _GeneratorContextManager to support use as decorators via + implicit recreation. + + DEPRECATED: refresh_cm was never added to the standard library's + ContextDecorator API + """ + warnings.warn("refresh_cm was never added to the standard library", + DeprecationWarning) + return self._recreate_cm() + + def _recreate_cm(self): + """Return a recreated instance of self. + + Allows an otherwise one-shot context manager like + _GeneratorContextManager to support use as + a decorator via implicit recreation. + + This is a private interface just for _GeneratorContextManager. + See issue #11647 for details. 
+ """ + return self + + def __call__(self, func): + @wraps(func) + def inner(*args, **kwds): + with self._recreate_cm(): + return func(*args, **kwds) + return inner + + +class _GeneratorContextManager(ContextDecorator): + """Helper for @contextmanager decorator.""" + + def __init__(self, func, args, kwds): + self.gen = func(*args, **kwds) + self.func, self.args, self.kwds = func, args, kwds + # Issue 19330: ensure context manager instances have good docstrings + doc = getattr(func, "__doc__", None) + if doc is None: + doc = type(self).__doc__ + self.__doc__ = doc + # Unfortunately, this still doesn't provide good help output when + # inspecting the created context manager instances, since pydoc + # currently bypasses the instance docstring and shows the docstring + # for the class instead. + # See http://bugs.python.org/issue19404 for more details. + + def _recreate_cm(self): + # _GCM instances are one-shot context managers, so the + # CM must be recreated each time a decorated function is + # called + return self.__class__(self.func, self.args, self.kwds) + + def __enter__(self): + try: + return next(self.gen) + except StopIteration: + raise RuntimeError("generator didn't yield") + + def __exit__(self, type, value, traceback): + if type is None: + try: + next(self.gen) + except StopIteration: + return + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + # Need to force instantiation so we can reliably + # tell if we get the same exception back + value = type() + try: + self.gen.throw(type, value, traceback) + raise RuntimeError("generator didn't stop after throw()") + except StopIteration as exc: + # Suppress StopIteration *unless* it's the same exception that + # was passed to throw(). This prevents a StopIteration + # raised inside the "with" statement from being suppressed. + return exc is not value + except RuntimeError as exc: + # Don't re-raise the passed in exception + if exc is value: + return False + # Likewise, avoid suppressing if a StopIteration exception + # was passed to throw() and later wrapped into a RuntimeError + # (see PEP 479). + if _HAVE_EXCEPTION_CHAINING and exc.__cause__ is value: + return False + raise + except: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise + # an exception unless __exit__() itself failed. But throw() + # has to raise the exception to signal propagation, so this + # fixes the impedance mismatch between the throw() protocol + # and the __exit__() protocol. + # + if sys.exc_info()[1] is not value: + raise + + +def contextmanager(func): + """@contextmanager decorator. + + Typical usage: + + @contextmanager + def some_generator(<arguments>): + <setup> + try: + yield <value> + finally: + <cleanup> + + This makes this: + + with some_generator(<arguments>) as <variable>: + <body> + + equivalent to this: + + <setup> + try: + <variable> = <value> + <body> + finally: + <cleanup> + + """ + @wraps(func) + def helper(*args, **kwds): + return _GeneratorContextManager(func, args, kwds) + return helper + + +class closing(object): + """Context to automatically close something at the end of a block. 
+ + Code like this: + + with closing(<module>.open(<arguments>)) as f: + <block> + + is equivalent to this: + + f = <module>.open(<arguments>) + try: + <block> + finally: + f.close() + + """ + def __init__(self, thing): + self.thing = thing + def __enter__(self): + return self.thing + def __exit__(self, *exc_info): + self.thing.close() + + +class _RedirectStream(object): + + _stream = None + + def __init__(self, new_target): + self._new_target = new_target + # We use a list of old targets to make this CM re-entrant + self._old_targets = [] + + def __enter__(self): + self._old_targets.append(getattr(sys, self._stream)) + setattr(sys, self._stream, self._new_target) + return self._new_target + + def __exit__(self, exctype, excinst, exctb): + setattr(sys, self._stream, self._old_targets.pop()) + + +class redirect_stdout(_RedirectStream): + """Context manager for temporarily redirecting stdout to another file. + + # How to send help() to stderr + with redirect_stdout(sys.stderr): + help(dir) + + # How to write help() to a file + with open('help.txt', 'w') as f: + with redirect_stdout(f): + help(pow) + """ + + _stream = "stdout" + + +class redirect_stderr(_RedirectStream): + """Context manager for temporarily redirecting stderr to another file.""" + + _stream = "stderr" + + +class suppress(object): + """Context manager to suppress specified exceptions + + After the exception is suppressed, execution proceeds with the next + statement following the with statement. + + with suppress(FileNotFoundError): + os.remove(somefile) + # Execution still resumes here if the file was already removed + """ + + def __init__(self, *exceptions): + self._exceptions = exceptions + + def __enter__(self): + pass + + def __exit__(self, exctype, excinst, exctb): + # Unlike isinstance and issubclass, CPython exception handling + # currently only looks at the concrete type hierarchy (ignoring + # the instance and subclass checking hooks). While Guido considers + # that a bug rather than a feature, it's a fairly hard one to fix + # due to various internal implementation details. suppress provides + # the simpler issubclass based semantics, rather than trying to + # exactly reproduce the limitations of the CPython interpreter. 
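# --- Editor's note (not part of the patch): because the check below is
# --- issubclass-based, suppressing a base class also swallows subclasses:
# ---
# ---     with suppress(LookupError):
# ---         {}['missing']   # raises KeyError, a LookupError subclass -> suppressed
# ---     # execution continues here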
+ # + # See http://bugs.python.org/issue12029 for more details + return exctype is not None and issubclass(exctype, self._exceptions) + + +# Context manipulation is Python 3 only +_HAVE_EXCEPTION_CHAINING = sys.version_info[0] >= 3 +if _HAVE_EXCEPTION_CHAINING: + def _make_context_fixer(frame_exc): + def _fix_exception_context(new_exc, old_exc): + # Context may not be correct, so find the end of the chain + while 1: + exc_context = new_exc.__context__ + if exc_context is old_exc: + # Context is already set correctly (see issue 20317) + return + if exc_context is None or exc_context is frame_exc: + break + new_exc = exc_context + # Change the end of the chain to point to the exception + # we expect it to reference + new_exc.__context__ = old_exc + return _fix_exception_context + + def _reraise_with_existing_context(exc_details): + try: + # bare "raise exc_details[1]" replaces our carefully + # set-up context + fixed_ctx = exc_details[1].__context__ + raise exc_details[1] + except BaseException: + exc_details[1].__context__ = fixed_ctx + raise +else: + # No exception context in Python 2 + def _make_context_fixer(frame_exc): + return lambda new_exc, old_exc: None + + # Use 3 argument raise in Python 2, + # but use exec to avoid SyntaxError in Python 3 + def _reraise_with_existing_context(exc_details): + exc_type, exc_value, exc_tb = exc_details + exec ("raise exc_type, exc_value, exc_tb") + +# Handle old-style classes if they exist +try: + from types import InstanceType +except ImportError: + # Python 3 doesn't have old-style classes + _get_type = type +else: + # Need to handle old-style context managers on Python 2 + def _get_type(obj): + obj_type = type(obj) + if obj_type is InstanceType: + return obj.__class__ # Old-style class + return obj_type # New-style class + +# Inspired by discussions on http://bugs.python.org/issue13585 +class ExitStack(object): + """Context manager for dynamic management of a stack of exit callbacks + + For example: + + with ExitStack() as stack: + files = [stack.enter_context(open(fname)) for fname in filenames] + # All opened files will automatically be closed at the end of + # the with statement, even if attempts to open files later + # in the list raise an exception + + """ + def __init__(self): + self._exit_callbacks = deque() + + def pop_all(self): + """Preserve the context stack by transferring it to a new instance""" + new_stack = type(self)() + new_stack._exit_callbacks = self._exit_callbacks + self._exit_callbacks = deque() + return new_stack + + def _push_cm_exit(self, cm, cm_exit): + """Helper to correctly register callbacks to __exit__ methods""" + def _exit_wrapper(*exc_details): + return cm_exit(cm, *exc_details) + _exit_wrapper.__self__ = cm + self.push(_exit_wrapper) + + def push(self, exit): + """Registers a callback with the standard __exit__ method signature + + Can suppress exceptions the same way __exit__ methods can. + + Also accepts any object with an __exit__ method (registering a call + to the method instead of the object itself) + """ + # We use an unbound method rather than a bound method to follow + # the standard lookup behaviour for special methods + _cb_type = _get_type(exit) + try: + exit_method = _cb_type.__exit__ + except AttributeError: + # Not a context manager, so assume its a callable + self._exit_callbacks.append(exit) + else: + self._push_cm_exit(exit, exit_method) + return exit # Allow use as a decorator + + def callback(self, callback, *args, **kwds): + """Registers an arbitrary callback and arguments. 
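# Illustrative sketch (separate from the vendored code above): pop_all()
# enables a "commit on success" pattern -- resources registered on the stack
# are torn down if a later acquisition fails, but survive the `with` block
# once everything succeeded, because the callbacks move to a new ExitStack the
# caller now owns. The import path and file names are assumptions.

from contextlib2 import ExitStack  # assumed import path

def open_all(filenames):
    with ExitStack() as stack:
        files = [stack.enter_context(open(name)) for name in filenames]
        # If any open() raised, the files opened so far were closed here.
        return stack.pop_all(), files   # transfer ownership to the caller

stack, files = open_all(['a.txt', 'b.txt'])
with stack:                            # closes both files when this block ends
    print([f.name for f in files])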
+ + Cannot suppress exceptions. + """ + def _exit_wrapper(exc_type, exc, tb): + callback(*args, **kwds) + # We changed the signature, so using @wraps is not appropriate, but + # setting __wrapped__ may still help with introspection + _exit_wrapper.__wrapped__ = callback + self.push(_exit_wrapper) + return callback # Allow use as a decorator + + def enter_context(self, cm): + """Enters the supplied context manager + + If successful, also pushes its __exit__ method as a callback and + returns the result of the __enter__ method. + """ + # We look up the special methods on the type to match the with statement + _cm_type = _get_type(cm) + _exit = _cm_type.__exit__ + result = _cm_type.__enter__(cm) + self._push_cm_exit(cm, _exit) + return result + + def close(self): + """Immediately unwind the context stack""" + self.__exit__(None, None, None) + + def __enter__(self): + return self + + def __exit__(self, *exc_details): + received_exc = exc_details[0] is not None + + # We manipulate the exception state so it behaves as though + # we were actually nesting multiple with statements + frame_exc = sys.exc_info()[1] + _fix_exception_context = _make_context_fixer(frame_exc) + + # Callbacks are invoked in LIFO order to match the behaviour of + # nested context managers + suppressed_exc = False + pending_raise = False + while self._exit_callbacks: + cb = self._exit_callbacks.pop() + try: + if cb(*exc_details): + suppressed_exc = True + pending_raise = False + exc_details = (None, None, None) + except: + new_exc_details = sys.exc_info() + # simulate the stack of exceptions by setting the context + _fix_exception_context(new_exc_details[1], exc_details[1]) + pending_raise = True + exc_details = new_exc_details + if pending_raise: + _reraise_with_existing_context(exc_details) + return received_exc and suppressed_exc + +# Preserve backwards compatibility +class ContextStack(ExitStack): + """Backwards compatibility alias for ExitStack""" + + def __init__(self): + warnings.warn("ContextStack has been renamed to ExitStack", + DeprecationWarning) + super(ContextStack, self).__init__() + + def register_exit(self, callback): + return self.push(callback) + + def register(self, callback, *args, **kwds): + return self.callback(callback, *args, **kwds) + + def preserve(self): + return self.pop_all() diff --git a/libraries/more_itertools/__init__.py b/libraries/more_itertools/__init__.py new file mode 100644 index 00000000..bba462c3 --- /dev/null +++ b/libraries/more_itertools/__init__.py @@ -0,0 +1,2 @@ +from more_itertools.more import * # noqa +from more_itertools.recipes import * # noqa diff --git a/libraries/more_itertools/more.py b/libraries/more_itertools/more.py new file mode 100644 index 00000000..05e851ee --- /dev/null +++ b/libraries/more_itertools/more.py @@ -0,0 +1,2211 @@ +from __future__ import print_function + +from collections import Counter, defaultdict, deque +from functools import partial, wraps +from heapq import merge +from itertools import ( + chain, + compress, + count, + cycle, + dropwhile, + groupby, + islice, + repeat, + starmap, + takewhile, + tee +) +from operator import itemgetter, lt, gt, sub +from sys import maxsize, version_info +try: + from collections.abc import Sequence +except ImportError: + from collections import Sequence + +from six import binary_type, string_types, text_type +from six.moves import filter, map, range, zip, zip_longest + +from .recipes import consume, flatten, take + +__all__ = [ + 'adjacent', + 'always_iterable', + 'always_reversible', + 'bucket', + 'chunked', + 
'circular_shifts', + 'collapse', + 'collate', + 'consecutive_groups', + 'consumer', + 'count_cycle', + 'difference', + 'distinct_permutations', + 'distribute', + 'divide', + 'exactly_n', + 'first', + 'groupby_transform', + 'ilen', + 'interleave_longest', + 'interleave', + 'intersperse', + 'islice_extended', + 'iterate', + 'last', + 'locate', + 'lstrip', + 'make_decorator', + 'map_reduce', + 'numeric_range', + 'one', + 'padded', + 'peekable', + 'replace', + 'rlocate', + 'rstrip', + 'run_length', + 'seekable', + 'SequenceView', + 'side_effect', + 'sliced', + 'sort_together', + 'split_at', + 'split_after', + 'split_before', + 'spy', + 'stagger', + 'strip', + 'unique_to_each', + 'windowed', + 'with_iter', + 'zip_offset', +] + +_marker = object() + + +def chunked(iterable, n): + """Break *iterable* into lists of length *n*: + + >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) + [[1, 2, 3], [4, 5, 6]] + + If the length of *iterable* is not evenly divisible by *n*, the last + returned list will be shorter: + + >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) + [[1, 2, 3], [4, 5, 6], [7, 8]] + + To use a fill-in value instead, see the :func:`grouper` recipe. + + :func:`chunked` is useful for splitting up a computation on a large number + of keys into batches, to be pickled and sent off to worker processes. One + example is operations on rows in MySQL, which does not implement + server-side cursors properly and would otherwise load the entire dataset + into RAM on the client. + + """ + return iter(partial(take, n, iter(iterable)), []) + + +def first(iterable, default=_marker): + """Return the first item of *iterable*, or *default* if *iterable* is + empty. + + >>> first([0, 1, 2, 3]) + 0 + >>> first([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + + :func:`first` is useful when you have a generator of expensive-to-retrieve + values and want any arbitrary one. It is marginally shorter than + ``next(iter(iterable), default)``. + + """ + try: + return next(iter(iterable)) + except StopIteration: + # I'm on the edge about raising ValueError instead of StopIteration. At + # the moment, ValueError wins, because the caller could conceivably + # want to do something different with flow control when I raise the + # exception, and it's weird to explicitly catch StopIteration. + if default is _marker: + raise ValueError('first() was called on an empty iterable, and no ' + 'default value was provided.') + return default + + +def last(iterable, default=_marker): + """Return the last item of *iterable*, or *default* if *iterable* is + empty. + + >>> last([0, 1, 2, 3]) + 3 + >>> last([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + try: + try: + # Try to access the last item directly + return iterable[-1] + except (TypeError, AttributeError, KeyError): + # If not slice-able, iterate entirely using length-1 deque + return deque(iterable, maxlen=1)[0] + except IndexError: # If the iterable was empty + if default is _marker: + raise ValueError('last() was called on an empty iterable, and no ' + 'default value was provided.') + return default + + +class peekable(object): + """Wrap an iterator to allow lookahead and prepending elements. + + Call :meth:`peek` on the result to get the value that will be returned + by :func:`next`. 
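# Illustrative sketch (separate from the vendored code above): chunked()
# relies on the two-argument form of iter(): iter(callable, sentinel) keeps
# calling the callable until it returns the sentinel. Here the callable is
# partial(take, n, it) and the empty list is the sentinel marking exhaustion.
# The same trick in isolation:

from functools import partial
from itertools import islice

def take(n, iterable):
    # mirrors the recipes.take helper: first n items as a list
    return list(islice(iterable, n))

it = iter(range(7))
chunks = iter(partial(take, 3, it), [])  # stop once take() returns []
print(list(chunks))                      # [[0, 1, 2], [3, 4, 5], [6]]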
This won't advance the iterator: + + >>> p = peekable(['a', 'b']) + >>> p.peek() + 'a' + >>> next(p) + 'a' + + Pass :meth:`peek` a default value to return that instead of raising + ``StopIteration`` when the iterator is exhausted. + + >>> p = peekable([]) + >>> p.peek('hi') + 'hi' + + peekables also offer a :meth:`prepend` method, which "inserts" items + at the head of the iterable: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> p.peek() + 11 + >>> list(p) + [11, 12, 1, 2, 3] + + peekables can be indexed. Index 0 is the item that will be returned by + :func:`next`, index 1 is the item after that, and so on: + The values up to the given index will be cached. + + >>> p = peekable(['a', 'b', 'c', 'd']) + >>> p[0] + 'a' + >>> p[1] + 'b' + >>> next(p) + 'a' + + Negative indexes are supported, but be aware that they will cache the + remaining items in the source iterator, which may require significant + storage. + + To check whether a peekable is exhausted, check its truth value: + + >>> p = peekable(['a', 'b']) + >>> if p: # peekable has items + ... list(p) + ['a', 'b'] + >>> if not p: # peekable is exhaused + ... list(p) + [] + + """ + def __init__(self, iterable): + self._it = iter(iterable) + self._cache = deque() + + def __iter__(self): + return self + + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def __nonzero__(self): + # For Python 2 compatibility + return self.__bool__() + + def peek(self, default=_marker): + """Return the item that will be next returned from ``next()``. + + Return ``default`` if there are no items left. If ``default`` is not + provided, raise ``StopIteration``. + + """ + if not self._cache: + try: + self._cache.append(next(self._it)) + except StopIteration: + if default is _marker: + raise + return default + return self._cache[0] + + def prepend(self, *items): + """Stack up items to be the next ones returned from ``next()`` or + ``self.peek()``. The items will be returned in + first in, first out order:: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> list(p) + [11, 12, 1, 2, 3] + + It is possible, by prepending items, to "resurrect" a peekable that + previously raised ``StopIteration``. + + >>> p = peekable([]) + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + >>> p.prepend(1) + >>> next(p) + 1 + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + + """ + self._cache.extendleft(reversed(items)) + + def __next__(self): + if self._cache: + return self._cache.popleft() + + return next(self._it) + + next = __next__ # For Python 2 compatibility + + def _get_slice(self, index): + # Normalize the slice's arguments + step = 1 if (index.step is None) else index.step + if step > 0: + start = 0 if (index.start is None) else index.start + stop = maxsize if (index.stop is None) else index.stop + elif step < 0: + start = -1 if (index.start is None) else index.start + stop = (-maxsize - 1) if (index.stop is None) else index.stop + else: + raise ValueError('slice step cannot be zero') + + # If either the start or stop index is negative, we'll need to cache + # the rest of the iterable in order to slice from the right side. + if (start < 0) or (stop < 0): + self._cache.extend(self._it) + # Otherwise we'll need to find the rightmost index and cache to that + # point. 
+ else: + n = min(max(start, stop) + 1, maxsize) + cache_len = len(self._cache) + if n >= cache_len: + self._cache.extend(islice(self._it, n - cache_len)) + + return list(self._cache)[index] + + def __getitem__(self, index): + if isinstance(index, slice): + return self._get_slice(index) + + cache_len = len(self._cache) + if index < 0: + self._cache.extend(self._it) + elif index >= cache_len: + self._cache.extend(islice(self._it, index + 1 - cache_len)) + + return self._cache[index] + + +def _collate(*iterables, **kwargs): + """Helper for ``collate()``, called when the user is using the ``reverse`` + or ``key`` keyword arguments on Python versions below 3.5. + + """ + key = kwargs.pop('key', lambda a: a) + reverse = kwargs.pop('reverse', False) + + min_or_max = partial(max if reverse else min, key=itemgetter(0)) + peekables = [peekable(it) for it in iterables] + peekables = [p for p in peekables if p] # Kill empties. + while peekables: + _, p = min_or_max((key(p.peek()), p) for p in peekables) + yield next(p) + peekables = [x for x in peekables if x] + + +def collate(*iterables, **kwargs): + """Return a sorted merge of the items from each of several already-sorted + *iterables*. + + >>> list(collate('ACDZ', 'AZ', 'JKL')) + ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z'] + + Works lazily, keeping only the next value from each iterable in memory. Use + :func:`collate` to, for example, perform a n-way mergesort of items that + don't fit in memory. + + If a *key* function is specified, the iterables will be sorted according + to its result: + + >>> key = lambda s: int(s) # Sort by numeric value, not by string + >>> list(collate(['1', '10'], ['2', '11'], key=key)) + ['1', '2', '10', '11'] + + + If the *iterables* are sorted in descending order, set *reverse* to + ``True``: + + >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True)) + [5, 4, 3, 2, 1, 0] + + If the elements of the passed-in iterables are out of order, you might get + unexpected results. + + On Python 2.7, this function delegates to :func:`heapq.merge` if neither + of the keyword arguments are specified. On Python 3.5+, this function + is an alias for :func:`heapq.merge`. + + """ + if not kwargs: + return merge(*iterables) + + return _collate(*iterables, **kwargs) + + +# If using Python version 3.5 or greater, heapq.merge() will be faster than +# collate - use that instead. +if version_info >= (3, 5, 0): + _collate_docstring = collate.__doc__ + collate = partial(merge) + collate.__doc__ = _collate_docstring + + +def consumer(func): + """Decorator that automatically advances a PEP-342-style "reverse iterator" + to its first yield point so you don't have to call ``next()`` on it + manually. + + >>> @consumer + ... def tally(): + ... i = 0 + ... while True: + ... print('Thing number %s is %s.' % (i, (yield))) + ... i += 1 + ... + >>> t = tally() + >>> t.send('red') + Thing number 0 is red. + >>> t.send('fish') + Thing number 1 is fish. + + Without the decorator, you would have to call ``next(t)`` before + ``t.send()`` could be used. + + """ + @wraps(func) + def wrapper(*args, **kwargs): + gen = func(*args, **kwargs) + next(gen) + return gen + return wrapper + + +def ilen(iterable): + """Return the number of items in *iterable*. + + >>> ilen(x for x in range(1000000) if x % 3 == 0) + 333334 + + This consumes the iterable, so handle with care. 
+ + """ + # maxlen=1 only stores the last item in the deque + d = deque(enumerate(iterable, 1), maxlen=1) + # since we started enumerate at 1, + # the first item of the last pair will be the length of the iterable + # (assuming there were items) + return d[0][0] if d else 0 + + +def iterate(func, start): + """Return ``start``, ``func(start)``, ``func(func(start))``, ... + + >>> from itertools import islice + >>> list(islice(iterate(lambda x: 2*x, 1), 10)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + + """ + while True: + yield start + start = func(start) + + +def with_iter(context_manager): + """Wrap an iterable in a ``with`` statement, so it closes once exhausted. + + For example, this will close the file when the iterator is exhausted:: + + upper_lines = (line.upper() for line in with_iter(open('foo'))) + + Any context manager which returns an iterable is a candidate for + ``with_iter``. + + """ + with context_manager as iterable: + for item in iterable: + yield item + + +def one(iterable, too_short=None, too_long=None): + """Return the first item from *iterable*, which is expected to contain only + that item. Raise an exception if *iterable* is empty or has more than one + item. + + :func:`one` is useful for ensuring that an iterable contains only one item. + For example, it can be used to retrieve the result of a database query + that is expected to return a single row. + + If *iterable* is empty, ``ValueError`` will be raised. You may specify a + different exception with the *too_short* keyword: + + >>> it = [] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (expected 1)' + >>> too_short = IndexError('too few items') + >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + IndexError: too few items + + Similarly, if *iterable* contains more than one item, ``ValueError`` will + be raised. You may specify a different exception with the *too_long* + keyword: + + >>> it = ['too', 'many'] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (expected 1)' + >>> too_long = RuntimeError + >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + Note that :func:`one` attempts to advance *iterable* twice to ensure there + is only one item. If there is more than one, both items will be discarded. + See :func:`spy` or :func:`peekable` to check iterable contents less + destructively. + + """ + it = iter(iterable) + + try: + value = next(it) + except StopIteration: + raise too_short or ValueError('too few items in iterable (expected 1)') + + try: + next(it) + except StopIteration: + pass + else: + raise too_long or ValueError('too many items in iterable (expected 1)') + + return value + + +def distinct_permutations(iterable): + """Yield successive distinct permutations of the elements in *iterable*. + + >>> sorted(distinct_permutations([1, 0, 1])) + [(0, 1, 1), (1, 0, 1), (1, 1, 0)] + + Equivalent to ``set(permutations(iterable))``, except duplicates are not + generated and thrown away. For larger input sequences this is much more + efficient. + + Duplicate permutations arise when there are duplicated elements in the + input iterable. The number of items returned is + `n! / (x_1! * x_2! * ... 
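# Illustrative sketch (separate from the vendored code above): ilen() counts
# items in constant memory -- deque(..., maxlen=1) retains only the final
# (index, item) pair produced by enumerate(), and that index is the length.
# The trick in isolation:

from collections import deque

items = (x for x in range(1000000) if x % 3 == 0)
d = deque(enumerate(items, 1), maxlen=1)   # only the last pair survives
print(d[0][0] if d else 0)                 # 333334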
* x_n!)`, where `n` is the total number of + items input, and each `x_i` is the count of a distinct item in the input + sequence. + + """ + def perm_unique_helper(item_counts, perm, i): + """Internal helper function + + :arg item_counts: Stores the unique items in ``iterable`` and how many + times they are repeated + :arg perm: The permutation that is being built for output + :arg i: The index of the permutation being modified + + The output permutations are built up recursively; the distinct items + are placed until their repetitions are exhausted. + """ + if i < 0: + yield tuple(perm) + else: + for item in item_counts: + if item_counts[item] <= 0: + continue + perm[i] = item + item_counts[item] -= 1 + for x in perm_unique_helper(item_counts, perm, i - 1): + yield x + item_counts[item] += 1 + + item_counts = Counter(iterable) + length = sum(item_counts.values()) + + return perm_unique_helper(item_counts, [None] * length, length - 1) + + +def intersperse(e, iterable, n=1): + """Intersperse filler element *e* among the items in *iterable*, leaving + *n* items between each filler element. + + >>> list(intersperse('!', [1, 2, 3, 4, 5])) + [1, '!', 2, '!', 3, '!', 4, '!', 5] + + >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2)) + [1, 2, None, 3, 4, None, 5] + + """ + if n == 0: + raise ValueError('n must be > 0') + elif n == 1: + # interleave(repeat(e), iterable) -> e, x_0, e, e, x_1, e, x_2... + # islice(..., 1, None) -> x_0, e, e, x_1, e, x_2... + return islice(interleave(repeat(e), iterable), 1, None) + else: + # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]... + # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]... + # flatten(...) -> x_0, x_1, e, x_2, x_3... + filler = repeat([e]) + chunks = chunked(iterable, n) + return flatten(islice(interleave(filler, chunks), 1, None)) + + +def unique_to_each(*iterables): + """Return the elements from each of the input iterables that aren't in the + other input iterables. + + For example, suppose you have a set of packages, each with a set of + dependencies:: + + {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}} + + If you remove one package, which dependencies can also be removed? + + If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not + associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for + ``pkg_2``, and ``D`` is only needed for ``pkg_3``:: + + >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'}) + [['A'], ['C'], ['D']] + + If there are duplicates in one input iterable that aren't in the others + they will be duplicated in the output. Input order is preserved:: + + >>> unique_to_each("mississippi", "missouri") + [['p', 'p'], ['o', 'u', 'r']] + + It is assumed that the elements of each iterable are hashable. + + """ + pool = [list(it) for it in iterables] + counts = Counter(chain.from_iterable(map(set, pool))) + uniques = {element for element in counts if counts[element] == 1} + return [list(filter(uniques.__contains__, it)) for it in pool] + + +def windowed(seq, n, fillvalue=None, step=1): + """Return a sliding window of width *n* over the given iterable. 
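# Illustrative sketch (separate from the vendored code above): for n=1,
# intersperse() is two primitives composed -- interleave the filler with the
# data, then drop the leading filler with islice. Stated directly, with the
# interleave helper inlined for self-containment:

from itertools import chain, islice, repeat

def interleave(*iterables):
    # round-robin until the shortest iterable ends (as more_itertools does)
    return chain.from_iterable(zip(*iterables))

data = [1, 2, 3]
out = islice(interleave(repeat('!'), data), 1, None)
print(list(out))   # [1, '!', 2, '!', 3]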
+ + >>> all_windows = windowed([1, 2, 3, 4, 5], 3) + >>> list(all_windows) + [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + + When the window is larger than the iterable, *fillvalue* is used in place + of missing values:: + + >>> list(windowed([1, 2, 3], 4)) + [(1, 2, 3, None)] + + Each window will advance in increments of *step*: + + >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) + [(1, 2, 3), (3, 4, 5), (5, 6, '!')] + + """ + if n < 0: + raise ValueError('n must be >= 0') + if n == 0: + yield tuple() + return + if step < 1: + raise ValueError('step must be >= 1') + + it = iter(seq) + window = deque([], n) + append = window.append + + # Initial deque fill + for _ in range(n): + append(next(it, fillvalue)) + yield tuple(window) + + # Appending new items to the right causes old items to fall off the left + i = 0 + for item in it: + append(item) + i = (i + 1) % step + if i % step == 0: + yield tuple(window) + + # If there are items from the iterable in the window, pad with the given + # value and emit them. + if (i % step) and (step - i < n): + for _ in range(step - i): + append(fillvalue) + yield tuple(window) + + +class bucket(object): + """Wrap *iterable* and return an object that buckets it iterable into + child iterables based on a *key* function. + + >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] + >>> s = bucket(iterable, key=lambda x: x[0]) + >>> a_iterable = s['a'] + >>> next(a_iterable) + 'a1' + >>> next(a_iterable) + 'a2' + >>> list(s['b']) + ['b1', 'b2', 'b3'] + + The original iterable will be advanced and its items will be cached until + they are used by the child iterables. This may require significant storage. + + By default, attempting to select a bucket to which no items belong will + exhaust the iterable and cache all values. + If you specify a *validator* function, selected buckets will instead be + checked against it. + + >>> from itertools import count + >>> it = count(1, 2) # Infinite sequence of odd numbers + >>> key = lambda x: x % 10 # Bucket by last digit + >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only + >>> s = bucket(it, key=key, validator=validator) + >>> 2 in s + False + >>> list(s[2]) + [] + + """ + def __init__(self, iterable, key, validator=None): + self._it = iter(iterable) + self._key = key + self._cache = defaultdict(deque) + self._validator = validator or (lambda x: True) + + def __contains__(self, value): + if not self._validator(value): + return False + + try: + item = next(self[value]) + except StopIteration: + return False + else: + self._cache[value].appendleft(item) + + return True + + def _get_values(self, value): + """ + Helper to yield items from the parent iterator that match *value*. + Items that don't match are stored in the local cache as they + are encountered. + """ + while True: + # If we've cached some items that match the target value, emit + # the first one and evict it from the cache. + if self._cache[value]: + yield self._cache[value].popleft() + # Otherwise we need to advance the parent iterator to search for + # a matching item, caching the rest. 
+ else: + while True: + try: + item = next(self._it) + except StopIteration: + return + item_value = self._key(item) + if item_value == value: + yield item + break + elif self._validator(item_value): + self._cache[item_value].append(item) + + def __getitem__(self, value): + if not self._validator(value): + return iter(()) + + return self._get_values(value) + + +def spy(iterable, n=1): + """Return a 2-tuple with a list containing the first *n* elements of + *iterable*, and an iterator with the same items as *iterable*. + This allows you to "look ahead" at the items in the iterable without + advancing it. + + There is one item in the list by default: + + >>> iterable = 'abcdefg' + >>> head, iterable = spy(iterable) + >>> head + ['a'] + >>> list(iterable) + ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + + You may use unpacking to retrieve items instead of lists: + + >>> (head,), iterable = spy('abcdefg') + >>> head + 'a' + >>> (first, second), iterable = spy('abcdefg', 2) + >>> first + 'a' + >>> second + 'b' + + The number of items requested can be larger than the number of items in + the iterable: + + >>> iterable = [1, 2, 3, 4, 5] + >>> head, iterable = spy(iterable, 10) + >>> head + [1, 2, 3, 4, 5] + >>> list(iterable) + [1, 2, 3, 4, 5] + + """ + it = iter(iterable) + head = take(n, it) + + return head, chain(head, it) + + +def interleave(*iterables): + """Return a new iterable yielding from each iterable in turn, + until the shortest is exhausted. + + >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7] + + For a version that doesn't terminate after the shortest iterable is + exhausted, see :func:`interleave_longest`. + + """ + return chain.from_iterable(zip(*iterables)) + + +def interleave_longest(*iterables): + """Return a new iterable yielding from each iterable in turn, + skipping any that are exhausted. + + >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7, 3, 8] + + This function produces the same output as :func:`roundrobin`, but may + perform better for some inputs (in particular when the number of iterables + is large). + + """ + i = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker)) + return (x for x in i if x is not _marker) + + +def collapse(iterable, base_type=None, levels=None): + """Flatten an iterable with multiple levels of nesting (e.g., a list of + lists of tuples) into non-iterable types. + + >>> iterable = [(1, 2), ([3, 4], [[5], [6]])] + >>> list(collapse(iterable)) + [1, 2, 3, 4, 5, 6] + + String types are not considered iterable and will not be collapsed. 
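# Illustrative sketch (separate from the vendored code above):
# interleave_longest() pads shorter inputs with a private sentinel object via
# zip_longest and then filters that sentinel back out, so legitimate None
# values in the data pass through untouched. The pattern in isolation (plain
# Python 3 imports; the vendored module goes through six.moves instead):

from itertools import chain, zip_longest

_marker = object()

def interleave_longest(*iterables):
    mixed = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker))
    return (x for x in mixed if x is not _marker)

print(list(interleave_longest([1, None, 3], [4, 5])))   # [1, 4, None, 5, 3]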
+ To avoid collapsing other types, specify *base_type*: + + >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] + >>> list(collapse(iterable, base_type=tuple)) + ['ab', ('cd', 'ef'), 'gh', 'ij'] + + Specify *levels* to stop flattening after a certain level: + + >>> iterable = [('a', ['b']), ('c', ['d'])] + >>> list(collapse(iterable)) # Fully flattened + ['a', 'b', 'c', 'd'] + >>> list(collapse(iterable, levels=1)) # Only one level flattened + ['a', ['b'], 'c', ['d']] + + """ + def walk(node, level): + if ( + ((levels is not None) and (level > levels)) or + isinstance(node, string_types) or + ((base_type is not None) and isinstance(node, base_type)) + ): + yield node + return + + try: + tree = iter(node) + except TypeError: + yield node + return + else: + for child in tree: + for x in walk(child, level + 1): + yield x + + for x in walk(iterable, 0): + yield x + + +def side_effect(func, iterable, chunk_size=None, before=None, after=None): + """Invoke *func* on each item in *iterable* (or on each *chunk_size* group + of items) before yielding the item. + + `func` must be a function that takes a single argument. Its return value + will be discarded. + + *before* and *after* are optional functions that take no arguments. They + will be executed before iteration starts and after it ends, respectively. + + `side_effect` can be used for logging, updating progress bars, or anything + that is not functionally "pure." + + Emitting a status message: + + >>> from more_itertools import consume + >>> func = lambda item: print('Received {}'.format(item)) + >>> consume(side_effect(func, range(2))) + Received 0 + Received 1 + + Operating on chunks of items: + + >>> pair_sums = [] + >>> func = lambda chunk: pair_sums.append(sum(chunk)) + >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2)) + [0, 1, 2, 3, 4, 5] + >>> list(pair_sums) + [1, 5, 9] + + Writing to a file-like object: + + >>> from io import StringIO + >>> from more_itertools import consume + >>> f = StringIO() + >>> func = lambda x: print(x, file=f) + >>> before = lambda: print(u'HEADER', file=f) + >>> after = f.close + >>> it = [u'a', u'b', u'c'] + >>> consume(side_effect(func, it, before=before, after=after)) + >>> f.closed + True + + """ + try: + if before is not None: + before() + + if chunk_size is None: + for item in iterable: + func(item) + yield item + else: + for chunk in chunked(iterable, chunk_size): + func(chunk) + for item in chunk: + yield item + finally: + if after is not None: + after() + + +def sliced(seq, n): + """Yield slices of length *n* from the sequence *seq*. + + >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) + [(1, 2, 3), (4, 5, 6)] + + If the length of the sequence is not divisible by the requested slice + length, the last slice will be shorter. + + >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) + [(1, 2, 3), (4, 5, 6), (7, 8)] + + This function will only work for iterables that support slicing. + For non-sliceable iterables, see :func:`chunked`. + + """ + return takewhile(bool, (seq[i: i + n] for i in count(0, n))) + + +def split_at(iterable, pred): + """Yield lists of items from *iterable*, where each list is delimited by + an item where callable *pred* returns ``True``. The lists do not include + the delimiting items. 
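# Illustrative sketch (separate from the vendored code above): sliced()
# produces a candidate slice at every multiple of n and lets
# takewhile(bool, ...) stop at the first empty slice, which is also why it
# requires a sliceable sequence. The pattern in isolation:

from itertools import count, takewhile

seq = (1, 2, 3, 4, 5, 6, 7, 8)
slices = takewhile(bool, (seq[i:i + 3] for i in count(0, 3)))
print(list(slices))   # [(1, 2, 3), (4, 5, 6), (7, 8)]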
+ + >>> list(split_at('abcdcba', lambda x: x == 'b')) + [['a'], ['c', 'd', 'c'], ['a']] + + >>> list(split_at(range(10), lambda n: n % 2 == 1)) + [[0], [2], [4], [6], [8], []] + """ + buf = [] + for item in iterable: + if pred(item): + yield buf + buf = [] + else: + buf.append(item) + yield buf + + +def split_before(iterable, pred): + """Yield lists of items from *iterable*, where each list starts with an + item where callable *pred* returns ``True``: + + >>> list(split_before('OneTwo', lambda s: s.isupper())) + [['O', 'n', 'e'], ['T', 'w', 'o']] + + >>> list(split_before(range(10), lambda n: n % 3 == 0)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + + """ + buf = [] + for item in iterable: + if pred(item) and buf: + yield buf + buf = [] + buf.append(item) + yield buf + + +def split_after(iterable, pred): + """Yield lists of items from *iterable*, where each list ends with an + item where callable *pred* returns ``True``: + + >>> list(split_after('one1two2', lambda s: s.isdigit())) + [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']] + + >>> list(split_after(range(10), lambda n: n % 3 == 0)) + [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]] + + """ + buf = [] + for item in iterable: + buf.append(item) + if pred(item) and buf: + yield buf + buf = [] + if buf: + yield buf + + +def padded(iterable, fillvalue=None, n=None, next_multiple=False): + """Yield the elements from *iterable*, followed by *fillvalue*, such that + at least *n* items are emitted. + + >>> list(padded([1, 2, 3], '?', 5)) + [1, 2, 3, '?', '?'] + + If *next_multiple* is ``True``, *fillvalue* will be emitted until the + number of items emitted is a multiple of *n*:: + + >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True)) + [1, 2, 3, 4, None, None] + + If *n* is ``None``, *fillvalue* will be emitted indefinitely. + + """ + it = iter(iterable) + if n is None: + for item in chain(it, repeat(fillvalue)): + yield item + elif n < 1: + raise ValueError('n must be at least 1') + else: + item_count = 0 + for item in it: + yield item + item_count += 1 + + remaining = (n - item_count) % n if next_multiple else n - item_count + for _ in range(remaining): + yield fillvalue + + +def distribute(n, iterable): + """Distribute the items from *iterable* among *n* smaller iterables. + + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + + If the length of *iterable* is smaller than *n*, then the last returned + iterables will be empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function uses :func:`itertools.tee` and may require significant + storage. If you need the order items in the smaller iterables to match the + original iterable, see :func:`divide`. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + children = tee(iterable, n) + return [islice(it, index, None, n) for index, it in enumerate(children)] + + +def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None): + """Yield tuples whose elements are offset from *iterable*. + The amount by which the `i`-th item in each tuple is offset is given by + the `i`-th item in *offsets*. 
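# Illustrative sketch (separate from the vendored code above): distribute()
# clones the source n ways with itertools.tee, then gives each clone a
# different start offset and a stride of n, which is why items are dealt out
# round-robin rather than in contiguous runs (contrast divide()). In
# isolation:

from itertools import islice, tee

children = tee([1, 2, 3, 4, 5, 6, 7], 3)
groups = [islice(it, i, None, 3) for i, it in enumerate(children)]
print([list(g) for g in groups])   # [[1, 4, 7], [2, 5], [3, 6]]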
+ + >>> list(stagger([0, 1, 2, 3])) + [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + >>> list(stagger(range(8), offsets=(0, 2, 4))) + [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)] + + By default, the sequence will end when the final element of a tuple is the + last item in the iterable. To continue until the first element of a tuple + is the last item in the iterable, set *longest* to ``True``:: + + >>> list(stagger([0, 1, 2, 3], longest=True)) + [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + children = tee(iterable, len(offsets)) + + return zip_offset( + *children, offsets=offsets, longest=longest, fillvalue=fillvalue + ) + + +def zip_offset(*iterables, **kwargs): + """``zip`` the input *iterables* together, but offset the `i`-th iterable + by the `i`-th item in *offsets*. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1))) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')] + + This can be used as a lightweight alternative to SciPy or pandas to analyze + data sets in which somes series have a lead or lag relationship. + + By default, the sequence will end when the shortest iterable is exhausted. + To continue until the longest iterable is exhausted, set *longest* to + ``True``. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True)) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + offsets = kwargs['offsets'] + longest = kwargs.get('longest', False) + fillvalue = kwargs.get('fillvalue', None) + + if len(iterables) != len(offsets): + raise ValueError("Number of iterables and offsets didn't match") + + staggered = [] + for it, n in zip(iterables, offsets): + if n < 0: + staggered.append(chain(repeat(fillvalue, -n), it)) + elif n > 0: + staggered.append(islice(it, n, None)) + else: + staggered.append(it) + + if longest: + return zip_longest(*staggered, fillvalue=fillvalue) + + return zip(*staggered) + + +def sort_together(iterables, key_list=(0,), reverse=False): + """Return the input iterables sorted together, with *key_list* as the + priority for sorting. All iterables are trimmed to the length of the + shortest one. + + This can be used like the sorting function in a spreadsheet. If each + iterable represents a column of data, the key list determines which + columns are used for sorting. + + By default, all iterables are sorted using the ``0``-th iterable:: + + >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')] + >>> sort_together(iterables) + [(1, 2, 3, 4), ('d', 'c', 'b', 'a')] + + Set a different key list to sort according to another iterable. + Specifying mutliple keys dictates how ties are broken:: + + >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')] + >>> sort_together(iterables, key_list=(1, 2)) + [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')] + + Set *reverse* to ``True`` to sort in descending order. + + >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) + [(3, 2, 1), ('a', 'b', 'c')] + + """ + return list(zip(*sorted(zip(*iterables), + key=itemgetter(*key_list), + reverse=reverse))) + + +def divide(n, iterable): + """Divide the elements from *iterable* into *n* parts, maintaining + order. 
+ + >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 2, 3] + >>> list(group_2) + [4, 5, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 2, 3], [4, 5], [6, 7]] + + If the length of the iterable is smaller than n, then the last returned + iterables will be empty: + + >>> children = divide(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function will exhaust the iterable before returning and may require + significant storage. If order is not important, see :func:`distribute`, + which does not first pull the iterable into memory. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + seq = tuple(iterable) + q, r = divmod(len(seq), n) + + ret = [] + for i in range(n): + start = (i * q) + (i if i < r else r) + stop = ((i + 1) * q) + (i + 1 if i + 1 < r else r) + ret.append(iter(seq[start:stop])) + + return ret + + +def always_iterable(obj, base_type=(text_type, binary_type)): + """If *obj* is iterable, return an iterator over its items:: + + >>> obj = (1, 2, 3) + >>> list(always_iterable(obj)) + [1, 2, 3] + + If *obj* is not iterable, return a one-item iterable containing *obj*:: + + >>> obj = 1 + >>> list(always_iterable(obj)) + [1] + + If *obj* is ``None``, return an empty iterable: + + >>> obj = None + >>> list(always_iterable(None)) + [] + + By default, binary and text strings are not considered iterable:: + + >>> obj = 'foo' + >>> list(always_iterable(obj)) + ['foo'] + + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. + + >>> obj = {'a': 1} + >>> list(always_iterable(obj)) # Iterate over the dict's keys + ['a'] + >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit + [{'a': 1}] + + Set *base_type* to ``None`` to avoid any special handling and treat objects + Python considers iterable as iterable: + + >>> obj = 'foo' + >>> list(always_iterable(obj, base_type=None)) + ['f', 'o', 'o'] + """ + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) + + +def adjacent(predicate, iterable, distance=1): + """Return an iterable over `(bool, item)` tuples where the `item` is + drawn from *iterable* and the `bool` indicates whether + that item satisfies the *predicate* or is adjacent to an item that does. + + For example, to find whether items are adjacent to a ``3``:: + + >>> list(adjacent(lambda x: x == 3, range(6))) + [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)] + + Set *distance* to change what counts as adjacent. For example, to find + whether items are two places away from a ``3``: + + >>> list(adjacent(lambda x: x == 3, range(6), distance=2)) + [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)] + + This is useful for contextualizing the results of a search function. + For example, a code comparison tool might want to identify lines that + have changed, but also surrounding lines to give the viewer of the diff + context. + + The predicate function will only be called once for each item in the + iterable. + + See also :func:`groupby_transform`, which can be used with this function + to group ranges of items with the same `bool` value. 
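# Worked trace (separate from the vendored code above) of how adjacent()
# arrives at its first doctest result: the predicate flags are padded with
# `distance` False values on each side, a window of width 2*distance + 1
# slides over them, and an item is marked True when any flag in its window
# is True.
#
#   pred(x) for range(6), x == 3 :  F F F T F F
#   padded (distance=1)          :  F F F F T F F F
#   any() over width-3 windows   :  F F T T T F
#   zipped with the items        :  [(False, 0), (False, 1), (True, 2),
#                                    (True, 3), (True, 4), (False, 5)]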
+ + """ + # Allow distance=0 mainly for testing that it reproduces results with map() + if distance < 0: + raise ValueError('distance must be at least 0') + + i1, i2 = tee(iterable) + padding = [False] * distance + selected = chain(padding, map(predicate, i1), padding) + adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1)) + return zip(adjacent_to_selected, i2) + + +def groupby_transform(iterable, keyfunc=None, valuefunc=None): + """An extension of :func:`itertools.groupby` that transforms the values of + *iterable* after grouping them. + *keyfunc* is a function used to compute a grouping key for each item. + *valuefunc* is a function for transforming the items after grouping. + + >>> iterable = 'AaaABbBCcA' + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: x.lower() + >>> grouper = groupby_transform(iterable, keyfunc, valuefunc) + >>> [(k, ''.join(g)) for k, g in grouper] + [('A', 'aaaa'), ('B', 'bbb'), ('C', 'cc'), ('A', 'a')] + + *keyfunc* and *valuefunc* default to identity functions if they are not + specified. + + :func:`groupby_transform` is useful when grouping elements of an iterable + using a separate iterable as the key. To do this, :func:`zip` the iterables + and pass a *keyfunc* that extracts the first element and a *valuefunc* + that extracts the second element:: + + >>> from operator import itemgetter + >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3] + >>> values = 'abcdefghi' + >>> iterable = zip(keys, values) + >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1)) + >>> [(k, ''.join(g)) for k, g in grouper] + [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')] + + Note that the order of items in the iterable is significant. + Only adjacent items are grouped together, so if you don't want any + duplicate groups, you should sort the iterable by the key function. + + """ + valuefunc = (lambda x: x) if valuefunc is None else valuefunc + return ((k, map(valuefunc, g)) for k, g in groupby(iterable, keyfunc)) + + +def numeric_range(*args): + """An extension of the built-in ``range()`` function whose arguments can + be any orderable numeric type. + + With only *stop* specified, *start* defaults to ``0`` and *step* + defaults to ``1``. The output items will match the type of *stop*: + + >>> list(numeric_range(3.5)) + [0.0, 1.0, 2.0, 3.0] + + With only *start* and *stop* specified, *step* defaults to ``1``. The + output items will match the type of *start*: + + >>> from decimal import Decimal + >>> start = Decimal('2.1') + >>> stop = Decimal('5.1') + >>> list(numeric_range(start, stop)) + [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')] + + With *start*, *stop*, and *step* specified the output items will match + the type of ``start + step``: + + >>> from fractions import Fraction + >>> start = Fraction(1, 2) # Start at 1/2 + >>> stop = Fraction(5, 2) # End at 5/2 + >>> step = Fraction(1, 2) # Count by 1/2 + >>> list(numeric_range(start, stop, step)) + [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)] + + If *step* is zero, ``ValueError`` is raised. Negative steps are supported: + + >>> list(numeric_range(3, -1, -1.0)) + [3.0, 2.0, 1.0, 0.0] + + Be aware of the limitations of floating point numbers; the representation + of the yielded numbers may be surprising. 
+ + """ + argc = len(args) + if argc == 1: + stop, = args + start = type(stop)(0) + step = 1 + elif argc == 2: + start, stop = args + step = 1 + elif argc == 3: + start, stop, step = args + else: + err_msg = 'numeric_range takes at most 3 arguments, got {}' + raise TypeError(err_msg.format(argc)) + + values = (start + (step * n) for n in count()) + if step > 0: + return takewhile(partial(gt, stop), values) + elif step < 0: + return takewhile(partial(lt, stop), values) + else: + raise ValueError('numeric_range arg 3 must not be zero') + + +def count_cycle(iterable, n=None): + """Cycle through the items from *iterable* up to *n* times, yielding + the number of completed cycles along with each item. If *n* is omitted the + process repeats indefinitely. + + >>> list(count_cycle('AB', 3)) + [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')] + + """ + iterable = tuple(iterable) + if not iterable: + return iter(()) + counter = count() if n is None else range(n) + return ((i, item) for i in counter for item in iterable) + + +def locate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``. + + *pred* defaults to :func:`bool`, which will select truthy items: + + >>> list(locate([0, 1, 1, 0, 1, 0, 0])) + [1, 2, 4] + + Set *pred* to a custom function to, e.g., find the indexes for a particular + item. + + >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b')) + [1, 3] + + If *window_size* is given, then the *pred* function will be called with + that many items. This enables searching for sub-sequences: + + >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + >>> pred = lambda *args: args == (1, 2, 3) + >>> list(locate(iterable, pred=pred, window_size=3)) + [1, 5, 9] + + Use with :func:`seekable` to find indexes and then retrieve the associated + items: + + >>> from itertools import count + >>> from more_itertools import seekable + >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count()) + >>> it = seekable(source) + >>> pred = lambda x: x > 100 + >>> indexes = locate(it, pred=pred) + >>> i = next(indexes) + >>> it.seek(i) + >>> next(it) + 106 + + """ + if window_size is None: + return compress(count(), map(pred, iterable)) + + if window_size < 1: + raise ValueError('window size must be at least 1') + + it = windowed(iterable, window_size, fillvalue=_marker) + return compress(count(), starmap(pred, it)) + + +def lstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the beginning + for which *pred* returns ``True``. + + For example, to remove a set of items from the start of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(lstrip(iterable, pred)) + [1, 2, None, 3, False, None] + + This function is analogous to to :func:`str.lstrip`, and is essentially + an wrapper for :func:`itertools.dropwhile`. + + """ + return dropwhile(pred, iterable) + + +def rstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the end + for which *pred* returns ``True``. + + For example, to remove a set of items from the end of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(rstrip(iterable, pred)) + [None, False, None, 1, 2, None, 3] + + This function is analogous to :func:`str.rstrip`. 
+ + """ + cache = [] + cache_append = cache.append + for x in iterable: + if pred(x): + cache_append(x) + else: + for y in cache: + yield y + del cache[:] + yield x + + +def strip(iterable, pred): + """Yield the items from *iterable*, but strip any from the + beginning and end for which *pred* returns ``True``. + + For example, to remove a set of items from both ends of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(strip(iterable, pred)) + [1, 2, None, 3] + + This function is analogous to :func:`str.strip`. + + """ + return rstrip(lstrip(iterable, pred), pred) + + +def islice_extended(iterable, *args): + """An extension of :func:`itertools.islice` that supports negative values + for *stop*, *start*, and *step*. + + >>> iterable = iter('abcdefgh') + >>> list(islice_extended(iterable, -4, -1)) + ['e', 'f', 'g'] + + Slices with negative values require some caching of *iterable*, but this + function takes care to minimize the amount of memory required. + + For example, you can use a negative step with an infinite iterator: + + >>> from itertools import count + >>> list(islice_extended(count(), 110, 99, -2)) + [110, 108, 106, 104, 102, 100] + + """ + s = slice(*args) + start = s.start + stop = s.stop + if s.step == 0: + raise ValueError('step argument must be a non-zero integer or None.') + step = s.step or 1 + + it = iter(iterable) + + if step > 0: + start = 0 if (start is None) else start + + if (start < 0): + # Consume all but the last -start items + cache = deque(enumerate(it, 1), maxlen=-start) + len_iter = cache[-1][0] if cache else 0 + + # Adjust start to be positive + i = max(len_iter + start, 0) + + # Adjust stop to be positive + if stop is None: + j = len_iter + elif stop >= 0: + j = min(stop, len_iter) + else: + j = max(len_iter + stop, 0) + + # Slice the cache + n = j - i + if n <= 0: + return + + for index, item in islice(cache, 0, n, step): + yield item + elif (stop is not None) and (stop < 0): + # Advance to the start position + next(islice(it, start, start), None) + + # When stop is negative, we have to carry -stop items while + # iterating + cache = deque(islice(it, -stop), maxlen=-stop) + + for index, item in enumerate(it): + cached_item = cache.popleft() + if index % step == 0: + yield cached_item + cache.append(item) + else: + # When both start and stop are positive we have the normal case + for item in islice(it, start, stop, step): + yield item + else: + start = -1 if (start is None) else start + + if (stop is not None) and (stop < 0): + # Consume all but the last items + n = -stop - 1 + cache = deque(enumerate(it, 1), maxlen=n) + len_iter = cache[-1][0] if cache else 0 + + # If start and stop are both negative they are comparable and + # we can just slice. Otherwise we can adjust start to be negative + # and then slice. + if start < 0: + i, j = start, stop + else: + i, j = min(start - len_iter, -1), None + + for index, item in list(cache)[i:j:step]: + yield item + else: + # Advance to the stop position + if stop is not None: + m = stop + 1 + next(islice(it, m, m), None) + + # stop is positive, so if start is negative they are not comparable + # and we need the rest of the items. + if start < 0: + i = start + n = None + # stop is None and start is positive, so we just need items up to + # the start index. + elif stop is None: + i = None + n = start + 1 + # Both stop and start are positive, so they are comparable. 
+ else: + i = None + n = start - stop + if n <= 0: + return + + cache = list(islice(it, n)) + + for item in cache[i::step]: + yield item + + +def always_reversible(iterable): + """An extension of :func:`reversed` that supports all iterables, not + just those which implement the ``Reversible`` or ``Sequence`` protocols. + + >>> print(*always_reversible(x for x in range(3))) + 2 1 0 + + If the iterable is already reversible, this function returns the + result of :func:`reversed()`. If the iterable is not reversible, + this function will cache the remaining items in the iterable and + yield them in reverse order, which may require significant storage. + """ + try: + return reversed(iterable) + except TypeError: + return reversed(list(iterable)) + + +def consecutive_groups(iterable, ordering=lambda x: x): + """Yield groups of consecutive items using :func:`itertools.groupby`. + The *ordering* function determines whether two items are adjacent by + returning their position. + + By default, the ordering function is the identity function. This is + suitable for finding runs of numbers: + + >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40] + >>> for group in consecutive_groups(iterable): + ... print(list(group)) + [1] + [10, 11, 12] + [20] + [30, 31, 32, 33] + [40] + + For finding runs of adjacent letters, try using the :meth:`index` method + of a string of letters: + + >>> from string import ascii_lowercase + >>> iterable = 'abcdfgilmnop' + >>> ordering = ascii_lowercase.index + >>> for group in consecutive_groups(iterable, ordering): + ... print(list(group)) + ['a', 'b', 'c', 'd'] + ['f', 'g'] + ['i'] + ['l', 'm', 'n', 'o', 'p'] + + """ + for k, g in groupby( + enumerate(iterable), key=lambda x: x[0] - ordering(x[1]) + ): + yield map(itemgetter(1), g) + + +def difference(iterable, func=sub): + """By default, compute the first difference of *iterable* using + :func:`operator.sub`. + + >>> iterable = [0, 1, 3, 6, 10] + >>> list(difference(iterable)) + [0, 1, 2, 3, 4] + + This is the opposite of :func:`accumulate`'s default behavior: + + >>> from more_itertools import accumulate + >>> iterable = [0, 1, 2, 3, 4] + >>> list(accumulate(iterable)) + [0, 1, 3, 6, 10] + >>> list(difference(accumulate(iterable))) + [0, 1, 2, 3, 4] + + By default *func* is :func:`operator.sub`, but other functions can be + specified. They will be applied as follows:: + + A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ... + + For example, to do progressive division: + + >>> iterable = [1, 2, 6, 24, 120] # Factorial sequence + >>> func = lambda x, y: x // y + >>> list(difference(iterable, func)) + [1, 2, 3, 4, 5] + + """ + a, b = tee(iterable) + try: + item = next(b) + except StopIteration: + return iter([]) + return chain([item], map(lambda x: func(x[1], x[0]), zip(a, b))) + + +class SequenceView(Sequence): + """Return a read-only view of the sequence object *target*. + + :class:`SequenceView` objects are analagous to Python's built-in + "dictionary view" types. They provide a dynamic view of a sequence's items, + meaning that when the sequence updates, so does the view. + + >>> seq = ['0', '1', '2'] + >>> view = SequenceView(seq) + >>> view + SequenceView(['0', '1', '2']) + >>> seq.append('3') + >>> view + SequenceView(['0', '1', '2', '3']) + + Sequence views support indexing, slicing, and length queries. 
They act + like the underlying sequence, except they don't allow assignment: + + >>> view[1] + '1' + >>> view[1:-1] + ['1', '2'] + >>> len(view) + 4 + + Sequence views are useful as an alternative to copying, as they don't + require (much) extra storage. + + """ + def __init__(self, target): + if not isinstance(target, Sequence): + raise TypeError + self._target = target + + def __getitem__(self, index): + return self._target[index] + + def __len__(self): + return len(self._target) + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, repr(self._target)) + + +class seekable(object): + """Wrap an iterator to allow for seeking backward and forward. This + progressively caches the items in the source iterable so they can be + re-visited. + + Call :meth:`seek` with an index to seek to that position in the source + iterable. + + To "reset" an iterator, seek to ``0``: + + >>> from itertools import count + >>> it = seekable((str(n) for n in count())) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> it.seek(0) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> next(it) + '3' + + You can also seek forward: + + >>> it = seekable((str(n) for n in range(20))) + >>> it.seek(10) + >>> next(it) + '10' + >>> it.seek(20) # Seeking past the end of the source isn't a problem + >>> list(it) + [] + >>> it.seek(0) # Resetting works even after hitting the end + >>> next(it), next(it), next(it) + ('0', '1', '2') + + The cache grows as the source iterable progresses, so beware of wrapping + very large or infinite iterables. + + You may view the contents of the cache with the :meth:`elements` method. + That returns a :class:`SequenceView`, a view that updates automatically: + + >>> it = seekable((str(n) for n in range(10))) + >>> next(it), next(it), next(it) + ('0', '1', '2') + >>> elements = it.elements() + >>> elements + SequenceView(['0', '1', '2']) + >>> next(it) + '3' + >>> elements + SequenceView(['0', '1', '2', '3']) + + """ + + def __init__(self, iterable): + self._source = iter(iterable) + self._cache = [] + self._index = None + + def __iter__(self): + return self + + def __next__(self): + if self._index is not None: + try: + item = self._cache[self._index] + except IndexError: + self._index = None + else: + self._index += 1 + return item + + item = next(self._source) + self._cache.append(item) + return item + + next = __next__ + + def elements(self): + return SequenceView(self._cache) + + def seek(self, index): + self._index = index + remainder = index - len(self._cache) + if remainder > 0: + consume(self, remainder) + + +class run_length(object): + """ + :func:`run_length.encode` compresses an iterable with run-length encoding. + It yields groups of repeated items with the count of how many times they + were repeated: + + >>> uncompressed = 'abbcccdddd' + >>> list(run_length.encode(uncompressed)) + [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + + :func:`run_length.decode` decompresses an iterable that was previously + compressed with run-length encoding. 
It yields the items of the + decompressed iterable: + + >>> compressed = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + >>> list(run_length.decode(compressed)) + ['a', 'b', 'b', 'c', 'c', 'c', 'd', 'd', 'd', 'd'] + + """ + + @staticmethod + def encode(iterable): + return ((k, ilen(g)) for k, g in groupby(iterable)) + + @staticmethod + def decode(iterable): + return chain.from_iterable(repeat(k, n) for k, n in iterable) + + +def exactly_n(iterable, n, predicate=bool): + """Return ``True`` if exactly ``n`` items in the iterable are ``True`` + according to the *predicate* function. + + >>> exactly_n([True, True, False], 2) + True + >>> exactly_n([True, True, False], 1) + False + >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3) + True + + The iterable will be advanced until ``n + 1`` truthy items are encountered, + so avoid calling it on infinite iterables. + + """ + return len(take(n + 1, filter(predicate, iterable))) == n + + +def circular_shifts(iterable): + """Return a list of circular shifts of *iterable*. + + >>> circular_shifts(range(4)) + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + """ + lst = list(iterable) + return take(len(lst), windowed(cycle(lst), len(lst))) + + +def make_decorator(wrapping_func, result_index=0): + """Return a decorator version of *wrapping_func*, which is a function that + modifies an iterable. *result_index* is the position in that function's + signature where the iterable goes. + + This lets you use itertools on the "production end," i.e. at function + definition. This can augment what the function returns without changing the + function's code. + + For example, to produce a decorator version of :func:`chunked`: + + >>> from more_itertools import chunked + >>> chunker = make_decorator(chunked, result_index=0) + >>> @chunker(3) + ... def iter_range(n): + ... return iter(range(n)) + ... + >>> list(iter_range(9)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + + To only allow truthy items to be returned: + + >>> truth_serum = make_decorator(filter, result_index=1) + >>> @truth_serum(bool) + ... def boolean_test(): + ... return [0, 1, '', ' ', False, True] + ... + >>> list(boolean_test()) + [1, ' ', True] + + The :func:`peekable` and :func:`seekable` wrappers make for practical + decorators: + + >>> from more_itertools import peekable + >>> peekable_function = make_decorator(peekable) + >>> @peekable_function() + ... def str_range(*args): + ... return (str(x) for x in range(*args)) + ... + >>> it = str_range(1, 20, 2) + >>> next(it), next(it), next(it) + ('1', '3', '5') + >>> it.peek() + '7' + >>> next(it) + '7' + + """ + # See https://sites.google.com/site/bbayles/index/decorator_factory for + # notes on how this works. + def decorator(*wrapping_args, **wrapping_kwargs): + def outer_wrapper(f): + def inner_wrapper(*args, **kwargs): + result = f(*args, **kwargs) + wrapping_args_ = list(wrapping_args) + wrapping_args_.insert(result_index, result) + return wrapping_func(*wrapping_args_, **wrapping_kwargs) + + return inner_wrapper + + return outer_wrapper + + return decorator + + +def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None): + """Return a dictionary that maps the items in *iterable* to categories + defined by *keyfunc*, transforms them with *valuefunc*, and + then summarizes them by category with *reducefunc*. + + *valuefunc* defaults to the identity function if it is unspecified. 
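For orientation, the three-stage pipeline described above (categorize, transform, summarize) is roughly what the following hand-rolled sketch does with a plain dict; ``categorize`` is an illustrative name, not part of the library, and the real implementation below differs in detail:

    def categorize(iterable, keyfunc, valuefunc=lambda x: x, reducefunc=None):
        groups = {}
        for item in iterable:
            groups.setdefault(keyfunc(item), []).append(valuefunc(item))
        if reducefunc is not None:
            groups = {key: reducefunc(values) for key, values in groups.items()}
        return groups

    assert categorize('abbccc', str.upper, reducefunc=len) == {'A': 1, 'B': 2, 'C': 3}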
+ If *reducefunc* is unspecified, no summarization takes place: + + >>> keyfunc = lambda x: x.upper() + >>> result = map_reduce('abbccc', keyfunc) + >>> sorted(result.items()) + [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])] + + Specifying *valuefunc* transforms the categorized items: + + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: 1 + >>> result = map_reduce('abbccc', keyfunc, valuefunc) + >>> sorted(result.items()) + [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])] + + Specifying *reducefunc* summarizes the categorized items: + + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: 1 + >>> reducefunc = sum + >>> result = map_reduce('abbccc', keyfunc, valuefunc, reducefunc) + >>> sorted(result.items()) + [('A', 1), ('B', 2), ('C', 3)] + + You may want to filter the input iterable before applying the map/reduce + procedure: + + >>> all_items = range(30) + >>> items = [x for x in all_items if 10 <= x <= 20] # Filter + >>> keyfunc = lambda x: x % 2 # Evens map to 0; odds to 1 + >>> categories = map_reduce(items, keyfunc=keyfunc) + >>> sorted(categories.items()) + [(0, [10, 12, 14, 16, 18, 20]), (1, [11, 13, 15, 17, 19])] + >>> summaries = map_reduce(items, keyfunc=keyfunc, reducefunc=sum) + >>> sorted(summaries.items()) + [(0, 90), (1, 75)] + + Note that all items in the iterable are gathered into a list before the + summarization step, which may require significant storage. + + The returned object is a :obj:`collections.defaultdict` with the + ``default_factory`` set to ``None``, such that it behaves like a normal + dictionary. + + """ + valuefunc = (lambda x: x) if (valuefunc is None) else valuefunc + + ret = defaultdict(list) + for item in iterable: + key = keyfunc(item) + value = valuefunc(item) + ret[key].append(value) + + if reducefunc is not None: + for key, value_list in ret.items(): + ret[key] = reducefunc(value_list) + + ret.default_factory = None + return ret + + +def rlocate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``, starting from the right and moving left. + + *pred* defaults to :func:`bool`, which will select truthy items: + + >>> list(rlocate([0, 1, 1, 0, 1, 0, 0])) # Truthy at 1, 2, and 4 + [4, 2, 1] + + Set *pred* to a custom function to, e.g., find the indexes for a particular + item: + + >>> iterable = iter('abcb') + >>> pred = lambda x: x == 'b' + >>> list(rlocate(iterable, pred)) + [3, 1] + + If *window_size* is given, then the *pred* function will be called with + that many items. This enables searching for sub-sequences: + + >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + >>> pred = lambda *args: args == (1, 2, 3) + >>> list(rlocate(iterable, pred=pred, window_size=3)) + [9, 5, 1] + + Beware, this function won't return anything for infinite iterables. + If *iterable* is reversible, ``rlocate`` will reverse it and search from + the right. Otherwise, it will search from the left and return the results + in reverse order. + + See :func:`locate` to for other example applications. + + """ + if window_size is None: + try: + len_iter = len(iterable) + return ( + len_iter - i - 1 for i in locate(reversed(iterable), pred) + ) + except TypeError: + pass + + return reversed(list(locate(iterable, pred, window_size))) + + +def replace(iterable, pred, substitutes, count=None, window_size=1): + """Yield the items from *iterable*, replacing the items for which *pred* + returns ``True`` with the items from the iterable *substitutes*. 
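For the default ``window_size=1`` with no *count* limit, the substitution just described boils down to a per-item swap; a hedged sketch (``replace_simple`` is an illustrative name, and the actual implementation below also handles windows and replacement counts):

    def replace_simple(iterable, pred, substitutes):
        # Illustrative sketch for window_size=1 with no count limit
        substitutes = tuple(substitutes)
        for item in iterable:
            if pred(item):
                for s in substitutes:
                    yield s
            else:
                yield item

    assert list(replace_simple([1, 1, 0, 1], lambda x: x == 0, (2, 3))) == [1, 1, 2, 3, 1]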
+ + >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1] + >>> pred = lambda x: x == 0 + >>> substitutes = (2, 3) + >>> list(replace(iterable, pred, substitutes)) + [1, 1, 2, 3, 1, 1, 2, 3, 1, 1] + + If *count* is given, the number of replacements will be limited: + + >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0] + >>> pred = lambda x: x == 0 + >>> substitutes = [None] + >>> list(replace(iterable, pred, substitutes, count=2)) + [1, 1, None, 1, 1, None, 1, 1, 0] + + Use *window_size* to control the number of items passed as arguments to + *pred*. This allows for locating and replacing subsequences. + + >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5] + >>> window_size = 3 + >>> pred = lambda *args: args == (0, 1, 2) # 3 items passed to pred + >>> substitutes = [3, 4] # Splice in these items + >>> list(replace(iterable, pred, substitutes, window_size=window_size)) + [3, 4, 5, 3, 4, 5] + + """ + if window_size < 1: + raise ValueError('window_size must be at least 1') + + # Save the substitutes iterable, since it's used more than once + substitutes = tuple(substitutes) + + # Add padding such that the number of windows matches the length of the + # iterable + it = chain(iterable, [_marker] * (window_size - 1)) + windows = windowed(it, window_size) + + n = 0 + for w in windows: + # If the current window matches our predicate (and we haven't hit + # our maximum number of replacements), splice in the substitutes + # and then consume the following windows that overlap with this one. + # For example, if the iterable is (0, 1, 2, 3, 4...) + # and the window size is 2, we have (0, 1), (1, 2), (2, 3)... + # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2) + if pred(*w): + if (count is None) or (n < count): + n += 1 + for s in substitutes: + yield s + consume(windows, window_size - 1) + continue + + # If there was no match (or we've reached the replacement limit), + # yield the first item from the window. + if w and (w[0] is not _marker): + yield w[0] diff --git a/libraries/more_itertools/recipes.py b/libraries/more_itertools/recipes.py new file mode 100644 index 00000000..3a7706cb --- /dev/null +++ b/libraries/more_itertools/recipes.py @@ -0,0 +1,565 @@ +"""Imported from the recipes section of the itertools documentation. + +All functions taken from the recipes section of the itertools library docs +[1]_. +Some backward-compatible usability improvements have been made. + +.. [1] http://docs.python.org/library/itertools.html#recipes + +""" +from collections import deque +from itertools import ( + chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee +) +import operator +from random import randrange, sample, choice + +from six import PY2 +from six.moves import filter, filterfalse, map, range, zip, zip_longest + +__all__ = [ + 'accumulate', + 'all_equal', + 'consume', + 'dotproduct', + 'first_true', + 'flatten', + 'grouper', + 'iter_except', + 'ncycles', + 'nth', + 'nth_combination', + 'padnone', + 'pairwise', + 'partition', + 'powerset', + 'prepend', + 'quantify', + 'random_combination_with_replacement', + 'random_combination', + 'random_permutation', + 'random_product', + 'repeatfunc', + 'roundrobin', + 'tabulate', + 'tail', + 'take', + 'unique_everseen', + 'unique_justseen', +] + + +def accumulate(iterable, func=operator.add): + """ + Return an iterator whose items are the accumulated results of a function + (specified by the optional *func* argument) that takes two arguments. + By default, returns accumulated sums with :func:`operator.add`. 
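A hedged note on how this relates to ``functools.reduce``: ``reduce`` collapses the iterable to the final value only, whereas ``accumulate`` yields each intermediate result along the way:

    import operator
    from functools import reduce

    data = [1, 2, 3, 4, 5]
    assert list(accumulate(data)) == [1, 3, 6, 10, 15]   # every running total
    assert reduce(operator.add, data) == 15              # only the final total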
+ + >>> list(accumulate([1, 2, 3, 4, 5])) # Running sum + [1, 3, 6, 10, 15] + >>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product + [1, 2, 6] + >>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum + [0, 1, 1, 2, 3, 3] + + This function is available in the ``itertools`` module for Python 3.2 and + greater. + + """ + it = iter(iterable) + try: + total = next(it) + except StopIteration: + return + else: + yield total + + for element in it: + total = func(total, element) + yield total + + +def take(n, iterable): + """Return first *n* items of the iterable as a list. + + >>> take(3, range(10)) + [0, 1, 2] + >>> take(5, range(3)) + [0, 1, 2] + + Effectively a short replacement for ``next`` based iterator consumption + when you want more than one item, but less than the whole iterator. + + """ + return list(islice(iterable, n)) + + +def tabulate(function, start=0): + """Return an iterator over the results of ``func(start)``, + ``func(start + 1)``, ``func(start + 2)``... + + *func* should be a function that accepts one integer argument. + + If *start* is not specified it defaults to 0. It will be incremented each + time the iterator is advanced. + + >>> square = lambda x: x ** 2 + >>> iterator = tabulate(square, -3) + >>> take(4, iterator) + [9, 4, 1, 0] + + """ + return map(function, count(start)) + + +def tail(n, iterable): + """Return an iterator over the last *n* items of *iterable*. + + >>> t = tail(3, 'ABCDEFG') + >>> list(t) + ['E', 'F', 'G'] + + """ + return iter(deque(iterable, maxlen=n)) + + +def consume(iterator, n=None): + """Advance *iterable* by *n* steps. If *n* is ``None``, consume it + entirely. + + Efficiently exhausts an iterator without returning values. Defaults to + consuming the whole iterator, but an optional second argument may be + provided to limit consumption. + + >>> i = (x for x in range(10)) + >>> next(i) + 0 + >>> consume(i, 3) + >>> next(i) + 4 + >>> consume(i) + >>> next(i) + Traceback (most recent call last): + File "<stdin>", line 1, in <module> + StopIteration + + If the iterator has fewer items remaining than the provided limit, the + whole iterator will be consumed. + + >>> i = (x for x in range(3)) + >>> consume(i, 5) + >>> next(i) + Traceback (most recent call last): + File "<stdin>", line 1, in <module> + StopIteration + + """ + # Use functions that consume iterators at C speed. + if n is None: + # feed the entire iterator into a zero-length deque + deque(iterator, maxlen=0) + else: + # advance to the empty slice starting at position n + next(islice(iterator, n, n), None) + + +def nth(iterable, n, default=None): + """Returns the nth item or a default value. + + >>> l = range(10) + >>> nth(l, 3) + 3 + >>> nth(l, 20, "zebra") + 'zebra' + + """ + return next(islice(iterable, n, None), default) + + +def all_equal(iterable): + """ + Returns ``True`` if all the elements are equal to each other. + + >>> all_equal('aaaa') + True + >>> all_equal('aaab') + False + + """ + g = groupby(iterable) + return next(g, True) and not next(g, False) + + +def quantify(iterable, pred=bool): + """Return the how many times the predicate is true. + + >>> quantify([True, False, True]) + 2 + + """ + return sum(map(pred, iterable)) + + +def padnone(iterable): + """Returns the sequence of elements and then returns ``None`` indefinitely. + + >>> take(5, padnone(range(3))) + [0, 1, 2, None, None] + + Useful for emulating the behavior of the built-in :func:`map` function. + + See also :func:`padded`. 
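To make the ``map`` remark concrete: padding the shorter argument with ``None`` keeps ``map`` running until the longer one is exhausted. A small hedged sketch with illustrative data:

    names = ['ann', 'bob']
    scores = [10, 20, 30]
    paired = list(map(lambda name, score: (name, score), padnone(names), scores))
    assert paired == [('ann', 10), ('bob', 20), (None, 30)]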
+ + """ + return chain(iterable, repeat(None)) + + +def ncycles(iterable, n): + """Returns the sequence elements *n* times + + >>> list(ncycles(["a", "b"], 3)) + ['a', 'b', 'a', 'b', 'a', 'b'] + + """ + return chain.from_iterable(repeat(tuple(iterable), n)) + + +def dotproduct(vec1, vec2): + """Returns the dot product of the two iterables. + + >>> dotproduct([10, 10], [20, 20]) + 400 + + """ + return sum(map(operator.mul, vec1, vec2)) + + +def flatten(listOfLists): + """Return an iterator flattening one level of nesting in a list of lists. + + >>> list(flatten([[0, 1], [2, 3]])) + [0, 1, 2, 3] + + See also :func:`collapse`, which can flatten multiple levels of nesting. + + """ + return chain.from_iterable(listOfLists) + + +def repeatfunc(func, times=None, *args): + """Call *func* with *args* repeatedly, returning an iterable over the + results. + + If *times* is specified, the iterable will terminate after that many + repetitions: + + >>> from operator import add + >>> times = 4 + >>> args = 3, 5 + >>> list(repeatfunc(add, times, *args)) + [8, 8, 8, 8] + + If *times* is ``None`` the iterable will not terminate: + + >>> from random import randrange + >>> times = None + >>> args = 1, 11 + >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP + [2, 4, 8, 1, 8, 4] + + """ + if times is None: + return starmap(func, repeat(args)) + return starmap(func, repeat(args, times)) + + +def pairwise(iterable): + """Returns an iterator of paired items, overlapping, from the original + + >>> take(4, pairwise(count())) + [(0, 1), (1, 2), (2, 3), (3, 4)] + + """ + a, b = tee(iterable) + next(b, None) + return zip(a, b) + + +def grouper(n, iterable, fillvalue=None): + """Collect data into fixed-length chunks or blocks. + + >>> list(grouper(3, 'ABCDEFG', 'x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + + """ + args = [iter(iterable)] * n + return zip_longest(fillvalue=fillvalue, *args) + + +def roundrobin(*iterables): + """Yields an item from each iterable, alternating between them. + + >>> list(roundrobin('ABC', 'D', 'EF')) + ['A', 'D', 'E', 'B', 'F', 'C'] + + This function produces the same output as :func:`interleave_longest`, but + may perform better for some inputs (in particular when the number of + iterables is small). + + """ + # Recipe credited to George Sakkis + pending = len(iterables) + if PY2: + nexts = cycle(iter(it).next for it in iterables) + else: + nexts = cycle(iter(it).__next__ for it in iterables) + while pending: + try: + for next in nexts: + yield next() + except StopIteration: + pending -= 1 + nexts = cycle(islice(nexts, pending)) + + +def partition(pred, iterable): + """ + Returns a 2-tuple of iterables derived from the input iterable. + The first yields the items that have ``pred(item) == False``. + The second yields the items that have ``pred(item) == True``. + + >>> is_odd = lambda x: x % 2 != 0 + >>> iterable = range(10) + >>> even_items, odd_items = partition(is_odd, iterable) + >>> list(even_items), list(odd_items) + ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) + + """ + # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + t1, t2 = tee(iterable) + return filterfalse(pred, t1), filter(pred, t2) + + +def powerset(iterable): + """Yields all possible subsets of the iterable. + + >>> list(powerset([1,2,3])) + [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] + + """ + s = list(iterable) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + + +def unique_everseen(iterable, key=None): + """ + Yield unique elements, preserving order. 
+ + >>> list(unique_everseen('AAAABBBCCDAABBB')) + ['A', 'B', 'C', 'D'] + >>> list(unique_everseen('ABBCcAD', str.lower)) + ['A', 'B', 'C', 'D'] + + Sequences with a mix of hashable and unhashable items can be used. + The function will be slower (i.e., `O(n^2)`) for unhashable items. + + """ + seenset = set() + seenset_add = seenset.add + seenlist = [] + seenlist_add = seenlist.append + if key is None: + for element in iterable: + try: + if element not in seenset: + seenset_add(element) + yield element + except TypeError: + if element not in seenlist: + seenlist_add(element) + yield element + else: + for element in iterable: + k = key(element) + try: + if k not in seenset: + seenset_add(k) + yield element + except TypeError: + if k not in seenlist: + seenlist_add(k) + yield element + + +def unique_justseen(iterable, key=None): + """Yields elements in order, ignoring serial duplicates + + >>> list(unique_justseen('AAAABBBCCDAABBB')) + ['A', 'B', 'C', 'D', 'A', 'B'] + >>> list(unique_justseen('ABBCcAD', str.lower)) + ['A', 'B', 'C', 'A', 'D'] + + """ + return map(next, map(operator.itemgetter(1), groupby(iterable, key))) + + +def iter_except(func, exception, first=None): + """Yields results from a function repeatedly until an exception is raised. + + Converts a call-until-exception interface to an iterator interface. + Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel + to end the loop. + + >>> l = [0, 1, 2] + >>> list(iter_except(l.pop, IndexError)) + [2, 1, 0] + + """ + try: + if first is not None: + yield first() + while 1: + yield func() + except exception: + pass + + +def first_true(iterable, default=False, pred=None): + """ + Returns the first true value in the iterable. + + If no true value is found, returns *default* + + If *pred* is not None, returns the first item for which + ``pred(item) == True`` . + + >>> first_true(range(10)) + 1 + >>> first_true(range(10), pred=lambda x: x > 5) + 6 + >>> first_true(range(10), default='missing', pred=lambda x: x > 9) + 'missing' + + """ + return next(filter(pred, iterable), default) + + +def random_product(*args, **kwds): + """Draw an item at random from each of the input iterables. + + >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP + ('c', 3, 'Z') + + If *repeat* is provided as a keyword argument, that many items will be + drawn from each iterable. + + >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP + ('a', 2, 'd', 3) + + This equivalent to taking a random selection from + ``itertools.product(*args, **kwarg)``. + + """ + pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1) + return tuple(choice(pool) for pool in pools) + + +def random_permutation(iterable, r=None): + """Return a random *r* length permutation of the elements in *iterable*. + + If *r* is not specified or is ``None``, then *r* defaults to the length of + *iterable*. + + >>> random_permutation(range(5)) # doctest:+SKIP + (3, 4, 0, 1, 2) + + This equivalent to taking a random selection from + ``itertools.permutations(iterable, r)``. + + """ + pool = tuple(iterable) + r = len(pool) if r is None else r + return tuple(sample(pool, r)) + + +def random_combination(iterable, r): + """Return a random *r* length subsequence of the elements in *iterable*. + + >>> random_combination(range(5), 3) # doctest:+SKIP + (2, 3, 4) + + This equivalent to taking a random selection from + ``itertools.combinations(iterable, r)``. 
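One hedged way to read that equivalence: every draw is one of the ``C(n, r)`` tuples that ``itertools.combinations`` would produce, returned in the pool's original order:

    from itertools import combinations

    draw = random_combination('ABCDE', 3)
    assert draw in set(combinations('ABCDE', 3))   # one of the C(5, 3) == 10 possibilities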
+ + """ + pool = tuple(iterable) + n = len(pool) + indices = sorted(sample(range(n), r)) + return tuple(pool[i] for i in indices) + + +def random_combination_with_replacement(iterable, r): + """Return a random *r* length subsequence of elements in *iterable*, + allowing individual elements to be repeated. + + >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP + (0, 0, 1, 2, 2) + + This equivalent to taking a random selection from + ``itertools.combinations_with_replacement(iterable, r)``. + + """ + pool = tuple(iterable) + n = len(pool) + indices = sorted(randrange(n) for i in range(r)) + return tuple(pool[i] for i in indices) + + +def nth_combination(iterable, r, index): + """Equivalent to ``list(combinations(iterable, r))[index]``. + + The subsequences of *iterable* that are of length *r* can be ordered + lexicographically. :func:`nth_combination` computes the subsequence at + sort position *index* directly, without computing the previous + subsequences. + + """ + pool = tuple(iterable) + n = len(pool) + if (r < 0) or (r > n): + raise ValueError + + c = 1 + k = min(r, n - r) + for i in range(1, k + 1): + c = c * (n - k + i) // i + + if index < 0: + index += c + + if (index < 0) or (index >= c): + raise IndexError + + result = [] + while r: + c, n, r = c * r // n, n - 1, r - 1 + while index >= c: + index -= c + c, n = c * (n - r) // n, n - 1 + result.append(pool[-1 - n]) + + return tuple(result) + + +def prepend(value, iterator): + """Yield *value*, followed by the elements in *iterator*. + + >>> value = '0' + >>> iterator = ['1', '2', '3'] + >>> list(prepend(value, iterator)) + ['0', '1', '2', '3'] + + To prepend multiple values, see :func:`itertools.chain`. + + """ + return chain([value], iterator) diff --git a/libraries/more_itertools/tests/__init__.py b/libraries/more_itertools/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libraries/more_itertools/tests/test_more.py b/libraries/more_itertools/tests/test_more.py new file mode 100644 index 00000000..a1b1e431 --- /dev/null +++ b/libraries/more_itertools/tests/test_more.py @@ -0,0 +1,2074 @@ +from __future__ import division, print_function, unicode_literals + +from collections import OrderedDict +from decimal import Decimal +from doctest import DocTestSuite +from fractions import Fraction +from functools import partial, reduce +from heapq import merge +from io import StringIO +from itertools import ( + chain, + count, + groupby, + islice, + permutations, + product, + repeat, +) +from operator import add, mul, itemgetter +from unittest import TestCase + +from six.moves import filter, map, range, zip + +import more_itertools as mi + + +def load_tests(loader, tests, ignore): + # Add the doctests + tests.addTests(DocTestSuite('more_itertools.more')) + return tests + + +class CollateTests(TestCase): + """Unit tests for ``collate()``""" + # Also accidentally tests peekable, though that could use its own tests + + def test_default(self): + """Test with the default `key` function.""" + iterables = [range(4), range(7), range(3, 6)] + self.assertEqual( + sorted(reduce(list.__add__, [list(it) for it in iterables])), + list(mi.collate(*iterables)) + ) + + def test_key(self): + """Test using a custom `key` function.""" + iterables = [range(5, 0, -1), range(4, 0, -1)] + actual = sorted( + reduce(list.__add__, [list(it) for it in iterables]), reverse=True + ) + expected = list(mi.collate(*iterables, key=lambda x: -x)) + self.assertEqual(actual, expected) + + def test_empty(self): + """Be nice if passed an 
empty list of iterables.""" + self.assertEqual([], list(mi.collate())) + + def test_one(self): + """Work when only 1 iterable is passed.""" + self.assertEqual([0, 1], list(mi.collate(range(2)))) + + def test_reverse(self): + """Test the `reverse` kwarg.""" + iterables = [range(4, 0, -1), range(7, 0, -1), range(3, 6, -1)] + + actual = sorted( + reduce(list.__add__, [list(it) for it in iterables]), reverse=True + ) + expected = list(mi.collate(*iterables, reverse=True)) + self.assertEqual(actual, expected) + + def test_alias(self): + self.assertNotEqual(merge.__doc__, mi.collate.__doc__) + self.assertNotEqual(partial.__doc__, mi.collate.__doc__) + + +class ChunkedTests(TestCase): + """Tests for ``chunked()``""" + + def test_even(self): + """Test when ``n`` divides evenly into the length of the iterable.""" + self.assertEqual( + list(mi.chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']] + ) + + def test_odd(self): + """Test when ``n`` does not divide evenly into the length of the + iterable. + + """ + self.assertEqual( + list(mi.chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']] + ) + + +class FirstTests(TestCase): + """Tests for ``first()``""" + + def test_many(self): + """Test that it works on many-item iterables.""" + # Also try it on a generator expression to make sure it works on + # whatever those return, across Python versions. + self.assertEqual(mi.first(x for x in range(4)), 0) + + def test_one(self): + """Test that it doesn't raise StopIteration prematurely.""" + self.assertEqual(mi.first([3]), 3) + + def test_empty_stop_iteration(self): + """It should raise StopIteration for empty iterables.""" + self.assertRaises(ValueError, lambda: mi.first([])) + + def test_default(self): + """It should return the provided default arg for empty iterables.""" + self.assertEqual(mi.first([], 'boo'), 'boo') + + +class IterOnlyRange: + """User-defined iterable class which only support __iter__. + + It is not specified to inherit ``object``, so indexing on a instance will + raise an ``AttributeError`` rather than ``TypeError`` in Python 2. + + >>> r = IterOnlyRange(5) + >>> r[0] + AttributeError: IterOnlyRange instance has no attribute '__getitem__' + + Note: In Python 3, ``TypeError`` will be raised because ``object`` is + inherited implicitly by default. + + >>> r[0] + TypeError: 'IterOnlyRange' object does not support indexing + """ + def __init__(self, n): + """Set the length of the range.""" + self.n = n + + def __iter__(self): + """Works same as range().""" + return iter(range(self.n)) + + +class LastTests(TestCase): + """Tests for ``last()``""" + + def test_many_nonsliceable(self): + """Test that it works on many-item non-slice-able iterables.""" + # Also try it on a generator expression to make sure it works on + # whatever those return, across Python versions. + self.assertEqual(mi.last(x for x in range(4)), 3) + + def test_one_nonsliceable(self): + """Test that it doesn't raise StopIteration prematurely.""" + self.assertEqual(mi.last(x for x in range(1)), 0) + + def test_empty_stop_iteration_nonsliceable(self): + """It should raise ValueError for empty non-slice-able iterables.""" + self.assertRaises(ValueError, lambda: mi.last(x for x in range(0))) + + def test_default_nonsliceable(self): + """It should return the provided default arg for empty non-slice-able + iterables. 
+ """ + self.assertEqual(mi.last((x for x in range(0)), 'boo'), 'boo') + + def test_many_sliceable(self): + """Test that it works on many-item slice-able iterables.""" + self.assertEqual(mi.last([0, 1, 2, 3]), 3) + + def test_one_sliceable(self): + """Test that it doesn't raise StopIteration prematurely.""" + self.assertEqual(mi.last([3]), 3) + + def test_empty_stop_iteration_sliceable(self): + """It should raise ValueError for empty slice-able iterables.""" + self.assertRaises(ValueError, lambda: mi.last([])) + + def test_default_sliceable(self): + """It should return the provided default arg for empty slice-able + iterables. + """ + self.assertEqual(mi.last([], 'boo'), 'boo') + + def test_dict(self): + """last(dic) and last(dic.keys()) should return same result.""" + dic = {'a': 1, 'b': 2, 'c': 3} + self.assertEqual(mi.last(dic), mi.last(dic.keys())) + + def test_ordereddict(self): + """last(dic) should return the last key.""" + od = OrderedDict() + od['a'] = 1 + od['b'] = 2 + od['c'] = 3 + self.assertEqual(mi.last(od), 'c') + + def test_customrange(self): + """It should work on custom class where [] raises AttributeError.""" + self.assertEqual(mi.last(IterOnlyRange(5)), 4) + + +class PeekableTests(TestCase): + """Tests for ``peekable()`` behavor not incidentally covered by testing + ``collate()`` + + """ + def test_peek_default(self): + """Make sure passing a default into ``peek()`` works.""" + p = mi.peekable([]) + self.assertEqual(p.peek(7), 7) + + def test_truthiness(self): + """Make sure a ``peekable`` tests true iff there are items remaining in + the iterable. + + """ + p = mi.peekable([]) + self.assertFalse(p) + + p = mi.peekable(range(3)) + self.assertTrue(p) + + def test_simple_peeking(self): + """Make sure ``next`` and ``peek`` advance and don't advance the + iterator, respectively. + + """ + p = mi.peekable(range(10)) + self.assertEqual(next(p), 0) + self.assertEqual(p.peek(), 1) + self.assertEqual(next(p), 1) + + def test_indexing(self): + """ + Indexing into the peekable shouldn't advance the iterator. 
+ """ + p = mi.peekable('abcdefghijkl') + + # The 0th index is what ``next()`` will return + self.assertEqual(p[0], 'a') + self.assertEqual(next(p), 'a') + + # Indexing further into the peekable shouldn't advance the itertor + self.assertEqual(p[2], 'd') + self.assertEqual(next(p), 'b') + + # The 0th index moves up with the iterator; the last index follows + self.assertEqual(p[0], 'c') + self.assertEqual(p[9], 'l') + + self.assertEqual(next(p), 'c') + self.assertEqual(p[8], 'l') + + # Negative indexing should work too + self.assertEqual(p[-2], 'k') + self.assertEqual(p[-9], 'd') + self.assertRaises(IndexError, lambda: p[-10]) + + def test_slicing(self): + """Slicing the peekable shouldn't advance the iterator.""" + seq = list('abcdefghijkl') + p = mi.peekable(seq) + + # Slicing the peekable should just be like slicing a re-iterable + self.assertEqual(p[1:4], seq[1:4]) + + # Advancing the iterator moves the slices up also + self.assertEqual(next(p), 'a') + self.assertEqual(p[1:4], seq[1:][1:4]) + + # Implicit starts and stop should work + self.assertEqual(p[:5], seq[1:][:5]) + self.assertEqual(p[:], seq[1:][:]) + + # Indexing past the end should work + self.assertEqual(p[:100], seq[1:][:100]) + + # Steps should work, including negative + self.assertEqual(p[::2], seq[1:][::2]) + self.assertEqual(p[::-1], seq[1:][::-1]) + + def test_slicing_reset(self): + """Test slicing on a fresh iterable each time""" + iterable = ['0', '1', '2', '3', '4', '5'] + indexes = list(range(-4, len(iterable) + 4)) + [None] + steps = [1, 2, 3, 4, -1, -2, -3, 4] + for slice_args in product(indexes, indexes, steps): + it = iter(iterable) + p = mi.peekable(it) + next(p) + index = slice(*slice_args) + actual = p[index] + expected = iterable[1:][index] + self.assertEqual(actual, expected, slice_args) + + def test_slicing_error(self): + iterable = '01234567' + p = mi.peekable(iter(iterable)) + + # Prime the cache + p.peek() + old_cache = list(p._cache) + + # Illegal slice + with self.assertRaises(ValueError): + p[1:-1:0] + + # Neither the cache nor the iteration should be affected + self.assertEqual(old_cache, list(p._cache)) + self.assertEqual(list(p), list(iterable)) + + def test_passthrough(self): + """Iterating a peekable without using ``peek()`` or ``prepend()`` + should just give the underlying iterable's elements (a trivial test but + useful to set a baseline in case something goes wrong)""" + expected = [1, 2, 3, 4, 5] + actual = list(mi.peekable(expected)) + self.assertEqual(actual, expected) + + # prepend() behavior tests + + def test_prepend(self): + """Tests intersperesed ``prepend()`` and ``next()`` calls""" + it = mi.peekable(range(2)) + actual = [] + + # Test prepend() before next() + it.prepend(10) + actual += [next(it), next(it)] + + # Test prepend() between next()s + it.prepend(11) + actual += [next(it), next(it)] + + # Test prepend() after source iterable is consumed + it.prepend(12) + actual += [next(it)] + + expected = [10, 0, 11, 1, 12] + self.assertEqual(actual, expected) + + def test_multi_prepend(self): + """Tests prepending multiple items and getting them in proper order""" + it = mi.peekable(range(5)) + actual = [next(it), next(it)] + it.prepend(10, 11, 12) + it.prepend(20, 21) + actual += list(it) + expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] + self.assertEqual(actual, expected) + + def test_empty(self): + """Tests prepending in front of an empty iterable""" + it = mi.peekable([]) + it.prepend(10) + actual = list(it) + expected = [10] + self.assertEqual(actual, expected) + + def 
test_prepend_truthiness(self): + """Tests that ``__bool__()`` or ``__nonzero__()`` works properly + with ``prepend()``""" + it = mi.peekable(range(5)) + self.assertTrue(it) + actual = list(it) + self.assertFalse(it) + it.prepend(10) + self.assertTrue(it) + actual += [next(it)] + self.assertFalse(it) + expected = [0, 1, 2, 3, 4, 10] + self.assertEqual(actual, expected) + + def test_multi_prepend_peek(self): + """Tests prepending multiple elements and getting them in reverse order + while peeking""" + it = mi.peekable(range(5)) + actual = [next(it), next(it)] + self.assertEqual(it.peek(), 2) + it.prepend(10, 11, 12) + self.assertEqual(it.peek(), 10) + it.prepend(20, 21) + self.assertEqual(it.peek(), 20) + actual += list(it) + self.assertFalse(it) + expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] + self.assertEqual(actual, expected) + + def test_prepend_after_stop(self): + """Test resuming iteration after a previous exhaustion""" + it = mi.peekable(range(3)) + self.assertEqual(list(it), [0, 1, 2]) + self.assertRaises(StopIteration, lambda: next(it)) + it.prepend(10) + self.assertEqual(next(it), 10) + self.assertRaises(StopIteration, lambda: next(it)) + + def test_prepend_slicing(self): + """Tests interaction between prepending and slicing""" + seq = list(range(20)) + p = mi.peekable(seq) + + p.prepend(30, 40, 50) + pseq = [30, 40, 50] + seq # pseq for prepended_seq + + # adapt the specific tests from test_slicing + self.assertEqual(p[0], 30) + self.assertEqual(p[1:8], pseq[1:8]) + self.assertEqual(p[1:], pseq[1:]) + self.assertEqual(p[:5], pseq[:5]) + self.assertEqual(p[:], pseq[:]) + self.assertEqual(p[:100], pseq[:100]) + self.assertEqual(p[::2], pseq[::2]) + self.assertEqual(p[::-1], pseq[::-1]) + + def test_prepend_indexing(self): + """Tests interaction between prepending and indexing""" + seq = list(range(20)) + p = mi.peekable(seq) + + p.prepend(30, 40, 50) + + self.assertEqual(p[0], 30) + self.assertEqual(next(p), 30) + self.assertEqual(p[2], 0) + self.assertEqual(next(p), 40) + self.assertEqual(p[0], 50) + self.assertEqual(p[9], 8) + self.assertEqual(next(p), 50) + self.assertEqual(p[8], 8) + self.assertEqual(p[-2], 18) + self.assertEqual(p[-9], 11) + self.assertRaises(IndexError, lambda: p[-21]) + + def test_prepend_iterable(self): + """Tests prepending from an iterable""" + it = mi.peekable(range(5)) + # Don't directly use the range() object to avoid any range-specific + # optimizations + it.prepend(*(x for x in range(5))) + actual = list(it) + expected = list(chain(range(5), range(5))) + self.assertEqual(actual, expected) + + def test_prepend_many(self): + """Tests that prepending a huge number of elements works""" + it = mi.peekable(range(5)) + # Don't directly use the range() object to avoid any range-specific + # optimizations + it.prepend(*(x for x in range(20000))) + actual = list(it) + expected = list(chain(range(20000), range(5))) + self.assertEqual(actual, expected) + + def test_prepend_reversed(self): + """Tests prepending from a reversed iterable""" + it = mi.peekable(range(3)) + it.prepend(*reversed((10, 11, 12))) + actual = list(it) + expected = [12, 11, 10, 0, 1, 2] + self.assertEqual(actual, expected) + + +class ConsumerTests(TestCase): + """Tests for ``consumer()``""" + + def test_consumer(self): + @mi.consumer + def eater(): + while True: + x = yield # noqa + + e = eater() + e.send('hi') # without @consumer, would raise TypeError + + +class DistinctPermutationsTests(TestCase): + def test_distinct_permutations(self): + """Make sure the output for 
``distinct_permutations()`` is the same as + set(permutations(it)). + + """ + iterable = ['z', 'a', 'a', 'q', 'q', 'q', 'y'] + test_output = sorted(mi.distinct_permutations(iterable)) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + def test_other_iterables(self): + """Make sure ``distinct_permutations()`` accepts a different type of + iterables. + + """ + # a generator + iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) + test_output = sorted(mi.distinct_permutations(iterable)) + # "reload" it + iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + # an iterator + iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) + test_output = sorted(mi.distinct_permutations(iterable)) + # "reload" it + iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + +class IlenTests(TestCase): + def test_ilen(self): + """Sanity-checks for ``ilen()``.""" + # Non-empty + self.assertEqual( + mi.ilen(filter(lambda x: x % 10 == 0, range(101))), 11 + ) + + # Empty + self.assertEqual(mi.ilen((x for x in range(0))), 0) + + # Iterable with __len__ + self.assertEqual(mi.ilen(list(range(6))), 6) + + +class WithIterTests(TestCase): + def test_with_iter(self): + s = StringIO('One fish\nTwo fish') + initial_words = [line.split()[0] for line in mi.with_iter(s)] + + # Iterable's items should be faithfully represented + self.assertEqual(initial_words, ['One', 'Two']) + # The file object should be closed + self.assertEqual(s.closed, True) + + +class OneTests(TestCase): + def test_basic(self): + it = iter(['item']) + self.assertEqual(mi.one(it), 'item') + + def test_too_short(self): + it = iter([]) + self.assertRaises(ValueError, lambda: mi.one(it)) + self.assertRaises(IndexError, lambda: mi.one(it, too_short=IndexError)) + + def test_too_long(self): + it = count() + self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1 + self.assertEqual(next(it), 2) + self.assertRaises( + OverflowError, lambda: mi.one(it, too_long=OverflowError) + ) + + +class IntersperseTest(TestCase): + """ Tests for intersperse() """ + + def test_even(self): + iterable = (x for x in '01') + self.assertEqual( + list(mi.intersperse(None, iterable)), ['0', None, '1'] + ) + + def test_odd(self): + iterable = (x for x in '012') + self.assertEqual( + list(mi.intersperse(None, iterable)), ['0', None, '1', None, '2'] + ) + + def test_nested(self): + element = ('a', 'b') + iterable = (x for x in '012') + actual = list(mi.intersperse(element, iterable)) + expected = ['0', ('a', 'b'), '1', ('a', 'b'), '2'] + self.assertEqual(actual, expected) + + def test_not_iterable(self): + self.assertRaises(TypeError, lambda: mi.intersperse('x', 1)) + + def test_n(self): + for n, element, expected in [ + (1, '_', ['0', '_', '1', '_', '2', '_', '3', '_', '4', '_', '5']), + (2, '_', ['0', '1', '_', '2', '3', '_', '4', '5']), + (3, '_', ['0', '1', '2', '_', '3', '4', '5']), + (4, '_', ['0', '1', '2', '3', '_', '4', '5']), + (5, '_', ['0', '1', '2', '3', '4', '_', '5']), + (6, '_', ['0', '1', '2', '3', '4', '5']), + (7, '_', ['0', '1', '2', '3', '4', '5']), + (3, ['a', 'b'], ['0', '1', '2', ['a', 'b'], '3', '4', '5']), + ]: + iterable = (x for x in '012345') + actual = list(mi.intersperse(element, iterable, n=n)) + self.assertEqual(actual, expected) + + def test_n_zero(self): + self.assertRaises( + ValueError, 
lambda: list(mi.intersperse('x', '012', n=0)) + ) + + +class UniqueToEachTests(TestCase): + """Tests for ``unique_to_each()``""" + + def test_all_unique(self): + """When all the input iterables are unique the output should match + the input.""" + iterables = [[1, 2], [3, 4, 5], [6, 7, 8]] + self.assertEqual(mi.unique_to_each(*iterables), iterables) + + def test_duplicates(self): + """When there are duplicates in any of the input iterables that aren't + in the rest, those duplicates should be emitted.""" + iterables = ["mississippi", "missouri"] + self.assertEqual( + mi.unique_to_each(*iterables), [['p', 'p'], ['o', 'u', 'r']] + ) + + def test_mixed(self): + """When the input iterables contain different types the function should + still behave properly""" + iterables = ['x', (i for i in range(3)), [1, 2, 3], tuple()] + self.assertEqual(mi.unique_to_each(*iterables), [['x'], [0], [3], []]) + + +class WindowedTests(TestCase): + """Tests for ``windowed()``""" + + def test_basic(self): + actual = list(mi.windowed([1, 2, 3, 4, 5], 3)) + expected = [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + self.assertEqual(actual, expected) + + def test_large_size(self): + """ + When the window size is larger than the iterable, and no fill value is + given,``None`` should be filled in. + """ + actual = list(mi.windowed([1, 2, 3, 4, 5], 6)) + expected = [(1, 2, 3, 4, 5, None)] + self.assertEqual(actual, expected) + + def test_fillvalue(self): + """ + When sizes don't match evenly, the given fill value should be used. + """ + iterable = [1, 2, 3, 4, 5] + + for n, kwargs, expected in [ + (6, {}, [(1, 2, 3, 4, 5, '!')]), # n > len(iterable) + (3, {'step': 3}, [(1, 2, 3), (4, 5, '!')]), # using ``step`` + ]: + actual = list(mi.windowed(iterable, n, fillvalue='!', **kwargs)) + self.assertEqual(actual, expected) + + def test_zero(self): + """When the window size is zero, an empty tuple should be emitted.""" + actual = list(mi.windowed([1, 2, 3, 4, 5], 0)) + expected = [tuple()] + self.assertEqual(actual, expected) + + def test_negative(self): + """When the window size is negative, ValueError should be raised.""" + with self.assertRaises(ValueError): + list(mi.windowed([1, 2, 3, 4, 5], -1)) + + def test_step(self): + """The window should advance by the number of steps provided""" + iterable = [1, 2, 3, 4, 5, 6, 7] + for n, step, expected in [ + (3, 2, [(1, 2, 3), (3, 4, 5), (5, 6, 7)]), # n > step + (3, 3, [(1, 2, 3), (4, 5, 6), (7, None, None)]), # n == step + (3, 4, [(1, 2, 3), (5, 6, 7)]), # line up nicely + (3, 5, [(1, 2, 3), (6, 7, None)]), # off by one + (3, 6, [(1, 2, 3), (7, None, None)]), # off by two + (3, 7, [(1, 2, 3)]), # step past the end + (7, 8, [(1, 2, 3, 4, 5, 6, 7)]), # step > len(iterable) + ]: + actual = list(mi.windowed(iterable, n, step=step)) + self.assertEqual(actual, expected) + + # Step must be greater than or equal to 1 + with self.assertRaises(ValueError): + list(mi.windowed(iterable, 3, step=0)) + + +class BucketTests(TestCase): + """Tests for ``bucket()``""" + + def test_basic(self): + iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] + D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) + + # In-order access + self.assertEqual(list(D[10]), [10, 11, 12]) + + # Out of order access + self.assertEqual(list(D[30]), [30, 31, 33]) + self.assertEqual(list(D[20]), [20, 21, 22, 23]) + + self.assertEqual(list(D[40]), []) # Nothing in here! 
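The grouping behavior exercised above is not tied to numeric keys; a small hedged sketch with string data (the names here are illustrative, not part of the test suite):

    words = mi.bucket(['apple', 'banana', 'avocado', 'blueberry'], key=lambda w: w[0])
    assert list(words['a']) == ['apple', 'avocado']
    assert list(words['b']) == ['banana', 'blueberry']
    assert list(words['z']) == []   # unseen keys simply yield nothing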
+ + def test_in(self): + iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] + D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) + + self.assertTrue(10 in D) + self.assertFalse(40 in D) + self.assertTrue(20 in D) + self.assertFalse(21 in D) + + # Checking in-ness shouldn't advance the iterator + self.assertEqual(next(D[10]), 10) + + def test_validator(self): + iterable = count(0) + key = lambda x: int(str(x)[0]) # First digit of each number + validator = lambda x: 0 < x < 10 # No leading zeros + D = mi.bucket(iterable, key, validator=validator) + self.assertEqual(mi.take(3, D[1]), [1, 10, 11]) + self.assertNotIn(0, D) # Non-valid entries don't return True + self.assertNotIn(0, D._cache) # Don't store non-valid entries + self.assertEqual(list(D[0]), []) + + +class SpyTests(TestCase): + """Tests for ``spy()``""" + + def test_basic(self): + original_iterable = iter('abcdefg') + head, new_iterable = mi.spy(original_iterable) + self.assertEqual(head, ['a']) + self.assertEqual( + list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + ) + + def test_unpacking(self): + original_iterable = iter('abcdefg') + (first, second, third), new_iterable = mi.spy(original_iterable, 3) + self.assertEqual(first, 'a') + self.assertEqual(second, 'b') + self.assertEqual(third, 'c') + self.assertEqual( + list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + ) + + def test_too_many(self): + original_iterable = iter('abc') + head, new_iterable = mi.spy(original_iterable, 4) + self.assertEqual(head, ['a', 'b', 'c']) + self.assertEqual(list(new_iterable), ['a', 'b', 'c']) + + def test_zero(self): + original_iterable = iter('abc') + head, new_iterable = mi.spy(original_iterable, 0) + self.assertEqual(head, []) + self.assertEqual(list(new_iterable), ['a', 'b', 'c']) + + +class InterleaveTests(TestCase): + def test_even(self): + actual = list(mi.interleave([1, 4, 7], [2, 5, 8], [3, 6, 9])) + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_short(self): + actual = list(mi.interleave([1, 4], [2, 5, 7], [3, 6, 8])) + expected = [1, 2, 3, 4, 5, 6] + self.assertEqual(actual, expected) + + def test_mixed_types(self): + it_list = ['a', 'b', 'c', 'd'] + it_str = '12345' + it_inf = count() + actual = list(mi.interleave(it_list, it_str, it_inf)) + expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', 3] + self.assertEqual(actual, expected) + + +class InterleaveLongestTests(TestCase): + def test_even(self): + actual = list(mi.interleave_longest([1, 4, 7], [2, 5, 8], [3, 6, 9])) + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_short(self): + actual = list(mi.interleave_longest([1, 4], [2, 5, 7], [3, 6, 8])) + expected = [1, 2, 3, 4, 5, 6, 7, 8] + self.assertEqual(actual, expected) + + def test_mixed_types(self): + it_list = ['a', 'b', 'c', 'd'] + it_str = '12345' + it_gen = (x for x in range(3)) + actual = list(mi.interleave_longest(it_list, it_str, it_gen)) + expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', '5'] + self.assertEqual(actual, expected) + + +class TestCollapse(TestCase): + """Tests for ``collapse()``""" + + def test_collapse(self): + l = [[1], 2, [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l)), [1, 2, 3, 4, 5]) + + def test_collapse_to_string(self): + l = [["s1"], "s2", [["s3"], "s4"], [[["s5"]]]] + self.assertEqual(list(mi.collapse(l)), ["s1", "s2", "s3", "s4", "s5"]) + + def test_collapse_flatten(self): + l = [[1], [2], [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l, levels=1)), 
list(mi.flatten(l))) + + def test_collapse_to_level(self): + l = [[1], 2, [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l, levels=2)), [1, 2, 3, 4, [5]]) + self.assertEqual( + list(mi.collapse(mi.collapse(l, levels=1), levels=1)), + list(mi.collapse(l, levels=2)) + ) + + def test_collapse_to_list(self): + l = (1, [2], (3, [4, (5,)], 'ab')) + actual = list(mi.collapse(l, base_type=list)) + expected = [1, [2], 3, [4, (5,)], 'ab'] + self.assertEqual(actual, expected) + + +class SideEffectTests(TestCase): + """Tests for ``side_effect()``""" + + def test_individual(self): + # The function increments the counter for each call + counter = [0] + + def func(arg): + counter[0] += 1 + + result = list(mi.side_effect(func, range(10))) + self.assertEqual(result, list(range(10))) + self.assertEqual(counter[0], 10) + + def test_chunked(self): + # The function increments the counter for each call + counter = [0] + + def func(arg): + counter[0] += 1 + + result = list(mi.side_effect(func, range(10), 2)) + self.assertEqual(result, list(range(10))) + self.assertEqual(counter[0], 5) + + def test_before_after(self): + f = StringIO() + collector = [] + + def func(item): + print(item, file=f) + collector.append(f.getvalue()) + + def it(): + yield u'a' + yield u'b' + raise RuntimeError('kaboom') + + before = lambda: print('HEADER', file=f) + after = f.close + + try: + mi.consume(mi.side_effect(func, it(), before=before, after=after)) + except RuntimeError: + pass + + # The iterable should have been written to the file + self.assertEqual(collector, [u'HEADER\na\n', u'HEADER\na\nb\n']) + + # The file should be closed even though something bad happened + self.assertTrue(f.closed) + + def test_before_fails(self): + f = StringIO() + func = lambda x: print(x, file=f) + + def before(): + raise RuntimeError('ouch') + + try: + mi.consume( + mi.side_effect(func, u'abc', before=before, after=f.close) + ) + except RuntimeError: + pass + + # The file should be closed even though something bad happened in the + # before function + self.assertTrue(f.closed) + + +class SlicedTests(TestCase): + """Tests for ``sliced()``""" + + def test_even(self): + """Test when the length of the sequence is divisible by *n*""" + seq = 'ABCDEFGHI' + self.assertEqual(list(mi.sliced(seq, 3)), ['ABC', 'DEF', 'GHI']) + + def test_odd(self): + """Test when the length of the sequence is not divisible by *n*""" + seq = 'ABCDEFGHI' + self.assertEqual(list(mi.sliced(seq, 4)), ['ABCD', 'EFGH', 'I']) + + def test_not_sliceable(self): + seq = (x for x in 'ABCDEFGHI') + + with self.assertRaises(TypeError): + list(mi.sliced(seq, 3)) + + +class SplitAtTests(TestCase): + """Tests for ``split()``""" + + def comp_with_str_split(self, str_to_split, delim): + pred = lambda c: c == delim + actual = list(map(''.join, mi.split_at(str_to_split, pred))) + expected = str_to_split.split(delim) + self.assertEqual(actual, expected) + + def test_seperators(self): + test_strs = ['', 'abcba', 'aaabbbcccddd', 'e'] + for s, delim in product(test_strs, 'abcd'): + self.comp_with_str_split(s, delim) + + +class SplitBeforeTest(TestCase): + """Tests for ``split_before()``""" + + def test_starts_with_sep(self): + actual = list(mi.split_before('xooxoo', lambda c: c == 'x')) + expected = [['x', 'o', 'o'], ['x', 'o', 'o']] + self.assertEqual(actual, expected) + + def test_ends_with_sep(self): + actual = list(mi.split_before('ooxoox', lambda c: c == 'x')) + expected = [['o', 'o'], ['x', 'o', 'o'], ['x']] + self.assertEqual(actual, expected) + + def test_no_sep(self): + actual = 
list(mi.split_before('ooo', lambda c: c == 'x')) + expected = [['o', 'o', 'o']] + self.assertEqual(actual, expected) + + +class SplitAfterTest(TestCase): + """Tests for ``split_after()``""" + + def test_starts_with_sep(self): + actual = list(mi.split_after('xooxoo', lambda c: c == 'x')) + expected = [['x'], ['o', 'o', 'x'], ['o', 'o']] + self.assertEqual(actual, expected) + + def test_ends_with_sep(self): + actual = list(mi.split_after('ooxoox', lambda c: c == 'x')) + expected = [['o', 'o', 'x'], ['o', 'o', 'x']] + self.assertEqual(actual, expected) + + def test_no_sep(self): + actual = list(mi.split_after('ooo', lambda c: c == 'x')) + expected = [['o', 'o', 'o']] + self.assertEqual(actual, expected) + + +class PaddedTest(TestCase): + """Tests for ``padded()``""" + + def test_no_n(self): + seq = [1, 2, 3] + + # No fillvalue + self.assertEqual(mi.take(5, mi.padded(seq)), [1, 2, 3, None, None]) + + # With fillvalue + self.assertEqual( + mi.take(5, mi.padded(seq, fillvalue='')), [1, 2, 3, '', ''] + ) + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=-1))) + self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=0))) + + def test_valid_n(self): + seq = [1, 2, 3, 4, 5] + + # No need for padding: len(seq) <= n + self.assertEqual(list(mi.padded(seq, n=4)), [1, 2, 3, 4, 5]) + self.assertEqual(list(mi.padded(seq, n=5)), [1, 2, 3, 4, 5]) + + # No fillvalue + self.assertEqual( + list(mi.padded(seq, n=7)), [1, 2, 3, 4, 5, None, None] + ) + + # With fillvalue + self.assertEqual( + list(mi.padded(seq, fillvalue='', n=7)), [1, 2, 3, 4, 5, '', ''] + ) + + def test_next_multiple(self): + seq = [1, 2, 3, 4, 5, 6] + + # No need for padding: len(seq) % n == 0 + self.assertEqual( + list(mi.padded(seq, n=3, next_multiple=True)), [1, 2, 3, 4, 5, 6] + ) + + # Padding needed: len(seq) < n + self.assertEqual( + list(mi.padded(seq, n=8, next_multiple=True)), + [1, 2, 3, 4, 5, 6, None, None] + ) + + # No padding needed: len(seq) == n + self.assertEqual( + list(mi.padded(seq, n=6, next_multiple=True)), [1, 2, 3, 4, 5, 6] + ) + + # Padding needed: len(seq) > n + self.assertEqual( + list(mi.padded(seq, n=4, next_multiple=True)), + [1, 2, 3, 4, 5, 6, None, None] + ) + + # With fillvalue + self.assertEqual( + list(mi.padded(seq, fillvalue='', n=4, next_multiple=True)), + [1, 2, 3, 4, 5, 6, '', ''] + ) + + +class DistributeTest(TestCase): + """Tests for distribute()""" + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: mi.distribute(-1, [1, 2, 3])) + self.assertRaises(ValueError, lambda: mi.distribute(0, [1, 2, 3])) + + def test_basic(self): + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + for n, expected in [ + (1, [iterable]), + (2, [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), + (3, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]), + (10, [[n] for n in range(1, 10 + 1)]), + ]: + self.assertEqual( + [list(x) for x in mi.distribute(n, iterable)], expected + ) + + def test_large_n(self): + iterable = [1, 2, 3, 4] + self.assertEqual( + [list(x) for x in mi.distribute(6, iterable)], + [[1], [2], [3], [4], [], []] + ) + + +class StaggerTest(TestCase): + """Tests for ``stagger()``""" + + def test_default(self): + iterable = [0, 1, 2, 3] + actual = list(mi.stagger(iterable)) + expected = [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + self.assertEqual(actual, expected) + + def test_offsets(self): + iterable = [0, 1, 2, 3] + for offsets, expected in [ + ((-2, 0, 2), [('', 0, 2), ('', 1, 3)]), + ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3)]), + ((1, 2), [(1, 2), (2, 
3)]), + ]: + all_groups = mi.stagger(iterable, offsets=offsets, fillvalue='') + self.assertEqual(list(all_groups), expected) + + def test_longest(self): + iterable = [0, 1, 2, 3] + for offsets, expected in [ + ( + (-1, 0, 1), + [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, ''), (3, '', '')] + ), + ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3), (3, '')]), + ((1, 2), [(1, 2), (2, 3), (3, '')]), + ]: + all_groups = mi.stagger( + iterable, offsets=offsets, fillvalue='', longest=True + ) + self.assertEqual(list(all_groups), expected) + + +class ZipOffsetTest(TestCase): + """Tests for ``zip_offset()``""" + + def test_shortest(self): + a_1 = [0, 1, 2, 3] + a_2 = [0, 1, 2, 3, 4, 5] + a_3 = [0, 1, 2, 3, 4, 5, 6, 7] + actual = list( + mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), fillvalue='') + ) + expected = [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)] + self.assertEqual(actual, expected) + + def test_longest(self): + a_1 = [0, 1, 2, 3] + a_2 = [0, 1, 2, 3, 4, 5] + a_3 = [0, 1, 2, 3, 4, 5, 6, 7] + actual = list( + mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), longest=True) + ) + expected = [ + (None, 0, 1), + (0, 1, 2), + (1, 2, 3), + (2, 3, 4), + (3, 4, 5), + (None, 5, 6), + (None, None, 7), + ] + self.assertEqual(actual, expected) + + def test_mismatch(self): + iterables = [0, 1, 2], [2, 3, 4] + offsets = (-1, 0, 1) + self.assertRaises( + ValueError, + lambda: list(mi.zip_offset(*iterables, offsets=offsets)) + ) + + +class SortTogetherTest(TestCase): + """Tests for sort_together()""" + + def test_key_list(self): + """tests `key_list` including default, iterables include duplicates""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertEqual( + mi.sort_together(iterables), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('June', 'July', 'July', 'May', 'Aug.', 'May'), + (70, 100, 20, 97, 20, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1)), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('July', 'July', 'June', 'Aug.', 'May', 'May'), + (100, 20, 70, 20, 97, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1, 2)), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('July', 'July', 'June', 'Aug.', 'May', 'May'), + (20, 100, 70, 20, 97, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(2,)), + [ + ('GA', 'CT', 'CT', 'GA', 'GA', 'CT'), + ('Aug.', 'July', 'June', 'May', 'May', 'July'), + (20, 20, 70, 97, 100, 100) + ] + ) + + def test_invalid_key_list(self): + """tests `key_list` for indexes not available in `iterables`""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertRaises( + IndexError, lambda: mi.sort_together(iterables, key_list=(5,)) + ) + + def test_reverse(self): + """tests `reverse` to ensure a reverse sort for `key_list` iterables""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1, 2), reverse=True), + [('GA', 'GA', 'GA', 'CT', 'CT', 'CT'), + ('May', 'May', 'Aug.', 'June', 'July', 'July'), + (100, 97, 20, 70, 100, 20)] + ) + + def test_uneven_iterables(self): + """tests trimming of iterables to the shortest length before sorting""" + iterables = [['GA', 'GA', 'GA', 'CT', 'CT', 'CT', 'MA'], + ['May', 'Aug.', 'May', 'June', 'July', 
'July'], + [97, 20, 100, 70, 100, 20, 0]] + + self.assertEqual( + mi.sort_together(iterables), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('June', 'July', 'July', 'May', 'Aug.', 'May'), + (70, 100, 20, 97, 20, 100) + ] + ) + + +class DivideTest(TestCase): + """Tests for divide()""" + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: mi.divide(-1, [1, 2, 3])) + self.assertRaises(ValueError, lambda: mi.divide(0, [1, 2, 3])) + + def test_basic(self): + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + for n, expected in [ + (1, [iterable]), + (2, [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), + (3, [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]), + (10, [[n] for n in range(1, 10 + 1)]), + ]: + self.assertEqual( + [list(x) for x in mi.divide(n, iterable)], expected + ) + + def test_large_n(self): + iterable = [1, 2, 3, 4] + self.assertEqual( + [list(x) for x in mi.divide(6, iterable)], + [[1], [2], [3], [4], [], []] + ) + + +class TestAlwaysIterable(TestCase): + """Tests for always_iterable()""" + def test_single(self): + self.assertEqual(list(mi.always_iterable(1)), [1]) + + def test_strings(self): + for obj in ['foo', b'bar', u'baz']: + actual = list(mi.always_iterable(obj)) + expected = [obj] + self.assertEqual(actual, expected) + + def test_base_type(self): + dict_obj = {'a': 1, 'b': 2} + str_obj = '123' + + # Default: dicts are iterable like they normally are + default_actual = list(mi.always_iterable(dict_obj)) + default_expected = list(dict_obj) + self.assertEqual(default_actual, default_expected) + + # Unitary types set: dicts are not iterable + custom_actual = list(mi.always_iterable(dict_obj, base_type=dict)) + custom_expected = [dict_obj] + self.assertEqual(custom_actual, custom_expected) + + # With unitary types set, strings are iterable + str_actual = list(mi.always_iterable(str_obj, base_type=None)) + str_expected = list(str_obj) + self.assertEqual(str_actual, str_expected) + + def test_iterables(self): + self.assertEqual(list(mi.always_iterable([0, 1])), [0, 1]) + self.assertEqual( + list(mi.always_iterable([0, 1], base_type=list)), [[0, 1]] + ) + self.assertEqual( + list(mi.always_iterable(iter('foo'))), ['f', 'o', 'o'] + ) + self.assertEqual(list(mi.always_iterable([])), []) + + def test_none(self): + self.assertEqual(list(mi.always_iterable(None)), []) + + def test_generator(self): + def _gen(): + yield 0 + yield 1 + + self.assertEqual(list(mi.always_iterable(_gen())), [0, 1]) + + +class AdjacentTests(TestCase): + def test_typical(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10))) + expected = [(True, 0), (True, 1), (False, 2), (False, 3), (True, 4), + (True, 5), (True, 6), (False, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_empty_iterable(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, [])) + expected = [] + self.assertEqual(actual, expected) + + def test_length_one(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, [0])) + expected = [(True, 0)] + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: x % 5 == 0, [1])) + expected = [(False, 1)] + self.assertEqual(actual, expected) + + def test_consecutive_true(self): + """Test that when the predicate matches multiple consecutive elements + it doesn't repeat elements in the output""" + actual = list(mi.adjacent(lambda x: x % 5 < 2, range(10))) + expected = [(True, 0), (True, 1), (True, 2), (False, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_distance(self): + 
actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=2)) + expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=3)) + expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (True, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_large_distance(self): + """Test distance larger than the length of the iterable""" + iterable = range(10) + actual = list(mi.adjacent(lambda x: x % 5 == 4, iterable, distance=20)) + expected = list(zip(repeat(True), iterable)) + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: False, iterable, distance=20)) + expected = list(zip(repeat(False), iterable)) + self.assertEqual(actual, expected) + + def test_zero_distance(self): + """Test that adjacent() reduces to zip+map when distance is 0""" + iterable = range(1000) + predicate = lambda x: x % 4 == 2 + actual = mi.adjacent(predicate, iterable, 0) + expected = zip(map(predicate, iterable), iterable) + self.assertTrue(all(a == e for a, e in zip(actual, expected))) + + def test_negative_distance(self): + """Test that adjacent() raises an error with negative distance""" + pred = lambda x: x + self.assertRaises( + ValueError, lambda: mi.adjacent(pred, range(1000), -1) + ) + self.assertRaises( + ValueError, lambda: mi.adjacent(pred, range(10), -10) + ) + + def test_grouping(self): + """Test interaction of adjacent() with groupby_transform()""" + iterable = mi.adjacent(lambda x: x % 5 == 0, range(10)) + grouper = mi.groupby_transform(iterable, itemgetter(0), itemgetter(1)) + actual = [(k, list(g)) for k, g in grouper] + expected = [ + (True, [0, 1]), + (False, [2, 3]), + (True, [4, 5, 6]), + (False, [7, 8, 9]), + ] + self.assertEqual(actual, expected) + + def test_call_once(self): + """Test that the predicate is only called once per item.""" + already_seen = set() + iterable = range(10) + + def predicate(item): + self.assertNotIn(item, already_seen) + already_seen.add(item) + return True + + actual = list(mi.adjacent(predicate, iterable)) + expected = [(True, x) for x in iterable] + self.assertEqual(actual, expected) + + +class GroupByTransformTests(TestCase): + def assertAllGroupsEqual(self, groupby1, groupby2): + """Compare two groupby objects for equality, both keys and groups.""" + for a, b in zip(groupby1, groupby2): + key1, group1 = a + key2, group2 = b + self.assertEqual(key1, key2) + self.assertListEqual(list(group1), list(group2)) + self.assertRaises(StopIteration, lambda: next(groupby1)) + self.assertRaises(StopIteration, lambda: next(groupby2)) + + def test_default_funcs(self): + """Test that groupby_transform() with default args mimics groupby()""" + iterable = [(x // 5, x) for x in range(1000)] + actual = mi.groupby_transform(iterable) + expected = groupby(iterable) + self.assertAllGroupsEqual(actual, expected) + + def test_valuefunc(self): + iterable = [(int(x / 5), int(x / 3), x) for x in range(10)] + + # Test the standard usage of grouping one iterable using another's keys + grouper = mi.groupby_transform( + iterable, keyfunc=itemgetter(0), valuefunc=itemgetter(-1) + ) + actual = [(k, list(g)) for k, g in grouper] + expected = [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])] + self.assertEqual(actual, expected) + + grouper = mi.groupby_transform( + iterable, keyfunc=itemgetter(1), valuefunc=itemgetter(-1) + ) + actual = [(k, 
list(g)) for k, g in grouper] + expected = [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])] + self.assertEqual(actual, expected) + + # and now for something a little different + d = dict(zip(range(10), 'abcdefghij')) + grouper = mi.groupby_transform( + range(10), keyfunc=lambda x: x // 5, valuefunc=d.get + ) + actual = [(k, ''.join(g)) for k, g in grouper] + expected = [(0, 'abcde'), (1, 'fghij')] + self.assertEqual(actual, expected) + + def test_no_valuefunc(self): + iterable = range(1000) + + def key(x): + return x // 5 + + actual = mi.groupby_transform(iterable, key, valuefunc=None) + expected = groupby(iterable, key) + self.assertAllGroupsEqual(actual, expected) + + actual = mi.groupby_transform(iterable, key) # default valuefunc + expected = groupby(iterable, key) + self.assertAllGroupsEqual(actual, expected) + + +class NumericRangeTests(TestCase): + def test_basic(self): + for args, expected in [ + ((4,), [0, 1, 2, 3]), + ((4.0,), [0.0, 1.0, 2.0, 3.0]), + ((1.0, 4), [1.0, 2.0, 3.0]), + ((1, 4.0), [1, 2, 3]), + ((1.0, 5), [1.0, 2.0, 3.0, 4.0]), + ((0, 20, 5), [0, 5, 10, 15]), + ((0, 20, 5.0), [0.0, 5.0, 10.0, 15.0]), + ((0, 10, 3), [0, 3, 6, 9]), + ((0, 10, 3.0), [0.0, 3.0, 6.0, 9.0]), + ((0, -5, -1), [0, -1, -2, -3, -4]), + ((0.0, -5, -1), [0.0, -1.0, -2.0, -3.0, -4.0]), + ((1, 2, Fraction(1, 2)), [Fraction(1, 1), Fraction(3, 2)]), + ((0,), []), + ((0.0,), []), + ((1, 0), []), + ((1.0, 0.0), []), + ((Fraction(2, 1),), [Fraction(0, 1), Fraction(1, 1)]), + ((Decimal('2.0'),), [Decimal('0.0'), Decimal('1.0')]), + ]: + actual = list(mi.numeric_range(*args)) + self.assertEqual(actual, expected) + self.assertTrue( + all(type(a) == type(e) for a, e in zip(actual, expected)) + ) + + def test_arg_count(self): + self.assertRaises(TypeError, lambda: list(mi.numeric_range())) + self.assertRaises( + TypeError, lambda: list(mi.numeric_range(0, 1, 2, 3)) + ) + + def test_zero_step(self): + self.assertRaises( + ValueError, lambda: list(mi.numeric_range(1, 2, 0)) + ) + + +class CountCycleTests(TestCase): + def test_basic(self): + expected = [ + (0, 'a'), (0, 'b'), (0, 'c'), + (1, 'a'), (1, 'b'), (1, 'c'), + (2, 'a'), (2, 'b'), (2, 'c'), + ] + for actual in [ + mi.take(9, mi.count_cycle('abc')), # n=None + list(mi.count_cycle('abc', 3)), # n=3 + ]: + self.assertEqual(actual, expected) + + def test_empty(self): + self.assertEqual(list(mi.count_cycle('')), []) + self.assertEqual(list(mi.count_cycle('', 2)), []) + + def test_negative(self): + self.assertEqual(list(mi.count_cycle('abc', -3)), []) + + +class LocateTests(TestCase): + def test_default_pred(self): + iterable = [0, 1, 1, 0, 1, 0, 0] + actual = list(mi.locate(iterable)) + expected = [1, 2, 4] + self.assertEqual(actual, expected) + + def test_no_matches(self): + iterable = [0, 0, 0] + actual = list(mi.locate(iterable)) + expected = [] + self.assertEqual(actual, expected) + + def test_custom_pred(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda x: x == '0' + actual = list(mi.locate(iterable, pred)) + expected = [0, 3, 5, 6] + self.assertEqual(actual, expected) + + def test_window_size(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda *args: args == ('0', 1) + actual = list(mi.locate(iterable, pred, window_size=2)) + expected = [0, 3] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = [1, 2, 3, 4] + pred = lambda a, b, c, d, e: True + actual = list(mi.locate(iterable, pred, window_size=5)) + expected = [0] + self.assertEqual(actual, expected) + + def 
test_window_size_zero(self): + iterable = [1, 2, 3, 4] + pred = lambda: True + with self.assertRaises(ValueError): + list(mi.locate(iterable, pred, window_size=0)) + + +class StripFunctionTests(TestCase): + def test_hashable(self): + iterable = list('www.example.com') + pred = lambda x: x in set('cmowz.') + + self.assertEqual(list(mi.lstrip(iterable, pred)), list('example.com')) + self.assertEqual(list(mi.rstrip(iterable, pred)), list('www.example')) + self.assertEqual(list(mi.strip(iterable, pred)), list('example')) + + def test_not_hashable(self): + iterable = [ + list('http://'), list('www'), list('.example'), list('.com') + ] + pred = lambda x: x in [list('http://'), list('www'), list('.com')] + + self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[2:]) + self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:3]) + self.assertEqual(list(mi.strip(iterable, pred)), iterable[2: 3]) + + def test_math(self): + iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] + pred = lambda x: x <= 2 + + self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[3:]) + self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:-3]) + self.assertEqual(list(mi.strip(iterable, pred)), iterable[3:-3]) + + +class IsliceExtendedTests(TestCase): + def test_all(self): + iterable = ['0', '1', '2', '3', '4', '5'] + indexes = list(range(-4, len(iterable) + 4)) + [None] + steps = [1, 2, 3, 4, -1, -2, -3, 4] + for slice_args in product(indexes, indexes, steps): + try: + actual = list(mi.islice_extended(iterable, *slice_args)) + except Exception as e: + self.fail((slice_args, e)) + + expected = iterable[slice(*slice_args)] + self.assertEqual(actual, expected, slice_args) + + def test_zero_step(self): + with self.assertRaises(ValueError): + list(mi.islice_extended([1, 2, 3], 0, 1, 0)) + + +class ConsecutiveGroupsTest(TestCase): + def test_numbers(self): + iterable = [-10, -8, -7, -6, 1, 2, 4, 5, -1, 7] + actual = [list(g) for g in mi.consecutive_groups(iterable)] + expected = [[-10], [-8, -7, -6], [1, 2], [4, 5], [-1], [7]] + self.assertEqual(actual, expected) + + def test_custom_ordering(self): + iterable = ['1', '10', '11', '20', '21', '22', '30', '31'] + ordering = lambda x: int(x) + actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] + expected = [['1'], ['10', '11'], ['20', '21', '22'], ['30', '31']] + self.assertEqual(actual, expected) + + def test_exotic_ordering(self): + iterable = [ + ('a', 'b', 'c', 'd'), + ('a', 'c', 'b', 'd'), + ('a', 'c', 'd', 'b'), + ('a', 'd', 'b', 'c'), + ('d', 'b', 'c', 'a'), + ('d', 'c', 'a', 'b'), + ] + ordering = list(permutations('abcd')).index + actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] + expected = [ + [('a', 'b', 'c', 'd')], + [('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c')], + [('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b')], + ] + self.assertEqual(actual, expected) + + +class DifferenceTest(TestCase): + def test_normal(self): + iterable = [10, 20, 30, 40, 50] + actual = list(mi.difference(iterable)) + expected = [10, 10, 10, 10, 10] + self.assertEqual(actual, expected) + + def test_custom(self): + iterable = [10, 20, 30, 40, 50] + actual = list(mi.difference(iterable, add)) + expected = [10, 30, 50, 70, 90] + self.assertEqual(actual, expected) + + def test_roundtrip(self): + original = list(range(100)) + accumulated = mi.accumulate(original) + actual = list(mi.difference(accumulated)) + self.assertEqual(actual, original) + + def test_one(self): + self.assertEqual(list(mi.difference([0])), [0]) + + def 
test_empty(self): + self.assertEqual(list(mi.difference([])), []) + + +class SeekableTest(TestCase): + def test_exhaustion_reset(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(list(s), iterable) # Normal iteration + self.assertEqual(list(s), []) # Iterable is exhausted + + s.seek(0) + self.assertEqual(list(s), iterable) # Back in action + + def test_partial_reset(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(5, s), iterable[:5]) # Normal iteration + + s.seek(1) + self.assertEqual(list(s), iterable[1:]) # Get the rest of the iterable + + def test_forward(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration + + s.seek(3) # Skip over index 2 + self.assertEqual(list(s), iterable[3:]) # Result is similar to slicing + + s.seek(0) # Back to 0 + self.assertEqual(list(s), iterable) # No difference in result + + def test_past_end(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration + + s.seek(20) + self.assertEqual(list(s), []) # Iterable is exhausted + + s.seek(0) # Back to 0 + self.assertEqual(list(s), iterable) # No difference in result + + def test_elements(self): + iterable = map(str, count()) + + s = mi.seekable(iterable) + mi.take(10, s) + + elements = s.elements() + self.assertEqual( + [elements[i] for i in range(10)], [str(n) for n in range(10)] + ) + self.assertEqual(len(elements), 10) + + mi.take(10, s) + self.assertEqual(list(elements), [str(n) for n in range(20)]) + + +class SequenceViewTests(TestCase): + def test_init(self): + view = mi.SequenceView((1, 2, 3)) + self.assertEqual(repr(view), "SequenceView((1, 2, 3))") + self.assertRaises(TypeError, lambda: mi.SequenceView({})) + + def test_update(self): + seq = [1, 2, 3] + view = mi.SequenceView(seq) + self.assertEqual(len(view), 3) + self.assertEqual(repr(view), "SequenceView([1, 2, 3])") + + seq.pop() + self.assertEqual(len(view), 2) + self.assertEqual(repr(view), "SequenceView([1, 2])") + + def test_indexing(self): + seq = ('a', 'b', 'c', 'd', 'e', 'f') + view = mi.SequenceView(seq) + for i in range(-len(seq), len(seq)): + self.assertEqual(view[i], seq[i]) + + def test_slicing(self): + seq = ('a', 'b', 'c', 'd', 'e', 'f') + view = mi.SequenceView(seq) + n = len(seq) + indexes = list(range(-n - 1, n + 1)) + [None] + steps = list(range(-n, n + 1)) + steps.remove(0) + for slice_args in product(indexes, indexes, steps): + i = slice(*slice_args) + self.assertEqual(view[i], seq[i]) + + def test_abc_methods(self): + # collections.Sequence should provide all of this functionality + seq = ('a', 'b', 'c', 'd', 'e', 'f', 'f') + view = mi.SequenceView(seq) + + # __contains__ + self.assertIn('b', view) + self.assertNotIn('g', view) + + # __iter__ + self.assertEqual(list(iter(view)), list(seq)) + + # __reversed__ + self.assertEqual(list(reversed(view)), list(reversed(seq))) + + # index + self.assertEqual(view.index('b'), 1) + + # count + self.assertEqual(seq.count('f'), 2) + + +class RunLengthTest(TestCase): + def test_encode(self): + iterable = (int(str(n)[0]) for n in count(800)) + actual = mi.take(4, mi.run_length.encode(iterable)) + expected = [(8, 100), (9, 100), (1, 1000), (2, 1000)] + self.assertEqual(actual, expected) + + def test_decode(self): + iterable = [('d', 4), ('c', 3), ('b', 2), ('a', 1)] + actual = ''.join(mi.run_length.decode(iterable)) + 
expected = 'ddddcccbba' + self.assertEqual(actual, expected) + + +class ExactlyNTests(TestCase): + """Tests for ``exactly_n()``""" + + def test_true(self): + """Iterable has ``n`` ``True`` elements""" + self.assertTrue(mi.exactly_n([True, False, True], 2)) + self.assertTrue(mi.exactly_n([1, 1, 1, 0], 3)) + self.assertTrue(mi.exactly_n([False, False], 0)) + self.assertTrue(mi.exactly_n(range(100), 10, lambda x: x < 10)) + + def test_false(self): + """Iterable does not have ``n`` ``True`` elements""" + self.assertFalse(mi.exactly_n([True, False, False], 2)) + self.assertFalse(mi.exactly_n([True, True, False], 1)) + self.assertFalse(mi.exactly_n([False], 1)) + self.assertFalse(mi.exactly_n([True], -1)) + self.assertFalse(mi.exactly_n(repeat(True), 100)) + + def test_empty(self): + """Return ``True`` if the iterable is empty and ``n`` is 0""" + self.assertTrue(mi.exactly_n([], 0)) + self.assertFalse(mi.exactly_n([], 1)) + + +class AlwaysReversibleTests(TestCase): + """Tests for ``always_reversible()``""" + + def test_regular_reversed(self): + self.assertEqual(list(reversed(range(10))), + list(mi.always_reversible(range(10)))) + self.assertEqual(list(reversed([1, 2, 3])), + list(mi.always_reversible([1, 2, 3]))) + self.assertEqual(reversed([1, 2, 3]).__class__, + mi.always_reversible([1, 2, 3]).__class__) + + def test_nonseq_reversed(self): + # Create a non-reversible generator from a sequence + with self.assertRaises(TypeError): + reversed(x for x in range(10)) + + self.assertEqual(list(reversed(range(10))), + list(mi.always_reversible(x for x in range(10)))) + self.assertEqual(list(reversed([1, 2, 3])), + list(mi.always_reversible(x for x in [1, 2, 3]))) + self.assertNotEqual(reversed((1, 2)).__class__, + mi.always_reversible(x for x in (1, 2)).__class__) + + +class CircularShiftsTests(TestCase): + def test_empty(self): + # empty iterable -> empty list + self.assertEqual(list(mi.circular_shifts([])), []) + + def test_simple_circular_shifts(self): + # test the a simple iterator case + self.assertEqual( + mi.circular_shifts(range(4)), + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + ) + + def test_duplicates(self): + # test non-distinct entries + self.assertEqual( + mi.circular_shifts([0, 1, 0, 1]), + [(0, 1, 0, 1), (1, 0, 1, 0), (0, 1, 0, 1), (1, 0, 1, 0)] + ) + + +class MakeDecoratorTests(TestCase): + def test_basic(self): + slicer = mi.make_decorator(islice) + + @slicer(1, 10, 2) + def user_function(arg_1, arg_2, kwarg_1=None): + self.assertEqual(arg_1, 'arg_1') + self.assertEqual(arg_2, 'arg_2') + self.assertEqual(kwarg_1, 'kwarg_1') + return map(str, count()) + + it = user_function('arg_1', 'arg_2', kwarg_1='kwarg_1') + actual = list(it) + expected = ['1', '3', '5', '7', '9'] + self.assertEqual(actual, expected) + + def test_result_index(self): + def stringify(*args, **kwargs): + self.assertEqual(args[0], 'arg_0') + iterable = args[1] + self.assertEqual(args[2], 'arg_2') + self.assertEqual(kwargs['kwarg_1'], 'kwarg_1') + return map(str, iterable) + + stringifier = mi.make_decorator(stringify, result_index=1) + + @stringifier('arg_0', 'arg_2', kwarg_1='kwarg_1') + def user_function(n): + return count(n) + + it = user_function(1) + actual = mi.take(5, it) + expected = ['1', '2', '3', '4', '5'] + self.assertEqual(actual, expected) + + def test_wrap_class(self): + seeker = mi.make_decorator(mi.seekable) + + @seeker() + def user_function(n): + return map(str, range(n)) + + it = user_function(5) + self.assertEqual(list(it), ['0', '1', '2', '3', '4']) + + it.seek(0) + 
self.assertEqual(list(it), ['0', '1', '2', '3', '4']) + + +class MapReduceTests(TestCase): + def test_default(self): + iterable = (str(x) for x in range(5)) + keyfunc = lambda x: int(x) // 2 + actual = sorted(mi.map_reduce(iterable, keyfunc).items()) + expected = [(0, ['0', '1']), (1, ['2', '3']), (2, ['4'])] + self.assertEqual(actual, expected) + + def test_valuefunc(self): + iterable = (str(x) for x in range(5)) + keyfunc = lambda x: int(x) // 2 + valuefunc = int + actual = sorted(mi.map_reduce(iterable, keyfunc, valuefunc).items()) + expected = [(0, [0, 1]), (1, [2, 3]), (2, [4])] + self.assertEqual(actual, expected) + + def test_reducefunc(self): + iterable = (str(x) for x in range(5)) + keyfunc = lambda x: int(x) // 2 + valuefunc = int + reducefunc = lambda value_list: reduce(mul, value_list, 1) + actual = sorted( + mi.map_reduce(iterable, keyfunc, valuefunc, reducefunc).items() + ) + expected = [(0, 0), (1, 6), (2, 4)] + self.assertEqual(actual, expected) + + def test_ret(self): + d = mi.map_reduce([1, 0, 2, 0, 1, 0], bool) + self.assertEqual(d, {False: [0, 0, 0], True: [1, 2, 1]}) + self.assertRaises(KeyError, lambda: d[None].append(1)) + + +class RlocateTests(TestCase): + def test_default_pred(self): + iterable = [0, 1, 1, 0, 1, 0, 0] + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it)) + expected = [4, 2, 1] + self.assertEqual(actual, expected) + + def test_no_matches(self): + iterable = [0, 0, 0] + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it)) + expected = [] + self.assertEqual(actual, expected) + + def test_custom_pred(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda x: x == '0' + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it, pred)) + expected = [6, 5, 3, 0] + self.assertEqual(actual, expected) + + def test_efficient_reversal(self): + iterable = range(10 ** 10) # Is efficiently reversible + target = 10 ** 10 - 2 + pred = lambda x: x == target # Find-able from the right + actual = next(mi.rlocate(iterable, pred)) + self.assertEqual(actual, target) + + def test_window_size(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda *args: args == ('0', 1) + for it in (iterable, iter(iterable)): + actual = list(mi.rlocate(it, pred, window_size=2)) + expected = [3, 0] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = [1, 2, 3, 4] + pred = lambda a, b, c, d, e: True + for it in (iterable, iter(iterable)): + actual = list(mi.rlocate(iterable, pred, window_size=5)) + expected = [0] + self.assertEqual(actual, expected) + + def test_window_size_zero(self): + iterable = [1, 2, 3, 4] + pred = lambda: True + for it in (iterable, iter(iterable)): + with self.assertRaises(ValueError): + list(mi.locate(iterable, pred, window_size=0)) + + +class ReplaceTests(TestCase): + def test_basic(self): + iterable = range(10) + pred = lambda x: x % 2 == 0 + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes)) + expected = [1, 3, 5, 7, 9] + self.assertEqual(actual, expected) + + def test_count(self): + iterable = range(10) + pred = lambda x: x % 2 == 0 + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, count=4)) + expected = [1, 3, 5, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size(self): + iterable = range(10) + pred = lambda *args: args == (0, 1, 2) + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) + expected = [3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, 
expected) + + def test_window_size_end(self): + iterable = range(10) + pred = lambda *args: args == (7, 8, 9) + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) + expected = [0, 1, 2, 3, 4, 5, 6] + self.assertEqual(actual, expected) + + def test_window_size_count(self): + iterable = range(10) + pred = lambda *args: (args == (0, 1, 2)) or (args == (7, 8, 9)) + substitutes = [] + actual = list( + mi.replace(iterable, pred, substitutes, count=1, window_size=3) + ) + expected = [3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = range(4) + pred = lambda a, b, c, d, e: True + substitutes = [5, 6, 7] + actual = list(mi.replace(iterable, pred, substitutes, window_size=5)) + expected = [5, 6, 7] + self.assertEqual(actual, expected) + + def test_window_size_zero(self): + iterable = range(10) + pred = lambda *args: True + substitutes = [] + with self.assertRaises(ValueError): + list(mi.replace(iterable, pred, substitutes, window_size=0)) + + def test_iterable_substitutes(self): + iterable = range(5) + pred = lambda x: x % 2 == 0 + substitutes = iter('__') + actual = list(mi.replace(iterable, pred, substitutes)) + expected = ['_', '_', 1, '_', '_', 3, '_', '_'] + self.assertEqual(actual, expected) diff --git a/libraries/more_itertools/tests/test_recipes.py b/libraries/more_itertools/tests/test_recipes.py new file mode 100644 index 00000000..98981fe8 --- /dev/null +++ b/libraries/more_itertools/tests/test_recipes.py @@ -0,0 +1,616 @@ +from doctest import DocTestSuite +from unittest import TestCase + +from itertools import combinations +from six.moves import range + +import more_itertools as mi + + +def load_tests(loader, tests, ignore): + # Add the doctests + tests.addTests(DocTestSuite('more_itertools.recipes')) + return tests + + +class AccumulateTests(TestCase): + """Tests for ``accumulate()``""" + + def test_empty(self): + """Test that an empty input returns an empty output""" + self.assertEqual(list(mi.accumulate([])), []) + + def test_default(self): + """Test accumulate with the default function (addition)""" + self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6]) + + def test_bogus_function(self): + """Test accumulate with an invalid function""" + with self.assertRaises(TypeError): + list(mi.accumulate([1, 2, 3], func=lambda x: x)) + + def test_custom_function(self): + """Test accumulate with a custom function""" + self.assertEqual( + list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3] + ) + + +class TakeTests(TestCase): + """Tests for ``take()``""" + + def test_simple_take(self): + """Test basic usage""" + t = mi.take(5, range(10)) + self.assertEqual(t, [0, 1, 2, 3, 4]) + + def test_null_take(self): + """Check the null case""" + t = mi.take(0, range(10)) + self.assertEqual(t, []) + + def test_negative_take(self): + """Make sure taking negative items results in a ValueError""" + self.assertRaises(ValueError, lambda: mi.take(-3, range(10))) + + def test_take_too_much(self): + """Taking more than an iterator has remaining should return what the + iterator has remaining. 
+ + """ + t = mi.take(10, range(5)) + self.assertEqual(t, [0, 1, 2, 3, 4]) + + +class TabulateTests(TestCase): + """Tests for ``tabulate()``""" + + def test_simple_tabulate(self): + """Test the happy path""" + t = mi.tabulate(lambda x: x) + f = tuple([next(t) for _ in range(3)]) + self.assertEqual(f, (0, 1, 2)) + + def test_count(self): + """Ensure tabulate accepts specific count""" + t = mi.tabulate(lambda x: 2 * x, -1) + f = (next(t), next(t), next(t)) + self.assertEqual(f, (-2, 0, 2)) + + +class TailTests(TestCase): + """Tests for ``tail()``""" + + def test_greater(self): + """Length of iterable is greather than requested tail""" + self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G']) + + def test_equal(self): + """Length of iterable is equal to the requested tail""" + self.assertEqual( + list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] + ) + + def test_less(self): + """Length of iterable is less than requested tail""" + self.assertEqual( + list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] + ) + + +class ConsumeTests(TestCase): + """Tests for ``consume()``""" + + def test_sanity(self): + """Test basic functionality""" + r = (x for x in range(10)) + mi.consume(r, 3) + self.assertEqual(3, next(r)) + + def test_null_consume(self): + """Check the null case""" + r = (x for x in range(10)) + mi.consume(r, 0) + self.assertEqual(0, next(r)) + + def test_negative_consume(self): + """Check that negative consumsion throws an error""" + r = (x for x in range(10)) + self.assertRaises(ValueError, lambda: mi.consume(r, -1)) + + def test_total_consume(self): + """Check that iterator is totally consumed by default""" + r = (x for x in range(10)) + mi.consume(r) + self.assertRaises(StopIteration, lambda: next(r)) + + +class NthTests(TestCase): + """Tests for ``nth()``""" + + def test_basic(self): + """Make sure the nth item is returned""" + l = range(10) + for i, v in enumerate(l): + self.assertEqual(mi.nth(l, i), v) + + def test_default(self): + """Ensure a default value is returned when nth item not found""" + l = range(3) + self.assertEqual(mi.nth(l, 100, "zebra"), "zebra") + + def test_negative_item_raises(self): + """Ensure asking for a negative item raises an exception""" + self.assertRaises(ValueError, lambda: mi.nth(range(10), -3)) + + +class AllEqualTests(TestCase): + """Tests for ``all_equal()``""" + + def test_true(self): + """Everything is equal""" + self.assertTrue(mi.all_equal('aaaaaa')) + self.assertTrue(mi.all_equal([0, 0, 0, 0])) + + def test_false(self): + """Not everything is equal""" + self.assertFalse(mi.all_equal('aaaaab')) + self.assertFalse(mi.all_equal([0, 0, 0, 1])) + + def test_tricky(self): + """Not everything is identical, but everything is equal""" + items = [1, complex(1, 0), 1.0] + self.assertTrue(mi.all_equal(items)) + + def test_empty(self): + """Return True if the iterable is empty""" + self.assertTrue(mi.all_equal('')) + self.assertTrue(mi.all_equal([])) + + def test_one(self): + """Return True if the iterable is singular""" + self.assertTrue(mi.all_equal('0')) + self.assertTrue(mi.all_equal([0])) + + +class QuantifyTests(TestCase): + """Tests for ``quantify()``""" + + def test_happy_path(self): + """Make sure True count is returned""" + q = [True, False, True] + self.assertEqual(mi.quantify(q), 2) + + def test_custom_predicate(self): + """Ensure non-default predicates return as expected""" + q = range(10) + self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5) + + +class PadnoneTests(TestCase): + """Tests for 
``padnone()``""" + + def test_happy_path(self): + """wrapper iterator should return None indefinitely""" + r = range(2) + p = mi.padnone(r) + self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)]) + + +class NcyclesTests(TestCase): + """Tests for ``nyclces()``""" + + def test_happy_path(self): + """cycle a sequence three times""" + r = ["a", "b", "c"] + n = mi.ncycles(r, 3) + self.assertEqual( + ["a", "b", "c", "a", "b", "c", "a", "b", "c"], + list(n) + ) + + def test_null_case(self): + """asking for 0 cycles should return an empty iterator""" + n = mi.ncycles(range(100), 0) + self.assertRaises(StopIteration, lambda: next(n)) + + def test_pathalogical_case(self): + """asking for negative cycles should return an empty iterator""" + n = mi.ncycles(range(100), -10) + self.assertRaises(StopIteration, lambda: next(n)) + + +class DotproductTests(TestCase): + """Tests for ``dotproduct()``'""" + + def test_happy_path(self): + """simple dotproduct example""" + self.assertEqual(400, mi.dotproduct([10, 10], [20, 20])) + + +class FlattenTests(TestCase): + """Tests for ``flatten()``""" + + def test_basic_usage(self): + """ensure list of lists is flattened one level""" + f = [[0, 1, 2], [3, 4, 5]] + self.assertEqual(list(range(6)), list(mi.flatten(f))) + + def test_single_level(self): + """ensure list of lists is flattened only one level""" + f = [[0, [1, 2]], [[3, 4], 5]] + self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f))) + + +class RepeatfuncTests(TestCase): + """Tests for ``repeatfunc()``""" + + def test_simple_repeat(self): + """test simple repeated functions""" + r = mi.repeatfunc(lambda: 5) + self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)]) + + def test_finite_repeat(self): + """ensure limited repeat when times is provided""" + r = mi.repeatfunc(lambda: 5, times=5) + self.assertEqual([5, 5, 5, 5, 5], list(r)) + + def test_added_arguments(self): + """ensure arguments are applied to the function""" + r = mi.repeatfunc(lambda x: x, 2, 3) + self.assertEqual([3, 3], list(r)) + + def test_null_times(self): + """repeat 0 should return an empty iterator""" + r = mi.repeatfunc(range, 0, 3) + self.assertRaises(StopIteration, lambda: next(r)) + + +class PairwiseTests(TestCase): + """Tests for ``pairwise()``""" + + def test_base_case(self): + """ensure an iterable will return pairwise""" + p = mi.pairwise([1, 2, 3]) + self.assertEqual([(1, 2), (2, 3)], list(p)) + + def test_short_case(self): + """ensure an empty iterator if there's not enough values to pair""" + p = mi.pairwise("a") + self.assertRaises(StopIteration, lambda: next(p)) + + +class GrouperTests(TestCase): + """Tests for ``grouper()``""" + + def test_even(self): + """Test when group size divides evenly into the length of + the iterable. + + """ + self.assertEqual( + list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')] + ) + + def test_odd(self): + """Test when group size does not divide evenly into the length of the + iterable. 
+ + """ + self.assertEqual( + list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)] + ) + + def test_fill_value(self): + """Test that the fill value is used to pad the final group""" + self.assertEqual( + list(mi.grouper(3, 'ABCDE', 'x')), + [('A', 'B', 'C'), ('D', 'E', 'x')] + ) + + +class RoundrobinTests(TestCase): + """Tests for ``roundrobin()``""" + + def test_even_groups(self): + """Ensure ordered output from evenly populated iterables""" + self.assertEqual( + list(mi.roundrobin('ABC', [1, 2, 3], range(3))), + ['A', 1, 0, 'B', 2, 1, 'C', 3, 2] + ) + + def test_uneven_groups(self): + """Ensure ordered output from unevenly populated iterables""" + self.assertEqual( + list(mi.roundrobin('ABCD', [1, 2], range(0))), + ['A', 1, 'B', 2, 'C', 'D'] + ) + + +class PartitionTests(TestCase): + """Tests for ``partition()``""" + + def test_bool(self): + """Test when pred() returns a boolean""" + lesser, greater = mi.partition(lambda x: x > 5, range(10)) + self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5]) + self.assertEqual(list(greater), [6, 7, 8, 9]) + + def test_arbitrary(self): + """Test when pred() returns an integer""" + divisibles, remainders = mi.partition(lambda x: x % 3, range(10)) + self.assertEqual(list(divisibles), [0, 3, 6, 9]) + self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8]) + + +class PowersetTests(TestCase): + """Tests for ``powerset()``""" + + def test_combinatorics(self): + """Ensure a proper enumeration""" + p = mi.powerset([1, 2, 3]) + self.assertEqual( + list(p), + [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] + ) + + +class UniqueEverseenTests(TestCase): + """Tests for ``unique_everseen()``""" + + def test_everseen(self): + """ensure duplicate elements are ignored""" + u = mi.unique_everseen('AAAABBBBCCDAABBB') + self.assertEqual( + ['A', 'B', 'C', 'D'], + list(u) + ) + + def test_custom_key(self): + """ensure the custom key comparison works""" + u = mi.unique_everseen('aAbACCc', key=str.lower) + self.assertEqual(list('abC'), list(u)) + + def test_unhashable(self): + """ensure things work for unhashable items""" + iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] + u = mi.unique_everseen(iterable) + self.assertEqual(list(u), ['a', [1, 2, 3]]) + + def test_unhashable_key(self): + """ensure things work for unhashable items with a custom key""" + iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] + u = mi.unique_everseen(iterable, key=lambda x: x) + self.assertEqual(list(u), ['a', [1, 2, 3]]) + + +class UniqueJustseenTests(TestCase): + """Tests for ``unique_justseen()``""" + + def test_justseen(self): + """ensure only last item is remembered""" + u = mi.unique_justseen('AAAABBBCCDABB') + self.assertEqual(list('ABCDAB'), list(u)) + + def test_custom_key(self): + """ensure the custom key comparison works""" + u = mi.unique_justseen('AABCcAD', str.lower) + self.assertEqual(list('ABCAD'), list(u)) + + +class IterExceptTests(TestCase): + """Tests for ``iter_except()``""" + + def test_exact_exception(self): + """ensure the exact specified exception is caught""" + l = [1, 2, 3] + i = mi.iter_except(l.pop, IndexError) + self.assertEqual(list(i), [3, 2, 1]) + + def test_generic_exception(self): + """ensure the generic exception can be caught""" + l = [1, 2] + i = mi.iter_except(l.pop, Exception) + self.assertEqual(list(i), [2, 1]) + + def test_uncaught_exception_is_raised(self): + """ensure a non-specified exception is raised""" + l = [1, 2, 3] + i = mi.iter_except(l.pop, KeyError) + self.assertRaises(IndexError, lambda: list(i)) + + def test_first(self): + """ensure 
first is run before the function""" + l = [1, 2, 3] + f = lambda: 25 + i = mi.iter_except(l.pop, IndexError, f) + self.assertEqual(list(i), [25, 3, 2, 1]) + + +class FirstTrueTests(TestCase): + """Tests for ``first_true()``""" + + def test_something_true(self): + """Test with no keywords""" + self.assertEqual(mi.first_true(range(10)), 1) + + def test_nothing_true(self): + """Test default return value.""" + self.assertEqual(mi.first_true([0, 0, 0]), False) + + def test_default(self): + """Test with a default keyword""" + self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!') + + def test_pred(self): + """Test with a custom predicate""" + self.assertEqual( + mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6 + ) + + +class RandomProductTests(TestCase): + """Tests for ``random_product()`` + + Since random.choice() has different results with the same seed across + Python versions 2.x and 3.x, these tests use highly probable events to + create predictable outcomes across platforms. + """ + + def test_simple_lists(self): + """Ensure that one item is chosen from each list in each pair. + Also ensure that each item from each list eventually appears in + the chosen combinations. + + Odds are roughly 1 in 7.1 * 10e16 that one item from either list will + not be chosen after 100 samplings of one item from each list. Just to + be safe, better use a known random seed, too. + + """ + nums = [1, 2, 3] + lets = ['a', 'b', 'c'] + n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)]) + n, m = set(n), set(m) + self.assertEqual(n, set(nums)) + self.assertEqual(m, set(lets)) + self.assertEqual(len(n), len(nums)) + self.assertEqual(len(m), len(lets)) + + def test_list_with_repeat(self): + """ensure multiple items are chosen, and that they appear to be chosen + from one list then the next, in proper order. + + """ + nums = [1, 2, 3] + lets = ['a', 'b', 'c'] + r = list(mi.random_product(nums, lets, repeat=100)) + self.assertEqual(2 * 100, len(r)) + n, m = set(r[::2]), set(r[1::2]) + self.assertEqual(n, set(nums)) + self.assertEqual(m, set(lets)) + self.assertEqual(len(n), len(nums)) + self.assertEqual(len(m), len(lets)) + + +class RandomPermutationTests(TestCase): + """Tests for ``random_permutation()``""" + + def test_full_permutation(self): + """ensure every item from the iterable is returned in a new ordering + + 15 elements have a 1 in 1.3 * 10e12 chance of appearing in sorted order, so + we fix a seed value just to be sure. + + """ + i = range(15) + r = mi.random_permutation(i) + self.assertEqual(set(i), set(r)) + if i == r: + raise AssertionError("Values were not permuted") + + def test_partial_permutation(self): + """ensure all returned items are from the iterable, that the returned + permutation is of the desired length, and that all items eventually + get returned. + + Sampling 100 permutations of length 5 from a set of 15 leaves a + (2/3)^100 chance that an item will not be chosen. Multiplied by 15 + items, there is a 1 in 2.6e16 chance that at least 1 item will not + show up in the resulting output. Using a random seed will fix that. 
+ + """ + items = range(15) + item_set = set(items) + all_items = set() + for _ in range(100): + permutation = mi.random_permutation(items, 5) + self.assertEqual(len(permutation), 5) + permutation_set = set(permutation) + self.assertLessEqual(permutation_set, item_set) + all_items |= permutation_set + self.assertEqual(all_items, item_set) + + +class RandomCombinationTests(TestCase): + """Tests for ``random_combination()``""" + + def test_psuedorandomness(self): + """ensure different subsets of the iterable get returned over many + samplings of random combinations""" + items = range(15) + all_items = set() + for _ in range(50): + combination = mi.random_combination(items, 5) + all_items |= set(combination) + self.assertEqual(all_items, set(items)) + + def test_no_replacement(self): + """ensure that elements are sampled without replacement""" + items = range(15) + for _ in range(50): + combination = mi.random_combination(items, len(items)) + self.assertEqual(len(combination), len(set(combination))) + self.assertRaises( + ValueError, lambda: mi.random_combination(items, len(items) + 1) + ) + + +class RandomCombinationWithReplacementTests(TestCase): + """Tests for ``random_combination_with_replacement()``""" + + def test_replacement(self): + """ensure that elements are sampled with replacement""" + items = range(5) + combo = mi.random_combination_with_replacement(items, len(items) * 2) + self.assertEqual(2 * len(items), len(combo)) + if len(set(combo)) == len(combo): + raise AssertionError("Combination contained no duplicates") + + def test_pseudorandomness(self): + """ensure different subsets of the iterable get returned over many + samplings of random combinations""" + items = range(15) + all_items = set() + for _ in range(50): + combination = mi.random_combination_with_replacement(items, 5) + all_items |= set(combination) + self.assertEqual(all_items, set(items)) + + +class NthCombinationTests(TestCase): + def test_basic(self): + iterable = 'abcdefg' + r = 4 + for index, expected in enumerate(combinations(iterable, r)): + actual = mi.nth_combination(iterable, r, index) + self.assertEqual(actual, expected) + + def test_long(self): + actual = mi.nth_combination(range(180), 4, 2000000) + expected = (2, 12, 35, 126) + self.assertEqual(actual, expected) + + def test_invalid_r(self): + for r in (-1, 3): + with self.assertRaises(ValueError): + mi.nth_combination([], r, 0) + + def test_invalid_index(self): + with self.assertRaises(IndexError): + mi.nth_combination('abcdefg', 3, -36) + + +class PrependTests(TestCase): + def test_basic(self): + value = 'a' + iterator = iter('bcdefg') + actual = list(mi.prepend(value, iterator)) + expected = list('abcdefg') + self.assertEqual(actual, expected) + + def test_multiple(self): + value = 'ab' + iterator = iter('cdefg') + actual = tuple(mi.prepend(value, iterator)) + expected = ('ab',) + tuple('cdefg') + self.assertEqual(actual, expected) diff --git a/libraries/portend.py b/libraries/portend.py new file mode 100644 index 00000000..4c393806 --- /dev/null +++ b/libraries/portend.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- + +""" +A simple library for managing the availability of ports. 
+""" + +from __future__ import print_function, division + +import time +import socket +import argparse +import sys +import itertools +import contextlib +import collections +import platform + +from tempora import timing + + +def client_host(server_host): + """Return the host on which a client can connect to the given listener.""" + if server_host == '0.0.0.0': + # 0.0.0.0 is INADDR_ANY, which should answer on localhost. + return '127.0.0.1' + if server_host in ('::', '::0', '::0.0.0.0'): + # :: is IN6ADDR_ANY, which should answer on localhost. + # ::0 and ::0.0.0.0 are non-canonical but common + # ways to write IN6ADDR_ANY. + return '::1' + return server_host + + +class Checker(object): + def __init__(self, timeout=1.0): + self.timeout = timeout + + def assert_free(self, host, port=None): + """ + Assert that the given addr is free + in that all attempts to connect fail within the timeout + or raise a PortNotFree exception. + + >>> free_port = find_available_local_port() + + >>> Checker().assert_free('localhost', free_port) + >>> Checker().assert_free('127.0.0.1', free_port) + >>> Checker().assert_free('::1', free_port) + + Also accepts an addr tuple + + >>> addr = '::1', free_port, 0, 0 + >>> Checker().assert_free(addr) + + Host might refer to a server bind address like '::', which + should use localhost to perform the check. + + >>> Checker().assert_free('::', free_port) + """ + if port is None and isinstance(host, collections.Sequence): + host, port = host[:2] + if platform.system() == 'Windows': + host = client_host(host) + info = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, + ) + list(itertools.starmap(self._connect, info)) + + def _connect(self, af, socktype, proto, canonname, sa): + s = socket.socket(af, socktype, proto) + # fail fast with a small timeout + s.settimeout(self.timeout) + + with contextlib.closing(s): + try: + s.connect(sa) + except socket.error: + return + + # the connect succeeded, so the port isn't free + port, host = sa[:2] + tmpl = "Port {port} is in use on {host}." + raise PortNotFree(tmpl.format(**locals())) + + +class Timeout(IOError): + pass + + +class PortNotFree(IOError): + pass + + +def free(host, port, timeout=float('Inf')): + """ + Wait for the specified port to become free (dropping or rejecting + requests). Return when the port is free or raise a Timeout if timeout has + elapsed. + + Timeout may be specified in seconds or as a timedelta. + If timeout is None or ∞, the routine will run indefinitely. + + >>> free('localhost', find_available_local_port()) + """ + if not host: + raise ValueError("Host values of '' or None are not allowed.") + + timer = timing.Timer(timeout) + + while not timer.expired(): + try: + # Expect a free port, so use a small timeout + Checker(timeout=0.1).assert_free(host, port) + return + except PortNotFree: + # Politely wait. + time.sleep(0.1) + + raise Timeout("Port {port} not free on {host}.".format(**locals())) +wait_for_free_port = free + + +def occupied(host, port, timeout=float('Inf')): + """ + Wait for the specified port to become occupied (accepting requests). + Return when the port is occupied or raise a Timeout if timeout has + elapsed. + + Timeout may be specified in seconds or as a timedelta. + If timeout is None or ∞, the routine will run indefinitely. + + >>> occupied('localhost', find_available_local_port(), .1) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + Timeout: Port ... not bound on localhost. 
+ """ + if not host: + raise ValueError("Host values of '' or None are not allowed.") + + timer = timing.Timer(timeout) + + while not timer.expired(): + try: + Checker(timeout=.5).assert_free(host, port) + # Politely wait + time.sleep(0.1) + except PortNotFree: + # port is occupied + return + + raise Timeout("Port {port} not bound on {host}.".format(**locals())) +wait_for_occupied_port = occupied + + +def find_available_local_port(): + """ + Find a free port on localhost. + + >>> 0 < find_available_local_port() < 65536 + True + """ + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + addr = '', 0 + sock.bind(addr) + addr, port = sock.getsockname()[:2] + sock.close() + return port + + +class HostPort(str): + """ + A simple representation of a host/port pair as a string + + >>> hp = HostPort('localhost:32768') + + >>> hp.host + 'localhost' + + >>> hp.port + 32768 + + >>> len(hp) + 15 + """ + + @property + def host(self): + host, sep, port = self.partition(':') + return host + + @property + def port(self): + host, sep, port = self.partition(':') + return int(port) + + +def _main(): + parser = argparse.ArgumentParser() + global_lookup = lambda key: globals()[key] + parser.add_argument('target', metavar='host:port', type=HostPort) + parser.add_argument('func', metavar='state', type=global_lookup) + parser.add_argument('-t', '--timeout', default=None, type=float) + args = parser.parse_args() + try: + args.func(args.target.host, args.target.port, timeout=args.timeout) + except Timeout as timeout: + print(timeout, file=sys.stderr) + raise SystemExit(1) + + +if __name__ == '__main__': + _main() diff --git a/libraries/tempora/__init__.py b/libraries/tempora/__init__.py new file mode 100644 index 00000000..e0cdead0 --- /dev/null +++ b/libraries/tempora/__init__.py @@ -0,0 +1,505 @@ +# -*- coding: UTF-8 -*- + +"Objects and routines pertaining to date and time (tempora)" + +from __future__ import division, unicode_literals + +import datetime +import time +import re +import numbers +import functools + +import six + +__metaclass__ = type + + +class Parser: + """ + Datetime parser: parses a date-time string using multiple possible + formats. + + >>> p = Parser(('%H%M', '%H:%M')) + >>> tuple(p.parse('1319')) + (1900, 1, 1, 13, 19, 0, 0, 1, -1) + >>> dateParser = Parser(('%m/%d/%Y', '%Y-%m-%d', '%d-%b-%Y')) + >>> tuple(dateParser.parse('2003-12-20')) + (2003, 12, 20, 0, 0, 0, 5, 354, -1) + >>> tuple(dateParser.parse('16-Dec-1994')) + (1994, 12, 16, 0, 0, 0, 4, 350, -1) + >>> tuple(dateParser.parse('5/19/2003')) + (2003, 5, 19, 0, 0, 0, 0, 139, -1) + >>> dtParser = Parser(('%Y-%m-%d %H:%M:%S', '%a %b %d %H:%M:%S %Y')) + >>> tuple(dtParser.parse('2003-12-20 19:13:26')) + (2003, 12, 20, 19, 13, 26, 5, 354, -1) + >>> tuple(dtParser.parse('Tue Jan 20 16:19:33 2004')) + (2004, 1, 20, 16, 19, 33, 1, 20, -1) + + Be forewarned, a ValueError will be raised if more than one format + matches: + + >>> Parser(('%H%M', '%H%M%S')).parse('732') + Traceback (most recent call last): + ... + ValueError: More than one format string matched target 732. + """ + + formats = ('%m/%d/%Y', '%m/%d/%y', '%Y-%m-%d', '%d-%b-%Y', '%d-%b-%y') + "some common default formats" + + def __init__(self, formats=None): + if formats: + self.formats = formats + + def parse(self, target): + self.target = target + results = tuple(filter(None, map(self._parse, self.formats))) + del self.target + if not results: + tmpl = "No format strings matched the target {target}." 
+ raise ValueError(tmpl.format(**locals())) + if not len(results) == 1: + tmpl = "More than one format string matched target {target}." + raise ValueError(tmpl.format(**locals())) + return results[0] + + def _parse(self, format): + try: + result = time.strptime(self.target, format) + except ValueError: + result = False + return result + + +# some useful constants +osc_per_year = 290091329207984000 +""" +mean vernal equinox year expressed in oscillations of atomic cesium at the +year 2000 (see http://webexhibits.org/calendars/timeline.html for more info). +""" +osc_per_second = 9192631770 +seconds_per_second = 1 +seconds_per_year = 31556940 +seconds_per_minute = 60 +minutes_per_hour = 60 +hours_per_day = 24 +seconds_per_hour = seconds_per_minute * minutes_per_hour +seconds_per_day = seconds_per_hour * hours_per_day +days_per_year = seconds_per_year / seconds_per_day +thirty_days = datetime.timedelta(days=30) +# these values provide useful averages +six_months = datetime.timedelta(days=days_per_year / 2) +seconds_per_month = seconds_per_year / 12 +hours_per_month = hours_per_day * days_per_year / 12 + + +def strftime(fmt, t): + """A function to replace strftime in the datetime package or time module. + Identical to strftime behavior in those modules except supports any + year. + Also supports datetime.datetime times. + Also supports milliseconds using %s + Also supports microseconds using %u""" + if isinstance(t, (time.struct_time, tuple)): + t = datetime.datetime(*t[:6]) + assert isinstance(t, (datetime.datetime, datetime.time, datetime.date)) + try: + year = t.year + if year < 1900: + t = t.replace(year=1900) + except AttributeError: + year = 1900 + subs = ( + ('%Y', '%04d' % year), + ('%y', '%02d' % (year % 100)), + ('%s', '%03d' % (t.microsecond // 1000)), + ('%u', '%03d' % (t.microsecond % 1000)) + ) + + def doSub(s, sub): + return s.replace(*sub) + + def doSubs(s): + return functools.reduce(doSub, subs, s) + + fmt = '%%'.join(map(doSubs, fmt.split('%%'))) + return t.strftime(fmt) + + +def strptime(s, fmt, tzinfo=None): + """ + A function to replace strptime in the time module. Should behave + identically to the strptime function except it returns a datetime.datetime + object instead of a time.struct_time object. + Also takes an optional tzinfo parameter which is a time zone info object. + """ + res = time.strptime(s, fmt) + return datetime.datetime(tzinfo=tzinfo, *res[:6]) + + +class DatetimeConstructor: + """ + >>> cd = DatetimeConstructor.construct_datetime + >>> cd(datetime.datetime(2011,1,1)) + datetime.datetime(2011, 1, 1, 0, 0) + """ + @classmethod + def construct_datetime(cls, *args, **kwargs): + """Construct a datetime.datetime from a number of different time + types found in python and pythonwin""" + if len(args) == 1: + arg = args[0] + method = cls.__get_dt_constructor( + type(arg).__module__, + type(arg).__name__, + ) + result = method(arg) + try: + result = result.replace(tzinfo=kwargs.pop('tzinfo')) + except KeyError: + pass + if kwargs: + first_key = kwargs.keys()[0] + tmpl = ( + "{first_key} is an invalid keyword " + "argument for this function." 
+ ) + raise TypeError(tmpl.format(**locals())) + else: + result = datetime.datetime(*args, **kwargs) + return result + + @classmethod + def __get_dt_constructor(cls, moduleName, name): + try: + method_name = '__dt_from_{moduleName}_{name}__'.format(**locals()) + return getattr(cls, method_name) + except AttributeError: + tmpl = ( + "No way to construct datetime.datetime from " + "{moduleName}.{name}" + ) + raise TypeError(tmpl.format(**locals())) + + @staticmethod + def __dt_from_datetime_datetime__(source): + dtattrs = ( + 'year', 'month', 'day', 'hour', 'minute', 'second', + 'microsecond', 'tzinfo', + ) + attrs = map(lambda a: getattr(source, a), dtattrs) + return datetime.datetime(*attrs) + + @staticmethod + def __dt_from___builtin___time__(pyt): + "Construct a datetime.datetime from a pythonwin time" + fmtString = '%Y-%m-%d %H:%M:%S' + result = strptime(pyt.Format(fmtString), fmtString) + # get milliseconds and microseconds. The only way to do this is + # to use the __float__ attribute of the time, which is in days. + microseconds_per_day = seconds_per_day * 1000000 + microseconds = float(pyt) * microseconds_per_day + microsecond = int(microseconds % 1000000) + result = result.replace(microsecond=microsecond) + return result + + @staticmethod + def __dt_from_timestamp__(timestamp): + return datetime.datetime.utcfromtimestamp(timestamp) + __dt_from___builtin___float__ = __dt_from_timestamp__ + __dt_from___builtin___long__ = __dt_from_timestamp__ + __dt_from___builtin___int__ = __dt_from_timestamp__ + + @staticmethod + def __dt_from_time_struct_time__(s): + return datetime.datetime(*s[:6]) + + +def datetime_mod(dt, period, start=None): + """ + Find the time which is the specified date/time truncated to the time delta + relative to the start date/time. + By default, the start time is midnight of the same day as the specified + date/time. + + >>> datetime_mod(datetime.datetime(2004, 1, 2, 3), + ... datetime.timedelta(days = 1.5), + ... start = datetime.datetime(2004, 1, 1)) + datetime.datetime(2004, 1, 1, 0, 0) + >>> datetime_mod(datetime.datetime(2004, 1, 2, 13), + ... datetime.timedelta(days = 1.5), + ... start = datetime.datetime(2004, 1, 1)) + datetime.datetime(2004, 1, 2, 12, 0) + >>> datetime_mod(datetime.datetime(2004, 1, 2, 13), + ... datetime.timedelta(days = 7), + ... start = datetime.datetime(2004, 1, 1)) + datetime.datetime(2004, 1, 1, 0, 0) + >>> datetime_mod(datetime.datetime(2004, 1, 10, 13), + ... datetime.timedelta(days = 7), + ... start = datetime.datetime(2004, 1, 1)) + datetime.datetime(2004, 1, 8, 0, 0) + """ + if start is None: + # use midnight of the same day + start = datetime.datetime.combine(dt.date(), datetime.time()) + # calculate the difference between the specified time and the start date. + delta = dt - start + + # now aggregate the delta and the period into microseconds + # Use microseconds because that's the highest precision of these time + # pieces. Also, using microseconds ensures perfect precision (no floating + # point errors). + def get_time_delta_microseconds(td): + return (td.days * seconds_per_day + td.seconds) * 1000000 + td.microseconds + delta, period = map(get_time_delta_microseconds, (delta, period)) + offset = datetime.timedelta(microseconds=delta % period) + # the result is the original specified time minus the offset + result = dt - offset + return result + + +def datetime_round(dt, period, start=None): + """ + Find the nearest even period for the specified date/time. + + >>> datetime_round(datetime.datetime(2004, 11, 13, 8, 11, 13), + ... 
datetime.timedelta(hours = 1)) + datetime.datetime(2004, 11, 13, 8, 0) + >>> datetime_round(datetime.datetime(2004, 11, 13, 8, 31, 13), + ... datetime.timedelta(hours = 1)) + datetime.datetime(2004, 11, 13, 9, 0) + >>> datetime_round(datetime.datetime(2004, 11, 13, 8, 30), + ... datetime.timedelta(hours = 1)) + datetime.datetime(2004, 11, 13, 9, 0) + """ + result = datetime_mod(dt, period, start) + if abs(dt - result) >= period // 2: + result += period + return result + + +def get_nearest_year_for_day(day): + """ + Returns the nearest year to now inferred from a Julian date. + """ + now = time.gmtime() + result = now.tm_year + # if the day is far greater than today, it must be from last year + if day - now.tm_yday > 365 // 2: + result -= 1 + # if the day is far less than today, it must be for next year. + if now.tm_yday - day > 365 // 2: + result += 1 + return result + + +def gregorian_date(year, julian_day): + """ + Gregorian Date is defined as a year and a julian day (1-based + index into the days of the year). + + >>> gregorian_date(2007, 15) + datetime.date(2007, 1, 15) + """ + result = datetime.date(year, 1, 1) + result += datetime.timedelta(days=julian_day - 1) + return result + + +def get_period_seconds(period): + """ + return the number of seconds in the specified period + + >>> get_period_seconds('day') + 86400 + >>> get_period_seconds(86400) + 86400 + >>> get_period_seconds(datetime.timedelta(hours=24)) + 86400 + >>> get_period_seconds('day + os.system("rm -Rf *")') + Traceback (most recent call last): + ... + ValueError: period not in (second, minute, hour, day, month, year) + """ + if isinstance(period, six.string_types): + try: + name = 'seconds_per_' + period.lower() + result = globals()[name] + except KeyError: + msg = "period not in (second, minute, hour, day, month, year)" + raise ValueError(msg) + elif isinstance(period, numbers.Number): + result = period + elif isinstance(period, datetime.timedelta): + result = period.days * get_period_seconds('day') + period.seconds + else: + raise TypeError('period must be a string or integer') + return result + + +def get_date_format_string(period): + """ + For a given period (e.g. 'month', 'day', or some numeric interval + such as 3600 (in secs)), return the format string that can be + used with strftime to format that time to specify the times + across that interval, but no more detailed. + For example, + + >>> get_date_format_string('month') + '%Y-%m' + >>> get_date_format_string(3600) + '%Y-%m-%d %H' + >>> get_date_format_string('hour') + '%Y-%m-%d %H' + >>> get_date_format_string(None) + Traceback (most recent call last): + ... + TypeError: period must be a string or integer + >>> get_date_format_string('garbage') + Traceback (most recent call last): + ... 
+ ValueError: period not in (second, minute, hour, day, month, year) + """ + # handle the special case of 'month' which doesn't have + # a static interval in seconds + if isinstance(period, six.string_types) and period.lower() == 'month': + return '%Y-%m' + file_period_secs = get_period_seconds(period) + format_pieces = ('%Y', '-%m-%d', ' %H', '-%M', '-%S') + seconds_per_second = 1 + intervals = ( + seconds_per_year, + seconds_per_day, + seconds_per_hour, + seconds_per_minute, + seconds_per_second, + ) + mods = list(map(lambda interval: file_period_secs % interval, intervals)) + format_pieces = format_pieces[: mods.index(0) + 1] + return ''.join(format_pieces) + + +def divide_timedelta_float(td, divisor): + """ + Divide a timedelta by a float value + + >>> one_day = datetime.timedelta(days=1) + >>> half_day = datetime.timedelta(days=.5) + >>> divide_timedelta_float(one_day, 2.0) == half_day + True + >>> divide_timedelta_float(one_day, 2) == half_day + True + """ + # td is comprised of days, seconds, microseconds + dsm = [getattr(td, attr) for attr in ('days', 'seconds', 'microseconds')] + dsm = map(lambda elem: elem / divisor, dsm) + return datetime.timedelta(*dsm) + + +def calculate_prorated_values(): + """ + A utility function to prompt for a rate (a string in units per + unit time), and return that same rate for various time periods. + """ + rate = six.moves.input("Enter the rate (3/hour, 50/month)> ") + res = re.match('(?P<value>[\d.]+)/(?P<period>\w+)$', rate).groupdict() + value = float(res['value']) + value_per_second = value / get_period_seconds(res['period']) + for period in ('minute', 'hour', 'day', 'month', 'year'): + period_value = value_per_second * get_period_seconds(period) + print("per {period}: {period_value}".format(**locals())) + + +def parse_timedelta(str): + """ + Take a string representing a span of time and parse it to a time delta. + Accepts any string of comma-separated numbers each with a unit indicator. 
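# A minimal illustrative sketch of the period helpers above, assuming the
# vendored ``libraries`` directory is on sys.path so this module imports as
# ``tempora``; the sample timestamp is arbitrary.
import datetime
import tempora

print(tempora.get_period_seconds('hour'))        # 3600
print(tempora.get_date_format_string(3600))      # '%Y-%m-%d %H'
# e.g. label a timestamp at hour resolution:
fmt = tempora.get_date_format_string('hour')
print(datetime.datetime(2004, 11, 13, 8, 30).strftime(fmt))   # 2004-11-13 08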
+ + >>> parse_timedelta('1 day') + datetime.timedelta(days=1) + + >>> parse_timedelta('1 day, 30 seconds') + datetime.timedelta(days=1, seconds=30) + + >>> parse_timedelta('47.32 days, 20 minutes, 15.4 milliseconds') + datetime.timedelta(days=47, seconds=28848, microseconds=15400) + + Supports weeks, months, years + + >>> parse_timedelta('1 week') + datetime.timedelta(days=7) + + >>> parse_timedelta('1 year, 1 month') + datetime.timedelta(days=395, seconds=58685) + + Note that months and years strict intervals, not aligned + to a calendar: + + >>> now = datetime.datetime.now() + >>> later = now + parse_timedelta('1 year') + >>> later.replace(year=now.year) - now + datetime.timedelta(seconds=20940) + """ + deltas = (_parse_timedelta_part(part.strip()) for part in str.split(',')) + return sum(deltas, datetime.timedelta()) + + +def _parse_timedelta_part(part): + match = re.match('(?P<value>[\d.]+) (?P<unit>\w+)', part) + if not match: + msg = "Unable to parse {part!r} as a time delta".format(**locals()) + raise ValueError(msg) + unit = match.group('unit').lower() + if not unit.endswith('s'): + unit += 's' + value = float(match.group('value')) + if unit == 'months': + unit = 'years' + value = value / 12 + if unit == 'years': + unit = 'days' + value = value * days_per_year + return datetime.timedelta(**{unit: value}) + + +def divide_timedelta(td1, td2): + """ + Get the ratio of two timedeltas + + >>> one_day = datetime.timedelta(days=1) + >>> one_hour = datetime.timedelta(hours=1) + >>> divide_timedelta(one_hour, one_day) == 1 / 24 + True + """ + try: + return td1 / td2 + except TypeError: + # Python 3.2 gets division + # http://bugs.python.org/issue2706 + return td1.total_seconds() / td2.total_seconds() + + +def date_range(start=None, stop=None, step=None): + """ + Much like the built-in function range, but works with dates + + >>> range_items = date_range( + ... datetime.datetime(2005,12,21), + ... datetime.datetime(2005,12,25), + ... ) + >>> my_range = tuple(range_items) + >>> datetime.datetime(2005,12,21) in my_range + True + >>> datetime.datetime(2005,12,22) in my_range + True + >>> datetime.datetime(2005,12,25) in my_range + False + """ + if step is None: + step = datetime.timedelta(days=1) + if start is None: + start = datetime.datetime.now() + while start < stop: + yield start + start += step diff --git a/libraries/tempora/schedule.py b/libraries/tempora/schedule.py new file mode 100644 index 00000000..1ad093b2 --- /dev/null +++ b/libraries/tempora/schedule.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- + +""" +Classes for calling functions a schedule. +""" + +from __future__ import absolute_import + +import datetime +import numbers +import abc +import bisect + +import pytz + +__metaclass__ = type + + +def now(): + """ + Provide the current timezone-aware datetime. + + A client may override this function to change the default behavior, + such as to use local time or timezone-naïve times. + """ + return datetime.datetime.utcnow().replace(tzinfo=pytz.utc) + + +def from_timestamp(ts): + """ + Convert a numeric timestamp to a timezone-aware datetime. + + A client may override this function to change the default behavior, + such as to use local time or timezone-naïve times. + """ + return datetime.datetime.utcfromtimestamp(ts).replace(tzinfo=pytz.utc) + + +class DelayedCommand(datetime.datetime): + """ + A command to be executed after some delay (seconds or timedelta). 
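# A minimal illustrative sketch, assuming the module above imports as
# ``tempora.schedule``: a DelayedCommand is itself a datetime, so commands
# compare and sort by their due time.
from tempora import schedule

soon = schedule.DelayedCommand.after(5, lambda: None)
later = schedule.DelayedCommand.after(60, lambda: None)
print(soon < later)   # True -- the command due first sorts first
print(soon.due())     # False until roughly five seconds have elapsed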
+ """ + + @classmethod + def from_datetime(cls, other): + return cls( + other.year, other.month, other.day, other.hour, + other.minute, other.second, other.microsecond, + other.tzinfo, + ) + + @classmethod + def after(cls, delay, target): + if not isinstance(delay, datetime.timedelta): + delay = datetime.timedelta(seconds=delay) + due_time = now() + delay + cmd = cls.from_datetime(due_time) + cmd.delay = delay + cmd.target = target + return cmd + + @staticmethod + def _from_timestamp(input): + """ + If input is a real number, interpret it as a Unix timestamp + (seconds sinc Epoch in UTC) and return a timezone-aware + datetime object. Otherwise return input unchanged. + """ + if not isinstance(input, numbers.Real): + return input + return from_timestamp(input) + + @classmethod + def at_time(cls, at, target): + """ + Construct a DelayedCommand to come due at `at`, where `at` may be + a datetime or timestamp. + """ + at = cls._from_timestamp(at) + cmd = cls.from_datetime(at) + cmd.delay = at - now() + cmd.target = target + return cmd + + def due(self): + return now() >= self + + +class PeriodicCommand(DelayedCommand): + """ + Like a delayed command, but expect this command to run every delay + seconds. + """ + def _next_time(self): + """ + Add delay to self, localized + """ + return self._localize(self + self.delay) + + @staticmethod + def _localize(dt): + """ + Rely on pytz.localize to ensure new result honors DST. + """ + try: + tz = dt.tzinfo + return tz.localize(dt.replace(tzinfo=None)) + except AttributeError: + return dt + + def next(self): + cmd = self.__class__.from_datetime(self._next_time()) + cmd.delay = self.delay + cmd.target = self.target + return cmd + + def __setattr__(self, key, value): + if key == 'delay' and not value > datetime.timedelta(): + raise ValueError( + "A PeriodicCommand must have a positive, " + "non-zero delay." + ) + super(PeriodicCommand, self).__setattr__(key, value) + + +class PeriodicCommandFixedDelay(PeriodicCommand): + """ + Like a periodic command, but don't calculate the delay based on + the current time. Instead use a fixed delay following the initial + run. + """ + + @classmethod + def at_time(cls, at, delay, target): + at = cls._from_timestamp(at) + cmd = cls.from_datetime(at) + if isinstance(delay, numbers.Number): + delay = datetime.timedelta(seconds=delay) + cmd.delay = delay + cmd.target = target + return cmd + + @classmethod + def daily_at(cls, at, target): + """ + Schedule a command to run at a specific time each day. + """ + daily = datetime.timedelta(days=1) + # convert when to the next datetime matching this time + when = datetime.datetime.combine(datetime.date.today(), at) + if when < now(): + when += daily + return cls.at_time(cls._localize(when), daily, target) + + +class Scheduler: + """ + A rudimentary abstract scheduler accepting DelayedCommands + and dispatching them on schedule. + """ + def __init__(self): + self.queue = [] + + def add(self, command): + assert isinstance(command, DelayedCommand) + bisect.insort(self.queue, command) + + def run_pending(self): + while self.queue: + command = self.queue[0] + if not command.due(): + break + self.run(command) + if isinstance(command, PeriodicCommand): + self.add(command.next()) + del self.queue[0] + + @abc.abstractmethod + def run(self, command): + """ + Run the command + """ + + +class InvokeScheduler(Scheduler): + """ + Command targets are functions to be invoked on schedule. 
+ """ + def run(self, command): + command.target() + + +class CallbackScheduler(Scheduler): + """ + Command targets are passed to a dispatch callable on schedule. + """ + def __init__(self, dispatch): + super(CallbackScheduler, self).__init__() + self.dispatch = dispatch + + def run(self, command): + self.dispatch(command.target) diff --git a/libraries/tempora/tests/test_schedule.py b/libraries/tempora/tests/test_schedule.py new file mode 100644 index 00000000..38eb8dc9 --- /dev/null +++ b/libraries/tempora/tests/test_schedule.py @@ -0,0 +1,118 @@ +import time +import random +import datetime + +import pytest +import pytz +import freezegun + +from tempora import schedule + +__metaclass__ = type + + +@pytest.fixture +def naive_times(monkeypatch): + monkeypatch.setattr( + 'irc.schedule.from_timestamp', + datetime.datetime.fromtimestamp) + monkeypatch.setattr('irc.schedule.now', datetime.datetime.now) + + +do_nothing = type(None) +try: + do_nothing() +except TypeError: + # Python 2 compat + def do_nothing(): + return None + + +def test_delayed_command_order(): + """ + delayed commands should be sorted by delay time + """ + delays = [random.randint(0, 99) for x in range(5)] + cmds = sorted([ + schedule.DelayedCommand.after(delay, do_nothing) + for delay in delays + ]) + assert [c.delay.seconds for c in cmds] == sorted(delays) + + +def test_periodic_command_delay(): + "A PeriodicCommand must have a positive, non-zero delay." + with pytest.raises(ValueError) as exc_info: + schedule.PeriodicCommand.after(0, None) + assert str(exc_info.value) == test_periodic_command_delay.__doc__ + + +def test_periodic_command_fixed_delay(): + """ + Test that we can construct a periodic command with a fixed initial + delay. + """ + fd = schedule.PeriodicCommandFixedDelay.at_time( + at=schedule.now(), + delay=datetime.timedelta(seconds=2), + target=lambda: None, + ) + assert fd.due() is True + assert fd.next().due() is False + + +class TestCommands: + def test_delayed_command_from_timestamp(self): + """ + Ensure a delayed command can be constructed from a timestamp. + """ + t = time.time() + schedule.DelayedCommand.at_time(t, do_nothing) + + def test_command_at_noon(self): + """ + Create a periodic command that's run at noon every day. + """ + when = datetime.time(12, 0, tzinfo=pytz.utc) + cmd = schedule.PeriodicCommandFixedDelay.daily_at(when, target=None) + assert cmd.due() is False + next_cmd = cmd.next() + daily = datetime.timedelta(days=1) + day_from_now = schedule.now() + daily + two_days_from_now = day_from_now + daily + assert day_from_now < next_cmd < two_days_from_now + + +class TestTimezones: + def test_alternate_timezone_west(self): + target_tz = pytz.timezone('US/Pacific') + target = schedule.now().astimezone(target_tz) + cmd = schedule.DelayedCommand.at_time(target, target=None) + assert cmd.due() + + def test_alternate_timezone_east(self): + target_tz = pytz.timezone('Europe/Amsterdam') + target = schedule.now().astimezone(target_tz) + cmd = schedule.DelayedCommand.at_time(target, target=None) + assert cmd.due() + + def test_daylight_savings(self): + """ + A command at 9am should always be 9am regardless of + a DST boundary. 
+ """ + with freezegun.freeze_time('2018-03-10 08:00:00'): + target_tz = pytz.timezone('US/Eastern') + target_time = datetime.time(9, tzinfo=target_tz) + cmd = schedule.PeriodicCommandFixedDelay.daily_at( + target_time, + target=lambda: None, + ) + + def naive(dt): + return dt.replace(tzinfo=None) + + assert naive(cmd) == datetime.datetime(2018, 3, 10, 9, 0, 0) + next_ = cmd.next() + assert naive(next_) == datetime.datetime(2018, 3, 11, 9, 0, 0) + assert next_ - cmd == datetime.timedelta(hours=23) diff --git a/libraries/tempora/timing.py b/libraries/tempora/timing.py new file mode 100644 index 00000000..03c22454 --- /dev/null +++ b/libraries/tempora/timing.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +import datetime +import functools +import numbers +import time + +__metaclass__ = type + + +class Stopwatch: + """ + A simple stopwatch which starts automatically. + + >>> w = Stopwatch() + >>> _1_sec = datetime.timedelta(seconds=1) + >>> w.split() < _1_sec + True + >>> import time + >>> time.sleep(1.0) + >>> w.split() >= _1_sec + True + >>> w.stop() >= _1_sec + True + >>> w.reset() + >>> w.start() + >>> w.split() < _1_sec + True + + It should be possible to launch the Stopwatch in a context: + + >>> with Stopwatch() as watch: + ... assert isinstance(watch.split(), datetime.timedelta) + + In that case, the watch is stopped when the context is exited, + so to read the elapsed time:: + + >>> watch.elapsed + datetime.timedelta(...) + >>> watch.elapsed.seconds + 0 + """ + def __init__(self): + self.reset() + self.start() + + def reset(self): + self.elapsed = datetime.timedelta(0) + if hasattr(self, 'start_time'): + del self.start_time + + def start(self): + self.start_time = datetime.datetime.utcnow() + + def stop(self): + stop_time = datetime.datetime.utcnow() + self.elapsed += stop_time - self.start_time + del self.start_time + return self.elapsed + + def split(self): + local_duration = datetime.datetime.utcnow() - self.start_time + return self.elapsed + local_duration + + # context manager support + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stop() + + +class IntervalGovernor: + """ + Decorate a function to only allow it to be called once per + min_interval. Otherwise, it returns None. + """ + def __init__(self, min_interval): + if isinstance(min_interval, numbers.Number): + min_interval = datetime.timedelta(seconds=min_interval) + self.min_interval = min_interval + self.last_call = None + + def decorate(self, func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + allow = ( + not self.last_call + or self.last_call.split() > self.min_interval + ) + if allow: + self.last_call = Stopwatch() + return func(*args, **kwargs) + return wrapper + + __call__ = decorate + + +class Timer(Stopwatch): + """ + Watch for a target elapsed time. + + >>> t = Timer(0.1) + >>> t.expired() + False + >>> __import__('time').sleep(0.15) + >>> t.expired() + True + """ + def __init__(self, target=float('Inf')): + self.target = self._accept(target) + super(Timer, self).__init__() + + def _accept(self, target): + "Accept None or ∞ or datetime or numeric for target" + if isinstance(target, datetime.timedelta): + target = target.total_seconds() + + if target is None: + # treat None as infinite target + target = float('Inf') + + return target + + def expired(self): + return self.split().total_seconds() > self.target + + +class BackoffDelay: + """ + Exponential backoff delay. 
+ + Useful for defining delays between retries. Consider for use + with ``jaraco.functools.retry_call`` as the cleanup. + + Default behavior has no effect; a delay or jitter must + be supplied for the call to be non-degenerate. + + >>> bd = BackoffDelay() + >>> bd() + >>> bd() + + The following instance will delay 10ms for the first call, + 20ms for the second, etc. + + >>> bd = BackoffDelay(delay=0.01, factor=2) + >>> bd() + >>> bd() + + Inspect and adjust the state of the delay anytime. + + >>> bd.delay + 0.04 + >>> bd.delay = 0.01 + + Set limit to prevent the delay from exceeding bounds. + + >>> bd = BackoffDelay(delay=0.01, factor=2, limit=0.015) + >>> bd() + >>> bd.delay + 0.015 + + Limit may be a callable taking a number and returning + the limited number. + + >>> at_least_one = lambda n: max(n, 1) + >>> bd = BackoffDelay(delay=0.01, factor=2, limit=at_least_one) + >>> bd() + >>> bd.delay + 1 + + Pass a jitter to add or subtract seconds to the delay. + + >>> bd = BackoffDelay(jitter=0.01) + >>> bd() + >>> bd.delay + 0.01 + + Jitter may be a callable. To supply a non-deterministic jitter + between -0.5 and 0.5, consider: + + >>> import random + >>> jitter=functools.partial(random.uniform, -0.5, 0.5) + >>> bd = BackoffDelay(jitter=jitter) + >>> bd() + >>> 0 <= bd.delay <= 0.5 + True + """ + + delay = 0 + + factor = 1 + "Multiplier applied to delay" + + jitter = 0 + "Number or callable returning extra seconds to add to delay" + + def __init__(self, delay=0, factor=1, limit=float('inf'), jitter=0): + self.delay = delay + self.factor = factor + if isinstance(limit, numbers.Number): + limit_ = limit + + def limit(n): + return max(0, min(limit_, n)) + self.limit = limit + if isinstance(jitter, numbers.Number): + jitter_ = jitter + + def jitter(): + return jitter_ + self.jitter = jitter + + def __call__(self): + time.sleep(self.delay) + self.delay = self.limit(self.delay * self.factor + self.jitter()) diff --git a/libraries/tempora/utc.py b/libraries/tempora/utc.py new file mode 100644 index 00000000..35bfdb06 --- /dev/null +++ b/libraries/tempora/utc.py @@ -0,0 +1,36 @@ +""" +Facilities for common time operations in UTC. + +Inspired by the `utc project <https://pypi.org/project/utc>`_. + +>>> dt = now() +>>> dt == fromtimestamp(dt.timestamp()) +True +>>> dt.tzinfo +datetime.timezone.utc + +>>> from time import time as timestamp +>>> now().timestamp() - timestamp() < 0.1 +True + +>>> datetime(2018, 6, 26, 0).tzinfo +datetime.timezone.utc + +>>> time(0, 0).tzinfo +datetime.timezone.utc +""" + +import datetime as std +import functools + + +__all__ = ['now', 'fromtimestamp', 'datetime', 'time'] + + +now = functools.partial(std.datetime.now, std.timezone.utc) +fromtimestamp = functools.partial( + std.datetime.fromtimestamp, + tz=std.timezone.utc, +) +datetime = functools.partial(std.datetime, tzinfo=std.timezone.utc) +time = functools.partial(std.time, tzinfo=std.timezone.utc) diff --git a/libraries/zc/__init__.py b/libraries/zc/__init__.py new file mode 100644 index 00000000..146c3362 --- /dev/null +++ b/libraries/zc/__init__.py @@ -0,0 +1 @@ +__namespace__ = 'zc' \ No newline at end of file diff --git a/libraries/zc/lockfile/README.txt b/libraries/zc/lockfile/README.txt new file mode 100644 index 00000000..89ef33e9 --- /dev/null +++ b/libraries/zc/lockfile/README.txt @@ -0,0 +1,70 @@ +Lock file support +================= + +The ZODB lock_file module provides support for creating file system +locks. These are locks that are implemented with lock files and +OS-provided locking facilities. 
To create a lock, instantiate a +LockFile object with a file name: + + >>> import zc.lockfile + >>> lock = zc.lockfile.LockFile('lock') + +If we try to lock the same name, we'll get a lock error: + + >>> import zope.testing.loggingsupport + >>> handler = zope.testing.loggingsupport.InstalledHandler('zc.lockfile') + >>> try: + ... zc.lockfile.LockFile('lock') + ... except zc.lockfile.LockError: + ... print("Can't lock file") + Can't lock file + +.. We don't log failure to acquire. + + >>> for record in handler.records: # doctest: +ELLIPSIS + ... print(record.levelname+' '+record.getMessage()) + +To release the lock, use it's close method: + + >>> lock.close() + +The lock file is not removed. It is left behind: + + >>> import os + >>> os.path.exists('lock') + True + +Of course, now that we've released the lock, we can create it again: + + >>> lock = zc.lockfile.LockFile('lock') + >>> lock.close() + +.. Cleanup + + >>> import os + >>> os.remove('lock') + +Hostname in lock file +===================== + +In a container environment (e.g. Docker), the PID is typically always +identical even if multiple containers are running under the same operating +system instance. + +Clearly, inspecting lock files doesn't then help much in debugging. To identify +the container which created the lock file, we need information about the +container in the lock file. Since Docker uses the container identifier or name +as the hostname, this information can be stored in the lock file in addition to +or instead of the PID. + +Use the ``content_template`` keyword argument to ``LockFile`` to specify a +custom lock file content format: + + >>> lock = zc.lockfile.LockFile('lock', content_template='{pid};{hostname}') + >>> lock.close() + +If you now inspected the lock file, you would see e.g.: + + $ cat lock + 123;myhostname + diff --git a/libraries/zc/lockfile/__init__.py b/libraries/zc/lockfile/__init__.py new file mode 100644 index 00000000..a0ac2ff1 --- /dev/null +++ b/libraries/zc/lockfile/__init__.py @@ -0,0 +1,104 @@ +############################################################################## +# +# Copyright (c) 2001, 2002 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 
+# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE +# +############################################################################## + +import os +import errno +import logging +logger = logging.getLogger("zc.lockfile") + +class LockError(Exception): + """Couldn't get a lock + """ + +try: + import fcntl +except ImportError: + try: + import msvcrt + except ImportError: + def _lock_file(file): + raise TypeError('No file-locking support on this platform') + def _unlock_file(file): + raise TypeError('No file-locking support on this platform') + + else: + # Windows + def _lock_file(file): + # Lock just the first byte + try: + msvcrt.locking(file.fileno(), msvcrt.LK_NBLCK, 1) + except IOError: + raise LockError("Couldn't lock %r" % file.name) + + def _unlock_file(file): + try: + file.seek(0) + msvcrt.locking(file.fileno(), msvcrt.LK_UNLCK, 1) + except IOError: + raise LockError("Couldn't unlock %r" % file.name) + +else: + # Unix + _flags = fcntl.LOCK_EX | fcntl.LOCK_NB + + def _lock_file(file): + try: + fcntl.flock(file.fileno(), _flags) + except IOError: + raise LockError("Couldn't lock %r" % file.name) + + def _unlock_file(file): + fcntl.flock(file.fileno(), fcntl.LOCK_UN) + +class LazyHostName(object): + """Avoid importing socket and calling gethostname() unnecessarily""" + def __str__(self): + import socket + return socket.gethostname() + + +class LockFile: + + _fp = None + + def __init__(self, path, content_template='{pid}'): + self._path = path + try: + # Try to open for writing without truncation: + fp = open(path, 'r+') + except IOError: + # If the file doesn't exist, we'll get an IO error, try a+ + # Note that there may be a race here. Multiple processes + # could fail on the r+ open and open the file a+, but only + # one will get the the lock and write a pid. + fp = open(path, 'a+') + + try: + _lock_file(fp) + except: + fp.close() + raise + + # We got the lock, record info in the file. + self._fp = fp + fp.write(" %s\n" % content_template.format(pid=os.getpid(), + hostname=LazyHostName())) + fp.truncate() + fp.flush() + + def close(self): + if self._fp is not None: + _unlock_file(self._fp) + self._fp.close() + self._fp = None diff --git a/libraries/zc/lockfile/tests.py b/libraries/zc/lockfile/tests.py new file mode 100644 index 00000000..e9fcbff3 --- /dev/null +++ b/libraries/zc/lockfile/tests.py @@ -0,0 +1,193 @@ +############################################################################## +# +# Copyright (c) 2004 Zope Foundation and Contributors. +# All Rights Reserved. +# +# This software is subject to the provisions of the Zope Public License, +# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. +# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED +# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS +# FOR A PARTICULAR PURPOSE. +# +############################################################################## +import os, re, sys, unittest, doctest +import zc.lockfile, time, threading +from zope.testing import renormalizing, setupstack +import tempfile +try: + from unittest.mock import Mock, patch +except ImportError: + from mock import Mock, patch + +checker = renormalizing.RENormalizing([ + # Python 3 adds module path to error class name. 
+ (re.compile("zc\.lockfile\.LockError:"), + r"LockError:"), + ]) + +def inc(): + while 1: + try: + lock = zc.lockfile.LockFile('f.lock') + except zc.lockfile.LockError: + continue + else: + break + f = open('f', 'r+b') + v = int(f.readline().strip()) + time.sleep(0.01) + v += 1 + f.seek(0) + f.write(('%d\n' % v).encode('ASCII')) + f.close() + lock.close() + +def many_threads_read_and_write(): + r""" + >>> with open('f', 'w+b') as file: + ... _ = file.write(b'0\n') + >>> with open('f.lock', 'w+b') as file: + ... _ = file.write(b'0\n') + + >>> n = 50 + >>> threads = [threading.Thread(target=inc) for i in range(n)] + >>> _ = [thread.start() for thread in threads] + >>> _ = [thread.join() for thread in threads] + >>> with open('f', 'rb') as file: + ... saved = int(file.read().strip()) + >>> saved == n + True + + >>> os.remove('f') + + We should only have one pid in the lock file: + + >>> f = open('f.lock') + >>> len(f.read().strip().split()) + 1 + >>> f.close() + + >>> os.remove('f.lock') + + """ + +def pid_in_lockfile(): + r""" + >>> import os, zc.lockfile + >>> pid = os.getpid() + >>> lock = zc.lockfile.LockFile("f.lock") + >>> f = open("f.lock") + >>> _ = f.seek(1) + >>> f.read().strip() == str(pid) + True + >>> f.close() + + Make sure that locking twice does not overwrite the old pid: + + >>> lock = zc.lockfile.LockFile("f.lock") + Traceback (most recent call last): + ... + LockError: Couldn't lock 'f.lock' + + >>> f = open("f.lock") + >>> _ = f.seek(1) + >>> f.read().strip() == str(pid) + True + >>> f.close() + + >>> lock.close() + """ + + +def hostname_in_lockfile(): + r""" + hostname is correctly written into the lock file when it's included in the + lock file content template + + >>> import zc.lockfile + >>> with patch('socket.gethostname', Mock(return_value='myhostname')): + ... lock = zc.lockfile.LockFile("f.lock", content_template='{hostname}') + >>> f = open("f.lock") + >>> _ = f.seek(1) + >>> f.read().rstrip() + 'myhostname' + >>> f.close() + + Make sure that locking twice does not overwrite the old hostname: + + >>> lock = zc.lockfile.LockFile("f.lock", content_template='{hostname}') + Traceback (most recent call last): + ... 
+ LockError: Couldn't lock 'f.lock' + + >>> f = open("f.lock") + >>> _ = f.seek(1) + >>> f.read().rstrip() + 'myhostname' + >>> f.close() + + >>> lock.close() + """ + + +class TestLogger(object): + def __init__(self): + self.log_entries = [] + + def exception(self, msg, *args): + self.log_entries.append((msg,) + args) + + +class LockFileLogEntryTestCase(unittest.TestCase): + """Tests for logging in case of lock failure""" + def setUp(self): + self.here = os.getcwd() + self.tmp = tempfile.mkdtemp(prefix='zc.lockfile-test-') + os.chdir(self.tmp) + + def tearDown(self): + os.chdir(self.here) + setupstack.rmtree(self.tmp) + + def test_log_formatting(self): + # PID and hostname are parsed and logged from lock file on failure + with patch('os.getpid', Mock(return_value=123)): + with patch('socket.gethostname', Mock(return_value='myhostname')): + lock = zc.lockfile.LockFile('f.lock', + content_template='{pid}/{hostname}') + with open('f.lock') as f: + self.assertEqual(' 123/myhostname\n', f.read()) + + lock.close() + + def test_unlock_and_lock_while_multiprocessing_process_running(self): + import multiprocessing + + lock = zc.lockfile.LockFile('l') + q = multiprocessing.Queue() + p = multiprocessing.Process(target=q.get) + p.daemon = True + p.start() + + # release and re-acquire should work (obviously) + lock.close() + lock = zc.lockfile.LockFile('l') + self.assertTrue(p.is_alive()) + + q.put(0) + lock.close() + p.join() + + +def test_suite(): + suite = unittest.TestSuite() + suite.addTest(doctest.DocFileSuite( + 'README.txt', checker=checker, + setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown)) + suite.addTest(doctest.DocTestSuite( + setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown, + checker=checker)) + # Add unittest test cases from this module + suite.addTest(unittest.defaultTestLoader.loadTestsFromName(__name__)) + return suite diff --git a/resources/lib/webservice.py b/resources/lib/webservice.py index 34826af0..7f692846 100644 --- a/resources/lib/webservice.py +++ b/resources/lib/webservice.py @@ -2,17 +2,14 @@ ################################################################################################# -import SimpleHTTPServer -import BaseHTTPServer import logging -import httplib import threading -import urlparse -import urllib import xbmc import xbmcvfs +import cherrypy + ################################################################################################# PORT = 57578 @@ -21,181 +18,59 @@ LOG = logging.getLogger("EMBY."+__name__) ################################################################################################# +class Root(object): + + @cherrypy.expose + def default(self, *args, **kwargs): + + try: + if not kwargs.get('Id').isdigit(): + raise IndexError("Incorrect Id format: %s" % kwargs.get('Id')) + + LOG.info("Webservice called with params: %s", kwargs) + + return ("plugin://plugin.video.emby?mode=play&id=%s&dbid=%s&filename=%s&transcode=%s" + % (kwargs.get('Id'), kwargs.get('KodiId'), kwargs.get('Name'), kwargs.get('transcode') or False)) + + except IndexError as error: + LOG.error(error) + + raise cherrypy.HTTPError(404, error) + + except Exception as error: + LOG.exception(error) + + raise cherrypy.HTTPError(500, "Exception occurred: %s" % error) + class WebService(threading.Thread): - ''' Run a webservice to trigger playback. - Inspired from script.skin.helper.service by marcelveldt. 
- ''' - stop_thread = False + root = None def __init__(self): + + self.root = Root() + cherrypy.config.update({ + 'engine.autoreload.on' : False, + 'log.screen': False, + 'engine.timeout_monitor.frequency': 5, + 'server.shutdown_timeout': 1, + }) threading.Thread.__init__(self) + def run(self): + + LOG.info("--->[ webservice/%s ]", PORT) + conf = { + 'global': { + 'server.socket_host': '0.0.0.0', + 'server.socket_port': PORT + }, '/': {} + } + cherrypy.quickstart(self.root, '/', conf) + def stop(self): - ''' Called when the thread needs to stop - ''' - try: - conn = httplib.HTTPConnection("127.0.0.1:%d" % PORT) - conn.request("QUIT", "/") - conn.getresponse() - self.stop_thread = True - except Exception as error: - LOG.exception(error) + cherrypy.engine.exit() + self.join(0) - def run(self): - - ''' Called to start the webservice. - ''' - LOG.info("--->[ webservice/%s ]", PORT) - - try: - server = StoppableHttpServer(('127.0.0.1', PORT), StoppableHttpRequestHandler) - server.serve_forever() - except Exception as error: - - if '10053' not in error: # ignore host diconnected errors - LOG.exception(error) - - LOG.info("---<[ webservice ]") - - -class Request(object): - - ''' Attributes from urlsplit that this class also sets - ''' - uri_attrs = ('scheme', 'netloc', 'path', 'query', 'fragment') - - def __init__(self, uri, headers, rfile=None): - - self.uri = uri - self.headers = headers - parsed = urlparse.urlsplit(uri) - - for i, attr in enumerate(self.uri_attrs): - setattr(self, attr, parsed[i]) - - try: - body_len = int(self.headers.get('Content-length', 0)) - except ValueError: - body_len = 0 - - self.body = rfile.read(body_len) if body_len and rfile else None - - -class StoppableHttpServer(BaseHTTPServer.HTTPServer): - - ''' Http server that reacts to self.stop flag. - ''' - def serve_forever(self): - - ''' Handle one request at a time until stopped. - ''' - self.stop = False - - while not self.stop: - - self.handle_request() - xbmc.sleep(100) - - -class StoppableHttpRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): - - ''' http request handler with QUIT stopping the server - ''' - raw_requestline = "" - - def __init__(self, request, client_address, server): - try: - SimpleHTTPServer.SimpleHTTPRequestHandler.__init__(self, request, client_address, server) - except Exception: - pass - - def log_message(self, format, *args): - - ''' Mute the webservice requests. - ''' - pass - - def do_QUIT(self): - - ''' send 200 OK response, and set server.stop to True - ''' - self.send_response(200) - self.end_headers() - self.server.stop = True - - def parse_request(self): - - ''' Modify here to workaround unencoded requests. - ''' - retval = SimpleHTTPServer.SimpleHTTPRequestHandler.parse_request(self) - self.request = Request(self.path, self.headers, self.rfile) - - return retval - - def do_HEAD(self): - - ''' Called on HEAD requests - ''' - self.handle_request(True) - - return - - def get_params(self): - - ''' Get the params - ''' - try: - path = self.path[1:] - - if '?' 
in path: - path = path.split('?', 1)[1] - - params = dict(urlparse.parse_qsl(path)) - except Exception: - params = {} - - return params - - def handle_request(self, headers_only=False): - - ''' Send headers and reponse - ''' - try: - params = self.get_params() - - if not params.get('Id').isdigit(): - raise IndexError("Incorrect Id format: %s" % params.get('Id')) - - LOG.info("Webservice called with params: %s", params) - - path = ("plugin://plugin.video.emby?mode=play&id=%s&dbid=%s&filename=%s&transcode=%s" - % (params.get('Id'), params.get('KodiId'), params.get('Name'), params.get('transcode') or False)) - - self.send_response(200) - self.send_header('Content-type', 'text/html') - self.send_header('Content-Length', len(path)) - self.end_headers() - - if not headers_only: - self.wfile.write(path) - except IndexError as error: - - LOG.error(error) - self.send_error(403) - - except Exception as error: - - LOG.exception(error) - self.send_error(500, "Exception occurred: %s" % error) - - return - - def do_GET(self): - - ''' Called on GET requests - ''' - self.handle_request() - - return + del self.root
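# A client-side sketch of the new cherrypy endpoint; host and port come from
# the handler above, while the query values are assumptions for illustration.
# Any request carrying a numeric Id is answered with the matching plugin://
# playback URL.
try:
    from urllib2 import urlopen         # Python 2 (Kodi Leia runtime)
    from urllib import urlencode
except ImportError:
    from urllib.request import urlopen  # Python 3
    from urllib.parse import urlencode

params = urlencode({'Id': '123', 'KodiId': '5', 'Name': 'movie.mkv'})
response = urlopen('http://127.0.0.1:57578/?' + params).read()
print(response)
# plugin://plugin.video.emby?mode=play&id=123&dbid=5&filename=movie.mkv&transcode=False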