import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup

class GRPOS_Trainer:
    """Train a policy model with a clipped-ratio objective plus a KL penalty.

    Each training step generates responses with the policy model, scores them
    against demonstration responses through ``shadow_deployer``, scales the
    normalised reward signal by two coefficients obtained from
    ``dhmr_measurer``, and applies one optimizer/scheduler step.
    """

    def __init__(self, policy_model, ref_model, shadow_deployer, dhmr_measurer, config):
        """Store collaborators and build the optimizer and LR scheduler.

        Args:
            policy_model: Model being optimised. Must expose ``parameters()``,
                ``generate(prompt, **gen_params)`` and ``get_log_probs(responses)``.
            ref_model: Reference model; must expose ``get_log_probs(responses)``.
            shadow_deployer: Object exposing
                ``get_reward(category, prompts, responses)``.
            dhmr_measurer: Object exposing ``compute_data_hardness`` and
                ``compute_model_responsiveness``.
            config: Mapping read for the keys ``device``, ``lr``,
                ``weight_decay``, ``warmup_steps``, ``total_steps`` here, and
                ``gen_params``, ``eps``, ``beta``, ``batch_size``, ``epochs``
                by the training methods.
        """
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.shadow = shadow_deployer
        self.measurer = dhmr_measurer
        self.config = config
        self.device = config['device']

        # NOTE(review): this `AdamW` is the one exported by `transformers`,
        # which recent releases deprecate in favour of `torch.optim.AdamW` --
        # confirm against the pinned transformers version.
        self.optimizer = AdamW(
            policy_model.parameters(),
            lr=config['lr'],
            weight_decay=config['weight_decay']
        )

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=config['warmup_steps'],
            num_training_steps=config['total_steps']
        )

    def compute_advantage(self, rewards, hardness_coeff):
        """Standardise ``rewards`` and scale them by ``hardness_coeff``.

        Args:
            rewards: Sequence of scalar rewards (Python numbers or one-element
                tensors), or a tensor of rewards.
            hardness_coeff: Multiplier applied to the standardised rewards.

        Returns:
            A float32 tensor on ``self.device`` holding
            ``(rewards - mean) / (std + 1e-8) * hardness_coeff``.
        """
        if isinstance(rewards, torch.Tensor):
            reward_t = rewards.detach().to(device=self.device, dtype=torch.float32)
        else:
            # Force float32: an integer tensor would make ``mean()`` raise.
            reward_t = torch.tensor(
                [float(r) for r in rewards],
                dtype=torch.float32,
                device=self.device,
            )

        mean = reward_t.mean()
        if reward_t.numel() > 1:
            std = reward_t.std()
        else:
            # The sample std of a single value is NaN; treat the spread as 0
            # so a batch of one yields a zero advantage instead of NaN.
            std = reward_t.new_zeros(())
        advantages = (reward_t - mean) / (std + 1e-8)
        return advantages * hardness_coeff

    def kl_penalty(self, log_probs, ref_log_probs):
        """Return the mean of ``exp(log_probs) * (log_probs - ref_log_probs)``.

        NOTE(review): if ``log_probs`` holds only the log-probabilities of the
        sampled tokens, the extra ``exp(log_probs)`` weighting is not the usual
        sample-based KL estimator -- confirm what ``get_log_probs`` returns.
        """
        kl_div = torch.exp(log_probs) * (log_probs - ref_log_probs)
        return kl_div.mean()

    @staticmethod
    def _unpack_batch(batch):
        """Split ``batch`` into aligned prompt / demo-response / category lists.

        Accepts either a mapping of columns (keys ``prompt``, ``response`` and
        ``categories`` or ``category``) or a sequence of per-sample mappings
        (keys ``prompt``, ``response`` and ``category`` or ``categories``).
        NOTE(review): the per-sample category key is an assumption -- verify
        against the dataset schema.

        Raises:
            ValueError: If the three columns do not have the same length.
        """
        from collections.abc import Mapping

        if isinstance(batch, Mapping):
            prompts = list(batch['prompt'])
            demo_responses = list(batch['response'])
            category_key = 'categories' if 'categories' in batch else 'category'
            categories = list(batch[category_key])
        else:
            items = list(batch)
            prompts = [item['prompt'] for item in items]
            demo_responses = [item['response'] for item in items]
            categories = [
                item['category'] if 'category' in item else item['categories']
                for item in items
            ]

        if not (len(prompts) == len(demo_responses) == len(categories)):
            raise ValueError(
                "batch columns differ in length: "
                f"{len(prompts)} prompts, {len(demo_responses)} responses, "
                f"{len(categories)} categories"
            )
        return prompts, demo_responses, categories

    def train_step(self, batch):
        """Run one optimisation step on ``batch`` and return the scalar loss.

        Args:
            batch: A mapping of columns or a sequence of per-sample mappings;
                see ``_unpack_batch`` for the keys that are read.

        Returns:
            The total loss (policy loss + weighted KL penalty) as a float.
        """
        prompts, demo_responses, categories = self._unpack_batch(batch)

        # Sampling needs no autograd graph; log-probs are recomputed below.
        with torch.no_grad():
            responses = [
                self.policy_model.generate(p, **self.config['gen_params'])
                for p in prompts
            ]

        alpha_D = self.measurer.compute_data_hardness(demo_responses, responses)

        # NOTE(review): the value used as the reward is (demo - generated), so
        # a larger value means the generation scored lower than the demo --
        # confirm this sign is intended for the policy objective.
        rewards = []
        for category, prompt, demo, resp in zip(
            categories, prompts, demo_responses, responses
        ):
            # Score each generation against its own demonstration only.
            r_demo = self.shadow.get_reward(category, [prompt], [demo])
            r_gen = self.shadow.get_reward(category, [prompt], [resp])
            rewards.append(r_demo - r_gen)
        alpha_M = self.measurer.compute_model_responsiveness(rewards)

        hardness_coeff = alpha_D * alpha_M
        advantages = self.compute_advantage(rewards, hardness_coeff)

        log_probs = self.policy_model.get_log_probs(responses)
        with torch.no_grad():
            ref_log_probs = self.ref_model.get_log_probs(responses)

        # NOTE(review): the ratio is taken against the reference model rather
        # than the policy that produced the samples -- confirm this is intended.
        ratios = torch.exp(log_probs - ref_log_probs)
        eps = self.config['eps']
        clipped_ratios = torch.clamp(ratios, 1 - eps, 1 + eps)

        # Align device/dtype with the ratios and right-pad dimensions so that
        # per-sample advantages broadcast over any trailing axes.
        advantages = advantages.to(device=ratios.device, dtype=ratios.dtype)
        while advantages.dim() < ratios.dim():
            advantages = advantages.unsqueeze(-1)

        policy_loss = -torch.min(ratios * advantages, clipped_ratios * advantages).mean()
        kl_penalty = self.config['beta'] * self.kl_penalty(log_probs, ref_log_probs)
        total_loss = policy_loss + kl_penalty

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()

        return total_loss.item()

    def train(self, train_dataset):
        """Iterate over ``train_dataset`` for ``config['epochs']`` epochs.

        Batches are drawn from a shuffled ``DataLoader`` of size
        ``config['batch_size']`` and passed to ``train_step``. A tqdm progress
        bar showing the latest loss is used when tqdm can be imported;
        otherwise training proceeds without one.
        """
        # Imported locally and guarded: the module's imports do not provide it.
        try:
            from tqdm import tqdm
        except ImportError:
            tqdm = None

        dataloader = DataLoader(
            train_dataset,
            batch_size=self.config['batch_size'],
            shuffle=True
        )

        for epoch in range(self.config['epochs']):
            if tqdm is not None:
                progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
            else:
                progress_bar = dataloader
            for batch in progress_bar:
                loss = self.train_step(batch)
                if tqdm is not None:
                    progress_bar.set_postfix({'loss': f"{loss:.4f}"})