import os
from typing import Any, Dict, List, Optional, Tuple, Union

import datasets
import torch
from torch.utils.data import Dataset

from egu.utils.utils import get_model_identifiers_from_yaml


def build_prompt(formatting: Dict, question: str) -> str:
    """Wrap *question* in the model's instruction delimiters.

    Args:
        formatting: Model formatting config. Only ``inst_start`` and
            ``inst_end`` are read (defaulting to ``"[INST]"`` / ``"[/INST]"``);
            any other keys, including system-prompt settings, are ignored.
        question: Raw question text.

    Returns:
        ``f"{inst_start} {question} {inst_end}"`` (no trailing space).
    """
    inst_s = formatting.get("inst_start", "[INST]")
    inst_e = formatting.get("inst_end", "[/INST]")
    return f"{inst_s} {question} {inst_e}"


def build_full_text(formatting: Dict, question: str, answer: str) -> str:
    """Return the formatted prompt for *question* followed by *answer*.

    A single space is inserted between prompt and answer unless the prompt
    already ends with one.
    """
    prefix = build_prompt(formatting, question)
    separator = "" if prefix.endswith(" ") else " "
    return prefix + separator + answer


class QACollatorDynamicPad:
    def __init__(
        self, tokenizer, formatting_tokens: Dict, max_length: int | None = None
    ):
        self.tok = tokenizer
        self.ftoks = formatting_tokens or {}
        if self.tok.pad_token_id is None:
            self.tok.pad_token = self.tok.eos_token
            self.tok.pad_token_id = self.tok.eos_token_id
        self.pad_id: int = int(self.tok.pad_token_id)  # <-- force plain int
        self.max_length = max_length

    def _ensure_1d_ids(self, ids):
        # HF can return [[...]] if return_tensors=None sometimes (tokenizers differ)
        return (
            ids[0]
            if isinstance(ids, (list, tuple)) and ids and isinstance(ids[0], list)
            else ids
        )

    def _pad_1d(self, t: torch.Tensor, L: int, fill: int) -> torch.Tensor:
        t = t[:L]
        out = t.new_full((L,), int(fill))  # ensure fill is in-range int for dtype
        out[: t.shape[0]] = t
        return out

    def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.pad_id
        seqs: List[Tuple[torch.Tensor, int]] = []  # (input_ids, prompt_len)

        for ex in batch:
            q, a = ex["question"], ex["answer"]
            prompt_text = build_prompt(self.ftoks, q)
            full_text = build_full_text(self.ftoks, q, a)

            tok_prompt = self.tok(
                prompt_text, add_special_tokens=True, return_tensors=None
            )
            tok_full = self.tok(
                full_text,
                add_special_tokens=True,
                return_tensors=None,
                truncation=self.max_length is not None,
                max_length=self.max_length if self.max_length is not None else None,
            )

            ids_prompt = self._ensure_1d_ids(tok_prompt["input_ids"])
            ids_full = self._ensure_1d_ids(tok_full["input_ids"])

            input_ids = torch.tensor(ids_full, dtype=torch.int64)  # explicit int64
            prompt_len = int(len(ids_prompt))
            seqs.append((input_ids, prompt_len))

        lengths = [s[0].shape[0] for s in seqs]
        L = (
            min(max(lengths), self.max_length)
            if self.max_length is not None
            else max(lengths)
        )

        input_ids = torch.stack([self._pad_1d(s[0], L, pad_id) for s in seqs], dim=0)
        attention_mask = (input_ids != pad_id).to(torch.int64)  # safe 0/1 int64

        # labels: -100 on prompt & pad, tokens on answer
        labels = input_ids.clone().to(torch.int64)
        labels.fill_(-100)
        nonpad_len = (input_ids != pad_id).sum(dim=1).tolist()
        for i, (_, prompt_len) in enumerate(seqs):
            ans_end = int(nonpad_len[i])
            p_len = min(int(prompt_len), ans_end)  # guard truncation
            if p_len < ans_end:
                labels[i, p_len:ans_end] = input_ids[i, p_len:ans_end]
            if ans_end < L:
                labels[i, ans_end:] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "questions": [ex["question"] for ex in batch],
            "answers": [ex["answer"] for ex in batch],
        }


class TextDatasetQA(Dataset):
    """Dataset yielding raw ``{"question": str, "answer": str}`` pairs.

    No tokenization happens here; ``__getitem__`` returns the untokenized text,
    for a collator such as ``QACollatorDynamicPad`` to consume.

    Args:
        data_path: Hub dataset id (e.g. ``locuslab/tofu``) or a local
            directory. Paths containing ``"egu/dataset/raw"`` are read as JSON
            from ``<data_path>/<split>.json``.
        tokenizer: Stored on the instance; not used by this class.
        model_family: Passed to ``get_model_identifiers_from_yaml``; its
            ``formatting_tokens`` entry (default ``{}``) is stored.
        max_length: Stored on the instance; not used by this class.
        split: For Hub datasets, passed as the second positional argument of
            ``datasets.load_dataset`` (e.g. ``forget10``). Required for local
            JSON datasets.
        question_key / answer_key: Column names to read from each row.

    Raises:
        TypeError: if ``split`` is ``None`` for a local JSON dataset.
    """

    def __init__(
        self,
        data_path,  # e.g. locuslab/tofu
        tokenizer,
        model_family,
        max_length=512,
        split=None,  # e.g. forget10
        question_key="question",
        answer_key="answer",
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        if "egu/dataset/raw" not in data_path:  # load dataset from the Hugging Face Hub
            self.data = datasets.load_dataset(data_path, split)["train"]
        else:
            if split is None:
                # Fail with a clear message instead of ``None + ".json"``.
                raise TypeError(
                    "split is required when loading a local JSON dataset "
                    f"from {data_path!r}"
                )
            self.data = datasets.load_dataset(
                "json", data_files=os.path.join(data_path, split + ".json")
            )["train"]
        self.model_configs = get_model_identifiers_from_yaml(model_family)
        self.formatting_tokens = self.model_configs.get("formatting_tokens", {})
        self.qk = question_key
        self.ak = answer_key

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"question", "answer"}`` for row ``idx``.

        If the answer column holds a list/tuple, its first element is used.

        Raises:
            ValueError: if that list/tuple is empty. Not IndexError: this class
                defines no ``__iter__``, so ``for x in dataset`` relies on the
                legacy ``__getitem__`` protocol, where IndexError silently ends
                iteration.
        """
        ex = self.data[idx]
        q = ex[self.qk]
        a = ex[self.ak]
        if isinstance(a, (list, tuple)):  # multi-answer rows: keep the first
            if not a:
                raise ValueError(f"row {idx} has an empty {self.ak!r} list")
            a = a[0]
        return {"question": q, "answer": a}
