import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import warnings

from evaluate_embedding import evaluate_embedding
from losses import get_positive_expectation, get_negative_expectation
from tu_datasets import load
warnings.filterwarnings("ignore")


class GCNLayer(nn.Module):
    """Dense graph convolution computing ``PReLU(bmm(adj, feat @ W) + b)``.

    ``torch.bmm`` requires batched 3-D inputs: ``feat`` is
    ``(batch, nodes, in_ft)`` and ``adj`` is ``(batch, *, nodes)``
    (presumably a square per-graph adjacency/propagation matrix — TODO
    confirm). The output is ``(batch, *, out_ft)``.

    Args:
        in_ft: input feature width.
        out_ft: output feature width.
        bias: when True, add a learnable bias (initialised to zero) after
            propagation; when False the ``bias`` parameter is registered as None.
    """

    def __init__(self, in_ft, out_ft, bias=True):
        super(GCNLayer, self).__init__()
        # Bias-free projection: the layer's own bias is added after propagation.
        self.fc = nn.Linear(in_ft, out_ft, bias=False)
        self.act = nn.PReLU()

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.FloatTensor(out_ft).fill_(0.0))

        # Xavier-initialise every nn.Linear submodule (here only ``fc``).
        self.apply(self.weights_init)

    def weights_init(self, m):
        """Xavier-uniform init for an ``nn.Linear``'s weight; zero its bias if any."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, feat, adj):
        """Project ``feat``, propagate it with ``adj``, add bias, apply PReLU."""
        projected = self.fc(feat)
        propagated = torch.bmm(adj, projected)
        if self.bias is not None:
            propagated = propagated + self.bias
        return self.act(propagated)


class GCN(nn.Module):
    """Stack of ``num_layers`` GCNLayer modules with a sum readout after each layer.

    The layers live in an ``nn.ModuleList`` so they are registered as
    submodules. Their weights therefore appear in ``parameters()`` (so the
    optimiser trains them) and in ``state_dict()`` (so checkpoints save them).
    They also follow ``.to()`` / ``.cuda()`` calls on the enclosing model. A
    plain Python list registers none of this, which left every GCN weight
    frozen at its random initialisation.

    Args:
        in_ft: input node-feature width.
        out_ft: hidden/output width of every layer.
        num_layers: number of stacked GCNLayer modules.
    """

    def __init__(self, in_ft, out_ft, num_layers):
        super(GCN, self).__init__()
        n_h = out_ft
        self.num_layers = num_layers
        # First layer maps in_ft -> n_h; the remaining ones keep width n_h.
        self.layers = nn.ModuleList(
            [GCNLayer(in_ft, n_h)]
            + [GCNLayer(n_h, n_h) for _ in range(num_layers - 1)]
        )

    def forward(self, feat, adj, mask):
        """Encode a padded batch of graphs.

        Args:
            feat: node features, 3-D ``(batch, nodes, in_ft)`` (required by bmm).
            adj: batched propagation matrices passed to every layer.
            mask: accepted for interface compatibility; not used here.

        Returns:
            ``(h, g)`` where ``h`` is the last layer's node embeddings and ``g``
            concatenates, along the last axis, the node-sum readout of every
            layer (width ``num_layers * out_ft``).
        """
        h = feat
        readouts = []
        for layer in self.layers:
            h = layer(h, adj)
            readouts.append(torch.sum(h, 1))
        return h, torch.cat(readouts, -1)


class MLP(nn.Module):
    """Three Linear+PReLU stages plus a linear residual shortcut.

    ``forward(x) = ffn(x) + linear_shortcut(x)``. The first stage maps
    ``in_ft -> out_ft``; the other two keep width ``out_ft``.
    """

    def __init__(self, in_ft, out_ft):
        super(MLP, self).__init__()
        stages = [nn.Linear(in_ft, out_ft), nn.PReLU()]
        for _ in range(2):
            stages.extend([nn.Linear(out_ft, out_ft), nn.PReLU()])
        # Sequential indices 0..5 keep the same state_dict keys (ffn.0 ... ffn.5).
        self.ffn = nn.Sequential(*stages)
        self.linear_shortcut = nn.Linear(in_ft, out_ft)

    def forward(self, x):
        """Apply the residual MLP to ``x`` (last dimension must be ``in_ft``)."""
        transformed = self.ffn(x)
        return transformed + self.linear_shortcut(x)


class MVGRL(nn.Module):
    """Two-view graph encoder: one GCN over ``adj``, a second over ``diff``.

    Both encoders read the same node features. Node-level outputs of both views
    pass through the shared projection head ``mlp1``. Graph-level outputs pass
    through the shared head ``mlp2``; these are the GCN's concatenated
    per-layer sum readouts, of width ``num_layers * n_h``. ``diff`` is
    presumably a graph-diffusion matrix (inferred from the name only — confirm
    against ``load``).

    Args:
        n_in: node-feature width.
        n_h: hidden width of the encoders and projection heads.
        num_layers: number of GCN layers per encoder.
    """

    def __init__(self, n_in, n_h, num_layers):
        super(MVGRL, self).__init__()
        # Construction order (mlp1, mlp2, gnn1, gnn2) fixes the order in which
        # parameter initialisation draws from the RNG.
        self.mlp1 = MLP(n_h, n_h)
        self.mlp2 = MLP(num_layers * n_h, n_h)
        self.gnn1 = GCN(n_in, n_h, num_layers)
        self.gnn2 = GCN(n_in, n_h, num_layers)

    def forward(self, adj, diff, feat, mask):
        """Return ``(local_adj, global_adj, local_diff, global_diff)`` projections."""
        local_adj, global_adj = self.gnn1(feat, adj, mask)
        local_diff, global_diff = self.gnn2(feat, diff, mask)

        projected_local = [self.mlp1(h) for h in (local_adj, local_diff)]
        projected_global = [self.mlp2(g) for g in (global_adj, global_diff)]

        return (projected_local[0], projected_global[0],
                projected_local[1], projected_global[1])

    def embed(self, feat, adj, diff, mask):
        """Graph embedding: sum of both views' projected readouts, detached."""
        outputs = self.forward(adj, diff, feat, mask)
        global_adj, global_diff = outputs[1], outputs[3]
        return (global_adj + global_diff).detach()


def local_global_loss_(l_enc, g_enc, batch, measure, mask):
    """Local-global contrastive loss between node and graph embeddings.

    Every node row is scored against every graph via a dot product. A node
    scored against its own graph is a positive pair; a node scored against any
    other graph in the batch is a negative pair.

    Args:
        l_enc: node embeddings ``(num_nodes, dim)``, laid out as ``num_graphs``
            consecutive blocks of ``num_nodes // num_graphs`` (padded) rows.
        g_enc: graph embeddings ``(num_graphs, dim)``.
        batch: graph index of every row of ``l_enc`` (sequence, ndarray or
            tensor of length ``num_nodes``).
        measure: divergence name forwarded to ``get_positive_expectation`` /
            ``get_negative_expectation`` (e.g. ``'JSD'``).
        mask: per-graph count of leading real rows; rows after it within a
            graph's block are padding and are zeroed against that graph.

    Returns:
        Scalar tensor ``E_neg - E_pos``.
    """
    num_graphs = g_enc.shape[0]
    num_nodes = l_enc.shape[0]
    max_nodes = num_nodes // num_graphs
    device = l_enc.device

    # Positive pairs (node, own graph), built with one vectorized scatter
    # instead of a per-node Python loop of single-element device writes.
    node_idx = torch.arange(num_nodes, device=device)
    graph_idx = torch.as_tensor(batch, dtype=torch.long, device=device)
    pos_mask = torch.zeros((num_nodes, num_graphs), device=device)
    pos_mask[node_idx, graph_idx] = 1.
    neg_mask = 1. - pos_mask

    # Zero padded node slots against their own graph column.
    msk = torch.ones((num_nodes, num_graphs), device=device)
    for idx, m in enumerate(mask):
        msk[idx * max_nodes + m: idx * max_nodes + max_nodes, idx] = 0.

    res = torch.mm(l_enc, g_enc.t()) * msk
    # NOTE(review): masking by multiplying with 0 relies on the expectation of
    # a 0 score being 0 for `measure` (true for the usual JSD form) — confirm
    # in losses.py for other measures.
    E_pos = get_positive_expectation(res * pos_mask, measure, average=False).sum()
    E_pos = E_pos / num_nodes
    if num_graphs > 1:
        E_neg = get_negative_expectation(res * neg_mask, measure, average=False).sum()
        E_neg = E_neg / (num_nodes * (num_graphs - 1))
    else:
        # A single-graph batch has no negative pairs; avoid 0/0 -> NaN.
        E_neg = res.new_zeros(())
    return E_neg - E_pos


def global_global_loss_(g1_enc, g2_enc, measure):
    """Global-global contrastive loss between two views' graph embeddings.

    Graph i of view 1 paired with graph i of view 2 is a positive pair; every
    off-diagonal pairing is a negative pair.

    Args:
        g1_enc: graph embeddings of the first view ``(num_graphs, dim)``.
        g2_enc: graph embeddings of the second view ``(num_graphs, dim)``.
        measure: divergence name forwarded to the expectation helpers.

    Returns:
        Scalar tensor ``E_neg - E_pos``.
    """
    num_graphs = g1_enc.shape[0]
    device = g1_enc.device

    # Diagonal = positives, everything else = negatives.
    pos_mask = torch.eye(num_graphs, device=device)
    neg_mask = 1. - pos_mask

    res = torch.mm(g1_enc, g2_enc.t())

    E_pos = get_positive_expectation(res * pos_mask, measure, average=False).sum()
    E_pos = E_pos / num_graphs
    if num_graphs > 1:
        E_neg = get_negative_expectation(res * neg_mask, measure, average=False).sum()
        E_neg = E_neg / (num_graphs * (num_graphs - 1))
    else:
        # One graph -> no negative pairs; avoid 0/0 -> NaN.
        E_neg = res.new_zeros(())
    return E_neg - E_pos


def arg_parse(argv=None):
    """Parse command-line options for MVGRL graph-level training.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing zero-argument calls
            behave exactly as before.

    Returns:
        argparse.Namespace with the options defined below. ``--prior`` and
        ``--aug`` are not read by ``train`` in this file.
    """
    parser = argparse.ArgumentParser(description='GcnInformax Arguments.')
    parser.add_argument('--dataset', help='Dataset')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--prior', dest='prior', action='store_const', const=True, default=False)

    parser.add_argument('--aug', type=str, default='dnodes')
    parser.add_argument('--epochs', type=int, default=40)
    parser.add_argument('--batch_size', type=int, default=64, help='Number of graphs per training batch.')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
    parser.add_argument('--wd', type=float, default=0.0)
    parser.add_argument('--patience', type=int, default=20)
    parser.add_argument('--num_layers', type=int, default=5, help='Number of graph convolution layers before each pooling')
    parser.add_argument('--hidden_dim', type=int, default=512, help='')
    return parser.parse_args(argv)


def _train_one_epoch(model, optimiser, adj, diff, feat, num_nodes, max_nodes, batch_size):
    """Run one shuffled pass over every graph and return the mean batch loss (a float)."""
    model.train()
    train_idx = np.arange(adj.shape[0])
    np.random.shuffle(train_idx)

    total, num_batches = 0.0, 0
    for start in range(0, len(train_idx), batch_size):
        optimiser.zero_grad()

        batch = train_idx[start: start + batch_size]
        # Real-node counts of the graphs actually in this (shuffled) batch;
        # slicing by position would pair them with the wrong graphs.
        mask = num_nodes[batch]

        lv1, gv1, lv2, gv2 = model(adj[batch], diff[batch], feat[batch], mask)

        # Flatten padded (graphs, max_nodes, dim) node outputs to one row per node.
        lv1 = lv1.view(batch.shape[0] * max_nodes, -1)
        lv2 = lv2.view(batch.shape[0] * max_nodes, -1)

        # Row r belongs to graph r // max_nodes within the batch.
        node_graph = torch.from_numpy(
            np.repeat(np.arange(batch.shape[0]), max_nodes)
        ).long().to(lv1.device)

        loss1 = local_global_loss_(lv1, gv2, node_graph, 'JSD', mask)
        loss2 = local_global_loss_(lv2, gv1, node_graph, 'JSD', mask)
        loss = loss1 + loss2
        loss.backward()
        optimiser.step()

        # .item() accumulates a plain float instead of autograd-tracked tensors.
        total += loss.item()
        num_batches += 1
    return total / max(num_batches, 1)


def train(args):
    """Pre-train MVGRL on ``args.dataset``, then evaluate its graph embeddings.

    Trains with early stopping on the mean epoch loss. The best model is
    checkpointed to ``mvgrl_<dataset>.pkl``; that checkpoint is restored and
    then deleted once training ends. The function then embeds every graph and
    prints ``"<res[0]> +- <res[1]>"`` from ``evaluate_embedding``.

    Args:
        args: namespace from ``arg_parse``. ``batch_size`` (default 64) and
            ``seed`` (no seeding if absent) are read with ``getattr`` so older
            namespaces still work.
    """
    batch_size = getattr(args, 'batch_size', 64)
    seed = getattr(args, 'seed', None)
    if seed is not None:
        # Honour --seed (previously parsed but never applied).
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    adj, diff, feat, labels, num_nodes = load(args.dataset)
    # ndarray so it can be fancy-indexed by the shuffled batch indices.
    num_nodes = np.asarray(num_nodes)

    feat = torch.FloatTensor(feat).cuda()
    diff = torch.FloatTensor(diff).cuda()
    adj = torch.FloatTensor(adj).cuda()
    labels = torch.LongTensor(labels).cuda()

    ft_size = feat[0].shape[1]
    max_nodes = feat[0].shape[0]

    model = MVGRL(ft_size, args.hidden_dim, args.num_layers).cuda()
    optimiser = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # One path for save, load and cleanup.
    ckpt_path = f'mvgrl_{args.dataset}.pkl'
    saved = False
    cnt_wait = 0
    best = 1e9
    for _ in range(args.epochs):
        epoch_loss = _train_one_epoch(
            model, optimiser, adj, diff, feat, num_nodes, max_nodes, batch_size
        )

        if epoch_loss < best:
            best = epoch_loss
            cnt_wait = 0
            torch.save(model.state_dict(), ckpt_path)
            saved = True
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            break

    # Without a saved checkpoint (e.g. epochs == 0), evaluate the current model.
    if saved:
        model.load_state_dict(torch.load(ckpt_path))
        os.remove(ckpt_path)

    model.eval()
    # embed() detaches its result; no_grad also avoids building a
    # full-dataset autograd graph during evaluation.
    with torch.no_grad():
        embeds = model.embed(feat, adj, diff, num_nodes)
    res = evaluate_embedding(embeds, labels)
    print(f'{res[0]} +- {res[1]}')


if __name__ == '__main__':
    args = arg_parse()
    # Echo the parsed configuration (not the literal string 'args').
    print(args)
    print('---------------------------')
    train(args)

    # Candidate hyper-parameter values, kept for reference only (not read here):
    # layers = [2, 8, 12]
    # batch = [32, 64, 128, 256]
    # epoch = [20, 40, 100]
    # ds = ['MUTAG', 'PTC_MR', 'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY', 'REDDIT-MULTI-5K']
    # seeds = [123, 132, 321, 312, 231]