import argparse, time
import os
import numpy as np
import networkx as nx
import math
import random
import copy
from tqdm import tqdm
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

import dgl
from dgl import DGLGraph
import dgl.function as fn
from dgl.data import register_data_args, load_data
from dgl.dataloading import MultiLayerFullNeighborSampler, MultiLayerNeighborSampler
from dgl.dataloading import NodeDataLoader
from dgl.nn import EdgeWeightNorm
from dgl.nn.pytorch import GraphConv, SGConv

from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from sklearn import preprocessing as sk_prep
from gcn_aggr import GraphConvAGGR, ChebnetIIProp_SP


def print_params(model, string):
    """Print the size of ``model.encoder``'s parameters under a titled banner.

    Args:
        model: Object exposing ``encoder.parameters()``; each yielded item
            must provide ``numel()``.
        string: Title placed in the opening banner line.

    The reported figure is the element count times 4 bytes (a float32
    assumption), expressed in KB.
    """
    element_count = 0
    for tensor in model.encoder.parameters():
        element_count += tensor.numel()
    # 4 bytes per element (float32 assumed), 1024 bytes per KB.
    size_kb = element_count * 4 / 1024

    print(f'----------- {string} ----------')
    print(f'Total parameters: {size_kb:.2f} KB')
    print('-----------------------------------')



class GCN(nn.Module):
    """Stack of ``GraphConv`` layers with a residual path, run on sampled blocks.

    Args:
        g: Graph handle; stored on the instance, not read by ``forward``.
        in_feats: Width of the input features.
        n_hidden: Width of every hidden layer.
        n_classes: Width of the final ``GraphConv`` output.
        n_layers: Controls depth; ``max(n_layers - 2, 0)`` hidden-to-hidden
            layers sit between the input layer and the output layer.
        activation: Activation handed to every layer except the last.
        dropout: Dropout probability.
        bias: Forwarded to every ``GraphConv`` except the last.
        weight: Forwarded to every ``GraphConv`` except the last.

    ``self.bns`` holds one ``BatchNorm1d`` per non-final layer; they are
    registered (and therefore part of the state dict) but ``forward`` never
    applies them.
    """

    def __init__(self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout, bias = True, weight=True):
        super().__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.res_linears = nn.ModuleList()

        # (input width, output width) for the input layer followed by the
        # hidden-to-hidden layers.
        shapes = [(in_feats, n_hidden)] + [(n_hidden, n_hidden)] * max(n_layers - 2, 0)
        for width_in, width_out in shapes:
            self.layers.append(GraphConv(width_in, width_out, weight=weight, bias=bias, activation=activation))
            self.bns.append(nn.BatchNorm1d(width_out, momentum=0.01))
            self.res_linears.append(nn.Linear(width_in, width_out))

        # Output layer: default GraphConv settings, identity on the skip path.
        self.layers.append(GraphConv(n_hidden, n_classes))
        self.res_linears.append(nn.Identity())
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, blocks):
        """Encode the output nodes of ``blocks``.

        Args:
            blocks: Sequence of block objects, one per layer. Features are
                read from ``blocks[0].srcdata['feat']``.

        Returns:
            The first ``blocks[-1].num_dst_nodes()`` rows of the last
            processed layer's output, residual included.
        """
        h = self.dropout(blocks[0].srcdata['feat'])
        n_out = blocks[-1].num_dst_nodes()
        latest = h[:n_out]
        for conv, shortcut, block in zip(self.layers, self.res_linears, blocks):
            skip_input = h[:block.num_dst_nodes()]
            h = self.dropout(conv(block, h))
            # ``latest`` is a view of ``h``, so the in-place residual add on
            # the next line is reflected in the tensor that gets returned.
            latest = h[:n_out]
            h += shortcut(skip_input)
        return latest


class GCNAGGR(nn.Module):
    """Stack of ``n_layers`` ``GraphConvAGGR`` layers applied to sampled blocks.

    Args:
        g: Graph handle; stored on the instance, not read by ``forward``.
        in_feats: Feature width handed to every ``GraphConvAGGR``.
        n_hidden: Width of the registered ``BatchNorm1d`` modules.
        n_layers: ``n_layers - 1`` layers are built in the loop, plus one
            final layer.
        activation: Activation object shared by every layer.
        dropout: Dropout probability.

    ``self.bns`` is populated but ``forward`` never applies it.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super().__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(n_layers - 1):
            self.layers.append(GraphConvAGGR(in_feats, activation=activation))
            self.bns.append(nn.BatchNorm1d(n_hidden, momentum=0.01))
        self.layers.append(GraphConvAGGR(in_feats, activation=activation))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, blocks):
        """Propagate ``blocks[0].srcdata['feat']`` through the layer stack.

        Args:
            blocks: Sequence of block objects, one per layer.

        Returns:
            The first ``blocks[-1].num_dst_nodes()`` rows of the last
            processed layer's output (after dropout).
        """
        h = self.dropout(blocks[0].srcdata['feat'])
        n_out = blocks[-1].num_dst_nodes()
        latest = h[:n_out]
        for conv, block in zip(self.layers, blocks):
            h = self.dropout(conv(block, h))
            latest = h[:n_out]
        return latest


class ChebNetAGGR(nn.Module):
    """Stack of ``n_layers`` ``ChebnetIIProp_SP`` layers applied to sampled blocks.

    Args:
        g: Graph handle; stored on the instance, not read by ``forward``.
        in_feats: Feature width handed to every ``ChebnetIIProp_SP``.
        n_hidden: Width of the registered ``BatchNorm1d`` modules.
        n_layers: ``n_layers - 1`` layers are built in the loop, plus one
            final layer.
        activation: Accepted for signature parity with the other encoders;
            not used here.
        dropout: Dropout probability.
        num_hop: Second positional argument of every ``ChebnetIIProp_SP``.

    ``self.bns`` is populated but ``forward`` never applies it.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout, num_hop):
        super().__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(n_layers - 1):
            self.layers.append(ChebnetIIProp_SP(in_feats, num_hop))
            self.bns.append(nn.BatchNorm1d(n_hidden, momentum=0.01))
        self.layers.append(ChebnetIIProp_SP(in_feats, num_hop))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, blocks):
        """Propagate ``blocks[0].srcdata['feat']`` through the layer stack.

        Args:
            blocks: Sequence of block objects, one per layer.

        Returns:
            The first ``blocks[-1].num_dst_nodes()`` rows of the last
            processed layer's output (after dropout).
        """
        h = self.dropout(blocks[0].srcdata['feat'])
        n_out = blocks[-1].num_dst_nodes()
        latest = h[:n_out]
        for prop, block in zip(self.layers, blocks):
            h = self.dropout(prop(block, h))
            latest = h[:n_out]
        return latest


class Encoder(nn.Module):
    """Dispatches to one of the supported GNN encoders by name.

    Args:
        g: Graph handle; stored and passed to the chosen encoder.
        in_feats: Input feature width.
        n_hidden: Hidden width.
        n_layers: Depth argument forwarded to the block-based encoders.
        dropout: Dropout probability forwarded to the block-based encoders.
        gnn_encoder: One of ``'gcn'``, ``'sgc'``, ``'gcn_aggr'`` or
            ``'chebnet_aggr'``.
        k: Hop count forwarded to ``ChebNetAGGR``.

    Raises:
        ValueError: If ``gnn_encoder`` is not a supported name. Previously an
            unknown name was accepted silently and only failed later inside
            ``forward`` with an unrelated ``UnboundLocalError``.
    """

    SUPPORTED_ENCODERS = ('gcn', 'sgc', 'gcn_aggr', 'chebnet_aggr')

    def __init__(self, g, in_feats, n_hidden, n_layers, dropout, gnn_encoder, k = 1):
        super(Encoder, self).__init__()
        self.g = g
        self.gnn_encoder = gnn_encoder
        if gnn_encoder == 'gcn':
            activation = nn.PReLU(n_hidden)
            self.conv = GCN(g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout)
        elif gnn_encoder == 'sgc':
            # NOTE(review): the hop count is fixed at 10 here and ``k`` is
            # ignored; confirm whether ``k`` was meant to be used.
            self.conv = SGConv(in_feats, n_hidden, k=10, cached=True)
        elif gnn_encoder == 'gcn_aggr':
            activation = nn.PReLU(in_feats)
            self.conv = GCNAGGR(g, in_feats, n_hidden, n_layers, activation, dropout)
        elif gnn_encoder == 'chebnet_aggr':
            activation = nn.PReLU(in_feats)
            self.conv = ChebNetAGGR(g, in_feats, n_hidden, n_layers, activation, dropout, k)
        else:
            raise ValueError(
                f"Unsupported gnn_encoder {gnn_encoder!r}; "
                f"expected one of {self.SUPPORTED_ENCODERS}"
            )

    def forward(self, blocks, corrupt=False):
        """Encode ``blocks``, optionally after shuffling their node features.

        Args:
            blocks: Input handed to the wrapped encoder.
            corrupt: When true, the ``'feat'`` rows of every block are
                replaced by a random permutation of themselves before
                encoding. The assignment targets the blocks passed in, so the
                caller's blocks are affected too.
                NOTE(review): relies on ``block.ndata['feat']['_N']`` being
                writable through the returned mapping - confirm against the
                DGL version in use.

        Returns:
            The wrapped encoder's output.
        """
        if corrupt:
            for block in blocks:
                block.ndata['feat']['_N'] = block.ndata['feat']['_N'][torch.randperm(block.num_src_nodes())]
        if self.gnn_encoder == 'sgc':
            return self.conv(self.g, blocks)
        # 'gcn', 'gcn_aggr' and 'chebnet_aggr' all take the block list only.
        return self.conv(blocks)


class GGD(nn.Module):
    """Graph Group Discrimination model: an encoder plus a linear projector.

    Args:
        g: Graph handle forwarded to ``Encoder``.
        in_feats: Input feature width.
        n_hidden: Hidden width (also the projector's output width).
        n_layers: Encoder depth argument.
        dropout: Encoder dropout probability.
        proj_layers: Number of ``nn.Linear`` layers in the projector.
        gnn_encoder: Encoder name understood by ``Encoder``.
        num_hop: Hop count forwarded to ``Encoder`` as ``k``.

    ``self.loss`` and ``self.graphconv`` are created here but are not used by
    ``forward`` or ``embed``.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, dropout, proj_layers, gnn_encoder, num_hop):
        super(GGD, self).__init__()
        self.encoder = Encoder(g, in_feats, n_hidden, n_layers, dropout, gnn_encoder, num_hop)
        self.mlp = torch.nn.ModuleList()
        # The first projector layer takes n_hidden inputs for 'gcn' and
        # in_feats inputs for every other encoder.
        self.mlp.append(nn.Linear(in_feats if gnn_encoder != 'gcn' else n_hidden, n_hidden))
        for _ in range(proj_layers - 1):
            self.mlp.append(nn.Linear(n_hidden, n_hidden))
        self.loss = nn.BCEWithLogitsLoss()
        self.graphconv = GraphConv(in_feats, n_hidden, weight=False, bias=False, activation=None)

    def forward(self, features, labels, loss_func):
        """Compute the group-discrimination loss for one batch.

        Args:
            features: Input handed to the encoder (the training loop passes
                the list of sampled blocks).
            labels: Unused; kept for interface compatibility.
            loss_func: Callable ``loss_func(logits, targets)``.

        Returns:
            Whatever ``loss_func`` returns for the concatenated scores, where
            the clean pass is labelled 1 and the corrupted pass 0.
        """
        # The clean pass must run first: the corrupt pass shuffles features.
        h_1 = self.encoder(features, corrupt=False)
        h_2 = self.encoder(features, corrupt=True)

        sc_1 = h_1.squeeze(0)
        sc_2 = h_2.squeeze(0)
        for lin in self.mlp:
            sc_1 = lin(sc_1)
            sc_2 = lin(sc_2)

        # One scalar score per node: sum over the projected feature axis.
        sc_1 = sc_1.sum(1).unsqueeze(0)
        sc_2 = sc_2.sum(1).unsqueeze(0)

        # Build the targets on the same device as the scores instead of
        # forcing ``.cuda()``, so this also works on CPU and when the model
        # lives on a CUDA device that is not the current one.
        n_nodes = sc_1.shape[1]
        lbl_1 = torch.ones(1, n_nodes, device=sc_1.device)
        lbl_2 = torch.zeros(1, n_nodes, device=sc_1.device)
        lbl = torch.cat((lbl_1, lbl_2), 1)

        logits = torch.cat((sc_1, sc_2), 1)
        loss = loss_func(logits, lbl)
        return loss

    def embed(self, blocks):
        """Return the encoder output for ``blocks``, detached from autograd."""
        h_1 = self.encoder(blocks, corrupt=False)
        return h_1.detach()


class Classifier(nn.Module):
    """Single linear layer followed by log-softmax (a linear probe).

    Args:
        n_hidden: Width of the input embeddings.
        n_classes: Number of output classes.
    """

    def __init__(self, n_hidden, n_classes):
        super().__init__()
        self.fc = nn.Linear(n_hidden, n_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layer's parameters."""
        self.fc.reset_parameters()

    def forward(self, features):
        """Return log-probabilities over the last dimension."""
        scores = self.fc(features)
        return F.log_softmax(scores, dim=-1)


class NodeSet(Dataset):
    """Dataset pairing each node id with its label.

    Args:
        node_list: Node ids.
        labels: Labels, aligned with ``node_list`` by position. Both
            arguments must have the same length.
    """

    def __init__(self, node_list: List[int], labels):
        super().__init__()
        self.node_list = node_list
        self.labels = labels
        n_nodes, n_labels = len(self.node_list), len(self.labels)
        assert n_nodes == n_labels

    def __len__(self):
        """Number of nodes in the set."""
        return len(self.node_list)

    def __getitem__(self, idx):
        """Return the ``(node id, label)`` pair stored at position ``idx``."""
        node = self.node_list[idx]
        label = self.labels[idx]
        return node, label


class NbrSampleCollater(object):
    """Collate function object that turns a batch of nodes into sampled blocks.

    Args:
        graph: Graph to sample from.
        block_sampler: Sampler whose ``sample_blocks(graph, nodes)`` is called
            once per batch.
    """

    def __init__(self, graph: dgl.DGLHeteroGraph,
                 block_sampler: dgl.dataloading.BlockSampler):
        self.block_sampler = block_sampler
        self.graph = graph

    def collate(self, batch):
        """Sample blocks for one batch of ``(node id, label)`` pairs.

        Args:
            batch: Sequence of ``(node id, label)`` pairs; it is converted to
                a single tensor, so both entries must be numeric.

        Returns:
            Tuple ``(sampler output, labels)`` where ``labels`` is the second
            column of the batch tensor.
        """
        pairs = torch.tensor(batch)
        seed_nodes, labels = pairs[:, 0], pairs[:, 1]
        sampled = self.block_sampler.sample_blocks(self.graph, seed_nodes)
        return sampled, labels


def aug_feature_dropout(input_feat, drop_percent=0.2):
    """Return a copy of ``input_feat`` with a random subset of columns zeroed.

    Args:
        input_feat: 2-D feature matrix (columns are features). It is deep
            copied, so the argument itself is left untouched.
        drop_percent: Fraction of columns to zero; the count is
            ``int(n_columns * drop_percent)``.

    Returns:
        The augmented copy.
    """
    augmented = copy.deepcopy(input_feat)
    n_columns = augmented.shape[1]
    n_dropped = int(n_columns * drop_percent)
    # Columns are drawn without replacement using the ``random`` module.
    dropped_columns = random.sample(list(range(n_columns)), n_dropped)
    augmented[:, dropped_columns] = 0
    return augmented


def evaluate(model, features, labels, mask):
    """Compute classification accuracy on the entries selected by ``mask``.

    Args:
        model: Module called as ``model(features)``; it is switched to eval
            mode and left in that mode.
        features: Input handed to the model.
        labels: Ground-truth class ids.
        mask: Index or boolean mask applied to both the model output and
            ``labels``.

    Returns:
        Fraction of selected entries whose arg-max prediction equals the
        label, as a Python float.
    """
    model.eval()
    with torch.no_grad():
        selected_scores = model(features)[mask]
        targets = labels[mask]
        predicted = selected_scores.max(dim=1).indices
        n_correct = (predicted == targets).sum().item()
        return n_correct * 1.0 / len(targets)


def load_data_ogb(dataset, args):
    """Load an OGB node-property dataset together with its split and evaluator.

    Args:
        dataset: OGB dataset name.
        args: Namespace providing ``data_root_dir`` (``'default'`` lets the
            loader pick its own root) and ``pretrain_path`` (``'None'`` keeps
            the dataset's own node features; anything else is loaded with
            ``np.load`` and replaces them).

    Returns:
        Tuple ``(graph, labels, train_idx, val_idx, test_idx, evaluator)``.

    Side effects:
        Sets the module-level globals ``n_node_feats`` and ``n_classes``.
    """
    global n_node_feats, n_classes

    loader_kwargs = {'name': dataset}
    if args.data_root_dir != 'default':
        loader_kwargs['root'] = args.data_root_dir
    data = DglNodePropPredDataset(**loader_kwargs)

    evaluator = Evaluator(name=dataset)

    split = data.get_idx_split()
    train_idx = split["train"]
    val_idx = split["valid"]
    test_idx = split["test"]
    graph, labels = data[0]

    # Optionally swap in externally pre-computed node features.
    if args.pretrain_path != 'None':
        pretrained = np.load(args.pretrain_path)
        graph.ndata["feat"] = torch.tensor(pretrained).float()
        print("Pretrained node feature loaded! Path: {}".format(args.pretrain_path))

    n_node_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    return graph, labels, train_idx, val_idx, test_idx, evaluator


def preprocess(graph):
    """Normalise self-loops and materialise the graph's sparse formats.

    Args:
        graph: Graph exposing ``ndata``, ``number_of_edges()``,
            ``remove_self_loop()``, ``add_self_loop()`` and
            ``create_formats_()``.

    Returns:
        The graph returned by ``remove_self_loop().add_self_loop()``, after
        ``create_formats_()`` has been called on it.
    """
    # Read the features and store them back unchanged; no edges are added or
    # reversed here.
    node_feat = graph.ndata["feat"]
    graph.ndata["feat"] = node_feat

    # Drop any existing self-loops, then add exactly one per node.
    edges_before = graph.number_of_edges()
    print(f"Total edges before adding self-loop {edges_before}")
    without_loops = graph.remove_self_loop()
    graph = without_loops.add_self_loop()
    edges_after = graph.number_of_edges()
    print(f"Total edges after adding self-loop {edges_after}")

    graph.create_formats_()

    return graph

def main(args):
    """Run one trial: GGD pre-training on sampled blocks, then a linear probe.

    Steps performed:
      1. Load the dataset selected by ``args.dataset_name``.
      2. Train ``GGD`` with mini-batches of sampled blocks, checkpointing the
         state dict under ``pkl/`` whenever the tracked loss improves.
      3. Reload the best checkpoint, embed every node, add the
         ``graph_power`` smoothed embedding, and L2-normalise.
      4. Train a ``Classifier`` on the embeddings and report accuracy.

    Args:
        args: Parsed command-line namespace. Fields read here: ``gpu``,
            ``dataset_name``, ``self_loop``, ``n_hidden``, ``n_layers``,
            ``dropout``, ``proj_layers``, ``gnn_encoder``, ``num_hop``,
            ``ggd_lr``, ``weight_decay``, ``n_ggd_epochs``, ``patience``,
            ``classifier_lr`` and ``n_classifier_epochs``.

    Returns:
        Tuple ``(test_acc, val_acc)`` of the final classifier.
    """
    cuda = True
    free_gpu_id = args.gpu
    torch.cuda.set_device(free_gpu_id)
    # load and preprocess dataset
    if 'ogbn' not in args.dataset_name:
        # NOTE(review): this branch never stores the features in
        # ``g.ndata['feat']``, which the encoders read from the sampled
        # blocks - confirm it is still a supported path.
        data = load_data(args)
        features = torch.FloatTensor(data.features)
        labels = torch.LongTensor(data.labels)
        if hasattr(torch, 'BoolTensor'):
            train_mask = torch.BoolTensor(data.train_mask)
            val_mask = torch.BoolTensor(data.val_mask)
            test_mask = torch.BoolTensor(data.test_mask)
        else:
            train_mask = torch.ByteTensor(data.train_mask)
            val_mask = torch.ByteTensor(data.val_mask)
            test_mask = torch.ByteTensor(data.test_mask)
        in_feats = features.shape[1]
        n_classes = data.num_labels
        n_edges = data.graph.number_of_edges()
        g = data.graph
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        if args.self_loop:
            g.remove_edges_from(nx.selfloop_edges(g))
            g.add_edges_from(zip(g.nodes(), g.nodes()))
        g = DGLGraph(g)
        # Everything below reads ``all_labels``; without this alias the
        # non-OGB path stopped with a NameError when building the node set.
        all_labels = labels
    else:
        g, all_labels, train_mask, val_mask, test_mask, evaluator = load_data_ogb(args.dataset_name, args)
        g = preprocess(g)

        features = g.ndata['feat']
        all_labels = all_labels.T.squeeze(0)

        all_labels, train_idx, val_idx, test_idx, features = map(
            lambda x: x.to(free_gpu_id), (all_labels, train_mask, val_mask, test_mask, features)
        )

        in_feats = g.ndata['feat'].shape[1]
        n_classes = all_labels.T.max().item() + 1
        n_edges = g.num_edges()

    fanouts_train = [12,12,12]
    fanouts_test = [12,12,12]

    train_collater = NbrSampleCollater(
        g, MultiLayerNeighborSampler(fanouts=fanouts_train, replace=False))
    train_node_set = NodeSet(torch.LongTensor(np.arange(g.num_nodes())).tolist(), all_labels.tolist())
    train_node_loader = DataLoader(dataset=train_node_set, batch_size=2048,
                                        shuffle=True, num_workers=4, pin_memory=True,
                                        collate_fn=train_collater.collate, drop_last=False)

    # create GGD model
    ggd = GGD(g,
              in_feats,
              args.n_hidden,
              args.n_layers,
              args.dropout,
              args.proj_layers,
              args.gnn_encoder,
              args.num_hop)

    if cuda:
        ggd.cuda()

    ggd_optimizer = torch.optim.AdamW(ggd.parameters(),
                                     lr=args.ggd_lr,
                                     weight_decay=args.weight_decay)

    b_xent = nn.BCEWithLogitsLoss()

    # train graph group discrimination
    cnt_wait = 0
    best = 1e9
    best_t = 0
    dur = []
    total_times = []
    tag = str(int(np.random.random() * 10000000000)) #generate a unique tag

    # The checkpoint directory is not guaranteed to exist; create it up front
    # so the first ``torch.save`` cannot fail on a missing folder.
    os.makedirs('pkl', exist_ok=True)
    checkpoint_path = 'pkl/best_ggd' + tag + '.pkl'

    for epoch in range(args.n_ggd_epochs):
        start_time = time.time()
        t0 = time.time()
        ggd.train()
        if epoch >= 3:
            t0 = time.time()

        loss = 0
        for n_iter, (nodes, labels) in enumerate(tqdm(train_node_loader, desc=f'train epoch {epoch}')):
            input_nodes, output_nodes, blocks = nodes
            blocks = [block.to(free_gpu_id) for block in blocks]
            labels = labels.to(free_gpu_id)
            loss = ggd(blocks, labels, b_xent)

            ggd_optimizer.zero_grad()
            loss.backward()
            ggd_optimizer.step()

        epoch_time = time.time() - start_time
        total_times.append(epoch_time)

        # NOTE: model selection uses the loss of the LAST mini-batch of the
        # epoch. It is converted to a plain float so ``best`` does not keep a
        # tensor (and its autograd graph) alive across epochs.
        epoch_loss = float(loss)

        if epoch_loss < best:
            best = epoch_loss
            best_t = epoch
            cnt_wait = 0
            torch.save(ggd.state_dict(), checkpoint_path)
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping!')
            break

        if epoch >= 3:
            dur.append(time.time() - t0)

        # ``dur`` is empty for the first three epochs; report NaN explicitly
        # instead of taking the mean of an empty list.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
              "ETputs(KTEPS) {:.2f}".format(epoch, mean_dur, epoch_loss,
                                            n_edges / mean_dur / 1000))

    print(f'Average epoch time = {np.mean(total_times):.6f} seconds.\n')
    # # train classifier
    print('Loading {}th epoch'.format(best_t))
    ggd.load_state_dict(torch.load(checkpoint_path))

    # get all embeddings for evaluation
    ggd.eval()
    embeds = []

    test_collater = NbrSampleCollater(
        g, MultiLayerNeighborSampler(fanouts=fanouts_test, replace=False))
    test_node_set = NodeSet(torch.LongTensor(np.arange(g.num_nodes())).tolist(), all_labels.tolist())
    # shuffle=False keeps the embeddings in node-id order.
    test_node_loader = DataLoader(dataset=test_node_set, batch_size=4196,
                                        shuffle=False, num_workers=0, pin_memory=True,
                                        collate_fn=test_collater.collate, drop_last=False)

    for n_iter, (nodes, labels) in enumerate(tqdm(test_node_loader, desc='loading embedding for evaluation')):
        input_nodes, output_nodes, blocks = nodes
        blocks = [block.to(free_gpu_id) for block in blocks]
        labels = labels.to(free_gpu_id)
        embed = ggd.embed(blocks)
        embeds.append(embed.cpu())

    l_embeds = torch.cat(embeds, dim = 0)

    torch.cuda.empty_cache()

    # obtain embedding for downstream classifier training

    print('Start Testing. Please wait...')
    g_embeds = graph_power(l_embeds, g)
    embeds = l_embeds + g_embeds
    embeds = sk_prep.normalize(X=embeds.cpu().numpy(), norm="l2")
    embeds = torch.FloatTensor(embeds).cuda()

    # create classifier model
    classifier = Classifier(in_feats if args.gnn_encoder !='gcn' else args.n_hidden, n_classes)
    if cuda:
        classifier.cuda()

    classifier_optimizer = torch.optim.AdamW(classifier.parameters(),
                                            lr=args.classifier_lr,
                                            weight_decay=args.weight_decay)

    all_labels = all_labels.cuda()
    dur = []
    best_acc = 0
    patience = 100
    wait = 0
    for epoch in range(args.n_classifier_epochs):
        classifier.train()
        if epoch >= 3:
            t0 = time.time()

        classifier_optimizer.zero_grad()
        preds = classifier(embeds)
        loss = F.nll_loss(preds[train_mask], all_labels[train_mask])
        loss.backward()
        classifier_optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)
        acc = evaluate(classifier, embeds, all_labels, val_mask)
        # Stop once validation accuracy has not improved for ``patience``
        # consecutive epochs.
        if acc > best_acc:
            best_acc = acc
            wait = 0
        wait += 1
        if wait > patience:
            break

    test_acc = evaluate(classifier, embeds, all_labels, test_mask)
    print(f"Test Accuracy {test_acc:.4f}")
    val_acc = evaluate(classifier, embeds, all_labels, val_mask)
    print(f"Val Accuracy {val_acc:.4f}")
    return test_acc, val_acc


def get_free_gpu():
    """Return the index of the GPU reporting the most free memory.

    Runs ``nvidia-smi`` through the shell, writes the filtered report to a
    file named ``tmp`` in the current working directory, and parses the third
    whitespace-separated field of every line as an integer.

    Returns:
        ``np.argmax`` over the parsed free-memory values.

    Raises:
        ValueError: If the report contains no entries (for example when
            ``nvidia-smi`` is unavailable).
    """
    os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp')
    # Context manager so the report file is closed even if parsing fails.
    with open('tmp', 'r') as report:
        memory_available = [int(line.split()[2]) for line in report]
    if not memory_available:
        # Same exception type np.argmax raises on empty input, with a
        # message that names the actual cause.
        raise ValueError('nvidia-smi reported no GPUs with free-memory information')
    return np.argmax(memory_available)


def graph_power(embed, g, num_iters=10):
    """Smooth ``embed`` by repeated degree-normalised neighbour summation.

    Each round scales the features by ``in_degree ** -0.5``, sums them over
    incoming edges with ``g.update_all``, and scales by the same factor again.

    Args:
        embed: Node embeddings; a leading dimension of size 1 is squeezed
            away.
        g: Graph exposing ``in_degrees()``, ``ndata`` and ``update_all``.
            The temporary ``'h2'`` node field is removed after every round.
        num_iters: Number of propagation rounds. Defaults to 10, the value
            that used to be hard-coded; 0 returns the squeezed input as is.

    Returns:
        The propagated feature tensor.
    """
    feat = embed.squeeze(0)

    # Clamp so isolated nodes (in-degree 0) do not produce an infinite scale.
    degs = g.in_degrees().float().clamp(min=1)
    norm = torch.pow(degs, -0.5)
    norm = norm.to(feat.device).unsqueeze(1)
    for _ in range(num_iters):
        feat = feat * norm
        g.ndata['h2'] = feat
        g.update_all(fn.copy_u('h2', 'm'),
                     fn.sum('m', 'h2'))
        feat = g.ndata.pop('h2')
        feat = feat * norm
    return feat


if __name__ == '__main__':
    # Script entry point: parse the options, run ``args.n_trails`` independent
    # trials of ``main`` and append a summary to ``result/result_<dataset>.txt``.
    import warnings

    warnings.filterwarnings("ignore")

    parser = argparse.ArgumentParser(description='DGI')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.,
                        help="dropout probability")
    parser.add_argument("--gpu", type=int, default=0,
                        help="gpu")
    parser.add_argument("--ggd-lr", type=float, default=0.001,
                        help="ggd learning rate")
    parser.add_argument("--drop_feat", type=float, default=0.2,
                        help="feature dropout rate")
    parser.add_argument("--classifier-lr", type=float, default=0.05,
                        help="classifier learning rate")
    parser.add_argument("--n-ggd-epochs", type=int, default=500,
                        help="number of training epochs")
    parser.add_argument("--n-classifier-epochs", type=int, default=100,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=512,
                        help="number of hidden gcn units")
    parser.add_argument("--proj_layers", type=int, default=1,
                        help="number of project linear layers")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=0.,
                        help="Weight for L2 loss")
    parser.add_argument("--patience", type=int, default=500,
                        help="early stop patience condition")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--n_trails", type=int, default=5,
                        help="number of trails")
    parser.add_argument("--gnn_encoder", type=str, default='gcn',
                        help="choice of gnn encoder")
    parser.add_argument("--num_hop", type=int, default=10,
                        help="number of k for sgc")
    parser.add_argument('--data_root_dir', type=str, default='default',
                           help="dir_path for saving graph data. Note that this model use DGL loader so do not mix up with the dir_path for the Pyg one. Use 'default' to save datasets at current folder.")
    parser.add_argument("--pretrain_path", type=str, default='None',
                        help="path for pretrained node features")
    parser.add_argument('--dataset_name', type=str, default='cora',
                        help='Dataset name: cora, citeseer, pubmed, cs, phy')
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    test_accs = []
    val_accs = []
    for i in range(args.n_trails):
        test_acc, val_acc = main(args)
        test_accs.append(test_acc)
        val_accs.append(val_acc)
    print(f'test accuracy: {np.array(test_accs).mean()} +- {np.array(test_accs).std()}')
    print(f'mean val accuracy: {np.array(val_accs).mean()} +- {np.array(val_accs).std()}' )

    # ``mean_acc`` was referenced below without ever being defined, so every
    # run ended with a NameError after training had finished. It is the mean
    # test accuracy over all trials.
    mean_acc = str(np.array(test_accs).mean())

    file_name = str(args.dataset_name)
    # Make sure the output folder exists, and close the file deterministically.
    os.makedirs('result', exist_ok=True)
    with open('result/' + 'result_' + file_name + '.txt', 'a') as f:
        f.write(str(args) + '\n')
        f.write(mean_acc + '\n')
