import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union

from structured_llmuq.data.qa_datasets import AbstractQADataset
from structured_llmuq.model.api_lm import ApiLM
from structured_llmuq.model.causal_lm import CausalLM
from structured_llmuq.utils.latent_encoder import LatentEncoder
from structured_llmuq.utils.postprocessing import STOP_SEQUENCES


class ModelGeneration(TypedDict, total=False):
    """Typed view of a model generation result as consumed in this module.

    Only the key read here is declared; ``total=False`` means it is not
    required at type-check time. NOTE(review): the actual ``generate()``
    output likely carries more keys — confirm against the model classes.
    """

    # Decoded generated texts after truncation. Where a single answer is
    # needed (the beam-search answer), entry 0 of this list is used.
    tokens_decoded_generated_truncated: list[str]


@dataclass
class ModelOutput:
    """Result record for one question: inputs, reference answer and model answers."""

    question: str
    prompt: str
    beam_search_answer: ModelGeneration
    reference_answer: str  # ground-truth answer for the question
    model_input: str
    generations: ModelGeneration

    @property
    def answers(self) -> dict[str, list[str]]:
        """Group answer strings by their source.

        Returns a dict with three keys, each mapping to a list of strings:
        ``"beam_search"`` holds only the first truncated beam-search output,
        ``"generations"`` is the truncated sampled-output list itself (not a
        copy), and ``"reference"`` wraps the ground-truth answer.
        """
        text_key = "tokens_decoded_generated_truncated"
        top_beam = self.beam_search_answer[text_key][0]
        sampled = self.generations[text_key]
        return {
            "beam_search": [top_beam],
            "generations": sampled,
            "reference": [self.reference_answer],
        }

    def to_dict(self) -> dict:
        """Return every field in declaration order as a plain (shallow) dict.

        Values are the stored objects themselves; nothing is copied.
        """
        field_names = (
            "question",
            "prompt",
            "beam_search_answer",
            "reference_answer",
            "model_input",
            "generations",
        )
        return {name: getattr(self, name) for name in field_names}


class Orchestrator:
    """Runs single QA samples through prompt construction, generation and encoding.

    Combines a QA dataset, a language model and a latent encoder. For a given
    dataset index it builds the prompt and produces a beam-search answer plus
    sampled generations. It can then map the resulting answers into the
    encoder's latent space.
    """

    def __init__(
        self,
        dataset: AbstractQADataset,
        model: CausalLM | ApiLM,
        latent_encoder: LatentEncoder,
        config: dict[str, Any] | None = None,
    ):
        """Store collaborators and read optional settings from ``config``.

        Args:
            dataset: Indexed by sample id and used to construct prompts.
            model: Language model used for both beam search and sampling.
            latent_encoder: Callable invoked as ``latent_encoder(question, answers)``.
            config: Optional settings. Recognised keys are ``stop_sequences``
                (default ``STOP_SEQUENCES``) and
                ``max_num_parallel_generations`` (default ``None``).
        """
        self.dataset = dataset
        self.model = model
        self.latent_encoder = latent_encoder
        config = config or {}
        self.stop_sequences = config.get("stop_sequences", STOP_SEQUENCES)
        self.max_num_parallel_generations = config.get("max_num_parallel_generations")

    def generate_answers(
        self,
        idx: int,
        stop_sequences: list[str] | None = None,
        max_num_parallel_generations: int | None = None,
    ) -> ModelOutput:
        """Produce a beam-search answer and sampled answers for one dataset sample.

        Args:
            idx: Index into ``self.dataset``. The sample must provide
                ``"question"`` and ``"answer"`` keys.
            stop_sequences: Stop sequences for the sampled generations. Falls
                back to the instance default when ``None``.
            max_num_parallel_generations: Forwarded to ``model.generate``.
                Falls back to the instance default when ``None``.

        Returns:
            A ``ModelOutput`` bundling the question, prompt, reference answer,
            beam-search output and sampled generations.
        """
        if stop_sequences is None:
            stop_sequences = self.stop_sequences
        if max_num_parallel_generations is None:
            max_num_parallel_generations = self.max_num_parallel_generations

        logging.info("Running sample %s", idx)
        sample = self.dataset[idx]
        question = sample["question"]
        true_answer = sample["answer"]
        prompt = self.dataset.construct_prompt(question, self.model.model_type)
        logging.info("Constructed Prompt: %s", prompt)

        # Beam search yields the single "best" answer. get_best_answer
        # restores the model's generation settings before returning/raising.
        best_answer_dict = self.get_best_answer(prompt)
        best_answer = best_answer_dict["tokens_decoded_generated_truncated"][0]
        logging.info("Generated Beam Search Answer: %s", best_answer)

        # Sampled generations use the model's regular generation settings.
        model_output = self.model.generate(
            prompt,
            stop_sequences=stop_sequences,
            max_num_parallel_generations=max_num_parallel_generations,
        )
        logging.info("Generated Responses")

        # The prompt is passed to the model unchanged, so it is also the model input.
        return ModelOutput(
            question=question,
            prompt=prompt,
            beam_search_answer=best_answer_dict,
            reference_answer=true_answer,
            model_input=prompt,
            generations=model_output,
        )

    def map_to_latent_space(self, run_result: ModelOutput) -> dict[str, list[Any]]:
        """Encode a run's question and grouped answers with the latent encoder.

        Returns element 0 of the encoder's output. NOTE(review): presumably
        the encoder returns a batch and element 0 is this single run — confirm.
        """
        return self.latent_encoder(run_result.question, run_result.answers)[0]

    def get_best_answer(self, prompt: str, num_beams: int = 5) -> dict[str, Any]:
        """Get the model's best answer for ``prompt`` using beam search.

        Temporarily sets ``num_return_sequences = 1`` and ``num_beams`` on
        ``self.model.generation_config``, then calls ``self.model.generate``.
        The previous values are restored in a ``finally`` clause. A failing
        generation therefore cannot leave the shared config in beam-search
        mode for later (sampled) generations.

        Args:
            prompt: The fully constructed model prompt.
            num_beams: Beam width for this call (5 by default).

        Returns:
            The raw output of ``self.model.generate``.
        """
        old_num_return_sequences = self.model.generation_config.num_return_sequences
        old_num_beams = self.model.generation_config.num_beams

        self.model.generation_config.num_return_sequences = 1
        self.model.generation_config.num_beams = num_beams
        try:
            # NOTE(review): unlike the sampled generations in generate_answers,
            # no stop_sequences are passed here — confirm this is intended.
            return self.model.generate(prompt)
        finally:
            # Always restore, whether generate() returned or raised.
            self.model.generation_config.num_return_sequences = old_num_return_sequences
            self.model.generation_config.num_beams = old_num_beams
