import json
import os
from typing import List

from tqdm import tqdm

from src.bongard_problems.classes import BongardResolveAttempt
from src.image import is_image_supported
from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.prompting_techniques.prompt import CommonPrompts, DescriptiveIterativePrompts


def _final_side_description(side_descriptions: dict) -> str:
    """Return the summary description stored for one side of a problem.

    Prefers the explicit ``"final"`` entry (the key used for the closing
    summary by ``get_side_descriptions``); for files that do not use that
    key, falls back to the last-inserted value as before.

    Raises:
        IndexError: if ``side_descriptions`` is empty and has no "final" key.
    """
    if "final" in side_descriptions:
        return side_descriptions["final"]
    return list(side_descriptions.values())[-1]


def resolve_bongard_with_iterative_prompting(
    problem_ids: List[int],
    output_file: str,
    model: LLMMessenger,
    descriptions_file: str,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = DescriptiveIterativePrompts.QUESTION,
    reevaluate: bool = False,
):
    """Ask ``model`` to solve Bongard problems from stored side descriptions.

    For every problem, the left and right summaries are read from
    ``descriptions_file`` and substituted into ``question``. The result is
    sent to the model after ``problem_description``. Each answer is added to a
    ``BongardResolveAttempt``, which is saved to ``output_file`` after every
    problem.

    Args:
        problem_ids: Ids of the problems to solve.
        output_file: Path the resolve attempt is loaded from and saved to.
        model: Messenger used to query the LLM.
        descriptions_file: JSON file mapping ``str(problem_id)`` to
            ``{"left": {...}, "right": {...}}`` (the layout written by
            ``dump_description``).
        problem_description: Text sent as the first message of each query.
        question: Format string with ``{left_description}`` and
            ``{right_description}`` placeholders.
        reevaluate: If True, existing results are neither loaded nor skipped,
            so every problem is solved again and ``output_file`` ends up
            containing only the solutions from this run.
    """
    resolve_attempt = BongardResolveAttempt(model_name=model.get_name())
    # Resume from earlier results unless everything is being recomputed.
    if os.path.exists(output_file) and not reevaluate:
        resolve_attempt.load(output_file)

    with open(descriptions_file, "r") as file:
        description_dictionary = json.load(file)

    for problem_id in tqdm(problem_ids):
        if resolve_attempt.has_solution(problem_id) and not reevaluate:
            continue

        try:
            problem_descriptions = description_dictionary[str(problem_id)]
            # Pick the "final" summary by key: dump_description updates
            # existing keys in place, so a panel key added on a later run
            # would sit after "final" and be returned by a plain [-1].
            left_description = _final_side_description(problem_descriptions["left"])
            right_description = _final_side_description(problem_descriptions["right"])

            msg_contents: List[Content] = [
                TextContent(problem_description),
                TextContent(
                    question.format(
                        left_description=left_description,
                        right_description=right_description,
                    )
                ),
            ]

            model_answer = model.ask(msg_contents)
            resolve_attempt.add_solution(problem_id, model_answer)
            # Persist after each problem so an interrupted run can resume.
            resolve_attempt.save(output_file)
        except Exception as e:
            # Best effort: report the failure and continue with the rest.
            print(f"Failed to solve problem with id: {problem_id}.")
            print(e)


def _side_is_described(problem_entry: dict, side: str) -> bool:
    """Return True if ``side`` of a stored problem entry has a "final" summary."""
    return "final" in problem_entry.get(side, {})


def get_iterative_descriptions(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    task_description: str = DescriptiveIterativePrompts.DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,
    final_answer_prompt: str = DescriptiveIterativePrompts.FINAL_ANSWER_PROMPT,
    keep_image_history: bool = True,
    num_panels: int = 6,
    reevaluate: bool = False,
):
    """Describe both sides of each Bongard problem and record them in ``output_file``.

    The panel images for a side are read from
    ``<splitted_data_path>/<problem_id>/<side>``. They are described by
    ``get_side_descriptions``, which writes each result to ``output_file`` as
    it goes. A side counts as done only once it has a ``"final"`` summary.
    This means a problem interrupted half-way is resumed on the next run
    instead of being skipped because a partial entry exists.

    Args:
        problem_ids: Ids of the problems to describe.
        splitted_data_path: Root directory holding ``<id>/left`` and ``<id>/right``.
        output_file: JSON file that descriptions are written to.
        model: Messenger used to query the LLM.
        task_description: Prompt that opens each side's conversation.
        final_answer_prompt: Prompt asking for the side's final summary.
        keep_image_history: Forwarded to ``model.open_context``.
        num_panels: Maximum number of images described per side.
        reevaluate: If True, describe every side again even if already done.
    """
    description_dictionary = {}
    # Treat an empty file like a missing one, consistent with dump_description.
    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        with open(output_file, "r") as file:
            description_dictionary = json.load(file)

    for problem_id in tqdm(problem_ids):
        problem_entry = description_dictionary.get(str(problem_id), {})
        try:
            for side in ("left", "right"):
                if not reevaluate and _side_is_described(problem_entry, side):
                    continue
                get_side_descriptions(
                    problem_id,
                    os.path.join(splitted_data_path, str(problem_id), side),
                    model,
                    task_description,
                    side,
                    output_file,
                    final_answer_prompt,
                    keep_image_history,
                    num_panels,
                )
        except Exception as e:
            # Best effort: report the failure and move on to the next problem.
            print(f"Failed to get side descriptions for problem with id: {problem_id}.")
            print(e)


def get_side_descriptions(
    problem_id: int,
    splitted_data_path: str,
    model: LLMMessenger,
    task_description: str,
    side: str,
    output_file: str,
    final_answer_prompt: str = DescriptiveIterativePrompts.FINAL_ANSWER_PROMPT,
    keep_image_history: bool = True,
    num_panels: int = 6,
):
    """Describe the images of one side of a problem, then ask for a summary.

    Opens a model context named ``problem_<id>_<side>`` and sends
    ``task_description``. It then sends each supported image in
    ``splitted_data_path`` (sorted by filename, at most ``num_panels`` of
    them) and finally asks ``final_answer_prompt``. Every answer is written to
    ``output_file`` via ``dump_description``: image descriptions are keyed by
    filename, and the summary is keyed ``"final"``.

    Args:
        problem_id: Id of the problem being described.
        splitted_data_path: Directory containing this side's panel images.
        model: Messenger used to query the LLM.
        task_description: Prompt that opens the conversation.
        side: Side name used as the key in the output file (e.g. "left").
        output_file: JSON file the descriptions are written to.
        final_answer_prompt: Prompt asking for the side's final summary.
        keep_image_history: Forwarded to ``model.open_context``.
        num_panels: Maximum number of images to describe.
    """
    # Filter before slicing, so non-image entries (e.g. ".DS_Store", which
    # sorts before digit-named files) do not use up panel slots.
    image_filenames = [
        filename
        for filename in sorted(os.listdir(splitted_data_path))
        if is_image_supported(filename)
    ][:num_panels]

    model.open_context(f"problem_{problem_id}_{side}", keep_image_history)
    try:
        model.ask([TextContent(task_description)])

        for filename in image_filenames:
            image_path = os.path.join(splitted_data_path, filename)
            description = get_description(image_path, model)
            dump_description(side, description, filename, problem_id, output_file)

        final_deduction = model.ask([TextContent(final_answer_prompt)])
        dump_description(side, final_deduction, "final", problem_id, output_file)
    finally:
        # Always pair close_context with open_context, even when a model call
        # or file write raises.
        model.close_context()


def dump_description(
    side: str,
    description: str,
    description_key: str,
    problem_id: int,
    output_file: str,
):
    """Store one description in the JSON file ``output_file``.

    The file holds ``{str(problem_id): {side: {description_key: description}}}``.
    New problems start as ``{"left": {}, "right": {}}``. A missing or empty
    file is treated as an empty mapping. The whole file is read, updated and
    rewritten on every call.

    The new content is written to ``<output_file>.tmp`` first and then moved
    over ``output_file`` with ``os.replace``. An interrupted write therefore
    leaves the previous file intact instead of a truncated JSON document,
    which would make every later ``json.load`` fail.

    Args:
        side: Side of the problem (e.g. "left" or "right").
        description: Text to store.
        description_key: Key within the side (an image filename or "final").
        problem_id: Id of the problem; stored as a string key.
        output_file: Path of the JSON file to update.
    """
    problem_id_str = str(problem_id)

    if not os.path.exists(output_file) or os.path.getsize(output_file) == 0:
        description_dictionary = {}
    else:
        with open(output_file, "r") as file:
            description_dictionary = json.load(file)

    problem_entry = description_dictionary.setdefault(
        problem_id_str, {"left": {}, "right": {}}
    )
    # setdefault on the side as well, so an existing entry that lacks this
    # side does not raise KeyError.
    problem_entry.setdefault(side, {})[description_key] = description

    temporary_file = f"{output_file}.tmp"
    with open(temporary_file, "w") as file:
        json.dump(description_dictionary, file, indent=4)
    os.replace(temporary_file, output_file)


def get_description(
    image_path: str,
    model: LLMMessenger,
) -> str:
    """Send a single image to ``model`` and return its answer.

    Any exception raised by ``model.ask`` is printed and an empty string is
    returned instead. This is best-effort, so a single failing image does not
    abort the caller.

    Args:
        image_path: Path of the image to send.
        model: Messenger used to query the LLM.

    Returns:
        The model's answer, or ``""`` if the request failed.
    """
    image_message = [ImageContent(image_path)]

    try:
        return model.ask(image_message)
    except Exception as error:
        print(error)
    return ""
