from __future__ import annotations

import logging
from typing import Any, NamedTuple

import numpy as np

from structured_llmuq.model.build import build_model
from structured_llmuq.utils.postprocessing import EntailmentDeberta, get_semantic_ids
from structured_llmuq.utils.prompts import SET_HUMAN_PROMPT, SET_SYS_PROMPT

from .latent_encoder import LatentEncoder

# Set representation of one answer: set element -> probability mass.
LatentSet = dict[Any, float]


class LatentSetMetadata(NamedTuple):
    """Metadata returned alongside a batch of (possibly aligned) latent sets."""

    # Original set element -> canonical cluster representative
    # (empty when no alignment was performed).
    mapping: dict[str, str]


class SetEncoder(LatentEncoder[LatentSet, LatentSetMetadata]):
    """Latent encoder that represents each answer as a set of distinct elements.

    Every answer becomes a ``LatentSet`` in which each extracted element gets
    mass ``1.0``. The extraction strategy is chosen by ``config["model_name"]``:

    * ``None`` (or ``initialize_selector_model=False``): the whole answer is a
      single element;
    * ``"exact"``: the answer is split on ``";"``;
    * anything else: a selector model built with ``build_model`` is prompted
      and its reply is split into one element per line.

    Unless the exact strategy is used, the sets are then aligned across keys:
    elements judged equivalent by an ``EntailmentDeberta`` model are merged
    onto a single canonical representative.
    """

    # Defaults forwarded to ``build_model``; either may be overridden via the
    # ``device_map`` / ``generation_config`` config keys.
    _DEFAULT_DEVICE_MAP = "cuda"
    _DEFAULT_GENERATION_CONFIG = {
        "max_new_tokens": 100,
        "temperature": 0,
        "num_return_sequences": 1,
    }

    def __init__(
        self,
        config: dict,
        initialize_selector_model: bool = True,
        add_question_to_entailment: bool = True,
    ) -> None:
        """Create a set-based latent encoder using a selector language model.

        Args:
            config: Selector configuration. ``model_name`` picks the extraction
                strategy (see the class docstring); remaining keys are forwarded
                to ``build_model``. The optional keys
                ``add_question_to_entailment``, ``device_map`` and
                ``generation_config`` are consumed here. The caller's dict is
                not modified.
            initialize_selector_model: If False, no selector model is built and
                each answer is encoded as a single-element set.
            add_question_to_entailment: Whether to prefix the question to every
                element before entailment clustering. A config key of the same
                name takes precedence.
        """
        super().__init__()

        # Work on a copy: popping keys from the caller's dict would silently
        # alter the config for any other encoder built from the same object.
        config = dict(config)
        add_question_to_entailment = config.pop(
            "add_question_to_entailment", add_question_to_entailment
        )

        model_name = config.get("model_name")
        if not initialize_selector_model or model_name is None:
            self.selector_model = None
        elif model_name == "exact":
            logging.info("Using exact set encoder, no model initialization.")
            self.selector_model = "exact"
        else:
            # Previously hard-coded keyword arguments; a config containing these
            # keys used to fail with a duplicate-keyword TypeError.
            device_map = config.pop("device_map", self._DEFAULT_DEVICE_MAP)
            generation_config = {
                **self._DEFAULT_GENERATION_CONFIG,
                **(config.pop("generation_config", None) or {}),
            }
            logging.info("Initializing selector model with config: %s", config)
            self.selector_model = build_model(
                **config,
                device_map=device_map,
                generation_config=generation_config,
            )
        self.implication_model = EntailmentDeberta()

        # The question prefix is only supported for EntailmentDeberta; guard in
        # case a different implication model is ever plugged in here.
        if add_question_to_entailment and not isinstance(
            self.implication_model, EntailmentDeberta
        ):
            logging.warning(
                "add_question_to_entailment is True but implication_model is not EntailmentDeberta. Ignoring."
            )
            add_question_to_entailment = False
        self.add_question_to_entailment = add_question_to_entailment

    def encode(self, question: str, answer: str) -> LatentSet:
        """Use the selector strategy to turn one answer into a set.

        Args:
            question (str): Source question string (only used when prompting the
                selector model).
            answer (str): Candidate answer string. Reference answers may instead
                be given as a list whose items are strings or non-empty
                sequences holding the element first, e.g. ``[["x1"], ["x2"]]``.

        Returns:
            LatentSet: Mapping from every non-blank element to ``1.0``.
        """
        # HACK: reference answers arrive as lists rather than strings.
        if isinstance(answer, list):
            elements = []
            for item in answer:
                if isinstance(item, str):
                    # A bare string item is the element itself (indexing it
                    # would keep only its first character).
                    elements.append(item)
                elif item:
                    # Skip empty inner lists instead of raising IndexError.
                    elements.append(item[0])
            return {s: 1.0 for s in elements if s.strip() != ""}

        if self.selector_model is None:
            elements = [answer]
        elif self.selector_model == "exact":
            # Simple heuristic: split on semicolons and strip whitespace.
            elements = [part.strip() for part in answer.split(";")]
        else:
            messages = [
                ("system", SET_SYS_PROMPT),
                (
                    "human",
                    SET_HUMAN_PROMPT.format(question=question, answer=answer),
                ),
            ]
            # HACK: calls the wrapped chat model directly (hard-coded for GPT).
            response = self.selector_model.model.invoke(messages)  # type: ignore
            # One element per line. Strip so "Paris", "Paris " and "Paris\r"
            # become the same key, matching the exact branch.
            elements = [line.strip() for line in response.content.split("\n")]  # type: ignore

        return {s: 1.0 for s in elements if s.strip() != ""}

    def __call__(
        self, question: str, answers: dict[Any, list[str]]
    ) -> tuple[dict[str, list[LatentSet]], LatentSetMetadata]:
        """Convert orchestrator run results into aligned set encodings.

        Args:
            question: Question shared by all answers.
            answers: Mapping from a key to that key's list of answers.

        Returns:
            The per-key encoded sets (aligned unless the exact strategy is used)
            and metadata holding the element -> representative mapping (empty
            for the exact strategy).
        """
        # batch_encode comes from the base class; presumably it returns one
        # LatentSet per answer — TODO confirm against LatentEncoder.
        encoded = {
            key: self.batch_encode([question] * len(ans_list), ans_list)
            for key, ans_list in answers.items()
        }

        # Exact encodings are used as-is, without semantic alignment.
        if self.selector_model == "exact":
            return encoded, LatentSetMetadata(mapping={})

        aligned, mapping = self.align(encoded, question=question)
        return aligned, LatentSetMetadata(mapping=mapping)

    def _aggregate_sample(
        self, sample_dict: dict[str, float], mapping: dict[str, str]
    ) -> LatentSet:
        """Re-key one sample onto canonical representatives.

        Elements that map to the same canonical answer have their probability
        mass summed.
        """
        aggregated: LatentSet = {}
        for ans, prob in sample_dict.items():
            canon = mapping[ans]
            aggregated[canon] = aggregated.get(canon, 0) + prob
        return aggregated

    def align(
        self, answers: dict[Any, list[LatentSet]], question: str | None = None
    ) -> tuple[dict[Any, list[LatentSet]], dict[str, str]]:
        """Merge semantically equivalent elements across all predictive sets.

        Given the predictive distributions, find the union support and cluster
        it with the implication model. E.g. with
        p(c_1|x) = {London: 0.5, Paris: 0.5}
        p(c_2|x) = {Paris: 0.3, Berlin: 0.7}

        -> vocab = {Berlin, London, Paris}
        -> vectors = [[0, 0.5, 0.5], [0.7, 0, 0.3]]

        Args:
            answers: Per-key lists of latent sets.
            question: If given and ``add_question_to_entailment`` is set, it is
                prefixed to every element before entailment clustering.

        Returns:
            The per-key sets re-keyed onto cluster representatives, and the
            element -> representative mapping.
        """
        # Union of all elements in first-seen order. Iterating a plain set here
        # would tie the order to string-hash randomisation; the representative
        # (first member of each cluster) depends on this order, and the
        # clustering in get_semantic_ids may too, so runs would not reproduce.
        union_support = list(
            dict.fromkeys(
                element
                for ans_list in answers.values()
                for ans in ans_list
                for element in ans
            )
        )

        mapping: dict[str, str] = {}
        if union_support:  # nothing to cluster otherwise; skip the model call
            if self.add_question_to_entailment and question is not None:
                entailment_input = [
                    question + " " + element for element in union_support
                ]
            else:
                entailment_input = union_support

            # asarray so that either a list or an array of ids is accepted.
            semantic_ids = np.asarray(
                get_semantic_ids(
                    strings_list=entailment_input,
                    model=self.implication_model,
                    strict_entailment=False,
                    example=None,
                )
            )

            # return_index yields the position of each cluster's first member,
            # used as its canonical representative. Zipping ids with those
            # positions stays correct even if ids are not contiguous from 0.
            cluster_ids, first_idx = np.unique(semantic_ids, return_index=True)
            representative = {
                cid: union_support[idx]
                for cid, idx in zip(cluster_ids.tolist(), first_idx.tolist())
            }
            mapping = {
                element: representative[cid]
                for element, cid in zip(union_support, semantic_ids.tolist())
            }

        return {
            key: [self._aggregate_sample(ans, mapping) for ans in ans_list]
            for key, ans_list in answers.items()
        }, mapping
