import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Callable
from typing import Optional, List


# Public API of this module.
__all__ = ["ConsistencyLoss"]

def _flatten_logits(x: torch.Tensor) -> torch.Tensor:
    """If logits are (B,N,C) flatten to (B*N,C) else keep as-is."""
    if x.dim() == 3:
        B, N, C = x.shape
        return x.reshape(B * N, C)
    return x


def _logits_process(logits: List[torch.Tensor]):
    """Return a new list in which every ``(B, N, C)`` logits tensor is ``(B*N, C)``.

    Tensors of any other rank are passed through untouched.
    """
    return list(map(_flatten_logits, logits))


class ConsistencyLoss:
    """Hierarchical consistency regularisation (inter-stage or root-referenced).

    Parameters
    ----------
    h_matrices : List[np.ndarray]
        List of hierarchy matrices ``H_k`` (one per level, coarse→fine) where
        each row sums to 1 and maps *leaf* distribution to *upper* classes.
        ``H_k`` is expected to have shape ``(C_k, C_leaf)``.
        For *inter* mode we also internally build
        ``H_inter(k) = clip(H_k · H_{k+1}ᵀ, 0, 1)``, which maps level ``k+1``
        onto level ``k``.
        NOTE(review): the [0, 1] clip suggests binary membership matrices
        (one-hot columns) rather than row-stochastic ones — confirm.
    ignore_label : int, default -100
        Unused here but kept to align with PointSegBase signature.
    layer_weights : List[float] or None
        Weighting per hierarchy level. If None defaults to 1 for each. Must
        have one entry per hierarchy matrix; only the first ``len - 1``
        entries are used (one per compared level pair).
    device : str, default 'cuda'
        Device the hierarchy matrices are stored on. At call time they are
        moved to the device (and dtype) of the predictions.
    reg_func : Callable, default ``torch.abs``
        Element-wise distance used between prob distributions.
    mode : {'inter','root'}, default 'inter'
        * inter : penalise adjacent levels (same as PointSegBase default).
        * root  : penalise every upper level against the finest.
    """

    def __init__(
        self,
        h_matrices: List[np.ndarray],
        ignore_label: int = -100,
        layer_weights: Optional[List[float]] = None,
        device: str = "cuda",
        reg_func: Callable[[torch.Tensor], torch.Tensor] = torch.abs,
        mode: str = "inter",
    ) -> None:
        # Validate cheap arguments before doing any tensor/device work.
        self.mode = mode.lower()
        assert self.mode in {"inter", "root"}, "mode must be 'inter' or 'root'"

        if layer_weights is None:
            layer_weights = [1.0] * len(h_matrices)
        else:
            assert len(layer_weights) == len(h_matrices), (
                "layer_weights must have one entry per hierarchy matrix"
            )
        self.layer_weights = layer_weights

        self.ignore_label = ignore_label  # kept for API compatibility
        self.device = device
        self.reg_func = reg_func

        self.h_matrices = [torch.from_numpy(m).float().to(device) for m in h_matrices]
        # H_inter(k) maps level k+1 onto level k. The clip presumably caps
        # shared-leaf counts at 1 for binary membership matrices.
        self.inter_h_matrices = [
            torch.from_numpy(np.clip(h_up @ h_down.T, 0.0, 1.0)).float().to(device)
            for h_up, h_down in zip(h_matrices[:-1], h_matrices[1:])
        ]

    # -----------------------------------------------------
    def __call__(self, pred: List[torch.Tensor]) -> torch.Tensor:
        """Compute consistency loss over a list of logits (coarse→fine).

        Parameters
        ----------
        pred : List[torch.Tensor]
            One logits tensor per hierarchy level, ordered like ``h_matrices``.
            ``(B, N, C)`` tensors are flattened to ``(B*N, C)`` first.

        Returns
        -------
        torch.Tensor
            Scalar loss: the weighted sum of per-pair mean distances divided
            by the number of compared level pairs. Zero when fewer than two
            levels are given, since there is nothing to compare.
        """
        assert len(pred) == len(self.h_matrices), "#levels mismatch"

        num_pairs = len(pred) - 1
        if num_pairs < 1:
            # Avoid dividing by zero pairs (which would yield NaN).
            if pred:
                return torch.zeros((), device=pred[0].device, dtype=pred[0].dtype)
            return torch.zeros((), device=self.device)

        # softmax probabilities per level
        probs = [F.softmax(p, dim=-1) for p in _logits_process(pred)]
        # Follow the predictions' dtype (handles AMP fp16) and device, so the
        # loss works regardless of the device the matrices were built on.
        ref = probs[0]
        total_loss = torch.zeros((), device=ref.device, dtype=ref.dtype)

        if self.mode == "inter":
            # Compare each level with the level directly below it.
            pairs = [
                (probs[i], probs[i + 1], self.inter_h_matrices[i])
                for i in range(num_pairs)
            ]
        else:
            # root mode: align every upper level with the finest level.
            pairs = [
                (probs[i], probs[-1], self.h_matrices[i])
                for i in range(num_pairs)
            ]

        for i, (upper, lower, H) in enumerate(pairs):
            H = H.to(device=ref.device, dtype=ref.dtype)
            # Project the lower-level distribution onto the upper level's classes.
            mapped = torch.matmul(lower, H.T)
            diff = self.reg_func(upper - mapped)
            total_loss = total_loss + self.layer_weights[i] * diff.mean()

        return total_loss / num_pairs