import time
from typing import Dict
import docker
import prometheus_client
from prometheus_client.core import REGISTRY, CounterMetricFamily, InfoMetricFamily, StateSetMetricFamily

class ContainerCollector(object):
    """Prometheus collector for per-container status and network counters.

    Each call to ``collect`` lists all containers (including stopped ones)
    through the wrapped docker client and yields:

    * ``docker_container_info`` -- one info sample per container, labelled
      with the container name and carrying its lower-cased ``State`` as
      ``status``.
    * ``node_network_{receive,transmit}_{bytes,packets,errors,dropped}_total``
      -- one counter sample per (container, network interface), read from the
      ``networks`` section of a one-shot ``stats`` call.
    """

    docker_client: docker.Client

    # Statistic suffixes looked up in each interface's Docker stats dict.
    _NETWORK_STATS = ("bytes", "packets", "errors", "dropped")
    # (Docker key prefix, direction word used in the metric name).
    _DIRECTIONS = (("rx", "receive"), ("tx", "transmit"))

    def __init__(self, docker_client: docker.Client):
        self.docker_client = docker_client

    def _new_metric_families(self):
        """Build fresh, empty metric families for one scrape or describe.

        Returns:
            ``(info, counters)``. ``info`` is the ``docker_container_info``
            family. ``counters`` maps Docker stats keys such as
            ``"rx_bytes"`` to their counter family, in a stable order.
        """
        counters: Dict[str, CounterMetricFamily] = {}
        for stat in self._NETWORK_STATS:
            for rt_short, rt_long in self._DIRECTIONS:
                key = f"{rt_short}_{stat}"
                counters[key] = CounterMetricFamily(
                    f"node_network_{rt_long}_{stat}_total",
                    f"Docker container stats network.{key}",
                    labels=["container", "device"],
                )
        info = InfoMetricFamily("docker_container_info", "Container info.", labels=["container"])
        return info, counters

    def describe(self):
        """Yield the exported metric families without any samples.

        When a collector has no ``describe``, prometheus_client calls
        ``collect`` at registration time to learn metric names. That would
        query Docker, including one blocking ``stats`` call per container.
        This method yields the same families without touching Docker.
        """
        info, counters = self._new_metric_families()
        yield info
        yield from counters.values()

    def collect(self):
        """Query Docker and yield freshly populated metric families.

        A container whose ``stats`` call fails with a Docker API error (for
        example, because it was removed after being listed) is still reported
        in ``docker_container_info``. It contributes no network samples, so
        one racing container cannot fail the whole scrape.
        """
        m_container_status, nw_counters = self._new_metric_families()

        for container in self.docker_client.containers(all=True):
            # Guard against a missing/empty/null ``Names`` list or a null
            # ``State``. Names are stripped of "/" as in the original naming.
            names = container.get("Names") or [""]
            container_name = names[0].strip("/")
            container_id = container.get("Id")
            container_status = (container.get("State") or "").lower()

            m_container_status.add_metric([container_name], dict(status=container_status))

            try:
                stats = self.docker_client.stats(container_id, stream=False)
            except docker.errors.APIError:
                # Deliberate best-effort: skip network stats for this one.
                continue

            # ``networks`` may be absent or null (presumably for containers
            # without network interfaces); treat that as "no interfaces".
            for interface, ifstats in (stats.get("networks") or {}).items():
                for stat, value in ifstats.items():
                    counter = nw_counters.get(stat)
                    if counter is not None:
                        counter.add_metric([container_name, interface], value)

        yield m_container_status
        yield from nw_counters.values()


# Register the collector with the default registry and serve the metrics
# endpoint on port 9101.
REGISTRY.register(ContainerCollector(docker.Client()))
prometheus_client.start_http_server(9101)


# flush=True: when stdout is a pipe (e.g. under `docker logs`), Python
# block-buffers it. This process then only sleeps, so without an explicit
# flush the line is lost if the process is killed rather than exiting cleanly.
print("Started", flush=True)
try:
    # Keep the main thread alive; the HTTP server started above serves
    # requests in the background.
    while True:
        time.sleep(10)
except KeyboardInterrupt:
    # Ctrl-C: exit quietly instead of printing a traceback.
    pass