import os
from typing import Dict, List, Literal, Set, Optional, Callable, Tuple

import pandas as pd
from tqdm import tqdm

from src.bongard_problems.classes import (
    BongardResolveAttempt,
    BongardSolution,
    ClassificationAttempt,
)
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.prompting_techniques.common import DEFAULT_PROBLEM_DESCRIPTION


# Longer, older wording of the judge prompt. It exposes the same str.format
# placeholders as COMPARISON_QUESTION ({first_label}, {second_label},
# {model_answer}), so it can be passed as `prompt` to
# evalute_answers_using_model. It is not referenced anywhere else in the
# visible part of this module. The surrounding newlines are removed by .strip().
OLD_COMPARISON_QUESTION = """
You are a logic module which is designed to provide accurate answers. 
A Bongard problem is a kind of puzzle. 
The objective is to spot the difference between the contents of squares located on the two opposite sides of the image. 

You are given correctly assigned labels of these sides showing the correct answer.
You must decide whether the answer provided by the user is correct and matches with those labels. Answer with 'OK' or 'WRONG'.

Notice: user may refer to the sides of the puzzle as separate images or sets. 
Also, we assume that the order of the sides (images) does not matter.

LEFT SIDE LABEL:
{first_label}

RIGHT SIDE LABEL:
{second_label}

USER ANSWER:
{model_answer}
""".strip()


# Default judge prompt of evalute_answers_using_model. The placeholders
# {first_label}, {second_label} and {model_answer} are filled with the
# left-side rule, the right-side rule and the answer under evaluation via
# str.format. The surrounding newlines are removed by .strip().
COMPARISON_QUESTION = """
You are a logic module designed to provide accurate answers. In a Bongard Problem the objective is to spot the difference between the contents of images located on the two opposite sides of the problem. You are given correct labels of these sides and must decide whether the answer provided by the user is correct and matches with those labels. Answer with 'OK' or 'WRONG'.

LEFT SIDE LABEL:
{first_label}

RIGHT SIDE LABEL:
{second_label}

USER ANSWER:
{model_answer}
""".strip()


def evalute_answers_using_model(
    answers_file: str,
    labels_file: str,
    model: LLMMessenger,
    splitted_data_path: Optional[str] = None,
    author: Optional[str] = None,
    with_image: bool = False,
    reevaluate: bool = False,
    prompt: str = COMPARISON_QUESTION,
):
    """Ask `model` to judge every stored answer against the ground-truth labels.

    The attempt stored in `answers_file` is loaded, each eligible solution is
    sent to the model together with the left/right rules read from
    `labels_file`, and the model response is recorded as an evaluation. The
    attempt is saved back to `answers_file` after every successful evaluation.

    Args:
        answers_file: Path of the stored resolve attempt; read and overwritten.
        labels_file: CSV file with "Left-side Rule" and "Right-side Rule"
            columns; row `problem_id - 1` is used for a given problem.
        model: Messenger used to query the judging model.
        splitted_data_path: Directory containing `<problem_id>/whole.png`;
            required when `with_image` is true.
        author: Name recorded as the evaluation author; defaults to
            `model.get_name()`.
        with_image: When true, the whole-problem image is attached to the
            message.
        reevaluate: When true, solutions already evaluated by this author are
            evaluated again.
        prompt: Template with `{first_label}`, `{second_label}` and
            `{model_answer}` placeholders.

    Raises:
        ValueError: If `with_image` is true but `splitted_data_path` is None.
    """
    # Validate the argument combination once, up front, instead of on every
    # loop iteration.
    if with_image and splitted_data_path is None:
        raise ValueError("splitted_data_path is required for with_image=True.")

    resolve_attempt = BongardResolveAttempt()
    resolve_attempt.load(answers_file)

    df = pd.read_csv(labels_file)
    evaluation_author = author if author else model.get_name()

    for solution in tqdm(resolve_attempt.get_solutions().values()):
        authors_that_checked_solution = [
            evaluation.author for evaluation in solution.evaluations
        ]
        already_evaluated = evaluation_author in authors_that_checked_solution

        # Explicit parentheses: skip solutions this author already judged
        # (unless re-evaluating) and solutions without any answer.
        if (already_evaluated and not reevaluate) or solution.answer == "":
            continue

        labels_row = df.iloc[solution.problem_id - 1]
        comparison_question = prompt.format(
            first_label=labels_row["Left-side Rule"],
            second_label=labels_row["Right-side Rule"],
            model_answer=solution.answer,
        )

        msg_contents = [TextContent(comparison_question)]

        if with_image:
            msg_contents.append(
                ImageContent(f"{splitted_data_path}/{solution.problem_id}/whole.png")
            )

        try:
            model_response = model.ask(msg_contents)
            resolve_attempt.evaluate_solution(
                solution.problem_id,
                model_response,
                author=evaluation_author,
                reevaluate=reevaluate,
            )
            resolve_attempt.save(answers_file)
        except ValueError as e:
            # Response might have been blocked by a safety filter, let's ignore it and proceed to the next problem
            print(f"Failed to evaluate answer to problem with id: {solution.problem_id}.")
            print(e)


def count_correct_answers(
    answers_file: str,
    evaluation_author: Optional[str] = None,
):
    """Return the number of problems in `answers_file` that count as resolved.

    The decision is delegated to `get_resolved_problems` with its default
    evaluation type.

    Args:
        answers_file: Path of the stored resolve attempt.
        evaluation_author: Name of the single evaluator whose verdicts should
            be considered. None or an empty string means all evaluators. A
            list of names is passed through unchanged.

    Returns:
        The size of the resolved-problem set.
    """
    # get_resolved_problems expects a list of author names. Forwarding the bare
    # string would turn its `author in evaluation_authors` checks into
    # substring matches and make the voting threshold depend on the length of
    # the name.
    if isinstance(evaluation_author, str):
        evaluation_authors = [evaluation_author] if evaluation_author else None
    else:
        evaluation_authors = evaluation_author

    return len(get_resolved_problems(answers_file, evaluation_authors))


def count_solutions(
    answers_file: str, solution_predicate: Callable[[BongardSolution], bool]
):
    """Count the solutions stored in `answers_file` accepted by `solution_predicate`.

    Args:
        answers_file: Path of the stored resolve attempt.
        solution_predicate: Called once per stored solution; a truthy result
            counts the solution.

    Returns:
        Number of solutions for which the predicate returned a truthy value.
    """
    stored_solutions = get_solutions_from_file(answers_file)
    return sum(
        1 for candidate in stored_solutions.values() if solution_predicate(candidate)
    )


def count_num_answers(answers_file: str):
    """Return how many solutions in `answers_file` have at least one evaluation.

    A solution is counted when its `evaluations` attribute differs from the
    empty list.
    """
    evaluated_count = 0
    for stored_solution in get_solutions_from_file(answers_file).values():
        if stored_solution.evaluations != []:
            evaluated_count += 1
    return evaluated_count


def get_solutions_from_file(answers_file: str) -> Dict[int, BongardSolution]:
    """Load the resolve attempt stored in `answers_file` and return its solutions.

    Returns:
        Whatever `BongardResolveAttempt.get_solutions()` yields after loading;
        callers in this module iterate it as a mapping from problem id to
        solution.
    """
    attempt = BongardResolveAttempt()
    attempt.load(answers_file)
    solutions_by_problem = attempt.get_solutions()
    return solutions_by_problem


def get_unresolved_problems(
    answers_file: str,
    evaluation_author: Optional[str] = None,
) -> Set[int]:
    """Return ids of problems that no considered evaluation marked as "OK".

    Args:
        answers_file: Path of the stored resolve attempt.
        evaluation_author: When truthy, only evaluations by this author are
            considered; otherwise every evaluation is.

    Returns:
        Ids of problems without any considered evaluation whose value contains
        the substring "OK" (including problems without evaluations at all).
    """

    def _accepts(evaluation) -> bool:
        # Evaluations by other authors never count as an acceptance.
        if evaluation_author and evaluation_author != evaluation.author:
            return False
        return "OK" in evaluation.value

    return {
        problem_id
        for problem_id, solution in get_solutions_from_file(answers_file).items()
        if not any(_accepts(evaluation) for evaluation in solution.evaluations)
    }


def get_resolved_problems(
    answers_file: str,
    evaluation_authors: Optional[List[str]] = None,
    evaluation_type: Literal["voting", "all models", "any model"] = "voting",
) -> Set[int]:
    """Return the ids of resolved problems according to `evaluation_type`.

    Args:
        answers_file: Path of the stored resolve attempt.
        evaluation_authors: Evaluators whose verdicts are considered; None
            means all evaluators.
        evaluation_type: Aggregation strategy: "voting", "all models" or
            "any model". Each value dispatches to the matching
            `get_resolved_problems_*` function.

    Returns:
        Set of resolved problem ids.

    Raises:
        ValueError: If `evaluation_type` is not one of the supported values.
            Previously the function fell through and returned None.
    """
    resolvers = {
        "voting": get_resolved_problems_with_voting,
        "all models": get_resolved_problems_by_all_models,
        "any model": get_resolved_problems_by_any_model,
    }

    try:
        resolver = resolvers[evaluation_type]
    except KeyError:
        raise ValueError(
            f"Unknown evaluation_type: {evaluation_type!r}. "
            f"Expected one of: {', '.join(resolvers)}."
        ) from None

    return resolver(answers_file, evaluation_authors)


def get_resolved_problems_by_all_models(
    answers_file: str,
    evaluation_authors: Optional[List[str]] = None,
) -> Set[int]:
    """Return ids of problems that every considered evaluation marked as "OK".

    Args:
        answers_file: Path of the stored resolve attempt.
        evaluation_authors: Evaluators whose verdicts are considered; None or
            an empty list means all evaluators.

    Returns:
        Ids of problems that have at least one considered evaluation and for
        which every considered evaluation value contains "OK". Problems
        without any considered evaluation are not reported as resolved
        (previously they were, by vacuous truth).
    """
    solutions = get_solutions_from_file(answers_file)
    solved_problems: Set[int] = set()

    for problem_id, solution in solutions.items():
        considered = [
            evaluation
            for evaluation in solution.evaluations
            if not evaluation_authors or evaluation.author in evaluation_authors
        ]

        # `all` over an empty list is True, so require at least one verdict.
        if considered and all("OK" in evaluation.value for evaluation in considered):
            solved_problems.add(problem_id)

    return solved_problems


def get_resolved_problems_by_any_model(
    answers_file: str,
    evaluation_authors: Optional[List[str]] = None,
) -> Set[int]:
    """Return ids of problems that at least one considered evaluation marked as "OK".

    Args:
        answers_file: Path of the stored resolve attempt.
        evaluation_authors: Evaluators whose verdicts are considered; None or
            an empty list means all evaluators. Annotated as a list of names
            to match `get_resolved_problems`, which forwards a list here
            (a bare string would make the membership test a substring match).

    Returns:
        Ids of problems with at least one considered evaluation whose value
        contains the substring "OK".
    """
    solutions = get_solutions_from_file(answers_file)

    return {
        problem_id
        for problem_id, solution in solutions.items()
        if any(
            "OK" in evaluation.value
            for evaluation in solution.evaluations
            if not evaluation_authors or evaluation.author in evaluation_authors
        )
    }


def get_resolved_problems_with_voting(
    answers_file: str,
    evaluation_authors: Optional[List[str]] = None,
) -> Set[int]:
    """Return ids of problems accepted by at least half of the voters.

    Args:
        answers_file: Path of the stored resolve attempt.
        evaluation_authors: Evaluators taking part in the vote. When given,
            the number of voters is `len(evaluation_authors)`, so a listed
            author without an evaluation effectively votes against. When
            None, every stored evaluation is one vote.

    Returns:
        Ids of problems whose number of "OK" votes is at least half of the
        number of voters (a tie counts as resolved). Problems with zero voters
        are never resolved; previously `0 >= 0` reported them as resolved.
    """
    solutions = get_solutions_from_file(answers_file)
    solved_problems: Set[int] = set()

    for problem_id, solution in solutions.items():
        if evaluation_authors is None:
            num_voters = len(solution.evaluations)
        else:
            num_voters = len(evaluation_authors)

        ok_votes = sum(
            1
            for evaluation in solution.evaluations
            if (evaluation_authors is None or evaluation.author in evaluation_authors)
            and "OK" in evaluation.value
        )

        # Without voters there is no majority to speak of.
        if num_voters > 0 and ok_votes >= num_voters / 2:
            solved_problems.add(problem_id)

    return solved_problems


def count_voting_solutions(
    answers_file: str,
    authors: Set[str],
    predicate: Callable[[str], bool],
    threshold: int,
    ignore_missing: bool = False,
) -> Tuple[int, int]:
    """Count problems whose evaluations pass `predicate` for enough authors.

    Args:
        answers_file: Path of the stored resolve attempt.
        authors: Evaluators whose single evaluation per problem is inspected.
        predicate: Applied to the evaluation value; a truthy result is one
            accepting vote.
        threshold: Minimum number of accepting votes for a problem to count.
        ignore_missing: When true, a missing evaluation is reported with
            `print` and skipped instead of raising.

    Returns:
        `(number of all problems, number of problems reaching the threshold)`.

    Raises:
        ValueError: If an author has no evaluation for a problem (and
            `ignore_missing` is false) or has more than one.
    """
    total_problems = 0
    accepted_problems = 0

    for problem_id, solution in get_solutions_from_file(answers_file).items():
        total_problems += 1
        votes = 0

        for author in authors:
            by_author = [e for e in solution.evaluations if e.author == author]

            if not by_author:
                if ignore_missing:
                    print(f"Problem {problem_id} was not evaluated by {author}")
                    continue
                raise ValueError(f"Problem {problem_id} was not evaluated by {author}")

            if len(by_author) > 1:
                raise ValueError(
                    f"Problem {problem_id} was evaluated by {author} {len(by_author)} times"
                )

            if predicate(by_author[0].value):
                votes += 1

        if votes >= threshold:
            accepted_problems += 1

    return total_problems, accepted_problems


def get_voting_mean_agreement(
    answers_file: str,
    evaluation_authors: Optional[List[str]] = None,
):
    """Return the mean share of voters that agree with the voting outcome.

    For every problem the share of voters on the winning side is computed:
    "OK" votes when `get_resolved_problems_with_voting` reports the problem as
    resolved, all remaining votes otherwise. The shares are then averaged.

    Args:
        answers_file: Path of the stored resolve attempt.
        evaluation_authors: Evaluators taking part in the vote. When given,
            the number of voters is `len(evaluation_authors)`; when None it is
            the number of stored evaluations of the problem.

    Returns:
        Mean agreement over all problems that have at least one voter, or 0.0
        when there is no such problem. Problems without voters are skipped;
        previously they caused a ZeroDivisionError, as did an empty answers
        file.
    """
    solutions = get_solutions_from_file(answers_file)
    solved_problems = get_resolved_problems_with_voting(
        answers_file, evaluation_authors
    )
    agreement_level = 0.0
    counted_problems = 0

    for problem_id, solution in solutions.items():
        included_authors_count = (
            len(evaluation_authors)
            if evaluation_authors is not None
            else len(solution.evaluations)
        )

        # No voters means no agreement can be measured for this problem.
        if included_authors_count == 0:
            continue

        ok_votes = sum(
            1
            for evaluation in solution.evaluations
            if (evaluation_authors is None or evaluation.author in evaluation_authors)
            and "OK" in evaluation.value
        )

        if problem_id in solved_problems:
            winning_votes = ok_votes
        else:
            winning_votes = included_authors_count - ok_votes

        agreement_level += winning_votes / included_authors_count
        counted_problems += 1

    if counted_problems == 0:
        return 0.0

    return agreement_level / counted_problems


# Judge prompt for check_if_explanation_matches_labels. It has its own name
# because EVALUATE_EXPLANATION_QUESTION is bound further down in this module to
# a template with a different placeholder; sharing the name made the
# str.format call below fail with KeyError at call time.
EVALUATE_EXPLANATION_WITH_LABELS_QUESTION = """
You are a logic module which is designed to provide accurate answers. 
A Bongard problem is a kind of puzzle. 
The objective is to classify a query image to proper side of the image and provide short explanation. 

You are given correctly assigned labels of these sides showing the proper answer.
You must decide wheter explanation provided by user is correct and matches with those labels. Answer with 'OK' or 'WRONG'. 

Notice: user may refer to the sides of the puzzle as seperate images or sets. 
Also, we assume that the order of the sides (images) does not matter.

FIRST IMAGE LABEL:
{first_label}

SECOND IMAGE LABEL:
{second_label}

USER ANSWER:
{model_answer}
"""


def check_if_explanation_matches_labels(
    answers_file: str,
    labels_file: str,
    model: LLMMessenger,
    reevaluate: bool = False,
):
    """Ask `model` whether each stored explanation matches the problem labels.

    Only solutions whose classification `evaluation` equals "OK" are
    considered. The model response is stored through `evaluate_explanation`
    and the attempt is saved back to `answers_file` after every solution.

    Args:
        answers_file: Path of the stored classification attempt; read and
            overwritten.
        labels_file: CSV file with "Left-side Rule" and "Right-side Rule"
            columns; row `problem_id - 1` is used for a given problem.
        model: Messenger used to query the judging model.
        reevaluate: When true, explanations that already have an evaluation
            are evaluated again.
    """
    resolve_attempt = ClassificationAttempt()
    resolve_attempt.load(answers_file)

    df = pd.read_csv(labels_file)

    for query_file, solution in resolve_attempt.get_solutions().items():
        # Explicit parentheses: skip wrong classifications and explanations
        # that were already evaluated (unless re-evaluating).
        if solution.evaluation != "OK" or (
            solution.explanation_evaluation != "" and not reevaluate
        ):
            continue

        labels_row = df.iloc[solution.problem_id - 1]
        comparison_question = EVALUATE_EXPLANATION_WITH_LABELS_QUESTION.format(
            first_label=labels_row["Left-side Rule"],
            second_label=labels_row["Right-side Rule"],
            model_answer=solution.answer,
        )

        msg_contents = [TextContent(comparison_question)]
        model_response = model.ask(msg_contents)
        resolve_attempt.evaluate_explanation(query_file, model_response)
        resolve_attempt.save(answers_file)


# Judge prompt used by check_if_explanations_matches_labels_and_images. Its only
# str.format placeholder is {task_to_evaluate_description}; the images, the
# left-side rule, the correct answer and the user answer are appended by that
# function as separate message parts.
# NOTE(review): the sentence "'OK' when user answer mentions  or 'WRONG' ..."
# looks truncated — confirm the intended wording.
EVALUATE_EXPLANATION_QUESTION = """
You are a logic module which is designed to provide accurate answers. 
A Bongard problem is a kind of puzzle. 
All images on the left side match a certain rule, and all images on the right do not.

The user was given a task: 
`{task_to_evaluate_description}`

Based on the images, rule and correct answer, you must decide whether the explanation provided by user is correct. 
Also, you are given a correct answer. DO NOT TRY TO RECLASSIFY THE IMAGE.
If the correct answer is 'LEFT', the left side of the image is the correct answer.
Answer with:
'OK' when user answer mentions  or 'WRONG' and provide a short comment.  

Notice: user may refer to the sides of the puzzle as seperate images or sets. 
"""


def check_if_explanations_matches_labels_and_images(
    answers_file: str,
    labels_file: str,
    splitted_data_path: str,
    model: LLMMessenger,
    task_to_evaluate_description: str,
    reevaluate: bool = False,
):
    """Ask `model` to judge stored explanations given the images and the rule.

    Only solutions whose classification `evaluation` equals "OK" are
    considered. For each of them the message contains the prompt, the images
    of both sides, the left-side rule, the correct side, the query image and
    the user answer. The response is stored through `evaluate_explanation`
    and the attempt is saved to `answers_file` after every solution.

    Args:
        answers_file: Path of the stored classification attempt; read and
            overwritten.
        labels_file: CSV file with a "Left-side Rule" column; row
            `problem_id - 1` is used for a given problem.
        splitted_data_path: Root directory with `<problem_id>/left` and
            `<problem_id>/right` image folders.
        model: Messenger used to query the judging model.
        task_to_evaluate_description: Text inserted into the prompt as the
            task the user was given.
        reevaluate: When true, explanations that already have an evaluation
            are evaluated again.
    """
    attempt = ClassificationAttempt()
    attempt.load(answers_file)

    labels = pd.read_csv(labels_file)

    for query_file, solution in attempt.get_solutions().items():
        # Only correctly classified queries get their explanation judged.
        if solution.evaluation != "OK":
            continue
        # Keep an existing verdict unless a re-evaluation was requested.
        if solution.explanation_evaluation != "" and not reevaluate:
            continue

        left_images = gather_side_images(
            solution.problem_id, splitted_data_path, "left"
        )
        right_images = gather_side_images(
            solution.problem_id, splitted_data_path, "right"
        )

        # The correct side is derived from the query file path.
        # NOTE(review): any "right" substring in the path selects RIGHT —
        # confirm that parent directory names cannot contain it.
        if "right" in query_file:
            correct_answer = "RIGHT"
        else:
            correct_answer = "LEFT"

        left_label = labels.iloc[solution.problem_id - 1]["Left-side Rule"]

        instructions = EVALUATE_EXPLANATION_QUESTION.format(
            task_to_evaluate_description=task_to_evaluate_description
        )

        msg_contents = [TextContent(instructions), TextContent("\nLeft images:\n")]
        msg_contents.extend(left_images)
        msg_contents.append(TextContent("\nRight images:\n"))
        msg_contents.extend(right_images)
        msg_contents.extend(
            [
                TextContent(f"\nLeft images common feature: {left_label}\n"),
                TextContent(f"Correct answer: {correct_answer}\n"),
                TextContent("Query image:\n"),
                ImageContent(query_file),
                TextContent(f"\n\nUser answer: `{solution.answer}`\n"),
            ]
        )

        verdict = model.ask(msg_contents)
        attempt.evaluate_explanation(query_file, verdict)
        attempt.save(answers_file)


def gather_side_images(
    problem_id: int, splitted_data_path: str, side: str, all: bool = False
):
    """Return image contents for one side of a problem, in sorted file-name order.

    Args:
        problem_id: Problem whose directory `<splitted_data_path>/<problem_id>`
            is read.
        splitted_data_path: Root directory of the split data set.
        side: Name of the sub-directory to read; callers in this module pass
            "left" or "right".
        all: When false (default) only the first six entries are returned.

    Returns:
        List of `ImageContent` objects, one per directory entry.
    """
    directory = os.path.join(splitted_data_path, str(problem_id), side)
    file_names = sorted(os.listdir(directory))

    contents = []
    for file_name in file_names:
        contents.append(ImageContent(os.path.join(directory, file_name)))

    if all:
        return contents
    return contents[:6]


# Default `question` of check_solution_correctness: instructs the judging model
# to answer 'OK' or 'WRONG' for the user's answer without explaining. It has no
# str.format placeholders and is sent verbatim as a text message part.
CHECK_CORRECTNESS_QUESTION = """
Your goal is to evaluate the correctness of the user's answer based on the given Bongard problem.
All images are provided correctly. Respond with 'OK' if the user's answer is correct, otherwise respond with 'WRONG'. 
Do not explain the answer, just evaluate it.
"""


def check_solution_correctness(
    answers_file: str,
    splitted_data_path: str,
    model: LLMMessenger,
    description: str = DEFAULT_PROBLEM_DESCRIPTION,
    question: str = CHECK_CORRECTNESS_QUESTION,
    reevaluate: bool = False,
    use_joint_image: bool = False,
):
    """Ask `model` to judge stored answers by looking at the problem images.

    No ground-truth labels are used here: the model receives the problem
    description, the question, the images and the user answer. Each response
    is recorded as an evaluation authored by `model.get_name()` and the
    attempt is saved to `answers_file` after every successful evaluation.
    Any exception raised while asking, recording or saving is printed and the
    loop moves on to the next solution.

    Args:
        answers_file: Path of the stored resolve attempt; read and overwritten.
        splitted_data_path: Root directory with per-problem image data.
        model: Messenger used to query the judging model.
        description: Problem description sent as the first message part.
        question: Instruction sent as the second message part.
        reevaluate: When true, solutions already evaluated by this model are
            evaluated again.
        use_joint_image: When true, a single image built from the basename
            `<problem_id>/whole` is sent; otherwise the left and right side
            images are sent separately.
    """
    attempt = BongardResolveAttempt()
    attempt.load(answers_file)

    for solution in tqdm(attempt.get_solutions().values()):
        previous_authors = [
            evaluation.author for evaluation in solution.evaluations
        ]
        if model.get_name() in previous_authors and not reevaluate:
            continue

        # Build the image part of the message first; the text parts around it
        # are the same for both layouts.
        if use_joint_image:
            joint_basename = os.path.join(
                splitted_data_path, str(solution.problem_id), "whole"
            )
            puzzle_contents = [ImageContent.from_basename(joint_basename)]
        else:
            left_images = gather_side_images(
                solution.problem_id, splitted_data_path, "left"
            )
            right_images = gather_side_images(
                solution.problem_id, splitted_data_path, "right"
            )
            puzzle_contents = [
                TextContent("\nLeft images:\n"),
                *left_images,
                TextContent("\nRight images:\n"),
                *right_images,
            ]

        msg_contents = [
            TextContent(description),
            TextContent(question),
            *puzzle_contents,
            TextContent(f"\n\nUser answer: `{solution.answer}`\n"),
        ]

        try:
            verdict = model.ask(msg_contents)
            attempt.evaluate_solution(
                solution.problem_id,
                verdict,
                author=model.get_name(),
                reevaluate=reevaluate,
            )
            attempt.save(answers_file)
        except Exception as e:
            print(f"Failed to solve problem with id: {solution.problem_id}.")
            print(e)


def get_perfect_resolve_attempt(
    problem_ids: List[int],
    labels_file: str,
    answers_file: Optional[str] = None,
    model_name: str = "",
) -> BongardResolveAttempt:
    """Build an attempt whose answers are each problem's own labels.

    Delegates to `get_moved_resolved_attempt` with an offset of 0 and no
    circular buffer, so each problem is answered with its own row of
    `labels_file`.
    """
    return get_moved_resolved_attempt(
        problem_ids=problem_ids,
        offset=0,
        labels_file=labels_file,
        answers_file=answers_file,
        circular_buffer_size=None,
        model_name=model_name,
    )


def get_wrong_resolve_attempt(
    problem_ids: List[int],
    labels_file: str,
    answers_file: Optional[str] = None,
    offset: int = 20,
    circular_buffer_size: Optional[int] = None,
    model_name: str = "",
) -> BongardResolveAttempt:
    """Build an attempt whose answers are the labels of other problems.

    Delegates to `get_moved_resolved_attempt`, which answers each problem with
    the labels found `offset` rows further in `labels_file`, optionally
    wrapped around `circular_buffer_size`.
    """
    return get_moved_resolved_attempt(
        problem_ids=problem_ids,
        offset=offset,
        labels_file=labels_file,
        answers_file=answers_file,
        circular_buffer_size=circular_buffer_size,
        model_name=model_name,
    )


def get_moved_resolved_attempt(
    problem_ids: List[int],
    offset: int,
    labels_file: str,
    answers_file: Optional[str] = None,
    circular_buffer_size: Optional[int] = None,
    model_name: str = "",
) -> "BongardResolveAttempt":
    """Build a synthetic attempt answering each problem with shifted labels.

    For every problem id without a stored solution, the answer
    "LEFT: <left rule>, RIGHT: <right rule>" is added, with the rules taken
    from row `problem_id - 1 + offset` of `labels_file`.

    Args:
        problem_ids: Problems that should have a solution in the result.
        offset: Number of rows by which the label lookup is shifted; 0 yields
            the problem's own labels.
        labels_file: CSV file with "Left-side Rule" and "Right-side Rule"
            columns.
        answers_file: Optional path of an existing attempt. When it is given
            and exists, it is loaded first and its solutions are kept.
        circular_buffer_size: When given, the shifted row index is taken
            modulo this value.
        model_name: Passed to the `BongardResolveAttempt` constructor.

    Returns:
        The populated attempt; it is not saved by this function.
    """
    df = pd.read_csv(labels_file)
    resolve_attempt = BongardResolveAttempt(model_name)

    # answers_file defaults to None and os.path.exists(None) raises TypeError,
    # so only probe the file system when a path was actually supplied.
    if answers_file is not None and os.path.exists(answers_file):
        resolve_attempt.load(answers_file)

    for problem_id in problem_ids:
        if resolve_attempt.has_solution(problem_id):
            continue

        label_idx = problem_id - 1 + offset
        if circular_buffer_size is not None:
            label_idx %= circular_buffer_size

        labels_row = df.iloc[label_idx]
        left_label = labels_row["Left-side Rule"]
        right_label = labels_row["Right-side Rule"]

        # NOTE(review): the meaning of the empty third argument is defined by
        # BongardResolveAttempt.add_solution — confirm there.
        resolve_attempt.add_solution(
            problem_id,
            f"LEFT: {left_label}, RIGHT: {right_label}",
            "",
        )

    return resolve_attempt
