import os
from typing import List

from tqdm import tqdm

from src.bongard_problems.classes import BongardResolveAttempt
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.prompting_techniques.prompt import CommonPrompts, ContrastiveIterativePrompts


def resolve_bongard_with_iterative_contrastive_prompting(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = ContrastiveIterativePrompts.QUESTION,
    final_question: str = ContrastiveIterativePrompts.FINAL_QUESTION,
    reevaluate: bool = False,
    num_comparisons: int = 6,
):
    """Solve Bongard problems by showing the model left/right image pairs one at a time.

    For each problem id a model context named ``problem_<id>`` is opened, the
    problem description and ``question`` are sent, then ``num_comparisons``
    left/right image pairs (square numbers ``0 .. num_comparisons - 1``) are sent
    via :func:`get_comparison`, and finally ``final_question`` is asked. The
    answer to the final question is stored in a ``BongardResolveAttempt`` and
    written to ``output_file`` after every successfully solved problem.

    Args:
        problem_ids: Ids of the problems to solve.
        splitted_data_path: Root directory passed to :func:`get_comparison`
            (images are looked up under ``<root>/<id>/left/<n>`` and
            ``<root>/<id>/right/<n>``).
        output_file: File that existing results are loaded from (unless
            ``reevaluate``) and that results are saved to.
        model: Messenger used to talk to the LLM.
        problem_description: First prompt sent for each problem.
        question: Prompt sent together with ``problem_description``.
        final_question: Prompt whose answer is recorded as the solution.
        reevaluate: If True, ``output_file`` is not loaded and every problem is
            solved again; saved results then contain only this run's answers
            (assuming ``save`` overwrites the file — confirm).
        num_comparisons: Number of left/right image pairs shown per problem.

    A failure on one problem is reported and the loop continues with the next.
    """
    resolve_attempt = BongardResolveAttempt(model_name=model.get_name())
    if os.path.exists(output_file) and not reevaluate:
        resolve_attempt.load(output_file)

    for problem_id in tqdm(problem_ids):
        if resolve_attempt.has_solution(problem_id) and not reevaluate:
            continue

        try:
            model.open_context(f"problem_{problem_id}", keep_image_history=False)
        except Exception as e:
            _report_failure(f"Failed to solve problem with id: {problem_id}.", e)
            continue

        try:
            model.ask([TextContent(problem_description), TextContent(question)])

            for square_number in range(num_comparisons):
                get_comparison(
                    problem_id=problem_id,
                    splitted_data_path=splitted_data_path,
                    model=model,
                    square_number=square_number,
                )

            model_answer = model.ask([TextContent(final_question)])
            resolve_attempt.add_solution(problem_id, model_answer)
            resolve_attempt.save(output_file)
        except Exception as e:
            _report_failure(f"Failed to solve problem with id: {problem_id}.", e)
        finally:
            # Always pair open_context with close_context. Previously an
            # exception skipped the close and left the context open. Closing is
            # best-effort so a failing close cannot abort the remaining problems.
            try:
                model.close_context()
            except Exception as e:
                _report_failure(
                    f"Failed to close context for problem with id: {problem_id}.", e
                )


def _report_failure(message: str, error: Exception) -> None:
    """Print a failure message and the exception without breaking the tqdm bar."""
    tqdm.write(message)
    tqdm.write(str(error))


def get_comparison(
    problem_id: int,
    splitted_data_path: str,
    model: LLMMessenger,
    square_number: int,
) -> str:
    """Send one labelled left/right image pair to ``model`` and return its reply.

    The message is ``"LEFT IMAGE:"``, the image at
    ``<splitted_data_path>/<problem_id>/left/<square_number>``, ``"RIGHT IMAGE:"``
    and the image at ``<splitted_data_path>/<problem_id>/right/<square_number>``,
    sent in that order in a single ``model.ask`` call.

    Any exception raised while loading the images or asking the model is printed,
    and an empty string is returned instead (best-effort).
    """
    try:
        problem_dir = f"{splitted_data_path}/{problem_id}"
        message = []
        # Left pair first, then right, matching the labels the model sees.
        for label, side in (("LEFT IMAGE:", "left"), ("RIGHT IMAGE:", "right")):
            message.append(TextContent(label))
            message.append(
                ImageContent.from_basename(f"{problem_dir}/{side}/{square_number}")
            )
        return model.ask(message)
    except Exception as err:
        print(err)
        return ""
