import argparse
import csv
import json
import math
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import evaluate
import torch
import torch.nn.functional as F
from huggingface_hub import HfApi, login
from rouge_score import rouge_scorer
from tqdm import tqdm

from egu.dataset.qa_dataset import QACollatorDynamicPad, TextDatasetQA
from egu.evaluators.utils import load_model_and_tokenizer
from egu.utils.metrics import get_forget_quality, get_model_utility, get_matching_model_utility
from egu.utils.utils import load_yaml


def parse_adapter_id(adapter_id: str) -> Dict[str, str]:
    """Split an adapter repo id into its hyper-parameter fields.

    Expected repo name (text after the last "/"):
    ``msc_unlearn_lora_{rank}_{alpha}_{lr}_{method}_{dataset}_{dataset_subset}``.
    Fields are taken positionally from underscore-separated parts 3..8, without
    validating the prefix. On any failure (too few parts, non-string input) a
    placeholder dict is returned instead.

    NOTE(review): a second ``parse_adapter_id`` is defined later in this
    module; the later ``def`` rebinds the name at import time, so this version
    is never the one called. Consider deleting it.
    """
    keys = ("rank", "alpha", "lr", "method", "dataset", "dataset_subset")
    try:
        segments = adapter_id.split("/")[-1].split("_")
        values = [segments[pos] for pos in range(3, 9)]
    except Exception:
        return {
            "rank": "r",
            "alpha": "a",
            "lr": "lr",
            "method": "method",
            "dataset": "dataset",
            "dataset_subset": "subset",
        }
    return dict(zip(keys, values))


def eval_accuracy_any(logits, labels, ignore_index: int = -100):
    """Token-level next-token accuracy over one batch or a list of batches.

    Delegates the counting to ``_accuracy_counts`` and returns
    ``{"eval accuracy": correct / scored}``; the value is NaN when no position
    is scored (every shifted label equals ``ignore_index``).
    """
    hits, total = _accuracy_counts(logits, labels, ignore_index)
    if not total:
        return {"eval accuracy": float("nan")}
    return {"eval accuracy": hits / total}


def _accuracy_counts(logits, labels, ignore_index: int):
    """Count correct next-token predictions and the number of scored positions.

    ``logits`` is a (B, T, V) or (T, V) tensor (or array-like), or a list/tuple
    of such items paired element-wise with ``labels``; counts are summed over
    the pairs. The prediction at position t is compared with the label at t+1,
    and positions whose label equals ``ignore_index`` are skipped.

    Returns:
        ``(correct, scored)`` as Python ints; ``(0, 0)`` if nothing is scored.

    Raises:
        ValueError: if a logits tensor is neither 2-D nor 3-D.
    """
    if isinstance(logits, (list, tuple)):
        per_pair = [
            _accuracy_counts(lg, lb, ignore_index) for lg, lb in zip(logits, labels)
        ]
        return sum(c for c, _ in per_pair), sum(n for _, n in per_pair)

    logits = logits if torch.is_tensor(logits) else torch.as_tensor(logits)
    labels = labels if torch.is_tensor(labels) else torch.as_tensor(labels)
    if labels.dtype != torch.long:
        labels = labels.long()

    ndim = logits.dim()
    if ndim not in (2, 3):
        raise ValueError(f"Unexpected logits.dim()={logits.dim()}")
    predicted = logits.argmax(dim=-1)
    if ndim == 3:
        pred_next, label_next = predicted[..., :-1], labels[..., 1:]
    else:
        pred_next, label_next = predicted[:-1], labels[1:]

    valid = label_next != ignore_index
    if not valid.any():
        return 0, 0
    hits = (pred_next == label_next) & valid
    return hits.sum().item(), valid.sum().item()


def eval_perplexity_any(logits, labels, ignore_index: int = -100):
    """Perplexity of next-token predictions over one batch or a list of batches.

    Uses ``_nll_sum_and_count`` to get the summed negative log-likelihood and
    the number of scored tokens.

    Returns:
        ``{"perplexity", "avg_nll", "num_tokens"}``. With no scored tokens the
        perplexity is ``inf`` and ``avg_nll`` is NaN. If ``exp(avg_nll)``
        overflows a float (average NLL above ~709.78, e.g. a diverged model),
        the perplexity is reported as ``inf`` instead of raising.
    """
    nll_sum, ntok = _nll_sum_and_count(logits, labels, ignore_index)
    if ntok == 0:
        return {"perplexity": float("inf"), "avg_nll": float("nan"), "num_tokens": 0}
    avg_nll = nll_sum / ntok
    try:
        perplexity = math.exp(avg_nll)
    except OverflowError:
        perplexity = float("inf")
    return {"perplexity": perplexity, "avg_nll": avg_nll, "num_tokens": ntok}


def _nll_sum_and_count(logits, labels, ignore_index: int):
    """Summed next-token negative log-likelihood and number of scored tokens.

    ``logits`` is either (B, T, V) with labels holding B*T entries (reshaped to
    (B, T)), or (T, V) for a single sequence with T labels. A list/tuple of
    such items paired with ``labels`` is summed item by item.

    Position t of ``logits`` predicts the label at t+1. For 3-D input the shift
    is done within each sequence, so the last position of row b is never
    scored against the first label of row b+1. Labels equal to
    ``ignore_index`` are skipped. The cross-entropy is accumulated in float32
    so that fp16/bf16 logits do not lose precision in the sum.

    Returns:
        ``(nll_sum, num_tokens)``; ``(0.0, 0)`` when nothing is scored.

    Raises:
        ValueError: if a logits tensor is neither 2-D nor 3-D.
    """
    if isinstance(logits, (list, tuple)):
        total_nll, total_tok = 0.0, 0
        for lg, lb in zip(logits, labels):
            nll, nt = _nll_sum_and_count(lg, lb, ignore_index)
            total_nll += nll
            total_tok += nt
        return total_nll, total_tok
    if not torch.is_tensor(logits):
        logits = torch.as_tensor(logits)
    if not torch.is_tensor(labels):
        labels = torch.as_tensor(labels)
    if labels.dtype != torch.long:
        labels = labels.long()

    if logits.dim() == 3:
        vocab = logits.shape[-1]
        # Shift per row, then flatten.
        shift_logits = logits[:, :-1, :].reshape(-1, vocab)
        shift_labels = labels.reshape(logits.shape[:2])[:, 1:].reshape(-1)
    elif logits.dim() == 2:
        shift_logits = logits[:-1]
        shift_labels = labels.reshape(-1)[1:]
    else:
        raise ValueError(f"Unexpected logits.dim()={logits.dim()}")

    mask = shift_labels != ignore_index
    ntok = int(mask.sum().item())
    if ntok == 0:
        return 0.0, 0
    nll = F.cross_entropy(
        shift_logits[mask].float(), shift_labels[mask], reduction="sum"
    )
    return float(nll.item()), ntok


def run_generation(cfg, batch, model, tokenizer, formatting_tokens):
    """Greedy-decode an answer for every question in ``batch``.

    Args:
        cfg: evaluation config; ``cfg["generation"]["max_new_tokens"]`` bounds
            the number of generated tokens.
        batch: collated batch; ``batch["questions"]`` supplies the prompts and
            ``batch["answers"]`` is returned unchanged as the references.
        model: causal LM exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        formatting_tokens: accepted for interface compatibility but unused;
            the prompt is simply ``question + " "``.

    Returns:
        ``(predictions, references)``: the decoded newly generated text for
        each question (special tokens skipped) and ``batch["answers"]``.

    The tokenizer is switched to left padding so all prompts end at the same
    column and the new tokens can be sliced off uniformly. Its previous
    ``padding_side`` is restored even if tokenization or generation raises.
    """
    prompts = [f"{q} " for q in batch["questions"]]
    gts = batch["answers"]
    old_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        enc = tokenizer(
            prompts,
            add_special_tokens=True,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(model.device)
        out = model.generate(
            enc.input_ids,
            attention_mask=enc.attention_mask.to(torch.long),
            max_new_tokens=cfg["generation"]["max_new_tokens"],
            do_sample=False,
            use_cache=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    finally:
        tokenizer.padding_side = old_side

    new_tokens = out[:, enc.input_ids.shape[-1]:]
    if new_tokens.size(1) == 0:
        return [""] * new_tokens.size(0), gts
    preds = [
        tokenizer.decode(seq, skip_special_tokens=True)
        for seq in new_tokens.detach().cpu().tolist()
    ]
    return preds, gts


def eval_bleu(gen_outputs, ground_truths):
    """Corpus-level ROUGE and BLEU computed with Hugging Face ``evaluate``.

    Both metric modules are loaded on every call (first ROUGE, then BLEU) and
    applied to the same predictions/references.

    Returns:
        ``{"rouge": <rouge.compute(...) result>, "bleu": <bleu.compute(...) result>}``.
    """
    loaded = {name: evaluate.load(name) for name in ("rouge", "bleu")}
    return {
        name: metric.compute(predictions=gen_outputs, references=ground_truths)
        for name, metric in loaded.items()
    }


def eval_rouge_recall(gen_outputs, ground_truths):
    """Per-item ROUGE-1 and ROUGE-L recall (with stemming) of each generation.

    Each generation is scored against its reference, the reference being the
    scorer's target.

    Returns:
        ``{"rouge1_recall": [...], "rougeL_recall": [...]}``, one value per pair.
    """
    scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
    scored = [
        scorer.score(reference, prediction)
        for prediction, reference in zip(gen_outputs, ground_truths)
    ]
    return {
        "rouge1_recall": [s["rouge1"].recall for s in scored],
        "rougeL_recall": [s["rougeL"].recall for s in scored],
    }


# Names of the four TOFU evaluation log files, keyed by evaluation set.
# run_eval writes them with a "{method}-" prefix; the baseline directory read
# by _load_four_logs is expected to contain the bare names.
TOFU_FILENAMES = {
    "retain": "eval_log.json",
    "forget": "eval_log_forget.json",
    "real_authors": "eval_real_author_wo_options.json",
    "world_facts": "eval_real_world_wo_options.json",
}

# Dataset split names for the retain, real-authors and world-facts evaluations.
RETAIN_EVAL_SPLIT = "retain_perturbed"
REAL_AUTHORS_SPLIT = "real_authors"
WORLD_FACTS_SPLIT = "world_facts"

# Forget subset -> (forget split to evaluate, name of the matching retain
# baseline folder under --retain_root used for Forget Quality).
PAIRING = {
    "forget10": ("forget10", "retain90"),
    "forget05": ("forget05", "retain95"),
    "forget01": ("forget01", "retain99"),
}

# Split name -> perturbed-answer file (JSON array or JSONL) inside --perturbed_dir.
PERT_JSON = {
    "retain_perturbed": "retain_perturbed.json",
    "real_authors": "real_authors_perturbed.json",
    "world_facts": "world_facts_perturbed.json",
    "forget10": "forget10_perturbed.json",
    "forget05": "forget05_perturbed.json",
    "forget01": "forget01_perturbed.json",
}


def _sum_nll_and_count_from_logits(
    logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100
):
    """Summed next-token NLL and scored-token count for a single sequence.

    Position t of ``logits`` is the prediction for token t+1, so logits are
    shifted left by one relative to the labels before scoring. This is the
    same convention as the other NLL helpers in this module. Positions whose
    label equals ``ignore_index`` contribute nothing. Log-probabilities are
    computed in float32.

    Args:
        logits: tensor of shape (1, T, V).
        labels: tensor of shape (1, T).
        ignore_index: label value that marks unscored positions.

    Returns:
        ``(nll_sum, num_scored_tokens)``.

    Raises:
        ValueError: if the shapes are not (1, T, V) / (1, T).
    """
    if logits.dim() != 3 or labels.dim() != 2 or logits.shape[0] != 1:
        raise ValueError(
            "expected logits of shape (1, T, V) and labels of shape (1, T), got "
            f"{tuple(logits.shape)} and {tuple(labels.shape)}"
        )
    with torch.no_grad():
        logp = torch.log_softmax(logits[0, :-1].float(), dim=-1)
        y = labels[0, 1:]
        mask = y.ne(ignore_index)
        rows = torch.arange(y.shape[0], device=logp.device)
        picked = logp[rows, y.clamp_min(0).to(logp.device)]
        picked = torch.where(mask.to(logp.device), picked, torch.zeros_like(picked))
        return float(-(picked.sum().item())), int(mask.sum().item())


def _nll_for_answer(model, tokenizer, question: str, answer: str) -> tuple[float, int]:
    """Score one answer given its question (unbatched).

    The prompt is ``question + " "``. The full text ``prompt + answer`` is fed
    to the model. Label positions covered by the prompt-only tokenization are
    set to -100, and the rest are scored by ``_sum_nll_and_count_from_logits``.

    Returns:
        ``(nll_sum, num_scored_tokens)``.
    """
    target_device = next(model.parameters()).device
    prompt_text = f"{question} "
    prompt_ids = (
        tokenizer(prompt_text, add_special_tokens=True, return_tensors="pt")
        .to(target_device)
        .input_ids
    )
    full = tokenizer(
        prompt_text + answer, add_special_tokens=True, return_tensors="pt"
    ).to(target_device)
    masked_labels = full.input_ids.clone()
    masked_labels[:, : prompt_ids.shape[1]] = -100
    with torch.no_grad():
        result = model(
            input_ids=full.input_ids,
            attention_mask=full.attention_mask,
            labels=masked_labels,
        )
        return _sum_nll_and_count_from_logits(result.logits, masked_labels)


def _batch_generate(model, tokenizer, qs: List[str], max_new: int) -> List[str]:
    """Greedily generate up to ``max_new`` tokens for each question.

    Each prompt is ``question + " "``. Prompts are left-padded so they all end
    at the same column, which lets the new tokens be sliced off uniformly.

    Returns:
        The decoded newly generated text per question (special tokens
        skipped); an empty list for empty ``qs``.

    The tokenizer's previous ``padding_side`` is restored even if tokenization
    or generation raises.
    """
    if not qs:
        return []
    device = next(model.parameters()).device
    prompts = [q + " " for q in qs]
    old = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        enc = tokenizer(
            prompts,
            add_special_tokens=True,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            out = model.generate(
                enc.input_ids,
                attention_mask=enc.attention_mask.to(torch.long),
                max_new_tokens=max_new,
                do_sample=False,
                use_cache=True,
                pad_token_id=tokenizer.pad_token_id,
            )
    finally:
        tokenizer.padding_side = old

    new = out[:, enc.input_ids.shape[-1]:]
    if new.size(1) == 0:
        return [""] * new.size(0)
    return [
        tokenizer.decode(seq, skip_special_tokens=True)
        for seq in new.detach().cpu().tolist()
    ]


def _load_perturbed_json(path: Path) -> tuple[List[str], List[List[str]]]:
    """Read paraphrased and perturbed answers from a JSON array or JSONL file.

    The format is chosen from the first non-blank character: ``[`` means one
    JSON array, anything else is parsed as one JSON object per non-blank line.
    For every record, ``paraphrased_answer`` (``None`` if absent) and
    ``perturbed_answer`` (``[]`` if absent or falsy) are collected in order.

    Returns:
        ``(paraphrased, perturbed)``; both are empty for a blank file.
    """
    with path.open("r", encoding="utf-8") as handle:
        body = handle.read().strip()
    paraphrased: List[str] = []
    perturbed: List[List[str]] = []
    if body:
        if body.startswith("["):
            records = json.loads(body)
        else:
            stripped = (ln.strip() for ln in body.splitlines())
            records = (json.loads(ln) for ln in stripped if ln)
        for rec in records:
            paraphrased.append(rec.get("paraphrased_answer", None))
            perturbed.append(rec.get("perturbed_answer", []) or [])
    return paraphrased, perturbed


def _avg(sum_list, cnt_list):
    """Element-wise ``sum / count`` as floats; NaN wherever the count is not positive.

    The output is as long as the shorter of the two inputs.
    """
    return [
        float(total) / float(count) if count > 0 else float("nan")
        for total, count in zip(sum_list, cnt_list)
    ]


def _kl_to_uniform(probs: List[float]) -> float:
    """KL divergence (nats) from the uniform distribution over ``len(probs)`` outcomes.

    Computes ``sum p * (log p - log(1/K))``. Entries that are not > 0
    (zeros, NaN) are skipped, following the ``0 * log 0 = 0`` convention.
    ``probs`` is used as given, without renormalization. Empty input gives 0.0.
    """
    count = len(probs)
    if count == 0:
        return 0.0
    log_uniform = math.log(1.0 / count)
    return float(sum(p * (math.log(p) - log_uniform) for p in probs if p > 0))


@torch.no_grad()
def _batched_avg_nll(
    model,
    tokenizer,
    prompts: List[str],
    conts: List[str],
    batch_size: int = 16,
    max_length: int = 512,
) -> Tuple[List[float], List[int]]:
    """Score ``P(cont | prompt)`` for each (prompt, cont) pair in mini-batches.

    Each pair is tokenized as ``prompt + cont`` (right-padded, truncated to
    ``max_length``). The first ``len(tokenize(prompt))`` label positions are
    masked, so only continuation tokens are scored.

    Returns:
        ``(avg_nll_per_item, tokcount_per_item)``: the mean per-token negative
        log-likelihood of each continuation and the number of scored tokens.
        The count is clamped to at least 1, so an item with no scored tokens
        (e.g. the continuation was truncated away) reports ``0.0`` and ``1``.

    NOTE(review): masking assumes the prompt's own tokenization is a prefix of
    the full tokenization. A token merge across the prompt/continuation
    boundary would shift the mask by one position.
    """
    device = next(model.parameters()).device
    model.eval()
    avg_losses: List[float] = []
    tok_counts: List[int] = []
    # Autocast to the model's own half dtype. The device default can differ
    # (fp16 on CUDA for a bf16 model). Disabled for fp32 models.
    amp_dtype = (
        model.dtype if model.dtype in (torch.float16, torch.bfloat16) else None
    )
    old_pad = tokenizer.padding_side
    tokenizer.padding_side = "right"
    try:
        for i in range(0, len(prompts), batch_size):
            ps = prompts[i : i + batch_size]
            cs = conts[i : i + batch_size]
            enc_p = tokenizer(
                ps,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(device)
            enc_full = tokenizer(
                [p + c for p, c in zip(ps, cs)],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(device)
            labels = enc_full.input_ids.clone()
            # With right padding, the attention mask counts exactly the real
            # prompt tokens. Comparing ids with pad_token_id miscounts when the
            # pad id is also a real token (e.g. pad == bos/eos).
            p_lens = enc_p.attention_mask.sum(dim=1)
            for b, pl in enumerate(p_lens.tolist()):
                labels[b, :pl] = -100

            with torch.autocast(
                device_type=device.type,
                dtype=amp_dtype,
                enabled=amp_dtype is not None,
            ):
                logits = model(
                    input_ids=enc_full.input_ids,
                    attention_mask=enc_full.attention_mask,
                ).logits

            # float32 keeps per-token log-probs and the per-row sum accurate.
            # log p(y) = logit_y - logsumexp(logits) avoids materializing a
            # full log_softmax tensor.
            shift_logits = logits[:, :-1, :].float()
            shift_labels = labels[:, 1:]
            mask = shift_labels != -100
            target_logits = shift_logits.gather(
                -1, shift_labels.clamp_min(0).unsqueeze(-1)
            ).squeeze(-1)
            picked = target_logits - torch.logsumexp(shift_logits, dim=-1)
            picked = torch.where(mask, picked, torch.zeros_like(picked))

            nll_sum = (-picked).sum(dim=1)
            tok_cnt = mask.sum(dim=1).clamp_min(1)

            avg_losses.extend((nll_sum / tok_cnt).float().cpu().tolist())
            tok_counts.extend(tok_cnt.int().cpu().tolist())
    finally:
        tokenizer.padding_side = old_pad
    return avg_losses, tok_counts


def _as_continuation(text: str) -> str:
    """Prefix ``text`` with a space (unless it already starts with whitespace) so it can follow a bare question."""
    return text if text[:1].isspace() else " " + text


def _rouge_recall_lists(
    preds: List[str], refs: List[str]
) -> Tuple[List[float], List[float]]:
    """Return per-item ROUGE-1 and ROUGE-L recall of ``preds`` against ``refs``."""
    scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
    rouge1: List[float] = []
    rougeL: List[float] = []
    for pred, ref in zip(preds, refs):
        rs = scorer.score(ref, pred)
        rouge1.append(float(rs["rouge1"].recall))
        rougeL.append(float(rs["rougeL"].recall))
    return rouge1, rougeL


def _kl_diagnostics(
    para_losses: List[float], pert_lists: List[List[float]]
) -> List[float]:
    """Per item, KL-to-uniform of the normalized {paraphrase, perturbations} likelihoods.

    Likelihoods are ``exp(-avg_nll)``. They are normalized with the smallest
    loss factored out, so large losses cannot underflow every likelihood to
    0.0 (which would divide by zero). NaN perturbation losses are dropped;
    items without perturbations get 0.0.
    """
    kl_list: List[float] = []
    for ploss, plist in zip(para_losses, pert_lists):
        finite = [x for x in plist if not math.isnan(x)]
        if not finite:
            kl_list.append(0.0)
            continue
        losses = [ploss] + finite
        base = min(losses)
        weights = [math.exp(base - x) for x in losses]
        z = sum(weights)
        kl_list.append(_kl_to_uniform([w / z for w in weights]))
    return kl_list


def _truth_ratios(
    pert_lists: List[List[float]], para_losses: List[float]
) -> List[float]:
    """Per item, ``exp(mean perturbed NLL - paraphrased NLL)``.

    Values above 1 mean the paraphrase is more likely than the perturbations.
    The ratio is NaN when the item has no non-NaN perturbation loss or its
    paraphrase loss is NaN, and ``inf`` if the exponent overflows.
    """
    ratios: List[float] = []
    for plist, ploss in zip(pert_lists, para_losses):
        finite = [x for x in plist if not math.isnan(x)]
        if not finite or math.isnan(ploss):
            ratios.append(float("nan"))
            continue
        try:
            ratios.append(float(math.exp(sum(finite) / len(finite) - ploss)))
        except OverflowError:
            ratios.append(float("inf"))
    return ratios


def _make_one_tofu_log_from_json(
    model,
    tokenizer,
    cfg,
    base_dataset_name: str,
    model_family_path: str,
    task_split: str,
    pert_json_file: Path,
    out_path: Path,
    batch_size: int,
):
    """Build one TOFU-style log JSON for ``task_split`` using batched NLL scoring.

    Questions and ground-truth answers come from ``TextDatasetQA`` (split
    ``task_split`` of ``base_dataset_name``). Paraphrased and perturbed answers
    come from ``pert_json_file`` and are matched to questions by position; if
    the two sources differ in length, both are truncated to the shorter one.

    Every answer is scored as ``P(" " + answer | question)``, i.e. on the text
    ``"question answer"``. This matches the ``"question "`` prompt used for
    generation instead of gluing question and answer together. The space
    belongs to the continuation so that the question's own tokenization stays
    a prefix of the full tokenization.

    Per-item lists written to ``out_path``:
      - ``generated_text``: (prediction, ground truth) pairs from greedy decoding.
      - ``rouge1_recall`` / ``rougeL_recall``: ROUGE recall of each prediction.
      - ``avg_gt_loss`` / ``gt_loss`` / ``num_token_gt``: ground-truth answer
        average NLL and token count.
      - ``avg_paraphrased_loss`` / ``paraphrased_loss`` /
        ``num_token_paraphrased``: same for the paraphrase. A missing or blank
        paraphrase falls back to the ground truth.
      - ``average_perturb_loss`` / ``num_token_perturb``: one list per item,
        one entry per perturbed answer.
      - ``truth_ratio``: ``exp(mean perturbed loss - paraphrased loss)``.
      - ``kl_divergence``: KL-to-uniform diagnostic over the normalized
        paraphrase/perturbation likelihoods.

    NOTE(review): ``gt_loss`` and ``paraphrased_loss`` hold the same per-token
    averages as their ``avg_*`` counterparts. Confirm downstream consumers do
    not expect summed losses there.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    ds = TextDatasetQA(
        base_dataset_name,
        tokenizer=tokenizer,
        model_family=model_family_path,
        max_length=512,
        split=task_split,
        question_key="question",
        answer_key="answer",
    )
    qs = list(ds.data["question"])
    gts = list(ds.data["answer"])
    paras, perts = _load_perturbed_json(pert_json_file)
    N = min(len(qs), len(paras), len(perts))
    if N < len(qs):
        print(f"[{task_split}] truncating to {N} due to JSON length mismatch.")
    elif N < len(paras):
        print(
            f"[{task_split}] ignoring {len(paras) - N} extra rows in {pert_json_file}."
        )
    qs, gts, paras, perts = qs[:N], gts[:N], paras[:N], perts[:N]

    # ROUGE recall of greedy generations against the ground truth.
    preds: List[str] = []
    for start in range(0, N, batch_size):
        preds.extend(
            _batch_generate(
                model,
                tokenizer,
                qs[start : start + batch_size],
                cfg["generation"]["max_new_tokens"],
            )
        )
    rouge1, rougeL = _rouge_recall_lists(preds, gts)

    # Batched NLLs: ground truth + paraphrase.
    gt_losses, gt_toks = _batched_avg_nll(
        model, tokenizer, qs, [_as_continuation(a) for a in gts], batch_size
    )
    use_paras = [p if p and p.strip() else gt for p, gt in zip(paras, gts)]
    para_losses, para_toks = _batched_avg_nll(
        model, tokenizer, qs, [_as_continuation(p) for p in use_paras], batch_size
    )

    # Batched NLLs: all perturbed answers in one flattened pass; owner[j] is
    # the item index of flat entry j.
    flat_qs: List[str] = []
    flat_wrongs: List[str] = []
    owner: List[int] = []
    for idx, (q, plist) in enumerate(zip(qs, perts)):
        for wrong in plist or []:
            flat_qs.append(q)
            flat_wrongs.append(_as_continuation(wrong))
            owner.append(idx)

    pert_lists: List[List[float]] = [[] for _ in range(N)]
    pert_tok_lists: List[List[int]] = [[] for _ in range(N)]
    if flat_qs:
        wrong_losses, wrong_toks = _batched_avg_nll(
            model, tokenizer, flat_qs, flat_wrongs, batch_size
        )
        for loss, tok, idx in zip(wrong_losses, wrong_toks, owner):
            pert_lists[idx].append(loss)
            pert_tok_lists[idx].append(tok)

    payload = {
        "avg_gt_loss": gt_losses,
        "gt_loss": gt_losses,
        "num_token_gt": gt_toks,
        "generated_text": list(zip(preds, gts)),
        "rouge1_recall": rouge1,
        "rougeL_recall": rougeL,
        "paraphrased_loss": para_losses,
        "avg_paraphrased_loss": para_losses,
        "num_token_paraphrased": para_toks,
        "average_perturb_loss": pert_lists,
        "num_token_perturb": pert_tok_lists,
        "truth_ratio": _truth_ratios(pert_lists, para_losses),
        "kl_divergence": _kl_diagnostics(para_losses, pert_lists),
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    print(f"[TOFU] wrote {out_path}")


def _load_four_logs(dir_path: Path) -> Dict[str, Any]:
    """Load the four baseline TOFU logs stored in ``dir_path``.

    Args:
        dir_path: directory (``Path`` or ``str``) that should contain
            ``eval_log.json``, ``eval_log_forget.json``,
            ``eval_real_author_wo_options.json`` and
            ``eval_real_world_wo_options.json``.

    Returns:
        The parsed JSON of each file keyed by file name, or ``{}`` if any of
        the four files is missing. Presence is checked before anything is
        parsed, and every file is closed after reading.
    """
    dir_path = Path(dir_path)
    names = (
        "eval_log.json",
        "eval_log_forget.json",
        "eval_real_author_wo_options.json",
        "eval_real_world_wo_options.json",
    )
    paths = {name: dir_path / name for name in names}
    if not all(p.is_file() for p in paths.values()):
        return {}
    out: Dict[str, Any] = {}
    for name, p in paths.items():
        with p.open("r", encoding="utf-8") as f:
            out[name] = json.load(f)
    return out


def _load_four_logs_with_method(dir_path, method):
    """Load the four ``{method}-``-prefixed TOFU logs from ``dir_path``.

    The files read are ``{method}-eval_log.json``,
    ``{method}-eval_log_forget.json``,
    ``{method}-eval_real_author_wo_options.json`` and
    ``{method}-eval_real_world_wo_options.json``.

    Returns:
        The parsed JSON keyed by the un-prefixed file name (the same shape as
        ``_load_four_logs``), or ``{}`` if any of the four files is missing.
        ``dir_path`` may be a ``Path`` or ``str``. Every file is closed after
        reading.
    """
    dir_path = Path(dir_path)
    base_names = (
        "eval_log.json",
        "eval_log_forget.json",
        "eval_real_author_wo_options.json",
        "eval_real_world_wo_options.json",
    )
    paths = {name: dir_path / f"{method}-{name}" for name in base_names}
    if not all(p.is_file() for p in paths.values()):
        return {}
    out: Dict[str, Any] = {}
    for name, p in paths.items():
        with p.open("r", encoding="utf-8") as f:
            out[name] = json.load(f)
    return out


def parse_adapter_id(adapter_id: str):
    """Split an adapter repo id into its hyper-parameter fields.

    The repo name (text after the last "/") must look like
    ``msc_unlearn_lora_{rank}_{alpha}_{lr}_{method}_{dataset}_{dataset_subset}``,
    i.e. at least nine underscore-separated parts starting with
    ``msc``/``unlearn``/``lora``. Anything else, including ``None``, yields
    the defaults: rank/alpha/lr ``"unknown"``, method ``"baseline"``,
    dataset ``"tofu"``, subset ``"forget10"``.
    """
    field_names = ("rank", "alpha", "lr", "method", "dataset", "dataset_subset")
    try:
        pieces = adapter_id.split("/")[-1].split("_")
    except Exception:
        pieces = []
    if len(pieces) >= 9 and pieces[0:3] == ["msc", "unlearn", "lora"]:
        return dict(zip(field_names, pieces[3:9]))
    return {
        "rank": "unknown",
        "alpha": "unknown",
        "lr": "unknown",
        "method": "baseline",
        "dataset": "tofu",
        "dataset_subset": "forget10",
    }


def _resolve_run_info(args) -> Dict[str, str]:
    """Return the rank/alpha/lr/method/dataset/dataset_subset fields of a run.

    ``--adapter_id retain90`` is the retain-90 reference run: it is evaluated
    on ``forget10`` and ``args.split`` is overwritten with ``"forget10"``.
    Other adapter ids are parsed with ``parse_adapter_id``. Without an adapter
    id, the base model is labelled ``"baseline"`` and uses ``args.split``
    (falling back to ``"forget10"``).
    """
    if args.adapter_id == "retain90":
        args.split = "forget10"
        return {
            "rank": "unknown",
            "alpha": "unknown",
            "lr": "unknown",
            "method": "retain90",
            "dataset": "tofu",
            "dataset_subset": "forget10",
        }
    if args.adapter_id:
        return parse_adapter_id(args.adapter_id)
    return {
        "rank": "unknown",
        "alpha": "unknown",
        "lr": "unknown",
        "method": "baseline",
        "dataset": "tofu",
        "dataset_subset": args.split or "forget10",
    }


def _token_metrics(n_correct, n_acc_tokens, nll_sum, n_nll_tokens):
    """Build the accuracy and perplexity dicts (same keys as eval_accuracy_any / eval_perplexity_any) from token counts."""
    accuracy = {
        "eval accuracy": (n_correct / n_acc_tokens) if n_acc_tokens else float("nan")
    }
    if n_nll_tokens == 0:
        return accuracy, {
            "perplexity": float("inf"),
            "avg_nll": float("nan"),
            "num_tokens": 0,
        }
    avg_nll = nll_sum / n_nll_tokens
    try:
        ppl = math.exp(avg_nll)
    except OverflowError:  # average NLL above ~709.78 (diverged model)
        ppl = float("inf")
    return accuracy, {"perplexity": ppl, "avg_nll": avg_nll, "num_tokens": n_nll_tokens}


def _write_tofu_logs(
    model,
    tokenizer,
    cfg,
    model_family_path,
    pert_dir: Path,
    tofu_dir: Path,
    method: str,
    forget_split: str,
    batch_size: int,
):
    """Write the four ``{method}-*.json`` TOFU logs (retain, forget, real authors, world facts)."""
    jobs = (
        (RETAIN_EVAL_SPLIT, PERT_JSON[RETAIN_EVAL_SPLIT], TOFU_FILENAMES["retain"]),
        (forget_split, PERT_JSON[forget_split], TOFU_FILENAMES["forget"]),
        (REAL_AUTHORS_SPLIT, PERT_JSON["real_authors"], TOFU_FILENAMES["real_authors"]),
        (WORLD_FACTS_SPLIT, PERT_JSON["world_facts"], TOFU_FILENAMES["world_facts"]),
    )
    for split, pert_file, log_name in jobs:
        _make_one_tofu_log_from_json(
            model,
            tokenizer,
            cfg,
            "locuslab/TOFU",
            model_family_path,
            task_split=split,
            pert_json_file=pert_dir / pert_file,
            out_path=tofu_dir / f"{method}-{log_name}",
            batch_size=batch_size,
        )


def _score_tofu_logs(
    tofu_dir: Path, method: str, retain_root, retain_folder: str
) -> Dict[str, Any]:
    """Compute model utility and, when baseline logs exist, forget quality from the written logs."""
    four = _load_four_logs_with_method(tofu_dir, method)
    tofu_scores: Dict[str, Any] = {}
    if not four:
        return tofu_scores

    # Sanity checks only: report missing keys, never modify the logs.
    for k in (
        "eval_real_author_wo_options.json",
        "eval_real_world_wo_options.json",
    ):
        log = four.get(k)
        if log is None:
            continue
        if "avg_gt_loss" not in log:
            print(
                f"[WARN] {k} missing 'avg_gt_loss' (expected for RA/WF probability)."
            )
        if "average_perturb_loss" not in log:
            print(
                f"[WARN] {k} missing 'average_perturb_loss' (needed for RA/WF probability & Truth Ratio)."
            )
        if "avg_paraphrased_loss" not in log:
            print(
                f"[WARN] {k} missing 'avg_paraphrased_loss' (needed for Truth Ratio)."
            )

    mu = get_matching_model_utility(four)
    tofu_scores["model_utility"] = mu.get("Model Utility", None)
    tofu_scores["model_utility_breakdown"] = mu

    # Forget Quality vs the paired retain baseline folder (retain_root/retain90|95|99).
    if retain_root:
        baseline_dir = Path(retain_root) / retain_folder
        retain_logs = _load_four_logs(baseline_dir)
        if retain_logs:
            fq_stats, _ = get_forget_quality(four, retain_logs)
            tofu_scores.update(fq_stats)
        else:
            print(
                f"[TOFU] Missing baseline logs in {baseline_dir}. Skipping Forget Quality."
            )
    return tofu_scores


def run_eval(args, cfg=None):
    """Evaluate one model/adapter/epoch combination and write its results.

    Results go under ``experiments/<dataset>/<subset>`` for the retain90
    reference, or ``experiments/<dataset>/<subset>/<lr>/<rank>_<alpha>/<epoch>``
    otherwise:

    * ``<method>.csv``: question, greedy prediction and reference per item.
    * ``<method>_metrics.json``: next-token accuracy, BLEU/ROUGE, ROUGE recall
      and perplexity on the evaluated split. When ``--perturbed_dir`` is
      given, it also holds a ``tofu_scores`` entry computed from the four
      ``<method>-eval_*.json`` TOFU logs written next to it.

    Args:
        args: parsed command-line namespace (see ``__main__``). For the
            ``retain90`` reference run, ``args.split`` is set to ``"forget10"``.
        cfg: evaluation config (``cfg["generation"]["max_new_tokens"]`` is
            read). Defaults to ``load_yaml(args.cfg)``, so the function also
            works when imported rather than run as a script.
    """
    if cfg is None:
        cfg = load_yaml(args.cfg)

    info = _resolve_run_info(args)
    method = info["method"]
    selected_dataset = info["dataset"]
    dataset_subset = info["dataset_subset"]
    # The retain90 reference is the base model without any adapter.
    is_reference = method == "retain90"
    forget_split, retain_folder = PAIRING.get(dataset_subset, ("forget10", "retain90"))

    epoch = args.epoch if args.epoch else "epoch-10"
    if is_reference:
        out_dir = os.path.join("experiments", selected_dataset, dataset_subset)
    else:
        out_dir = os.path.join(
            "experiments",
            selected_dataset,
            dataset_subset,
            info["lr"],
            f"{info['rank']}_{info['alpha']}",
            epoch,
        )
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"{method}.csv")
    metrics_path = os.path.join(out_dir, f"{method}_metrics.json")

    model, tokenizer = load_model_and_tokenizer(
        model_id_or_path=args.model_id,
        tokenizer_id=args.tokenizer_id,
        adapter_id_or_path=None if is_reference else args.adapter_id,
        epoch=epoch,
        dtype_str=args.dtype,
        use_8bit=args.use_8bit,
        use_4bit=args.use_4bit,
        trust_remote_code=args.trust_remote_code,
        local_files_only=args.local_files_only,
        device_map=args.device_map,
        peft=args.peft,
    )

    # original metric
    hf_name = "open-unlearning/tofu_Llama-2-7b-chat-hf_full"
    config_path = "tofu_model_config"
    model_family_path = os.path.join(config_path, hf_name)
    max_length = 512
    if args.adapter_id == "retain90":
        print(f"loading the baseline dataset for {info['dataset_subset']}")
        split = info["dataset_subset"]
    else:
        print(f"loading dataset with {args.split}")
        split = args.split
    ds = TextDatasetQA(
        "locuslab/TOFU",
        tokenizer=tokenizer,
        model_family=model_family_path,
        max_length=max_length,
        split=split,
        question_key="question",
        answer_key="answer",
    )
    collator = QACollatorDynamicPad(
        tokenizer, formatting_tokens=ds.formatting_tokens, max_length=512
    )
    loader = torch.utils.data.DataLoader(
        ds, batch_size=args.batch_size, collate_fn=collator
    )

    all_preds: List[str] = []
    all_refs: List[str] = []
    n_correct, n_acc_tokens = 0, 0
    nll_sum, n_nll_tokens = 0.0, 0
    with open(csv_path, "w", encoding="utf-8", newline="") as f_csv:
        writer = csv.writer(f_csv)
        writer.writerow(["question", "prediction", "reference"])
        for batch in tqdm(loader):
            batch = {
                k: (v.to(model.device) if torch.is_tensor(v) else v)
                for k, v in batch.items()
            }
            with torch.no_grad():
                gen_outputs, ground_truths = run_generation(
                    cfg, batch, model, tokenizer, ds.formatting_tokens
                )
                qs = batch.get("questions", [""] * len(gen_outputs))
                for q, pred, ref in zip(qs, gen_outputs, ground_truths):
                    writer.writerow([q, pred, ref])
                all_preds.extend(gen_outputs)
                all_refs.extend(ground_truths)
                logits = (
                    model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                    )
                    .logits.detach()
                    .to("cpu")
                )
            labels_cpu = (
                batch["labels"].detach().to("cpu")
                if torch.is_tensor(batch["labels"])
                else batch["labels"]
            )
            # Reduce each batch to token counts immediately instead of keeping
            # every (batch, seq, vocab) logits tensor in memory until the end.
            c, n = _accuracy_counts(logits, labels_cpu, -100)
            s, t = _nll_sum_and_count(logits, labels_cpu, -100)
            n_correct += c
            n_acc_tokens += n
            nll_sum += s
            n_nll_tokens += t
            del logits
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    accuracy, perplexity = _token_metrics(n_correct, n_acc_tokens, nll_sum, n_nll_tokens)
    base_metrics: Dict[str, Any] = {
        "accuracy": accuracy,
        "bleu_rouge": eval_bleu(all_preds, all_refs),
        "rouge_recall": eval_rouge_recall(all_preds, all_refs),
        "perplexity": perplexity,
    }
    with open(metrics_path, "w", encoding="utf-8") as f_json:
        json.dump(base_metrics, f_json, ensure_ascii=False, indent=2)

    #  TOFU logs + score
    if not args.perturbed_dir:
        return
    tofu_dir = Path(out_dir)
    _write_tofu_logs(
        model,
        tokenizer,
        cfg,
        model_family_path,
        Path(args.perturbed_dir),
        tofu_dir,
        method,
        forget_split,
        args.batch_size,
    )
    tofu_scores = _score_tofu_logs(tofu_dir, method, args.retain_root, retain_folder)
    if tofu_scores:
        base_metrics["tofu_scores"] = tofu_scores
        with open(metrics_path, "w", encoding="utf-8") as f_json:
            json.dump(base_metrics, f_json, ensure_ascii=False, indent=2)
        print(f"[TOFU] scores appended for {epoch} → {metrics_path}")


if __name__ == "__main__":
    # For private Hub repos, authenticate before running, e.g.
    #     login(token=os.environ["HUGGINGFACE_HUB_TOKEN"])

    def _parse_bool_flag(value):
        """Map common boolean spellings to ``bool``; pass any other string through.

        With a plain ``type=str``, ``--peft False`` yields the non-empty (hence
        truthy) string ``"False"``.
        """
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        return value

    def _epoch_number(name):
        """Return N for an ``epoch-N`` folder name, or None when N is not an integer."""
        suffix = name.split("-")[1]
        return int(suffix) if suffix.isdigit() else None

    parser = argparse.ArgumentParser(description="Evaluate Causal LM + TOFU per-epoch")
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--adapter_id", type=str, default=None)
    parser.add_argument("--tokenizer_id", type=str, default=None)
    parser.add_argument("--all_epochs", action="store_true")
    parser.add_argument("--peft", type=_parse_bool_flag, default=True)
    parser.add_argument(
        "--dtype", type=str, default="bf16", choices=["auto", "fp16", "bf16", "fp32"]
    )
    parser.add_argument("--use_8bit", action="store_true")
    parser.add_argument("--use_4bit", action="store_true")
    parser.add_argument("--trust_remote_code", action="store_true")
    parser.add_argument("--local_files_only", action="store_true")
    parser.add_argument("--epoch", default="", type=str)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--split", type=str, default="forget10")
    parser.add_argument("--device_map", type=str, default="auto")
    parser.add_argument("--cfg", type=str, default="egu/config/eval.yaml")
    parser.add_argument(
        "--perturbed_dir",
        type=str,
        default=None,
        help="Folder with: retain_perturbed.json, forget{01,05,10}_perturbed.json, real_authors_perturbed.json, world_facts_perturbed.json",
    )
    parser.add_argument(
        "--retain_root",
        type=str,
        default=None,
        help="Folder with subdirs retain90/, retain95/, retain99/, each containing the 4 baseline logs.",
    )

    args = parser.parse_args()
    cfg = load_yaml(args.cfg)

    available_epochs = []
    # "retain90" is a pseudo adapter id (the bare base model). It parses to
    # method "baseline", so it must be excluded explicitly: there is no Hub
    # repo to list for it.
    lists_hub_epochs = (
        bool(args.adapter_id)
        and args.adapter_id != "retain90"
        and parse_adapter_id(args.adapter_id)["method"] != "retain90"
    )
    if lists_hub_epochs:
        files = HfApi().list_repo_files(args.adapter_id, repo_type="model")
        numbered = {}
        for name in {f.split("/")[0] for f in files if f.startswith("epoch-")}:
            num = _epoch_number(name)
            if num is None:
                print(f"skipping epoch folder with non-integer suffix: {name}")
            else:
                numbered[name] = num
        available_epochs = sorted(numbered, key=lambda n: (numbered[n], n))
        print("available epochs:", available_epochs)

    if args.all_epochs and available_epochs:
        for ep in available_epochs:
            print(f"\n=== running evaluation of {ep} ===")
            args.epoch = ep
            run_eval(args)
    else:
        run_eval(args)
