import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig
import torch
import json
import os
from datetime import datetime
import gc


class RewardModelEvaluator:
    """Score prompt/response texts with a sequence-classification reward model.

    The model is loaded with 4-bit NF4 bitsandbytes quantization and queried
    either one text at a time or in mini-batches. ``evaluate_file_streaming``
    scores every valid sample of a file once, standardizes all rewards with the
    file-level mean/std, and summarizes each sample's standardized rewards.
    """

    def __init__(self, model_config):
        """Store configuration; the model is loaded separately by ``load_model``.

        Args:
            model_config: mapping with ``"name"`` (used as a log prefix) and
                ``"path"`` (passed to ``from_pretrained``).
        """
        self.model_name = model_config["name"]
        self.model_path = model_config["path"]
        self.model = None
        self.tokenizer = None
        # Device the tokenized inputs are moved to before each forward pass.
        self.device = "cuda:0"

    def load_model(self):
        """Load the quantized reward model and its tokenizer.

        Returns:
            True on success; False if loading failed. On failure the error is
            printed and any partially loaded model/tokenizer is released.
        """
        try:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_path,
                quantization_config=quantization_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                use_safetensors=True,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                use_fast=True,
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            # Sequence-classification heads locate the last real token through
            # config.pad_token_id; without it, padded batches (> 1) fail.
            if getattr(self.model.config, "pad_token_id", None) is None:
                self.model.config.pad_token_id = self.tokenizer.pad_token_id

            return True
        except Exception as e:
            print(f"[{self.model_name}] 模型加载失败 ({self.model_path}): {e!r}")
            # Do not keep a half-loaded model occupying GPU memory.
            self.unload_model()
            return False

    def unload_model(self):
        """Drop the model/tokenizer references and release cached CUDA memory."""
        self.model = None
        self.tokenizer = None
        # Collect first so tensors kept alive only by reference cycles are
        # actually freed before the CUDA caching allocator is emptied.
        gc.collect()
        torch.cuda.empty_cache()

    def calculate_reward_batch(self, texts, batch_size=4):
        """Return one scalar reward per text, running ``batch_size`` texts per pass.

        Args:
            texts: sequence of strings to score.
            batch_size: number of texts per forward pass.

        Returns:
            A list of floats, in the same order as ``texts``.
        """
        all_rewards = []

        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]

            inputs = self.tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024,
                return_attention_mask=True,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits
                # Same reduction as calculate_reward_single: a single-logit
                # head is the reward; with several labels use the last column.
                if logits.dim() > 1 and logits.size(-1) > 1:
                    logits = logits[:, -1]
                all_rewards.extend(logits.reshape(-1).float().cpu().tolist())

            del inputs, outputs, logits
            torch.cuda.empty_cache()

        return all_rewards

    def calculate_reward_single(self, text):
        """Return the scalar reward for one text.

        A single-logit head yields the reward directly; with several labels the
        last logit is used. On any error the failure is printed and 0.0 is
        returned so evaluation can continue (note that 0.0 then enters the
        global statistics).
        """
        try:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
                padding=False,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits

                if logits.dim() > 1:
                    if logits.size(-1) == 1:
                        reward = logits.squeeze().cpu().float().item()
                    else:
                        reward = logits[0, -1].cpu().float().item()
                else:
                    reward = logits.cpu().float().item()

            del inputs, outputs
            torch.cuda.empty_cache()

            return reward

        except Exception as e:
            print(f"[{self.model_name}] 奖励计算失败，返回 0.0: {e!r}")
            return 0.0

    @staticmethod
    def _format_text(prompt, response):
        """Build the ``[INST] prompt [/INST] response`` text that is scored."""
        return f"[INST] {prompt} [/INST] {response}"

    def process_sample_streaming(self, sample):
        """Score the original response and every valid paraphrase of one sample.

        Args:
            sample: dict with ``"id"``, ``"prompt"``, ``"original_response"``,
                ``"valid_paraphrases_count"`` and a ``"paraphrases"`` list whose
                items carry ``"status"`` and ``"text"``.

        Returns:
            None when the sample reports fewer than 3 valid paraphrases;
            otherwise ``{"id", "rewards", "num_responses"}``, where the first
            reward belongs to the original response.
        """
        if sample.get("valid_paraphrases_count", 0) < 3:
            return None

        prompt = sample["prompt"]
        rewards = [
            self.calculate_reward_single(
                self._format_text(prompt, sample["original_response"])
            )
        ]

        for para in sample.get("paraphrases", []):
            if para.get("status") == "valid":
                rewards.append(
                    self.calculate_reward_single(self._format_text(prompt, para["text"]))
                )

        return {
            "id": sample["id"],
            "rewards": rewards,
            "num_responses": len(rewards),
        }

    @staticmethod
    def _summarize_sample(sample_id, rewards, global_mean, global_std):
        """Standardize one sample's rewards with file-level stats and summarize them."""
        if global_std > 0:
            standardized = (rewards - global_mean) / global_std
        else:
            standardized = rewards - global_mean

        if len(standardized) > 1:
            p90, p10 = np.percentile(standardized, [90, 10])
            quantile_spread = float(p90 - p10)
            rvariance = float(np.var(standardized))
        else:
            quantile_spread = 0.0
            rvariance = 0.0

        return {
            "id": sample_id,
            # Historical key name: this is the 90th-10th percentile spread of
            # the standardized rewards, not a variance ("rvariance" is that).
            "variance": quantile_spread,
            "mean_reward": float(np.mean(standardized)),
            "min_reward": float(np.min(standardized)),
            "max_reward": float(np.max(standardized)),
            "num_responses": len(standardized),
            "rvariance": rvariance,
        }

    def evaluate_file_streaming(self, file_path, samples):
        """Evaluate all samples of one file.

        Each response is scored exactly once. Every reward is then standardized
        with the mean/std over all rewards of the valid samples, and per-sample
        statistics are computed from the standardized values.

        Args:
            file_path: source path; only its basename is used in the result.
            samples: sized sequence of sample dicts (see
                ``process_sample_streaming``).

        Returns:
            A dict with ``file_name``, ``processed_samples``, ``global_mean``,
            ``global_std``, ``mean_variance``, ``mean_rvariance`` and
            ``sample_results``.
        """
        total = len(samples)
        # Pass 1 (model): cache raw rewards per valid sample. These are a few
        # floats each, negligible next to the model, and caching them avoids a
        # second full round of forward passes.
        scored = []
        for j, sample in enumerate(samples):
            if j % 50 == 0:
                print(f"[{self.model_name}] 奖励计算进度: {j+1}/{total}")
            sample_result = self.process_sample_streaming(sample)
            if sample_result is not None:
                scored.append(
                    (sample_result["id"], np.asarray(sample_result["rewards"], dtype=float))
                )

        if not scored:
            print(f"[{self.model_name}] 无有效样本")
            return self._create_empty_result(file_path)

        all_rewards = np.concatenate([rewards for _, rewards in scored])
        global_mean = float(np.mean(all_rewards))
        global_std = float(np.std(all_rewards))

        print(f"[{self.model_name}] 全局均值: {global_mean:.4f}, 全局标准差: {global_std:.4f}")
        print(f"[{self.model_name}] 奖励总数: {len(all_rewards)}")
        del all_rewards

        # Pass 2 (cached rewards only): standardize and summarize.
        print(f"[{self.model_name}] 第二次遍历：计算标准化结果...")
        sample_results = [
            self._summarize_sample(sample_id, rewards, global_mean, global_std)
            for sample_id, rewards in scored
        ]
        file_variances = [result["variance"] for result in sample_results]
        file_rvariances = [result["rvariance"] for result in sample_results]

        return {
            "file_name": os.path.basename(file_path),
            "processed_samples": len(sample_results),
            "global_mean": global_mean,
            "global_std": global_std,
            "mean_variance": float(np.mean(file_variances)) if file_variances else 0.0,
            "mean_rvariance": float(np.mean(file_rvariances)) if file_rvariances else 0.0,
            "sample_results": sample_results,
        }

    def _create_empty_result(self, file_path):
        """Return the result dict used when a file has no valid samples."""
        return {
            "file_name": os.path.basename(file_path),
            "processed_samples": 0,
            "global_mean": 0.0,
            "global_std": 0.0,
            "mean_variance": 0.0,
            "mean_rvariance": 0.0,
            "sample_results": [],
        }
