import numpy as np
from transformers import AutoTokenizer, AutoModel  # 改为AutoModel
import torch
import json
import os
# Device string used both as the model's device_map and as the target for
# the tokenized inputs in calculate_reward.
device = "cuda:0"
# Module-level counter; process_sample increments it whenever the original
# response holds the highest reward among all scored responses.
cnt = 0
# Hugging Face model identifier passed to both from_pretrained calls below.
modelname = "openbmb/Eurus-RM-7b"

# The model and tokenizer are loaded once at import time and shared by
# every call to calculate_reward.
model = AutoModel.from_pretrained(
    modelname,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True,  # lets transformers run the model code shipped with the checkpoint
    use_safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(modelname)

def calculate_reward(prompt, response):
    """Score one prompt/response pair with the module-level reward model.

    The pair is rendered as ``[INST] {prompt} [/INST] {response}``,
    tokenized with truncation at 2048 tokens, moved to ``device`` and passed
    to ``model`` with gradient tracking disabled.

    Args:
        prompt: The prompt text.
        response: The response text to be scored against the prompt.

    Returns:
        The model output converted to a Python scalar via ``.item()``.
    """
    text = f"[INST] {prompt} [/INST] {response}"
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048)
    encoded = encoded.to(device)

    # Inference only: no autograd graph is needed for scoring.
    with torch.no_grad():
        score = model(**encoded)
    return score.item()

def process_sample(sample):
    """Score a sample's original response and its valid paraphrases.

    The original response is always scored first, so ``rewards[0]`` is its
    reward. Paraphrases are scored only when their ``"status"`` equals
    ``"valid"``. If the original response attains the maximum reward, the
    module-level counter ``cnt`` is incremented.

    Args:
        sample: Mapping with the keys ``"id"``, ``"prompt"``,
            ``"original_response"`` and ``"paraphrases"``; each paraphrase is
            a mapping with ``"status"`` and ``"text"``.

    Returns:
        A dict with the raw ``rewards`` list, their ``variance``,
        ``mean_reward``, ``min_reward``, ``max_reward``, ``num_responses``,
        and ``rvariance`` -- the variance of the rewards after min-max
        scaling to [0, 1] (scaling is skipped when all rewards are equal).

    Raises:
        KeyError: If ``sample`` or a paraphrase lacks one of the keys above.
    """
    global cnt

    prompt = sample["prompt"]
    rewards = [calculate_reward(prompt, sample["original_response"])]
    rewards.extend(
        calculate_reward(prompt, para["text"])
        for para in sample["paraphrases"]
        if para["status"] == "valid"
    )

    print(rewards)

    min_reward = min(rewards)
    max_reward = max(rewards)

    # Count the samples whose original response is (one of) the best scored.
    if max_reward == rewards[0]:
        cnt += 1

    # Min-max scale the rewards; when they are all equal the scale would be
    # zero, so the rewards are left untouched in that case.
    if max_reward != min_reward:
        spread = max_reward - min_reward
        scaled_rewards = [(r - min_reward) / spread for r in rewards]
    else:
        scaled_rewards = list(rewards)

    return {
        "id": sample["id"],
        "rewards": rewards,
        "variance": np.var(rewards),
        "mean_reward": np.mean(rewards),
        "min_reward": min_reward,
        "max_reward": max_reward,
        "num_responses": len(rewards),
        "rvariance": np.var(scaled_rewards),
    }


if __name__ == "__main__":
    # Paths of the JSON files to score. Each file is expected to be an
    # object whose "data" key holds the list of samples.
    file_paths = []
    # One entry per successfully loaded file: the mean per-sample reward
    # variance, or 0.0 when no sample in the file could be scored.
    all_variances = []
    for i, file_path in enumerate(file_paths):
        print(f"\n{i+1}/{len(file_paths)}: {os.path.basename(file_path)}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            # Best effort: report the unreadable file and move on.
            print(f"{e}")
            continue

        file_variances = []
        processed_count = 0

        samples = data.get('data', [])
        print(f"{len(samples)}")

        for j, sample in enumerate(samples):
            # Only samples with at least three valid paraphrases are scored.
            if sample.get("valid_paraphrases_count", 0) < 3:
                continue

            show_progress = j % 10 == 0
            if show_progress:
                print(f"{j+1}/{len(samples)}")

            try:
                result = process_sample(sample)
            except Exception as e:
                # Keep going on a bad sample, but surface the failure rather
                # than dropping it silently.
                print(f"sample {j+1} failed: {e!r}")
                continue

            if result is None:
                continue

            processed_count += 1
            file_variances.append(result['variance'])

            if show_progress:
                print(f"var = {result['variance']}")
                print(f"rvar = {result['rvariance']}")

        if file_variances:
            mean_var = np.mean(file_variances)
            print(f"mean_var:{mean_var:.4f}")

            all_variances.append(mean_var)
        else:
            all_variances.append(0.0)

    print(f"var=: {[f'{v:.4f}' for v in all_variances]}")
    # Files that contributed no variance (recorded as 0.0) are excluded from
    # the overall mean.
    valid_vars = [v for v in all_variances if v > 0]
    if valid_vars:
        print(f"mean: {np.mean(valid_vars):.4f}")