import torch.nn as nn
import torch
from stmodels.layers import DenseNet, ContinuousConv
import torch.nn.functional as F
import numpy as np


class SpatialInducedConv(nn.Module):
    """Multi-level residual continuous graph convolution.

    ``forward`` repeats the following cycle ``depth`` times. Every step has the
    form ``x = relu(x + conv(x, edges_l, attrs_l))``:

    1. downward: the ``level_num - 1`` convolutions of ``conv_down_list`` over
       the "down" edges, level index 0 first;
    2. midpass:  the ``level_num`` convolutions of ``conv_list`` over the
       "mid" edges, highest level index first;
    3. upward:   the ``level_num - 1`` convolutions of ``conv_up_list`` over
       the "up" edges, highest level index first.

    Args:
        embed_size: Input and output channel width of every convolution.
        ker_input_size: Input width of each kernel network (presumably the
            edge-attribute width -- confirm against the data pipeline).
        ker_embed_size: Base hidden width of the kernel networks; the network
            for level ``l`` uses ``ker_embed_size // 2 ** l``.
        level_sizes: Per-level sizes. They are only stored as cumulative
            offsets in ``self.points`` (with a leading 0) and their sum in
            ``self.points_total``; ``forward`` does not read them.
        level_num: Number of levels.
        depth: Number of down/mid/up cycles run by ``forward``.
    """

    def __init__(self, embed_size, ker_input_size, ker_embed_size, level_sizes, level_num, depth=1):
        super().__init__()

        self.embed_size = embed_size
        self.level_num = level_num
        self.depth = depth
        # Cumulative offsets with a leading 0, e.g. level_sizes [3, 2] -> [0, 3, 5].
        self.points = np.cumsum(np.insert(np.array(level_sizes), 0, 0))
        self.points_total = self.points[-1]

        # The lists are built in the order down, mid, up (and kernel before
        # conv inside each helper call), matching the original construction
        # order so parameter initialisation consumes the RNG identically.
        # Attribute names are unchanged, so state_dict keys are unchanged.
        self.conv_down_list = nn.ModuleList(
            self._make_conv(embed_size, ker_input_size, ker_embed_size // (2 ** l), n_hidden=1)
            for l in range(1, level_num)
        )
        self.conv_list = nn.ModuleList(
            self._make_conv(embed_size, ker_input_size, ker_embed_size // (2 ** l), n_hidden=2)
            for l in range(level_num)
        )
        self.conv_up_list = nn.ModuleList(
            self._make_conv(embed_size, ker_input_size, ker_embed_size // (2 ** l), n_hidden=1)
            for l in range(1, level_num)
        )

    @staticmethod
    def _make_conv(embed_size, ker_input_size, ker_width, n_hidden):
        """Build one ``ContinuousConv`` (mean aggregation, no bias).

        Its kernel is a ``DenseNet`` with layer widths
        ``[ker_input_size] + [ker_width] * n_hidden + [embed_size ** 2]`` and
        ``nn.GELU`` passed as the nonlinearity.
        """
        kernel = DenseNet([ker_input_size] + [ker_width] * n_hidden + [embed_size ** 2], nn.GELU)
        return ContinuousConv(embed_size, embed_size, kernel, aggr='mean', bias=False)

    @staticmethod
    def _residual_step(conv, x, edge_index, edge_attr, edge_range, level):
        """Return ``relu(x + conv(x, level_edges, level_attrs))``.

        The edges of ``level`` are columns ``edge_range[level, 0]:edge_range[level, 1]``
        of ``edge_index``. The matching attributes are the same slice along the
        second-to-last axis of ``edge_attr``.
        """
        start, end = edge_range[level, 0], edge_range[level, 1]
        out = conv(x, edge_index[:, start:end], edge_attr[..., start:end, :])
        return F.relu(x + out)

    def forward(self, embedding, batch, t):
        """Run ``depth`` down/mid/up cycles of residual convolutions.

        Args:
            embedding: Node embeddings. They are unpacked as a rank-3 tensor
                whose first axis is the batch size.
            batch: Object exposing ``edge_index``, ``edge_attr`` and
                ``edge_range``. Each is a 3-sequence ordered (mid, down, up).
                - ``edge_index`` entries are sliced along axis 1; presumably
                  ``(2, num_edges)`` COO indices -- confirm.
                - ``edge_attr`` entries are rank-4, with batch on axis 0 and
                  the axis selected by ``t`` on axis 1.
                - ``edge_range`` entries give, in row ``l``, the
                  ``[start, end)`` edge slice belonging to level ``l``.
            t: Index into axis 1 of every ``edge_attr`` tensor (presumably a
                time step -- confirm against the caller).

        Returns:
            The updated node embeddings, or ``embedding`` itself when
            ``depth == 0``.
        """
        edge_index_mid, edge_index_down, edge_index_up = batch.edge_index
        edge_attr_mid, edge_attr_down, edge_attr_up = batch.edge_attr
        edge_range_mid, edge_range_down, edge_range_up = batch.edge_range
        batch_size, _, _ = embedding.shape
        _, _, _, attr_dim = edge_attr_mid.shape

        # Keep only slice t of axis 1, then flatten to (batch_size, edges, attr_dim).
        edge_attr_mid = edge_attr_mid[:, [t]].reshape(batch_size, -1, attr_dim)
        edge_attr_down = edge_attr_down[:, [t]].reshape(batch_size, -1, attr_dim)
        edge_attr_up = edge_attr_up[:, [t]].reshape(batch_size, -1, attr_dim)

        # Every step is x = relu(x + conv(x)). The residual is simply the
        # current x, so no separate cloned copy is needed.
        x = embedding
        for _ in range(self.depth):
            # downward
            for l in range(self.level_num - 1):
                x = self._residual_step(
                    self.conv_down_list[l], x, edge_index_down, edge_attr_down, edge_range_down, l
                )

            # midpass, highest level index first
            for l in reversed(range(self.level_num)):
                x = self._residual_step(
                    self.conv_list[l], x, edge_index_mid, edge_attr_mid, edge_range_mid, l
                )

            # upward: all level_num - 1 up-convolutions, mirroring the downward
            # pass. (range(level_num - 2) previously skipped conv_up_list[-1].)
            for l in reversed(range(self.level_num - 1)):
                x = self._residual_step(
                    self.conv_up_list[l], x, edge_index_up, edge_attr_up, edge_range_up, l
                )

        return x