from .gnnconv import GCNLayer, untrainedGCNLayer, untrainedGCONVLayer
from torch import nn
import torch.nn.functional as F
import torch
import dgl
import dgl.function as fn
from dgl.base import DGLError

class GCN(nn.Module):
    """Stack of ``GCNLayer`` modules that maps node features to class logits.

    Layer widths are ``[args.d_data] + args.GCN_args['h_dims'] + [args.n_cls]``.
    Every layer except the last is followed by ReLU and dropout with
    probability ``args.GCN_args['dropout']``.
    """

    def __init__(self, args):
        super(GCN, self).__init__()
        widths = [args.d_data] + args.GCN_args['h_dims'] + [args.n_cls]
        self.dropout = args.GCN_args['dropout']
        self.gnn_layers = nn.ModuleList()
        # Consecutive width pairs are the (input, output) sizes of each layer.
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            self.gnn_layers.append(GCNLayer(d_in, d_out))

    def _relu_dropout(self, hidden):
        """Apply ReLU, then dropout (active only while ``self.training``)."""
        return F.dropout(F.relu(hidden), p=self.dropout, training=self.training)

    def forward_aux(self, features):
        """Graph-free pass that uses only each layer's ``linear`` sub-module.

        Args:
            features: tensor fed to the first layer's ``linear``.

        Returns:
            Output of the last layer's ``linear`` (no activation applied).
        """
        hidden = features
        for layer in list(self.gnn_layers)[:-1]:
            hidden = self._relu_dropout(layer.linear(hidden))
        return self.gnn_layers[-1].linear(hidden)

    def forward(self, blocks, twp=False):
        """Run the network on a sequence of blocks, or on a raw feature tensor.

        Args:
            blocks: either a ``torch.Tensor`` (dispatched to ``forward_aux``)
                or an indexable sequence of blocks; ``blocks[0].srcdata['feat']``
                supplies the input features and layer ``i`` receives
                ``blocks[i]`` (the last layer receives ``blocks[-1]``).
            twp: forwarded to every layer except the last; also selects the
                return shape.

        Returns:
            ``logits`` when ``twp`` is falsy, otherwise ``(logits, e_list)``
            where ``e_list`` holds the second value returned by each layer
            (presumably per-edge scores used by TWP; confirm against
            ``GCNLayer``).
        """
        if isinstance(blocks, torch.Tensor):
            return self.forward_aux(blocks)

        hidden = blocks[0].srcdata['feat']
        extras = []
        for depth, layer in enumerate(list(self.gnn_layers)[:-1]):
            hidden, extra = layer.forward(blocks[depth], hidden, twp)
            extras.append(extra)
            hidden = self._relu_dropout(hidden)

        # NOTE(review): the output layer is called without `twp`, so it runs
        # with whatever default GCNLayer.forward declares.
        logits, extra = self.gnn_layers[-1].forward(blocks[-1], hidden)
        extras.append(extra)
        return (logits, extras) if twp else logits

    def reset_params(self):
        """Call ``reset_parameters()`` on every layer, in order."""
        for gnn_layer in self.gnn_layers:
            gnn_layer.reset_parameters()

class U_GCN(nn.Module):
    """Stack of ``untrainedGCNLayer`` modules used as a feature extractor.

    Layer widths are ``[args.d_data] + args.backbone_args['h_dims']`` and
    ``args.gain`` is forwarded unchanged to every layer's constructor.
    ``forward`` returns, for the seed nodes of the mini-batch, the tanh
    output of every layer concatenated along the feature dimension.
    """

    def __init__(self, args):
        super(U_GCN, self).__init__()
        dims = [args.d_data] + args.backbone_args['h_dims']
        self.gnn_layers = nn.ModuleList()
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.gnn_layers.append(untrainedGCNLayer(d_in, d_out, args.gain))

    @staticmethod
    def _seed_rows(dst_nodes, seed_nodes):
        """Locate every seed ID inside ``dst_nodes``.

        Args:
            dst_nodes: 1-D tensor of node IDs.
            seed_nodes: 1-D tensor of node IDs to look up.

        Returns:
            1-D int64 tensor of positions into ``dst_nodes``, grouped by seed
            in ``seed_nodes`` order and ascending within a seed. Seeds that
            do not occur in ``dst_nodes`` contribute nothing; an empty
            ``seed_nodes`` yields an empty result.
        """
        # A plain `==` comparison promotes mixed integer dtypes, while
        # searchsorted requires both operands to share one dtype.
        common = torch.promote_types(dst_nodes.dtype, seed_nodes.dtype)
        dst_nodes = dst_nodes.to(common)
        seed_nodes = seed_nodes.to(common)

        # Stable sort keeps equal IDs in ascending original position.
        sorted_dst, order = torch.sort(dst_nodes, stable=True)
        first = torch.searchsorted(sorted_dst, seed_nodes)
        past_last = torch.searchsorted(sorted_dst, seed_nodes, right=True)
        counts = past_last - first  # number of matches per seed

        # Expand each seed's [first, past_last) range into explicit positions.
        starts = torch.repeat_interleave(first, counts)
        group_start = torch.cumsum(counts, dim=0) - counts
        within = (
            torch.arange(starts.numel(), device=starts.device)
            - torch.repeat_interleave(group_start, counts)
        )
        return order[starts + within]

    def forward(self, blocks):
        """Embed the seed nodes of a mini-batch.

        Args:
            blocks: indexable sequence of DGL blocks, one per layer.
                ``blocks[0].srcdata['feat']`` supplies the input features and
                ``blocks[-1].dstdata[dgl.NID]`` defines the seed nodes.

        Returns:
            Tensor formed by concatenating, along ``dim=1``, each layer's
            tanh-activated output restricted to the seed nodes.
        """
        seed_nodes = blocks[-1].dstdata[dgl.NID]
        h_list = []
        h = blocks[0].srcdata['feat']
        for i, layer in enumerate(self.gnn_layers):
            block = blocks[i]
            h = torch.tanh(layer(block, h))
            # Assumes row j of `h` belongs to dst node j of this block
            # (TODO confirm against untrainedGCNLayer).
            rows = self._seed_rows(block.dstdata[dgl.NID], seed_nodes)
            h_list.append(h[rows])
        return torch.cat(h_list, dim=1)

class U_GCONV(nn.Module):
    """Stack of ``untrainedGCONVLayer`` modules used as a feature extractor.

    Layer widths are ``[args.d_data] + args.backbone_args['h_dims']`` and
    ``args.gain`` is forwarded unchanged to every layer's constructor.
    ``forward`` returns, for the seed nodes of the mini-batch, the tanh
    output of every layer concatenated along the feature dimension.
    """

    def __init__(self, args):
        super(U_GCONV, self).__init__()
        dims = [args.d_data] + args.backbone_args['h_dims']
        self.gnn_layers = nn.ModuleList()
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.gnn_layers.append(untrainedGCONVLayer(d_in, d_out, args.gain))

    @staticmethod
    def _seed_rows(dst_nodes, seed_nodes):
        """Locate every seed ID inside ``dst_nodes``.

        Args:
            dst_nodes: 1-D tensor of node IDs.
            seed_nodes: 1-D tensor of node IDs to look up.

        Returns:
            1-D int64 tensor of positions into ``dst_nodes``, grouped by seed
            in ``seed_nodes`` order and ascending within a seed. Seeds that
            do not occur in ``dst_nodes`` contribute nothing; an empty
            ``seed_nodes`` yields an empty result.
        """
        # A plain `==` comparison promotes mixed integer dtypes, while
        # searchsorted requires both operands to share one dtype.
        common = torch.promote_types(dst_nodes.dtype, seed_nodes.dtype)
        dst_nodes = dst_nodes.to(common)
        seed_nodes = seed_nodes.to(common)

        # Stable sort keeps equal IDs in ascending original position.
        sorted_dst, order = torch.sort(dst_nodes, stable=True)
        first = torch.searchsorted(sorted_dst, seed_nodes)
        past_last = torch.searchsorted(sorted_dst, seed_nodes, right=True)
        counts = past_last - first  # number of matches per seed

        # Expand each seed's [first, past_last) range into explicit positions.
        starts = torch.repeat_interleave(first, counts)
        group_start = torch.cumsum(counts, dim=0) - counts
        within = (
            torch.arange(starts.numel(), device=starts.device)
            - torch.repeat_interleave(group_start, counts)
        )
        return order[starts + within]

    def forward(self, blocks):
        """Embed the seed nodes of a mini-batch.

        Args:
            blocks: indexable sequence of DGL blocks, one per layer.
                ``blocks[0].srcdata['feat']`` supplies the input features and
                ``blocks[-1].dstdata[dgl.NID]`` defines the seed nodes.

        Returns:
            Tensor formed by concatenating, along ``dim=1``, each layer's
            tanh-activated ``forward_batch`` output restricted to the seed
            nodes.
        """
        seed_nodes = blocks[-1].dstdata[dgl.NID]
        h_list = []
        h = blocks[0].srcdata['feat']
        for i, layer in enumerate(self.gnn_layers):
            block = blocks[i]
            h = torch.tanh(layer.forward_batch(block, h))
            # Assumes row j of `h` belongs to dst node j of this block
            # (TODO confirm against untrainedGCONVLayer.forward_batch).
            rows = self._seed_rows(block.dstdata[dgl.NID], seed_nodes)
            h_list.append(h[rows])
        return torch.cat(h_list, dim=1)

class SGC(nn.Module):
    """Weight-free feature propagation over a sequence of DGL blocks.

    Each block contributes one hop of degree-normalised sum aggregation.
    The module defines no weights itself; an optional ``norm`` callable is
    applied to the final features.

    Args:
        args: configuration object; ``args.SGC_args['k']`` is the number of
            hops and must equal the number of blocks given to ``forward``.
        norm: optional callable applied to the propagated features.
    """

    def __init__(self, args, norm=None):
        super().__init__()
        self._k = args.SGC_args['k']
        self.norm = norm

    @staticmethod
    def _inv_sqrt_degree(degrees, device):
        """Return ``clamp(degrees, min=1) ** -0.5`` as a column on ``device``.

        Clamping to 1 keeps zero-degree nodes from producing ``inf``.
        """
        scale = torch.pow(degrees.float().clamp(min=1), -0.5)
        return scale.to(device).unsqueeze(1)

    def forward(self, blocks):
        """Propagate ``blocks[0].srcdata['feat']`` through every block.

        Args:
            blocks: sequence of DGL blocks whose length equals ``k``.

        Returns:
            Features of the last block's destination nodes, passed through
            ``norm`` when one was supplied.

        Raises:
            DGLError: if ``len(blocks)`` differs from ``k``.
        """
        # Validate before touching blocks[0] so a wrong depth (including an
        # empty list) reports the intended error rather than an IndexError.
        if self._k != len(blocks):
            raise DGLError('The depth of the dataloader sampler is incompatible with the depth of SGC')

        feat = blocks[0].srcdata['feat']
        for block in blocks:
            # local_scope keeps the temporary 'h' field out of the caller's block.
            with block.local_scope():
                # One hop: D_in^{-1/2} A D_out^{-1/2} X
                feat = feat * self._inv_sqrt_degree(block.out_degrees(), feat.device)
                block.srcdata['h'] = feat
                block.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                feat = block.dstdata.pop('h')
                feat = feat * self._inv_sqrt_degree(block.in_degrees(), feat.device)

        if self.norm is not None:
            feat = self.norm(feat)

        return feat
    