from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional


def contraction_margin(gamma: float, k: float, L_kappa: float, R_inf: float, tau: float, C_tau: float = 1.0) -> float:
    """Return ``1 - rho`` where ``rho = gamma + (k / 4) * L_kappa * R_inf + C_tau * tau``.

    Args:
        gamma: Additive base term of ``rho``.
        k: Scaled by 1/4 and multiplied into the ``L_kappa * R_inf`` product.
        L_kappa: Factor of the coupling term.
        R_inf: Factor of the coupling term.
        tau: Scaled by ``C_tau`` and added to ``rho``.
        C_tau: Weight applied to ``tau`` (defaults to 1.0).

    Returns:
        ``1 - rho``. The result is positive exactly when ``rho < 1``; going by the
        function name, that presumably means the bound is contractive (confirm
        against the accompanying theory).
    """
    # Keep the original left-to-right summation order so float results match bit-for-bit.
    coupling = (k / 4.0) * L_kappa * R_inf
    tau_term = C_tau * tau
    rho = gamma + coupling + tau_term
    return 1.0 - rho


def summarize_training(gamma: float, k: float, L_kappa_hat: float, R_inf: float, tau_hat: float) -> Dict[str, float]:
    """Build a flat ``"theory/*"`` metrics dict from the given estimates.

    The margin comes from ``contraction_margin`` with its default ``C_tau``.
    Every input is echoed back under its own ``"theory/<name>"`` key.
    Keys are inserted in the order: margin, gamma, k, L_kappa_hat, R_inf, tau_hat.
    """
    summary: Dict[str, float] = {
        "theory/margin": contraction_margin(gamma=gamma, k=k, L_kappa=L_kappa_hat, R_inf=R_inf, tau=tau_hat),
    }
    summary["theory/gamma"] = gamma
    summary["theory/k"] = k
    summary["theory/L_kappa_hat"] = L_kappa_hat
    summary["theory/R_inf"] = R_inf
    summary["theory/tau_hat"] = tau_hat
    return summary


@dataclass
class MarginTracker:
    """Sliding-window tracker of recent margin values.

    ``update`` records values and keeps only the ``window`` most recent ones.
    ``stats`` reports their mean/min/max.
    """

    # Maximum number of most-recent values retained after each ``update``.
    window: int = 100
    # Retained values, oldest first. ``None`` is replaced by a fresh list in
    # ``__post_init__`` so instances never share one list object.
    margins: Optional[List[float]] = None

    def __post_init__(self) -> None:
        if self.margins is None:
            self.margins = []

    def update(self, value: float) -> None:
        """Append ``value`` (coerced to float) and drop the oldest values beyond ``window``."""
        self.margins.append(float(value))
        # Trim *all* overflow, not just one element. A caller-supplied initial
        # list, or a window shrunk after construction, can exceed ``window`` by
        # more than one. Popping a single item per update would never catch up.
        excess = len(self.margins) - self.window
        if excess > 0:
            del self.margins[:excess]

    def stats(self) -> Dict[str, float]:
        """Return mean/min/max of the retained values; all 0.0 when nothing is retained."""
        if not self.margins:
            return {"margin/mean": 0.0, "margin/min": 0.0, "margin/max": 0.0}
        m = self.margins
        return {"margin/mean": sum(m) / len(m), "margin/min": min(m), "margin/max": max(m)}


@dataclass
class RiskMonitor:
    """Sliding-window monitor of kappa, sigma and cost values.

    Each call to ``update`` appends one value to each series and keeps only the
    ``window`` most recent entries. ``stats`` reports mean/min/max per series.

    NOTE(review): the fields store exactly the values passed to ``update``. The
    names ``*_mean`` / ``cost_ema`` suggest callers pass per-step means or an
    EMA; this class does not compute them itself, so confirm against the callers.
    """

    # Values passed as ``kappa`` to ``update``, oldest first.
    kappa_mean: Optional[List[float]] = None
    # Values passed as ``sigma`` to ``update``, oldest first.
    sigma_mean: Optional[List[float]] = None
    # Values passed as ``cost`` to ``update``, oldest first.
    cost_ema: Optional[List[float]] = None
    # Maximum number of most-recent values retained per series.
    window: int = 100

    def __post_init__(self) -> None:
        # Replace ``None`` with fresh lists so instances never share storage.
        self.kappa_mean = [] if self.kappa_mean is None else self.kappa_mean
        self.sigma_mean = [] if self.sigma_mean is None else self.sigma_mean
        self.cost_ema = [] if self.cost_ema is None else self.cost_ema

    def update(self, kappa: float, sigma: float, cost: float) -> None:
        """Record one (kappa, sigma, cost) triple and trim each series to ``window``."""
        self.kappa_mean.append(float(kappa))
        self.sigma_mean.append(float(sigma))
        self.cost_ema.append(float(cost))
        for arr in (self.kappa_mean, self.sigma_mean, self.cost_ema):
            # Remove all overflow at once. Caller-supplied initial lists may
            # exceed ``window`` by more than one element.
            excess = len(arr) - self.window
            if excess > 0:
                del arr[:excess]

    def stats(self) -> Dict[str, float]:
        """Return ``"<series>/<mean|min|max>"`` for kappa, sigma and cost_ema (0.0 when empty)."""

        def _agg(x: List[float]) -> Dict[str, float]:
            if not x:
                return {"mean": 0.0, "min": 0.0, "max": 0.0}
            return {"mean": sum(x) / len(x), "min": min(x), "max": max(x)}

        out: Dict[str, float] = {}
        for prefix, series in (
            ("kappa", self.kappa_mean),
            ("sigma", self.sigma_mean),
            ("cost_ema", self.cost_ema),
        ):
            for stat, value in _agg(series).items():
                out[f"{prefix}/{stat}"] = value
        return out


def aggregate_dicts(*ds: Dict[str, float]) -> Dict[str, float]:
    """Merge the given dicts left to right into a new dict.

    When a key repeats, the value from the later dict wins. The key keeps the
    position of its first occurrence. The input dicts are not modified.
    """
    return {key: val for mapping in ds for key, val in mapping.items()}
                
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
