import numpy as np
from typing import List
import random

import sys

sys.path.append("..")
from skate_utils.utils import (
    get_answer_to_verifiable_MC_problem_with_ABCD,
)

from skate_utils.scoring_tools import option_generator, is_correct

from skate_utils.utils import (
    get_task,
    get_ground_truth_answer,
    get_embedding,
    get_mc_options,
    is_sufficiently_different,
    get_answer_to_verifiable_MC_problem_with_ABCD,
)


def calculate_p_std(
    answers_set: List[str],
    gt_abcd_answers_set: List[str],
    n_total_samples: int,
):
    """
    Return the fraction of correct answers and its binomial standard error.

    The first ``n_total_samples`` entries of ``answers_set`` are compared
    pairwise against ``gt_abcd_answers_set`` with ``is_correct``. With ``p``
    the fraction judged correct, the spread is ``sqrt(p * (1 - p) / n)``.
    That spread is 0.0 whenever every answer was right or every answer was
    wrong.

    Args:
        answers_set: Model answers; at least ``n_total_samples`` long.
        gt_abcd_answers_set: Ground-truth answer letters aligned with
            ``answers_set``; at least ``n_total_samples`` long.
        n_total_samples: How many leading pairs to score.

    Returns:
        tuple: ``(p, std)``, or ``(0.0, 0.0)`` when ``n_total_samples`` is 0.
    """
    if n_total_samples == 0:
        return 0.0, 0.0

    num_correct = 0
    for idx in range(n_total_samples):
        if is_correct(answers_set[idx], gt_abcd_answers_set[idx]):
            num_correct += 1

    proportion = num_correct / n_total_samples
    variance = (1.0 / n_total_samples) * proportion * (1 - proportion)
    return proportion, np.sqrt(variance)


def score_model_on_question(
    model,
    problem,
    ground_truth_answer,
    distractor_options,
    n_samples_warm_start: int = 10,
    n_samples_upper_limit: int = 150,
    score_std_stopping_criterion: float = 0.05,
):
    """
    Scores a model's performance on a given multiple-choice question problem.

    The function repeatedly draws batches of option sets from
    ``option_generator`` and asks the model to answer each one. After every
    batch it recomputes the running proportion of correct answers and its
    standard deviation with ``calculate_p_std``. Sampling stops when the
    standard deviation falls below ``score_std_stopping_criterion`` or when
    ``n_samples_upper_limit`` samples have been taken. The final batch is
    truncated so that limit is never exceeded.

    Args:
        model: The model under evaluation; passed to
            ``get_answer_to_verifiable_MC_problem_with_ABCD``.
        problem: The question text.
        ground_truth_answer: The correct option. It must appear in every
            option list produced by ``option_generator``; otherwise
            ``list.index`` raises ValueError.
        distractor_options: A list of distractor options, or a string
            holding a Python literal for one (e.g. ``"['x', 'y']"``).
        n_samples_warm_start: Number of samples drawn per batch.
        n_samples_upper_limit: Maximum total number of samples.
        score_std_stopping_criterion: Stop once the standard deviation of
            the proportion correct is below this value.

    Returns:
        Tuple[float, float]: A tuple containing:
            - The proportion of correct answers (accuracy) as a float.
            - The standard deviation of the proportion correct as a float.
        If no sample can be drawn (e.g. ``n_samples_upper_limit <= 0``),
        ``(0.0, float("inf"))`` is returned.

    Raises:
        ValueError: If ``distractor_options`` is not a list and cannot be
            parsed as a Python literal.
    """
    import ast  # only needed to parse serialized distractor options

    if not isinstance(distractor_options, list):
        # Serialized options are model-generated text, so parse them strictly
        # as a Python literal. eval() would execute any code embedded in them.
        try:
            distractor_options = ast.literal_eval(distractor_options)
        except Exception as e:
            raise ValueError(
                f"Invalid distractor options format: {distractor_options}. Error: {e}"
            ) from e

    abcd_letters = ["A", "B", "C", "D"]
    all_gt_abcd_answers: List[str] = []
    all_model_answers: List[str] = []
    total_samples_taken = 0

    # Returned unchanged only when no sample can be drawn (no information).
    p, std = 0.0, float("inf")

    while total_samples_taken < n_samples_upper_limit:
        # Truncate the final batch so the upper limit is never overshot.
        batch_size = min(
            n_samples_warm_start, n_samples_upper_limit - total_samples_taken
        )
        if batch_size <= 0:
            break  # non-positive batch size: sampling cannot make progress

        batch_options = [
            option_generator(distractor_options, ground_truth_answer)
            for _ in range(batch_size)
        ]

        # Letter under which the ground truth sits in each option set
        # (assumes option_generator yields at most four options, A-D).
        batch_gt_abcd_answers = [
            abcd_letters[options.index(ground_truth_answer)]
            for options in batch_options
        ]

        batch_model_answers = [
            get_answer_to_verifiable_MC_problem_with_ABCD(model, problem, options)
            for options in batch_options
        ]

        all_gt_abcd_answers.extend(batch_gt_abcd_answers)
        all_model_answers.extend(batch_model_answers)
        total_samples_taken += len(batch_options)

        p, std = calculate_p_std(
            answers_set=all_model_answers,
            gt_abcd_answers_set=all_gt_abcd_answers,
            n_total_samples=total_samples_taken,
        )

        # Note: an all-correct or all-wrong history gives std == 0.0, so such
        # a run stops after its first batch.
        if std < score_std_stopping_criterion:
            break

    return p, std


def historical_performance_aug_strat_prompt_generator(
    round_num, archive, failed_archive, max_attempts: int = 3
):
    """
    Generate prompt for historical performance augmentation strategy.

    This strategy uses the task setter's archive and task setter's scores and
    the current round's rejected questions.

    The template is read from
    ``skate/eval_duel/prompts/historical-performance-prompt.txt``, a path
    relative to the current working directory. It is filled via
    ``str.format`` with ``previous_questions``, ``round_number``,
    ``num_attempts_left`` and ``failed_attempts``.

    Args:
        round_num: Current round number, inserted verbatim.
        archive: Previously asked questions. Each entry is indexed as
            ``entry[0]`` (question text) and ``entry[1]`` (score as a
            fraction, rendered as a percentage). The archive itself is not
            modified; a shuffled copy sets the presentation order.
        failed_archive: This round's rejected attempts, each a dict with
            ``"text"`` and ``"reason"`` keys.
        max_attempts: Attempts allowed per round. The prompt reports
            ``max_attempts - len(failed_archive)`` attempts left. The default
            of 3 matches the retry loop in ``get_valid_task_with_options``.

    Returns:
        str: The filled-in prompt.
    """
    with open(
        "skate/eval_duel/prompts/historical-performance-prompt.txt",
        "r",
        encoding="utf-8",
    ) as f:
        prompt = f.read()

    # Shuffle a copy: shuffling the argument in place would silently reorder
    # the caller's question history on every call.
    questions = list(archive)
    random.shuffle(questions)

    previous_questions = "\n".join(
        f"On the following question, you scored {round(100 * q[1], 2)}%: \n{q[0]!s}"
        for q in questions
    )

    num_attempts_left = max_attempts - len(failed_archive)

    failed_attempt_string = "".join(
        f"Question: {attempt['text']}\nReason: {attempt['reason']}\n"
        for attempt in failed_archive
    )

    return prompt.format(
        previous_questions=previous_questions,
        round_number=round_num,
        num_attempts_left=num_attempts_left,
        failed_attempts=failed_attempt_string,
    )


def is_question_unique(q_embedding, archive) -> bool:
    """
    Return True when ``q_embedding`` is far enough from every stored embedding.

    ``archive`` holds the player's previously asked questions. Each entry's
    embedding is read from its last element (``entry[-1]``), and entries whose
    embedding is None are skipped. With nothing to compare against (empty
    archive or no usable embeddings), the question counts as unique.
    Otherwise the decision is delegated to ``is_sufficiently_different`` with
    a fixed threshold of 0.336.
    """
    if not archive:
        return True  # nothing has been asked yet

    known_embeddings = []
    for entry in archive:
        emb = entry[-1]
        if emb is not None:
            known_embeddings.append(emb)

    if len(known_embeddings) == 0:
        return True  # no stored embedding to compare with

    return is_sufficiently_different(
        q_embedding=q_embedding,
        other_embeddings=known_embeddings,
        threshold=0.336,
    )


def get_valid_task_with_options(task_setter, archive, round_number):
    """
    Generate a valid task with multiple-choice options for the given task setter.

    A valid task is one that has a verifiable ground truth answer, is original
    (not too similar to earlier questions), and has enough distractor-rich
    options. Up to 3 attempts are made. Each attempt builds a prompt with
    ``historical_performance_aug_strat_prompt_generator``, which also shows
    the model the attempts rejected so far in this call. It then checks, in
    order, that the generated question:

    1. has a ground-truth answer (``get_ground_truth_answer`` is truthy);
    2. can be embedded (``get_embedding`` is not None) and passes
       ``is_question_unique`` against ``archive``;
    3. yields options from ``get_mc_options`` (9 options requested).

    A question failing a check is recorded with a reason and the next attempt
    is made. If the model returns an empty question, generation is abandoned
    immediately without further attempts.

    Args:
        task_setter: The model/player setting the question; passed to
            ``get_task`` and ``get_mc_options``.
        archive: The task setter's previously asked questions, as consumed by
            the prompt generator and ``is_question_unique``. Its length
            determines the ``Q<n>:`` number of the new question.
        round_number: Current round number, shown in the prompt and in logs.

    Returns:
        tuple: ``(q_text, gt_answer, options, q_embedding)`` for the first
        valid question, where ``q_text`` carries the ``Q<n>: `` prefix.
        Returns ``(None, None, None, None)`` if generation failed or every
        attempt was rejected.
    """
    max_attempts = 3
    rejected_questions = []
    # The new question is numbered after the existing archive entries.
    question_number = len(archive) + 1

    for attempt in range(max_attempts):
        prompt = historical_performance_aug_strat_prompt_generator(
            round_num=round_number,
            archive=archive,
            failed_archive=rejected_questions,
        )

        q_text = get_task(model=task_setter, prompt=prompt)

        if not q_text:
            print(
                f"question generated failed on attempt {attempt + 1} for round {round_number}: {q_text}"
            )
            return None, None, None, None  # Question generation failed

        gt_answer = get_ground_truth_answer(q_text)
        if not gt_answer:
            rejected_questions.append(
                {
                    "text": q_text,
                    "reason": "This question does not have a verifiable ground truth answer/ the code did not return a valid answer.",
                }
            )
            continue

        # Number the question before embedding it; the prefixed text is what
        # gets embedded, rejected, and returned from here on.
        q_text = f"Q{question_number}: {q_text}"

        q_embedding = get_embedding(question=q_text)
        if q_embedding is None or not is_question_unique(
            q_embedding=q_embedding, archive=archive
        ):
            rejected_questions.append(
                {
                    "text": q_text,
                    "reason": "This question is too similar to one, or more, of the questions you previous created.",
                }
            )
            continue

        options = get_mc_options(
            model=task_setter,
            question=q_text,
            gt_answer=gt_answer,
            N_options=9,
        )

        if not options:
            rejected_questions.append(
                {
                    "text": q_text,
                    "reason": "This question is not complex enough to have many distractor options.",
                }
            )
            continue

        return q_text, gt_answer, options, q_embedding

    return None, None, None, None
