from __future__ import annotations

import logging
from typing import Any, NamedTuple

import json
from typing import Dict

from structured_llmuq.model.build import build_model
from structured_llmuq.utils.postprocessing import EntailmentDeberta, get_semantic_ids
from structured_llmuq.utils import prompts
import numpy as np

from .latent_encoder import LatentEncoder


# Type alias for a simplex: a discrete probability distribution over answers.
# Keys are answer strings and values are their probabilities, which are
# expected to sum to 1.0 (the alias itself does not enforce this).
Simplex = dict[str, float]


class SimplexMetadata(NamedTuple):
    """Side information produced while aligning simplexes.

    Attributes:
        mapping: For each original answer string, the canonical representative
            it was merged into during semantic alignment, e.g.
            ``{"London": "London", "london": "London", "LONDON": "London"}``.
    """

    mapping: dict[str, str]


def parse_simplex(llm_output: str) -> Dict[str, float]:
    """Parse LLM output into a simplex (probability distribution over answers).

    Expects JSON format like:
    {"items": [{"answer": "Paris", "p": 0.6}, {"answer": "London", "p": 0.4}]}

    If the same answer appears in several items, their probabilities are
    summed so that no probability mass is silently dropped.

    Falls back to treating the entire (stripped) string as a single answer
    with probability 1.0 when the output is not valid JSON, does not follow
    the expected schema, or contains no items.

    Args:
        llm_output: Raw text returned by the model.

    Returns:
        Dictionary mapping each answer to its probability. Probabilities are
        taken as reported by the model and are not re-normalised.
    """
    text = llm_output.strip()

    simplex: Dict[str, float] = {}
    try:
        for item in json.loads(text)["items"]:
            answer = item["answer"]
            # Accumulate rather than overwrite: repeated answers keep their mass.
            simplex[answer] = simplex.get(answer, 0.0) + float(item["p"])
    except (ValueError, TypeError, KeyError, OverflowError):
        # ValueError covers json.JSONDecodeError and non-numeric "p";
        # TypeError/KeyError cover outputs that do not match the schema.
        return {text: 1.0}

    # An empty "items" list carries no distribution; use the documented fallback.
    if not simplex:
        return {text: 1.0}
    return simplex


class SimplexEncoder(LatentEncoder[Simplex, SimplexMetadata]):
    def __init__(
        self,
        config: dict,
        initialize_selector_model: bool = True,
        add_question_to_entailment: bool = True,
    ) -> None:
        """Create a simplex-based latent encoder.
        
        Maps question-answer pairs into probability distributions (simplexes) over 
        possible answer strings, then aligns these distributions to a common vocabulary 
        using semantic similarity.
        """
        super().__init__()

        add_question_to_entailment = config.pop(
            "add_question_to_entailment", add_question_to_entailment
        )

        if initialize_selector_model:
            #check if model_name == exact 
            if config.get("model_name") is None:
                self.selector_model = None
            elif config.get("model_name") == "exact":
                logging.info("Using exact set encoder, no model initialization.")
                self.selector_model = "exact"
            else:
                logging.info(
                    f"Initializing selector model with config: {config}"
                )
                self.selector_model = build_model(
                    **config,
                    device_map="cuda",
                    generation_config={
                        "max_new_tokens": 100,
                        "temperature": 0,
                        "num_return_sequences": 1,
                    },
                )
        else:
            self.selector_model = None
        self.implication_model = EntailmentDeberta()

        if add_question_to_entailment and not isinstance(
            self.implication_model, EntailmentDeberta
        ):
            logging.warning(
                "add_question_to_entailment is True but implication_model is not EntailmentDeberta. Ignoring."
            )
            add_question_to_entailment = False
        self.add_question_to_entailment = add_question_to_entailment

    def encode(self, question: str, answer: str | dict) -> Simplex:
        """Encode a question-answer pair into a simplex (probability distribution).

        Args:
            question (str): Source question string.
            answer (str | dict): Candidate answer string OR dictionary with answer:probability pairs.

        Returns:
            Simplex: Dictionary mapping answer strings to probabilities (sums to 1.0).
                     Example: {"Paris": 0.6, "London": 0.4}
        """

        # If answer is already a dictionary (ground truth format), return it directly
        if isinstance(answer, dict):
            return answer
    
        
        # Otherwise, answer is a string - parse it to get simplex representation
        if self.selector_model is None:
            # No model: treat entire answer as single element with probability 1.0
            # This creates a degenerate simplex (point mass)
            return {answer: 1.0}
        elif self.selector_model == "exact":
            # Simple heuristic: split by semicolons and assign equal probability
            distinct_set = [ans.strip() for ans in answer.split(";") if ans.strip()]
            if len(distinct_set) == 0:
                return {answer: 1.0}
            # Equal probability for each distinct answer (uniform distribution)
            prob = 1.0 / len(distinct_set)
            return {ans: prob for ans in distinct_set}
        else:
            # Use LLM to encode as simplex (expecting JSON with answer:probability pairs)
            messages = [
                (
                    "system",
                    prompts.SIMPLEX_SYS_PROMPT
                ),
                (
                    "human",
                    prompts.SIMPLEX_HUMAN_PROMPT.format(
                        question=question,
                        answer=answer,
                    ),
                ),
            ]

            # Hard coded for gpt fix
            response = self.selector_model.model.invoke(messages)  # type: ignore
            content = response.content
            simplex_dict = parse_simplex(content)  # type: ignore

        return simplex_dict

    def __call__(
        self, question: str, answers: dict[Any, str | dict | list[str]]
    ) -> tuple[dict[str, list[Simplex]], SimplexMetadata]:
        """Convert orchestrator run results into aligned simplex encodings.
        
        Args:
            question (str): The question string.
            answers (dict): Dictionary mapping keys to answers, where:
                - For model outputs: values are lists of strings (sampled answers)
                - For ground truth: value is a single dict with {answer: probability} pairs
                
        Returns:
            Tuple of (aligned_simplexes, metadata) where:
                - aligned_simplexes: Maps keys to lists of Simplex objects (probability 
                  distributions) all defined over the same aligned vocabulary
                - metadata: SimplexMetadata containing the answer alignment mapping
        """

        encoded = {}
        for key, ans_value in answers.items():
            # Handle ground truth format: single dictionary
            if isinstance(ans_value, dict):
                # Ground truth is already in simplex format, wrap in list
                encoded[key] = [ans_value]
            # Handle model outputs: list of strings
            elif isinstance(ans_value, list):
                # Batch encode all string answers
                encoded[key] = self.batch_encode([question] * len(ans_value), ans_value)
            # Handle single string (edge case)
            else:
                encoded[key] = [self.encode(question, ans_value)]
        
        # Check if selector model is exact - if so, skip alignment
        if self.selector_model == "exact":
            # No need to align - each simplex already has its own vocabulary
            return encoded, SimplexMetadata(mapping={})
        
        # Align all simplexes to a common vocabulary using semantic similarity
        aligned, mapping = self.align(encoded, question=question)
        return aligned, SimplexMetadata(mapping=mapping)

    def _aggregate_sample(
        self, simplex: Simplex, mapping: dict[str, str]
    ) -> Simplex:
        """Aggregate a simplex according to the canonical answer mapping.
        
        When multiple answer strings map to the same canonical representative 
        (e.g., "London" and "london" both map to "London"), their probabilities 
        are summed to maintain a valid probability distribution.
        
        Args:
            simplex: A probability distribution over answer strings
            mapping: Dictionary mapping answer strings to canonical representatives
            
        Returns:
            Aggregated simplex with probabilities summed for equivalent answers
        """
        aggregated = {}
        for ans, prob in simplex.items():
            # Retrieve canonical representative for this answer
            canon = mapping[ans]
            # Sum probability mass for all answers that map to same canonical form
            aggregated[canon] = aggregated.get(canon, 0.0) + prob
        return aggregated

    def align(
        self, answers: dict[Any, list[Simplex]], question: str | None = None
    ) -> tuple[dict[Any, list[Simplex]], dict[str, str]]:
        """
        Align all answer distributions to a common vocabulary using semantic similarity.
        
        Given predictive distributions like:
        p(c_1|x) = {London: 0.5, Paris: 0.5}
        p(c_2|x) = {Paris: 0.3, Berlin: 0.7}

        Returns aligned distributions on common vocab = {Berlin, London, Paris} with
        semantically equivalent answers merged (e.g., "London" and "london" become one).
        
        Args:
            answers: Dictionary mapping keys to lists of simplex dictionaries
            question: Optional question string to add context for entailment checking
            
        Returns:
            Tuple of (aligned_answers, mapping) where:
                - aligned_answers: Same structure as input but with aligned vocabularies
                - mapping: Dictionary mapping original answers to canonical representatives
        """

        # Collect all unique answer strings from all distributions
        union_support = set().union(
            *(set(ans.keys()) for ans_list in answers.values() for ans in ans_list)
        )

        # Convert to list for indexing
        union_support = list(union_support)

        # Prepare input for semantic similarity checking
        # Optionally prepend question for better context
        if self.add_question_to_entailment and question is not None:
            entailment_input = [question + " " + answer for answer in union_support]
        else:
            entailment_input = union_support

        # Find semantically equivalent answers and group them
        semantic_ids = get_semantic_ids(
            strings_list=entailment_input,
            model=self.implication_model,
            strict_entailment=False,
            example=None,
        )

        # For each semantic cluster, pick the first occurrence as canonical representative
        vocab_idx = np.array(
            [np.argwhere(semantic_ids == i)[0][0] for i in np.unique(semantic_ids)]
        )
        vocab = [union_support[i] for i in vocab_idx]

        # Create mapping from semantic ID to canonical representative
        reverse_mapping = {i: vocab[i] for i in np.unique(semantic_ids)}

        # Create mapping from each answer to its canonical form
        mapping = {
            union_support[i]: reverse_mapping[semantic_ids[i]]
            for i in range(len(union_support))
        }
        
        # Aggregate probabilities for answers that map to the same canonical form
        return {
            key: [self._aggregate_sample(ans, mapping) for ans in ans_list]
            for key, ans_list in answers.items()
        }, mapping
