import torch
import torch.nn as nn
from hydra.utils import instantiate

# Maximum number of node sequences passed to a transformer block in one call.
# NAGphormer.forward splits larger inputs into chunks of this size because
# xformers runs out of memory when the batch is too big. A larger value
# (e.g. 16384) means fewer chunks but more memory per call.
MAX_BATCH = 2048


class NAGphormer(torch.nn.Module):
    r"""Node-level transformer over per-node sequences of hop tokens.

    Node features are concatenated with positional features and propagated
    over the symmetrically normalized adjacency with self-loops. This yields
    ``num_hops + 1`` tokens per node; hop 0 is the unpropagated input. The
    token sequences are projected to ``hidden_dim`` and passed through
    ``n_layers`` blocks built from ``cfg.model.block``. The hop-0 token of
    each requested node is then decoded with a linear readout.

    The hop tokens are computed once, on the first ``forward`` call, and
    cached in ``self.all_features``. Later calls reuse the cache regardless
    of the graph contained in ``batch``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        feat_dim = cfg.model.dim_in
        pos_dim = cfg.model.num_eigenvecs

        # Node and positional features are concatenated before this projection.
        self.read_in = nn.Linear(
            feat_dim + pos_dim,
            cfg.model.hidden_dim,
        )

        self.num_hops = cfg.model.num_hops

        # transformer
        self.blocks = nn.ModuleList(
            [
                instantiate(
                    cfg.model.block,
                    feat_dim=cfg.model.hidden_dim,
                    # positional features are already part of the read_in input
                    pos_dim=0,
                    cfg=cfg,
                    _recursive_=False,
                )
                for _ in range(cfg.model.n_layers)
            ]
        )

        # readout
        self.read_out = nn.Linear(cfg.model.hidden_dim, cfg.model.dim_out)

        # Lazily filled cache of hop tokens, [N x (num_hops + 1) x D_in];
        # see _get_k_hop_features.
        self.all_features = None

    @staticmethod
    def _add_self_loops(edge_index, N):
        """Return ``edge_index`` with exactly one self-loop per node.

        Self-loops already present are dropped first so none is duplicated.
        """
        without_loops = edge_index[:, edge_index[0] != edge_index[1]]
        diag = torch.arange(N, device=edge_index.device)
        return torch.cat((without_loops, diag.repeat(2, 1)), dim=-1)

    @classmethod
    def _normalized_adjacency(cls, edge_index, N, device, sparse=True):
        r"""Build :math:`D^{-1/2}(A + I)D^{-1/2}` for the graph in ``edge_index``.

        ``D`` holds the row sums (out-degrees, self-loop included) of ``A + I``.

        Returns a sparse COO ``[N x N]`` tensor when ``sparse`` is True.
        Otherwise returns a dense ``[N x N]`` tensor on ``device``.

        Duplicate edges in ``edge_index`` are counted once per occurrence on
        the sparse path (degrees via ``bincount``; COO entries are summed on
        coalescing). On the dense path they collapse to a single 1.0 entry.
        """
        if sparse:
            ei = cls._add_self_loops(edge_index, N)
            # deg[i] = out-degree of node i, self-loop included (so never 0)
            deg_sqrt = torch.bincount(ei[0], minlength=N).sqrt()
            weights = 1 / (deg_sqrt[ei[0]] * deg_sqrt[ei[1]])
            # Explicit size: do not rely on shape inference from the max index.
            return torch.sparse_coo_tensor(ei, weights.float(), size=(N, N))

        A = torch.zeros((N, N), device=device)
        A[edge_index[0], edge_index[1]] = 1.0
        # Add self-loops
        diag = torch.arange(N, device=device)
        A[diag, diag] = 1.0
        # Scale rows and columns by D^{-1/2}. This equals D @ A @ D for
        # diagonal D, but costs O(N^2) instead of two O(N^3) matmuls.
        d_inv_sqrt = A.sum(dim=-1).rsqrt()
        return d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

    def _get_k_hop_features(self, node_feat, pos_feat, edge_index, sparse=True):
        r"""Compute the hop tokens and cache them in ``self.all_features``.

        node_feat: [N x D_node]
        pos_feat: [N x D_num_eigen_freqs]
        edge_index: [2 x N_edges]
        sparse: build the normalized adjacency as a sparse COO tensor (True)
            or as a dense N x N matrix (False).

        After the call, ``self.all_features`` has shape
        ``[N x (K + 1) x (D_node + D_num_eigen_freqs)]``, with
        ``K = self.num_hops``. Slice ``k`` equals ``Â^k @ cat(node_feat, pos_feat)``,
        where ``Â`` is the normalized adjacency with self-loops.
        """
        device = node_feat.device
        N = node_feat.size(0)
        K = self.num_hops
        x = torch.cat([node_feat, pos_feat], dim=-1)

        A = self._normalized_adjacency(edge_index, N, device, sparse=sparse)

        self.all_features = torch.zeros((N, K + 1, x.size(1)), device=device)
        self.all_features[:, 0, :] = x
        # Running hop tensor, in the cache's dtype (as the cache slots would be).
        h = x.to(self.all_features.dtype)
        # Autocast only for the dense path on CUDA tensors: sparse matmuls are
        # not supported by torch under autocast.
        use_amp = (not A.is_sparse) and device.type == "cuda"
        with torch.autocast("cuda", enabled=use_amp):
            # Fill hops 1..K inclusive, so the last slot is not left at zero.
            for k in range(1, K + 1):
                h = A @ h
                self.all_features[:, k, :] = h

        print(f"Created k-hop tokens: {self.all_features.shape}")

    def forward(self, batch):
        r"""Return ``(predictions, labels)`` for the nodes in ``batch["node_ids"]``.

        ``batch`` must provide:

        - ``"x_feat"``, ``"graph_pos"``, ``"edge_index"``: used only on the
          first call, to build the cached hop tokens;
        - ``"node_ids"``: indices of the nodes to decode;
        - ``"task_label"``: returned unchanged next to the predictions.
        """
        # -- prep input (hop tokens are computed once, then reused)
        if self.all_features is None:
            self._get_k_hop_features(batch["x_feat"], batch["graph_pos"], batch["edge_index"])

        x = self.read_in(self.all_features)  # [N x (K + 1) x hidden_dim]

        # -- transformer
        for block in self.blocks:
            if x.size(0) <= MAX_BATCH:
                x = block(x)
            else:
                # xformers goes OOM if batch size is too big,
                # so big batches are processed in chunks along dim 0
                x = torch.cat([block(chunk) for chunk in x.split(MAX_BATCH, dim=0)], dim=0)

        # -- readout from the hop-0 token of each requested node
        nodes_to_decode = batch["node_ids"]
        pred = self.read_out(x[nodes_to_decode, 0, :])
        return pred, batch["task_label"]
