from __future__ import annotations
import contextlib
import torch
from torch import Tensor
import torch.nn as nn
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
from .mgda import MGDASolver
from .utils import rank_normalize, zscore_normalize, EMATracker, RunningStats

class ControlGSensor:
    """Per-task sensor producing loss statistics and a smoothed difficulty score.

    For ``K`` tasks, :meth:`sense` runs one forward/backward probe per task,
    builds the Gram matrix of encoder gradients, hands it to ``MGDASolver``
    and combines two normalised per-task signals into a difficulty value::

        raw_D = alpha * normalize(RQ) + beta * normalize(Conf)
        D     = difficulty_ema.update_all(raw_D)

    ``Conf`` is the second value returned by ``MGDASolver.solve_from_gram``.
    Losses are tracked separately through ``RunningStats`` (per-task
    normalisation) and an ``EMATracker`` (smoothing).
    """

    def __init__(self, K: int, device: str = "cpu", alpha: float = 0.5, beta: float = 0.5,
                 rho: float = 0.2, D_min: float = 0.0, D_max: float = 1.0, norm_type: str = "rank",
                 mgda_iters: int = 100, mgda_lr: float = 0.1, mgda_eps: float = 1e-8,
                 probe_num_nodes: int = 512, probe_num_edges: int = 2048, loss_ema_rho: Optional[float] = None):
        """Create the sensor and its trackers.

        Args:
            K: Number of tasks; sizes the solver and all trackers.
            device: Device on which sensor tensors are allocated.
            alpha: Weight of the normalised ``RQ`` term in the raw difficulty.
            beta: Weight of the normalised ``Conf`` term in the raw difficulty.
            rho: Smoothing factor passed to the difficulty ``EMATracker``.
            D_min: Lower bound passed to the difficulty tracker (``v_min``).
            D_max: Upper bound passed to the difficulty tracker (``v_max``).
            norm_type: ``"rank"`` selects ``rank_normalize``; any other value
                selects ``zscore_normalize``.
            mgda_iters: Forwarded to ``MGDASolver`` as ``max_iters``.
            mgda_lr: Forwarded to ``MGDASolver`` as ``lr``.
            mgda_eps: Forwarded to ``MGDASolver`` as ``eps``.
            probe_num_nodes: Stored on the instance; not read by this class.
            probe_num_edges: Stored on the instance; not read by this class.
            loss_ema_rho: Smoothing factor for the loss tracker; falls back to
                ``rho`` when ``None``.
        """
        self.K, self.device = K, device
        self.alpha, self.beta, self.rho = alpha, beta, rho
        self.D_min, self.D_max, self.norm_type = D_min, D_max, norm_type
        self.probe_num_nodes, self.probe_num_edges = probe_num_nodes, probe_num_edges
        self.loss_ema_rho = float(rho if loss_ema_rho is None else loss_ema_rho)
        self.mgda = MGDASolver(K=K, max_iters=mgda_iters, lr=mgda_lr, eps=mgda_eps, device=device)
        self.loss_stats = RunningStats(K, device=device)
        self.loss_ema = EMATracker(K=K, alpha=self.loss_ema_rho, v_min=None, v_max=None, init_value=0.0, device=device)
        self.difficulty_ema = EMATracker(K=K, alpha=rho, v_min=D_min, v_max=D_max, init_value=0.5, device=device)
        # Results of the most recent sense() call; None until it has run once.
        self.last_RQ, self.last_Conf, self.last_D = None, None, None
        self.last_losses, self.last_normalized_losses, self.last_smoothed_losses = None, None, None
        # Populated by set_full_graph().
        self._full_graph, self._full_deg = None, None

    def _normalize_across_tasks(self, values: Tensor) -> Tensor:
        """Normalise a per-task vector with the scheme selected by ``norm_type``."""
        if self.norm_type == "rank":
            return rank_normalize(values)
        return zscore_normalize(values)

    def set_full_graph(self, graph: Any) -> None:
        """Remember ``graph`` and cache its node degrees, if obtainable.

        Degrees are read from ``graph.out_degrees()`` and, failing that, from
        ``graph.in_degrees()``. The lookup is deliberately best-effort: any
        exception from an accessor is ignored and ``_full_deg`` is left as
        ``None`` when neither accessor yields a value.
        """
        self._full_graph = graph
        deg = None
        for accessor in ("out_degrees", "in_degrees"):
            # Best-effort: the graph type is not known here, so a missing or
            # failing accessor simply moves on to the next candidate.
            with contextlib.suppress(Exception):
                deg = getattr(graph, accessor)()
            if deg is not None:
                break
        if deg is None:
            self._full_deg = None
        else:
            # Clamp to >= 1 so isolated nodes never produce a zero divisor.
            self._full_deg = deg.to(self.device).clamp(min=1).to(torch.float32)

    def compute_spectral_demand(self, rep_grads: Tensor, edge_index: Tensor, degrees: Optional[Tensor] = None, eps: float = 1e-8) -> float:
        """Return a degree-normalised Dirichlet energy ratio of ``rep_grads``.

        Computes ``sum_e ||h[src]/sqrt(d[src]+eps) - h[dst]/sqrt(d[dst]+eps)||^2``
        over all edges, divided by ``sum(h^2) + eps``.

        Args:
            rep_grads: Per-node values; indexed along dim 0 by node id and
                broadcast against a per-node column of scale factors.
            edge_index: Edge list with sources in row 0 and destinations in row 1.
            degrees: Optional per-node degrees. When omitted, degrees are
                counted from the source row of ``edge_index``.
            eps: Stabiliser added under the square root and in the denominator.

        Returns:
            The ratio as a Python float, or ``0.0`` when ``edge_index`` is empty.
        """
        if edge_index.numel() == 0:
            return 0.0
        src, dst = edge_index[0].long(), edge_index[1].long()
        if degrees is None:
            num_nodes = rep_grads.shape[0]
            degrees = torch.zeros(num_nodes, device=rep_grads.device, dtype=torch.float32)
            # Explicit dtype: scatter_add_ requires source and target dtypes to
            # match, which the global default dtype does not guarantee.
            ones = torch.ones(edge_index.shape[1], device=rep_grads.device, dtype=degrees.dtype)
            degrees.scatter_add_(0, src, ones)
            degrees = degrees.clamp(min=1.0)
        else:
            degrees = degrees.to(device=rep_grads.device, dtype=torch.float32).clamp(min=1.0)
        h_normalized = rep_grads * (1.0 / torch.sqrt(degrees + eps)).unsqueeze(1)
        diff = h_normalized[src] - h_normalized[dst]
        dirichlet = torch.sum(diff ** 2)
        h_norm_sq = torch.sum(rep_grads ** 2)
        return float((dirichlet / (h_norm_sq + eps)).item())

    @staticmethod
    def _gram_from_param_grads(param_grads: Sequence[Sequence[Optional[Tensor]]]) -> Tensor:
        """Build the symmetric ``K x K`` Gram matrix of per-task gradients.

        ``param_grads[i]`` holds task ``i``'s gradient for each parameter, with
        ``None`` where no gradient exists. Entry ``(i, j)`` is the sum of
        element-wise products over the parameters for which both tasks have a
        gradient, so a task with no gradients yields a zero row and column.

        Returns:
            The Gram matrix on the device of the first available gradient
            (CPU when there is none); a ``(0, 0)`` tensor when there are no tasks.
        """
        K = len(param_grads)
        if K == 0:
            return torch.zeros((0, 0))
        # Look through every task: the first task may have been skipped and
        # carry only None entries while later tasks hold accelerator tensors.
        device = next(
            (g.device for task_grads in param_grads for g in task_grads if g is not None),
            torch.device("cpu"),
        )
        G = torch.zeros((K, K), device=device)
        with torch.no_grad():
            for i in range(K):
                for j in range(i, K):
                    dot = torch.zeros((), device=device)
                    for gi, gj in zip(param_grads[i], param_grads[j]):
                        if gi is not None and gj is not None:
                            dot = dot + torch.sum(gi * gj)
                    G[i, j] = G[j, i] = dot
        return G

    def sense(self, model: nn.Module, sample_dict: Dict[str, Any], tasks: List[str], opt: Any,
              rng: np.random.RandomState, encoder_params: Optional[List[nn.Parameter]] = None) -> Dict[str, Any]:
        """Probe every task once and refresh losses, conflict and difficulty.

        For each task present in ``sample_dict`` the model is called as
        ``model({task: sample_dict[task]}, opt)`` in eval mode with gradients
        enabled, and the gradient of the task loss with respect to
        ``encoder_params`` is collected. Tasks that are absent, or for which
        the model output has no entry, contribute no gradients and a loss of 0.

        Args:
            model: Model to probe. Its training mode is restored on exit, even
                when the probe raises.
            sample_dict: Mapping from task name to that task's batch.
            tasks: Task names, in the order used for all per-task vectors.
            opt: Passed through to the model call unchanged.
            rng: Currently unused by this method.
            encoder_params: Parameters whose gradients enter the Gram matrix.
                Defaults to ``model.big_model.node_module.parameters()`` when
                the model has a ``big_model`` attribute, otherwise to all
                model parameters.

        Returns:
            Dict with keys ``losses``, ``normalized_losses``,
            ``smoothed_losses``, ``RQ``, ``Conf``, ``D``, ``mgda_lambda`` and
            ``gram``.
        """
        K, device = len(tasks), torch.device(self.device)
        if encoder_params is None:
            if hasattr(model, "big_model"):
                encoder_params = list(model.big_model.node_module.parameters())
            else:
                encoder_params = list(model.parameters())
        encoder_params = list(encoder_params)
        no_grads = tuple(None for _ in encoder_params)

        losses = torch.full((K,), float("nan"), device=device)
        # NOTE(review): RQ is never filled in below, so it stays all-zero and
        # only Conf varies in the difficulty. The representation gradients
        # captured per task look intended for compute_spectral_demand, but the
        # edge index to use is not available here - confirm with the caller.
        RQ = torch.zeros(K, device=device)
        param_grads_per_task = []

        was_training = bool(getattr(model, "training", False))
        model.eval()
        try:
            with torch.enable_grad():
                for k, task in enumerate(tasks):
                    if task not in sample_dict:
                        param_grads_per_task.append(no_grads)
                        continue

                    captured_z = []

                    def _hook(_module, _inputs, output, _sink=captured_z):
                        # autograd.grad rejects inputs that do not require
                        # grad, so only differentiable outputs are kept.
                        if torch.is_tensor(output) and output.requires_grad:
                            _sink.append(output)

                    handle = model.big_model.register_forward_hook(_hook) if hasattr(model, "big_model") else None
                    try:
                        out = model({task: sample_dict[task]}, opt)
                        loss = out.get(task) if isinstance(out, dict) else None
                    finally:
                        if handle is not None:
                            handle.remove()

                    if loss is None:
                        param_grads_per_task.append(no_grads)
                        continue

                    losses[k] = loss.detach()
                    self.loss_stats.update(k, float(loss.detach().item()))
                    grads = torch.autograd.grad(
                        loss,
                        list(captured_z) + encoder_params,
                        retain_graph=False,
                        allow_unused=True,
                    )
                    # Leading entries belong to the captured representations;
                    # only the parameter gradients feed the Gram matrix.
                    param_grads_per_task.append(tuple(grads[len(captured_z):]))
        finally:
            # Restore the caller's mode even if a forward/backward pass failed.
            model.train(was_training)

        losses = torch.nan_to_num(losses, nan=0.0)
        G = self._gram_from_param_grads(param_grads_per_task).to(device)
        lam, Conf = self.mgda.solve_from_gram(G)
        RQ_norm, Conf_norm = self._normalize_across_tasks(RQ), self._normalize_across_tasks(Conf)
        raw_D = self.alpha * RQ_norm + self.beta * Conf_norm
        D = self.difficulty_ema.update_all(raw_D)

        # One host transfer for all losses instead of one .item() per task.
        loss_values = losses.tolist()
        norm_losses = torch.tensor(
            [self.loss_stats.normalize(k, value) for k, value in enumerate(loss_values)],
            device=device,
        )
        smoothed_losses = self.loss_ema.update_all(norm_losses)

        self.last_RQ, self.last_Conf, self.last_D = RQ.clone(), Conf.clone(), D.clone()
        self.last_losses = losses.clone()
        self.last_normalized_losses = norm_losses.clone()
        self.last_smoothed_losses = smoothed_losses.clone()
        return {"losses": losses, "normalized_losses": norm_losses, "smoothed_losses": smoothed_losses,
                "RQ": RQ, "Conf": Conf, "D": D, "mgda_lambda": lam, "gram": G}

    def get_difficulty(self) -> Tensor:
        """Return a copy of the latest difficulty, or 0.5 per task before any sense() call."""
        if self.last_D is not None:
            return self.last_D.clone()
        return torch.full((self.K,), 0.5, device=self.device)

    def get_normalized_losses(self) -> Tensor:
        """Return a copy of the latest smoothed losses, or zeros before any sense() call."""
        if self.last_smoothed_losses is not None:
            return self.last_smoothed_losses.clone()
        return torch.zeros(self.K, device=self.device)
