import __init__
import time
import torch
import torch.nn as nn
from gcn_lib.sparse.torch_nn import norm_layer
import torch.nn.functional as F
import logging
import eff_gcn_modules.rev.memgcn as memgcn
from eff_gcn_modules.rev.rev_layer import GENBlock, GENNonlinearBlock
import copy

from ipdb import set_trace as stc


class RevGCN(torch.nn.Module):
    """Reversible grouped GEN-convolution network for node-level prediction.

    Node inputs are rows of a feature tensor loaded from ``args.nf_path``
    (optionally concatenated with ``node_one_hot_encoder(x)``), projected to
    ``hidden_channels`` and passed through ``args.num_layers`` invertible
    layers. Each layer is a ``memgcn.GroupAdditiveCoupling`` over
    ``args.group`` ``GENBlock`` copies, wrapped in
    ``memgcn.InvertibleModuleWrapper(keep_input=False)``.

    ``forward`` prints coarse per-stage wall-clock timings to stdout.
    """

    def __init__(self, args):
        super(RevGCN, self).__init__()

        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.group = args.group

        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        aggr = args.gcn_aggr

        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        y = args.y
        self.learn_y = args.learn_y

        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        conv_encode_edge = args.conv_encode_edge
        norm = args.norm
        mlp_layers = args.mlp_layers
        node_features_file_path = args.nf_path

        self.use_one_hot_encoding = args.use_one_hot_encoding

        self.gcns = torch.nn.ModuleList()
        self.last_norm = norm_layer(norm, hidden_channels)

        for _ in range(self.num_layers):
            # Every group member operates on a hidden_channels // group slice.
            fm = GENBlock(hidden_channels // self.group, hidden_channels // self.group,
                          aggr=aggr,
                          t=t, learn_t=self.learn_t,
                          p=p, learn_p=self.learn_p,
                          y=y, learn_y=self.learn_y,
                          msg_norm=self.msg_norm,
                          learn_msg_scale=learn_msg_scale,
                          encode_edge=conv_encode_edge,
                          edge_feat_dim=hidden_channels,
                          norm=norm, mlp_layers=mlp_layers)
            # The other members are deep copies: identical initial weights,
            # but separate parameters.
            Fms = nn.ModuleList(
                [fm if i == 0 else copy.deepcopy(fm) for i in range(self.group)])

            invertible_module = memgcn.GroupAdditiveCoupling(Fms, group=self.group)
            gcn = memgcn.InvertibleModuleWrapper(fn=invertible_module, keep_input=False)
            self.gcns.append(gcn)

        # Plain attribute (not a registered buffer): it is not part of
        # state_dict() and is not moved by Module.to().
        self.node_features = torch.load(node_features_file_path).to(f'cuda:{args.device}')

        if self.use_one_hot_encoding:
            self.node_one_hot_encoder = torch.nn.Linear(8, 8)
            self.node_features_encoder = torch.nn.Linear(8 * 2, hidden_channels)
        else:
            self.node_features_encoder = torch.nn.Linear(8, hidden_channels)

        self.edge_encoder = torch.nn.Linear(8, hidden_channels)

        self.node_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)

    @staticmethod
    def _wait_for_device(tensor):
        """Block until queued CUDA work has finished; no-op for CPU tensors.

        CUDA kernels are launched asynchronously, so reading a wall-clock timer
        without synchronizing measures launch overhead, not execution time.
        """
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    @staticmethod
    def _begin_stage(label, tensor):
        """Drain pending device work, start a timer and print ``label``."""
        RevGCN._wait_for_device(tensor)
        start = time.perf_counter()
        print(label, flush=True, end=' ')
        return start

    @staticmethod
    def _end_stage(start, tensor):
        """Wait for the stage's device work and print its elapsed time."""
        RevGCN._wait_for_device(tensor)
        print(f'Done! [{time.perf_counter() - start:.2f}s]')

    def forward(self, x, node_index, edge_index, edge_attr, epoch=-1):
        """Predict per-node outputs for the nodes selected by ``node_index``.

        Args:
            x: input to ``node_one_hot_encoder``; only read when
                ``use_one_hot_encoding`` is enabled.
            node_index: indices into the stored node-feature tensor.
            edge_index: graph connectivity, passed unchanged to every layer.
            edge_attr: raw edge features fed to ``edge_encoder``.
            epoch: unused; kept for call-site compatibility.

        Returns:
            ``(logits, None, None)``; the two ``None`` slots keep the return
            arity of the jumping-knowledge variants in this file.
        """
        # Part 1: input/edge encoding and the shared dropout mask.
        start = self._begin_stage('Part 1 begin...', self.node_features)
        node_features = self.node_features[node_index]
        if self.use_one_hot_encoding:
            node_features = torch.cat(
                (node_features, self.node_one_hot_encoder(x)), dim=1)
        h = self.node_features_encoder(node_features)

        # Each group member receives the full edge embedding.
        edge_emb = torch.cat([self.edge_encoder(edge_attr)] * self.group, dim=-1)

        # One inverted-dropout keep-mask, drawn once per forward pass and
        # handed to every layer. Whether the layers apply it outside training
        # is decided inside them (not visible here).
        keep_prob = 1 - self.dropout
        mask = torch.zeros_like(h).bernoulli_(keep_prob).requires_grad_(False) / keep_prob
        self._end_stage(start, h)

        # Part 2: the stack of invertible GCN layers.
        start = self._begin_stage('Part 2 begin...', h)
        for gcn in self.gcns:
            h = gcn(h, edge_index, mask, edge_emb)
        self._end_stage(start, h)

        # Part 3: final norm, activation and dropout.
        start = self._begin_stage('Part 3 begin...', h)
        h = F.relu(self.last_norm(h))
        h = F.dropout(h, p=self.dropout, training=self.training)
        self._end_stage(start, h)

        return self.node_pred_linear(h), None, None

    def print_params(self, epoch=None, final=False):
        """Report the learnable aggregation parameters of every layer.

        For each enabled option (``learn_t``, ``learn_p``, ``learn_y`` and
        ``msg_norm``), the per-layer values are collected with ``.item()``.
        They are then printed (``final=True``) or logged at INFO level,
        tagged with ``epoch``.

        NOTE(review): ``self.gcns`` holds ``memgcn.InvertibleModuleWrapper``
        objects. This assumes they expose ``t`` / ``p`` / ``sigmoid_y`` /
        ``msg_norm`` directly. Confirm against memgcn; otherwise this raises
        AttributeError.
        """
        reports = (
            (self.learn_t, lambda gcn: gcn.t, 'Final t {}', 'Epoch {}, t {}'),
            (self.learn_p, lambda gcn: gcn.p, 'Final p {}', 'Epoch {}, p {}'),
            (self.learn_y, lambda gcn: gcn.sigmoid_y,
             'Final sigmoid(y) {}', 'Epoch {}, sigmoid(y) {}'),
            (self.msg_norm, lambda gcn: gcn.msg_norm.msg_scale,
             'Final s {}', 'Epoch {}, s {}'),
        )
        for enabled, read, final_fmt, epoch_fmt in reports:
            if not enabled:
                continue
            values = [read(gcn).item() for gcn in self.gcns]
            if final:
                print(final_fmt.format(values))
            else:
                logging.info(epoch_fmt.format(epoch, values))


class RevGCNpyg(torch.nn.Module):
    """Reversible grouped GEN-convolution network fed with node features directly.

    Same layer stack as ``RevGCN``, but ``forward`` takes the node-feature
    matrix as ``x`` instead of looking rows up by ``node_index``. It also
    prints no timings.

    NOTE(review): with ``use_one_hot_encoding`` the node encoder expects
    ``8 * 2`` input features, so ``x`` must already be that wide. Confirm
    with the caller.
    """

    def __init__(self, args):
        super().__init__()

        # Settings read later by forward() / print_params().
        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.group = args.group
        self.learn_t = args.learn_t
        self.learn_p = args.learn_p
        self.learn_y = args.learn_y
        self.msg_norm = args.msg_norm
        self.use_one_hot_encoding = args.use_one_hot_encoding

        hidden = args.hidden_channels
        # Keyword arguments shared by every GENBlock in the stack.
        block_options = dict(
            aggr=args.gcn_aggr,
            t=args.t, learn_t=self.learn_t,
            p=args.p, learn_p=self.learn_p,
            y=args.y, learn_y=self.learn_y,
            msg_norm=self.msg_norm,
            learn_msg_scale=args.learn_msg_scale,
            encode_edge=args.conv_encode_edge,
            edge_feat_dim=hidden,
            norm=args.norm, mlp_layers=args.mlp_layers,
        )

        self.gcns = torch.nn.ModuleList()
        self.last_norm = norm_layer(args.norm, hidden)

        for _ in range(self.num_layers):
            width = hidden // self.group
            prototype = GENBlock(width, width, **block_options)
            # First member is the prototype itself; the rest are deep copies.
            members = nn.ModuleList(
                [prototype if i == 0 else copy.deepcopy(prototype)
                 for i in range(self.group)])
            coupling = memgcn.GroupAdditiveCoupling(members, group=self.group)
            self.gcns.append(
                memgcn.InvertibleModuleWrapper(fn=coupling, keep_input=False))

        # NOTE(review): loaded (on CPU) but never read by forward(), which
        # takes its features from ``x``.
        self.node_features = torch.load(args.nf_path)

        if not self.use_one_hot_encoding:
            self.node_features_encoder = torch.nn.Linear(8, hidden)
        else:
            self.node_one_hot_encoder = torch.nn.Linear(8, 8)
            self.node_features_encoder = torch.nn.Linear(8 * 2, hidden)

        self.edge_encoder = torch.nn.Linear(8, hidden)
        self.node_pred_linear = torch.nn.Linear(hidden, args.num_tasks)

    def forward(self, x, edge_index, edge_attr, epoch=-1):
        """Predict per-node outputs from the node-feature matrix ``x``.

        Args:
            x: node features, passed straight to ``node_features_encoder``.
            edge_index: graph connectivity, passed unchanged to every layer.
            edge_attr: raw edge features fed to ``edge_encoder``.
            epoch: unused; kept for call-site compatibility.

        Returns:
            ``(logits, None, None)``.
        """
        hidden_state = self.node_features_encoder(x)

        # Replicate the edge embedding once per group member.
        edge_emb = torch.cat([self.edge_encoder(edge_attr)] * self.group, dim=-1)

        # Inverted-dropout keep-mask shared by every layer in this pass.
        keep_prob = 1 - self.dropout
        mask = torch.zeros_like(hidden_state).bernoulli_(keep_prob).requires_grad_(False) / keep_prob

        hidden_state = self.gcns[0](hidden_state, edge_index, mask, edge_emb)
        for gcn in self.gcns[1:self.num_layers]:
            hidden_state = gcn(hidden_state, edge_index, mask, edge_emb)

        hidden_state = F.relu(self.last_norm(hidden_state))
        hidden_state = F.dropout(hidden_state, p=self.dropout, training=self.training)

        return self.node_pred_linear(hidden_state), None, None

    def print_params(self, epoch=None, final=False):
        """Report the learnable aggregation parameters of every layer.

        For each enabled option (``learn_t``, ``learn_p``, ``learn_y`` and
        ``msg_norm``), the per-layer values are collected with ``.item()``.
        They are then printed (``final=True``) or logged at INFO level,
        tagged with ``epoch``.

        NOTE(review): assumes the ``InvertibleModuleWrapper`` entries of
        ``self.gcns`` expose ``t`` / ``p`` / ``sigmoid_y`` / ``msg_norm``
        directly. Confirm against memgcn.
        """
        reports = (
            (self.learn_t, lambda gcn: gcn.t, 'Final t {}', 'Epoch {}, t {}'),
            (self.learn_p, lambda gcn: gcn.p, 'Final p {}', 'Epoch {}, p {}'),
            (self.learn_y, lambda gcn: gcn.sigmoid_y,
             'Final sigmoid(y) {}', 'Epoch {}, sigmoid(y) {}'),
            (self.msg_norm, lambda gcn: gcn.msg_norm.msg_scale,
             'Final s {}', 'Epoch {}, s {}'),
        )
        for enabled, read, final_fmt, epoch_fmt in reports:
            if not enabled:
                continue
            values = [read(gcn).item() for gcn in self.gcns]
            if final:
                print(final_fmt.format(values))
            else:
                logging.info(epoch_fmt.format(epoch, values))




class JKRevGCN(torch.nn.Module):
    """Reversible grouped GEN-convolution network with a jumping-knowledge head.

    The encoded input and the output of every invertible layer are
    concatenated and projected to ``teacher_channels`` by ``W_jk``. That
    representation is normalized, activated and classified. ``forward``
    returns the logits together with the representation before and after
    the norm/ReLU/dropout step. It prints coarse per-stage wall-clock
    timings to stdout.
    """

    def __init__(self, args):
        super().__init__()

        self.num_layers = args.num_layers

        # Per-layer weights for a weighted layer sum. forward() does not read
        # them (it concatenates layer outputs instead); they stay registered,
        # presumably so existing checkpoints keep loading.
        self.weight = torch.nn.Parameter(torch.randn(self.num_layers))

        self.dropout = args.dropout
        self.group = args.group

        hidden_channels = args.hidden_channels
        teacher_channels = args.teacher_channels
        num_tasks = args.num_tasks
        aggr = args.gcn_aggr

        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        y = args.y
        self.learn_y = args.learn_y

        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale

        conv_encode_edge = args.conv_encode_edge
        norm = args.norm
        mlp_layers = args.mlp_layers
        node_features_file_path = args.nf_path

        self.use_one_hot_encoding = args.use_one_hot_encoding

        self.gcns = torch.nn.ModuleList()
        self.last_norm = norm_layer(norm, teacher_channels)

        for _ in range(self.num_layers):
            # Every group member operates on a hidden_channels // group slice.
            fm = GENBlock(hidden_channels // self.group, hidden_channels // self.group,
                          aggr=aggr,
                          t=t, learn_t=self.learn_t,
                          p=p, learn_p=self.learn_p,
                          y=y, learn_y=self.learn_y,
                          msg_norm=self.msg_norm,
                          learn_msg_scale=learn_msg_scale,
                          encode_edge=conv_encode_edge,
                          edge_feat_dim=hidden_channels,
                          norm=norm, mlp_layers=mlp_layers)
            # The other members are deep copies: identical initial weights,
            # but separate parameters.
            Fms = nn.ModuleList(
                [fm if i == 0 else copy.deepcopy(fm) for i in range(self.group)])

            invertible_module = memgcn.GroupAdditiveCoupling(Fms, group=self.group)
            gcn = memgcn.InvertibleModuleWrapper(fn=invertible_module, keep_input=False)
            self.gcns.append(gcn)

        # Plain attribute (not a registered buffer): it is not part of
        # state_dict() and is not moved by Module.to().
        self.node_features = torch.load(node_features_file_path).to(args.device)

        if self.use_one_hot_encoding:
            self.node_one_hot_encoder = torch.nn.Linear(8, 8)
            self.node_features_encoder = torch.nn.Linear(8 * 2, hidden_channels)
        else:
            self.node_features_encoder = torch.nn.Linear(8, hidden_channels)

        self.edge_encoder = torch.nn.Linear(8, hidden_channels)

        # Not used by forward(), which predicts through W_jk + classifier.
        self.node_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)
        # Input is the encoded features plus one output per layer.
        self.W_jk = torch.nn.Linear(hidden_channels * (self.num_layers + 1), teacher_channels)
        self.classifier = torch.nn.Linear(teacher_channels, num_tasks)

    @staticmethod
    def _wait_for_device(tensor):
        """Block until queued CUDA work has finished; no-op for CPU tensors.

        CUDA kernels are launched asynchronously, so reading a wall-clock timer
        without synchronizing measures launch overhead, not execution time.
        """
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    @staticmethod
    def _begin_stage(label, tensor):
        """Drain pending device work, start a timer and print ``label``."""
        JKRevGCN._wait_for_device(tensor)
        start = time.perf_counter()
        print(label, flush=True, end=' ')
        return start

    @staticmethod
    def _end_stage(start, tensor):
        """Wait for the stage's device work and print its elapsed time."""
        JKRevGCN._wait_for_device(tensor)
        print(f'Done! [{time.perf_counter() - start:.2f}s]')

    def forward(self, x, node_index, edge_index, edge_attr, epoch=-1):
        """Predict per-node outputs for the nodes selected by ``node_index``.

        Args:
            x: input to ``node_one_hot_encoder``; only read when
                ``use_one_hot_encoding`` is enabled.
            node_index: indices into the stored node-feature tensor.
            edge_index: graph connectivity, passed unchanged to every layer.
            edge_attr: raw edge features fed to ``edge_encoder``.
            epoch: unused; kept for call-site compatibility.

        Returns:
            ``(logits, h_norelu, h_relu)``. ``h_norelu`` is the ``W_jk``
            projection before ``last_norm``. ``h_relu`` is the same
            representation after ``last_norm``, ReLU and dropout.
        """
        # Part 1: input/edge encoding and the shared dropout mask.
        start = self._begin_stage('Part 1 begin...', self.node_features)
        node_features = self.node_features[node_index]
        if self.use_one_hot_encoding:
            node_features = torch.cat(
                (node_features, self.node_one_hot_encoder(x)), dim=1)
        h = self.node_features_encoder(node_features)
        # NOTE(review): outputs are cloned, presumably because
        # InvertibleModuleWrapper(keep_input=False) may release its input's
        # storage after the forward pass. Confirm in memgcn.
        layer_outputs = [h.clone()]

        # Each group member receives the full edge embedding.
        edge_emb = torch.cat([self.edge_encoder(edge_attr)] * self.group, dim=-1)

        # One inverted-dropout keep-mask, drawn once per forward pass and
        # handed to every layer.
        keep_prob = 1 - self.dropout
        mask = torch.zeros_like(h).bernoulli_(keep_prob).requires_grad_(False) / keep_prob
        self._end_stage(start, h)

        # Part 2: the invertible layers, recording every output.
        start = self._begin_stage('Part 2 begin...', h)
        for gcn in self.gcns:
            h = gcn(h, edge_index, mask, edge_emb)
            layer_outputs.append(h.clone())
        self._end_stage(start, h)

        # Part 3: jumping-knowledge concatenation and projection.
        start = self._begin_stage('Part 3 begin...', h)
        h = self.W_jk(torch.cat(layer_outputs, dim=1))
        self._end_stage(start, h)

        # Part 4: norm, activation and dropout, keeping both sides for the caller.
        start = self._begin_stage('Part 4 begin...', h)
        h_norelu = h.clone()
        h = F.relu(self.last_norm(h))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h_relu = h.clone()
        self._end_stage(start, h)

        return self.classifier(h), h_norelu, h_relu

    def print_params(self, epoch=None, final=False):
        """Report the learnable aggregation parameters of every layer.

        For each enabled option (``learn_t``, ``learn_p``, ``learn_y`` and
        ``msg_norm``), the per-layer values are collected with ``.item()``.
        They are then printed (``final=True``) or logged at INFO level,
        tagged with ``epoch``.

        NOTE(review): assumes the ``InvertibleModuleWrapper`` entries of
        ``self.gcns`` expose ``t`` / ``p`` / ``sigmoid_y`` / ``msg_norm``
        directly. Confirm against memgcn.
        """
        reports = (
            (self.learn_t, lambda gcn: gcn.t, 'Final t {}', 'Epoch {}, t {}'),
            (self.learn_p, lambda gcn: gcn.p, 'Final p {}', 'Epoch {}, p {}'),
            (self.learn_y, lambda gcn: gcn.sigmoid_y,
             'Final sigmoid(y) {}', 'Epoch {}, sigmoid(y) {}'),
            (self.msg_norm, lambda gcn: gcn.msg_norm.msg_scale,
             'Final s {}', 'Epoch {}, s {}'),
        )
        for enabled, read, final_fmt, epoch_fmt in reports:
            if not enabled:
                continue
            values = [read(gcn).item() for gcn in self.gcns]
            if final:
                print(final_fmt.format(values))
            else:
                logging.info(epoch_fmt.format(epoch, values))


class JKNonlinearRevGCN(torch.nn.Module):
    """Jumping-knowledge reversible GCN built from ``GENNonlinearBlock`` layers.

    ``num_layers + num_nonlinear`` invertible layers are stacked. The encoded
    input and every layer output are concatenated, projected to
    ``teacher_channels`` by ``W_jk``, then normalized, activated and
    classified.
    """

    def __init__(self, args):
        super().__init__()

        self.num_layers = args.num_layers
        self.num_linear, self.num_nonlinear = args.num_linear, args.num_nonlinear

        # Per-layer weights for a weighted layer sum. forward() does not read
        # them; they remain registered as a parameter.
        self.weight = torch.nn.Parameter(torch.randn(self.num_layers))

        # Settings read later by forward() / print_params().
        self.dropout = args.dropout
        self.group = args.group
        self.learn_t = args.learn_t
        self.learn_p = args.learn_p
        self.learn_y = args.learn_y
        self.msg_norm = args.msg_norm
        self.use_one_hot_encoding = args.use_one_hot_encoding

        hidden = args.hidden_channels
        teacher = args.teacher_channels
        # Keyword arguments shared by every GENNonlinearBlock in the stack.
        block_options = dict(
            aggr=args.gcn_aggr,
            t=args.t, learn_t=self.learn_t,
            p=args.p, learn_p=self.learn_p,
            y=args.y, learn_y=self.learn_y,
            msg_norm=self.msg_norm,
            learn_msg_scale=args.learn_msg_scale,
            encode_edge=args.conv_encode_edge,
            edge_feat_dim=hidden,
            norm=args.norm, mlp_layers=args.mlp_layers,
        )

        self.gcns = torch.nn.ModuleList()
        self.last_norm = norm_layer(args.norm, teacher)

        depth = self.num_layers + self.num_nonlinear
        for index in range(depth):
            width = hidden // self.group
            # The block also receives the stack layout and its own position.
            prototype = GENNonlinearBlock(
                width, width, self.num_layers, self.num_nonlinear, index,
                **block_options,
            )
            members = nn.ModuleList(
                [prototype if i == 0 else copy.deepcopy(prototype)
                 for i in range(self.group)])
            coupling = memgcn.GroupAdditiveCouplingNonlinear(members, group=self.group)
            self.gcns.append(
                memgcn.InvertibleModuleWrapperNonlinear(fn=coupling, keep_input=False))

        # Plain attribute (not a registered buffer).
        self.node_features = torch.load(args.nf_path).to(args.device)

        if not self.use_one_hot_encoding:
            self.node_features_encoder = torch.nn.Linear(8, hidden)
        else:
            self.node_one_hot_encoder = torch.nn.Linear(8, 8)
            self.node_features_encoder = torch.nn.Linear(8 * 2, hidden)

        self.edge_encoder = torch.nn.Linear(8, hidden)

        # Not used by forward(), which predicts through W_jk + classifier.
        self.node_pred_linear = torch.nn.Linear(hidden, args.num_tasks)
        # Encoded input plus one output per stacked layer.
        self.W_jk = torch.nn.Linear(hidden * (depth + 1), teacher)
        self.classifier = torch.nn.Linear(teacher, args.num_tasks)

    def forward(self, x, node_index, edge_index, edge_attr, epoch=-1):
        """Predict per-node outputs for the nodes selected by ``node_index``.

        Args:
            x: input to ``node_one_hot_encoder``; only read when
                ``use_one_hot_encoding`` is enabled.
            node_index: indices into the stored node-feature tensor.
            edge_index: graph connectivity, passed unchanged to every layer.
            edge_attr: raw edge features fed to ``edge_encoder``.
            epoch: unused; kept for call-site compatibility.

        Returns:
            ``(logits, h_norelu, h_relu)``. The second item is the ``W_jk``
            projection before ``last_norm``; the third is after ``last_norm``,
            ReLU and dropout.
        """
        selected = self.node_features[node_index]
        if self.use_one_hot_encoding:
            selected = torch.cat((selected, self.node_one_hot_encoder(x)), dim=1)

        state = self.node_features_encoder(selected)
        # Snapshots are cloned before the next layer consumes ``state``.
        snapshots = [state.clone()]

        # Replicate the edge embedding once per group member.
        edge_emb = torch.cat([self.edge_encoder(edge_attr)] * self.group, dim=-1)

        # Inverted-dropout keep-mask shared by every layer in this pass.
        keep_prob = 1 - self.dropout
        mask = torch.zeros_like(state).bernoulli_(keep_prob).requires_grad_(False) / keep_prob

        for gcn in self.gcns[:self.num_layers + self.num_nonlinear]:
            state = gcn(state, edge_index, mask, edge_emb)
            snapshots.append(state.clone())

        state = self.W_jk(torch.cat(snapshots, dim=1))
        h_norelu = state.clone()
        state = F.dropout(F.relu(self.last_norm(state)), p=self.dropout, training=self.training)
        h_relu = state.clone()

        return self.classifier(state), h_norelu, h_relu

    def print_params(self, epoch=None, final=False):
        """Report the learnable aggregation parameters of every layer.

        For each enabled option (``learn_t``, ``learn_p``, ``learn_y`` and
        ``msg_norm``), the per-layer values are collected with ``.item()``.
        They are then printed (``final=True``) or logged at INFO level,
        tagged with ``epoch``.

        NOTE(review): assumes the wrapper entries of ``self.gcns`` expose
        ``t`` / ``p`` / ``sigmoid_y`` / ``msg_norm`` directly. Confirm
        against memgcn.
        """
        reports = (
            (self.learn_t, lambda gcn: gcn.t, 'Final t {}', 'Epoch {}, t {}'),
            (self.learn_p, lambda gcn: gcn.p, 'Final p {}', 'Epoch {}, p {}'),
            (self.learn_y, lambda gcn: gcn.sigmoid_y,
             'Final sigmoid(y) {}', 'Epoch {}, sigmoid(y) {}'),
            (self.msg_norm, lambda gcn: gcn.msg_norm.msg_scale,
             'Final s {}', 'Epoch {}, s {}'),
        )
        for enabled, read, final_fmt, epoch_fmt in reports:
            if not enabled:
                continue
            values = [read(gcn).item() for gcn in self.gcns]
            if final:
                print(final_fmt.format(values))
            else:
                logging.info(epoch_fmt.format(epoch, values))


class MLP(torch.nn.Module):
    """Graph-agnostic MLP baseline over the stored node-feature table.

    ``layers`` holds ``max(1, num_layers)`` hidden ``Linear`` layers
    (8 -> hidden, then hidden -> hidden), plus a final
    ``Linear(hidden, num_tasks)`` output layer. ``forward`` applies ReLU and
    dropout after every hidden layer.
    """

    def __init__(self, args):
        super().__init__()
        self.layers = nn.ModuleList()
        self.hidden_channels = args.hidden_channels
        self.num_tasks = args.num_tasks
        self.num_layers = args.num_layers
        self.dropout = args.dropout

        node_features_file_path = args.nf_path
        # Plain attribute (not a registered buffer): not in state_dict().
        self.node_features = torch.load(node_features_file_path).to(args.device)

        self.layers.append(nn.Linear(8, self.hidden_channels))
        for _ in range(1, self.num_layers):
            self.layers.append(nn.Linear(self.hidden_channels, self.hidden_channels))
        self.layers.append(nn.Linear(self.hidden_channels, self.num_tasks))

    def forward(self, x, node_index, edge_index, edge_attr, epoch=-1):
        """Predict per-node outputs from the stored features of ``node_index``.

        ``x``, ``edge_index``, ``edge_attr`` and ``epoch`` are unused; they
        keep the call signature shared with the graph models in this file.

        Returns:
            Output of the final ``Linear(hidden, num_tasks)`` layer.
        """
        h = self.node_features[node_index]
        # Apply every hidden layer. The previous ``range(num_layers - 1)``
        # loop skipped layers[num_layers - 1]. That left a dead layer and,
        # for num_layers == 1, fed the 8-wide raw features straight into the
        # hidden-wide output layer.
        for layer in self.layers[:-1]:
            h = layer(h)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        return self.layers[-1](h)