import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch_geometric
import torch_geometric.transforms as T

from models import GMI_layer, LogReg, GMI
from utils import process
from loads import load_split
from data import load_dataset
import random
import os
import argparse
import csv
# from loads import load_split

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so an explicit converter is required. An empty string
    maps to False, matching what ``bool('')`` used to give.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line configuration for GMI pre-training and linear-probe evaluation.
parser = argparse.ArgumentParser(description="PyTorch Implementation of GMI")
parser.add_argument('--dataset', default='wiki',
                    help='name of dataset.')
parser.add_argument('--save_path', default='results',
                    help='directory for checkpoints, embeddings and the results CSV.')
parser.add_argument('--split_type', default='standard',
                    help="'standard' uses the dataset's own split (or a random per-class split "
                         "for wiki/computers/photo/physics/cs); any other value loads "
                         "pre-computed split masks.")
parser.add_argument('--num_splits', default=1, type=int,
                    help='number of data splits to evaluate.')
parser.add_argument('--num_seeds', default=1, type=int,
                    help='number of random seeds to run per split.')
parser.add_argument('--batch_size', default=1, type=int, help='what batch size')
parser.add_argument('--nb_epochs', default=550, type=int, help='how many epochs')
parser.add_argument('--patience', default=20, type=int,
                    help='how many epochs to tolerate if loss criteria not met before breaking')
parser.add_argument('--drop_prob', default=0, type=float, help='drop probability')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--l2_coef', default=0.0, type=float, help='weight decay (default: 0.0)')
parser.add_argument('--hid_units', default=256, type=int, help='embedding size')
parser.add_argument('--sparse', default=True, type=_str2bool,
                    help='use the sparse adjacency tensor (true/false, default: true)')
parser.add_argument('--nonlinearity', default='prelu', type=str,
                    help='name of the nonlinearity passed to the encoder (default: prelu)')
parser.add_argument('--negative_num', type=int, default=5,
                    help='number of negative examples used in the discriminator (default: 5)')
parser.add_argument('--alpha', type=float, default=0.8,
                    help='parameter for I(h_i; x_i) (default: 0.8)')
parser.add_argument('--beta', type=float, default=1.0,
                    help='parameter for I(h_i; x_j), node j is a neighbor (default: 1.0)')
parser.add_argument('--gamma', type=float, default=1.0,
                    help='parameter for I(w_ij; a_ij) (default: 1.0)')

args = parser.parse_args()

dataset = args.dataset
save_path = args.save_path
# Results CSV: one file per dataset inside the save directory.
csv_path = os.path.join(save_path, dataset.lower() + '.csv')

def make_dir(dirName):
    """Ensure ``dirName`` exists, creating any missing parent directories.

    Prints an ``[INFO]`` line reporting whether the directory was created
    now or was already present.
    """
    already_there = os.path.exists(dirName)
    if already_there:
        print("[INFO] Directory ", dirName, " already exists")
        return
    os.makedirs(dirName, exist_ok=True)
    print("[INFO] Directory ", dirName, " created")

# Ensure the output directory exists before the results CSV, checkpoints and embeddings are written into it.
make_dir(save_path)

def create_csv(path=None):
    """(Re)create the results CSV and write only its header row.

    Args:
        path: file to write. Defaults to the module-level ``csv_path``.

    Any existing file at ``path`` is overwritten.
    """
    if path is None:
        path = csv_path
    # Mode 'w' already truncates, so no seek/truncate is needed. newline='' lets
    # the csv module control line endings itself (avoids blank rows on Windows).
    with open(path, 'w', newline='') as f:
        csv_write = csv.writer(f)
        csv_head = ["split", "seed", "Best_Epoch", "loss", "Test Acc", "Test Acc std"]
        csv_write.writerow(csv_head)

# Start a fresh results file (header only); one row per (split, seed) is appended later in the run.
create_csv()

def generate_split(num_classes, labels, seed=0, train_num_per_c=20, val_num_per_c=30):
    """Build a random per-class train/val/test split as boolean node masks.

    For each class with more than ``train_num_per_c + val_num_per_c`` nodes,
    ``train_num_per_c`` randomly chosen nodes go to train and the next
    ``val_num_per_c`` to val. The remaining nodes of that class go to test.
    Classes that are too small are left out of all three masks.

    Args:
        num_classes: number of classes; classes ``0 .. num_classes - 1`` are visited.
        labels: class index per node, shape ``(N,)``. A one-hot / score matrix of
            shape ``(N, C)`` is also accepted and reduced with ``argmax``. An
            ``(N, 1)`` column is squeezed.
        seed: seed for the permutation. The same seed always yields the same
            split, independent of the global RNG state.
        train_num_per_c: number of training nodes per class.
        val_num_per_c: number of validation nodes per class.

    Returns:
        ``(train_mask, val_mask, test_mask)``: boolean tensors of length ``N``,
        on the same device as ``labels``.
    """
    if labels.dim() == 2:
        labels = labels.squeeze(-1) if labels.size(-1) == 1 else labels.argmax(dim=-1)
    num_nodes = labels.shape[0]
    device = labels.device
    train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    # Nodes already placed in train/val, or excluded because their class is too
    # small. The test mask is the complement.
    assigned = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    # A private generator makes the split a function of `seed` alone.
    gen = torch.Generator().manual_seed(seed)
    n_fixed = train_num_per_c + val_num_per_c
    for c in range(num_classes):
        all_c_idx = (labels == c).nonzero(as_tuple=True)[0]
        if all_c_idx.numel() <= n_fixed:
            assigned[all_c_idx] = True
            continue
        perm = torch.randperm(all_c_idx.numel(), generator=gen).to(device)
        c_train_idx = all_c_idx[perm[:train_num_per_c]]
        c_val_idx = all_c_idx[perm[train_num_per_c:n_fixed]]
        train_mask[c_train_idx] = True
        val_mask[c_val_idx] = True
        assigned[c_train_idx] = True
        assigned[c_val_idx] = True
    test_mask = ~assigned
    return train_mask, val_mask, test_mask

# training params
batch_size = args.batch_size
nb_epochs = args.nb_epochs
patience = args.patience
lr = args.lr
l2_coef = args.l2_coef
drop_prob = args.drop_prob
hid_units = args.hid_units
sparse = args.sparse
nonlinearity = args.nonlinearity  # special name to separate parameters
num_splits = args.num_splits

# Load the graph. After this section:
#   features   -> tensor (1, N, F)
#   labels     -> one-hot float tensor (1, N, C)
#   sp_adj     -> torch sparse tensor of the normalised adjacency (with self-loops)
#   adj_target -> dense numpy adjacency + identity, used by the reconstruction loss
print(f'----------------{args.dataset}---------------------')
if dataset in ['computers', 'photo', 'physics', 'cs']:
    print('using custom loader')
    # Keep `dataset` as the dataset *name*; the loaded PyG object gets its own name.
    pyg_dataset = load_dataset(args.dataset, transform=T.NormalizeFeatures())
    data = pyg_dataset[0]
    # Pass the node count explicitly. Otherwise to_dense_adj infers the size from
    # the largest index in edge_index and silently drops trailing isolated nodes.
    adj = torch_geometric.utils.to_dense_adj(
        data.edge_index, max_num_nodes=data.num_nodes)[0].numpy()
    adj_target = adj + np.eye(adj.shape[0])
    adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))
    sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj)
    features = data.x.unsqueeze(0)
    nb_classes = pyg_dataset.num_classes
    # One-hot encode the integer class labels, then add the leading batch axis.
    labels = torch.eye(nb_classes).index_select(0, data.y.long()).unsqueeze(0)
else:
    if args.dataset == 'wiki':
        print('using custom loader')
        adj, features, gnd, idx_train, idx_val, idx_test, labels = process.load_other_data(dataset)
    else:
        print('using standard loader')
        adj, features, labels, idx_train, idx_val, idx_test = process.load_data(dataset)

    adj_target = adj + np.eye(adj.shape[0])
    adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))
    sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj)

    features, _ = process.preprocess_features(features)
    features = torch.FloatTensor(features[np.newaxis])
    labels = torch.FloatTensor(labels[np.newaxis])

    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

if not sparse:
    # Dense model input of shape (1, N, N). normalize_adj may return a scipy
    # sparse matrix (NOTE(review): confirm against utils.process), so densify it
    # before adding the batch axis.
    dense_adj = adj.todense() if sp.issparse(adj) else adj
    adj = torch.FloatTensor(np.asarray(dense_adj)[np.newaxis])

nb_nodes = features.shape[1]
ft_size = features.shape[-1]
nb_classes = labels.shape[-1]

print(f"feature.shape: {features.shape}, adj.shape: {adj.shape}, "
      f"labels.shape: {labels.shape}, nb_classes: {nb_classes}, "
      f"nb_nodes: {nb_nodes}, ft_size: {ft_size}")

def setup_seed(seed):
    """Seed every random number generator the script uses, and force cuDNN determinism.

    Seeds are applied in order to torch (CPU), torch (all CUDA devices), NumPy,
    and Python's ``random`` module.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Directory holding pre-computed split masks, used when --split_type is not 'standard'.
path_split = "/home/han/.datasets/splits"
# Directory where randomly generated splits are cached.
split_dir = './splits'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
##################################################


def _linear_probe(embeds, labels, idx_train, idx_test, n_runs=50, n_steps=100):
    """Evaluate frozen embeddings with freshly initialised logistic-regression heads.

    Each of ``n_runs`` heads is trained for ``n_steps`` full-batch Adam steps
    (lr=0.01) on the training nodes, then evaluated once on the test nodes.
    ``embeds`` and ``labels`` carry a leading batch axis of size 1, and
    ``labels`` is one-hot.

    Returns:
        1-D tensor of test accuracies in percent, one entry per run.
    """
    train_embs = embeds[0, idx_train]
    test_embs = embeds[0, idx_test]
    train_lbls = torch.argmax(labels[0, idx_train], dim=1)
    test_lbls = torch.argmax(labels[0, idx_test], dim=1)
    xent = nn.CrossEntropyLoss()

    accs = []
    for _ in range(n_runs):
        log = LogReg(embeds.shape[-1], labels.shape[-1]).to(embeds.device)
        opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
        log.train()
        for _ in range(n_steps):
            opt.zero_grad()
            probe_loss = xent(log(train_embs), train_lbls)
            probe_loss.backward()
            opt.step()
        with torch.no_grad():
            preds = torch.argmax(log(test_embs), dim=1)
        acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
        accs.append(acc * 100)
        print(acc)
    return torch.stack(accs)


for split in range(num_splits):
    print(f"split {split}")
    if args.split_type == 'standard':
        print("Using standard split")
        if args.dataset in ['wiki', 'computers', 'photo', 'physics', 'cs']:
            # No canonical split for these datasets: draw a random per-class one.
            # generate_split needs one class index per node, so collapse the one-hot labels.
            class_idx = labels[0].argmax(dim=-1).cpu()
            splits = generate_split(nb_classes, class_idx, split)
            os.makedirs(split_dir, exist_ok=True)
            torch.save(splits, os.path.join(split_dir, args.dataset + str(split) + '.pt'))
        else:
            splits = idx_train, idx_val, idx_test
    else:
        print("Using preloaded split")
        splits = load_split(os.path.join(path_split, args.dataset.lower() + '_' + str(split) + '.mask'))

    for seed in range(args.num_seeds):
        setup_seed(seed)
        print(f"seed: {seed}")

        if args.dataset == 'cora':
            model = GMI(ft_size, hid_units, nonlinearity)
        else:
            model = GMI_layer(ft_size, hid_units, nonlinearity)

        if device.type == 'cuda':
            print('Using CUDA')
        model.to(device)
        optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)

        features = features.to(device)
        # sp_adj is passed to both the forward pass and embed(), so it always moves.
        sp_adj = sp_adj.to(device)
        if torch.is_tensor(adj):
            adj = adj.to(device)
        labels = labels.to(device)
        # Take this split's indices/masks on every device (not only on CUDA).
        idx_train, idx_val, idx_test = (torch.as_tensor(s, device=device) for s in splits)

        cnt_wait = 0
        best = float('inf')
        best_t = 0
        ckpt_path = f'{save_path}/best_gmi' + str(args.dataset) + '_' + str(split) + '.pkl'

        for epoch in range(nb_epochs):
            model.train()
            optimiser.zero_grad()

            res = model(features, adj, args.negative_num, sp_adj, None, None)

            loss = args.alpha * process.mi_loss_jsd(res[0], res[1]) \
                + args.beta * process.mi_loss_jsd(res[2], res[3]) \
                + args.gamma * process.reconstruct_loss(res[4], adj_target)
            print('Epoch:', (epoch + 1), '  Loss:', loss)

            # Track a Python float so `best` does not keep the autograd graph alive.
            loss_value = loss.item()
            if loss_value < best:
                best = loss_value
                best_t = epoch
                cnt_wait = 0
                torch.save(model.state_dict(), ckpt_path)
            else:
                cnt_wait += 1

            if cnt_wait == patience:
                print('Early stopping!')
                break

            loss.backward()
            optimiser.step()

        print('Loading {}th epoch'.format(best_t))
        model.load_state_dict(torch.load(ckpt_path, map_location=device))

        # The embeddings are fixed inputs to the linear probe; no gradients are needed.
        with torch.no_grad():
            embeds = model.embed(features, sp_adj)
        np.savez(f'{save_path}/best_gmi_embed_' + args.dataset + str(split), x=embeds.cpu().numpy())

        accs = _linear_probe(embeds, labels, idx_train, idx_test)
        print('Average accuracy:', accs.mean().item() / 100)
        print(accs.mean())
        print(accs.std())

        # The "loss" column holds the best GMI pre-training loss, i.e. the one at Best_Epoch.
        with open(csv_path, 'a', newline='') as f:
            csv_write = csv.writer(f)
            data_row = [split, seed, best_t, best, accs.mean().item(), accs.std().item()]
            csv_write.writerow(data_row)
