#!/opt/conda/bin/python3
from __future__ import division, print_function

import argparse
import re
import geotorch
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import warnings

from mono_ignn.tools.data import load_chain
from mono_ignn import ChainsMonIGNN
from mono_ignn.ignn import ChainsIGNN
from mono_ignn.other import ChainsGCN
from mono_ignn.tools.normalization import cal_norm
from mono_ignn.tools.recorders import GradientStats
from mono_ignn.tools.utils import accuracy, clip_gradient

# Silence every warning for the whole run.
warnings.filterwarnings(action='ignore')

# Wall-clock reference for the "Total time elapsed" report printed at the end.
t_total = time.time()

# ARGS
# Command-line interface. Options are organised into argument groups so that
# `--help` lists related settings together; every option lands on one namespace.
parser = argparse.ArgumentParser(
    description='Chains dataset for IGNN and Monotone-IGNN.')

# COMPUTE ARGS
compute_parser = parser.add_argument_group('Computational Parameters')
compute_parser.add_argument('--no-cuda', action='store_true',
                            help='Disables CUDA training.')
compute_parser.add_argument('--fastmode', action='store_true',
                            help='Validate during training pass.')

# DATA ARGS
data_parser = parser.add_argument_group('Data Parameters')
# NOTE(review): with nargs='+' a value given on the command line is parsed into
# a list, while the default stays the plain string 'AugNorm' -- confirm that
# load_chain accepts both forms.
data_parser.add_argument('--normalization', type=str, nargs='+',
                         default='AugNorm',
                         choices=['AugNorm', 'DiagNorm', 'IdentNorm',
                                  'LaplaceNorm', 'RWNorm', 'TransposeNorm'],
                         help='Normalization method for the adjacency matrix.')
data_parser.add_argument('--seed', type=int, default=42,
                         help='Random seed.')
data_parser.add_argument('--num_chains', type=int, default=20,
                         help='num of chains')
data_parser.add_argument('--chain_len', type=int, default=10,
                         help='the length of each chain')
data_parser.add_argument('--num_class', type=int, default=2,
                         help='num of class')

# TRAIN ARGS
train_parser = parser.add_argument_group('Model Parameters')
train_parser.add_argument('--model', type=str, default='MIGNN',
                          choices=['IGNN', 'MIGNN', 'GCN'],
                          help='Model selection, fixed point only for MIGNN.')
train_parser.add_argument('--hidden', type=int, default=16,
                          help='Number of hidden units.')
train_parser.add_argument('--epochs', type=int, default=2000,
                          help='Number of epochs to train.')
train_parser.add_argument('--lr', type=float, default=0.01,
                          help='Initial learning rate.')
train_parser.add_argument('--weight_decay', type=float, default=5e-4,
                          help='Weight decay (L2 loss on parameters).')
train_parser.add_argument('--RGD', action='store_true',
                          help='Train via SGD and RGD.')

# MONOTONE ARGS
mon_parser = parser.add_argument_group('Monotone Operator Splitting Parameters')
mon_parser.add_argument('--lin_module', type=str, default='proj',
                        choices=['cayley', 'frob', 'proj', 'expm', 'symm', 'skew'],
                        help='Linear module selection, only for MIGNN.')
mon_parser.add_argument('--mu', type=float, default=None,
                        help='Linear module parameter.')
mon_parser.add_argument('--fp_method', type=str, default='pwr',
                        choices=['pwr', 'pwr+a', 'fb', 'fb+a', 'pr+a', 'pr',
                                 'dr', 'dr+a', 'dr+h'],
                        help='Fixed point solving method.')
mon_parser.add_argument('--alpha', type=float, default=.9,
                        help='Fixed point convergence parameter alpha')
# Help text previously said "alpha" here (copy-paste from --alpha).
mon_parser.add_argument('--beta', type=float, default=.9,
                        help='Acceleration point convergence parameter beta')
mon_parser.add_argument('--fp_tol', type=float, default=3e-6,
                        help='Fixed point tolerance parameter.')
mon_parser.add_argument('--max_iter', type=int, default=300,
                        help='Fixed point maximum iters.')
mon_parser.add_argument('--inv_method', type=str, default='direct',
                        help='Fixed point solving method [direct,eig,neumann-*].')

# REGULARIZATION ARGS
reg_parser = parser.add_argument_group('Regularization Parameters')
reg_parser.add_argument('--jac_weight', type=float, default=0.0,
                        help='jacobian regularization loss weight (default to 0)')
reg_parser.add_argument('--jac_freq', type=float, default=0.0,
                        help='the frequency of applying the jacobian regularization (default to 0)')

# OTHER GNN ARGS
oth_parser = parser.add_argument_group('Fixed Point Parameters')
oth_parser.add_argument('--kappa', type=float, default=0.9,
                        help='Projection parameter. ||W|| <= kappa/lpf(A)')
# Help text previously duplicated the --dropout description.
oth_parser.add_argument('--num_eigenvec', type=int, default=100,
                        help='Number of eigenvectors.')
oth_parser.add_argument('--dropout', type=float, default=0.8,
                        help='Dropout rate (1 - keep probability).')
oth_parser.add_argument('--rho', type=float, default=0.2,
                        help='Percent of -1 in lin_module Cayley')
oth_parser.add_argument('--adj_pow', type=int, default=1,
                        help='Adjacency power.')
oth_parser.add_argument('--lr_orth', type=float, default=0.2,
                        help='Learning rate for RGD')
oth_parser.add_argument('--clip', type=float, default=0.5,
                        help='Gradient clipping.')
oth_parser.add_argument('--allow_tf32', action='store_true',
                        help='Allow fast float multiplication.')

args = parser.parse_args()
# CUDA is used only when it is not disabled via --no-cuda and is available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
# Echo the effective configuration, one "name value" pair per line.
for name, value in vars(args).items():
    print(name, value)

# --allow_tf32 is a store_true flag, so its boolean maps directly onto both
# backend switches (explicitly False when the flag is absent).
torch.backends.cuda.matmul.allow_tf32 = args.allow_tf32
torch.backends.cudnn.allow_tf32 = args.allow_tf32

# SEEDING
# Seed NumPy and PyTorch from --seed; the CUDA generator is seeded only when
# CUDA is in use.
seed = args.seed
np.random.seed(seed)
torch.manual_seed(seed)
if args.cuda:
    torch.cuda.manual_seed(seed)

# LOAD
# Build the chains dataset from the CLI settings.
(adj, sp_adj, features, labels,
 idx_train, idx_val, idx_test, edge_index) = load_chain(
    args.normalization,
    args.cuda,
    args.num_chains,
    args.chain_len,
    num_class=args.num_class,
    adj_pow=args.adj_pow,
)
m = features.shape[0]
Y = labels
# Largest label value plus one.
m_y = torch.max(Y).int().item() + 1



# An inversion method of the form 'neumann-<int>' carries an integer suffix
# that is stored as args.neum; every other method uses 1.
# The previous pattern r'^neumann-*' meant "zero or more hyphens", so a bare
# 'neumann' matched and then crashed with IndexError on split('-')[1].
# Capturing the digits explicitly makes that input fall back to 1 instead.
neumann_match = re.match(r'^neumann-(\d+)', args.inv_method)
if neumann_match:
    args.neum = int(neumann_match.group(1))
else:
    args.neum = 1


# MODEL
# Constructor arguments shared by every model variant. The class count is
# derived from the labels for all models; the GCN branch previously hard-coded
# nclass=2, which disagreed with the other branches when --num_class != 2.
_common_model_kwargs = dict(
    nfeat=features.shape[1],
    nhid=args.hidden,
    nclass=labels.max().item() + 1,
    num_node=adj.shape[1],
    dropout=args.dropout,
    adj=adj,
)
_device = 'cuda' if args.cuda else 'cpu'

if args.model == 'MIGNN':
    model = ChainsMonIGNN(
        sp_adj=sp_adj,
        linModule=args.lin_module,
        fpMethod=args.fp_method,
        invMethod=args.inv_method,
        kappa=args.kappa,
        max_iter=args.max_iter,
        alpha=args.alpha,
        update_alpha=False,
        alpha_factor=args.adj_pow * args.neum,
        beta=args.beta,
        tol=args.fp_tol,
        mu=args.mu,
        rho=args.rho,
        device=_device,
        **_common_model_kwargs)
elif args.model == 'IGNN':
    model = ChainsIGNN(
        rho=args.rho,
        tol=args.fp_tol,
        linModule=args.lin_module,
        device=_device,
        kappa=args.kappa,
        **_common_model_kwargs)
elif args.model == 'GCN':
    model = ChainsGCN(**_common_model_kwargs)
else:
    # The message was a plain string containing '{args.model}', so the model
    # name was never interpolated; format it explicitly.
    raise NotImplementedError('Model not found: {}'.format(args.model))

# Dynamic Trivialization
# With the 'expm' linear module and --RGD, the parameters of
# model.ig1.lin_module.C get their own SGD learning rate (--lr_orth) while all
# remaining parameters use --lr. Every other configuration trains with Adam.
if args.lin_module == 'expm' and args.RGD:
    p_orth = model.ig1.lin_module.C
    orth_params = list(p_orth.parameters())
    # Build the identity set once; it was previously rebuilt for every model
    # parameter inside the generator.
    orth_param_ids = {id(p) for p in orth_params}
    non_orth_params = [
        p for p in model.parameters() if id(p) not in orth_param_ids
    ]
    # NOTE(review): args.weight_decay is not passed in this branch -- confirm
    # that is intended.
    optimizer = torch.optim.SGD(
        [{"params": non_orth_params},
         {"params": orth_params, "lr": args.lr_orth}],
        lr=args.lr,
    )
else:
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr, weight_decay=args.weight_decay)

# Transpose the feature matrix before it is handed to the model.
features = features.T

if args.cuda:
    model.cuda()
    # Move every data object to the GPU, in the same order as before.
    (features, adj, edge_index, labels,
     idx_train, idx_val, idx_test) = (
        item.cuda()
        for item in (features, adj, edge_index, labels,
                     idx_train, idx_val, idx_test)
    )

# cal_norm returns the normalisation factor together with a replacement
# edge_index; both are passed to the model on every forward call.
norm_factor, edge_index = cal_norm(edge_index)
grad_stats = GradientStats(model)

# TEST/TRAIN
def train(epoch, W_old=None):
    """Run one optimisation step and print train/val/test metrics.

    Args:
        epoch: Zero-based epoch index; reported as ``epoch + 1``.
        W_old: Unused; kept so existing call sites keep working.

    The validation and test metrics come from a second forward pass in eval
    mode, unless ``--fastmode`` is set, in which case the output of the
    training forward pass is reused.
    """
    start = time.time()
    model.train()
    optimizer.zero_grad()

    # One draw per call: the Jacobian term is included when the draw falls
    # below args.jac_freq.
    compute_jac_loss = np.random.uniform(0, 1) < args.jac_freq
    raw_out, jac_loss = model(features,
                              norm_factor=norm_factor,
                              edge_index=edge_index,
                              compute_jac_loss=compute_jac_loss)
    output = F.log_softmax(raw_out, dim=1)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])

    objective = loss_train
    if compute_jac_loss:
        objective = loss_train + jac_loss * args.jac_weight
    objective.backward()

    grad_stats.update()

    # A clip value of 0 disables gradient clipping.
    if args.clip != 0:
        clip_gradient(model, clip_norm=args.clip)
    optimizer.step()

    if args.lin_module == 'expm' and args.RGD:
        geotorch.update_base(model.ig1.lin_module.C, 'weight')

    if not args.fastmode:
        # Separate evaluation pass with the model in eval mode.
        model.eval()
        raw_out, _ = model(features, norm_factor=norm_factor,
                           edge_index=edge_index)
        output = F.log_softmax(raw_out, dim=1)

    model.eval()

    def _split_metrics(idx):
        # Loss and accuracy of the current output restricted to one index set.
        return (F.nll_loss(output[idx], labels[idx]),
                accuracy(output[idx], labels[idx]))

    loss_val, acc_val = _split_metrics(idx_val)
    loss_test, acc_test = _split_metrics(idx_test)

    fields = (
        ('Epoch: {:04d}', epoch + 1),
        ('loss_train: {:.4f}', loss_train.item()),
        ('acc_train: {:.4f}', acc_train.item()),
        ('loss_val: {:.4f}', loss_val.item()),
        ('acc_val: {:.4f}', acc_val.item()),
        ('loss_test: {:.4f}', loss_test.item()),
        ('acc_test: {:.4f}', acc_test.item()),
        ('time: {:.4f}s', time.time() - start),
    )
    print(*[template.format(value) for template, value in fields])
    # grad_stats.report()

def test():
    """Evaluate the model on the test split and print loss and accuracy.

    The model is switched to eval mode and a single forward pass is made over
    the full graph; the metrics are restricted to ``idx_test``.
    """
    model.eval()
    output, _ = model(features, norm_factor=norm_factor, edge_index=edge_index)
    # F.nll_loss expects log-probabilities. train() applies log_softmax to the
    # model output before computing the loss; doing the same here makes the
    # reported test loss comparable (it was previously computed on raw output).
    output = F.log_softmax(output, dim=1)
    loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print("Test set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))

# MAIN
# Train for the requested number of epochs, then report the total wall-clock
# time and the final test-set metrics.
for epoch in range(args.epochs):
    train(epoch)

print("Optimization Finished!")
elapsed = time.time() - t_total
print("Total time elapsed: {:.4f}s".format(elapsed))
test()