import json
import os
from typing import List
from tqdm import tqdm

from src.bongard_problems.classes import BongardResolveAttempt
from src.image import get_image_filename_with_extension
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.prompting_techniques.prompt import CommonPrompts, ContrastivePrompts


def resolve_bongard_with_contrastive_prompting(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    comparisons_file: str,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = ContrastivePrompts.QUESTION,
    with_image: bool = False,
    reevaluate: bool = False,
) -> None:
    """Ask the model to solve each problem using precomputed comparisons.

    For every problem the prompt consists of the problem description,
    optionally the problem image, and ``question`` formatted with the
    problem's comparisons joined by newlines. Each answer is added to a
    ``BongardResolveAttempt`` which is saved to ``output_file`` right after
    every successful answer.

    Args:
        problem_ids: IDs of the problems to solve.
        splitted_data_path: Directory holding one sub-directory per problem;
            the image is looked up from ``<path>/<problem_id>/whole``.
        output_file: File the resolve attempt is loaded from and saved to.
        model: Messenger used to query the model.
        comparisons_file: JSON file mapping ``str(problem_id)`` to a mapping
            whose values are the comparison texts.
        problem_description: Text placed first in the prompt.
        question: Template formatted with a ``comparisons`` keyword.
        with_image: When true, the problem image is included in the prompt.
        reevaluate: When true, existing results in ``output_file`` are not
            loaded and every problem is asked again.

    Raises:
        KeyError: If a problem to be solved has no entry in
            ``comparisons_file``.
    """
    with open(comparisons_file, "r") as f:
        comparisons_dictionary: dict = json.load(f)

    resolve_attempt = BongardResolveAttempt(model_name=model.get_name())
    if os.path.exists(output_file) and not reevaluate:
        resolve_attempt.load(output_file)

    for problem_id in tqdm(problem_ids):
        if resolve_attempt.has_solution(problem_id) and not reevaluate:
            continue
        comparisons = comparisons_dictionary[str(problem_id)]
        comparisons_list = "\n".join(comparisons.values())

        # Build the prompt incrementally so that no ``None`` placeholder is
        # passed to the model when the image is not requested.
        msg_contents = [TextContent(problem_description)]
        if with_image:
            msg_contents.append(
                ImageContent(
                    get_image_filename_with_extension(
                        f"{splitted_data_path}/{problem_id}/whole"
                    )
                )
            )
        msg_contents.append(
            TextContent(question.format(comparisons=comparisons_list))
        )

        try:
            model_answer = model.ask(msg_contents)
            resolve_attempt.add_solution(problem_id, model_answer)
            # Save after every answer so an interrupted run can be resumed.
            resolve_attempt.save(output_file)
        except Exception as e:
            # Best effort: report the failure and carry on with the next
            # problem; it stays unsolved and is retried on the next run.
            print(f"Failed to solve problem with id: {problem_id}.")
            print(e)


def get_comparisons(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = ContrastivePrompts.COMPARE_IMAGES_PROMPT,
    num_comparisons: int = 6,
    reevaluate: bool = False,
) -> None:
    """Collect left/right image comparisons for each problem into a JSON file.

    The output maps ``str(problem_id)`` to a mapping from ``str(index)`` to
    the comparison returned by ``get_comparison`` for that index. The file
    is rewritten after every comparison so progress survives interruption.

    On a rerun, comparisons already stored with a non-empty value are kept
    and only the missing or empty ones are requested. Empty values are what
    ``get_comparison`` returns on failure, so failed or interrupted problems
    are completed instead of being skipped.

    Args:
        problem_ids: IDs of the problems to process.
        splitted_data_path: Directory holding one sub-directory per problem.
        output_file: JSON file the comparisons are loaded from and saved to.
        model: Messenger used to query the model.
        problem_description: Text placed first in each prompt.
        question: Instruction text sent with each image pair.
        num_comparisons: Number of image pairs per problem, indexed from 0.
        reevaluate: When true, stored comparisons for the given problems are
            discarded and requested again.
    """
    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        with open(output_file, "r") as file:
            comparisons_dictionary = json.load(file)
    else:
        comparisons_dictionary = {}

    for problem_id in tqdm(problem_ids):
        key = str(problem_id)
        # Start from scratch when re-evaluating or when nothing usable is
        # stored for this problem yet.
        if reevaluate or not isinstance(comparisons_dictionary.get(key), dict):
            comparisons_dictionary[key] = {}
        problem_comparisons = comparisons_dictionary[key]

        for i in range(num_comparisons):
            # Keep answers from earlier runs; retry missing or empty ones.
            if problem_comparisons.get(str(i)):
                continue

            comparison = get_comparison(
                problem_id=problem_id,
                splitted_data_path=splitted_data_path,
                model=model,
                description=problem_description,
                question=question,
                square_number=i,
            )

            problem_comparisons[str(i)] = comparison

            # Checkpoint after every comparison.
            with open(output_file, "w") as f:
                json.dump(comparisons_dictionary, f, indent=4)


def get_comparison(
    problem_id: int,
    splitted_data_path: str,
    model: LLMMessenger,
    square_number: int,
    description: str,
    question: str,
):
    """Ask the model to compare one left/right image pair of a problem.

    The prompt is the description, the question, and then each side's label
    followed by its image, loaded from
    ``<splitted_data_path>/<problem_id>/<side>/<square_number>``.

    Returns:
        The model's answer, or an empty string if building the prompt or
        asking the model raised an exception (the error is printed).
    """
    try:
        problem_dir = f"{splitted_data_path}/{problem_id}"
        prompt_parts = [TextContent(description), TextContent(question)]
        for label, side in (("LEFT IMAGE:", "left"), ("RIGHT IMAGE:", "right")):
            prompt_parts.append(TextContent(label))
            prompt_parts.append(
                ImageContent.from_basename(f"{problem_dir}/{side}/{square_number}")
            )
        answer = model.ask(prompt_parts)
    except Exception as error:
        # Best effort: a failed comparison yields an empty answer.
        print(f"Failed to solve problem with id: {problem_id}.")
        print(error)
        answer = ""
    return answer
