import json
import re
from typing import Optional, Sequence, Union, Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from models.utils import CausalLMOutputWithPastExtended
from utils.tasks import CausalQATask


class CausalCriterion:
    """
    Training objective: sequence LM loss + graph reconstruction loss
    (+ optional slot/node alignment losses + NOTEARS DAG regularizer).

    Terms (each only when its task is enabled and its inputs are present):
      * ``L_text``: ``outputs.loss`` (constant zero when the model returned no loss).
      * ``L_cons``: consistency term, currently a constant-zero placeholder.
      * ``L_graph`` [CausalQATask.SUPERVISION]: BCE-with-logits between ``outputs.A_logits``
        and ``batch["A_star"]``, averaged over the entries selected by ``batch["A_mask"]``.
      * ``L_node_q`` / ``L_node_token`` / ``L_graph_align`` [CausalQATask.ALIGNMENT]:
        per-sample node InfoNCE of ``outputs.q`` / ``outputs.node_tokens`` against
        ``batch["node_prompt_emb"]``, and a batch-level CLIP loss between the mean of
        ``outputs.global_tokens`` and ``batch["explanation_prompt_emb"]``.
      * ``L_dag`` [notears_enable]: NOTEARS augmented-Lagrangian penalty
        ``notears_weight * (alpha * h + 0.5 * rho * h**2)`` on ``sigmoid(A_logits)``.

    The weights of the enabled main terms are normalized to sum to 1; ``L_dag`` is added
    with its own, un-normalized ``notears_weight``. When Hungarian matching is enabled,
    predicted node slots are re-ordered into ground-truth node order before the slot-wise
    losses are computed.
    """

    def __init__(
        self,
        lambda_text_weight: float,
        lambda_consistency_weight: float,
        lambda_graph_weight: float,
        lambda_node_q_weight: float,
        lambda_node_token_weight: float,
        lambda_graph_alighnment_weight: float,
        temp_node: float,
        temp_graph: float,
        exp_tasks: Optional[list[CausalQATask]],
        # ----- Hungarian matching -----
        enable_hungarian: bool,
        hungarian_match_on: str,  # "q"  or "node_tokens"
        hungarian_fallback_to_node_tokens: bool,
        # ----- NOTEARS DAG regularizer -----
        notears_enable: bool,
        notears_weight: float,
        notears_warmup_steps: int,
        notears_update_every: int,
        notears_alpha_init: float,
        notears_rho_init: float,
        notears_rho_max: float,
        notears_rho_mult: float,
        notears_zero_diag: bool,
        notears_mask_invalid_nodes: bool,
        notears_grad_clip: Optional[float],
        graph_pos_weight: float = 2.0,
    ):
        """
        Args:
            lambda_*_weight: raw weights of the main loss terms. Each is divided by the sum
                of the weights of the terms enabled by ``exp_tasks`` (text and consistency
                are always enabled). ``lambda_graph_alighnment_weight`` weights the graph
                CLIP loss; its spelling is kept for caller compatibility.
            temp_node: softmax temperature of the node InfoNCE losses.
            temp_graph: softmax temperature of the graph CLIP loss.
            exp_tasks: enabled tasks; ``None`` means no optional task.
            enable_hungarian: re-order predicted node slots to GT order before slot losses.
            hungarian_match_on: ``"q"`` or ``"node_tokens"``; features used for matching.
                Any other value tries ``q`` then ``node_tokens``.
            hungarian_fallback_to_node_tokens: when the selected features are missing,
                use the other feature set instead.
            notears_*: NOTEARS regularizer configuration (see ``_notears_loss`` and
                ``maybe_update_notears``). ``notears_grad_clip`` clamps element-wise the
                gradient that flows back through the regularizer only.
            graph_pos_weight: positive-class weight of the edge BCE loss (2.0 was the
                previously hard-coded value).
        """
        self.exp_tasks = [] if exp_tasks is None else exp_tasks

        # Normalizer: sum of the raw weights of every enabled main term.
        total = lambda_text_weight + lambda_consistency_weight
        if CausalQATask.SUPERVISION in self.exp_tasks:
            total += lambda_graph_weight
        if CausalQATask.ALIGNMENT in self.exp_tasks:
            total += lambda_graph_alighnment_weight
            total += lambda_node_q_weight
            total += lambda_node_token_weight

        def _normalized(weight: float) -> float:
            return weight / total if total > 0 else 0.0

        self.lambda_text = _normalized(lambda_text_weight)
        self.lambda_consistency = _normalized(lambda_consistency_weight)
        self.lambda_graph = _normalized(lambda_graph_weight)
        self.lambda_node_q = _normalized(lambda_node_q_weight)
        self.lambda_node_token = _normalized(lambda_node_token_weight)
        self.lambda_graph_alignment = _normalized(lambda_graph_alighnment_weight)

        self.temp_node = temp_node
        self.temp_graph = temp_graph
        self.graph_pos_weight = float(graph_pos_weight)

        # ----- Hungarian -----
        self.enable_hungarian = bool(enable_hungarian)
        self.hungarian_match_on = str(hungarian_match_on)
        self.hungarian_fallback_to_node_tokens = bool(hungarian_fallback_to_node_tokens)

        # ----- NOTEARS -----
        self.notears_enable = bool(notears_enable)
        self.notears_weight = float(notears_weight)
        self.notears_warmup_steps = int(notears_warmup_steps)
        self.notears_update_every = int(notears_update_every)
        # alpha / rho are mutable state, advanced by maybe_update_notears().
        self.notears_alpha = float(notears_alpha_init)
        self.notears_rho = float(notears_rho_init)
        self.notears_rho_max = float(notears_rho_max)
        self.notears_rho_mult = float(notears_rho_mult)
        self.notears_zero_diag = bool(notears_zero_diag)
        self.notears_mask_invalid_nodes = bool(notears_mask_invalid_nodes)
        self.notears_grad_clip = None if notears_grad_clip is None else float(notears_grad_clip)

    # ------------------------- helpers -------------------------

    @staticmethod
    def _normalize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """L2-normalize along the last dimension (``eps`` guards against zero vectors)."""
        return x / (x.norm(dim=-1, keepdim=True) + eps)

    @staticmethod
    def _apply_perm_vec(x: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
        """
        Re-order node slots.

        x:    (B, D, H) or None
        perm: (B, D)  -- for each GT index j, perm[b, j] is the predicted slot index.
        return: (B, D, H) with out[b, j] = x[b, perm[b, j]], or None when x is None.
        """
        if x is None:
            return None
        B, D, H = x.shape
        idx = perm.to(x.device).unsqueeze(-1).expand(B, D, H)
        return x.gather(dim=1, index=idx)

    @staticmethod
    def _apply_perm_mat(A: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
        """
        Re-order rows and columns of a per-sample square matrix.

        A:    (B, D, D) or None
        perm: (B, D)
        return: (B, D, D) with out[b, j, k] = A[b, perm[b, j], perm[b, k]], or None.
        """
        if A is None:
            return None
        B, D, _ = A.shape
        perm = perm.to(A.device)
        rows = A.gather(dim=1, index=perm.unsqueeze(-1).expand(B, D, D))
        return rows.gather(dim=2, index=perm.unsqueeze(1).expand(B, D, D))

    @staticmethod
    def _derive_node_mask_from_A_mask(A_mask: torch.Tensor) -> torch.Tensor:
        """
        Fallback node mask when neither ``node_mask`` nor node embeddings are available:
        a node is valid if its row or column of ``A_mask`` contains any nonzero entry
        (this includes its diagonal entry).

        A_mask: (B, D, D) 0/1 or bool, or None
        return: (B, D) bool, or None
        """
        if A_mask is None:
            return None
        m = A_mask.float()
        return (m.sum(dim=-1) + m.sum(dim=-2)) > 0

    def _hungarian_perm(
        self,
        pred_for_match: torch.Tensor,  # (B, Dp, H)
        node_text: torch.Tensor,       # (B, Dg, H) in GT order
        node_mask: torch.Tensor,       # (B, Dg) bool
    ) -> torch.Tensor:
        """
        Per-sample permutation mapping GT node index -> predicted slot index, chosen by
        maximizing cosine similarity with Hungarian assignment over the valid GT nodes.

        GT positions left unassigned (invalid nodes, or more valid nodes than predicted
        slots) are filled with the unused predicted slots in increasing order, and with
        ``j % Dp`` once those run out. Samples without valid nodes get the identity.
        If SciPy is unavailable, a greedy one-to-one matching is used instead (not optimal).

        Returns:
            (B, Dg) long tensor on ``pred_for_match.device``.
        """
        try:
            from scipy.optimize import linear_sum_assignment  # type: ignore
        except Exception:
            # best effort: any import failure selects the greedy fallback
            linear_sum_assignment = None

        device = pred_for_match.device
        B, Dp, _ = pred_for_match.shape
        Dg = node_text.size(1)

        # Matching is non-differentiable: skip autograd and move data to the host once,
        # instead of syncing with the device per sample / per element.
        with torch.no_grad():
            g = self._normalize(pred_for_match.detach().float())
            t = self._normalize(node_text.detach().float().to(g.device))
            sim = torch.matmul(g, t.transpose(1, 2)).cpu()  # (B, Dp, Dg)
        mask_cpu = node_mask.detach().to(device="cpu", dtype=torch.bool)

        perm_rows: List[List[int]] = []
        for b in range(B):
            valid_js = torch.nonzero(mask_cpu[b], as_tuple=False).squeeze(-1).tolist()
            if not valid_js:
                perm_rows.append(list(range(Dg)))
                continue

            # scores[r, i]: similarity of the r-th valid GT node to predicted slot i
            scores = sim[b].t()[valid_js]  # (n_valid, Dp)
            assignment: Dict[int, int] = {}  # GT index j -> predicted slot i

            if linear_sum_assignment is not None:
                row_ind, col_ind = linear_sum_assignment((-scores).numpy())
                for r, c in zip(row_ind.tolist(), col_ind.tolist()):
                    assignment[valid_js[int(r)]] = int(c)
            else:
                # Greedy: serve GT nodes with the highest best-score first.
                used = set()
                order = torch.argsort(scores.max(dim=1).values, descending=True).tolist()
                for r in order:
                    if len(used) >= Dp:
                        break  # all slots taken; remaining GT nodes go to the fill step
                    row = scores[r].clone()
                    if used:
                        row[list(used)] = -1e9
                    c = int(torch.argmax(row))
                    assignment[valid_js[int(r)]] = c
                    used.add(c)

            taken = set(assignment.values())
            spare = iter([i for i in range(Dp) if i not in taken])
            perm_rows.append(
                [assignment[j] if j in assignment else next(spare, j % Dp) for j in range(Dg)]
            )

        return torch.tensor(perm_rows, dtype=torch.long).reshape(B, Dg).to(device)

    def _node_infonce_in_sample(
        self,
        node_vec: torch.Tensor,   # (B, d_max, d_model)
        node_text: torch.Tensor,  # (B, d_max, d_model)
        node_mask: torch.Tensor   # (B, d_max)
    ) -> torch.Tensor:
        """
        Symmetric InfoNCE between predicted node vectors and node text embeddings,
        computed within each sample over its valid nodes only, then averaged over samples.

        A sample with one valid node contributes ``1 - cosine``; samples without valid
        nodes are skipped. Returns a float32 zero when ``node_vec`` is None or no sample
        has a valid node.
        """
        if node_vec is None:
            return torch.tensor(0.0, device=node_text.device, dtype=torch.float32)

        g = self._normalize(node_vec.float())
        t = self._normalize(node_text.float())

        losses = []
        for b in range(node_vec.size(0)):
            idx = torch.nonzero(node_mask[b], as_tuple=False).squeeze(-1)
            n = int(idx.numel())
            if n == 0:
                continue
            if n == 1:
                i = idx[0]
                losses.append(1.0 - (g[b, i] * t[b, i]).sum())
                continue

            logits = torch.matmul(g[b, idx], t[b, idx].t()) / self.temp_node  # (n, n)
            labels = torch.arange(n, device=logits.device)
            loss_i2t = F.cross_entropy(logits, labels)
            loss_t2i = F.cross_entropy(logits.t(), labels)
            losses.append(0.5 * (loss_i2t + loss_t2i))

        if not losses:
            return torch.tensor(0.0, device=node_vec.device, dtype=torch.float32)
        return torch.stack(losses).mean()

    def _graph_clip_loss(
        self,
        global_tokens: torch.Tensor,  # (B, num_graph_tokens - d_max, d_model)
        graph_text: torch.Tensor,     # (B, d_model)
    ) -> torch.Tensor:
        """
        CLIP-style symmetric cross-entropy across the batch between the mean of the graph
        global tokens and the graph text embedding. With B < 2 it becomes ``1 - cosine``.
        Returns a float32 zero when there are no global tokens.
        """
        if global_tokens is None or global_tokens.size(1) == 0:
            return torch.tensor(0.0, device=graph_text.device, dtype=torch.float32)

        g = self._normalize(global_tokens.float().mean(dim=1))  # (B, d_model)
        t = self._normalize(graph_text.float())                 # (B, d_model)

        B = g.size(0)
        if B < 2:
            return (1.0 - (g * t).sum(dim=-1)).mean()

        logits = torch.matmul(g, t.t()) / self.temp_graph  # (B, B)
        labels = torch.arange(B, device=logits.device)
        loss_i2t = F.cross_entropy(logits, labels)
        loss_t2i = F.cross_entropy(logits.t(), labels)
        return 0.5 * (loss_i2t + loss_t2i)

    def _notears_h(self, P: torch.Tensor) -> torch.Tensor:
        """
        NOTEARS acyclicity measure per sample: h(P) = tr(exp(P * P)) - D.

        P: (B, D, D) float tensor
        return: (B,) tensor
        """
        D = P.size(-1)
        expm = torch.linalg.matrix_exp(P * P)
        return torch.diagonal(expm, dim1=-2, dim2=-1).sum(dim=-1) - float(D)

    def maybe_update_notears(self, step_after: int, h_value: Optional[float]):
        """
        Augmented-Lagrangian update of ``notears_alpha`` / ``notears_rho``; intended to be
        called once per optimizer step.

        Updates happen only when NOTEARS is enabled, ``h_value`` is given,
        ``step_after > notears_warmup_steps``, ``notears_update_every > 0`` and
        ``step_after - notears_warmup_steps`` is a multiple of ``notears_update_every``
        (e.g. warmup=1000, update_every=100 -> first update at step 1100).

        Args:
            step_after: number of completed optimizer steps.
            h_value: scalar acyclicity value (ideally averaged across ranks).
        """
        if not self.notears_enable or h_value is None:
            return
        if step_after <= self.notears_warmup_steps or self.notears_update_every <= 0:
            return
        if (step_after - self.notears_warmup_steps) % self.notears_update_every != 0:
            return

        h = float(h_value)
        # alpha uses the current rho, then rho grows geometrically up to rho_max
        self.notears_alpha = float(self.notears_alpha + self.notears_rho * h)
        self.notears_rho = float(min(self.notears_rho_max, self.notears_rho_mult * self.notears_rho))

    def _resolve_node_mask(self, batch: dict, device: torch.device) -> Optional[torch.Tensor]:
        """
        Node validity mask (B, D) as bool on ``device``, taken from (in order of priority)
        ``batch["node_mask"]``, nonzero rows of ``batch["node_prompt_emb"]``, or
        ``batch["A_mask"]``; None when none of these keys is present.
        """
        node_mask = batch.get("node_mask", None)
        if node_mask is None and "node_prompt_emb" in batch:
            node_mask = batch["node_prompt_emb"].abs().sum(-1) > 0
        if node_mask is None and "A_mask" in batch:
            node_mask = self._derive_node_mask_from_A_mask(batch["A_mask"])
        if node_mask is not None:
            node_mask = node_mask.to(device=device, dtype=torch.bool)
        return node_mask

    def _match_features(self, outputs) -> Optional[torch.Tensor]:
        """Pick the predicted node features used for Hungarian matching (may be None)."""
        q = getattr(outputs, "q", None)
        node_tokens = getattr(outputs, "node_tokens", None)
        if self.hungarian_match_on == "q":
            primary, secondary, allow_fallback = q, node_tokens, self.hungarian_fallback_to_node_tokens
        elif self.hungarian_match_on == "node_tokens":
            primary, secondary, allow_fallback = node_tokens, q, self.hungarian_fallback_to_node_tokens
        else:
            # unknown setting: prefer q, always fall back to node_tokens
            primary, secondary, allow_fallback = q, node_tokens, True
        if primary is None and allow_fallback:
            return secondary
        return primary

    def _alignment_perm(
        self,
        outputs,
        node_text: torch.Tensor,
        node_mask: torch.Tensor,
        ret_dict: Dict[str, torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """
        Hungarian permutation (B, D) or None when no matching features exist.

        Matching failures never crash training: the identity permutation is returned and
        ``ret_dict["hungarian_error"]`` is set to 1.
        """
        pred_for_match = self._match_features(outputs)
        if pred_for_match is None:
            return None
        try:
            return self._hungarian_perm(
                pred_for_match=pred_for_match,
                node_text=node_text,
                node_mask=node_mask,
            )
        except Exception:
            B = int(pred_for_match.size(0))
            D = int(node_text.size(1))
            ret_dict["hungarian_error"] = torch.tensor(1.0, device=node_text.device)
            return torch.arange(D, device=node_text.device).unsqueeze(0).expand(B, -1).contiguous()

    def _graph_supervision_loss(self, A_logits: torch.Tensor, batch: dict) -> torch.Tensor:
        """
        Masked, positively re-weighted BCE-with-logits between ``A_logits`` (B, D, D) and
        ``batch["A_star"]``, averaged over entries where ``batch["A_mask"]`` is nonzero
        (denominator clamped to at least 1). Computed in float32 for stability.
        """
        dev = A_logits.device
        A_star = batch["A_star"].to(dev).float()
        A_mask = batch["A_mask"].to(dev).float()
        pos_weight = torch.tensor(self.graph_pos_weight, device=dev, dtype=torch.float32)
        loss_el = F.binary_cross_entropy_with_logits(
            A_logits.float(), A_star, reduction="none", pos_weight=pos_weight
        )
        return (loss_el * A_mask).sum() / A_mask.sum().clamp_min(1.0)

    def _notears_loss(
        self,
        A_logits: torch.Tensor,
        node_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        NOTEARS penalty on P = sigmoid(A_logits).

        Returns:
            (L_dag, h_dag, alpha, rho), where h_dag is the batch-mean acyclicity measure and
            L_dag = notears_weight * (alpha * h_dag + 0.5 * rho * h_dag**2).
        """
        A_dag = A_logits
        if self.notears_grad_clip is not None and A_logits.requires_grad:
            clip = float(self.notears_grad_clip)
            # Route the regularizer through a fresh autograd node so the hook clips only
            # the gradient coming from this term, not the edge-supervision gradient that
            # reaches A_logits through other paths.
            A_dag = A_logits.view_as(A_logits)
            A_dag.register_hook(lambda grad: grad.clamp(min=-clip, max=clip))

        P = torch.sigmoid(A_dag.float())  # (B, D, D) in [0, 1]

        if self.notears_mask_invalid_nodes and (node_mask is not None):
            m = node_mask.to(device=P.device, dtype=P.dtype)
            P = P * (m.unsqueeze(-1) * m.unsqueeze(-2))

        if self.notears_zero_diag:
            D = P.size(-1)
            off_diag = 1.0 - torch.eye(D, device=P.device, dtype=P.dtype).unsqueeze(0)
            P = P * off_diag

        h_dag = self._notears_h(P).mean()
        alpha = torch.tensor(self.notears_alpha, device=P.device, dtype=P.dtype)
        rho = torch.tensor(self.notears_rho, device=P.device, dtype=P.dtype)
        L_dag = self.notears_weight * (alpha * h_dag + 0.5 * rho * (h_dag ** 2))
        return L_dag, h_dag, alpha, rho

    # ------------------------- main call -------------------------

    def __call__(
        self,
        outputs: CausalLMOutputWithPastExtended,
        batch: dict,
        global_step: Optional[int] = None,
    ):
        """
        Compute the total training loss.

        Args:
            outputs: model outputs; optional attributes read: ``loss``, ``A_logits``
                (B, D, D), ``q`` and ``node_tokens`` (B, D, H), ``global_tokens``.
            batch: dict; depending on configuration reads ``labels`` (only for its device,
                when there is neither an LM loss nor ``A_logits``), ``node_mask``,
                ``node_prompt_emb``, ``A_mask``, ``A_star``, ``explanation_prompt_emb``.
            global_step: current step for the NOTEARS warm-up (None is treated as 0).

        Returns:
            (loss, ret_dict): scalar loss and a dict of detached per-term diagnostics.

        Raises:
            ValueError: if the total loss is NaN.
        """
        ret_dict: Dict[str, torch.Tensor] = {}

        A_logits = getattr(outputs, "A_logits", None)  # (B, D, D) or None

        # ===== Text (sequence) loss =====
        lm_loss = getattr(outputs, "loss", None)
        if lm_loss is None:
            zero_device = A_logits.device if A_logits is not None else batch["labels"].device
            L_text = torch.tensor(0.0, device=zero_device, dtype=torch.float32)
        else:
            L_text = lm_loss.float()
        ret_dict["L_text"] = L_text.detach()
        device = L_text.device

        # ===== Consistency (placeholder: always zero) =====
        L_cons = torch.tensor(0.0, device=device, dtype=torch.float32)

        # Node validity mask for Hungarian matching, alignment and NOTEARS masking
        node_mask = self._resolve_node_mask(batch, device)

        # ===== Hungarian alignment permutation (GT index -> predicted slot) =====
        perm = None
        if self.enable_hungarian and ("node_prompt_emb" in batch) and (node_mask is not None):
            perm = self._alignment_perm(
                outputs,
                node_text=batch["node_prompt_emb"].to(device=device),
                node_mask=node_mask,
                ret_dict=ret_dict,
            )

        # Re-order predictions into GT node order for all slot-wise losses
        A_logits_used = A_logits
        q_used = getattr(outputs, "q", None)
        node_tokens_used = getattr(outputs, "node_tokens", None)
        if perm is not None:
            A_logits_used = self._apply_perm_mat(A_logits_used, perm)
            q_used = self._apply_perm_vec(q_used, perm)
            node_tokens_used = self._apply_perm_vec(node_tokens_used, perm)

        # ===== Graph loss (supervision) =====
        L_graph = None
        if (CausalQATask.SUPERVISION in self.exp_tasks) and (A_logits_used is not None):
            L_graph = self._graph_supervision_loss(A_logits_used, batch)
            ret_dict["L_graph"] = L_graph.detach()

        # ===== Alignment losses =====
        L_node_q = None
        L_node_token = None
        L_graph_align = None
        if CausalQATask.ALIGNMENT in self.exp_tasks:
            node_text = batch["node_prompt_emb"].to(device)  # (B, D, H)
            if node_mask is None:
                node_mask = (node_text.abs().sum(-1) > 0).to(device)

            L_node_q = self._node_infonce_in_sample(
                node_vec=q_used, node_text=node_text, node_mask=node_mask
            )
            L_node_token = self._node_infonce_in_sample(
                node_vec=node_tokens_used, node_text=node_text, node_mask=node_mask
            )
            L_graph_align = self._graph_clip_loss(
                global_tokens=getattr(outputs, "global_tokens", None),
                graph_text=batch["explanation_prompt_emb"].to(device),  # (B, H)
            )

            ret_dict["L_node_q"] = L_node_q.detach()
            ret_dict["L_node_token"] = L_node_token.detach()
            ret_dict["L_graph_align"] = L_graph_align.detach()

        # ===== NOTEARS DAG regularizer =====
        L_dag = None
        if self.notears_enable and (A_logits_used is not None):
            step = int(global_step) if global_step is not None else 0
            if step >= self.notears_warmup_steps:
                L_dag, h_dag, alpha, rho = self._notears_loss(A_logits_used, node_mask)
                ret_dict["h_dag"] = h_dag.detach()
                ret_dict["L_dag"] = L_dag.detach()
                ret_dict["dag_alpha"] = alpha.detach()
                ret_dict["dag_rho"] = rho.detach()

        # ===== Total loss =====
        # Optional terms are only non-None when their task is enabled.
        loss = self.lambda_text * L_text + self.lambda_consistency * L_cons
        if L_graph is not None:
            loss = loss + self.lambda_graph * L_graph
        if L_node_q is not None:
            loss = loss + self.lambda_node_q * L_node_q
        if L_node_token is not None:
            loss = loss + self.lambda_node_token * L_node_token
        if L_graph_align is not None:
            loss = loss + self.lambda_graph_alignment * L_graph_align
        if L_dag is not None:
            loss = loss + L_dag

        ret_dict["loss"] = loss.detach()

        if torch.isnan(loss):
            raise ValueError("NaN detected in loss")

        return loss, ret_dict


def make_compute_metrics(
    processor,
    pos_weight: float = 2.0,
    graph_threshold: float = 0.5,
    exp_tasks: Optional[list[CausalQATask]] = None
):
    """
    Build the evaluation-metrics callback.

    Placeholder implementation: all arguments are currently ignored, and the returned
    ``compute_metrics(eval_preds)`` always yields the string ``'metrics_placeholder'``.
    If you need to run evaluation, implement this function appropriately for your task.

    NOTE(review): Trainer-style evaluation loops usually expect the callback to return a
    mapping of metric name -> value; confirm how the return value is consumed before
    enabling evaluation with this placeholder.
    """
    placeholder_result = 'metrics_placeholder'

    def compute_metrics(eval_preds):
        # eval_preds is intentionally unused until real metrics are implemented
        return placeholder_result

    return compute_metrics