import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm
from transformers import TrainerCallback


def _ngrams(tok_ids: List[int], n: int):
    """Return the set of distinct contiguous ``n``-grams of ``tok_ids`` as tuples.

    The result is empty when ``tok_ids`` has fewer than ``n`` elements.
    """
    last_start = len(tok_ids) - n
    if last_start < 0:
        return set()
    grams = set()
    for start in range(last_start + 1):
        grams.add(tuple(tok_ids[start:start + n]))
    return grams


def _unwrap_tdec_record(ex: dict) -> dict:
    """Return the forget-side row of a paired record, otherwise ``ex`` unchanged.

    A record is treated as paired when it is a dict containing a ``"forget"``
    key whose value is itself a dict, e.g. ``{'forget': {...}, 'retain': {...}}``.
    """
    if not isinstance(ex, dict) or "forget" not in ex:
        return ex
    inner = ex["forget"]
    return inner if isinstance(inner, dict) else ex


def _get_ids(tok, ex: dict, key_ids: str, key_text: str) -> List[int]:
    """Return token ids for one field of ``ex``.

    If ``ex[key_ids]`` is a non-empty list, its elements are returned cast to
    ``int``. Otherwise ``ex.get(key_text, "")`` is tokenized with ``tok``
    (``add_special_tokens=False``) and its ``input_ids`` are returned.
    """
    stored = ex[key_ids] if key_ids in ex else None
    if isinstance(stored, list) and stored:
        return list(map(int, stored))
    text = ex.get(key_text, "")
    encoded = tok(text, add_special_tokens=False)
    return encoded.input_ids


@torch.no_grad()
def _greedy_continuation_ids(
    model,
    input_ids: List[int],
    gen_len: int,
    pad_id: int,
    eos_id: int,
    device: torch.device,
):
    """
    Greedily decode up to ``gen_len`` new tokens after ``input_ids``.

    Args:
        model: Object exposing a Hugging Face style ``generate`` method.
        input_ids: Prompt token ids (a single sequence). If empty, a single
            ``eos_id`` token is used as the prompt instead.
        gen_len: Maximum number of new tokens; ``<= 0`` returns ``[]``.
        pad_id: Forwarded to ``generate`` as ``pad_token_id``.
        eos_id: Forwarded to ``generate`` as ``eos_token_id``; also the
            fallback prompt token.
        device: Device the prompt tensor is created on.

    Returns:
        List of newly generated token ids, with the prompt tokens removed.
        It may be shorter than ``gen_len`` if generation stops at ``eos_id``.
    """
    if gen_len <= 0:
        return []
    # fall back to eos if input is empty (extra safety)
    if len(input_ids) == 0:
        input_ids = [eos_id]
    inp = torch.tensor([input_ids], dtype=torch.long, device=device)
    # A single unpadded row: every position is a real token. Passing the mask
    # explicitly stops generate() from inferring it from pad_token_id. That
    # inference is impossible (and warns) when pad_id == eos_id, and would
    # wrongly mask genuine occurrences of pad_id in the prompt otherwise.
    attention_mask = torch.ones_like(inp)
    out = model.generate(
        input_ids=inp,
        attention_mask=attention_mask,
        max_new_tokens=gen_len,
        do_sample=False,
        use_cache=True,
        pad_token_id=pad_id,
        eos_token_id=eos_id,
    )[0].tolist()
    return out[len(input_ids):]  # only new tokens


@torch.no_grad()
def _ma_teacher_forced(
    model, prefix_ids: List[int], suffix_ids: List[int], device: torch.device
) -> float:
    """
    Memorization Accuracy (Eq. 6) using a single forward pass.

    MA is the mean, over scorable suffix positions, of whether the argmax
    prediction equals the gold token. The prediction for suffix token ``j`` is
    taken from the logits one position before it in ``prefix_ids + suffix_ids``.

    With an empty ``prefix_ids`` the first suffix token has no preceding
    context, so it is excluded and the mean covers the remaining positions.

    Returns:
        The accuracy as a float in [0, 1]. It is 0.0 when no position can be
        scored: an empty suffix, or an empty prefix with a one-token suffix.
    """
    if len(suffix_ids) == 0:
        return 0.0
    inp_ids = prefix_ids + suffix_ids[:-1]
    if len(inp_ids) == 0:
        return 0.0
    # Logits at index i predict token i + 1 of prefix + suffix. Suffix token j
    # sits at index len(prefix) + j, so it is predicted by logits index
    # len(prefix) + j - 1. For j == 0 that index exists only if the prefix is
    # non-empty.
    first = 0 if len(prefix_ids) > 0 else 1
    inp = torch.tensor([inp_ids], dtype=torch.long, device=device)
    gold = torch.tensor(suffix_ids[first:], dtype=torch.long, device=device)  # [T']
    logits = model(input_ids=inp).logits.squeeze(0)  # [L, V]
    # gold.numel() >= 1 here, so the negative slice never degenerates to [-0:].
    pred = logits.argmax(dim=-1)[-gold.numel():]  # [T']
    return (pred == gold).float().mean().item()


@torch.no_grad()
def _el_n_exact_for_example(
    model,
    prefix_ids: List[int],
    suffix_ids: List[int],
    n: int,
    pad_id: int,
    eos_id: int,
    device: torch.device,
) -> float:
    """
    Extraction Likelihood EL_n for one (prefix, suffix) example.

    For each offset ``t0`` in ``0 .. len(suffix_ids) - n - 1``:
      1. Greedily continue ``prefix_ids + suffix_ids[:t0]`` for
         ``len(suffix_ids) - t0`` tokens.
      2. Score the offset as the fraction of distinct gold n-grams of
         ``suffix_ids[t0:]`` that also appear in the continuation.

    EL_n is the mean of these per-offset scores. It is 0.0 when the suffix has
    ``n`` or fewer tokens.
    """
    num_offsets = len(suffix_ids) - n
    if num_offsets <= 0:
        return 0.0

    scores = []
    # The paper's t runs over 1..T-n (1-based); t0 here covers 0..T-n-1.
    for t0 in range(num_offsets):
        target = suffix_ids[t0:]
        continuation = _greedy_continuation_ids(
            model,
            prefix_ids + suffix_ids[:t0],
            gen_len=len(target),
            pad_id=pad_id,
            eos_id=eos_id,
            device=device,
        )
        reference = _ngrams(target, n)
        if not reference:
            scores.append(0.0)
        else:
            shared = reference & _ngrams(continuation, n)
            scores.append(len(shared) / len(reference))

    if not scores:
        return 0.0
    return float(np.mean(scores))


def eval_split_exact_ELn_MA(model, tok, split_ds, n: int = 10) -> Tuple[float, float]:
    """
    Average EL_n and MA over a dataset split.

    Returns ``(mean EL_n, mean MA)``:
      - EL_n is computed exactly, averaging over all suffix offsets t.
      - MA is teacher-forced next-token accuracy on the suffix.

    Rows may be plain TDEC rows or paired ``{'forget': {...}}`` rows. Rows
    whose prefix and suffix are both empty are skipped. If the split yields
    no scored rows, both averages are 0.0.

    Side effect: when ``tok`` has no pad token, its pad token is set to its
    EOS token.
    """
    device = next(model.parameters()).device
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    pad_id, eos_id = tok.pad_token_id, tok.eos_token_id

    el_scores: List[float] = []
    ma_scores: List[float] = []
    for record in tqdm(split_ds, desc=f"EL{n}/MA eval"):
        row = _unwrap_tdec_record(record)
        prefix = _get_ids(tok, row, "prefix_ids", "prefix_text")
        suffix = _get_ids(tok, row, "suffix_ids", "suffix_text")
        if not prefix and not suffix:
            continue  # nothing to score

        el_scores.append(
            _el_n_exact_for_example(
                model, prefix, suffix, n=n, pad_id=pad_id, eos_id=eos_id, device=device
            )
        )
        ma_scores.append(_ma_teacher_forced(model, prefix, suffix, device=device))

    def _avg(values: List[float]) -> float:
        return float(np.mean(values)) if values else 0.0

    return _avg(el_scores), _avg(ma_scores)


def _extraction_likelihood(
    gen_ids: List[int], gold_ids: List[int], n: int = 10
) -> float:
    """Return the fraction of distinct gold ``n``-grams that also occur in ``gen_ids``.

    The result is 0.0 when ``gold_ids`` yields no n-grams, i.e. it is shorter
    than ``n``.
    """
    reference = _ngrams(gold_ids, n)
    if not reference:
        return 0.0
    shared = reference & _ngrams(gen_ids, n)
    return len(shared) / len(reference)


def _memorization_accuracy(gen_ids: List[int], gold_ids: List[int]) -> float:
    """Return position-wise token agreement between ``gen_ids`` and ``gold_ids``.

    Only the common prefix length of the two sequences is compared. The result
    is the fraction of equal positions, or 0.0 if either sequence is empty.
    """
    pairs = list(zip(gen_ids, gold_ids))
    if not pairs:
        return 0.0
    matches = sum(int(g == t) for g, t in pairs)
    return matches / len(pairs)


@torch.no_grad()
def _greedy_new_tokens(model, input_ids, max_new_tokens, pad_id, eos_id):
    """Greedy-decode ``input_ids`` and return only the generated columns.

    ``input_ids`` must have a second dimension; it is presumably a
    ``[batch, prompt_len]`` LongTensor (TODO confirm at call sites). The first
    ``prompt_len`` columns of ``model.generate``'s output are dropped.
    """
    prompt_len = input_ids.shape[1]
    generated = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=True,
        pad_token_id=pad_id,
        eos_token_id=eos_id,
    )
    return generated[:, prompt_len:]


class TDECSuccessStopCallback(TrainerCallback):
    """
    Stop unlearning once the forget split scores below a validation split.

    At the end of every epoch the callback does three things:
      1. Computes EL_n and MA on ``forget_ds`` and on ``val_ds`` via
         ``eval_split_exact_ELn_MA``.
      2. Reports the four numbers.
      3. Sets ``control.should_training_stop`` when both forget-split metrics
         are strictly lower than the validation-split ones.

    Args:
        tokenizer: Tokenizer used to encode text fields. Its pad token is set
            to its EOS token when missing.
        forget_ds: Rows to forget (plain TDEC rows or ``{'forget': {...}}``).
        val_ds: Validation rows. Must support ``select`` and ``len`` when
            ``max_val_examples`` is given.
        n_gram: n used for EL_n. Metric keys are named ``EL{n_gram}_*``.
        max_val_examples: Optional cap on the number of validation rows.
        trainer: Optional object with a ``log(dict)`` method used to report
            metrics. NOTE: Transformers' Trainer does not appear to pass
            itself in callback kwargs, so without this argument the metrics
            are appended to ``state.log_history`` instead.
    """

    def __init__(
        self,
        tokenizer,
        forget_ds,
        val_ds,
        n_gram: int = 10,
        max_val_examples: Optional[int] = None,
        trainer=None,
    ):
        self.tok = tokenizer
        if self.tok.pad_token_id is None:
            self.tok.pad_token = self.tok.eos_token
        self.pad_id = self.tok.pad_token_id
        self.eos_id = self.tok.eos_token_id
        self.forget_ds = forget_ds
        self.val_ds = (
            val_ds
            if max_val_examples is None
            else val_ds.select(range(min(max_val_examples, len(val_ds))))
        )
        self.n = n_gram
        self.trainer = trainer

    def _log(self, state, metrics: Dict[str, float], callback_kwargs: dict) -> None:
        """Report ``metrics`` through a known trainer, else into ``state.log_history``."""
        trainer = self.trainer
        if trainer is None:
            trainer = callback_kwargs.get("trainer", None)
        if trainer is not None and hasattr(trainer, "log"):
            trainer.log(metrics)
            return
        # Fallback so the numbers are not silently dropped.
        history = getattr(state, "log_history", None)
        if isinstance(history, list):
            history.append(
                {
                    **metrics,
                    "epoch": getattr(state, "epoch", None),
                    "step": getattr(state, "global_step", None),
                }
            )

    def on_epoch_end(self, args, state, control, **kwargs):
        """Evaluate both splits, log the metrics, and request a stop on success."""
        model = kwargs["model"]
        was_training = getattr(model, "training", False)
        model.eval()
        try:
            ELf, MAf = eval_split_exact_ELn_MA(model, self.tok, self.forget_ds, n=self.n)
            ELv, MAv = eval_split_exact_ELn_MA(model, self.tok, self.val_ds, n=self.n)
        finally:
            # Put the model back into the mode it was in before evaluation.
            if was_training:
                model.train()

        self._log(
            state,
            {
                f"EL{self.n}_forget": ELf,
                "MA_forget": MAf,
                f"EL{self.n}_val": ELv,
                "MA_val": MAv,
            },
            kwargs,
        )

        # Success criterion: the forget split looks "less memorized" than the
        # validation split on both metrics.
        if (ELf < ELv) and (MAf < MAv):
            control.should_training_stop = True
        return control
