import json
import os
import re
from typing import Dict, List, Union
import pandas as pd

from src.bongard_problems.classes import ClassificationAttempt
from src.image import get_image_filename_with_extension
from src.llm_messenger.classes.content import Content
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent
from src.prompting_techniques.prompt import CommonPrompts

# Prompt templates passed as ``question`` to the side-classification solvers
# in this module. Each asks the model to label a query image 'LEFT' or 'RIGHT'.
# NOTE(review): "Bondgard" is a typo for "Bongard". It is part of the text sent
# to the model, so it is deliberately left unchanged to keep prompts identical
# to the ones behind already-saved results.

# Plain two-class prompt that also asks for a short explanation.
DEFAULT_QUESTION = """
Your goal is to classify the image into one of the two classes: 'LEFT' or 'RIGHT'. 
Each class is represented by respective side of the provided Bondgard problem. 
Choose the class that best fits the provided image.
The image is always provided correctly. Respond only to the specific request.
Respond with 'LEFT' or 'RIGHT'. Shortly explain your choice.
"""

# Treats 'RIGHT' as "does not share the left-side concept" and asks the model
# to name the feature common to the left side.
CONCEPT_QUESTION = """
Your goal is to classify the image into one of the two classes: 'LEFT' or 'RIGHT'. 
Each class is represented by respective side of the provided Bondgard problem. 
Left images share common visual concepts, while right images do not. 
So if you think the query image matches the left side, respond with 'LEFT', otherwise respond with 'RIGHT'.
The image is always provided correctly. Respond only to the specific request.
Respond with 'LEFT' or 'RIGHT'. Additionally, try to provide a feature that all images on the left side have in common.
"""

# Like CONCEPT_QUESTION, plus two example answers; this is the default prompt
# of resolve_bongard_with_last_image_to_side_classification.
ONE_SHOT_QUESTION = """
Your goal is to classify the image into one of the two classes: 'LEFT' or 'RIGHT'. 
Each class is represented by respective side of the provided Bondgard problem. 
Left images share common visual concept, while right images do not. 
So if you think the query image matches the left side, respond with 'LEFT', otherwise respond with 'RIGHT'.

The image is always provided correctly. Respond only to the specific request.
Respond with 'LEFT' or 'RIGHT'. Additionally, try to provide a feature that all images on the left side have in common.

EXAMPLES: 
LEFT. All images on the left side have a brown dog, same as the query image.
RIGHT. All images on the left side have a brown dog, while query image has a black dog.
"""

# One-shot variant whose example answers state the shared concept explicitly
# ("COMMON CONCEPT: ...") before the explanation.
CLASSIC_ONE_SHOT_QUESTION = """
Your goal is to classify the image into one of the two classes: 'LEFT' or 'RIGHT'. 
Each class is represented by respective side of the provided Bondgard problem. 
Left images share common visual concept, while right images do not. 
So if you think the query image matches the left side, respond with 'LEFT', otherwise respond with 'RIGHT'.

The image is always provided correctly. Respond only to the specific request.
Respond with 'LEFT' or 'RIGHT'. Additionally, try to provide a concept that all images on the left side have in common.

EXAMPLES: 
LEFT. COMMON CONCEPT: two triangles. All images on the left side have two triangles, same as the query image.
RIGHT. COMMON CONCEPT: single shape. All images on the left side have single shape, while query image has two shapes.
"""


def resolve_bongard_with_more_examples_image_to_side_classification(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    use_spliited_data: bool = False,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = DEFAULT_QUESTION,
    reevaluate: bool = False,
):
    """Classify the extra example images of every problem with ``model``.

    For each problem id, the images under ``more_examples/left`` and then
    ``more_examples/right`` are classified by ``classify_more_examples_to_side``.

    Args:
        problem_ids: Problems to process, in order.
        splitted_data_path: Root directory holding one sub-directory per problem.
        output_file: Where the ClassificationAttempt is loaded from / saved to.
        model: Model that answers the classification prompts.
        use_spliited_data: Show the problem as separate side images instead of
            the single "whole" image.
        problem_description: General task description sent first.
        question: Classification instructions sent after the description.
        reevaluate: When True, existing results in ``output_file`` are not
            loaded, so every image is classified again.
    """
    attempt = ClassificationAttempt(model_name=model.get_name())
    # Resume from earlier results unless a fresh run was requested.
    if os.path.exists(output_file) and not reevaluate:
        attempt.load(output_file)

    for problem_id in problem_ids:
        for side in ("left", "right"):
            classify_more_examples_to_side(
                problem_id=problem_id,
                side=side,
                splitted_data_path=splitted_data_path,
                attempt=attempt,
                model=model,
                output_file=output_file,
                problem_description=problem_description,
                question=question,
                use_spliited_data=use_spliited_data,
            )


def gather_side_images(problem_id: int, splitted_data_path: str, side: str):
    """Return one ImageContent per file in ``<splitted_data_path>/<problem_id>/<side>``.

    The files are returned in sorted filename order. ``os.listdir`` makes no
    ordering guarantee, and callers treat the last element as the held-out
    query image. Sorting makes that choice reproducible across machines and
    filesystems.

    Args:
        problem_id: Problem whose images are listed.
        splitted_data_path: Root directory holding one sub-directory per problem.
        side: Side sub-directory name (the callers in this module pass
            "left" or "right").

    Returns:
        A list of ImageContent objects, one per directory entry.
    """
    side_directory = os.path.join(splitted_data_path, str(problem_id), side)
    return [
        ImageContent(os.path.join(side_directory, image_file))
        for image_file in sorted(os.listdir(side_directory))
    ]


def get_image_data(
    splitted_data_path: str,
    problem_id: int,
    use_spliited_data: bool = False,
) -> List[Content]:
    """Build the prompt contents that show the Bongard problem itself.

    Args:
        splitted_data_path: Root directory holding one sub-directory per problem.
        problem_id: Problem to show.
        use_spliited_data: When False, return the single "whole" problem image.
            When True, return the individual side images, each group preceded
            by a text header.

    Returns:
        A list of Content items ready to be spliced into a model prompt.
    """
    if use_spliited_data:
        left_side = gather_side_images(problem_id, splitted_data_path, "left")
        right_side = gather_side_images(problem_id, splitted_data_path, "right")
        contents: List[Content] = [TextContent("Left side images:")]
        contents.extend(left_side)
        contents.append(TextContent("Right side images:"))
        contents.extend(right_side)
        return contents

    whole_image = get_image_filename_with_extension(
        f"{splitted_data_path}/{problem_id}/whole"
    )
    return [ImageContent(whole_image)]


def classify_more_examples_to_side(
    problem_id: int,
    side: str,
    splitted_data_path: str,
    attempt: ClassificationAttempt,
    model: LLMMessenger,
    output_file: str,
    problem_description: str,
    question: str,
    use_spliited_data: bool = False,
):
    """Classify every extra example image of one side of a Bongard problem.

    Each file in ``<splitted_data_path>/<problem_id>/more_examples/<side>``
    that ``attempt`` has no solution for yet is sent to ``model`` together with
    the problem images. The response is recorded and scored against ``side``,
    and ``attempt`` is saved to ``output_file`` after every image so that an
    interrupted run can resume.

    Args:
        problem_id: Problem the example images belong to.
        side: Side the examples really belong to ("left" or "right"); used as
            the expected label.
        splitted_data_path: Root directory holding one sub-directory per problem.
        attempt: Result store, updated in place.
        model: Model that answers the classification prompt.
        output_file: Where ``attempt`` is saved.
        problem_description: General task description sent first.
        question: Classification instructions sent after the description.
        use_spliited_data: Passed to ``get_image_data``.
    """
    examples_path = os.path.join(
        splitted_data_path, str(problem_id), "more_examples", side
    )
    expected_label = side.upper()
    # The problem images do not depend on the query image. Build them once,
    # on first need, instead of re-listing the directories for every example.
    problem_contents = None

    # Sorted so the processing order does not depend on the filesystem.
    for image_file in sorted(os.listdir(examples_path)):
        image_path = os.path.join(examples_path, image_file)

        if attempt.has_solution(image_path):
            continue

        if problem_contents is None:
            problem_contents = get_image_data(
                splitted_data_path, problem_id, use_spliited_data
            )

        model_response = model.ask(
            [
                TextContent(problem_description),
                TextContent(question),
                TextContent("Bongard problem:"),
                *problem_contents,
                TextContent("Image to classify:"),
                ImageContent(image_path),
            ]
        )

        attempt.add_solution(problem_id, model_response, image_path)
        # Score by the first label the model names. The prompts ask for the
        # label first, and a plain substring test would count any answer that
        # mentions both LEFT and RIGHT as correct for both sides.
        predicted = re.search(r"LEFT|RIGHT", model_response)
        evaluation = (
            "OK" if predicted and predicted.group(0) == expected_label else "WRONG"
        )
        attempt.evaluate_solution(image_path, evaluation)
        attempt.save(output_file)


def resolve_bongard_with_last_image_to_side_classification(
    problem_ids: List[int],
    splitted_data_path: str,
    output_file: str,
    model: LLMMessenger,
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = ONE_SHOT_QUESTION,
) -> None:
    """Hold out the last image of each side and ask ``model`` to classify it.

    For every problem, the remaining images of both sides form the context.
    The held-out left image is classified first, then the held-out right one.
    Results are loaded from and saved to ``output_file``.

    Args:
        problem_ids: Problems to process, in order.
        splitted_data_path: Root directory holding one sub-directory per problem.
        output_file: Where the ClassificationAttempt is loaded from / saved to.
        model: Model that answers the classification prompts.
        problem_description: General task description sent first.
        question: Classification instructions sent after the description.
    """
    attempt = ClassificationAttempt(model_name=model.get_name())

    if os.path.exists(output_file):
        attempt.load(output_file)

    for problem_id in problem_ids:
        left_images = gather_side_images(problem_id, splitted_data_path, "left")
        right_images = gather_side_images(problem_id, splitted_data_path, "right")
        held_out = {"left": left_images[-1], "right": right_images[-1]}
        context_left, context_right = left_images[:-1], right_images[:-1]

        for side in ("left", "right"):
            classify_picked_image_to_side(
                problem_id=problem_id,
                side=side,
                picked_image_path=held_out[side].image_path,
                left_side_images=context_left,
                right_side_images=context_right,
                attempt=attempt,
                model=model,
                output_file=output_file,
                problem_description=problem_description,
                question=question,
            )


def classify_picked_image_to_side(
    problem_id: int,
    side: str,
    picked_image_path: str,
    left_side_images: List[Content],
    right_side_images: List[Content],
    attempt: ClassificationAttempt,
    model: LLMMessenger,
    output_file: str,
    problem_description: str,
    question: str,
):
    """Ask ``model`` which side ``picked_image_path`` belongs to and score it.

    Does nothing if ``attempt`` already has a solution for the image.
    Otherwise the response is recorded and scored against ``side``, and
    ``attempt`` is saved to ``output_file``.

    Args:
        problem_id: Problem the image belongs to.
        side: True side of the picked image ("left" or "right").
        picked_image_path: Path of the image to classify.
        left_side_images: Context images shown as the left side.
        right_side_images: Context images shown as the right side.
        attempt: Result store, updated in place.
        model: Model that answers the classification prompt.
        output_file: Where ``attempt`` is saved.
        problem_description: General task description sent first.
        question: Classification instructions sent after the description.
    """
    if attempt.has_solution(picked_image_path):
        return

    model_response = model.ask(
        [
            TextContent(problem_description),
            TextContent(question),
            TextContent("Bongard problem:"),
            TextContent("Left side images:"),
            *left_side_images,
            TextContent("Right side images:"),
            *right_side_images,
            TextContent("Image to classify:"),
            ImageContent(picked_image_path),
        ]
    )

    attempt.add_solution(problem_id, model_response, picked_image_path)
    # Score by the first label the model names. The prompts ask for the label
    # first, and a plain substring test would count any answer that mentions
    # both LEFT and RIGHT as correct for both sides.
    predicted = re.search(r"LEFT|RIGHT", model_response)
    evaluation = (
        "OK" if predicted and predicted.group(0) == side.upper() else "WRONG"
    )
    attempt.evaluate_solution(picked_image_path, evaluation)
    attempt.save(output_file)


# Asks only for the left-side concept (no classification). The answer is
# expected on a "CONCEPT: ..." line, which parse_concept extracts.
# NOTE(review): "Bondgard" is a typo for "Bongard"; left unchanged because it
# is prompt text sent to the model.
ONE_SHOT_GIVE_ONLY_CONCEPT_QUESTION = """
Let's consider two classes: 'LEFT' and 'RIGHT'. 
Each class is represented by respective side of the provided Bondgard problem. 
Left images share common visual concept, which is not present in any of the right images. 
Your task is to provide this common concept.

The image is always provided correctly. Respond only to the specific request.
Provide a feature that all images on the left side have in common and it is not present in any of the right images.

EXAMPLES: 
CONCEPT: A brown dog. 
EXPLANATION: All images on the left side have a brown dog. On the right side there are some dogs, but they are black.

CONCEPT: People in a car.
EXPLANATION: All images on the left side present people in a car. Some images on the right side also have people, but they are outside the car. Also, there are some images with cars, but without people.
"""

# Template for a single yes/no check of an image. ``{question}`` is filled with
# str.format in check_concept_in_image.
# NOTE(review): "asnwers" is a typo; left unchanged because it is prompt text.
VERIFY_QUESTION = """
{question}
Your asnwers must be strict and correct. Respond only with 'YES' or 'NO'.
"""

# Asks the verifier model to turn a concept into simple questions, one per
# line. get_concept_questions splits the reply into lines.
FORMULATE_VERIFY_QUESTIONS = """
Your goal is to formulate a set of questions about the image from a given concept.
The question should be simple and clear.

EXAMPLES:
CONCEPT: A brown dog. 
QUESTIONS: 
Is there a dog in the image?
Is the dog brown?

CONCEPT: People in a car. 
QUESTIONS: 
Is there a car in the image?
Are there people in the image?
Are people inside the car?

Respond only with the questions.
"""


class ConceptEntry:
    """A concept proposed for one Bongard problem plus its verification data.

    Attributes:
        problem_id: Problem the concept was proposed for.
        concept: Concept text (as extracted from the model answer).
        questions: Yes/no questions derived from ``concept``; empty until set.
        evaluation: Grading of ``concept`` against the label; "" until set.
    """

    def __init__(
        self,
        problem_id: int,
        concept: str,
        questions: Union[List[str], None] = None,
        evaluation: str = "",
    ):
        self.problem_id = problem_id
        self.concept = concept
        # A fresh list per entry. A ``[]`` default would be one list object
        # shared by every entry created without explicit questions.
        self.questions = [] if questions is None else questions
        self.evaluation = evaluation

    def to_dict(self):
        """Return a JSON-serialisable dict with all four fields."""
        return {
            "problem_id": self.problem_id,
            "concept": self.concept,
            "questions": self.questions,
            "evaluation": self.evaluation,
        }

    @staticmethod
    def from_dict(entry_dict):
        """Build an entry from ``to_dict`` output.

        "questions" and "evaluation" are optional and default to empty.
        """
        return ConceptEntry(
            problem_id=entry_dict["problem_id"],
            concept=entry_dict["concept"],
            questions=entry_dict.get("questions", []),
            evaluation=entry_dict.get("evaluation", ""),
        )

    def to_json(self):
        """Serialise the entry to a JSON string."""
        return json.dumps(self.to_dict())

    @staticmethod
    def from_json(json_string):
        """Parse an entry from a JSON string produced by ``to_json``."""
        return ConceptEntry.from_dict(json.loads(json_string))


class ConceptDictionary:
    """JSON-file-backed store of ConceptEntry objects keyed by problem id.

    The file holds "model_name", "dataset_name" and a "concepts" list. It is
    loaded on construction if it exists and rewritten in full by ``dump()``
    after every mutating call.
    """

    def __init__(
        self,
        concept_dictionary_path: str,
        model_name: str = "",
        dataset_name: str = "",
    ):
        self.__concept_dictionary_path = concept_dictionary_path
        self.__model_name = model_name
        self.__dataset_name = dataset_name
        self.__concepts: Dict[int, ConceptEntry] = {}

        if os.path.exists(concept_dictionary_path):
            with open(concept_dictionary_path, "r") as json_file:
                concept_dictionary = json.load(json_file)

            # Metadata stored in the file wins. If the file does not record
            # it, keep the constructor arguments instead of blanking them.
            self.__dataset_name = concept_dictionary.get(
                "dataset_name", self.__dataset_name
            )
            self.__model_name = concept_dictionary.get(
                "model_name", self.__model_name
            )
            self.__concepts = {
                concept["problem_id"]: ConceptEntry.from_dict(concept)
                for concept in concept_dictionary.get("concepts", [])
            }

    def get_concept(self, problem_id: int) -> Union["ConceptEntry", None]:
        """Return the entry for ``problem_id``, or None if there is none."""
        return self.__concepts.get(problem_id, None)

    def save_concept(self, problem_id: int, concept: str):
        """Store a fresh entry for ``problem_id`` (replacing any old one) and dump."""
        self.__concepts[problem_id] = ConceptEntry(problem_id, concept)
        self.dump()

    def save_concept_questions(self, problem_id: int, questions: List[str]):
        """Attach ``questions`` to an existing entry and dump; no-op if absent."""
        if problem_id not in self.__concepts:
            return

        self.__concepts[problem_id].questions = questions
        self.dump()

    def evaluate_concept(self, problem_id: int, evaluation: str):
        """Record ``evaluation`` on an existing entry and dump; no-op if absent."""
        if problem_id not in self.__concepts:
            return

        self.__concepts[problem_id].evaluation = evaluation
        self.dump()

    def dump(self):
        """Write the whole dictionary to the backing JSON file."""
        with open(self.__concept_dictionary_path, "w") as json_file:
            json.dump(
                {
                    "model_name": self.__model_name,
                    "dataset_name": self.__dataset_name,
                    "concepts": [
                        concept.to_dict() for concept in self.__concepts.values()
                    ],
                },
                json_file,
                indent=4,
            )


def resolve_bongard_with_last_image_to_side_classification_and_two_agents(
    problem_ids: List[int],
    splitted_data_path: str,
    concept_dictionary_path: str,
    output_file: str,
    model: LLMMessenger,
    verify_model: LLMMessenger,
    dataset_name: str = "",
    problem_description: str = CommonPrompts.PROBLEM_DESCRIPTION,
    question: str = ONE_SHOT_GIVE_ONLY_CONCEPT_QUESTION,
):
    """Solve problems with a concept-proposing model and a verifying model.

    For each problem, the last image of each side is held out. ``model``
    proposes the left-side concept from the remaining images. ``verify_model``
    turns that concept into yes/no questions and answers them for each
    held-out image (see ``check_concept_in_image``). Concepts and questions
    are cached in ``concept_dictionary_path``; classifications are loaded
    from and saved to ``output_file``.

    Args:
        problem_ids: Problems to process, in order.
        splitted_data_path: Root directory holding one sub-directory per problem.
        concept_dictionary_path: JSON file caching concepts and questions.
        output_file: Where the ClassificationAttempt is loaded from / saved to.
        model: Model that proposes the concept.
        verify_model: Model that formulates and answers the yes/no questions.
        dataset_name: Stored in the concept dictionary file.
        problem_description: General task description sent to ``model``.
        question: Concept-extraction instructions sent to ``model``.
    """
    resolve_attempt = ClassificationAttempt(model_name=model.get_name())
    concept_dictionary = ConceptDictionary(
        concept_dictionary_path,
        model.get_name(),
        dataset_name=dataset_name,
    )

    if os.path.exists(output_file):
        resolve_attempt.load(output_file)

    for problem_id in problem_ids:
        left_images = gather_side_images(problem_id, splitted_data_path, "left")
        right_images = gather_side_images(problem_id, splitted_data_path, "right")
        held_out = {"left": left_images[-1], "right": right_images[-1]}
        left_images = left_images[:-1]
        right_images = right_images[:-1]

        if concept_dictionary.get_concept(problem_id) is None:
            concept = get_concept(
                left_side_images=left_images,
                right_side_images=right_images,
                model=model,
                problem_description=problem_description,
                question=question,
            )
            concept_dictionary.save_concept(problem_id, concept)

        concept_entry = concept_dictionary.get_concept(problem_id)
        # `not` also catches a null question list, which `== []` let through.
        if not concept_entry.questions:
            concept_dictionary.save_concept_questions(
                problem_id,
                get_concept_questions(verify_model, concept_entry.concept),
            )

        concept_questions = concept_dictionary.get_concept(problem_id).questions

        for side in ("left", "right"):
            picked_image_path = held_out[side].image_path
            # Same resume check as the other solvers in this module. Unlike a
            # snapshot taken before the loop, it reflects solutions saved
            # earlier in this same run.
            if resolve_attempt.has_solution(picked_image_path):
                continue

            check_concept_in_image(
                problem_id=problem_id,
                side=side,
                picked_image_path=picked_image_path,
                attempt=resolve_attempt,
                verify_model=verify_model,
                output_file=output_file,
                concept_questions=concept_questions,
            )


def get_concept(
    left_side_images: List[Content],
    right_side_images: List[Content],
    model: LLMMessenger,
    problem_description: str,
    question: str = ONE_SHOT_GIVE_ONLY_CONCEPT_QUESTION,
) -> str:
    """Ask ``model`` for the concept shared by the left-side images.

    Args:
        left_side_images: Images shown as the left side.
        right_side_images: Images shown as the right side.
        model: Model that proposes the concept.
        problem_description: General task description sent first.
        question: Concept-extraction instructions.

    Returns:
        The text ``parse_concept`` extracts from the answer ("" if none).
    """
    prompt: List[Content] = [
        TextContent(problem_description),
        TextContent(question),
        TextContent("Bongard problem:"),
        TextContent("Left side images:"),
    ]
    prompt.extend(left_side_images)
    prompt.append(TextContent("Right side images:"))
    prompt.extend(right_side_images)

    return parse_concept(model.ask(prompt))


def get_concept_questions(
    verify_model: LLMMessenger,
    concept: str,
) -> List[str]:
    """Ask ``verify_model`` to turn ``concept`` into simple yes/no questions.

    Args:
        verify_model: Model that formulates the questions.
        concept: Concept text to break down.

    Returns:
        One question per non-blank line of the answer, stripped of surrounding
        whitespace. A bare "QUESTIONS:" header line is dropped; it is the label
        used in the FORMULATE_VERIFY_QUESTIONS examples and the model may echo it.
    """
    model_response = verify_model.ask(
        [
            TextContent(FORMULATE_VERIFY_QUESTIONS),
            TextContent(f"CONCEPT: {concept}"),
        ]
    )

    questions = []
    for line in model_response.splitlines():
        line = line.strip()
        # Blank lines (e.g. from a trailing newline) would otherwise become
        # empty "questions" sent to the verifier, and a non-YES answer to
        # one of them fails the whole check.
        if not line or line.upper() == "QUESTIONS:":
            continue
        questions.append(line)
    return questions


def check_concept_in_image(
    problem_id: int,
    side: str,
    picked_image_path: str,
    attempt: ClassificationAttempt,
    verify_model: LLMMessenger,
    output_file: str,
    concept_questions: List[str],
):
    """Classify an image as LEFT iff the verifier answers YES to every question.

    Questions are asked one by one, stopping at the first answer that is not
    YES. The resulting label is recorded in ``attempt``, scored against
    ``side``, and ``attempt`` is saved to ``output_file``. With no questions,
    the image is classified LEFT.

    Args:
        problem_id: Problem the image belongs to.
        side: True side of the image ("left" or "right").
        picked_image_path: Path of the image to check.
        attempt: Result store, updated in place.
        verify_model: Model that answers the yes/no questions.
        output_file: Where ``attempt`` is saved.
        concept_questions: Yes/no questions derived from the concept.
    """
    all_questions_passed = True

    for question in concept_questions:
        model_response = verify_model.ask(
            [
                TextContent(VERIFY_QUESTION.format(question=question)),
                ImageContent(picked_image_path),
            ]
        )

        # Use the first standalone YES/NO, case-insensitively. Models often
        # reply "Yes.", which a raw substring test for "YES" misses, while
        # that test also matches words such as "EYES".
        verdict = re.search(r"\b(YES|NO)\b", model_response, re.IGNORECASE)
        if verdict is None or verdict.group(1).upper() != "YES":
            all_questions_passed = False
            break

    answer = "LEFT" if all_questions_passed else "RIGHT"
    attempt.add_solution(problem_id, answer, picked_image_path)
    evaluation = "OK" if answer.lower() == side.lower() else "WRONG"
    attempt.evaluate_solution(picked_image_path, evaluation)
    attempt.save(output_file)


# Grading prompt used by evaluate_concept. It is filled with str.format using
# ``concept`` and ``label``, and asks for one of 'OK', 'WRONG',
# 'LACK OF DETAIL' or 'TOO MUCH DETAIL'.
# NOTE(review): the third line opens 'OK with a single quote but closes it with
# a double quote, and "comparing to" should read "compared to". Left unchanged
# because this is prompt text sent to the model.
EVALUATE_CONCEPT_QUESTION = """
You are a logic module which is designed to provide accurate answers. 
Your task is to tell how the provided concept matches with the provided label. 
Use 'OK" when the CONCEPT fully matches provided label, 'WRONG' when the CONCEPT does not match label, 
'LACK OF DETAIL' when the CONCEPT is too general comparing to label, 
and 'TOO MUCH DETAIL' when the CONCEPT is too specific comparing to label.
Respond only with 'OK', 'WRONG', 'LACK OF DETAIL' or 'TOO MUCH DETAIL'.

EXAMPLES: 
CONCEPT: A dog.
LABEL: A brown dog.
LACK OF DETAIL

CONCEPT: A car with people.
LABEL: A car.
TOO MUCH DETAIL 

CONCEPT: A bedroom.
LABEL: A bed.
WRONG

CONCEPT: Some apples.
LABEL: A few apples.
OK

CONCEPT: {concept}
LABEL: {label}
"""


def evaluate_concept(
    model: LLMMessenger,
    problem_id: int,
    concept: str,
    labels: pd.DataFrame,
    concept_dictionary: ConceptDictionary,
) -> str:
    """Grade ``concept`` against the ground-truth left-side rule of a problem.

    The label is read from the "Left-side Rule" column of row
    ``problem_id - 1`` (positional ``iloc``). NOTE(review): this assumes row i
    of ``labels`` holds problem i + 1 — confirm against the labels file.

    Args:
        model: Model that grades the concept.
        problem_id: Problem whose concept is graded.
        concept: Concept text to grade.
        labels: Table with a "Left-side Rule" column.
        concept_dictionary: Store that receives the grading.

    Returns:
        The model's raw answer, also stored via
        ``concept_dictionary.evaluate_concept``. The prompt asks for 'OK',
        'WRONG', 'LACK OF DETAIL' or 'TOO MUCH DETAIL', but the answer is not
        validated here.
    """
    label = labels.iloc[problem_id - 1]["Left-side Rule"]

    model_response = model.ask(
        [
            TextContent(EVALUATE_CONCEPT_QUESTION.format(concept=concept, label=label)),
        ]
    )

    concept_dictionary.evaluate_concept(problem_id, model_response)
    return model_response


def save_solution_with_failed_concept(
    problem_id: int,
    picked_image_path: str,
    attempt: ClassificationAttempt,
):
    """Record ``picked_image_path`` as solved wrongly because its concept failed.

    The stored answer is "Wrong concept", evaluated as "WRONG". Unlike the
    solvers above, this does not call ``attempt.save``; presumably the caller
    saves.
    """
    failure_answer = "Wrong concept"
    attempt.add_solution(problem_id, failure_answer, picked_image_path)
    attempt.evaluate_solution(picked_image_path, "WRONG")


def parse_concept(model_response) -> str:
    """Extract the concept from a model answer.

    Takes the text after the first "CONCEPT:" label, up to the end of that
    line, and strips surrounding whitespace (the space after the colon, a
    trailing carriage return from CRLF line endings).

    Returns:
        The concept text, or "" when no "CONCEPT:" label is present.
    """
    match = re.search(r"CONCEPT:(.*)$", model_response, re.MULTILINE)
    return match.group(1).strip() if match else ""
