import functools
import gc
import time

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from transformers import GenerationConfig

from model import EAReranker

# Select the compute device: CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


test_words = ["hello", "test", "cool", "cat", "dog", "my", "hot", "look", "six", "one", "two", "red", "blue", "take", "the", "to"]


def memory_monitor(func):
    """Decorator measuring peak CUDA memory and per-word wall time of ``func``.

    The wrapped function's own return value is discarded; the wrapper returns
    a metrics dict instead:

    * ``"peak_memory_MB"`` -- peak CUDA memory allocated while ``func`` ran,
      in MiB (0.0 when CUDA is not available).
    * ``"time_s"`` -- total wall time of ``func`` divided by
      ``len(test_words)``. This covers everything ``func`` does (including any
      model loading inside it), not just an inference loop.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        use_cuda = torch.cuda.is_available()

        # Drop unreachable Python objects first so that empty_cache() can
        # actually hand their freed CUDA blocks back.
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()
            start_mem = torch.cuda.memory_allocated()
            print(start_mem, torch.cuda.max_memory_allocated())
            # Reset so the reported peak reflects only this call.
            torch.cuda.reset_peak_memory_stats()
            # Ensure previously queued kernels don't leak into the timing.
            torch.cuda.synchronize()

        start_time = time.perf_counter()
        func(*args, **kwargs)
        if use_cuda:
            # CUDA kernels run asynchronously; wait for completion before
            # stopping the clock.
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        peak_mb = torch.cuda.max_memory_allocated() / 1024**2 if use_cuda else 0.0
        return {
            "peak_memory_MB": peak_mb,
            # Guard against an empty word list.
            "time_s": elapsed / max(len(test_words), 1),
        }
    return wrapper

# Helper functions for the lightblue reranker.
def make_reranker_input(context, query):
    """Return the reranker's user-message text.

    The query comes first and the context second, each preceded by its own
    ``<<<...>>>`` marker line.
    """
    template = "<<<Query>>>\n{}\n\n<<<Context>>>\n{}"
    return template.format(query, context)

def make_reranker_inference_conversation(context, query):
    """Return the chat messages used to score ``query`` against ``context``.

    The list holds a system message asking for a 1-7 relatedness score,
    followed by a user message built by ``make_reranker_input``.
    """
    instructions = (
        "Given a query and a piece of text, output a score of 1-7 based on how related "
        "the query is to the text. 1 means least related and 7 is most related."
    )
    user_turn = make_reranker_input(context, query)
    return [
        dict(role="system", content=instructions),
        dict(role="user", content=user_turn),
    ]

def get_prob(logits, token_id):
    """Return, as a Python float, the softmax probability of ``token_id``.

    Softmax is taken over the last dimension and the value is read from the
    first row (index 0) of ``logits``.
    """
    first_row = logits.softmax(dim=-1)[0]
    return first_row[token_id].item()
@memory_monitor
def measure_lightblue_performance(model_name, model_id, length=1):
    """Benchmark a lightblue LLM-based reranker (peak memory / time per word).

    For each word in ``test_words``, a text is built: the word itself when
    ``length == 1``, otherwise the word followed by a space repeated
    ``length`` times. That text is used as both query and context in the
    reranker chat prompt. The prompt is scored with one forward pass.
    The softmax probabilities of the tokens "1".."7" are weighted by 1..7
    to get an expected score.

    Args:
        model_name: Display name, used only in the error message.
        model_id: Hugging Face model id to load.
        length: How many times each word is repeated to build the text.

    Returns:
        None; the ``memory_monitor`` decorator replaces it with its metrics.

    NOTE: model loading happens inside this function and is therefore part of
    the time measured by ``memory_monitor``.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
            device_map="auto",
        )
        model.eval()
        # Token ids of the digits 1-7. Special tokens are excluded so that
        # index 0 is the digit itself even for tokenizers that prepend a BOS.
        idx_tokens = [tokenizer.encode(str(i), add_special_tokens=False)[0] for i in range(1, 8)]
        score_values = np.arange(1, 8)

        for word in test_words:
            test_text = word if length == 1 else f"{word} " * length
            conversation = make_reranker_inference_conversation(test_text, test_text)
            # add_generation_prompt=True appends the assistant header, so the
            # next token is the model's answer (the score digit) rather than
            # the start of a new chat turn.
            prompt = tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
            # With device_map="auto" the model decides its placement.
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            with torch.no_grad():
                # Raw next-token logits from a single forward pass (what
                # generating one token costs). generate(output_scores=True)
                # returns *processed* scores that the model's generation
                # config (temperature / top-k / top-p / repetition penalty,
                # if set) may have altered, which would distort the digit
                # probabilities.
                logits = model(**inputs).logits[:, -1, :]

            # One softmax (in float32 for precision), then pick the digits.
            probs = torch.softmax(logits.float(), dim=-1)[0, idx_tokens].cpu().numpy()
            # Computed as part of the timed inference work; not stored.
            expected_score = float(np.sum(probs * score_values))
        del model, tokenizer
    except Exception as e:
        # Same reporting as measure_standard_performance: one failing
        # configuration (e.g. an over-long input) must not abort the sweep.
        print(f"评估 {model_name} 时出错: {str(e)}")
    return
    
@memory_monitor
def measure_eareranker(input_dim, hidden_dim=768, length=1):
    """Benchmark EAReranker (peak memory / time per word) on random inputs.

    For each word in ``test_words``, one random tensor of shape
    ``(1, input_dim)`` is created and passed to the model as both of its
    inputs.

    Args:
        input_dim: Input size passed to ``EAReranker`` and used for the
            random input tensors.
        hidden_dim: Hidden size passed to ``EAReranker``.
        length: Unused. It is accepted so the signature matches the other
            ``measure_*`` functions; the inputs here have a fixed size.

    Returns:
        None; the ``memory_monitor`` decorator replaces it with its metrics.
    """
    model = EAReranker(input_dim=input_dim, hidden_dim=hidden_dim).to(device)
    # Inference setup matching the other benchmarks: eval mode, and no
    # autograd bookkeeping inflating the measured peak memory.
    # (Assumes EAReranker is a torch.nn.Module -- it supports .to(device).)
    model.eval()
    with torch.no_grad():
        for _ in test_words:
            input_emb = torch.rand(1, input_dim, device=device)
            model(input_emb, input_emb)
    del model

    return

@memory_monitor
def measure_standard_performance(model_name, model_id, length=1):
    """Benchmark a cross-encoder reranker loaded via AutoModelForSequenceClassification.

    Each word in ``test_words`` becomes a text: the word itself when
    ``length == 1``, otherwise the word plus a space repeated ``length``
    times. The text is scored as a (text, text) pair. Models whose id
    contains "jina" use their ``compute_score`` helper. All other models
    get a tokenized forward pass under ``torch.no_grad()``. Any exception
    is printed and swallowed, so a failing model does not stop the run.

    Returns:
        None; the ``memory_monitor`` decorator replaces it with its metrics.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id, trust_remote_code=True
        )
        model = model.to(device)
        model.eval()

        is_jina = "jina" in model_id
        for word in test_words:
            text = word if length == 1 else f"{word} " * length
            pair_batch = [[text, text]]

            if is_jina:
                scores = model.compute_score(pair_batch)
                continue

            encoded = tokenizer(pair_batch, padding=True, truncation=True, return_tensors='pt')
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
            with torch.no_grad():
                outputs = model(**encoded)

        del model, tokenizer
    except Exception as e:
        print(f"评估 {model_name} 时出错: {str(e)}")
    return

def main():
    """Run every reranker benchmark and print a summary table.

    Each row shows, for every input length, the peak CUDA memory (MB) and
    the time per test word (s), as returned by the ``memory_monitor``-wrapped
    measure functions. Missing measurements print as 0. For example,
    EAReranker is only measured at length 1.
    """
    models_config = [
        {"name": "lb-reranker-v1.0",
         "id": "lightblue/lb-reranker-v1.0",
         "type": "lightblue"},
        {"name": "jina-reranker-v2-base-multilingual",
         "id": "jinaai/jina-reranker-v2-base-multilingual",
         "type": "standard"},
        {"name": "bge-reranker-v2-m3",
         "id": "BAAI/bge-reranker-v2-m3",
         "type": "standard"},
        {"name": "gte-multilingual-reranker-base",
         "id": "Alibaba-NLP/gte-multilingual-reranker-base",
         "type": "standard"},
    ]
    # Number of times each test word is repeated to build the input text.
    # Single source of truth for both the sweep and the printed table.
    lengths = [1, 512, 1024, 4096, 8192, 64000]

    all_results = {}
    for dim in [768, 896, 1024, 1536]:
        # EAReranker is fed fixed-size random embeddings, so only "1" is recorded.
        all_results[f"EAReranker_{dim}"] = {"1": measure_eareranker(dim)}

    for config in models_config:
        measure = (
            measure_lightblue_performance
            if config["type"] == "lightblue"
            else measure_standard_performance
        )
        all_results[config["name"]] = {
            str(length): measure(config["name"], config["id"], length)
            for length in lengths
        }

    header = [f"{'model':<40}"]
    for length in lengths:
        header.append(f"{f'mem@{length}(MB)':<15}")
        header.append(f"{f'time@{length}(s)':<15}")
    print(" ".join(header))
    print("=" * 80)

    for model_name, metrics in all_results.items():
        row = [f"{model_name:<40}"]
        for length in lengths:
            measured = metrics.get(str(length), {})
            row.append(f"{measured.get('peak_memory_MB', 0):<15.2f}")
            row.append(f"{measured.get('time_s', 0):<15.4f}")
        print(" ".join(row))

    print("=" * 80)

# Run the full benchmark sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
