import os
import math
from typing import (
    Tuple,
    List,
)
import itertools
from collections import Counter

import numpy as np
from tqdm import tqdm
from openai import OpenAI
from google import genai
from google.genai.types import GenerateContentConfig
from loguru import logger
from dotenv import load_dotenv

from src.schema import (
    QA,
    Answer,
    Evidence,
    ConflictGraph,
)


# Multiple-choice prompt: the model is asked to reply with only the number of
# the option it picks. `{question}` and `{options}` are `str.format` fields.
DISTRIBUTION_MIMICKING_PROMPT_TEMPLATE = "\n".join(
    [
        "Choose one of the given options, and respond with exactly that option.",
        "",
        "Options are numbered, and you should only respond with the number of the option you choose.",
        "",
        "# Question",
        "{question}",
        "",
        "# Options",
        "{options}",
        "",
        "# Answer",
    ]
)


class DistributionMimickingStage:
    """Duplicate answers so their multiplicities mimic an LLM answer distribution.

    For every question the stage asks three models (GPT-4o, Gemini, Llama)
    which of the distinct short answers they would choose, averages the
    resulting probabilities, turns them into integer target counts and appends
    duplicated answers (plus matching conflict-graph nodes) until each short
    answer occurs that many times.
    """

    def __init__(self) -> None:
        """Load environment variables and create the three API clients.

        Raises:
            KeyError: If ``OPENAI_API_KEY``, ``GEMINI_PROJECT_ID`` or
                ``FIREWORKS_API_KEY`` is not set.
        """
        logger.info("Initializing DistributionMimickingStage.")

        load_dotenv()

        self.gpt_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
        self.gemini_client = genai.Client(vertexai=True, project=os.environ["GEMINI_PROJECT_ID"], location="global")
        # Fail fast like the lookups above: with `api_key=None` the OpenAI
        # client falls back to OPENAI_API_KEY, which would then be sent to
        # the Fireworks endpoint.
        self.fireworks_client = OpenAI(
            base_url="https://api.fireworks.ai/inference/v1",
            api_key=os.environ["FIREWORKS_API_KEY"],
        )

    @staticmethod
    def _build_prompt(question: str, answers: List[str]) -> str:
        """Render the multiple-choice prompt with options numbered from 1."""
        options = "\n".join(f"{i}. {answer}" for i, answer in enumerate(answers, start=1)).strip()
        return DISTRIBUTION_MIMICKING_PROMPT_TEMPLATE.format(question=question, options=options)

    @staticmethod
    def _parse_option(token: str, n_options: int) -> int:
        """Return the 1-based option a token denotes, or 0 if it denotes none."""
        if token is None:
            return 0
        try:
            # Surrounding quotes are dropped because the Gemini request asks
            # for a JSON string, so the digit may arrive as `"1"`.
            option = int(token.strip().strip("\"'"))
        except ValueError:
            return 0
        return option if 1 <= option <= n_options else 0

    @classmethod
    def _probs_from_token_logprobs(
        cls,
        token_logprobs: List[Tuple[str, float]],
        n_options: int,
    ) -> List[float]:
        """Convert ``(token, logprob)`` pairs into one probability per option.

        Options that do not occur among the tokens get probability 0.0.
        Distinct tokens that denote the same option (e.g. ``"1"`` and
        ``" 1"``) have their probabilities added up.
        """
        probs = [0.0] * n_options
        for token, logprob in token_logprobs:
            option = cls._parse_option(token, n_options)
            if option and logprob is not None:
                probs[option - 1] += math.exp(logprob)
        return probs

    def _chat_completion_dist(
        self,
        client: OpenAI,
        model: str,
        question: str,
        answers: List[str],
    ) -> List[float]:
        """Query an OpenAI-compatible chat endpoint and read the option probabilities.

        Only one token is generated; its top log-probabilities are mapped
        back to the numbered options.
        """
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": self._build_prompt(question, answers)}],
            max_tokens=1,
            temperature=1.0,  # TODO: check here
            logprobs=True,
            top_logprobs=len(answers),
        )
        top_logprobs = response.choices[0].logprobs.content[0].top_logprobs
        return self._probs_from_token_logprobs(
            [(item.token, item.logprob) for item in top_logprobs],
            len(answers),
        )

    def get_raw_dist_gpt(
        self,
        question: str,
        answers: List[str],
    ) -> List[float]:
        """Return GPT-4o's probability for each answer (unnormalised).

        Args:
            question: Question shown to the model.
            answers: Candidate answers; they are presented as options 1..N.

        Returns:
            One probability per answer, in input order; 0.0 for options that
            are absent from the returned top log-probabilities.
        """
        return self._chat_completion_dist(
            client=self.gpt_client,
            model="gpt-4o-2024-08-06",
            question=question,
            answers=answers,
        )

    def get_raw_dist_gemini(
        self,
        question: str,
        answers: List[str],
    ) -> List[float]:
        """Return Gemini's probability for each answer (unnormalised).

        The response is constrained to a JSON string enum of the option
        numbers. The probabilities are taken from the first decoding step
        whose candidates contain an option number.

        Args:
            question: Question shown to the model.
            answers: Candidate answers; they are presented as options 1..N.

        Returns:
            One probability per answer, in input order; 0.0 for options that
            are absent from the returned candidates.
        """
        n_options = len(answers)
        response = self.gemini_client.models.generate_content(
            model="gemini-2.0-flash-001",
            contents=self._build_prompt(question, answers),
            config=GenerateContentConfig(
                response_mime_type="application/json",
                response_schema={"type": "STRING", "enum": [str(i) for i in range(1, n_options + 1)]},
                response_logprobs=True,
                logprobs=n_options,
            ),
        )
        logprobs_result = response.candidates[0].logprobs_result
        # `chosen_candidates` only holds the sampled token per step; the
        # alternatives requested via `logprobs=` are in `top_candidates`.
        chosen_steps = logprobs_result.chosen_candidates or []
        top_steps = logprobs_result.top_candidates or []
        for position in range(max(len(chosen_steps), len(top_steps))):
            # Keyed by token so the chosen token is not counted twice when it
            # is also listed among the top candidates.
            by_token = {}
            if position < len(chosen_steps):
                chosen = chosen_steps[position]
                by_token[chosen.token] = chosen.log_probability
            if position < len(top_steps):
                for candidate in top_steps[position].candidates or []:
                    by_token[candidate.token] = candidate.log_probability
            # Skip steps without any option token (e.g. an opening quote).
            if any(self._parse_option(token, n_options) for token in by_token):
                return self._probs_from_token_logprobs(list(by_token.items()), n_options)
        return [0.0] * n_options

    def get_raw_dist_llama(
        self,
        question: str,
        answers: List[str],
    ) -> List[float]:
        """Return Llama 3.1 70B's probability for each answer (unnormalised).

        Args:
            question: Question shown to the model.
            answers: Candidate answers; they are presented as options 1..N.

        Returns:
            One probability per answer, in input order; 0.0 for options that
            are absent from the returned top log-probabilities.
        """
        return self._chat_completion_dist(
            client=self.fireworks_client,
            model="accounts/fireworks/models/llama-v3p1-70b-instruct#accounts/shovelingpig-c02c73/deployments/leiptr5m",
            question=question,
            answers=answers,
        )

    def get_raw_dist(
        self,
        question: str,
        answers: List[str],
    ) -> List[float]:
        """Average the per-answer probabilities of GPT, Gemini and Llama.

        Args:
            question: Question shown to the models.
            answers: Candidate answers.

        Returns:
            The element-wise mean of the three model distributions.
        """
        per_model_probs = [
            self.get_raw_dist_gpt(question=question, answers=answers),
            self.get_raw_dist_gemini(question=question, answers=answers),
            self.get_raw_dist_llama(question=question, answers=answers),
        ]
        return [sum(probs) / len(probs) for probs in zip(*per_model_probs)]

    @staticmethod
    def _iter_bounded_counts(min_counts: List[int], max_total: int):
        """Yield every count list with ``counts[i] >= min_counts[i]`` and ``sum <= max_total``.

        Lists are produced in lexicographic order (first position varies
        slowest), i.e. the order of ``itertools.product`` over the per-position
        ranges, but branches whose sum can no longer fit are never entered.
        """
        n = len(min_counts)
        # remaining_min[i] is the smallest sum positions i.. can still add.
        remaining_min = [0] * (n + 1)
        for i in range(n - 1, -1, -1):
            remaining_min[i] = remaining_min[i + 1] + min_counts[i]
        current = [0] * n

        def extend(position, used):
            if position == n:
                yield list(current)
                return
            upper = min(max_total, max_total - used - remaining_min[position + 1])
            for count in range(min_counts[position], upper + 1):
                current[position] = count
                yield from extend(position + 1, used + count)

        yield from extend(0, 0)

    def get_quantized_dist(
        self,
        raw_dist: List[float],
        min_counts: List[int],
        max_total_count: int,
    ) -> List[int]:
        """Find integer counts whose proportions best match ``raw_dist``.

        A candidate is valid when every count is at least its minimum, the
        total does not exceed ``max_total_count`` (raised to
        ``sum(min_counts)`` if smaller), a strictly larger probability gets a
        strictly larger count, and equal probabilities get equal counts.
        Among valid candidates the one with the lowest mean squared error
        between normalised counts and normalised probabilities wins; ties go
        to the lexicographically first candidate.

        Args:
            raw_dist: Unnormalised probability per option. An all-zero input
                is treated as uniform.
            min_counts: Lower bound for each option's count.
            max_total_count: Upper bound for the sum of all counts.

        Returns:
            The chosen count per option; an empty list for empty input.

        Raises:
            ValueError: If the inputs differ in length or no candidate
                satisfies the constraints.
        """
        if len(raw_dist) != len(min_counts):
            raise ValueError("raw_dist and min_counts must have the same length")

        n = len(raw_dist)
        if n == 0:
            return []

        floors = [int(count) for count in min_counts]
        max_total_count = max(max_total_count, sum(floors))  # TODO: check here

        probs = np.asarray(raw_dist, dtype=float)
        mass = probs.sum()
        # Without any probability mass there is nothing to normalise; fall
        # back to a uniform target instead of dividing by zero.
        target = np.full(n, 1.0 / n) if mass == 0 else probs / mass

        # The ordering constraints depend only on raw_dist, so derive them once.
        must_exceed = [(i, j) for i in range(n) for j in range(n) if probs[i] > probs[j]]
        must_equal = [(i, j) for i in range(n) for j in range(i + 1, n) if probs[i] == probs[j]]

        best_assignment = None
        min_mse = float("inf")
        for counts in self._iter_bounded_counts(floors, max_total_count):
            total = sum(counts)
            if total == 0:
                continue
            if any(counts[i] <= counts[j] for i, j in must_exceed):
                continue
            if any(counts[i] != counts[j] for i, j in must_equal):
                continue

            mse = np.mean((np.array(counts) / total - target) ** 2)
            if mse < min_mse:
                min_mse = mse
                best_assignment = counts

        if best_assignment is None:
            raise ValueError("No valid assignment found")

        return best_assignment

    def duplicate_answer(
        self,
        answer: Answer,
        duplicate_index: int,
    ) -> Answer:
        """Copy an answer and its evidences, suffixing every id with ``_{duplicate_index}``.

        Args:
            answer: Answer to copy.
            duplicate_index: Number appended to the answer id and evidence ids.

        Returns:
            A new ``Answer`` with the same short answer, rubric question and
            evidence contents.
        """
        copied_evidences = [
            Evidence(
                evidence_id=f"{evidence.evidence_id}_{duplicate_index}",
                question=evidence.question,
                answer=evidence.answer,
            )
            for evidence in answer.evidences
        ]
        return Answer(
            answer_id=f"{answer.answer_id}_{duplicate_index}",
            short_answer=answer.short_answer,
            rubric_question=answer.rubric_question,
            evidences=copied_evidences,
        )

    @staticmethod
    def _add_duplicate_node(conflict_graph: ConflictGraph, orig_answer_id, new_answer_id) -> None:
        """Add a node for a duplicated answer, linked to the original and its neighbours."""
        orig_node_id = conflict_graph.answer_id_to_node_id[orig_answer_id]
        new_node_id = conflict_graph.n_nodes
        conflict_graph.answer_id_to_node_id[new_answer_id] = new_node_id
        conflict_graph.node_id_to_answer_id[new_node_id] = new_answer_id
        neighbors = conflict_graph.adjacency_dict[orig_node_id] + [orig_node_id]
        conflict_graph.adjacency_dict[new_node_id] = neighbors
        # Keep the adjacency symmetric.
        for neighbor in neighbors:
            conflict_graph.adjacency_dict[neighbor].append(new_node_id)
        conflict_graph.n_nodes += 1

    def run(
        self,
        qa_dataset: List[QA],
        conflict_graph: ConflictGraph,
    ) -> Tuple[List[QA], ConflictGraph]:
        """Duplicate answers in place until their counts follow the model distribution.

        For each sample, the current count of every distinct short answer is
        the minimum; the target counts come from ``get_quantized_dist`` with a
        total budget of ``n * (n + 1) // 2`` for ``n`` answers. Missing copies
        are created by cycling through the original answers with that short
        answer, and each copy gets a conflict-graph node. Finally every
        adjacency list is sorted.

        Args:
            qa_dataset: Samples whose ``answers`` lists are extended in place.
            conflict_graph: Graph that is extended in place.

        Returns:
            The same ``qa_dataset`` and ``conflict_graph`` objects.
        """
        for qa_sample in tqdm(qa_dataset):
            answers = qa_sample.answers
            short_answer_counter = Counter(answer.short_answer for answer in answers)
            unique_short_answers = list(short_answer_counter)
            min_counts = list(short_answer_counter.values())

            # With fewer than two distinct answers there is nothing to
            # rebalance, so skip the three API calls.
            if len(unique_short_answers) < 2:
                continue

            raw_dist = self.get_raw_dist(question=qa_sample.question, answers=unique_short_answers)
            n = len(answers)
            max_total_count = n * (n + 1) // 2
            quantized_dist = self.get_quantized_dist(
                raw_dist=raw_dist,
                min_counts=min_counts,
                max_total_count=max_total_count,
            )

            # Group the original answers before appending, so duplicates are
            # only ever made from originals.
            originals_by_short_answer = {}
            for answer in answers:
                originals_by_short_answer.setdefault(answer.short_answer, []).append(answer)

            for short_answer, target_count, current_count in zip(unique_short_answers, quantized_dist, min_counts):
                sources = originals_by_short_answer[short_answer]
                for j in range(target_count - current_count):
                    selected_answer = sources[j % len(sources)]
                    new_answer = self.duplicate_answer(answer=selected_answer, duplicate_index=j)
                    answers.append(new_answer)
                    self._add_duplicate_node(
                        conflict_graph,
                        selected_answer.answer_id,
                        new_answer.answer_id,
                    )

        for node_id in conflict_graph.adjacency_dict:
            conflict_graph.adjacency_dict[node_id] = sorted(conflict_graph.adjacency_dict[node_id])

        return qa_dataset, conflict_graph
