#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2021 
#
# Distributed under terms of the MIT license.

"""
This script contains all models in our paper.
"""

import torch
import utils

import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn.conv import MessagePassing, GCNConv, GATConv
from layers import *

import math 

from torch_scatter import scatter, scatter_mean, scatter_add
from torch_geometric.utils import softmax

from pooling_set import *
from equiv_set import *
from layers_set import *




class SetHNN(nn.Module):
    """Set-based hypergraph network built from a stack of ``HyperLayer`` blocks.

    Pipeline: input dropout -> linear projection to ``MLP_hidden`` ->
    ``HyperLayer`` stack (each preceded by dropout and followed by ReLU) ->
    dropout -> ``MLP`` classifier.

    Fields read from ``args``:
        All_num_layers, aggregate, normalization, deepset_input_norm, GPR,
        LearnMask, sharing, readout_type, dropout, input_dropout, task_type,
        output_type, num_features, pooling_type, proc_type, MLP_hidden,
        Classifier_hidden, num_classes, Classifier_num_layers, restart_alpha.

    ``num_features`` is the input dimension of the linear projection and
    ``num_classes`` is the classifier's output dimension. The ``norm``
    argument is accepted for interface compatibility but is not used.
    """

    def __init__(self, args, norm=None):
        super(SetHNN, self).__init__()

        # Hyper-parameters copied from args. aggr, InputNorm, GPR, LearnMask,
        # readout_type, task_type, output_type and alpha are stored here but
        # are not read anywhere else in this class.
        self.All_num_layers = args.All_num_layers
        self.aggr = args.aggregate
        self.NormLayer = args.normalization
        self.InputNorm = args.deepset_input_norm
        self.GPR = args.GPR
        self.LearnMask = args.LearnMask
        self.sharing = args.sharing
        self.readout_type = args.readout_type

        # The (empty) container is registered before lin/classifier so the
        # submodule order stays: layers, lin, classifier.
        self.layers = nn.ModuleList()

        # One dropout rate between layers and inside the classifier; a
        # separate rate is applied to the raw input features.
        self.dropout = args.dropout
        self.input_dropout = args.input_dropout

        self.task_type = args.task_type
        self.output_type = args.output_type

        # Both message directions (V2E and E2V) use the same processing and
        # pooling configuration.
        self.pooling_type = args.pooling_type
        self.proc_type_V2E = self.proc_type_E2V = args.proc_type
        self.pooling_type_V2E = self.pooling_type_E2V = args.pooling_type

        self.lin = nn.Linear(args.num_features, args.MLP_hidden)

        # At least one HyperLayer is always built, even if All_num_layers < 1.
        # When `sharing` is truthy only layers[0] is used in forward, but all
        # layers are still constructed.
        layer_kwargs = dict(
            proc_type_V2E=self.proc_type_V2E,
            pooling_type_V2E=self.pooling_type_V2E,
            proc_type_E2V=self.proc_type_E2V,
            pooling_type_E2V=self.pooling_type_E2V,
            args=args,
        )
        for _ in range(max(self.All_num_layers, 1)):
            self.layers.append(HyperLayer(**layer_kwargs))

        self.classifier = MLP(
            in_channels=args.MLP_hidden,
            hidden_channels=args.Classifier_hidden,
            out_channels=args.num_classes,
            num_layers=args.Classifier_num_layers,
            dropout=self.dropout,
            Normalization=self.NormLayer,
            InputNorm=False,
        )
        self.alpha = args.restart_alpha

    def reset_parameters(self):
        """Re-initialise the projection, every HyperLayer, then the classifier."""
        for module in (self.lin, *self.layers, self.classifier):
            module.reset_parameters()

    def forward(self, data):
        """Run the network on a graph-like ``data`` object.

        Args:
            data: object exposing ``.x`` (node features) and ``.edge_index``
                (a tensor with at least two rows). It is also forwarded
                unchanged to every ``HyperLayer``.

        Returns:
            Tuple ``(out, None)``: ``out`` is the classifier output with a
            trailing size-1 dimension squeezed away. The second element is
            always ``None``.

        Side effects:
            ``data.edge_index[1]`` is shifted IN PLACE so that its minimum
            becomes 0. NOTE(review): row 1 presumably holds hyperedge ids
            (per the V2E/E2V naming); confirm against the data loader.
        """
        x = data.x
        edge_index = data.edge_index

        # Re-base row 1 at zero (in place, on the caller's tensor) to avoid
        # allocating for unused low ids.
        offset = edge_index[1].min()
        edge_index[1] -= offset

        # Copy with the two rows swapped, for the opposite message direction.
        reversed_edge_index = edge_index[[1, 0]]

        hidden = self.lin(
            F.dropout(x, p=self.input_dropout, training=self.training)
        )
        initial = hidden.clone()  # passed to every layer as the x0 argument

        for depth in range(len(self.layers)):
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
            layer = self.layers[0] if self.sharing else self.layers[depth]
            hidden, _ = layer(hidden, initial, edge_index, reversed_edge_index, data)
            hidden = F.relu(hidden)

        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.classifier(hidden).squeeze(-1)
        return logits, None














