"""
Megatron Bridge: Utilities for integrating Megatron-LM optimizations
and resolving CUBLAS issues in distributed/GPU training.

This module provides:
1. Proper CUBLAS initialization and handle management
2. Fused kernel operations from Megatron (when available)
3. Graceful fallbacks to PyTorch native ops
4. Memory-efficient attention implementations
"""

import os
import logging
import warnings
from typing import Optional, Tuple, Callable, Any
from functools import wraps
from contextlib import contextmanager

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)

# Track initialization state.
# All flags start False and are only changed by the functions below; nothing
# in this module probes for libraries at import time. Until the _check_*
# helpers have run (setup_megatron_bridge() calls all three), the
# get_fused_* helpers and MegatronAttentionBridge read these as False and
# therefore use the PyTorch fallbacks.
_CUBLAS_INITIALIZED = False  # set to True by initialize_cublas()
_MEGATRON_AVAILABLE = False  # written by _check_megatron()
_APEX_AVAILABLE = False  # written by _check_apex()
_FLASH_ATTN_AVAILABLE = False  # written by _check_flash_attn()


def _check_megatron():
    """Probe for Megatron-LM and record the outcome in ``_MEGATRON_AVAILABLE``.

    The modern ``megatron.core`` layout is tried first. If that cannot be
    imported, a bare ``import megatron`` is accepted as a legacy install.

    Returns:
        The updated value of ``_MEGATRON_AVAILABLE``.
    """
    global _MEGATRON_AVAILABLE

    found = False
    banner = None
    try:
        from megatron.core import tensor_parallel  # noqa: F401
        from megatron.core.fusions import fused_layer_norm  # noqa: F401
    except ImportError:
        try:
            # Try older megatron import path
            import megatron  # noqa: F401
        except ImportError:
            pass
        else:
            found, banner = True, "Megatron-LM (legacy) detected"
    else:
        found, banner = True, "Megatron-LM core detected"

    _MEGATRON_AVAILABLE = found
    if banner is not None:
        logger.info(banner)
    return _MEGATRON_AVAILABLE


def _check_apex():
    """Probe for NVIDIA Apex fused kernels and cache the result.

    Apex only counts as usable when both ``apex.normalization`` (for
    FusedLayerNorm / FusedRMSNorm) and ``apex.transformer.functional``
    (for fused_softmax) import cleanly.

    Returns:
        The new value of ``_APEX_AVAILABLE``.
    """
    global _APEX_AVAILABLE

    try:
        from apex.normalization import FusedLayerNorm, FusedRMSNorm  # noqa: F401
        from apex.transformer.functional import fused_softmax  # noqa: F401
    except ImportError:
        _APEX_AVAILABLE = False
    else:
        _APEX_AVAILABLE = True
        logger.info("NVIDIA Apex detected")
    return _APEX_AVAILABLE


def _check_flash_attn():
    """Probe for the ``flash_attn`` package and cache the result.

    Returns:
        The new value of ``_FLASH_ATTN_AVAILABLE``: True when
        ``flash_attn.flash_attn_func`` can be imported.
    """
    global _FLASH_ATTN_AVAILABLE

    detected = True
    try:
        from flash_attn import flash_attn_func  # noqa: F401
    except ImportError:
        detected = False

    _FLASH_ATTN_AVAILABLE = detected
    if detected:
        logger.info("Flash Attention detected")
    return _FLASH_ATTN_AVAILABLE


def initialize_cublas(device: Optional[int] = None) -> bool:
    """
    Initialize CUBLAS handles to prevent CUBLAS_STATUS_NOT_INITIALIZED errors.

    This should be called early in training to warm up CUBLAS and avoid
    initialization issues during forward/backward passes. The function:

    - runs a small float32 matmul on the device;
    - runs a best-effort cuDNN convolution;
    - synchronizes the device and empties the CUDA cache.

    Initialization is tracked per device. A successful call for device 0
    does not short-circuit a later call for device 1, so
    ``initialize_all_devices()`` really warms up every GPU. The module-level
    ``_CUBLAS_INITIALIZED`` flag is set to True once any device has been
    initialized successfully.

    Args:
        device: CUDA device index (default: current device)

    Returns:
        True if the CUBLAS warmup matmul succeeded on ``device``, either in
        this call or an earlier one. False if any of these hold:

        - CUDA is unavailable;
        - the warmup matmul failed;
        - any other error was raised while initializing.

        A cuDNN warmup failure is only logged and does not affect the result.
    """
    global _CUBLAS_INITIALIZED

    if not torch.cuda.is_available():
        logger.warning("CUDA not available, skipping CUBLAS initialization")
        return False

    # Per-device bookkeeping lives on the function object so this block stays
    # self-contained; ``_CUBLAS_INITIALIZED`` keeps its "any device" meaning.
    initialized_devices = initialize_cublas.__dict__.setdefault(
        "_initialized_devices", set()
    )

    try:
        if device is None:
            device = torch.cuda.current_device()
        if device in initialized_devices:
            return True

        target = f'cuda:{device}'
        with torch.cuda.device(device):
            # Create tensors to force CUBLAS initialization
            # This triggers the CUBLAS handle creation
            logger.info(f"Initializing CUBLAS on device {device}...")

            # Warmup with small matmul operations - float32 only, which is
            # the most stable; every listed dtype must succeed.
            cublas_ok = True
            for dtype in [torch.float32]:
                try:
                    a = torch.randn(64, 64, dtype=dtype, device=target)
                    b = torch.randn(64, 64, dtype=dtype, device=target)
                    c = torch.matmul(a, b)
                    del a, b, c
                    logger.info(f"CUBLAS warmup successful for {dtype}")
                except Exception as e:
                    cublas_ok = False
                    logger.warning(f"CUBLAS warmup failed for {dtype}: {e}")

            # Warmup cuDNN with convolution (best effort: failures are logged only)
            try:
                x = torch.randn(1, 64, 8, 8, device=target)
                conv = nn.Conv2d(64, 64, 3, padding=1).to(target)
                y = conv(x)
                del x, y, conv
            except Exception as e:
                logger.warning(f"cuDNN warmup failed: {e}")

            # Sync and clear cache
            torch.cuda.synchronize(device)
            torch.cuda.empty_cache()

        if not cublas_ok:
            # A failed matmul is exactly the condition this function exists
            # to detect, so it must not be reported as success.
            logger.error(f"CUBLAS initialization failed on device {device}")
            return False

        initialized_devices.add(device)
        _CUBLAS_INITIALIZED = True
        logger.info(f"CUBLAS initialized successfully on device {device}")
        return True

    except Exception as e:
        logger.error(f"CUBLAS initialization failed: {e}")
        return False


def initialize_all_devices() -> bool:
    """Warm up CUBLAS on every visible CUDA device.

    Every device is attempted even if an earlier one fails.

    Returns:
        False when CUDA is unavailable or any device failed to initialize,
        otherwise True.
    """
    if torch.cuda.is_available():
        # Build the full list first so no device is skipped by short-circuiting.
        outcomes = [initialize_cublas(idx) for idx in range(torch.cuda.device_count())]
        return all(outcomes)
    return False


def _cublas_workspace_config_value(size_mb: int) -> str:
    """Translate a total workspace size in MiB into a CUBLAS_WORKSPACE_CONFIG value.

    The variable has the form ``:SIZE:COUNT`` with SIZE in KiB per buffer.
    The requested total is split across 8 buffers, so the default of 32 MiB
    yields ``:4096:8``.

    Raises:
        ValueError: If ``size_mb`` is not positive.
    """
    if size_mb <= 0:
        raise ValueError(f"size_mb must be positive, got {size_mb}")
    buffers = 8
    size_kib = max(1, int(size_mb * 1024) // buffers)
    return f":{size_kib}:{buffers}"


@contextmanager
def cublas_workspace_config(size_mb: int = 32):
    """
    Context manager to configure CUBLAS workspace size.

    Sets ``CUBLAS_WORKSPACE_CONFIG`` so that a total of ``size_mb`` MiB is
    split across 8 buffers (32 MiB -> ``:4096:8``). The previous value is
    restored (or the variable removed) on exit. Without CUDA the body runs
    with the environment untouched.

    Larger workspace can improve performance but uses more memory.
    Smaller workspace can help with CUBLAS_STATUS_ALLOC_FAILED errors.

    NOTE:
        PyTorch generally reads this variable when it first sizes its cuBLAS
        workspace. The ``with`` block should therefore enclose the first
        cuBLAS use in the process.

        ``torch.use_deterministic_algorithms(True)`` accepts only ``:4096:8``
        or ``:16:8``, so keep the default ``size_mb=32`` when determinism is
        required.

    Args:
        size_mb: Total workspace size in MiB (must be positive)

    Raises:
        ValueError: If ``size_mb`` is not positive (raised on entering the block).
    """
    new_val = _cublas_workspace_config_value(size_mb)

    if not torch.cuda.is_available():
        yield
        return

    # Set workspace limit via environment variable before CUBLAS operations
    old_val = os.environ.get('CUBLAS_WORKSPACE_CONFIG', None)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = new_val

    try:
        yield
    finally:
        if old_val is None:
            os.environ.pop('CUBLAS_WORKSPACE_CONFIG', None)
        else:
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = old_val


def get_fused_layer_norm() -> type:
    """
    Return the fastest LayerNorm class that can actually be imported.

    Candidates are tried in this order. Each is considered only if its
    availability flag (set by ``_check_apex`` / ``_check_megatron``) is True
    and its import succeeds:

    1. ``apex.normalization.FusedLayerNorm``
    2. ``megatron.core.fusions.fused_layer_norm.FusedLayerNorm``
    3. ``torch.nn.LayerNorm`` (always available)

    Returns:
        LayerNorm class to use
    """
    chosen = None

    if _APEX_AVAILABLE:
        try:
            from apex.normalization import FusedLayerNorm as chosen
        except ImportError:
            chosen = None
        else:
            logger.debug("Using Apex FusedLayerNorm")

    if chosen is None and _MEGATRON_AVAILABLE:
        try:
            from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as chosen
        except ImportError:
            chosen = None
        else:
            logger.debug("Using Megatron FusedLayerNorm")

    if chosen is None:
        logger.debug("Using PyTorch native LayerNorm")
        chosen = nn.LayerNorm

    return chosen


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (unfused PyTorch fallback).

    ``y = weight * x / sqrt(mean(x**2, dim=-1) + eps)``.

    The statistics are computed in float32 and cast back to the input dtype,
    because squaring fp16/bf16 activations overflows easily (1000**2 exceeds
    the fp16 max). The result is then multiplied by ``weight`` with normal
    type promotion.

    The class is defined once at module level, so ``get_fused_rms_norm()``
    returns the same picklable class on every call.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


def get_fused_rms_norm() -> type:
    """
    Get the best available fused RMS norm implementation.

    Priority (each candidate only if its availability flag is set and the
    import succeeds):

    1. Apex ``FusedRMSNorm``
    2. Megatron ``FusedRMSNorm``
    3. The module-level ``RMSNorm`` fallback (no fused kernel)

    Returns:
        RMSNorm class to use
    """
    if _APEX_AVAILABLE:
        try:
            from apex.normalization import FusedRMSNorm
            logger.debug("Using Apex FusedRMSNorm")
            return FusedRMSNorm
        except ImportError:
            pass

    if _MEGATRON_AVAILABLE:
        try:
            from megatron.core.fusions.fused_layer_norm import FusedRMSNorm
            logger.debug("Using Megatron FusedRMSNorm")
            return FusedRMSNorm
        except ImportError:
            pass

    logger.debug("Using custom RMSNorm (no fused kernel)")
    return RMSNorm


class MegatronAttentionBridge(nn.Module):
    """
    Bridge module that selects the best available attention implementation.

    Handles CUBLAS issues by:
    1. Trying Flash Attention first (uses custom CUDA kernels, avoids CUBLAS)
    2. Falling back to Megatron fused attention
    3. Final fallback to memory-efficient attention or standard SDPA
    4. Ultimate fallback to eager attention

    Tensors use the (batch, heads, seq, head_dim) layout. The SDPA and eager
    paths follow ``F.scaled_dot_product_attention`` semantics:

    * Key/value may have fewer heads than the query (grouped-/multi-query
      attention) as long as the query head count is a multiple of it. Query
      head ``h`` attends to key/value head ``h // (q_heads // kv_heads)``.
    * A boolean ``attention_mask`` marks positions that may be attended
      (True). A floating-point mask is added to the attention scores.
    * ``is_causal`` masks with top-left alignment, i.e. query ``i`` sees keys
      ``0..i``.

    Calls that Flash Attention cannot serve are routed through SDPA. These
    are calls with an explicit mask, non-CUDA tensors, or a dtype other than
    fp16/bf16.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_kv_heads: Optional[int] = None,
        attention_dropout: float = 0.0,
        use_flash_attn: bool = True,
        use_megatron: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_kv_heads = num_kv_heads or num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.attention_dropout = attention_dropout

        # Determine best implementation
        self.impl = self._select_implementation(use_flash_attn, use_megatron)
        logger.info(f"MegatronAttentionBridge using: {self.impl}")

    def _select_implementation(self, use_flash_attn: bool, use_megatron: bool) -> str:
        """Select the best available attention implementation.

        Returns one of "flash_attn", "megatron", "sdpa" or "eager".
        """
        if use_flash_attn and _FLASH_ATTN_AVAILABLE:
            return "flash_attn"

        if use_megatron and _MEGATRON_AVAILABLE:
            try:
                from megatron.core.transformer.attention import SelfAttention  # noqa: F401
                return "megatron"
            except ImportError:
                pass

        # Check if SDPA is available and working
        if hasattr(F, 'scaled_dot_product_attention'):
            # Test SDPA with a tiny input to see if it works on this machine
            probe_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            try:
                with torch.no_grad():
                    q = torch.randn(1, 1, 4, 8, device=probe_device)
                    k = torch.randn(1, 1, 4, 8, device=probe_device)
                    v = torch.randn(1, 1, 4, 8, device=probe_device)
                    _ = F.scaled_dot_product_attention(q, k, v)
                return "sdpa"
            except Exception as e:
                logger.warning(f"SDPA test failed: {e}")

        return "eager"

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Forward pass with automatic implementation selection.

        Args:
            query: (batch, heads, seq, head_dim)
            key: (batch, kv_heads, seq, head_dim)
            value: (batch, kv_heads, seq, head_dim)
            attention_mask: Optional attention mask (bool = keep, float = additive)
            is_causal: Whether to use causal masking

        Returns:
            Attention output (batch, heads, seq, head_dim)
        """
        if self.impl == "flash_attn":
            return self._flash_attention(query, key, value, attention_mask, is_causal)
        elif self.impl == "megatron":
            return self._megatron_attention(query, key, value, attention_mask, is_causal)
        elif self.impl == "sdpa":
            return self._sdpa_attention(query, key, value, attention_mask, is_causal)
        else:
            return self._eager_attention(query, key, value, attention_mask, is_causal)

    @staticmethod
    def _expand_kv_heads(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Repeat grouped key/value heads (dim 1) so they match the query head count.

        Raises:
            ValueError: If the query head count is not a multiple of the
                key/value head count.
        """
        q_heads, kv_heads = query.size(1), key.size(1)
        if kv_heads == q_heads:
            return key, value
        if kv_heads == 0 or q_heads % kv_heads != 0:
            raise ValueError(
                f"query has {q_heads} heads, which is not a multiple of the "
                f"{kv_heads} key/value heads"
            )
        n_rep = q_heads // kv_heads
        # repeat_interleave keeps each group contiguous: kv head j serves
        # query heads j*n_rep .. (j+1)*n_rep-1 (same grouping as flash-attn).
        return (
            key.repeat_interleave(n_rep, dim=1),
            value.repeat_interleave(n_rep, dim=1),
        )

    def _flash_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        is_causal: bool,
    ) -> torch.Tensor:
        """Flash Attention implementation, deferring to SDPA when it cannot apply."""
        # flash_attn_func takes no arbitrary mask and only supports fp16/bf16
        # CUDA tensors: dropping the mask would give silently wrong results,
        # and other dtypes/devices would raise.
        if (
            attention_mask is not None
            or not query.is_cuda
            or query.dtype not in (torch.float16, torch.bfloat16)
        ):
            return self._sdpa_attention(query, key, value, attention_mask, is_causal)

        from flash_attn import flash_attn_func

        # Flash attention expects (batch, seq, heads, head_dim)
        q = query.transpose(1, 2)
        k = key.transpose(1, 2)
        v = value.transpose(1, 2)

        out = flash_attn_func(
            q, k, v,
            dropout_p=self.attention_dropout if self.training else 0.0,
            causal=is_causal,
        )

        # Back to (batch, heads, seq, head_dim)
        return out.transpose(1, 2)

    def _megatron_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        is_causal: bool,
    ) -> torch.Tensor:
        """Megatron attention implementation (falls back to SDPA on failure)."""
        try:
            from megatron.core.transformer.dot_product_attention import DotProductAttention
            # Use Megatron's optimized attention
            # Note: This is a simplified bridge - full integration would require
            # more Megatron infrastructure
            attn = DotProductAttention(
                config=None,  # Would need MegatronConfig
                layer_number=1,
                attn_mask_type='causal' if is_causal else 'padding',
            )
            return attn(query, key, value, attention_mask)
        except Exception as e:
            logger.warning(f"Megatron attention failed: {e}, falling back to SDPA")
            # Pin the fallback so the failing setup and the warning are not
            # repeated on every forward pass.
            self.impl = "sdpa"
            return self._sdpa_attention(query, key, value, attention_mask, is_causal)

    def _sdpa_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        is_causal: bool,
    ) -> torch.Tensor:
        """PyTorch SDPA implementation with CUBLAS error handling."""
        # Expand explicitly so GQA works on every PyTorch version (SDPA only
        # broadcasts equal head counts unless enable_gqa is available).
        key, value = self._expand_kv_heads(query, key, value)
        try:
            return F.scaled_dot_product_attention(
                query, key, value,
                attn_mask=attention_mask,
                dropout_p=self.attention_dropout if self.training else 0.0,
                is_causal=is_causal,
            )
        except RuntimeError as e:
            if "CUBLAS" in str(e) or "cublas" in str(e):
                logger.warning(f"CUBLAS error in SDPA: {e}, falling back to eager")
                return self._eager_attention(query, key, value, attention_mask, is_causal)
            raise

    def _eager_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        is_causal: bool,
    ) -> torch.Tensor:
        """Eager (manual) attention implementation - most compatible.

        Mirrors SDPA semantics so results match whichever path was used.
        """
        key, value = self._expand_kv_heads(query, key, value)
        head_dim = query.size(-1)
        q_len, kv_len = query.size(-2), key.size(-2)

        # Compute attention scores: (batch, heads, q_len, kv_len)
        scale = head_dim ** -0.5
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale

        # Apply causal mask (top-left aligned, sized to the score matrix so
        # q_len != kv_len works)
        if is_causal:
            causal_mask = torch.triu(
                torch.ones(q_len, kv_len, dtype=torch.bool, device=query.device),
                diagonal=1
            )
            attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))

        # Apply attention mask: boolean = positions allowed to attend,
        # floating point = additive bias (same convention as SDPA)
        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                attn_weights = attn_weights.masked_fill(~attention_mask, float('-inf'))
            else:
                attn_weights = attn_weights + attention_mask

        # Softmax and dropout
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        if self.training and self.attention_dropout > 0:
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout)

        # Compute output
        return torch.matmul(attn_weights, value)


class CUBLASSafeLinear(nn.Linear):
    """
    Linear layer with CUBLAS error handling and automatic fallback.

    Wraps nn.Linear to catch CUBLAS/CUDA runtime errors. On such an error it
    retries after synchronizing and emptying the CUDA cache. If every GPU
    attempt fails, it computes the layer on CPU and moves the result back to
    the input's device.

    Errors that retrying cannot fix are re-raised unchanged: shape errors,
    and device-placement mistakes such as "Expected all tensors to be on the
    same device" (which also mentions CUDA).

    Args:
        *args, **kwargs: Forwarded to ``nn.Linear``.
        max_retries: GPU retries after the first failure before the CPU
            fallback; negative values are treated as 0.
    """

    def __init__(self, *args, max_retries: int = 2, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_retries = max_retries
        # Number of recoverable CUBLAS/CUDA errors seen by forward().
        self._cublas_failures = 0

    @staticmethod
    def _is_recoverable_error(error: RuntimeError) -> bool:
        """Return True for CUBLAS/CUDA runtime errors that are worth retrying."""
        message = str(error).lower()
        # Device-mismatch messages contain "cuda:0" but are programming errors;
        # the CPU fallback would silently hide them.
        if 'same device' in message:
            return False
        return 'cublas' in message or 'cuda' in message

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Clamp so at least one attempt is made (otherwise forward returns None).
        retries = max(0, self.max_retries)
        for attempt in range(retries + 1):
            try:
                return super().forward(input)
            except RuntimeError as e:
                if not self._is_recoverable_error(e):
                    raise
                self._cublas_failures += 1
                if attempt < retries:
                    logger.warning(
                        f"CUBLAS error in Linear (attempt {attempt + 1}): {e}"
                    )
                    # Try to recover (only meaningful with a CUDA runtime)
                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                        torch.cuda.empty_cache()
                    continue
                # Last resort: CPU fallback
                logger.warning("CUBLAS failed, computing on CPU")
                device = input.device
                result = F.linear(
                    input.cpu(),
                    self.weight.cpu(),
                    self.bias.cpu() if self.bias is not None else None,
                )
                return result.to(device)


def with_cublas_retry(func: Callable) -> Callable:
    """
    Decorator to add CUBLAS error handling and retry logic to functions.

    The wrapped function accepts an extra keyword argument,
    ``_cublas_max_retries`` (default 2; negative values are treated as 0).
    The wrapper consumes it and does not forward it to ``func``.

    ``func`` is called at most ``_cublas_max_retries + 1`` times. Only a
    ``RuntimeError`` whose message contains "cublas" (case-insensitive)
    triggers a retry. Between attempts the wrapper:

    - synchronizes and empties the CUDA cache, when CUDA is available;
    - sleeps 0.1 s.

    Other errors, and the error from the final attempt, are re-raised
    unchanged.

    Usage:
        @with_cublas_retry
        def my_cuda_function(x):
            return torch.matmul(x, x.T)
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        import time

        max_retries = max(0, kwargs.pop('_cublas_max_retries', 2))
        total_attempts = max_retries + 1

        for attempt in range(total_attempts):
            try:
                return func(*args, **kwargs)
            except RuntimeError as e:
                is_last = attempt + 1 >= total_attempts
                if 'cublas' not in str(e).lower() or is_last:
                    raise
                logger.warning(
                    f"CUBLAS error (attempt {attempt + 1}/{total_attempts}): {e}"
                )
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                    torch.cuda.empty_cache()
                # Brief pause to let GPU recover
                time.sleep(0.1)

    return wrapper


class MegatronBridgeConfig:
    """Plain settings holder for the Megatron bridge.

    ``configure_cuda_backends`` reads ``enable_cublas_tf32``,
    ``enable_cudnn_benchmark`` and ``enable_cudnn_deterministic``.

    The kernel-selection toggles and ``cublas_workspace_mb`` are stored for
    callers. None of the functions in this module read them.
    """

    def __init__(
        self,
        use_fused_layer_norm: bool = True,
        use_fused_attention: bool = True,
        use_flash_attention: bool = True,
        cublas_workspace_mb: int = 32,
        enable_cublas_tf32: bool = True,
        enable_cudnn_benchmark: bool = True,
        enable_cudnn_deterministic: bool = False,
    ):
        # Kernel selection toggles
        (
            self.use_fused_layer_norm,
            self.use_fused_attention,
            self.use_flash_attention,
        ) = (use_fused_layer_norm, use_fused_attention, use_flash_attention)

        # CUBLAS workspace size, in MB
        self.cublas_workspace_mb = cublas_workspace_mb

        # CUDA backend switches applied by configure_cuda_backends()
        (
            self.enable_cublas_tf32,
            self.enable_cudnn_benchmark,
            self.enable_cudnn_deterministic,
        ) = (enable_cublas_tf32, enable_cudnn_benchmark, enable_cudnn_deterministic)


def configure_cuda_backends(config: Optional[MegatronBridgeConfig] = None):
    """
    Configure CUDA backends for optimal performance and stability.

    Args:
        config: Configuration object (uses defaults if None)
    """
    if not torch.cuda.is_available():
        logger.warning("CUDA not available, skipping backend configuration")
        return

    if config is None:
        config = MegatronBridgeConfig()

    # TF32 for faster matrix multiplication on Ampere+ GPUs
    if config.enable_cublas_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        logger.info("Enabled TF32 for CUBLAS and cuDNN")

    # cuDNN settings
    torch.backends.cudnn.benchmark = config.enable_cudnn_benchmark
    torch.backends.cudnn.deterministic = config.enable_cudnn_deterministic
    logger.info(
        f"cuDNN: benchmark={config.enable_cudnn_benchmark}, "
        f"deterministic={config.enable_cudnn_deterministic}"
    )

    # Set memory fraction to avoid OOM issues
    if hasattr(torch.cuda, 'set_per_process_memory_fraction'):
        # Leave some memory for CUBLAS workspace
        try:
            torch.cuda.set_per_process_memory_fraction(0.95)
        except Exception as e:
            logger.warning(f"Could not set memory fraction: {e}")


def setup_megatron_bridge(
    config: Optional[MegatronBridgeConfig] = None,
    device: Optional[int] = None,
) -> dict:
    """
    Run the full Megatron-bridge setup and report what is available.

    Steps, in order:
    1. Probe for Megatron-LM, NVIDIA Apex and Flash Attention, updating the
       module-level availability flags.
    2. Apply CUDA backend settings from ``config`` (configure_cuda_backends).
    3. Warm up CUBLAS on ``device``, or on every visible device when
       ``device`` is None.

    Args:
        config: Configuration object (defaults are used when None)
        device: CUDA device index to initialize; None means all devices

    Returns:
        Dictionary with:

        - the three availability flags;
        - the CUBLAS initialization flag;
        - the class names chosen for layer norm and RMS norm.
    """
    logger.info("Setting up Megatron bridge...")

    # Probe optional libraries (order: Megatron, Apex, Flash Attention)
    for probe in (_check_megatron, _check_apex, _check_flash_attn):
        probe()

    configure_cuda_backends(config)

    if device is None:
        initialize_all_devices()
    else:
        initialize_cublas(device)

    features = dict(
        megatron_available=_MEGATRON_AVAILABLE,
        apex_available=_APEX_AVAILABLE,
        flash_attn_available=_FLASH_ATTN_AVAILABLE,
        cublas_initialized=_CUBLAS_INITIALIZED,
        fused_layer_norm=get_fused_layer_norm().__name__,
        fused_rms_norm=get_fused_rms_norm().__name__,
    )

    logger.info(f"Megatron bridge setup complete: {features}")
    return features


# Convenience function for common use case
def safe_matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute ``torch.matmul(a, b)`` and retry once after a CUBLAS error.

    On a ``RuntimeError`` whose message mentions "cublas" (case-insensitive),
    the function synchronizes CUDA, empties the cache and tries once more.
    Any other error is re-raised.

    Args:
        a: First tensor
        b: Second tensor
        out: Optional output tensor, passed through as ``out=``

    Returns:
        Result of matmul(a, b)
    """
    def _run() -> torch.Tensor:
        if out is None:
            return torch.matmul(a, b)
        return torch.matmul(a, b, out=out)

    try:
        return _run()
    except RuntimeError as e:
        if 'cublas' not in str(e).lower():
            raise
        logger.warning(f"CUBLAS error in matmul: {e}, retrying...")
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        return _run()


if __name__ == "__main__":
    # Manual smoke test: run this module directly to exercise the bridge.
    print("Testing Megatron Bridge...")

    features = setup_megatron_bridge()
    print(f"\nAvailable features: {features}")

    if not torch.cuda.is_available():
        print("\nCUDA not available, skipping GPU tests")
    else:
        print("\nTesting safe_matmul...")
        lhs = torch.randn(100, 100, device='cuda')
        rhs = torch.randn(100, 100, device='cuda')
        product = safe_matmul(lhs, rhs)
        print(f"safe_matmul result shape: {product.shape}")

        print("\nTesting MegatronAttentionBridge...")
        attn = MegatronAttentionBridge(hidden_size=512, num_attention_heads=8)
        # Generated in q, k, v order
        q, k, v = (torch.randn(2, 8, 32, 64, device='cuda') for _ in range(3))
        out = attn(q, k, v, is_causal=True)
        print(f"Attention output shape: {out.shape}")

        print("\nMegatron Bridge tests passed!")
