diff --git a/consumers/__init__.py b/consumers/__init__.py new file mode 100644 index 0000000..f1b8cf9 --- /dev/null +++ b/consumers/__init__.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class BaseConsumer(ABC): + settings: Dict[str, Any] + + @abstractmethod + def __init__(self, settings: Dict[str, Any]) -> None: + self.config(settings) + + @abstractmethod + def write(self, data: Dict[str, Any]): + """ + Process and send data to wherever it is going. + Avoid blocking. + """ + pass + + @abstractmethod + def poll(self): + """ + This function will be ran whenever there is down time. + If your consumer needs to do something periodically, do so here. + This function should not block. + """ + pass + + def exit(self): + """ + Called on exit, clean up your handles here + """ + pass + + def config(self, settings: Dict[str, Any]): + self.settings = settings + + def __enter__(self): + return self + + def __exit__(self, etype, value, traceback): + self.exit() diff --git a/consumers/mqtt.py b/consumers/mqtt.py new file mode 100644 index 0000000..a8e4d2b --- /dev/null +++ b/consumers/mqtt.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +import paho.mqtt.client as mqtt + +from . import BaseConsumer + +# MAP_VALUES = {} + + +class MqttConsumer(BaseConsumer): + client: mqtt.Client + initialized: List[str] + + def __init__(self, settings: Dict[str, Any]) -> None: + self.initialized = [] + + super().__init__(settings) + self.client = mqtt.Client(client_id=settings["client"]["id"], userdata=self) + self.client.on_connect = self.on_connect + self.client.on_message = self.on_message + # Will must be set before connecting!! + self.client.will_set( + f"{self.topic_prefix}/available", payload="offline", retain=True + ) + self.client.connect( + settings["client"]["host"], + settings["client"]["port"], + settings["client"]["keepalive"], + ) + + def config(self, settings: Dict[str, Any]): + super().config(settings) + settings.setdefault("client", {}) + settings["client"].setdefault("id", None) + settings["client"].setdefault("host", "") + settings["client"].setdefault("port", 1883) + settings["client"].setdefault("keepalive", 60) + + if not settings.get("device_id"): + settings["device_id"] = str(uuid4()) + + settings.setdefault("prefix", "solarmppt") + + settings.setdefault("discovery_prefix", "homeassistant") + + @property + def topic_prefix(self): + return f"{self.settings['prefix']}/{self.settings['device_id']}" + + def get_ha_config( + self, + id, + name, + unit: Optional[str] = None, + type: Optional[str] = None, + expiry: int = 90, + state_class: Optional[str] = None, + ): + assert state_class in [None, "measurement", "total", "total_increasing"] + + res = { + "~": f"{self.topic_prefix}", + "unique_id": f"{self.settings['device_id']}_{id}", + "availability_topic": "~/available", + "state_topic": f"~/{id}", + "device": { + "identifiers": [ + self.settings["device_id"], + ], + # TODO: Get charger serial and use for identifier instead + # See: https://www.home-assistant.io/integrations/sensor.mqtt/#device + # "via_device": self.settings["device_id"], + }, + "force_update": True, + "expire_after": expiry, + } + + if unit: + res["unit_of_meas"] = unit + if type: + res["dev_cla"] = type + if state_class: + res["state_class"] = state_class + + # The callback for when the client receives a CONNACK response from the server. + @staticmethod + def on_connect(client: mqtt.Client, userdata: "MqttConsumer", flags, rc): + print("Connected with result code " + str(rc)) + + # Subscribing in on_connect() means that if we lose the connection and + # reconnect then subscriptions will be renewed. + # client.subscribe("$SYS/#") + client.publish( + f"{userdata.topic_prefix}/available", payload="online", retain=True + ) + + # The callback for when a PUBLISH message is received from the server. + @staticmethod + def on_message(client, userdata, msg): + print(msg.topic + " " + str(msg.payload)) + + def poll(self): + self.client.loop(timeout=0.1, max_packets=5) + return super().poll() + + def write(self, data: Dict[str, Any]): + self.client.publish(f"{self.topic_prefix}/raw", payload=data) + return super().write(data) + + def exit(self): + self.client.publish( + f"{self.topic_prefix}/available", payload="offline", retain=True + ) + + while self.client.want_write(): + self.client.loop_write(10) + + self.client.disconnect() + return super().exit() + + +# Client(client_id="", clean_session=True, userdata=None, +# protocol=MQTTv311, transport="tcp") + +# connect_srv(domain, keepalive=60, bind_address="") +# Connect to a broker using an SRV DNS lookup to obtain the broker address. +# Takes the following arguments: +# domain +# the DNS domain to search for SRV records. +# If None, try to determine the local domain name. + +# client.will_set(topic, payload=None, qos=0, retain=False) + +# Blocking call that processes network traffic, dispatches callbacks and +# handles reconnecting. +# Other loop*() functions are available that give a threaded interface and a +# manual interface. + +# client.loop_forever() diff --git a/test_config.py b/test_config.py new file mode 100644 index 0000000..a0f9e3f --- /dev/null +++ b/test_config.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +import importlib +import os +from time import sleep +from typing import Any, Dict, Optional, Type + +import yaml + +from consumers import BaseConsumer + + +def get_consumer(name: str) -> Optional[Type[BaseConsumer]]: + mod_name, cls_name = name.rsplit(".", 1) + + mod = importlib.import_module(f"consumers.{mod_name}") + + print(mod) + print(dir(mod)) + res = getattr(mod, cls_name) + assert issubclass(res, BaseConsumer) + + return res + + +def get_config() -> Dict[str, Any]: + with open("config.yaml", "r") as fh: + conf: dict = yaml.safe_load(fh) + conf.setdefault("consumers", {}) + + return conf + + +def write_config(conf: Dict[str, Any]): + with open(".config.yaml~writing", "w") as fh: + yaml.safe_dump(conf, fh, indent=2, encoding="utf-8") + os.rename(".config.yaml~writing", "config.yaml") + + +conf = get_config() + +consumers = [] +for name, consumer_config in conf["consumers"].items(): + print(name, consumer_config) + mod = get_consumer(name) + if mod: + print(mod) + consumers.append(mod(consumer_config)) + +write_config(conf) + + +try: + while True: + for consumer in consumers: + consumer.poll() + sleep(1) +except (KeyboardInterrupt, SystemExit, Exception) as e: + for consumer in consumers: + consumer.exit() + + if type(e) is not KeyboardInterrupt: + raise + + write_config(conf)