from __future__ import annotations

from typing import Any, Dict, Optional, Tuple, Callable

import numpy as np
import torch
from torch import nn



def mse(r: torch.Tensor) -> torch.Tensor:
    """Return the mean of the element-wise squares of ``r`` as a 0-dim tensor."""
    squared = r * r
    return torch.mean(squared)


def t_objective_d(r: torch.Tensor, nu: Any, lam: Any, eps: float = 1e-12) -> torch.Tensor:
    """Return the mean of ``0.5 * (nu + 1) * log1p(lam * r**2 / (nu + eps))``.

    The expression has the form of the residual-dependent part of a Student-t
    negative log-likelihood (the configuration section feeding ``nu``/``lam``
    in this file is called ``student_t``).

    Args:
        r: Residual tensor.
        nu: Scalar or tensor; converted to ``r``'s dtype and device.
        lam: Scalar or tensor; converted to ``r``'s dtype and device.
        eps: Small constant added to ``nu`` in the denominator.

    Returns:
        A 0-dim tensor holding the mean over all elements.
    """
    nu_t = torch.as_tensor(nu, dtype=r.dtype, device=r.device)
    lam_t = torch.as_tensor(lam, dtype=r.dtype, device=r.device)

    scaled_sq = (lam_t * (r * r)) / (nu_t + eps)
    per_element = 0.5 * (nu_t + 1.0) * torch.log1p(scaled_sq)
    return per_element.mean()


def estep_w_d(r: torch.Tensor, nu: Any, lam: Any, eps: float = 1e-12) -> torch.Tensor:
    """Return the element-wise weights ``(nu + 1) / (nu + lam * r**2 + eps)``.

    The name suggests these are E-step weights of an EM scheme; that usage is
    not visible in this file, so treat it as an assumption.

    Args:
        r: Residual tensor.
        nu: Scalar or tensor; converted to ``r``'s dtype and device.
        lam: Scalar or tensor; converted to ``r``'s dtype and device.
        eps: Small constant added to the denominator.

    Returns:
        Weights broadcast against ``r`` (one per residual for scalar parameters).
    """
    nu_t = torch.as_tensor(nu, dtype=r.dtype, device=r.device)
    lam_t = torch.as_tensor(lam, dtype=r.dtype, device=r.device)

    numerator = nu_t + 1.0
    denominator = nu_t + lam_t * (r * r) + eps
    return numerator / denominator


def snll_term_loss(resid: torch.Tensor, nu: Any, lam: Any) -> torch.Tensor:
    """Return ``t_objective_d(resid, nu=nu, lam=lam)`` using its default ``eps``."""
    objective = t_objective_d(resid, nu=nu, lam=lam)
    return objective


def wls_term_loss(resid: torch.Tensor, nu: Any, lam: Any, eps: float = 1e-12) -> torch.Tensor:
    """Return the mean of ``0.5 * (nu + 1) * exp(1 + lam * resid**2 / (nu + eps))``.

    NOTE(review): the original author marked this as a literal translation of
    a JAX snippet and doubted that the formula is what was intended
    mathematically. The expression is kept exactly as it was.

    Args:
        resid: Residual tensor.
        nu: Scalar or tensor; converted to ``resid``'s dtype and device.
        lam: Scalar or tensor; converted to ``resid``'s dtype and device.
        eps: Small constant added to ``nu`` in the denominator.

    Returns:
        A 0-dim tensor holding the mean over all elements.
    """
    nu_t = torch.as_tensor(nu, dtype=resid.dtype, device=resid.device)
    lam_t = torch.as_tensor(lam, dtype=resid.dtype, device=resid.device)

    exponent = 1.0 + (lam_t * (resid * resid)) / (nu_t + eps)
    per_element = 0.5 * (nu_t + 1.0) * torch.exp(exponent)
    return per_element.mean()



def _cfg_lookup(container, key):
    """Read ``key`` from ``container`` by attribute, falling back to ``.get``.

    Returns an empty dict when the key is missing or its value is falsy.
    ``.get`` is only called when ``container`` actually provides it, so plain
    attribute-style objects (e.g. a namespace) without the key do not raise.
    """
    value = getattr(container, key, None)
    if value:
        return value
    getter = getattr(container, "get", None)
    if callable(getter):
        value = getter(key, {})
    return value or {}


def build_objectives(config, loss_keys):
    """Collect the per-term objective modes and initial ``nu``/``lam`` values.

    Reads ``config.objectives`` and, inside it, ``terms`` and
    ``student_t.init.nu`` / ``student_t.init.lam``. Each nested level may be
    either an attribute-style object or a mapping with ``.get``. Missing
    levels are treated as empty.

    Args:
        config: Object whose optional ``objectives`` attribute holds the settings.
        loss_keys: Iterable of loss-term names that must be present in the result.

    Returns:
        ``(terms, init_st_params)`` where ``terms`` maps each term to its mode
        (default ``"mse"``) and ``init_st_params`` is
        ``{"nu": {...}, "lam": {...}}`` with defaults ``5.0`` and ``1.0``.
        All returned dicts are fresh copies; the configuration is not mutated.
    """
    obj = getattr(config, "objectives", None) or {}

    terms = dict(_cfg_lookup(obj, "terms"))
    st_cfg = _cfg_lookup(obj, "student_t")
    init = _cfg_lookup(st_cfg, "init")

    nu0 = dict(_cfg_lookup(init, "nu"))
    lam0 = dict(_cfg_lookup(init, "lam"))

    # Fill in defaults for every requested term without overriding configured values.
    for k in loss_keys:
        terms.setdefault(k, "mse")
        nu0.setdefault(k, 5.0)
        lam0.setdefault(k, 1.0)

    init_st_params = {"nu": nu0, "lam": lam0}
    return terms, init_st_params


def group_by_prefix(losses: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Sum the values of ``losses`` that share the key text before the first ``_``.

    Keys without an underscore form their own group. Each group's sum starts
    from ``0.0``, and groups appear in order of first occurrence.
    """
    totals: Dict[str, torch.Tensor] = {}
    for name, value in losses.items():
        head, _, _ = name.partition("_")
        running = totals.get(head, 0.0)
        totals[head] = running + value
    return totals


def residual_stats_logging(residuals: torch.Tensor, qs: Tuple[float, ...] = (0.01, 0.05)) -> Dict[str, float]:
    """Summarise a residual tensor as a dict of plain Python floats.

    The tensor is detached, moved to CPU and flattened before any statistic is
    computed.

    Args:
        residuals: Tensor of residual values (any shape).
        qs: Fractions of the sample; for each ``q`` the share of total squared
            magnitude carried by the ``ceil(q * n)`` largest-magnitude entries
            (at least one, at most ``n``) is reported.

    Returns:
        A dict with the keys, in this order:

        * ``mean``, ``std`` (``ddof=1`` when more than one element),
        * ``skewness_bias`` / ``kurtosis_bias``: SciPy estimates computed with
          ``bias=False`` (i.e. the bias-corrected ones, despite the key name),
        * ``skewness`` / ``kurtosis``: SciPy estimates with default settings,
        * ``rms``, ``max_abs``, ``max_to_rms``,
        * one ``top_{P}pct_energy`` entry per ``q`` where ``P`` is ``q`` as a
          whole-number percentage.

        Statistics that are undefined for the given sample size are ``nan``.
    """
    # SciPy is imported at call time rather than at module import.
    from scipy.stats import kurtosis, skew

    nan = float("nan")
    res = residuals.detach().cpu().reshape(-1).numpy()
    n = res.size

    # Guard the empty case explicitly so NumPy does not emit empty-slice warnings.
    mean = float(np.mean(res)) if n else nan
    if n > 1:
        std = float(np.std(res, ddof=1))
    elif n == 1:
        std = float(np.std(res))
    else:
        std = nan

    skewness_bias = float(skew(res, bias=False)) if n > 2 else nan
    kurtosis_bias = float(kurtosis(res, bias=False)) if n > 3 else nan
    skewness = float(skew(res)) if n > 2 else nan
    kurt = float(kurtosis(res)) if n > 3 else nan

    energy = res**2
    total_energy = float(np.sum(energy))

    rms = float(np.sqrt(np.mean(energy))) if n else nan
    max_abs = float(np.max(np.abs(res))) if n else nan
    # ``rms > 0`` is False for both zero and nan, so those cases yield nan.
    max_to_rms = float(max_abs / rms) if rms > 0 else nan

    def _label(q: float) -> str:
        # Round away float noise before truncating: 0.29 * 100 evaluates to
        # 28.999999999999996, which plain int() would label as 28.
        return f"top_{int(round(float(q) * 100.0, 6))}pct_energy"

    energy_conc: Dict[str, float] = {}
    if total_energy > 0 and n > 0:
        abs_res = np.abs(res)
        for q in qs:
            qf = float(q)
            # Clamp to [1, n]; argpartition rejects k > n (e.g. q > 1).
            k = min(max(int(np.ceil(qf * n)), 1), n)
            idx = np.argpartition(abs_res, -k)[-k:]
            energy_conc[_label(qf)] = float(np.sum(energy[idx]) / total_energy)
    else:
        for q in qs:
            energy_conc[_label(q)] = nan

    out = dict(
        mean=mean,
        std=std,
        skewness_bias=skewness_bias,
        kurtosis_bias=kurtosis_bias,
        skewness=skewness,
        kurtosis=kurt,
        rms=rms,
        max_abs=max_abs,
        max_to_rms=max_to_rms,
    )
    out.update(energy_conc)
    return out


class BaseMonitorMixin:
    """Mixin that gathers diagnostic values into ``self.log_dict``.

    The host class must provide ``leaf_losses(batch, *args)`` and ``config``;
    ``residuals(batch, *args)``, ``model`` (and ``parameters()`` when
    ``network_only=False``) are needed for the optional statistics and
    gradient diagnostics. ``weights`` and ``log_sigma`` are read when present.
    """

    def init_monitor(self):
        """Create an empty ``log_dict``."""
        self.log_dict: Dict[str, Any] = {}

    def _log_losses(self, batch: Any, *args):
        """Store each leaf loss as ``<term>_loss`` and, if ``log_sigma`` exists,
        ``exp(-log_sigma[k])`` as ``scale_<k>``."""
        losses = self.leaf_losses(batch, *args)
        for k, v in losses.items():
            self.log_dict[k + "_loss"] = v.detach()

        log_sigma = getattr(self, "log_sigma", None)
        if log_sigma is not None:
            for k, p in log_sigma.items():
                self.log_dict["scale_" + k] = torch.exp(-p.detach())

    def _log_weights(self):
        """Store each entry of ``self.weights`` as ``w_<k>`` (tensors detached,
        everything else converted to ``float``). No-op without ``weights``."""
        weights = getattr(self, "weights", None)
        if weights is None:
            return
        for k, v in weights.items():
            if torch.is_tensor(v):
                self.log_dict["w_" + k] = v.detach()
            else:
                self.log_dict["w_" + k] = float(v)

    def _leaf_grads(self, batch: Any, *args, network_only: bool = True) -> Dict[str, torch.Tensor]:
        """Return, per loss group, the flattened gradient w.r.t. the parameters.

        Losses are grouped with ``group_by_prefix``. Gradients are taken with
        respect to ``self.model.parameters()`` (``network_only=True``) or
        ``self.parameters()``; in both cases only parameters with
        ``requires_grad`` are used, because ``torch.autograd.grad`` rejects
        inputs that do not require grad. Parameters a loss does not depend on
        contribute zeros.
        """
        grouped = group_by_prefix(self.leaf_losses(batch, *args))
        source = self.model.parameters() if network_only else self.parameters()
        params = [p for p in source if p.requires_grad]
        out: Dict[str, torch.Tensor] = {}

        for k, Lk in grouped.items():
            if not params:
                # Nothing to differentiate against: report an empty gradient.
                out[k] = torch.empty(0, device=Lk.device)
                continue
            grads = torch.autograd.grad(Lk, params, retain_graph=True, create_graph=False, allow_unused=True)
            out[k] = torch.cat(
                [(torch.zeros_like(p) if gi is None else gi).reshape(-1) for p, gi in zip(params, grads)]
            )
        return out

    def _log_leaf_grad(self, batch: Any, *args, eps: float = 1e-12):
        """Log gradient norms, contributions, alignment and pairwise cosines.

        With ``g_k`` the flattened gradient of group ``k`` and ``w_k`` its
        weight (``1`` when the host has no ``weights``; groups missing from
        ``weights`` are skipped), the following entries are written:

        * ``grads/norm_total``: norm of ``sum_k w_k g_k``,
        * ``grads/align_score``: alignment score of the weighted gradients,
        * ``grads/norm_<k>``: norm of the unweighted ``g_k``,
        * ``grads/contrib_<k>``: ``(w_k g_k . g_total) / (|g_total|^2 + eps)``,
        * ``grads/cos_<a>_<b>``: cosine between unweighted ``g_a`` and ``g_b``.
        """
        g = self._leaf_grads(batch, *args, network_only=True)
        keys = list(g)
        if not keys:
            return

        weights = getattr(self, "weights", None)
        if weights is None:
            w = {k: torch.ones((), device=g[k].device) for k in keys}
        else:
            keys = [k for k in keys if k in weights]
            if not keys:
                return
            w = {
                k: weights[k] if torch.is_tensor(weights[k]) else torch.tensor(float(weights[k]))
                for k in keys
            }

        def alignment_score(vectors: torch.Tensor, eps_: float = 1e-8) -> torch.Tensor:
            # 2 / n^2 * |sum_i v_i / |v_i||^2 - 1 over the rows of ``vectors``.
            n = vectors.shape[0]
            norms = torch.linalg.norm(vectors, dim=1, keepdim=True)
            v = vectors / (norms + eps_)
            s = v.sum(dim=0)
            return (2.0 / (n * n)) * s.square().sum() - 1.0

        # Match both device and dtype of the gradient before scaling it.
        wflat = {k: w[k].to(device=g[k].device, dtype=g[k].dtype) * g[k] for k in keys}
        stacked = torch.stack([wflat[k] for k in keys], dim=0)
        g_total = stacked.sum(dim=0)

        denom = (g_total @ g_total) + eps
        self.log_dict["grads/norm_total"] = torch.linalg.norm(g_total).detach()
        self.log_dict["grads/align_score"] = alignment_score(stacked).detach()

        raw_norms = {k: torch.linalg.norm(g[k]) for k in keys}
        for k in keys:
            self.log_dict[f"grads/norm_{k}"] = raw_norms[k].detach()
            self.log_dict[f"grads/contrib_{k}"] = ((wflat[k] @ g_total) / denom).detach()

        for i, a in enumerate(keys):
            for b in keys[i + 1:]:
                cos = (g[a] @ g[b]) / (raw_norms[a] * raw_norms[b] + eps)
                self.log_dict[f"grads/cos_{a}_{b}"] = cos.detach()

    def _log_stats(self, batch: Any, *args):
        """Log residual statistics.

        ``residuals["res"]`` is required and summarised via
        ``residual_stats_logging`` under ``stats/pde_*``. The optional
        ``"ics"`` and ``"bcs"`` entries get variance, mean, max and min under
        ``stats/ic_*`` and ``stats/bc_*``.
        """
        residuals = self.residuals(batch, *args)
        pde = residuals["res"]
        for k, v in residual_stats_logging(pde, qs=(0.01, 0.05)).items():
            self.log_dict[f"stats/pde_{k}"] = v

        for res_key, tag in (("ics", "ic"), ("bcs", "bc")):
            values = residuals.get(res_key, None)
            if values is None:
                continue
            self.log_dict[f"stats/{tag}_var"] = values.var().detach()
            self.log_dict[f"stats/{tag}_mean"] = values.mean().detach()
            self.log_dict[f"stats/{tag}_max"] = values.max().detach()
            self.log_dict[f"stats/{tag}_min"] = values.min().detach()

    def log(self, batch: Any, *args) -> Dict[str, Any]:
        """Reset ``log_dict``, log the losses and any diagnostics enabled by
        ``config.logging`` (``log_weights``, ``log_grads``, ``log_stats``),
        and return the dict."""
        self.log_dict = {}
        self._log_losses(batch, *args)

        cfg = getattr(self.config, "logging", None)
        if cfg is not None:
            if getattr(cfg, "log_weights", False):
                self._log_weights()
            if getattr(cfg, "log_grads", False):
                self._log_leaf_grad(batch, *args)
            if getattr(cfg, "log_stats", False):
                self._log_stats(batch, *args)

        return self.log_dict


class PINN(nn.Module):
    """Base module pairing a network with per-term residual losses.

    (The name presumably stands for physics-informed neural network.)
    Subclasses implement ``residuals``; every key in ``loss_keys`` gets a loss
    function chosen by the configured objective mode, and terms sharing the
    prefix before the first ``_`` share one scalar weight in ``weights``.
    """

    loss_keys: Tuple[str, ...] = ("ics", "bcs", "res")

    def __init__(self, config: Any, model: nn.Module, device: Optional[torch.device] = None):
        """Store the config, move ``model`` to ``device`` (CPU when omitted),
        read the objective settings and build the per-term loss functions."""
        super().__init__()
        self.config = config
        if device is None:
            device = torch.device("cpu")
        self.device = device

        self.model = model.to(self.device)
        self.term_mode, student_params = build_objectives(self.config, self.loss_keys)
        # Fixed dicts for now: {"nu": {...}, "lam": {...}}.
        self.st_params = student_params

        # One unit weight per loss-key prefix, in sorted order.
        prefixes = sorted({name.split("_", 1)[0] for name in self.loss_keys})
        self.weights: Dict[str, torch.Tensor] = {
            prefix: torch.tensor(1.0, device=self.device) for prefix in prefixes
        }

        self._term_loss_fn: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {}
        self._init_term_losses()

    def _init_term_losses(self) -> None:
        """Rebuild ``_term_loss_fn`` with one callable per entry of ``loss_keys``.

        Supported modes are ``"mse"``, ``"snll"``, ``"wls"`` and ``"em"``
        (case-insensitive; ``"mse"`` when a term has no mode). ``nu``/``lam``
        are looked up in ``self.st_params`` each time a loss is evaluated.

        Raises:
            ValueError: for an unknown mode, or for ``"em"`` when the instance
                has no ``st_handler``.
        """
        registry: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {}
        self._term_loss_fn = registry

        def make_mse():
            return lambda e: mse(e)

        def make_snll(term):
            def _snll(e: torch.Tensor) -> torch.Tensor:
                table = self.st_params
                return snll_term_loss(e, nu=table["nu"][term], lam=table["lam"][term])

            return _snll

        def make_wls(term):
            def _wls(e: torch.Tensor) -> torch.Tensor:
                table = self.st_params
                return wls_term_loss(e, nu=table["nu"][term], lam=table["lam"][term])

            return _wls

        def make_em(term):
            def _em(e: torch.Tensor) -> torch.Tensor:
                return self.st_handler.term_loss(term, e, self)

            return _em

        for term in self.loss_keys:
            mode = str(self.term_mode.get(term, "mse")).lower()

            if mode == "mse":
                registry[term] = make_mse()
                continue

            if mode == "snll":
                registry[term] = make_snll(term)
                print(f"Using dynamic SNLL for term {term} with nu={self.st_params['nu'][term]} lam={self.st_params['lam'][term]}")
                continue

            if mode == "wls":
                registry[term] = make_wls(term)
                print(f"Using dynamic WLS for term {term} with nu={self.st_params['nu'][term]} lam={self.st_params['lam'][term]}")
                continue

            if mode == "em":
                if getattr(self, "st_handler", None) is None:
                    raise ValueError("objective mode 'em' requested but no EM handler was built")
                registry[term] = make_em(term)
                continue

            raise ValueError(f"Unknown objective mode '{mode}' for term '{term}'")

    def residuals(self, batch: Any, *args) -> Dict[str, torch.Tensor]:
        """Return the residual tensor for every key in ``loss_keys`` (abstract)."""
        raise NotImplementedError

    def leaf_losses(self, batch: Any, *args) -> Dict[str, torch.Tensor]:
        """Apply each term's loss function to its residual; keyed by ``loss_keys``."""
        residual_map = self.residuals(batch, *args)
        per_term: Dict[str, torch.Tensor] = {}
        for term in self.loss_keys:
            per_term[term] = self._term_loss_fn[term](residual_map[term])
        return per_term

    def loss(self, batch: Any, *args) -> torch.Tensor:
        """Return the weighted sum of the prefix-grouped leaf losses.

        A zero scalar on ``self.device`` is returned when there are no terms.
        """
        grouped = group_by_prefix(self.leaf_losses(batch, *args))

        weighted = []
        for prefix, value in grouped.items():
            scale = self.weights[prefix].to(value.dtype)
            weighted.append(scale * value)

        if not weighted:
            return torch.zeros((), device=self.device)

        total = weighted[0]
        for contribution in weighted[1:]:
            total = total + contribution
        return total

    def step(self, batch: Any, optimizer: torch.optim.Optimizer, *args) -> torch.Tensor:
        """Run one optimisation step and return the detached loss value."""
        optimizer.zero_grad(set_to_none=True)
        loss_value = self.loss(batch, *args)
        loss_value.backward()
        optimizer.step()
        return loss_value.detach()


class IVP(PINN, BaseMonitorMixin):
    """PINN with monitoring plus grid-based error metrics against a reference.

    ``t_star``, ``x_star`` and ``u_ref`` start as ``None`` and must be set by
    a subclass before the prediction/error helpers (and therefore ``log``)
    are used.
    """

    def __init__(self, config: Any, model: nn.Module, device: Optional[torch.device] = None):
        """Initialise the PINN base and the monitor, and clear the grid/reference."""
        PINN.__init__(self, config, model, device=device)
        BaseMonitorMixin.init_monitor(self)

        self.t_star = None
        self.x_star = None
        self.u_ref = None

    def u_net(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the solution at ``(t, x)``; must be provided by a subclass."""
        raise NotImplementedError

    @torch.no_grad()
    def u_pred_grid(self) -> torch.Tensor:
        """Evaluate ``u_net`` on the ``ij``-indexed mesh of ``t_star`` and ``x_star``.

        Runs without gradient tracking.

        Raises:
            AttributeError: if ``t_star`` or ``x_star`` has not been set.
        """
        if self.t_star is None or self.x_star is None:
            raise AttributeError("IVP requires t_star and x_star to be set by subclass.")
        tt, xx = torch.meshgrid(self.t_star, self.x_star, indexing="ij")
        return self.u_net(tt, xx)

    def _require_u_ref(self) -> torch.Tensor:
        """Return ``u_ref``, raising a descriptive error when it is unset."""
        if self.u_ref is None:
            raise AttributeError("IVP requires u_ref to be set by subclass.")
        return self.u_ref

    def compute_l2_error(self) -> torch.Tensor:
        """Return ``|u_pred - u_ref| / |u_ref|`` using ``torch.linalg.norm``.

        Raises:
            AttributeError: if ``u_ref``, ``t_star`` or ``x_star`` is unset.
        """
        u_ref = self._require_u_ref()
        u_pred = self.u_pred_grid()
        return torch.linalg.norm(u_pred - u_ref) / torch.linalg.norm(u_ref)

    def compute_rmae(self) -> torch.Tensor:
        """Return ``sum|u_pred - u_ref| / sum|u_ref|``.

        Raises:
            AttributeError: if ``u_ref``, ``t_star`` or ``x_star`` is unset.
        """
        u_ref = self._require_u_ref()
        u_pred = self.u_pred_grid()
        return torch.sum(torch.abs(u_pred - u_ref)) / torch.sum(torch.abs(u_ref))

    def log_errors(self):
        """Store both error metrics in ``log_dict``.

        Note that the ``"rmse_error"`` key holds the relative L2 error from
        ``compute_l2_error``; the key name is kept for existing consumers.
        """
        self.log_dict["rmse_error"] = self.compute_l2_error().detach()
        self.log_dict["rmae_error"] = self.compute_rmae().detach()

    def log(self, batch: Any, *args) -> Dict[str, Any]:
        """Run the mixin's ``log`` and add the error metrics to the same dict."""
        out = BaseMonitorMixin.log(self, batch, *args)
        # ``out`` is ``self.log_dict``, so the entries added here are part of it.
        self.log_errors()
        return out


class BVP(PINN):
    """``PINN`` subclass that currently adds no behaviour of its own.

    (The name presumably refers to boundary-value problems.)
    """
