import os
import random
from typing import List

from tqdm import tqdm

from src.bongard_problems.classes import ClassificationAttempt
from src.evaluation.model import gather_side_images
from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.postprocessing import fix_json
from src.prompting_techniques.common import parse_json_from_model_response
from src.prompting_techniques.prompt import CommonPrompts, ImageToSidePrompts


def resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = ImageToSidePrompts.QUESTION,
    reevaluate: bool = False,
    do_hallucination_intervention: bool = False,
    use_joint_side_image: bool = False,
    do_shuffle_test_panels: bool = False,
):
    """Classify the last left and last right image of each problem in one query.

    For every id in ``problem_ids`` the last image of each side is picked as a
    test image, and ``classify_picked_images_to_sides`` is asked to assign both
    test images to a side in a single model call.

    Args:
        problem_ids: Ids of the problems to process; each id names a
            sub-directory of ``splitted_data_path``.
        splitted_data_path: Root directory holding one folder per problem.
        output_file: File the attempt is loaded from (when it exists and
            ``reevaluate`` is false) and saved to by the classification step.
        model: Messenger used to query the model; its name labels the attempt.
        problem_description: Introductory prompt text sent to the model.
        question: Question prompt text sent to the model.
        reevaluate: When true, previous results in ``output_file`` are not
            loaded and already-solved images are classified again.
        do_hallucination_intervention: Forwarded to the classification step.
        use_joint_side_image: When true, each side is represented by the single
            image built by ``ImageContent.from_basename`` on the side directory
            (presumably a pre-rendered joint panel -- verify against
            ``ImageContent``) and the test image is the alphabetically last
            file in that directory. Otherwise the side images come from
            ``gather_side_images`` and the last one is held out as the test
            image.
        do_shuffle_test_panels: Forwarded to the classification step.
    """
    attempt = ClassificationAttempt(model_name=model.get_name())
    resume_from_disk = os.path.exists(output_file) and not reevaluate
    if resume_from_disk:
        attempt.load(output_file)

    for problem_id in tqdm(problem_ids):
        if not use_joint_side_image:
            # Individual panels: hold out the last image of each side as the
            # test image and show the remaining ones as examples.
            all_left = gather_side_images(
                problem_id, splitted_data_path, "left", all=True
            )
            all_right = gather_side_images(
                problem_id, splitted_data_path, "right", all=True
            )
            last_left_image, left_images = all_left[-1], all_left[:-1]
            last_right_image, right_images = all_right[-1], all_right[:-1]
        else:
            # Joint mode: one image per side, test image taken from the
            # side's directory listing.
            problem_dir = os.path.join(splitted_data_path, str(problem_id))
            left_dir = os.path.join(problem_dir, "left")
            right_dir = os.path.join(problem_dir, "right")
            left_images = [ImageContent.from_basename(left_dir)]
            right_images = [ImageContent.from_basename(right_dir)]
            left_names = sorted(os.listdir(left_dir))
            last_left_image = ImageContent(os.path.join(left_dir, left_names[-1]))
            right_names = sorted(os.listdir(right_dir))
            last_right_image = ImageContent(
                os.path.join(right_dir, right_names[-1])
            )

        classify_picked_images_to_sides(
            problem_id=problem_id,
            left_picked_image_path=last_left_image.image_path,
            right_picked_image_path=last_right_image.image_path,
            left_side_images=left_images,
            right_side_images=right_images,
            attempt=attempt,
            model=model,
            output_file=output_file,
            problem_description=problem_description,
            question=question,
            do_hallucination_intervention=do_hallucination_intervention,
            reevaluate=reevaluate,
            do_shuffle_test_panels=do_shuffle_test_panels,
        )


# Keys every per-image answer object must carry before it can be recorded.
_REQUIRED_ANSWER_KEYS = ("answer", "explanation", "concept")


def _extract_slot_response(json_response: dict, slot: str) -> dict:
    """Return the validated answer object stored under *slot* ("first"/"second").

    The slot key is matched exactly first and then case-insensitively, so
    "first", "FIRST" and "First" are all accepted.

    Raises:
        KeyError: if the slot or one of the required answer keys is missing.
        TypeError: if the slot entry is not a dict or its answer is not a string.
    """
    if slot in json_response:
        slot_response = json_response[slot]
    else:
        wanted = slot.casefold()
        for key, value in json_response.items():
            if isinstance(key, str) and key.casefold() == wanted:
                slot_response = value
                break
        else:
            raise KeyError(slot)

    if not isinstance(slot_response, dict):
        raise TypeError(f"'{slot}' entry is not a JSON object: {slot_response!r}")
    missing = [key for key in _REQUIRED_ANSWER_KEYS if key not in slot_response]
    if missing:
        raise KeyError(f"'{slot}' entry is missing keys: {missing}")
    if not isinstance(slot_response["answer"], str):
        raise TypeError(f"'{slot}' answer is not a string")
    return slot_response


def classify_picked_images_to_sides(
    problem_id: int,
    left_picked_image_path: str,
    right_picked_image_path: str,
    left_side_images: List[Content],
    right_side_images: List[Content],
    attempt: ClassificationAttempt,
    model: LLMMessenger,
    output_file: str,
    problem_description: str,
    question: str,
    do_hallucination_intervention: bool = False,
    reevaluate: bool = False,
    do_shuffle_test_panels: bool = False,
):
    """Ask the model to assign two held-out test images to a side and record it.

    The model is shown the example images of both sides followed by the two
    test images (one truly from the left side, one from the right) and is
    expected to answer with a JSON object holding a "first" and a "second"
    entry, each with "answer", "explanation" and "concept".

    Args:
        problem_id: Id of the problem, stored with each solution.
        left_picked_image_path: Path of the test image whose true side is LEFT.
        right_picked_image_path: Path of the test image whose true side is RIGHT.
        left_side_images: Example content shown under "Left images:".
        right_side_images: Example content shown under "Right images:".
        attempt: Attempt object that solutions and evaluations are written to.
        model: Messenger used to query the model.
        output_file: File the attempt is saved to after recording.
        problem_description: Introductory prompt text.
        question: Question prompt text.
        do_hallucination_intervention: When true, the raw response is passed
            through ``fix_json`` before parsing.
        reevaluate: When true, images that already have a solution are
            classified and recorded again.
        do_shuffle_test_panels: When true, the presentation order of the two
            test images is swapped with probability 0.5.

    Any exception raised while querying, parsing or recording is caught and
    printed so that one failing problem does not abort the whole run.
    """
    test_image_paths = [left_picked_image_path, right_picked_image_path]
    expected_answers = ["LEFT", "RIGHT"]

    # Skip only when BOTH test images are already solved. Checking just the
    # left image would leave a missing right-image result unfilled on resume.
    if not reevaluate and all(
        attempt.has_solution(path) for path in test_image_paths
    ):
        return

    first, second = 0, 1
    if do_shuffle_test_panels and random.random() < 0.5:
        first, second = second, first

    try:
        model_response = model.ask(
            [
                TextContent(problem_description),
                TextContent(question),
                TextContent("Bongard problem:"),
                TextContent("Left images:"),
                *left_side_images,
                TextContent("Right images:"),
                *right_side_images,
                TextContent("First test image:"),
                ImageContent(test_image_paths[first]),
                TextContent("Second test image:"),
                ImageContent(test_image_paths[second]),
            ]
        )

        if do_hallucination_intervention:
            model_response = fix_json(model_response)

        json_response = parse_json_from_model_response(model_response)

        # Resolve and validate both entries BEFORE touching ``attempt`` so a
        # malformed response cannot leave only one of the two images recorded.
        slot_responses = [
            _extract_slot_response(json_response, "first"),
            _extract_slot_response(json_response, "second"),
        ]

        for index, slot_response in zip((first, second), slot_responses):
            image_path = test_image_paths[index]
            # Keep an existing result untouched unless re-evaluation was asked.
            if not reevaluate and attempt.has_solution(image_path):
                continue
            add_and_evaluate_solution(
                problem_id,
                image_path,
                attempt,
                slot_response,
                expected_answers[index],
                reevaluate=reevaluate,
            )

        attempt.save(output_file)

    except Exception as e:
        # Deliberate best-effort boundary: report and move on to the next
        # problem instead of aborting the evaluation run.
        print(f"problem_id: {problem_id} - Failed to parse json from model response:")
        print(f"Exception: {e}")


def add_and_evaluate_solution(
    problem_id: int,
    picked_image_path: str,
    attempt: ClassificationAttempt,
    json_response: dict,
    expected_answer: str,
    reevaluate: bool = False,
):
    """Record one model answer on *attempt* and grade it against the expected side.

    Args:
        problem_id: Id of the problem the test image belongs to.
        picked_image_path: Path of the classified test image; used as the
            identifier passed to both ``add_solution`` and ``evaluate_solution``.
        attempt: Attempt object receiving the solution and its evaluation.
        json_response: Mapping with the keys "answer", "explanation" and
            "concept" describing the model's answer for this image.
        expected_answer: Expected side label; the answer is graded "OK" when
            this label occurs anywhere in the upper-cased answer, else "WRONG".
        reevaluate: Forwarded to ``attempt.evaluate_solution``.

    Raises:
        KeyError: if ``json_response`` lacks one of the required keys.
    """
    answer = json_response["answer"]
    explanation = json_response["explanation"]
    concept = json_response["concept"]
    attempt.add_solution(problem_id, answer, picked_image_path, explanation, concept)

    # Substring match on purpose: answers such as "left side" still count.
    if expected_answer in answer.upper():
        verdict = "OK"
    else:
        verdict = "WRONG"

    attempt.evaluate_solution(
        picked_image_path, verdict, author="oracle", reevaluate=reevaluate
    )
