from pathlib import Path
from typing import Annotated, cast

import torch
import typer
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from safetensors.torch import save_file
from tqdm import trange
import time
import json

from custom_colbert.interpretability.processor import ColPaliProcessor
from custom_colbert.models.paligemma_colbert_architecture import ColPali
from custom_colbert.utils.torch_utils import get_torch_device

# Read variables from a local `.env` file (presumably credentials such as a Hugging Face
# token — verify); `override=True` lets the file win over values already in the environment.
load_dotenv(override=True)

# Pick the accelerator once at import time; `main` places the model and inputs on it.
device = get_torch_device()


# Destination for the per-batch embedding files and the `times.json` timing dump.
OUTDIR_MEASURE_LATENCY = Path("outputs") / "measure_latency"


def _synchronize(target_device: "torch.device | str") -> None:
    """Block until all work queued on ``target_device`` has finished.

    CUDA and MPS run kernels asynchronously, so a wall-clock timer around the forward
    pass would otherwise only measure kernel *launch*; the real compute time would be
    charged to whatever later operation forces a device-to-host copy (here, writing the
    embeddings to disk). No-op for CPU and any other device type.
    """
    device_type = torch.device(target_device).type
    if device_type == "cuda":
        torch.cuda.synchronize(target_device)
    elif device_type == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()


def main(
    n_iter: Annotated[int, typer.Option(help="Number of iterations")],
    dataset: Annotated[str, typer.Option(help="Dataset name")] = "coldoc/shiftproject_test",
    batch_size: Annotated[int, typer.Option(help="Batch size")] = 4,
):
    """Measure ColPali image-encoding and embedding-storage latency.

    Loads the PaliGemma base model plus the ``colpali`` LoRA adapter. It then processes
    ``n_iter`` consecutive, non-overlapping batches of ``batch_size`` examples from the
    ``test`` split of ``dataset``. For each batch it times two things separately:

    - the forward pass over the images;
    - writing the resulting embeddings to a safetensors file in ``OUTDIR_MEASURE_LATENCY``.

    Raw per-batch timings are written to ``times.json`` in the same directory, and
    per-image averages are printed.

    Args:
        n_iter: Number of batches to time (must be >= 1).
        dataset: Dataset id passed to ``load_dataset``; its ``test`` split must provide
            ``image`` and ``query`` columns.
        batch_size: Number of examples per forward pass (must be >= 1).

    Raises:
        ValueError: If ``n_iter`` or ``batch_size`` is < 1. Also raised if the adapter
            did not load as the only active adapter, or if a batch does not contain
            exactly ``batch_size`` examples (e.g. the dataset ran out).
    """
    # Fail fast, before the slow model load: with zero iterations the averages at the end
    # would divide by zero.
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model_path = "google/paligemma-3b-mix-448"
    lora_path = "coldoc/paligemma-3b-mix-448"

    # Load the base model.
    model = cast(ColPali, ColPali.from_pretrained(model_path, device_map=device))

    # Load the LoRA adapter into the model.
    # Note: `add_adapter` creates a new adapter, while `load_adapter` loads an existing one.
    model.load_adapter(lora_path, adapter_name="colpali", device_map=device)
    if model.active_adapters() != ["colpali"]:
        raise ValueError(f"Incorrect adapters loaded: {model.active_adapters()}")
    print(f"Loaded model from {model_path} and LORA from {lora_path}")

    processor = ColPaliProcessor.from_pretrained(model_path)
    print("Loaded custom processor")

    ds = cast(Dataset, load_dataset(dataset, split="test"))
    print("Dataset loaded")

    # Create the output directory once, up front. Previously it was created only inside
    # the loop, so `times.json` could not be written when no batch had run.
    OUTDIR_MEASURE_LATENCY.mkdir(parents=True, exist_ok=True)
    image_seq_length = processor.processor.image_seq_length
    batched_keys = ("input_ids", "pixel_values", "attention_mask")

    image_encoding_times: list[float] = []
    vector_store_times: list[float] = []

    for idx in trange(n_iter):
        # Take the head of the remaining dataset; it is skipped at the end of the iteration.
        batch = ds.take(batch_size)

        # NOTE: the text is not used, but needed for the processor to return a batched output
        inputs = processor.processor(
            images=batch["image"], text=batch["query"], return_tensors="pt", padding="longest"
        ).to(device)

        # Keep only the first `image_seq_length` positions so the query text is dropped
        # (assumes the image tokens precede the text tokens — confirm).
        # NOTE(review): slicing `pixel_values` on dim 1 looks like a no-op if it is an
        # image tensor (batch, channels, H, W); kept to preserve the original behaviour.
        for key in batched_keys:
            inputs[key] = inputs[key][:, :image_seq_length]

        actual_sizes = [inputs[key].shape[0] for key in batched_keys]
        if any(size != batch_size for size in actual_sizes):
            raise ValueError(f"Batch size mismatch: expected {batch_size}, got {actual_sizes}")

        # Forward pass. Synchronize on both sides so the timer covers exactly the
        # device-side compute of this batch.
        _synchronize(device)
        start = time.perf_counter()
        with torch.no_grad():
            output_images = model(**inputs)  # (batch_size, n_patch_x * n_patch_y, hidden_dim)
        _synchronize(device)
        image_encoding_times.append(time.perf_counter() - start)

        # Save the embeddings as safetensors. Only the write itself is timed;
        # logging happens outside the timed region.
        savepath = OUTDIR_MEASURE_LATENCY / f"doc_embedding_colpali_{idx}.safetensors"
        start = time.perf_counter()
        save_file({"output_images": output_images}, filename=savepath)
        vector_store_times.append(time.perf_counter() - start)
        print(f"Embeddings saved to `{savepath}`")

        ds = ds.skip(batch_size)

    times = {
        "image_encoding_times": image_encoding_times,
        "vector_store_times": vector_store_times,
    }

    with open(OUTDIR_MEASURE_LATENCY / "times.json", "w", encoding="utf-8") as f:
        json.dump(times, f)

    # Per-image averages: each timing covers a whole batch of `batch_size` images.
    n_images = len(image_encoding_times) * batch_size
    print(f"Average image encoding time: {sum(image_encoding_times) / n_images}")
    print(f"Average vector store time: {sum(vector_store_times) / n_images}")
# CLI entry point: typer derives the command-line options from `main`'s annotated parameters.
if __name__ == "__main__":
    typer.run(main)
