from typing import Dict, Tuple
import torch
import torch.nn.functional as F
import torch.linalg as LA
from torch_geometric.data import Batch
import networkarch as NA

def _active_submatrix(M: torch.Tensor, eps: float = 1e-12):
    row_energy = M.abs().sum(dim=-1)
    idx = (row_energy > eps).nonzero(as_tuple=False).squeeze(1)
    if idx.numel() >= 1 and idx.numel() < M.size(0):
        Mb = M.index_select(0, idx).index_select(1, idx)
        return Mb
    return M

def _active_indices_pair(L: torch.Tensor, C: torch.Tensor, eps: float = 1e-12):
    rowL = L.abs().sum(dim=-1) > eps
    rowC = C.abs().sum(dim=-1) > eps
    mask = rowL | rowC
    idx = mask.nonzero(as_tuple=False).squeeze(1)
    if idx.numel() == 0:
        n = min(L.size(0), C.size(0))
        if n >= 1:
            idx = torch.arange(1, device=L.device)
    return idx


def _eigvals_eigvecs_hermitian(M: torch.Tensor):
    try:
        w, U = torch.linalg.eigh(0.5*(M + M.transpose(-1, -2)))
        return w.to(M.dtype), U
    except Exception:
        e, U = torch.linalg.eig(M)
        return e, U

def _get_koopman_eigs(K, device):
    try:
        lam = torch.linalg.eigvals(K)  # [d]
    except Exception:
        lam = torch.linalg.eigvals(K.to(torch.complex64))
    return lam

def _safe(t: torch.Tensor, clamp: float = 1e6):
    if t is None:
        return None
    t = torch.nan_to_num(t, nan=0.0, posinf=clamp, neginf=-clamp)
    return torch.clamp(t, min=-clamp, max=clamp)

def _spec_over_limit_penalty(sig_hist, sig_seq, limit: float) -> torch.Tensor:
    def _to_list(x):
        if x is None: return []
        if isinstance(x, (list, tuple)): return [t for t in x if isinstance(t, torch.Tensor)]
        if isinstance(x, torch.Tensor): return [x]
        return []

    buckets = _to_list(sig_hist) + _to_list(sig_seq)
    if len(buckets) == 0:
        return torch.tensor(0.0)

    vals = []
    for t in buckets:
        t = _safe(t)
        if t.dim() == 0: t = t.view(1)

        over = torch.relu(t - float(limit))
        vals.append((over ** 2).mean())
    return torch.stack(vals).mean()

def _rollout_loss(pred_seq: torch.Tensor, tgt_seq: torch.Tensor, horizon: int, gamma: float = 0.97):
    T = min(pred_seq.size(1), tgt_seq.size(1), horizon)
    w = pred_seq.new_tensor([gamma**k for k in range(T)])
    w = w / w.sum()
    err = (pred_seq[:, :T] - tgt_seq[:, :T]).pow(2).sum(dim=-1).mean(dim=0)  # [T]
    return (err * w).sum()


def compute_graph_matrices_batched(batch, gk_aux, device):
    """Return a pair of batched matrices ``(L, C)`` for the graphs in ``batch``.

    Sources are tried in this order, the first that succeeds wins:

    1. ``gk_aux['L']`` / ``gk_aux['C']`` -- sanitised with ``_safe``, moved to
       ``device``; a 2-D matrix is expanded to ``B`` copies, anything else is
       returned as given.
    2. ``gk_aux['L_list']`` / ``gk_aux['C_list']`` -- per-graph matrices that
       are stacked along a new leading dim (``torch.stack`` requires them to
       share one shape).
    3. Edge weights ``gk_aux['w_diff']`` / ``gk_aux['w_adv']`` together with
       ``batch.edge_index_diff`` / ``batch.edge_index_adv`` -- one matrix pair
       is assembled per graph and zero-padded to the largest node count.
    4. Fallback -- scaled 64x64 identities (``0.1 * I`` and ``0.05 * I``).

    Args:
        batch: a ``torch_geometric`` ``Batch`` (its ``.batch`` vector assigns
            nodes to graphs) or any object with an ``x`` attribute, which is
            treated as a single graph.
        gk_aux: dict of auxiliary outputs; a falsy value is treated as ``{}``.
        device: device on which the returned tensors are created/moved.

    Returns:
        ``(L, C)`` tensors; batched ``[B, n, n]`` in paths 2-4.
    """
    gk_aux = gk_aux or {}

    def _to_batched(mat, B):
        # Sanitise + move to device; broadcast a single 2-D matrix over the batch.
        if mat is None:
            return None
        mat = _safe(mat).to(device)
        return mat.unsqueeze(0).expand(B, *mat.shape) if mat.dim() == 2 else mat

    # Determine the number of graphs B and the global node indices of each graph.
    if isinstance(batch, Batch):
        node_batch = batch.batch
        B = int(node_batch.max().item()) + 1 if node_batch.numel() > 0 else 1
        node_idx_lists = [(node_batch == b).nonzero(as_tuple=False).squeeze(1) for b in range(B)]
    else:
        B = 1
        node_idx_lists = [torch.arange(batch.x.size(0), device=device)]

    # Path 1: matrices supplied directly.
    if isinstance(gk_aux, dict) and ('L' in gk_aux and 'C' in gk_aux):
        L = _to_batched(gk_aux['L'], B)
        C = _to_batched(gk_aux['C'], B)
        if L is not None and C is not None:
            return L, C

    # Path 2: one matrix per graph, stacked into a batch.
    if isinstance(gk_aux, dict) and ('L_list' in gk_aux and 'C_list' in gk_aux):
        L_list = [ _safe(M).to(device) for M in gk_aux['L_list'] ]
        C_list = [ _safe(M).to(device) for M in gk_aux['C_list'] ]
        if len(L_list) > 0 and len(L_list) == len(C_list):
            return torch.stack(L_list, 0), torch.stack(C_list, 0)

    def _build_L(ei, w, nodes):
        # Symmetric matrix for one graph: every kept edge (u, v, w) adds w to
        # both diagonal entries and subtracts w from both off-diagonal entries.
        n = nodes.numel()
        Lb = torch.zeros((n, n), device=device)
        if ei is None or ei.numel() == 0 or w is None or w.numel() == 0:
            # NOTE: this early exit returns plain zeros, without the 1e-6 shift below.
            return Lb
        # Map global node id -> local row/column index within this graph.
        idmap = {int(nodes[i]): int(i) for i in range(n)}
        src, dst = ei[0], ei[1]
        # Pre-filter by source membership only; the destination is checked in the loop.
        mask = torch.isin(src, nodes)
        e_src, e_dst = src[mask].tolist(), dst[mask].tolist()
        w_b = w[mask].tolist()
        # Python-level loop over edges (values were copied to host lists above).
        for u, v, ww in zip(e_src, e_dst, w_b):
            if u in idmap and v in idmap:
                iu, iv = idmap[u], idmap[v]
                ww = float(ww)
                Lb[iu, iu] += ww
                Lb[iv, iv] += ww
                Lb[iu, iv] -= ww
                Lb[iv, iu] -= ww
        Lb = 0.5 * (Lb + Lb.T)
        # Small diagonal shift added after symmetrising.
        Lb = Lb + 1e-6 * torch.eye(n, device=device)
        return Lb

    def _build_C(ei, s, nodes):
        # Antisymmetric matrix for one graph: every kept edge (u, v, s) adds
        # +s at (u, v) and -s at (v, u).
        n = nodes.numel()
        Cb = torch.zeros((n, n), device=device)
        if ei is None or ei.numel() == 0 or s is None or s.numel() == 0:
            return Cb
        idmap = {int(nodes[i]): int(i) for i in range(n)}
        src, dst = ei[0], ei[1]
        mask = torch.isin(src, nodes)
        e_src, e_dst = src[mask].tolist(), dst[mask].tolist()
        s_b = s[mask].tolist()
        for u, v, ss in zip(e_src, e_dst, s_b):
            if u in idmap and v in idmap:
                iu, iv = idmap[u], idmap[v]
                ss = float(ss)
                Cb[iu, iv] += ss
                Cb[iv, iu] -= ss
        # Project onto the antisymmetric part.
        Cb = 0.5 * (Cb - Cb.T)
        return Cb

    # Path 3: assemble from edge weights when at least one weight key is present.
    use_edges = ('w_diff' in gk_aux) or ('w_adv' in gk_aux)
    if use_edges:
        ei_diff = getattr(batch, 'edge_index_diff', None)
        ei_adv  = getattr(batch, 'edge_index_adv', None)
        w_diff = gk_aux.get('w_diff', None)
        w_adv  = gk_aux.get('w_adv', None)
        max_n = max(idx.numel() for idx in node_idx_lists) if len(node_idx_lists) else 0
        if max_n > 0:
            Ls, Cs = [], []
            for b in range(B):
                nodes_b = node_idx_lists[b]
                Lb = _build_L(ei_diff, w_diff, nodes_b)
                Cb = _build_C(ei_adv,  w_adv,  nodes_b)
                # Zero-pad each graph to max_n x max_n so the batch can be stacked.
                padL = torch.zeros((max_n, max_n), device=device); padL[:Lb.size(0), :Lb.size(1)] = Lb
                padC = torch.zeros((max_n, max_n), device=device); padC[:Cb.size(0), :Cb.size(1)] = Cb
                Ls.append(padL); Cs.append(padC)
            return torch.stack(Ls, 0), torch.stack(Cs, 0)


    # Path 4: fallback placeholders; the size 64 is hard-coded here.
    d_model = 64
    eye = torch.eye(d_model, device=device)
    L = eye.unsqueeze(0).expand(B, d_model, d_model) * 0.1
    C = eye.unsqueeze(0).expand(B, d_model, d_model) * 0.05
    return L, C


def compute_graph_matrices_single(data, gk_aux, device):
    """Return placeholder matrices ``(0.1 * I, 0.05 * I)`` of size 64x64.

    ``data`` and ``gk_aux`` are accepted but not used; both outputs are
    created on ``device``.
    """
    size = 64
    identity = torch.eye(size, device=device)
    return identity * 0.1, identity * 0.05


def mse_loss_safe_batched(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean squared error with NaN/Inf sanitising and simple rank alignment.

    * ``target=None`` is treated as an all-zero target.
    * When shapes differ and exactly one operand is 2-D while the other is
      3-D, the 2-D one gains a leading dim and is expanded to match.
    * Both operands are sanitised with ``_safe`` before the error is taken.

    Returns:
        A per-sample loss (mean over dims 1 and 2) when ``pred`` is 3-D,
        otherwise the scalar ``F.mse_loss``.
    """
    if target is None:
        target = torch.zeros_like(pred)

    if pred.shape != target.shape:
        ranks = (pred.dim(), target.dim())
        if ranks == (3, 2):
            target = target.unsqueeze(0).expand_as(pred)
        elif ranks == (2, 3):
            pred = pred.unsqueeze(0).expand_as(target)

    pred, target = _safe(pred), _safe(target)

    if pred.dim() != 3:
        return F.mse_loss(pred, target)
    return (pred - target).pow(2).mean(dim=[1, 2])  # [B]


class LossComputer:
    def __init__(self, cfg: Dict):
        """Read loss weights and limits from ``cfg``, falling back to defaults.

        Args:
            cfg: configuration mapping; it is also kept as ``self.cfg``
                because some methods read further keys from it directly.
        """
        self.cfg = cfg
        opt = cfg.get

        # One more than the configured number of history steps.
        self.hist_steps = opt('history_steps', 75) + 1

        # Weights of the main loss terms.
        self.w_micro = opt("w_micro", 1.0)
        self.w_macro = opt("w_macro", 1.0)
        self.w_jad = opt("w_jad", 0.1)
        self.w_spec = opt("w_spec", 0.1)

        # Reconstruction weights: the micro/macro variants default to w_recon.
        self.w_recon = opt('w_recon', 0.1)
        self.w_recon_micro = opt('w_recon_micro', self.w_recon)
        self.w_recon_macro = opt('w_recon_macro', self.w_recon)

        # L1 coefficients and ISS penalty settings.
        self.lam_diff_edge_l1 = opt("lam_diff_edge_l1", 1e-4)
        self.lam_adv_edge_l1 = opt("lam_adv_edge_l1", 1e-4)
        self.lam_alpha_l1 = opt("lam_alpha_l1", 1e-5)
        self.lam_beta_l1 = opt("lam_beta_l1", 1e-5)
        self.lam_iss = opt("lam_iss", 1.0)
        self.delta_iss = opt("delta_iss", 0.05)

        # Limits and weights of the over-limit penalties (coerced to float).
        self.k_max = float(opt("k_max", 0.90))
        self.b_max = float(opt("b_max", 0.80))
        self.w_specK = float(opt("w_specK", 0.01))
        self.w_specB = float(opt("w_specB", 0.01))

        # Weights of the intent regularisers (coerced to float).
        self.w_intent_entropy = float(opt("w_intent_entropy", 1e-3))
        self.w_intent_div = float(opt("w_intent_div", 5e-3))
        self.w_intent_smooth = float(opt("w_intent_smooth", 1e-2))
        self.w_intent_switch = float(opt("w_intent_switch", 1e-2))

    def _intent_regularizers(self, aux, device):
        it = aux.get("intent", {})
        p_hist = it.get("hist", None)
        p_fut = it.get("future", None)

        zero = torch.tensor(0.0, device=device)

        def _entropy(p):
            eps = 1e-8
            return -(p.clamp_min(eps) * (p.clamp_min(eps)).log()).sum(dim=-1)

        H_hist = _entropy(p_hist).mean() if isinstance(p_hist, torch.Tensor) else zero
        H_fut = _entropy(p_fut).mean() if isinstance(p_fut, torch.Tensor) else zero
        L_ent = -(H_hist + H_fut) * 0.5

        def _div_kl(p):
            if not isinstance(p, torch.Tensor): return zero
            q = p.mean(dim=(0, 1))  # [M]
            M = q.numel()
            u = torch.full_like(q, 1.0 / M)
            eps = 1e-8
            return (q.clamp_min(eps) * (q.clamp_min(eps) / u).log()).sum()

        L_div = _div_kl(p_hist) + _div_kl(p_fut)

        if isinstance(p_fut, torch.Tensor) and p_fut.size(1) > 1:
            diff = p_fut[:, 1:, :] - p_fut[:, :-1, :]
            L_smooth = (diff.pow(2).mean())
        else:
            L_smooth = zero

        if isinstance(p_fut, torch.Tensor) and p_fut.size(1) > 1:
            m = p_fut.argmax(dim=-1)
            sw = (m[:, 1:] != m[:, :-1]).float().mean()
            L_switch = sw
        else:
            L_switch = zero

        return L_ent, L_div, L_smooth, L_switch

    def _intent_target_entropy(self, aux, device, H_target=0.8):
        it = aux.get("intent")
        p_hist = it.get("hist")
        p_fut = it.get("future")

        def H(p):
            eps = 1e-8;
            return -(p.clamp_min(eps) * p.clamp_min(eps).log()).sum(dim=-1).mean()

        Hh = H(p_hist)
        Hf = H(p_fut)
        L = 0.5 * ((F.relu(Hh - H_target)) ** 2 + (F.relu(Hf - H_target)) ** 2)
        return L

    def _compute_intent_loss(self,aux,batch,device):
        it=aux.get("intent")
        top1=it.get('top1').unsqueeze(-1)
        B,T_h,_=top1.shape
        intent_label=batch.hist_intent.view(B,-1,1)[:,:-1,:]
        loss=torch.norm(intent_label-top1,dim=-1).mean(dim=1)
        return loss


    def _compute_trajectory_loss_batched(self, y_vk_135, batch, device):
        B, T, _ = y_vk_135.shape
        future_traj=batch.future_traj.view(B,-1,2)
        #idx=torch.tensor([24, 74, 124], device=device, dtype=torch.long)
        fps = int(round(1.0 / float(self.cfg.get('dt', 0.04))))
        #idx = torch.tensor([1 * fps/stride- 1, 3 * fps /stride- 1, 5 * fps /stride- 1],
        #                   device=device, dtype=torch.long)
        # idx = torch.tensor([0.8 * fps  - 1, 1.6 * fps - 1, 2.4 * fps - 1, 3.2 * fps  - 1, 4 * fps  - 1, 4.8 * fps- 1],
        #                    device=device, dtype=torch.long)
        #print('idx:',idx)
        idx=torch.tensor([23,71,119])
        future_traj=future_traj.index_select(1,idx)
        anchor_xy = batch.vk_feat.view(B,-1,10)[:,-1,:2]
        pred_pos=anchor_xy.unsqueeze(1)+y_vk_135
        se=(pred_pos-future_traj).pow(2).sum(dim=-1).mean(dim=0).sqrt()
        #print(se)
        ade=torch.norm(pred_pos-future_traj,dim=2).mean(dim=1)
        return ade,se

    def _compute_macro_loss_batched(self, y_gk_seq, batch, device):
        B,T,_ = y_gk_seq.shape
        density=batch.density_seq.view(B,-1,1)
        anchor_density=density[:,0]
        future_density=density[:,1:]
        pred_density=anchor_density.unsqueeze(1)+y_gk_seq.cumsum(dim=1)
        per=(pred_density-future_density).abs().div(future_density).mean(dim=0).squeeze(-1)
        ade=torch.norm(pred_density-future_density,dim=2).mean(dim=1)
        return ade,per

    def _compute_recon_losses(self, aux, device):
        rec = aux.get('recon', {})
        latent_1s = rec.get('latent_1s', {})

        z_pred = latent_1s.get('z_pred', None)
        z_tgt = latent_1s.get('z_tgt', None)
        Z_pred = latent_1s.get('Z_pred', None)
        Z_tgt = latent_1s.get('Z_tgt', None)

        def _mse_time(pred, tgt):
            if pred is None or tgt is None:
                return torch.tensor(0.0, device=device)
            if pred.numel() == 0 or tgt.numel() == 0:
                return torch.tensor(0.0, device=device)

            pred = torch.nan_to_num(pred, nan=0.0, posinf=0.0, neginf=0.0)
            tgt = torch.nan_to_num(tgt, nan=0.0, posinf=0.0, neginf=0.0)
            if pred.dim() == 2:  # [N_hist, D] -> [1, N_hist, D]
                pred = pred.unsqueeze(0)
                tgt = tgt.unsqueeze(0)
            return ((pred - tgt) ** 2).mean(dim=(1, 2))  # [B]

        L_micro_recon = _mse_time(z_pred, z_tgt)
        L_macro_recon = _mse_time(Z_pred, Z_tgt)
        return L_micro_recon, L_macro_recon

    def _compute_decode_losses(self,aux,batch,device):
        rec=aux.get('recon',{})
        lag1s_result=rec.get('lag1s',{})
        pred=lag1s_result.get('y_delta_seq',{})
        B,T,_=pred.shape
        src_idx = lag1s_result['src_idx'].to(pred.device)  # [N_hist]
        tgt_idx = lag1s_result['tgt_idx'].to(pred.device)  # [N_hist]

        pos = batch.vk_feat.view(B, -1, 10)[..., :2].to(pred.device)  # [B, T_hist, 2]
        pos_src = pos.index_select(1, src_idx)  # [B, N_hist, 2]
        pos_tgt = pos.index_select(1, tgt_idx)  # [B, N_hist, 2]
        tgt_delta = pos_tgt - pos_src  # [B, N_hist, 2]

        loss = F.smooth_l1_loss(pred, tgt_delta, beta=1.0)
        return loss

    def _sparsity_regs(self, gk_aux: dict, model) -> torch.Tensor:
        try:
            device = next(model.parameters()).device
        except Exception:
            device = torch.device("cpu")

        lam_diff_edge_l1 = float(self.cfg.get("lam_diff_edge_l1", 1e-4))
        lam_adv_edge_l1 = float(self.cfg.get("lam_adv_edge_l1", 1e-4))

        reg = torch.tensor(0.0, device=device)

        aux = gk_aux if isinstance(gk_aux, dict) else {}
        w_diff = aux.get("w_diff", None)  # [E_diff]
        w_adv = aux.get("w_adv", None)  # [E_adv]
        reg += lam_diff_edge_l1 * w_diff.abs().mean()
        reg += lam_adv_edge_l1 * w_adv.abs().mean()

        return reg

    def __call__(self, outputs, batch, aux, model):
        """Combine all loss terms into one scalar and collect logging values.

        Changes from the earlier version: the unused eigen-decomposition of
        the full ``K_hist`` is gone, the spectral term falls back to zero
        when no graph contributes (instead of failing on an empty stack),
        and the logging of stability statistics tolerates entries that are
        ``None`` or lists of tensors.

        Args:
            outputs: pair ``(y_gk_135, y_vk_135)`` of model predictions.
            batch: data batch providing the targets read by the helpers.
            aux: auxiliary model outputs; ``aux['stability']['K_hist']`` is
                required, other entries are read by the helper methods.
            model: must expose ``phys_params`` with entries ``'c'`` and
                ``'nu'``; also used to pick the device of the sparsity term.

        Returns:
            ``(total, logs)``: the scalar loss and a dict of Python floats.
        """
        y_gk_135, y_vk_135 = outputs
        device = y_vk_135.device

        # Per-sample prediction / reconstruction losses.
        L_micro, se = self._compute_trajectory_loss_batched(y_vk_135, batch, device)  # [B]
        rmse1, rmse3, rmse5 = se
        L_macro, per = self._compute_macro_loss_batched(y_gk_135, batch, device)  # [B]
        per1, per3, per5 = per
        L_micro_rec, L_macro_rec = self._compute_recon_losses(aux, device)
        L_decode = self._compute_decode_losses(aux, batch, device)
        L_intent = self._compute_intent_loss(aux, batch, device)

        # Commutator penalty: mean squared entry of [L, C] = LC - CL.
        Ls, Cs = compute_graph_matrices_batched(batch, aux.get('gk_aux', {}), device)
        if Ls.dim() == 2:
            Ls = Ls.unsqueeze(0)
        if Cs.dim() == 2:
            Cs = Cs.unsqueeze(0)
        comm = torch.matmul(Ls, Cs) - torch.matmul(Cs, Ls)
        L_jad = _safe(comm).pow(2).mean(dim=(1, 2)).mean()

        # Spectral matching between graph operators and the averaged operator.
        K_hist = aux['stability']["K_hist"]  # presumably [B, T, D, D] -- confirm
        dt = 1.0
        K_eff = K_hist.mean(dim=1)

        per_graph_losses = []
        for b in range(Ls.size(0)):
            idx = _active_indices_pair(Ls[b], Cs[b])
            Lb = Ls[b].index_select(0, idx).index_select(1, idx)
            Cb = Cs[b].index_select(0, idx).index_select(1, idx)
            if Lb.numel() == 0 or Cb.numel() == 0:
                continue

            # Express C in the eigenbasis of L.
            wL, U = _eigvals_eigvecs_hermitian(Lb)
            U_H = torch.conj(U.transpose(-1, -2))
            Ct = U_H @ (Cb.to(U.dtype)) @ U

            eta = wL.real
            if torch.is_complex(Ct):
                diagCt = torch.diagonal(Ct, dim1=-2, dim2=-1)
                xi = (-diagCt.imag).real
            else:
                xi = torch.linalg.vector_norm(Ct, ord=2, dim=-1)

            # Log-magnitude and phase of the eigenvalues of K_eff[b].
            lam_b = _get_koopman_eigs(K_eff[b], device)
            mag_b = lam_b.abs().clamp_min(1e-8)
            loglam_re_b = torch.log(mag_b)
            loglam_im_b = torch.atan2(lam_b.imag, lam_b.real)

            k_modes = min(eta.numel(), loglam_re_b.numel())
            if k_modes == 0:
                continue

            # Pair the k smallest eta with the k largest-magnitude eigenvalues.
            idx_eta = torch.argsort(eta)[:k_modes]
            eta_k = eta.index_select(0, idx_eta)
            xi_k = xi.index_select(0, idx_eta)

            idx_lam = torch.argsort(mag_b, descending=True)[:k_modes]
            logre_k = loglam_re_b.index_select(0, idx_lam)
            logim_k = loglam_im_b.index_select(0, idx_lam)

            c_val = model.phys_params.get('c', None)
            nu_val = model.phys_params.get('nu', None)

            res_re = (logre_k + (nu_val * eta_k * dt))
            res_im = (logim_k + (c_val * xi_k * dt))
            per_graph_losses.append((res_re.pow(2).mean() + res_im.pow(2).mean()))

        if per_graph_losses:
            L_spec = torch.stack(per_graph_losses).mean()
        else:
            # Every graph was skipped: contribute nothing rather than crash.
            L_spec = torch.zeros((), device=device)

        # Hinge penalty on the largest |eigenvalue| of C above 1 - delta_iss.
        eigC = torch.linalg.eigvals(Cs)
        rhoC = _safe(eigC.abs())
        max_rho = rhoC.max(dim=-1).values
        margin = 1.0 - float(self.delta_iss)
        iss_violation = F.relu(max_rho - margin)
        pen_iss = (self.lam_iss * iss_violation.mean()).to(device)

        reg = self._sparsity_regs(aux.get('gk_aux', {}), model)

        stability_aux = aux.get('stability', {})
        rho_hist = stability_aux.get('rho_hist', None)
        rho_fut = stability_aux.get('rho_fut', None)
        kdiag_hist = stability_aux.get('kdiag_hist', None)
        kdiag_fut = stability_aux.get('kdiag_fut', None)

        pen_specK = _spec_over_limit_penalty(rho_hist, rho_fut, self.k_max).to(device)
        pen_specB = _spec_over_limit_penalty(kdiag_hist, kdiag_fut, self.b_max).to(device)
        L_specKB = self.w_specK * pen_specK + self.w_specB * pen_specB

        L_Ht = self._intent_target_entropy(aux, device, H_target=0.8)
        # Only L_ent and L_div are used below, and only for logging.
        L_ent, L_div, L_smooth, L_switch = self._intent_regularizers(aux, device)

        # The intent regularisers are deliberately not part of the total.
        total = (self.w_micro * L_micro.mean() + self.w_recon_micro * L_micro_rec.mean() +
                 self.w_macro * L_macro.mean() + self.w_recon_macro * L_macro_rec.mean() + L_decode +
                 self.w_jad * L_jad + L_intent.mean() +
                 self.w_spec * L_spec +
                 reg + pen_iss + L_specKB + L_Ht * (5e-3)
                 )

        def _log_mean(value):
            # Mean of a tensor, or of the per-tensor means of a list/tuple;
            # NaN marks a statistic that is not available.
            if isinstance(value, torch.Tensor):
                return value.mean().detach().item()
            if isinstance(value, (list, tuple)):
                means = [t.detach().float().mean() for t in value if isinstance(t, torch.Tensor)]
                if means:
                    return torch.stack(means).mean().item()
            return float("nan")

        logs = {
            "loss/total": total.detach().item(),
            "pred/ADE_rel": L_micro.mean().detach().item(),
            "pred/macro": L_macro.mean().detach().item(),
            'pred/RMSE1': rmse1.detach().item(),
            'pred/RMSE3': rmse3.detach().item(),
            'pred/RMSE5': rmse5.detach().item(),
            'pred/PER1': per1.detach().item(),
            'pred/PER3': per3.detach().item(),
            'pred/PER5': per5.detach().item(),
            "recon/micro": L_micro_rec.mean().detach().item(),
            "recon/macro": L_macro_rec.mean().detach().item(),
            "recon/decode": L_decode.mean().detach().item(),
            "recon/intent": L_intent.mean().detach().item(),

            "specK/pen": (self.w_specK * pen_specK).detach().item(),
            "specB/pen": (self.w_specB * pen_specB).detach().item(),

            "specK/mean_hist": _log_mean(rho_hist),
            "specB/mean_hist": _log_mean(kdiag_hist),
            "specK/mean_fut": _log_mean(rho_fut),
            "specB/mean_fut": _log_mean(kdiag_fut),

            "intent/neg_entropy": L_ent.detach().item(),
            "intent/div_kl": L_div.detach().item(),
            "intent/Ht": L_Ht.mean().detach().item(),
        }

        return total, logs