from functools import partial
import json
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple

from datasets import load_dataset
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import faiss
import pytrec_eval
from mair import print_results, trec_eval

import torch
import lightning as L

# Fallback import location: the repository root (two directories above this
# file) is appended to sys.path so `litgpt` resolves without a package install.
wd = Path(__file__).parent.parent.resolve()
sys.path += [str(wd)]
print(sys.path)  # debug: show the effective import search path
from litgpt import Tokenizer
from litgpt.config import Config
from litgpt.retrieval_model import PSLM
from litgpt.multiple_negative_ranking_loss import cos_sim

# BEGIN HACK
# As a hack, we need to be able to import utils from the main training script,
# so we add the repo root to the Python path.
import sys
import os

# repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# sys.path.append(repo_root)

# Command-line argument parsing
import argparse
from sentence_transformers import SentenceTransformer

parser = argparse.ArgumentParser(description="Evaluate an embedding model on MAIR retrieval tasks.")
# Restrict --model to the names the script actually knows how to load; any other
# value previously fell through every branch and crashed later on an unbound `results`.
parser.add_argument(
    "--model",
    type=str,
    default="our_model",
    choices=["our_model", "e5", "gte_qwen"],
    help="Model to evaluate",
)
parser.add_argument("--output_file", type=str, default="output.md", help="Output file to save the results")
args = parser.parse_args()

import torch

# Pin scaled-dot-product attention to the reference "math" kernel: the flash and
# memory-efficient CUDA backends are switched off and math is explicitly enabled
# (presumably for numerical consistency across models — TODO confirm intent).
for _sdp_switch, _sdp_on in (
    (torch.backends.cuda.enable_flash_sdp, False),
    (torch.backends.cuda.enable_mem_efficient_sdp, False),
    (torch.backends.cuda.enable_math_sdp, True),
):
    _sdp_switch(_sdp_on)
del _sdp_switch, _sdp_on

def _encode_documents(model, doc_content, model_name):
    """Embed every document text in ``doc_content`` with the encoder for ``model_name``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported encoders.
    """
    if model_name == 'our_model':
        # Documents go in verbatim through the "{text}" template. Other templates that
        # were tried: "<DOC> {text} <DOC>", "<DOC> {text} <DOC> Search_Document:".
        return model.encode_corpus(doc_content, instruction="{text}", batch_size=256, max_length=2048,
                                   add_eos=False, add_bos=False, pooling_method="lasttoken")
    if model_name == 'e5':
        return model.encode(doc_content, batch_size=1)
    if model_name == 'gte_qwen':
        return model.encode(doc_content)
    raise ValueError("Model name not recognized")


def _encode_query(model, item, instruct, model_name):
    """Embed one MAIR query row (``item['query']``, optionally with ``item['instruction']``).

    Raises:
        ValueError: If ``model_name`` is not one of the supported encoders.
    """
    if model_name == 'our_model':
        encode_kwargs = dict(batch_size=512, max_length=2048, add_eos=False, add_bos=False,
                             pooling_method="lasttoken")
        if instruct:
            # The "{text}" template means the per-query instruction is NOT injected for
            # our model; an alternative tried earlier was
            # f"Instruction: {item['instruction']} " + "Question: {text}\nAnswer:".
            encode_kwargs['instruction'] = "{text}"
        return model.encode_queries(item['query'], **encode_kwargs)
    if model_name in ('e5', 'gte_qwen'):
        if instruct:
            return model.encode(item['query'], prompt=f"Instruct: {item['instruction']}\nQuery: ")
        return model.encode(item['query'])
    raise ValueError("Model name not recognized")


def _collect_run(qids, doc_ids, rank, distance):
    """Turn FAISS ``search`` output into a run dict ``{qid: {doc_id: score}}``.

    Row ``i`` of ``rank``/``distance`` belongs to ``qids[i]``; each entry of ``rank`` is
    a position into ``doc_ids``. Negative positions are FAISS padding and are skipped.
    """
    run = {}
    for qid, positions, scores in zip(qids, rank, distance):
        hits = {}
        for position, score in zip(positions, scores):
            if position < 0:
                continue
            hits[doc_ids[int(position)]] = float(score)
        run[qid] = hits
    return run


def eval_embedding(model, tokenizer, tasks, instruct=True, model_name='our_model'):
    """Evaluate an embedding model on MAIR retrieval tasks with exact inner-product search.

    For every task in ``tasks`` and every query split of that task, the matching
    document split is embedded and added to a flat FAISS inner-product index. The
    top-100 documents per query (fewer if the corpus is smaller) are then scored
    against the split's labels with ``trec_eval``.

    Args:
        model: Model to evaluate. For ``'our_model'`` it must expose
            ``encode_corpus`` / ``encode_queries``; for ``'e5'`` and ``'gte_qwen'``
            a SentenceTransformer-style ``encode``.
        tokenizer: Unused; kept for compatibility with existing callers.
        tasks: MAIR task (dataset config) names. A task listed twice is evaluated once.
        instruct: If True, e5/gte_qwen queries are prefixed with an
            ``"Instruct: ...\\nQuery: "`` prompt built from the row's instruction.
        model_name: One of ``'our_model'``, ``'e5'``, ``'gte_qwen'``.

    Returns:
        ``defaultdict(list)`` mapping ``"<task>/<split>"`` to a list holding a dict with
        keys ``task``, ``split``, ``eval_results``, ``size`` and ``results``.

    Raises:
        ValueError: If ``model_name`` is not supported (checked before any download).
    """
    if model_name not in ('our_model', 'e5', 'gte_qwen'):
        raise ValueError("Model name not recognized")

    output_dict = defaultdict(list)
    seen_tasks = set()
    for task in tasks:
        # output_dict is keyed by "task/split", so track evaluated tasks separately.
        if task in seen_tasks:
            continue
        seen_tasks.add(task)

        data = load_dataset('MAIR-Bench/MAIR-Queries', task)
        docs = load_dataset('MAIR-Bench/MAIR-Docs', task)
        for split in data:
            # Query split "queries" pairs with "docs"; "<x>_queries" pairs with "<x>_docs".
            doc_split = 'docs' if split == 'queries' else split.replace('_queries', '_docs')
            corpus = docs[doc_split]
            doc_content = list(corpus['doc'])
            # Precompute string ids once instead of a random row lookup per retrieved hit.
            doc_ids = [str(doc_id) for doc_id in corpus['id']]
            items = list(data[split])

            doc_embedding = np.asarray(_encode_documents(model, doc_content, model_name), dtype=np.float32)
            dim = doc_embedding.shape[1]
            index = faiss.index_factory(dim, "Flat", faiss.METRIC_INNER_PRODUCT)
            index.add(doc_embedding)

            query_vectors = [_encode_query(model, item, instruct, model_name) for item in items]
            # FAISS needs a 2-D (n_queries, dim) matrix; per-query encoders may return
            # either (dim,) or (1, dim), so normalise the stacked shape explicitly.
            query_embedding = np.asarray(query_vectors, dtype=np.float32).reshape(len(query_vectors), dim)

            # Never request more neighbours than the index holds: FAISS pads missing
            # slots with id -1, which would otherwise silently map to the last document.
            top_k = min(100, index.ntotal)
            distance, rank = index.search(query_embedding, top_k)

            qrels = {
                item['qid']: {str(x['id']): int(x['score']) for x in item['labels']}
                for item in items
            }
            results = _collect_run([item['qid'] for item in items], doc_ids, rank, distance)
            eval_results = trec_eval(qrels, results, k_values=(1, 5, 10, 100))
            output_dict[task + '/' + split].append(
                {'task': task, 'split': split, 'eval_results': eval_results, 'size': len(items),
                 'results': results})
            print(task + '/' + split, eval_results)
    print_results(output_dict)
    return output_dict

import contextlib
import io

# Evaluation settings shared by every model branch below.
EVAL_TASKS = ['IFEval']
EVAL_INSTRUCT = True

if args.model == "our_model":
    root_dir = "/XXXX-36/XXXX-22/retrieval-pretrained-01"
    # root_dir = "/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_models/retrieval/jwk_ckpts/pythia-1.4b-retr-32k_w_meta_mb2-wb2048-grp256_keep368640_pad2blckTrue_1-1-16_128N_peak6e-04_cosine_min6e-05"
    # root_dir = "/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_fineweb_100b_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal"
    # root_dir = "/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/axonn_nomic_finetune_phase3_pt_step_130k_fineweb_100b_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512"
    run_config_path = root_dir + "/run_config.json"
    model_config_path = root_dir + "/model_config.json"
    checkpoint_dir = root_dir + "/lit_ckpts/step-00120000_ckpt.pth"
    # checkpoint_dir = root_dir + "/combined_ckpts/step-00100000_ckpt.pth"
    # checkpoint_dir = root_dir + "/combined_ckpts/step-00072000_ckpt.pth"
    # checkpoint_dir = root_dir + "/combined_ckpts/step-00002483_ckpt.pth"
    # The tokenizer is deliberately pinned to this directory, independent of root_dir.
    tokenizer_path = "/XXXX-36/XXXX-22/retrieval-pretrained-01" + "/ret_meta_tokens"

    # Close the config file deterministically instead of leaking the handle.
    with open(run_config_path, "r") as run_config_file:
        run_config = json.load(run_config_file)

    model_config = Config.from_file(Path(model_config_path))
    model_config.structured_init = False
    model_config.structured_init_for_wte = False
    model_config.structured_init_olmo_variant = False
    model_config.strategy = "ddp"  # TODO: It's a placeholder to avoid error, need to be fixed
    model_config.attn_impl = "sdpa"

    max_seq_length = run_config["block_size"]

    prefix_suffix_tokenizer = Tokenizer(tokenizer_path)
    print("====== Model args: ======")
    print("Tokenizer path:", tokenizer_path)
    print("suffix_is_prefix:", run_config["suffix_is_prefix"])
    print("flip_rope_embedding_suffix:", run_config["flip_rope_embedding_suffix"])
    print("add_suf_pre_tokens:", run_config["add_suf_pre_tokens"])
    print("nope_pos_embeddings:", run_config["nope_pos_embedding"])
    prefix_suffix_model = PSLM(
        model_config,
        objective=None,
        tokenizer=prefix_suffix_tokenizer,
        suffix_is_prefix=run_config["suffix_is_prefix"],
        flip_rope_embedding_suffix=run_config["flip_rope_embedding_suffix"],
        add_suf_pre_tokens=run_config["add_suf_pre_tokens"],
        nope_pos_embeddings=run_config["nope_pos_embedding"],
    )

    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_dir, map_location=torch.device("cpu"), weights_only=False)
    prefix_suffix_model.load_state_dict(checkpoint["model"])

    if run_config["suffix_is_prefix"]:
        assert prefix_suffix_model.prefix_model == prefix_suffix_model.suffix_model

    # Debug: show the first parameter to sanity-check the loaded weights.
    for name, param in prefix_suffix_model.named_parameters():
        print(name, param.size(), param.dtype, param.device)
        print(param)
        break
    prefix_suffix_model = prefix_suffix_model.to(torch.bfloat16).to("cuda:0")
    # prefix_suffix_model = prefix_suffix_model.to("cuda:0")
    # Inference only: switch to eval mode so any dropout is disabled
    # (assumes PSLM is a torch.nn.Module, as its load_state_dict/to API suggests).
    prefix_suffix_model.eval()
    _model = prefix_suffix_model

    # Alternative instruction templates tried previously:
    # prefix_suffix_model.encode_queries = partial(prefix_suffix_model.encode_queries, instruction="<s><QUERY> {text} <QUERY>")
    # prefix_suffix_model.encode_corpus = partial(prefix_suffix_model.encode_corpus, instruction="<DOC> {text} <DOC>")
elif args.model == "e5":
    print("Loading e5 model")
    _model = SentenceTransformer("intfloat/e5-mistral-7b-instruct").to(torch.float32).to("cuda:0")
    # _model = SentenceTransformer("intfloat/e5-mistral-7b-instruct").to(torch.bfloat16).to("cuda:0")
elif args.model == "gte_qwen":
    print("Loading gte-qwen model")
    _model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True).to(torch.float32).to("cuda:0")
    # _model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True).to(torch.bfloat16).to("cuda:0")
else:
    # Without this, `results` would be unbound and the script would crash at the very end.
    raise ValueError(f"Unknown --model {args.model!r}; expected one of: our_model, e5, gte_qwen")

# --model values coincide with eval_embedding's model_name identifiers.
results = eval_embedding(_model, tokenizer=None, tasks=EVAL_TASKS, instruct=EVAL_INSTRUCT, model_name=args.model)

# Save the table produced by `print_results(results)` in a .md file. print_results may
# print its table rather than return it, so capture stdout as a fallback instead of
# crashing on f.write(None) after the whole evaluation has finished.
_table_buffer = io.StringIO()
with contextlib.redirect_stdout(_table_buffer):
    _table = print_results(results)
if not isinstance(_table, str):
    _table = _table_buffer.getvalue()
if _table and not _table.endswith("\n"):
    _table += "\n"  # keep the closing fence on its own line

with open(args.output_file, "w", encoding="utf-8") as f:
    f.write("```markdown\n")
    f.write(_table)
    f.write("```\n")

# code
# results = eval_embedding(prefix_suffix_model, tokenizer = None, tasks = ['RepoBench', 'SWE-Bench-Lite', 'FoodAPI', 'HuggingfaceAPI', 'PytorchAPI', 'SpotifyAPI', 'TMDB', 'TensorAPI', 'ToolBench', 'WeatherAPI', 'APPS', 'CodeSearchNet', 'HumanEval-X', 'LeetCode', 'MBPP', 'Conala', 'TLDR', 'CodeEditSearch'], instruct = True, model_name = 'our_model')
# legal
# results = eval_embedding(prefix_suffix_model, tokenizer = None, tasks = ['BillSum', 'AILA2019-Case', 'GerDaLIR', 'LeCaRDv2', 'AILA2019-Statutes', 'BSARD', 'LegalQuAD', 'REGIR-EU2UK', 'REGIR-UK2EU', 'TREC-Legal_2011', 'CUAD'], instruct = True, model_name = 'our_model')
# Finance
# results = eval_embedding(prefix_suffix_model, tokenizer = None, tasks = ['ConvFinQA', 'Apple', 'FinQA', 'FinanceBench', 'HC3Finance', 'TAT-DQA', 'Trade-the-event', 'FiQA'], instruct = True, model_name = 'our_model')
# Academic
# results = eval_embedding(prefix_suffix_model, tokenizer = None, tasks = ['LitSearch', 'FairRanking_2020', 'ProofWiki_Reference', 'Stacks_Reference', 'Stein_Reference', 'Trench_Reference', 'TAD', 'TAS2', 'SciDocs', 'ProofWiki_Proof', 'Stacks_Proof', 'Stein_Proof', 'Trench_Proof', 'SciFact', 'Competition-Math', 'StackMathQA'], instruct = True, model_name = 'our_model')
# results = eval_embedding(prefix_suffix_model, tokenizer = None, tasks = ['Core17', 'News21', 'Robust04', 'InstructIR', 'NevIR'], instruct = False, model_name = 'our_model')
# results = eval_embedding(prefix_suffix_model, tokenizer = None, tasks = ['IFEval'], instruct = True, model_name = 'our_model')
# results = eval_embedding(_model, tokenizer = None, tasks = ['IFEval'], instruct = True, model_name = 'e5')
# results = eval_embedding(_model, tokenizer = None, tasks = ['IFEval'], instruct = True, model_name = 'gte_qwen')