import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json
import os

# Point the Hugging Face client at a mirror endpoint, if needed. For submission,
# you might want to remove this.
# NOTE(review): huggingface_hub appears to read HF_ENDPOINT when it is first
# imported, and `transformers` is already imported above this line, so this
# assignment may have no effect -- confirm, and move it before the transformers
# import if the mirror is actually required.
os.environ['HF_ENDPOINT']= "https://hf-mirror.com"
# Device the model is placed on and that tokenized inputs are moved to.
device = "cuda:0"
# Hub identifier of the reward model and its tokenizer.
modelname = "Skywork/Skywork-Reward-Llama-3.1-8B"

# Load the reward model as a sequence classifier with a single output logit
# (num_labels=1); calculate_reward() reads that logit as the reward score.
# NOTE(review): attn_implementation="flash_attention_2" presumably requires the
# flash-attn package and a supported GPU -- confirm for the target environment.
model = AutoModelForSequenceClassification.from_pretrained(
    modelname,
    torch_dtype=torch.bfloat16,
    device_map=device,
    attn_implementation="flash_attention_2",
    num_labels=1,
)
# Tokenizer matching the model; also supplies the chat template used below.
tokenizer = AutoTokenizer.from_pretrained(modelname)

def calculate_reward(prompt, response):
    """Score one (prompt, response) pair with the reward model.

    The pair is rendered as a two-turn user/assistant conversation through
    the tokenizer's chat template, tokenized, and passed to the model. The
    first logit of the first sequence is returned as a Python float.
    """
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Strip a leading BOS token from the rendered text so that tokenizing it
    # below cannot produce a duplicate BOS token.
    bos = tokenizer.bos_token
    if bos is not None and text.startswith(bos):
        text = text[len(bos):]

    encoded = tokenizer(text, return_tensors="pt").to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    return logits[0][0].item()

def collect_all_rewards(samples):
    """First pass: score every eligible sample and compute global statistics.

    A sample is eligible when its ``valid_paraphrases_count`` is at least 3.
    For each eligible sample the original response is scored first, followed
    by every paraphrase whose ``status`` is ``"valid"``.

    Args:
        samples: Sequence of sample dicts. Eligible samples must provide
            ``id``, ``prompt`` and ``original_response``; ``paraphrases`` is
            optional and defaults to an empty list.

    Returns:
        A tuple ``(sample_rewards_list, global_mean, global_std)`` where
        ``sample_rewards_list`` holds one ``{"id", "rewards", "sample"}`` dict
        per eligible sample, and the two statistics are the mean and the
        population standard deviation over all collected rewards. When no
        sample is eligible the list is empty and both statistics are NaN.
    """
    all_rewards = []
    sample_rewards_list = []  # Per-sample rewards, reused by the second pass

    print("First pass: Collecting all rewards...")
    for j, sample in enumerate(samples):
        if sample.get("valid_paraphrases_count", 0) < 3:
            continue

        if j % 10 == 0:
            print(f"Collection progress: {j+1}/{len(samples)}")

        prompt = sample["prompt"]

        # Original response first, then every valid paraphrase.
        rewards = [calculate_reward(prompt, sample["original_response"])]
        rewards.extend(
            calculate_reward(prompt, para["text"])
            for para in sample.get("paraphrases", [])
            if para.get("status") == "valid"
        )

        sample_rewards_list.append({
            "id": sample["id"],
            "rewards": rewards,
            "sample": sample,
        })
        all_rewards.extend(rewards)

    if not all_rewards:
        # np.mean/np.std on an empty array only emit RuntimeWarnings and
        # return NaN; report the undefined statistics explicitly instead.
        print("No eligible samples found; global statistics are undefined")
        print("Total number of rewards: 0")
        return sample_rewards_list, float("nan"), float("nan")

    # Calculate global statistics
    all_rewards = np.array(all_rewards, dtype=float)
    global_mean = np.mean(all_rewards)
    global_std = np.std(all_rewards)

    print(f"Global statistics: Mean={global_mean:.4f}, Std Dev={global_std:.4f}")
    print(f"Total number of rewards: {len(all_rewards)}")

    return sample_rewards_list, global_mean, global_std

def process_sample_with_global_stats(sample_data, global_mean, global_std):
    """Second pass: standardize one sample's rewards and summarize their spread.

    Args:
        sample_data: Dict with an ``id`` and a ``rewards`` list of raw scores.
        global_mean: Mean used to center the rewards.
        global_std: Standard deviation used to scale the centered rewards.

    Returns:
        ``None`` when the sample has no rewards. Otherwise a dict with:
        ``id``; ``rewards`` (standardized scores as a list); ``variance``
        (the P90 - P10 quantile spread of the standardized scores, despite
        the key name); ``mean_reward``, ``min_reward``, ``max_reward``;
        ``num_responses``; and ``rvariance`` (population variance of the
        standardized scores).
    """
    rewards = np.array(sample_data["rewards"], dtype=float)
    if rewards.size == 0:
        # Percentiles and min/max are undefined for an empty sample.
        return None

    centered = rewards - global_mean
    if global_std == 0:
        # A zero standard deviation means there is no spread to scale by;
        # define the z-scores as 0 instead of dividing by zero (NaN/inf).
        standardized_rewards = np.zeros_like(centered)
    else:
        standardized_rewards = centered / global_std

    # Quantile spread (P90 - P10)
    p90, p10 = np.percentile(standardized_rewards, [90, 10])
    quantile_spread = p90 - p10

    return {
        "id": sample_data["id"],
        "rewards": standardized_rewards.tolist(),
        "variance": float(quantile_spread),      # Quantile spread after standardization
        "mean_reward": float(np.mean(standardized_rewards)),
        "min_reward": float(np.min(standardized_rewards)),
        "max_reward": float(np.max(standardized_rewards)),
        "num_responses": len(standardized_rewards),
        "rvariance": float(np.var(standardized_rewards)),  # Variance after standardization
    }

if __name__ == "__main__":
    # Replace these with the paths to your data files
    file_paths = []

    # One entry per input file, in file order. NaN marks a file that yielded
    # no usable result, so it cannot be mistaken for a genuine zero value.
    all_variances = []
    all_rvariances = []
    no_result = float("nan")

    for i, file_path in enumerate(file_paths):
        print(f"\nProcessing file {i+1}/{len(file_paths)}: {os.path.basename(file_path)}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # OSError: file cannot be opened/read. ValueError: invalid JSON or
            # invalid UTF-8 (JSONDecodeError and UnicodeDecodeError subclass it).
            print(f"Failed to read file: {e}")
            all_variances.append(no_result)
            all_rvariances.append(no_result)
            continue

        samples = data.get('data', [])
        print(f"Total samples: {len(samples)}")

        # First pass: Collect all rewards and calculate global statistics
        sample_rewards_list, global_mean, global_std = collect_all_rewards(samples)

        if not sample_rewards_list:
            print(f"File {i+1} has no valid samples")
            all_variances.append(no_result)
            all_rvariances.append(no_result)
            continue

        # Second pass: Calculate standardized variance using global statistics
        print("Second pass: Calculating standardized variance...")
        file_variances = []
        file_rvariances = []

        for j, sample_data in enumerate(sample_rewards_list):
            show_progress = j % 10 == 0
            if show_progress:
                print(f"Standardization progress: {j+1}/{len(sample_rewards_list)}")

            result = process_sample_with_global_stats(sample_data, global_mean, global_std)
            if result is None:
                continue

            file_variances.append(result['variance'])
            file_rvariances.append(result['rvariance'])

            if show_progress:
                print(f"var = {result['variance']:.4f}")
                print(f"rvar = {result['rvariance']:.4f}")

        if file_variances:
            mean_var = np.mean(file_variances)
            mean_rvar = np.mean(file_rvariances)
            print(f"File {i+1} results:")
            print(f"  Processed samples: {len(file_variances)}")
            print(f"  Mean variance (quantile spread): {mean_var:.4f}")
            print(f"  Mean rvariance (variance): {mean_rvar:.4f}")
            all_variances.append(mean_var)
            all_rvariances.append(mean_rvar)
        else:
            print(f"File {i+1} has no valid results")
            all_variances.append(no_result)
            all_rvariances.append(no_result)

    # Final results
    print("\n=== Final Results ===")
    print(f"Variance per file (quantile spread): {[f'{v:.4f}' for v in all_variances]}")
    print(f"R-variance per file (variance): {[f'{v:.4f}' for v in all_rvariances]}")

    # Average over the same set of files for both metrics: keep a file only
    # when both of its values are real numbers (NaN != NaN filters sentinels).
    valid_pairs = [
        (v, rv)
        for v, rv in zip(all_variances, all_rvariances)
        if v == v and rv == rv
    ]
    if valid_pairs:
        print(f"Overall mean variance: {np.mean([v for v, _ in valid_pairs]):.4f}")
        print(f"Overall mean r-variance: {np.mean([rv for _, rv in valid_pairs]):.4f}")
