#!/opt/conda/bin/python3
from __future__ import division, print_function

import argparse
import numpy as np
import os.path as osp
from scipy.sparse import csr_matrix
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold
import time
import torch
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader

from mono_ignn import GraphClassificationMonIGNN
from mono_ignn.tools.normalization import fetch_normalization

# ---------------------------------------------------------------------------
# Command-line interface.
#
# The options are organised into argparse groups purely for a readable
# ``--help`` output; every value ends up on the same namespace.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Graph classification on TU datasets with IGNN and Monotone-IGNN.')
parser.add_argument('--dout', type=str, default="./out/chains/",
                    help='out directory')

# --- Compute ---
compute_parser = parser.add_argument_group('Computational Parameters')
compute_parser.add_argument('--no-cuda', action='store_true',
                    help='Disables CUDA training.')

# --- Data ---
data_parser = parser.add_argument_group('Data Parameters')
data_parser.add_argument('--normalization', type=str, default='LaplaceNorm',
                   choices=['AugNorm','DiagNorm','IdentNorm','LaplaceNorm','RWNorm','TransposeNorm'],
                   help='Normalization method for the adjacency matrix.')
data_parser.add_argument('--seed', type=int, default=42,
                    help='Random seed.')
data_parser.add_argument('--dataset', type=str, default="MUTAG",
                        help='Dataset to use.')
data_parser.add_argument('--fold_idx', type=int, default=0,
                    help='Which fold is chosen for test (0-9).')

# --- Training / model ---
train_parser = parser.add_argument_group('Model Parameters')
train_parser.add_argument('--model', type=str, default='MIGNN',
                    choices=['IGNN','MIGNN','EIGNN'],
                    help='Model selection, fixed point only for MIGNN.')
train_parser.add_argument('--RGD',  action='store_true',
                    help='Train via SGD and RGD.')
train_parser.add_argument('--epochs', type=int, default=200,
                    help='Number of epochs to train.')
train_parser.add_argument('--lr', type=float, default=0.01,
                    help='Initial learning rate.')
train_parser.add_argument('--weight_decay', type=float, default=0,
                    help='Weight decay (L2 loss on parameters).')
train_parser.add_argument('--hidden', type=int, default=64,
                    help='Number of hidden units.')
train_parser.add_argument('--feature', type=str, default="mul",
                    choices=['mul', 'cat', 'adj'],
                    help='feature-type')

# --- Monotone operator splitting ---
mon_parser = parser.add_argument_group('Monotone Operator Splitting Parameters')
mon_parser.add_argument('--lin_module', type=str, default='cayley',
                    choices=['cayley','diagd','frob','proj','expm','symm','skew'],
                    help='Linear module selection, only for MIGNN.')
mon_parser.add_argument('--mu', type=float, default=None,
                    help='Linear module parameter.')
mon_parser.add_argument('--mu0', type=float, default=1,
                    help='Linear module parameter.')
mon_parser.add_argument('--fp_method', type=str, default='pr+a',
                   choices=['pwr','pwr+a','fb','fb+a','pr+a','pr','dr','dr+a','dr+h'],
                   help='Fixed point solving method.')
mon_parser.add_argument('--alpha', type=float, default=.5,
                    help='Fixed point convergence parameter alpha')
# The help text used to say "alpha" here (copy-paste slip).
mon_parser.add_argument('--beta', type=float, default=.5,
                    help='Fixed point convergence parameter beta')
mon_parser.add_argument('--fp_tol', type=float, default=3e-6,
                    help='Fixed point tolerance parameter.')
mon_parser.add_argument('--max_iter', type=int, default=300,
                    help='Fixed point maximum iters.')
# The help text used to duplicate the one of --fp_method.
mon_parser.add_argument('--inv_method', type=str, default='direct',
                   help='Inversion method [direct,eig,neumann-*].')
# NOTE(review): the flag name says "disable" while the help says "add";
# the flag is not read anywhere in this file, so the intended meaning could
# not be confirmed and the text is left untouched.
mon_parser.add_argument('--disable_norm', action='store_true',
                    help='Add additional normalization layer.')

# --- Regularization ---
reg_parser = parser.add_argument_group('Regularization Parameters')
reg_parser.add_argument('--jac_weight', type=float, default=0.0,
                    help='jacobian regularization loss weight (default to 0)')
reg_parser.add_argument('--jac_freq', type=float, default=0.0,
                    help='the frequency of applying the jacobian regularization (default to 0)')

# --- Other GNN options ---
oth_parser = parser.add_argument_group('Fixed Point Parameters')
oth_parser.add_argument('--ignn_default', action='store_true',
                    help='Run as default ignn')
oth_parser.add_argument('--kappa', type=float, default=0.99,
                    help='Projection parameter. ||W|| <= kappa/lpf(A)')
oth_parser.add_argument('--dropout', type=float, default=0.5,
                    help='Dropout rate (1 - keep probability).')
oth_parser.add_argument('--rho', type=float, default=0.2,
                    help='Percent of -1 in lin_module Cayley')
oth_parser.add_argument('--adj_pow', type=int, default=1,
                    help='Adjacency power.')
oth_parser.add_argument('--lr_orth', type=float, default=0.2,
                    help='Learning rate for RGD')
oth_parser.add_argument('--allow_tf32', action='store_true',
                    help='Allow fast float multiplication.')
args = parser.parse_args()

# CUDA is used only when it was not disabled on the command line and a CUDA
# device is actually available.
if args.no_cuda:
    args.cuda = False
else:
    args.cuda = torch.cuda.is_available()

# Echo the effective configuration, one "name value" pair per line.
for name, value in vars(args).items():
    print(name, value)

""" SEEDING """
# Seed every random number generator in use so runs are repeatable.
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# TF32 matmul / cuDNN kernels are enabled only when --allow_tf32 is given;
# otherwise both switches are explicitly turned off.
use_tf32 = bool(args.allow_tf32)
torch.backends.cuda.matmul.allow_tf32 = use_tf32
torch.backends.cudnn.allow_tf32 = use_tf32


class LinearScheduler(LambdaLR):
    """Learning-rate multiplier with linear warmup followed by linear decay.

    The multiplier rises linearly from 0 to 1 during the first
    ``warmup_steps`` steps, then falls linearly from 1 to 0 over the
    remaining ``t_total - warmup_steps`` steps and is clamped at 0 afterwards.
    """

    def __init__(self, optimizer, t_total, warmup_steps=0, last_epoch=-1):
        # Both attributes must exist before the base constructor runs,
        # because it is handed ``self.lr_lambda``, which reads them.
        self.t_total = t_total
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the multiplier applied to the base learning rate at ``step``."""
        warmup = self.warmup_steps
        if step >= warmup:
            # Decay phase: fraction of the post-warmup budget still left.
            remaining = float(self.t_total - step)
            decay_span = float(max(1.0, self.t_total - warmup))
            return max(0.0, remaining / decay_span)
        # Warmup phase.
        return float(step) / float(max(1, warmup))


""" LOAD """
dataset = TUDataset('/root/workspace/data/', name=args.dataset).shuffle()

# Gather the graph labels by iterating the (shuffled) dataset itself, so that
# position i of `labels` belongs to dataset[i].  Reading `dataset.data.y`
# instead returns the internal storage, which in torch_geometric versions
# where `shuffle()` is implemented as an index view is NOT in the shuffled
# order; the folds would then be stratified on the wrong labels.
labels = torch.cat([graph.y.view(-1) for graph in dataset]).cpu().numpy()

# Ten stratified (train_indices, test_indices) pairs, reproducible via --seed.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=args.seed)
idx_list = list(skf.split(np.zeros(len(labels)), labels))

# Explicit check rather than `assert`, which disappears under `python -O`.
# NOTE(review): --fold_idx is validated here, but the training loop in this
# file iterates over all ten folds and never reads it.
if not 0 <= args.fold_idx < 10:
    raise ValueError("fold_idx must be from 0 to 9.")

# Device on which the model and every batch are placed.
device = torch.device('cuda' if args.cuda else 'cpu')

# results[k] collects the per-epoch test accuracy of fold k.
results = [[] for _ in range(10)]

for fold_idx in range(10):
    print('*'*8,f'FOLD INDX {fold_idx}','*'*8)

    # Datasets without node features report zero features; the train/test
    # routines then substitute a single per-node feature computed from the
    # adjacency, so the model needs an input width of 1.  The width is kept in
    # a local variable: `num_features` is a read-only property in
    # torch_geometric, so assigning to it (as was done before) raises
    # AttributeError.
    nfeat = max(dataset.num_features, 1)

    # A fresh model is built for every fold.
    if args.model == 'MIGNN':
        model = GraphClassificationMonIGNN(nfeat=nfeat,
            nhid=args.hidden,
            nclass=dataset.num_classes,
            num_node = None,
            dropout=args.dropout,
            adj=None,
            sp_adj = None,
            linModule=args.lin_module,
            fpMethod=args.fp_method,
            invMethod=args.inv_method,
            record=True,
            kappa=args.kappa,
            max_iter=args.max_iter,
            alpha = args.alpha,
            beta = args.beta,
            tol = args.fp_tol,
            mu = args.mu,
            mu0 = args.mu0,
            rho = args.rho,
            device='cuda' if args.cuda else 'cpu').to(device)
    elif args.model == 'IGNN':
        # NOTE(review): GraphClassificationIGNN is not imported anywhere in
        # this file, so this branch raises NameError until the import is added.
        model = GraphClassificationIGNN(nfeat=nfeat,
                    nhid=args.hidden,
                    nclass=dataset.num_classes,
                    num_node = None,
                    record=True,
                    dropout=args.dropout,
                    kappa=args.kappa).to(device)
    else:
        # 'EIGNN' is accepted by the argument parser but has no implementation here.
        raise NotImplementedError

    # Slice the dataset into this fold's train / test subsets.
    train_idx, test_idx = idx_list[fold_idx]
    test_dataset = dataset[test_idx.tolist()]
    train_dataset = dataset[train_idx.tolist()]

    # Both loaders use the default (sequential) ordering.
    test_loader = DataLoader(test_dataset, batch_size=128)
    train_loader = DataLoader(train_dataset, batch_size=128)

    def train(epoch):
        """Run one optimisation pass over ``train_loader`` for the current fold.

        For every mini-batch a sparse adjacency matrix is built from the
        edge list, optionally replaced by an average of its first
        ``args.adj_pow`` powers, normalised with the routine selected by
        ``args.normalization`` and fed to ``model`` together with the node
        features.  The negative log-likelihood loss is back-propagated and
        ``optimizer`` takes one step per batch.

        Args:
            epoch: 1-based epoch number.  Currently unused; kept so the
                caller's signature stays unchanged.

        Returns:
            The mean training loss per graph over ``train_dataset``.
        """
        model.train()

        # The normaliser depends only on the command-line choice, so it is
        # looked up once instead of once per batch.
        adj_normalizer = fetch_normalization(args.normalization)

        loss_all = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()

            num_nodes = data.num_nodes
            edge_attr = data.edge_attr
            if edge_attr is None:
                edge_weight = torch.ones((data.edge_index.size(1), ), dtype=torch.float32, device=data.edge_index.device)
            elif args.dataset in ('MUTAG', 'PTC_MR') or (edge_attr.dim() > 1 and edge_attr.size(1) > 1):
                # The previous test `args.dataset == 'MUTAG' or 'PTC_MR'` was
                # always true (a non-empty string is truthy), which made the
                # raw-weight branch below unreachable.  Multi-column edge
                # attributes keep the argmax reduction they always received
                # (presumably one-hot edge labels -- confirm per dataset).
                edge_weight = edge_attr.argmax(1)
            else:
                # Scalar edge attributes are used directly as edge weights,
                # flattened to the 1-D layout the sparse constructor expects.
                edge_weight = edge_attr.reshape(-1)
            # NOTE(review): test() in this file uses unit edge weights and
            # applies the --ignn_default symmetrisation, neither of which
            # happens here; confirm that the asymmetry is intended.

            adj_sp = csr_matrix((edge_weight.cpu().numpy(), (data.edge_index[0,:].cpu().numpy(), data.edge_index[1,:].cpu().numpy() )), shape=(num_nodes, num_nodes))

            if args.adj_pow > 1 :
                # Scale by the largest row sum, then average the first
                # `adj_pow` powers of the scaled matrix.
                d = np.max(adj_sp.sum(axis=1))
                adj_sp = (1/d) * adj_sp
                new_adj = [adj_sp]
                for _ in range(1,args.adj_pow):
                    new_adj.append(new_adj[-1] @ adj_sp)
                adj_sp = (1/args.adj_pow)*sum(new_adj)

            adj_sp_nz = adj_normalizer(adj_sp)
            adj = torch.sparse.FloatTensor(torch.LongTensor(np.array([adj_sp_nz.row,adj_sp_nz.col])).to(device), torch.Tensor(adj_sp_nz.data).to(device), torch.Size([num_nodes, num_nodes])) #normalized adj

            if data.x is None:
                # No node features: use the column sums of the original
                # (un-normalised) adjacency as a single feature per node.
                # The matrix is only needed here, so it is only built here.
                adj_ori = torch.sparse.FloatTensor(data.edge_index, edge_weight, torch.Size([num_nodes, num_nodes])) #original adj
                data.x = torch.sparse.sum(adj_ori, [0]).to_dense().unsqueeze(1).to(device)

            output, _ = model(data.x.T, adj, data.batch,sp_adj = adj_sp_nz)
            loss = F.nll_loss(output, data.y)
            loss.backward()
            # Weight the batch-mean loss by the batch size so the value
            # returned below is a per-graph average.
            loss_all += loss.item() * data.num_graphs

            optimizer.step()

        return loss_all / len(train_dataset)


    def test(loader):
        """Return the classification accuracy of ``model`` over ``loader``.

        The model is put in eval mode and every batch is converted to a
        normalised sparse adjacency before the forward pass.  All edges get
        unit weight here, whether or not the batch carries edge attributes.

        Args:
            loader: iterable of batches exposing ``edge_index``, ``x``,
                ``y``, ``batch`` and ``num_nodes``, plus a ``dataset``
                attribute whose length is the number of graphs.

        Returns:
            Number of correctly classified graphs divided by
            ``len(loader.dataset)``.
        """
        # NOTE(review): train() in this file derives edge weights from
        # ``edge_attr`` and passes ``sp_adj`` to the model; neither happens
        # here -- confirm that the difference is intended.
        model.eval()

        n_correct = 0
        for batch in loader:
            batch = batch.to(device)
            n_nodes = batch.num_nodes
            edges = batch.edge_index
            square = torch.Size([n_nodes, n_nodes])

            # Unit weight for every edge.
            edge_weight = torch.ones((edges.size(1), ), dtype=torch.float32, device=edges.device)

            src = edges[0,:].cpu().numpy()
            dst = edges[1,:].cpu().numpy()
            adj_sp = csr_matrix((edge_weight.cpu().numpy(), (src, dst)), shape=(n_nodes, n_nodes))
            if args.ignn_default: #see ignn github
                adj_sp = adj_sp + adj_sp.T

            hops = args.adj_pow
            if hops > 1:
                # Scale by the largest row sum, then average the first
                # `hops` powers of the scaled matrix.
                scaled = (1/np.max(adj_sp.sum(axis=1))) * adj_sp
                powers = [scaled]
                for _ in range(1, hops):
                    powers.append(powers[-1] @ scaled)
                adj_sp = (1/hops)*sum(powers)

            adj_sp_nz = fetch_normalization(args.normalization)(adj_sp)

            # Normalised adjacency as a torch sparse tensor on `device`.
            coo_index = torch.LongTensor(np.array([adj_sp_nz.row,adj_sp_nz.col])).to(device)
            coo_value = torch.Tensor(adj_sp_nz.data).to(device)
            adj = torch.sparse.FloatTensor(coo_index, coo_value, square)

            # Original (un-normalised) adjacency.
            adj_ori = torch.sparse.FloatTensor(edges, edge_weight, square)
            if batch.x is None:
                # No node features: fall back to the column sums of the
                # original adjacency as a single feature per node.
                batch.x = torch.sparse.sum(adj_ori, [0]).to_dense().unsqueeze(1).to(device)

            output, _ = model(batch.x.T, adj, batch.batch)
            _, pred = output.max(dim=1)
            n_correct += pred.eq(batch.y).sum().item()

        return n_correct / len(loader.dataset)

    # One Adam optimiser per fold, bound to the freshly built model.
    # No learning-rate schedule is attached; LinearScheduler is defined in
    # this file but is not used.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    fold_history = results[fold_idx]
    report = ('Epoch: {:03d}, Train Loss: {:.7f}, '
              'Test Acc: {:.7f}, Time: {:.4f}')

    for epoch in range(1, args.epochs+1):
        start = time.time()

        # One training pass, then evaluation on the held-out fold.
        train_loss = train(epoch)
        test_acc = test(test_loader)

        # Keep the per-epoch accuracy; the summary after the loop reads it.
        fold_history.append(test_acc)
        print(report.format(epoch, train_loss, test_acc, time.time()-start))
        
# Summarise the run: for every fold take the best test accuracy reached in any
# epoch, then report the mean and standard deviation across folds.
re_np = np.array(results)
best_per_fold = re_np.max(1)
re_all = [best_per_fold.mean(), best_per_fold.std()]
print('Graph classification mean accuracy and std are {}'.format(re_all))