import numpy as np
from transformers import AutoTokenizer, pipeline
import torch
import json
import os
from datetime import datetime
import gc


# Reward models to evaluate. Each entry carries a short display name ("name")
# and the model identifier handed to the tokenizer/pipeline loaders ("path").
MODELS = [
    dict(name="RM-Gemma-7B", path="weqweasdas/RM-Gemma-7B"),
]

class RewardModelEvaluator:
    """Score responses with a reward model and measure the spread across paraphrases.

    For every eligible sample the original response and each valid paraphrase
    are formatted as a user/assistant chat, scored by the reward pipeline, and
    the scores are standardized with the file-wide mean and standard deviation
    before per-sample spread metrics are computed.
    """

    def __init__(self, model_config):
        """Store the model identity; the model itself is loaded by `load_model`.

        Args:
            model_config: Mapping with the keys "name" (display name) and
                "path" (identifier passed to the loaders). An optional
                "device" key overrides the pipeline device (default: 0).
        """
        self.model_name = model_config["name"]
        self.model_path = model_config["path"]
        self.rm_pipe = None
        self.rm_tokenizer = None
        # Defaults to GPU 0, exactly as before, unless the config says otherwise.
        self.device = model_config.get("device", 0)

    def load_model(self):
        """Load the tokenizer and build the scoring pipeline.

        Returns:
            True when both were created, False if any step raised (the error
            is printed and swallowed so the caller can skip this model).
        """
        print(f"Loading model: {self.model_path}")
        try:
            # Load tokenizer
            self.rm_tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                use_fast=True
            )
            # Batched inference needs a padding token; reuse EOS when none is set.
            if self.rm_tokenizer.pad_token is None:
                self.rm_tokenizer.pad_token = self.rm_tokenizer.eos_token

            # Create pipeline
            # NOTE(review): some transformers releases spell the "dtype" model
            # kwarg "torch_dtype" - confirm against the installed version.
            self.rm_pipe = pipeline(
                "sentiment-analysis",  # Or "text-classification"
                model=self.model_path,
                device=self.device,
                tokenizer=self.rm_tokenizer,
                model_kwargs={
                    "dtype": torch.bfloat16,
                    "low_cpu_mem_usage": True
                },
                trust_remote_code=True
            )

            print(f"Model {self.model_name} loaded successfully")
            return True

        except Exception as e:
            print(f"Failed to load model {self.model_name}: {e}")
            return False

    @staticmethod
    def _build_pipe_kwargs(batch_size):
        """Return the call options shared by the batch and single-text paths."""
        return {
            "return_all_scores": True,
            # "none" asks the pipeline for the raw score, with no squashing applied.
            "function_to_apply": "none",
            "batch_size": batch_size,
            "truncation": True,
            "max_length": 1024
        }

    @staticmethod
    def _extract_score(output):
        """Pull one float reward out of a single pipeline output.

        A list output contributes the score of its first entry; a mapping
        contributes its own "score". Anything without a score yields 0.0.
        """
        if isinstance(output, list):
            return float(output[0]["score"]) if len(output) > 0 else 0.0
        return float(output["score"] if "score" in output else 0.0)

    def calculate_reward_batch(self, texts, batch_size=4):
        """Calculate reward scores for `texts`, batching inside the pipeline.

        Args:
            texts: Sequence of already formatted strings.
            batch_size: Batch size forwarded to the pipeline.

        Returns:
            One float per input text. If the batched call fails, every text is
            scored individually (a text that still fails scores 0.0).
        """
        if len(texts) == 0:
            # Nothing to score; do not invoke the pipeline at all.
            return []

        try:
            pipe_outputs = self.rm_pipe(texts, **self._build_pipe_kwargs(batch_size))
            return [self._extract_score(output) for output in pipe_outputs]
        except Exception as e:
            print(f"Failed to calculate rewards in batch: {e}")
            # Fallback to single text processing
            return [self.calculate_reward_single(text) for text in texts]

    def calculate_reward_single(self, text):
        """Calculate the reward score for a single text.

        Returns:
            The score as a float, or 0.0 if scoring raised (the error is printed).
        """
        try:
            pipe_outputs = self.rm_pipe([text], **self._build_pipe_kwargs(1))
            return self._extract_score(pipe_outputs[0])
        except Exception as e:
            print(f"Failed to calculate reward: {e}")
            return 0.0

    def format_text_for_model(self, prompt, response):
        """Render a prompt/response pair with the tokenizer's chat template.

        Returns:
            The templated text with every occurrence of the BOS token removed,
            or a plain "[INST] ... [/INST] ..." string when templating fails.
        """
        # Construct chat format
        chat = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response}
        ]

        # Apply chat template
        try:
            formatted_text = self.rm_tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=False
            )
            # Remove bos_token if it exists
            if self.rm_tokenizer.bos_token:
                formatted_text = formatted_text.replace(self.rm_tokenizer.bos_token, "")
            return formatted_text
        except Exception as e:
            print(f"Failed to format text: {e}")
            # Fallback to simple format
            return f"[INST] {prompt} [/INST] {response}"

    def process_sample_streaming(self, sample):
        """Score the original response and all valid paraphrases of one sample.

        Args:
            sample: Mapping with "id", "prompt", "original_response" and,
                optionally, "valid_paraphrases_count" and "paraphrases"
                (entries with "status" and "text").

        Returns:
            None when the sample has fewer than 3 valid paraphrases; otherwise
            a dict with "id", "rewards" (original response first, then the
            valid paraphrases in order) and "num_responses".
        """
        if sample.get("valid_paraphrases_count", 0) < 3:
            return None

        prompt = sample["prompt"]

        # 1. Original response, 2. all valid paraphrases
        texts_to_evaluate = [
            self.format_text_for_model(prompt, sample['original_response'])
        ]
        texts_to_evaluate.extend(
            self.format_text_for_model(prompt, para['text'])
            for para in sample.get("paraphrases", [])
            if para.get("status") == "valid"
        )

        # Calculate rewards in batches
        rewards = self.calculate_reward_batch(texts_to_evaluate)

        return {
            "id": sample["id"],
            "rewards": rewards,
            "num_responses": len(rewards)
        }

    @staticmethod
    def _summarize_sample(sample_id, standardized_rewards):
        """Build the per-sample result dict from standardized rewards."""
        if len(standardized_rewards) > 1:
            p90, p10 = np.percentile(standardized_rewards, [90, 10])
            quantile_spread = float(p90 - p10)
            rvariance = float(np.var(standardized_rewards))
        else:
            # A single score has no spread.
            quantile_spread = 0.0
            rvariance = 0.0

        return {
            "id": sample_id,
            # NOTE: "variance" holds the 90th-10th percentile spread; the
            # statistical variance is reported separately as "rvariance".
            "variance": quantile_spread,
            "mean_reward": float(np.mean(standardized_rewards)),
            "min_reward": float(np.min(standardized_rewards)),
            "max_reward": float(np.max(standardized_rewards)),
            "num_responses": len(standardized_rewards),
            "rvariance": rvariance
        }

    def evaluate_file_streaming(self, file_path, samples):
        """Evaluate all samples of one file and aggregate spread metrics.

        Each eligible sample is scored exactly once. The raw rewards are kept
        so that the standardization step, which needs the file-wide mean and
        standard deviation, can reuse them instead of running the model again.

        Args:
            file_path: Path of the source file; only its basename is reported.
            samples: Sequence of sample mappings (see `process_sample_streaming`).

        Returns:
            Dict with "file_name", "processed_samples", "global_mean",
            "global_std", "mean_variance", "mean_rvariance" and
            "sample_results" (one entry per scored sample, in input order).
        """
        print(f"\n[{self.model_name}] Processing file: {os.path.basename(file_path)}")
        print(f"[{self.model_name}] Total number of samples: {len(samples)}")

        # First iteration: score every eligible sample and collect raw rewards
        print(f"[{self.model_name}] First iteration: Calculating global statistics...")
        scored = []  # (sample id, raw rewards) per eligible sample, input order
        for j, sample in enumerate(samples):
            if j % 50 == 0:
                print(f"[{self.model_name}] First iteration progress: {j+1}/{len(samples)}")

            sample_result = self.process_sample_streaming(sample)
            if sample_result is None:
                continue
            rewards = np.asarray(sample_result["rewards"], dtype=float)
            if rewards.size:
                scored.append((sample_result["id"], rewards))

        if not scored:
            print(f"[{self.model_name}] No valid samples")
            return self._create_empty_result(file_path)

        # Calculate global statistics
        all_rewards = np.concatenate([rewards for _, rewards in scored])
        global_mean = float(np.mean(all_rewards))
        global_std = float(np.std(all_rewards))

        print(f"[{self.model_name}] Global statistics: Mean={global_mean:.4f}, Std={global_std:.4f}")
        print(f"[{self.model_name}] Total number of rewards: {len(all_rewards)}")
        del all_rewards

        # Second iteration: standardize the cached rewards (no model calls)
        print(f"[{self.model_name}] Second iteration: Calculating standardized results...")
        # With zero spread only the mean is removed, so divide by 1 instead.
        scale = global_std if global_std > 0 else 1.0
        sample_results = [
            self._summarize_sample(sample_id, (rewards - global_mean) / scale)
            for sample_id, rewards in scored
        ]

        file_variances = [result["variance"] for result in sample_results]
        file_rvariances = [result["rvariance"] for result in sample_results]

        return {
            "file_name": os.path.basename(file_path),
            "processed_samples": len(sample_results),
            "global_mean": global_mean,
            "global_std": global_std,
            "mean_variance": float(np.mean(file_variances)) if file_variances else 0.0,
            "mean_rvariance": float(np.mean(file_rvariances)) if file_rvariances else 0.0,
            "sample_results": sample_results
        }

    def _create_empty_result(self, file_path):
        """Return the result structure used when no sample could be scored."""
        return {
            "file_name": os.path.basename(file_path),
            "processed_samples": 0,
            "global_mean": 0.0,
            "global_std": 0.0,
            "mean_variance": 0.0,
            "mean_rvariance": 0.0,
            "sample_results": []
        }

def load_file_lazily(file_path):
    """Load the sample list stored under the "data" key of a JSON file.

    Despite the name, the whole file is parsed in one go by `json.load`;
    "lazy" only means the caller loads one file at a time.

    Args:
        file_path: Path to a UTF-8 encoded JSON file whose top level is an
            object holding a "data" list.

    Returns:
        The list found under "data" (empty when the key is absent). An empty
        list is also returned, after printing the reason, when the file cannot
        be read or parsed or does not have the expected structure.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both malformed JSON and undecodable bytes.
        print(f"Failed to read file {file_path}: {e}")
        return []

    samples = data.get('data', []) if isinstance(data, dict) else None
    if not isinstance(samples, list):
        print(
            f"Failed to read file {file_path}: "
            "expected a JSON object with a 'data' list"
        )
        return []

    print(f"Loaded file: {os.path.basename(file_path)}, Number of samples: {len(samples)}")
    return samples

def main(file_paths=None):
    """Evaluate every model in MODELS on a set of JSON files and save a report.

    For each model that loads successfully, every file is read with
    `load_file_lazily` and scored with
    `RewardModelEvaluator.evaluate_file_streaming`. Per-file results,
    per-model averages and a cross-model comparison are written to a
    timestamped "reward_model_evaluation_results_*.json" file in the current
    directory and a short summary is printed.

    Args:
        file_paths: Iterable of JSON file paths to evaluate. When omitted, the
            command-line arguments (`sys.argv[1:]`) are used, so running the
            script without arguments processes no files, as before.
    """
    import sys  # local import: only needed for the command-line fallback

    if file_paths is None:
        file_paths = sys.argv[1:]
    file_paths = list(file_paths)

    # Store all results
    all_results = {
        "timestamp": datetime.now().isoformat(),
        "models": {},
        "summary": {}
    }

    # Process each model
    for model_idx, model_config in enumerate(MODELS):
        print(f"\n{'='*50}")
        print(f"Processing model {model_idx+1}/{len(MODELS)}: {model_config['name']}")
        print(f"{'='*50}")

        evaluator = RewardModelEvaluator(model_config)

        # Load model
        if not evaluator.load_model():
            print(f"Skipping model {model_config['name']}")
            continue

        model_results = {
            "model_name": model_config["name"],
            "model_path": model_config["path"],
            "files": {},
            "overall_stats": {}
        }

        all_variances = []
        all_rvariances = []

        # Process each file one by one
        for file_path in file_paths:
            file_name = os.path.basename(file_path)
            print(f"\nProcessing file: {file_name}")

            # Load only the current file
            samples = load_file_lazily(file_path)
            if not samples:
                continue

            file_result = evaluator.evaluate_file_streaming(file_path, samples)
            model_results["files"][file_name] = file_result

            # Count every file that produced results. Testing the spread
            # itself would wrongly drop files whose spread is exactly 0.
            if file_result["processed_samples"] > 0:
                all_variances.append(file_result["mean_variance"])
                all_rvariances.append(file_result["mean_rvariance"])

            # Clean up memory after processing a file
            del samples
            gc.collect()
            torch.cuda.empty_cache()

            print(f"File {file_name} processed successfully")

        # Calculate overall model statistics
        if all_variances:
            mean_variance = float(np.mean(all_variances))
            mean_rvariance = float(np.mean(all_rvariances))
            model_results["overall_stats"] = {
                "mean_variance_across_files": mean_variance,
                "mean_rvariance_across_files": mean_rvariance,
                "total_processed_files": len(all_variances)
            }

        # Save model results
        all_results["models"][model_config["name"]] = model_results

        print(f"\n[{model_config['name']}] Completed!")
        if all_variances:
            print(f"Average variance: {mean_variance:.4f}")
            print(f"Average rvariance: {mean_rvariance:.4f}")

        # Drop the references to this model before the next one is loaded so
        # its memory can be reclaimed.
        evaluator.rm_pipe = None
        evaluator.rm_tokenizer = None
        del evaluator
        gc.collect()
        torch.cuda.empty_cache()

    # Calculate cross-model comparison
    summary_stats = {
        model_name: {
            "mean_variance": model_data["overall_stats"]["mean_variance_across_files"],
            "mean_rvariance": model_data["overall_stats"]["mean_rvariance_across_files"]
        }
        for model_name, model_data in all_results["models"].items()
        if model_data.get("overall_stats")
    }

    all_results["summary"] = {
        "model_comparison": summary_stats,
        "files_processed": [os.path.basename(f) for f in file_paths]
    }

    # Save results
    output_file = f"reward_model_evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(all_results, f, ensure_ascii=False, indent=2)
        print(f"\nResults saved to: {output_file}")
    except (OSError, TypeError, ValueError) as e:
        # A failed save must not hide the summary printed below.
        print(f"Failed to save results: {e}")

    # Print final summary
    print(f"\n{'='*60}")
    print("Final Summary")
    print(f"{'='*60}")
    for model_name, stats in summary_stats.items():
        print(f"{model_name}:")
        print(f"  Average variance: {stats['mean_variance']:.4f}")
        print(f"  Average rvariance: {stats['mean_rvariance']:.4f}")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()