from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
from transformers import AutoProcessor, PreTrainedTokenizerBase
import re
from utils.tasks import CausalQATask

# Matches choice markers written either as "A." (a single capital letter that
# forms its own word, optionally followed by whitespace, then a period) or as
# "(A)". With re.findall, group 1 holds the "A." form and group 2 the "(A)" form;
# the group that did not participate is returned as "".
# NOTE(review): any lone capital before a period (e.g. the letters of "U.S." or a
# sentence-final "I.") also matches — confirm prompts never contain such text.
CHOICE_RE = re.compile(r"(?:\b([A-Z])\s*\.|\(\s*([A-Z])\s*\))", re.UNICODE)

def parse_choice_letters(prompt: str) -> list[str]:
    """Extract the multiple-choice letters mentioned in ``prompt``.

    Letters are recognised in the ``A.`` and ``(A)`` forms matched by
    ``CHOICE_RE``. Each distinct letter is kept once, in order of its first
    appearance.

    Args:
        prompt (str): Text that may list answer options.
    Returns:
        list[str]: The distinct choice letters found, or ``["A", "B", "C", "D"]``
        when the prompt contains none.
    """
    # findall yields (dot_form, paren_form) pairs; the non-participating
    # alternative is "", so `or` picks whichever form actually matched.
    # dict.fromkeys keeps insertion order, giving an order-preserving dedup.
    ordered = dict.fromkeys(
        dot_form or paren_form for dot_form, paren_form in CHOICE_RE.findall(prompt)
    )
    if not ordered:
        return ["A", "B", "C", "D"]
    return list(ordered)

def get_letter_token_id(tokenizer: "PreTrainedTokenizerBase", letter: str) -> int:
    """Return the token ID used for a single choice letter.

    The letter is tokenized on its own, without special tokens, and the last
    resulting ID is returned. If the tokenizer splits the letter into several
    pieces, the last piece is used; that piece is presumably the letter itself
    when a prefix piece is emitted (confirm for the tokenizer in use).

    Args:
        tokenizer (PreTrainedTokenizerBase): Tokenizer called as
            ``tokenizer(letter, add_special_tokens=False)``; the result must
            expose an ``input_ids`` sequence.
        letter (str): The choice letter (e.g. 'A', 'B', etc.).
    Returns:
        int: The token ID corresponding to the letter.
    Raises:
        ValueError: If tokenizing ``letter`` produces no token IDs.
    """
    # A single tokenization suffices: the previous "fallback" re-tokenized the
    # exact same string and therefore could never succeed where this failed.
    token_ids = tokenizer(letter, add_special_tokens=False).input_ids
    if not token_ids:
        raise ValueError(f"Cannot find token ID for letter '{letter}'")
    return token_ids[-1]

# Chat-template special tokens used by the collators below. Every prompt must
# contain ASSISTANT_TOKEN: label supervision starts right after its last
# occurrence. Every label/completion must end with END_TOKEN. Both are resolved
# to single vocabulary IDs via tokenizer.convert_tokens_to_ids.
ASSISTANT_TOKEN = "<|assistant|>"
END_TOKEN = "<|end|>"

class DataCollator:
    """
    Collate callable that prepares causal QA samples for multimodal models.

    Each feature dict is read for ``images`` (a list of images), ``prompt``,
    ``label`` (the completion, ending with ``END_TOKEN``), ``A_star``,
    ``A_mask``, ``answer`` and ``id``. It may also contain the optional eval
    prompt variants ``prompt_nodeexp`` / ``prompt_astar`` / ``prompt_both``.
    When ``CausalQATask.ALIGNMENT`` is enabled, it must also contain
    ``node_prompt_emb`` and ``explanation_prompt_emb``.

    Args:
        processor (AutoProcessor): Hugging Face processor handling vision-text batching.
        tokenizer (PreTrainedTokenizerBase): Tokenizer used to map choice letters to token IDs.
        max_length (int): Maximum tokenized prompt length passed to the processor.
        tasks (Optional[List[CausalQATask]]): Experiment tasks. ``CausalQATask.ALIGNMENT``
            adds the node/explanation embedding fields to each batch.
    """

    # Optional eval prompt variants: (feature key, suffix appended to every
    # processor output key produced for that variant).
    _EXTRA_PROMPT_FIELDS: Tuple[Tuple[str, str], ...] = (
        ("prompt_nodeexp", "_nodeexp_prompt"),
        ("prompt_astar", "_astar_prompt"),
        ("prompt_both", "_both_prompt"),
    )

    def __init__(self, processor: AutoProcessor, tokenizer: PreTrainedTokenizerBase, max_length: int, tasks: Optional[List[CausalQATask]] = None):
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.assistant_token = ASSISTANT_TOKEN
        self.end_token = END_TOKEN
        self.assistant_token_id = self.tokenizer.convert_tokens_to_ids(self.assistant_token)
        self.end_token_id = self.tokenizer.convert_tokens_to_ids(self.end_token)
        # Hugging Face tokenizers typically map an out-of-vocabulary token to
        # ``unk_token_id`` instead of returning None, so reject both outcomes.
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        missing = (
            self.assistant_token_id is None
            or self.end_token_id is None
            or (unk_id is not None and unk_id in (self.assistant_token_id, self.end_token_id))
        )
        if missing:
            raise ValueError("Assistant or end token not found in tokenizer vocabulary.")

        self.exp_tasks = tasks if tasks else []

    def _process(self, texts: List[str], images: List[Any]) -> Dict[str, Any]:
        """Run the processor on ``texts``/``images`` with this collator's padding/truncation settings."""
        return self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def _build_labels(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Return LM labels for ``input_ids`` with the prompt and padding masked to -100.

        For each row, every position up to and including the LAST assistant
        token is masked, so only the completion after it is supervised. Rows
        with no assistant token (presumably truncated away) are fully masked.
        Positions where ``attention_mask == 0`` are masked too.
        """
        labels = input_ids.clone()
        for b in range(input_ids.size(0)):
            idx = (input_ids[b] == self.assistant_token_id).nonzero(as_tuple=False).squeeze(-1)
            if idx.numel() == 0:
                labels[b, :] = -100
                continue
            start_idx = idx[-1].item() + 1  # first completion token after the last assistant token
            labels[b, :start_idx] = -100

        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        return labels

    def _build_choice_targets(
        self, features: List[Dict[str, Any]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build multiple-choice supervision for the batch.

        Returns:
            choice_ids: (B, max_ops) long tensor of choice-letter token IDs, -100 where unused.
            choice_mask: (B, max_ops) bool tensor, True for real choices.
            target_index: (B,) long tensor with the position of the ground-truth
                letter among the parsed choices, or -1 if it is not among them.
        """
        choices_list = [parse_choice_letters(f["prompt"]) for f in features]
        batch_size = len(features)
        max_ops = max(len(choices) for choices in choices_list)
        choice_ids = torch.full((batch_size, max_ops), fill_value=-100, dtype=torch.long)
        choice_mask = torch.zeros((batch_size, max_ops), dtype=torch.bool)
        target_index = torch.full((batch_size,), fill_value=-1, dtype=torch.long)

        for i, (letters, feat) in enumerate(zip(choices_list, features)):
            ids = [get_letter_token_id(self.tokenizer, letter) for letter in letters]
            choice_ids[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
            choice_mask[i, :len(ids)] = True
            # The ground truth is the first non-space character of the label,
            # upper-cased (e.g. a label like "b<|end|>" yields "B").
            gt = feat["label"].strip().upper()[0]
            if gt in letters:
                target_index[i] = letters.index(gt)

        return choice_ids, choice_mask, target_index

    def __call__(self, features: List[Dict[str, Any]], num_items_in_batch: Optional[int] = None) -> Dict[str, Any]:
        """Collate ``features`` into a model-ready batch.

        Args:
            features: Per-sample dicts (see the class docstring for the keys read).
            num_items_in_batch: Unused; accepted for call-signature compatibility.

        Returns:
            The processor output for the ``prompt + label`` texts. It is extended with:
            ``<key>_prompt`` entries from the prompt-only processor output;
            ``<key>_nodeexp_prompt`` / ``<key>_astar_prompt`` / ``<key>_both_prompt``
            entries for the optional prompt variants that are present;
            ``labels``, ``A_star``, ``A_mask``, ``choice_token_ids``,
            ``choice_mask``, ``target_index``, ``answer`` and ``ids``;
            and, for alignment, ``node_prompt_emb``, ``explanation_prompt_emb``,
            ``node_mask`` and ``node_mask_prompt``.

        Raises:
            ValueError: If a label does not end with the end token, a prompt lacks
                the assistant token, or an optional prompt variant is present on
                only some of the features.
        """
        # Each feature contributes its whole list of images, in feature order.
        flat_images = [img for f in features for img in f["images"]]

        prompts = [f["prompt"] for f in features]
        completions = [f["label"] for f in features]
        for c in completions:
            if not c.endswith(self.end_token):
                raise ValueError(f"Completion '{c}' does not end with end token '{self.end_token}'")
        for p in prompts:
            if self.assistant_token not in p:
                raise ValueError(f"Prompt '{p}' does not contain assistant token '{self.assistant_token}'")

        full_texts = [p + c for p, c in zip(prompts, completions)]
        proc = self._process(full_texts, flat_images)
        prompt_proc = self._process(prompts, flat_images)
        for k in prompt_proc:
            proc[k + "_prompt"] = prompt_proc[k]

        # ===== Extra eval prompt variants (optional) =====
        # A variant must be present on all features or on none; a partial batch
        # cannot be processed and previously surfaced as a bare KeyError.
        for field, suffix in self._EXTRA_PROMPT_FIELDS:
            present = [field in f for f in features]
            if not any(present):
                continue
            if not all(present):
                raise ValueError(
                    f"Optional prompt field '{field}' is present on only some features in the batch"
                )
            extra_proc = self._process([f[field] for f in features], flat_images)
            for k in extra_proc:
                proc[k + suffix] = extra_proc[k]

        # Supervise only the completion: mask the prompt (through the last
        # assistant token) and padding.
        proc["labels"] = self._build_labels(proc["input_ids"], proc.get("attention_mask"))

        proc["A_star"] = torch.stack([f["A_star"] for f in features], dim=0)
        proc["A_mask"] = torch.stack([f["A_mask"] for f in features], dim=0)

        choice_ids, choice_mask, target_index = self._build_choice_targets(features)
        proc["choice_token_ids"] = choice_ids
        proc["choice_mask"] = choice_mask
        proc["target_index"] = target_index
        proc["answer"] = [f["answer"] for f in features]
        proc["ids"] = [f["id"] for f in features]

        if CausalQATask.ALIGNMENT in self.exp_tasks:
            node_prompt_emb = torch.stack([f["node_prompt_emb"] for f in features], dim=0)
            explanation_prompt_emb = torch.stack([f["explanation_prompt_emb"] for f in features], dim=0)
            proc["node_prompt_emb"] = node_prompt_emb
            proc["explanation_prompt_emb"] = explanation_prompt_emb
            # A slot counts as a real node when its embedding (last dim) is not
            # all zeros; zero rows presumably mark padding — confirm upstream.
            node_mask = node_prompt_emb.abs().sum(-1) > 0
            proc["node_mask"] = node_mask
            proc["node_mask_prompt"] = node_mask

        return proc


class DataCollatorCF:
    """
    Counterfactual: free-form structured generation (variable: value lines)

    Collates samples into the full ``prompt + label`` processor batch, a
    prompt-only batch (keys suffixed ``_prompt``), LM ``labels`` that supervise
    only the completion, graph supervision tensors (``A_star``, ``A_mask``) and
    bookkeeping fields. When ``CausalQATask.ALIGNMENT`` is enabled, it also adds
    the alignment embeddings. Unlike ``DataCollator``, no multiple-choice
    targets are built.

    Args:
        processor (AutoProcessor): Hugging Face processor handling vision-text batching.
        tokenizer (PreTrainedTokenizerBase): Tokenizer used to resolve the special token IDs.
        max_length (int): Maximum tokenized length passed to the processor.
        tasks (Optional[List[CausalQATask]]): Experiment tasks. ``CausalQATask.ALIGNMENT``
            adds the node/explanation embedding fields to each batch.
    """
    def __init__(
        self,
        processor: AutoProcessor,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int,
        tasks: Optional[List[CausalQATask]] = None
    ):
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.assistant_token = ASSISTANT_TOKEN
        self.end_token = END_TOKEN
        self.assistant_token_id = self.tokenizer.convert_tokens_to_ids(self.assistant_token)
        self.end_token_id = self.tokenizer.convert_tokens_to_ids(self.end_token)
        # Hugging Face tokenizers typically map an out-of-vocabulary token to
        # ``unk_token_id`` instead of returning None, so reject both outcomes.
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        missing = (
            self.assistant_token_id is None
            or self.end_token_id is None
            or (unk_id is not None and unk_id in (self.assistant_token_id, self.end_token_id))
        )
        if missing:
            raise ValueError("Assistant or end token not found in tokenizer vocabulary.")

        self.exp_tasks = tasks or []

    def _process(self, texts: List[str], images: List[Any]) -> Dict[str, Any]:
        """Run the processor on ``texts``/``images`` with this collator's padding/truncation settings."""
        return self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def _build_labels(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Return LM labels for ``input_ids`` with the prompt and padding masked to -100.

        For each row, every position up to and including the LAST assistant
        token is masked. Rows with no assistant token (presumably truncated
        away) are fully masked, as are positions where ``attention_mask == 0``.
        """
        labels = input_ids.clone()
        for b in range(input_ids.size(0)):
            idx = (input_ids[b] == self.assistant_token_id).nonzero(as_tuple=False).squeeze(-1)
            if idx.numel() == 0:
                labels[b, :] = -100
                continue
            start_idx = idx[-1].item() + 1  # first completion token after the last assistant token
            labels[b, :start_idx] = -100

        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        return labels

    def __call__(self, features: List[Dict[str, Any]], num_items_in_batch: Optional[int] = None) -> Dict[str, Any]:
        """Collate ``features`` into a model-ready batch.

        Args:
            features: Per-sample dicts with ``images``, ``prompt``, ``label``,
                ``A_star``, ``A_mask`` and optionally ``answer`` / ``id``. They
                also need the alignment embedding keys when that task is enabled.
            num_items_in_batch: Unused; accepted for call-signature compatibility.

        Returns:
            The processor output for ``prompt + label`` texts. It is extended with
            ``<key>_prompt`` entries, ``labels``, ``A_star``, ``A_mask``,
            ``answer`` (default ``""``), ``ids`` (default ``None``) and,
            for alignment, the embedding/mask fields.

        Raises:
            ValueError: If a label does not end with the end token or a prompt
                lacks the assistant token.
        """
        # ---- images ----
        # Each feature contributes its whole image list (presumably one image
        # per sample for this task), in feature order.
        flat_images = [img for f in features for img in f["images"]]

        # ---- text ----
        prompts = [f["prompt"] for f in features]
        completions = [f["label"] for f in features]
        for c in completions:
            if not c.endswith(self.end_token):
                raise ValueError(f"Completion does not end with {self.end_token}")
        for p in prompts:
            if self.assistant_token not in p:
                raise ValueError(f"Prompt does not contain {self.assistant_token}")

        full_texts = [p + c for p, c in zip(prompts, completions)]
        proc = self._process(full_texts, flat_images)
        prompt_proc = self._process(prompts, flat_images)
        for k in prompt_proc:
            proc[k + "_prompt"] = prompt_proc[k]

        # ---- labels: mask prompt part and padding ----
        proc["labels"] = self._build_labels(proc["input_ids"], proc.get("attention_mask", None))

        # ---- graph supervision ----
        proc["A_star"] = torch.stack([f["A_star"] for f in features], dim=0)
        proc["A_mask"] = torch.stack([f["A_mask"] for f in features], dim=0)

        # ---- bookkeeping ----
        proc["answer"] = [f.get("answer", "") for f in features]
        proc["ids"] = [f.get("id", None) for f in features]

        # ---- alignment tasks ----
        if CausalQATask.ALIGNMENT in self.exp_tasks:
            node_prompt_emb = torch.stack([f["node_prompt_emb"] for f in features], dim=0)
            explanation_prompt_emb = torch.stack([f["explanation_prompt_emb"] for f in features], dim=0)
            proc["node_prompt_emb"] = node_prompt_emb
            proc["explanation_prompt_emb"] = explanation_prompt_emb
            # A slot counts as a real node when its embedding (last dim) is not
            # all zeros; zero rows presumably mark padding — confirm upstream.
            node_mask = node_prompt_emb.abs().sum(-1) > 0
            proc["node_mask"] = node_mask
            proc["node_mask_prompt"] = node_mask

        return proc