from typing import Optional, Tuple
import os
import sys

import torch
from torch import Tensor, nn
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import add_self_loops, coalesce
from torch_sparse import SparseTensor

# Process-wide cache for the compiled extension's `dimmedian_idx` callable.
# Stays None until `_load_cuda_extension` has compiled and loaded the extension.
_CUDA_EXT = None

# Dynamic compilation and import of CUDA extension
def _prepend_env_path(var_name, entry):
    """Prepend ``entry`` to the ``os.pathsep``-separated variable ``var_name``.

    The variable is left untouched when ``entry`` is already one of its
    entries.  When the variable is unset or empty it becomes exactly
    ``entry``: no trailing separator is added, because an empty entry in
    ``PATH``/``LD_LIBRARY_PATH`` is interpreted as the current directory.
    """
    current = os.environ.get(var_name, '')
    if entry in current.split(os.pathsep):
        return
    os.environ[var_name] = f"{entry}{os.pathsep}{current}" if current else entry


def _load_cuda_extension():
    """Compile (on first use) and return the ``dimmedian_idx`` CUDA op.

    The extension is built from ``ops.cpp`` and ``ops_kernel.cu`` located
    next to this file, using :func:`torch.utils.cpp_extension.load`.  The
    resulting ``dimmedian_idx`` callable is cached in the module-level
    ``_CUDA_EXT`` so that later calls return immediately.

    Side effects
    ------------
    Mutates ``os.environ``: ``TORCH_CUDA_ARCH_LIST`` is always set; when the
    expected CUDA 12.5 toolkit directory exists, ``CUDA_HOME``,
    ``CUDA_ROOT``, ``NVCC``, ``PATH`` and ``LD_LIBRARY_PATH`` are pointed at
    it as well.  Progress messages are printed to stdout.

    Returns
    -------
    callable
        the ``dimmedian_idx`` function exported by the compiled extension

    Raises
    ------
    RuntimeError
        if anything goes wrong while preparing the environment, compiling
        or loading the extension; the original exception is attached as
        ``__cause__``
    """
    global _CUDA_EXT

    # Return cached extension if already loaded
    if _CUDA_EXT is not None:
        return _CUDA_EXT

    try:
        import subprocess

        # Auto-setup CUDA environment for A100 GPUs
        print("Setting up CUDA environment...")

        # Single source of truth for the toolkit location; every other path
        # is derived from it so they cannot drift apart.
        cuda_home = '/usr/local/cuda-12.5'
        nvcc_path = os.path.join(cuda_home, 'bin', 'nvcc')

        if os.path.isdir(cuda_home):
            os.environ['CUDA_HOME'] = cuda_home
            os.environ['CUDA_ROOT'] = cuda_home
            os.environ['NVCC'] = nvcc_path
            # Put the CUDA 12.5 tools/libraries first on the search paths
            _prepend_env_path('PATH', os.path.join(cuda_home, 'bin'))
            _prepend_env_path('LD_LIBRARY_PATH',
                              os.path.join(cuda_home, 'lib64'))
        else:
            # Do not clobber a working environment with a path that does
            # not exist on this machine.
            print(f"Warning: {cuda_home} not found; "
                  "keeping the existing CUDA environment")
            nvcc_path = 'nvcc'  # resolved through PATH for the version check

        # Set TORCH_CUDA_ARCH_LIST for A100 (Ampere) architecture
        os.environ['TORCH_CUDA_ARCH_LIST'] = '8.0;7.5;7.0'

        print(f"Using CUDA_HOME: {os.environ.get('CUDA_HOME')}")
        print(f"Using NVCC: {os.environ.get('NVCC', nvcc_path)}")
        print(f"Using TORCH_CUDA_ARCH_LIST: {os.environ['TORCH_CUDA_ARCH_LIST']}")

        # Verify nvcc version before compilation (best effort: a failure
        # here is only reported, the build is still attempted).
        try:
            result = subprocess.run([nvcc_path, '--version'],
                                    capture_output=True, text=True, check=True)
            print(f"NVCC version check: {result.stdout.split('release')[1].split(',')[0].strip()}")
        except (OSError, subprocess.SubprocessError, IndexError) as e:
            print(f"Warning: Could not verify nvcc version: {e}")

        # Imported only after the environment is prepared:
        # torch.utils.cpp_extension resolves CUDA_HOME when the module is
        # first imported, so setting the variable afterwards is too late.
        # NOTE: this has no effect if the module was already imported
        # elsewhere in the process.
        from torch.utils.cpp_extension import load

        # Get the directory where this file is located
        current_dir = os.path.dirname(os.path.abspath(__file__))

        # Define source files
        sources = [
            os.path.join(current_dir, 'ops.cpp'),
            os.path.join(current_dir, 'ops_kernel.cu'),
        ]

        # Optimized compilation for A100 with CUDA 12.5
        soft_median_cuda_ops = load(
            name='soft_median_cuda_ops',
            sources=sources,
            extra_cflags=['-O3'],
            extra_cuda_cflags=[
                '-O3',
                '--use_fast_math',
                '--expt-relaxed-constexpr',
                # A100 (Ampere) architecture - primary target
                '-gencode=arch=compute_80,code=sm_80',
                # Include Volta and Turing for compatibility
                '-gencode=arch=compute_70,code=sm_70',
                '-gencode=arch=compute_75,code=sm_75',
            ],
            verbose=True  # Enable verbose to see compilation details
        )

        print("CUDA extension compiled successfully!")
        _CUDA_EXT = soft_median_cuda_ops.dimmedian_idx
        return _CUDA_EXT

    except Exception as e:
        print(f"Failed to compile CUDA extension: {e}")
        # Chain the original exception so the real cause stays visible.
        raise RuntimeError(
            f"CUDA extension compilation failed: {e}. "
            "Please ensure you have a compatible CUDA installation, "
            "appropriate NVCC version, and PyTorch with CUDA support."
        ) from e

class SoftMedianConv(nn.Module):
    r"""The graph convolutional operator with soft
    median aggregation from the `"Robustness of Graph Neural Networks
    at Scale" <https://arxiv.org/abs/2110.14038>`_ paper (NeurIPS'21)

    Parameters
    ----------
    in_channels : int
        dimensions of input samples
    out_channels : int
        dimensions of output samples
    temperature : float, optional
        temperature parameter for softmax weighting, by default 1.0
    cached : bool, optional
        whether the layer will cache the preprocessed edges
        (self-loops added, weights filled in, coalesced) on first
        execution, and will use the cached version for further
        executions, by default False
    add_self_loops : bool, optional
        whether to add self-loops to the input graph, by default True
    normalize : bool, optional
        stored on the layer but currently NOT applied: no GCN-style
        symmetric normalization is performed because the soft median
        aggregation carries its own normalization, by default False
    row_normalize : bool, optional
        stored on the layer but currently NOT applied, for the same
        reason as :obj:`normalize`, by default False
    bias : bool, optional
        whether to use bias in the layers, by default True

    Raises
    ------
    RuntimeError
        raised from :meth:`forward` (not at construction time) if the
        local CUDA extension cannot be compiled or loaded

    Note
    ----
    The input edges must be sorted for :meth:`dimmedian_idx`
    from the local CUDA extension; :meth:`forward` coalesces the
    edges before aggregating.

    See also
    --------
    :class:`greatx.nn.models.supervised.SoftMedianGCN`
    """
    # Preprocessed ``(edge_index, edge_weight)`` kept between calls when
    # ``cached=True``; ``None`` means nothing is cached.
    _cached_edges: Optional[Tuple[Tensor, Tensor]] = None

    def __init__(self, in_channels: int, out_channels: int,
                 temperature: float = 1.0, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = False,
                 row_normalize: bool = False, bias: bool = True):

        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.temperature = temperature
        self.cached = cached
        self.add_self_loops = add_self_loops
        # Kept for API compatibility; not consulted by `forward`.
        self.normalize = normalize
        self.row_normalize = row_normalize

        self.lin = Linear(in_channels, out_channels, bias=False,
                          weight_initializer='glorot')

        if bias:
            # Values are set by `reset_parameters` below.
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weights and drop any cached edges."""
        self.lin.reset_parameters()
        zeros(self.bias)
        # A re-initialized layer must not keep edges cached from a
        # previous graph.
        self.cache_clear()

    def cache_clear(self):
        """Clear cached inputs or intermediate results."""
        self._cached_edges = None
        return self

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_weight: OptTensor = None) -> Tensor:
        """Apply the linear transform followed by soft median aggregation.

        Parameters
        ----------
        x : Tensor
            node feature matrix
        edge_index : Tensor
            edge indices, or a :class:`torch_sparse.SparseTensor`
            (dense adjacency matrices are not supported)
        edge_weight : OptTensor, optional
            edge weights; all-ones weights are used when omitted,
            by default None

        Returns
        -------
        Tensor
            the aggregated node features, plus the bias if enabled

        Note
        ----
        When cached edges are present, the given :obj:`edge_index` and
        :obj:`edge_weight` are ignored in favour of the cached ones.
        """

        x = self.lin(x)

        if self._cached_edges is not None:
            edge_index, edge_weight = self._cached_edges
        else:
            # NOTE: we do not support Dense adjacency matrix here
            if isinstance(edge_index, SparseTensor):
                # A SparseTensor carries its own values; they replace
                # any `edge_weight` argument.
                row, col, edge_weight = edge_index.coo()
                edge_index = torch.stack([row, col], dim=0)

            if self.add_self_loops:
                edge_index, edge_weight = add_self_loops(
                    edge_index, edge_weight, num_nodes=x.size(0))

            # GCN-style normalization is intentionally skipped:
            # SoftMedian has its own normalization.

            if edge_weight is None:
                edge_weight = x.new_ones(edge_index.size(1))

            edge_index, edge_weight = coalesce(edge_index, edge_weight)

            # cache edges
            if self.cached:
                self._cached_edges = edge_index, edge_weight

        x = soft_median_reduce(x, edge_index, edge_weight, self.temperature)

        if self.bias is not None:
            x = x + self.bias

        return x

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, temperature={self.temperature})')


def soft_median_reduce(x: Tensor, edge_index: Tensor,
                       edge_weight: Tensor, temperature: float = 1.0) -> Tensor:
    """Weighted dimension-wise Soft Median aggregation.

    For every node the dimension-wise (weighted) median of its neighbours
    is looked up with the CUDA ``dimmedian_idx`` op; each edge is then
    re-weighted by a softmax over the negative distance between that
    median and the neighbour's features, and the neighbour features are
    summed with the re-normalized weights.

    Parameters
    ----------
    x : Tensor
        node features of shape ``[N, D]``
    edge_index : Tensor
        edge indices of shape ``[2, E]``, unpacked as ``(row, col)``;
        results are aggregated per ``row`` index
        (assumed to be sorted/coalesced as the CUDA op requires)
    edge_weight : Tensor
        edge weights of shape ``[E]``
    temperature : float, optional
        softmax temperature, must be positive, by default 1.0

    Returns
    -------
    Tensor
        aggregated features of shape ``[N, D]``

    Raises
    ------
    ValueError
        if ``edge_weight`` is ``None`` or ``temperature`` is not positive
    RuntimeError
        if the CUDA extension cannot be compiled or loaded
    """
    # Validate cheaply up front, before the (potentially very slow)
    # extension build is triggered.  Explicit raises survive `python -O`.
    if edge_weight is None:
        raise ValueError("edge_weight must not be None")
    if temperature <= 0:
        raise ValueError(
            f"temperature must be positive, got {temperature}")

    import math
    from torch_scatter import scatter, scatter_softmax

    # Lazy load the CUDA extension only when this function is called
    dimmedian_idx = _load_cuda_extension()

    row, col = edge_index
    N, D = x.size()

    # Step 1: Find dimension-wise median indices
    # (median_idx is used as a row index per feature column, so it is
    # expected to have shape [N, D])
    median_idx = dimmedian_idx(x, row, col, edge_weight, N)
    col_idx = torch.arange(D, device=row.device).view(1, -1).expand(N, D)
    x_median = x[median_idx, col_idx]  # "x bar" in the paper

    # Neighbour features per edge, gathered once and reused in steps 2 and 6
    x_col = x[col]  # Shape: [num_edges, D]

    # Step 2: Distance from each row's median to each neighbour
    dist = (x_median[row] - x_col).norm(dim=-1)  # [num_edges], "c" in the paper

    # Step 3: Soft weights, softmax over the edges sharing the same row
    logits = -dist / (temperature * math.sqrt(D))
    weights = scatter_softmax(logits, row, dim=0)

    # Step 4: Weighted adjacency values (s * a in the paper)
    weighted_values = weights * edge_weight

    # Step 5: Rescale so every row keeps its original total edge weight
    row_sum_orig = scatter(edge_weight, row, dim=0, dim_size=N, reduce='sum')
    row_sum_weighted = scatter(weighted_values, row, dim=0, dim_size=N,
                               reduce='sum')
    # 1e-8 guards against division by zero
    normalizers = row_sum_orig / (row_sum_weighted + 1e-8)
    final_weights = weighted_values * normalizers[row]

    # Step 6: Aggregate the weighted neighbour features per row
    return scatter(x_col * final_weights.unsqueeze(1), row, dim=0,
                   dim_size=N, reduce='sum')