# text_mcqa.py
import os
from typing import Any, Dict, List, Optional

import datasets
import torch
from datasets import Dataset as HFDataset
from datasets import load_dataset
from torch.utils.data import Dataset

from egu.utils.utils import get_model_identifiers_from_yaml

LETTER2IDX = {"A": 0, "B": 1, "C": 2, "D": 3}
IDX2LETTER = {v: k for k, v in LETTER2IDX.items()}


def _safe_answer_letter(answer, choices: Optional[List[str]] = None) -> str:
    """
    Normalize various answer encodings to a single letter 'A'..'E'.
    Handles: int index, str index, 'A'..'E', '3 D', full option text, etc.
    """
    # int index
    if isinstance(answer, int):
        if 0 <= answer < 5:
            return IDX2LETTER.get(answer, "A")
        # some sets use 1-based
        if 1 <= answer <= 5:
            return IDX2LETTER.get(answer - 1, "A")

    # str forms
    if isinstance(answer, str):
        s = answer.strip()
        # exact letter
        if len(s) == 1 and s.upper() in LETTER2IDX:
            return s.upper()
        # things like "3 D" or "Answer: D"
        tail = s.split()[-1].upper()
        if tail in LETTER2IDX:
            return tail
        # numeric index
        if s.isdigit():
            idx = int(s)
            if 0 <= idx < 5:
                return IDX2LETTER.get(idx, "A")
            if 1 <= idx <= 5:
                return IDX2LETTER.get(idx - 1, "A")
        # match by option text
        if choices:
            try:
                # exact match
                j = choices.index(s)
                return IDX2LETTER.get(j, "A")
            except ValueError:
                pass

    # fallback
    return "A"


def _build_prompt_no_answer(
    formatting_tokens: Dict[str, str], question: str, choices: List[str]
) -> str:
    """
    Simple, label-free MC prompt that ends with 'Answer:'.
    You can adapt to your chat template here if needed.
    """
    sys_p = formatting_tokens.get("system_prompt", None)
    lines = []
    if sys_p:
        lines += [sys_p, ""]
    lines += [question, "", "Choices:"]
    for L, opt in zip("ABCD", choices):
        lines.append(f"{L}. {opt}")
    lines.append("Answer:")
    return "\n".join(lines)


class TextDatasetMCQA(Dataset):
    """
    Multiple-choice QA dataset for tinyMMLU-style evaluation.

    Returns items with:
      - question: str
      - choices: List[str]
      - answer_letter: str  ('A'..'D')
      - subject: str (if present, else 'unknown')
      - prompt: str (label-free prompt ending with 'Answer:')
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        model_family: str,
        split: str = "test",
        max_length: int = 512,
        question_key: str = "question",
        choices_key: str = "choices",
        answer_key: str = "answer",
        subject_key: str = "subject",
    ):
        """
        Load the split and precompute the normalized per-row fields.

        Args:
            data_path: Passed to ``load_dataset`` as a dataset name/repo, unless
                it contains 'egu/dataset/raw', in which case
                ``<data_path>/<split>.json`` is loaded as a local JSON file.
            tokenizer: Stored on the instance; not used by this class itself.
            model_family: Forwarded to ``get_model_identifiers_from_yaml``; the
                resulting config's 'formatting_tokens' entry feeds the prompt.
            split: Split name (or JSON file stem for local data).
            max_length: Stored on the instance; not used by this class itself.
            question_key: Column holding the question text.
            choices_key: Column holding the options: either a list, or a dict
                with a 'text' list (and optionally a parallel 'label' list).
            answer_key: Column holding the gold answer in any encoding that
                ``_safe_answer_letter`` understands.
            subject_key: Column holding the subject; 'unknown' when absent.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load dataset: hub/registered name, or a local JSON file per split.
        if "egu/dataset/raw" not in data_path:
            raw_data = load_dataset(data_path, split=split)
        else:
            raw_data = load_dataset(
                "json", data_files=os.path.join(data_path, split + ".json")
            )["train"]

        # Model formatting tokens (optional)
        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.formatting_tokens = self.model_configs.get("formatting_tokens", {})

        self.qk = question_key
        self.ck = choices_key
        self.ak = answer_key
        self.sk = subject_key

        # Precompute normalized fields via map to keep __getitem__ lightweight
        def _map_row(ex):
            raw_choices = ex[self.ck]
            labels = None
            if isinstance(raw_choices, dict):
                # ARC-style column: parallel "text" / "label" lists.
                # NOTE(review): the "label" key follows the ARC schema; rows
                # without it simply skip the label-based answer lookup below.
                labels = raw_choices.get("label")
                raw_choices = raw_choices.get("text")

            if isinstance(raw_choices, (list, tuple)):
                choices = list(raw_choices)[:5]  # keep at most 5 options
            else:
                # Missing/malformed choices: emit an empty list so the filter
                # below drops the row instead of map() raising on it.
                choices = []

            answer = ex[self.ak]
            if isinstance(labels, (list, tuple)) and answer in labels:
                # The answer key names one of the option labels (e.g. "1".."4"
                # or "A".."E"); its position in the label list is the
                # unambiguous 0-based option index.
                answer = list(labels).index(answer)

            ex["__prompt"] = _build_prompt_no_answer(
                self.formatting_tokens, ex[self.qk], choices
            )
            ex["__answer_letter"] = _safe_answer_letter(answer, choices)
            ex["__choices_norm"] = choices
            ex["__subject_norm"] = ex.get(self.sk, "unknown")
            return ex

        mapped = raw_data.map(_map_row)
        # Filter out broken rows (fewer than two usable choices).
        self.data: HFDataset = mapped.filter(
            lambda row: isinstance(row["__choices_norm"], list)
            and len(row["__choices_norm"]) >= 2
        )

    def __len__(self):
        """Return the number of rows that survived the broken-row filter."""
        return len(self.data)

    def __getitem__(self, idx) -> Dict[str, Any]:
        """Return the precomputed, normalized fields for row ``idx``."""
        ex = self.data[idx]
        return {
            "question": ex[self.qk],
            "choices": ex["__choices_norm"],
            "answer_letter": ex["__answer_letter"],
            "subject": ex["__subject_norm"],
            "prompt": ex["__prompt"],  # label-free formatted input
        }


class QAMCCollatorDynamicPad:
    """
    Collator for MCQA evaluation.

    Tokenizes the label-free prompt, pads dynamically to the longest prompt in
    the batch (left padding by default, for generation convenience), and passes
    through the raw fields for log-prob scoring or constrained generation.
    """

    def __init__(self, tokenizer, max_length: int = 512, left_pad: bool = True):
        """
        Args:
            tokenizer: Callable tokenizer exposing a mutable ``padding_side``.
            max_length: Truncation length passed to the tokenizer.
            left_pad: When True, pad on the left for the duration of the call.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.left_pad = left_pad

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate dataset items into model inputs plus pass-through metadata.

        Args:
            batch: Items with 'prompt', 'question', 'choices', 'answer_letter'
                and 'subject' keys.

        Returns:
            Dict with tensor 'input_ids' and 'attention_mask' (cast to
            torch.long), and the lists 'questions', 'choices', 'answer' (gold
            letters), 'subject' and 'prompts'.
        """
        prompts = [ex["prompt"] for ex in batch]

        old_side = self.tokenizer.padding_side
        if self.left_pad:
            self.tokenizer.padding_side = "left"
        try:
            enc = self.tokenizer(
                prompts,
                add_special_tokens=True,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
        finally:
            # Restore even if tokenization fails: the tokenizer is shared, and a
            # leaked "left" setting would silently affect its other users.
            self.tokenizer.padding_side = old_side

        # Pass through raw fields for scoring
        return {
            "input_ids": enc.input_ids,
            "attention_mask": enc.attention_mask.to(torch.long),
            "questions": [ex["question"] for ex in batch],
            "choices": [ex["choices"] for ex in batch],
            "answer": [ex["answer_letter"] for ex in batch],  # gold letters
            "subject": [ex["subject"] for ex in batch],
            "prompts": prompts,
        }
