import torch
import torch.nn.functional as F
from torch import nn

try:
    from torch_scatter import scatter_add, scatter_softmax
except ImportError:
    scatter_softmax = None
    scatter_add = None


# Module-level memo of generated expander edge indices. Entries are keyed by
# (num_nodes, degree, algorithm, str(device), layer_idx) and are only removed
# by clear_expander_cache(); nothing evicts them automatically.
_EXPANDER_CACHE = {}


def create_expander_graph_cached(num_nodes, degree, algorithm="Random-d", device="cpu", layer_idx=0):
    """Return the expander edge index for a configuration, building it at most once.

    Results are memoised in the module-level ``_EXPANDER_CACHE`` under the key
    ``(num_nodes, degree, algorithm, str(device), layer_idx)``, so repeated
    calls with the same arguments hand back the very same tensor object.

    Args:
        num_nodes: Number of nodes in the graph.
        degree: Degree parameter forwarded to the selected builder.
        algorithm: One of ``"Random-d"``, ``"Random-d2"`` or ``"Hamiltonian"``.
        device: Device the edge index is created on.
        layer_idx: Only used as part of the cache key, so that different
            layers get separately cached graphs.

    Returns:
        The edge index tensor produced by the selected builder.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names (only
            raised on a cache miss; nothing is cached in that case).
    """
    key = (num_nodes, degree, algorithm, str(device), layer_idx)

    try:
        return _EXPANDER_CACHE[key]
    except KeyError:
        pass

    # The table is only built on a cache miss, mirroring the lazy lookup of
    # the builder functions.
    builders = {
        "Random-d": create_random_d_expander_vectorized,
        "Random-d2": create_random_d2_expander_vectorized,
        "Hamiltonian": create_hamiltonian_expander_vectorized,
    }
    build = builders.get(algorithm)
    if build is None:
        raise ValueError(f"Unknown expander algorithm: {algorithm}")

    graph = build(num_nodes, degree, device)
    _EXPANDER_CACHE[key] = graph
    return graph


def create_random_d_expander_vectorized(num_nodes, degree, device="cpu"):
    """Build a directed graph where each node points at random distinct peers.

    Every node gets ``min(degree, num_nodes - 1)`` outgoing edges to distinct
    nodes other than itself. When that equals ``num_nodes - 1`` no sampling is
    done and each node is linked to all other nodes in ascending order.

    Args:
        num_nodes: Number of nodes in the graph.
        degree: Requested number of outgoing edges per node. Values that are
            zero or negative produce an empty graph.
        device: Device on which the tensors are created.

    Returns:
        A ``(2, num_nodes * out_degree)`` ``torch.long`` tensor whose first row
        holds source nodes (grouped by node, ascending) and whose second row
        holds the chosen neighbours. The result is ``(2, 0)`` when there are
        fewer than two nodes or the degree is not positive.
    """
    out_degree = min(degree, num_nodes - 1)
    # A non-positive degree must yield no edges; without this guard a negative
    # value would be used as a negative slice bound below and silently select
    # almost every candidate instead of none.
    if num_nodes <= 1 or out_degree <= 0:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    num_candidates = num_nodes - 1
    needs_sampling = num_candidates > out_degree
    all_slots = torch.arange(num_candidates, device=device)

    targets = []
    for node in range(num_nodes):
        if needs_sampling:
            slots = torch.randperm(num_candidates, device=device)[:out_degree]
        else:
            slots = all_slots
        # Slot j stands for node j when j < node and for node j + 1 otherwise,
        # i.e. the list of all nodes with `node` itself removed. Shifting the
        # slots avoids materialising that list for every node.
        targets.append(slots + (slots >= node))

    sources = torch.arange(num_nodes, device=device).repeat_interleave(out_degree)
    return torch.stack([sources, torch.cat(targets)])


def create_random_d2_expander_vectorized(num_nodes, degree, device="cpu"):
    """Build a directed circulant graph linking each node to its next ``degree`` successors.

    The construction involves no random sampling: for every offset ``i`` in
    ``1..degree`` each node ``v`` receives the edge ``v -> (v + i) % num_nodes``.
    Offsets that are multiples of ``num_nodes`` would only produce self-loops
    and contribute nothing. Offsets beyond ``num_nodes`` repeat earlier edges;
    duplicates are not removed here.

    Args:
        num_nodes: Number of nodes in the graph.
        degree: Number of successor offsets to connect.
        device: Device on which the tensors are created.

    Returns:
        A ``(2, E)`` ``torch.long`` edge index, one block of ``num_nodes``
        edges per contributing offset, or ``(2, 0)`` if no edge is produced.
    """
    if num_nodes <= 1:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    node_ids = torch.arange(num_nodes, device=device)

    # (v + offset) % num_nodes == v holds for every v exactly when offset is a
    # multiple of num_nodes, so self-loops are an all-or-nothing property of
    # the offset and can be filtered per offset.
    blocks = [
        torch.stack([node_ids, (node_ids + offset) % num_nodes])
        for offset in range(1, degree + 1)
        if offset % num_nodes != 0
    ]

    if blocks:
        return torch.cat(blocks, dim=1)
    return torch.empty((2, 0), dtype=torch.long, device=device)


def create_hamiltonian_expander_vectorized(num_nodes, degree, device="cpu"):
    """Build a directed graph from ``degree // 2`` forward/backward ring pairs.

    For every step ``k`` in ``1..degree // 2`` each node ``v`` is linked to
    ``(v + k) % num_nodes`` and to ``(v - k) % num_nodes``. Self-loops are
    dropped; duplicate edges are not removed.

    Args:
        num_nodes: Number of nodes in the graph.
        degree: Target degree; only ``degree // 2`` steps are generated.
        device: Device on which the tensors are created.

    Returns:
        A ``(2, E)`` ``torch.long`` edge index in which, for each step, the
        forward ring precedes the backward ring, or ``(2, 0)`` if no edge is
        produced.
    """
    if num_nodes <= 1:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    ids = torch.arange(num_nodes, device=device)

    def ring(offset):
        # Edges v -> (v + offset) % num_nodes without self-loops, or None when
        # nothing is left after filtering.
        dst = (ids + offset) % num_nodes
        keep = ids != dst
        if keep.any():
            return torch.stack([ids[keep], dst[keep]])
        return None

    blocks = []
    for step in range(1, degree // 2 + 1):
        for offset in (step, -step):
            edges = ring(offset)
            if edges is not None:
                blocks.append(edges)

    if blocks:
        return torch.cat(blocks, dim=1)
    return torch.empty((2, 0), dtype=torch.long, device=device)


class ExphormerSparseAttentionOptimized(nn.Module):
    """Multi-head attention restricted to the input edges plus an expander graph.

    Each node attends only over its incoming edges in the union of the
    caller-supplied ``edge_index`` and a cached expander graph obtained from
    ``create_expander_graph_cached``. Nodes without incoming edges receive a
    zero attention output (before the output projection).

    Args:
        d_model: Feature width of the input and output.
        nhead: Number of attention heads; must divide ``d_model``.
        exp_degree: Degree parameter passed to the expander builder.
        exp_algorithm: Name of the expander construction.
        dropout: Dropout probability applied to the attention weights.
        layer_idx: Forwarded to the expander cache as part of its key.
    """

    def __init__(
        self, d_model, nhead, exp_degree=5, exp_algorithm="Random-d", dropout=0.1, layer_idx=0
    ):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"

        self.d_model = d_model
        self.nhead = nhead
        self.exp_degree = exp_degree
        self.exp_algorithm = exp_algorithm
        self.layer_idx = layer_idx
        self.head_dim = d_model // nhead

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        # Standard 1 / sqrt(head_dim) scaling of the dot products.
        self.scale = self.head_dim**-0.5

    def forward(self, x, edge_index):
        """Apply sparse attention over the graph edges and the expander edges.

        Args:
            x: Node features of shape ``(num_nodes, d_model)``.
            edge_index: ``(2, E)`` tensor; row 0 holds source nodes and row 1
                the target nodes that aggregate from them.

        Returns:
            Tensor of shape ``(num_nodes, d_model)``.
        """
        num_nodes, d_model = x.shape
        device = x.device

        exp_edge_index = create_expander_graph_cached(
            num_nodes, self.exp_degree, self.exp_algorithm, device, self.layer_idx
        )

        if edge_index.size(1) > 0 and exp_edge_index.size(1) > 0:
            combined_edge_index = torch.cat([edge_index, exp_edge_index], dim=1)
        elif edge_index.size(1) > 0:
            combined_edge_index = edge_index
        else:
            combined_edge_index = exp_edge_index

        # An edge present in both graphs (or repeated within one) must be
        # counted once in the softmax.
        combined_edge_index = torch.unique(combined_edge_index, dim=1)

        Q = self.q_proj(x).view(num_nodes, self.nhead, self.head_dim)
        K = self.k_proj(x).view(num_nodes, self.nhead, self.head_dim)
        V = self.v_proj(x).view(num_nodes, self.nhead, self.head_dim)

        out = self._sparse_attention_optimized(Q, K, V, combined_edge_index)

        out = out.reshape(num_nodes, d_model)
        return self.out_proj(out)

    def _sparse_attention_optimized(self, Q, K, V, edge_index):
        """Compute per-edge attention and aggregate values at the target nodes.

        All heads are processed in a single pass. ``torch_scatter`` kernels
        are used when the import succeeded; otherwise native PyTorch
        operations are used.

        Args:
            Q, K, V: Tensors of shape ``(num_nodes, nhead, head_dim)``.
            edge_index: ``(2, E)`` tensor of (source, target) node pairs.

        Returns:
            Tensor shaped and typed like ``Q``; all zeros if there are no edges.
        """
        num_nodes = Q.size(0)

        if edge_index.size(1) == 0:
            return torch.zeros_like(Q)

        src, tgt = edge_index[0], edge_index[1]

        # (E, nhead): scaled dot product of the target's query with the
        # source's key.
        scores = (Q[tgt] * K[src]).sum(dim=-1) * self.scale

        # Softmax over all edges that share a target node, per head.
        if scatter_softmax is not None:
            attn = scatter_softmax(scores, tgt, dim=0, dim_size=num_nodes)
        else:
            attn = self._manual_softmax_normalization(scores, tgt, num_nodes)

        attn = self.dropout(attn)

        # (E, nhead, head_dim); cast so the result always has Q's dtype even
        # if type promotion widened the product.
        weighted_v = (V[src] * attn.unsqueeze(-1)).to(Q.dtype)

        if scatter_add is not None:
            return scatter_add(weighted_v, tgt, dim=0, dim_size=num_nodes)

        out = torch.zeros_like(Q)
        out.index_add_(0, tgt, weighted_v)
        return out

    def _manual_softmax_normalization(self, scores, tgt, num_nodes):
        """Softmax ``scores`` along dim 0 within groups of equal ``tgt``.

        Pure-PyTorch replacement for ``scatter_softmax`` that handles every
        group (and every trailing dimension, e.g. heads) at once instead of
        looping over nodes.

        Args:
            scores: Tensor of shape ``(E, ...)``.
            tgt: Long tensor of shape ``(E,)`` with values in
                ``[0, num_nodes)`` giving the group of each row.
            num_nodes: Number of groups.

        Returns:
            Tensor shaped like ``scores`` whose entries sum to one within
            each group.
        """
        group_shape = (num_nodes,) + tuple(scores.shape[1:])
        index = tgt.view(-1, *([1] * (scores.dim() - 1))).expand_as(scores)

        # Per-group maximum for numerical stability. It is a constant shift
        # of the softmax input, so it is taken from detached scores and does
        # not take part in the backward pass.
        group_max = scores.new_full(group_shape, float("-inf"))
        group_max.scatter_reduce_(
            0, index, scores.detach(), reduce="amax", include_self=True
        )

        exp_scores = (scores - group_max.index_select(0, tgt)).exp()

        # Every group that occurs contains its own maximum, which contributes
        # exp(0) == 1, so the gathered denominators are never zero.
        denom = scores.new_zeros(group_shape).index_add_(0, tgt, exp_scores)

        return exp_scores / denom.index_select(0, tgt)


class ExphormerLayerOptimized(nn.Module):
    """One post-norm transformer block built around sparse expander attention.

    The block applies the attention sub-layer and a two-layer ReLU
    feed-forward sub-layer, each followed by dropout, a residual connection
    and layer normalisation.
    """

    def __init__(
        self, d_model, nhead, dim_feedforward, exp_degree, exp_algorithm, dropout, layer_idx=0
    ):
        super().__init__()

        # Attention sub-layer.
        self.self_attn = ExphormerSparseAttentionOptimized(
            d_model,
            nhead,
            exp_degree,
            exp_algorithm,
            dropout,
            layer_idx,
        )

        # Position-wise feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One normalisation per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x, edge_index):
        """Run the attention and feed-forward sub-layers on ``x``.

        Args:
            x: Node features passed to the attention module.
            edge_index: Edge index passed to the attention module.

        Returns:
            The transformed node features after the second layer norm.
        """
        attended = self.self_attn(x, edge_index)
        residual = self.norm1(x + self.dropout(attended))

        hidden = self.activation(self.linear1(residual))
        projected = self.linear2(self.dropout(hidden))

        return self.norm2(residual + self.dropout(projected))


class ExphormerOptimized(nn.Module):
    """Stack of Exphormer layers with an input projection and a final layer norm.

    Args:
        in_features: Width of the incoming node features.
        d_model: Width used inside the layers and for the output.
        nhead: Number of attention heads per layer.
        dim_feedforward: Hidden width of each layer's feed-forward part.
        num_layers: Number of stacked layers.
        exp_degree: Degree parameter of the expander graph.
        exp_algorithm: Name of the expander construction.
        dropout: Dropout probability used inside the layers.
    """

    def __init__(
        self,
        in_features,
        d_model,
        nhead,
        dim_feedforward,
        num_layers,
        exp_degree=5,
        exp_algorithm="Random-d",
        dropout=0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.embedding_dim = d_model

        # A projection is only needed when the input width differs from the
        # model width.
        if in_features == d_model:
            self.input_proj = nn.Identity()
        else:
            self.input_proj = nn.Linear(in_features, d_model)

        # Each layer is given its position in the stack as layer_idx.
        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            self.layers.append(
                ExphormerLayerOptimized(
                    d_model,
                    nhead,
                    dim_feedforward,
                    exp_degree,
                    exp_algorithm,
                    dropout,
                    layer_idx=idx,
                )
            )

        self.norm = nn.LayerNorm(d_model)

    def forward(self, data):
        """Encode the nodes of ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            The layer-normalised node embeddings.
        """
        edge_index = data.edge_index
        hidden = self.input_proj(data.x)

        for block in self.layers:
            hidden = block(hidden, edge_index)

        return self.norm(hidden)


def clear_expander_cache():
    """Remove every memoised expander graph from the module-level cache."""
    # The dict is emptied in place rather than rebound, so any other
    # reference to it observes the change as well.
    _EXPANDER_CACHE.clear()
