import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json
import os
# Prefer the first GPU, but fall back to CPU so the script still runs (slowly)
# on machines without CUDA instead of failing when inputs are moved to "cuda:0".
device = "cuda:0" if torch.cuda.is_available() else "cpu"
modelname = "LxzGordon/URM-LLaMa-3.1-8B"

# Reward model scored by calculate_reward. device_map='auto' lets
# transformers/accelerate decide where the parameters are placed.
# trust_remote_code=True executes modelling code shipped with the checkpoint,
# so only use it with a repository you trust.
model = AutoModelForSequenceClassification.from_pretrained(
    modelname,
    device_map='auto',
    trust_remote_code=True,
)

# Tokenizer matching the reward model (fast Rust-backed implementation).
tokenizer = AutoTokenizer.from_pretrained(
    modelname,
    use_fast=True,
)
def calculate_reward(prompt, response):
    """Return the scalar reward the model assigns to a single-turn conversation.

    The (prompt, response) pair is rendered as a user/assistant exchange with
    the tokenizer's chat template, tokenized, and scored by ``model``. The
    first logit of the first (and only) batch element is returned.

    Args:
        prompt: Text of the user turn.
        response: Text of the assistant turn to be scored.

    Returns:
        The reward as a Python ``float``.
    """
    conv = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]

    # Render the conversation to a plain string (no tokenization yet).
    conv_formatted = tokenizer.apply_chat_template(conv, tokenize=False)

    # The rendered template may already start with the BOS token, and the
    # tokenizer call below may prepend its own; strip the template's copy so
    # the model does not see a duplicated BOS.
    bos = tokenizer.bos_token
    if bos is not None and conv_formatted.startswith(bos):
        conv_formatted = conv_formatted[len(bos):]

    # Move inputs to the device of the model's parameters rather than a
    # hard-coded one: with device_map='auto' this follows wherever the first
    # parameters were placed, and it also works if the model ended up on CPU.
    conv_tokenized = tokenizer(conv_formatted, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**conv_tokenized)
    return outputs.logits[0][0].item()


def process_sample(sample, op: int = 0):
    """Score a prompt's original response and its valid paraphrases.

    Every reward comes from ``calculate_reward``. The original response is
    always scored first. Each entry of ``sample["paraphrases"]`` is then scored,
    in order, but only when its ``"status"`` is ``"valid"``.

    Args:
        sample: Mapping with keys ``"id"``, ``"prompt"``,
            ``"original_response"`` and optionally ``"paraphrases"``. The
            paraphrases are a list of dicts with ``"status"`` and ``"text"``.
        op: Debug switch; when ``1`` the intermediate statistics are printed.

    Returns:
        A dict with:
          - ``"id"``: copied from ``sample``.
          - ``"rewards"``: all rewards, original response first.
          - ``"variance"``: the normalised quantile spread
            ``(P90 - P10) / std(rewards)``. Despite the key name, this is not
            a variance. It is ``0.0`` when all rewards are identical
            (std == 0), which includes the single-reward case.
          - ``"mean_reward"``, ``"min_reward"``, ``"max_reward"``.
          - ``"num_responses"``: number of rewards computed.
          - ``"rvariance"``: population variance (``np.var``, ddof=0) of the
            rewards.
    """
    prompt = sample["prompt"]

    rewards = [calculate_reward(prompt, sample["original_response"])]
    for para in sample.get("paraphrases", []):
        if para.get("status") == "valid":
            rewards.append(calculate_reward(prompt, para["text"]))
    rewards = np.array(rewards, dtype=float)

    p90, p10 = np.percentile(rewards, [90, 10])
    std = np.std(rewards)
    # Normalise the inter-decile range by the (population) standard deviation.
    # std == 0 means every reward is identical, so the spread is genuinely 0.
    # Report 0.0 instead of letting 0/0 yield NaN, which would poison any
    # downstream average of these values.
    quantile_spread = (p90 - p10) / std if std > 0 else 0.0

    if op == 1:
        print(rewards)
        print(p90, p10)
        print(quantile_spread)
        print(std)

    return {
        "id": sample["id"],
        "rewards": rewards.tolist(),
        "variance": float(quantile_spread),
        "mean_reward": float(np.mean(rewards)),
        "min_reward": float(np.min(rewards)),
        "max_reward": float(np.max(rewards)),
        "num_responses": len(rewards),
        "rvariance": float(np.var(rewards)),
    }
if __name__ == "__main__":
    import sys

    # NOTE(review): the list of input JSON files is empty, so as written the
    # script processes nothing; fill it in before running.
    file_paths = []
    # One entry per successfully loaded file: the mean of its samples'
    # "variance" values, or 0.0 when the file had no qualifying samples.
    all_variances = []
    for file_path in file_paths:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # Best-effort: skip unreadable or malformed files, but report them
            # rather than dropping them silently. A skipped file adds no entry
            # to all_variances, so its indices no longer line up with file_paths.
            print(f"skipping {file_path}: {e}", file=sys.stderr)
            continue

        # Score only samples with at least 3 valid paraphrases and keep the
        # "variance" value (normalised quantile spread) of each.
        file_variances = [
            process_sample(sample, 0)['variance']
            for sample in data.get('data', [])
            if sample.get("valid_paraphrases_count", 0) >= 3
        ]

        if file_variances:
            mean_var = np.mean(file_variances)
            print(f" mean_var: {mean_var:.4f}")
            all_variances.append(mean_var)
        else:
            all_variances.append(0.0)

    print(f"{[f'{v:.4f}' for v in all_variances]}")
    # Average only the positive per-file means. This excludes the 0.0
    # placeholders and any NaN means (NaN > 0 is False).
    valid_vars = [v for v in all_variances if v > 0]
    if valid_vars:
        print(f"{np.mean(valid_vars):.4f}")