import json
import os

# huggingface_hub resolves HF_ENDPOINT once, when it is first imported (which
# happens while importing transformers). The mirror therefore has to be
# configured BEFORE the transformers import below; setting it afterwards, as
# this file used to do, leaves downloads pointed at the default endpoint.
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

import numpy as np  # noqa: E402
import torch  # noqa: E402
from transformers import AutoTokenizer, AutoModelForSequenceClassification  # noqa: E402

# Device the reward model is placed on and that tokenized inputs are moved to.
device = "cuda:0"
# Not referenced by any code visible in this file; kept for compatibility.
cnt = 0
# Checkpoint of the reward model to load.
# Alternative checkpoint: "allenai/Llama-3.1-8B-Instruct-RM-RB2"
modelname = "Ray2333/GRM-llama3-8B-distill"


# Reward model: a sequence classifier with a single output logit, which
# calculate_reward reads as the scalar reward of a conversation.
model = AutoModelForSequenceClassification.from_pretrained(
    modelname,
    num_labels=1,                             # one logit per sequence
    use_safetensors=True,                     # load weights from safetensors files
    torch_dtype=torch.bfloat16,               # load weights in bfloat16
    attn_implementation="flash_attention_2",  # request the FlashAttention-2 backend
    device_map=device,                        # place the whole model on `device`
)

# Tokenizer of the same checkpoint; it also supplies the chat template.
tokenizer = AutoTokenizer.from_pretrained(modelname)

def calculate_reward(prompt, response):
    """Score one (prompt, response) pair with the reward model.

    The pair is rendered as a two-turn user/assistant conversation through the
    tokenizer's chat template, tokenized, moved to ``device`` and passed to
    ``model`` with gradients disabled.

    Returns:
        The first logit of the first (only) sequence, as a Python float.
    """
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]

    # Render the conversation to text only; tokenization happens below.
    text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Strip a leading BOS from the rendered text to avoid a potential
    # duplicate BOS token after the tokenizer call below.
    bos = tokenizer.bos_token
    if bos is not None and text.startswith(bos):
        text = text[len(bos):]

    model_inputs = tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**model_inputs).logits
    return logits[0][0].item()

# Running global reward range. It is widened by process_sample(..., op=2) and
# consumed by process_sample(..., op=0). The initial sentinels (mi > mx) mean
# "no reward has been recorded yet".
mi = 1e9
mx = -1e9


def _build_result(sample, rewards, variance):
    """Assemble the per-sample statistics dictionary returned by process_sample."""
    return {
        "id": sample["id"],
        "rewards": rewards.tolist(),
        "variance": float(variance),
        "mean_reward": float(np.mean(rewards)),
        "min_reward": float(np.min(rewards)),
        "max_reward": float(np.max(rewards)),
        "num_responses": len(rewards),
        "rvariance": float(np.var(rewards)),
    }


def process_sample(sample, op: int):
    """Score a sample's original response and valid paraphrases and summarize them.

    The original response is always scored first, so ``rewards[0]`` in the
    result belongs to it; every paraphrase whose ``status`` is ``"valid"``
    follows in order.

    Args:
        sample: Mapping with ``"id"``, ``"prompt"``, ``"original_response"``
            and optionally ``"paraphrases"`` (a list of mappings with
            ``"status"`` and ``"text"``).
        op: Selects the mode.
            ``0`` - rescale the rewards to ``[-1, 1]`` using the global range
            ``mi``/``mx`` (skipped when no valid range has been collected) and
            report the variance of the rescaled rewards as ``"variance"``.
            ``1`` - report the standardized quantile spread and print the
            intermediate values for debugging.
            ``2`` - report the standardized quantile spread and widen the
            global range ``mi``/``mx`` with this sample's raw rewards.
            Any other value reports the standardized quantile spread with no
            side effects.

    Returns:
        dict with keys ``id``, ``rewards``, ``variance``, ``mean_reward``,
        ``min_reward``, ``max_reward``, ``num_responses`` and ``rvariance``.
        ``rvariance`` is always the plain variance of the returned rewards;
        ``variance`` is that same variance for ``op == 0`` and
        ``(P90 - P10) / std`` otherwise (``0.0`` when the std is zero).
    """
    # Declared once at the top: op == 0 reads the range, op == 2 rebinds it.
    global mi, mx

    prompt = sample["prompt"]

    scores = [calculate_reward(prompt, sample["original_response"])]
    scores.extend(
        calculate_reward(prompt, para["text"])
        for para in sample.get("paraphrases", [])
        if para.get("status") == "valid"
    )
    rewards = np.array(scores, dtype=float)

    if op == 0:
        # Only normalize against a real range: mx > mi excludes both the
        # degenerate single-value range and the untouched sentinels (mi > mx),
        # which would otherwise produce meaningless rescaled values.
        if mx > mi:
            rewards = 2 * (rewards - mi) / (mx - mi) - 1
        return _build_result(sample, rewards, np.var(rewards))

    # Quantile spread (P90 - P10), standardized by the standard deviation.
    p90, p10 = np.percentile(rewards, [90, 10])
    quantile_spread = p90 - p10
    std = np.std(rewards)
    if std == 0:
        # All rewards are identical: the spread is zero as well, so report 0.0
        # instead of the NaN (and RuntimeWarning) that 0 / 0 would give.
        quantile_spread = 0.0
    else:
        quantile_spread = quantile_spread / std

    if op == 1:
        print(rewards)
        print(p90, p10)
        print(quantile_spread)
        print(std)

    if op == 2:
        mi = min(mi, float(np.min(rewards)))
        mx = max(mx, float(np.max(rewards)))

    return _build_result(sample, rewards, quantile_spread)

if __name__ == "__main__":
    import sys

    # Input JSON files come from the command line. Without arguments the list
    # is empty, exactly as before, and only the final summary line is printed.
    file_paths = sys.argv[1:]
    all_variances = []
    for i, file_path in enumerate(file_paths):
        print(f"\nProcessing file {i+1}/{len(file_paths)}: {os.path.basename(file_path)}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # OSError: the file cannot be opened or read.
            # ValueError: invalid JSON or invalid UTF-8 (both subclass it).
            print(f"Failed to read file: {e}")
            continue

        samples = data.get('data', [])
        print(f"Total number of samples: {len(samples)}")

        # Only samples with at least three valid paraphrases are evaluated.
        # The original index j is kept so progress output is unchanged.
        eligible = [
            (j, sample)
            for j, sample in enumerate(samples)
            if sample.get("valid_paraphrases_count", 0) >= 3
        ]

        # Pass 1 (op=2): collect the global min/max reward used for
        # normalization. The returned statistics are not needed here.
        # NOTE(review): that global range is never reset between files, so a
        # later file is normalized against the range of all files processed so
        # far - confirm this is intended.
        # NOTE(review): every sample is scored by the reward model in both
        # passes; caching the raw rewards would halve the model calls.
        for j, sample in eligible:
            if j % 10 == 0:  # Show progress every 10 samples
                print(f"Processing progress: {j+1}/{len(samples)}")
            process_sample(sample, 2)

        # Pass 2 (op=0): variance of the normalized rewards per sample.
        file_variances = []
        original_best_count = 0
        for j, sample in eligible:
            if j % 10 == 0:  # Show progress every 10 samples
                print(f"Processing progress: {j+1}/{len(samples)}")

            result = process_sample(sample, 0)
            file_variances.append(result['variance'])

            # process_sample scores the original response first, so rewards[0]
            # is its reward; it counts as best when it reaches the maximum
            # (ties included). This counter used to stay at zero.
            if result['rewards'][0] >= result['max_reward']:
                original_best_count += 1

            if j % 10 == 0:
                print(f"var = {result['variance']}")
                print(f"rvar = {result['rvariance']}")

        processed_count = len(file_variances)

        if file_variances:
            mean_var = np.mean(file_variances)
            print(f"File {i+1} results:")
            print(f"  Number of processed samples: {processed_count}")
            print(f"  Average variance: {mean_var:.4f}")
            print(f"  Original response is best: {original_best_count}/{processed_count} ({original_best_count/processed_count*100:.1f}%)")
            all_variances.append(mean_var)
        else:
            print(f"File {i+1} has no valid results")
            all_variances.append(0.0)

    # Final results: files without results (recorded as 0.0) are excluded
    # from the overall average.
    print(f"Variances per file: {[f'{v:.4f}' for v in all_variances]}")
    valid_vars = [v for v in all_variances if v > 0]
    if valid_vars:
        print(f"Overall average variance: {np.mean(valid_vars):.4f}")