import __init__
import torch
from gcn_lib.sparse.torch_vertex import GENConv,SemiGCNConv,SGConv
from gcn_lib.sparse.torch_nn import norm_layer
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import logging

def filter_edges(edge_index, y, train_idx):
    """Remove edges that join two training nodes with different labels.

    An edge is discarded only when BOTH endpoints are in ``train_idx`` AND
    their labels differ; every other edge is kept, in its original order.

    Args:
        edge_index: integer edge list whose row 0 holds source ids and row 1
            destination ids (presumably shape (2, num_edges) — confirm).
        y: per-node labels indexable by node id; ``y.size(0)`` is taken as the
            node count. The ``squeeze`` below suggests (N,) or (N, 1) labels.
        train_idx: indices (or mask) selecting the training nodes.

    Returns:
        The kept columns of ``edge_index``, on ``edge_index``'s device.
    """
    dev = edge_index.device
    in_train = torch.zeros(y.size(0), dtype=torch.bool, device=dev)
    in_train[train_idx] = True

    src, dst = edge_index[0], edge_index[1]
    both_in_train = in_train[src] & in_train[dst]
    same_label = (y[src] == y[dst]).squeeze()

    # De Morgan of "not (both in train and labels differ)".
    keep = ~both_in_train | same_label
    return edge_index[:, keep].to(dev)


def dropedge(edge_index, scale):
    """Randomly remove a fraction of the edges (DropEdge-style).

    Args:
        edge_index: edge list whose columns are edges (presumably shape
            (2, num_edges) — confirm against callers).
        scale: fraction of edges to drop. ``int(num_edges * scale)`` edges are
            removed; the count is clamped to ``[0, num_edges]`` so a negative
            scale drops nothing and a scale above 1 drops everything.

    Returns:
        The surviving columns of ``edge_index``, in their original order.
    """
    num_edges = edge_index.size(1)
    # Clamp: a negative count would otherwise turn the slice below into
    # "all but the last k" and drop nearly every edge.
    num_drop = min(max(int(num_edges * scale), 0), num_edges)
    device = edge_index.device
    # The permutation is drawn on CPU (same random stream as before) and only
    # the index/mask live on edge_index's device, so no cross-device indexing.
    drop_indices = torch.randperm(num_edges)[:num_drop].to(device)
    mask = torch.ones(num_edges, dtype=torch.bool, device=device)
    mask[drop_indices] = False
    return edge_index[:, mask]


class DeeperGCN(torch.nn.Module):
    """Deep graph-convolution node classifier with configurable blocks/norms.

    Node features are first projected to ``hidden_channels`` by a linear
    encoder; the stack of convolutions in ``self.gcns`` then operates on that
    hidden representation. Except for the 'sgc_plain' block, a final linear
    head maps to ``num_tasks`` scores, and ``forward`` returns their
    ``log_softmax`` over the last dimension.

    ``args`` is an argparse-style namespace; the attributes read are the ones
    accessed in ``__init__``.
    """

    def __init__(self, args):
        super(DeeperGCN, self).__init__()

        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.block = args.block
        # Only consumed by the SGC convolution. It used to be commented out
        # while still being read below, which made conv='sgc' crash.
        # NOTE(review): None is forwarded to SGConv when args has no 'mode'.
        self.mode = getattr(args, 'mode', None)
        # Shared hyper-parameter: residual blend weight ('res' block and the
        # 'label'/'sign' norms of the 'plain' block), edge drop rate for
        # norm='drop', and an argument to norm_layer / SGConv.
        self.scale = args.scale
        self.checkpoint_grad = False

        in_channels = args.in_channels
        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        conv = args.conv
        aggr = args.gcn_aggr

        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        y = args.y
        self.learn_y = args.learn_y

        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        self.norm = args.norm  # normalization variant, dispatched in forward()
        mlp_layers = args.mlp_layers

        # Deep stacks with these aggregators use gradient checkpointing on
        # every layer except multiples of ckp_k (see _forward_res_plus).
        if aggr in ['softmax_sg', 'softmax', 'power'] and self.num_layers > 7:
            self.checkpoint_grad = True
            self.ckp_k = self.num_layers // 2

        print('The number of layers {}'.format(self.num_layers),
              'Aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))

        if self.block == 'res+':
            print('LN/BN->ReLU->GraphConv->Res')
        elif self.block == 'res':
            print('GraphConv->LN/BN->ReLU->Res')
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == "plain":
            print('GraphConv->LN/BN->ReLU')
        elif self.block == "sgc_plain":
            print('GraphConv')
        else:
            raise Exception('Unknown block Type')

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        self.node_features_encoder = torch.nn.Linear(in_channels, hidden_channels)
        self.node_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)

        for layer in range(self.num_layers):
            if conv == 'gen':
                gcn = GENConv(hidden_channels, hidden_channels,
                              aggr=aggr,
                              t=t, learn_t=self.learn_t,
                              p=p, learn_p=self.learn_p,
                              y=y, learn_y=self.learn_y,
                              msg_norm=self.msg_norm, learn_msg_scale=learn_msg_scale,
                              norm=self.norm, mlp_layers=mlp_layers)
            elif conv == 'gcn':
                # Every conv receives the encoder output (hidden_channels
                # wide), never the raw in_channels features.
                gcn = SemiGCNConv(hidden_channels, hidden_channels)
            elif conv == 'sgc':
                # Same input width as above. NOTE(review): only gcns[0] is used
                # by the 'sgc_plain' block; the other copies are kept so that
                # state_dict keys stay unchanged.
                gcn = SGConv(hidden_channels, num_tasks, K=self.num_layers,
                             mode=self.mode, scale=self.scale)
            else:
                raise Exception('Unknown Conv Type')

            self.gcns.append(gcn)
            # SGC and the DropEdge variant build no per-layer norm layers.
            if conv != 'sgc' and self.norm != 'drop':
                self.norms.append(norm_layer(self.norm, hidden_channels, self.scale))

    def forward(self, x, edge_index, y=None, train_idx=None):
        """Encode ``x``, run the configured block, return log-probabilities.

        Args:
            x: node feature matrix fed to the linear encoder.
            edge_index: graph connectivity passed to every convolution.
            y: node labels; required by norm='label' and forwarded to SGConv
                in the 'sgc_plain' block.
            train_idx: training node indices; same consumers as ``y``.

        Returns:
            ``log_softmax`` over the last dimension of the class scores.

        Raises:
            NotImplementedError: for block 'dense'.
            ValueError: norm='label' without ``y``/``train_idx``.
        """
        h = self.node_features_encoder(x)
        h0 = h  # encoder output, re-used by the 'sign' norm
        if self.block == 'res+':
            h = self._forward_res_plus(h, edge_index)
        elif self.block == 'res':
            h = self._forward_res(h, edge_index)
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == 'plain':
            h = self._forward_plain(h, h0, edge_index, y, train_idx)
        elif self.block == 'sgc_plain':
            # The SGC layer maps straight to class scores (no prediction head).
            h = self.gcns[0](h, edge_index, y, train_idx)
            h = F.dropout(h, p=self.dropout, training=self.training)
        else:
            raise Exception('Unknown block Type')

        if self.block != 'sgc_plain':
            h = self.node_pred_linear(h)

        return torch.log_softmax(h, dim=-1)

    def _forward_res_plus(self, h, edge_index):
        """'res+' block: norm -> ReLU -> dropout -> conv -> residual add."""
        h = self.gcns[0](h, edge_index)
        for layer in range(1, self.num_layers):
            h1 = self.norms[layer - 1](h)
            h2 = F.relu(h1)
            h2 = F.dropout(h2, p=self.dropout, training=self.training)
            if self.checkpoint_grad and layer % self.ckp_k != 0:
                # Recompute this layer during backward to save memory.
                h = checkpoint(self.gcns[layer], h2, edge_index) + h
            else:
                h = self.gcns[layer](h2, edge_index) + h
        h = F.relu(self.norms[self.num_layers - 1](h))
        return F.dropout(h, p=self.dropout, training=self.training)

    def _forward_res(self, h, edge_index):
        """'res' block: conv -> norm -> ReLU, blended with previous h by scale."""
        h = self.gcns[0](h, edge_index)
        if self.norm != 'none':
            h = self.norms[0](h)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)

        for layer in range(1, self.num_layers):
            h1 = self.gcns[layer](h, edge_index)
            h2 = h1 if self.norm == 'none' else self.norms[layer](h1)
            h = (1 - self.scale) * F.relu(h2) + self.scale * h
            h = F.dropout(h, p=self.dropout, training=self.training)
        return h

    def _plain_norm(self, layer, h, edge_index, h0):
        """Apply the 'plain' block's normalization for ``layer`` per self.norm."""
        if self.norm in ('none', 'drop'):
            # 'drop' builds no norm layers; 'none' skips them.
            return h
        if self.norm == 'contra':
            return self.norms[layer](h, edge_index)
        if self.norm == 'sign':
            return self.norms[layer](h, h0)
        # 'label' and the standard norm layers take only the features.
        return self.norms[layer](h)

    def _forward_plain(self, h, h0, edge_index, y, train_idx):
        """'plain' block: conv -> norm -> ReLU -> dropout, with per-norm extras.

        norm='drop' drops edges once per forward pass, only while training.
        norm='label' filters edges with ``filter_edges`` after the first layer.
        norm='label'/'sign' blend the pre-norm conv output back in by scale.
        """
        if self.norm == 'drop' and self.training:
            # DropEdge is a training-time regularizer; evaluation keeps the
            # full graph so predictions are deterministic.
            edge_index = dropedge(edge_index, self.scale)

        h = F.relu(self._plain_norm(0, self.gcns[0](h, edge_index), edge_index, h0))
        if self.norm == 'label':
            if y is None or train_idx is None:
                raise ValueError("norm='label' requires y and train_idx in forward()")
            edge_index = filter_edges(edge_index, y, train_idx)
        h = F.dropout(h, p=self.dropout, training=self.training)

        for layer in range(1, self.num_layers):
            h1 = self.gcns[layer](h, edge_index)
            # Previously `self.norm == 'none' or 'drop'` was always truthy,
            # so layers >= 1 silently skipped their norm.
            h = F.relu(self._plain_norm(layer, h1, edge_index, h0))
            h = F.dropout(h, p=self.dropout, training=self.training)
            if self.norm in ('label', 'sign'):
                h = (1 - self.scale) * h1 + self.scale * h
        return h

    def print_params(self, epoch=None, final=False):
        """Report learnable per-layer aggregation parameters of every conv.

        Values are printed when ``final`` is true, otherwise logged at INFO
        level tagged with ``epoch``. Assumes the convs expose the accessed
        attributes (as GENConv-style layers presumably do — confirm).
        """
        if self.learn_t:
            self._report_param('t', [gcn.t.item() for gcn in self.gcns], epoch, final)
        if self.learn_p:
            self._report_param('p', [gcn.p.item() for gcn in self.gcns], epoch, final)
        if self.learn_y:
            self._report_param('sigmoid(y)',
                               [gcn.sigmoid_y.item() for gcn in self.gcns],
                               epoch, final)
        if self.msg_norm:
            self._report_param('s',
                               [gcn.msg_norm.msg_scale.item() for gcn in self.gcns],
                               epoch, final)

    @staticmethod
    def _report_param(name, values, epoch, final):
        """Print (final) or log (per-epoch) one named list of parameter values."""
        if final:
            print('Final {} {}'.format(name, values))
        else:
            logging.info('Epoch {}, {} {}'.format(epoch, name, values))

