import argparse
import os
import time
import warnings

import numpy as np
import scipy.sparse as sp
import seaborn as sns

import torch
import torch.nn.modules.loss
import torch.nn.functional as F
from torch import optim
from torch_geometric.utils import to_scipy_sparse_matrix

from models import *
from utils import set_seed, random_splits
from unsup_model import GCNModelAE, GCNModelVAE
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
# Silence every warning category process-wide. Note this also hides
# deprecation warnings coming from torch/scipy; remove it when debugging.
warnings.simplefilter("ignore")


def parse_args():
    """Parse command-line options for unsupervised training + linear evaluation.

    Returns:
        argparse.Namespace holding the general run settings (seed, dataset,
        device, runs, net), encoder sizes, split ratios, and the learning-rate
        / weight-decay / epoch settings of the unsupervised model and of the
        linear evaluator.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, default='GCN')

    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden1', type=int, default=32, help='Number of units in hidden layer 1.')
    parser.add_argument('--hidden2', type=int, default=16, help='Number of units in hidden layer 2.')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout for neural networks.')

    parser.add_argument('--fix_split', action='store_true')
    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    # unsupervised learning
    parser.add_argument('--model', type=str, default='gcn_vae', help="models used")

    parser.add_argument("--patience", type=int, default=20,
                        help="Patience: number of epochs to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=500, help="Unsupervised training epochs.")
    parser.add_argument("--lr1", type=float, default=0.001, help="Learning rate of the unsupervised model.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")
    return parser.parse_args()


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    The matrix is converted to COO format and cast to ``float32``; its
    ``(row, col)`` coordinates become an ``int64`` index tensor of shape
    ``(2, nnz)``.

    Args:
        sparse_mx: a scipy sparse matrix (anything exposing ``tocoo()``).

    Returns:
        A ``torch.float32`` sparse COO tensor with the same shape and entries.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported equivalent.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))


def preprocess_graph(adj):
    """Add self-loops to ``adj`` and apply symmetric degree normalisation.

    With ``A_hat = A + I`` and ``D`` the diagonal matrix of ``A_hat``'s row
    sums, this computes ``(A_hat @ D^-1/2)^T @ D^-1/2``. For a symmetric
    ``adj`` that equals ``D^-1/2 A_hat D^-1/2``. Because of the transpose, a
    non-symmetric input yields ``D^-1/2 A_hat^T D^-1/2`` instead.

    Args:
        adj: square adjacency matrix accepted by ``scipy.sparse.coo_matrix``.

    Returns:
        The normalised matrix as a torch sparse tensor (via
        ``sparse_mx_to_torch_sparse_tensor``).
    """
    adj_coo = sp.coo_matrix(adj)
    adj_hat = adj_coo + sp.eye(adj_coo.shape[0])
    degrees = np.array(adj_hat.sum(1)).flatten()
    inv_sqrt_deg = sp.diags(np.power(degrees, -0.5))
    scaled = adj_hat.dot(inv_sqrt_deg).transpose().dot(inv_sqrt_deg)
    return sparse_mx_to_torch_sparse_tensor(scaled.tocoo())


def vgae_loss_function(preds, labels, mu, logvar, n_nodes, norm, pos_weight):
    """Variational graph auto-encoder loss: weighted reconstruction + KL term.

    Args:
        preds: reconstruction logits, same shape as ``labels``.
        labels: target adjacency values in [0, 1].
        mu: latent means, reduced over dim 1 per node.
        logvar: latent log-scale tensor, same shape as ``mu``. The formula uses
            ``2 * logvar`` and ``exp(logvar)**2``, i.e. it is the Gaussian KL
            when this holds log standard deviation.
        n_nodes: number of nodes; scales the KL term by ``1 / n_nodes``.
        norm: multiplier applied to the BCE reconstruction term.
        pos_weight: ``pos_weight`` for ``binary_cross_entropy_with_logits``.

    Returns:
        Scalar tensor ``norm * BCE + KL``.
    """
    reconstruction = F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)
    per_node_kl = torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1)
    kl_divergence = -0.5 / n_nodes * per_node_kl.mean()
    return norm * reconstruction + kl_divergence


def gae_loss_function(preds, labels, norm, pos_weight):
    """Graph auto-encoder loss: ``norm``-scaled weighted BCE on the logits.

    Args:
        preds: reconstruction logits, same shape as ``labels``.
        labels: target adjacency values in [0, 1].
        norm: multiplier applied to the mean BCE.
        pos_weight: ``pos_weight`` for ``binary_cross_entropy_with_logits``.

    Returns:
        Scalar loss tensor.
    """
    weighted_bce = F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)
    return norm * weighted_bce


def sparse_to_tuple(sparse_mx):
    """Split a scipy sparse matrix into ``(coords, values, shape)``.

    ``coords`` is an ``(nnz, 2)`` array of ``(row, col)`` pairs, and ``values``
    are the matching stored entries of the COO form. Non-COO inputs are
    converted with ``tocoo()`` first.
    """
    coo = sparse_mx if sp.isspmatrix_coo(sparse_mx) else sparse_mx.tocoo()
    coords = np.vstack((coo.row, coo.col)).T
    return coords, coo.data, coo.shape


def unsupervised_learning(data, args, device):
    """Train a GAE/VGAE on ``data``'s graph and return its node embeddings.

    The auto-encoder is trained full-batch for ``args.unsup_epochs`` epochs to
    reconstruct the adjacency matrix with self-loops. Each time the training
    loss improves, the weights that produced that loss are checkpointed to a
    timestamped file under ``unsup_pkl/``. After training, the best
    checkpoint is reloaded and the embeddings are recomputed with the model in
    eval mode, without gradient tracking.

    Args:
        data: graph object providing ``x``, ``edge_index`` and ``num_nodes``
            (presumably a PyG ``Data``; ``x`` assumed 2-D
            ``(num_nodes, num_features)`` -- TODO confirm).
        args: parsed CLI namespace; reads ``model``, ``hidden1``, ``hidden2``,
            ``dropout``, ``lr1``, ``wd1``, ``unsup_epochs``, ``net`` and
            ``dataset``.
        device: ``torch.device`` for the model and all tensors.

    Returns:
        The detached ``mu`` output of the best checkpointed model.

    Raises:
        ValueError: if ``args.model`` is not ``'gcn_vae'``/``'gcn_ae'``, or
            the graph has no edges (the positive weight would be undefined).
    """
    num_nodes = data.num_nodes
    # Pass num_nodes explicitly: otherwise the size is inferred from the largest
    # node index, so trailing isolated nodes would be dropped and the adjacency
    # would no longer match data.x.
    adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=num_nodes)
    n = adj.shape[0]
    # Accumulate in float64: n*n exceeds float32's exact-integer range on
    # mid-sized graphs, which skewed pos_weight/norm.
    adj_total = float(adj.sum(dtype=np.float64))
    if adj_total == 0:
        raise ValueError("unsupervised_learning needs a graph with at least one edge")

    adj_norm = preprocess_graph(adj).to(device)
    # Reconstruction target: dense adjacency with self-loops.
    adj_label = torch.FloatTensor((adj + sp.eye(n)).toarray()).to(device)

    # Up-weight positive entries by the (non-edge mass : edge mass) ratio;
    # `norm` rescales the BCE term.
    num_entries = float(n) * n
    pos_weight = torch.tensor((num_entries - adj_total) / adj_total, dtype=torch.float32, device=device)
    norm = num_entries / ((num_entries - adj_total) * 2)

    model_classes = {'gcn_vae': GCNModelVAE, 'gcn_ae': GCNModelAE}
    if args.model not in model_classes:
        raise ValueError(f"Unknown --model {args.model!r}; expected one of {sorted(model_classes)}")
    # Take the feature dimension from the data itself rather than a
    # module-level `dataset` that only exists when this file runs as a script.
    model = model_classes[args.model](data.x.size(1), args.hidden1, args.hidden2, args.dropout).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)
    use_vae = args.model == 'gcn_vae'

    os.makedirs('unsup_pkl', exist_ok=True)
    unsup_tag = str(int(time.time()))
    ckpt_path = f'unsup_pkl/{args.model}_cl_{args.net}_best_model_{args.dataset}{unsup_tag}.pkl'

    best = float("inf")
    saved = False
    for _ in range(args.unsup_epochs):
        model.train()
        optimizer.zero_grad()
        recovered, mu, logvar = model(data.x, adj_norm)
        if use_vae:
            loss = vgae_loss_function(preds=recovered, labels=adj_label, mu=mu, logvar=logvar,
                                      n_nodes=num_nodes, norm=norm, pos_weight=pos_weight)
        else:
            loss = gae_loss_function(preds=recovered, labels=adj_label, norm=norm, pos_weight=pos_weight)
        loss.backward()

        # Compare a Python float so `best` does not keep this epoch's autograd
        # graph alive, and checkpoint *before* the optimizer step so the saved
        # weights are exactly the ones that produced `loss`.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            torch.save(model.state_dict(), ckpt_path)
            saved = True
        optimizer.step()

    if saved:
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    # Recompute embeddings with the restored weights in eval mode; the `mu`
    # from the training loop reflects the last step's weights, not the best
    # checkpoint.
    model.eval()
    with torch.no_grad():
        _, mu, _ = model(data.x, adj_norm)
    return mu.detach()



if __name__ == '__main__':
    # Script entry point: learn embeddings once without labels, then fit a
    # linear evaluator on them for each of `args.runs` train/val/test splits.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    set_seed(args.seed)
    # 10 fixed seeds for random splits, taken from BernNet (one per run).
    SEEDS = [1941488137, 4198936517, 983997847, 4023022221, 4019585660,
             2108550661, 1648766618, 629014539, 3212139042, 2424918363]
    # Fail fast: each run needs its own fixed seed, and zero runs would only
    # produce a NaN mean.
    if not 1 <= args.runs <= len(SEEDS):
        raise ValueError(f"--runs must be between 1 and {len(SEEDS)}, got {args.runs}")
    device = torch.device('cuda:' + str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0].to(device)

    no_public_split = args.dataset in ['Computers', 'Photo']
    if args.fix_split and no_public_split:
        # No public splitting for these datasets: use train/val/test = 1/1/8.
        percls_trn = int(round(0.1 * len(data.y) / dataset.num_classes))
        val_lb = int(round(0.1 * len(data.y)))
    else:
        percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
        val_lb = int(round(args.val_rate * len(data.y)))

    # Embeddings are learned once and shared by every evaluation run.
    embeds = unsupervised_learning(data=data, args=args, device=device)

    if not no_public_split:
        # Public split masks; assumed to hold one column per run (indexed
        # [:, run] below) -- TODO confirm for each dataset.
        full_train_mask, full_val_mask, full_test_mask = data.train_mask, data.val_mask, data.test_mask

    unsup_results = []
    for run, seed in enumerate(SEEDS[:args.runs]):
        args.seed = seed
        if args.fix_split and not no_public_split:
            data.train_mask = full_train_mask[:, run]
            data.val_mask = full_val_mask[:, run]
            data.test_mask = full_test_mask[:, run]
        else:
            data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)

        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Largest distance from the mean to either end of the 95% bootstrap CI.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')

