from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm
from transformers import TrainerCallback


def _ngrams(seq, n):
    """Return the set of distinct length-``n`` windows of ``seq`` as tuples.

    A sequence shorter than ``n`` yields an empty set.
    """
    window_count = max(0, len(seq) - n + 1)
    grams = set()
    for start in range(window_count):
        grams.add(tuple(seq[start: start + n]))
    return grams


def _extraction_likelihood(
    gen_ids: List[int], gold_ids: List[int], n: int = 10
) -> float:
    """Fraction of the distinct gold ``n``-grams that also occur in the generation.

    Returns 0.0 when ``gold_ids`` has no ``n``-gram at all (i.e. it is shorter
    than ``n``).
    """
    gold_grams = _ngrams(gold_ids, n)
    if len(gold_grams) == 0:
        return 0.0
    generated_grams = _ngrams(gen_ids, n)
    # Both sides are sets, so the overlap count is the intersection size.
    shared = gold_grams & generated_grams
    return len(shared) / len(gold_grams)


def _memorization_accuracy(gen_ids: List[int], gold_ids: List[int]) -> float:
    """Position-wise token match rate between generation and gold.

    Only the overlapping prefix (the shorter of the two lengths) is compared,
    and that length is the denominator. Returns 0.0 if either list is empty.
    """
    overlap = min(len(gold_ids), len(gen_ids))
    if not overlap:
        return 0.0
    hits = 0
    # zip stops at the shorter sequence, i.e. after exactly `overlap` pairs.
    for produced, expected in zip(gen_ids, gold_ids):
        hits += int(produced == expected)
    return hits / overlap


def _unwrap_tdec_record(ex: dict) -> dict:
    """Return the TDEC row to evaluate.

    A paired record ``{'forget': {...}, 'retain': {...}}`` is reduced to its
    ``'forget'`` dict; anything else is returned unchanged.
    """
    if not isinstance(ex, dict) or "forget" not in ex:
        return ex
    nested = ex["forget"]
    if isinstance(nested, dict):
        return nested
    return ex


def _get_ids(tok, ex: dict, key_ids: str, key_text: str) -> List[int]:
    """Return the token ids for one field of a record.

    Pre-tokenized ids stored under ``key_ids`` are used when that entry is a
    non-empty list. Otherwise the text under ``key_text`` is tokenized by
    calling ``tok(text, add_special_tokens=False)`` and reading ``.input_ids``
    from the result.

    A text that is missing *or* explicitly ``None`` is treated as the empty
    string, so a row with a null text column does not crash the tokenizer
    call.
    """
    ids = ex.get(key_ids)
    if isinstance(ids, list) and ids:
        return [int(x) for x in ids]
    txt = ex.get(key_text)
    if txt is None:
        # `.get(key, "")` alone would still hand a stored None to the tokenizer.
        txt = ""
    return tok(txt, add_special_tokens=False).input_ids


@torch.no_grad()
def _greedy_new_tokens(model, input_ids, max_new_tokens, pad_id, eos_id):
    """Greedily decode from ``input_ids`` and return only the new tokens.

    Args:
        model: object exposing ``generate(...)`` and ``parameters()``.
        input_ids: prompt tensor; indexed as ``(batch, prompt_len)``.
        max_new_tokens: decoding budget; ``<= 0`` returns an empty
            ``(batch, 0)`` tensor without calling the model.
        pad_id: forwarded to ``generate`` as ``pad_token_id``.
        eos_id: forwarded as ``eos_token_id``; also used to seed an empty
            prompt.

    Returns:
        ``generate``'s output with the first ``prompt_len`` columns removed.
    """
    if max_new_tokens <= 0:
        return input_ids.new_zeros((input_ids.shape[0], 0))
    if input_ids.numel() == 0:
        # No prompt tokens: seed every row with EOS so generate() has something
        # to condition on. Keep the batch dimension (at least one row) so this
        # path agrees with the zero-budget path above.
        rows = max(int(input_ids.shape[0]), 1)
        input_ids = torch.full(
            (rows, 1),
            eos_id,
            dtype=torch.long,
            device=next(model.parameters()).device,
        )
    extra = {}
    if input_ids.dim() == 2 and input_ids.shape[0] == 1:
        # A single row cannot contain padding, so every position is real.
        # Passing the mask explicitly stops generate() from guessing it from
        # pad_id (impossible when pad_id == eos_id, and wrong when the prompt
        # legitimately contains the pad id).
        extra["attention_mask"] = torch.ones_like(input_ids)
    out = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=True,
        pad_token_id=pad_id,
        eos_token_id=eos_id,
        **extra,
    )
    # NOTE(review): assumes generate() echoes the prompt before the new tokens
    # (decoder-only behavior) — confirm for the model in use.
    return out[:, input_ids.shape[1]:]


# ---------- callback (FAST EL10/MA at the suffix boundary) ----------


class TDECSuccessStopCallback(TrainerCallback):
    """
    Early-stopping callback driven by memorization metrics.

    FAST version: computes EL10/MA using only the main prefix→suffix boundary,
    i.e. one greedy generation per example, prompted with the prefix and
    compared against the gold suffix.
    Early-stops when EL10(D_f) < EL10(D_val) AND MA(D_f) < MA(D_val).
    Accepts forget/val datasets that are either plain TDEC rows or paired rows.
    """

    def __init__(
        self,
        tokenizer,
        forget_ds,
        val_ds,
        n_gram: int = 10,
        max_val_examples: Optional[int] = None,
    ):
        """
        Args:
            tokenizer: tokenizer used to encode text fields. NOTE: if it has no
                pad token, its ``pad_token`` is set to its ``eos_token`` in
                place (the caller's object is modified).
            forget_ds: iterable of records for the forget split.
            val_ds: iterable of records for the validation split.
            n_gram: n-gram size used for the extraction-likelihood metric.
            max_val_examples: if given, only the first
                ``min(max_val_examples, len(val_ds))`` validation rows are
                used; requires ``val_ds`` to support ``len()`` and
                ``.select(range)`` (presumably a Hugging Face ``Dataset`` —
                confirm against caller).
        """
        self.tok = tokenizer
        if self.tok.pad_token_id is None:
            self.tok.pad_token = self.tok.eos_token
        self.pad_id = self.tok.pad_token_id
        self.eos_id = self.tok.eos_token_id
        self.forget_ds = forget_ds
        self.val_ds = (
            val_ds
            if max_val_examples is None
            else val_ds.select(range(min(max_val_examples, len(val_ds))))
        )
        self.n = n_gram

    def _eval_split_fast(self, model, ds, desc="eval") -> Tuple[float, float]:
        """Return ``(mean EL_n, mean MA)`` over the scorable examples of ``ds``.

        Examples without any gold suffix tokens are skipped: there is nothing
        to compare a generation against, and scoring them as 0 would drag the
        split average down. Returns ``(0.0, 0.0)`` if nothing was scored.
        """
        el_scores, ma_scores = [], []
        device = next(model.parameters()).device
        for raw in tqdm(ds, desc=desc):
            ex = _unwrap_tdec_record(raw)
            pref = _get_ids(self.tok, ex, "prefix_ids", "prefix_text")
            gold = _get_ids(self.tok, ex, "suffix_ids", "suffix_text")

            if not gold:
                continue

            # An empty prefix is replaced by a lone EOS token so the model
            # still has a token to condition on.
            prompt = pref if pref else [self.eos_id]
            inp = torch.tensor([prompt], dtype=torch.long, device=device)
            gen = _greedy_new_tokens(
                model,
                inp,
                max_new_tokens=len(gold),
                pad_id=self.pad_id,
                eos_id=self.eos_id,
            )[0].tolist()

            el_scores.append(_extraction_likelihood(gen, gold, n=self.n))
            ma_scores.append(_memorization_accuracy(gen, gold))

        if not el_scores:
            return 0.0, 0.0
        return float(np.mean(el_scores)), float(np.mean(ma_scores))

    def _report(self, state, trainer, metrics) -> None:
        """Publish ``metrics`` through ``trainer.log`` or, failing that, print them."""
        if trainer is not None and hasattr(trainer, "log"):
            trainer.log(metrics)
            return
        # No trainer handle in the callback kwargs: without this fallback the
        # metrics would be dropped silently. Print once (main process only)
        # via tqdm.write so active progress bars are not garbled.
        if getattr(state, "is_world_process_zero", True):
            rendered = " ".join(f"{key}={value:.4f}" for key, value in metrics.items())
            epoch = getattr(state, "epoch", None)
            tqdm.write(f"[TDECSuccessStopCallback] epoch={epoch} {rendered}")

    def on_epoch_end(self, args, state, control, **kwargs):
        """Evaluate both splits and request a stop once the forget split
        scores strictly below the validation split on both metrics."""
        model = kwargs["model"]
        was_training = model.training
        model.eval()
        try:
            ELf, MAf = self._eval_split_fast(
                model, self.forget_ds, desc="forget split"
            )
            ELv, MAv = self._eval_split_fast(
                model, self.val_ds, desc="validation split"
            )
        finally:
            # Hand the model back in the mode it arrived in.
            model.train(was_training)

        # Key names keep the "EL10" label even when n_gram != 10, so existing
        # consumers of these metric names keep working.
        self._report(
            state,
            kwargs.get("trainer", None),
            {
                "EL10_forget": ELf,
                "MA_forget": MAf,
                "EL10_val": ELv,
                "MA_val": MAv,
            },
        )

        if (ELf < ELv) and (MAf < MAv):
            control.should_training_stop = True
        return control
