import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter, scatter_mean, scatter_add
import torch_scatter

from pooling_set import *
from equiv_set import *

from layers import *
import pdb


class HyperLayer(nn.Module):
    """Two chained ``SetLayer`` passes followed by a blend with ``x0``.

    ``forward`` first runs ``V2EConvs`` (tagged ``'V2E'``) on ``x`` with
    ``edge_index``, then runs ``E2VConvs`` (tagged ``'E2V'``) on that result
    with ``reversed_edge_index``. The output is mixed with ``x0`` as
    ``alpha * x + (1 - alpha) * x0`` and passed through ``W``.

    NOTE(review): the names suggest vertex-to-hyperedge and
    hyperedge-to-vertex propagation on a hypergraph - confirm against the
    ``SetLayer`` components and the callers.

    Args:
        proc_type_V2E: ``proc_type`` forwarded to the first ``SetLayer``.
        pooling_type_V2E: ``pooling_type`` forwarded to the first ``SetLayer``.
        proc_type_E2V: ``proc_type`` forwarded to the second ``SetLayer``.
        pooling_type_E2V: ``pooling_type`` forwarded to the second ``SetLayer``.
        args: Configuration object. This class reads ``restart_alpha``,
            ``dropout``, ``MLP3_num_layers``, ``normalization``,
            ``deepset_input_norm``, ``MLP_hidden`` and ``MLP_num_layers``;
            the object is also handed on to both ``SetLayer`` instances.
    """

    def __init__(self, proc_type_V2E, pooling_type_V2E, proc_type_E2V, pooling_type_E2V, args):
        super(HyperLayer, self).__init__()
        self.alpha = args.restart_alpha
        self.dropout = args.dropout
        self.mlp3_layers = args.MLP3_num_layers

        self.normalization = args.normalization
        input_norm = args.deepset_input_norm

        # Both set layers keep the feature width at args.MLP_hidden.
        self.V2EConvs = SetLayer(d_in=args.MLP_hidden,
                                 d_out=args.MLP_hidden,
                                 num_layers=args.MLP_num_layers,
                                 d_hid=args.MLP_hidden,
                                 proc_type=proc_type_V2E,
                                 pooling_type=pooling_type_V2E,
                                 args=args)

        self.E2VConvs = SetLayer(d_in=args.MLP_hidden,
                                 d_out=args.MLP_hidden,
                                 num_layers=args.MLP_num_layers,
                                 d_hid=args.MLP_hidden,
                                 proc_type=proc_type_E2V,
                                 pooling_type=pooling_type_E2V,
                                 args=args)

        # Optional output transform; a no-op when MLP3_num_layers <= 0.
        if self.mlp3_layers > 0:
            self.W = MLP(args.MLP_hidden, args.MLP_hidden, args.MLP_hidden, self.mlp3_layers,
                         dropout=self.dropout, Normalization=self.normalization, InputNorm=input_norm)
        else:
            self.W = nn.Identity()

    def forward(self, x, x0, edge_index, reversed_edge_index, data):
        """Run the V2E and E2V passes and blend the result with ``x0``.

        Args:
            x: Input features for the first set layer.
            x0: Tensor blended with the E2V output; its row count is the
                target the E2V output is padded up to.
            edge_index: Index structure passed to the ``'V2E'`` pass.
            reversed_edge_index: Index structure passed to the ``'E2V'`` pass.
            data: Passed through unchanged to both set layers.

        Returns:
            A tuple ``(x, h)``: the blended and ``W``-transformed output, and
            the intermediate output of the ``'V2E'`` pass.
        """
        h = self.V2EConvs(x, None, edge_index, data, 'V2E')
        x = self.E2VConvs(h, x, reversed_edge_index, data, 'E2V')

        # The E2V output may have fewer rows than x0 (presumably rows that do
        # not appear in reversed_edge_index - TODO confirm). Append zero rows
        # so the blend below is shape-compatible.
        missing_rows = x0.shape[0] - x.shape[0]
        if missing_rows > 0:
            # new_zeros inherits dtype and device from x, so the concatenation
            # does not promote a reduced-precision x to the default dtype and
            # no temporary CPU tensor is created.
            pad = x.new_zeros((missing_rows, x.shape[1]))
            x = torch.cat((x, pad), dim=0)

        x = self.alpha * x + (1 - self.alpha) * x0
        x = self.W(x)

        return x, h

    def reset_parameters(self):
        """Reset both set layers, and ``W`` when it is an ``MLP``."""
        self.V2EConvs.reset_parameters()
        self.E2VConvs.reset_parameters()
        # nn.Identity has no reset_parameters, hence the guard.
        if self.mlp3_layers > 0:
            self.W.reset_parameters()

class SetLayer(nn.Module):
    """A processing module followed by a pooling module.

    ``forward`` calls ``self.proc`` and feeds the first element of its result
    to ``self.pooling``; the pooled tensor is returned.

    Supported ``proc_type`` values and the module each one builds:
        ``'MLP'``  -> ``MLP_DS``
        ``'Id'``   -> ``Identity``
        ``'SAB'``  -> ``SAB``
        ``'ISAB'`` -> ``ISAB``

    Supported ``pooling_type`` values and the module each one builds:
        ``'DeepSet'`` -> ``GAP`` with ``use_mlp=True``
        ``'FPSWE'``   -> ``FPSWE_pool`` with ``anch_freeze=True``
        ``'LPSWE'``   -> ``FPSWE_pool`` with ``anch_freeze=False``
        ``'PMA'``     -> ``PMAPool``

    Args:
        d_in: Input feature width given to the processing module.
        d_out: Output feature width of the processing and pooling modules
            (used as ``num_projections`` for the ``FPSWE_pool`` variants).
        num_layers: Layer count forwarded to ``MLP_DS``, ``GAP`` and
            ``PMAPool``.
        d_hid: Hidden width forwarded to ``MLP_DS``, ``GAP`` and ``PMAPool``.
        proc_type: One of the processing types listed above.
        pooling_type: One of the pooling types listed above.
        args: Configuration object. Depending on the chosen types this class
            reads ``normalization``, ``heads``, ``isab_num_inds``,
            ``apprepset_n_anchors``, ``fpswe_out`` and ``dropout``.

    Raises:
        ValueError: If ``proc_type`` or ``pooling_type`` is not supported.
    """

    _PROC_TYPES = ('MLP', 'Id', 'SAB', 'ISAB')
    _POOLING_TYPES = ('DeepSet', 'FPSWE', 'LPSWE', 'PMA')

    def __init__(self, d_in, d_out, num_layers, d_hid, proc_type, pooling_type, args):
        super(SetLayer, self).__init__()

        # Fail fast with a clear message instead of leaving the module
        # half-built (missing self.proc / self.pooling) or tripping over an
        # unbound pool_d_in further down.
        if proc_type not in self._PROC_TYPES:
            raise ValueError(
                f"Unsupported proc_type {proc_type!r}; expected one of {self._PROC_TYPES}")
        if pooling_type not in self._POOLING_TYPES:
            raise ValueError(
                f"Unsupported pooling_type {pooling_type!r}; expected one of {self._POOLING_TYPES}")

        self.pooling_type = pooling_type

        # pool_d_in is the feature width the pooling module receives: d_out
        # for every processing module except the identity, which keeps d_in.
        if proc_type == 'MLP':
            self.proc = MLP_DS(d_in=d_in,
                               d_out=d_out,
                               num_layers=num_layers,
                               hidden_layer_size=d_hid,
                               norm_type=args.normalization,
                               args=args)
            pool_d_in = d_out
        elif proc_type == 'Id':
            self.proc = Identity()
            pool_d_in = d_in
        elif proc_type == 'SAB':
            ln = (args.normalization == 'ln')
            self.proc = SAB(d_in, d_out, args.heads, ln=ln)
            pool_d_in = d_out
        else:  # 'ISAB'
            ln = (args.normalization == 'ln')
            self.proc = ISAB(d_in, d_out, args.heads, num_inds=args.isab_num_inds, ln=ln)
            pool_d_in = d_out

        if pooling_type == 'DeepSet':
            self.pooling = GAP(use_mlp=True,
                               d_in=pool_d_in,
                               d_out=d_out,
                               num_layers=num_layers,
                               hidden_layer_size=d_hid,
                               norm_type=args.normalization,
                               args=args)
        elif pooling_type in ('FPSWE', 'LPSWE'):
            # The two variants differ only in anch_freeze:
            # True for 'FPSWE', False for 'LPSWE'.
            self.pooling = FPSWE_pool(d_in=pool_d_in,
                                      num_anchors=args.apprepset_n_anchors,
                                      num_projections=d_out,
                                      anch_freeze=(pooling_type == 'FPSWE'),
                                      out_type=args.fpswe_out)
        else:  # 'PMA'
            self.pooling = PMAPool(pool_d_in, d_hid,
                                   d_out, num_layers, heads=args.heads,
                                   dropout=args.dropout)

        # Stored on the instance; not read by the methods of this class.
        self.dropout = args.dropout

    def forward(self, x, h, edge_index, data, name):
        """Apply the processing module, then the pooling module.

        Args:
            x: Input features for the processing module.
            h: Accepted for interface compatibility; not used here.
            edge_index: Passed through to the processing and pooling modules.
            data: Passed through to the processing and pooling modules.
            name: Tag passed through to both modules (callers in this file
                use ``'V2E'`` and ``'E2V'``).

        Returns:
            The output of the pooling module.
        """
        # The processing module returns a pair; only its first element is
        # pooled, the second is discarded.
        x, _ = self.proc(x, edge_index, data, name)
        return self.pooling(x, edge_index, data, name)

    def reset_parameters(self):
        """Reset the parameters of the processing and pooling modules."""
        self.proc.reset_parameters()
        self.pooling.reset_parameters()




