# -*- coding: utf-8 -*-
import gc
import json
import os
from datetime import datetime

# Optional Hugging Face mirror. huggingface_hub reads HF_ENDPOINT when it is
# first imported (importing transformers pulls it in), so this must run BEFORE
# the transformers import below or it has no effect. setdefault keeps an
# endpoint the user has already exported. For submission, you might want to
# remove this.
os.environ.setdefault('HF_ENDPOINT', "https://hf-mirror.com")

import numpy as np  # noqa: E402
import torch  # noqa: E402
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig  # noqa: E402

# Fallback device for model inputs; used only when the loaded model does not
# expose its own `.device`.
device = "cuda:0"

# Reward models to evaluate, in this order. Each entry has a display "name"
# and the "path" handed to `from_pretrained`.
MODELS = [
    dict(
        name="Skywork-Reward-V2-Qwen3-8B",
        path="Skywork/Skywork-Reward-V2-Qwen3-8B",
    ),
]

class RewardModelEvaluator:
    """Score responses and their paraphrases with one reward model.

    For every sample, the original response and each valid paraphrase are scored
    with a sequence-classification reward model. The rewards are then z-scored
    with statistics pooled over the whole file and summarised per sample
    (percentile spread, variance, mean/min/max).
    """

    def __init__(self, model_config):
        """Remember which model to evaluate; nothing is loaded until ``load_model``.

        Args:
            model_config: dict with a display ``"name"`` and the ``"path"`` that
                is passed to ``from_pretrained``.
        """
        self.model_name = model_config["name"]
        self.model_path = model_config["path"]
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the 4-bit (NF4) quantized reward model and its tokenizer.

        Returns:
            True on success. False on any failure (the error is printed); in that
            case ``self.model`` and ``self.tokenizer`` are reset to ``None``.
        """
        print(f"Loading model: {self.model_path}")
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_path,
                quantization_config=quantization_config,  # 4-bit quantization
                device_map="auto",                       # Auto-assign device
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",  # needs the flash-attn package
                use_safetensors=True,
                trust_remote_code=True,
                low_cpu_mem_usage=True,                  # Reduce CPU memory usage
            )

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                use_fast=True                            # Use fast tokenizer
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Sequence-classification heads use config.pad_token_id to find the
            # scoring position and refuse batches > 1 without it, so keep the
            # model config in sync with the tokenizer.
            if self.model.config.pad_token_id is None:
                self.model.config.pad_token_id = self.tokenizer.pad_token_id

            print(f"Model {self.model_name} loaded successfully")
            return True
        except Exception as e:
            print(f"Failed to load model {self.model_name}: {e}")
            # Drop any half-initialised objects so they do not pin GPU memory.
            self.model = None
            self.tokenizer = None
            return False

    def unload_model(self):
        """Drop the model and tokenizer and release cached CUDA memory."""
        self.model = None
        self.tokenizer = None
        # Collect first so the dropped tensors are really freed, then return the
        # now-unused cached CUDA blocks to the driver.
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Model {self.model_name} has been unloaded.")

    def _input_device(self):
        """Return the device that tokenized inputs must be moved to.

        With ``device_map="auto"`` the placement is chosen at load time, so follow
        the model's own ``.device``. Fall back to the module-level ``device`` only
        when the model does not expose one.
        """
        model_device = getattr(self.model, "device", None)
        return model_device if model_device is not None else device

    @staticmethod
    def _logits_to_rewards(logits):
        """Convert classification logits into one float reward per batch row.

        A single-label head produces one column, which is the reward. For heads
        with several labels the LAST label column is used (this indexes labels,
        not tokens).
        """
        per_row = logits.detach().float().reshape(logits.shape[0], -1)
        return per_row[:, -1].cpu().tolist()

    def _format_text(self, prompt, response):
        """Render one (prompt, response) pair into the text fed to the reward model.

        If the tokenizer ships a chat template, the pair is rendered with it as a
        user/assistant conversation, the format chat reward models such as
        Skywork-Reward-V2 are trained on. A leading BOS token is stripped because
        tokenizing the string may add it again. Without a chat template, the
        legacy ``[INST] ... [/INST]`` format is used.
        """
        template = getattr(self.tokenizer, "chat_template", None)
        if not template:
            return f"[INST] {prompt} [/INST] {response}"
        conversation = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ]
        text = self.tokenizer.apply_chat_template(conversation, tokenize=False)
        bos = getattr(self.tokenizer, "bos_token", None)
        if bos and text.startswith(bos):
            text = text[len(bos):]
        return text

    def calculate_reward_batch(self, texts, batch_size=4, max_length=1024):
        """Score texts in padded mini-batches of ``batch_size``.

        Args:
            texts: list of already-formatted input strings.
            batch_size: number of texts per forward pass.
            max_length: truncation length in tokens (1024 keeps VRAM usage low).

        Returns:
            List of floats, one per text, in input order.
        """
        all_rewards = []
        input_device = self._input_device()
        for start in range(0, len(texts), batch_size):
            inputs = self.tokenizer(
                texts[start:start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                return_attention_mask=True,
            ).to(input_device)
            with torch.no_grad():
                logits = self.model(**inputs).logits
            all_rewards.extend(self._logits_to_rewards(logits))
        return all_rewards

    def calculate_reward_single(self, text, max_length=1024):
        """Score one already-formatted text without padding.

        Args:
            text: input string, e.g. produced by ``_format_text``.
            max_length: truncation length in tokens.

        Returns:
            The reward as a float, or ``float("nan")`` if scoring raised (the
            error is printed). Unlike 0.0, a NaN cannot be mistaken for a real
            score; ``process_sample_streaming`` drops it.
        """
        try:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
                padding=False,  # single sequence: nothing to pad
            ).to(self._input_device())
            with torch.no_grad():
                logits = self.model(**inputs).logits
            return self._logits_to_rewards(logits)[0]
        except Exception as e:
            print(f"Failed to calculate reward: {e}")
            return float("nan")

    def process_sample_streaming(self, sample):
        """Score the original response and every valid paraphrase of one sample.

        Reads the sample keys ``valid_paraphrases_count``, ``prompt``,
        ``original_response``, ``paraphrases`` (items with ``status``/``text``)
        and ``id``.

        Returns:
            None when the sample has fewer than 3 valid paraphrases. Otherwise a
            dict with ``id``, ``rewards`` (finite scores only; responses whose
            scoring failed are left out) and ``num_responses``.
        """
        if sample.get("valid_paraphrases_count", 0) < 3:
            return None

        prompt = sample["prompt"]
        responses = [sample["original_response"]]
        for para in sample.get("paraphrases", []):
            if para.get("status") == "valid":
                responses.append(para["text"])

        rewards = []
        for response in responses:
            reward = self.calculate_reward_single(self._format_text(prompt, response))
            if np.isfinite(reward):
                rewards.append(reward)

        return {
            "id": sample["id"],
            "rewards": rewards,
            "num_responses": len(rewards)
        }

    def evaluate_file_streaming(self, file_path, samples):
        """Score every eligible sample of one file and summarise the rewards.

        Each response is scored exactly once. The per-sample rewards (a few
        floats per sample) are kept, so the file-level mean/std used for
        z-scoring is computed without a second round of model inference.

        Args:
            file_path: path of the source file; only its basename is reported.
            samples: list of sample dicts (see ``process_sample_streaming``).

        Returns:
            dict with ``file_name``, ``processed_samples``, ``global_mean``,
            ``global_std``, ``mean_variance``, ``mean_rvariance`` and the
            per-sample ``sample_results``.
        """
        tag = f"[{self.model_name}]"
        print(f"\n{tag} Processing file: {os.path.basename(file_path)}")
        print(f"{tag} Total samples: {len(samples)}")

        print(f"{tag} Scoring responses...")
        scored = []  # (sample id, np.ndarray of raw rewards)
        for j, sample in enumerate(samples):
            if j % 50 == 0:
                print(f"{tag} Scoring progress: {j+1}/{len(samples)}")
            sample_result = self.process_sample_streaming(sample)
            # Skip ineligible samples and samples where every scoring call failed.
            if sample_result is not None and sample_result["rewards"]:
                scored.append(
                    (sample_result["id"], np.asarray(sample_result["rewards"], dtype=float))
                )

        if not scored:
            print(f"{tag} No valid samples found.")
            return self._create_empty_result(file_path)

        # File-level statistics used to z-score every sample.
        all_rewards = np.concatenate([rewards for _, rewards in scored])
        global_mean = float(np.mean(all_rewards))
        global_std = float(np.std(all_rewards))
        print(f"{tag} Global stats: mean={global_mean:.4f}, std={global_std:.4f}")
        print(f"{tag} Total number of rewards: {len(all_rewards)}")

        sample_results = [
            self._summarize_sample(sample_id, rewards, global_mean, global_std)
            for sample_id, rewards in scored
        ]

        return {
            "file_name": os.path.basename(file_path),
            "processed_samples": len(sample_results),
            "global_mean": global_mean,
            "global_std": global_std,
            "mean_variance": float(np.mean([r["variance"] for r in sample_results])),
            "mean_rvariance": float(np.mean([r["rvariance"] for r in sample_results])),
            "sample_results": sample_results
        }

    @staticmethod
    def _summarize_sample(sample_id, rewards, global_mean, global_std):
        """Z-score one sample's rewards with file statistics and summarise them.

        ``"variance"`` is the 90th-minus-10th percentile spread of the z-scores and
        ``"rvariance"`` is their population variance. Both are 0.0 when the sample
        has a single reward.
        """
        if global_std > 0:
            z = (rewards - global_mean) / global_std
        else:
            # All rewards in the file are equal: centre only, avoid dividing by 0.
            z = rewards - global_mean

        if z.size > 1:
            p90, p10 = np.percentile(z, [90, 10])
            spread = float(p90 - p10)
            rvariance = float(np.var(z))
        else:
            spread = 0.0
            rvariance = 0.0

        return {
            "id": sample_id,
            "variance": spread,
            "mean_reward": float(np.mean(z)),
            "min_reward": float(np.min(z)),
            "max_reward": float(np.max(z)),
            "num_responses": int(z.size),
            "rvariance": rvariance
        }

    def _create_empty_result(self, file_path):
        """Return the result dict for a file that yielded no scorable samples."""
        return {
            "file_name": os.path.basename(file_path),
            "processed_samples": 0,
            "global_mean": 0.0,
            "global_std": 0.0,
            "mean_variance": 0.0,
            "mean_rvariance": 0.0,
            "sample_results": []
        }

def load_file_lazily(file_path):
    """Load the list of samples from one JSON file.

    Despite the name, the whole file is parsed at once with ``json.load``. The
    "lazy" part is only that files are loaded one at a time by the caller.

    Accepted layouts: an object whose ``"data"`` key holds the sample list, or
    a bare top-level list of samples.

    Returns:
        The sample list. An empty list if the file cannot be read or parsed, or
        does not contain a list of samples (the problem is printed).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and UTF-8 decode errors
        print(f"Failed to read file {file_path}: {e}")
        return []

    if isinstance(data, dict):
        samples = data.get('data', [])
    elif isinstance(data, list):
        samples = data
    else:
        samples = None
    if not isinstance(samples, list):
        print(f"Failed to read file {file_path}: expected a list of samples")
        return []

    print(f"Loading file: {os.path.basename(file_path)}, Sample count: {len(samples)}")
    return samples

def _collect_input_files(paths):
    """Expand every directory in *paths* into its ``*.json`` files, sorted by name.

    Non-directory entries are passed through unchanged, so a missing or
    unreadable path is reported later by ``load_file_lazily``.
    """
    collected = []
    for path in paths:
        if os.path.isdir(path):
            collected.extend(
                sorted(
                    os.path.join(path, name)
                    for name in os.listdir(path)
                    if name.endswith(".json")
                )
            )
        else:
            collected.append(path)
    return collected


def _evaluate_model(model_config, file_paths):
    """Load one reward model, evaluate it on every file, then unload it.

    Args:
        model_config: one entry of ``MODELS``.
        file_paths: JSON files to evaluate, in order.

    Returns:
        The model's results dict, or None if the model failed to load.
    """
    evaluator = RewardModelEvaluator(model_config)
    if not evaluator.load_model():
        return None

    model_results = {
        "model_name": model_config["name"],
        "model_path": model_config["path"],
        "files": {},
        "overall_stats": {}
    }
    all_variances = []
    all_rvariances = []

    try:
        # Process files one by one without pre-loading
        for file_path in file_paths:
            file_name = os.path.basename(file_path)
            print(f"\nProcessing file: {file_name}")

            samples = load_file_lazily(file_path)
            if not samples:
                continue

            file_result = evaluator.evaluate_file_streaming(file_path, samples)
            model_results["files"][file_name] = file_result

            # Every file that produced results counts toward the model average,
            # including one whose mean spread is exactly 0; filtering on "> 0"
            # would bias the average upward.
            if file_result["processed_samples"] > 0:
                all_variances.append(file_result["mean_variance"])
                all_rvariances.append(file_result["mean_rvariance"])

            # Release this file's samples before the next file is loaded.
            del samples
            gc.collect()
            torch.cuda.empty_cache()

            print(f"File {file_name} processing complete")
    finally:
        # Free the GPU even if a file raised, so a caller can go on to another model.
        evaluator.unload_model()

    if all_variances:
        model_results["overall_stats"] = {
            "mean_variance_across_files": float(np.mean(all_variances)),
            "mean_rvariance_across_files": float(np.mean(all_rvariances)),
            "total_processed_files": len(all_variances)
        }

    print(f"\n[{model_config['name']}] Complete!")
    if all_variances:
        print(f"Average variance: {np.mean(all_variances):.4f}")
        print(f"Average rvariance: {np.mean(all_rvariances):.4f}")
    return model_results


def main():
    """Evaluate every model in ``MODELS`` on every input file and save a JSON report.

    Input paths come from the command line when given; otherwise from the
    ``file_paths`` list below. Directories are expanded to their ``*.json`` files.
    """
    import sys

    # Replace these with your actual file paths or a directory path, or pass
    # them as command-line arguments instead.
    file_paths = []
    file_paths = _collect_input_files(sys.argv[1:] or file_paths)
    if not file_paths:
        print("No input files given; pass JSON files or directories as arguments.")

    # Store all results
    all_results = {
        "timestamp": datetime.now().isoformat(),
        "models": {},
        "summary": {}
    }

    # Process each model
    for model_idx, model_config in enumerate(MODELS):
        print(f"\n{'='*50}")
        print(f"Processing model {model_idx+1}/{len(MODELS)}: {model_config['name']}")
        print(f"{'='*50}")

        model_results = _evaluate_model(model_config, file_paths)
        if model_results is None:
            print(f"Skipping model {model_config['name']}")
            continue
        all_results["models"][model_config["name"]] = model_results

    # Cross-model comparison: only models that produced overall statistics.
    summary_stats = {
        model_name: {
            "mean_variance": model_data["overall_stats"]["mean_variance_across_files"],
            "mean_rvariance": model_data["overall_stats"]["mean_rvariance_across_files"]
        }
        for model_name, model_data in all_results["models"].items()
        if model_data.get("overall_stats")
    }

    all_results["summary"] = {
        "model_comparison": summary_stats,
        "files_processed": [os.path.basename(f) for f in file_paths]
    }

    # Save results
    output_file = f"reward_model_evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(all_results, f, ensure_ascii=False, indent=2)
        print(f"\nResults saved to: {output_file}")
    except OSError as e:
        print(f"Failed to save results: {e}")

    # Print final summary
    print(f"\n{'='*60}")
    print("Final Summary")
    print(f"{'='*60}")
    for model_name, stats in summary_stats.items():
        print(f"{model_name}:")
        print(f"  Average variance: {stats['mean_variance']:.4f}")
        print(f"  Average rvariance: {stats['mean_rvariance']:.4f}")


if __name__ == "__main__":
    main()