import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


def create_expander_graph(num_nodes, degree, algorithm="Random-d", device="cpu"):
    """
    Build the edge list of an expander-style graph with the chosen generator.

    Args:
        num_nodes (int): Number of nodes in the graph
        degree (int): Degree parameter forwarded to the selected generator
        algorithm (str): Generator name ('Random-d', 'Random-d2', 'Hamiltonian')
        device (str): Device to place the tensor on

    Returns:
        torch.Tensor: Edge indices of shape [2, num_edges] for the expander graph

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names
    """
    # Reject unsupported names first so the selection below stays flat.
    if algorithm not in ("Random-d", "Random-d2", "Hamiltonian"):
        raise ValueError(f"Unknown expander algorithm: {algorithm}")

    if algorithm == "Random-d":
        builder = create_random_d_expander
    elif algorithm == "Random-d2":
        builder = create_random_d2_expander
    else:
        builder = create_hamiltonian_expander

    return builder(num_nodes, degree, device)


def create_random_d_expander(num_nodes, degree, device="cpu"):
    """
    Create a random directed graph in which every node has the same out-degree.

    Each node independently samples ``min(degree, num_nodes - 1)`` distinct
    neighbors (never itself) using NumPy's global random state, and one
    directed edge ``node -> neighbor`` is emitted per sample. Only the
    out-degree is fixed; in-degrees vary, so the result is not guaranteed to
    be a d-regular graph.

    Args:
        num_nodes (int): Number of nodes in the graph
        degree (int): Requested number of sampled neighbors per node
        device (str): Device to place the tensor on

    Returns:
        torch.Tensor: Edge indices of shape [2, num_edges] (dtype long); the
            shape is [2, 0] when no edge can be produced (fewer than two
            nodes, or a non-positive degree).
    """
    # A node can have at most num_nodes - 1 distinct non-self neighbors.
    actual_degree = min(degree, num_nodes - 1)
    if actual_degree <= 0:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    targets = np.empty((num_nodes, actual_degree), dtype=np.int64)
    for node in range(num_nodes):
        # Sample positions in the candidate list "all nodes except `node`"
        # without materializing that list: position p maps to node p when
        # p < node and to node p + 1 otherwise.
        picks = np.random.choice(num_nodes - 1, actual_degree, replace=False)
        picks[picks >= node] += 1
        targets[node] = picks

    # Row-major flattening keeps the edges grouped by source node.
    sources = np.repeat(np.arange(num_nodes, dtype=np.int64), actual_degree)
    edges = np.stack([sources, targets.reshape(-1)])
    return torch.tensor(edges, dtype=torch.long, device=device)


def create_random_d2_expander(num_nodes, degree, device="cpu"):
    """
    Create the 'Random-d2' graph: each node points to its next few successors.

    Node ``i`` gets one directed edge to ``(i + k) % num_nodes`` for every
    offset ``k`` in ``1..min(degree, num_nodes - 1)``. Despite the name, the
    construction is deterministic: no random edges are added.

    The offset range is clamped to ``num_nodes - 1`` so that wrapping around
    the ring never produces a self-loop or a repeated edge.

    Args:
        num_nodes (int): Number of nodes in the graph
        degree (int): Requested number of successors per node
        device (str): Device to place the tensor on

    Returns:
        torch.Tensor: Edge indices of shape [2, num_edges] (dtype long); the
            shape is [2, 0] when no edge can be produced.
    """
    # Offsets beyond num_nodes - 1 would only revisit earlier neighbors.
    max_offset = min(degree, num_nodes - 1)
    if max_offset <= 0:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    edge_list = [
        [node, (node + offset) % num_nodes]
        for node in range(num_nodes)
        for offset in range(1, max_offset + 1)
    ]
    return torch.tensor(edge_list, dtype=torch.long, device=device).t()


def create_hamiltonian_expander(num_nodes, degree, device="cpu"):
    """
    Create the 'Hamiltonian' graph: a ring with symmetric skip connections.

    Every node ``i`` is connected to ``(i + k) % num_nodes`` and
    ``(i - k) % num_nodes`` for ``k`` in ``1..degree // 2``; the ``k = 1``
    offsets form a ring through every node. An odd ``degree`` is rounded
    down to the next even number by the integer division.

    Offsets that wrap onto the node itself are skipped, and offsets that
    coincide (for example ``+k`` and ``-k`` when ``2 * k == num_nodes``) are
    emitted only once, so the result has no self-loops or repeated edges.

    Args:
        num_nodes (int): Number of nodes in the graph
        degree (int): Requested degree; ``degree // 2`` offsets are used in
            each direction
        device (str): Device to place the tensor on

    Returns:
        torch.Tensor: Edge indices of shape [2, num_edges] (dtype long); the
            shape is [2, 0] when no edge can be produced.
    """
    if num_nodes <= 1:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    # The offsets do not depend on the node, so compute them once.
    # dict.fromkeys drops repeats while keeping the forward/backward order.
    offsets = list(
        dict.fromkeys(
            offset
            for k in range(1, degree // 2 + 1)
            for offset in (k % num_nodes, -k % num_nodes)
            if offset != 0
        )
    )
    if not offsets:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    edge_list = [
        [node, (node + offset) % num_nodes]
        for node in range(num_nodes)
        for offset in offsets
    ]
    return torch.tensor(edge_list, dtype=torch.long, device=device).t()


class ExphormerSparseAttention(nn.Module):
    """
    Exphormer sparse attention mechanism using original graph edges + expander graph edges.

    Attention is computed only along edges: for an edge ``(src, tgt)`` the
    target node attends to the source node, and the weights are normalized
    over all edges that share the same target.

    Note: the expander graph is rebuilt inside every ``forward`` call. With a
    randomized algorithm such as 'Random-d' the attention pattern therefore
    differs between calls, including in eval mode.

    Args:
        d_model (int): Model dimension
        nhead (int): Number of attention heads
        exp_degree (int): Degree of expander graph
        exp_algorithm (str): Expander graph algorithm ('Random-d', 'Random-d2', 'Hamiltonian')
        dropout (float): Dropout probability
    """

    def __init__(
        self, d_model, nhead, exp_degree=5, exp_algorithm="Random-d", dropout=0.1
    ):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"

        self.d_model = d_model
        self.nhead = nhead
        self.exp_degree = exp_degree
        self.exp_algorithm = exp_algorithm
        self.head_dim = d_model // nhead

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        # Standard scaled dot-product factor 1 / sqrt(head_dim).
        self.scale = self.head_dim**-0.5

    def forward(self, x, edge_index):
        """
        Forward pass of sparse attention.

        Args:
            x (torch.Tensor): Node features [num_nodes, d_model]
            edge_index (torch.Tensor): Original graph edges [2, num_edges]

        Returns:
            torch.Tensor: Output features [num_nodes, d_model]
        """
        num_nodes, d_model = x.shape

        # A fresh expander graph is generated on every call.
        exp_edge_index = create_expander_graph(
            num_nodes, self.exp_degree, self.exp_algorithm, x.device
        )

        # Concatenate only the non-empty edge sets; if both are empty the
        # (empty) expander tensor is used as-is.
        edge_sets = [e for e in (edge_index, exp_edge_index) if e.size(1) > 0]
        if edge_sets:
            combined_edge_index = torch.cat(edge_sets, dim=1)
        else:
            combined_edge_index = exp_edge_index

        # Drop duplicate edges. Self-loops present in ``edge_index`` are kept:
        # torch.unique only removes repeated columns.
        combined_edge_index = torch.unique(combined_edge_index, dim=1)

        # Project to per-head Q, K, V: [num_nodes, nhead, head_dim]
        Q = self.q_proj(x).view(num_nodes, self.nhead, self.head_dim)
        K = self.k_proj(x).view(num_nodes, self.nhead, self.head_dim)
        V = self.v_proj(x).view(num_nodes, self.nhead, self.head_dim)

        out = self._sparse_attention(Q, K, V, combined_edge_index)

        # Merge the heads back together and apply the output projection.
        out = out.reshape(num_nodes, d_model)
        return self.out_proj(out)

    def _sparse_attention(self, Q, K, V, edge_index):
        """
        Compute sparse attention over the given edges.

        Args:
            Q, K, V (torch.Tensor): Per-head projections [num_nodes, nhead, head_dim]
            edge_index (torch.Tensor): Edges [2, num_edges]; row 0 holds the
                source nodes and row 1 the target nodes

        Returns:
            torch.Tensor: Aggregated values [num_nodes, nhead, head_dim].
                Nodes without incoming edges receive zeros.
        """
        num_nodes = Q.size(0)

        if edge_index.size(1) == 0:
            return torch.zeros_like(Q)

        src, tgt = edge_index[0], edge_index[1]

        # Scaled dot product between each target's query and its source's key.
        scores = (Q[tgt] * K[src]).sum(dim=-1) * self.scale  # [num_edges, nhead]

        # Normalize over the incoming edges of every target node, per head.
        weights = self._segment_softmax(scores, tgt, num_nodes)
        weights = self.dropout(weights)

        # Weighted sum of source values, accumulated into the target rows.
        out = torch.zeros_like(Q)
        out.index_add_(0, tgt, V[src] * weights.unsqueeze(-1))
        return out

    @staticmethod
    def _segment_softmax(scores, segment_ids, num_segments):
        """Softmax of ``scores`` [num_edges, nhead] within groups of equal ``segment_ids``."""
        nhead = scores.size(1)
        index = segment_ids.unsqueeze(-1).expand_as(scores)

        # Per-segment maximum for numerical stability. It is a constant shift,
        # so it is taken from detached scores and carries no gradient.
        seg_max = scores.new_full((num_segments, nhead), float("-inf"))
        seg_max = seg_max.scatter_reduce(0, index, scores.detach(), reduce="amax")

        exp_scores = torch.exp(scores - seg_max[segment_ids])

        # Every segment that owns an edge sums to at least exp(0) = 1, so the
        # division below never hits zero.
        seg_sum = scores.new_zeros((num_segments, nhead))
        seg_sum = seg_sum.index_add_(0, segment_ids, exp_scores)
        return exp_scores / seg_sum[segment_ids]


class ExphormerLayer(nn.Module):
    """One Exphormer block: sparse attention followed by a two-layer MLP.

    Both sub-blocks use a residual connection and apply LayerNorm after the
    residual sum.
    """

    def __init__(
        self, d_model, nhead, dim_feedforward, exp_degree, exp_algorithm, dropout
    ):
        super().__init__()

        # Attention sub-block
        self.self_attn = ExphormerSparseAttention(
            d_model, nhead, exp_degree, exp_algorithm, dropout
        )

        # Position-wise feedforward sub-block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm per sub-block
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # A single dropout module is shared by every dropout site in forward.
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x, edge_index):
        """Run attention and feedforward over ``x`` using ``edge_index``."""
        # Attention -> dropout -> residual -> norm
        attended = self.self_attn(x, edge_index)
        x = self.norm1(x + self.dropout(attended))

        # Linear -> ReLU -> dropout -> linear, then dropout -> residual -> norm
        hidden = self.activation(self.linear1(x))
        projected = self.linear2(self.dropout(hidden))
        return self.norm2(x + self.dropout(projected))


class Exphormer(nn.Module):
    """
    Exphormer: Sparse Graph Transformer using expander graphs.

    Node features are projected to ``d_model`` (when needed), passed through
    a stack of Exphormer layers that attend along:
    1. Original graph edges
    2. Expander graph edges
    and finally layer-normalized.

    Args:
        in_features (int): Input feature dimension
        d_model (int): Model dimension
        nhead (int): Number of attention heads
        dim_feedforward (int): Feedforward dimension
        num_layers (int): Number of transformer layers
        exp_degree (int): Degree of expander graph
        exp_algorithm (str): Expander algorithm ('Random-d', 'Random-d2', 'Hamiltonian')
        dropout (float): Dropout probability
    """

    def __init__(
        self,
        in_features,
        d_model,
        nhead,
        dim_feedforward,
        num_layers,
        exp_degree=5,
        exp_algorithm="Random-d",
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding_dim = d_model

        # Only learn an input projection when the widths actually differ.
        if in_features == d_model:
            self.input_proj = nn.Identity()
        else:
            self.input_proj = nn.Linear(in_features, d_model)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                ExphormerLayer(
                    d_model, nhead, dim_feedforward, exp_degree, exp_algorithm, dropout
                )
            )

        self.norm = nn.LayerNorm(d_model)

    def forward(self, data):
        """
        Compute node embeddings for a graph.

        Args:
            data (Data): PyTorch Geometric Data object; its ``x`` and
                ``edge_index`` attributes are read

        Returns:
            torch.Tensor: Node embeddings [num_nodes, d_model]
        """
        features, edge_index = data.x, data.edge_index

        # Bring the input features to the model width.
        hidden = self.input_proj(features)

        # Every layer sees the same original edges.
        for block in self.layers:
            hidden = block(hidden, edge_index)

        return self.norm(hidden)
