#!/opt/conda/bin/python3
from __future__ import division, print_function

import argparse
import os
import numpy as np
import scipy.sparse as sp
from sklearn import metrics
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warnings

#from utils import accuracy, clip_gradient, load_txt_data, Evaluation, AdditionalLayer
from mono_ignn.tools.data import load_txt_data
from mono_ignn import AmazonMonIGNN

# Silence every warning category for the whole run.
warnings.filterwarnings('ignore')

# ---------------------------------------------------------------------------
# Command-line arguments
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Chains dataset for IGNN and Monotone-IGNN.')

# Compute arguments
compute_parser = parser.add_argument_group('Computational Parameters')
compute_parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='Disables CUDA training.')
compute_parser.add_argument('--fastmode', action='store_true',
                            help='Validate during training pass.')

# Data arguments
data_parser = parser.add_argument_group('Data Parameters')
data_parser.add_argument('--normalization', type=str, default='AugNorm',
                         choices=['AugNorm', 'DiagNorm', 'IdentNorm', 'LaplaceNorm', 'TransposeNorm'],
                         help='Normalization method for the adjacency matrix.')
data_parser.add_argument('--seed', type=int, default=42,
                         help='Random seed.')
data_parser.add_argument('--degree', type=int, default=2,
                         help='degree of the approximation.')
data_parser.add_argument('--per', type=int, default=-1,
                         help='Number of each nodes so as to balance.')
data_parser.add_argument('--experiment', type=str, default="base-experiment",
                         help='feature-type')
data_parser.add_argument('--dataset', type=str, default="amazon-all",
                         help='Dataset to use.')
data_parser.add_argument('--portion', type=float, default=0.06,
                         help='training set fraction for amazon dataset.')

# Training / model arguments
train_parser = parser.add_argument_group('Model Parameters')
# argparse does not validate a default against ``choices``, so the default
# must itself be one of the listed values. 'MIGNN' is the only model class
# imported by this script.
train_parser.add_argument('--model', type=str, default='MIGNN',
                          choices=['MIGNN'],
                          help='Model selection, fixed point only for MIGNN.')
train_parser.add_argument('--epochs', type=int, default=5000,
                          help='Number of epochs to train.')
train_parser.add_argument('--lr', type=float, default=0.01,
                          help='Initial learning rate.')
train_parser.add_argument('--weight_decay', type=float, default=0,
                          help='Weight decay (L2 loss on parameters).')
train_parser.add_argument('--hidden', type=int, default=128,
                          help='Number of hidden units.')
train_parser.add_argument('--patience', type=int, default=500,
                          help='Number of bad iterations to exit.')
train_parser.add_argument('--feature', type=str, default="mul",
                          choices=['mul', 'cat', 'adj'],
                          help='feature-type')

# Monotone operator splitting arguments
mon_parser = parser.add_argument_group('Monotone Operator Splitting Parameters')
mon_parser.add_argument('--lin_module', type=str, default='proj',
                        choices=['cayley', 'frob', 'proj', 'expm', 'symm', 'skew'],
                        help='Linear module selection, only for MIGNN.')
mon_parser.add_argument('--mu', type=float, default=None,
                        help='Linear module parameter.')
mon_parser.add_argument('--fp_method', type=str, default='pwr',
                        choices=['pwr', 'pwr+a', 'fb', 'fb+a', 'pr+a', 'pr', 'dr', 'dr+a', 'dr+h'],
                        help='Fixed point solving method.')
mon_parser.add_argument('--alpha', type=float, default=.9,
                        help='Fixed point convergence parameter alpha')
mon_parser.add_argument('--beta', type=float, default=.9,
                        help='Fixed point convergence parameter beta')
mon_parser.add_argument('--fp_tol', type=float, default=1e-6,
                        help='Fixed point tolerance parameter.')
mon_parser.add_argument('--inv_method', type=str, default='direct',
                        help='Fixed point solving method [direct,eig,neumann-*].')
mon_parser.add_argument('--disable_norm', action='store_true',
                        help='Add additional normalization layer.')

# Regularization arguments
reg_parser = parser.add_argument_group('Regularization Parameters')
reg_parser.add_argument('--jac_weight', type=float, default=0.0,
                        help='jacobian regularization loss weight (default to 0)')
reg_parser.add_argument('--jac_freq', type=float, default=0.0,
                        help='the frequency of applying the jacobian regularization (default to 0)')
reg_parser.add_argument('--jac_incremental', type=int, default=0,
                        help='if positive, increase jac_weight by 0.1 after this many steps')
reg_parser.add_argument('--spectral_radius_mode', action='store_true',
                        help='compute spectral radius at validation time')

# Other GNN arguments
oth_parser = parser.add_argument_group('Fixed Point Parameters')
oth_parser.add_argument('--kappa', type=float, default=0.95,
                        help='Projection parameter. ||W|| <= kappa/lpf(A)')
oth_parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout rate (1 - keep probability).')
oth_parser.add_argument('--rho', type=float, default=0.2,
                        help='Percent of -1 in lin_module Cayley')
oth_parser.add_argument('--adj_pow', type=int, default=1,
                        help='Adjacency power.')
oth_parser.add_argument('--lr_orth', type=float, default=0.2,
                        help='Learning rate for RGD')

# Parse the command line. CUDA is used only when it was not disabled with
# --no-cuda and torch reports a usable device.
args = parser.parse_args()
args.cuda = False if args.no_cuda else torch.cuda.is_available()

# Echo the effective configuration, one "name value" pair per line.
for name, value in vars(args).items():
    print(name, value)

# Seed NumPy and torch (CPU, plus the current CUDA device when in use)
# with the same value.
seed = args.seed
np.random.seed(seed)
torch.manual_seed(seed)
if args.cuda:
    torch.cuda.manual_seed(seed)

# Load data
# Only the two Amazon variants are handled by this script. Fail immediately
# on anything else instead of printing a message and crashing later with a
# NameError on the variables that were never assigned.
if args.dataset not in ("amazon-all", "amazon-top5000"):
    raise ValueError(
        "dataset provided is not supported: {!r} "
        "(expected 'amazon-all' or 'amazon-top5000')".format(args.dataset))

portion = args.portion
(adj, sp_adj, features, labels,
 idx_train, idx_val, idx_test,
 num_nodes, num_class) = load_txt_data(args.adj_pow, args.dataset, portion,
                                       normalization=args.normalization)
# The model constructor receives ``adj``, so place it on the target device now.
adj = adj.to('cuda' if args.cuda else 'cpu')

# Summary values derived from the loaded data. They are not referenced by
# the rest of the visible script and are kept for backward compatibility.
Y = labels
m = features.shape[0]
m_y = torch.max(Y).int().item() + 1

""" MODEL """
# The same device string is handed to whichever model is built.
model_device = 'cuda' if args.cuda else 'cpu'

if args.model == 'MIGNN':
    model = AmazonMonIGNN(nfeat=features.shape[1],
                          nhid=args.hidden,
                          nclass=num_class,
                          num_node=num_nodes,
                          dropout=args.dropout,
                          adj=adj,
                          sp_adj=sp_adj,
                          linModule=args.lin_module,
                          fpMethod=args.fp_method,
                          invMethod=args.inv_method,
                          kappa=args.kappa,
                          alpha=args.alpha,
                          tol=args.fp_tol,
                          mu=args.mu,
                          beta=args.beta,
                          rho=args.rho,
                          device=model_device)
elif args.model == 'IGNN':
    # AmazonIGNN is not imported at the top of this file, so referencing it
    # directly raises NameError. Import it lazily here and report a clear
    # error if it is unavailable.
    # NOTE(review): assumes AmazonIGNN is exported by the mono_ignn package
    # next to AmazonMonIGNN - confirm against the package.
    try:
        from mono_ignn import AmazonIGNN
    except ImportError as err:
        raise NotImplementedError(
            "model 'IGNN' was requested but AmazonIGNN could not be "
            "imported from mono_ignn") from err
    model = AmazonIGNN(nfeat=features.shape[1],
                       nhid=args.hidden,
                       nclass=num_class,
                       num_node=num_nodes,
                       dropout=args.dropout,
                       adj=adj,
                       rho=args.rho,
                       tol=args.fp_tol,
                       linModule=args.lin_module,
                       device=model_device,
                       kappa=args.kappa)
else:
    raise NotImplementedError("unknown model: {}".format(args.model))

# Adam over all model parameters; learning rate and weight decay come from
# the command line.
optimizer = optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# When CUDA is in use, move the model and every tensor read by
# train()/test() onto the GPU.
if args.cuda:
    model.cuda()
    features, adj, labels = features.cuda(), adj.cuda(), labels.cuda()
    idx_train, idx_val, idx_test = idx_train.cuda(), idx_val.cuda(), idx_test.cuda()

# Sigmoid + binary cross-entropy, applied to the raw model outputs.
criterion = nn.BCEWithLogitsLoss()

def Evaluation(output, labels):
    """Score top-k multi-label predictions with micro and macro F1.

    For each row, the ``k`` highest-scoring classes are predicted positive,
    where ``k`` is the sum of that row of ``labels`` (i.e. the number of
    positive labels when ``labels`` holds 0/1 indicators). Rows with no
    positive label get no predicted positives.

    Args:
        output: 2-D tensor of per-class scores, one row per sample.
        labels: tensor with the same shape as ``output``; assumed to hold
            0/1 indicator values.

    Returns:
        Tuple ``(micro_f1, macro_f1)`` computed by
        ``sklearn.metrics.f1_score`` on the binarized predictions.

    Side effects:
        Prints the number of positive labels that were predicted positive.
    """
    scores = output.cpu().detach().numpy()
    targets = labels.cpu().detach().numpy()

    binary_pred = np.zeros(scores.shape, dtype='int')
    for row, (row_scores, row_targets) in enumerate(zip(scores, targets)):
        k = int(row_targets.sum())
        if k <= 0:
            # ``argsort()[-0:]`` is the same slice as ``[0:]`` and would mark
            # every class positive, so rows without positives are skipped.
            continue
        binary_pred[row, row_scores.argsort()[-k:]] = 1

    # A hit is a position whose label is non-zero and equals the prediction.
    num_correct = int(np.count_nonzero((targets != 0) & (targets == binary_pred)))
    print('total number of correct is: {}'.format(num_correct))

    micro_f1 = metrics.f1_score(targets, binary_pred, average="micro")
    macro_f1 = metrics.f1_score(targets, binary_pred, average="macro")
    return micro_f1, macro_f1

def train(epoch):
    """Run one full-batch optimisation step, then print train/test metrics.

    Args:
        epoch: zero-based epoch index; it is printed one-based.
    """
    started = time.time()

    # ---- optimisation step ------------------------------------------------
    model.train()
    optimizer.zero_grad()
    # One uniform draw per step: the Jacobian term is requested whenever the
    # draw falls below ``args.jac_freq``.
    compute_jac_loss = np.random.uniform(0, 1) < args.jac_freq
    output, jac_loss = model(features, compute_jac_loss=compute_jac_loss)

    train_output, train_labels = output[idx_train], labels[idx_train]
    loss_train = criterion(train_output, train_labels)
    f1_train_micro, f1_train_macro = Evaluation(train_output, train_labels)

    # Backpropagate the plain loss, or the loss plus the weighted Jacobian
    # term when it was requested for this step.
    objective = loss_train
    if compute_jac_loss:
        objective = loss_train + jac_loss * args.jac_weight
    objective.backward()
    optimizer.step()

    # ---- evaluation -------------------------------------------------------
    # Outside fast mode, run a second forward pass in eval mode. In fast
    # mode the training-pass output is reused.
    if not args.fastmode:
        model.eval()
        output, jac_loss = model(features)

    # Only the test split is reported; no validation metrics are computed.
    test_output, test_labels = output[idx_test], labels[idx_test]
    loss_test = criterion(test_output, test_labels)
    f1_test_micro, f1_test_macro = Evaluation(test_output, test_labels)

    report = [
        'Epoch: {:04d}'.format(epoch+1),
        'loss_train: {:.4f}'.format(loss_train.item()),
        "f1_train_micro= {:.4f}".format(f1_train_micro),
        "f1_train_macro= {:.4f}".format(f1_train_macro),
        'loss_test: {:.4f}'.format(loss_test.item()),
        "f1_test_micro= {:.4f}".format(f1_test_micro),
        "f1_test_macro= {:.4f}".format(f1_test_macro),
        'time: {:.4f}s'.format(time.time() - started),
    ]
    print(*report)

def test():
    """Evaluate the model on the test split and print loss and F1 scores."""
    model.eval()
    output, _ = model(features)

    test_output, test_labels = output[idx_test], labels[idx_test]
    loss_test = criterion(test_output, test_labels)
    f1_test_micro, f1_test_macro = Evaluation(test_output, test_labels)

    print("Dataset: " + args.dataset)
    summary = [
        "loss= {:.4f}".format(loss_test.item()),
        "f1_test_micro= {:.4f}".format(f1_test_micro),
        "f1_test_macro= {:.4f}".format(f1_test_macro),
    ]
    print("Test set results:", *summary)


# Run the configured number of epochs, timing the whole loop.
t_total = time.time()
for epoch in range(args.epochs):
    train(epoch)
print("Optimization Finished!")

elapsed = time.time() - t_total
print("Total time elapsed: {:.4f}s".format(elapsed))

# Final evaluation on the test split.
test()