from functools import partial
import json
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import lightning as L
import mteb
from mteb.model_meta import ModelMeta

from tqdm import tqdm
import random

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from litgpt import Tokenizer
from litgpt.config import Config
from litgpt.retrieval_model import PSLM
from litgpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint

# BEGIN HACK
# As a hack, we need to be able to import utilities from the main training script,
# so we add the repo root to the Python path.
import sys
import os

repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(repo_root)

# Wildcard import: pulls in training-script helpers such as the scale_lr function.
from train import *

from eval.mteb_task_prompts import (
    TASK_LIST_STS,
    TASK_LIST,
    task2prefix_short,
    task2prefix_long,
    TASK_LIST_RETRIEVAL,
    TASK_LIST_CLUSTERING,
    TASK_LIST_PAIR_CLASSIFICATION,
    TASK_LIST_RERANKING,
    TASK_LIST_CLASSIFICATION,
)

# Fix the Python and PyTorch RNG seeds at import time so runs are reproducible.
random.seed(42)
torch.manual_seed(42)

# Named task groups that ``main`` looks up through its ``task`` argument.
TASKS_BY_CATEGORY = {
    "retrieval": TASK_LIST_RETRIEVAL,
    "retrieval_tiny": ["NFCorpus", "ArguAna", "SciFact"],
    # A small sample spanning several task categories, listed group by group.
    "mteb_subset": (
        ["NFCorpus", "ArguAna", "SciFact"]  # retrieval
        + ["StackOverflowDupQuestions", "SciDocsRR"]  # reranking
        + ["BiorxivClusteringS2S", "MedrxivClusteringS2S", "TwentyNewsgroupsClustering"]  # clustering
        + ["SprintDuplicateQuestions"]  # pair classification
        + ["Banking77Classification", "EmotionClassification", "MassiveIntentClassification"]  # classification
        + ["STS17", "SICK-R", "STSBenchmark"]  # STS
    ),
    "msmarco": ["MSMARCO"],
    "sts": TASK_LIST_STS,
    "clustering": TASK_LIST_CLUSTERING,
    "pair_classification": TASK_LIST_PAIR_CLASSIFICATION,
    "reranking": TASK_LIST_RERANKING,
    "classification": TASK_LIST_CLASSIFICATION,
    "all": TASK_LIST,
    # Placeholder: the task names come from the separate ``task_list`` argument.
    "user_specified": None,
}


def get_detailed_instruct(task_description: str) -> str:
    """Wrap a task description in the ``Instruct: ...\\nQuery: `` prompt header.

    Returns an empty string when ``task_description`` is empty or otherwise falsy.
    """
    return f"Instruct: {task_description}\nQuery: " if task_description else ""


def medi2_instruct(task_description: str) -> str:
    """Return the task description followed by a single space.

    Returns an empty string when ``task_description`` is empty or otherwise falsy.
    """
    return f"{task_description} " if task_description else ""


# Lookup table from an instruction-format name to the function that renders it.
NAME_TO_FUNC = {"default": get_detailed_instruct, "medi2": medi2_instruct}


# Placeholder fragments shared by the meta-token prompt templates. "{text}" is
# left as a literal placeholder for the encoder to fill in.
_QUERY_SLOT = "<QUERY> {text} <QUERY>"
_DOC_SLOT = "<DOC> {text} <DOC>"


def _build_prompt_templates(short_prefixes, long_prefix, include_long_prompt, include_meta_tokens, prompt_style):
    """Return the ``(query_template, document_template)`` pair for one task.

    Args:
        short_prefixes: Mapping with ``"query"`` and ``"document"`` short prefixes.
        long_prefix: Long task description for the task.
        include_long_prompt: Use the long description instead of the short prefixes.
        include_meta_tokens: Wrap the text in ``<QUERY>``/``<DOC>`` style marker tokens.
        prompt_style: Template layout; only consulted when both
            ``include_long_prompt`` and ``include_meta_tokens`` are true.

    Raises:
        ValueError: If ``prompt_style`` is consulted and is not a known style.
    """
    if not include_long_prompt:
        if include_meta_tokens:
            return "<s>" + _QUERY_SLOT, _DOC_SLOT
        return short_prefixes["query"] + ": {text}", short_prefixes["document"] + ": {text}"

    if not include_meta_tokens:
        query_template = long_prefix + ": {text}"
        is_symmetric = short_prefixes["query"] == short_prefixes["document"]
        # Asymmetric tasks get an empty document template.
        return query_template, (query_template if is_symmetric else "")

    sys_block = f"<SYS> {long_prefix} <SYS>"

    # Styles that only need the long description.
    if prompt_style == "query_doc":
        return "<s>" + _QUERY_SLOT, _DOC_SLOT
    if prompt_style == "long_prefix_query_doc":
        return f"<s> {long_prefix} " + _QUERY_SLOT, _DOC_SLOT
    if prompt_style == "sys_query_doc":
        return f"<s>{sys_block} " + _QUERY_SLOT, _DOC_SLOT
    if prompt_style == "sys_query_doc_sys":
        return f"<s>{sys_block} " + _QUERY_SLOT, _DOC_SLOT + " " + sys_block
    if prompt_style == "instruction_question_answer":
        return f"<s> Instruction: {long_prefix}\n" + "Question: {text}\nAnswer:", "{text}"
    if prompt_style == "question_excerpt_doc":
        query_template = (
            f"<s> Instruction: {long_prefix}"
            + '\nQuestion: {text}\nThe answer to this question can be found in the following excerpt: "...'
        )
        return query_template, '"{text}'

    task_styles = (
        "prefix_query_doc_prefix",
        "task_query_doc_task",
        "sys_task_query_doc",
        "sys_task_query_doc_task",
        "sys_task_query_doc_task_sys",
        "fineweb",
        "phase_2",
    )
    if prompt_style not in task_styles:
        raise ValueError(f"Unknown prompt_style: {prompt_style}")

    # The remaining styles embed the title-cased short prefixes.
    query_task = short_prefixes["query"].title()
    doc_task = short_prefixes["document"].title()
    doc_task_block = f"<TASK> {doc_task} <TASK>"
    sys_task_query = f"<s>{sys_block} <TASK> {query_task} <TASK> " + _QUERY_SLOT

    if prompt_style == "prefix_query_doc_prefix":
        return f"<s> {query_task} " + _QUERY_SLOT, _DOC_SLOT + " " + doc_task
    if prompt_style == "task_query_doc_task":
        return f"<s><TASK> {query_task} <TASK> " + _QUERY_SLOT, _DOC_SLOT + " " + doc_task_block
    if prompt_style == "sys_task_query_doc":
        return sys_task_query, _DOC_SLOT
    if prompt_style == "sys_task_query_doc_task":
        return sys_task_query, _DOC_SLOT + " " + doc_task_block
    if prompt_style == "sys_task_query_doc_task_sys":
        return sys_task_query, _DOC_SLOT + " " + doc_task_block + " " + sys_block
    if prompt_style == "fineweb":
        return f"<s> {query_task}: " + "{text}", f"{doc_task}: " + "{text}"
    # "phase_2" (phase 2 and phase 3 use the same prompt style). The strip only
    # applies to the trailing "<prefix>: " fragment, as in the original expression.
    return f"<s> {query_task}: " + _QUERY_SLOT, _DOC_SLOT + " " + f"{doc_task}: ".strip()


@torch.inference_mode()
def main(
    model_path: str = "/path/to/model",
    checkpoint_dir: Path = None,
    max_seq_length: int = None,
    no_instruction: bool = False,
    prefix_add_eos: bool = False,
    suffix_add_eos: bool = False,
    task: str = "retrieval_tiny",
    task_list: str = None,
    task_locks: bool = False,
    include_long_prompt: bool = False,
    include_meta_tokens: bool = False,
    overwrite_results: bool = True,
    result_dir: str = None,
    prompt_style: str = "sys_query_doc",
    batch_size: int = 1,
    pooling_method: str = "lasttoken",
    num_few_shot: int = 0,
) -> None:
    """Evaluate a PSLM checkpoint on a group of MTEB tasks.

    Args:
        model_path: Directory containing ``run_config.json`` and ``model_config.json``.
        checkpoint_dir: Path to the checkpoint file passed to ``torch.load``; its
            ``"model"`` entry is loaded as the state dict. Required.
        max_seq_length: Maximum sequence length forwarded to the encoder; defaults
            to ``run_config["block_size"]``.
        no_instruction: Skip prompt templates entirely.
        prefix_add_eos: Currently unused by this function.
        suffix_add_eos: Currently unused by this function.
        task: Key into ``TASKS_BY_CATEGORY``, or ``"user_specified"``.
        task_list: Comma-separated task names; required when ``task`` is
            ``"user_specified"``.
        task_locks: Create a per-task ``.lock`` file next to the results so that
            concurrent processes skip tasks another process is running.
        include_long_prompt: Use the long task description in the prompt.
        include_meta_tokens: Wrap text in marker tokens such as ``<QUERY>``/``<DOC>``.
        overwrite_results: Forwarded to ``evaluation.run``.
        result_dir: Output folder; ``"results/"`` when omitted.
        prompt_style: Template layout, used when both ``include_long_prompt`` and
            ``include_meta_tokens`` are set.
        batch_size: Encoding batch size.
        pooling_method: Pooling method forwarded to the encoder.
        num_few_shot: Number of in-context examples sampled per side (query/corpus).

    Raises:
        ValueError: If a required argument is missing or ``prompt_style`` is unknown.
    """
    # if task is user_specified, parse as comma sep list
    if task == "user_specified":
        if task_list is None:
            raise ValueError("task_list must be provided when task is user_specified")
        tasks = task_list.split(",")
    else:
        tasks = TASKS_BY_CATEGORY[task]

    # Fail early instead of deep inside torch.load after the model has been built.
    if checkpoint_dir is None:
        raise ValueError("checkpoint_dir must be provided")
    checkpoint_dir = Path(checkpoint_dir)

    print(f"Running {len(tasks)} tasks:\n", json.dumps(tasks, indent=4))

    if task_locks:
        print("Task locks enabled")
        # A lockfile named after the task is created where the results go; any
        # process that finds an existing lockfile for a task skips that task.
    else:
        print("Task locks disabled")

    tasks = mteb.get_tasks(tasks=tasks)

    print("Loading model from:", model_path)

    model_dir = Path(model_path)
    with open(model_dir / "run_config.json", "r") as config_file:
        run_config = json.load(config_file)

    model_config = Config.from_file(model_dir / "model_config.json")
    model_config.structured_init = False
    model_config.structured_init_for_wte = False
    model_config.structured_init_olmo_variant = False
    model_config.strategy = "ddp"  # TODO: It's a placeholder to avoid error, need to be fixed

    max_seq_length = run_config["block_size"] if max_seq_length is None else max_seq_length

    tokenizer = Tokenizer(run_config["tokenizer_path"])
    print("====== Model args: ======")
    print("Tokenizer path:", run_config["tokenizer_path"])
    print("suffix_is_prefix:", run_config["suffix_is_prefix"])
    print("flip_rope_embedding_suffix:", run_config["flip_rope_embedding_suffix"])
    print("add_suf_pre_tokens:", run_config["add_suf_pre_tokens"])
    print("nope_pos_embeddings:", run_config["nope_pos_embedding"])
    model = PSLM(
        model_config,
        objective=None,
        tokenizer=tokenizer,
        suffix_is_prefix=run_config["suffix_is_prefix"],
        flip_rope_embedding_suffix=run_config["flip_rope_embedding_suffix"],
        add_suf_pre_tokens=run_config["add_suf_pre_tokens"],
        nope_pos_embeddings=run_config["nope_pos_embedding"],
    )

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_dir, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model"])

    if run_config["suffix_is_prefix"]:
        assert model.prefix_model == model.suffix_model

    # Print the first parameter only, as a quick sanity check of the loaded weights.
    for name, param in model.named_parameters():
        print(name, param.size(), param.dtype, param.device)
        print(param)
        break
    model = model.to(torch.bfloat16).to("cuda:0")  # device is hard-coded to the first GPU

    # str.strip(".pth") would remove any of the characters ".", "p", "t", "h" from
    # both ends (e.g. "checkpoint.pth" -> "checkpoin"); drop only the suffix.
    checkpoint_name = checkpoint_dir.name
    if checkpoint_name.endswith(".pth"):
        checkpoint_name = checkpoint_name[: -len(".pth")]
    model.mteb_model_meta = ModelMeta(
        name=str(checkpoint_dir).split("/")[-3],
        revision=os.path.join(
            f"prompt_style_{prompt_style}_pooling_method_{pooling_method}_num_few_shot_{num_few_shot}_no_instruction_{no_instruction}_include_long_prompt_{include_long_prompt}_include_meta_tokens_{include_meta_tokens}",
            checkpoint_name,
        ),
        release_date=None,
        languages=None,
    )

    # Resolve the output location once so the banner, the lockfile and the
    # evaluation all agree, including when result_dir is omitted.
    output_folder = "results/" if result_dir is None else result_dir
    task_output_dir = Path(output_folder) / model.mteb_model_meta.name / model.mteb_model_meta.revision

    # Keep the unwrapped encoders so each task builds its partials from scratch
    # instead of stacking on top of the previous task's wrappers.
    base_encode = model.encode
    base_encode_queries = model.encode_queries
    base_encode_corpus = model.encode_corpus

    start_time = time.time()
    for t in tqdm(tasks, desc="Evaluating tasks", total=len(tasks)):
        task_name = t.metadata_dict["name"]

        # Default to empty templates so the few-shot path also works with no_instruction.
        query_prefix = ""
        document_prefix = ""
        encode_fn = base_encode
        if not no_instruction:
            query_prefix, document_prefix = _build_prompt_templates(
                task2prefix_short[task_name],
                task2prefix_long[task_name],
                include_long_prompt,
                include_meta_tokens,
                prompt_style,
            )
            model.encode_queries = partial(base_encode_queries, instruction=query_prefix)
            model.encode_corpus = partial(base_encode_corpus, instruction=document_prefix)

            # for single sequence tasks like clustering, classification, etc. where we
            # don't have a query/document: use the query template in "prefix" mode
            encode_fn = partial(base_encode, instruction=query_prefix, encoding_mode="prefix")
            model.encode = encode_fn

        eval_splits = ["test" if task_name not in ["MSMARCO"] else "dev"]
        evaluation = mteb.MTEB(tasks=[task_name], task_langs=["en"])

        if num_few_shot > 0:
            # HACK: passing the few shot examples to `encode` so that we can create in-context examples from mteb dataset
            eval_task = evaluation.tasks[0]
            eval_task.load_data()
            split = eval_splits[0]
            # Sample twice as many pairs as requested: half for queries, half for the corpus.
            sampled_pairs = random.sample(list(eval_task.relevant_docs[split].items()), num_few_shot * 2)
            few_shot_examples = []
            for query_id, relevant in sampled_pairs:
                query = eval_task.queries[split][query_id]
                doc = eval_task.corpus[split][list(relevant)[0]]  # picking the first relevant doc
                doc = doc["title"] + " " + doc["text"] if "title" in doc else doc["text"]
                few_shot_examples.append((query, doc))

            # rewriting the query and document prefix to include the few shot examples
            query_prefix = "{few_shot}" + query_prefix
            document_prefix = document_prefix + "{few_shot}"

            model.encode_queries = partial(
                base_encode_queries, instruction=query_prefix, few_shot_examples=few_shot_examples[:num_few_shot]
            )
            model.encode_corpus = partial(
                base_encode_corpus, instruction=document_prefix, few_shot_examples=few_shot_examples[num_few_shot:]
            )
            model.encode = partial(encode_fn, num_few_shot=num_few_shot)

        print(f"{'#'*80}\nEvaluating task: {task_name}\nWriting to {task_output_dir}\n{'#'*80}")

        # here is where we do the lockfile check
        # write the hostname and human readable datetime to the lockfile for human inspection
        lockfile = None
        got_lock = False
        if task_locks:
            lockfile = task_output_dir / f"{task_name}.lock"
            lockfile.parent.mkdir(parents=True, exist_ok=True)
            try:
                # Mode "x" creates the file atomically and fails if it already
                # exists, so two processes cannot both claim the same task.
                with open(lockfile, "x") as f:
                    f.write(f"{os.uname().nodename}@{time.strftime('%Y-%m-%d %H:%M:%S')}\n")
            except FileExistsError:
                print(
                    f"Lockfile {lockfile} exists, skipping task on {os.uname().nodename} {time.strftime('%Y-%m-%d %H:%M:%S')}"
                )
                continue
            got_lock = True

        try:
            results = evaluation.run(
                model,
                output_folder=output_folder,
                eval_splits=eval_splits,
                verbosity=2,
                encode_kwargs={
                    "batch_size": batch_size,
                    "max_length": max_seq_length,
                    "add_eos": False,
                    "pooling_method": pooling_method,
                },
                overwrite_results=overwrite_results,
            )
        finally:
            # Release the lock even when the run fails, so a crashed task does not
            # stay blocked for every other process.
            if got_lock:
                if lockfile.exists():
                    lockfile.unlink()
                    print(f"Lockfile {lockfile} released!")
                else:
                    print(f"{'#'*80}\nWARN: Expected lockfile {lockfile} does not exist even though got_lock=={got_lock}, skipping deletion :/\n{'#'*80}")

        print(task_name)
        print("\t", results)
        if len(results):
            results_dict = results[0].to_dict()["scores"][eval_splits[0]][0]
            try:
                print("main_score =>", results_dict["main_score"])
            except KeyError:
                print("could not find main_score")
                continue
        print()

    print(f"======= Evaluation Time (in minutes): {(time.time() - start_time) / 60} =======")
    print("======= Evaluation Done =======")


if __name__ == "__main__":
    # Script entry point: hand ``main`` to jsonargparse's CLI runner.
    import jsonargparse

    # torch.set_float32_matmul_precision("high")
    jsonargparse.CLI(main)
