import glob
import json
import os
import random
from pathlib import Path
from typing import Annotated

import typer
from datasets import Dataset, DatasetDict, Features, Image, Value, load_dataset
from generate_pdfs_questions import generate_qa_from_page
from model_names import LLMS_NAME, VERTEX
from PIL import Image as PILImage
from prompts import PROMPT_QUESTION_ANSWER_FRENCH
from tqdm import tqdm

# Seed the module-level RNG so the random.sample() call used to pick pages is repeatable.
random.seed(42)

# Row limit passed as max_len_dataset when main() builds a dataset.
MAX_LEN_DATASET_SHIFT = 20


def generate_n_questions_from_dir(dataset_name: str, pdf_pages_dir: str, n: int, model_name="gpt4o"):
    """
    Generate question-answer pairs for up to n randomly sampled page images.

    Every ``*.jpg`` file directly inside ``pdf_pages_dir`` is a candidate page.
    For each sampled page, ``generate_qa_from_page`` is called and, when its
    answer parses as JSON, two lines are appended to
    ``data/shift_dataset/qa_pairs/qa_pairs_{dataset_name}.txt``:
    ``**Image:** <path>`` and ``**Questions:** <json>``.

    Args:
        dataset_name: Suffix used to name the output text file.
        pdf_pages_dir: Directory holding the page images.
        n: Number of pages to sample; capped at the number of images found.
        model_name: Key looked up in ``VERTEX`` first, then in ``LLMS_NAME``.

    Returns:
        None. Returns early (without creating the output file) when there is
        no image to process or when ``n`` is not positive.

    Raises:
        KeyError: If ``model_name`` is in neither ``VERTEX`` nor ``LLMS_NAME``.
    """
    print(f"Generating {n} questions from {pdf_pages_dir}")
    # glob returns files in a filesystem-dependent order; sorting makes the seeded
    # random.sample below pick the same pages on every machine.
    image_paths = sorted(glob.glob(f"{pdf_pages_dir}/*.jpg"))
    print(f"Found {len(image_paths)} images in {pdf_pages_dir}")

    if not image_paths:
        print(f"No images found in {pdf_pages_dir}")
        return
    if n <= 0:
        # random.sample rejects negative sizes, and an empty sample would make the
        # average-cost computation at the end divide by zero.
        print(f"Nothing to generate: n must be positive, got {n}")
        return
    if n > len(image_paths):
        print(f"Number of images is less than {n}. Generating questions for all images.")
        n = len(image_paths)

    # take n random images
    random_images = random.sample(image_paths, n)

    # The model does not depend on the page, so resolve it once up front. This also
    # surfaces an unknown model name before the output file is touched.
    model = VERTEX[model_name] if model_name in VERTEX else LLMS_NAME[model_name]

    cost_list = []
    output_file = Path(f"data/shift_dataset/qa_pairs/qa_pairs_{dataset_name}.txt")
    output_file.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, "a", encoding="utf-8") as f:
        for random_image in tqdm(random_images, desc="Generating QA pairs", total=n):
            print(f"Generating QA with model {model}")

            answer, cost = generate_qa_from_page(
                random_image, model=model, max_tokens=1000, prompt=PROMPT_QUESTION_ANSWER_FRENCH
            )
            cost_list.append(cost)

            # Only the parsing step is best-effort: a malformed answer skips the page,
            # while genuine I/O errors on the output file are no longer swallowed.
            try:
                parsed_answer = json.loads(answer)
            except (TypeError, ValueError) as e:
                print(f"Error processing {random_image}: {e}")
                continue

            # Re-serialise on a single line so the log can be read back line by line.
            f.write(f"\n**Image:** {random_image}\n")
            f.write(f"\n**Questions:** {json.dumps(parsed_answer, ensure_ascii=False)}\n")

    print(f"Questions saved to {output_file}")
    print(f"Average cost: {sum(cost_list) / len(cost_list)}")


def _parse_qa_log(text_log):
    """
    Read a QA log and return two parallel lists: image paths and parsed question payloads.
    """
    image_marker = "**Image:**"
    questions_marker = "**Questions:**"
    images_path = []
    qas = []

    # The log is written as UTF-8 with non-ASCII characters kept verbatim, so it must be
    # read with the same encoding rather than the platform default.
    with open(text_log, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line.startswith(image_marker):
                images_path.append(line[len(image_marker) :].strip())
            elif line.startswith(questions_marker):
                # Slice off the marker instead of splitting on it, so a marker-like
                # substring inside the JSON payload cannot truncate it.
                qas_json_str = line[len(questions_marker) :].strip()
                try:
                    qas.append(json.loads(qas_json_str))
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")

    return images_path, qas


def _page_from_filename(image_path):
    """
    Return the part of the file name after its last underscore, with everything from the first dot dropped.
    """
    return os.path.basename(image_path).split(".")[0].split("_")[-1]


def get_shift_dataset(text_log: str, max_len_dataset=20) -> "DatasetDict":
    """
    Interactively build a dataset from a log of generated question-answer pairs.

    The log is expected to contain ``**Image:** <path>`` lines followed by
    ``**Questions:** <json>`` lines, where the JSON object has a ``questions``
    list of ``{"question": ..., "answer": ...}`` entries. For each logged page
    the image is displayed and the candidate questions are printed; the user
    then picks one by number, types ``e`` to enter a custom pair, or types
    ``0`` to skip. Any other input also skips the page.

    Args:
        text_log: Path to the QA log file.
        max_len_dataset: Stop once this many rows have been collected.

    Returns:
        A ``DatasetDict`` with a single ``test`` split.

    Raises:
        AssertionError: If the log holds a different number of image entries
            and successfully parsed question entries.
    """
    features = Features(
        {
            "query": Value("string"),
            "image": Image(),
            "image_filename": Value("string"),
            "answer": Value("string"),
            "page": Value("string"),
            "model": Value("string"),
            "prompt": Value("string"),
            "source": Value("string"),
        }
    )

    # Kept from the previous implementation: makes sure the data/ directory exists.
    Path("data/failed_llm_calls.txt").parent.mkdir(parents=True, exist_ok=True)

    dataset_dict = {key: [] for key in features.keys()}
    images_path, qas = _parse_qa_log(text_log)

    # Raised explicitly (same exception type as before) so the check is not
    # stripped when Python runs with -O.
    if len(images_path) != len(qas):
        raise AssertionError("Number of images and questions do not match")

    for image_path, qa in zip(images_path, qas):
        if len(dataset_dict["query"]) >= max_len_dataset:
            break

        print("\n---------------------------------------------\n")

        pil_image = PILImage.open(image_path)
        pil_image.show()

        questions = qa["questions"]
        for j, question in enumerate(questions, start=1):
            print(f"\nQuestion {j}: {question['question']}")
            print(f"Answer {j}: {question['answer']}")

        user_input = (
            input(
                f"\nWhich question do you want to use? (1-{len(questions)})?\n"
                "Type '0' if you want to skip this image.\n"
                "Type 'e' to enter your own question-answer pair\n\nYour input: "
            )
            .strip()
            .lower()
        )

        if user_input == "0":
            continue

        if user_input == "e":
            query = input("Enter your question: ")
            answer = input("Enter the answer: ")
        # isdecimal() only accepts characters int() can parse, and the upper bound
        # follows the number of questions actually available for this page.
        elif user_input.isdecimal() and 0 < int(user_input) <= len(questions):
            chosen = questions[int(user_input) - 1]
            query = chosen["question"]
            answer = chosen["answer"]
        else:
            print("Invalid input. Skipping image.")
            continue

        dataset_dict["query"].append(query)
        dataset_dict["image"].append(pil_image)
        dataset_dict["image_filename"].append(image_path)
        # The column is declared as a string, so store the answer itself, not a list.
        dataset_dict["answer"].append(answer)
        dataset_dict["page"].append(_page_from_filename(image_path))
        dataset_dict["model"].append("")
        dataset_dict["prompt"].append(PROMPT_QUESTION_ANSWER_FRENCH)
        dataset_dict["source"].append("pdf")

        print(f"Generated {len(dataset_dict['query'])} rows")

    ds = {"test": Dataset.from_dict(dataset_dict, features=features)}

    return DatasetDict(ds)


def concat_datasets():
    """
    Concatenate all the datasets generated from the PDFs
    and push the concatenated dataset to the Hugging Face hub.

    Each ``coldoc/temp_shiftproject_{name}_test`` repository is loaded and its
    ``test`` split collected; repositories that fail to load are reported and
    skipped. The collected splits are merged into one ``test`` split and pushed
    to ``coldoc/shiftproject_test``. Nothing is pushed when no split could be
    loaded.
    """
    # Imported here because only this function needs it.
    from datasets import concatenate_datasets

    dataset_names = [
        "approvisionnement_petrolier",
        "decarbonner_sante",
        "aviation",
        "cartographie_transition",
        "rapport_avancement",
    ]

    test_splits = []
    for dataset_name in dataset_names:
        try:
            ds = load_dataset(f"coldoc/temp_shiftproject_{dataset_name}_test")
            test_splits.append(ds["test"])
            print(f"Loaded dataset: {dataset_name}")
        except Exception as e:
            # Best effort: one missing or broken repository must not block the others.
            print(f"Failed to load dataset {dataset_name}: {e}")

    if not test_splits:
        print("No datasets were loaded. Concatenation not performed.")
        return

    # Let the library merge the splits instead of rebuilding every column as a
    # Python list: this keeps the declared features and avoids the quadratic
    # list concatenation of sum(..., []).
    concatenated_dataset = DatasetDict({"test": concatenate_datasets(test_splits)})
    concatenated_dataset.push_to_hub("coldoc/shiftproject_test")
    print("Concatenated dataset pushed to Hugging Face hub successfully.")


def main(
    dataset_name: Annotated[str, typer.Argument(help="The name of the dataset to validate")],
    generate_questions: Annotated[bool, typer.Option(help="Generate questions from the PDFs")] = False,
):
    """
    Command-line entry point.

    With ``dataset_name`` set to ``concat``, the per-document datasets are merged
    and pushed. With a known dataset name, either QA pairs are generated for its
    pages (``--generate-questions``) or the previously generated QA log is
    reviewed interactively and the resulting dataset is pushed to
    ``coldoc/temp_shiftproject_{dataset_name}_test``.

    Raises:
        typer.BadParameter: If ``dataset_name`` is neither ``concat`` nor a known dataset.
    """
    model_name = "sonnet"
    base_path = "data/scrapped_pdfs_split/pages_extracted/shift_project_test"
    pdf_by_dataset = {
        "approvisionnement_petrolier": "Approvisionnement-petrolier-futur-de-lUE_Shift-Project_Mai-2021_RAPPORT-COMPLET.pdf",
        "decarbonner_sante": "Decarboner-la-sante-pour-soigner-durablement.pdf",
        "aviation": "TSP_AVIATION_RAPPORT_211116.pdf",
        "cartographie_transition": "tsp-_cartographie_de_la_transition-rapport_final.pdf",
        "rapport_avancement": "TSP-PTEF-V1-Rapport-dAvancement.pdf",
    }

    if dataset_name == "concat":
        concat_datasets()
        return

    if dataset_name not in pdf_by_dataset:
        # Previously an unknown name was a silent no-op; report it instead.
        valid_names = ", ".join([*pdf_by_dataset, "concat"])
        raise typer.BadParameter(f"Unknown dataset '{dataset_name}'. Expected one of: {valid_names}")

    if generate_questions:
        data_path = f"{base_path}/{pdf_by_dataset[dataset_name]}"
        generate_n_questions_from_dir(dataset_name, data_path, 200, model_name)
    else:
        ds = get_shift_dataset(
            f"data/shift_dataset/qa_pairs/qa_pairs_{dataset_name}.txt", max_len_dataset=MAX_LEN_DATASET_SHIFT
        )
        ds.push_to_hub(f"coldoc/temp_shiftproject_{dataset_name}_test")


# When executed as a script, hand main() to Typer to run as the command-line entry point.
if __name__ == "__main__":
    typer.run(main)
