#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
import sys
import torch
import argparse
from mteb import MTEB
from transformers import AutoModel
from collections import OrderedDict
from cmteb_dres_model import CmtebDRESModel

# Absolute path of the directory two levels above this file, i.e. the parent
# of the directory that contains this script.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Put that directory at the front of the module search path.
# NOTE(review): this statement executes after the imports at the top of the
# file, so it cannot influence how those imports were resolved -- confirm that
# none of them is expected to be found via project_root.
sys.path.insert(0, project_root)

# Instruction text for each task in the "STS" family, keyed by task name.
_STS_INSTRUCTIONS = {
    "AFQMC": "为金融领域句子生成表示",
    "ATEC": "为金融领域客服场景的句子生成表示",
    "BQ": "为银行客服对话中的句子生成表示",
    "LCQMC": "为通用领域句子生成表示",
    "PAWSX": "为英文翻译得到的中文句子生成表示",
    "QBQTC": "为搜索引擎的搜索文本生成表示",
    "STSB": "为简短的通用领域句子生成表示",
    "STS22": "为新闻报道文本生成表示",
}

# Instruction text for each task in the "IR" family, keyed by task name.
_IR_INSTRUCTIONS = {
    "CmedqaRetrieval": "为医学问题生成表示，用于匹配相关医学文本",
    "CovidRetrieval": "为新冠疫情相关问题生成表示，用于匹配相关文本",
    "DuRetrieval": "为用户日常搜索问题生成表示，用于匹配网页内容",
    "EcomRetrieval": "为用户购物搜索问题生成表示，用于匹配商品属性文本",
    "MedicalRetrieval": "为患者症状描述文本生成表示，用于匹配诊疗方案或医学指南",
    "MMarcoRetrieval": "为简短搜索问题生成表示，用于匹配相关长文本网页内容",
    "T2Retrieval": "为通用领域问题生成表示，用于匹配相关文本",
    "VideoRetrieval": "为简短用户问题生成表示，用于匹配视频标题或视频描述",
}

# Two-level lookup: task family ("STS" or "IR") -> task name -> instruction.
task2instruction = {
    "STS": _STS_INSTRUCTIONS,
    "IR": _IR_INSTRUCTIONS,
}


def build_instruction_prefix(model_name: str, instruction: str = "") -> str:
    """Return the text to prepend to a query for the given model.

    Args:
        model_name: Model name or path; matched case-insensitively by
            substring to pick the leading marker.
        instruction: Task instruction. When empty, no
            ``"Instruction: ... Query: "`` part is produced.

    Returns:
        The instruction part (possibly empty), preceded by ``"<s>"`` when the
        name contains ``"minicpm"`` or ``"e5"``, else by ``"[CLS]"`` when it
        contains ``"bge"``, else by nothing.
    """
    lowered = model_name.lower()
    body = f"Instruction: {instruction} Query: " if instruction else ""

    # The minicpm/e5 check takes precedence over the bge check.
    if any(marker in lowered for marker in ("minicpm", "e5")):
        return "<s>" + body
    if "bge" in lowered:
        return "[CLS]" + body
    return body


def build_passage_instruction_prefix(model_name: str) -> str:
    """Return the text to prepend to a passage for the given model.

    Args:
        model_name: Model name or path; matched case-insensitively by
            substring.

    Returns:
        ``"<s>"`` when the name contains ``"minicpm"`` or ``"e5"``,
        ``"[CLS]Passage: "`` when it otherwise contains ``"bge"``, and the
        empty string in every other case.
    """
    lowered = model_name.lower()

    # Ordered rules: the first rule with a matching marker decides the prefix.
    rules = (
        (("minicpm", "e5"), "<s>"),
        (("bge",), "[CLS]Passage: "),
    )
    for markers, passage_prefix in rules:
        if any(marker in lowered for marker in markers):
            return passage_prefix
    return ""


def load_model(args):
    """Build the ``CmtebDRESModel`` used for evaluation.

    When ``args.base_model_path`` equals ``args.checkpoint_path`` the wrapper
    is constructed directly from that path. Otherwise an ``AutoModel`` is
    created from the base path, the ``pytorch_model.bin`` state dict found in
    the checkpoint directory is loaded into it non-strictly (after removing a
    leading ``'model.'`` from any key that has one), and the wrapper is built
    around that model object.

    Args:
        args: Namespace providing ``base_model_path``, ``checkpoint_path``,
            ``pooling_method``, ``max_length``, ``batch_size`` and ``gpu_id``.

    Returns:
        The constructed ``CmtebDRESModel``.
    """
    base_path = args.base_model_path
    ckpt_path = args.checkpoint_path
    banner = "=" * 60

    def wrapper_options():
        # Keyword arguments shared by both construction paths; read lazily so
        # the attribute lookups happen at the same point as the constructor call.
        return dict(
            pooling_method=args.pooling_method,
            normalize_embeddings=True,
            max_length=args.max_length,
            batch_size=args.batch_size,
            gpu_id=args.gpu_id,
        )

    # Raw-model mode: nothing to merge, hand the path straight to the wrapper.
    if base_path == ckpt_path:
        print(banner)
        print("Mode: Evaluating RAW model (base_model_path == checkpoint_path)")
        print(f"Loading raw model directly from: {base_path}")
        print(banner)
        return CmtebDRESModel(model_name_or_path=base_path, **wrapper_options())

    print(banner)
    print("Mode: Evaluating fine-tuned CHECKPOINT")
    print(f"Base model for architecture: {base_path}")
    print(f"Checkpoint for weights:      {ckpt_path}")
    print(banner)

    print(f"Step 1: Creating a PURE AutoModel shell from '{base_path}' with correct dtype...")
    backbone = AutoModel.from_pretrained(base_path, trust_remote_code=True)

    print(f"Step 2: Loading state_dict from checkpoint '{ckpt_path}'...")
    weights_file = os.path.join(ckpt_path, "pytorch_model.bin")
    raw_state = torch.load(weights_file, map_location="cpu", weights_only=True)

    print("Step 3: Cleaning state_dict keys by stripping the 'model.' prefix...")
    wrapper_prefix = 'model.'
    renamed_state = OrderedDict(
        (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key, tensor)
        for key, tensor in raw_state.items()
    )

    print("Step 4: Loading the cleaned state_dict into the PURE AutoModel shell...")
    # strict=False: mismatched keys are reported below instead of raising.
    load_report = backbone.load_state_dict(renamed_state, strict=False)
    missing_keys, unexpected_keys = load_report
    print(f"State_dict loaded. Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}")

    print("Step 5: Initializing the evaluation model (CmtebDRESModel)...")
    return CmtebDRESModel(
        model_object=backbone,
        model_name_or_path=ckpt_path,
        **wrapper_options(),
    )


def set_instruction(model, task, args):
    """Attach the instruction prefixes for ``task`` to ``model``.

    For a task listed under ``task2instruction["IR"]`` this sets
    ``model.ir_instruction_for_query`` and ``model.ir_instruction_for_passage``;
    for a task listed under ``task2instruction["STS"]`` it sets
    ``model.sts_instruction``. A task found in neither table leaves the model
    untouched. The prefixes are built from ``args.base_model_path``.

    Args:
        model: Object that receives the instruction attributes.
        task: Task name to look up.
        args: Namespace providing ``base_model_path``.

    Returns:
        The same ``model`` object that was passed in.
    """
    print("========== use task instruction ==========")

    retrieval_table = task2instruction["IR"]
    if task in retrieval_table:
        model.ir_instruction_for_query = build_instruction_prefix(
            model_name=args.base_model_path,
            instruction=retrieval_table[task],
        )
        model.ir_instruction_for_passage = build_passage_instruction_prefix(
            model_name=args.base_model_path,
        )
        print(f"query_instruction: {model.ir_instruction_for_query}")
        print(f"passage_instruction: {model.ir_instruction_for_passage}")
        return model

    similarity_table = task2instruction["STS"]
    if task in similarity_table:
        model.sts_instruction = build_instruction_prefix(
            model_name=args.base_model_path,
            instruction=similarity_table[task],
        )
        print(f"sts_instruction: {model.sts_instruction}")

    return model


def get_args():
    """Parse the command-line arguments of the evaluation script.

    Returns:
        argparse.Namespace with ``base_model_path``, ``checkpoint_path``,
        ``pooling_method``, ``max_length``, ``batch_size``, ``gpu_id``,
        ``task_names`` and ``use_task_instruction``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model_path', default=None, type=str)
    parser.add_argument('--checkpoint_path', default=None, type=str)
    parser.add_argument('--pooling_method', default="mean", type=str)
    parser.add_argument('--max_length', default=1024, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--gpu_id', default=0, type=int)

    parser.add_argument('--task_names', default="DuRetrieval", type=str)

    # Task instructions stay enabled by default. A store_true flag whose
    # default is already True can never be switched off, so a companion
    # --no_task_instruction flag writes False to the same destination.
    parser.add_argument('--use_task_instruction', dest="use_task_instruction",
                        action="store_true", default=True,
                        help="use task prompt (enabled by default)")
    parser.add_argument('--no_task_instruction', dest="use_task_instruction",
                        action="store_false",
                        help="do not use task prompt")
    return parser.parse_args()


def main():
    """Evaluate the loaded model on the fixed list of IR and STS tasks.

    For every task the model's instruction attributes are reset, optionally
    re-populated via ``set_instruction``, and a single-task ``MTEB`` run is
    executed with results written below ``result/<IR|STS>/<run tag>``.
    A task whose evaluation raises is reported with its traceback and the loop
    continues with the next task; a summary of failed tasks is printed at the
    end.
    """
    # Local import: only needed for failure reporting in this function.
    import traceback

    args = get_args()
    print(f"Args: {args}")

    model = load_model(args)

    ir_tasks = ["VideoRetrieval", "EcomRetrieval", "CmedqaRetrieval", "MMarcoRetrieval",
                "MedicalRetrieval", "CovidRetrieval", "DuRetrieval", "T2Retrieval"]

    sts_tasks = ['AFQMC', 'STS22', 'ATEC', 'BQ', 'LCQMC', 'PAWSX', 'STSB', 'QBQTC']
    # NOTE(review): args.task_names is parsed by get_args but is not consulted
    # here; the full list below is always evaluated.
    task_names = ir_tasks + sts_tasks

    print(task_names)

    # The run tag depends only on the checkpoint path, so compute it once.
    # Joining the last (up to) two path components yields the same
    # "<exp>_<ckpt>" name as before for paths with two or more components and
    # no longer fails to unpack when the path has a single component.
    run_tag = "_".join(args.checkpoint_path.split('/')[-2:])

    failed_tasks = []
    for task in task_names:
        print(f"========== {task} ==========")

        # Clear instructions left over from the previous task.
        model.sts_instruction = None
        model.ir_instruction_for_query = None
        model.ir_instruction_for_passage = None

        if args.use_task_instruction:
            model = set_instruction(model, task, args)

        evaluation = MTEB(tasks=[task], task_langs=['zh', 'zh-CN'])

        task_type = "IR" if task in task2instruction["IR"] else "STS"
        output_folder = f"result/{task_type}/{run_tag}"
        print(f"output folder: {output_folder}")

        try:
            evaluation.run(model, output_folder=output_folder)
        except Exception:
            # Best effort: keep evaluating the remaining tasks, but make the
            # failure visible. KeyboardInterrupt/SystemExit are not caught, so
            # the run can still be aborted.
            traceback.print_exc()
            print(f"Task {task} failed; continuing with the remaining tasks.")
            failed_tasks.append(task)

    if failed_tasks:
        print(f"Failed tasks: {failed_tasks}")


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
