import os
from typing import List
from tqdm import tqdm

from src.bongard_problems.classes import BongardResolveAttempt
from src.evaluation.model import gather_side_images
from src.image import get_image_filename_with_extension
from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.prompting_techniques.common import (
    get_concepts,
    parse_json_from_model_response,
)
from src.prompting_techniques.prompt import CommonPrompts


def resolve_bongard_with_direct_prompting(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = CommonPrompts.QUESTION,
    reevaluate: bool = False,
):
    """Ask ``model`` an open question about each problem's whole-board image.

    For every id in ``problem_ids`` the image ``<splitted_data_path>/<id>/whole``
    (extension resolved by ``get_image_filename_with_extension``) is sent
    together with ``problem_description`` and ``question``. The raw model
    answer is recorded and ``output_file`` is saved after every problem.

    Args:
        problem_ids: ids of the problems to solve.
        splitted_data_path: root directory holding one folder per problem.
        output_file: where results are loaded from (if present) and saved to.
        model: messenger used to query the model.
        problem_description: text sent after the image.
        question: text sent after the problem description.
        reevaluate: when True, previously saved results are not loaded and
            every problem is asked again.
    """
    resolve_attempt = BongardResolveAttempt(model_name=model.get_name())
    # Resume from an earlier run unless a full re-evaluation was requested.
    if os.path.exists(output_file) and not reevaluate:
        resolve_attempt.load(output_file)

    for problem_id in tqdm(problem_ids):
        if resolve_attempt.has_solution(problem_id) and not reevaluate:
            continue
        msg_contents: List[Content] = [
            ImageContent(
                get_image_filename_with_extension(
                    f"{splitted_data_path}/{problem_id}/whole"
                )
            ),
            TextContent(problem_description),
            TextContent(question),
        ]

        try:
            model_answer = model.ask(msg_contents)
            resolve_attempt.add_solution(problem_id, model_answer)
            # Save after every problem so progress survives an interruption.
            resolve_attempt.save(output_file)
        except Exception as e:
            # Deliberate best-effort: report the failure and continue with the
            # next problem. tqdm.write keeps the progress bar intact, whereas a
            # plain print() would garble it.
            tqdm.write(f"Failed to solve problem with id: {problem_id}.")
            tqdm.write(str(e))


# Default prompt for the classification variants. It is rendered with
# str.format(classes=...): "{classes}" is replaced by the candidate class list,
# and the doubled braces "{{" / "}}" render as literal braces in the JSON example.
CLASSIFICATION_DEFAULT_QUESTION = """
You are a visual understanding module which is designed to provide accurate answers. 
Your goal is to solve a provided Bongard problem using one of the following classes:

{classes}

What is the difference between the two sides of the image? Choose the correct class.
Remember to provide only a chosen class. Use the JSON format: 

{{
    "class": "pentagons vs circles",
    "explanation": "It's clear that all shapes on the left are pentagons, while all shapes on the right are circles. The difference is clear. Other classes are not relevant."
}}
"""


def resolve_bongard_with_direct_classification_prompting(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    labels_file: str,
    model: LLMMessenger,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = CLASSIFICATION_DEFAULT_QUESTION,
):
    """Ask ``model`` to pick a class for each problem's whole-board image.

    For every id in ``problem_ids`` a model context named ``problem_<id>`` is
    opened. The image ``<splitted_data_path>/<id>/whole`` is sent, followed by
    ``problem_description`` (if not None) and ``question``. The question's
    ``{classes}`` placeholder is filled with the newline-joined concepts
    returned by ``get_concepts``. The raw model answer is recorded and
    ``output_file`` is saved after every problem.

    Args:
        problem_ids: ids of the problems to solve.
        splitted_data_path: root directory holding one folder per problem.
        output_file: path the results are saved to. Existing content is NOT
            loaded first (presumably it is overwritten by ``save`` — confirm
            against BongardResolveAttempt).
        labels_file: file passed to ``get_concepts`` to obtain the classes.
        model: messenger used to query the model.
        problem_description: optional preamble; omitted when ``None``.
        question: prompt template containing a ``{classes}`` placeholder.
    """
    # NOTE(review): the "_splitted" variant in this file passes an int
    # (n_classes) to get_concepts, while this one passes the id list —
    # confirm that get_concepts accepts both forms.
    classes = get_concepts(problem_ids, labels_file)
    # The class list is identical for every problem, so render the prompt once.
    prompt_text = question.format(classes="\n".join(classes))
    resolve_attempt = BongardResolveAttempt(model_name=model.get_name())

    for problem_id in problem_ids:
        model.open_context(f"problem_{problem_id}")
        try:
            msg_contents: List[Content] = [
                ImageContent(
                    get_image_filename_with_extension(
                        f"{splitted_data_path}/{problem_id}/whole"
                    )
                )
            ]

            if problem_description is not None:
                msg_contents.append(TextContent(problem_description))

            msg_contents.append(TextContent(prompt_text))

            model_answer = model.ask(msg_contents)
            resolve_attempt.add_solution(problem_id, model_answer)
            resolve_attempt.save(output_file)
        finally:
            # Always close the per-problem context, even if building the
            # message, querying the model or saving raises, so the messenger
            # is not left with a dangling open context.
            model.close_context()


def resolve_bongard_with_direct_classification_prompting_splitted(
    problem_ids: List[int],
    n_classes: int,
    splitted_data_path: str,
    output_file: str,
    labels_file: str,
    model: LLMMessenger,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = CLASSIFICATION_DEFAULT_QUESTION,
):
    """Classify each problem from its separate left/right side images and score it.

    For every id in ``problem_ids`` that has a folder under
    ``splitted_data_path``, a model context named ``<id>_<expected class>`` is
    opened. The model receives ``problem_description``, the rendered
    ``question``, and the images returned by ``gather_side_images`` for the
    left and right sides. The JSON reply is parsed and its ``"class"`` and
    ``"explanation"`` are recorded. The solution is marked ``"OK"`` when the
    class equals the expected one exactly, otherwise ``"WRONG"``.
    ``output_file`` is saved after every problem.

    Args:
        problem_ids: ids of the problems to solve; used as 1-based indices
            into the concept list.
        n_classes: passed to ``get_concepts`` to select the candidate classes.
        splitted_data_path: root directory holding one folder per problem.
        output_file: path results are loaded from (if present) and saved to.
        labels_file: file passed to ``get_concepts``.
        model: messenger used to query the model.
        problem_description: text sent before the question.
        question: prompt template containing a ``{classes}`` placeholder.

    Raises:
        IndexError: if a problem id (with an existing folder) does not map to
            a concept, i.e. it is < 1 or larger than the concept list.
    """
    classes = get_concepts(n_classes, labels_file)
    # The class list is identical for every problem, so render the prompt once.
    prompt_text = question.format(classes="\n".join(classes))
    resolve_attempt = BongardResolveAttempt(model_name=model.get_name())

    # Load earlier results. Problems listed in problem_ids are still asked
    # again; solved problems are not skipped.
    if os.path.exists(output_file):
        resolve_attempt.load(output_file)

    for problem_id in problem_ids:
        # Problems without an extracted image folder are skipped silently.
        if not os.path.exists(f"{splitted_data_path}/{problem_id}"):
            continue

        # Ids are 1-based indices into the concept list. Guard ids < 1
        # explicitly: negative indexing would otherwise silently select a
        # concept from the end of the list.
        if problem_id < 1:
            raise IndexError(
                f"problem_id {problem_id} does not map to a concept (ids are 1-based)"
            )
        expected_class = classes[problem_id - 1]

        model.open_context(f"{problem_id}_{expected_class}")
        try:
            msg_contents: List[Content] = [
                TextContent(problem_description),
                TextContent(prompt_text),
                TextContent("Left side images:"),
                *gather_side_images(problem_id, splitted_data_path, "left"),
                TextContent("Right side images:"),
                *gather_side_images(problem_id, splitted_data_path, "right"),
            ]

            model_answer = model.ask(msg_contents)
            model_answer_json = parse_json_from_model_response(model_answer)
            predicted_class = model_answer_json["class"]
            resolve_attempt.add_solution(
                problem_id,
                predicted_class,
                model_answer_json["explanation"],
            )
            resolve_attempt.evaluate_solution(
                problem_id,
                "OK" if expected_class == predicted_class else "WRONG",
            )
            resolve_attempt.save(output_file)
        finally:
            # Always close the per-problem context, even if the query, JSON
            # parsing or saving raises, so the messenger is not left with a
            # dangling open context.
            model.close_context()
