import time
import os
import random
import argparse
import numpy as np
import warnings
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import remove_self_loops
import dgl
from dgl.nn import EdgeWeightNorm
from sklearn.neighbors import kneighbors_graph
from scipy import sparse

# from unsup_model import GREET, Edge_Discriminator
from utils import set_seed, random_splits, get_structural_encoding
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
from utils import split_batch, get_adj_from_edges, normalize_adj
warnings.filterwarnings("ignore")


class SGC(nn.Module):
    """Encoder made of one linear layer followed by repeated propagation.

    Args:
        dataset: object exposing ``num_node_features`` (input feature width).
        args: namespace exposing ``dropout``, ``hidden`` (output width) and
            ``K`` (number of propagation steps).
    """

    def __init__(self, dataset, args):
        super().__init__()
        in_dim, out_dim = dataset.num_node_features, args.hidden
        self.linear = nn.Linear(in_dim, out_dim)
        self.dropout = args.dropout
        self.k = args.K

    def forward(self, x, g):
        """Return ``g`` applied ``self.k`` times to the dropped-out ``relu(linear(x))``.

        ``g`` is used as the left operand of ``torch.matmul``.
        """
        hidden = F.relu(self.linear(x))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        for _hop in range(self.k):
            hidden = torch.matmul(g, hidden)
        return hidden


class GREET(nn.Module):
    """Two-encoder contrastive model with one projection head per encoder.

    ``encoder1`` / ``encoder2`` are called as ``encoder(x, a)``; their outputs
    are projected and contrasted with a masked InfoNCE loss.  The positive and
    negative masks used by the loss are only available after
    :meth:`set_mask_knn` has been called.

    Args:
        encoder1: callable ``(x, a1) -> embedding`` for the first view.
        encoder2: callable ``(x, a2) -> embedding`` for the second view.
        nlayers_proj: number of linear layers in each projection head (1 or 2).
        emb_dim: width of the encoder outputs.
        proj_dim: width of the projections.
        batch_size: nodes per contrastive batch; ``0`` (or a value larger than
            the number of nodes) computes the loss on all nodes at once.
        device: device on which the masks are created.

    Raises:
        ValueError: if ``nlayers_proj`` is neither 1 nor 2.
    """

    def __init__(self, encoder1, encoder2, nlayers_proj, emb_dim, proj_dim, batch_size, device):
        super(GREET, self).__init__()
        # Fail at construction time: without this check an unsupported value
        # left the model without projection heads and only crashed later.
        if nlayers_proj not in (1, 2):
            raise ValueError('nlayers_proj must be 1 or 2, got {!r}'.format(nlayers_proj))

        self.encoder1 = encoder1
        self.encoder2 = encoder2
        self.batch_size = batch_size
        self.device = device

        # Heads are built in this order so parameter initialisation consumes
        # the RNG exactly as before.
        self.proj_head1 = self._make_proj_head(nlayers_proj, emb_dim, proj_dim)
        self.proj_head2 = self._make_proj_head(nlayers_proj, emb_dim, proj_dim)

    @staticmethod
    def _make_proj_head(nlayers_proj, emb_dim, proj_dim):
        """Build a 1-layer (linear) or 2-layer (linear-ReLU-linear) projection head."""
        if nlayers_proj == 1:
            return nn.Sequential(nn.Linear(emb_dim, proj_dim))
        return nn.Sequential(nn.Linear(emb_dim, proj_dim), nn.ReLU(inplace=True), nn.Linear(proj_dim, proj_dim))

    def get_embedding(self, x, a1, a2):
        """Return both encoder outputs concatenated along the feature axis."""
        emb1 = self.encoder1(x, a1)
        emb2 = self.encoder2(x, a2)
        return torch.cat((emb1, emb2), dim=1)

    def get_projection(self, x, a1, a2):
        """Return both projected encoder outputs concatenated along the feature axis."""
        emb1 = self.encoder1(x, a1)
        emb2 = self.encoder2(x, a2)
        proj1 = self.proj_head1(emb1)
        proj2 = self.proj_head2(emb2)
        return torch.cat((proj1, proj2), dim=1)

    def forward(self, x1, a1, x2, a2):
        """Encode and project both views and return their contrastive loss."""
        emb1 = self.encoder1(x1, a1)
        emb2 = self.encoder2(x2, a2)
        proj1 = self.proj_head1(emb1)
        proj2 = self.proj_head2(emb2)
        loss = self.batch_nce_loss(proj1, proj2)
        return loss

    def set_mask_knn(self, X, k, dataset, metric='cosine'):
        """Build ``self.pos_mask`` / ``self.neg_mask`` from a k-NN graph of ``X``.

        With ``k == 0`` only the diagonal is positive.  Otherwise the k-NN
        graph is loaded from (or computed and saved to)
        ``../data/knn/<dataset>/<dataset>_<k>.npz`` and the identity is added.
        ``neg_mask`` is ``1 - pos_mask``.
        """
        if k != 0:
            path = '../data/knn/{}'.format(dataset)
            # exist_ok avoids the check-then-create race between parallel runs.
            os.makedirs(path, exist_ok=True)
            file_name = path + '/{}_{}.npz'.format(dataset, k)
            if os.path.exists(file_name):
                knn = sparse.load_npz(file_name)
            else:
                knn = kneighbors_graph(X, k, metric=metric)
                sparse.save_npz(file_name, knn)
            # NOTE(review): the mask dtype follows the scipy matrix here
            # (presumably float64), unlike the float32 identity used for
            # k == 0 - confirm whether that promotion is intended.
            knn = torch.tensor(knn.toarray(), device=self.device) + torch.eye(X.shape[0], device=self.device)
        else:
            knn = torch.eye(X.shape[0], device=self.device)
        self.pos_mask = knn
        self.neg_mask = 1 - self.pos_mask

    def batch_nce_loss(self, z1, z2, temperature=0.2, pos_mask=None, neg_mask=None):
        """Symmetric InfoNCE loss between ``z1`` and ``z2``.

        When both masks are omitted the ones stored by :meth:`set_mask_knn`
        are used.  If ``self.batch_size`` is 0 or exceeds the node count the
        loss is computed on all nodes; otherwise nodes are shuffled, split
        with ``split_batch`` and the per-batch losses are averaged weighted by
        batch size.
        """
        if pos_mask is None and neg_mask is None:
            pos_mask = self.pos_mask
            neg_mask = self.neg_mask

        nnodes = z1.shape[0]
        if (self.batch_size == 0) or (self.batch_size > nnodes):
            loss_0 = self.infonce(z1, z2, pos_mask, neg_mask, temperature)
            loss_1 = self.infonce(z2, z1, pos_mask, neg_mask, temperature)
            loss = (loss_0 + loss_1) / 2.0
        else:
            node_idxs = list(range(nnodes))
            random.shuffle(node_idxs)
            batches = split_batch(node_idxs, self.batch_size)
            loss = 0
            for b in batches:
                weight = len(b) / nnodes
                # Slice each mask once per batch; both directions share them.
                pos_b = pos_mask[:, b][b, :]
                neg_b = neg_mask[:, b][b, :]
                loss_0 = self.infonce(z1[b], z2[b], pos_b, neg_b, temperature)
                loss_1 = self.infonce(z2[b], z1[b], pos_b, neg_b, temperature)
                loss += (loss_0 + loss_1) / 2.0 * weight
        return loss

    def infonce(self, anchor, sample, pos_mask, neg_mask, tau):
        """Masked InfoNCE of ``anchor`` rows against ``sample`` rows.

        The denominator sums ``exp(sim / tau)`` over the entries selected by
        ``neg_mask`` only; the log-ratio is averaged over the entries selected
        by ``pos_mask`` in each row, then over rows, and negated.
        """
        sim = self.similarity(anchor, sample) / tau
        exp_sim = torch.exp(sim) * neg_mask
        log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True))
        loss = log_prob * pos_mask
        loss = loss.sum(dim=1) / pos_mask.sum(dim=1)
        return -loss.mean()

    def similarity(self, h1: torch.Tensor, h2: torch.Tensor):
        """Return the pairwise cosine-similarity matrix between rows of ``h1`` and ``h2``."""
        h1 = F.normalize(h1)
        h2 = F.normalize(h2)
        return h1 @ h2.t()


class Edge_Discriminator(nn.Module):
    """Scores every edge and derives two weighted views of the graph.

    Each edge gets a weight ``weights_lp`` in (0, 1) and its complement
    ``weights_hp = 1 - weights_lp``; :meth:`weight_to_adj` turns the two
    weight vectors into a pair of adjacency structures (dense tensors, or DGL
    graphs carrying an ``'w'`` edge feature when ``sparse`` is truthy).

    Args:
        nnodes: number of nodes in the graph.
        input_dim: width of the node features fed to the discriminator.
        alpha: scale applied to the ``hp`` edge weights.
        sparse: truthy to build DGL graphs instead of dense matrices.
        hidden_dim: width of the node embedding.
        temperature: divisor applied before the sigmoid gate.
        bias: offset keeping the sampled noise away from 0 and 1.
        device: device on which helper tensors are created.
    """

    def __init__(self, nnodes, input_dim, alpha, sparse, hidden_dim=128, temperature=1.0, bias=0.0 + 0.0001, device=None):
        super().__init__()
        # Layer creation order is kept so initialisation draws the same RNG values.
        self.embedding_layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim)])
        self.edge_mlp = nn.Linear(hidden_dim * 2, 1)

        self.nnodes = nnodes
        self.alpha = alpha
        self.sparse = sparse
        self.temperature = temperature
        self.bias = bias
        self.device = device

    def get_node_embedding(self, h):
        """Pass ``h`` through every embedding layer, each followed by a ReLU."""
        for embed in self.embedding_layers:
            h = F.relu(embed(h))
        return h

    def get_edge_weight(self, embeddings, edges):
        """Return one raw score per edge, averaged over both endpoint orders."""
        src_emb = embeddings[edges[0]]
        dst_emb = embeddings[edges[1]]
        score_fwd = self.edge_mlp(torch.cat((src_emb, dst_emb), dim=1)).flatten()
        score_bwd = self.edge_mlp(torch.cat((dst_emb, src_emb), dim=1)).flatten()
        return (score_fwd + score_bwd) / 2

    def gumbel_sampling(self, edges_weights_raw):
        """Perturb the raw scores with logistic noise and squash them with a sigmoid.

        The noise ``log(eps) - log(1 - eps)`` is drawn with ``torch.rand`` on
        the default device and then moved to ``self.device``.
        """
        upper = 1 - self.bias
        uniform = torch.rand(edges_weights_raw.size())
        eps = (self.bias - upper) * uniform + upper
        noise = torch.log(eps) - torch.log(1 - eps)
        noise = noise.to(self.device)
        gated = (noise + edges_weights_raw) / self.temperature
        return torch.sigmoid(gated).squeeze()

    def weight_forward(self, features, edges):
        """Return ``(weights_lp, weights_hp)`` for ``edges``, with ``hp = 1 - lp``."""
        node_emb = self.get_node_embedding(features)
        raw_scores = self.get_edge_weight(node_emb, edges)
        weights_lp = self.gumbel_sampling(raw_scores)
        return weights_lp, 1 - weights_lp

    def _dense_self_loop_adj(self, edges, weights):
        """Dense adjacency from ``weights`` plus identity, symmetrically normalised."""
        adj = get_adj_from_edges(edges, weights, self.nnodes)
        adj += torch.eye(self.nnodes, device=self.device)
        return normalize_adj(adj, 'sym', self.sparse)

    def _graph_with_self_loops(self, edges):
        """DGL graph over ``edges`` with self-loops added."""
        graph = dgl.graph((edges[0], edges[1]), num_nodes=self.nnodes, device=self.device)
        return dgl.add_self_loop(graph)

    def weight_to_adj(self, edges, weights_lp, weights_hp):
        """Turn the two edge-weight vectors into ``(adj_lp, adj_hp)``.

        Dense mode: both are normalised adjacencies with self-loops; ``adj_hp``
        is then replaced by ``I - alpha * adj_hp`` restricted to the given
        edges.  Sparse mode: both are DGL graphs with self-loops whose ``'w'``
        edge feature holds the normalised weights; the ``hp`` weights are
        scaled by ``-alpha`` and the entries from index ``edges.shape[1]`` on
        (presumably the appended self-loops) are set to 1.
        """
        EOS = 1e-10
        norm = EdgeWeightNorm(norm='both')
        if self.sparse:
            adj_lp = self._graph_with_self_loops(edges)
            loop_weights = torch.ones(self.nnodes, device=self.device)
            weights_lp = torch.cat((weights_lp, loop_weights)) + EOS
            weights_lp = norm(adj_lp, weights_lp)
            adj_lp.edata['w'] = weights_lp

            adj_hp = self._graph_with_self_loops(edges)
            loop_weights = torch.ones(self.nnodes, device=self.device)
            weights_hp = torch.cat((weights_hp, loop_weights)) + EOS
            weights_hp = norm(adj_hp, weights_hp)
            weights_hp *= - self.alpha
            weights_hp[edges.shape[1]:] = 1
            adj_hp.edata['w'] = weights_hp
        else:
            adj_lp = self._dense_self_loop_adj(edges, weights_lp)
            adj_hp = self._dense_self_loop_adj(edges, weights_hp)

            # Keep only the entries that correspond to the given edges.
            mask = torch.zeros(adj_lp.shape, device=self.device)
            mask[edges[0], edges[1]] = 1.
            mask.requires_grad = False
            adj_hp = torch.eye(self.nnodes, device=self.device) - adj_hp * mask * self.alpha
        return adj_lp, adj_hp

    def forward(self, features, edges):
        """Return ``(adj_lp, adj_hp, weights_lp, weights_hp)`` for the given graph."""
        weights_lp, weights_hp = self.weight_forward(features, edges)
        adj_lp, adj_hp = self.weight_to_adj(edges, weights_lp, weights_hp)
        return adj_lp, adj_hp, weights_lp, weights_hp
    


def get_feat_mask(features, mask_rate):
    """Pick random feature columns and return a 0/1 mask marking them.

    Args:
        features: 2-D tensor; only its shape and device are used.
        mask_rate: fraction of columns to mark (truncated to an integer count).

    Returns:
        ``(mask, samples)`` where ``mask`` has the shape of ``features`` with
        ones in the chosen columns, and ``samples`` is the array of chosen
        column indices drawn with ``np.random.choice`` without replacement.
    """
    n_cols = features.shape[1]
    n_masked = int(n_cols * mask_rate)
    samples = np.random.choice(n_cols, size=n_masked, replace=False)

    mask = torch.zeros(features.shape, device=features.device)
    mask[:, samples] = 1
    return mask, samples


def _augment_view(features, adj, maskfeat_rate, dropedge_rate, sparse, training):
    """Mask random feature columns of one view and apply dropout to its edge weights."""
    mask, _ = get_feat_mask(features, maskfeat_rate)
    features = features * (1 - mask)
    if not sparse:
        adj = F.dropout(adj, p=dropedge_rate, training=training)
    else:
        # Sparse graphs carry their weights in the 'w' edge feature, which is
        # overwritten on the graph object that was passed in.
        adj.edata['w'] = F.dropout(adj.edata['w'], p=dropedge_rate, training=training)
    return features, adj


def augmentation(features_1, adj_1, features_2, adj_2, args, training):
    """Augment two graph views by feature masking and edge dropout.

    Args:
        features_1, features_2: node-feature tensors of the two views.
        adj_1, adj_2: adjacency of each view; dense tensors when
            ``args.sparse`` is falsy, otherwise objects with an ``edata['w']``
            edge-weight entry.
        args: namespace with ``maskfeat_rate_1/2``, ``dropedge_rate_1/2`` and
            ``sparse``.
        training: forwarded to ``F.dropout``; dropout is a no-op when False.

    Returns:
        ``(features_1, adj_1, features_2, adj_2)`` after augmentation.
    """
    # View 1 is processed before view 2 so random draws happen in the same
    # order as before.
    features_1, adj_1 = _augment_view(
        features_1, adj_1, args.maskfeat_rate_1, args.dropedge_rate_1, args.sparse, training)
    # The mask for view 2 is now sized from features_2 itself; it used to be
    # sized from features_1, which was only correct when both views shared the
    # same shape and device.
    features_2, adj_2 = _augment_view(
        features_2, adj_2, args.maskfeat_rate_2, args.dropedge_rate_2, args.sparse, training)
    return features_1, adj_1, features_2, adj_2


def _draw_distinct_pairs(nnodes, count):
    """Draw ``count`` random (src, dst) columns and drop those with src == dst."""
    drawn = np.random.choice(nnodes, size=count * 2, replace=True)
    drawn = torch.from_numpy(drawn.reshape((2, count)))
    return drawn[:, drawn[0, :] != drawn[1, :]]


def generate_random_node_pairs(nnodes, nedges, backup=300):
    """Sample ``nedges`` random node pairs with distinct endpoints.

    Args:
        nnodes: number of nodes; indices are drawn from ``range(nnodes)``.
        nedges: number of pairs to return.
        backup: extra pairs drawn up front to compensate for discarded
            self-pairs.

    Returns:
        Integer tensor of shape ``(2, nedges)``.

    Raises:
        ValueError: if pairs are requested from a graph with fewer than two
            nodes, where no distinct pair exists.
    """
    if nedges > 0 and nnodes < 2:
        raise ValueError('cannot sample distinct node pairs from {} node(s)'.format(nnodes))

    first_draw = nedges + backup
    rand_edges = _draw_distinct_pairs(nnodes, first_draw)

    # On small graphs more than `backup` draws can be self-pairs, which used
    # to yield fewer than `nedges` columns.  Keep drawing until there are
    # enough; the first draw is untouched so seeded runs are unchanged when
    # no refill is needed.
    refill = max(first_draw, nedges, 1)
    while rand_edges.shape[1] < nedges:
        rand_edges = torch.cat((rand_edges, _draw_distinct_pairs(nnodes, refill)), dim=1)

    return rand_edges[:, 0:nedges]


def train_cl(cl_model, discriminator, optimizer_cl, features, str_encodings, edges, args=None):
    """Run one optimisation step of the contrastive model.

    The discriminator is put in eval mode and only provides the two graph
    views; only ``optimizer_cl`` is stepped.

    Args:
        cl_model: contrastive model called as ``cl_model(x1, a1, x2, a2)``.
        discriminator: called as ``discriminator(inputs, edges)``; its first
            two outputs are used as the views.
        optimizer_cl: optimizer that is zeroed and stepped here.
        features: node features.
        str_encodings: structural encodings, concatenated to ``features``
            column-wise for the discriminator.
        edges: edge index passed to the discriminator.
        args: augmentation hyper-parameters.  When omitted, the module-level
            ``args`` is used, which is what this function always relied on.

    Returns:
        The contrastive loss as a Python float.

    Raises:
        ValueError: if ``args`` is omitted and no module-level ``args`` exists.
    """
    if args is None:
        # Backward-compatible fallback for callers that do not pass args.
        args = globals().get('args')
        if args is None:
            raise ValueError('train_cl needs `args`: pass it explicitly or define a module-level `args`')

    cl_model.train()
    discriminator.eval()

    # The discriminator is not updated in this step, so its forward pass does
    # not need an autograd graph; gradients w.r.t. cl_model are unaffected.
    with torch.no_grad():
        adj_1, adj_2, _, _ = discriminator(torch.cat((features, str_encodings), 1), edges)
    features_1, adj_1, features_2, adj_2 = augmentation(features, adj_1, features, adj_2, args, cl_model.training)
    cl_loss = cl_model(features_1, adj_1, features_2, adj_2)

    optimizer_cl.zero_grad()
    cl_loss.backward()
    optimizer_cl.step()
    return cl_loss.item()


def train_discriminator(cl_model, discriminator, optimizer_disc, features, str_encodings, edges, args):
    """Run one optimisation step of the edge discriminator with a ranking loss.

    Embedding similarity along real edges is compared with similarity along
    random node pairs.  With an all-ones target, ``margin_ranking_loss(a, b)``
    is ``max(0, b - a + margin)``; each per-edge term is then weighted by
    ``relu(weight - 0.5)`` of the corresponding edge weight.

    Args:
        cl_model: contrastive model (kept in eval mode) providing
            ``get_embedding``.
        discriminator: model being trained.
        optimizer_disc: optimizer that is zeroed and stepped here.
        features: node features.
        str_encodings: structural encodings, concatenated to ``features``.
        edges: edge index of shape ``(2, E)``.
        args: namespace with ``margin_hom`` and ``margin_het``.

    Returns:
        The ranking loss as a Python float.
    """
    cl_model.eval()
    discriminator.train()

    adj_1, adj_2, weights_lp, weights_hp = discriminator(torch.cat((features, str_encodings), 1), edges)
    # Place the random pairs next to the edges instead of relying on a
    # module-level `device` variable.
    rand_np = generate_random_node_pairs(features.shape[0], edges.shape[1]).to(edges.device)
    psu_label = torch.ones(edges.shape[1], device=edges.device)

    embedding = cl_model.get_embedding(features, adj_1, adj_2)
    edge_emb_sim = F.cosine_similarity(embedding[edges[0]], embedding[edges[1]])
    # Both loss terms use the same random pairs and the same embedding, so
    # this similarity is computed once and shared.
    rnp_emb_sim = F.cosine_similarity(embedding[rand_np[0]], embedding[rand_np[1]])

    # Edges weighted above 0.5 in the lp view should be more similar than random pairs.
    loss_lp = F.margin_ranking_loss(edge_emb_sim, rnp_emb_sim, psu_label, margin=args.margin_hom, reduction='none')
    loss_lp *= torch.relu(weights_lp - 0.5)

    # Edges weighted above 0.5 in the hp view should be less similar than random pairs.
    loss_hp = F.margin_ranking_loss(rnp_emb_sim, edge_emb_sim, psu_label, margin=args.margin_het, reduction='none')
    loss_hp *= torch.relu(weights_hp - 0.5)

    rank_loss = (loss_lp.mean() + loss_hp.mean()) / 2

    optimizer_disc.zero_grad()
    rank_loss.backward()
    optimizer_disc.step()
    return rank_loss.item()


def unsupervised_learning(dataset, data, args, device):
    """Train the contrastive model and edge discriminator, return node embeddings.

    Each epoch runs ``args.cl_rounds`` contrastive steps followed by one
    discriminator step.  The models with the lowest contrastive loss are
    checkpointed under ``unsup_pkl/`` and restored before the final embeddings
    are computed; training stops early after ``args.patience`` epochs without
    improvement.  The temporary checkpoints are always removed on exit.

    Args:
        dataset: dataset object exposing ``num_node_features``.
        data: graph object exposing ``x``, ``edge_index`` and ``num_nodes``.
        args: parsed hyper-parameters.
        device: device on which the models are placed.

    Returns:
        Detached embedding tensor from ``cl_model.get_embedding``.

    Raises:
        ValueError: if ``args.cl_rounds`` is smaller than 1.
    """
    if args.cl_rounds < 1:
        # Without at least one contrastive step there is no loss to track.
        raise ValueError('cl_rounds must be >= 1, got {}'.format(args.cl_rounds))

    features = data.x
    edges = remove_self_loops(data.edge_index)[0]

    # Structural encodings are cached on disk, keyed by dataset name.
    # NOTE(review): this cache lives at the absolute path /data/se while the
    # kNN cache uses the relative ../data/knn - confirm which is intended.
    path = '/data/se'
    os.makedirs(path, exist_ok=True)
    file_name = f'{path}/{args.dataset}_16.pt'
    if os.path.exists(file_name):
        str_encodings = torch.load(file_name)
    else:
        str_encodings = get_structural_encoding(edges.detach().cpu(), data.num_nodes)
        torch.save(str_encodings, file_name)
    str_encodings = str_encodings.to(device)

    encoder1 = SGC(dataset=dataset, args=args)
    encoder2 = SGC(dataset=dataset, args=args)
    cl_model = GREET(encoder1=encoder1, encoder2=encoder2, nlayers_proj=args.nlayers_proj, emb_dim=args.hidden,
                    proj_dim=args.proj_dim, batch_size=args.cl_batch_size, device=device).to(device)
    cl_model.set_mask_knn(features.detach().cpu(), k=args.knn_k, dataset=args.dataset)
    discriminator = Edge_Discriminator(data.num_nodes, dataset.num_node_features + str_encodings.shape[1],
                                       args.alpha, args.sparse, device=device).to(device)

    optimizer_cl = torch.optim.Adam(cl_model.parameters(), lr=args.lr_gcl, weight_decay=args.w_decay)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=args.lr_disc, weight_decay=args.w_decay)

    # The checkpoint directory was never created, so the first save failed on
    # a fresh checkout.  The pid is added to the tag so runs launched within
    # the same second do not overwrite each other's files.
    os.makedirs('unsup_pkl', exist_ok=True)
    unsup_tag = f'{int(time.time())}_{os.getpid()}'
    cl_ckpt = f'unsup_pkl/greet_cl_{args.net}_best_cl_model_{args.dataset}{unsup_tag}.pkl'
    disc_ckpt = f'unsup_pkl/greet_cl_{args.net}_best_discriminator_{args.dataset}{unsup_tag}.pkl'

    best = float("inf")
    cnt_wait = 0
    has_checkpoint = False
    try:
        for epoch in range(1, args.unsup_epochs + 1):
            for _ in range(args.cl_rounds):
                cl_loss = train_cl(cl_model, discriminator, optimizer_cl, features, str_encodings, edges)
            train_discriminator(cl_model, discriminator, optimizer_discriminator, features, str_encodings, edges, args)

            # Model selection uses the loss of the last contrastive round.
            if cl_loss < best:
                best = cl_loss
                cnt_wait = 0
                torch.save(cl_model.state_dict(), cl_ckpt)
                torch.save(discriminator.state_dict(), disc_ckpt)
                has_checkpoint = True
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                print(f'Early stopping at epoch {epoch}.')
                break

        cl_model.eval()
        discriminator.eval()
        if has_checkpoint:
            cl_model.load_state_dict(torch.load(cl_ckpt))
            discriminator.load_state_dict(torch.load(disc_ckpt))
        else:
            # Happens with zero epochs or a loss that never compared below
            # infinity (e.g. NaN); previously this crashed on a missing file.
            print('No checkpoint was saved; using the current model weights.')

        # Inference only: no autograd graph is needed for the final embeddings.
        with torch.no_grad():
            adj_1, adj_2, _, _ = discriminator(torch.cat((features, str_encodings), 1), edges)
            embeds = cl_model.get_embedding(features, adj_1, adj_2).detach()
    finally:
        # Checkpoints are temporary; remove them even if training failed.
        for ckpt in (cl_ckpt, disc_ckpt):
            if os.path.exists(ckpt):
                os.remove(ckpt)
    return embeds


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: optional list of argument strings.  When ``None`` (the default)
            the arguments are taken from ``sys.argv``, as before; passing a
            list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str,default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, default='SGC')

    # Encoder
    parser.add_argument('--K', type=int, default=10)
    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')

    # Data split
    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    # unsupervised learning
    parser.add_argument("--patience", type=int, default=50, help="Patient epochs to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=1000, help="Unupservised training epochs.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")
    parser.add_argument('--lr_gcl', type=float, default=0.001)
    parser.add_argument('--lr_disc', type=float, default=0.001)
    parser.add_argument('--cl_rounds', type=int, default=2)
    parser.add_argument('--w_decay', type=float, default=0.0)

    # DISC Module - Hyper-param
    parser.add_argument('--sparse', type=int, default=0)
    parser.add_argument('--alpha', type=float, default=0.1)
    parser.add_argument('--margin_hom', type=float, default=0.5)
    parser.add_argument('--margin_het', type=float, default=0.5)

    # GRL Module - Hyper-param
    parser.add_argument('--nlayers_proj', type=int, default=1, choices=[1, 2])
    parser.add_argument('--proj_dim', type=int, default=128)
    parser.add_argument('--cl_batch_size', type=int, default=0)
    parser.add_argument('--knn_k', type=int, default=20)
    parser.add_argument('--maskfeat_rate_1', type=float, default=0.1)
    parser.add_argument('--maskfeat_rate_2', type=float, default=0.5)
    parser.add_argument('--dropedge_rate_1', type=float, default=0.5)
    parser.add_argument('--dropedge_rate_2', type=float, default=0.1)
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: learn embeddings once, then evaluate them with a
    # linear probe over several fixed random splits.
    # `args` and `device` stay module-level names because other functions in
    # this file may read them as globals.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    # 10 fixed seeds for random splits from BernNet
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    # Validate before the expensive training: more runs than seeds used to
    # fail with an IndexError only after unsupervised learning had finished.
    if not 1 <= args.runs <= len(SEEDS):
        raise ValueError(f'--runs must be between 1 and {len(SEEDS)}, got {args.runs}')
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    data = data.to(device)

    # Split sizes: training nodes per class and total validation nodes.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    embeds = unsupervised_learning(dataset=dataset, data=data, args=args, device=device)

    unsup_results = []
    for run_seed in SEEDS[:args.runs]:
        # args.seed is overwritten per run and passed on through `args`.
        args.seed = run_seed
        data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    # Report the mean and the half-width of a bootstrapped 95% interval.
    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')