import torch
import torch.nn as nn
from hydra.utils import instantiate

from src.utils import get_class_from_path

# Largest size of dim 0 that `NAGphormerOurs.forward` passes to a transformer
# block in a single call; bigger inputs are sliced into pieces of at most this
# size, run through the block one by one, and concatenated again.
# MAX_BATCH = 16384
MAX_BATCH = 2048


class NAGphormerOurs(torch.nn.Module):
    """Transformer over K-hop neighbourhood tokens of every node in a graph.

    Each node is represented by a sequence of ``num_hops + 1`` tokens: its own
    concatenated (feature, position) vector, followed by that vector propagated
    1..K times through the normalized adjacency matrix with self-loops.  The
    sequences are encoded, run through a stack of blocks, and the hop-0 token of
    every requested node is decoded by a linear readout.

    The hop tokens are computed once, on the first ``forward`` call, and cached
    on the module (``all_node_features`` / ``all_pos_features``).  Later calls
    reuse the cache regardless of the graph contained in the batch.
    """

    def __init__(self, cfg):
        """Build encoders, blocks and readout from ``cfg``.

        Reads ``cfg.model.feat_encoder.name``, ``cfg.model.pos_encoder.name``,
        ``cfg.model.num_hops``, ``cfg.model.block``, ``cfg.model.n_layers`` and
        ``cfg.model.dim_out``.
        """
        super().__init__()
        self.cfg = cfg

        # feature and position encoders
        self.encoder = get_class_from_path(cfg.model.feat_encoder.name)(cfg)
        self.pos_encoder = get_class_from_path(cfg.model.pos_encoder.name)(cfg)

        self.feat_dim = feat_dim = self.encoder.dim_feat
        self.pos_dim = pos_dim = self.pos_encoder.dim_pe

        self.num_hops = cfg.model.num_hops

        # transformer
        self.blocks = nn.ModuleList(
            [
                instantiate(
                    cfg.model.block,
                    feat_dim=feat_dim,
                    pos_dim=pos_dim,
                    cfg=cfg,
                    _recursive_=False,
                )
                for _ in range(cfg.model.n_layers)
            ]
        )

        # readout
        self.read_out = nn.Linear(feat_dim + pos_dim, cfg.model.dim_out)

        # Lazily filled cache of hop tokens (see `_get_k_hop_features`).
        self.all_node_features = None
        self.all_pos_features = None
        # Declared here so the flag can be read before the first forward pass.
        self.hops_created = False

    @staticmethod
    def _with_self_loops(edge_index, num_nodes):
        """Return ``edge_index`` with exactly one self-loop per node appended."""
        # Drop self-loops that are already present so none ends up duplicated.
        not_loop = edge_index[0] != edge_index[1]
        loops = torch.arange(num_nodes, device=edge_index.device).repeat(2, 1)
        return torch.cat((edge_index[:, not_loop], loops), dim=-1)

    @torch.no_grad()
    def _get_k_hop_features(self, node_feat, pos_feat, edge_index, sparse=True):
        r"""Compute and cache the hop tokens of every node.

        node_feat: [N x D_node]
        pos_feat: [N x D_num_eigen_freqs]
        edge_index: [2 x N_edges]
        sparse: build the adjacency as a sparse COO tensor (default) instead of
            a dense ``N x N`` matrix.

        Returns nothing.  Stores ``all_node_features`` ([N x K+1 x D_node]) and
        ``all_pos_features`` ([N x K+1 x D_num_eigen_freqs]) on the module and
        sets ``hops_created`` to ``True``.  Runs without gradient tracking.
        """
        device = node_feat.device
        N = node_feat.size(0)
        node_dim = node_feat.size(-1)
        x = torch.cat([node_feat, pos_feat], dim=-1)

        if sparse:
            # Sparse normalized adjacency: weight(i, j) = 1 / sqrt(d_i * d_j)
            edges = self._with_self_loops(edge_index, N)
            rows, cols = edges[0], edges[1]
            degree = torch.bincount(rows, minlength=N)  # degree[i] = out-degree of node i
            degree_sqrt = degree.sqrt()
            edge_weights = 1 / (degree_sqrt[rows] * degree_sqrt[cols])
            # Explicit size: do not let the shape be inferred from the largest index.
            A = torch.sparse_coo_tensor(edges, edge_weights.float(), size=(N, N))
        else:
            # Dense 0/1 adjacency with self-loops
            A = torch.zeros((N, N), device=device)
            A[edge_index[0], edge_index[1]] = 1.0
            loop_idx = torch.arange(N, device=device)
            A[loop_idx, loop_idx] = 1.0

            # Normalize: scale row i and column j by 1 / sqrt(row-sum).  This equals
            # D @ A @ D for the diagonal matrix D without materialising D.
            inv_sqrt_degree = 1 / (A.sum(dim=-1, keepdim=True) ** 0.5)  # [N x 1]
            A = inv_sqrt_degree * A * inv_sqrt_degree.t()

        # Propagate
        with torch.autocast("cuda", enabled=not A.is_sparse):  # sparse matmuls are not supported by torch
            hops = [x]
            for _ in range(self.num_hops):
                hops.append(A @ hops[-1])
            # Stack at dim 1 so that final shape is: (n_nodes, n_hops+1, dim)
            # batch size is n_nodes, context length is n_hops + 1
            all_features = torch.stack(hops, dim=1)

        self.all_node_features = all_features[..., :node_dim]
        self.all_pos_features = all_features[..., node_dim:]
        self.hops_created = True

        print(f"Created k-hop tokens: {self.all_node_features.shape}, {self.all_pos_features.shape}")

    def forward(self, batch):
        """Predict outputs for the nodes listed in ``batch["node_ids"]``.

        ``batch`` is indexed with the keys ``"x_feat"``, ``"graph_pos"`` and
        ``"edge_index"`` (only when the hop-token cache is still empty),
        ``"node_ids"`` and ``"task_label"``; it is also handed to the position
        encoder as a whole.

        Returns ``(pred, batch["task_label"])`` where ``pred`` is the readout of
        the hop-0 token of every node in ``batch["node_ids"]``.
        """
        # -- prep input
        if self.all_node_features is None:
            self._get_k_hop_features(batch["x_feat"], batch["graph_pos"], batch["edge_index"])

        x_node = self.encoder(self.all_node_features)
        x_pos = self.pos_encoder(self.all_pos_features, batch)

        x = torch.cat((x_node, x_pos), dim=-1)

        # -- transformer
        for block in self.blocks:
            if x.size(0) <= MAX_BATCH:
                x = block(x)
            else:
                # xformers goes OOM if batch size is too big
                # so we feed big batches to the block in slices of MAX_BATCH
                x = torch.cat([block(chunk) for chunk in x.split(MAX_BATCH, dim=0)], dim=0)

        # -- readout
        nodes_to_decode = batch["node_ids"]
        pred = self.read_out(x[nodes_to_decode, 0, :])
        return pred, batch["task_label"]
