import json
from transformers import AutoProcessor
import sys 
import os 
current_file_path = os.path.dirname(os.path.abspath(__file__))
module_path = os.path.join(current_file_path, "../")
sys.path.append(module_path)
from models.qwen2_vl import Qwen2VLRetForConditionalGeneration # 作者的源代码,有问题吧，为什么拿这个模型评估
# 该使用 finetune 模型来进行评估吧,但是使用这个模型来评估效果会奇差，预训练也是如此
from models.qwen2_vl_finetune import Qwen2VLRetFinetuneForConditionalGeneration
import torch
# 不确定是否需要适配
import torch_npu                              # 适配 npu
from torch_npu.contrib import transfer_to_npu # 适配 npu
import argparse
from dataset.datasets_mbeir import QueryDataset, CandidateDataset
from collators.mbeir_eval import MbeirQueryDataCollator, MbeirCandidateDataCollator
from torch.utils.data import DataLoader 
import torch.nn.functional as F 
from accelerate import Accelerator
import accelerate
# Moduli used to decode hashed ids back to "<dataset_id>:<index>" strings:
# unhash_qid splits a query id by DATASET_QUERY_NUM_UPPER_BOUND and unhash_did
# splits a candidate id by DATASET_CAN_NUM_UPPER_BOUND (quotient = dataset id,
# remainder = index within that dataset). They must match whatever produced the
# hashed ids (presumably the dataset classes) — TODO confirm.
DATASET_QUERY_NUM_UPPER_BOUND = 500000
DATASET_CAN_NUM_UPPER_BOUND = 10000000
# 导入自定义的工具函数 debug --------------------------------------------------------------------------------
from utils import (
    rank0_print, find_all_linear_names, safe_save_model_for_hf_trainer,
    get_peft_state_maybe_zero_3, TrainerWithCustomSampler
)
import time
# --------------------------------------------------------------------------------------------------------

def unhash_qid(hashed_qid):
    """Decode a hashed query id into a ``"<dataset_id>:<index_within_dataset>"`` string.

    The quotient of ``hashed_qid`` by ``DATASET_QUERY_NUM_UPPER_BOUND`` is the
    dataset id and the remainder is the position inside that dataset.
    """
    bound = DATASET_QUERY_NUM_UPPER_BOUND
    return f"{hashed_qid // bound}:{hashed_qid % bound}"

def unhash_did(hashed_did):
    """Decode a hashed candidate id into a ``"<dataset_id>:<index_within_dataset>"`` string.

    The quotient of ``hashed_did`` by ``DATASET_CAN_NUM_UPPER_BOUND`` is the
    dataset id and the remainder is the position inside that dataset.
    """
    bound = DATASET_CAN_NUM_UPPER_BOUND
    return f"{hashed_did // bound}:{hashed_did % bound}"

def load_qrel(filename):
    """Load a qrels file, keeping only positively judged (query, doc) pairs.

    Every non-blank line must contain exactly five whitespace-separated fields:
    ``query_id <ignored> doc_id relevance_score task_id``. Blank lines are skipped.

    Args:
        filename: path to the qrels text file.

    Returns:
        A tuple ``(qrel, qid_to_taskid)``:
          * ``qrel`` maps each query_id to the list of doc_ids whose
            relevance_score is > 0, in file order.
          * ``qid_to_taskid`` maps each such query_id to the task_id of its
            first positive line.

    Raises:
        ValueError: if a non-blank line does not have exactly five fields or its
            relevance score is not an integer.
    """
    qrel = {}
    qid_to_taskid = {}
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line at EOF).
                continue
            query_id, _, doc_id, relevance_score, task_id = fields
            # Only positive relevance scores are treated as relevant documents.
            if int(relevance_score) <= 0:
                continue
            qrel.setdefault(query_id, []).append(doc_id)
            # Keep the task id of the first positive judgement for this query.
            qid_to_taskid.setdefault(query_id, task_id)
    print(f"Retriever: Loaded {len(qrel)} queries from {filename}")
    num_relevant = sum(len(docs) for docs in qrel.values())
    # Guard against an empty / all-negative qrels file (was a ZeroDivisionError).
    avg_relevant = num_relevant / len(qrel) if qrel else 0.0
    print(
        f"Retriever: Average number of relevant documents per query: {avg_relevant:.2f}"
    )
    return qrel, qid_to_taskid

def compute_recall_at_k(relevant_docs, retrieved_indices, k):
    """Return the hit-style Recall@k for a single query.

    Args:
        relevant_docs: ids of the documents judged relevant for the query.
        retrieved_indices: retrieved ids, best first (anything sliceable).
        k: cutoff; only the first ``k`` retrieved ids are considered.

    Returns:
        1.0 if at least one relevant id appears in the top ``k`` retrieved ids,
        otherwise 0.0. A query with no relevant documents scores 0.0.
    """
    if not relevant_docs:
        return 0.0

    top_k = set(retrieved_indices[:k])
    found_any = not set(relevant_docs).isdisjoint(top_k)
    return 1.0 if found_any else 0.0


def eval(args):
    """Encode an M-BEIR query set and candidate pool, retrieve top-k and report Recall@k.

    Pipeline:
      1. Load the retrieval model from ``args.model_id`` (bf16) and the processor
         from ``args.original_model_id``. Register the ``<emb>`` token, plus
         ``<instruction_start>``/``<instruction_end>`` when instruction masking is on.
      2. Every process encodes its shard of queries and candidates. Embeddings
         are L2-normalised and gathered across processes via ``accelerate``.
      3. On the main process only:
         * score every query against every candidate by dot product (cosine
           similarity, since both sides are normalised) and keep the top 50
           (fewer if the pool is smaller);
         * dump names, scores, features and ids under ``args.save_dir_name``;
         * print Recall@{1,5,10,50} and append it to ``<model_name>_results.txt``.

    The function name shadows the builtin ``eval``; it is kept because the
    ``__main__`` entry point calls it by this name.

    Args:
        args: parsed CLI namespace. Boolean-like options are strings compared
            against ``"True"``.
    """
    original_model_id = args.original_model_id
    model_id = args.model_id
    use_instruction_mask = args.model_use_instruction_mask == "True"

    # NOTE(review): the original author questions whether the finetune class
    # (Qwen2VLRetFinetuneForConditionalGeneration) should be evaluated instead.
    model = Qwen2VLRetForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    # Model options arrive from the CLI as "True"/"False" strings.
    model.mean_pooling = args.model_mean_pooling == "True"
    model.use_bi_atten = args.model_use_bi_atten == "True"
    model.use_latent_atten = args.model_use_latent_atten == "True"
    model.use_instruction_mask = use_instruction_mask

    # The processor is not changed by fine-tuning, so load it from the original model repo.
    processor = AutoProcessor.from_pretrained(original_model_id)

    tokenizer = processor.tokenizer
    tokenizer.model_max_length = args.model_max_length

    def add_embed_token(tokenizer, model, emb_token="<emb>"):
        """Add ``emb_token`` to the vocab, resize the embeddings and record its id on ``model.config``."""
        emb_tokens = [emb_token]
        num_new_tokens = tokenizer.add_tokens(emb_tokens)
        # The token must be genuinely new; otherwise ids would not line up with training.
        assert len(emb_tokens) == num_new_tokens
        model.resize_token_embeddings(len(tokenizer))
        token_id = tokenizer.convert_tokens_to_ids(emb_token)
        if emb_token == "<instruction_start>":
            model.config.instruction_start_token_id = token_id
        elif emb_token == "<instruction_end>":
            model.config.instruction_end_token_id = token_id
        else:
            model.config.emb_token_id = token_id  # default: the generic embedding token

    add_embed_token(tokenizer, model)
    if model.use_instruction_mask:
        add_embed_token(tokenizer, model, emb_token="<instruction_start>")
        add_embed_token(tokenizer, model, emb_token="<instruction_end>")

    query_dataset = QueryDataset(
        query_data_path=args.query_data_path,
        cand_pool_path=args.query_cand_pool_path,
        instructions_path=args.instructions_path,
        image_path_prefix=args.image_path_prefix,
        use_instruction_token=(args.query_dataset_use_instruction_token == "True"),  # queries use instruction tokens?
        has_instruction=(args.query_dataset_has_instruction == "True"),  # queries carry an instruction?
        prompt_index=int(args.query_prompt_index),  # index of the prompt template
    )

    # NOTE(review): the candidate set reuses --query_prompt_index; confirm this is intended.
    cand_dataset = CandidateDataset(
        query_data_path=args.query_data_path,
        cand_pool_path=args.cand_pool_path,
        instructions_path=args.instructions_path,
        image_path_prefix=args.image_path_prefix,
        prompt_index=int(args.query_prompt_index),
    )

    query_data_collator = MbeirQueryDataCollator(
        tokenizer=tokenizer,
        processor=processor,
        has_instruction=query_dataset.has_instruction,
        use_instruction_token=query_dataset.use_instruction_token,
    )
    cand_data_collator = MbeirCandidateDataCollator(tokenizer=tokenizer, processor=processor)

    query_dataloader = DataLoader(query_dataset, batch_size=16, num_workers=8, shuffle=False, collate_fn=query_data_collator)
    candidate_dataloader = DataLoader(cand_dataset, batch_size=16, num_workers=8, shuffle=False, collate_fn=cand_data_collator)

    accelerator = Accelerator(mixed_precision='bf16')
    device = accelerator.device
    is_main_process = accelerator.is_main_process

    # Configuration dump (main process only).
    rank0_print("query_data_collator 的 has_instruction 是：",query_data_collator.has_instruction)
    rank0_print("query_data_collator 的 use_instruction_token 是：",query_data_collator.use_instruction_token)
    rank0_print("query_dataset 的长度是：",len(query_dataset))
    rank0_print("query_dataset 的咒语是：",query_dataset.prompt)
    rank0_print("cand_dataset 使用的咒语： ",cand_dataset.prompt)
    rank0_print("query_dataset 是否使用指令是：",query_dataset.has_instruction)
    rank0_print("query_dataset 是否使用指令 token 是：",query_dataset.use_instruction_token)
    rank0_print("模型初始化完成————————————————————————————————————————————————————————————————————————")
    rank0_print("mean_pooling: ",model.mean_pooling ,"use_bi_atten: ",model.use_bi_atten)
    rank0_print("use_instruction_mask: ",model.use_instruction_mask,type(model.use_instruction_mask))
    rank0_print("use_latent_atten: ",model.use_latent_atten,type(model.use_latent_atten))
    rank0_print("model.config.emb_token_id: ",model.config.emb_token_id)
    rank0_print("模型的类名是：",model.__class__.__name__)
    rank0_print("脚本的运行时间是: ",time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    rank0_print("脚本的参数是: ",args)

    model.eval()

    def tensors_to_device(data, device, dtype=model.dtype):
        """Move every tensor in ``data`` to ``device`` in place; ``pixel_values`` is also cast to ``dtype``."""
        for key in data.keys():
            if isinstance(data[key], torch.Tensor):
                if key == 'pixel_values':
                    data[key] = data[key].to(device).to(dtype)
                else:
                    data[key] = data[key].to(device)
        return data

    query_features = []
    query_ids = []
    candidate_features = []
    candidate_ids = []

    from tqdm import tqdm
    with torch.no_grad():
        candidate_batch_times = 0  # timing of the first 100 candidate batches
        query_batch_times = 0      # timing of the first 100 query batches
        query_dataloader, candidate_dataloader, model = accelerator.prepare(query_dataloader, candidate_dataloader, model)
        # After prepare() the model may be wrapped (e.g. DistributedDataParallel).
        # Setting an attribute on the wrapper does not reach the wrapped module,
        # which is presumably what reads `use_instruction_mask` in forward().
        # So toggle the flag on the unwrapped model.
        core_model = accelerator.unwrap_model(model)

        # Queries: restore the instruction-mask setting requested on the CLI.
        core_model.use_instruction_mask = use_instruction_mask
        for batch_idx, batch in enumerate(tqdm(query_dataloader, disable=not is_main_process)):
            if batch_idx == 0:
                start_time = time.time()
                rank0_print("1"*50)
            batch = tensors_to_device(batch, device)
            query_embed, batch_query_ids, _ = model(**batch, inference=True)
            query_embed = F.normalize(query_embed, dim=-1)
            query_embed = accelerator.gather_for_metrics(query_embed)
            # Trim ids to the number of gathered embeddings (padding duplicates on the last batch).
            batch_query_ids = accelerate.utils.gather_object(batch_query_ids)[:len(query_embed)]
            query_ids.extend(batch_query_ids)
            query_features.append(query_embed.cpu())
            if batch_idx == 99:
                query_batch_times = time.time() - start_time
                rank0_print(f"查询集前 100 个batch的时间: {query_batch_times/60.00} min")
                query_batch_times = query_batch_times * len(query_dataloader) / 100
                rank0_print(f"查询集所有 batch的时间: {query_batch_times/3600.00} h")

        # Candidates are always encoded without the instruction mask.
        core_model.use_instruction_mask = False
        for batch_idx, batch in enumerate(tqdm(candidate_dataloader, disable=not is_main_process)):
            if batch_idx == 0:
                start_time = time.time()
            batch = tensors_to_device(batch, device)
            candidate_embed, _, batch_candidate_ids = model(**batch, inference=True)
            candidate_embed = F.normalize(candidate_embed, dim=-1)
            candidate_embed = accelerator.gather_for_metrics(candidate_embed)
            batch_candidate_ids = accelerator.gather_for_metrics(batch_candidate_ids)[:len(candidate_embed)]
            candidate_ids.extend(batch_candidate_ids)
            candidate_features.append(candidate_embed.cpu())
            if batch_idx == 99:
                candidate_batch_times = time.time() - start_time
                rank0_print(f"候选集前 100 个 batch 的时间: {candidate_batch_times/60.00} min")
                candidate_batch_times = candidate_batch_times * len(candidate_dataloader) / 100
                rank0_print(f"候选集所有 batch 的时间: {candidate_batch_times/3600.00} h")

    # Release the model on every rank before scoring. Drop *all* references,
    # including the unwrapped one, so the memory can actually be freed.
    model = model.to("cpu")  # move weights off the NPU first
    del core_model
    del model
    accelerator.free_memory()  # release resources held by the accelerator
    torch.npu.empty_cache()

    # Only the main process scores, saves and reports.
    if not is_main_process:
        return

    import numpy as np

    query_features = torch.cat(query_features, dim=0).to(device)
    candidate_features = torch.cat(candidate_features, dim=0).to(device)

    # torch.topk raises if k exceeds the number of candidates, so clamp it.
    top_k = min(50, candidate_features.shape[0])
    index = []
    scores = []
    for i in range(len(query_features)):
        query_feature = query_features[i:i+1]
        # Dot product == cosine similarity: both sides were L2-normalised above.
        score = query_feature @ candidate_features.T  # (1, num_candidate)
        topk_score, topk_indexes = torch.topk(score, k=top_k, dim=-1)
        # squeeze(0), not squeeze(): stays a list even when top_k == 1.
        index.append(topk_indexes.squeeze(0).tolist())
        scores.append(topk_score.tolist())

    cand_names = np.array([[unhash_did(candidate_ids[item]) for item in row] for row in index])
    query_names = [unhash_qid(item) for item in query_ids]

    save_dir_name = args.save_dir_name
    os.makedirs(save_dir_name, exist_ok=True)
    model_name = model_id.split('/')[-1]
    save_name = args.qrels_path.split('/')[-1].replace('_qrels.txt', '')
    save_name = f"{save_name}_{model_name}"
    with open(f"{save_dir_name}/{save_name}_query_names.json", 'w') as f:
        json.dump(query_names, f, indent=2)
    with open(f"{save_dir_name}/{save_name}_cand_names.json", 'w') as f:
        json.dump(cand_names.tolist(), f, indent=2)
    with open(f"{save_dir_name}/{save_name}_scores.json", 'w') as f:
        json.dump(scores, f, indent=2)
    torch.save(query_features.cpu(), f"{save_dir_name}/{save_name}_query_features.pth")
    torch.save(candidate_features.cpu(), f"{save_dir_name}/{save_name}_candidate_features.pth")
    with open(f"{save_dir_name}/{save_name}_query_ids.json", 'w') as f:
        json.dump(query_ids, f, indent=2)
    with open(f"{save_dir_name}/{save_name}_candidate_ids.json", 'w') as f:
        json.dump(candidate_ids, f, indent=2)

    qrel, _ = load_qrel(args.qrels_path)

    k_lists = [1, 5, 10, 50]
    res = {f'recall_{k}': [] for k in k_lists}

    # A query with no positive judgement used to abort the run with a KeyError
    # after all embeddings were computed. It now counts as a miss (recall 0),
    # consistent with compute_recall_at_k, and the count is reported.
    num_missing = sum(1 for query_name in query_names if query_name not in qrel)
    if num_missing:
        print(f"Warning: {num_missing} queries have no positive qrels; counted as recall 0")

    for ind, query_name in enumerate(tqdm(query_names)):
        relevant_docs = qrel.get(query_name, [])
        retrieved_indices_for_qid = cand_names[ind]
        for k in k_lists:
            recall_at_k = compute_recall_at_k(relevant_docs, retrieved_indices_for_qid, k)
            res[f'recall_{k}'].append(recall_at_k)

    # Mean Recall@k per cutoff. max(..., 1) guards against an empty query set.
    mean_recall = {
        k: sum(res[f'recall_{k}']) / max(len(res[f'recall_{k}']), 1) for k in k_lists
    }
    for k in k_lists:
        print(f"recall_at_{k} = {mean_recall[k]}")

    with open(f"{save_dir_name}/{model_name}_results.txt", 'a') as f:
        f.write(args.qrels_path + '\n')
        for k in k_lists:
            f.write(f"recall_at_{k} = {mean_recall[k]}" + '\n')

if __name__ == "__main__":
    # Every CLI option, in --help order. Options without an explicit "type" are
    # strings; boolean-like switches are passed as the literal strings "True"/"False".
    cli_options = [
        ("--query_data_path", {}),
        ("--cand_pool_path", {}),
        ("--instructions_path", {}),
        ("--qrels_path", {}),
        ("--model_max_length", {"type": int, "default": 1024}),
        ("--original_model_id", {}),
        ("--model_id", {}),
        ("--query_cand_pool_path", {}),
        ("--save_dir_name", {}),
        ("--image_path_prefix", {"default": "./data/M-BEIR/"}),
        ("--query_dataset_has_instruction", {}),
        ("--query_dataset_use_instruction_token", {}),
        ("--model_mean_pooling", {}),
        ("--model_use_bi_atten", {}),
        ("--model_use_latent_atten", {}),
        ("--model_use_instruction_mask", {}),
        ("--query_prompt_index", {"default": "1"}),
    ]
    cli_parser = argparse.ArgumentParser()
    for flag, overrides in cli_options:
        # str is the default type; an option's own "type" entry overrides it.
        cli_parser.add_argument(flag, **{"type": str, **overrides})
    args = cli_parser.parse_args()
    eval(args)

