import argparse
import csv
import json
import os
from collections import defaultdict
from typing import Any, Dict, List, Tuple

import torch
import torch.nn.functional as F
from huggingface_hub import HfApi, login

from egu.dataset.text_mcqa import QAMCCollatorDynamicPad, TextDatasetMCQA
from egu.evaluators.utils import load_model_and_tokenizer
from egu.utils.utils import load_yaml


@torch.no_grad()
def score_options_by_logprob(
    model, tokenizer, prompts: List[str], options_list: List[List[str]]
) -> torch.Tensor:
    """Score every answer option by its length-normalised log-probability.

    Each option is appended to its prompt (with a leading space), the
    concatenated sequences are right-padded into one batch, and the model is
    run once. The score of an option is the mean log-probability of the
    option's own tokens given everything before them.

    Args:
        model: Callable as ``model(input_ids=..., attention_mask=...)`` whose
            result exposes ``.logits``; must also provide ``parameters()`` and
            ``eval()``. Put into eval mode as a side effect.
            (Presumably a Hugging Face causal LM -- confirm against caller.)
        tokenizer: Callable as ``tokenizer(text, add_special_tokens=...,
            return_tensors="pt")`` returning an object with ``input_ids``.
            If it has no pad token, the EOS token is installed as pad token
            (side effect).
        prompts: One prompt string per question.
        options_list: For each prompt, the list of candidate answer strings.
            Lists may have different lengths.

    Returns:
        A CPU float tensor of shape ``(num_questions, max_num_options)``.
        Slots that do not correspond to a real option (rows with fewer
        options than the longest row) and options that tokenise to nothing
        are ``-inf`` so that ``argmax`` never selects them. An empty batch
        yields an empty tensor instead of raising.
    """
    # Nothing to score: return an empty tensor rather than failing in max().
    if not prompts or not options_list:
        return torch.empty((0, 0))

    device = next(model.parameters()).device
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    max_k = max(len(opts) for opts in options_list)

    # One flat sequence per (question, option slot); ``spans`` records where
    # the option tokens live inside each sequence as (start, end).
    flat_ids, spans = [], []
    n_rows = 0
    for prompt, opts in zip(prompts, options_list):
        # Encode the prompt once and reuse it for all of its options.
        p_ids = tokenizer(
            prompt, add_special_tokens=True, return_tensors="pt"
        ).input_ids[0]
        p_len = p_ids.shape[-1]
        for opt in opts:
            o_ids = tokenizer(
                " " + opt, add_special_tokens=False, return_tensors="pt"
            ).input_ids[0]
            ids = torch.cat([p_ids, o_ids], dim=0)
            flat_ids.append(ids)
            spans.append((p_len, ids.shape[-1]))
        # Fill the row up to max_k with prompt-only dummies (empty span) so
        # every question contributes exactly max_k sequences.
        for _ in range(max_k - len(opts)):
            flat_ids.append(p_ids)
            spans.append((p_len, p_len))
        n_rows += 1

    # Every option list was empty: there is nothing to run through the model.
    if not flat_ids:
        return torch.empty((n_rows, 0))

    # Right-pad all sequences to a common length.
    max_len = max(t.shape[0] for t in flat_ids)
    input_ids = torch.full(
        (len(flat_ids), max_len), tokenizer.pad_token_id, dtype=torch.long
    )
    attention_mask = torch.zeros_like(input_ids, dtype=torch.long)
    for i, ids in enumerate(flat_ids):
        n = ids.shape[0]
        input_ids[i, :n] = ids
        attention_mask[i, :n] = 1

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    scores = []
    for i, (opt_start, opt_end) in enumerate(spans):
        # The logit at position t predicts the token at position t + 1, so the
        # option tokens [opt_start, opt_end) are predicted by the logits at
        # [opt_start - 1, opt_end - 1).
        s = max(0, opt_start - 1)
        e = opt_end - 1
        if opt_end <= opt_start or e <= s:
            scores.append(float("-inf"))
            continue
        # Only the option positions are normalised, and in float32: this keeps
        # memory small and avoids half-precision round-off in log_softmax.
        step_logits = logits[i, s:e].float()
        targets = input_ids[i, s + 1 : e + 1]
        lp = (
            F.log_softmax(step_logits, dim=-1)
            .gather(-1, targets.unsqueeze(-1))
            .squeeze(-1)
        )
        scores.append(lp.mean().item())

    return torch.tensor(scores).view(n_rows, max_k)


# ----------------------- path + parsing utils -----------------------


def parse_adapter_id(adapter_id: str) -> Dict[str, str]:
    """Extract the training hyper-parameters encoded in an adapter repo name.

    Expected format::

        org/msc_unlearn_lora_{rank}_{alpha}_{lr}_{method}_{dataset}_{dataset_subset}

    e.g. ``org/msc_unlearn_lora_16_32_1e-05_npo_tofu_forget10``. Only the part
    after the last ``/`` is inspected; it is split on ``_`` and fields 3..8
    are used. Any fields beyond the ninth are ignored.

    Args:
        adapter_id: Repo id (with or without the ``org/`` prefix).

    Returns:
        Dict with keys ``rank``, ``alpha``, ``lr``, ``method``, ``dataset``
        and ``dataset_subset``. When the name has fewer than nine
        ``_``-separated fields (or ``adapter_id`` is not a string), the
        placeholder values ``"r"``, ``"a"``, ``"lr"``, ``"method"``,
        ``"dataset"`` and ``"subset"`` are returned instead.
    """
    keys = ("rank", "alpha", "lr", "method", "dataset", "dataset_subset")
    placeholders = ("r", "a", "lr", "method", "dataset", "subset")

    if not isinstance(adapter_id, str):
        return dict(zip(keys, placeholders))

    # e.g. ['msc','unlearn','lora','16','32','1e-05','npo','tofu','forget10']
    parts = adapter_id.split("/")[-1].split("_")
    if len(parts) < 9:
        return dict(zip(keys, placeholders))
    return dict(zip(keys, parts[3:9]))


def epoch_dir_from_arg(epoch_arg: str) -> str:
    """Turn an ``--epoch`` value into a directory name of the form ``epoch-*``.

    An empty/falsy value maps to ``"epoch-unknown"``; a value that already
    starts with ``"epoch-"`` is returned unchanged; anything else gets the
    ``"epoch-"`` prefix.
    """
    if epoch_arg:
        already_prefixed = epoch_arg.startswith("epoch-")
        if already_prefixed:
            return epoch_arg
        return f"epoch-{epoch_arg}"
    return "epoch-unknown"


def make_out_paths(root: str, meta: Dict[str, str], epoch_arg: str) -> Tuple[str, str]:
    """Create the output directory for one run and return its result paths.

    The directory is
    ``{root}/{dataset}/{dataset_subset}/{lr}/{rank}_{alpha}/{epoch dir}``
    (e.g. with root ``experiments/utility``) and is created if missing.

    Returns:
        ``(csv_path, json_path)`` where the CSV is ``{method}.csv`` and the
        metrics file is ``{method}_metrics.json`` inside that directory.
    """
    leaf = epoch_dir_from_arg(epoch_arg)
    components = (
        meta["dataset"],
        meta["dataset_subset"],
        meta["lr"],
        f"{meta['rank']}_{meta['alpha']}",
        leaf,
    )
    target_dir = os.path.join(root, *components)
    os.makedirs(target_dir, exist_ok=True)

    stem = meta["method"]
    predictions_file = os.path.join(target_dir, f"{stem}.csv")
    metrics_file = os.path.join(target_dir, f"{stem}_metrics.json")
    return predictions_file, metrics_file


# ----------------------- evaluation (one epoch) -----------------------


def evaluate_one_epoch(args) -> None:
    """Evaluate one checkpoint on a multiple-choice benchmark and save results.

    Loads the model/tokenizer described by ``args``, scores every answer
    option of every question with ``score_options_by_logprob``, and writes

    * a CSV with one ``(question, predicted, golden)`` row per example, and
    * a JSON file with overall accuracy, example count and per-subject
      accuracy,

    to the paths produced by ``make_out_paths`` under
    ``experiments/utility/{mmlu|arc}``. A one-line summary is printed.

    Args:
        args: Parsed CLI namespace. Reads ``model_id``, ``tokenizer_id``,
            ``adapter_id``, ``epoch``, ``dtype``, ``use_8bit``, ``use_4bit``,
            ``trust_remote_code``, ``local_files_only``, ``device_map``,
            ``peft``, ``data_path``, ``split`` and ``batch_size``.
    """
    # Load model/tokenizer
    model, tokenizer = load_model_and_tokenizer(
        model_id_or_path=args.model_id,
        tokenizer_id=args.tokenizer_id,
        adapter_id_or_path=args.adapter_id,
        epoch=args.epoch,
        dtype_str=args.dtype,
        use_8bit=args.use_8bit,
        use_4bit=args.use_4bit,
        trust_remote_code=args.trust_remote_code,
        local_files_only=args.local_files_only,
        device_map=args.device_map,
        peft=args.peft,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    hf_name = "open-unlearning/tofu_Llama-2-7b-chat-hf_full"
    config_path = "tofu_model_config"

    # Dataset column names. The defaults are the MMLU layout and are also
    # used when the data path matches neither known benchmark; ARC overrides
    # the answer and subject columns.
    my_root = ""
    question_key: str = "question"
    choices_key: str = "choices"
    answer_key: str = "answer"
    subject_key: str = "subject"
    if "tinyMMLU" in args.data_path:
        my_root = "mmlu"
    elif "tinyAI2" in args.data_path:
        my_root = "arc"
        answer_key = "answerKey"
        subject_key = "id"

    # Dataset + loader
    dataset = TextDatasetMCQA(
        data_path=args.data_path,
        tokenizer=tokenizer,
        model_family=os.path.join(config_path, hf_name),
        split=args.split,
        question_key=question_key,
        choices_key=choices_key,
        answer_key=answer_key,
        subject_key=subject_key,
    )
    collator = QAMCCollatorDynamicPad(tokenizer, max_length=512, left_pad=True)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, collate_fn=collator
    )

    # Output paths
    meta = parse_adapter_id(args.adapter_id or "")
    csv_path, json_path = make_out_paths(
        f"experiments/utility/{my_root}", meta, args.epoch
    )

    # Evaluate + collect rows
    total, correct = 0, 0
    per_subject_ct = defaultdict(lambda: [0, 0])  # subject -> [correct, seen]
    rows = []  # (question, predicted, golden)

    for batch in loader:
        prompts = batch["prompts"]
        choices_list = batch["choices"]
        gold_letters = batch["answer"]
        subjects = batch.get("subject", ["unknown"] * len(prompts))
        questions = batch["questions"]

        scores = score_options_by_logprob(model, tokenizer, prompts, choices_list)
        pred_idx = scores.argmax(dim=1).tolist()
        # Map option index -> letter (0 -> "A", 1 -> "B", ...). Computed
        # rather than looked up in "ABCD" so that items with more than four
        # choices do not raise IndexError.
        preds = [chr(ord("A") + i) for i in pred_idx]

        for q, pred, gold, subj in zip(questions, preds, gold_letters, subjects):
            rows.append((q, pred, gold))
            total += 1
            ok = int(pred == gold)
            correct += ok
            per_subject_ct[subj][0] += ok
            per_subject_ct[subj][1] += 1

    # max(1, total) keeps an empty dataset from dividing by zero.
    acc = correct / max(1, total)
    per_subject = {s: c / t if t else 0.0 for s, (c, t) in per_subject_ct.items()}

    # Write CSV
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        w = csv.writer(f)
        w.writerow(["question", "predicted", "golden"])
        w.writerows(rows)

    # Write metrics JSON — named {method}_metrics.json
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "accuracy": acc,
                "n_examples": total,
                "per_subject": per_subject,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    # NOTE: the "[tinyMMLU]" tag is printed for every benchmark, ARC included.
    print(
        f"[tinyMMLU] epoch={args.epoch}  acc={acc:.4f}  n={total}  csv='{csv_path}'  json='{json_path}'"
    )

    # Best-effort release of accelerator memory before the next checkpoint is
    # loaded. Moving to CPU may be unsupported for some models (presumably
    # quantized or dispatched ones -- confirm), so a failure is reported but
    # never aborts the run.
    try:
        model.to("cpu")
    except Exception as exc:
        print(f"[warn] could not move model to CPU before release: {exc}")
    del model
    del loader, dataset, collator


# ----------------------- main -----------------------


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then evaluate either every
    # "epoch-*" checkpoint found in the adapter repo (--all_epochs) or a
    # single checkpoint (--epoch).
    import gc

    def _str2bool(value) -> bool:
        """Parse a CLI boolean such as 'true'/'false', '1'/'0', 'yes'/'no'."""
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    if "HUGGINGFACE_HUB_TOKEN" in os.environ:
        login(token=os.environ["HUGGINGFACE_HUB_TOKEN"])
    api = HfApi()  # unauth still allows public listing

    p = argparse.ArgumentParser("Evaluate tinyMMLU and write CSV/JSON.")
    p.add_argument("--model_id", required=True, type=str)
    p.add_argument("--adapter_id", default=None, type=str)
    p.add_argument("--tokenizer_id", default=None, type=str)
    # Parsed as a real boolean: with type=str, "--peft False" produced the
    # non-empty (hence truthy) string "False".
    p.add_argument("--peft", default=True, type=_str2bool)
    p.add_argument("--dtype", default="fp16", choices=["auto", "fp16", "bf16", "fp32"])
    p.add_argument("--use_8bit", action="store_true")
    p.add_argument("--use_4bit", action="store_true")
    p.add_argument("--trust_remote_code", action="store_true")
    p.add_argument("--local_files_only", action="store_true")
    p.add_argument("--epoch", default="", type=str)
    p.add_argument("--all_epochs", action="store_true")
    p.add_argument("--batch_size", default=4, type=int)
    p.add_argument("--split", default="test", type=str)
    p.add_argument("--device_map", default="auto", type=str)
    p.add_argument("--cfg", default="egu/config/eval.yaml", type=str)
    p.add_argument(
        "--data_path",
        required=True,
        type=str,
        help="HF repo id or local JSON dir for tinyMMLU.",
    )
    args = p.parse_args()

    # load config (even if not strictly needed here)
    _ = load_yaml(args.cfg)

    # all-epochs branch
    if args.all_epochs and args.adapter_id:
        files = api.list_repo_files(args.adapter_id, repo_type="model")
        # Top-level folders named "epoch-*" are treated as checkpoints.
        available_epochs = sorted(
            {f.split("/")[0] for f in files if f.startswith("epoch-")}
        )
        print("available epochs:", available_epochs)
        for ep in available_epochs:
            print("running evaluation of", ep)
            args.epoch = ep
            evaluate_one_epoch(args)

            gc.collect()
            gc.collect()
            if torch.cuda.is_available():  # clear the cache after each epoch
                torch.cuda.empty_cache()
    else:
        if args.all_epochs:
            # --all_epochs needs a repo to list; say so instead of silently
            # falling back to a single run.
            print("--all_epochs requires --adapter_id; running a single evaluation")
        if not args.epoch:
            print("No --epoch provided; running once with epoch='epoch-unknown'")
            args.epoch = "epoch-unknown"
        evaluate_one_epoch(args)
