From d38abe28babc905e1c170cedcdd1c3c5c362f0c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Odd=20Str=C3=A5b=C3=B8?= <oddstr13@openshell.no>
Date: Sat, 16 Dec 2023 23:36:53 +0100
Subject: [PATCH] Move to python logging

---
 srnemqtt/__main__.py       | 29 ++++++++++---------
 srnemqtt/config.py         | 27 ++++++++++++++++-
 srnemqtt/consumers/mqtt.py | 27 +++++++++--------
 srnemqtt/protocol.py       | 59 +++++++++++++++++++++-----------------
 srnemqtt/util.py           | 15 ++++------
 5 files changed, 95 insertions(+), 62 deletions(-)

diff --git a/srnemqtt/__main__.py b/srnemqtt/__main__.py
index 9cd4316..2a520d7 100755
--- a/srnemqtt/__main__.py
+++ b/srnemqtt/__main__.py
@@ -2,13 +2,17 @@
 # -*- coding: utf-8 -*-
 
 import time
+from logging import getLogger
+from logging.config import dictConfig as loggingDictConfig
 
 from bluepy.btle import BTLEDisconnectError  # type: ignore
 from serial import SerialException  # type: ignore
 
 from .config import get_config, get_consumers, get_interface
 from .protocol import ChargeController
-from .util import Periodical, log
+from .util import Periodical
+
+logger = getLogger(__name__)
 
 
 class CommunicationError(BTLEDisconnectError, SerialException, IOError):
@@ -17,25 +21,24 @@ class CommunicationError(BTLEDisconnectError, SerialException, IOError):
 
 def main():
     conf = get_config()
+
+    loggingDictConfig(conf.get("logging", {}))
     consumers = get_consumers(conf)
 
     per_voltages = Periodical(interval=15)
     per_current_hist = Periodical(interval=60)
-    # import serial
-
-    # ser = serial.Serial()
 
     try:
         while True:
             try:
-                log("Connecting...")
+                logger.info("Connecting...")
                 with get_interface() as dev:
-                    log("Connected.")
+                    logger.info("Connected.")
 
                     cc = ChargeController(dev)
-                    log(f"Controller model: {cc.model}")
-                    log(f"Controller version: {cc.version}")
-                    log(f"Controller serial: {cc.serial}")
+                    logger.info(f"Controller model: {cc.model}")
+                    logger.info(f"Controller version: {cc.version}")
+                    logger.info(f"Controller serial: {cc.serial}")
                     for consumer in consumers:
                         consumer.controller = cc
 
@@ -59,7 +62,7 @@ def main():
                     for i in range(min(days, 4)):
                         hist = cc.get_historical(i)
                         res = hist.as_dict()
-                        log({i: res})
+                        logger.debug({i: res})
                         for consumer in consumers:
                             consumer.write({str(i): res})
 
@@ -68,14 +71,14 @@ def main():
 
                         if per_voltages(now):
                             data = cc.state.as_dict()
-                            log(data)
+                            logger.debug(data)
                             for consumer in consumers:
                                 consumer.write(data)
 
                         if per_current_hist(now):
                             data = cc.today.as_dict()
                             data.update(cc.extra.as_dict())
-                            log(data)
+                            logger.debug(data)
                             for consumer in consumers:
                                 consumer.write(data)
 
@@ -91,7 +94,7 @@ def main():
                     #    write(wd, CMD_ENABLE_LOAD)
 
             except CommunicationError:
-                log("ERROR: Disconnected")
+                logger.error("Disconnected")
                 time.sleep(1)
 
     except (KeyboardInterrupt, SystemExit, Exception) as e:
diff --git a/srnemqtt/config.py b/srnemqtt/config.py
index fd5ec8c..b82357b 100644
--- a/srnemqtt/config.py
+++ b/srnemqtt/config.py
@@ -27,6 +27,29 @@ def get_config() -> Dict[str, Any]:
     with open("config.yaml", "r") as fh:
         conf: dict = yaml.safe_load(fh)
         conf.setdefault("consumers", {})
+        logging = conf.setdefault("logging", {})
+        logging.setdefault("version", 1)
+        logging.setdefault("disable_existing_loggers", False)
+        logging.setdefault(
+            "handlers",
+            {
+                "console": {
+                    "class": "logging.StreamHandler",
+                    "formatter": "default",
+                    "level": "INFO",
+                    "stream": "ext://sys.stdout",
+                }
+            },
+        )
+        logging.setdefault(
+            "formatters",
+            {
+                "format": "%(asctime)s %(levelname)-8s %(name)-15s %(message)s",
+                "datefmt": "%Y-%m-%d %H:%M:%S",
+            },
+        )
+        loggers = logging.setdefault("loggers", {})
+        loggers.setdefault("root", {"handlers": ["console"], "level": "DEBUG"})
 
         return conf
 
@@ -34,7 +57,9 @@ def get_config() -> Dict[str, Any]:
 def write_config(conf: Dict[str, Any]):
     with open(".config.yaml~writing", "w") as fh:
         yaml.safe_dump(conf, fh, indent=2, encoding="utf-8")
-    os.rename(".config.yaml~writing", "config.yaml")
+        fh.flush()
+        os.fsync(fh.fileno())
+    os.replace(".config.yaml~writing", "config.yaml")
 
 
 def get_consumers(conf: Optional[Dict[str, Any]] = None) -> List[BaseConsumer]:
diff --git a/srnemqtt/consumers/mqtt.py b/srnemqtt/consumers/mqtt.py
index db53002..7a29e1c 100644
--- a/srnemqtt/consumers/mqtt.py
+++ b/srnemqtt/consumers/mqtt.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 import json
+from logging import getLogger
 from time import sleep
 from typing import Any, Dict, List, Optional, TypeAlias
 from uuid import uuid4
@@ -9,6 +10,8 @@ import paho.mqtt.client as mqtt
 from ..solar_types import DataName
 from . import BaseConsumer
 
+logger = getLogger(__name__)
+
 MAP_VALUES: Dict[DataName, Dict[str, Any]] = {
     # DataName.BATTERY_VOLTAGE_MIN: {},
     # DataName.BATTERY_VOLTAGE_MAX: {},
@@ -161,8 +164,11 @@ class MqttConsumer(BaseConsumer):
                 elif err.errno == -3:
                     pass
                 else:
+                    logger.exception("Unknown error connecting to mqtt server")
                     raise
-                print(err)
+                logger.warning(
+                    "Temporary failure connecting to mqtt server", exc_info=True
+                )
                 sleep(0.1)
         return self._client
 
@@ -245,7 +251,7 @@ class MqttConsumer(BaseConsumer):
     # The callback for when the client receives a CONNACK response from the server.
     @staticmethod
     def on_connect(client: mqtt.Client, userdata: "MqttConsumer", flags, rc):
-        print("Connected with result code " + str(rc))
+        logger.info("MQTT connected with result code %s", rc)
 
         # Subscribing in on_connect() means that if we lose the connection and
         # reconnect then subscriptions will be renewed.
@@ -263,10 +269,7 @@ class MqttConsumer(BaseConsumer):
         client: mqtt.Client, userdata: "MqttConsumer", message: mqtt.MQTTMessage
     ):
         assert userdata.controller is not None
-        print(message)
-        print(message.info)
-        print(message.state)
-        print(message.payload)
+        logger.debug(message.payload)
         payload = message.payload.decode().upper() in ("ON", "TRUE", "ENABLE", "YES")
 
         res = userdata.controller.load_enabled = payload
@@ -276,29 +279,29 @@ class MqttConsumer(BaseConsumer):
 
     @staticmethod
     def on_connect_fail(client: mqtt.Client, userdata: "MqttConsumer"):
-        print(userdata.__class__.__name__, "on_connect_fail")
+        logger.warning("on_connect_fail")
 
     # The callback for when a PUBLISH message is received from the server.
     @staticmethod
     def on_message(client, userdata, msg):
-        print(msg.topic + " " + str(msg.payload))
+        logger.info(msg.topic + " " + str(msg.payload))
 
     @staticmethod
     def on_disconnect(client: mqtt.Client, userdata: "MqttConsumer", rc, prop=None):
-        print(userdata.__class__.__name__, "on_disconnect", rc)
+        logger.warning("on_disconnect %s", rc)
 
     def poll(self):
         res = self.client.loop(timeout=0.1, max_packets=5)
 
         if res != mqtt.MQTT_ERR_SUCCESS:
-            print(self.__class__.__name__, "loop returned non-success:", res)
+            logger.warning("loop returned non-success: %s", res)
             try:
                 sleep(1)
                 res = self.client.reconnect()
                 if res != mqtt.MQTT_ERR_SUCCESS:
-                    print(self.__class__.__name__, "Reconnect failed:", res)
+                    logger.error("Reconnect failed: %s", res)
             except (OSError, mqtt.WebsocketConnectionError) as err:
-                print(self.__class__.__name__, "Reconnect failed:", err)
+                logger.error("Reconnect failed: %s", err)
 
         return super().poll()
 
diff --git a/srnemqtt/protocol.py b/srnemqtt/protocol.py
index bdb1c97..5b66ea6 100644
--- a/srnemqtt/protocol.py
+++ b/srnemqtt/protocol.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 import struct
-import sys
 import time
-from typing import Callable, Collection, Optional
+from logging import getLogger
+from typing import Callable, Collection, List, Optional
 
 from libscrc import modbus  # type: ignore
 
@@ -16,7 +16,8 @@ from .solar_types import (
     HistoricalData,
     HistoricalExtraInfo,
 )
-from .util import log
+
+logger = getLogger(__name__)
 
 
 def write(fh, data):
@@ -92,15 +93,13 @@ def discardUntil(fh: BaseInterface, byte: int, timeout=10) -> Optional[int]:
         return b[0]
 
     start = time.time()
-    discarded = 0
+    discarded: List[str] = []
     read_byte = expand(fh.read(1))
     while read_byte != byte:
         if read_byte is not None:
             if not discarded:
-                log("Discarding", end="")
-            discarded += 1
-            print(f" {read_byte:02X}", end="")
-            sys.stdout.flush()
+                discarded.append("Discarding")
+            discarded.append(f"{read_byte:02X}")
 
         if time.time() - start > timeout:
             read_byte = None
@@ -109,8 +108,7 @@ def discardUntil(fh: BaseInterface, byte: int, timeout=10) -> Optional[int]:
         read_byte = expand(fh.read(1))
 
     if discarded:
-        print()
-        sys.stdout.flush()
+        logger.debug(" ".join(discarded))
 
     return read_byte
 
@@ -134,14 +132,18 @@ def readMemory(fh: BaseInterface, address: int, words: int = 1) -> Optional[byte
             try:
                 crc = struct.unpack_from("<H", _crc)[0]
             except struct.error:
-                log(f"readMemory: CRC error; read {len(_crc)} bytes (2 expected)")
+                logger.error(
+                    "readMemory: CRC error; read %s bytes (2 expected)", len(_crc)
+                )
                 return None
             calculated_crc = modbus(bytes([tag, operation, size, *data]))
             if crc == calculated_crc:
                 return data
             else:
-                log(f"readMemory: CRC error; {crc:04X} != {calculated_crc:04X}")
-        log("data or crc is falsely", header, data, _crc)
+                logger.error(
+                    f"readMemory: CRC error; {crc:04X} != {calculated_crc:04X}"
+                )
+        logger.error("data or crc is falsely %s %s %s", header, data, _crc)
     return None
 
 
@@ -166,7 +168,7 @@ def writeMemory(fh: BaseInterface, address: int, data: bytes):
     header = fh.read(3)
     if header and len(header) == 3:
         operation, size, address = header
-        log(header)
+        logger.log(5, header)
         # size field is zero when writing device name for whatever reason
         # write command seems to only accept a single word, so this is fine;
         # we just hardcode the number of bytes read to two here.
@@ -176,14 +178,18 @@ def writeMemory(fh: BaseInterface, address: int, data: bytes):
             try:
                 crc = struct.unpack_from("<H", _crc)[0]
             except struct.error:
-                log(f"writeMemory: CRC error; read {len(_crc)} bytes (2 expected)")
+                logger.error(
+                    f"writeMemory: CRC error; read {len(_crc)} bytes (2 expected)"
+                )
                 return None
             calculated_crc = modbus(bytes([tag, operation, size, address, *rdata]))
             if crc == calculated_crc:
                 return rdata
             else:
-                log(f"writeMemory: CRC error; {crc:04X} != {calculated_crc:04X}")
-        log("data or crc is falsely", header, rdata, _crc)
+                logger.error(
+                    f"writeMemory: CRC error; {crc:04X} != {calculated_crc:04X}"
+                )
+        logger.error("data or crc is falsely %s %s %s", header, rdata, _crc)
     return None
 
 
@@ -193,7 +199,6 @@ def writeMemoryMultiple(fh: BaseInterface, address: int, data: bytes):
     res = bytearray()
     for i in range(len(data) // 2):
         d = data[i * 2 : (i + 1) * 2]
-        log(address + i, d)
         r = writeMemory(fh, address + i, d)
         if r:
             res.extend(r)
@@ -214,15 +219,16 @@ def try_read_parse(
             try:
                 if parser:
                     return parser(res)
-            except struct.error as e:
-                log(e)
-                log("0x0100 Unpack error:", len(res), res)
+            except struct.error:
+                logger.exception("0x0100 Unpack error: %s %s", len(res), res)
                 _timeout = dev.timeout
                 dev.timeout = 0.5
-                log("Flushed from read buffer; ", dev.read())
+                logger.warning("Flushed from read buffer; %s", dev.read())
                 dev.timeout = _timeout
         else:
-            log(f"No data read, expected {words*2} bytes (attempts left: {attempts})")
+            logger.warning(
+                f"No data read, expected {words*2} bytes (attempts left: {attempts})"
+            )
     return None
 
 
@@ -308,7 +314,6 @@ class ChargeController:
         # Pad name to 32 bytes to ensure ensure nothing is left of old name
         while len(bin_value) < 32:
             bin_value.extend(b"\x00\x20")
-        print(len(bin_value), bin_value)
 
         data = writeMemoryMultiple(self.device, 0x0049, bin_value)
         if data is None:
@@ -316,7 +321,7 @@ class ChargeController:
 
         res = data.decode("UTF-16BE").strip()
         if res != value:
-            log(f"setting device name failed; {res!r} != {value!r}")
+            logger.error("setting device name failed; %r != %r", res, value)
         self._cached_name = value
 
     @property
@@ -333,9 +338,9 @@ class ChargeController:
         if data is not None:
             res = struct.unpack("x?", data)[0]
             if res != value:
-                log(f"setting load_enabled failed; {res!r} != {value!r}")
+                logger.error("setting load_enabled failed; %r != %r", res, value)
         else:
-            log("setting load_enabled failed; communications error")
+            logger.error("setting load_enabled failed; communications error")
 
     @property
     def state(self) -> ChargerState:
diff --git a/srnemqtt/util.py b/srnemqtt/util.py
index b641e70..254e38b 100644
--- a/srnemqtt/util.py
+++ b/srnemqtt/util.py
@@ -1,14 +1,14 @@
 # -*- coding: utf-8 -*-
-
-import datetime
-import sys
 import time
+from logging import getLogger
 from typing import Optional
 
 # Only factor of 1000
 SI_PREFIXES_LARGE = "kMGTPEZY"
 SI_PREFIXES_SMALL = "mµnpfazy"
 
+logger = getLogger(__name__)
+
 
 def humanize_number(data, unit: str = ""):
     counter = 0
@@ -35,11 +35,6 @@ def humanize_number(data, unit: str = ""):
     return f"{data:.3g} {prefix}{unit}"
 
 
-def log(*message: object, **kwargs):
-    print(datetime.datetime.utcnow().isoformat(" "), *message, **kwargs)
-    sys.stdout.flush()
-
-
 class Periodical:
     prev: float
     interval: float
@@ -56,7 +51,9 @@ class Periodical:
             skipped, overshoot = divmod(now - self.prev, self.interval)
             skipped -= 1
             if skipped:
-                log("Skipped:", skipped, overshoot, now - self.prev, self.interval)
+                logger.debug(
+                    "Skipped:", skipped, overshoot, now - self.prev, self.interval
+                )
             self.prev = now - overshoot
             return True