from typing import List, Dict, Tuple, Optional, Set
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
from tqdm import tqdm

from .config import ExistenceConfig
from .data import Slot

# Dictionary key under which BeliefExtractor.pushforward_to_candidates stores
# the probability mass of every token outside the candidate set. It is
# negative so it cannot collide with a real vocabulary id (candidate ids come
# from np.argsort indices, which are always >= 0).
OTHER_TOKEN_ID = -1

class BeliefExtractor:
    """Query a masked language model for token distributions at a slot.

    Each query places one mask token between a left and a right context taken
    from a ``Slot`` and reads the model's log-probabilities at the mask
    position. Full-vocabulary distributions are then projected onto a small
    candidate set plus a catch-all bucket keyed by ``OTHER_TOKEN_ID``.
    """

    def __init__(self, config: ExistenceConfig):
        """Load the tokenizer and masked-LM named by ``config.model_name``.

        Args:
            config: ``model_name`` and ``device`` are read here; other methods
                read ``l_max``, ``r_max``, ``l_list``, ``r_list``, ``eps`` and
                ``batch_size``.

        Raises:
            ValueError: if the tokenizer has no mask token. Every query built
                by this class needs one. The check runs before the model
                weights are loaded.
        """
        self.config = config
        self.device = torch.device(config.device)

        print(f"[Belief] Loading model: {config.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)

        tok = self.tokenizer
        self.mask_id = tok.mask_token_id
        if self.mask_id is None:
            raise ValueError(
                f"Tokenizer for {config.model_name!r} has no mask token; "
                "a masked language model is required"
            )
        # Explicit `is None` fallbacks rather than `or`: id 0 is a legitimate
        # special-token id (BERT-style vocabularies use 0 for [PAD]).
        self.cls_id = self._first_not_none(tok.cls_token_id, tok.bos_token_id)
        self.sep_id = self._first_not_none(tok.sep_token_id, tok.eos_token_id)
        # Padded positions are excluded through attention_mask, so any valid
        # id (0 as a last resort) is acceptable as filler.
        self.pad_id = self._first_not_none(tok.pad_token_id, tok.eos_token_id, 0)

        self.model = AutoModelForMaskedLM.from_pretrained(config.model_name)
        self.model.to(self.device)
        self.model.eval()

        print(f"[Belief] Model loaded on {self.device}")
        print(f"[Belief] Vocab size: {self.tokenizer.vocab_size}")

    @staticmethod
    def _first_not_none(*values):
        """Return the first argument that is not ``None`` (``None`` if all are)."""
        return next((v for v in values if v is not None), None)

    def _build_input(
        self,
        left_tokens: List[int],
        right_tokens: List[int]
    ) -> Tuple[List[int], int]:
        """Build ``[CLS] + left + [MASK] + right + [SEP]`` token ids.

        Returns:
            ``(input_ids, mask_position)``, where ``mask_position`` is the
            index of the mask token inside ``input_ids``.

        NOTE(review): no truncation is applied. Sequences longer than the
        model's maximum input length will fail in the forward pass. This
        assumes the configured context lengths keep inputs within bounds;
        confirm.
        """
        input_ids = [self.cls_id] + left_tokens + [self.mask_id] + right_tokens + [self.sep_id]
        mask_position = 1 + len(left_tokens)
        return input_ids, mask_position

    def _get_log_probs_single(
        self,
        input_ids: List[int],
        mask_position: int
    ) -> np.ndarray:
        """Run one unpadded sequence; return log-probabilities at ``mask_position``."""
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        with torch.no_grad():
            outputs = self.model(input_tensor)
            logits = outputs.logits[0, mask_position]
            log_probs = F.log_softmax(logits, dim=-1)

        return log_probs.cpu().numpy()

    def _get_log_probs_batch(
        self,
        inputs: List[Tuple[List[int], int]]
    ) -> List[np.ndarray]:
        """Run several ``(input_ids, mask_position)`` queries in one forward pass.

        Sequences are right-padded with ``self.pad_id`` to the longest one.
        Padding is excluded via ``attention_mask``, so mask positions (counted
        from the left) are unaffected.

        Returns:
            One log-probability vector per input, in input order (empty list
            for empty input).
        """
        if not inputs:
            return []

        max_len = max(len(input_ids) for input_ids, _ in inputs)
        padded = []
        attention_masks = []
        for input_ids, _ in inputs:
            n_pad = max_len - len(input_ids)
            padded.append(input_ids + [self.pad_id] * n_pad)
            attention_masks.append([1] * len(input_ids) + [0] * n_pad)

        input_tensor = torch.tensor(padded, dtype=torch.long, device=self.device)
        attention_tensor = torch.tensor(attention_masks, dtype=torch.long, device=self.device)
        mask_positions = torch.tensor(
            [mask_pos for _, mask_pos in inputs], dtype=torch.long, device=self.device
        )
        rows = torch.arange(len(inputs), device=self.device)

        with torch.no_grad():
            logits = self.model(input_tensor, attention_mask=attention_tensor).logits
            # Gather every row's mask-position logits in one indexed op.
            mask_logits = logits[rows, mask_positions]
            log_probs = F.log_softmax(mask_logits, dim=-1)

        # One device->host copy for the whole batch instead of one per row.
        return list(log_probs.cpu().numpy())

    @staticmethod
    def _top_k_ids(log_probs: np.ndarray, k: int) -> List[int]:
        """Return the ids of the ``k`` largest entries of ``log_probs`` (unordered); requires ``k >= 1``."""
        return np.argsort(log_probs)[-k:].tolist()

    def compute_candidate_set(
        self,
        slot: Slot,
        k_cand: int
    ) -> Set[int]:
        """Union the top-``k_cand`` predictions from three views of the slot's context.

        The views are:

        - the left context only (``config.l_max`` tokens);
        - the right context only (``config.r_max`` tokens, ``"real"`` condition);
        - both contexts together.

        Returns:
            A set of vocabulary ids, empty when ``k_cand <= 0``.
        """
        if k_cand <= 0:
            # Guard: np.argsort(...)[-0:] would select the entire vocabulary.
            return set()

        left = slot.get_left_context(self.config.l_max)
        right = slot.get_right_context(self.config.r_max, "real")

        candidates: Set[int] = set()
        for left_tokens, right_tokens in ((left, []), ([], right), (left, right)):
            input_ids, mask_pos = self._build_input(left_tokens, right_tokens)
            log_probs = self._get_log_probs_single(input_ids, mask_pos)
            candidates.update(self._top_k_ids(log_probs, k_cand))

        return candidates

    def pushforward_to_candidates(
        self,
        log_probs_full: np.ndarray,
        candidate_set: Set[int],
        eps: float = 1e-12
    ) -> Dict[int, float]:
        """Project a full-vocabulary distribution onto ``candidate_set`` plus an "other" bucket.

        Steps:

        1. Each candidate id keeps its model probability.
        2. The leftover mass ``max(1 - sum, 0)`` is stored under ``OTHER_TOKEN_ID``.
        3. The result is normalized.
        4. Every entry is floored at ``eps``, then the result is renormalized.

        If ``candidate_set`` itself contains ``OTHER_TOKEN_ID``, that entry is
        ignored. It is not a vocabulary index, and treating it as one would
        read the last vocab entry.

        Returns:
            Dict mapping token id (and ``OTHER_TOKEN_ID``) to a probability.
            The values sum to 1 and are strictly positive.
        """
        token_ids = [t for t in candidate_set if t != OTHER_TOKEN_ID]
        # Exponentiate only the gathered candidates, not the whole vocabulary.
        cand_probs = np.exp(log_probs_full[token_ids])

        mu = {t: float(p) for t, p in zip(token_ids, cand_probs)}
        prob_sum = sum(mu.values())
        mu[OTHER_TOKEN_ID] = max(1.0 - prob_sum, 0.0)

        total = sum(mu.values())
        if total > 0:
            mu = {k: v / total for k, v in mu.items()}
        else:
            # Degenerate input (e.g. NaN probabilities): fall back to uniform.
            mu = dict.fromkeys(mu, 1.0 / len(mu))

        # Floor at eps so no entry is exactly zero (presumably so downstream
        # log / divergence computations stay finite), then renormalize.
        mu = {k: max(v, eps) for k, v in mu.items()}
        total = sum(mu.values())
        return {k: v / total for k, v in mu.items()}

    def compute_belief_grid(
        self,
        slot: Slot,
        candidate_set: Set[int],
        condition: str = "real"
    ) -> Dict[Tuple[int, int], Dict[int, float]]:
        """Compute a candidate-projected belief for each ``(L, R)`` context size.

        ``L`` ranges over ``config.l_list`` and ``R`` over ``config.r_list``.
        The right context is fetched with ``slot.get_right_context(R, condition)``.
        Presumably ``condition`` selects the real or a perturbed right
        context; confirm in ``Slot``. Queries run in chunks of
        ``config.batch_size``.

        Returns:
            Dict mapping ``(L, R)`` to the output of
            ``pushforward_to_candidates`` for that context.
        """
        eps = self.config.eps

        inputs = []
        grid_coords = []
        for L in self.config.l_list:
            for R in self.config.r_list:
                left_tokens = slot.get_left_context(L)
                right_tokens = slot.get_right_context(R, condition)
                inputs.append(self._build_input(left_tokens, right_tokens))
                grid_coords.append((L, R))

        batch_size = self.config.batch_size
        all_log_probs = []
        for start in range(0, len(inputs), batch_size):
            all_log_probs.extend(self._get_log_probs_batch(inputs[start:start + batch_size]))

        return {
            coord: self.pushforward_to_candidates(log_probs, candidate_set, eps)
            for coord, log_probs in zip(grid_coords, all_log_probs)
        }

def extract_all_beliefs(
    slots: List[Slot],
    config: ExistenceConfig,
    conditions: Tuple[str, ...] = ("real", "swap", "shuffle"),
) -> Tuple[BeliefExtractor, List[Dict]]:
    """Extract a candidate set and per-condition belief grids for every slot.

    Args:
        slots: Slots to process, in order.
        config: Used to build the ``BeliefExtractor``; ``config.k_cand`` sets
            the per-view top-k used for candidate sets.
        conditions: Right-context conditions for which a belief grid is
            computed. Defaults to the original ``("real", "swap", "shuffle")``.

    Returns:
        ``(extractor, results)``. Each entry of ``results`` is a dict with
        keys ``"slot"``, ``"candidate_set"`` and ``"beliefs"``, where
        ``"beliefs"`` maps each condition to its belief grid.
    """
    extractor = BeliefExtractor(config)
    # Check the resolved device type so "cuda:0"-style device strings also count.
    use_cuda = extractor.device.type == "cuda"
    results = []

    for slot in tqdm(slots, desc="Extracting beliefs"):
        candidate_set = extractor.compute_candidate_set(slot, config.k_cand)
        beliefs = {
            condition: extractor.compute_belief_grid(slot, candidate_set, condition)
            for condition in conditions
        }
        results.append({
            "slot": slot,
            "candidate_set": candidate_set,
            "beliefs": beliefs,
        })

        # Every 100 slots, release unoccupied memory held by PyTorch's CUDA caching allocator.
        if use_cuda and len(results) % 100 == 0:
            torch.cuda.empty_cache()

    return extractor, results
