import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Literal, Union
import math
import logging

# How a metric relates to the count keys ("batches", "frames", "samples"):
# "*_avg" values are means weighted by the matching count, "sum" values are
# plain running totals.
NormType = Literal["batch_avg", "frame_avg", "sample_avg", "sum"]
# Any value accepted by ``MetricsTracker.set_value``; stored as ``float``.
Number = Union[int, float]


class MetricsTracker:
    """Accumulate scalar metrics and merge them across steps and DDP workers.

    Every metric is stored as a float together with a normalization mode.
    The keys ``"batches"``, ``"frames"`` and ``"samples"`` are reserved: they
    hold the counts used as weights for ``"batch_avg"``, ``"frame_avg"`` and
    ``"sample_avg"`` metrics respectively, and are always merged as sums.
    """

    # Reserved keys that carry weights rather than metrics.
    _COUNT_KEYS = ("batches", "frames", "samples")
    # Normalization mode -> count key that weights metrics of that mode.
    _COUNT_KEY_FOR_NORM = {
        "batch_avg": "batches",
        "frame_avg": "frames",
        "sample_avg": "samples",
    }

    def __init__(self):
        self._values: Dict[str, float] = {}
        self._norms: Dict[str, NormType] = {}

    def set_value(self, key: str, value: Number, normalization: NormType = "batch_avg"):
        """Set a pre-normalized metric value.

        Args:
            key: Metric name. ``"batches"``, ``"frames"`` and ``"samples"``
                are interpreted as counts by ``update`` and ``reduce``,
                whatever normalization is recorded for them here.
            value: The value, converted to ``float``.
            normalization: How the value is normalized; decides how it is
                merged in ``update`` and ``reduce``.
        """
        self._values[key] = float(value)
        self._norms[key] = normalization

    def _weight_key(self, key: str):
        """Return the count key that weights *key*, or ``None`` for plain sums."""
        if key in self._COUNT_KEYS:
            return None
        return self._COUNT_KEY_FOR_NORM.get(self._norms.get(key, "sum"))

    def update(self, other: "MetricsTracker", reset_interval: int = -1):
        """Merge the metrics of *other* into this tracker.

        ``"sum"`` metrics are added. Averaged metrics are combined as a
        weighted mean, using the matching counts of both trackers as weights.

        Args:
            other: Tracker holding the new values (typically one batch).
            reset_interval: When positive, previously accumulated counts are
                scaled by ``1 - 1 / reset_interval`` on every update, so older
                contributions decay. When not positive, nothing decays.

        Raises:
            ValueError: If a metric in *other* has an unsupported
                normalization.
        """
        alpha = 1 - 1.0 / reset_interval if reset_interval > 0 else 1

        # Step 1: merge the actual metrics (counts are handled in step 2).
        for k, v in other._values.items():
            if k in self._COUNT_KEYS:
                continue

            if not math.isfinite(v):
                logging.warning(
                    "[MetricsTracker] Invalid value in update('%s'): %s. Skipping.", k, v
                )
                continue

            norm = other._norms[k]
            if norm == "sum":
                self._values[k] = self._values.get(k, 0.0) + v
                self._norms[k] = norm
                continue

            count_key = self._COUNT_KEY_FOR_NORM.get(norm)
            if count_key is None:
                raise ValueError(f"Unsupported normalization: {norm}")

            # A missing "batches" count means one batch; missing frame/sample
            # counts mean there is nothing to weight the value with.
            new_count = other._values.get(count_key, 1.0 if count_key == "batches" else 0.0)
            if count_key != "batches" and not new_count > 0:
                # Skip entirely, and do NOT register the norm: a norm without
                # a value would make reduce() fail on the missing key.
                continue

            if k in self._values:
                prev_count = self._values.get(count_key, 0.0)
                denom = alpha * prev_count + new_count
                # With zero total weight there is nothing to average; keep
                # the previous value instead of dividing by zero.
                if denom > 0:
                    self._values[k] = (
                        alpha * self._values[k] * prev_count + v * new_count
                    ) / denom
            else:
                self._values[k] = v
            self._norms[k] = norm

        # Step 2: update the counts last, so step 1 saw the previous counts.
        # Batches are always incremented (default = 1).
        self._values["batches"] = (
            alpha * self._values.get("batches", 0.0) + other._values.get("batches", 1.0)
        )
        self._norms["batches"] = "sum"
        for meta_key in ("frames", "samples"):
            if meta_key in other._values:
                self._values[meta_key] = (
                    alpha * self._values.get(meta_key, 0.0) + other._values[meta_key]
                )
                self._norms[meta_key] = "sum"

    def write_summary(self, tb_writer: "SummaryWriter", prefix: str, step: int):
        """Write every stored value (counts included) as a scalar.

        Args:
            tb_writer: Writer whose ``add_scalar`` is called once per key.
            prefix: String prepended to each key to form the tag.
            step: Global step passed to ``add_scalar``.
        """
        for k, v in self._values.items():
            tb_writer.add_scalar(f"{prefix}{k}", v, step)

    def reduce(self, device):
        """All-reduce values across DDP workers, weighting averaged metrics.

        Does nothing when torch.distributed is unavailable or uninitialized.
        Otherwise counts and ``"sum"`` metrics are summed across workers and
        averaged metrics are combined as a weighted mean. The weight is the
        local count, or 1.0 per worker when that count is absent.

        NOTE: values are matched across workers by position in the sorted key
        list, so every worker must hold the same set of keys.

        Args:
            device: Device on which the float32 exchange tensor is created.
        """
        if not dist.is_available() or not dist.is_initialized():
            return

        keys = sorted(self._values)
        count_keys = self._COUNT_KEYS
        local_weights = {c: self._values.get(c, 1.0) for c in count_keys}

        payload = []
        for k in keys:
            count_key = self._weight_key(k)
            v = self._values[k]
            # Averages travel as weighted sums; counts and sums travel as-is.
            payload.append(v if count_key is None else v * local_weights[count_key])
        # Also reduce the weights actually used, so the divisor is right even
        # when a count key is absent (per-worker fallback weight of 1.0).
        payload.extend(local_weights[c] for c in count_keys)

        tensor = torch.tensor(payload, dtype=torch.float32, device=device)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        reduced = tensor.tolist()

        total_weights = dict(zip(count_keys, reduced[len(keys):]))
        for k, total in zip(keys, reduced):
            count_key = self._weight_key(k)
            if count_key is None:
                self._values[k] = total
            elif total_weights[count_key] > 0:
                self._values[k] = total / total_weights[count_key]
            # Otherwise no worker contributed any weight: keep the local value
            # rather than dividing by zero.

    def __str__(self):
        """Return ``"<norm>_<key>=<value>"`` pairs joined by ``", "``."""
        return ", ".join(f"{self._norms[k]}_{k}={v:.4g}" for k, v in self._values.items())
