import os
import torch
import torch.nn as nn
import argparse
import numpy as np
import scipy.sparse as sp
from models import GMI_layer, LogReg, GMI
from utils import process
import random
import csv
from loads import load_split
import scipy.sparse as sp
import torch.nn.functional as F


from sklearn.cluster import KMeans
from clustering_metric import clustering_metrics
from tqdm import tqdm
import scipy.sparse as sp
from sklearn.decomposition import PCA



"""command-line interface"""
parser = argparse.ArgumentParser(description="PyTorch Implementation of GMI")
parser.add_argument('--dataset', default='cora',
                    help='name of dataset (default: cora). cora uses the GMI encoder; '
                         'every other dataset uses GMI_layer')
parser.add_argument('--gpu', type=int, default=0,
                    help='index of the CUDA device to use (default: 0)')
# training params
parser.add_argument('--hid_units', type=int, default=512,
                    help='dim of node embedding (default: 512)')
parser.add_argument('--nb_epochs', type=int, default=550,
                    help='number of epochs to train (default: 550)')
parser.add_argument('--epoch_flag', type=int, default=20,
                    help='early-stopping patience: stop after this many epochs '
                         'without a loss improvement (default: 20)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--l2_coef', type=float, default=0.0,
                    help='weight decay (default: 0.0)')
parser.add_argument('--negative_num', type=int, default=5,
                    help='number of negative examples used in the discriminator (default: 5)')
parser.add_argument('--alpha', type=float, default=0.8,
                    help='parameter for I(h_i; x_i) (default: 0.8)')
parser.add_argument('--beta', type=float, default=1.0,
                    help='parameter for I(h_i; x_j), node j is a neighbor (default: 1.0)')
parser.add_argument('--gamma', type=float, default=1.0,
                    help='parameter for I(w_ij; a_ij) (default: 1.0)')
parser.add_argument('--activation', default='prelu',
                    help='activation function (default: prelu)')
parser.add_argument('--num-splits', type=int, default=1,
                    help='number of data splits to use; 1 uses the standard split (default: 1)')
parser.add_argument('--num-seeds', type=int, default=10,
                    help='number of random seeds (runs) per split (default: 10)')
parser.add_argument("--pre-norm", type=int, default=1,
                    help='flag recorded in the results CSV file name (default: 1)')
parser.add_argument("--final-norm", type=int, default=1,
                    help='flag recorded in the results CSV file name (default: 1)')
parser.add_argument("--use-seed", type=int, default=1,
                    help='if nonzero, seed all RNGs with the run index before each run (default: 1)')
parser.add_argument("--n_clusters", type=int, default=7,
                    help='number of k-means clusters in the clustering evaluation (default: 7)')

###############################################
# This section of code adapted from Petar Veličković/DGI #
###############################################

args = parser.parse_args()
# Only select a CUDA device when one exists, so the script can still start on
# CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(args.gpu)


# Results file name encodes the run configuration, e.g.
# "cora_use_seed1_num_splits=1_pre_norm1_final_norm1_.csv".
csv_path = '{}_use_seed{}_num_splits={}_pre_norm{}_final_norm{}_.csv'.format(
    args.dataset.lower(), args.use_seed, args.num_splits, args.pre_norm, args.final_norm)

def clustering(u, labels, tqdm, message, n_clusters=None, rep=5):
    """Cluster ``u`` with k-means several times and print averaged metrics.

    k-means is run ``rep`` times (restart ``e`` uses ``random_state=e``). Each
    prediction is scored with ``clustering_metrics``, and the mean/std of the
    three returned scores (ACC, NMI, F1) are printed on one line.

    Args:
        u: samples to cluster, one row per node (anything ``KMeans.fit`` accepts).
        labels: ground-truth labels aligned with the rows of ``u``.
        tqdm: passed through to ``evaluationClusterModelFromLabel``.
        message: prefix for the printed summary line.
        n_clusters: number of clusters; ``None`` means ``args.n_clusters``.
        rep: number of k-means restarts (default 5).

    Returns:
        ``(acc_mean, acc_std, nmi_mean, nmi_std, f1_mean, f1_std)``.
    """
    if n_clusters is None:
        n_clusters = args.n_clusters

    # Per-restart scores, kept local so the function does not depend on
    # module-level buffers that happen to exist at call time.
    ac = np.zeros(rep)
    nm = np.zeros(rep)
    f1 = np.zeros(rep)
    for e in range(rep):
        kmeans = KMeans(n_clusters=n_clusters, random_state=e).fit(u)
        predict_labels = kmeans.predict(u)
        cm = clustering_metrics(labels, predict_labels)
        ac[e], nm[e], f1[e] = cm.evaluationClusterModelFromLabel(tqdm)

    acc_means = np.mean(ac)
    acc_stds = np.std(ac)
    nmi_means = np.mean(nm)
    nmi_stds = np.std(nm)
    f1_means = np.mean(f1)
    f1_stds = np.std(f1)

    print(message, args.dataset,
        'acc_mean: {}'.format(acc_means),
        'acc_std: {}'.format(acc_stds),
        'nmi_mean: {}'.format(nmi_means),
        'nmi_std: {}'.format(nmi_stds),
        'f1_mean: {}'.format(f1_means),
        'f1_std: {}'.format(f1_stds))
    return acc_means, acc_stds, nmi_means, nmi_stds, f1_means, f1_stds



def create_csv():
    """Create (or overwrite) the results CSV at ``csv_path`` and write its header.

    The columns match the rows appended after each seed: split index, seed,
    and the mean and standard deviation of test accuracy (in percent).
    """
    # Mode 'w' already truncates an existing file. newline='' is what the csv
    # module requires so it can control line endings itself.
    with open(csv_path, 'w', newline='') as f:
        csv.writer(f).writerow(["split", "seed", "Test Acc Mean", "Test Acc Std"])


create_csv()



print('Loading ', args.dataset)
(adj_ori, features, labels,
 idx_train, idx_val, idx_test) = process.load_data(args.dataset)
features, _ = process.preprocess_features(features)

# Problem sizes: node count, input feature width and number of classes.
nb_nodes, ft_size = features.shape[0], features.shape[1]
nb_classes = labels.shape[1]

# Add self-loops, normalize with the project helper, and convert to a torch
# sparse tensor for the encoder.
adj = process.normalize_adj(adj_ori + sp.eye(adj_ori.shape[0]))
sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj)

# Prepend a leading axis via np.newaxis (later code indexes these as [0, ...]).
features = torch.FloatTensor(features[np.newaxis])
labels = torch.FloatTensor(labels[np.newaxis])
idx_train, idx_val, idx_test = [torch.LongTensor(ix) for ix in (idx_train, idx_val, idx_test)]

# model = GMI(ft_size, args.hid_units, args.activation)
# optimiser = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_coef)

# if torch.cuda.is_available():
#     print('GPU available: Using CUDA')
#     model.cuda()
#     features = features.cuda()
#     sp_adj = sp_adj.cuda()
#     labels = labels.cuda()
#     idx_train = idx_train.cuda()
#     idx_val = idx_val.cuda()
#     idx_test = idx_test.cuda()

xent = nn.CrossEntropyLoss()
cnt_wait = 0
best = 1e9
best_t = 0

# Dense adjacency plus self-loops: the target passed to process.reconstruct_loss.
adj_dense = adj_ori.toarray()
adj_target = adj_dense + np.eye(adj_dense.shape[0])

# Row-normalize the adjacency: every row is divided by its row sum. Rows that
# sum to 0 (isolated nodes) would give 1/0 = inf; those factors are zeroed
# instead, and the divide-by-zero warning is suppressed.
with np.errstate(divide='ignore', invalid='ignore'):
    adj_row_avg = 1.0 / np.sum(adj_dense, axis=1)
adj_row_avg[~np.isfinite(adj_row_avg)] = 0.0
# Broadcast the per-row factor across columns (replaces a Python loop over rows).
adj_dense = adj_dense * adj_row_avg[:, np.newaxis]
adj_ori = sp.csr_matrix(adj_dense, dtype=np.float32)


##################################################
def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``.

    Also sets cuDNN to deterministic mode so repeated runs can be reproduced.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
##################################################

##################################################
# Directory holding the "<dataset>_<split>.mask" files used when --num-splits > 1.
# Override with the GMI_SPLITS_DIR environment variable instead of editing the source.
path_split = os.environ.get("GMI_SPLITS_DIR", "/home/han/.datasets/splits")
##################################################

###########FOR CLUSTERING###################
# Ground-truth class ids for the clustering metrics: squeeze away size-1 axes
# of the (presumably one-hot) label tensor and take the arg-max over classes.
clust_labels = np.argmax(labels.squeeze().cpu().data.numpy(), axis=1)

# Result accumulators (currently unused).
acc_list, nmi_list, f1_list = [], [], []
stdacc_list, stdnmi_list, stdf1_list = [], [], []
Loss_list = []



# Main experiment: for every split and seed, train GMI with early stopping on
# the unsupervised loss, then evaluate the best checkpoint with k-means
# clustering and with repeated logistic-regression probes.
use_cuda = torch.cuda.is_available()
ckpt_path = 'best_gmi'+str(args.dataset)+'.pkl'

for split in range(args.num_splits):
    print(split)
    if args.num_splits == 1:
        # Using the standard split
        splits = idx_train, idx_val, idx_test
    else:
        splits = load_split(os.path.join(path_split, args.dataset.lower()+'_'+str(split)+'.mask'))

    for seed in range(args.num_seeds):
        if args.use_seed:
            setup_seed(seed)
            print(seed)

        # Early-stopping state must be reset for every run. Otherwise a later
        # seed inherits the previous run's best loss and patience counter, and
        # may never save (and then reload) its own checkpoint.
        best = 1e9
        best_t = 0
        cnt_wait = 0

        if args.dataset == 'cora':
            model = GMI(ft_size, args.hid_units, args.activation)
        else:
            model = GMI_layer(ft_size, args.hid_units, args.activation)
        optimiser = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2_coef)

        # Use the indices of the current split on CPU as well as on GPU.
        idx_train, idx_val, idx_test = splits[0], splits[1], splits[2]
        if use_cuda:
            print('GPU available: Using CUDA')
            model.cuda()
            features = features.cuda()
            sp_adj = sp_adj.cuda()
            labels = labels.cuda()
            idx_train = idx_train.cuda()
            idx_val = idx_val.cuda()
            idx_test = idx_test.cuda()

        for epoch in range(args.nb_epochs):
            model.train()
            optimiser.zero_grad()

            res = model(features, adj_ori, args.negative_num, sp_adj, None, None)

            loss = (args.alpha*process.mi_loss_jsd(res[0], res[1])
                    + args.beta*process.mi_loss_jsd(res[2], res[3])
                    + args.gamma*process.reconstruct_loss(res[4], adj_target))
            # Compare and store a plain float so the best value does not keep
            # the autograd graph of that epoch alive.
            loss_value = loss.item()
            print('Epoch:', (epoch+1), '  Loss:', loss_value)

            if loss_value < best:
                best = loss_value
                best_t = epoch
                cnt_wait = 0
                torch.save(model.state_dict(), ckpt_path)
            else:
                cnt_wait += 1

            if cnt_wait >= args.epoch_flag:
                print('Early stopping!')
                break

            loss.backward()
            optimiser.step()

            # For intermediate monitoring, `clustering` can be called on the
            # current embeddings every few epochs.

        # Module-level per-restart k-means score buffers; `clustering` may
        # write into these as globals.
        rep = 5
        ac = np.zeros(rep)
        nm = np.zeros(rep)
        f1 = np.zeros(rep)

        #############CLUSTERING FINAL##############################
        tqdm.write("Evaluating Optimized clustering results...")
        print('Loading {}th epoch'.format(best_t+1))
        model.load_state_dict(torch.load(ckpt_path))

        # Embeddings are only needed for evaluation: no autograd graph required.
        with torch.no_grad():
            embeds = model.embed(features, sp_adj)

        z = F.normalize(embeds, p=2., dim=-1)
        print('z.shape before: {}'.format(z.shape))
        z = z.squeeze().cpu().numpy()
        print('z.shape after: {}'.format(z.shape))
        clustering(z, clust_labels, tqdm, "no dim: ")

        u_svd, _, _ = sp.linalg.svds(z, k=args.n_clusters, which='LM')
        clustering(u_svd, clust_labels, tqdm, "svd: ")

        pca = PCA(n_components=args.n_clusters)
        u_pca = pca.fit_transform(z)
        clustering(u_pca, clust_labels, tqdm, "pca: ")

        ##############LOGISTIC-REGRESSION PROBE######################
        train_embs = embeds[0, idx_train]
        test_embs = embeds[0, idx_test]

        train_lbls = torch.argmax(labels[0, idx_train], dim=1)
        test_lbls = torch.argmax(labels[0, idx_test], dim=1)

        accs = []
        iter_num = process.find_epoch(args.hid_units, nb_classes, train_embs, train_lbls, test_embs, test_lbls)
        for _ in range(50):
            clf = LogReg(args.hid_units, nb_classes)
            opt = torch.optim.Adam(clf.parameters(), lr=0.001, weight_decay=0.00001)
            if use_cuda:
                clf.cuda()

            for _ in range(iter_num):
                clf.train()
                opt.zero_grad()

                logits = clf(train_embs)
                loss = xent(logits, train_lbls)

                loss.backward()
                opt.step()

            with torch.no_grad():
                logits = clf(test_embs)
            preds = torch.argmax(logits, dim=1)
            acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
            print(acc * 100)
            accs.append(acc * 100)

        accs = torch.stack(accs)
        acc_mean = accs.mean().item()
        acc_std = accs.std().item()
        print('Average accuracy:', acc_mean)
        print('STD:', acc_std)

        # Write plain floats (not tensor reprs) so the CSV is machine-readable.
        with open(csv_path, 'a+', newline='') as f:
            csv.writer(f).writerow([split, seed, acc_mean, acc_std])
