import json
import time
from collections import defaultdict
from pathlib import Path
from typing import Annotated, Any, Dict, List, Tuple, cast
from dataclasses import asdict
import timeit

import huggingface_hub
import torch
import torch.nn.functional as F
import typer
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from FlagEmbedding import BGEM3FlagModel
from safetensors.torch import save_file
from torch import Tensor
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer, AutoProcessor

from custom_colbert.trainer.retrieval_evaluator import CustomEvaluator
from custom_colbert.models.paligemma_colbert_architecture import ColPali
from custom_colbert.interpretability.processor import ColPaliProcessor
from custom_colbert.interpretability.vit_configs import VIT_CONFIG
from peft import LoraConfig, PeftConfig

from custom_colbert.utils.iter_utils import batched
from custom_colbert.utils.torch_utils import get_torch_device
from PIL import Image   

# Number of passages encoded per call when computing passage embeddings.
BATCH_SIZE = 4
# Output directory for latency measurements (only referenced by commented-out code in this file).
OUTDIR_MEASURE_LATENCY = Path("outputs/measure_latency/")

# Torch device shared by every model and tensor in this script; resolved once at import time.
DEVICE = get_torch_device()
print(f"Using DEVICE {DEVICE}")

# Directory where a debug dump of passage embeddings is written when DEBUG_SAVE_TENSOR is enabled.
OUTPUT_MEASURE_EMB_SIZE = Path("outputs/measure_emb_size/")
# Debug switch: when True, the first batch of passage embeddings is saved to disk and the script exits.
DEBUG_SAVE_TENSOR = False
if DEBUG_SAVE_TENSOR:
    print("DEBUG_SAVE_TENSOR is set to `True`: the first embeddings will be saved on disk and the script will stop.")

# Load variables from a `.env` file; `override=True` lets them replace already-set environment variables.
load_dotenv(override=True)


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Average the hidden states over dim 1, ignoring positions whose mask is zero.

    Masked positions are zeroed before summing, and the sum is divided by the
    number of unmasked positions of each row (no guard against an all-zero mask).
    """
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return zeroed.sum(dim=1) / token_counts


def mean_pooling(model_output, attention_mask):
    """Mask-weighted mean of the token embeddings found in `model_output[0]`.

    The attention mask is broadcast to the embedding shape and used as weights;
    the denominator is clamped to 1e-9 so a fully masked row yields zeros
    instead of a division by zero.
    """
    hidden = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    weighted_sum = (hidden * weights).sum(1)
    normaliser = weights.sum(1).clamp(min=1e-9)
    return weighted_sum / normaliser


# Define function to compute query embeddings
def compute_query_embeddings(queries: List[str], model_name: str, model, tokenizer, processor, text_model, DEVICE: str):
    """
    Embed text queries with the backend selected by `model_name`.

    The backend is picked by substring match on `model_name`, tested in this
    order: "e5", "bge", "jina", "nomic", "colpali", "siglip".

    Args:
        queries: Raw query strings, without any model-specific prefix.
        model_name: Model identifier used to select the encoding branch.
        model: Main model object (used by every branch except "nomic").
        tokenizer: Tokenizer, used by the "e5" and "nomic" branches.
        processor: Processor, used by the "colpali" and "siglip" branches.
        text_model: Text encoder, used only by the "nomic" branch.
        DEVICE: Torch device the inputs and the returned tensor are moved to.

    Returns:
        A tensor of query embeddings located on `DEVICE`.

    Raises:
        ValueError: If `model_name` matches none of the supported backends.
    """
    if "e5" in model_name:
        query_texts = ["query: " + query for query in queries]
        qs = compute_embeddings_e5(query_texts, model, tokenizer)

    elif "bge" in model_name:
        qs = model.encode(queries, batch_size=1, max_length=512)["dense_vecs"]

    elif "jina" in model_name:
        qs = model.encode_text(queries, batch_size=1)

    elif "nomic" in model_name:
        # The task prefix has to reach the tokenizer: the previous code built the
        # prefixed strings but then tokenized the raw queries.
        query_texts = ["search_query: " + query for query in queries]
        encoded_input = tokenizer(query_texts, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            qs = text_model(**encoded_input)
        qs = mean_pooling(qs, encoded_input["attention_mask"])
        qs = F.layer_norm(qs, normalized_shape=(qs.shape[1],))
        qs = F.normalize(qs, p=2, dim=1)

    elif "colpali" in model_name:
        # TODO: confirm the query embedding path for colpali
        processed_queries = processor.process_text(queries).to(DEVICE)
        with torch.no_grad():
            qs = model.forward(**asdict(processed_queries))  # type: ignore

    elif "siglip" in model_name:
        # TODO: confirm the query embedding path for siglip
        inputs_queries = processor(text=queries, return_tensors="pt", padding="max_length", truncation=True).to(DEVICE)
        # Inference only: do not build an autograd graph for the text tower.
        with torch.no_grad():
            qs = model.get_text_features(**inputs_queries)

    else:
        raise ValueError(f"Model {model_name} not supported")

    # `torch.tensor` on an existing Tensor copies it and emits a UserWarning, so
    # tensors are only detached; other outputs (e.g. arrays) are converted.
    if isinstance(qs, Tensor):
        return qs.detach().to(DEVICE)
    return torch.as_tensor(qs).to(DEVICE)


# Define function to compute passage embeddings
def compute_passage_embeddings(batch_passages: List[str|Image.Image], model_name, model, tokenizer, processor, vision_model, DEVICE):
    """
    Embed one batch of passages with the backend selected by `model_name`.

    The backend is picked by substring match on `model_name`, tested in this
    order: "e5", "bge", "jina", "nomic", "colpali", "siglip". The "e5" and "bge"
    branches treat the passages as text; the other branches hand them to an
    image encoder or image processor.

    Args:
        batch_passages: The passages of one batch (texts or images).
        model_name: Model identifier used to select the encoding branch.
        model: Main model object (used by every branch except "nomic").
        tokenizer: Tokenizer, used only by the "e5" branch.
        processor: Processor, used by the "nomic", "colpali" and "siglip" branches.
        vision_model: Vision encoder, used only by the "nomic" branch.
        DEVICE: Torch device the inputs and the returned tensor are moved to.

    Returns:
        A tensor of passage embeddings located on `DEVICE`.

    Raises:
        ValueError: If `model_name` matches none of the supported backends.
        SystemExit: After dumping the embeddings when `DEBUG_SAVE_TENSOR` is set.
    """
    # The caller may pass any sequence (e.g. a tuple); give every backend a plain list.
    batch_passages = list(batch_passages)

    if "e5" in model_name:
        passage_texts = ["passage: " + passage for passage in batch_passages]
        ps = compute_embeddings_e5(passage_texts, model, tokenizer)

    elif "bge" in model_name:
        ps = model.encode(batch_passages, batch_size=BATCH_SIZE, max_length=512)["dense_vecs"]

    elif "jina" in model_name:
        ps = model.encode_image(batch_passages, batch_size=BATCH_SIZE)

    elif "nomic" in model_name:
        vision_inputs = processor(batch_passages, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
        with torch.no_grad():
            ps = vision_model(**vision_inputs).last_hidden_state
            # Keep only the first token of each sequence, then L2-normalise.
            ps = F.normalize(ps[:, 0], p=2, dim=1)

    elif "colpali" in model_name:
        input_image_processed = processor.process_image(batch_passages, add_special_prompt=True).to(DEVICE)
        with torch.no_grad():
            ps = model.forward(**asdict(input_image_processed))  # type: ignore

    elif "siglip" in model_name:
        input_image_processed = processor(images=batch_passages, return_tensors="pt", padding=True).to(DEVICE)
        with torch.no_grad():
            ps = model.get_image_features(**input_image_processed)

    else:
        raise ValueError(f"Model {model_name} not supported")

    # `torch.tensor` on an existing Tensor copies it and emits a UserWarning, so
    # tensors are only detached; other outputs (e.g. arrays) are converted.
    ps = ps.detach() if isinstance(ps, Tensor) else torch.as_tensor(ps)

    if DEBUG_SAVE_TENSOR:
        savedir = OUTPUT_MEASURE_EMB_SIZE / model_name.replace("/", "_")
        savedir.mkdir(parents=True, exist_ok=True)
        savepath = savedir / f"doc_embedding_{0}.pt"
        torch.save(ps.to(torch.float32), savepath)
        print(f"Document embedding saved to `{savepath}`")
        # `exit()` is a helper injected by the `site` module and is not always
        # available; raising SystemExit stops the script the same way.
        raise SystemExit

    return ps.to(DEVICE)


def compute_embeddings_e5(inputs: List[str], model, tokenizer, batch_size=BATCH_SIZE) -> Tensor:
    """Encode `inputs` with an E5-style model, `batch_size` texts at a time.

    Each batch is tokenized (truncated to 512 tokens), run through the model
    without gradients, average-pooled with the attention mask and L2-normalised.
    The per-batch results are concatenated along dim 0.
    """
    model.to(DEVICE)
    model.eval()

    def _encode(texts: List[str]) -> Tensor:
        # Tokenize one batch and move every tensor of the encoding to DEVICE.
        encoded = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors="pt")
        on_device = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

        with torch.no_grad():
            hidden = model(**on_device).last_hidden_state
            pooled = average_pool(hidden, on_device["attention_mask"])
            return F.normalize(pooled, p=2, dim=1)

    batch_starts = range(0, len(inputs), batch_size)
    pooled_batches = [_encode(inputs[start : start + batch_size]) for start in batch_starts]
    return torch.cat(pooled_batches, dim=0)


def log_metrics(metrics, dataset_name, log_file="metrics.json"):
    """
    Store `metrics` under the key `dataset_name` in the JSON file `log_file`.

    If the file already exists, its content is loaded and the entry for
    `dataset_name` is added or replaced; entries for other datasets are kept.
    Missing parent directories are created so that a long evaluation run does
    not fail at the very end on a missing output folder.

    Args:
        metrics: JSON-serialisable object to store.
        dataset_name: Key under which `metrics` is stored.
        log_file: Path of the JSON file to create or update.
    """
    log_path = Path(log_file)
    log_path.parent.mkdir(parents=True, exist_ok=True)

    if log_path.exists():
        # Context manager so the read handle is closed before the file is rewritten.
        with open(log_path) as f:
            logged = json.load(f)
    else:
        logged = {}

    logged[dataset_name] = metrics

    with open(log_path, "w") as f:
        json.dump(logged, f)


def _load_retrieval_model(model_name: str):
    """Load the model objects required by the backend that `model_name` selects.

    Returns:
        A `(model, tokenizer, processor, vision_model, text_model)` tuple; the
        entries that the selected backend does not use are `None`.

    Raises:
        ValueError: If `model_name` matches none of the supported backends, or
            if the colpali adapter is not the only active adapter after loading.
    """
    model = None
    tokenizer = None
    processor = None
    vision_model = None
    text_model = None

    if "e5" in model_name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(DEVICE)
    elif "bge" in model_name:
        model = BGEM3FlagModel(model_name, use_fp16=True)
    elif "jina" in model_name:
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(DEVICE)
    elif "nomic" in model_name:
        processor = AutoImageProcessor.from_pretrained("nomic-ai/nomic-embed-vision-v1.5")
        vision_model = AutoModel.from_pretrained("nomic-ai/nomic-embed-vision-v1.5", trust_remote_code=True).to(DEVICE)
        text_model = AutoModel.from_pretrained("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
    elif "colpali" in model_name:
        # TODO: the colpali path is still work in progress
        model_path = "google/paligemma-3b-mix-448"
        lora_path = "coldoc/paligemma-3b-mix-448"

        model = cast(ColPali, ColPali.from_pretrained(model_path, device_map=DEVICE))

        # Note: `add_adapter` is used to create a new adapter while `load_adapter`
        # is used to load an existing adapter.
        model.load_adapter(lora_path, adapter_name="colpali", device_map=DEVICE)
        if model.active_adapters() != ["colpali"]:
            raise ValueError(f"Incorrect adapters loaded: {model.active_adapters()}")
        print(f"Loaded model from {model_path} and LORA from {lora_path}")

        processor = ColPaliProcessor.from_pretrained(model_path)
        print("Loaded custom processor")
    elif "siglip" in model_name:
        # TODO: the siglip path is still work in progress
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(DEVICE)
    else:
        raise ValueError(f"Model {model_name} not supported")

    return model, tokenizer, processor, vision_model, text_model


def evaluate_model(ds: Dataset, model_name: str, ocr_only: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Evaluate a retrieval model on a dataset.

    Queries are the distinct non-null values of the `query` column. The corpus
    depends on the model and on `ocr_only`:

    - page-retrieval models ("jina", "nomic", "colpali", "siglip" in
      `model_name`): one passage per row, taken from the `image` column;
    - `ocr_only`: one passage per `image_filename`, made of its text chunks
      (`chunk_type == "text"`) joined with spaces;
    - otherwise: one passage per row, taken from the `text_description` column.

    Every passage is scored against every query with a dot product; the score
    of a file is the best score among its passages. The file listed next to a
    query in `ds` is its single relevant document.

    Args:
        ds: Dataset with `query` and `image_filename` columns plus the columns
            required by the selected corpus (see above).
        model_name: Model identifier; selects the backend by substring match.
        ocr_only: Build the corpus from concatenated OCR text chunks (ignored
            for page-retrieval models).

    Returns:
        `(metrics, times)`: the output of `CustomEvaluator.compute_metrics`, and
        a dict with `query_encoding` (seconds per query), `passage_time`
        (seconds per passage) and `index_lookup_time` (mean seconds spent
        scoring one query against the whole corpus).

    Raises:
        ValueError: If `model_name` matches none of the supported backends.
    """
    evaluator = CustomEvaluator(is_multi_vector=False)

    # Distinct non-null queries; dict.fromkeys keeps first-seen order, which makes
    # the row order of `scores` reproducible between runs (a set does not).
    queries = [query for query in dict.fromkeys(ds["query"]) if query is not None]

    page_retrieving = any(name in model_name for name in ("jina", "nomic", "colpali", "siglip"))

    # `passage_filenames[i]` is the file that `passages[i]` belongs to. Keeping the
    # mapping index-aligned (rather than keyed by passage text) stops identical
    # texts found on different pages from collapsing onto a single filename.
    if page_retrieving:
        passages = ds["image"]
        passage_filenames = list(ds["image_filename"])

    elif ocr_only:
        text_descriptions = defaultdict(list)

        for record in tqdm(ds, desc="Filtering OCR chunks"):
            if record["chunk_type"] == "text":
                text_descriptions[record["image_filename"]].append(record["text_description"])

        # concatenate all text chunks for each image
        passage_filenames = list(text_descriptions)
        passages = [" ".join(text_descriptions[filename]) for filename in passage_filenames]

    else:
        passages = ds["text_description"]
        passage_filenames = list(ds["image_filename"])

    print(f"Number of queries: {len(queries)}, Number of passages: {len(passages)}")

    # ========================   Load model   ========================
    model, tokenizer, processor, vision_model, text_model = _load_retrieval_model(model_name)

    # ========================   Compute embeddings   ========================
    print("Computing query embeddings")
    start = time.perf_counter()
    qs = compute_query_embeddings(queries, model_name, model, tokenizer, processor, text_model, DEVICE)
    query_time = time.perf_counter() - start

    print("Computing passage embeddings")
    # Ceiling division: the last, possibly partial, batch counts as well.
    num_batches = (len(passages) + BATCH_SIZE - 1) // BATCH_SIZE
    list_ps = []

    start = time.perf_counter()
    for batch in tqdm(batched(passages, n=BATCH_SIZE), desc="Computing embeddings", total=num_batches):
        list_ps.append(
            compute_passage_embeddings(batch, model_name, model, tokenizer, processor, vision_model, DEVICE)
        )
    passage_time = time.perf_counter() - start

    # Stack the per-batch embeddings into one corpus tensor.
    ps = torch.cat(list_ps, dim=0)

    # ========================   Score queries against passages   ========================
    # NOTE(review): this is a single-vector dot product and needs 2-D `qs`/`ps`.
    # The colpali branch presumably yields multi-vector embeddings that require
    # late-interaction scoring instead - confirm before relying on that path.
    if ps.is_cuda:
        torch.cuda.synchronize()  # finish pending encoder kernels before timing
    start = time.perf_counter()
    scores = torch.einsum("bd,cd->bc", qs, ps)
    if scores.is_cuda:
        torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait for the result
    # Previously read from a variable that was only assigned in commented-out code.
    index_lookup_time = (time.perf_counter() - start) / len(queries)

    assert scores.shape == (
        len(queries),
        len(passages),
    ), f"Scores shape is {scores.shape} instead of {(len(queries), len(passages))}"

    # ========================   Compute metrics   ============================
    print("Computing metrics")
    relevant_docs = {}
    results = {}

    queries2filename = dict(zip(ds["query"], ds["image_filename"]))

    for query, query_scores in zip(queries, scores):
        relevant_docs[query] = {queries2filename[query]: 1}

        # Best score per file. No artificial floor is applied, so negative
        # similarities keep their real value and their relative order.
        best_per_file = {}
        for filename, score_passage in zip(passage_filenames, query_scores.tolist()):
            if filename not in best_per_file or score_passage > best_per_file[filename]:
                best_per_file[filename] = score_passage
        results[query] = best_per_file

    assert len(results) == len(queries), f"Results length is {len(results)} instead of {len(queries)}"

    metrics = evaluator.compute_metrics(relevant_docs, results)

    times = {
        "query_encoding": query_time / len(queries),
        "passage_time": passage_time / len(passages),
        "index_lookup_time": index_lookup_time,
    }

    for metric, value in metrics.items():
        print(f"{metric} : {value}")

    for time_name, time_value in times.items():
        print(f"{time_name} : {time_value}")
    return metrics, times


def evaluate_dataset(dataset_name, model_name, text_chunk_only=False, ocr_only=False):
    """Evaluate `model_name` on the test split of `dataset_name` and log the results.

    Metrics and timings are written under `data/evaluation_results/`, in files
    named after the model and the evaluation mode ("text", "ocr" or empty).
    When `text_chunk_only` is set and the dataset has a `chunk_type` column,
    only the rows whose chunk type is "text" are evaluated.
    """
    # `text_chunk_only` takes precedence over `ocr_only` for the file suffix.
    text_string = "text" if text_chunk_only else ("ocr" if ocr_only else "")

    metric_log = f"data/evaluation_results/metrics/metrics_{model_name.replace('/', '_')}_{text_string}.json"
    time_log = f"data/evaluation_results/times/times_{model_name.replace('/', '_')}_{text_string}.json"

    ds = cast(Dataset, load_dataset(dataset_name, split="test"))

    if text_chunk_only and "chunk_type" in ds.column_names:
        ds = ds.filter(lambda x: x["chunk_type"] == "text")

    available_columns = ds.column_names
    assert "query" in available_columns, "The dataset should have a query column"
    assert "image_filename" in available_columns, "The dataset should have a image_filename column"

    evaluation = evaluate_model(ds, model_name, ocr_only)
    metrics, times = evaluation

    print(f"Finished evaluation of {model_name} on {dataset_name}")

    # Same JSON layout for both logs: one entry per dataset name.
    log_metrics(metrics, dataset_name, metric_log)
    log_metrics(times, dataset_name, time_log)


def main(
    model_name: Annotated[str, typer.Option(help="model name to use for evaluation")],
    dataset_name: Annotated[str, typer.Option(help="dataset on hugging face to evaluate")] = "",
    collection_name: Annotated[str, typer.Option(help="collection name to use for evaluation")] = "",
    text_only: Annotated[bool, typer.Option(help="If True, only text chunks will be used for evaluation")] = False,
    ocr_only: Annotated[bool, typer.Option(help="If True, only ocr chunks will be used for evaluation")] = False,
):
    # CLI entry point: evaluate a single dataset, or every item of a Hugging Face
    # collection when no dataset name is given. (Kept as a comment rather than a
    # docstring so the command's help output is unchanged.)

    # Echo the parameters of the run.
    print(f"Model name: {model_name}")
    print(f"Dataset name: {dataset_name}")
    print(f"Text only: {text_only}")
    print(f"OCR only: {ocr_only}")

    has_dataset = dataset_name != ""
    has_collection = collection_name != ""

    if not (has_dataset or has_collection):
        raise ValueError("Please provide a dataset name or collection name")

    # A dataset name takes precedence over a collection name.
    if has_dataset:
        evaluate_dataset(dataset_name, model_name, text_only, ocr_only)
        return

    for entry in huggingface_hub.get_collection(collection_name).items:
        print(f"\n---------------------------\nEvaluating {entry.item_id}")
        evaluate_dataset(entry.item_id, model_name, text_only, ocr_only)


# Script entry point: typer exposes `main`'s parameters as command-line options.
if __name__ == "__main__":
    typer.run(main)
