# -*- coding: utf-8 -*- """ A simple library for managing the availability of ports. """ from __future__ import print_function, division import time import socket import argparse import sys import itertools import contextlib import collections import platform from tempora import timing def client_host(server_host): """Return the host on which a client can connect to the given listener.""" if server_host == '0.0.0.0': # 0.0.0.0 is INADDR_ANY, which should answer on localhost. return '127.0.0.1' if server_host in ('::', '::0', '::0.0.0.0'): # :: is IN6ADDR_ANY, which should answer on localhost. # ::0 and ::0.0.0.0 are non-canonical but common # ways to write IN6ADDR_ANY. return '::1' return server_host class Checker(object): def __init__(self, timeout=1.0): self.timeout = timeout def assert_free(self, host, port=None): """ Assert that the given addr is free in that all attempts to connect fail within the timeout or raise a PortNotFree exception. >>> free_port = find_available_local_port() >>> Checker().assert_free('localhost', free_port) >>> Checker().assert_free('127.0.0.1', free_port) >>> Checker().assert_free('::1', free_port) Also accepts an addr tuple >>> addr = '::1', free_port, 0, 0 >>> Checker().assert_free(addr) Host might refer to a server bind address like '::', which should use localhost to perform the check. >>> Checker().assert_free('::', free_port) """ if port is None and isinstance(host, collections.Sequence): host, port = host[:2] if platform.system() == 'Windows': host = client_host(host) info = socket.getaddrinfo( host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, ) list(itertools.starmap(self._connect, info)) def _connect(self, af, socktype, proto, canonname, sa): s = socket.socket(af, socktype, proto) # fail fast with a small timeout s.settimeout(self.timeout) with contextlib.closing(s): try: s.connect(sa) except socket.error: return # the connect succeeded, so the port isn't free port, host = sa[:2] tmpl = "Port {port} is in use on {host}." raise PortNotFree(tmpl.format(**locals())) class Timeout(IOError): pass class PortNotFree(IOError): pass def free(host, port, timeout=float('Inf')): """ Wait for the specified port to become free (dropping or rejecting requests). Return when the port is free or raise a Timeout if timeout has elapsed. Timeout may be specified in seconds or as a timedelta. If timeout is None or ∞, the routine will run indefinitely. >>> free('localhost', find_available_local_port()) """ if not host: raise ValueError("Host values of '' or None are not allowed.") timer = timing.Timer(timeout) while not timer.expired(): try: # Expect a free port, so use a small timeout Checker(timeout=0.1).assert_free(host, port) return except PortNotFree: # Politely wait. time.sleep(0.1) raise Timeout("Port {port} not free on {host}.".format(**locals())) wait_for_free_port = free def occupied(host, port, timeout=float('Inf')): """ Wait for the specified port to become occupied (accepting requests). Return when the port is occupied or raise a Timeout if timeout has elapsed. Timeout may be specified in seconds or as a timedelta. If timeout is None or ∞, the routine will run indefinitely. >>> occupied('localhost', find_available_local_port(), .1) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... Timeout: Port ... not bound on localhost. """ if not host: raise ValueError("Host values of '' or None are not allowed.") timer = timing.Timer(timeout) while not timer.expired(): try: Checker(timeout=.5).assert_free(host, port) # Politely wait time.sleep(0.1) except PortNotFree: # port is occupied return raise Timeout("Port {port} not bound on {host}.".format(**locals())) wait_for_occupied_port = occupied def find_available_local_port(): """ Find a free port on localhost. >>> 0 < find_available_local_port() < 65536 True """ sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) addr = '', 0 sock.bind(addr) addr, port = sock.getsockname()[:2] sock.close() return port class HostPort(str): """ A simple representation of a host/port pair as a string >>> hp = HostPort('localhost:32768') >>> hp.host 'localhost' >>> hp.port 32768 >>> len(hp) 15 """ @property def host(self): host, sep, port = self.partition(':') return host @property def port(self): host, sep, port = self.partition(':') return int(port) def _main(): parser = argparse.ArgumentParser() global_lookup = lambda key: globals()[key] parser.add_argument('target', metavar='host:port', type=HostPort) parser.add_argument('func', metavar='state', type=global_lookup) parser.add_argument('-t', '--timeout', default=None, type=float) args = parser.parse_args() try: args.func(args.target.host, args.target.port, timeout=args.timeout) except Timeout as timeout: print(timeout, file=sys.stderr) raise SystemExit(1) if __name__ == '__main__': _main()