from typing import Any

import numpy as np
import torch


class RunningStats:
    """Mergeable running count / mean / variance / min / max, kept in float64.

    Batches are combined with the pairwise (Chan et al.) update of
    ``(n, mean, M2)``, where ``M2`` is the sum of squared deviations from the
    mean. ``merge`` applies the same update to another ``RunningStats``, so
    accumulators built on separate data shards can be combined exactly.
    """
    __slots__ = ("n", "mean", "M2", "min", "max")

    def __init__(self):
        self.n = 0                 # number of values accumulated
        self.mean = 0.0            # running mean of all values
        self.M2 = 0.0              # sum of squared deviations from self.mean
        self.min = float("inf")    # identity for min
        self.max = float("-inf")   # identity for max

    def _merge_moments(self, nb, mb, M2b, b_min, b_max) -> None:
        """Fold a summary (count, mean, M2, min, max) of other values into self."""
        if nb == 0:
            return
        na = self.n
        if na == 0:
            # Nothing accumulated yet: adopt the incoming summary as-is.
            self.n, self.mean, self.M2 = nb, mb, M2b
            self.min, self.max = b_min, b_max
            return

        # Chan et al. parallel combination of two (n, mean, M2) summaries.
        nt = na + nb
        delta = mb - self.mean
        self.mean = self.mean + delta * (nb / nt)
        self.M2 = self.M2 + M2b + (delta * delta) * (na * nb / nt)
        self.n = nt

        # Same semantics as explicit `<` / `>` updates (a NaN candidate never wins).
        self.min = min(self.min, b_min)
        self.max = max(self.max, b_max)

    def update_batch(self, x: torch.Tensor) -> None:
        """Accumulate every element of ``x``.

        ``x`` is cast to float64 before reducing. Integer and bool tensors are
        therefore accepted, and low-precision float tensors are reduced without
        float16/float32 accumulation error. All elements are pooled regardless
        of shape (the reductions take no ``dim``). Empty tensors are ignored.
        """
        if x.numel() == 0:
            return

        xb = x.detach().to(torch.float64)
        nb = xb.numel()
        # Bring all four reductions to the host in one transfer (one device sync).
        mb, vb, b_min, b_max = torch.stack(
            (xb.mean(), xb.var(correction=0), xb.min(), xb.max())
        ).tolist()
        # correction=0 is the population variance, so M2 = var * n; a single
        # element contributes no spread.
        M2b = vb * nb if nb > 1 else 0.0

        self._merge_moments(nb, mb, M2b, b_min, b_max)

    def merge(self, other: "RunningStats") -> None:
        """Fold the statistics accumulated in ``other`` into this instance.

        ``other`` is not modified. Merging an empty accumulator is a no-op.
        """
        self._merge_moments(other.n, other.mean, other.M2, other.min, other.max)

    def finalize_row(
        self,
        **kwargs
    ) -> dict[str, Any]:
        """Return a flat dict of the summary, prefixed by ``kwargs``.

        The ``kwargs`` entries come first. The keys "count", "mean", "std",
        "min" and "max" overwrite any ``kwargs`` entry with the same name.
        ``std`` is the sample standard deviation (divisor ``n - 1``), so it is
        NaN when fewer than two values were seen. With no values at all,
        mean/min/max are NaN as well.
        """
        if self.n == 0:
            return {
                **kwargs,
                "count": 0,
                "mean": float("nan"),
                "std": float("nan"),
                "min": float("nan"),
                "max": float("nan")
            }

        if self.n > 1:
            var_sample = self.M2 / (self.n - 1)
            std = var_sample ** 0.5
        else:
            std = float("nan")

        return {
            **kwargs,
            "count": int(self.n),
            "mean": self.mean,
            "std": std,
            "min": self.min,
            "max": self.max
        }
