import json
import os
from typing import List
from tqdm import tqdm

from src.bongard_problems.classes import BongardResolveAttempt
from src.image import is_image_supported
from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.prompting_techniques.prompt import CommonPrompts, DescriptivePrompts


def resolve_bongard_with_descriptive_prompting(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    descriptions_file: str,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = DescriptivePrompts.QUESTION,
    with_image: bool = False,
    reevaluate: bool = False,
    num_panels: int = 6,
):
    """Ask ``model`` to solve Bongard problems from textual panel descriptions.

    For every problem id the stored descriptions of the left and right panels
    are inserted into ``question`` and sent to the model, optionally together
    with the image of the whole problem. Each answer is added to a
    ``BongardResolveAttempt`` which is saved to ``output_file`` right after
    the answer arrives, so an interrupted run keeps its progress.

    Args:
        problem_ids: Ids of the problems to solve.
        splitted_data_path: Directory with one sub-directory per problem id;
            only used to locate ``<problem_id>/whole`` when ``with_image`` is
            set.
        output_file: File the resolve attempt is loaded from and saved to.
        model: Messenger used to query the LLM.
        descriptions_file: JSON file mapping ``str(problem_id)`` to
            ``{"left": {...}, "right": {...}}``, where each side maps a panel
            name to its description.
        problem_description: Text sent as the first message content.
        question: Template formatted with the ``left_descriptions`` and
            ``right_descriptions`` keyword arguments.
        with_image: Whether to also send the image of the whole problem.
        reevaluate: When true, the existing ``output_file`` is not loaded and
            every problem is asked again.
        num_panels: Maximum number of descriptions used per side.

    Raises:
        KeyError: If ``descriptions_file`` has no entry for a requested
            problem id or the entry lacks a ``"left"``/``"right"`` side.
    """
    resolve_attempt = BongardResolveAttempt(model_name=model.get_name())
    if os.path.exists(output_file) and not reevaluate:
        resolve_attempt.load(output_file)

    with open(descriptions_file, "r", encoding="utf-8") as file:
        description_dictionary = json.load(file)

    for problem_id in tqdm(problem_ids):
        if resolve_attempt.has_solution(problem_id) and not reevaluate:
            continue

        descriptions = description_dictionary[str(problem_id)]
        left_descriptions = _first_panel_descriptions(
            descriptions["left"], num_panels
        )
        right_descriptions = _first_panel_descriptions(
            descriptions["right"], num_panels
        )

        # Build the message incrementally so that it only ever holds Content
        # objects; no ``None`` placeholder is sent when the image is omitted.
        msg_contents: List[Content] = [TextContent(problem_description)]
        if with_image:
            # NOTE(review): ``from_basename`` presumably resolves the file
            # extension of ``whole`` itself - confirm against ImageContent.
            msg_contents.append(
                ImageContent.from_basename(f"{splitted_data_path}/{problem_id}/whole")
            )
        msg_contents.append(
            TextContent(
                question.format(
                    left_descriptions="\n".join(left_descriptions),
                    right_descriptions="\n".join(right_descriptions),
                )
            )
        )

        try:
            model_answer = model.ask(msg_contents)
            resolve_attempt.add_solution(problem_id, model_answer)
            resolve_attempt.save(output_file)
        except Exception as e:
            # Best effort: report the failure and continue with the next
            # problem instead of aborting the whole run.
            print(f"Failed to solve problem with id: {problem_id}.")
            print(e)


def _first_panel_descriptions(side_descriptions: dict, num_panels: int) -> list:
    """Return up to ``num_panels`` descriptions ordered by their panel name."""
    ordered_names = sorted(side_descriptions)
    return [side_descriptions[name] for name in ordered_names][:num_panels]


def get_descriptions(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    question: str = DescriptivePrompts.DESCRIBE_IMAGE_PROMPT,
    num_panels: int = 6,
    reevaluate: bool = False,
):
    """Generate panel descriptions for both sides of every given problem.

    For each problem id the ``left`` side is processed first and the ``right``
    side second, delegating the actual work to ``get_side_descriptions``.

    Args:
        problem_ids: Ids of the problems whose panels should be described.
        splitted_data_path: Directory containing ``<problem_id>/left`` and
            ``<problem_id>/right`` sub-directories.
        output_file: JSON file the descriptions are stored in.
        model: Messenger used to query the LLM.
        question: Prompt sent together with every panel image.
        num_panels: Forwarded to ``get_side_descriptions`` as the per-side
            panel limit.
        reevaluate: Forwarded to ``get_side_descriptions``; when true,
            already stored descriptions are generated again.
    """
    sides = ("left", "right")
    for problem_id in tqdm(problem_ids):
        for side in sides:
            side_directory = f"{splitted_data_path}/{problem_id}/{side}"
            get_side_descriptions(
                problem_id=problem_id,
                splitted_data_path=side_directory,
                model=model,
                question=question,
                side=side,
                output_file=output_file,
                num_panels=num_panels,
                reevaluate=reevaluate,
            )


def get_side_descriptions(
    problem_id: int,
    splitted_data_path: str,
    model: LLMMessenger,
    question: str,
    side: str,
    output_file: str,
    num_panels: int = 6,
    reevaluate: bool = False,
):
    """Describe the panels of one side of a problem and store the results.

    The supported image files in ``splitted_data_path`` are taken in
    lexicographic filename order and at most ``num_panels`` of them are
    described. Descriptions are stored in ``output_file`` under
    ``[str(problem_id)][side][filename]``; the file is rewritten after every
    new description so an interrupted run keeps its progress.

    Args:
        problem_id: Id of the problem the panels belong to.
        splitted_data_path: Directory holding the panel images of this side.
        model: Messenger used to query the LLM.
        question: Prompt sent together with every panel image.
        side: Key of the side inside the problem entry (``"left"`` or
            ``"right"``).
        output_file: JSON file the descriptions are read from and written to.
            A missing or empty file is treated as an empty dictionary.
        num_panels: Maximum number of panel images to describe.
        reevaluate: When true, panels that already have a stored description
            are described again.
    """
    supported_filenames = []
    for filename in sorted(os.listdir(splitted_data_path)):
        if is_image_supported(filename):
            supported_filenames.append(filename)
        else:
            print(f"Unsupported extension: {filename}")

    problem_key = str(problem_id)
    # Apply the panel limit only after unsupported files were filtered out,
    # so stray non-image files cannot reduce the number of described panels.
    for filename in supported_filenames[:num_panels]:
        # The file is re-read for every panel so each write is based on the
        # latest content stored on disk.
        if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
            with open(output_file, "r", encoding="utf-8") as file:
                description_dictionary = json.load(file)
        else:
            description_dictionary = {}

        problem_descriptions = description_dictionary.setdefault(
            problem_key, {"left": {}, "right": {}}
        )
        side_descriptions = problem_descriptions[side]
        if filename in side_descriptions and not reevaluate:
            continue

        image_path = os.path.join(splitted_data_path, filename)
        # NOTE(review): get_description returns "" when the model call fails;
        # that empty string is stored as well, so the panel is only retried
        # when ``reevaluate`` is set - confirm this is intended.
        side_descriptions[filename] = get_description(image_path, model, question)

        with open(output_file, "w", encoding="utf-8") as file:
            json.dump(description_dictionary, file, indent=4)


def get_description(
    image_path: str,
    model: LLMMessenger,
    question: str = DescriptivePrompts.DESCRIBE_IMAGE_PROMPT,
) -> str:
    """Ask ``model`` to describe a single image.

    Args:
        image_path: Path of the image to describe.
        model: Messenger used to query the LLM.
        question: Prompt sent after the image.

    Returns:
        The answer returned by ``model.ask``, or an empty string when the
        call raises an exception.
    """
    contents = [
        ImageContent(image_path),
        TextContent(question),
    ]

    try:
        return model.ask(contents)
    except Exception as e:
        # Still best effort (callers rely on "" for failures), but report the
        # failure instead of hiding it completely.
        print(f"Failed to describe image: {image_path}.")
        print(e)
        return ""
