import torch
from torch import nn
from torch_geometric.nn import global_mean_pool
import torch.nn.functional as F

from src.models.gnn_classes.message_passing import MessagePassingLayer
from src.models.model_utils import ACTIVATION_MAPPING, POOLING_MAPPING


def _sl2_order(n):
    """Return |SL(2, Z_n)| = n^3 * prod_{p | n} (1 - 1/p^2), computed exactly with integers."""
    order = n ** 3
    remaining = n
    p = 2
    while p * p <= remaining:
        if remaining % p == 0:
            # n^3 carries p^3, so dividing by p^2 is always exact.
            order = order // (p * p) * (p * p - 1)
            while remaining % p == 0:
                remaining //= p
        p += 1
    if remaining > 1:  # one prime factor larger than sqrt of what is left
        order = order // (remaining * remaining) * (remaining * remaining - 1)
    return order


def build_cayley_edge_index(num_nodes):
    """Build a directed Cayley-graph edge_index on (a prefix of) SL(2, Z_n).

    The modulus n is the smallest n >= 2 whose group SL(2, Z_n) has at least
    ``num_nodes`` elements. The exact group order is used, not the n^3 upper
    bound, so the result always covers every requested node.

    Group elements (a, b, c, d), meaning the matrix [[a, b], [c, d]] mod n, are
    enumerated in lexicographic order and numbered 0..N-1. For each element g
    and each generator s in {[[1, 1], [0, 1]], [[1, 0], [1, 1]]}, an edge
    g -> g @ s (mod n) is emitted. If the group has more than ``num_nodes``
    elements, only the first ``num_nodes`` elements are kept, together with the
    edges between them. The kept subgraph is not guaranteed to be connected.

    Args:
        num_nodes: number of nodes the expander should span.

    Returns:
        (edge_index, N): a ``torch.long`` tensor of shape [2, E] holding
        (source, target) rows, and the number of nodes N it spans
        (N == num_nodes for any num_nodes >= 0).
    """
    n = 2
    while _sl2_order(n) < num_nodes:
        n += 1

    # Enumerate SL(2, Z_n): all 2x2 matrices mod n with determinant 1.
    els = [(a, b, c, d)
           for a in range(n) for b in range(n)
           for c in range(n) for d in range(n)
           if (a * d - b * c) % n == 1]
    idx = {el: i for i, el in enumerate(els)}
    N = len(els)

    # Generators, stored as (sx, sy, sz, sw) == [[sx, sy], [sz, sw]].
    gens = ((1, 1, 0, 1), (1, 0, 1, 1))
    rows, cols = [], []
    for i, (a, b, c, d) in enumerate(els):
        for sx, sy, sz, sw in gens:
            # Right-multiply g by the generator, mod n.
            j = idx.get(((a * sx + b * sz) % n,
                         (a * sy + b * sw) % n,
                         (c * sx + d * sz) % n,
                         (c * sy + d * sw) % n))
            if j is not None:
                rows.append(i)
                cols.append(j)

    # Truncate to the first num_nodes elements and the edges among them.
    if N > num_nodes:
        kept = [(r, c) for r, c in zip(rows, cols) if r < num_nodes and c < num_nodes]
        rows = [r for r, _ in kept]
        cols = [c for _, c in kept]
        N = num_nodes
    return torch.tensor([rows, cols], dtype=torch.long), N

# Utility: Build batched block-diagonal expander edge_index

def build_batch_expander_edge_index(batch):
    """Build a block-diagonal expander ``edge_index`` for a batch of graphs.

    For every graph id present in ``batch``, a Cayley expander is built with
    ``build_cayley_edge_index`` over that graph's node count. Its local node
    indices are then mapped to the graph's global node indices, so expander
    edges only connect nodes of the same graph. If the expander spans fewer
    nodes than the graph has, the remaining nodes receive no expander edges.

    Args:
        batch: 1-D long tensor of shape [total_nodes] assigning each node to a
            graph id (values in [0, num_graphs)).

    Returns:
        Long tensor of shape [2, num_expander_edges] on ``batch.device``.
    """
    device = batch.device
    # Graphs of equal size share the same local expander; build it once per size.
    local_cache = {}
    exp_rows, exp_cols = [], []
    for graph_id in batch.unique().tolist():
        # Global node indices belonging to this graph.
        node_idx = (batch == graph_id).nonzero(as_tuple=False).view(-1)
        num_local = node_idx.size(0)
        if num_local == 0:
            continue
        if num_local not in local_cache:
            edge_idx, _ = build_cayley_edge_index(num_local)
            local_cache[num_local] = edge_idx.to(device)
        rows_local, cols_local = local_cache[num_local]
        # Map local expander indices to global batch indices.
        exp_rows.append(node_idx[rows_local])
        exp_cols.append(node_idx[cols_local])
    if not exp_rows:
        return torch.empty((2, 0), dtype=torch.long, device=device)
    return torch.stack([torch.cat(exp_rows, dim=0), torch.cat(exp_cols, dim=0)], dim=0)


# Define the expander-augmented GNN model
class ExpanderGNN(nn.Module):
    """GNN that alternates message passing over the input graph and over a
    per-graph Cayley expander.

    An embedding layer (``nn.LazyLinear``) maps node features to
    ``hidden_dims[0]``. It is followed by ``len(hidden_dims) - 1`` message-passing
    layers: layers at even index use the input ``edge_index``, and layers at odd
    index use the block-diagonal expander edges. An optional pooling step and
    the readout with the final activation come last.

    In training mode ``forward`` returns the readout output only. In eval mode
    it returns ``(output, node_representations)``, where the node
    representations are taken before pooling.
    """

    def __init__(self,
                 output_dim: int,
                 final_activation: str,
                 msg_passing_method: str = "gcn",
                 hidden_dims: list[int] = None,
                 hidden_dim: int = None,
                 n_message_passings: int = None,
                 dropout=0.0,
                 norm: str = None,
                 pooling: str = None,
                 softmax_function: str = "softmax"):
        """
        Args:
            output_dim: size of the readout output.
            final_activation: key into ``ACTIVATION_MAPPING``.
            msg_passing_method: message-passing type forwarded to ``MessagePassingLayer``.
            hidden_dims: explicit per-layer sizes; ``hidden_dims[0]`` is the embedding size.
            hidden_dim / n_message_passings: used together to build
                ``hidden_dims = [hidden_dim] * n_message_passings`` when ``hidden_dims`` is None.
            dropout: dropout rate forwarded to every message-passing layer.
            norm: normalization forwarded to all message-passing layers except the last.
            pooling: key into ``POOLING_MAPPING``.
            softmax_function: stored on the module; not used in ``forward`` here.

        Raises:
            ValueError: if the hidden sizes cannot be determined or are empty.
        """
        super().__init__()

        if hidden_dims is None:
            if hidden_dim is None or n_message_passings is None:
                raise ValueError(
                    "Provide either `hidden_dims` or both `hidden_dim` and `n_message_passings`."
                )
            hidden_dims = [hidden_dim] * n_message_passings
        if not hidden_dims:
            raise ValueError("`hidden_dims` must contain at least one entry.")

        # Embedding layer (input size inferred lazily on the first call).
        self.embedding = nn.LazyLinear(out_features=hidden_dims[0])
        self.conv_layers = nn.ModuleList()
        for i in range(1, len(hidden_dims)):
            self.conv_layers.append(MessagePassingLayer(input_dim=hidden_dims[i - 1],
                                                        output_dim=hidden_dims[i],
                                                        msg_passing_type=msg_passing_method,
                                                        dropout=dropout,
                                                        # No normalization after the last layer.
                                                        norm=(norm if (i < len(hidden_dims) - 1) else None)))

        self.readout = nn.Linear(hidden_dims[-1], output_dim)
        self.dropout_rate = dropout
        self.softmax_function = softmax_function
        self.pooling = POOLING_MAPPING[pooling]
        self.activation = ACTIVATION_MAPPING[final_activation]

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Run the model.

        Args:
            x: node features.
            edge_index: [2, E] edge index of the input graph(s).
            edge_attr: accepted for interface compatibility; not used.
            batch: graph-assignment vector for the nodes. If None, all nodes
                are treated as one graph.

        Returns:
            Readout output in training mode; ``(output, node_representations)``
            in eval mode.
        """
        # Apply embedding layer
        x = self.embedding(x)
        device = x.device

        # Build the expander per graph so no expander edge crosses graph boundaries
        # within a mini-batch. Without a batch vector, every node belongs to graph 0.
        if batch is None:
            expander_batch = torch.zeros(x.size(0), dtype=torch.long, device=device)
        else:
            expander_batch = batch
        edge_index_exp = build_batch_expander_edge_index(expander_batch).to(device)

        # Apply message passing layers
        for i, conv in enumerate(self.conv_layers):
            # Odd-indexed layers (every second layer) use the Cayley expander edges.
            if i % 2 == 1:
                x = conv(x, edge_index_exp)
            else:
                x = conv(x, edge_index)

        # Save final node representations in evaluation mode
        if not self.training:
            x_L = x

        if self.pooling is not None:
            x = self.pooling(x=x, batch=batch)

        x = self.activation(self.readout(x))

        output = x if self.training else (x, x_L)

        return output
