import json
import os
import sys

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = "auto"  # Passed as device_map: let transformers/accelerate place the model automatically
modelname = "PKU-Alignment/beaver-7b-v2.0-reward"

# Initialize tokenizer and model.
# NOTE(review): the Beaver reward checkpoints appear to be published for a
# custom score-model class (safe-rlhf's AutoModelForScore, which returns
# `end_scores`). Loading through AutoModelForSequenceClassification may leave
# the classification head newly initialized — check the load warnings for
# "newly initialized" weights before trusting the scores. TODO confirm.
print("Loading model...")
try:
    model = AutoModelForSequenceClassification.from_pretrained(
        modelname,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(modelname, trust_remote_code=True)
    print("Model loaded successfully")
except Exception as e:
    # Nothing below can run without the model, so report and abort.
    # sys.exit is used instead of the site-provided `exit`, which is meant for
    # the interactive shell and is not guaranteed to exist in scripts.
    print(f"Failed to load model: {e}", file=sys.stderr)
    print("Please try method 2: Use the complete independent implementation", file=sys.stderr)
    sys.exit(1)

def calculate_reward(prompt, response):
    """Return the reward-model score for a single prompt/response pair.

    The pair is rendered with the Beaver conversation template, tokenized
    (truncated to at most 2048 tokens) and run through the module-level
    ``model`` without gradient tracking.

    The scalar is taken from, in order of preference:
      1. ``outputs.end_scores[0, 0]``,
      2. ``outputs.logits`` (first element for 1-D/2-D logits, mean otherwise),
      3. the first element of the output container, or ``float(outputs)``.

    Args:
        prompt: The user turn.
        response: The assistant turn to be scored.

    Returns:
        float: The reward for ``response`` given ``prompt``.
    """
    # Construct conversation format
    conversation = f"BEGINNING OF CONVERSATION: USER: {prompt} ASSISTANT: {response}"

    # NOTE(review): truncation uses the tokenizer's default side; if that is
    # the right side, very long responses lose their ending, which changes any
    # score read at the final token.
    inputs = tokenizer(conversation, return_tensors='pt', truncation=True, max_length=2048)

    # Put the inputs on the device of the model's first parameter. The previous
    # `.cuda()` always targeted the *current* CUDA device, which is wrong when
    # device_map places the model on a different GPU, and it never moved inputs
    # for non-CUDA accelerators. `.to()` is a no-op when already on that device.
    model_device = next(model.parameters()).device
    inputs = {k: v.to(model_device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

        if hasattr(outputs, 'end_scores'):
            # Score-model output: end_scores indexed as [batch, score_dim].
            reward = outputs.end_scores[0, 0].item()
        elif hasattr(outputs, 'logits'):
            logits = outputs.logits
            if logits.dim() == 2:  # [batch_size, num_classes]
                reward = logits[0, 0].item()
            elif logits.dim() == 1:  # [num_classes]
                reward = logits[0].item()
            else:
                reward = float(logits.mean())
        else:
            # Fallback: treat the output as a tuple-like container.
            if hasattr(outputs, '__getitem__') and len(outputs) > 0:
                first_output = outputs[0]
                if hasattr(first_output, 'numel') and first_output.numel() > 1:
                    reward = first_output.flatten()[0].item()
                else:
                    reward = first_output.item() if hasattr(first_output, 'item') else float(first_output)
            else:
                reward = float(outputs)

    return reward

def process_sample(sample):
    """Score the original response and every valid paraphrase of one sample.

    Rewards are computed with ``calculate_reward``. The original response
    comes first, followed by each paraphrase whose ``status`` is ``"valid"``,
    in input order. The rewards are then min-max normalized to [-1, 1]. This
    is skipped when all rewards are equal. Summary statistics are computed on
    the normalized values.

    Args:
        sample: Mapping with keys ``"id"``, ``"prompt"``,
            ``"original_response"`` and ``"paraphrases"``. ``"paraphrases"``
            is a list of mappings with ``"status"`` and ``"text"`` keys.

    Returns:
        dict with keys ``id``, ``rewards`` (normalized, original first),
        ``raw_rewards`` (un-normalized, same order), ``variance``,
        ``mean_reward``, ``min_reward``, ``max_reward`` and ``num_responses``.
        Returns None if any step raises. A message is printed in that case.
    """
    try:
        # Reading the prompt lives inside the try so a malformed sample is
        # reported and skipped like any other missing key, instead of
        # aborting the whole run.
        prompt = sample["prompt"]

        # 1-2. Original response first, then every valid paraphrase.
        texts = [sample["original_response"]]
        texts.extend(para["text"] for para in sample["paraphrases"] if para["status"] == "valid")
        raw_rewards = [calculate_reward(prompt, text) for text in texts]

        # 3. Min-max normalize to [-1, 1]. This is skipped when all rewards
        #    are equal, which also covers the single-response case.
        lo, hi = min(raw_rewards), max(raw_rewards)
        if hi != lo:
            rewards = [2 * (r - lo) / (hi - lo) - 1 for r in raw_rewards]
        else:
            rewards = list(raw_rewards)

        # 4. Statistics on the normalized rewards.
        return {
            "id": sample["id"],
            "rewards": rewards,
            "raw_rewards": raw_rewards,
            "variance": float(np.var(rewards)),
            "mean_reward": float(np.mean(rewards)),
            "min_reward": min(rewards),
            "max_reward": max(rewards),
            "num_responses": len(rewards),
        }

    except Exception as e:
        # Best-effort per sample: one failure must not abort the batch.
        print(f"Error processing sample {sample.get('id', 'unknown')}: {e}")
        return None


if __name__ == "__main__":
    # Input JSON files are taken from the command line. Fill in
    # DEFAULT_FILE_PATHS to hard-code them instead.
    DEFAULT_FILE_PATHS = []
    file_paths = sys.argv[1:] or DEFAULT_FILE_PATHS

    all_variances = []         # per-file mean variance (0.0 placeholder for files with no results)
    valid_file_variances = []  # mean variances of files that produced at least one result

    for i, file_path in enumerate(file_paths, start=1):
        print(f"\nProcessing file {i}/{len(file_paths)}: {os.path.basename(file_path)}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # unreadable file, bad encoding, or invalid JSON
            print(f"Failed to read file: {e}")
            continue

        file_variances = []
        original_best_count = 0
        processed_count = 0

        samples = data.get('data', [])
        print(f"Total number of samples: {len(samples)}")

        for j, sample in enumerate(samples):
            # Only samples with at least 3 valid paraphrases are scored.
            if sample.get("valid_paraphrases_count", 0) < 3:
                continue
            if j % 10 == 0:  # Show progress every 10 samples
                print(f"Processing progress: {j+1}/{len(samples)}")

            result = process_sample(sample)
            if result is None:
                continue

            processed_count += 1
            file_variances.append(result['variance'])

            # rewards[0] is the original response (process_sample scores it
            # first). Min-max normalization preserves ordering, so the
            # original is best when it reaches the maximum. Ties, including the
            # all-equal case, count as best.
            if result['rewards'][0] >= result['max_reward']:
                original_best_count += 1

            print(f"Sample {result['id']} reward scores: {result['rewards']}")
            print(f"Variance: {result['variance']:.4f}")

        if file_variances:
            mean_var = np.mean(file_variances)
            print(f"File {i} results:")
            print(f"  Number of processed samples: {processed_count}")
            print(f"  Average variance: {mean_var:.4f}")
            print(f"  Original response is best: {original_best_count}/{processed_count} ({original_best_count/processed_count*100:.1f}%)")
            all_variances.append(mean_var)
            valid_file_variances.append(mean_var)
        else:
            print(f"File {i} has no valid results")
            all_variances.append(0.0)

    # Final results. The overall average covers every file that produced
    # results, including ones whose mean variance is genuinely 0.
    print("\n=== Final Results ===")
    print(f"Variances per file: {[f'{v:.4f}' for v in all_variances]}")
    if valid_file_variances:
        print(f"Overall average variance: {np.mean(valid_file_variances):.4f}")