From dd7c43f7e71c3692e9c148279fe5297a1eb7ad81 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Odd=20Str=C3=A5b=C3=B8?= <oddstr13@openshell.no>
Date: Mon, 10 Apr 2023 03:39:19 +0200
Subject: [PATCH] Add support for the load switch Rework mqtt structure

---
 .pre-commit-config.yaml        |   1 +
 requirements.txt               |   1 +
 srnemqtt/__main__.py           |  62 +++++++-------------
 srnemqtt/config.py             |  12 ++--
 srnemqtt/constants.py          |   2 +-
 srnemqtt/consumers/__init__.py |   6 +-
 srnemqtt/consumers/mqtt.py     |  58 ++++++++++++++++---
 srnemqtt/consumers/stdio.py    |   5 +-
 srnemqtt/protocol.py           |  50 ++++++++++++++--
 srnemqtt/solar_types.py        |   3 +
 srnemqtt/srne.py               | 103 +++++++++++++++++++++++++++++++++
 11 files changed, 239 insertions(+), 64 deletions(-)
 create mode 100644 srnemqtt/srne.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b615718..c63df7c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -41,6 +41,7 @@ repos:
         args:
           - "--install-types"
           - "--non-interactive"
+          - "--ignore-missing-imports"
 
   - repo: https://github.com/psf/black
     rev: 23.3.0
diff --git a/requirements.txt b/requirements.txt
index 5a919f0..dafbd39 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ paho-mqtt
 pyserial
 
 types-PyYAML
+types-paho-mqtt
diff --git a/srnemqtt/__main__.py b/srnemqtt/__main__.py
index 38d8058..f13b40a 100755
--- a/srnemqtt/__main__.py
+++ b/srnemqtt/__main__.py
@@ -2,47 +2,42 @@
 # -*- coding: utf-8 -*-
 
 import time
-from decimal import Decimal
-from typing import cast
+from typing import List, Optional, cast
 
 from bluepy.btle import BTLEDisconnectError
 from serial import SerialException
 
+from srnemqtt.consumers import BaseConsumer
+
 from .config import get_config, get_consumers, get_interface
-from .protocol import parse_battery_state, parse_historical_entry, try_read_parse
-from .solar_types import DataName
+from .srne import Srne
 from .util import Periodical, log
 
 
-class CommunicationError(BTLEDisconnectError, SerialException, IOError):
+class CommunicationError(BTLEDisconnectError, SerialException, TimeoutError):
     pass
 
 
-def main():
+def main() -> None:
     conf = get_config()
-    consumers = get_consumers(conf)
+    consumers: Optional[List[BaseConsumer]] = None
 
     per_voltages = Periodical(interval=15)
     per_current_hist = Periodical(interval=60)
-    # import serial
-
-    # ser = serial.Serial()
 
     try:
         while True:
             try:
                 log("Connecting...")
                 with get_interface() as dev:
+                    srne = Srne(dev)
                     log("Connected.")
 
-                    # write(dev, construct_request(0, 32))
+                    if consumers is None:
+                        consumers = get_consumers(srne, conf)
 
-                    # Memory dump
-                    # for address in range(0, 0x10000, 16):
-                    #    log(f"Reading 0x{address:04X}...")
-                    #    write(wd, construct_request(address, 16))
                     days = 7
-                    res = try_read_parse(dev, 0x010B, 21, parse_historical_entry)
+                    res = srne.get_historical_entry()
                     if res:
                         log(res)
                         for consumer in consumers:
@@ -50,9 +45,7 @@ def main():
                         days = cast(int, res.get("run_days", 7))
 
                     for i in range(days):
-                        res = try_read_parse(
-                            dev, 0xF000 + i, 10, parse_historical_entry
-                        )
+                        res = srne.get_historical_entry(i)
                         if res:
                             log({i: res})
                             for consumer in consumers:
@@ -62,40 +55,26 @@ def main():
                         now = time.time()
 
                         if per_voltages(now):
-                            data = try_read_parse(dev, 0x0100, 11, parse_battery_state)
+                            data = srne.get_battery_state()
                             if data:
-                                data[DataName.CALCULATED_BATTERY_POWER] = float(
-                                    Decimal(str(data.get(DataName.BATTERY_VOLTAGE, 0)))
-                                    * Decimal(
-                                        str(data.get(DataName.BATTERY_CURRENT, 0))
-                                    )
-                                )
-                                data[DataName.CALCULATED_PANEL_POWER] = float(
-                                    Decimal(str(data.get(DataName.PANEL_VOLTAGE, 0)))
-                                    * Decimal(str(data.get(DataName.PANEL_CURRENT, 0)))
-                                )
-                                data[DataName.CALCULATED_LOAD_POWER] = float(
-                                    Decimal(str(data.get(DataName.LOAD_VOLTAGE, 0)))
-                                    * Decimal(str(data.get(DataName.LOAD_CURRENT, 0)))
-                                )
                                 log(data)
                                 for consumer in consumers:
                                     consumer.write(data)
 
                         if per_current_hist(now):
-                            data = try_read_parse(
-                                dev, 0x010B, 21, parse_historical_entry
-                            )
-                            if data:
+                            try:
+                                data = srne.get_historical_entry()
                                 log(data)
                                 for consumer in consumers:
                                     consumer.write(data)
+                            except TimeoutError:
+                                pass
 
                         # print(".")
                         for consumer in consumers:
                             consumer.poll()
 
-                        time.sleep(max(0, 1 - time.time() - now))
+                        time.sleep(max(0, 1 - time.time() - now))  # 1s loop
 
                     # if STATUS.get('load_enabled'):
                     #    write(wd, CMD_DISABLE_LOAD)
@@ -107,8 +86,9 @@ def main():
                 time.sleep(1)
 
     except (KeyboardInterrupt, SystemExit, Exception) as e:
-        for consumer in consumers:
-            consumer.exit()
+        if consumers is not None:
+            for consumer in consumers:
+                consumer.exit()
 
         if type(e) is not KeyboardInterrupt:
             raise
diff --git a/srnemqtt/config.py b/srnemqtt/config.py
index 4b3a4c1..301f2f7 100644
--- a/srnemqtt/config.py
+++ b/srnemqtt/config.py
@@ -6,9 +6,9 @@ from typing import Any, Dict, List, Optional, Type
 
 import yaml
 
-from srnemqtt.interfaces import BaseInterface
-
 from .consumers import BaseConsumer
+from .interfaces import BaseInterface
+from .srne import Srne
 
 
 def get_consumer(name: str) -> Optional[Type[BaseConsumer]]:
@@ -38,7 +38,9 @@ def write_config(conf: Dict[str, Any]):
     os.rename(".config.yaml~writing", "config.yaml")
 
 
-def get_consumers(conf: Optional[Dict[str, Any]] = None) -> List[BaseConsumer]:
+def get_consumers(
+    srne: Srne, conf: Optional[Dict[str, Any]] = None
+) -> List[BaseConsumer]:
     if conf is None:
         conf = get_config()
 
@@ -48,7 +50,7 @@ def get_consumers(conf: Optional[Dict[str, Any]] = None) -> List[BaseConsumer]:
         mod = get_consumer(name)
         if mod:
             # print(mod)
-            consumers.append(mod(consumer_config))
+            consumers.append(mod(settings=consumer_config, srne=srne))
 
     write_config(conf)
     return consumers
@@ -81,7 +83,7 @@ def get_interface(conf: Optional[Dict[str, Any]] = None) -> BaseInterface:
 if __name__ == "__main__":
     conf = get_config()
 
-    consumers = get_consumers(conf)
+    consumers = get_consumers(Srne(BaseInterface()), conf)
 
     try:
         while True:
diff --git a/srnemqtt/constants.py b/srnemqtt/constants.py
index 2a13eac..a68bb5c 100644
--- a/srnemqtt/constants.py
+++ b/srnemqtt/constants.py
@@ -5,7 +5,7 @@ MAC = "DC:0D:30:9C:61:BA"
 # read_service  = "0000fff0-0000-1000-8000-00805f9b34fb"
 
 ACTION_READ = 0x03
-ACTION_WRITE = 0x03
+ACTION_WRITE = 0x06
 
 POSSIBLE_MARKER = (0x01, 0xFD, 0xFE, 0xFF)
 
diff --git a/srnemqtt/consumers/__init__.py b/srnemqtt/consumers/__init__.py
index f1b8cf9..bc12ed4 100644
--- a/srnemqtt/consumers/__init__.py
+++ b/srnemqtt/consumers/__init__.py
@@ -2,12 +2,16 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict
 
+from ..srne import Srne
+
 
 class BaseConsumer(ABC):
     settings: Dict[str, Any]
+    srne: Srne
 
     @abstractmethod
-    def __init__(self, settings: Dict[str, Any]) -> None:
+    def __init__(self, settings: Dict[str, Any], srne: Srne) -> None:
+        self.srne = srne
         self.config(settings)
 
     @abstractmethod
diff --git a/srnemqtt/consumers/mqtt.py b/srnemqtt/consumers/mqtt.py
index 6cd7497..772600a 100644
--- a/srnemqtt/consumers/mqtt.py
+++ b/srnemqtt/consumers/mqtt.py
@@ -7,6 +7,7 @@ from uuid import uuid4
 import paho.mqtt.client as mqtt
 
 from ..solar_types import DataName
+from ..srne import Srne
 from . import BaseConsumer
 
 MAP_VALUES: Dict[DataName, Dict[str, Any]] = {
@@ -81,7 +82,15 @@ MAP_VALUES: Dict[DataName, Dict[str, Any]] = {
         "type": "current",
         "state_class": "measurement",
     },
-    DataName.LOAD_POWER: {"unit": "W", "type": "power", "state_class": "measurement"},
+    DataName.LOAD_POWER: {
+        "unit": "W",
+        "type": "power",
+        "state_class": "measurement",
+    },
+    DataName.LOAD_ENABLED: {
+        "type": "outlet",
+        "platform": "switch",
+    },
     DataName.PANEL_VOLTAGE: {
         "unit": "V",
         "type": "voltage",
@@ -115,11 +124,12 @@ MAP_VALUES: Dict[DataName, Dict[str, Any]] = {
 class MqttConsumer(BaseConsumer):
     client: mqtt.Client
     initialized: List[str]
+    srne: Srne
 
-    def __init__(self, settings: Dict[str, Any]) -> None:
+    def __init__(self, settings: Dict[str, Any], srne: Srne) -> None:
         self.initialized = []
 
-        super().__init__(settings)
+        super().__init__(settings, srne)
         self.client = mqtt.Client(client_id=settings["client"]["id"], userdata=self)
         self.client.on_connect = self.on_connect
         self.client.on_message = self.on_message
@@ -166,7 +176,7 @@ class MqttConsumer(BaseConsumer):
 
     @property
     def topic_prefix(self):
-        return f"{self.settings['prefix']}/{self.settings['device_id']}"
+        return f"{self.settings['prefix']}/{self.srne.serial}"
 
     def get_ha_config(
         self,
@@ -176,23 +186,30 @@ class MqttConsumer(BaseConsumer):
         type: Optional[str] = None,
         expiry: int = 90,
         state_class: Optional[str] = None,
+        platform: str = "sensor",
     ):
         assert state_class in [None, "measurement", "total", "total_increasing"]
 
         res = {
             "~": f"{self.topic_prefix}",
-            "unique_id": f"{self.settings['device_id']}_{id}",
+            "unique_id": f"srne_{self.srne.serial}_{id}",
+            "object_id": f"srne_{self.srne.serial}_{id}",  # Used for entity id
             "availability_topic": "~/available",
             "state_topic": f"~/{id}",
             "name": name,
             "device": {
                 "identifiers": [
-                    self.settings["device_id"],
+                    self.srne.serial,
                 ],
                 # TODO: Get charger serial and use for identifier instead
                 # See: https://www.home-assistant.io/integrations/sensor.mqtt/#device
                 # "via_device": self.settings["device_id"],
                 "suggested_area": "Solar panel",
+                "manufacturer": "SRNE Solar",
+                "model": self.srne.model,
+                "name": self.srne.name,
+                "sw_version": self.srne.version,
+                "via_device": self.settings["device_id"],
             },
             "force_update": True,
             "expire_after": expiry,
@@ -204,6 +221,10 @@ class MqttConsumer(BaseConsumer):
             res["dev_cla"] = type
         if state_class:
             res["state_class"] = state_class
+        if platform == "switch":
+            res["command_topic"] = f"{res['state_topic']}/set"
+            res["payload_on"] = True
+            res["payload_off"] = False
 
         return res
 
@@ -219,6 +240,27 @@ class MqttConsumer(BaseConsumer):
             f"{userdata.topic_prefix}/available", payload="online", retain=True
         )
 
+        load_set_topic = f"{userdata.topic_prefix}/load_enabled/set"
+        client.message_callback_add(load_set_topic, userdata.on_load_switch)
+        client.subscribe(load_set_topic)
+
+    @staticmethod
+    def on_load_switch(
+        client: mqtt.Client, userdata: "MqttConsumer", message: mqtt.MQTTMessage
+    ):
+        print(message)
+        print(message.info)
+        print(message.state)
+        print(message.payload)
+        payload = message.payload.decode().upper() in ("ON", "TRUE", "ENABLE", "YES")
+        if type(payload) is bool:
+            res = userdata.srne.enable_load(payload)
+            client.publish(
+                f"{userdata.topic_prefix}/load_enabled", payload=res, retain=True
+            )
+        else:
+            print(f"!!! Unknown payload for switch callback: {message.payload!r}")
+
     @staticmethod
     def on_connect_fail(client: mqtt.Client, userdata: "MqttConsumer"):
         print(userdata.__class__.__name__, "on_connect_fail")
@@ -256,9 +298,9 @@ class MqttConsumer(BaseConsumer):
                     km = MAP_VALUES[DataName(k)]
                     pretty_name = k.replace("_", " ").capitalize()
                     disc_prefix = self.settings["discovery_prefix"]
-                    device_id = self.settings["device_id"]
+                    platform = km.get("platform", "sensor")
                     self.client.publish(
-                        f"{disc_prefix}/sensor/{device_id}_{k}/config",
+                        f"{disc_prefix}/{platform}/srne_{self.srne.serial}_{k}/config",
                         payload=json.dumps(self.get_ha_config(k, pretty_name, **km)),
                         retain=True,
                     )
diff --git a/srnemqtt/consumers/stdio.py b/srnemqtt/consumers/stdio.py
index df63e70..bf5a8e2 100644
--- a/srnemqtt/consumers/stdio.py
+++ b/srnemqtt/consumers/stdio.py
@@ -2,12 +2,13 @@
 import json
 from typing import Any, Dict
 
+from ..srne import Srne
 from . import BaseConsumer
 
 
 class StdoutConsumer(BaseConsumer):
-    def __init__(self, settings: Dict[str, Any]) -> None:
-        super().__init__(settings)
+    def __init__(self, settings: Dict[str, Any], srne: Srne) -> None:
+        super().__init__(settings, srne)
 
     def poll(self):
         return super().poll()
diff --git a/srnemqtt/protocol.py b/srnemqtt/protocol.py
index 160da0b..11dcd73 100644
--- a/srnemqtt/protocol.py
+++ b/srnemqtt/protocol.py
@@ -7,8 +7,8 @@ from typing import Callable, Collection, Optional
 
 from libscrc import modbus
 
-from .constants import ACTION_READ, POSSIBLE_MARKER
-from .lib.feasycom_ble import BTLEUart
+from .constants import ACTION_READ, ACTION_WRITE, POSSIBLE_MARKER
+from .interfaces import BaseInterface
 from .solar_types import DATA_BATTERY_STATE, HISTORICAL_DATA, DataItem
 from .util import log
 
@@ -25,6 +25,11 @@ def construct_request(address, words=1, action=ACTION_READ, marker=0xFF):
     return struct.pack("!BBHH", marker, action, address, words)
 
 
+def construct_write(address, data: bytes, action=ACTION_WRITE, marker=0xFF):
+    assert marker in POSSIBLE_MARKER, f"marker should be one of {POSSIBLE_MARKER}"
+    return struct.pack("!BBH", marker, action, address) + data
+
+
 def parse(data: bytes, items: Collection[DataItem], offset: int = 0) -> dict:
     pos = offset
     res = {}
@@ -84,7 +89,6 @@ def discardUntil(fh: RawIOBase, byte: int, timeout=10) -> Optional[int]:
     discarded = 0
     read_byte = expand(fh.read(1))
     while read_byte != byte:
-
         if read_byte is not None:
             if not discarded:
                 log("Discarding", end="")
@@ -105,7 +109,7 @@ def discardUntil(fh: RawIOBase, byte: int, timeout=10) -> Optional[int]:
     return read_byte
 
 
-def readMemory(fh: RawIOBase, address: int, words: int = 1) -> Optional[bytes]:
+def readMemory(fh: BaseInterface, address: int, words: int = 1) -> Optional[bytes]:
     # log(f"Reading {words} words from 0x{address:04X}")
     request = construct_request(address, words=words)
     # log("Request:", request)
@@ -135,8 +139,42 @@ def readMemory(fh: RawIOBase, address: int, words: int = 1) -> Optional[bytes]:
     return None
 
 
+def writeMemory(fh: BaseInterface, address: int, output_data: bytes) -> Optional[bytes]:
+    # TODO: Verify behavior on multi-word writes
+    # log(f"Reading {words} words from 0x{address:04X}")
+    request = construct_write(address, data=output_data)
+    # log("Request:", request)
+    write(fh, request)
+
+    tag = discardUntil(fh, 0xFF)
+    if tag is None:
+        return None
+
+    _operation = fh.read(1)
+    result_addr = fh.read(2)
+    # log("Operation:", _operation)
+    if _operation is not None and result_addr is not None:
+        operation = _operation[0]
+        data = fh.read(2)
+        # log("Data:", data)
+        _crc = fh.read(2)
+        if data and _crc:
+            try:
+                crc = struct.unpack_from("<H", _crc)[0]
+            except struct.error:
+                log(f"readMemory: CRC error; read {len(_crc)} bytes (2 expected)")
+                return None
+            calculated_crc = modbus(bytes([tag, operation, *result_addr, *data]))
+            if crc == calculated_crc:
+                return data
+            else:
+                log(f"readMemory: CRC error; {crc:04X} != {calculated_crc:04X}")
+        log("data or crc is falsely", operation, result_addr, data, _crc)
+    return None
+
+
 def try_read_parse(
-    dev: BTLEUart,
+    dev: BaseInterface,
     address: int,
     words: int = 1,
     parser: Optional[Callable] = None,
@@ -152,7 +190,7 @@ def try_read_parse(
             except struct.error as e:
                 log(e)
                 log("0x0100 Unpack error:", len(res), res)
-                log("Flushed from read buffer; ", dev.read(timeout=0.5))
+                log("Flushed from read buffer; ", dev.read())  # TODO: timeout=0.5
         else:
             log(f"No data read, expected {words*2} bytes (attempts left: {attempts})")
     return None
diff --git a/srnemqtt/solar_types.py b/srnemqtt/solar_types.py
index 8fdcb83..a4833ae 100644
--- a/srnemqtt/solar_types.py
+++ b/srnemqtt/solar_types.py
@@ -43,6 +43,9 @@ class DataName(str, Enum):
     def __repr__(self):
         return repr(self.value)
 
+    def __str__(self):
+        return str(self.value)
+
 
 class DataItem:
     name: DataName
diff --git a/srnemqtt/srne.py b/srnemqtt/srne.py
new file mode 100644
index 0000000..58b063e
--- /dev/null
+++ b/srnemqtt/srne.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+import struct
+from decimal import Decimal
+from functools import cached_property
+from typing import Optional
+
+from .interfaces import BaseInterface
+from .protocol import (
+    parse_battery_state,
+    parse_historical_entry,
+    readMemory,
+    try_read_parse,
+    writeMemory,
+)
+from .solar_types import DataName
+
+
+class Srne:
+    _dev: BaseInterface
+
+    def __init__(self, dev: BaseInterface) -> None:
+        self._dev = dev
+
+    def get_historical_entry(self, day: Optional[int] = None) -> dict:
+        address = 0x010B
+        words = 21
+        if day is not None:
+            address = 0xF000 + day
+        res = try_read_parse(self._dev, address, words, parse_historical_entry)
+
+        if res is None:
+            raise TimeoutError("Timeout reading historical entry")
+        return res
+
+    def run_days(self) -> int:
+        return self.get_historical_entry()["run_days"]
+
+    def get_battery_state(self) -> dict:
+        data = try_read_parse(self._dev, 0x0100, 11, parse_battery_state)
+
+        if data is None:
+            raise TimeoutError("Timeout reading battery state")
+
+        data[DataName.CALCULATED_BATTERY_POWER] = float(
+            Decimal(str(data.get(DataName.BATTERY_VOLTAGE, 0)))
+            * Decimal(str(data.get(DataName.BATTERY_CURRENT, 0)))
+        )
+        data[DataName.CALCULATED_PANEL_POWER] = float(
+            Decimal(str(data.get(DataName.PANEL_VOLTAGE, 0)))
+            * Decimal(str(data.get(DataName.PANEL_CURRENT, 0)))
+        )
+        data[DataName.CALCULATED_LOAD_POWER] = float(
+            Decimal(str(data.get(DataName.LOAD_VOLTAGE, 0)))
+            * Decimal(str(data.get(DataName.LOAD_CURRENT, 0)))
+        )
+        return data
+
+    @cached_property
+    def model(self) -> str:
+        data = readMemory(self._dev, address=0x000C, words=8)
+        if data is None:
+            raise TimeoutError("Timeout reading model")
+
+        return data.decode().strip()
+
+    @cached_property
+    def version(self) -> str:
+        data = readMemory(self._dev, address=0x0014, words=2)
+        if data is None:
+            raise TimeoutError("Timeout reading version")
+
+        return "{}.{}.{}".format(*struct.unpack("!HBB", data))
+
+    @cached_property
+    def serial(self) -> str:
+        data = readMemory(self._dev, address=0x0018, words=2)
+        if data is None:
+            raise TimeoutError("Timeout reading serial")
+
+        return "{:02n}-{:02n}-{:04n}".format(*struct.unpack("!BBH", data))
+
+    @property
+    def load_enabled(self) -> bool:
+        data = readMemory(self._dev, address=0x010A)
+        if data is None:
+            raise TimeoutError("Timeout reading serial")
+
+        return bool(struct.unpack("!xB", data)[0])
+
+    def enable_load(self, enable: bool) -> bool:
+        data = writeMemory(self._dev, 0x010A, bytes((0, enable)))
+        if data is None:
+            raise TimeoutError("Timeout reading serial")
+        print(data)
+        return bool(struct.unpack("!xB", data)[0])
+
+    @cached_property
+    def name(self) -> str:
+        data = readMemory(self._dev, address=0x0049, words=16)
+        if data is None:
+            raise TimeoutError("Timeout reading name")
+
+        return data.decode("utf-16be").strip()