"""
Group Relative Policy Optimization (GRPO) Trainer.

GRPO is simpler than PPO - it doesn't require a critic network.
Instead, it uses group-relative baselines computed from the batch.

This implementation is based on DeepSeekMath and is well-suited
for code generation tasks.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Tuple
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

from ..utils.megatron_bridge import initialize_cublas, with_cublas_retry

logger = logging.getLogger(__name__)


@dataclass
class GRPOConfig:
    """Hyperparameters read by the GRPO trainer.

    The field order doubles as the positional constructor signature, so
    new fields should only ever be appended.
    """
    # --- Optimisation -----------------------------------------------------
    learning_rate: float = 1e-5  # Base learning rate handed to the optimizer
    clip_epsilon: float = 0.2  # Half-width of the clipping band around ratio == 1
    kl_coef: float = 0.1  # Multiplier applied to the KL penalty term
    entropy_coef: float = 0.01  # Entropy weight (not referenced by the trainer in this file)
    max_grad_norm: float = 1.0  # Threshold used for gradient-norm clipping
    # --- Advantage estimation ---------------------------------------------
    group_size: int = 8  # Completions per prompt that share one baseline
    normalize_advantages: bool = True  # Re-standardise advantages over the whole batch
    use_kl_penalty: bool = True  # Include the KL term in the loss when True
    # --- Learning-rate schedule -------------------------------------------
    use_lr_scheduler: bool = True  # Turn on warmup followed by cosine annealing
    warmup_ratio: float = 0.05  # Share of total_steps spent warming up (5%)
    min_lr: float = 1e-6  # Floor that the cosine annealing decays towards
    total_steps: int = 10000  # Planned number of training steps (sizes the schedule)


@dataclass
class GRPOExperience:
    """One sampled (prompt, response) pair plus the signal used to train on it."""
    prompt: str  # Text the response was generated from
    response: str  # Generated completion text
    reward: float  # Scalar reward assigned to this response
    old_log_prob: float  # Log-probability recorded for the response at sampling time
    difficulty: int = 0  # Difficulty label carried along with the sample (defaults to 0)
    policy_version: int = 0  # Version tag of the policy for this sample (defaults to 0)


@dataclass
class GRPOBatch:
    """Column-wise collection of experiences consumed by one training step."""
    prompts: List[str]  # One prompt per sample
    responses: List[str]  # One response per sample, aligned with ``prompts``
    rewards: torch.Tensor  # Per-sample rewards
    old_log_probs: torch.Tensor  # Per-sample log-probs recorded at sampling time
    difficulties: List[int]  # Per-sample difficulty labels
    advantages: Optional[torch.Tensor] = None  # Optional precomputed advantages

    def to(self, device: torch.device) -> 'GRPOBatch':
        """Return a new batch whose tensor fields live on ``device``.

        The list fields are passed through as the same objects; only the
        tensors are transferred. ``advantages`` stays ``None`` when unset.
        """
        relocated_advantages = None
        if self.advantages is not None:
            relocated_advantages = self.advantages.to(device)

        relocated_rewards = self.rewards.to(device)
        relocated_old_log_probs = self.old_log_probs.to(device)

        return GRPOBatch(
            self.prompts,
            self.responses,
            relocated_rewards,
            relocated_old_log_probs,
            self.difficulties,
            relocated_advantages,
        )


class GRPOTrainer:
    """
    GRPO Trainer for code generation.

    Key differences from PPO:
    - No critic network (uses group-relative baseline)
    - Advantage = (reward - mean_reward) / std_reward, computed per group
    - KL penalty added to loss instead of subtracted from reward
    - Single update per batch (no multiple epochs)

    Log-probabilities are sequence-level: the log-probs of all response
    tokens are summed, so the clipped ratio is a whole-sequence ratio.
    """

    # Truncation limit shared by the prompt-only and the prompt+response
    # tokenization. Both must use the same limit, otherwise a long prompt is
    # measured shorter than the span it really occupies in the full sequence
    # and its tail would be scored as if it were response text.
    _MAX_SEQ_LENGTH = 1024

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Any,
        config: Optional["GRPOConfig"] = None,
        device: str = "cuda",
        ref_model: Optional[nn.Module] = None,
    ):
        """
        Initialize GRPO trainer.

        Args:
            model: Policy model
            tokenizer: Tokenizer for the model
            config: GRPOConfig (a default-constructed one is used when omitted)
            device: Device to use
            ref_model: Reference model for KL penalty (optional)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or GRPOConfig()
        self.device = device
        self.ref_model = ref_model

        # Initialize CUBLAS for this device to prevent CUBLAS_STATUS_NOT_INITIALIZED.
        # str() lets the check work for "cuda", "cuda:1" and torch.device objects.
        if str(device).startswith("cuda"):
            initialize_cublas()
            logger.info("GRPO Trainer: CUBLAS initialized")

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=self.config.learning_rate,
        )

        # Learning rate scheduler with warmup + cosine annealing
        self.scheduler = None
        self.warmup_steps = 0
        if self.config.use_lr_scheduler and self.config.total_steps > 0:
            # Keep the floor of 50 warmup steps, but never warm up for longer
            # than the run itself (short runs would otherwise never reach the
            # configured learning rate).
            self.warmup_steps = min(
                self.config.total_steps,
                max(50, int(self.config.total_steps * self.config.warmup_ratio)),
            )
            # Use cosine annealing after warmup
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=max(1, self.config.total_steps - self.warmup_steps),
                eta_min=self.config.min_lr,
            )
            logger.info(f"LR scheduler initialized: warmup={self.warmup_steps} steps, "
                       f"cosine decay to {self.config.min_lr}")

        # Metrics
        self.metrics_history: List[Dict[str, float]] = []
        self.total_updates = 0

    @staticmethod
    def _standardize(values: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Standardize ``values`` along the last dimension: (x - mean) / (std + eps).

        With fewer than two elements the sample standard deviation is
        undefined (``Tensor.std`` returns NaN), so zeros are returned instead:
        a lone sample carries no relative signal.
        """
        if values.shape[-1] < 2:
            return torch.zeros_like(values)
        mean = values.mean(dim=-1, keepdim=True)
        std = values.std(dim=-1, keepdim=True)
        return (values - mean) / (std + eps)

    def compute_group_advantages(
        self,
        rewards: torch.Tensor,
        group_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Compute group-relative advantages.

        Consecutive runs of ``group_size`` rewards form one group, and within
        each group: advantage = (reward - mean_reward) / (std_reward + eps).
        A trailing partial group is standardized on its own. When the batch is
        smaller than one group, the whole batch is treated as a single group.
        If ``config.normalize_advantages`` is set, the result is standardized
        once more across the whole batch.

        Groups with a single sample (and batches of one) get an advantage of
        0 rather than NaN, so one stray sample cannot poison the batch.

        Args:
            rewards: Tensor of rewards [batch_size]
            group_size: Size of each group (default: use config)

        Returns:
            Advantages tensor with the same shape as ``rewards``
        """
        group_size = group_size or self.config.group_size
        batch_size = rewards.shape[0]

        if batch_size < group_size:
            # Not enough for groups, use global normalization
            advantages = self._standardize(rewards)
        else:
            num_groups = batch_size // group_size
            full = num_groups * group_size
            # All complete groups are standardized in one vectorized call.
            grouped = rewards[:full].reshape(num_groups, group_size)
            parts = [self._standardize(grouped).reshape(-1)]
            if full < batch_size:
                # Handle remaining samples as their own (smaller) group
                parts.append(self._standardize(rewards[full:]))
            advantages = torch.cat(parts)

        if self.config.normalize_advantages:
            advantages = self._standardize(advantages)

        return advantages

    def compute_log_probs(
        self,
        prompts: List[str],
        responses: List[str],
    ) -> torch.Tensor:
        """
        Compute log probabilities for given prompt-response pairs.

        Uses the policy model (``self.model``). Gradients flow through the
        result unless the caller disables them.

        Args:
            prompts: List of prompts
            responses: List of responses

        Returns:
            Tensor [batch_size] with the summed log-probability of each
            response given its prompt
        """
        # Same code path as the reference model, so policy and reference
        # log-probs are always computed over identical token spans.
        return self.compute_log_probs_with_model(self.model, prompts, responses)

    def compute_kl_divergence(
        self,
        prompts: List[str],
        responses: List[str],
        new_log_probs: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute KL divergence from reference model.

        Args:
            prompts: List of prompts
            responses: List of responses
            new_log_probs: Log probs from current policy

        Returns:
            Per-sample log-ratio ``new_log_probs - ref_log_probs`` (all zeros
            when no reference model is configured)
        """
        if self.ref_model is None:
            return torch.zeros_like(new_log_probs)

        with torch.no_grad():
            ref_log_probs = self.compute_log_probs_with_model(
                self.ref_model, prompts, responses
            )

        # KL(new || ref) approximation via the plain log-ratio.
        # NOTE(review): ref_log_probs carries no gradient, so the gradient of
        # this term w.r.t. the policy does not depend on the reference model,
        # and individual values can be negative. Left unchanged because
        # swapping the estimator alters training dynamics - worth revisiting.
        kl = new_log_probs - ref_log_probs
        return kl

    def compute_log_probs_with_model(
        self,
        model: nn.Module,
        prompts: List[str],
        responses: List[str],
    ) -> torch.Tensor:
        """Compute summed response log-probs using a specific model.

        Each prompt is concatenated with its response and run through
        ``model``; the log-probabilities of the response tokens are summed.
        The prompt length is measured by tokenizing the prompt on its own,
        which assumes the prompt tokenizes the same way standalone as it does
        as a prefix of prompt+response.

        Args:
            model: Model to score with
            prompts: List of prompts
            responses: List of responses (same length as ``prompts``)

        Returns:
            Float tensor [batch_size] on ``self.device``

        Raises:
            ValueError: If ``prompts`` and ``responses`` differ in length
        """
        if len(prompts) != len(responses):
            raise ValueError(
                f"prompts and responses must have the same length, "
                f"got {len(prompts)} and {len(responses)}"
            )
        if not prompts:
            return torch.zeros(0, device=self.device)

        full_texts = [p + r for p, r in zip(prompts, responses)]

        inputs = self.tokenizer(
            full_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self._MAX_SEQ_LENGTH,
        ).to(self.device)

        # Prompt lengths, measured with the same truncation limit as above
        prompt_inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self._MAX_SEQ_LENGTH,
        )
        prompt_lengths = prompt_inputs.attention_mask.sum(dim=1)

        # Forward pass with CUBLAS error handling
        # Use model's dtype dynamically instead of hardcoded float16 to avoid precision mismatch
        model_dtype = next(model.parameters()).dtype
        use_autocast = model_dtype in (torch.float16, torch.bfloat16)

        # Determine device type for autocast
        device_type = 'cuda' if torch.cuda.is_available() and 'cuda' in str(self.device) else 'cpu'

        @with_cublas_retry
        def _forward_pass():
            if use_autocast:
                with torch.amp.autocast(device_type=device_type, dtype=model_dtype):
                    return model(
                        input_ids=inputs.input_ids,
                        attention_mask=inputs.attention_mask,
                        return_dict=True,
                    )
            else:
                return model(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    return_dict=True,
                )

        outputs = _forward_pass()

        total_log_probs = self._sum_response_log_probs(
            outputs.logits,
            inputs.input_ids,
            inputs.attention_mask,
            prompt_lengths,
        )
        return total_log_probs.to(self.device)

    @staticmethod
    def _sum_response_log_probs(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        prompt_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Sum the log-probabilities of the response tokens of each sequence.

        Args:
            logits: Model output [batch, seq_len, vocab]
            input_ids: Token ids [batch, seq_len]
            attention_mask: 1 for real tokens, 0 for padding [batch, seq_len]
            prompt_lengths: Number of prompt tokens per sequence [batch]

        Returns:
            Float32 tensor [batch] on the device of ``logits``
        """
        device = logits.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        prompt_lengths = prompt_lengths.to(device)

        batch_size, seq_len = input_ids.shape
        if seq_len < 2:
            # Nothing can be predicted from fewer than two positions.
            return torch.zeros(batch_size, device=device)

        # Logits at position t-1 predict the token at position t, so drop the
        # last logit row and the first token to line the two up.
        token_log_probs = F.log_softmax(logits[:, :-1, :].float(), dim=-1)
        token_log_probs = token_log_probs.gather(
            -1, input_ids[:, 1:].unsqueeze(-1)
        ).squeeze(-1)

        # Padding tokens in front of the text: 0 for right-padded batches,
        # the pad count for left-padded ones.
        leading_pad = attention_mask.cumsum(dim=1).eq(0).sum(dim=1)
        # First response token sits right after the prompt. clamp(min=1)
        # because the very first real token has nothing to be predicted from.
        first = (leading_pad + prompt_lengths.clamp(min=1)).unsqueeze(1)
        last = (leading_pad + attention_mask.sum(dim=1)).unsqueeze(1)

        # Column k of token_log_probs scores the token at position k + 1.
        positions = torch.arange(1, seq_len, device=device).unsqueeze(0)
        response_mask = (positions >= first) & (positions < last)

        return token_log_probs.masked_fill(~response_mask, 0.0).sum(dim=1)

    def _apply_lr_schedule(self) -> float:
        """Set the learning rate for the upcoming update and return it.

        Must be called before ``optimizer.step()`` so that the first update
        already runs at the warmup rate instead of the full learning rate.
        Warmup is linear over ``warmup_steps`` updates and reaches the
        configured learning rate on the last warmup update. Afterwards the
        cosine scheduler advances once per update until ``T_max`` and is then
        held, because ``CosineAnnealingLR`` would otherwise raise the LR again.
        """
        if self.scheduler is not None:
            step_index = self.total_updates + 1
            if step_index <= self.warmup_steps:
                # Linear warmup
                warmup_factor = step_index / max(1, self.warmup_steps)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.config.learning_rate * warmup_factor
            elif self.scheduler.last_epoch < self.scheduler.T_max:
                # Cosine annealing
                self.scheduler.step()
        return self.optimizer.param_groups[0]['lr']

    def train_step(
        self,
        batch: "GRPOBatch",
        importance_weights: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """
        Perform one GRPO update step.

        ``batch.old_log_probs`` must be summed over the same response tokens
        as ``compute_log_probs`` uses, otherwise the probability ratio is
        biased.

        Args:
            batch: GRPOBatch of experiences
            importance_weights: Optional weights for off-policy correction
                (multiplied into the advantages)

        Returns:
            Dictionary of training metrics. ``learning_rate`` is the rate
            that was used for this update.
        """
        self.model.train()
        batch = batch.to(self.device)

        # Compute advantages if not provided
        if batch.advantages is None:
            advantages = self.compute_group_advantages(batch.rewards)
        else:
            advantages = batch.advantages

        # Apply importance weights if provided
        if importance_weights is not None:
            advantages = advantages * importance_weights.to(self.device)

        # Compute new log probs
        new_log_probs = self.compute_log_probs(batch.prompts, batch.responses)

        # Sequence-level probability ratio between current and sampling policy
        ratio = torch.exp(new_log_probs - batch.old_log_probs)

        # Clipped surrogate loss
        clipped_ratio = torch.clamp(
            ratio,
            1 - self.config.clip_epsilon,
            1 + self.config.clip_epsilon,
        )
        policy_loss = -torch.min(
            ratio * advantages,
            clipped_ratio * advantages,
        ).mean()

        # KL penalty
        if self.config.use_kl_penalty:
            kl = self.compute_kl_divergence(batch.prompts, batch.responses, new_log_probs)
            kl_loss = self.config.kl_coef * kl.mean()
        else:
            kl_loss = torch.tensor(0.0, device=self.device)

        # Total loss
        loss = policy_loss + kl_loss

        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.max_grad_norm,
        )

        # Learning rate scheduling happens before the step so that this
        # update, including the very first one, uses the scheduled rate.
        current_lr = self._apply_lr_schedule()

        self.optimizer.step()
        self.total_updates += 1

        # Compute metrics
        with torch.no_grad():
            approx_kl = (batch.old_log_probs - new_log_probs).mean().item()
            clip_frac = ((ratio - 1).abs() > self.config.clip_epsilon).float().mean().item()
            # std of a single value is NaN; report 0.0 for one-sample batches
            std_advantage = advantages.std().item() if advantages.numel() > 1 else 0.0

        metrics = {
            "policy_loss": policy_loss.item(),
            "kl_loss": kl_loss.item(),
            "total_loss": loss.item(),
            "approx_kl": approx_kl,
            "clip_fraction": clip_frac,
            "grad_norm": float(grad_norm),
            "mean_reward": batch.rewards.mean().item(),
            "mean_advantage": advantages.mean().item(),
            "std_advantage": std_advantage,
            "learning_rate": current_lr,
        }

        self.metrics_history.append(metrics)
        return metrics

    def get_gradient_magnitude(self) -> float:
        """Get the magnitude of gradients (for ACB learning signal).

        Returns the L2 norm over all parameter gradients currently stored on
        the model, or 0.0 when no parameter has a gradient.
        """
        squared_norm = sum(
            p.grad.detach().norm(2).item() ** 2
            for p in self.model.parameters()
            if p.grad is not None
        )
        return float(squared_norm) ** 0.5

    def save_checkpoint(self, path: str):
        """Save trainer checkpoint.

        Stores model, optimizer and LR-scheduler state together with the
        config object, the metrics history and the update counter.

        Args:
            path: Destination file passed to ``torch.save``
        """
        scheduler_state = self.scheduler.state_dict() if self.scheduler is not None else None
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            # Without the scheduler state a resumed run would restart the
            # cosine phase from its beginning.
            "scheduler_state_dict": scheduler_state,
            "config": self.config,
            "metrics_history": self.metrics_history,
            "total_updates": self.total_updates,
        }, path)

    def load_checkpoint(self, path: str):
        """Load trainer checkpoint.

        Restores model, optimizer, metrics history and update counter, plus
        the scheduler state when the checkpoint contains one. Checkpoints
        without scheduler state still load; their cosine schedule then starts
        from its initial position.

        Args:
            path: File previously written by ``save_checkpoint``
        """
        # SECURITY: the checkpoint embeds the pickled config object, which the
        # weights-only loader (the default since PyTorch 2.6) rejects.
        # weights_only=False unpickles arbitrary objects, so only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.metrics_history = checkpoint.get("metrics_history", [])
        self.total_updates = checkpoint.get("total_updates", 0)

        scheduler_state = checkpoint.get("scheduler_state_dict")
        if self.scheduler is not None and scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)


def create_grpo_batch(experiences: List[GRPOExperience]) -> GRPOBatch:
    """Collate a list of experiences into one column-wise GRPOBatch.

    Rewards and old log-probs become float32 tensors; prompts, responses
    and difficulties stay Python lists in the input order. ``advantages``
    is left unset on the returned batch.
    """
    prompt_column: List[str] = []
    response_column: List[str] = []
    reward_column: List[float] = []
    log_prob_column: List[float] = []
    difficulty_column: List[int] = []

    # Single pass over the experiences, splitting each record into columns.
    for item in experiences:
        prompt_column.append(item.prompt)
        response_column.append(item.response)
        reward_column.append(item.reward)
        log_prob_column.append(item.old_log_prob)
        difficulty_column.append(item.difficulty)

    return GRPOBatch(
        prompts=prompt_column,
        responses=response_column,
        rewards=torch.tensor(reward_column, dtype=torch.float32),
        old_log_probs=torch.tensor(log_prob_column, dtype=torch.float32),
        difficulties=difficulty_column,
    )


if __name__ == "__main__":
    # Smoke test for the model-free parts of this module. Because of the
    # package-relative import at the top of the file, run it as a module
    # (python -m <package>.<module>) rather than as a plain script.
    print("Testing GRPO components...")

    config = GRPOConfig()
    print(f"GRPO Config: {config}")

    # Test experience creation
    exp = GRPOExperience(
        prompt="def add(a, b):\n    ",
        response="return a + b",
        reward=1.0,
        old_log_prob=-2.5,
        difficulty=3,
    )
    print(f"Created experience: {exp}")

    # Test batch creation
    experiences = [exp] * 8
    batch = create_grpo_batch(experiences)
    assert len(batch.prompts) == len(experiences)
    assert batch.rewards.shape == (len(experiences),)
    assert batch.old_log_probs.shape == (len(experiences),)
    print(f"Created batch with {len(batch.prompts)} experiences")

    # Test advantage computation (without model). This exercises the real
    # GRPOTrainer implementation: __new__ skips __init__, so no model,
    # tokenizer or GPU is needed, and only the config attribute is set.
    rewards = torch.tensor([1.0, 0.5, 0.0, 1.0, 0.5, 0.0, 0.8, 0.3])

    trainer = GRPOTrainer.__new__(GRPOTrainer)
    trainer.config = config
    advantages = trainer.compute_group_advantages(rewards)
    print(f"Advantages: {advantages}")
    print(f"Advantage mean: {advantages.mean():.4f}, std: {advantages.std():.4f}")

    # Actually verify the results instead of only printing them.
    assert advantages.shape == rewards.shape
    assert torch.isfinite(advantages).all()
    assert abs(advantages.mean().item()) < 1e-5

    print("\nAll GRPO tests passed!")
