from typing import List, Dict, Optional
import numpy as np
import json

import sys

sys.path.append("..")
from skate_utils.utils import (
    get_task,
    get_ground_truth_answer,
    get_embedding,
    get_mc_options,
    is_sufficiently_different,
    get_answer_to_verifiable_MC_problem_with_ABCD,
)

from skate_utils.scoring_tools import option_generator, is_correct
from question import Question
from player import Player


class Round:
    """
    Represents a single round of the game.

    Per player tag, a round holds the questions that player set this round,
    the questions that were rejected (as ``{"text": ..., "reason": ...}``
    records appended by the game controller), and a score matrix indexed as
    ``scores[task_setter_tag][answerer_tag]``.
    """

    def __init__(self, round_number: int, player_tags: List[str]):
        """
        Args:
            round_number: 1-based index of this round within the game.
            player_tags: Tags of every participating player; each tag gets an
                empty question list, an empty rejection list, and a row/column
                in the score matrix.
        """
        self.round_number = round_number
        # Accepted questions, keyed by the tag of the player who set them.
        self.questions: Dict[str, List["Question"]] = {tag: [] for tag in player_tags}
        # Rejected-question records, keyed by the tag of the player who set them.
        self.rejected_questions: Dict[str, List[Dict]] = {
            tag: [] for tag in player_tags
        }
        # scores[task_setter_tag][answerer_tag]; every pairing starts at 0.
        self.scores: Dict[str, Dict[str, float]] = {
            setter_tag: {answerer_tag: 0 for answerer_tag in player_tags}
            for setter_tag in player_tags
        }

    def add_question(self, question: "Question", player_tag: str):
        """Adds a question created by a specific player to the round.

        Raises:
            KeyError: If ``player_tag`` was not passed to the constructor.
        """
        self.questions[player_tag].append(question)

    def get_score_of_player_on_task(
        self, answerer_tag: str, task_setter_tag: str
    ) -> float:
        """
        Retrieves the score of a player answering a question set by another player.

        Note the argument order (answerer first) is the reverse of
        ``record_score`` (setter first); pass by keyword to avoid mix-ups.
        """
        return self.scores[task_setter_tag][answerer_tag]

    def record_score(
        self, task_setter_tag: str, question_answerer_tag: str, score: float
    ):
        """Records the score of an answerer for a question set by a task setter."""
        self.scores[task_setter_tag][question_answerer_tag] = score

    def did_player_create_valid_question(self, player_tag: str) -> bool:
        """
        Check if player successfully created at least one valid question in this round.
        """
        return len(self.questions[player_tag]) > 0


class GameArchive:
    """
    Archive of every accepted question across all rounds of a game.

    Questions are kept in per-player lists, keyed by the tag of the player
    who set them, in the order they were added.
    """

    def __init__(self, players: Optional[List[Player]] = None):
        """
        Args:
            players: Participating players; each one's ``player_tag`` gets an
                empty question list. Defaults to no players.
        """
        if players is None:
            players = []
        self.players = players
        self.questions: Dict[str, List[Question]] = {}
        for member in self.players:
            self.questions[member.player_tag] = []

    def add_question(self, player_tag, question: Question):
        """Appends ``question`` to the list archived under ``player_tag``.

        Raises:
            KeyError: If ``player_tag`` did not belong to a player supplied
                at construction time.
        """
        bucket = self.questions[player_tag]
        bucket.append(question)


class Game:
    """
    Main controller for the game, managing rounds, players and scoring.

    In each round every player gets up to ``num_attempts`` tries to produce a
    valid multiple-choice question. Every player (the setter included) is
    then scored on each accepted question. The game state is written to
    ``game_file_path`` as JSON at the end of every round.

    ``_get_matchups``, ``get_competitor_tag``, ``get_game_score_summary`` and
    ``display_results`` assume exactly two players.
    """

    def __init__(
        self,
        players: List["Player"],
        num_rounds: int = 50,
        num_attempts: int = 3,
        n_samples_warm_start: int = 10,
        n_samples_upper_limit: int = 150,
        score_std_stopping_criterion: float = 0.05,
        sim_distance_threshold: float = 0.336,
    ):
        """
        Args:
            players: Participating players; must be non-empty.
            num_rounds: Number of rounds ``play`` runs.
            num_attempts: Maximum question-generation attempts per player per round.
            n_samples_warm_start: Number of answers sampled per batch in
                ``score_model_on_question``.
            n_samples_upper_limit: Sample count at which scoring stops even if
                the std criterion has not been met.
            score_std_stopping_criterion: Scoring stops once the std of the
                accuracy estimate drops below this value.
            sim_distance_threshold: Threshold passed to
                ``is_sufficiently_different`` for the uniqueness check.

        Raises:
            ValueError: If ``players`` is empty.
        """
        if not players:
            raise ValueError("The game must be initialized with at least one player.")
        self.players = players
        self.num_attempts = num_attempts
        self.archive = GameArchive(players=players)
        self.rounds: List[Round] = []
        self.current_round: Optional[Round] = None
        self.num_rounds = num_rounds
        self.player_tags = [player.player_tag for player in players]
        self.n_samples_warm_start = n_samples_warm_start
        self.n_samples_upper_limit = n_samples_upper_limit
        self.score_std_stopping_criterion = score_std_stopping_criterion
        self.sim_distance_threshold = sim_distance_threshold
        # Built from the tag list rather than players[0]/players[1] so that a
        # single-player game (allowed by the check above) does not IndexError.
        self.game_file_path = self._build_game_file_path(
            self.player_tags, self.num_rounds
        )

    @staticmethod
    def _build_game_file_path(player_tags: List[str], num_rounds: int) -> str:
        """
        Return the JSON output path for a game.

        With two or more players this is
        ``.../<tag0>_v_<tag1>_for_<num_rounds>_rounds_game_data.json``, naming
        only the first two players. With one player only that tag is used.
        """
        matchup = "_v_".join(player_tags[:2])
        return (
            "skate/eval_duel/workshop-experiments/game-database/"
            + f"{matchup}_for_{num_rounds}_rounds_game_data.json"
        )

    def _get_matchups(self):
        """Returns a list of player matchups for scoring purposes.

        Raises:
            ValueError: If the game does not have exactly two players.
        """
        if len(self.players) != 2:
            raise ValueError("Game must have exactly two players for matchups.")
        return [
            (self.players[0], self.players[1]),
            (self.players[1], self.players[0]),
        ]

    def get_competitor_tag(self, player_tag: str) -> str:
        """Returns the competitor's player tag for a given player tag.

        Any tag other than the first player's is treated as the second
        player and maps to the first player's tag.

        Raises:
            ValueError: If the game does not have exactly two players.
        """
        if len(self.players) != 2:
            raise ValueError(
                "Game must have exactly two players to determine competitor."
            )
        return (
            self.players[1].player_tag
            if player_tag == self.players[0].player_tag
            else self.players[0].player_tag
        )

    def start_new_round(self):
        """Initializes and appends a new round to the game and makes it current."""
        round_number = len(self.rounds) + 1
        r = Round(round_number, self.player_tags)
        self.current_round = r
        self.rounds.append(r)
        print(f"\n--- Starting Round {round_number} ---")

    def add_question_to_current_round(self, question: "Question", task_setter_tag: str):
        """Adds question to the current round and the game archive."""
        self.current_round.add_question(question=question, player_tag=task_setter_tag)
        self.archive.add_question(player_tag=task_setter_tag, question=question)

    def end_round(self):
        """
        Finalizes the current round: prints each player's score row (or a
        note that they produced no valid question), then saves game data once.

        Raises:
            RuntimeError: If no round is currently active.
        """
        if self.current_round is None:
            raise RuntimeError("Cannot end round: no round is currently active.")

        for player in self.players:
            if self.current_round.did_player_create_valid_question(
                player_tag=player.player_tag
            ):
                print(
                    f"Scores for {player.player_tag}'s Q: {self.current_round.scores[player.player_tag]}"
                )
            else:
                print(
                    f"{player.player_tag} did not create a valid question this round."
                )
        # Save once per round. This call previously sat inside the loop above,
        # which rewrote the entire game file once per player.
        self._save_game_data()

    @staticmethod
    def _to_jsonable(value):
        """Convert numpy arrays/scalars to plain Python values so ``json`` can
        serialize them; any other value is returned unchanged."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        return value

    def _save_game_data(self):
        """Saves the current state of the game (all rounds so far) to a JSON file.

        The file at ``self.game_file_path`` is overwritten on every call. Its
        parent directory must already exist.
        """
        game_data_summary = {
            "num_attempts": self.num_attempts,
            "num_rounds": self.num_rounds,
            "n_samples_warm_start": self.n_samples_warm_start,
            "n_samples_upper_limit": self.n_samples_upper_limit,
            "score_std_stopping_criterion": self.score_std_stopping_criterion,
            "sim_distance_threshold": self.sim_distance_threshold,
        }
        player_data_summary = {
            "player_tags": self.player_tags,
            "model_names": [player.model_name for player in self.players],
        }

        # Keyed by 0-based round index (json turns the int keys into strings).
        round_rejection_summary = {
            idx: {
                player_tag: rnd.rejected_questions[player_tag]
                for player_tag in self.player_tags
            }
            for idx, rnd in enumerate(self.rounds)
        }

        serializable_rounds = {}
        for idx, rnd in enumerate(self.rounds):
            serializable_questions = [
                {
                    "text": q.text,
                    "task_setter_player_tag": q.task_setter_player_tag,
                    "distractor_options": q.distractor_options,
                    "ground_truth_answer": q.ground_truth_answer,
                    # Embeddings may be numpy arrays, which json cannot encode.
                    "embedding": self._to_jsonable(q.embedding),
                }
                for question_list in rnd.questions.values()
                for q in question_list
            ]
            serializable_rounds[idx] = {
                "round_number": rnd.round_number,
                "scores": rnd.scores,
                "questions": serializable_questions,
            }

        full_data = {
            "game_data": game_data_summary,
            "player_data": player_data_summary,
            "round_data": serializable_rounds,
            "rejected_questions": round_rejection_summary,
        }
        # Serialize before opening the file: opening with "w" truncates it, so
        # a serialization error would otherwise wipe out the previous save.
        payload = json.dumps(full_data, indent=4)
        with open(self.game_file_path, "w", encoding="utf-8") as f:
            f.write(payload)

    def get_game_score_summary(self):
        """
        Calculates and returns the aggregated scores for each round,
        considering valid question creation and answer accuracy.

        For each player who set a valid question in a round:
          * the setter gets +1 for creating it;
          * the setter gets +1 if their own score on it is >= 0.5;
          * the competitor gets +1 if their score on it is >= 0.5.

        Returns:
            Dict mapping round number -> {player_tag: points}.

        Raises:
            ValueError: If there is at least one round and the game does not
                have exactly two players.
        """
        scores: Dict[int, Dict[str, float]] = {}

        for r in self.rounds:
            score = {player_tag: 0.0 for player_tag in self.player_tags}
            for player, competitor in self._get_matchups():
                if r.did_player_create_valid_question(player_tag=player.player_tag):
                    player_score = r.get_score_of_player_on_task(
                        answerer_tag=player.player_tag,
                        task_setter_tag=player.player_tag,
                    )
                    competitor_score = r.get_score_of_player_on_task(
                        answerer_tag=competitor.player_tag,
                        task_setter_tag=player.player_tag,
                    )

                    score[player.player_tag] += 1  # for creating valid q
                    if player_score >= 0.5:
                        score[player.player_tag] += 1.0
                    if competitor_score >= 0.5:
                        score[competitor.player_tag] += 1.0

            scores[r.round_number] = score
        return scores

    def _is_question_unique(self, q_embedding, player_tag: str) -> bool:
        """
        Checks if a question's embedding is sufficiently different from all previous
        asked questions by the same player.

        Returns True when the player has no archived questions, or none of
        them has an embedding.
        """
        archive_questions = self.archive.questions[player_tag]
        if not archive_questions:
            return True  # archive is empty.

        other_embeddings = [
            q.embedding for q in archive_questions if q.embedding is not None
        ]
        if not other_embeddings:
            return True

        return is_sufficiently_different(
            q_embedding=q_embedding,
            other_embeddings=other_embeddings,
            threshold=self.sim_distance_threshold,
        )

    def _is_question_a_copy(self, question_text: str, player_tag: str) -> bool:
        """
        Checks if the question text is an exact copy of any previous question
        created by the same player.
        """
        archive_questions = self.archive.questions[player_tag]
        return any(question_text == prev_q.text for prev_q in archive_questions)

    def get_valid_task_with_options(self, task_setter: "Player", prompt: str):
        """
        Generates a valid task with multiple-choice options for the given task setter.
        A valid task is one that has a verifiable ground truth answer, is
        sufficiently different (by embedding) from the setter's earlier
        questions, and has enough distractor-rich options.

        Every rejection except an empty generation is recorded in
        ``self.current_round.rejected_questions`` with a reason.

        Args:
            task_setter (Player): The player whose model sets the question.
            prompt (str): The prompt to generate the question.

        Returns:
            Question: Question object or None if the question is not valid.
        """
        q_text = get_task(model=task_setter.model_name, prompt=prompt)

        if not q_text:
            return None  # Question generation failed

        gt_answer = get_ground_truth_answer(q_text)
        if not gt_answer:
            self.current_round.rejected_questions[task_setter.player_tag].append(
                {
                    "text": q_text,
                    "reason": "This question does not have a verifiable ground truth answer/ the code did not return a valid answer.",
                }
            )
            return None  # No verifiable ground truth answer

        # NOTE(review): an exact-copy check via ``_is_question_a_copy`` was left
        # commented out at this point and is disabled. If it is re-enabled,
        # note that archived texts carry the "Q<n>: " prefix added below, so
        # comparing the raw ``q_text`` against them would never match.

        # Number questions by how many this setter already has in the archive.
        question_number = len(self.archive.questions[task_setter.player_tag]) + 1
        q_text = f"Q{question_number}: {q_text}"

        q_embedding = get_embedding(question=q_text)
        # NOTE(review): a failed embedding (None) is recorded with the same
        # "too similar" reason as a genuine near-duplicate.
        if q_embedding is None or not self._is_question_unique(
            q_embedding=q_embedding, player_tag=task_setter.player_tag
        ):
            self.current_round.rejected_questions[task_setter.player_tag].append(
                {
                    "text": q_text,
                    "reason": "This question is too similar to one, or more, of the questions you previous created.",
                }
            )
            return None

        options = get_mc_options(
            model=task_setter.model_name,
            question=q_text,
            gt_answer=gt_answer,
            N_options=9,
        )

        if not options:
            self.current_round.rejected_questions[task_setter.player_tag].append(
                {
                    "text": q_text,
                    "reason": "This question is not complex enough to have many distractor options.",
                }
            )
            return None

        return Question(
            text=q_text,
            task_setter_player_tag=task_setter.player_tag,
            distractor_options=options,
            ground_truth_answer=gt_answer,
            round_number=self.current_round.round_number,
            embedding=q_embedding,
            scores={p.player_tag: None for p in self.players},
        )

    def score_model_on_question(
        self,
        model,
        problem: "Question",
    ):
        """
        Scores a model's performance on a given multiple-choice question problem.

        Answers are sampled in batches of ``n_samples_warm_start``, each
        sample presenting a fresh option set from ``option_generator``. The
        loop stops once the standard deviation of the proportion correct falls
        below ``score_std_stopping_criterion``, or once at least
        ``n_samples_upper_limit`` samples have been taken.

        Args:
            model: The model name passed through to the answering helper.
            problem (Question): The multiple-choice question to be answered.

        Returns:
            Tuple[float, float]: A tuple containing:
                - The proportion of correct answers (accuracy) as a float.
                - The standard deviation of the proportion correct as a float.
        """
        all_gt_abcd_answers: List[str] = []
        all_model_answers: List[str] = []
        total_samples_taken: int = 0

        while True:
            batch_options = [
                option_generator(
                    problem.distractor_options, problem.ground_truth_answer
                )
                for _ in range(self.n_samples_warm_start)
            ]

            # Letter under which the ground truth appears in each shuffled option set.
            batch_gt_abcd_answers = [
                ["A", "B", "C", "D"][options.index(problem.ground_truth_answer)]
                for options in batch_options
            ]

            batch_model_answers = [
                get_answer_to_verifiable_MC_problem_with_ABCD(
                    model, problem.text, options
                )
                for options in batch_options
            ]

            all_gt_abcd_answers.extend(batch_gt_abcd_answers)
            all_model_answers.extend(batch_model_answers)
            total_samples_taken += self.n_samples_warm_start

            p, std = self._calculate_p_std(
                answers_set=all_model_answers,
                gt_abcd_answers_set=all_gt_abcd_answers,
                n_total_samples=total_samples_taken,
            )

            if std < self.score_std_stopping_criterion:
                break
            if total_samples_taken >= self.n_samples_upper_limit:
                break

        return p, std

    def _calculate_p_std(
        self,
        answers_set: List[str],
        gt_abcd_answers_set: List[str],
        n_total_samples: int,
    ):
        """
        Helper method to calculate the proportion of correct answers (p) and
        its binomial standard error, sqrt(p * (1 - p) / n).

        Returns (0.0, 0.0) when ``n_total_samples`` is 0. Both lists must hold
        at least ``n_total_samples`` entries.
        """
        if n_total_samples == 0:
            return 0.0, 0.0

        correct_count = sum(
            1
            for i in range(n_total_samples)
            if is_correct(answers_set[i], gt_abcd_answers_set[i])
        )

        p = correct_count / n_total_samples
        std = np.sqrt((1.0 / n_total_samples) * p * (1 - p))
        return p, std

    def play(self):
        """
        Executes the main game loop for the specified number of rounds.

        Each round, every player tries up to ``num_attempts`` times to set a
        valid question. The first valid one is scored against every player
        and added to the round and archive.
        """
        for _ in range(self.num_rounds):
            self.start_new_round()

            for task_setter in self.players:
                for _attempt in range(self.num_attempts):
                    prompt = task_setter.prompt_generator(self, task_setter.player_tag)
                    question = self.get_valid_task_with_options(
                        task_setter=task_setter, prompt=prompt
                    )

                    if question:
                        for answerer in self.players:
                            score_val, _std = self.score_model_on_question(
                                model=answerer.model_name, problem=question
                            )
                            self.current_round.record_score(
                                task_setter_tag=task_setter.player_tag,
                                question_answerer_tag=answerer.player_tag,
                                score=score_val,
                            )
                            question.scores[answerer.player_tag] = score_val

                        self.add_question_to_current_round(
                            task_setter_tag=task_setter.player_tag, question=question
                        )
                        break  # one accepted question per setter per round

            self.end_round()  # saves data and prints summary of the round

        print(f"Data saved to: {self.game_file_path}")
        print("\n--- Game Over ---")

    def display_results(self):
        """
        Prints per-round and cumulative scores, announces the winner, and
        returns the cumulative scores as {player_tag: total_points}.

        Assumes exactly two players.
        """
        scores = self.get_game_score_summary()
        cum_score = {tag: 0.0 for tag in self.player_tags}
        for round_number, score in scores.items():
            print(f"\n--- Round {round_number} Scores ---")
            print(f"Scores for Round {round_number}: {score}")
            for player_tag, player_score in score.items():
                cum_score[player_tag] += player_score

        print("\n--- Final Scores ---")
        for player_tag, score in cum_score.items():
            print(f"{player_tag}: {score}")
        if cum_score[self.player_tags[0]] > cum_score[self.player_tags[1]]:
            print(f"{self.player_tags[0]} wins!")
        elif cum_score[self.player_tags[0]] < cum_score[self.player_tags[1]]:
            print(f"{self.player_tags[1]} wins!")
        else:
            print("It's a tie!")
        return cum_score
