import argparse
import glob
import hashlib
import json
import os
import random
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Tuple, cast

import litellm
from datasets import Dataset, DatasetDict, Features, Image, IterableDataset, Value
from dotenv import load_dotenv
from litellm.caching import Cache
from model_names import LLMS_NAME, VERTEX
from PIL import Image as PILImage
from tqdm import tqdm

from custom_colbert.utils.image_utils import get_base64_image
from custom_colbert.utils.pdf_utils import convert_all_pdfs_to_images
from scripts.generate_data.custom_llm_prices import CUSTOM_COSTS_PER_TOKEN
from scripts.generate_data.prompts import PROMPT_QUESTION_ANSWER

# Load variables from a `.env` file; `override=True` lets them take precedence over variables already set in the
# environment. This must run before the `os.getenv` calls below.
load_dotenv(override=True)

# Seed the global `random` module, which is used later to sample the page images to process.
random.seed(42)

# Number of retries passed to every `litellm.completion` call.
LITELLM_NUM_RETRIES = 3

# On-disk cache for LLM completions. The directory is created at import time if it does not exist.
LITELLM_CACHEDIR = Path("data/litellm_cache/")
LITELLM_CACHEDIR.mkdir(parents=True, exist_ok=True)
cache = Cache(type="disk", disk_cache_dir=LITELLM_CACHEDIR)

# Vertex AI project/location, read from the environment (possibly populated by `load_dotenv` above).
litellm.vertex_project = os.getenv("VERTEX_PROJECT")
litellm.vertex_location = os.getenv("VERTEX_LOCATION")

# NOTE(review): this check runs at import time, so both variables are required even when the script is run
# without `--vertex_ai`.
if litellm.vertex_project is None or litellm.vertex_location is None:
    raise ValueError("Please set the VERTEX_PROJECT and VERTEX_LOCATION environment variables.")

print(f"Vertex project: {litellm.vertex_project}")
print(f"Vertex location: {litellm.vertex_location}")

# Ask litellm to drop request parameters that the selected provider does not support instead of failing
# (per the litellm documentation of `drop_params`).
litellm.drop_params = True

# NOTE: Uncomment the line below to see the logs of the LLM calls.
# litellm.set_verbose = True


@dataclass
class CompletionData:
    """Result of one LLM completion as stored in (and restored from) the completion cache."""

    # Raw text content returned by the model.
    answer: str
    # Cost of the call, as reported by the cost computation at generation time.
    cost: float


def shorten_image_path(image_path: str) -> str:
    """
    Return `image_path` relative to its grandparent directory, as a POSIX string.

    For a typical `<folder>/<document>/<page>.jpg` path this keeps only `<document>/<page>.jpg`,
    which is much easier to read in logs and progress bars.
    """
    path = Path(image_path)
    grandparent_dir = path.parent.parent
    shortened = path.relative_to(grandparent_dir)
    return shortened.as_posix()


def _completion_cache_key(image_path: str, model: str, max_tokens: int, prompt: str, seed: int) -> str:
    """
    Build the disk-cache key for the completion of one page image.

    The image is identified by its path, to avoid having the entire base64 image in the cache key. Every
    generation setting (model, max_tokens, seed and a digest of the prompt) is part of the key. Changing any
    of them therefore triggers a new completion instead of silently reusing one made with other settings.

    NOTE: entries cached by earlier versions of this script under the bare image path are not reused.
    """
    prompt_digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()[:16]
    return f"{image_path}|model={model}|max_tokens={max_tokens}|seed={seed}|prompt={prompt_digest}"


def generate_qa_from_page(
    image_path: str, model: str, max_tokens: int, prompt: str = PROMPT_QUESTION_ANSWER, seed: int = 42
) -> Tuple[str, float]:
    """
    Generate questions and answers from a page image with an LLM, caching the result on disk.

    Args:
    - image_path (str): path to the page image
    - model (str): litellm model identifier
    - max_tokens (int): maximum number of tokens of the completion
    - prompt (str): text prompt sent together with the image
    - seed (int): seed forwarded to the completion call

    Returns:
    - (answer, cost): the raw text content of the completion and its cost as computed by
      `litellm.completion_cost` (or as stored in the cache)

    Raises:
    - ValueError: if a cached entry cannot be turned into a `CompletionData`
    """
    cache_key = _completion_cache_key(image_path, model=model, max_tokens=max_tokens, prompt=prompt, seed=seed)

    # If the completion is already in the cache, return the answer and the cost.
    cached_data = cast(dict | None, cache.get_cache(cache_key=cache_key))

    if cached_data is not None:
        print(f"\nLoading cached completion for {shorten_image_path(image_path)}...")
        try:
            cached_completion_data = CompletionData(**cached_data)
        except Exception as e:
            raise ValueError(f"Error loading cached completion data: {e}") from e
        return cached_completion_data.answer, cached_completion_data.cost

    # Otherwise, generate the question and answer from the image.
    # The image file is closed as soon as it has been encoded.
    with PILImage.open(image_path) as image:
        image_url = get_base64_image(image, add_url_prefix=True)

    # Custom per-token prices for models listed in CUSTOM_COSTS_PER_TOKEN; None otherwise.
    custom_costs = CUSTOM_COSTS_PER_TOKEN[model] if model in CUSTOM_COSTS_PER_TOKEN else None
    input_cost_per_token = custom_costs["input_cost_per_token"] if custom_costs is not None else None
    output_cost_per_token = custom_costs["output_cost_per_token"] if custom_costs is not None else None

    response = litellm.completion(
        model=model,
        max_tokens=max_tokens,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": image_url},
                    },
                ],
            }
        ],
        input_cost_per_token=input_cost_per_token,
        output_cost_per_token=output_cost_per_token,
        stream=False,
        num_retries=LITELLM_NUM_RETRIES,
        seed=seed,
    )

    answer = cast(str, response["choices"][0]["message"]["content"])

    cost = litellm.completion_cost(response, custom_cost_per_token=custom_costs)

    # Cache the answer and the cost.
    print("\nCaching completion...\n")
    current_completion = CompletionData(answer=answer, cost=cost)
    # NOTE: `add_cache` is given a JSON string; the cached data comes back already deserialized (as a dict)
    # from `get_cache`, which is why it is unpacked directly into `CompletionData` above.
    cache.add_cache(cache_key=cache_key, result=json.dumps(asdict(current_completion)))

    return answer, cost


def _sample_image_files(path_to_folder: str, n_samples: int) -> list[str]:
    """
    Sample up to `n_samples` `.jpg` page images from the sub-directories of `path_to_folder`.

    The sampled paths are exported, one per line, to `data/sampled_files/<folder name>.txt` so that the same
    sample can be reloaded later.
    """
    # NOTE: without `recursive=True`, `**` behaves like `*`, so this matches `.jpg` files exactly one
    # sub-directory below `path_to_folder`.
    # The glob order depends on the filesystem; sorting it makes the seeded `random.sample` reproducible.
    img_files = sorted(glob.glob(os.path.join(path_to_folder, "**/*.jpg")))
    print(f"Number of images found: {len(img_files)}")

    if n_samples > len(img_files):
        print(f"Requested {n_samples} samples but only {len(img_files)} images were found: using all of them.")
        n_samples = len(img_files)

    sampled_files = random.sample(img_files, n_samples)

    # Build the file name with an f-string rather than `with_suffix`, which would truncate folder names
    # containing a dot.
    export_filepath = Path("data") / "sampled_files" / f"{Path(path_to_folder).name}.txt"
    export_filepath.parent.mkdir(parents=True, exist_ok=True)
    export_filepath.write_text("".join(f"{item}\n" for item in sampled_files), encoding="utf-8")

    # NOTE: to reuse a previously exported sample instead of drawing a new one, replace the sampling above with:
    # sampled_files = export_filepath.read_text(encoding="utf-8").splitlines()

    print(f"Number of images sampled: {len(sampled_files)}")
    return sampled_files


def _rows_for_page(image_path: str, raw_answer: str, model_name: str) -> list[dict]:
    """
    Turn the raw JSON completion generated for one page into dataset rows (one row per question/answer pair).

    Raises on malformed completions (invalid JSON, missing keys) or unreadable images. All rows are built before
    any of them is returned, so a failing page never contributes a partial set of rows.
    """
    qa_pairs = json.loads(raw_answer)["questions"]
    page = os.path.basename(image_path).split(".")[0].split("_")[-1]

    # Open the image once per page and read its pixel data, so the file handle is released while the image
    # object stays usable after the `with` block.
    with PILImage.open(image_path) as pil_image:
        pil_image.load()

    return [
        {
            "query": qa["question"],
            "image": pil_image,
            "image_filename": image_path,
            "answer": qa["answer"],
            "page": page,
            "model": model_name,
            "prompt": PROMPT_QUESTION_ANSWER,
            "source": "pdf",
        }
        for qa in qa_pairs
    ]


def generate_dataset_from_img_folder(
    path_to_folder: str, n_samples: int = 1, model_name: str = "haiku", vertex_ai: bool = True
) -> Dataset:
    """
    Generate questions and answers from a sample of the page images found in `path_to_folder`.

    Args:
    - path_to_folder (str): path to the folder whose sub-directories contain the `.jpg` page images
    - n_samples (int): total number of images to sample (capped at the number of images found)
    - model_name (str): key of the model in `VERTEX` (if `vertex_ai`) or in `LLMS_NAME`
    - vertex_ai (bool): whether to use the Vertex AI model names

    Returns:
    - ds (Dataset): a dataset containing the questions and answers generated from the images

    structure of the dataset:
    - query (str): the question generated from the image
    - image (PIL.Image): the image
    - image_filename (str): the path to the image
    - answer (str): the answer to the question
    - page (str): the page number
    - model (str): the model used to generate the question
    - prompt (str): the prompt used to generate the question
    - source (str): the source of the image

    Pages whose completion cannot be parsed are skipped and listed in `data/failed_llm_calls.txt`.
    NOTE: their completions stay in the litellm disk cache, so re-running reuses them as-is.
    """
    sampled_files = _sample_image_files(path_to_folder, n_samples)

    model = VERTEX[model_name] if vertex_ai else LLMS_NAME[model_name]

    features = Features(
        {
            "query": Value("string"),
            "image": Image(),
            "image_filename": Value("string"),
            "answer": Value("string"),
            "page": Value("string"),
            "model": Value("string"),
            "prompt": Value("string"),
            "source": Value("string"),
        }
    )

    costs_list: list[float] = []

    # File listing the pages whose completion could not be turned into rows; it is (re)created when the
    # generator runs.
    failed_calls_path = Path("data/failed_llm_calls.txt")
    failed_calls_path.parent.mkdir(parents=True, exist_ok=True)

    def gen():
        with failed_calls_path.open("w") as f_failed, tqdm(total=len(sampled_files)) as pbar:
            for image_path in sampled_files:
                pbar.set_description(f"Processing {shorten_image_path(image_path)}")
                answer, cost = generate_qa_from_page(image_path, model=model, max_tokens=1000)
                costs_list.append(cost)

                try:
                    rows = _rows_for_page(image_path, answer, model_name)
                except Exception as e:
                    print(f"Error processing {image_path}: {e}")
                    f_failed.write(f"{image_path}\n")
                else:
                    yield from rows

                pbar.update(1)

    # `Dataset.from_generator` returns a (map-style) `Dataset`, not an `IterableDataset`.
    ds = Dataset.from_generator(gen, features=features)

    # `costs_list` stays empty if `datasets` served the result from its own cache without running `gen`.
    if costs_list:
        print(f"Average cost: {sum(costs_list) / len(costs_list)}")
        print(f"Total cost: {sum(costs_list)}")

    return ds


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate questions and answers from pdf images.")
    parser.add_argument(
        "--pdf_folder",
        type=str,
        help="Path to the folder containing the pdf files.",
        default="data/test/pages_extracted/healthcare_industry_test",
    )
    parser.add_argument("--convert_to_images", action="store_true", help="Convert the pdf files to images.")
    parser.add_argument("--n_samples", type=int, help="Number of pdf files to sample in each subfolder.", default=30)
    parser.add_argument(
        "--model_name",
        type=str,
        help="Name of the model to use for generating questions and answers.",
        default="sonnet",
    )
    parser.add_argument("--split_name", type=str, help="Name of the split to save the dataset.")
    parser.add_argument("--hub_dataset_name", type=str, help="Name of the dataset to push to the hub.")
    parser.add_argument("--vertex_ai", action="store_true", help="Use Vertex AI to generate questions and answers.")

    args = parser.parse_args()

    # Fail fast: without these two arguments, `push_to_hub` would only fail after the whole dataset has been
    # generated, i.e. after all the (paid) LLM calls have been made.
    if not args.convert_to_images and (args.split_name is None or args.hub_dataset_name is None):
        parser.error("--split_name and --hub_dataset_name are required to generate and push the dataset.")

    start = time.time()

    if args.convert_to_images:
        print("Converting pdf files to images...")
        convert_all_pdfs_to_images(args.pdf_folder, n_samples=args.n_samples)
    else:
        print("Generating questions and answers from pdf images...")
        ds = generate_dataset_from_img_folder(
            args.pdf_folder, n_samples=args.n_samples, model_name=args.model_name, vertex_ai=args.vertex_ai
        )

        ds_dict = DatasetDict({args.split_name: ds})
        ds_dict.push_to_hub(args.hub_dataset_name, private=True)

    print(f"Time taken: {time.time() - start}")
