New hybrid method

This commit is contained in:
angelblue05 2018-09-06 03:36:32 -05:00
parent 7f5084c62e
commit ace50b34dc
279 changed files with 39526 additions and 19994 deletions

View file

@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
#################################################################################################
import logging
from client import EmbyClient
from helpers import has_attribute
#################################################################################################
class NullHandler(logging.Handler):
def emit(self, record):
print(self.format(record))
loghandler = NullHandler
LOG = logging.getLogger('Emby')
#################################################################################################
def config(level=logging.INFO):
logger = logging.getLogger('Emby')
logger.addHandler(Emby.loghandler())
logger.setLevel(level)
def ensure_client():
def decorator(func):
def wrapper(self, *args, **kwargs):
if self.client.get(self.server_id) is None:
self.construct()
return func(self, *args, **kwargs)
return wrapper
return decorator
class Emby(object):
''' This is your Embyclient, you can create more than one. The server_id is only a temporary thing.
from emby import Emby
default_client = Emby()['config/app']
another_client = Emby('123456')['config/app']
'''
# Borg - multiple instances, shared state
_shared_state = {}
client = {}
server_id = "default"
loghandler = loghandler
def __init__(self, server_id=None):
self.__dict__ = self._shared_state
self.server_id = server_id or "default"
@classmethod
def set_loghandler(cls, func=loghandler, level=logging.INFO):
for handler in logging.getLogger('Emby').handlers:
if isinstance(handler, cls.loghandler):
logging.getLogger('Emby').removeHandler(handler)
cls.loghandler = func
config(level)
def close(self):
if self.server_id not in self.client:
return
self.client[self.server_id].stop()
self.client.pop(self.server_id, None)
LOG.info("---[ STOPPED EMBYCLIENT: %s ]---", self.server_id)
@classmethod
def close_all(cls):
for client in cls.client:
cls.client[client].stop()
cls.client = {}
LOG.info("---[ STOPPED ALL EMBYCLIENTS ]---")
@classmethod
def get_active_clients(cls):
return cls.client
@ensure_client()
def __setattr__(self, name, value):
if has_attribute(self, name):
return super(Emby, self).__setattr__(name, value)
setattr(self.client[self.server_id], name, value)
@ensure_client()
def __getattr__(self, name):
return getattr(self.client[self.server_id], name)
@ensure_client()
def __getitem__(self, key):
return self.client[self.server_id][key]
def construct(self):
self.client[self.server_id] = EmbyClient()
if self.server_id == 'default':
LOG.info("---[ START EMBYCLIENT ]---")
else:
LOG.info("---[ START EMBYCLIENT: %s ]---", self.server_id)
config()

View file

@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
#################################################################################################
import logging
import core.api as api
from core.configuration import Config
from core.http import HTTP
from core.ws_client import WSClient
from core.connection_manager import ConnectionManager, CONNECTION_STATE
#################################################################################################
LOG = logging.getLogger('Emby.'+__name__)
#################################################################################################
def callback(message, data):
''' Callback function should received message, data
message: string
data: json dictionary
'''
pass
class EmbyClient(object):
logged_in = False
def __init__(self):
LOG.debug("EmbyClient initializing...")
self.config = Config()
self.http = HTTP(self)
self.wsc = WSClient(self)
self.auth = ConnectionManager(self)
self.emby = api
self.emby.client = self.http
self.callback_ws = callback
self.callback = callback
def set_credentials(self, credentials=None):
self.auth.credentials.set_credentials(credentials or {})
def get_credentials(self):
return self.auth.credentials.get_credentials()
def authenticate(self, credentials=None, options=None):
self.set_credentials(credentials or {})
state = self.auth.connect(options or {})
if state['State'] == CONNECTION_STATE['SignedIn']:
LOG.info("User is authenticated.")
self.logged_in = True
self.callback("ServerOnline", {'Id': self['auth/server-id']})
state['Credentials'] = self.get_credentials()
return state
def start(self, websocket=False):
if not self.logged_in:
raise ValueError("User is not authenticated.")
self.http.start_session()
if websocket:
self.start_wsc()
def start_wsc(self):
self.wsc.start()
def stop(self):
self.wsc.stop_client()
self.http.stop_session()
def __getitem__(self, key):
if key.startswith('config'):
return self.config[key.replace('config/', "", 1)] if "/" in key else self.config
elif key.startswith('http'):
return self.http.__shortcuts__(key.replace('http/', "", 1))
elif key.startswith('websocket'):
return self.wsc.__shortcuts__(key.replace('websocket/', "", 1))
elif key.startswith('callback'):
return self.callback_ws if 'ws' in key else self.callback
elif key.startswith('auth'):
return self.auth.__shortcuts__(key.replace('auth/', "", 1))
elif key.startswith('api'):
self.emby.client = self.http # Since api is not a class, re-assign global var to correct http adapter
return self.emby
elif key == 'connected':
return self.logged_in
return

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,287 @@
# -*- coding: utf-8 -*-
#################################################################################################
client = None
#################################################################################################
def _http(action, url, request={}):
request.update({'type': action, 'handler': url})
return client.request(request)
def _get(handler, params=None):
return _http("GET", handler, {'params': params})
def _post(handler, json=None, params=None):
return _http("POST", handler, {'params': params, 'json': json})
def _delete(handler, params=None):
return _http("DELETE", handler, {'params': params})
def emby_url(handler):
return "%s/emby/%s" % (client.config['auth.server'], handler)
def basic_info():
return "Etag"
def info():
return (
"Path,Genres,SortName,Studios,Writer,Taglines,"
"OfficialRating,CumulativeRunTimeTicks,"
"Metascore,AirTime,DateCreated,People,Overview,"
"CriticRating,CriticRatingSummary,Etag,ShortOverview,ProductionLocations,"
"Tags,ProviderIds,ParentId,RemoteTrailers,SpecialEpisodeNumbers,"
"MediaSources,VoteCount,RecursiveItemCount,PrimaryImageAspectRatio"
)
def music_info():
return (
"Etag,Genres,SortName,Studios,Writer,"
"OfficialRating,CumulativeRunTimeTicks,Metascore,"
"AirTime,DateCreated,MediaStreams,People,ProviderIds,Overview,ItemCounts"
)
#################################################################################################
# Bigger section of the Emby api
#################################################################################################
def try_server():
return _get("System/Info/Public")
def sessions(handler="", action="GET", params=None, json=None):
if action == "POST":
return _post("Sessions%s" % handler, json, params)
elif action == "DELETE":
return _delete("Sessions%s" % handler, params)
else:
return _get("Sessions%s" % handler, params)
def users(handler="", action="GET", params=None, json=None):
if action == "POST":
return _post("Users/{UserId}%s" % handler, json, params)
elif action == "DELETE":
return _delete("Users/{UserId}%s" % handler, params)
else:
return _get("Users/{UserId}%s" % handler, params)
def items(handler="", action="GET", params=None, json=None):
if action == "POST":
return _post("Items%s" % handler, json, params)
elif action == "DELETE":
return _delete("Items%s" % handler, params)
else:
return _get("Items%s" % handler, params)
def user_items(handler="", params=None):
return users("/Items%s" % handler, params)
def shows(handler, params):
return _get("Shows%s" % handler, params)
def videos(handler):
return _get("Videos%s" % handler)
def artwork(item_id, art, max_width, ext="jpg", index=None):
if index is None:
return emby_url("Items/%s/Images/%s?MaxWidth=%s&format=%s" % (item_id, art, max_width, ext))
return emby_url("Items/%s/Images/%s/%s?MaxWidth=%s&format=%s" % (item_id, art, index, max_width, ext))
#################################################################################################
# More granular api
#################################################################################################
def get_users(disabled=False, hidden=False):
return _get("Users", params={
'IsDisabled': disabled,
'IsHidden': hidden
})
def get_public_users():
return _get("Users/Public")
def get_user(user_id=None):
return users() if user_id is None else _get("Users/%s" % user_id)
def get_views():
return users("/Views")
def get_media_folders():
return users("/Items")
def get_item(item_id):
return users("/Items/%s" % item_id)
def get_items(item_ids):
return users("/Items", params={
'Ids': ','.join(str(x) for x in item_ids),
'Fields': info()
})
def get_sessions():
return sessions(params={'ControllableByUserId': "{UserId}"})
def get_device(device_id):
return sessions(params={'DeviceId': device_id})
def post_session(session_id, url, params=None, data=None):
return sessions("/%s/%s" % (session_id, url), "POST", params, data)
def get_images(item_id):
return items("/%s/Images" % item_id)
def get_suggestion(media="Movie,Episode", limit=1):
return users("/Suggestions", {
'Type': media,
'Limit': limit
})
def get_recently_added(media=None, limit=20):
return user_items("/Latest", {
'Limit': limit,
'UserId': "{UserId}",
'IncludeItemTypes': media
})
def get_next(index=None, limit=1):
return shows("/NextUp", {
'Limit': limit,
'UserId': "{UserId}",
'StartIndex': None if index is None else int(index)
})
def get_intros(item_id):
return user_items("/%s/Intros" % item_id)
def get_additional_parts(item_id):
return videos("/%s/AdditionalParts" % item_id)
def delete_item(item_id):
return items("/%s" % item_id, "DELETE")
def get_local_trailers(item_id):
return user_items("/%s/LocalTrailers" % item_id)
def get_ancestors(item_id):
return items("/%s/Ancestors" % item_id, params={
'UserId': "{UserId}"
})
def get_items_theme_video(parent_id):
return users("/Items", params={
'HasThemeVideo': True,
'ParentId': parent_id
})
def get_themes(item_id):
return items("/%s/ThemeMedia" % item_id, params={
'UserId': "{UserId}",
'InheritFromParent': True
})
def get_items_theme_song(parent_id):
return users("/Items", params={
'HasThemeSong': True,
'ParentId': parent_id
})
def get_plugins():
return _get("Plugins")
def get_seasons(show_id):
return shows("/%s/Seasons" % show_id, params={
'UserId': "{UserId}",
'Fields': basic_info()
})
def get_date_modified(date, parent_id, media=None):
return users("/Items", params={
'ParentId': parent_id,
'Recursive': False,
'IsMissing': False,
'IsVirtualUnaired': False,
'IncludeItemTypes': media or None,
#'MinDateLastSavedForUser': date,
'MinDateLastSaved': date,
'Fields': info()
})
def refresh_item(item_id):
return items("/%s/Refresh" % item_id, "POST", json={
'Recursive': True,
'ImageRefreshMode': "FullRefresh",
'MetadataRefreshMode': "FullRefresh",
'ReplaceAllImages': False,
'ReplaceAllMetadata': True
})
def favorite(item_id, option=True):
return users("/FavoriteItems/%s" % item_id, "POST" if option else "DELETE")
def get_system_info():
return _get("System/Configuration")
def post_capabilities(data):
return sessions("/Capabilities/Full", "POST", json=data)
def session_add_user(session_id, user_id, option=True):
return sessions("/%s/Users/%s" % (session_id, user_id), "POST" if option else "DELETE")
def session_playing(data):
return sessions("/Playing", "POST", json=data)
def session_progress(data):
return sessions("/Playing/Progress", "POST", json=data)
def session_stop(data):
return sessions("/Playing/Stopped", "POST", json=data)
def item_played(item_id, watched):
return users("/PlayedItems/%s" % item_id, "POST" if watched else "DELETE")
def get_sync_queue(date, filters=None):
return _get("Emby.Kodi.SyncQueue/{UserId}/GetItems", params={
'LastUpdateDT': date,
'filter': filters or None
})
def get_server_time():
return _get("Emby.Kodi.SyncQueue/GetServerDateTime")
def get_play_info(item_id, profile):
return items("/%s/PlaybackInfo" % item_id, "POST", json={
'UserId': "{UserId}",
'DeviceProfile': profile
})
def get_live_stream(item_id, play_id, token, profile):
return _post("LiveStreams/Open", json={
'UserId': "{UserId}",
'DeviceProfile': profile,
'OpenToken': token,
'PlaySessionId': play_id,
'ItemId': item_id
})
def close_live_stream(live_id):
return _post("LiveStreams/Close", json={
'LiveStreamId': live_id
})
def close_transcode(device_id):
return _delete("Videos/ActiveEncodings", params={
'DeviceId': device_id
})
def delete_item(item_id):
return items("/%s" % item_id, "DELETE")

View file

@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
''' This will hold all configs from the client.
Configuration set here will be used for the HTTP client.
'''
#################################################################################################
import logging
#################################################################################################
DEFAULT_HTTP_MAX_RETRIES = 3
DEFAULT_HTTP_TIMEOUT = 30
LOG = logging.getLogger('Emby.'+__name__)
#################################################################################################
class Config(object):
def __init__(self):
LOG.debug("Configuration initializing...")
self.data = {}
self.http()
def __shortcuts__(self, key):
if key == "auth":
return self.auth
elif key == "app":
return self.app
elif key == "http":
return self.http
elif key == "data":
return self
return
def __getitem__(self, key):
return self.data.get(key, self.__shortcuts__(key))
def __setitem__(self, key, value):
self.data[key] = value
def app(self, name, version, device_name, device_id, capabilities=None, device_pixel_ratio=None):
LOG.info("Begin app constructor.")
self.data['app.name'] = name
self.data['app.version'] = version
self.data['app.device_name'] = device_name
self.data['app.device_id'] = device_id
self.data['app.capabilities'] = capabilities
self.data['app.device_pixel_ratio'] = device_pixel_ratio
self.data['app.default'] = False
def auth(self, server, user_id, token=None, ssl=None):
LOG.info("Begin auth constructor.")
self.data['auth.server'] = server
self.data['auth.user_id'] = user_id
self.data['auth.token'] = token
self.data['auth.ssl'] = ssl
def http(self, user_agent=None, max_retries=DEFAULT_HTTP_MAX_RETRIES, timeout=DEFAULT_HTTP_TIMEOUT):
LOG.info("Begin http constructor.")
self.data['http.max_retries'] = max_retries
self.data['http.timeout'] = timeout
self.data['http.user_agent'] = user_agent

View file

@ -0,0 +1,846 @@
# -*- coding: utf-8 -*-
#################################################################################################
import json
import logging
import hashlib
import socket
import time
from datetime import datetime
from credentials import Credentials
from http import HTTP
#################################################################################################
LOG = logging.getLogger('Emby.'+__name__)
CONNECTION_STATE = {
'Unavailable': 0,
'ServerSelection': 1,
'ServerSignIn': 2,
'SignedIn': 3,
'ConnectSignIn': 4,
'ServerUpdateNeeded': 5
}
CONNECTION_MODE = {
'Local': 0,
'Remote': 1,
'Manual': 2
}
#################################################################################################
def get_server_address(server, mode):
modes = {
CONNECTION_MODE['Local']: server.get('LocalAddress'),
CONNECTION_MODE['Remote']: server.get('RemoteAddress'),
CONNECTION_MODE['Manual']: server.get('ManualAddress')
}
return modes.get(mode) or server.get('ManualAddress', server.get('LocalAddress', server.get('RemoteAddress')))
class ConnectionManager(object):
min_server_version = "3.0.5930"
server_version = min_server_version
user = {}
server_id = None
timeout = 10
def __init__(self, client):
LOG.debug("ConnectionManager initializing...")
self.client = client
self.config = client.config
self.credentials = Credentials()
self.http = HTTP(client)
def __shortcuts__(self, key):
if key == "clear":
return self.clear_data
elif key == "servers":
return self.get_available_servers()
elif key in ("reconnect", "refresh"):
return self.connect
elif key == "login":
return self.login
elif key == "login-connect":
return self.login_to_connect
elif key == "connect-user":
return self.connect_user()
elif key == "connect-token":
return self.connect_token()
elif key == "connect-user-id":
return self.connect_user_id()
elif key == "server":
return self.get_server_info(self.server_id)
elif key == "server-id":
return self.server_id
elif key == "server-version":
return self.server_version
elif key == "user-id":
return self.emby_user_id()
elif key == "public-users":
return self.get_public_users()
elif key == "token":
return self.emby_token()
elif key == "manual-server":
return self.connect_to_address
elif key == "connect-to-server":
return self.connect_to_server
elif key == "server-address":
server = self.get_server_info(self.server_id)
return get_server_address(server, server['LastConnectionMode'])
elif key == "revoke-token":
return self.revoke_token()
return
def __getitem__(self, key):
return self.__shortcuts__(key)
def clear_data(self):
LOG.info("connection manager clearing data")
self.user = None
credentials = self.credentials.get_credentials()
credentials['ConnectAccessToken'] = None
credentials['ConnectUserId'] = None
credentials['Servers'] = list()
self.credentials.get_credentials(credentials)
self.config.auth(None, None)
def revoke_token(self):
LOG.info("revoking token")
self['server']['AccessToken'] = None
self.credentials.get_credentials(self.credentials.get_credentials())
self.config['auth.token'] = None
def get_available_servers(self):
LOG.info("Begin getAvailableServers")
# Clone the credentials
credentials = self.credentials.get_credentials()
connect_servers = self._get_connect_servers(credentials)
found_servers = self._find_servers(self._server_discovery())
if not connect_servers and not found_servers and not credentials['Servers']: # back out right away, no point in continuing
LOG.info("Found no servers")
return list()
servers = list(credentials['Servers'])
self._merge_servers(servers, found_servers)
self._merge_servers(servers, connect_servers)
servers = self._filter_servers(servers, connect_servers)
try:
servers.sort(key=lambda x: datetime.strptime(x['DateLastAccessed'], "%Y-%m-%dT%H:%M:%SZ"), reverse=True)
except TypeError:
servers.sort(key=lambda x: datetime(*(time.strptime(x['DateLastAccessed'], "%Y-%m-%dT%H:%M:%SZ")[0:6])), reverse=True)
credentials['Servers'] = servers
self.credentials.get_credentials(credentials)
return servers
def login_to_connect(self, username, password):
if not username:
raise AttributeError("username cannot be empty")
if not password:
raise AttributeError("password cannot be empty")
try:
result = self._request_url({
'type': "POST",
'url': self.get_connect_url("user/authenticate"),
'data': {
'nameOrEmail': username,
'password': self._get_connect_password_hash(password)
},
'dataType': "json"
})
except Exception as error: # Failed to login
LOG.error(error)
return False
else:
credentials = self.credentials.get_credentials()
credentials['ConnectAccessToken'] = result['AccessToken']
credentials['ConnectUserId'] = result['User']['Id']
credentials['ConnectUser'] = result['User']['DisplayName']
self.credentials.get_credentials(credentials)
# Signed in
self._on_connect_user_signin(result['User'])
return result
def login(self, server, username, password="", options={}):
if not username:
raise AttributeError("username cannot be empty")
if not server:
raise AttributeError("server cannot be empty")
try:
result = self._request_url({
'type': "POST",
'url': self.get_emby_url(server, "Users/AuthenticateByName"),
'json': {
'username': username,
'password': hashlib.sha1(password or "").hexdigest()
}
}, False)
except Exception as error: # Failed to login
LOG.error(error)
return False
else:
self._on_authenticated(result, options)
return result
def connect_to_address(self, address, options={}):
if not address:
return False
address = self._normalize_address(address)
def _on_fail():
LOG.error("connectToAddress %s failed", address)
return self._resolve_failure()
try:
public_info = self._try_connect(address, options=options)
except Exception:
return _on_fail()
else:
LOG.info("connectToAddress %s succeeded", address)
server = {
'ManualAddress': address,
'LastConnectionMode': CONNECTION_MODE['Manual']
}
self._update_server_info(server, public_info)
server = self.connect_to_server(server, options)
if server is False:
return _on_fail()
return server
def connect_to_server(self, server, options={}):
LOG.info("begin connectToServer")
tests = []
if server.get('LastConnectionMode') is not None:
#tests.append(server['LastConnectionMode'])
pass
if CONNECTION_MODE['Manual'] not in tests:
tests.append(CONNECTION_MODE['Manual'])
if CONNECTION_MODE['Local'] not in tests:
tests.append(CONNECTION_MODE['Local'])
if CONNECTION_MODE['Remote'] not in tests:
tests.append(CONNECTION_MODE['Remote'])
# TODO: begin to wake server
LOG.info("beginning connection tests")
return self._test_next_connection_mode(tests, 0, server, options)
def connect(self, options={}):
LOG.info("Begin connect")
return self._connect_to_servers(self.get_available_servers(), options)
def connect_user(self):
return self.user
def connect_user_id(self):
return self.credentials.get_credentials().get('ConnectUserId')
def connect_token(self):
return self.credentials.get_credentials().get('ConnectAccessToken')
def emby_user_id(self):
return self.get_server_info(self.server_id)['UserId']
def emby_token(self):
return self.get_server_info(self.server_id)['AccessToken']
def get_server_info(self, server_id):
if server_id is None:
LOG.info("server_id is empty")
return {}
servers = self.credentials.get_credentials()['Servers']
for server in servers:
if server['Id'] == server_id:
return server
def get_public_users(self):
return self.client.emby.get_public_users()
def get_connect_url(self, handler):
return "https://connect.emby.media/service/%s" % handler
def get_emby_url(self, base, handler):
return "%s/emby/%s" % (base, handler)
def _request_url(self, request, headers=True):
request['timeout'] = request.get('timeout') or self.timeout
if headers:
self._get_headers(request)
try:
return self.http.request(request)
except Exception as error:
LOG.error(error)
raise
def _add_app_info(self):
return "%s/%s" % (self.config['app.name'], self.config['app.version'])
def _get_headers(self, request):
headers = request.setdefault('headers', {})
if request.get('dataType') == "json":
headers['Accept'] = "application/json"
request.pop('dataType')
headers['X-Application'] = self._add_app_info()
headers['Content-type'] = request.get('contentType',
'application/x-www-form-urlencoded; charset=UTF-8')
def _connect_to_servers(self, servers, options):
LOG.info("Begin connectToServers, with %s servers", len(servers))
result = {}
if len(servers) == 1:
result = self.connect_to_server(servers[0], options)
"""
if result['State'] == CONNECTION_STATE['Unavailable']:
result['State'] = CONNECTION_STATE['ConnectSignIn'] if result['ConnectUser'] is None else CONNECTION_STATE['ServerSelection']
"""
LOG.debug("resolving connectToServers with result['State']: %s", result)
return result
first_server = self._get_last_used_server()
# See if we have any saved credentials and can auto sign in
if first_server is not None and first_server['DateLastAccessed'] != "2001-01-01T00:00:00Z":
result = self.connect_to_server(first_server, options)
if result['State'] == CONNECTION_STATE['SignedIn']:
return result
# Return loaded credentials if exists
credentials = self.credentials.get_credentials()
self._ensure_connect_user(credentials)
return {
'Servers': servers,
'State': CONNECTION_STATE['ConnectSignIn'] if (not len(servers) and not self.connect_user()) else (result.get('State') or CONNECTION_STATE['ServerSelection']),
'ConnectUser': self.connect_user()
}
def _try_connect(self, url, timeout=None, options={}):
url = self.get_emby_url(url, "system/info/public")
LOG.info("tryConnect url: %s", url)
return self._request_url({
'type': "GET",
'url': url,
'dataType': "json",
'timeout': timeout,
'verify': options.get('ssl'),
'retry': False
})
def _test_next_connection_mode(self, tests, index, server, options):
if index >= len(tests):
LOG.info("Tested all connection modes. Failing server connection.")
return self._resolve_failure()
mode = tests[index]
address = get_server_address(server, mode)
enable_retry = False
skip_test = False
timeout = self.timeout
LOG.info("testing connection mode %s with server %s", mode, server.get('Name'))
if mode == CONNECTION_MODE['Local']:
enable_retry = True
timeout = 8
if self._string_equals_ignore_case(address, server.get('ManualAddress')):
LOG.info("skipping LocalAddress test because it is the same as ManualAddress")
skip_test = True
elif mode == CONNECTION_MODE['Manual']:
if self._string_equals_ignore_case(address, server.get('LocalAddress')):
enable_retry = True
timeout = 8
if skip_test or not address:
LOG.info("skipping test at index: %s", index)
return self._test_next_connection_mode(tests, index + 1, server, options)
try:
result = self._try_connect(address, timeout, options)
except Exception:
LOG.error("test failed for connection mode %s with server %s", mode, server.get('Name'))
if enable_retry:
# TODO: wake on lan and retry
return self._test_next_connection_mode(tests, index + 1, server, options)
else:
return self._test_next_connection_mode(tests, index + 1, server, options)
else:
if self._compare_versions(self._get_min_server_version(), result['Version']) == 1:
LOG.warn("minServerVersion requirement not met. Server version: %s", result['Version'])
return {
'State': CONNECTION_STATE['ServerUpdateNeeded'],
'Servers': [server]
}
else:
LOG.info("calling onSuccessfulConnection with connection mode %s with server %s", mode, server.get('Name'))
return self._on_successful_connection(server, result, mode, options)
def _on_successful_connection(self, server, system_info, connection_mode, options):
credentials = self.credentials.get_credentials()
if credentials.get('ConnectAccessToken') and options.get('enableAutoLogin') is not False:
if self._ensure_connect_user(credentials) is not False:
if server.get('ExchangeToken'):
self._add_authentication_info_from_connect(server, connection_mode, credentials, options)
return self._after_connect_validated(server, credentials, system_info, connection_mode, True, options)
def _resolve_failure(self):
return {
'State': CONNECTION_STATE['Unavailable'],
'ConnectUser': self.connect_user()
}
def _get_min_server_version(self, val=None):
if val is not None:
LOG.info("hello?")
self.min_server_version = val
return self.min_server_version
def _compare_versions(self, a, b):
''' -1 a is smaller
1 a is larger
0 equal
'''
a = a.split('.')
b = b.split('.')
for i in range(0, max(len(a), len(b)), 1):
try:
aVal = a[i]
except IndexError:
aVal = 0
try:
bVal = b[i]
except IndexError:
bVal = 0
if aVal < bVal:
return -1
if aVal > bVal:
return 1
return 0
def _string_equals_ignore_case(self, str1, str2):
return (str1 or "").lower() == (str2 or "").lower()
def _get_connect_user(self, user_id, access_token):
if not user_id:
raise AttributeError("null userId")
if not access_token:
raise AttributeError("null accessToken")
return self._request_url({
'type': "GET",
'url': self.get_connect_url('user?id=%s' % user_id),
'dataType': "json",
'headers': {
'X-Connect-UserToken': access_token
}
})
def _server_discovery(self):
MULTI_GROUP = ("<broadcast>", 7359)
MESSAGE = "who is EmbyServer?"
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(1.0) # This controls the socket.timeout exception
sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 20)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_LOOP, 1)
sock.setsockopt(socket.IPPROTO_IP, socket.SO_REUSEADDR, 1)
LOG.debug("MultiGroup : %s", str(MULTI_GROUP))
LOG.debug("Sending UDP Data: %s", MESSAGE)
servers = []
try:
sock.sendto(MESSAGE, MULTI_GROUP)
except Exception as error:
LOG.error(error)
return servers
while True:
try:
data, addr = sock.recvfrom(1024) # buffer size
servers.append(json.loads(data))
except socket.timeout:
LOG.info("Found Servers: %s", servers)
return servers
except Exception as e:
LOG.error("Error trying to find servers: %s", e)
return servers
def _get_connect_servers(self, credentials):
LOG.info("Begin getConnectServers")
servers = list()
if not credentials.get('ConnectAccessToken') or not credentials.get('ConnectUserId'):
return servers
url = self.get_connect_url("servers?userId=%s" % credentials['ConnectUserId'])
request = {
'type': "GET",
'url': url,
'dataType': "json",
'headers': {
'X-Connect-UserToken': credentials['ConnectAccessToken']
}
}
for server in self._request_url(request):
servers.append({
'ExchangeToken': server['AccessKey'],
'ConnectServerId': server['Id'],
'Id': server['SystemId'],
'Name': server['Name'],
'RemoteAddress': server['Url'],
'LocalAddress': server['LocalAddress'],
'UserLinkType': "Guest" if server['UserType'].lower() == "guest" else "LinkedUser",
})
return servers
def _get_last_used_server(self):
servers = self.credentials.get_credentials()['Servers']
if not len(servers):
return
try:
servers.sort(key=lambda x: datetime.strptime(x['DateLastAccessed'], "%Y-%m-%dT%H:%M:%SZ"), reverse=True)
except TypeError:
servers.sort(key=lambda x: datetime(*(time.strptime(x['DateLastAccessed'], "%Y-%m-%dT%H:%M:%SZ")[0:6])), reverse=True)
return servers[0]
def _merge_servers(self, list1, list2):
for i in range(0, len(list2), 1):
try:
self.credentials.add_update_server(list1, list2[i])
except KeyError:
continue
return list1
def _find_servers(self, found_servers):
servers = []
for found_server in found_servers:
server = self._convert_endpoint_address_to_manual_address(found_server)
info = {
'Id': found_server['Id'],
'LocalAddress': server or found_server['Address'],
'Name': found_server['Name']
} #TODO
info['LastConnectionMode'] = CONNECTION_MODE['Manual'] if info.get('ManualAddress') else CONNECTION_MODE['Local']
servers.append(info)
else:
return servers
def _filter_servers(self, servers, connect_servers):
filtered = list()
for server in servers:
if server.get('ExchangeToken') is None:
# It's not a connect server, so assume it's still valid
filtered.append(server)
continue
for connect_server in connect_servers:
if server['Id'] == connect_server['Id']:
filtered.append(server)
break
return filtered
def _convert_endpoint_address_to_manual_address(self, info):
if info.get('Address') and info.get('EndpointAddress'):
address = info['EndpointAddress'].split(':')[0]
# Determine the port, if any
parts = info['Address'].split(':')
if len(parts) > 1:
port_string = parts[len(parts)-1]
try:
address += ":%s" % int(port_string)
return self._normalize_address(address)
except ValueError:
pass
return None
def _normalize_address(self, address):
# Attempt to correct bad input
address = address.strip()
address = address.lower()
if 'http' not in address:
address = "http://%s" % address
return address
def _get_connect_password_hash(self, password):
password = self._clean_connect_password(password)
return hashlib.md5(password).hexdigest()
def _clean_connect_password(self, password):
password = password or ""
password = password.replace("&", '&amp;')
password = password.replace("/", '&#092;')
password = password.replace("!", '&#33;')
password = password.replace("$", '&#036;')
password = password.replace("\"", '&quot;')
password = password.replace("<", '&lt;')
password = password.replace(">", '&gt;')
password = password.replace("'", '&#39;')
return password
def _ensure_connect_user(self, credentials):
if self.user and self.user['Id'] == credentials['ConnectUserId']:
return
elif credentials.get('ConnectUserId') and credentials.get('ConnectAccessToken'):
self.user = None
try:
result = self._get_connect_user(credentials['ConnectUserId'], credentials['ConnectAccessToken'])
self._on_connect_user_signin(result)
except Exception:
return False
def _on_connect_user_signin(self, user):
self.user = user
LOG.info("connectusersignedin %s", user)
def _save_user_info_into_credentials(self, server, user):
info = {
'Id': user['Id'],
'IsSignedInOffline': True
}
self.credentials.add_update_user(server, info)
def _add_authentication_info_from_connect(self, server, connection_mode, credentials, options={}):
if not server.get('ExchangeToken'):
raise KeyError("server['ExchangeToken'] cannot be null")
if not credentials.get('ConnectUserId'):
raise KeyError("credentials['ConnectUserId'] cannot be null")
auth = "MediaBrowser "
auth += "Client='%s', " % self.config['app.name']
auth += "Device='%s', " % self.config['app.device_name']
auth += "DeviceId='%s', " % self.config['app.device_id']
auth += "Version='%s' " % self.config['app.version']
try:
auth = self._request_url({
'url': self.get_emby_url(get_server_address(server, connection_mode), "Connect/Exchange"),
'type': "GET",
'dataType': "json",
'verify': options.get('ssl'),
'params': {
'ConnectUserId': credentials['ConnectUserId']
},
'headers': {
'X-MediaBrowser-Token': server['ExchangeToken'],
'X-Emby-Authorization': auth
}
})
except Exception:
server['UserId'] = None
server['AccessToken'] = None
return False
else:
server['UserId'] = auth['LocalUserId']
server['AccessToken'] = auth['AccessToken']
return auth
def _after_connect_validated(self, server, credentials, system_info, connection_mode, verify_authentication, options):
if options.get('enableAutoLogin') == False:
self.config['auth.user_id'] = server.pop('UserId', None)
self.config['auth.token'] = server.pop('AccessToken', None)
elif verify_authentication and server.get('AccessToken'):
if self._validate_authentication(server, connection_mode, options) is not False:
self.config['auth.user_id'] = server['UserId']
self.config['auth.token'] = server['AccessToken']
return self._after_connect_validated(server, credentials, system_info, connection_mode, False, options)
return self._resolve_failure()
self._update_server_info(server, system_info)
self.server_version = system_info['Version']
server['LastConnectionMode'] = connection_mode
if options.get('updateDateLastAccessed') is not False:
server['DateLastAccessed'] = datetime.now().strftime('%Y-%m-%dT%H:%M:%SZ')
self.credentials.add_update_server(credentials['Servers'], server)
self.credentials.get_credentials(credentials)
self.server_id = server['Id']
# Update configs
self.config['auth.server'] = get_server_address(server, connection_mode)
self.config['auth.server-name'] = server['Name']
self.config['auth.server=id'] = server['Id']
self.config['auth.ssl'] = options.get('ssl', self.config['auth.ssl'])
result = {
'Servers': [server],
'ConnectUser': self.connect_user()
}
result['State'] = CONNECTION_STATE['SignedIn'] if server.get('AccessToken') else CONNECTION_STATE['ServerSignIn']
# Connected
return result
def _validate_authentication(self, server, connection_mode, options={}):
try:
system_info = self._request_url({
'type': "GET",
'url': self.get_emby_url(get_server_address(server, connection_mode), "System/Info"),
'verify': options.get('ssl'),
'dataType': "json",
'headers': {
'X-MediaBrowser-Token': server['AccessToken']
}
})
self._update_server_info(server, system_info)
except Exception:
server['UserId'] = None
server['AccessToken'] = None
return False
def _update_server_info(self, server, system_info):
if server is None or system_info is None:
return
server['Name'] = system_info['ServerName']
server['Id'] = system_info['Id']
if system_info.get('LocalAddress'):
server['LocalAddress'] = system_info['LocalAddress']
if system_info.get('WanAddress'):
server['RemoteAddress'] = system_info['WanAddress']
if 'MacAddress' in system_info:
server['WakeOnLanInfos'] = [{'MacAddress': system_info['MacAddress']}]
def _on_authenticated(self, result, options={}):
credentials = self.credentials.get_credentials()
self.config['auth.user_id'] = result['User']['Id']
self.config['auth.token'] = result['AccessToken']
for server in credentials['Servers']:
if server['Id'] == result['ServerId']:
found_server = server
break
else: return # No server found
if options.get('updateDateLastAccessed') is not False:
found_server['DateLastAccessed'] = datetime.now().strftime('%Y-%m-%dT%H:%M:%SZ')
found_server['UserId'] = result['User']['Id']
found_server['AccessToken'] = result['AccessToken']
self.credentials.add_update_server(credentials['Servers'], found_server)
self._save_user_info_into_credentials(found_server, result['User'])
self.credentials.get_credentials(credentials)

View file

@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
#################################################################################################
import json
import logging
import os
import time
from datetime import datetime
#################################################################################################
LOG = logging.getLogger('Emby.'+__name__)
#################################################################################################
class Credentials(object):
credentials = None
def __init__(self):
LOG.debug("Credentials initializing...")
def set_credentials(self, credentials):
self.credentials = credentials
def get_credentials(self, data=None):
if data is not None:
self._set(data)
return self._get()
def _ensure(self):
if not self.credentials:
try:
LOG.info(self.credentials)
if not isinstance(self.credentials, dict):
raise ValueError("invalid credentials format")
except Exception as e: # File is either empty or missing
LOG.warn(e)
self.credentials = {}
LOG.debug("credentials initialized with: %s", self.credentials)
self.credentials['Servers'] = self.credentials.setdefault('Servers', [])
def _get(self):
self._ensure()
return self.credentials
def _set(self, data):
if data:
self.credentials.update(data)
else:
self._clear()
LOG.debug("credentialsupdated")
def _clear(self):
self.credentials.clear()
def add_update_user(self, server, user):
for existing in server.setdefault('Users', []):
if existing['Id'] == user['Id']:
# Merge the data
existing['IsSignedInOffline'] = True
break
else:
server['Users'].append(user)
def add_update_server(self, servers, server):
if server.get('Id') is None:
raise KeyError("Server['Id'] cannot be null or empty")
# Add default DateLastAccessed if doesn't exist.
server.setdefault('DateLastAccessed', "2001-01-01T00:00:00Z")
for existing in servers:
if existing['Id'] == server['Id']:
# Merge the data
if server.get('DateLastAccessed'):
if self._date_object(server['DateLastAccessed']) > self._date_object(existing['DateLastAccessed']):
existing['DateLastAccessed'] = server['DateLastAccessed']
if server.get('UserLinkType'):
existing['UserLinkType'] = server['UserLinkType']
if server.get('AccessToken'):
existing['AccessToken'] = server['AccessToken']
existing['UserId'] = server['UserId']
if server.get('ExchangeToken'):
existing['ExchangeToken'] = server['ExchangeToken']
if server.get('RemoteAddress'):
existing['RemoteAddress'] = server['RemoteAddress']
if server.get('ManualAddress'):
existing['ManualAddress'] = server['ManualAddress']
if server.get('LocalAddress'):
existing['LocalAddress'] = server['LocalAddress']
if server.get('Name'):
existing['Name'] = server['Name']
if server.get('WakeOnLanInfos'):
existing['WakeOnLanInfos'] = server['WakeOnLanInfos']
if server.get('LastConnectionMode') is not None:
existing['LastConnectionMode'] = server['LastConnectionMode']
if server.get('ConnectServerId'):
existing['ConnectServerId'] = server['ConnectServerId']
return existing
else:
servers.append(server)
return server
def _date_object(self, date):
# Convert string to date
try:
date_obj = time.strptime(date, "%Y-%m-%dT%H:%M:%SZ")
except (ImportError, TypeError):
# TypeError: attribute of type 'NoneType' is not callable
# Known Kodi/python error
date_obj = datetime(*(time.strptime(date, "%Y-%m-%dT%H:%M:%SZ")[0:6]))
return date_obj

View file

@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
#################################################################################################
class HTTPException(Exception):
# Emby HTTP exception
def __init__(self, status, message):
self.status = status
self.message = message

View file

@ -0,0 +1,233 @@
# -*- coding: utf-8 -*-
#################################################################################################
import json
import logging
import time
from libraries import requests
from exceptions import HTTPException
#################################################################################################
LOG = logging.getLogger('Emby.'+__name__)
#################################################################################################
class HTTP(object):
session = None
keep_alive = False
def __init__(self, client):
self.client = client
self.config = client['config']
def __shortcuts__(self, key):
if key == "request":
return self.request
return
def start_session(self):
self.session = requests.Session()
max_retries = self.config['http.max_retries']
self.session.mount("http://", requests.adapters.HTTPAdapter(max_retries=max_retries))
self.session.mount("https://", requests.adapters.HTTPAdapter(max_retries=max_retries))
def stop_session(self):
if self.session is None:
return
try:
self.session.close()
except Exception as error:
LOG.warn("The requests session could not be terminated: %s", error)
def _replace_user_info(self, string):
if self.config['auth.server']:
string = string.replace("{server}", self.config['auth.server'])
if self.config['auth.user_id']:
string = string.replace("{UserId}", self.config['auth.user_id'])
return string
def request(self, data, session=None):
''' Give a chance to retry the connection. Emby sometimes can be slow to answer back
data dictionary can contain:
type: GET, POST, etc.
url: (optional)
handler: not considered when url is provided (optional)
params: request parameters (optional)
json: request body (optional)
headers: (optional),
verify: ssl certificate, True (verify using device built-in library) or False
'''
if not data:
raise AttributeError("Request cannot be empty")
data = self._request(data)
LOG.debug("--->[ http ] %s", json.dumps(data, indent=4))
retry = data.pop('retry', 5)
while True:
try:
r = self._requests(session or self.session or requests, data.pop('type', "GET"), **data)
r.content # release the connection
if not self.keep_alive and self.session is not None:
self.stop_session()
r.raise_for_status()
except requests.exceptions.ConnectionError as error:
if retry:
retry -= 1
time.sleep(1)
continue
LOG.error(error)
self.client['callback']("ServerUnreachable", {'ServerId': self.config['auth.server-id']})
raise HTTPException("ServerUnreachable", error)
except requests.exceptions.ReadTimeout as error:
if retry:
retry -= 1
time.sleep(1)
continue
LOG.error(error)
raise HTTPException("ReadTimeout", error)
except requests.exceptions.HTTPError as error:
LOG.error(error)
if r.status_code == 401:
if 'X-Application-Error-Code' in r.headers:
self.client['callback']("AccessRestricted", {'ServerId': self.config['auth.server-id']})
raise HTTPException("AccessRestricted", error)
else:
self.client['callback']("Unauthorized", {'ServerId': self.config['auth.server-id']})
self.client['auth/revoke-token']
raise HTTPException("Unauthorized", error)
elif r.status_code == 500: # log and ignore.
LOG.error("--[ 500 response ] %s", error)
return
elif r.status_code == 502:
if retry:
retry -= 1
time.sleep(1)
continue
raise HTTPException(r.status_code, error)
except requests.exceptions.MissingSchema as error:
raise HTTPException("MissingSchema", {'Id': self.config['auth.server']})
except Exception as error:
raise
else:
try:
elapsed = int(r.elapsed.total_seconds() * 1000)
response = r.json()
LOG.debug("---<[ http ][%s ms]", elapsed)
LOG.debug(json.dumps(response, indent=4))
return response
except ValueError:
return
def _request(self, data):
if 'url' not in data:
data['url'] = "%s/emby/%s" % (self.config['auth.server'], data.pop('handler', ""))
self._get_header(data)
data['timeout'] = data.get('timeout') or self.config['http.timeout']
data['verify'] = data.get('verify') or self.config['auth.ssl'] or False
data['url'] = self._replace_user_info(data['url'])
self._process_params(data.get('params') or {})
self._process_params(data.get('json') or {})
return data
def _process_params(self, params):
for key in params:
value = params[key]
if isinstance(value, dict):
self._process_params(value)
if isinstance(value, str):
params[key] = self._replace_user_info(value)
def _get_header(self, data):
data['headers'] = data.setdefault('headers', {})
if not data['headers']:
data['headers'].update({
'Content-type': "application/json",
'Accept-Charset': "UTF-8,*",
'Accept-encoding': "gzip",
'User-Agent': self.config['http.user_agent'] or "%s/%s" % (self.config['app.name'], self.config['app.version'])
})
if 'Authorization' not in data['headers']:
self._authorization(data)
return data
def _authorization(self, data):
auth = "MediaBrowser "
auth += "Client=%s, " % self.config['app.name']
auth += "Device=%s, " % self.config['app.device_name']
auth += "DeviceId=%s, " % self.config['app.device_id']
auth += "Version=%s" % self.config['app.version']
data['headers'].update({'Authorization': auth})
if self.config['auth.token']:
auth += ', UserId=%s' % self.config['auth.user_id']
data['headers'].update({'Authorization': auth, 'X-MediaBrowser-Token': self.config['auth.token']})
return data
def _requests(self, session, action, **kwargs):
if action == "GET":
return session.get(**kwargs)
elif action == "POST":
return session.post(**kwargs)
elif action == "HEAD":
return session.head(**kwargs)
elif action == "DELETE":
return session.delete(**kwargs)

View file

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
#################################################################################################
import json
import logging
import threading
import time
from ..resources import websocket
##################################################################################################
LOG = logging.getLogger('Emby.'+__name__)
##################################################################################################
class WSClient(threading.Thread):
wsc = None
stop = False
def __init__(self, client):
LOG.debug("WSClient initializing...")
self.client = client
threading.Thread.__init__(self)
def __shortcuts__(self, key):
if key == "send":
return self.send
elif key == "stop":
return self.stop_client()
return
def send(self, message, data=""):
if self.wsc is None:
raise ValueError("The websocket client is not started.")
self.wsc.send(json.dumps({'MessageType': message, "Data": data}))
def run(self):
token = self.client['config/auth.token']
device_id = self.client['config/app.device_id']
server = self.client['config/auth.server']
server = server.replace('https', "wss") if server.startswith('https') else server.replace('http', "ws")
wsc_url = "%s/embywebsocket?api_key=%s&device_id=%s" % (server, token, device_id)
LOG.info("Websocket url: %s", wsc_url)
self.wsc = websocket.WebSocketApp(wsc_url,
on_message=self.on_message,
on_error=self.on_error)
self.wsc.on_open = self.on_open
while not self.stop:
self.wsc.run_forever(ping_interval=10)
if not self.stop:
time.sleep(5)
LOG.info("---<[ websocket ]")
def on_error(self, ws, error):
LOG.error(error)
def on_open(self, ws):
LOG.info("--->[ websocket ]")
def on_message(self, ws, message):
message = json.loads(message)
data = message.get('Data', {})
if not self.client['config/app.default']:
data['ServerId'] = self.client['auth/server-id']
self.client['callback_ws'](message['MessageType'], data)
def stop_client(self):
self.stop = True
if self.wsc is not None:
self.wsc.close()

View file

@ -0,0 +1,7 @@
def has_attribute(obj, name):
try:
object.__getattribute__(obj, name)
return True
except AttributeError:
return False

View file

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
#################################################################################################
import logging
from uuid import uuid4
#################################################################################################
LOG = logging.getLogger('Emby.'+__name__)
#################################################################################################
def generate_client_id():
return str("%012X" % uuid4())

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,930 @@
"""
websocket - WebSocket client library for Python
Copyright (C) 2010 Hiroki Ohtani(liris)
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""
import socket
try:
import ssl
from ssl import SSLError
HAVE_SSL = True
except ImportError:
# dummy class of SSLError for ssl none-support environment.
class SSLError(Exception):
pass
HAVE_SSL = False
from urlparse import urlparse
import os
import array
import struct
import uuid
import hashlib
import base64
import threading
import time
import logging
import traceback
import sys
"""
websocket python client.
=========================
This version support only hybi-13.
Please see http://tools.ietf.org/html/rfc6455 for protocol.
"""
# websocket supported version.
VERSION = 13
# closing frame status codes.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_TLS_HANDSHAKE_ERROR = 1015
logger = logging.getLogger()
class WebSocketException(Exception):
"""
websocket exeception class.
"""
pass
class WebSocketConnectionClosedException(WebSocketException):
"""
If remote host closed the connection or some network error happened,
this exception will be raised.
"""
pass
class WebSocketTimeoutException(WebSocketException):
"""
WebSocketTimeoutException will be raised at socket timeout during read/write data.
"""
pass
default_timeout = None
traceEnabled = False
def enableTrace(tracable):
"""
turn on/off the tracability.
tracable: boolean value. if set True, tracability is enabled.
"""
global traceEnabled
traceEnabled = tracable
if tracable:
if not logger.handlers:
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG)
def setdefaulttimeout(timeout):
"""
Set the global timeout setting to connect.
timeout: default socket timeout time. This value is second.
"""
global default_timeout
default_timeout = timeout
def getdefaulttimeout():
"""
Return the global timeout setting(second) to connect.
"""
return default_timeout
def _wrap_sni_socket(sock, sslopt, hostname):
context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))
if sslopt.get('cert_reqs', ssl.CERT_NONE) != ssl.CERT_NONE:
capath = ssl.get_default_verify_paths().capath
context.load_verify_locations(cafile=sslopt.get('ca_certs', None),
capath=sslopt.get('ca_cert_path', capath))
return context.wrap_socket(
sock,
do_handshake_on_connect=sslopt.get('do_handshake_on_connect', True),
suppress_ragged_eofs=sslopt.get('suppress_ragged_eofs', True),
server_hostname=hostname,
)
def _parse_url(url):
"""
parse url and the result is tuple of
(hostname, port, resource path and the flag of secure mode)
url: url string.
"""
if ":" not in url:
raise ValueError("url is invalid")
scheme, url = url.split(":", 1)
parsed = urlparse(url, scheme="http")
if parsed.hostname:
hostname = parsed.hostname
else:
raise ValueError("hostname is invalid")
port = 0
if parsed.port:
port = parsed.port
is_secure = False
if scheme == "ws":
if not port:
port = 80
elif scheme == "wss":
is_secure = True
if not port:
port = 443
else:
raise ValueError("scheme %s is invalid" % scheme)
if parsed.path:
resource = parsed.path
else:
resource = "/"
if parsed.query:
resource += "?" + parsed.query
return (hostname, port, resource, is_secure)
def create_connection(url, timeout=None, **options):
"""
connect to url and return websocket object.
Connect to url and return the WebSocket object.
Passing optional timeout parameter will set the timeout on the socket.
If no timeout is supplied, the global default timeout setting returned by getdefauttimeout() is used.
You can customize using 'options'.
If you set "header" list object, you can set your own custom header.
>>> conn = create_connection("ws://echo.websocket.org/",
... header=["User-Agent: MyProgram",
... "x-custom: header"])
timeout: socket timeout time. This value is integer.
if you set None for this value, it means "use default_timeout value"
options: current support option is only "header".
if you set header as dict value, the custom HTTP headers are added.
"""
sockopt = options.get("sockopt", [])
sslopt = options.get("sslopt", {})
websock = WebSocket(sockopt=sockopt, sslopt=sslopt)
websock.settimeout(timeout if timeout is not None else default_timeout)
websock.connect(url, **options)
return websock
_MAX_INTEGER = (1 << 32) -1
_AVAILABLE_KEY_CHARS = range(0x21, 0x2f + 1) + range(0x3a, 0x7e + 1)
_MAX_CHAR_BYTE = (1<<8) -1
# ref. Websocket gets an update, and it breaks stuff.
# http://axod.blogspot.com/2010/06/websocket-gets-update-and-it-breaks.html
def _create_sec_websocket_key():
uid = uuid.uuid4()
return base64.encodestring(uid.bytes).strip()
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
class ABNF(object):
"""
ABNF frame class.
see http://tools.ietf.org/html/rfc5234
and http://tools.ietf.org/html/rfc6455#section-5.2
"""
# operation code values.
OPCODE_CONT = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xa
# available operation code value tuple
OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
OPCODE_PING, OPCODE_PONG)
# opcode human readable string
OPCODE_MAP = {
OPCODE_CONT: "cont",
OPCODE_TEXT: "text",
OPCODE_BINARY: "binary",
OPCODE_CLOSE: "close",
OPCODE_PING: "ping",
OPCODE_PONG: "pong"
}
# data length threashold.
LENGTH_7 = 0x7d
LENGTH_16 = 1 << 16
LENGTH_63 = 1 << 63
def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
opcode=OPCODE_TEXT, mask=1, data=""):
"""
Constructor for ABNF.
please check RFC for arguments.
"""
self.fin = fin
self.rsv1 = rsv1
self.rsv2 = rsv2
self.rsv3 = rsv3
self.opcode = opcode
self.mask = mask
self.data = data
self.get_mask_key = os.urandom
def __str__(self):
return "fin=" + str(self.fin) \
+ " opcode=" + str(self.opcode) \
+ " data=" + str(self.data)
@staticmethod
def create_frame(data, opcode):
"""
create frame to send text, binary and other data.
data: data to send. This is string value(byte array).
if opcode is OPCODE_TEXT and this value is uniocde,
data value is conveted into unicode string, automatically.
opcode: operation code. please see OPCODE_XXX.
"""
if opcode == ABNF.OPCODE_TEXT and isinstance(data, unicode):
data = data.encode("utf-8")
# mask must be set if send data from client
return ABNF(1, 0, 0, 0, opcode, 1, data)
def format(self):
"""
format this object to string(byte array) to send data to server.
"""
if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
raise ValueError("not 0 or 1")
if self.opcode not in ABNF.OPCODES:
raise ValueError("Invalid OPCODE")
length = len(self.data)
if length >= ABNF.LENGTH_63:
raise ValueError("data is too long")
frame_header = chr(self.fin << 7
| self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4
| self.opcode)
if length < ABNF.LENGTH_7:
frame_header += chr(self.mask << 7 | length)
elif length < ABNF.LENGTH_16:
frame_header += chr(self.mask << 7 | 0x7e)
frame_header += struct.pack("!H", length)
else:
frame_header += chr(self.mask << 7 | 0x7f)
frame_header += struct.pack("!Q", length)
if not self.mask:
return frame_header + self.data
else:
mask_key = self.get_mask_key(4)
return frame_header + self._get_masked(mask_key)
def _get_masked(self, mask_key):
s = ABNF.mask(mask_key, self.data)
return mask_key + "".join(s)
@staticmethod
def mask(mask_key, data):
"""
mask or unmask data. Just do xor for each byte
mask_key: 4 byte string(byte).
data: data to mask/unmask.
"""
_m = array.array("B", mask_key)
_d = array.array("B", data)
for i in xrange(len(_d)):
_d[i] ^= _m[i % 4]
return _d.tostring()
class WebSocket(object):
"""
Low level WebSocket interface.
This class is based on
The WebSocket protocol draft-hixie-thewebsocketprotocol-76
http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
We can connect to the websocket server and send/recieve data.
The following example is a echo client.
>>> import websocket
>>> ws = websocket.WebSocket()
>>> ws.connect("ws://echo.websocket.org")
>>> ws.send("Hello, Server")
>>> ws.recv()
'Hello, Server'
>>> ws.close()
get_mask_key: a callable to produce new mask keys, see the set_mask_key
function's docstring for more details
sockopt: values for socket.setsockopt.
sockopt must be tuple and each element is argument of sock.setscokopt.
sslopt: dict object for ssl socket option.
"""
def __init__(self, get_mask_key=None, sockopt=None, sslopt=None):
"""
Initalize WebSocket object.
"""
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
self.connected = False
self.sock = socket.socket()
for opts in sockopt:
self.sock.setsockopt(*opts)
self.sslopt = sslopt
self.get_mask_key = get_mask_key
# Buffers over the packets from the layer beneath until desired amount
# bytes of bytes are received.
self._recv_buffer = []
# These buffer over the build-up of a single frame.
self._frame_header = None
self._frame_length = None
self._frame_mask = None
self._cont_data = None
def fileno(self):
return self.sock.fileno()
def set_mask_key(self, func):
"""
set function to create musk key. You can custumize mask key generator.
Mainly, this is for testing purpose.
func: callable object. the fuct must 1 argument as integer.
The argument means length of mask key.
This func must be return string(byte array),
which length is argument specified.
"""
self.get_mask_key = func
def gettimeout(self):
"""
Get the websocket timeout(second).
"""
return self.sock.gettimeout()
def settimeout(self, timeout):
"""
Set the timeout to the websocket.
timeout: timeout time(second).
"""
self.sock.settimeout(timeout)
timeout = property(gettimeout, settimeout)
def connect(self, url, **options):
"""
Connect to url. url is websocket url scheme. ie. ws://host:port/resource
You can customize using 'options'.
If you set "header" dict object, you can set your own custom header.
>>> ws = WebSocket()
>>> ws.connect("ws://echo.websocket.org/",
... header={"User-Agent: MyProgram",
... "x-custom: header"})
timeout: socket timeout time. This value is integer.
if you set None for this value,
it means "use default_timeout value"
options: current support option is only "header".
if you set header as dict value,
the custom HTTP headers are added.
"""
hostname, port, resource, is_secure = _parse_url(url)
# TODO: we need to support proxy
self.sock.connect((hostname, port))
if is_secure:
if HAVE_SSL:
if self.sslopt is None:
sslopt = {}
else:
sslopt = self.sslopt
if ssl.HAS_SNI:
self.sock = _wrap_sni_socket(self.sock, sslopt, hostname)
else:
self.sock = ssl.wrap_socket(self.sock, **sslopt)
else:
raise WebSocketException("SSL not available.")
self._handshake(hostname, port, resource, **options)
def _handshake(self, host, port, resource, **options):
headers = []
headers.append("GET %s HTTP/1.1" % resource)
headers.append("Upgrade: websocket")
headers.append("Connection: Upgrade")
if port == 80:
hostport = host
else:
hostport = "%s:%d" % (host, port)
headers.append("Host: %s" % hostport)
if "origin" in options:
headers.append("Origin: %s" % options["origin"])
else:
headers.append("Origin: http://%s" % hostport)
key = _create_sec_websocket_key()
headers.append("Sec-WebSocket-Key: %s" % key)
headers.append("Sec-WebSocket-Version: %s" % VERSION)
if "header" in options:
headers.extend(options["header"])
headers.append("")
headers.append("")
header_str = "\r\n".join(headers)
self._send(header_str)
if traceEnabled:
logger.debug("--- request header ---")
logger.debug(header_str)
logger.debug("-----------------------")
status, resp_headers = self._read_headers()
if status != 101:
self.close()
raise WebSocketException("Handshake Status %d" % status)
success = self._validate_header(resp_headers, key)
if not success:
self.close()
raise WebSocketException("Invalid WebSocket Header")
self.connected = True
def _validate_header(self, headers, key):
for k, v in _HEADERS_TO_CHECK.iteritems():
r = headers.get(k, None)
if not r:
return False
r = r.lower()
if v != r:
return False
result = headers.get("sec-websocket-accept", None)
if not result:
return False
result = result.lower()
value = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
hashed = base64.encodestring(hashlib.sha1(value).digest()).strip().lower()
return hashed == result
def _read_headers(self):
status = None
headers = {}
if traceEnabled:
logger.debug("--- response header ---")
while True:
line = self._recv_line()
if line == "\r\n":
break
line = line.strip()
if traceEnabled:
logger.debug(line)
if not status:
status_info = line.split(" ", 2)
status = int(status_info[1])
else:
kv = line.split(":", 1)
if len(kv) == 2:
key, value = kv
headers[key.lower()] = value.strip().lower()
else:
raise WebSocketException("Invalid header")
if traceEnabled:
logger.debug("-----------------------")
return status, headers
def send(self, payload, opcode=ABNF.OPCODE_TEXT):
"""
Send the data as string.
payload: Payload must be utf-8 string or unicoce,
if the opcode is OPCODE_TEXT.
Otherwise, it must be string(byte array)
opcode: operation code to send. Please see OPCODE_XXX.
"""
frame = ABNF.create_frame(payload, opcode)
if self.get_mask_key:
frame.get_mask_key = self.get_mask_key
data = frame.format()
length = len(data)
if traceEnabled:
logger.debug("send: " + repr(data))
while data:
l = self._send(data)
data = data[l:]
return length
def send_binary(self, payload):
return self.send(payload, ABNF.OPCODE_BINARY)
def ping(self, payload=""):
"""
send ping data.
payload: data payload to send server.
"""
self.send(payload, ABNF.OPCODE_PING)
def pong(self, payload):
"""
send pong data.
payload: data payload to send server.
"""
self.send(payload, ABNF.OPCODE_PONG)
def recv(self):
"""
Receive string data(byte array) from the server.
return value: string(byte array) value.
"""
opcode, data = self.recv_data()
return data
def recv_data(self):
"""
Recieve data with operation code.
return value: tuple of operation code and string(byte array) value.
"""
while True:
frame = self.recv_frame()
if not frame:
# handle error:
# 'NoneType' object has no attribute 'opcode'
raise WebSocketException("Not a valid frame %s" % frame)
elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY, ABNF.OPCODE_CONT):
if frame.opcode == ABNF.OPCODE_CONT and not self._cont_data:
raise WebSocketException("Illegal frame")
if self._cont_data:
self._cont_data[1] += frame.data
else:
self._cont_data = [frame.opcode, frame.data]
if frame.fin:
data = self._cont_data
self._cont_data = None
return data
elif frame.opcode == ABNF.OPCODE_CLOSE:
self.send_close()
return (frame.opcode, None)
elif frame.opcode == ABNF.OPCODE_PING:
self.pong(frame.data)
def recv_frame(self):
"""
recieve data as frame from server.
return value: ABNF frame object.
"""
# Header
if self._frame_header is None:
self._frame_header = self._recv_strict(2)
b1 = ord(self._frame_header[0])
fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1
rsv2 = b1 >> 5 & 1
rsv3 = b1 >> 4 & 1
opcode = b1 & 0xf
b2 = ord(self._frame_header[1])
has_mask = b2 >> 7 & 1
# Frame length
if self._frame_length is None:
length_bits = b2 & 0x7f
if length_bits == 0x7e:
length_data = self._recv_strict(2)
self._frame_length = struct.unpack("!H", length_data)[0]
elif length_bits == 0x7f:
length_data = self._recv_strict(8)
self._frame_length = struct.unpack("!Q", length_data)[0]
else:
self._frame_length = length_bits
# Mask
if self._frame_mask is None:
self._frame_mask = self._recv_strict(4) if has_mask else ""
# Payload
payload = self._recv_strict(self._frame_length)
if has_mask:
payload = ABNF.mask(self._frame_mask, payload)
# Reset for next frame
self._frame_header = None
self._frame_length = None
self._frame_mask = None
return ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
def send_close(self, status=STATUS_NORMAL, reason=""):
"""
send close data to the server.
status: status code to send. see STATUS_XXX.
reason: the reason to close. This must be string.
"""
if status < 0 or status >= ABNF.LENGTH_16:
raise ValueError("code is invalid range")
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status=STATUS_NORMAL, reason=""):
"""
Close Websocket object
status: status code to send. see STATUS_XXX.
reason: the reason to close. This must be string.
"""
try:
self.sock.shutdown(socket.SHUT_RDWR)
except:
pass
'''
if self.connected:
if status < 0 or status >= ABNF.LENGTH_16:
raise ValueError("code is invalid range")
try:
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
timeout = self.sock.gettimeout()
self.sock.settimeout(3)
try:
frame = self.recv_frame()
if logger.isEnabledFor(logging.ERROR):
recv_status = struct.unpack("!H", frame.data)[0]
if recv_status != STATUS_NORMAL:
logger.error("close status: " + repr(recv_status))
except:
pass
self.sock.settimeout(timeout)
self.sock.shutdown(socket.SHUT_RDWR)
except:
pass
'''
self._closeInternal()
def _closeInternal(self):
self.connected = False
self.sock.close()
def _send(self, data):
try:
return self.sock.send(data)
except socket.timeout as e:
raise WebSocketTimeoutException(e.args[0])
except Exception as e:
if "timed out" in e.args[0]:
raise WebSocketTimeoutException(e.args[0])
else:
raise e
def _recv(self, bufsize):
try:
bytes = self.sock.recv(bufsize)
except socket.timeout as e:
raise WebSocketTimeoutException(e.args[0])
except SSLError as e:
if e.args[0] == "The read operation timed out":
raise WebSocketTimeoutException(e.args[0])
else:
raise
if not bytes:
raise WebSocketConnectionClosedException()
return bytes
def _recv_strict(self, bufsize):
shortage = bufsize - sum(len(x) for x in self._recv_buffer)
while shortage > 0:
bytes = self._recv(shortage)
self._recv_buffer.append(bytes)
shortage -= len(bytes)
unified = "".join(self._recv_buffer)
if shortage == 0:
self._recv_buffer = []
return unified
else:
self._recv_buffer = [unified[bufsize:]]
return unified[:bufsize]
def _recv_line(self):
line = []
while True:
c = self._recv(1)
line.append(c)
if c == "\n":
break
return "".join(line)
class WebSocketApp(object):
"""
Higher level of APIs are provided.
The interface is like JavaScript WebSocket object.
"""
def __init__(self, url, header=[],
on_open=None, on_message=None, on_error=None,
on_close=None, keep_running=True, get_mask_key=None):
"""
url: websocket url.
header: custom header for websocket handshake.
on_open: callable object which is called at opening websocket.
this function has one argument. The arugment is this class object.
on_message: callbale object which is called when recieved data.
on_message has 2 arguments.
The 1st arugment is this class object.
The passing 2nd arugment is utf-8 string which we get from the server.
on_error: callable object which is called when we get error.
on_error has 2 arguments.
The 1st arugment is this class object.
The passing 2nd arugment is exception object.
on_close: callable object which is called when closed the connection.
this function has one argument. The arugment is this class object.
keep_running: a boolean flag indicating whether the app's main loop should
keep running, defaults to True
get_mask_key: a callable to produce new mask keys, see the WebSocket.set_mask_key's
docstring for more information
"""
self.url = url
self.header = header
self.on_open = on_open
self.on_message = on_message
self.on_error = on_error
self.on_close = on_close
self.keep_running = keep_running
self.get_mask_key = get_mask_key
self.sock = None
def send(self, data, opcode=ABNF.OPCODE_TEXT):
"""
send message.
data: message to send. If you set opcode to OPCODE_TEXT, data must be utf-8 string or unicode.
opcode: operation code of data. default is OPCODE_TEXT.
"""
if self.sock.send(data, opcode) == 0:
raise WebSocketConnectionClosedException()
def close(self):
"""
close websocket connection.
"""
self.keep_running = False
if(self.sock != None):
self.sock.close()
def _send_ping(self, interval):
while True:
for i in range(interval):
time.sleep(1)
if not self.keep_running:
return
self.sock.ping()
def run_forever(self, sockopt=None, sslopt=None, ping_interval=0):
"""
run event loop for WebSocket framework.
This loop is infinite loop and is alive during websocket is available.
sockopt: values for socket.setsockopt.
sockopt must be tuple and each element is argument of sock.setscokopt.
sslopt: ssl socket optional dict.
ping_interval: automatically send "ping" command every specified period(second)
if set to 0, not send automatically.
"""
if sockopt is None:
sockopt = []
if sslopt is None:
sslopt = {}
if self.sock:
raise WebSocketException("socket is already opened")
thread = None
self.keep_running = True
try:
self.sock = WebSocket(self.get_mask_key, sockopt=sockopt, sslopt=sslopt)
self.sock.settimeout(default_timeout)
self.sock.connect(self.url, header=self.header)
self._callback(self.on_open)
if ping_interval:
thread = threading.Thread(target=self._send_ping, args=(ping_interval,))
thread.setDaemon(True)
thread.start()
while self.keep_running:
try:
data = self.sock.recv()
if data is None or self.keep_running == False:
break
self._callback(self.on_message, data)
except Exception, e:
#print str(e.args[0])
if "timed out" not in e.args[0]:
raise e
except Exception, e:
self._callback(self.on_error, e)
finally:
if thread:
self.keep_running = False
self.sock.close()
self._callback(self.on_close)
self.sock = None
def _callback(self, callback, *args):
if callback:
try:
callback(self, *args)
except Exception, e:
logger.error(e)
if True:#logger.isEnabledFor(logging.DEBUG):
_, _, tb = sys.exc_info()
traceback.print_tb(tb)
if __name__ == "__main__":
enableTrace(True)
ws = create_connection("ws://echo.websocket.org/")
print("Sending 'Hello, World'...")
ws.send("Hello, World")
print("Sent")
print("Receiving...")
result = ws.recv()
print("Received '%s'" % result)
ws.close()