
import datetime
import os
import time
import uuid
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Annotated, Optional, cast
import re
import io
import contextlib
import time
import json

import litellm
import typer
from dataset_structure import BaseChunk, CaptionChunk
from datasets import Dataset, DatasetDict, Value, load_dataset
from datasets import Image as DatasetImage
from PIL import Image
from tqdm import tqdm
from typing import Dict
from unstructured.partition.image import partition_image

from custom_colbert.utils.image_utils import get_base64_image

# Intended retry count for litellm completion requests; confirm call sites actually pass it
# (e.g. as `num_retries=`).
LITELLM_NUM_RETRIES = 3

# Vertex AI project/location for litellm, read from the environment (os.getenv yields None when
# unset). Presumably required by the `vertex_ai/...` model used in `main` — verify for other models.
litellm.vertex_project = os.getenv("VERTEX_PROJECT")
litellm.vertex_location = os.getenv("VERTEX_LOCATION")


@dataclass
class CompletionDataUnstructured:
    """Container pairing a model completion with the cost of obtaining it.

    Attributes:
        answer: Text returned by the model.
        cost: Cost associated with the request.
    """

    answer: str  # the model's reply text
    cost: float  # cost of the request that produced `answer`


def shorten_image_path(image_path: str) -> str:
    """Return *image_path* relative to its grandparent directory, using POSIX separators.

    Only the innermost directory and the file name are kept, e.g.
    ``"data/run/figures/x.jpg"`` becomes ``"figures/x.jpg"``.
    """
    as_path = Path(image_path)
    return as_path.relative_to(as_path.parent.parent).as_posix()


# Prompt template asking the model for a qualitative caption of a table/figure.
# NOTE(review): this constant is not referenced in the visible code (`main` passes its own
# inline prompt), and the `{image}` placeholder is never `.format()`-ed, so it would reach the
# model literally; the image itself is sent separately as an `image_url` message part.
GENERATE_CAPTION = """ 
        You are an assistant specialized in document analysis. Given a table or a figure,
        you have to provide a detailed summary of the content in maximum 3000 characters. Your summary should 
        be qualitative and not quantitative.
        Here is the table/figure to analyze: {image}. Answer ONLY with the caption of the table/figure.
"""

def get_elapsed_time_from_captured_output(captured_output: str) -> Dict[str, float]:
    """Parse ``<task>: Elapsed time: <seconds> seconds`` lines from captured stdout.

    Args:
        captured_output: Text captured from stdout, possibly mixed with unrelated lines.

    Returns:
        A mapping from each whitespace-stripped task name to its elapsed time in seconds.
        Only decimal values such as ``1.50`` are recognized. If a task name appears several
        times, the last occurrence wins.
    """
    timing_line = re.compile(r"(?P<task>.*): Elapsed time: (?P<time>\d+\.\d+) seconds")

    parsed: Dict[str, float] = {}
    for hit in timing_line.finditer(captured_output):
        parsed[hit.group("task").strip()] = float(hit.group("time"))
    return parsed


def generate_caption(
    image_path: str, model: str, prompt: str, seed: Optional[int] = None, max_tokens: Optional[int] = 1048
) -> tuple[str, float]:
    """
    Generate a caption for an image using a model and a prompt.

    The image is base64-encoded and sent together with ``prompt`` as a single user message
    through ``litellm.completion``.

    Args:
        image_path: Path of the image file to caption.
        model: litellm model identifier.
        prompt: Instruction text sent alongside the image.
        seed: Optional sampling seed forwarded to the model.
        max_tokens: Maximum number of tokens for the completion.

    Returns:
        ``(answer, cost)``: the message content of the first choice, and the request cost
        from ``litellm.completion_cost`` (``0.0`` for models whose name contains "llava").

    Raises:
        Any exception raised while opening the image or by litellm is propagated
        (after litellm's own retries).
    """
    print(f"Generating caption for {image_path}")

    # Encode inside a context manager so the file handle is closed once the image is encoded.
    with Image.open(image_path) as image:
        image_url = get_base64_image(image, add_url_prefix=True)

    # ------------- use litellm to generate the caption -------------
    response = litellm.completion(
        model=model,
        max_tokens=max_tokens,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": image_url},
                    },
                ],
            }
        ],
        stream=False,
        seed=seed,
        # Retry transient API failures using the module-level retry budget.
        num_retries=LITELLM_NUM_RETRIES,
    )

    # Get answers
    answer = response["choices"][0]["message"]["content"]

    # Get cost
    if "llava" in model:
        cost = 0.0  # no cost
    else:
        cost = litellm.completion_cost(response)

    return answer, cost


def _partition_image_with_timings(image, image_dir: Path):
    """Partition one image with unstructured's hi-res pipeline and parse its timing output.

    Args:
        image: Image object exposing ``save(fp, format=...)`` (presumably a PIL image from the
            dataset's ``image`` column — confirm).
        image_dir: Directory where unstructured writes extracted table/figure crops.

    Returns:
        ``(chunks, elapsed_time)``: the elements returned by ``partition_image``, and a mapping
        of task name to seconds parsed from the stdout printed during partitioning.
    """
    # partition_image takes a file-like object, so serialize the image to an in-memory PNG.
    image_bytes_io = BytesIO()
    image.save(image_bytes_io, format="PNG")
    image_bytes_io.seek(0)

    # Timing lines are printed to stdout while partitioning; capture them for parsing.
    output_capture = io.StringIO()
    with contextlib.redirect_stdout(output_capture):
        chunks = partition_image(
            file=cast(bytes, image_bytes_io),
            strategy="hi_res",
            hi_res_model_name="yolox",
            infer_table_structure=True,
            extract_images_in_pdf=True,
            extract_image_block_types=["Table", "Image"],
            extract_image_block_output_dir=image_dir.as_posix(),
            chunking_strategy="by_title",
            max_characters=4000,
            new_after_n_chars=3800,
        )

    return chunks, get_elapsed_time_from_captured_output(output_capture.getvalue())


def _append_chunk_row(
    new_dataset: Dict[str, list],
    row: dict,
    *,
    chunk_type: str,
    text_description: Optional[str],
    chunk_image: Optional[str],
    model: str,
    prompt: str,
) -> None:
    """Append one chunk to ``new_dataset`` as a new row.

    The chunk columns get a fresh UUID4 ``chunk_id`` plus the given type, description and
    image path. Every column of the source ``row`` is copied over, except ``prompt`` and
    ``model``, which are overwritten with the values used for this run.
    """
    new_dataset["chunk_id"].append(str(uuid.uuid4()))
    new_dataset["chunk_type"].append(chunk_type)
    new_dataset["text_description"].append(text_description)
    new_dataset["chunk_image"].append(chunk_image)
    for key in row.keys():
        if key == "prompt":
            new_dataset[key].append(prompt)
        elif key == "model":
            new_dataset[key].append(model)
        else:
            new_dataset[key].append(row[key])


def generate_chunk_dataset(
    dataset: Dataset,
    model: str,
    prompt: str,
    output_dir: str,
    max_files: Optional[int] = 100,
) -> Dict[str, list]:
    """Split each dataset image into text/table/figure chunks and caption the visual ones.

    For every processed row:
      * the row's image is partitioned with unstructured (hi-res, ``by_title`` chunking),
        and crops are written under ``data/unstructured/<output_dir>/<uuid4>``;
      * text chunks are kept as-is;
      * table chunks with detection probability > 0.5 are captioned with ``generate_caption``;
      * extracted ``*figure*.jpg`` crops are captioned with ``generate_caption``.
    Per-stage timings are written to
    ``data/evaluation_results/speed_<output_dir>_times.json``, and summary statistics are printed.

    Args:
        dataset: Source dataset. It must have ``image`` and ``image_filename`` columns.
        model: litellm model identifier used for captioning.
        prompt: Captioning prompt. It is also stored in any ``prompt`` column of the output.
        output_dir: Sub-directory name used for extracted crops and the timings file.
        max_files: Maximum number of rows to process (``None`` for all). The default of
            100 matches the previous hard-coded limit.

    Returns:
        Column-oriented dict holding the new chunk dataset: the source columns plus
        ``text_description``, ``chunk_id``, ``chunk_type`` and ``chunk_image``.

    Raises:
        ValueError: If a table chunk has no detection class probability.
    """
    import itertools  # local import: only needed to cap the number of processed rows

    n_rows = len(dataset) if max_files is None else min(len(dataset), max_files)
    print(f"Generating captions for {n_rows} images")

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    # Wall-clock start of the whole run; per-caption timings use their own variable.
    run_start = time.time()
    costs_list: list = []

    count_tables = 0
    count_figures = 0
    count_text_chunks = 0

    # Same columns as the source dataset, plus the per-chunk columns.
    new_dataset: Dict[str, list] = {key: [] for key in dataset.features.keys()}
    new_dataset.update({"text_description": [], "chunk_id": [], "chunk_type": [], "chunk_image": []})

    times: Dict[str, list] = {"Layout Detection": [], "OCR": [], "Saving chunks": [], "Generating Captions": []}

    for row in tqdm(itertools.islice(dataset, max_files), desc="Iterating over dataset rows", total=n_rows):
        image_path_or_url = row["image_filename"]  # FIXME: poor column naming -> it can be either a path OR a url

        # NOTE: We use uuid to avoid having filenames that are too long. This is acceptable since we will create
        # the new dataset from `row["image_filename"]`.
        image_dir = Path("data/unstructured") / output_dir / str(uuid.uuid4())

        chunks, elapsed_time = _partition_image_with_timings(row["image"], image_dir)
        # Record None when unstructured printed no timing line for a stage (instead of crashing).
        for stage in ("Layout Detection", "OCR", "Saving chunks"):
            times[stage].append(elapsed_time.get(stage))

        # Due to chunking strategy, there are only tables and text chunks
        for chunk in tqdm(chunks, desc="Iterating over chunks"):
            chunk_kind = str(type(chunk))
            if "Table" in chunk_kind:
                detection_prob = chunk.metadata.detection_class_prob
                if detection_prob is None:
                    raise ValueError("Detection class probability is None")
                if not detection_prob > 0.5:
                    continue  # skip low-confidence table detections

                try:
                    caption_start = time.time()
                    caption, cost = generate_caption(cast(str, chunk.metadata.image_path), model, prompt)
                    times["Generating Captions"].append(time.time() - caption_start)
                    base_chunk = CaptionChunk(
                        filename=image_path_or_url,
                        type="table",
                        page=chunk.metadata.page_number,
                        timestamp=timestamp,
                        bbox=chunk.metadata.coordinates.points if chunk.metadata.coordinates is not None else None,
                        imagepath=chunk.metadata.image_path,
                        caption=caption,
                        model=model,
                        prompt=prompt,
                    ).model_dump()
                except Exception as e:
                    # Best effort: one failed caption must not abort the whole run.
                    print(f"Error generating caption for {chunk.metadata.image_path}: {e}. Skipping.")
                    continue

                count_tables += 1
                costs_list.append(cost)
                text_description = base_chunk["caption"]
                chunk_image = chunk.metadata.image_path

            elif "CompositeElement" in chunk_kind:  # text-only chunk
                base_chunk = BaseChunk(
                    filename=image_path_or_url,
                    type="text",
                    page=chunk.metadata.page_number,
                    timestamp=timestamp,
                    text=chunk.text,
                ).model_dump()
                count_text_chunks += 1
                text_description = base_chunk["text"]
                chunk_image = None  # None for text chunks

            else:
                # NOTE: This should not happen because of `extract_image_block_types=["Table", "Image"]`
                # but it does with `unstructured==0.13.2`
                print(f"Skipping chunk of type `{type(chunk)}`")
                continue

            _append_chunk_row(
                new_dataset,
                row,
                chunk_type=base_chunk["type"],
                text_description=text_description,
                chunk_image=chunk_image,
                model=model,
                prompt=prompt,
            )

        # Figures are not returned as chunks; read the crops unstructured extracted to disk.
        figures = [str(f) for f in Path(image_dir).rglob("*.jpg") if "figure" in f.name]
        for figure in figures:
            try:
                caption_start = time.time()
                caption, cost = generate_caption(figure, model, prompt)
                times["Generating Captions"].append(time.time() - caption_start)

                base_chunk = CaptionChunk(
                    filename=image_path_or_url,
                    type="figure",
                    page=0,
                    timestamp=timestamp,
                    bbox=None,
                    imagepath=figure,
                    caption=caption,
                    model=model,
                    prompt=prompt,
                ).model_dump()
            except Exception as e:
                print(f"Error generating caption for {figure}: {e}. Skipping.")
                continue

            count_figures += 1
            costs_list.append(cost)
            _append_chunk_row(
                new_dataset,
                row,
                chunk_type=base_chunk["type"],
                text_description=base_chunk["caption"],
                chunk_image=figure,
                model=model,
                prompt=prompt,
            )

    # Save the per-stage timings, creating the results directory if needed.
    times_path = Path("data/evaluation_results") / f"speed_{output_dir}_times.json"
    times_path.parent.mkdir(parents=True, exist_ok=True)
    with times_path.open("w") as f:
        json.dump(times, f)

    n_chunks = len(new_dataset["chunk_id"])
    print(
        f"The generated dataset contains {count_tables} tables, {count_figures} figures, "
        f"and {count_text_chunks} text chunks."
    )
    if costs_list:
        print(f"Average cost per request: {sum(costs_list) / len(costs_list):.3f}$")
        print(f"Total cost: {sum(costs_list):.3f}$")
    if n_chunks:
        elapsed = time.time() - run_start
        print(f"Time taken: {elapsed:.2f}s for {n_chunks} chunks.")
        print(f"Average time per chunk: {elapsed / n_chunks:.2f}s")

    return new_dataset

def main(
    dataset_name: Annotated[str, typer.Argument(help="The name of the dataset to caption")],
    model: Annotated[
        str, typer.Option(help="litellm model identifier used to caption tables and figures.")
    ] = "vertex_ai/claude-3-sonnet@20240229",
    split: Annotated[str, typer.Option(help="Dataset split to chunk and caption.")] = "test",
):
    """Chunk and caption one split of a dataset loaded with `datasets.load_dataset`.

    Args:
        dataset_name: Name or path passed to `load_dataset`. Its last `/`-separated
            component is used as the output sub-directory name.
        model: litellm model used for captioning (default: Vertex AI Claude 3 Sonnet).
        split: Split of the loaded `DatasetDict` to process (default: "test").
    """
    ds = cast(DatasetDict, load_dataset(dataset_name))

    generate_chunk_dataset(
        ds[split],
        model,
        """ You are an assistant specialized in document analysis. Given a table or a figure,
                            you have to provide a detailed summary of the content in maximum 3000 characters. Your summary should be qualitative and not quantitative.
                            Here is the table/figure to analyze: {image}. Answer ONLY with the caption of the table/figure.""",
        output_dir=dataset_name.split("/")[-1],
    )

    print("Done!")


# Script entry point: expose `main` as a Typer CLI taking the dataset name as its argument.
if __name__ == "__main__":
    typer.run(main)
