import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json
import os

# Target device for reward-model inference.
device = "cuda:0"

# Model identifiers scored in turn by process_file_with_multiple_models;
# each entry is loaded through RewardModel(model_name).
model_configs = [
    "LxzGordon/URM-LLaMa-3.1-8B",
    "Skywork/Skywork-Reward-Llama-3.1-8B",
    "Ray2333/GRM-llama3-8B-distill"  
]

class RewardModel:
    """Score prompt/response pairs with a sequence-classification checkpoint.

    The checkpoint and its tokenizer are loaded with ``from_pretrained``;
    :meth:`calculate_reward` then returns one scalar per conversation.
    """

    # Checkpoints that are loaded with a head wider than a single logit.
    # Every other checkpoint is loaded with ``num_labels=1``.
    _NUM_LABELS_OVERRIDES = {"LxzGordon/URM-LLaMa-3.1-8B": 10}

    def __init__(self, model_name):
        """Load the model (bfloat16, flash-attention 2) and its tokenizer.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        self.model_name = model_name
        print(f"{model_name}")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            attn_implementation="flash_attention_2",
            num_labels=self._NUM_LABELS_OVERRIDES.get(model_name, 1),
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def calculate_reward(self, prompt, response):
        """Return the reward for a single user/assistant exchange.

        Args:
            prompt: Text used as the ``user`` turn.
            response: Text used as the ``assistant`` turn.

        Returns:
            The first logit of the first (only) sequence, as a Python number.
        """
        conv = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ]
        conv_formatted = self.tokenizer.apply_chat_template(conv, tokenize=False)

        # Drop a leading BOS emitted by the chat template; presumably the
        # tokenizer call below adds its own and it would otherwise be
        # duplicated -- TODO confirm per tokenizer.
        bos_token = self.tokenizer.bos_token
        if bos_token is not None and conv_formatted.startswith(bos_token):
            conv_formatted = conv_formatted[len(bos_token):]

        # Send the inputs to wherever the model actually lives rather than to
        # a module-level constant, so the two can never disagree.
        conv_tokenized = self.tokenizer(
            conv_formatted, return_tensors="pt"
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**conv_tokenized)
        return outputs.logits[0][0].item()

def collect_all_rewards_for_dataset(samples, reward_model):
    """Score every response in ``samples`` and return the rewards as an array.

    For each sample the original response is scored first, followed by each
    paraphrase whose ``status`` is ``"valid"``. Progress is printed on every
    fifth sample.

    Args:
        samples: Sequence of dicts with ``"prompt"``, ``"original_response"``
            and an optional ``"paraphrases"`` list.
        reward_model: Object exposing ``calculate_reward(prompt, response)``.

    Returns:
        A flat ``numpy`` array of rewards in scoring order.
    """

    def _responses(entry):
        # Lazily yield the texts to score, original response first.
        yield entry["original_response"]
        for candidate in entry.get("paraphrases", []):
            if candidate.get("status") == "valid":
                yield candidate["text"]

    scores = []
    for position, entry in enumerate(samples):
        if position % 5 == 0:
            print(f"{position+1}/{len(samples)}")

        question = entry["prompt"]
        scores.extend(
            reward_model.calculate_reward(question, text)
            for text in _responses(entry)
        )

    return np.array(scores)

def process_sample_with_global_stats(sample, reward_model, global_mean, global_std):
    """Score one sample's responses and summarise the normalised rewards.

    The original response and every paraphrase whose ``status`` is
    ``"valid"`` are scored, then standardised as
    ``(reward - global_mean) / global_std``.

    Args:
        sample: Dict with ``"id"``, ``"prompt"``, ``"original_response"`` and
            an optional ``"paraphrases"`` list.
        reward_model: Object exposing ``calculate_reward(prompt, response)``.
        global_mean: Mean used to centre the rewards.
        global_std: Standard deviation used to scale the rewards. A value of
            zero is treated as 1.0 so the result stays finite.

    Returns:
        A dict of per-sample statistics. Note that ``"variance"`` is the
        spread between the 90th and 10th percentiles of the normalised
        rewards, while ``"std_variance"`` is their actual variance.
    """
    prompt = sample["prompt"]

    # A zero std (all rewards identical) would raise or yield NaN/inf, which
    # later makes any ordering by spread meaningless; fall back to centring.
    scale = global_std if global_std != 0 else 1.0

    responses = [sample["original_response"]]
    responses.extend(
        para["text"]
        for para in sample.get("paraphrases", [])
        if para.get("status") == "valid"
    )

    raw_rewards = [
        reward_model.calculate_reward(prompt, text) for text in responses
    ]
    normalized_rewards = (np.array(raw_rewards, dtype=float) - global_mean) / scale

    p90, p10 = np.percentile(normalized_rewards, [90, 10])
    quantile_spread = p90 - p10

    return {
        "id": sample["id"],
        "raw_rewards": raw_rewards,
        "normalized_rewards": normalized_rewards.tolist(),
        "variance": float(quantile_spread),
        "mean_reward": float(np.mean(normalized_rewards)),
        "min_reward": float(np.min(normalized_rewards)),
        "max_reward": float(np.max(normalized_rewards)),
        "num_responses": len(normalized_rewards),
        "std_variance": float(np.var(normalized_rewards)),
    }

class _MemoizedRewardModel:
    """Wrap a reward model so each (prompt, response) pair is scored once."""

    def __init__(self, reward_model):
        self._reward_model = reward_model
        self._cache = {}

    def calculate_reward(self, prompt, response):
        key = (prompt, response)
        try:
            return self._cache[key]
        except KeyError:
            pass
        except TypeError:
            # Unhashable inputs cannot be cached; score them directly.
            return self._reward_model.calculate_reward(prompt, response)
        score = self._reward_model.calculate_reward(prompt, response)
        self._cache[key] = score
        return score


def process_file_with_multiple_models(file_path, target_samples=200):
    """Select the samples whose rewards vary most across paraphrases.

    The JSON file is loaded and samples with fewer than three valid
    paraphrases are discarded. Each model in ``model_configs`` then scores
    every remaining sample; the rewards are standardised with that model's
    dataset-wide mean and std, and samples are ranked by the ``"variance"``
    value returned by ``process_sample_with_global_stats`` (rank 0 = largest).
    Ranks are summed over models and the ``target_samples`` samples with the
    smallest sum are returned.

    Args:
        file_path: Path to a JSON file whose top level is an object with a
            ``"data"`` list of samples.
        target_samples: Maximum number of samples to keep.

    Returns:
        A dict with ``"data"`` (selected samples, each extended with
        ``"rm_results"`` and ``"total_ranking"``) and ``"metadata"``, or
        ``None`` if the file cannot be used or holds no valid samples.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # Covers missing/unreadable files, bad encodings and malformed JSON.
        print(f"Failed to load {file_path}: {exc}")
        return None

    if not isinstance(data, dict):
        print(f"Unexpected top-level JSON type in {file_path}: {type(data).__name__}")
        return None

    samples = data.get('data', [])
    valid_samples = [s for s in samples if s.get("valid_paraphrases_count", 0) >= 3]

    print(f"{len(samples)}, {len(valid_samples)}")

    if not valid_samples:
        return None

    model_rankings = []
    model_results = []

    for i, model_config in enumerate(model_configs):
        print(f"\n{i+1}/{len(model_configs)}: {model_config}")

        reward_model = RewardModel(model_config)
        # Both passes below score exactly the same pairs; memoizing avoids
        # running every forward pass a second time.
        scorer = _MemoizedRewardModel(reward_model)

        all_rewards = collect_all_rewards_for_dataset(valid_samples, scorer)
        global_mean = np.mean(all_rewards)
        global_std = np.std(all_rewards)

        print(f"{global_mean:.4f}, {global_std:.4f}")

        sample_results = []
        for j, sample in enumerate(valid_samples):
            if j % 50 == 0:
                print(f"{j+1}/{len(valid_samples)}")

            result = process_sample_with_global_stats(sample, scorer, global_mean, global_std)
            sample_results.append((j, result))

        model_results.append(sample_results)

        # Rank 0 goes to the sample with the largest spread for this model.
        sample_variances = [(j, result['variance']) for j, result in sample_results]
        sample_variances.sort(key=lambda x: x[1], reverse=True)
        rankings = {idx: rank for rank, (idx, _) in enumerate(sample_variances)}
        model_rankings.append(rankings)

        variances = [result['variance'] for _, result in sample_results]
        print(f"{i+1} {np.mean(variances):.4f}")
        print(f"{i+1} [{np.min(variances):.4f}, {np.max(variances):.4f}]")

        # Drop every reference to the model before releasing cached GPU memory.
        del scorer, reward_model
        torch.cuda.empty_cache()

    total_ranks = [
        sum(rankings[idx] for rankings in model_rankings)
        for idx in range(len(valid_samples))
    ]
    # Smallest summed rank first; sorted() is stable, so ties keep file order.
    by_total_rank = sorted(range(len(total_ranks)), key=total_ranks.__getitem__)
    selected_indices = by_total_rank[:target_samples]

    selected_samples = []
    for idx in selected_indices:
        sample = valid_samples[idx].copy()
        sample['rm_results'] = {
            model_config.split('/')[-1]: {
                'variance': model_results[j][idx][1]['variance'],
                'ranking': model_rankings[j][idx],
            }
            for j, model_config in enumerate(model_configs)
        }
        sample['total_ranking'] = total_ranks[idx]
        selected_samples.append(sample)

    print(f"\n{len(selected_samples)} ")

    return {
        'data': selected_samples,
        'metadata': {
            'original_count': len(samples),
            'valid_count': len(valid_samples),
            'selected_count': len(selected_samples),
            'models_used': model_configs,
            'selection_method': 'sum_of_rankings'
        }
    }

if __name__ == "__main__":
    import sys

    # Input JSON files come from the command line. With no arguments the list
    # is empty and nothing is processed, as with the former hard-coded list.
    file_paths = sys.argv[1:]

    for i, file_path in enumerate(file_paths):
        print(f"\n{'='*60}")
        print(f"{i+1}/{len(file_paths)}: {os.path.basename(file_path)}")
        print(f"{'='*60}")

        result = process_file_with_multiple_models(file_path, target_samples=200)

        if result:
            # Write the selection next to the input as "<name>_filtered.json".
            base_name = os.path.splitext(file_path)[0]
            output_path = f"{base_name}_filtered.json"

            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)

            print(f"\n{output_path}")
            print(f"{result['metadata']['original_count']}")
            print(f"{result['metadata']['valid_count']}")
            print(f"{result['metadata']['selected_count']}")
        else:
            # No output is written when the file could not be processed.
            print(f"{file_path}")
    