
import math
import os
from dataclasses import dataclass, field
from typing import Dict, Optional, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import Trainer
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import logging

logger = logging.get_logger(__name__)

# ===== Calibration Head =====
class CalibrationHead(nn.Module):
    """
    Small module that turns a feature vector into one scalar confidence logit.

    With ``hidden_dim`` unset, zero or negative the head is a single linear
    layer; otherwise it is Linear -> ReLU -> Linear with ``hidden_dim`` units
    in the middle.
    """
    def __init__(self, in_dim: int, hidden_dim: int = 0):
        super().__init__()
        if not hidden_dim or hidden_dim <= 0:
            # No hidden layer requested: plain linear probe.
            self.net = nn.Linear(in_dim, 1)
            return
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        """Apply the head and drop the trailing size-1 dimension of its output."""
        return self.net(feats).squeeze(-1)

# ===== Helper: ranking pairs =====
def _build_rank_pairs(scores: torch.Tensor, top_k: int = 16) -> List[Tuple[int,int]]:
    """
    Build a small set of index pairs (i,j) where score[i] > score[j].
    We sample by sorting and taking heads/tails to keep compute small.
    """
    B = scores.shape[0]
    if B < 4:
        return []
    k = min(top_k, B // 2)
    order = torch.argsort(scores, descending=True)
    heads = order[:k]
    tails = order[-k:]
    pairs = [(int(i), int(j)) for i in heads for j in tails if int(i) != int(j)]
    return pairs

# ===== Args for unsupervised calibration =====
@dataclass
class CalibArgs:
    """Weights and settings for the unsupervised calibration objective."""

    # Weight of the mean-matching term; only active when `target_global_acc` is set.
    lambda_mean: float = 0.75
    # Weight of the pairwise ranking term (needs "rank_score" in the batch).
    lambda_rank: float = 0.10
    # Weight of the total-variation smoothing term across confidence bins.
    lambda_tv: float = 0.01
    # Weight of the per-bin alignment term (needs `target_bin_acc`).
    lambda_align: float = 0.10

    # If provided, each example already carries "targets" in the dataset (soft labels in [0,1]).
    # If not provided, we will skip BCE and rely on mean/rank/tv/align (still works but weaker).
    use_bce_when_targets_available: bool = True

    # Binning config for align loss
    num_bins: int = 10

    # Optional per-bin target accuracies for UTDC-lite style alignment (length == num_bins)
    target_bin_acc: Optional[List[float]] = None

    # Clamp calibrated probabilities to avoid log/grad blow-ups
    prob_eps: float = 1e-5

    # Optional estimate of the overall target-domain accuracy in [0, 1]. The trainer
    # looks this attribute up for the mean constraint; it was previously undeclared,
    # so the constraint could never fire. Declared last to keep positional
    # construction of the earlier fields unchanged.
    target_global_acc: Optional[float] = None

class MyTrainerEnhanced(Trainer):
    """
    Drop-in replacement that expects the model to return a dict with:
      - "loss" (optional, if doing SFT jointly)
      - "features": [B, D]   (pooled features for calibration head)
      - optionally "conf_logits": [B] if the model already produces a confidence logit;
        otherwise we compute it via CalibrationHead.

    The dataset should provide (when available):
      - "targets": [B] float tensor in [0,1], the soft pseudo labels from E-step.
      - "bin_idx": [B] int in [0, num_bins-1] (optional), used for align loss.
      - "rank_score": [B] float (optional), higher means more likely correct, for rank loss.

    These three keys are stripped from the batch before the model is called.

    NOTE(review): the model's own "loss" output is not added to the objective
    computed in ``compute_loss``; confirm whether joint SFT is expected here.
    NOTE(review): with the Hugging Face default ``remove_unused_columns=True``
    the auxiliary columns may be dropped before they reach ``compute_loss`` --
    confirm the training arguments.
    NOTE(review): an external ``calib_head`` lives on the trainer, not on the
    model, so it is presumably not part of saved checkpoints -- verify.
    """

    # Batch keys consumed by the losses and never forwarded to the model.
    _AUX_KEYS = ("targets", "bin_idx", "rank_score")
    # Hinge margin and head/tail size used by the ranking constraint.
    _RANK_MARGIN = 0.05
    _RANK_TOP_K = 16

    def __init__(self, *args, calib_args: Optional[CalibArgs] = None, calib_head: Optional[nn.Module] = None, **kwargs):
        """
        Args:
            calib_args: calibration hyper-parameters; defaults to ``CalibArgs()``.
            calib_head: optional module mapping "features" to confidence logits.
                If None, the model itself must emit "conf_logits".
            *args, **kwargs: forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.calib_args = calib_args or CalibArgs()
        self.calib_head = calib_head  # if None, expect model to emit conf_logits
        self.bce = nn.BCEWithLogitsLoss(reduction="mean")

    def create_optimizer(self, *args, **kwargs):
        """
        Build the optimizer as ``Trainer`` does, then register the external head.

        The parent only collects parameters of ``self.model``; without this step
        a separately supplied ``calib_head`` would receive gradients but never be
        updated. Parameters already known to the optimizer (e.g. when the head is
        a sub-module of the model) and frozen parameters are skipped. A
        user-supplied optimizer is left untouched.
        """
        created_here = self.optimizer is None
        optimizer = super().create_optimizer(*args, **kwargs)
        target = self.optimizer if self.optimizer is not None else optimizer
        if (
            not created_here
            or self.calib_head is None
            or not isinstance(target, torch.optim.Optimizer)
        ):
            return optimizer
        known = {id(p) for group in target.param_groups for p in group["params"]}
        extra = [
            p for p in self.calib_head.parameters()
            if p.requires_grad and id(p) not in known
        ]
        if extra:
            target.add_param_group({"params": extra, "weight_decay": self.args.weight_decay})
        return optimizer

    def _get_conf_logits(self, model_outputs: Dict) -> torch.Tensor:
        """
        Return confidence logits, preferring the model's own "conf_logits".

        Raises:
            ValueError: if the model emits no "conf_logits" and either no
                ``calib_head`` was supplied or the outputs lack "features".
        """
        if "conf_logits" in model_outputs and model_outputs["conf_logits"] is not None:
            return model_outputs["conf_logits"]
        if self.calib_head is None:
            raise ValueError("No conf_logits from model and no calib_head provided.")
        feats = model_outputs.get("features", None)
        if feats is None:
            raise ValueError("Model outputs must include 'features' for external calibration head.")
        # Trainer only places self.model on the device; keep the head next to its input.
        head_param = next(self.calib_head.parameters(), None)
        if head_param is not None and head_param.device != feats.device:
            self.calib_head.to(feats.device)
        return self.calib_head(feats)

    def _bin_means(self, probs: torch.Tensor, bin_idx: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Mean probability per bin as a float32 tensor of shape [num_bins].

        Args:
            probs: probabilities, flattened internally.
            bin_idx: optional precomputed bin index per example; when None the
                bins are derived from ``probs`` as ``floor(p * num_bins)``.

        Empty bins take the detached global mean, so they add no gradient.
        """
        num_bins = self.calib_args.num_bins
        # float32 keeps the per-bin counts exact even for half-precision probs.
        flat = probs.reshape(-1).float()
        if bin_idx is None:
            bin_idx = (flat.detach() * num_bins).long()
        bin_idx = bin_idx.reshape(-1).to(device=flat.device, dtype=torch.long)
        bin_idx = bin_idx.clamp(0, num_bins - 1)
        membership = F.one_hot(bin_idx, num_classes=num_bins).to(flat.dtype)  # [B, num_bins]
        counts = membership.sum(dim=0)
        sums = (membership * flat.unsqueeze(1)).sum(dim=0)
        fallback = flat.mean().detach()
        return torch.where(counts > 0, sums / counts.clamp_min(1), fallback)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the calibration objective:
        BCE + lambda_mean*mean + lambda_rank*rank + lambda_tv*tv + lambda_align*align.

        Args:
            model: the (possibly wrapped) model to run.
            inputs: batch dict; auxiliary keys are removed before the forward pass.
            return_outputs: when True also return the model outputs, with
                "conf_logits" guaranteed to be present.
            num_items_in_batch: accepted for compatibility with Trainer versions
                that pass it; unused because every term is a batch mean.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.

        Raises:
            ValueError: if ``target_bin_acc`` does not have ``num_bins`` entries.
        """
        cfg = self.calib_args
        # Forward
        outputs = model(**{k: v for k, v in inputs.items() if k not in self._AUX_KEYS})
        conf_logits = self._get_conf_logits(outputs)
        probs = torch.sigmoid(conf_logits).clamp(cfg.prob_eps, 1 - cfg.prob_eps)
        zero = torch.zeros((), device=conf_logits.device)

        # Base BCE with soft targets (if provided)
        loss_main = zero
        targets = inputs.get("targets", None)
        if targets is not None and cfg.use_bce_when_targets_available:
            # BCE needs targets with the logits' dtype and shape (e.g. under fp16/bf16).
            targets = targets.to(device=conf_logits.device, dtype=conf_logits.dtype)
            loss_main = self.bce(conf_logits, targets.reshape(conf_logits.shape))

        # Mean constraint (match global average to estimated target-domain accuracy if available)
        loss_mean = zero
        target_global_acc = getattr(cfg, "target_global_acc", None)
        if target_global_acc is not None:
            loss_mean = (probs.mean() - float(target_global_acc)) ** 2

        # Ranking constraint (if rank_score provided): hinge on every (better, worse) pair
        loss_rank = zero
        rank_score = inputs.get("rank_score", None)
        if rank_score is not None:
            pairs = _build_rank_pairs(rank_score, top_k=self._RANK_TOP_K)
            if pairs:
                pair_idx = torch.tensor(pairs, dtype=torch.long, device=probs.device)  # [P, 2]
                gaps = probs[pair_idx[:, 0]] - probs[pair_idx[:, 1]]
                loss_rank = F.relu(self._RANK_MARGIN - gaps).mean()

        # Total variation (TV) smoothing across confidence histogram bins.
        # A single bin has no neighbours, so the term is skipped in that case.
        loss_tv = zero
        if cfg.lambda_tv > 0 and cfg.num_bins > 1:
            bin_means = self._bin_means(probs)
            loss_tv = (bin_means[1:] - bin_means[:-1]).abs().mean()

        # Align loss to per-bin target accuracies (UTDC-lite)
        loss_align = zero
        if cfg.target_bin_acc is not None:
            if len(cfg.target_bin_acc) != cfg.num_bins:
                raise ValueError(
                    f"target_bin_acc has {len(cfg.target_bin_acc)} entries but num_bins is {cfg.num_bins}."
                )
            # Use the dataset's fixed binning when given, else bin by current probs.
            pred_means = self._bin_means(probs, inputs.get("bin_idx", None))
            tb = torch.tensor(cfg.target_bin_acc, device=pred_means.device, dtype=pred_means.dtype)
            # KL-like or L2 alignment; we use L2 for stability
            loss_align = F.mse_loss(pred_means, tb)

        loss = (
            loss_main
            + cfg.lambda_mean * loss_mean
            + cfg.lambda_rank * loss_rank
            + cfg.lambda_tv * loss_tv
            + cfg.lambda_align * loss_align
        )

        if return_outputs:
            outputs["conf_logits"] = conf_logits  # ensure exposed
            return loss, outputs
        return loss

    @staticmethod
    def _soft_ece(p, y, M=10):
        """Expected calibration error of ``p`` against soft labels ``y`` over ``M`` equal-width bins."""
        import numpy as np

        bins = np.linspace(0, 1, M + 1)
        ece = 0.0
        for m in range(M):
            # The last bin is closed on the right so that p == 1.0 is counted.
            upper = (p < bins[m + 1]) if m < M - 1 else (p <= bins[m + 1])
            mask = (p >= bins[m]) & upper
            if mask.any():
                ece += mask.mean() * abs(p[mask].mean() - y[mask].mean())
        return float(ece)

    # Optional: lightweight evaluate that logs reliability with pseudo labels if present
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Run the model over ``dataloader`` and report confidence statistics.

        Metrics (all prefixed with ``metric_key_prefix``): ``prob_mean``,
        ``prob_std`` and, when every batch carries "targets", ``soft_ece``.
        ``description``, ``prediction_loss_only`` and ``ignore_keys`` are
        accepted for interface compatibility and not used.

        Returns:
            EvalLoopOutput with the flattened probabilities as ``predictions``,
            the soft targets (if complete) as ``label_ids``, the metrics and
            the number of scored examples.

        Results are computed from the local dataloader only; nothing is
        gathered across processes.
        """
        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
        model.eval()
        head = self.calib_head
        head_was_training = head is not None and head.training
        if head is not None:
            head.eval()
        preds = []
        targs = []
        try:
            with torch.no_grad():
                for batch in dataloader:
                    batch = self._prepare_inputs(batch)
                    outputs = model(**{k: v for k, v in batch.items() if k not in self._AUX_KEYS})
                    probs = torch.sigmoid(self._get_conf_logits(outputs))
                    # float() first: numpy cannot represent bfloat16 tensors.
                    preds.append(probs.detach().float().reshape(-1).cpu())
                    if "targets" in batch:
                        targs.append(batch["targets"].detach().float().reshape(-1).cpu())
        finally:
            # Restore the head's mode so a following training step is unaffected.
            if head_was_training:
                head.train()
        logs = {}
        p = None
        y = None
        if preds:
            p = torch.cat(preds).numpy()
            logs[f"{metric_key_prefix}_prob_mean"] = float(p.mean())
            logs[f"{metric_key_prefix}_prob_std"] = float(p.std())
            if targs:
                y_all = torch.cat(targs).numpy()
                if y_all.shape == p.shape:
                    y = y_all
                    # proxy ECE with soft labels
                    logs[f"{metric_key_prefix}_soft_ece"] = self._soft_ece(p, y, M=self.calib_args.num_bins)
                else:
                    logger.warning(
                        "Skipping soft ECE: %d targets for %d predictions.", y_all.shape[0], p.shape[0]
                    )
        # NOTE(review): Trainer.evaluate() presumably logs the returned metrics as
        # well, so these values may appear twice -- confirm this call is still wanted.
        self.log(logs)
        num_samples = 0 if p is None else int(p.shape[0])
        return EvalLoopOutput(predictions=p, label_ids=y, metrics=logs, num_samples=num_samples)
