from typing import List, Tuple, Union

from layers import *
import torch
import torch.nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch.glob import SortPooling


class GraphCrossModule(torch.nn.Module):
    """
    Description
    -----------
    The Graph Cross Module used by Graph Cross Networks.
    This module only contains graph cross layers.

    The input graph is scale-1; scale-2 and scale-3 are obtained by two
    successive node-selection + pooling steps. Three GCN layers are applied
    on every scale. After the first ``num_cross_layers`` of them (at most 2),
    features are exchanged between neighbouring scales through
    pooling/unpooling ("graph cross" connections). Finally the coarse scales
    are unpooled back to scale-1, fused with it and passed through a last GCN.

    Parameters
    ----------
    pool_ratios : Union[float, List[float]]
        The pooling ratios (for keeping nodes) for each layer.
        For example, if `pool_ratio=0.8`, 80% nodes will be preserved.
        If a single number (or a one-element sequence) is given, both pooling
        layers use the same pooling ratio. Only the first two entries of a
        longer sequence are used.
    in_dim : int
        The number of input node feature channels.
    out_dim : int
        The number of output node feature channels.
    hidden_dim : int
        The number of hidden node feature channels.
    cross_weight : float, optional
        The weight parameter used in graph cross layers.
        The second cross layer uses ``0.05 * cross_weight``.
        Default: :obj:`1.0`
    fuse_weight : float, optional
        The weight parameter used at the end of GXN for channel fusion.
        Default: :obj:`1.0`
    dist : int, optional
        Forwarded unchanged to both ``IndexSelect`` layers; its meaning is
        defined there. Default: :obj:`1`
    num_cross_layers : int, optional
        Number of graph cross layers. Values above 2 behave like 2; values
        below 1 disable the cross-scale exchanges (including the residual
        connections that are added together with the first exchange).
        Default: :obj:`2`

    Raises
    ------
    ValueError
        If ``pool_ratios`` is an empty sequence.
    """
    def __init__(self, pool_ratios:Union[float, List[float]], in_dim:int,
                 out_dim:int, hidden_dim:int, cross_weight:float=1.,
                 fuse_weight:float=1., dist:int=1, num_cross_layers:int=2):
        super().__init__()
        # Normalise pool_ratios to (ratio_scale1, ratio_scale2). Accept ints as
        # well as floats (an int used to crash on indexing) and broadcast a
        # one-element sequence, matching the "single value" documentation.
        if isinstance(pool_ratios, (int, float)):
            pool_ratios = (pool_ratios, pool_ratios)
        pool_ratios = tuple(pool_ratios)
        if len(pool_ratios) == 1:
            pool_ratios = pool_ratios * 2
        if len(pool_ratios) < 2:
            raise ValueError("pool_ratios must be a number or a non-empty sequence of numbers")
        self.cross_weight = cross_weight
        self.fuse_weight = fuse_weight
        self.num_cross_layers = num_cross_layers

        # build network
        self.start_gcn_scale1 = GraphConvWithDropout(in_dim, hidden_dim)
        self.start_gcn_scale2 = GraphConvWithDropout(hidden_dim, hidden_dim)
        # Input is the fused features concatenated with the first-GCN features.
        self.end_gcn = GraphConvWithDropout(2 * hidden_dim, out_dim)

        self.index_select_scale1 = IndexSelect(pool_ratios[0], hidden_dim, act="prelu", dist=dist)
        self.index_select_scale2 = IndexSelect(pool_ratios[1], hidden_dim, act="prelu", dist=dist)
        self.start_pool_s12 = GraphPool(hidden_dim)
        self.start_pool_s23 = GraphPool(hidden_dim)
        self.end_unpool_s21 = GraphUnpool(hidden_dim)
        self.end_unpool_s32 = GraphUnpool(hidden_dim)

        # Three GCN layers per scale (s1/s2/s3 = scale, l1/l2/l3 = layer).
        self.s1_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
        self.s1_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
        self.s1_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

        self.s2_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
        self.s2_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
        self.s2_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

        self.s3_l1_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
        self.s3_l2_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
        self.s3_l3_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

        # Cross-layer pool/unpool modules are only built when they are used.
        if num_cross_layers >= 1:
            self.pool_s12_1 = GraphPool(hidden_dim, use_gcn=True)
            self.unpool_s21_1 = GraphUnpool(hidden_dim)
            self.pool_s23_1 = GraphPool(hidden_dim, use_gcn=True)
            self.unpool_s32_1 = GraphUnpool(hidden_dim)
        if num_cross_layers >= 2:
            self.pool_s12_2 = GraphPool(hidden_dim, use_gcn=True)
            self.unpool_s21_2 = GraphUnpool(hidden_dim)
            self.pool_s23_2 = GraphPool(hidden_dim, use_gcn=True)
            self.unpool_s32_2 = GraphUnpool(hidden_dim)

    def forward(self, graph, feat):
        """
        Run the graph cross module.

        Parameters
        ----------
        graph : DGLGraph
            The (possibly batched) input graph, used as scale-1.
        feat : Tensor
            Node features of ``graph`` with ``in_dim`` channels.

        Returns
        -------
        feat_agg : Tensor
            Scale-1 node features with ``out_dim`` channels.
        logit_s1, logit_s2 : Tensor
            First outputs of the two ``IndexSelect`` layers (presumably the
            node-selection logits used by an auxiliary loss — verify in
            ``IndexSelect``).
        """
        # start of scale-1
        graph_scale1 = graph
        feat_scale1 = self.start_gcn_scale1(graph_scale1, feat)
        feat_origin = feat_scale1
        # Row-shuffled features serve as negative samples for node selection.
        feat_scale1_neg = feat_scale1[torch.randperm(feat_scale1.size(0))] # negative samples
        logit_s1, scores_s1, select_idx_s1, non_select_idx_s1, feat_down_s1 = \
            self.index_select_scale1(graph_scale1, feat_scale1, feat_scale1_neg)
        feat_scale2, graph_scale2 = self.start_pool_s12(graph_scale1, feat_scale1,
                                                        select_idx_s1, non_select_idx_s1,
                                                        scores_s1, pool_graph=True)

        # start of scale-2
        feat_scale2 = self.start_gcn_scale2(graph_scale2, feat_scale2)
        feat_scale2_neg = feat_scale2[torch.randperm(feat_scale2.size(0))] # negative samples
        logit_s2, scores_s2, select_idx_s2, non_select_idx_s2, feat_down_s2 = \
            self.index_select_scale2(graph_scale2, feat_scale2, feat_scale2_neg)
        feat_scale3, graph_scale3 = self.start_pool_s23(graph_scale2, feat_scale2,
                                                        select_idx_s2, non_select_idx_s2,
                                                        scores_s2, pool_graph=True)

        # layer-1
        # Keep the pre-GCN features; they are added back as residuals in the
        # first cross layer.
        res_s1_0, res_s2_0, res_s3_0 = feat_scale1, feat_scale2, feat_scale3

        feat_scale1 = F.relu(self.s1_l1_gcn(graph_scale1, feat_scale1))
        feat_scale2 = F.relu(self.s2_l1_gcn(graph_scale2, feat_scale2))
        feat_scale3 = F.relu(self.s3_l1_gcn(graph_scale3, feat_scale3))

        if self.num_cross_layers >= 1:
            # Pool finer -> coarser, unpool coarser -> finer.
            feat_s12_fu = self.pool_s12_1(graph_scale1, feat_scale1,
                                          select_idx_s1, non_select_idx_s1,
                                          scores_s1)
            feat_s21_fu = self.unpool_s21_1(graph_scale1, feat_scale2, select_idx_s1)
            feat_s23_fu = self.pool_s23_1(graph_scale2, feat_scale2,
                                          select_idx_s2, non_select_idx_s2,
                                          scores_s2)
            feat_s32_fu = self.unpool_s32_1(graph_scale2, feat_scale3, select_idx_s2)

            # Scale-2 receives messages from both neighbours, so they are averaged.
            feat_scale1 = feat_scale1 + self.cross_weight * feat_s21_fu + res_s1_0
            feat_scale2 = feat_scale2 + self.cross_weight * (feat_s12_fu + feat_s32_fu) / 2 + res_s2_0
            feat_scale3 = feat_scale3 + self.cross_weight * feat_s23_fu + res_s3_0

        # layer-2
        feat_scale1 = F.relu(self.s1_l2_gcn(graph_scale1, feat_scale1))
        feat_scale2 = F.relu(self.s2_l2_gcn(graph_scale2, feat_scale2))
        feat_scale3 = F.relu(self.s3_l2_gcn(graph_scale3, feat_scale3))

        if self.num_cross_layers >= 2:
            feat_s12_fu = self.pool_s12_2(graph_scale1, feat_scale1,
                                          select_idx_s1, non_select_idx_s1,
                                          scores_s1)
            feat_s21_fu = self.unpool_s21_2(graph_scale1, feat_scale2, select_idx_s1)
            feat_s23_fu = self.pool_s23_2(graph_scale2, feat_scale2,
                                          select_idx_s2, non_select_idx_s2,
                                          scores_s2)
            feat_s32_fu = self.unpool_s32_2(graph_scale2, feat_scale3, select_idx_s2)

            # The second cross layer uses a much smaller weight and no residual.
            cross_weight = self.cross_weight * 0.05
            feat_scale1 = feat_scale1 + cross_weight * feat_s21_fu
            feat_scale2 = feat_scale2 + cross_weight * (feat_s12_fu + feat_s32_fu) / 2
            feat_scale3 = feat_scale3 + cross_weight * feat_s23_fu

        # layer-3
        feat_scale1 = F.relu(self.s1_l3_gcn(graph_scale1, feat_scale1))
        feat_scale2 = F.relu(self.s2_l3_gcn(graph_scale2, feat_scale2))
        feat_scale3 = F.relu(self.s3_l3_gcn(graph_scale3, feat_scale3))

        # final layers: unpool scale-3 -> scale-2 -> scale-1 and fuse.
        feat_s3_out = self.end_unpool_s32(graph_scale2, feat_scale3, select_idx_s2) + feat_down_s2
        feat_s2_out = self.end_unpool_s21(graph_scale1, feat_scale2 + feat_s3_out, select_idx_s1)
        feat_agg = feat_scale1 + self.fuse_weight * feat_s2_out + self.fuse_weight * feat_down_s1
        feat_agg = torch.cat((feat_agg, feat_origin), dim=1)
        feat_agg = self.end_gcn(graph_scale1, feat_agg)

        return feat_agg, logit_s1, logit_s2


class GraphCrossNet(torch.nn.Module):
    """
    Description
    -----------
    The Graph Cross Network.

    Node features are optionally concatenated with edge features aggregated
    onto nodes. They are then processed by a ``GraphCrossModule``, read out
    with ``SortPooling``, and passed through two Conv1d layers (with a
    max-pool in between) and an optional final Linear layer.

    Parameters
    ----------
    in_dim : int
        The number of input node feature channels.
    out_dim : int
        The number of output channels. If ``out_dim <= 0``, no output Linear
        layer is built. The flattened Conv1d features
        (``self.final_dense_dim`` channels) are returned instead.
    edge_feat_dim : int, optional
        The number of input edge feature channels. Edge features are passed
        to a Linear layer, summed over the incoming edges of each node, and
        concatenated to the input node features. Default: :obj:`0`
    hidden_dim : int, optional
        The number of node feature channels produced by the graph cross
        module (which itself uses ``hidden_dim // 2`` hidden channels).
        Default: :obj:`96`
    pool_ratios : Union[float, List[float]], optional
        The pooling ratios (for keeping nodes) for each layer.
        For example, if `pool_ratio=0.8`, 80% nodes will be preserved.
        If a single float number is given, all pooling layers will have the
        same pooling ratio.
        Default: :obj:`(0.9, 0.7)`
    readout_nodes : int, optional
        Number of nodes preserved in the final sort pool operation.
        Default: :obj:`30`
    conv1d_dims : List[int], optional
        The number of kernels (output channels) of the two Conv1d operations.
        Default: :obj:`(16, 32)`
    conv1d_kws : List[int], optional
        The kernel size of the second Conv1d. The first Conv1d always uses
        ``hidden_dim`` as both kernel size and stride.
        Default: :obj:`(5,)`
    cross_weight : float, optional
        The weight parameter used in graph cross layers.
        Default: :obj:`1.0`
    fuse_weight : float, optional
        The weight parameter used at the end of GXN for channel fusion.
        Default: :obj:`1.0`
    dist : int, optional
        Forwarded unchanged to ``GraphCrossModule``. Default: :obj:`1`
    """
    def __init__(self, in_dim:int, out_dim:int, edge_feat_dim:int=0,
                 hidden_dim:int=96,
                 pool_ratios:Union[List[float], Tuple[float, ...], float]=(0.9, 0.7),
                 readout_nodes:int=30,
                 conv1d_dims:Union[List[int], Tuple[int, ...]]=(16, 32),
                 conv1d_kws:Union[List[int], Tuple[int, ...]]=(5,),
                 cross_weight:float=1., fuse_weight:float=1., dist:int=1):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.edge_feat_dim = edge_feat_dim
        self.readout_nodes = readout_nodes
        # The first kernel size is fixed to hidden_dim; user-given sizes follow.
        # Unpacking (instead of list concatenation) also accepts tuples.
        conv1d_kws = [hidden_dim, *conv1d_kws]

        if edge_feat_dim > 0:
            # Aggregated edge features add hidden_dim extra node channels.
            self.in_dim += hidden_dim
            self.e2l_lin = torch.nn.Linear(edge_feat_dim, hidden_dim)
        else:
            self.e2l_lin = None

        self.gxn = GraphCrossModule(pool_ratios, in_dim=self.in_dim, out_dim=hidden_dim,
                                    hidden_dim=hidden_dim//2, cross_weight=cross_weight,
                                    fuse_weight=fuse_weight, dist=dist)
        self.sortpool = SortPooling(readout_nodes)

        # final updates
        # kernel == stride == hidden_dim: one output position per sort-pooled
        # node, since each node contributes hidden_dim consecutive values.
        self.final_conv1 = torch.nn.Conv1d(1, conv1d_dims[0],
                                           kernel_size=conv1d_kws[0],
                                           stride=conv1d_kws[0])
        self.final_maxpool = torch.nn.MaxPool1d(2, 2)
        self.final_conv2 = torch.nn.Conv1d(conv1d_dims[0], conv1d_dims[1],
                                           kernel_size=conv1d_kws[1], stride=1)
        # Sequence length after MaxPool1d(2, 2) over `readout_nodes` positions,
        # then after the stride-1 second convolution, times its channels.
        pooled_len = (readout_nodes - 2) // 2 + 1
        self.final_dense_dim = (pooled_len - conv1d_kws[1] + 1) * conv1d_dims[1]

        if self.out_dim > 0:
            self.out_lin = torch.nn.Linear(self.final_dense_dim, out_dim)

        self.init_weights()

    def init_weights(self):
        """Xavier-normal initialise the weights of the linear/conv layers built here."""
        if self.e2l_lin is not None:
            torch.nn.init.xavier_normal_(self.e2l_lin.weight)
        torch.nn.init.xavier_normal_(self.final_conv1.weight)
        torch.nn.init.xavier_normal_(self.final_conv2.weight)
        if self.out_dim > 0:
            torch.nn.init.xavier_normal_(self.out_lin.weight)

    def forward(self, graph:DGLGraph, node_feat:Tensor, edge_feat:Optional[Tensor]=None):
        """
        Compute graph-level embeddings.

        Parameters
        ----------
        graph : DGLGraph
            A (batched) graph.
        node_feat : Tensor
            Node features with ``in_dim`` channels.
        edge_feat : Optional[Tensor]
            Edge features with ``edge_feat_dim`` channels. Only allowed when
            the model was built with ``edge_feat_dim > 0``.

        Returns
        -------
        out : Tensor
            ``(graph.batch_size, out_dim)`` after a ReLU. When ``out_dim <= 0``,
            the shape is ``(graph.batch_size, final_dense_dim)`` instead.
        logits1, logits2 : Tensor
            The extra outputs of ``GraphCrossModule``.

        Raises
        ------
        ValueError
            If ``edge_feat`` is given but the model has no edge-feature layer.
        """
        num_batch = graph.batch_size
        if edge_feat is not None:
            if self.e2l_lin is None:
                # Previously failed with "'NoneType' object is not callable".
                raise ValueError("edge_feat was given but the model was built with "
                                 "edge_feat_dim=0; set edge_feat_dim > 0 to use edge features")
            edge_feat = self.e2l_lin(edge_feat)
            with graph.local_scope():
                # Sum projected features of incoming edges onto each node.
                graph.edata["he"] = edge_feat
                graph.update_all(fn.copy_e("he", "m"), fn.sum("m", "hn"))
                edge2node_feat = graph.ndata.pop("hn")
                node_feat = torch.cat((node_feat, edge2node_feat), dim=1)

        node_feat, logits1, logits2 = self.gxn(graph, node_feat)
        # SortPooling yields one row per graph: readout_nodes * hidden_dim values.
        batch_sortpool_feats = self.sortpool(graph, node_feat)

        # final updates
        to_conv1d = batch_sortpool_feats.unsqueeze(1)  # add a single input channel
        conv1d_result = F.relu(self.final_conv1(to_conv1d))
        conv1d_result = self.final_maxpool(conv1d_result)
        conv1d_result = F.relu(self.final_conv2(conv1d_result))

        to_dense = conv1d_result.view(num_batch, -1)
        if self.out_dim > 0:
            out = F.relu(self.out_lin(to_dense))
        else:
            out = to_dense

        return out, logits1, logits2


class GraphClassifier(torch.nn.Module):
    """
    Description
    -----------
    Graph Classifier for graph classification.
    GXN + MLP

    Parameters
    ----------
    args : namespace-like object
        Must provide the attributes ``in_dim``, ``embed_dim``,
        ``edge_feat_dim``, ``hidden_dim``, ``pool_ratios``, ``readout_nodes``,
        ``conv1d_dims``, ``conv1d_kws``, ``cross_weight``, ``fuse_weight``,
        ``final_dense_hidden_dim``, ``out_dim`` and ``dropout``.
        If ``embed_dim <= 0``, the GXN has no output Linear layer. The MLP then
        consumes the GXN's flattened features (``gxn.final_dense_dim``
        channels) directly.
    """
    def __init__(self, args):
        super().__init__()
        self.gxn = GraphCrossNet(in_dim=args.in_dim,
                                 out_dim=args.embed_dim,
                                 edge_feat_dim=args.edge_feat_dim,
                                 hidden_dim=args.hidden_dim,
                                 pool_ratios=args.pool_ratios,
                                 readout_nodes=args.readout_nodes,
                                 conv1d_dims=args.conv1d_dims,
                                 conv1d_kws=args.conv1d_kws,
                                 cross_weight=args.cross_weight,
                                 fuse_weight=args.fuse_weight)
        # With embed_dim <= 0 GraphCrossNet skips its output Linear layer and
        # returns `final_dense_dim` features, so size lin1 to match; otherwise
        # lin1 would be built with in_features=embed_dim and fail in forward.
        embed_dim = args.embed_dim if args.embed_dim > 0 else self.gxn.final_dense_dim
        self.lin1 = torch.nn.Linear(embed_dim, args.final_dense_hidden_dim)
        self.lin2 = torch.nn.Linear(args.final_dense_hidden_dim, args.out_dim)
        self.dropout = args.dropout

    def forward(self, graph:DGLGraph, node_feat:Tensor, edge_feat:Optional[Tensor]=None):
        """
        Classify a batch of graphs.

        Returns
        -------
        log_probs : Tensor
            Log-softmax over the ``args.out_dim`` classes, one row per graph.
        logits1, logits2 : Tensor
            The extra outputs passed through from ``GraphCrossNet``.
        """
        embed, logits1, logits2 = self.gxn(graph, node_feat, edge_feat)
        logits = F.relu(self.lin1(embed))
        if self.dropout > 0:
            # Only active in training mode.
            logits = F.dropout(logits, p=self.dropout, training=self.training)
        logits = self.lin2(logits)
        return F.log_softmax(logits, dim=1), logits1, logits2
