import numpy as np
from transformers import AutoTokenizer, pipeline
import torch
import json
import os

# Reward-model configuration. Everything below runs at import time: the
# tokenizer and the 7B model are loaded once and shared by calculate_reward().

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU (-1) instead of failing when the pipeline is constructed.
device = 0 if torch.cuda.is_available() else -1
modelname = "weqweasdas/RM-Mistral-7B"


rm_tokenizer = AutoTokenizer.from_pretrained(modelname)
# "sentiment-analysis" is the text-classification pipeline; the reward model's
# classification head output is used as the reward score.
rm_pipe = pipeline(
    "sentiment-analysis",
    model=modelname,
    device=device,
    tokenizer=rm_tokenizer,
    model_kwargs={"torch_dtype": torch.bfloat16},
)

# Per-call options for rm_pipe:
# - return_all_scores: return a score for every label, not just the top one.
#   NOTE(review): deprecated in newer transformers in favour of top_k=None.
# - function_to_apply="none": keep the raw head output (no sigmoid/softmax).
# - batch_size=1: one conversation per forward pass.
pipe_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": 1,
}

def calculate_reward(prompt, response):
    """Score a single prompt/response exchange with the reward model.

    The exchange is rendered as a two-turn chat (user, then assistant) using
    the reward tokenizer's chat template, without a generation prompt. It is
    then passed to ``rm_pipe``. The score of the first returned label is the
    reward.

    Args:
        prompt: Text of the user turn.
        response: Text of the assistant turn to be scored.

    Returns:
        The ``"score"`` value of the first label for the single input.
    """
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
    text = rm_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )

    # The rendered template may already begin with the BOS string. The pipeline
    # re-tokenizes this text, which presumably adds BOS again, so drop the
    # leading copy to avoid a duplicate.
    bos = rm_tokenizer.bos_token
    has_leading_bos = bos is not None and text.startswith(bos)
    if has_leading_bos:
        text = text[len(bos):]

    per_input_scores = rm_pipe([text], **pipe_kwargs)
    first_label = per_input_scores[0][0]
    return first_label["score"]

def process_sample(sample, *, reward_fn=None):
    """Score a sample's original response and valid paraphrases, then summarize.

    Every response is scored against ``sample["prompt"]``. The raw scores are
    min-max rescaled to [-1, 1] so their spread is comparable across samples,
    whatever the reward model's absolute scale.

    Args:
        sample: Mapping with keys ``"id"``, ``"prompt"`` and
            ``"original_response"``, plus an optional ``"paraphrases"`` list of
            dicts carrying ``"text"`` and ``"status"``. Only paraphrases whose
            status is ``"valid"`` are scored; a missing status counts as not
            valid.
        reward_fn: Optional callable ``(prompt, response) -> float`` used to
            score each response. Defaults to ``calculate_reward``.

    Returns:
        Dict with the sample ``"id"`` and the rescaled ``"rewards"`` (original
        response first, then valid paraphrases in order). It also holds their
        ``"variance"``, ``"mean_reward"``, ``"min_reward"``, ``"max_reward"``
        and ``"num_responses"``. When all raw scores are identical the range
        is zero, so every rescaled reward is reported as 0.0, the midpoint of
        [-1, 1]. This keeps the output on the same scale as other samples.
    """
    if reward_fn is None:
        # Resolved at call time so the default reward model is only needed
        # when no explicit scorer is supplied.
        reward_fn = calculate_reward

    prompt = sample["prompt"]
    responses = [sample["original_response"]]
    responses.extend(
        para["text"]
        for para in sample.get("paraphrases", [])
        if para.get("status") == "valid"
    )
    raw = [reward_fn(prompt, resp) for resp in responses]

    lo, hi = min(raw), max(raw)
    if hi != lo:
        span = hi - lo
        rewards = [2 * (r - lo) / span - 1 for r in raw]
    else:
        # Zero range: rescaling is undefined; use the midpoint of [-1, 1]
        # rather than leaking raw-scale values into the summary.
        rewards = [0.0] * len(raw)

    return {
        "id": sample["id"],
        "rewards": rewards,
        "variance": np.var(rewards),
        "mean_reward": np.mean(rewards),
        "min_reward": min(rewards),
        "max_reward": max(rewards),
        "num_responses": len(rewards),
    }


# Script entry point: score the samples in each JSON file and print the
# per-sample rewards/variances and the per-file mean variance.
if __name__ == "__main__":
    # NOTE(review): empty placeholder; nothing is processed until this list
    # is populated with JSON file paths.
    file_paths = []
    # One entry per readable file: the mean per-sample variance, or 0.0 when
    # no sample in that file was processed.
    all_variances = []
    
    for i, file_path in enumerate(file_paths):
        
        # Files that cannot be opened or parsed are skipped without any message.
        # NOTE(review): this also hides bad paths and malformed JSON; consider
        # logging the exception instead of discarding it.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            continue
        
        file_variances = []
        # NOTE(review): never updated in the visible code.
        original_best_count = 0
        processed_count = 0
        
        # Samples are read from the top-level "data" key (assumes the JSON
        # root is an object; a list root would raise AttributeError here).
        samples = data.get('data', [])
        
        for j, sample in enumerate(samples):
            # Only samples reporting at least 3 valid paraphrases are scored.
            if sample.get("valid_paraphrases_count", 0) >= 3:
                
                
                result = process_sample(sample)
                if result is not None:
                    processed_count += 1
                    file_variances.append(result['variance'])
                    
                # NOTE(review): these prints sit outside the None check above;
                # process_sample as written always returns a dict, so the
                # check never fails today.
                print(result["rewards"])
                print(result["variance"])
        
        if file_variances:
            mean_var = np.mean(file_variances)
            print(f"  {mean_var:.4f}")
            
            all_variances.append(mean_var)
        else:
            all_variances.append(0.0)
    
    print(f"{[f'{v:.4f}' for v in all_variances]}")
    # The v > 0 filter drops files recorded as 0.0 (nothing processed), but it
    # also drops any file whose mean variance is genuinely 0.
    if any(v > 0 for v in all_variances):
        valid_vars = [v for v in all_variances if v > 0]