from typing import Dict, List, Optional

import torch


class DataCollatorTDEC:
    """Collate forget / retain / "I don't know" examples into fixed-length tensors.

    Each batch element is a dict with two entries:

    * ``pair["forget"]`` -- provides a prefix (``prefix_ids`` or
      ``prefix_text``) and a suffix (``suffix_ids`` or ``suffix_text``).
    * ``pair["retain"]`` -- provides ``input_ids`` or ``text``.

    Three streams are produced per element, every one padded or truncated to
    ``max_length``:

    * ``forget_*`` -- prefix followed by the suffix from the example.
    * ``retain_*`` -- the retain sequence; labels cover every non-pad token.
    * ``idk_*``    -- the same prefix followed by a randomly drawn IDK phrase.

    For the forget and idk streams the labels are ``-100`` on the prefix and
    on padding, so only suffix positions carry label values.
    """

    # File read for fallback phrases; relative to the current working directory.
    DEFAULT_IDK_PATH = "egu/dataset/idontknow.jsonl"

    def __init__(
        self,
        tokenizer,
        max_length: int = 256,
        idk_phrases: Optional[List[str]] = None,
        idk_file: Optional[str] = None,
        use_ids_first: bool = True,
    ):
        """Store settings and pre-tokenize the IDK phrases.

        Phrase sources, in priority order: non-empty lines of ``idk_file``,
        then ``idk_phrases``, then the lines of ``DEFAULT_IDK_PATH``.

        Args:
            tokenizer: Callable as ``tokenizer(text, add_special_tokens=False)``
                returning an object with ``input_ids``; must expose
                ``eos_token`` / ``eos_token_id`` and optionally
                ``pad_token_id``.
            max_length: Exact length of every returned sequence.
            idk_phrases: Explicit phrases used as IDK targets.
            idk_file: Text file with one phrase per line. If it cannot be
                read, a warning is issued and the next source is used.
            use_ids_first: Prefer pre-tokenized ``*_ids`` fields over text
                when reading the forget example.

        Raises:
            OSError: If the default phrase file is needed but cannot be read.
        """
        self.tok = tokenizer
        self.max_length = max_length
        self.use_ids_first = use_ids_first

        # GPT-2 family has no pad by default -> use EOS as pad (standard for causal LMs)
        self.pad_id = getattr(self.tok, "pad_token_id", None)
        if self.pad_id is None:
            self.tok.pad_token = self.tok.eos_token
            self.pad_id = self.tok.eos_token_id
        self.eos_id = self.tok.eos_token_id

        phrases: List[str] = []
        if idk_file:
            try:
                phrases = self._read_phrases(idk_file)
            except (OSError, UnicodeError) as exc:
                # Best-effort source: keep going, but do not hide the failure.
                import warnings

                warnings.warn(
                    f"Could not read idk_file {idk_file!r} ({exc}); "
                    "falling back to idk_phrases or the default phrases.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        # Only populated when the default file is actually needed, so a
        # missing default file cannot break callers who pass their own phrases.
        self.DEFAULT_IDK: List[str] = []
        if not phrases:
            if idk_phrases:
                phrases = list(idk_phrases)
            else:
                self.DEFAULT_IDK = self._read_phrases(self.DEFAULT_IDK_PATH)
                phrases = self.DEFAULT_IDK

        token_lists: List[List[int]] = [
            self.tok(p, add_special_tokens=False).input_ids for p in phrases
        ]
        # Drop phrases that tokenize to nothing; fall back to a lone EOS target.
        self.idk_token_lists: List[List[int]] = [t for t in token_lists if t] or [
            [self.eos_id]
        ]

    @staticmethod
    def _read_phrases(path: str) -> List[str]:
        """Return the stripped, non-empty lines of the UTF-8 text file ``path``."""
        with open(path, "r", encoding="utf-8") as f:
            return [ln.strip() for ln in f if ln.strip()]

    def _pad(self, ids: List[int]) -> Dict[str, torch.Tensor]:
        """Right-pad (with ``pad_id``) or truncate ``ids`` to ``max_length``.

        Returns:
            Dict with 1-D long tensors ``input_ids`` and ``attention_mask``
            (1 on real tokens, 0 on padding).
        """
        L = len(ids)
        if L >= self.max_length:
            ids = ids[: self.max_length]
            attn = [1] * self.max_length
        else:
            n_pad = self.max_length - L
            ids = ids + [self.pad_id] * n_pad
            attn = [1] * L + [0] * n_pad
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
        }

    def _labels_mask_prefix(
        self, padded_ids: torch.Tensor, attn: torch.Tensor, prefix_len: int
    ) -> torch.Tensor:
        """Copy ``padded_ids`` and set padding and the first ``prefix_len`` slots to -100."""
        labels = padded_ids.clone()
        labels[attn == 0] = -100  # mask pads
        if prefix_len > 0:
            labels[:prefix_len] = -100  # mask prefix (no loss on prompt)
        return labels

    def _get_ids(self, obj: Dict, text_key: str, ids_key: str) -> List[int]:
        """Return token ids for ``obj``.

        Uses ``obj[ids_key]`` when ``use_ids_first`` is set and that entry is
        a non-empty list; otherwise tokenizes ``obj[text_key]`` (empty string
        if absent) without special tokens.
        """
        if (
            self.use_ids_first
            and ids_key in obj
            and isinstance(obj[ids_key], list)
            and len(obj[ids_key]) > 0
        ):
            return [int(x) for x in obj[ids_key]]
        txt = obj.get(text_key, "")
        return self.tok(txt, add_special_tokens=False).input_ids

    def _encode_prefixed(
        self, prefix_ids: List[int], suffix_ids: List[int]
    ) -> Dict[str, torch.Tensor]:
        """Build padded ids, attention mask and labels for ``prefix + suffix``.

        The suffix is truncated to the room left after the prefix. When the
        prefix alone fills ``max_length``, only its last ``max_length`` tokens
        are kept and every label is masked.
        """
        room = max(self.max_length - len(prefix_ids), 0)
        if room == 0:
            full = prefix_ids[-self.max_length:]
            prefix_len = self.max_length
        else:
            full = prefix_ids + suffix_ids[:room]
            prefix_len = len(prefix_ids)

        padded = self._pad(full)
        padded["labels"] = self._labels_mask_prefix(
            padded["input_ids"], padded["attention_mask"], prefix_len
        )
        return padded

    # ---------- main ----------

    def __call__(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        """Collate ``batch`` into stacked ``(len(batch), max_length)`` long tensors.

        Returns:
            Dict with ``{forget,retain,idk}_{input_ids,attention_mask,labels}``.
        """
        streams = ("forget", "retain", "idk")
        fields = ("input_ids", "attention_mask", "labels")
        collected: Dict[str, List[torch.Tensor]] = {
            f"{s}_{k}": [] for s in streams for k in fields
        }

        def _store(stream: str, enc: Dict[str, torch.Tensor]) -> None:
            for k in fields:
                collected[f"{stream}_{k}"].append(enc[k])

        for pair in batch:
            # ===== FORGET (prefix + gold suffix) =====
            f = pair["forget"]
            prefix_ids = self._get_ids(f, "prefix_text", "prefix_ids")
            suffix_ids = self._get_ids(f, "suffix_text", "suffix_ids")
            _store("forget", self._encode_prefixed(prefix_ids, suffix_ids))

            # ===== RETAIN (labels on every real token) =====
            r = pair["retain"]
            # Pre-tokenized ids win here regardless of ``use_ids_first``.
            if "input_ids" in r and isinstance(r["input_ids"], list) and r["input_ids"]:
                ids_r = [int(x) for x in r["input_ids"]]
            else:
                ids_r = self._get_ids(r, "text", "input_ids")
            rpad = self._pad(ids_r)
            rpad["labels"] = self._labels_mask_prefix(
                rpad["input_ids"], rpad["attention_mask"], 0
            )
            _store("retain", rpad)

            # ===== IDK (same prefix + randomly drawn phrase) =====
            idk_suffix_ids = self.idk_token_lists[
                torch.randint(0, len(self.idk_token_lists), ()).item()
            ]
            _store("idk", self._encode_prefixed(prefix_ids, idk_suffix_ids))

        return {
            f"{s}_{k}": torch.stack(collected[f"{s}_{k}"])
            for s in streams
            for k in fields
        }
