import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
from typing import Dict, List, Tuple
from models import BaseLanguageModel

class ValueNetwork(nn.Module):
    """Two-layer MLP head that turns pooled hidden states into scalar values.

    The head is ``Linear(hidden_size, hidden_size // 2) -> ReLU ->
    Linear(hidden_size // 2, 1)``, stored as ``self.value_head``.
    """

    def __init__(self, hidden_size: int = 768):
        super().__init__()
        bottleneck = hidden_size // 2
        layers = [
            nn.Linear(hidden_size, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
        ]
        self.value_head = nn.Sequential(*layers)

    def forward(self, hidden_states):
        """Average ``hidden_states`` over dim 1 and apply the value head.

        Every position along dim 1 is weighted equally (no masking).
        Presumably the input is (batch, seq, hidden_size) -- confirm against
        the caller; the result then has shape (batch, 1).
        """
        pooled = hidden_states.mean(dim=1)
        return self.value_head(pooled)

class PPOTrainer:
    """Clipped-surrogate PPO trainer for a wrapped language model.

    The ``model`` wrapper is expected to expose ``.model`` (the underlying
    network, called with ``input_ids`` / ``attention_mask`` /
    ``output_hidden_states`` and providing ``.generate``), ``.tokenizer``,
    ``.device`` and ``.config.max_length``.
    NOTE(review): the code treats ``model.model`` as a decoder-only (causal)
    LM whose ``generate`` output starts with the prompt -- confirm.

    Each rollout is handled as a single action: the whole generated
    continuation is scored by its mean token log-probability, and the critic
    (``ValueNetwork``) sees only the prompt's final-layer hidden states.
    Policy and critic are trained by separate optimizers on disjoint graphs.
    """

    def __init__(self, model: "BaseLanguageModel", training_config: "TrainingConfig"):
        """Store hyper-parameters and build the critic and both optimizers.

        The annotations are string forward references so that defining the
        class does not require the project types to be resolvable at import
        time.

        Args:
            model: Language-model wrapper (see class docstring).
            training_config: Object providing ``clip_epsilon``,
                ``value_loss_coef``, ``entropy_coef``, ``max_grad_norm``,
                ``ppo_epochs``, ``gamma``, ``lam`` and ``learning_rate``.
        """
        self.model = model
        self.clip_epsilon = training_config.clip_epsilon
        self.value_loss_coef = training_config.value_loss_coef
        self.entropy_coef = training_config.entropy_coef
        self.max_grad_norm = training_config.max_grad_norm
        self.ppo_epochs = training_config.ppo_epochs
        self.gamma = training_config.gamma
        self.lam = training_config.lam

        # Critic sized to the policy's hidden width (768 when not advertised).
        hidden_size = getattr(model.model.config, 'hidden_size', 768)
        self.value_network = ValueNetwork(hidden_size).to(model.device)

        # Separate optimizers; the critic uses a 3x larger learning rate.
        self.policy_optimizer = torch.optim.AdamW(
            model.model.parameters(), lr=training_config.learning_rate
        )
        self.value_optimizer = torch.optim.AdamW(
            self.value_network.parameters(), lr=training_config.learning_rate * 3
        )

        # One entry per ppo_update() call for every key.
        self.training_stats = {
            'policy_loss': [],
            'value_loss': [],
            'entropy_loss': [],
            'total_loss': [],
            'clip_fraction': []
        }

    def compute_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute GAE advantages and discounted returns for one trajectory.

        The value after the last step is taken to be 0. ``values`` is
        detached first, so the advantages are constants with respect to the
        critic and the policy loss cannot push gradients into it.

        Args:
            rewards: 1-D tensor of per-step rewards.
            values: 1-D tensor of value estimates, at least as long as
                ``rewards``.

        Returns:
            ``(advantages, returns)``, both shaped like ``rewards``.
            Advantages are standardized only when there is more than one
            element and their standard deviation exceeds 1e-8; returns are
            plain discounted reward sums.
        """
        values = values.detach()
        num_steps = len(rewards)

        advantages = torch.zeros_like(rewards)
        returns = torch.zeros_like(rewards)

        running_gae = 0.0
        running_return = 0.0

        for t in reversed(range(num_steps)):
            next_value = values[t + 1] if t + 1 < num_steps else 0.0
            delta = rewards[t] + self.gamma * next_value - values[t]

            running_gae = delta + self.gamma * self.lam * running_gae
            running_return = rewards[t] + self.gamma * running_return

            advantages[t] = running_gae
            returns[t] = running_return

        # std() of a single element is NaN, so only standardize real batches.
        if advantages.numel() > 1:
            spread = advantages.std()
            if spread > 1e-8:
                advantages = (advantages - advantages.mean()) / (spread + 1e-8)

        return advantages, returns

    def get_action_log_probs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                           action_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score ``action_ids`` as a continuation of ``input_ids``.

        Prompt and action tokens are concatenated and run through the policy
        in one pass; the logits at position ``i`` are used as the prediction
        for token ``i + 1``. Gradients are NOT disabled here -- wrap the call
        in ``torch.no_grad()`` when only the numbers are needed.

        Args:
            input_ids: Prompt token ids, ``(batch, prompt_len)`` or 1-D.
            attention_mask: Mask matching ``input_ids``.
            action_ids: Generated token ids, ``(batch, num_actions)`` or 1-D,
                on the same device as ``input_ids``.

        Returns:
            ``(log_prob, entropy, hidden_states)``: the mean log-probability
            of the action tokens and the mean predictive entropy over the
            action positions (each of shape ``(batch,)``), plus the detached
            final-layer hidden states of the prompt positions.

        Raises:
            ValueError: If the prompt or the action sequence is empty.
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)
        if action_ids.dim() == 1:
            action_ids = action_ids.unsqueeze(0)

        prompt_len = input_ids.shape[-1]
        num_actions = action_ids.shape[-1]
        if prompt_len == 0 or num_actions == 0:
            raise ValueError(
                "get_action_log_probs requires a non-empty prompt and at least one action token"
            )

        full_ids = torch.cat([input_ids, action_ids], dim=-1)
        action_mask = torch.ones_like(action_ids, dtype=attention_mask.dtype)
        full_mask = torch.cat([attention_mask, action_mask], dim=-1)

        outputs = self.model.model(input_ids=full_ids, attention_mask=full_mask,
                                   output_hidden_states=True)

        # The first action token is predicted from the last prompt position.
        first = prompt_len - 1
        action_logits = outputs.logits[:, first:first + num_actions, :]
        log_probs = F.log_softmax(action_logits, dim=-1)

        action_log_probs = torch.gather(log_probs, -1, action_ids.unsqueeze(-1)).squeeze(-1)
        entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)

        # The critic only sees the prompt, and must not backprop into the policy.
        hidden_states = outputs.hidden_states[-1][:, :prompt_len].detach()

        return action_log_probs.mean(dim=-1), entropy.mean(dim=-1), hidden_states

    def ppo_update(self, rollout_batch: List[Dict]) -> Dict[str, float]:
        """Run ``ppo_epochs`` passes of PPO over ``rollout_batch``.

        Each rollout dict must provide ``'input_ids'``, ``'attention_mask'``,
        ``'action_ids'``, ``'old_log_probs'`` and ``'rewards'``. Note that
        ``collect_rollout`` does not add ``'rewards'``; the caller has to.
        Rollouts with an empty prompt or no action tokens (the fallback
        records emitted by ``collect_rollout`` on failure) are skipped.

        The policy is trained on ``policy_loss + entropy_coef * entropy_loss``
        and the critic on the unscaled ``value_loss``; ``value_loss_coef``
        only enters the reported ``total_loss``.

        Returns:
            Averages over all performed updates for ``'policy_loss'``,
            ``'value_loss'``, ``'entropy_loss'``, ``'total_loss'`` and
            ``'clip_fraction'`` (all 0.0 when nothing was updated). The same
            averages are appended to ``self.training_stats``.
        """
        device = self.model.device
        totals = {
            'policy_loss': 0.0,
            'value_loss': 0.0,
            'entropy_loss': 0.0,
            'total_loss': 0.0,
            'clip_fraction': 0.0,
        }
        num_updates = 0

        for _ in range(self.ppo_epochs):
            for rollout in rollout_batch:
                # Fallback records from failed rollouts carry no usable data.
                if rollout['action_ids'].numel() == 0 or rollout['input_ids'].numel() == 0:
                    continue

                input_ids = rollout['input_ids'].unsqueeze(0).to(device)
                attention_mask = rollout['attention_mask'].unsqueeze(0).to(device)
                action_ids = rollout['action_ids'].to(device)
                old_log_probs = rollout['old_log_probs'].to(device).detach().reshape(-1)

                # Current policy outputs (with gradients) and critic estimate.
                new_log_probs, entropy, hidden_states = self.get_action_log_probs(
                    input_ids, attention_mask, action_ids
                )
                values = self.value_network(hidden_states).reshape(-1)
                rewards = torch.as_tensor(
                    rollout['rewards'], dtype=values.dtype, device=device
                ).reshape(-1)

                advantages, returns = self.compute_advantages(rewards, values)

                # Clipped surrogate objective.
                ratio = torch.exp(new_log_probs.reshape(-1) - old_log_probs)
                surrogate1 = ratio * advantages
                surrogate2 = torch.clamp(
                    ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon
                ) * advantages
                policy_loss = -torch.min(surrogate1, surrogate2).mean()

                value_loss = F.mse_loss(values, returns)
                entropy_loss = -entropy.mean()

                total_loss = (policy_loss +
                              self.value_loss_coef * value_loss +
                              self.entropy_coef * entropy_loss)

                # The two graphs are disjoint (hidden states and advantages
                # are detached), so each gets exactly one backward pass.
                self.policy_optimizer.zero_grad()
                self.value_optimizer.zero_grad()
                (policy_loss + self.entropy_coef * entropy_loss).backward()
                value_loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model.model.parameters(), self.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.value_network.parameters(), self.max_grad_norm)
                self.policy_optimizer.step()
                self.value_optimizer.step()

                # Statistics
                detached_ratio = ratio.detach()
                clipped = (
                    (detached_ratio > 1 + self.clip_epsilon)
                    | (detached_ratio < 1 - self.clip_epsilon)
                )

                totals['policy_loss'] += policy_loss.item()
                totals['value_loss'] += value_loss.item()
                totals['entropy_loss'] += entropy_loss.item()
                totals['total_loss'] += total_loss.item()
                totals['clip_fraction'] += clipped.float().mean().item()
                num_updates += 1

        if num_updates > 0:
            averages = {name: total / num_updates for name, total in totals.items()}
        else:
            averages = {name: 0.0 for name in totals}

        for name, value in averages.items():
            self.training_stats[name].append(value)

        return averages

    def collect_rollout(self, prompts: List[str], max_new_tokens: int = 100,
                        temperature: float = 0.7) -> List[Dict]:
        """Sample one continuation per prompt and record its log-probability.

        ``'old_log_probs'`` is computed with ``get_action_log_probs`` under
        ``torch.no_grad()``, i.e. with exactly the scoring used later in
        ``ppo_update``. NOTE: sampling applies ``temperature`` while the
        stored log-probability is taken from the unscaled logits.

        Args:
            prompts: Prompt strings, processed one at a time.
            max_new_tokens: Generation budget per prompt.
            temperature: Sampling temperature passed to ``generate``.

        Returns:
            One dict per prompt, in order, with ``'input_ids'``,
            ``'attention_mask'``, ``'action_ids'``, ``'old_log_probs'``
            (0-dim CPU tensor) and ``'generated_text'``. If anything fails
            for a prompt, the error is printed and a placeholder record with
            empty tensors is appended so the output stays aligned with
            ``prompts``.
        """
        rollout_data = []

        for prompt in prompts:
            try:
                inputs = self.model.tokenizer(
                    prompt, return_tensors="pt", truncation=True,
                    max_length=self.model.config.max_length, padding=True
                ).to(self.model.device)

                with torch.no_grad():
                    outputs = self.model.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=True,
                        temperature=temperature,
                        pad_token_id=self.model.tokenizer.eos_token_id
                    )

                    # generate() output starts with the prompt; keep the rest.
                    prompt_len = inputs['input_ids'].shape[1]
                    generated_ids = outputs[0][prompt_len:]

                    if generated_ids.numel() > 0:
                        sequence_log_prob, _, _ = self.get_action_log_probs(
                            inputs['input_ids'], inputs['attention_mask'], generated_ids
                        )
                        old_log_probs = sequence_log_prob.reshape(()).detach().cpu()
                    else:
                        old_log_probs = torch.tensor(0.0)

                generated_text = self.model.tokenizer.decode(generated_ids, skip_special_tokens=True)

                rollout_data.append({
                    'input_ids': inputs['input_ids'][0],
                    'attention_mask': inputs['attention_mask'][0],
                    'action_ids': generated_ids,
                    'old_log_probs': old_log_probs,
                    'generated_text': generated_text
                })

            except Exception as e:
                # Deliberate best effort: keep one record per prompt.
                print(f"Error in rollout collection: {e}")
                empty_tensor = torch.tensor([], dtype=torch.long).to(self.model.device)
                rollout_data.append({
                    'input_ids': empty_tensor,
                    'attention_mask': empty_tensor,
                    'action_ids': empty_tensor,
                    'old_log_probs': torch.tensor(0.0),
                    'generated_text': ""
                })

        return rollout_data

    def get_training_stats(self) -> Dict[str, List[float]]:
        """Return a copy of the per-update history (the lists are copied too)."""
        return {name: list(history) for name, history in self.training_stats.items()}