import datetime
import os
import time
import uuid
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Annotated, Optional, cast

import litellm
import typer
from dataset_structure import BaseChunk, CaptionChunk
from datasets import Dataset, DatasetDict, Value, load_dataset
from datasets import Image as DatasetImage
from PIL import Image
from tqdm import tqdm
from unstructured.partition.image import partition_image

from custom_colbert.utils.image_utils import get_base64_image

# Presumably the number of retries intended for LiteLLM requests.
# NOTE(review): this constant is not referenced in the visible part of this file — confirm whether it
# should be forwarded to `litellm.completion`.
LITELLM_NUM_RETRIES = 3

# Configure LiteLLM's Vertex AI project/location from the environment.
# `os.getenv` returns None when the variable is unset, so either setting may end up as None.
litellm.vertex_project = os.getenv("VERTEX_PROJECT")
litellm.vertex_location = os.getenv("VERTEX_LOCATION")


@dataclass
class CompletionDataUnstructured:
    """
    Simple record pairing a completion answer with its cost.

    NOTE(review): this class is not referenced in the visible part of this file.
    """

    # Text of the answer (presumably the model's generated caption — confirm against callers).
    answer: str
    # Cost associated with the completion (presumably in USD, as reported by LiteLLM — confirm).
    cost: float


def shorten_image_path(image_path: str) -> str:
    """
    Return a shorter, more readable version of an image path.

    The path is expressed relative to its grandparent directory, which keeps at most the last
    two components (enclosing directory and filename). The result always uses forward slashes.
    """
    path = Path(image_path)
    return path.relative_to(path.parent.parent).as_posix()


# Prompt template asking the model for a qualitative summary (at most 3000 characters) of a table or figure.
# NOTE(review): nothing in the visible code fills in the `{image}` placeholder, and `main` passes its own
# near-identical inline prompt rather than this constant — confirm whether this constant is still needed.
GENERATE_CAPTION = """ 
        You are an assistant specialized in document analysis. Given a table or a figure,
        you have to provide a detailed summary of the content in maximum 3000 characters. Your summary should 
        be qualitative and not quantitative.
        Here is the table/figure to analyze: {image}. Answer ONLY with the caption of the table/figure.
"""


def generate_caption(
    image_path: str, model: str, prompt: str, seed: Optional[int] = None, max_tokens: Optional[int] = 1048
) -> tuple[str, float]:
    """
    Generate a caption for an image using a model and a prompt.

    Args:
        image_path: Path of the image file to caption.
        model: LiteLLM model identifier used for the completion.
        prompt: Text instruction sent to the model alongside the image.
        seed: Optional seed forwarded to the completion call.
        max_tokens: Maximum number of tokens requested for the answer.

    Returns:
        An `(answer, cost)` tuple: the message content of the first choice and the cost of the
        request. The cost is reported as 0.0 when the model name contains "llava".

    Raises:
        Exceptions raised while opening the image or calling the model are not caught here and
        propagate to the caller.
    """
    print(f"Generating caption for {image_path}")

    # Encode the image while the file is open, then release the file handle right away instead of
    # leaving it open until garbage collection (this function is called once per table/figure).
    with Image.open(image_path) as image:
        image_url = get_base64_image(image, add_url_prefix=True)

    # ------------- use litellm to generate the caption -------------
    response = litellm.completion(
        model=model,
        max_tokens=max_tokens,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": image_url},
                    },
                ],
            }
        ],
        stream=False,
        seed=seed,
    )

    # Get answers
    answer = response["choices"][0]["message"]["content"]

    # Get cost
    if "llava" in model:
        cost = 0.0  # no cost
    else:
        cost = litellm.completion_cost(response)

    return answer, cost


def _append_chunk_row(
    columns: dict,
    row: dict,
    *,
    chunk_type: str,
    text_description: str,
    chunk_image: Optional[str],
    model: str,
    prompt: str,
) -> None:
    """
    Append one chunk to the column-oriented dataset dict `columns`.

    Every column is extended here, in one place, so that all columns always keep the same length.
    The chunk gets a fresh UUID as `chunk_id`; the columns of the source `row` are copied as-is,
    except `prompt` and `model` which are overridden with the values used for captioning.
    """
    columns["chunk_id"].append(str(uuid.uuid4()))
    columns["chunk_type"].append(chunk_type)
    columns["text_description"].append(text_description)
    columns["chunk_image"].append(chunk_image)

    for key, value in row.items():
        if key == "prompt":
            columns[key].append(prompt)
        elif key == "model":
            columns[key].append(model)
        else:
            columns[key].append(value)


def generate_chunk_dataset(dataset: Dataset, model: str, prompt: str, output_dir: str):
    """
    Split every image of `dataset` into chunks, caption the tables/figures, and push the result to the Hub.

    Each row's image is partitioned with `partition_image`. Text chunks keep their text, while
    tables (with a detection probability above 0.5) and extracted figures are captioned with
    `generate_caption`. Every chunk becomes one row of the new dataset: the original row's columns
    plus `text_description`, `chunk_id`, `chunk_type` and `chunk_image`. Captioning failures are
    printed and the corresponding chunk is skipped.

    Args:
        dataset: Source dataset; each row must provide the `image_filename` and `image` columns.
        model: LiteLLM model identifier used for captioning.
        prompt: Prompt sent to the model with each table/figure image.
        output_dir: Name used both for the local extraction directory
            (`data/unstructured/<output_dir>/<uuid>`) and for the Hub repository
            (`coldoc/baseline_cap_<output_dir>`).

    Raises:
        ValueError: If a table chunk has no detection class probability.
    """
    print(f"Generating captions for {len(dataset)} images")

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    start = time.time()
    costs_list = []

    count_tables = 0
    count_figures = 0
    count_text_chunks = 0

    # Create a new dataset with the same features as the original dataset
    new_dataset = {key: [] for key in dataset.features.keys()}

    # Add new features to the dataset
    new_dataset.update({"text_description": [], "chunk_id": [], "chunk_type": [], "chunk_image": []})

    for row in tqdm(dataset, desc="Iterating over dataset rows"):
        image_path_or_url = row["image_filename"]  # FIXME: poor column naming -> it can be either a path OR a url

        # NOTE: We use uuid to avoid having filenames that are too long. This is acceptable since we will create
        # the new dataset from `row["image_filename"]`.
        image_dir = Path("data/unstructured") / output_dir / str(uuid.uuid4())

        # Save the image to a BytesIO object to be able to pass it to the partition_image function
        image_bytes_io = BytesIO()
        row["image"].save(image_bytes_io, format="PNG")
        image_bytes_io.seek(0)

        chunks = partition_image(
            file=cast(bytes, image_bytes_io),
            strategy="hi_res",
            hi_res_model_name="yolox",
            infer_table_structure=True,
            extract_images_in_pdf=True,
            extract_image_block_types=["Table", "Image"],
            extract_image_block_output_dir=image_dir.as_posix(),
            chunking_strategy="by_title",
            max_characters=4000,
            new_after_n_chars=3800,
        )

        # Due to chunking strategy, there are only tables and text chunks
        for chunk in tqdm(chunks, desc="Iterating over chunks"):
            chunk_class = str(type(chunk))

            if "Table" in chunk_class:
                if chunk.metadata.detection_class_prob is None:
                    raise ValueError("Detection class probability is None")

                # Low-confidence table detections are dropped
                if not (chunk.metadata.detection_class_prob > 0.5):
                    continue

                try:
                    caption, cost = generate_caption(cast(str, chunk.metadata.image_path), model, prompt)
                    table_chunk = CaptionChunk(
                        filename=image_path_or_url,
                        type="table",
                        page=chunk.metadata.page_number,
                        timestamp=timestamp,
                        bbox=chunk.metadata.coordinates.points if chunk.metadata.coordinates is not None else None,
                        imagepath=chunk.metadata.image_path,
                        caption=caption,
                        model=model,
                        prompt=prompt,
                    ).model_dump()
                except Exception as e:
                    # Best effort: a failed caption must not abort the whole run
                    print(f"Error generating caption for {chunk.metadata.image_path}: {e}. Skipping.")
                    continue

                count_tables += 1
                costs_list.append(cost)
                _append_chunk_row(
                    new_dataset,
                    row,
                    chunk_type=table_chunk["type"],
                    text_description=table_chunk["caption"],
                    chunk_image=chunk.metadata.image_path,
                    model=model,
                    prompt=prompt,
                )

            elif "CompositeElement" in chunk_class:  # text-only chunk
                text_chunk = BaseChunk(
                    filename=image_path_or_url,
                    type="text",
                    page=chunk.metadata.page_number,
                    timestamp=timestamp,
                    text=chunk.text,
                ).model_dump()
                count_text_chunks += 1
                _append_chunk_row(
                    new_dataset,
                    row,
                    chunk_type=text_chunk["type"],
                    text_description=text_chunk["text"],
                    chunk_image=None,  # None for text chunks
                    model=model,
                    prompt=prompt,
                )

            else:
                # NOTE: This should not happen because of `extract_image_block_types=["Table", "Image"]`
                # but it does with `unstructured==0.13.2`
                print(f"Skipping chunk of type `{type(chunk)}`")

        # for images we need to get the extracted images from unstructured.partition.image
        figures = [str(f) for f in image_dir.rglob("*.jpg") if "figure" in f.name]
        for figure in figures:
            # Only the fallible captioning step is guarded: the dataset columns are extended afterwards,
            # so an error can never leave them with mismatched lengths.
            try:
                caption, cost = generate_caption(figure, model, prompt)
                figure_chunk = CaptionChunk(
                    filename=image_path_or_url,
                    type="figure",
                    page=0,
                    timestamp=timestamp,
                    bbox=None,
                    imagepath=figure,
                    caption=caption,
                    model=model,
                    prompt=prompt,
                ).model_dump()
            except Exception as e:
                print(f"Error generating caption for {figure}: {e}. Skipping.")
                continue

            count_figures += 1
            costs_list.append(cost)
            _append_chunk_row(
                new_dataset,
                row,
                chunk_type=figure_chunk["type"],
                text_description=figure_chunk["caption"],
                chunk_image=figure,
                model=model,
                prompt=prompt,
            )

        # TODO: Remove after debug
        assert all(
            len(new_dataset[key]) == len(new_dataset["chunk_id"]) for key in new_dataset.keys()
        ), "Length mismatch: " + str({key: len(new_dataset[key]) for key in new_dataset.keys()})

    # `chunk_id` always exists, unlike dataset-specific columns such as `query`
    num_chunks = len(new_dataset["chunk_id"])

    print(
        f"The generated dataset contains {count_tables} tables, {count_figures} figures, and {count_text_chunks} text chunks."
    )
    if costs_list:
        print(f"Average cost per request: {sum(costs_list) / len(costs_list):.3f}$")
        print(f"Total cost: {sum(costs_list):.3f}$")
    if num_chunks:
        elapsed = time.time() - start
        print(
            f"Time taken: {elapsed:.2f}s for {num_chunks} chunks.\n \
            Average time per chunk: {elapsed / num_chunks:.2f}s"
        )

    # Convert paths to 'Image' feature type
    features = dataset.features.copy()
    # add new features
    features.update(
        {
            "text_description": Value("string"),
            "chunk_id": Value("string"),
            "chunk_type": Value("string"),
            "chunk_image": DatasetImage(),
        }
    )

    dataset_dict = DatasetDict({"test": Dataset.from_dict(new_dataset, features=features)})

    # Single source of truth for the repository name, so the log message matches the actual destination
    repo_id = f"coldoc/baseline_cap_{output_dir}"
    print(f"Saving the dataset to huggingface in the {repo_id} dataset")
    dataset_dict["test"].push_to_hub(repo_id, split="test", private=True)


def main(dataset_name: Annotated[str, typer.Argument(help="The name of the dataset to caption")]):
    """
    Caption the `test` split of the given dataset and publish the resulting chunk dataset.

    The dataset is loaded with `load_dataset`, and its `test` split is handed to
    `generate_chunk_dataset` together with a fixed model and prompt. The last `/`-separated
    component of `dataset_name` is used as the output directory name.
    """
    captioning_model = "vertex_ai/claude-3-sonnet@20240229"
    # NOTE: the whitespace inside this literal is kept verbatim because the exact string is what gets passed on.
    # NOTE(review): `{image}` is passed through literally; nothing in the visible code substitutes it.
    captioning_prompt = """ You are an assistant specialized in document analysis. Given a table or a figure,
                            you have to provide a detailed summary of the content in maximum 3000 characters. Your summary should be qualitative and not quantitative.
                            Here is the table/figure to analyze: {image}. Answer ONLY with the caption of the table/figure."""
    short_name = dataset_name.rsplit("/", 1)[-1]

    splits = cast(DatasetDict, load_dataset(dataset_name))
    generate_chunk_dataset(splits["test"], captioning_model, captioning_prompt, output_dir=short_name)

    print("Done!")


# Script entry point: let Typer build the command-line interface from `main`'s signature and run it.
if __name__ == "__main__":
    typer.run(main)
