"""
Universal GPU-based Adam optimizer for contextual bandit learning.

This module implements a memory-efficient Adam optimizer with configurable
norm constraints for learning tool embeddings in the contextual bandit setting.
"""

import gc
from enum import Enum
from typing import Optional

import torch


class NormConstraintType(Enum):
    """
    Norm constraint modes that can be applied to embedding vectors.

    Members:
        NONE: No norm constraints.
        UNIT_SPHERE: Normalize vectors with norm > 1 to the unit sphere;
            vectors with norm <= 1 are left as they are.
        UNIT_VECTOR: Normalize all vectors to unit length.
    """

    NONE = "none"
    UNIT_SPHERE = "unit_sphere"
    UNIT_VECTOR = "unit_vector"


class UniversalBanditOptimizer:
    """
    Universal GPU-based Adam optimizer with configurable norm constraints.

    Keeps one first-moment and one second-moment estimate per parameter
    element (shape ``[num_arms, embedding_dim]``) and, after every Adam
    update, rescales the rows of the parameter matrix according to the
    configured ``NormConstraintType``.

    The moment buffers are stored in ``state_dtype`` (half precision by
    default, to save memory). float16 cannot represent magnitudes above
    65504 or below roughly 6e-8, so squared gradients outside that range
    overflow or underflow inside the second-moment buffer; pass
    ``state_dtype=torch.float32`` when gradients may be very large or very
    small.
    """

    def __init__(
        self,
        num_arms: int,
        embedding_dim: int,
        device: torch.device,
        learning_rate: float = 1e-6,
        beta1: float = 0.9,
        beta2: float = 0.999,
        epsilon: float = 1e-8,
        norm_constraint: NormConstraintType = NormConstraintType.UNIT_SPHERE,
        state_dtype: torch.dtype = torch.float16,
    ):
        """
        Initialize the universal Adam optimizer.

        Args:
            num_arms: Number of arms (tools) in the bandit
            embedding_dim: Dimension of the embedding vectors
            device: PyTorch device (CPU or CUDA); a device string such as
                ``"cpu"`` or ``"cuda"`` is accepted as well
            learning_rate: Base learning rate
            beta1: Exponential decay rate for first moment estimates
            beta2: Exponential decay rate for second moment estimates
            epsilon: Small constant for numerical stability
            norm_constraint: Type of norm constraint to apply
            state_dtype: Floating-point dtype used to store the moment
                buffers (defaults to ``torch.float16``)

        Raises:
            ValueError: If ``state_dtype`` is not a floating-point dtype
        """
        if not state_dtype.is_floating_point:
            raise ValueError(
                f"state_dtype must be a floating-point dtype, got {state_dtype}"
            )

        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        # Normalize so that ``self.device.type`` also works when the caller
        # passed a plain device string.
        self.device = torch.device(device)
        self.norm_constraint = norm_constraint

        # Optimizer states (half precision by default for memory efficiency)
        state_shape = (num_arms, embedding_dim)
        self.first_moment = torch.zeros(
            state_shape, dtype=state_dtype, device=self.device
        )
        self.second_moment = torch.zeros(
            state_shape, dtype=state_dtype, device=self.device
        )
        self.step_count = 0

    def step(
        self,
        parameters: torch.Tensor,
        gradients: torch.Tensor,
        current_lr: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Perform one optimization step with configurable norm constraint.

        ``parameters`` is modified in place and also returned.

        Args:
            parameters: Current parameter tensor [num_arms, embedding_dim]
            gradients: Gradient tensor [num_arms, embedding_dim]
            current_lr: Current learning rate (if None, uses base learning rate)

        Returns:
            Updated parameter tensor (the same object as ``parameters``)
        """
        self.step_count += 1
        lr = self.learning_rate if current_lr is None else current_lr

        # Run the whole update outside autograd: if ``gradients`` happens to
        # require grad, the in-place state updates would otherwise attach
        # the moment buffers to a graph that grows with every step.
        with torch.no_grad():
            # Match the storage dtype of the moment buffers
            grads = gradients.to(self.first_moment.dtype)

            # Update biased first and second moment estimates
            self.first_moment.mul_(self.beta1).add_(
                grads, alpha=1 - self.beta1
            )
            self.second_moment.mul_(self.beta2).addcmul_(
                grads, grads, value=1 - self.beta2
            )

            # Bias correction terms compensate for the zero-initialized moments
            bias_correction1 = 1 - self.beta1**self.step_count
            bias_correction2 = 1 - self.beta2**self.step_count

            # Bias-corrected moments, converted to float32 for the update
            first_moment_corrected = (
                self.first_moment.float() / bias_correction1
            )
            second_moment_corrected = (
                self.second_moment.float() / bias_correction2
            )

            # Apply Adam update: p -= lr * m_hat / (sqrt(v_hat) + eps)
            parameters.data.addcdiv_(
                first_moment_corrected,
                torch.sqrt(second_moment_corrected) + self.epsilon,
                value=-lr,
            )

        # Apply norm constraint based on configuration
        self._apply_norm_constraint(parameters)

        return parameters

    def _apply_norm_constraint(self, parameters: torch.Tensor) -> None:
        """
        Apply the configured norm constraint to parameters in place.

        Each row (dim=1) of ``parameters`` is divided by its L2 norm after
        the norm has been clamped from below:

        * ``NONE``: nothing is done.
        * ``UNIT_SPHERE``: norms are clamped to a minimum of 1, so only rows
          with norm > 1 are rescaled (to norm 1).
        * ``UNIT_VECTOR``: norms are clamped to a minimum of 1e-8 (avoids
          division by zero), so every row is rescaled to unit length.

        Raises:
            ValueError: If the configured constraint is not a known type
        """
        if self.norm_constraint == NormConstraintType.NONE:
            return

        if self.norm_constraint == NormConstraintType.UNIT_SPHERE:
            min_norm = 1.0
        elif self.norm_constraint == NormConstraintType.UNIT_VECTOR:
            min_norm = 1e-8
        else:
            raise ValueError(
                f"Unknown norm constraint type: {self.norm_constraint}"
            )

        with torch.no_grad():
            # Row-wise L2 norm, kept as a column so it broadcasts over rows
            norms = torch.norm(parameters.data, dim=1, keepdim=True)
            parameters.data.div_(norms.clamp_(min=min_norm))

    def get_memory_usage(self) -> float:
        """
        Get memory usage of optimizer states in MB.

        Returns the combined size of the two moment buffers when the
        optimizer device is CUDA, and 0.0 for any other device type.
        """
        if self.device.type != "cuda":
            return 0.0
        state_bytes = sum(
            state.element_size() * state.numel()
            for state in (self.first_moment, self.second_moment)
        )
        return state_bytes / (1024**2)

    def reset(self):
        """Reset optimizer state (useful for multiple runs)."""
        self.first_moment.zero_()
        self.second_moment.zero_()
        self.step_count = 0

    def set_norm_constraint(self, norm_constraint: NormConstraintType):
        """Change the norm constraint type during training."""
        self.norm_constraint = norm_constraint

    def get_norm_constraint(self) -> NormConstraintType:
        """Get the current norm constraint type."""
        return self.norm_constraint


def get_gpu_memory_usage() -> float:
    """
    Report the GPU memory currently allocated through PyTorch, in MB.

    Returns:
        ``torch.cuda.memory_allocated()`` converted from bytes to MB, or
        0.0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1024**2)


def cleanup_gpu_memory():
    """
    Run Python garbage collection, then clear the CUDA cache if available.

    The garbage collection pass always runs; ``torch.cuda.empty_cache()``
    is only called when CUDA is available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


# Convenience functions for creating optimizers with common configurations
def create_unit_sphere_optimizer(
    num_arms: int,
    embedding_dim: int,
    device: torch.device,
    learning_rate: float = 1e-3,
    **kwargs,
) -> UniversalBanditOptimizer:
    """
    Build an optimizer that uses the unit sphere constraint (norm <= 1).

    Args:
        num_arms: Number of arms (tools) in the bandit
        embedding_dim: Dimension of the embedding vectors
        device: PyTorch device for the optimizer states
        learning_rate: Base learning rate
        **kwargs: Forwarded unchanged to ``UniversalBanditOptimizer``

    Returns:
        A ``UniversalBanditOptimizer`` configured with
        ``NormConstraintType.UNIT_SPHERE``
    """
    constraint = NormConstraintType.UNIT_SPHERE
    optimizer = UniversalBanditOptimizer(
        num_arms=num_arms,
        embedding_dim=embedding_dim,
        device=device,
        learning_rate=learning_rate,
        norm_constraint=constraint,
        **kwargs,
    )
    return optimizer


def create_unit_vector_optimizer(
    num_arms: int,
    embedding_dim: int,
    device: torch.device,
    learning_rate: float = 1e-3,
    **kwargs,
) -> UniversalBanditOptimizer:
    """
    Build an optimizer that uses the unit vector constraint (norm = 1).

    Args:
        num_arms: Number of arms (tools) in the bandit
        embedding_dim: Dimension of the embedding vectors
        device: PyTorch device for the optimizer states
        learning_rate: Base learning rate
        **kwargs: Forwarded unchanged to ``UniversalBanditOptimizer``

    Returns:
        A ``UniversalBanditOptimizer`` configured with
        ``NormConstraintType.UNIT_VECTOR``
    """
    constraint = NormConstraintType.UNIT_VECTOR
    optimizer = UniversalBanditOptimizer(
        num_arms=num_arms,
        embedding_dim=embedding_dim,
        device=device,
        learning_rate=learning_rate,
        norm_constraint=constraint,
        **kwargs,
    )
    return optimizer


def create_unconstrained_optimizer(
    num_arms: int,
    embedding_dim: int,
    device: torch.device,
    learning_rate: float = 1e-3,
    **kwargs,
) -> UniversalBanditOptimizer:
    """
    Build an optimizer that applies no norm constraints.

    Args:
        num_arms: Number of arms (tools) in the bandit
        embedding_dim: Dimension of the embedding vectors
        device: PyTorch device for the optimizer states
        learning_rate: Base learning rate
        **kwargs: Forwarded unchanged to ``UniversalBanditOptimizer``

    Returns:
        A ``UniversalBanditOptimizer`` configured with
        ``NormConstraintType.NONE``
    """
    constraint = NormConstraintType.NONE
    optimizer = UniversalBanditOptimizer(
        num_arms=num_arms,
        embedding_dim=embedding_dim,
        device=device,
        learning_rate=learning_rate,
        norm_constraint=constraint,
        **kwargs,
    )
    return optimizer
