import os
from typing import List

from src.bongard_problems.classes import ClassificationAttempt
from src.evaluation.model import gather_side_images
from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.prompting_techniques.common import (
    get_concepts,
    parse_json_from_model_response,
)
from src.prompting_techniques.prompt import CommonPrompts, ImageToSidePrompts

# Follow-up prompt sent in the same model context after the initial
# side-classification question. ``{classes}`` is filled via ``str.format`` with
# the newline-joined list of candidate concepts. The JSON example shows the
# expected reply shape: a "first" and a "second" entry, each with "answer",
# "concept" and "explanation". Literal braces are doubled so ``str.format``
# leaves them intact.
MULTICLASSIFICATION_QUESTION = """
Provided problem presents one of the following concepts:

{classes}

Pick one of the concepts above that best matches the provided problem and enhance your previous response.
Warning: If none of the concepts match your previous reasoning, you might reconsider your previous response.
Use the following format:

{{
    "first": {{
        "answer": "RIGHT",
        "concept": "compact vs elongated",
        "explanation": "The query image shows an elongated object, similar to all the images on the right side, which feature bottles. The images on the left side, on the other hand, feature compact and round apples."
    }},
    "second": {{
        "answer": "LEFT",
        "concept": "compact vs elongated",
        "explanation": "The query image shows a compact and round object, similar to all the images on the left side, which feature apples. The images on the right side, on the other hand, feature elongated bottles."
    }}
}}
"""


def resolve_bongard_with_last_image_to_side_multiclassification_two_sides_at_once(
    problem_ids: List[int],
    splitted_data_path: str,
    n_classes: int,
    labels_file: str,
    output_file: str,
    model: LLMMessenger,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = ImageToSidePrompts.QUESTION,
    multiclassification_question: str = MULTICLASSIFICATION_QUESTION,
    reevaluate: bool = False,
):
    """Classify the held-out last image of each side for every listed problem.

    For every problem, the last left-side image and the last right-side image
    are removed from their sides. They become the two query images, and the
    remaining images form the problem shown to the model. Problems whose two
    query images already have a solution in the attempt are skipped.

    Args:
        problem_ids: Problems to resolve. The ground-truth concept is taken as
            ``concepts[problem_id - 1]``, so ids presumably are 1-based
            positions in the labels file (NOTE(review): confirm).
        splitted_data_path: Data root passed to ``gather_side_images``.
        n_classes: Number of concepts requested from ``get_concepts``.
        labels_file: File the concepts are read from.
        output_file: Results file. It is loaded when present (unless
            ``reevaluate``) and is the save target for new results.
        model: Messenger used to query the LLM.
        problem_description: Prompt describing the task.
        question: Prompt asking which side each query image belongs to.
        multiclassification_question: Follow-up prompt listing candidate
            concepts; must contain a ``{classes}`` placeholder.
        reevaluate: When True, ignore an existing ``output_file`` and start a
            fresh attempt.
    """
    attempt = ClassificationAttempt(model_name=model.get_name())
    previous_results_exist = os.path.exists(output_file)
    if previous_results_exist and not reevaluate:
        attempt.load(output_file)

    concepts = get_concepts(n_classes, labels_file)
    classes_listing = "\n".join(concepts)

    for problem_id in problem_ids:
        left_side = gather_side_images(problem_id, splitted_data_path, "left", all=True)
        right_side = gather_side_images(
            problem_id, splitted_data_path, "right", all=True
        )

        # The final image of each side is held out as a query image.
        left_query, right_query = left_side[-1], right_side[-1]

        already_solved = all(
            attempt.has_solution(query.image_path)
            for query in (left_query, right_query)
        )
        if already_solved:
            continue

        classify_picked_images_to_sides(
            problem_id=problem_id,
            concept=concepts[problem_id - 1],
            left_picked_image_path=left_query.image_path,
            right_picked_image_path=right_query.image_path,
            left_side_images=left_side[:-1],
            right_side_images=right_side[:-1],
            attempt=attempt,
            model=model,
            output_file=output_file,
            problem_description=problem_description,
            question=question,
            multiclassification_question=multiclassification_question,
            available_classes_text=classes_listing,
        )


def classify_picked_images_to_sides(
    problem_id: int,
    concept: str,
    left_picked_image_path: str,
    right_picked_image_path: str,
    left_side_images: List[Content],
    right_side_images: List[Content],
    attempt: ClassificationAttempt,
    model: LLMMessenger,
    output_file: str,
    problem_description: str,
    question: str,
    multiclassification_question: str,
    available_classes_text: str,
):
    """Ask the model to place two query images on sides, then grade and save.

    The model is queried in two turns inside one context. The first turn shows
    the problem and both query images. The second turn lists the candidate
    concepts and asks for a JSON reply. The reply's ``"first"`` entry is
    graded against ``left_picked_image_path`` (expected ``"LEFT"``) and its
    ``"second"`` entry against ``right_picked_image_path`` (expected
    ``"RIGHT"``). The attempt is then saved to ``output_file``.

    Only images that do not yet have a solution in ``attempt`` are recorded.
    If both already have one, the model is not queried at all.

    Args:
        problem_id: Id of the problem being solved.
        concept: Ground-truth concept the predicted concept is graded against.
        left_picked_image_path: Query image taken from the left side.
        right_picked_image_path: Query image taken from the right side.
        left_side_images: Remaining left-side images shown as the problem.
        right_side_images: Remaining right-side images shown as the problem.
        attempt: Accumulator for solutions and evaluations.
        model: Messenger used to query the LLM.
        output_file: Where ``attempt`` is saved afterwards.
        problem_description: Prompt describing the task.
        question: Prompt asking which side each query image belongs to.
        multiclassification_question: Follow-up prompt with a ``{classes}``
            placeholder.
        available_classes_text: Candidate concepts substituted for
            ``{classes}``.
    """
    left_done = attempt.has_solution(left_picked_image_path)
    right_done = attempt.has_solution(right_picked_image_path)
    # The caller only skips a problem when BOTH images are solved, so this
    # guard must agree. Checking just the left image would leave a lone
    # unsolved right image unclassified forever.
    if left_done and right_done:
        return

    model.open_context(f"{problem_id}_{concept}")
    try:
        # First turn: the reply is not used directly. The follow-up prompt asks
        # the model to "enhance your previous response", so it presumably
        # relies on the open context retaining this exchange.
        model.ask(
            [
                TextContent(problem_description),
                TextContent(question),
                TextContent("Bongard problem:"),
                TextContent("Left side images:"),
                *left_side_images,
                TextContent("Right side images:"),
                *right_side_images,
                TextContent("First image query:"),
                ImageContent(left_picked_image_path),
                TextContent("Second image query:"),
                ImageContent(right_picked_image_path),
            ]
        )

        model_response = model.ask(
            [
                TextContent(
                    multiclassification_question.format(classes=available_classes_text)
                ),
            ]
        )
    finally:
        # Always release the context, even if a request fails.
        model.close_context()

    json_response = parse_json_from_model_response(model_response)

    if not left_done:
        add_and_evaluate_solution(
            problem_id,
            concept,
            left_picked_image_path,
            attempt,
            json_response["first"],
            "LEFT",
        )

    if not right_done:
        add_and_evaluate_solution(
            problem_id,
            concept,
            right_picked_image_path,
            attempt,
            json_response["second"],
            "RIGHT",
        )

    attempt.save(output_file)


def add_and_evaluate_solution(
    problem_id: int,
    concept: str,
    image_path: str,
    attempt: "ClassificationAttempt",
    json_response: dict,
    expected_answer: str,
):
    """Record the model's answer for one query image and grade it.

    The raw ``answer``, ``explanation`` and ``concept`` values from
    ``json_response`` are stored via ``attempt.add_solution``. The side
    answer is graded against ``expected_answer`` and the predicted concept
    against ``concept``. Each grade is recorded on ``attempt`` as ``"OK"`` or
    ``"WRONG"``.

    Comparisons ignore surrounding whitespace and letter case, because model
    output (and label text) may carry stray spaces or different
    capitalisation. A non-string value is stringified rather than crashing
    after the solution has already been recorded.

    Args:
        problem_id: Id of the problem the image belongs to.
        concept: Ground-truth concept of the problem.
        image_path: Query image the answer refers to.
        attempt: Accumulator receiving the solution and its evaluations.
        json_response: Parsed model reply for this image, with keys
            ``"answer"``, ``"explanation"`` and ``"concept"``.
        expected_answer: Correct side, e.g. ``"LEFT"`` or ``"RIGHT"``.

    Raises:
        KeyError: If ``json_response`` lacks one of the required keys.
    """
    answer = json_response["answer"]
    explanation = json_response["explanation"]
    predicted_concept = json_response["concept"]

    attempt.add_solution(
        problem_id,
        answer,
        image_path,
        explanation,
        predicted_concept,
    )

    answer_matches = str(answer).strip().upper() == expected_answer.strip().upper()
    attempt.evaluate_solution(image_path, "OK" if answer_matches else "WRONG")

    concept_matches = (
        str(predicted_concept).strip().casefold() == concept.strip().casefold()
    )
    attempt.evaluate_concept(image_path, "OK" if concept_matches else "WRONG")
