import json
import os
import re
import shutil
from operator import itemgetter
from typing import Dict, List, Optional

import pandas as pd
import requests
from tqdm import tqdm

from src.domain_translator.image_generator import ImageGenerator
from src.llm_messenger.classes.exceptions import UnsupportedImageException
from src.llm_messenger.classes.image_content import ImageContent
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.classes.text_content import TextContent

from .image_provider import ImageProvider


class ImageCheck:
    """Outcome of checking one image against a concept.

    ``answer`` is the raw text returned by the checking model and ``decision`` is
    the verdict derived from it; both are plain strings.
    """

    def __init__(self, answer: str, decision: str):
        self.answer = answer
        self.decision = decision

    def to_dict(self):
        """Return a JSON-friendly dictionary with the ``answer`` and ``decision`` keys."""
        return dict(answer=self.answer, decision=self.decision)

    @staticmethod
    def from_dict(data: dict):
        """Rebuild an ImageCheck from a dictionary produced by ``to_dict``."""
        answer, decision = data["answer"], data["decision"]
        return ImageCheck(answer=answer, decision=decision)


class Concept:
    """One side of a translated comparison.

    Attributes:
        concept: real-world concept name; also used (with ':' removed) as the
            name of the folder holding its images.
        prompt: text-to-image prompt describing the concept.
        question: yes/no question used to verify gathered images.
        img_checks: per-image check results keyed by image file name without
            extension.
    """

    def __init__(
        self,
        concept: str,
        prompt: str,
        question: str,
        img_checks: Optional[Dict[str, ImageCheck]] = None,
    ):
        self.concept = concept
        self.question = question
        self.prompt = prompt
        # Give each instance its own dict. A shared `{}` default would leak checks
        # between Concepts, because callers add entries to img_checks in place.
        self.img_checks = {} if img_checks is None else img_checks

    def to_dict(self):
        """Serialize the concept, including its image checks, to a JSON-friendly dict."""
        return {
            "concept": self.concept,
            "prompt": self.prompt,
            "question": self.question,
            "img_checks": {
                key: value.to_dict() for key, value in self.img_checks.items()
            },
        }

    @staticmethod
    def from_dict(data: dict):
        """Rebuild a Concept from ``to_dict`` output; ``img_checks`` may be absent."""
        return Concept(
            data["concept"],
            data["prompt"],
            data["question"],
            {
                key: ImageCheck.from_dict(value)
                for key, value in data.get("img_checks", {}).items()
            },
        )


class Translation:
    """A pair of real-world concepts translated from one geometric comparison.

    ``original_concept`` holds the geometric comparison text the pair came from.
    ``evaluation`` holds the pipeline's verdict for the pair; it is empty until
    evaluated.
    """

    def __init__(
        self,
        left: Concept,
        right: Concept,
        original_concept: str,
        evaluation: str = "",
    ):
        self.left = left
        self.right = right
        self.original_concept = original_concept
        self.evaluation = evaluation

    def to_dict(self):
        """Serialize both concepts plus the translation metadata."""
        serialized_left = self.left.to_dict()
        serialized_right = self.right.to_dict()
        return {
            "left": serialized_left,
            "right": serialized_right,
            "original_concept": self.original_concept,
            "evaluation": self.evaluation,
        }

    @staticmethod
    def from_dict(data: dict):
        """Rebuild a Translation; missing metadata fields default to ""."""
        left = Concept.from_dict(data["left"])
        right = Concept.from_dict(data["right"])
        return Translation(
            left=left,
            right=right,
            original_concept=data.get("original_concept", ""),
            evaluation=data.get("evaluation", ""),
        )

    def to_comparison(self) -> str:
        """Return "<left concept> vs <right concept>"."""
        return "{} vs {}".format(self.left.concept, self.right.concept)

    def __repr__(self) -> str:
        return self.to_comparison()


class TranslationDictionary(Dict[int, List[Translation]]):
    """Map problem id -> translations, persisted as JSON.

    JSON object keys are always strings, so ``load`` converts them back to ``int``.
    """

    def dump(self, file_path: str):
        """Write the dictionary to ``file_path`` as indented JSON.

        The data is serialized before any file is touched. It is then written to
        a sibling temporary file, which replaces ``file_path`` via ``os.replace``.
        The pipeline functions in this module rewrite this file after nearly every
        step. A direct write that is interrupted (or a failing ``to_dict``) would
        otherwise leave a truncated JSON and lose all saved progress.
        """
        serialized = {
            key: [translation.to_dict() for translation in value]
            for key, value in self.items()
        }
        temporary_path = f"{file_path}.tmp"
        with open(temporary_path, "w") as f:
            json.dump(serialized, f, indent=4)
        os.replace(temporary_path, file_path)

    @staticmethod
    def load(file_path: str):
        """Read a file written by ``dump`` and return a TranslationDictionary."""
        with open(file_path, "r") as f:
            data = json.load(f)

        translations = TranslationDictionary()
        translations.update(
            {
                int(key): [
                    Translation.from_dict(translation) for translation in value
                ]
                for key, value in data.items()
            }
        )
        return translations


def translate_from_geometric_concepts(
    rows_to_translate: List[int],
    geometric_labels_csv_path: str,
    output_path: str,
    model: LLMMessenger,
    trials: int = 5,
):
    """Translate geometric label pairs from a CSV into real-world concept pairs.

    Each entry of ``rows_to_translate`` is a 0-based row index. That row's
    "Left-side Rule" and "Right-side Rule" are passed to ``translate_concept``,
    and the results are stored under problem id ``index + 1``. Problems that
    already have at least one translation in ``<output_path>/translations.json``
    are skipped, so an interrupted run can be resumed.

    After each newly translated problem, both ``translations.json`` and
    ``translations.csv`` are rewritten. The CSV is rebuilt from the complete
    dictionary, so rows of problems skipped on a resumed run are not dropped.
    It reuses the input CSV's column names. That assumes the input has exactly
    three columns (problem id, left, right) — TODO confirm against the labels
    file.
    """
    df = pd.read_csv(geometric_labels_csv_path)
    translations_dict_path = os.path.join(output_path, "translations.json")
    translations_csv_path = f"{output_path}/translations.csv"
    translation_dict = (
        TranslationDictionary.load(translations_dict_path)
        if os.path.exists(translations_dict_path)
        else TranslationDictionary()
    )

    for index in tqdm(rows_to_translate):
        problem_id = index + 1
        # Resume support: skip problems that already have translations.
        if translation_dict.get(problem_id):
            continue

        row = df.iloc[index]
        translations = translate_concept(
            left_label=row["Left-side Rule"],
            right_label=row["Right-side Rule"],
            number_of_translations=trials,
            model=model,
        )

        translation_dict[problem_id] = translations
        translation_dict.dump(translations_dict_path)

        csv_rows = [
            [saved_problem_id, translation.left.concept, translation.right.concept]
            for saved_problem_id, saved_translations in translation_dict.items()
            for translation in saved_translations
        ]
        pd.DataFrame(csv_rows, columns=df.columns).to_csv(
            translations_csv_path, index=False
        )


def query_images_for_translations(
    translations_directory: str,
    img_provider: ImageProvider,
    imgs_per_concept: int,
):
    """Fetch provider images for every concept that has no image folder yet.

    Folders live at ``<translations_directory>/<problem_id>/<left> vs <right>/<concept>``
    with ':' removed from the names. The provider is still queried with the
    original, unsanitized concept label. Existing folders are left untouched, so
    the function can be re-run to fill gaps.
    """
    translations = TranslationDictionary.load(
        os.path.join(translations_directory, "translations.json")
    )

    for problem_id, problem_translations in tqdm(translations.items()):
        problem_directory = os.path.join(translations_directory, str(problem_id))

        for translation in tqdm(problem_translations):
            labels = (translation.left.concept, translation.right.concept)
            folder_names = tuple(label.replace(":", "") for label in labels)
            comparison_directory = os.path.join(
                problem_directory, " vs ".join(folder_names)
            )

            for label, folder_name in zip(labels, folder_names):
                image_output = os.path.join(comparison_directory, folder_name)
                if os.path.exists(image_output):
                    continue

                find_matching_images(
                    label=label,
                    img_provider=img_provider,
                    output_path=image_output,
                    count=imgs_per_concept,
                )


def generate_images_for_translations(
    translations_directory: str,
    generator: ImageGenerator,
    imgs_per_concept: int = 1,
):
    """Generate images for both concepts of every stored translation.

    Images are saved under
    ``<translations_directory>/<problem_id>/<left> vs <right>/<concept>``. The
    names have ':' removed, matching query_images_for_translations, so the
    later checking steps look in the same folders.

    Args:
        translations_directory: directory containing ``translations.json``.
        generator: image generator fed with each concept's prompt.
        imgs_per_concept: number of images to generate per concept. This is
            passed explicitly as ``count``; ``generate_matching_images``
            requires it.
    """
    translations = TranslationDictionary.load(
        os.path.join(translations_directory, "translations.json")
    )

    for problem_id, problem_translations in translations.items():
        for translation in problem_translations:
            sanitized_left = translation.left.concept.replace(":", "")
            sanitized_right = translation.right.concept.replace(":", "")
            comparison_directory = os.path.join(
                translations_directory,
                str(problem_id),
                f"{sanitized_left} vs {sanitized_right}",
            )

            generate_matching_images(
                prompt=translation.left.prompt,
                generator=generator,
                output_path=os.path.join(comparison_directory, sanitized_left),
                count=imgs_per_concept,
            )

            generate_matching_images(
                prompt=translation.right.prompt,
                generator=generator,
                output_path=os.path.join(comparison_directory, sanitized_right),
                count=imgs_per_concept,
            )


def try_generate_lacking_images_for_translations(
    problem_ids: List[int],
    translations_directory: str,
    generator: ImageGenerator,
    expected_image_pairs: int,
    expected_correct_translations: int,
):
    """Top up generated images for translations marked "NEEDS_MORE_IMAGES".

    Only problems listed in ``problem_ids`` are processed. Within a problem,
    translations are walked in order until ``expected_correct_translations`` is
    reached. Both "OK" translations and translations topped up here count
    toward it; topped-up ones are counted optimistically, before any
    re-evaluation. Each concept receives
    ``expected_image_pairs - <its current OK-image count>`` new images.
    """
    translations = TranslationDictionary.load(
        os.path.join(translations_directory, "translations.json")
    )

    for problem_id, problem_translations in translations.items():
        if problem_id not in problem_ids:
            continue

        correct_translations = 0

        for candidate in problem_translations:
            if candidate.evaluation == "OK":
                correct_translations += 1
            if correct_translations >= expected_correct_translations:
                break
            if candidate.evaluation != "NEEDS_MORE_IMAGES":
                continue

            right_ok_images = count_correct_images_in_concept(candidate.right)
            left_ok_images = count_correct_images_in_concept(candidate.left)

            safe_left = candidate.left.concept.replace(":", "")
            safe_right = candidate.right.concept.replace(":", "")
            comparison = f"{safe_left} vs {safe_right}"
            comparison_directory = os.path.join(
                translations_directory, str(problem_id), comparison
            )

            print("Generating images for: ", comparison)

            generate_matching_images(
                prompt=candidate.left.prompt,
                generator=generator,
                output_path=os.path.join(comparison_directory, safe_left),
                count=expected_image_pairs - left_ok_images,
            )
            generate_matching_images(
                prompt=candidate.right.prompt,
                generator=generator,
                output_path=os.path.join(comparison_directory, safe_right),
                count=expected_image_pairs - right_ok_images,
            )

            # Assume the top-up will make this translation acceptable.
            correct_translations += 1

            correct_translations += 1


def count_correct_images_in_concept(
    concept: Concept,
):
    """Return how many of ``concept``'s image checks were approved.

    A check counts when its decision is non-empty and contains the substring
    "OK".
    """
    approved_checks = [
        check
        for check in concept.img_checks.values()
        if check.decision and "OK" in check.decision
    ]
    return len(approved_checks)


# First turn of translate_concept. Formatted with `number_of_translations` and
# `concept`; literal braces are doubled so str.format leaves them in place.
# The example must be valid JSON: the model is told to imitate it, and the
# final reply is parsed with json.
TRANSLATE_DOMAIN_QUESTION = """
Your goal is to translate a comparison concept from the geometric domain to the real-world domain.
Your translations should be expressible as images.

Example:
Geometric domain: triangles vs squares
{{
    "left": {{
        "concept": "pyramids"
    }},
    "right": {{
        "concept": "rectangular buildings"
    }}
}}

Give {number_of_translations} unique translations for the following concept as a raw JSON array of objects (same as in the example above).
{concept}
"""

# Second turn of translate_concept: asks the model to add a yes/no "question"
# to every translated concept. Literal braces are doubled (str.format
# escaping). This text must therefore go through str.format() before it is
# sent; otherwise the model sees "{{"/"}}" in the example JSON.
GENERATE_CONCEPT_QUESTIONS_PROMPT = """
Now, you need to provide a simple and meaningful question for each given concept, 
so that a simple model can classify if the image matches the concept.

Use the question template "Does the picture present <...>? Respond only with 'yes' or 'no'"

Example:
Geometric domain: triangles vs squares
{{
    "left": {{
        "concept": "pyramids",
        "question": "Does the picture present a pyramid? Respond only with 'yes' or 'no'"
    }},
    "right": {{
        "concept": "rectangular buildings",
        "question": "Does the picture present a rectangular building? Respond only with 'yes' or 'no'"
    }}
}}
"""

# Third turn of translate_concept: asks for an image-generation "prompt" per
# concept. The reply to this turn is the one parsed by parse_json_array.
# Literal braces are doubled (str.format escaping). This text must therefore
# go through str.format() before it is sent; otherwise the model sees
# "{{"/"}}" in the example JSON.
GENERATE_PROMPTS_PROMPT = """
Now, you need to provide a detailed prompts for stable diffusion model to generate images that match each given concept. 
Your prompts should be simple, short and descriptive and lead to generating realistic image.

Example:
Geometric domain: triangles vs squares
{{
    "left": {{
        "concept": "pyramids",
        "question": "Does the picture present a pyramid? Respond only with 'yes' or 'no'",
        "prompt": "A realistic, high-resolution image of the ancient pyramids of Giza at sunset, with the Great Pyramid prominently in the foreground and the other pyramids visible in the background. The sky is clear with vibrant colors, and the scene includes some tourists in the distance and a few camels near the base of the pyramids. The surrounding desert is bathed in the warm glow of the setting sun, highlighting the textures and details of the stone structures."
    }},
    "right": {{
        "concept": "rectangular buildings",
        "question": "Does the picture present a rectangular building? Respond only with 'yes' or 'no'",
        "prompt": "A realistic cityscape featuring modern rectangular buildings with large glass windows, set against a clear blue sky. The buildings should vary in height and be surrounded by green trees and bustling streets with people and cars."
    }}
}}

"""


def translate_concept(
    left_label: str,
    right_label: str,
    number_of_translations: int,
    model: LLMMessenger,
) -> List[Translation]:
    """Translate one geometric comparison into real-world concept pairs.

    Runs a three-turn conversation inside a dedicated model context:
    (1) ask for ``number_of_translations`` concept pairs, (2) ask for a yes/no
    question per concept, (3) ask for an image-generation prompt per concept.
    Only the last reply is parsed. Every returned Translation gets
    ``original_concept`` set to "<left_label> vs <right_label>".

    Returns:
        The translations parsed from the final reply.
    """
    comparison = f"{left_label} vs {right_label}"
    model.open_context(
        f"translate_concept_{sanitize_text(left_label)}_vs_{sanitize_text(right_label)}"
    )

    try:
        model.ask(
            [
                TextContent(
                    TRANSLATE_DOMAIN_QUESTION.format(
                        number_of_translations=number_of_translations,
                        concept=comparison,
                    )
                )
            ]
        )
        # The follow-up prompts escape their JSON braces as "{{"/"}}". Without
        # .format() the model would be shown doubled braces (invalid JSON) as
        # the example to imitate, and the final reply could fail to parse.
        model.ask([TextContent(GENERATE_CONCEPT_QUESTIONS_PROMPT.format())])
        model_response = model.ask([TextContent(GENERATE_PROMPTS_PROMPT.format())])
    finally:
        # Close the context even when a request fails, so it is never left open.
        model.close_context()

    translations = parse_json_array(model_response)

    for translation in translations:
        translation.original_concept = comparison

    return translations


def sanitize_text(text: str) -> str:
    """Remove every ':' from ``text`` and cap the result at 64 characters."""
    without_colons = "".join(text.split(":"))
    return without_colons[:64]


def find_matching_images(
    label: str,
    img_provider: ImageProvider,
    output_path: str,
    count: int,
) -> List[str]:
    """Ask ``img_provider`` for ``count`` images of ``label`` and download them.

    Returns:
        Paths of the images actually saved into ``output_path``.
    """
    return download_images(img_provider.find_images(label, count), output_path)


def download_images(
    img_urls: List[str],
    output_path: str,
    prefix: str = "",
) -> List[str]:
    downloaded_images: List[str] = []
    os.makedirs(output_path, exist_ok=True)
    files_in_output_directory = os.listdir(output_path)
    files_count_in_output_directory = len(files_in_output_directory)

    for i, img_url in enumerate(img_urls):
        image_extension = img_url.split(".")[-1].split("?")[0].lower()
        img_path = f"{output_path}/{prefix}{files_count_in_output_directory + i}.{image_extension}"
        response = requests.get(img_url)

        if response.status_code == 200:
            with open(img_path, "wb") as f:
                f.write(response.content)
            downloaded_images.append(img_path)

    return downloaded_images


def check_if_images_match_concepts(
    translations_directory: str,
    model: LLMMessenger,
):
    """Ask each concept's VQA question about its images and store the answers.

    Concepts that already have ``img_checks`` are skipped. ``translations.json``
    is saved after each concept so progress survives interruptions.

    Image folders are addressed with ':' removed from the names, which is where
    query_images_for_translations saves them. Only regular files are sent to
    the model, so a ``trash`` sub-folder is ignored.
    """
    translation_dict_path = os.path.join(translations_directory, "translations.json")
    translation_dict = TranslationDictionary.load(translation_dict_path)

    for problem_id, translations in translation_dict.items():
        for translation in translations:
            sanitized_left = translation.left.concept.replace(":", "")
            sanitized_right = translation.right.concept.replace(":", "")
            comparison_directory = os.path.join(
                translations_directory,
                str(problem_id),
                f"{sanitized_left} vs {sanitized_right}",
            )

            for concept, folder_name in (
                (translation.left, sanitized_left),
                (translation.right, sanitized_right),
            ):
                if concept.img_checks:
                    continue

                # List lazily, so already-checked concepts do not require their
                # folder to exist.
                images_path = os.path.join(comparison_directory, folder_name)
                images = [
                    f"{images_path}/{image}"
                    for image in os.listdir(images_path)
                    if os.path.isfile(os.path.join(images_path, image))
                ]

                concept.img_checks = check_if_images_match_quesion(
                    question=concept.question,
                    images=images,
                    model=model,
                )

                translation_dict.dump(translation_dict_path)


def check_if_images_match_quesion(
    question: str,
    images: List[str],
    model: LLMMessenger,
) -> Dict[str, ImageCheck]:
    """Ask ``model`` the ``question`` about every image path in ``images``.

    Returns:
        Checks keyed by image file name without extension. Each check holds the
        raw model answer and an empty decision, to be filled in later.
    """
    results: Dict[str, ImageCheck] = {}

    for path in images:
        answer = model.ask([TextContent(question), ImageContent(path)])
        stem = os.path.splitext(os.path.basename(path))[0]
        results[stem] = ImageCheck(answer, "")

    return results


# Used by evaluate_concept_images when a VQA answer is not a plain "yes"/"no".
# Formatted with `question` and `answer`; the reply is expected to contain an
# "EVALUATION: OK" or "EVALUATION: REJECTED" line (read by parse_model_decision).
EVALUATE_CHEKS_PROMPT = """
A visual question answering model (VQA) was asked a 'yes' / 'no' question, but sometimes it described the image instead giving a clear answer. 
Your goal is to answer the same question using its answer. Give your answer in format: 

EVALUATION: OK
or
EVALUATION: REJECTED

Question: {question}
VQA answer:  {answer}
"""


def evaluate_gathered_images(
    translations_directory: str,
    model: LLMMessenger,
):
    """Derive decisions for the stored VQA answers of every concept.

    ``translations.json`` is re-saved after each translation (both of its
    concepts) has been processed.
    """
    translation_dict_path = os.path.join(translations_directory, "translations.json")
    translation_dict = TranslationDictionary.load(translation_dict_path)

    every_translation = (
        translation
        for group in translation_dict.values()
        for translation in group
    )

    for translation in every_translation:
        for concept in (translation.left, translation.right):
            evaluate_concept_images(model, concept)

        translation_dict.dump(translation_dict_path)


def evaluate_concept_images(model: LLMMessenger, concept: Concept):
    """Turn each raw VQA answer of ``concept`` into a decision, in place.

    A plain "yes"/"no" answer maps directly to "YES"/"NO". Case, surrounding
    whitespace and a trailing period are ignored. Any other answer is sent to
    ``model`` via EVALUATE_CHEKS_PROMPT, and the parsed "EVALUATION:" value
    (requested as "OK" or "REJECTED") is stored instead.
    """
    for check in concept.img_checks.values():
        # VQA models often reply "Yes." or " no\n". Normalize first so these do
        # not cost an extra LLM round-trip.
        normalized_answer = check.answer.strip().rstrip(".").strip().lower()

        if normalized_answer == "yes":
            check.decision = "YES"
            continue

        if normalized_answer == "no":
            check.decision = "NO"
            continue

        model_response = model.ask(
            [
                TextContent(
                    EVALUATE_CHEKS_PROMPT.format(
                        question=concept.question,
                        answer=check.answer,
                    )
                )
            ]
        )

        check.decision = parse_model_decision(model_response)


# Sent together with one image by evaluate_concept_images_with_multimodal.
# Formatted with `original_concept`, `translation` (a JSON dump of both concepts
# and their prompts) and `image_concept`. The reply's EVALUATION and
# EXPLANATION lines are read by parse_model_decision / parse_model_explanation.
MULTIMODAL_EVALUATE_CHEKS_PROMPT = """
You translated a concept comparison from geometric domain to the real-world domain as follows:

Geometric domain: {original_concept}

Real world domain: 
{translation}

Now, you need to check if the queried image matches your translation 
and provides enough information to distinguish it from the other concept.
Don't focus too much on the prompt. It's just a hint for you to understand the concept better.
Provided image represents {image_concept}. 

Give your answer in format:
EVALUATION: OK
EXPLANATION: <here you can provide additional information>
or
EVALUATION: REJECTED
EXPLANATION: <here you can provide additional information>
"""


def evaluate_concept_images_with_multimodal(
    translations_directory: str,
    model: LLMMessenger,
    problem_id: int,
    translation: Translation,
    concept: Concept,
    stop_after_n_approved_images: int,
) -> int:
    """Ask a multimodal model whether each image of ``concept`` fits the translation.

    Images are visited oldest-first by modification time. Images that already
    have a check in ``concept.img_checks`` are not sent again; they count as
    approved only when their decision is exactly "OK". Evaluation stops once
    ``stop_after_n_approved_images`` approvals are reached, or when the
    remaining images can no longer reach that number.

    New checks are written into ``concept.img_checks`` in place. This function
    does not save them to disk.

    Returns:
        Number of images whose decision is "OK".
    """
    concept_path = os.path.join(
        translations_directory,
        str(problem_id),
        translation.to_comparison().replace(":", ""),
        concept.concept.replace(":", ""),
    )

    # Only concept names and prompts are shown to the model.
    translation_example = json.dumps(
        {
            "left": {
                "concept": translation.left.concept,
                "prompt": translation.left.prompt,
            },
            "right": {
                "concept": translation.right.concept,
                "prompt": translation.right.prompt,
            },
        },
        indent=4,
    )

    task_content = TextContent(
        MULTIMODAL_EVALUATE_CHEKS_PROMPT.format(
            original_concept=translation.original_concept,
            translation=translation_example,
            image_concept=concept.concept,
        )
    )

    # Only regular files are candidate images. Sub-directories such as the
    # "trash" folder created by move_wrong_concept_images_to_trash must not be
    # sent to the model as images, nor counted by the feasibility check below.
    files = [
        name
        for name in os.listdir(concept_path)
        if os.path.isfile(os.path.join(concept_path, name))
    ]
    files.sort(key=lambda name: os.path.getmtime(os.path.join(concept_path, name)))

    approved_images_count = 0

    for index, img in enumerate(files):
        img_name, _ = os.path.splitext(img)
        img_path = os.path.join(concept_path, img)

        # Reuse earlier verdicts instead of asking the model again.
        if img_name in concept.img_checks:
            if concept.img_checks[img_name].decision == "OK":
                approved_images_count += 1
            continue

        if (
            approved_images_count >= stop_after_n_approved_images
            or there_is_not_enough_imgs_to_accept_translation(
                stop_after_n_approved_images,
                approved_images_count,
                files,
                index,
            )
        ):
            break

        try:
            model_response = model.ask([task_content, ImageContent(img_path)])
        except UnsupportedImageException:
            concept.img_checks[img_name] = ImageCheck("Unsupported image", "REJECTED")
            continue
        except Exception:
            # One blind retry (presumably for transient API errors). A second
            # failure propagates to the caller.
            model_response = model.ask([task_content, ImageContent(img_path)])

        decision = parse_model_decision(model_response)
        explanation = parse_model_explanation(model_response)
        concept.img_checks[img_name] = ImageCheck(explanation, decision)

        if decision == "OK":
            approved_images_count += 1

    return approved_images_count


def there_is_not_enough_imgs_to_accept_translation(
    stop_after_n_approved_images,
    approved_images_count,
    files,
    index,
):
    """Return True when the images from ``index`` onward cannot reach the target.

    ``index`` is the position of the image about to be evaluated. That image is
    still unevaluated, so it is included in the remaining count:
    ``len(files) - index``. Excluding it gave up one image too early; for
    example, with exactly N images and N approvals needed, nothing was ever
    evaluated.
    """
    remaining_images = len(files) - index
    still_needed = stop_after_n_approved_images - approved_images_count
    return remaining_images < still_needed


def parse_model_decision(model_response: str) -> str:
    """Return the value following the first "EVALUATION:" marker in ``model_response``.

    Spaces and tabs after the colon, and trailing spaces, tabs or a carriage
    return at the end of the line, are dropped. A reply such as
    "EVALUATION: OK  " therefore compares equal to "OK".

    Raises:
        ValueError: if the response contains no "EVALUATION:" marker.
    """
    match = re.search(r"EVALUATION:[ \t]*(.*?)[ \t\r]*$", model_response, re.MULTILINE)
    if match is None:
        raise ValueError(
            f"No 'EVALUATION:' line in model response: {model_response[:200]!r}"
        )
    return match.group(1)


def parse_model_explanation(model_response: str) -> str:
    """Return the rest of the line after the first "EXPLANATION:" marker.

    Surrounding spaces, tabs and a trailing carriage return are dropped. The
    explanation is optional extra information in the prompts that request it,
    so a missing marker yields "" rather than an exception. This way a valid
    decision is not lost because of it.
    """
    match = re.search(r"EXPLANATION:[ \t]*(.*?)[ \t\r]*$", model_response, re.MULTILINE)
    return match.group(1) if match else ""


def move_wrong_images_to_trash(translations_directory: str):
    """Move rejected images of every stored concept into a per-concept ``trash`` folder.

    Folders are addressed with ':' removed from the names. This matches the
    paths query_images_for_translations and the multimodal evaluation use, so
    concepts containing ':' are found.
    """
    translation_dict_path = os.path.join(translations_directory, "translations.json")
    translation_dict = TranslationDictionary.load(translation_dict_path)

    for problem_id, translations in translation_dict.items():
        for translation in translations:
            sanitized_left = translation.left.concept.replace(":", "")
            sanitized_right = translation.right.concept.replace(":", "")
            translation_directory = os.path.join(
                translations_directory,
                str(problem_id),
                f"{sanitized_left} vs {sanitized_right}",
            )

            move_wrong_concept_images_to_trash(
                translation.left,
                os.path.join(translation_directory, sanitized_left),
            )

            move_wrong_concept_images_to_trash(
                translation.right,
                os.path.join(translation_directory, sanitized_right),
            )


def move_wrong_concept_images_to_trash(
    concept: Concept,
    output_path: str,
):
    """Move images of ``concept`` that were not approved into ``<output_path>/trash``.

    A check counts as approved when its decision contains "yes" or "ok",
    ignoring case. Both vocabularies occur in this module: direct VQA answers
    become "YES"/"NO", while LLM and multimodal evaluations produce
    "OK"/"REJECTED". Testing only for "yes" would trash every "OK" image. Checks
    with an empty decision are treated as not approved.

    ``img_checks`` keys are file names without extension, so files are matched
    by stem; downloads keep the URL's extension, which is not always ".jpg".
    Checks whose file is no longer present (e.g. already moved on an earlier
    run) are skipped.
    """
    trash_directory = f"{output_path}/trash"

    if not os.path.isdir(output_path):
        return

    # file stem -> paths of regular files with that stem
    paths_by_stem: Dict[str, List[str]] = {}
    for file_name in os.listdir(output_path):
        file_path = os.path.join(output_path, file_name)
        if os.path.isfile(file_path):
            stem, _ = os.path.splitext(file_name)
            paths_by_stem.setdefault(stem, []).append(file_path)

    for img_id, check in concept.img_checks.items():
        decision = (check.decision or "").lower()
        if "yes" in decision or "ok" in decision:
            continue

        for image_path in paths_by_stem.get(img_id, []):
            os.makedirs(trash_directory, exist_ok=True)
            shutil.move(image_path, trash_directory)


def generate_matching_images(
    prompt: str,
    generator: ImageGenerator,
    output_path: str,
    count: int = 1,
) -> List[str]:
    """Generate ``count`` images for ``prompt`` one at a time and download them.

    Files are saved into ``output_path`` with the "gen_" prefix. A ``count`` of
    zero or less generates nothing and returns ``[]``.

    Returns:
        Paths of every image saved across all iterations.
    """
    downloaded_images: List[str] = []

    for _ in range(count):
        generated_urls = generator.generate_images(prompt, 1)
        downloaded_images.extend(download_images(generated_urls, output_path, "gen_"))

    return downloaded_images


def move_from_translations_to_problem_directory(
    problem_directory: str,
    translations_directory: str,
    comparison: str,
    left: str,
    right: str,
):
    """Move a comparison's two concept folders into ``problem_directory``.

    ``<translations_directory>/<comparison>/<left>`` becomes
    ``<problem_directory>/left``, and likewise for ``right``.
    ``problem_directory`` is created if needed.
    """
    os.makedirs(problem_directory, exist_ok=True)
    source_root = f"{translations_directory}/{comparison}"

    for concept_folder, side in ((left, "left"), (right, "right")):
        os.rename(f"{source_root}/{concept_folder}", f"{problem_directory}/{side}")


def parse_json_array(text: str) -> List[Translation]:
    """Extract the first JSON array of objects from ``text`` as Translations.

    The model reply may wrap the array in prose or Markdown code fences. Parsing
    is therefore attempted from each ``[`` in turn with
    ``json.JSONDecoder.raw_decode``, which reads exactly one JSON value and
    ignores what follows. Unlike matching up to the first ``]``, this keeps
    arrays whose strings or nested values contain brackets intact. The first
    value that is a list of objects is converted with ``Translation.from_dict``.

    Returns:
        The parsed translations. Returns ``[]`` when ``text`` has no ``[``, or
        when every candidate parsed but none was a list of objects.

    Raises:
        json.JSONDecodeError: if no list of objects was found and at least one
            candidate failed to decode (the last error is re-raised).
    """
    decoder = json.JSONDecoder()
    last_error = None
    start = text.find("[")

    while start != -1:
        try:
            value, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError as error:
            last_error = error
        else:
            if isinstance(value, list) and all(isinstance(item, dict) for item in value):
                return [Translation.from_dict(json_object) for json_object in value]
        start = text.find("[", start + 1)

    if last_error is not None:
        raise last_error
    return []


def clear_checks(translations_directory: str):
    """Reset the image checks of every stored concept and save the result."""
    translation_dict_path = os.path.join(translations_directory, "translations.json")
    translation_dict = TranslationDictionary.load(translation_dict_path)

    every_concept = (
        concept
        for translations in translation_dict.values()
        for translation in translations
        for concept in (translation.left, translation.right)
    )
    for concept in every_concept:
        concept.img_checks = {}

    translation_dict.dump(translation_dict_path)


def restore_images_from_trash(translations_directory: str):
    """Restore trashed images for every concept folder under ``translations_directory``.

    Walks ``<problem_id>/<comparison>/<concept>``. Top-level entries that are
    not directories are ignored.
    """
    for problem_id in os.listdir(translations_directory):
        problem_directory = f"{translations_directory}/{problem_id}"
        if not os.path.isdir(problem_directory):
            continue

        for comparison in os.listdir(problem_directory):
            comparison_directory = f"{translations_directory}/{problem_id}/{comparison}"
            for concept in os.listdir(comparison_directory):
                restore_concept_images_from_trash(f"{comparison_directory}/{concept}")


def restore_concept_images_from_trash(
    output_path: str,
):
    """Move everything in ``<output_path>/trash`` back into ``output_path``.

    The then-empty trash folder is removed. Nothing happens if there is no
    trash folder.
    """
    trash_directory = f"{output_path}/trash"

    if os.path.exists(trash_directory):
        for entry in os.listdir(trash_directory):
            shutil.move(f"{trash_directory}/{entry}", output_path)
        os.rmdir(trash_directory)


def remove_additional_images(translations_directory: str, files_to_keep: int):
    """Delete all but the oldest ``files_to_keep`` images in every concept folder.

    Walks ``<translations_directory>/<problem_id>/<comparison>/<concept>``.
    Entries that are not directories at any of these levels are skipped.

    Inside a concept folder only regular files are considered, so a ``trash``
    sub-folder is left alone. Files are ordered by modification time, with the
    path as tie-breaker, matching the oldest-first order of the multimodal
    evaluation. ``os.listdir`` order is arbitrary, so without sorting the
    surviving images would differ from run to run.
    """
    for problem_id in os.listdir(translations_directory):
        problem_directory = f"{translations_directory}/{problem_id}"

        if not os.path.isdir(problem_directory):
            continue

        for comparison in os.listdir(problem_directory):
            comparison_directory = os.path.join(problem_directory, comparison)
            if not os.path.isdir(comparison_directory):
                continue

            for concept in os.listdir(comparison_directory):
                concept_directory = os.path.join(comparison_directory, concept)
                if not os.path.isdir(concept_directory):
                    continue

                image_files = sorted(
                    (
                        os.path.join(concept_directory, file)
                        for file in os.listdir(concept_directory)
                        if os.path.isfile(os.path.join(concept_directory, file))
                    ),
                    key=lambda path: (os.path.getmtime(path), path),
                )

                for file in image_files[files_to_keep:]:
                    os.remove(file)


def check_with_multimodal(
    problem_ids: List[int],
    translations_directory: str,
    model: LLMMessenger,
    stop_after_n_approved_images: int,
    stop_after_n_approved_translations: int,
):
    """Evaluate not-yet-evaluated translations of the given problems with a multimodal model.

    For each problem, translations are visited in stored order. Those already
    evaluated only count toward the approval total (if "OK"). The loop stops
    once ``stop_after_n_approved_translations`` are approved, or when the
    remaining translations can no longer reach that target.

    Each newly evaluated translation gets one verdict:
    - "OK": both concepts have at least ``stop_after_n_approved_images``
      approved images;
    - "REJECTED": both concepts have at most 2 approved images;
    - "NEEDS_MORE_IMAGES": anything in between.

    ``translations.json`` is saved after the left concept is evaluated and
    again after the verdict is set.
    """
    translations_dict_path = os.path.join(translations_directory, "translations.json")
    translations_dict = TranslationDictionary.load(translations_dict_path)

    for problem_id in problem_ids:
        translations = translations_dict[problem_id]

        approved_translations = 0

        # `checked_translations` is the number of translations visited *before*
        # the current one. The helper then computes `len(translations) -
        # checked_translations` as the remaining count, including the current
        # translation. Counting from 1 would make it give up one translation too
        # early.
        for checked_translations, translation in enumerate(translations):
            if translation.evaluation != "":
                if translation.evaluation == "OK":
                    approved_translations += 1
                continue

            if (
                approved_translations >= stop_after_n_approved_translations
                or there_is_not_enough_translations(
                    stop_after_n_approved_translations,
                    translations,
                    approved_translations,
                    checked_translations,
                )
            ):
                break

            approved_images_count_left = evaluate_concept_images_with_multimodal(
                translations_directory,
                model,
                problem_id,
                translation,
                translation.left,
                stop_after_n_approved_images,
            )

            translations_dict.dump(translations_dict_path)

            approved_images_count_right = evaluate_concept_images_with_multimodal(
                translations_directory,
                model,
                problem_id,
                translation,
                translation.right,
                stop_after_n_approved_images,
            )

            if (
                approved_images_count_left >= stop_after_n_approved_images
                and approved_images_count_right >= stop_after_n_approved_images
            ):
                translation.evaluation = "OK"
                approved_translations += 1
            elif approved_images_count_left <= 2 and approved_images_count_right <= 2:
                translation.evaluation = "REJECTED"
            else:
                translation.evaluation = "NEEDS_MORE_IMAGES"

            translations_dict.dump(translations_dict_path)


def there_is_not_enough_translations(
    stop_after_n_approved_translations,
    translations,
    approved_translations,
    checked_translations,
):
    """Return True when the remaining translations cannot reach the approval target.

    ``len(translations) - checked_translations`` is taken as the number of
    translations still available. It is compared against the approvals still
    missing.
    """
    remaining = len(translations) - checked_translations
    still_needed = stop_after_n_approved_translations - approved_translations
    return remaining < still_needed


def complete_with_generated_photos(
    translations_directory: str,
    generator: ImageGenerator,
):
    """Print, per problem, each translation's combined number of approved images.

    For every problem, a list of ``(right concept name, OK images of right + left)``
    tuples is printed, sorted by the concept name.

    NOTE(review): ``generator`` is not used. The list is keyed by the right
    concept only and sorted by name rather than by count. This looks
    unfinished; confirm the intended behavior before relying on it.
    """
    translations_dict = TranslationDictionary.load(
        os.path.join(translations_directory, "translations.json")
    )

    for translations in translations_dict.values():
        translation_to_correct_images = []
        for translation in translations:
            total_correct = count_correct_images_in_concept(
                translation.right
            ) + count_correct_images_in_concept(translation.left)
            translation_to_correct_images.append(
                (translation.right.concept, total_correct)
            )

        translation_to_correct_images.sort(key=lambda pair: pair[0])

        print(translation_to_correct_images)
