from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib
import scipy
import itertools
import sys
from utils import load_data, accuracy, full_load_data, data_split, normalize, train, normalize_adj,sparse_mx_to_torch_sparse_tensor, dataset_edge_balance, random_disassortative_splits,rand_train_test_idx
from models import GCN


# Training settings
# Command-line interface of the training / grid-search script. Flag names and
# help strings are part of the user-facing interface and are kept as they are.
parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--param_tunning', action='store_true', default=False,
                    help='Parameter fine-tunning mode')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=5000,
                    help='Number of epochs to train.')
parser.add_argument('--num_splits', type=int, help='number of training/val/test splits ', default = 10)
parser.add_argument('--model', type=str, help='name of model (gcn, sgc, mfgcn, mfsgc, mlp)', default = 'mfgcn')
# The early-stopping value is used as a window length (a slice bound over the
# validation-loss history), so it has to be an integer: a float parsed from
# the command line would raise TypeError when slicing.
parser.add_argument('--early_stopping', type=int, default=200,
                    help='early stopping used in GPRGNN')
parser.add_argument('--lr', type=float, default=0.05,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default= 1e-5,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=64,
                    help='Number of hidden units.')
parser.add_argument('--hops', type=int, default=1,
                    help='Number of hops we use, k= 1,2,3,5.')
parser.add_argument('--idx', type=int, default=0,
                    help='Split number.')
parser.add_argument('--dataset_name', type=str,
                    help='Dataset name.', default = 'cornell')
parser.add_argument('--dropout', type=float, default=0.8,
                    help='Dropout rate (1 - keep probability).')

parser.add_argument('--filter_type', type=int,help='0 for renormalized adjacency with self-loop, 1 for normalized adjacency without self-loop, 2 for lazy random walk matrix, 3 for generalized lazy random walk matrix',
                    default = 0)
parser.add_argument('--low_pass', type=int, help='1 for low pass filter, 0 for high pass filter', default = 1)
parser.add_argument('--sym_filter', type=int, help='0 for random walk normalized filter, 1 for symmetric normalized filter', default = 0)

args = parser.parse_args()
# Use the GPU only when it is available and not disabled with --no-cuda.
args.cuda = not args.no_cuda and torch.cuda.is_available()
# acos(0) == pi / 2, so this stores pi as an attribute on the torch module.
torch.pi = torch.acos(torch.zeros(1)).item() * 2

np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Best configuration seen so far; updated by the grid search further down.
best_result = 0
best_std = 0
best_dropout = None
best_weight_decay = None
best_lr = None
best_hop = None
best_runtime_average = None
best_epoch_average = None

# Hyper-parameter grids searched by the main loop.
# NOTE: with hops == 1 the learning-rate grid is fixed and --lr is ignored.
if args.hops == 1:
    lr = [0.01, 0.05]
else:
    lr = [args.lr]
weight_decay = [0, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3, 1e-2]
if args.model == 'sgc':
    dropout = [0.0]
else:
    dropout = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
datasets = ['cornell', 'wisconsin', 'texas', 'film', 'chameleon', 'squirrel', 'cora', 'citeseer', 'pubmed']
adj_low, features, labels = full_load_data(args.dataset_name)
# NOTE: the number of splits is decided by the dataset alone; whatever was
# passed through --num_splits is overwritten here.
if args.dataset_name in {'CitationFull_dblp', 'Coauthor_CS', 'Coauthor_Physics', 'Amazon_Computers', 'Amazon_Photo'}:
    args.num_splits = 20
else:
    args.num_splits = 10

nnodes = adj_low.shape[0]

# For the SGC-style models with several hops, replace adj_low by its
# args.hops-th matrix power (computed densely, then converted back to sparse).
if (args.model == 'sgc' or args.model == 'mfsgc') and (args.hops > 1):
    # The dense form of the one-hop matrix does not change inside the loop,
    # so it is materialized once instead of on every multiplication.
    adj_dense = adj_low.to_dense()
    A_EXP = adj_dense
    for _ in range(args.hops - 1):
        A_EXP = torch.mm(A_EXP, adj_dense)
    adj_low = A_EXP.to_sparse()
    del A_EXP, adj_dense
# Complementary filter: identity minus adj_low, stored as a sparse tensor.
adj_high = torch.eye(nnodes) - adj_low
adj_high = adj_high.to_sparse()

if args.cuda:
    features = features.cuda()
    adj_low = adj_low.cuda()
    adj_high = adj_high.cuda()
    labels = labels.cuda()
def test():
    """Return the accuracy of the current model on the current test split.

    Reads the module-level ``model``, ``features``, ``adj_low``, ``adj_high``,
    ``labels`` and ``idx_test`` at call time. The model is switched to
    evaluation mode and the forward pass runs without gradient tracking.
    """
    model.eval()
    with torch.no_grad():
        output = F.log_softmax(model(features, adj_low, adj_high), dim=1)
    return accuracy(output[idx_test], labels[idx_test])


# The early-stopping window is used as a slice bound below, so it must be an
# integer even if the command-line value was parsed as a float.
early_stopping = int(args.early_stopping)

# Grid search: every (lr, weight decay, dropout) combination is trained on
# args.num_splits splits, and the combination with the best mean test
# accuracy is remembered in the best_* variables.
for args.lr, args.weight_decay, args.dropout in itertools.product(lr, weight_decay, dropout):

    # Train model
    t_total = time.time()
    epoch_total = 0  # number of training epochs actually run, over all splits

    result = np.zeros(args.num_splits)

    for idx in range(args.num_splits):
        idx_train, idx_val, idx_test = random_disassortative_splits(labels, labels.max()+1)

        model = GCN(nfeat=features.shape[1], nhid=args.hidden, nclass=labels.max().item() + 1, dropout=args.dropout, hops = args.hops, features = features, model_type = args.model)
        if args.cuda:
            idx_train = idx_train.cuda()
            idx_val = idx_val.cuda()
            idx_test = idx_test.cuda()
            model.cuda()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        best_test = 0
        # NaN (not None) so the tuning printout below still formats when the
        # validation accuracy never rises above zero.
        best_training_loss = float('nan')
        best_val_acc = 0
        val_loss_history = np.zeros(args.epochs)
        epochs_run = 0

        for epoch in range(args.epochs):
            acc_train, loss_train = train(model, optimizer, adj_low, adj_high, features, labels, idx_train, idx_val)
            epochs_run = epoch + 1

            # Evaluate the validation set in a separate pass with the model in
            # evaluation mode. This runs regardless of --fastmode: train()
            # does not hand back the logits, and the validation metrics are
            # needed for model selection and early stopping.
            model.eval()
            with torch.no_grad():
                output = F.log_softmax(model(features, adj_low, adj_high), dim=1)

            val_loss = F.nll_loss(output[idx_val], labels[idx_val]).cpu()
            val_acc = accuracy(output[idx_val], labels[idx_val]).cpu()

            # Model selection: keep the test accuracy measured at the epoch
            # with the best validation accuracy seen so far.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test = test()
                best_training_loss = loss_train

            val_loss_history[epoch] = val_loss.detach().numpy()
            # Stop once the validation loss exceeds its mean over the
            # previous `early_stopping` epochs.
            if early_stopping > 0 and epoch > early_stopping:
                tmp = np.mean(val_loss_history[epoch - early_stopping:epoch])
                if val_loss > tmp:
                    break

        epoch_total = epoch_total + epochs_run

        if args.param_tunning:
            print("Optimization for %s, %s, weight decay %.5f, dropout %.4f, split %d, Best Test Result: %.4f, Training Loss: %.4f"%(args.model, args.dataset_name, args.weight_decay, args.dropout, idx, best_test, best_training_loss))
        # Testing

        # float() accepts a plain number as well as a 0-d tensor on any device.
        result[idx] = float(best_test)

        del model, optimizer
        if args.cuda: torch.cuda.empty_cache()
    total_time_elapsed = time.time() - t_total
    runtime_average = total_time_elapsed/args.num_splits
    # Milliseconds per epoch; NaN when no epoch was run at all.
    if epoch_total > 0:
        epoch_average = total_time_elapsed/epoch_total * 1000
    else:
        epoch_average = float('nan')
    # The first configuration is always recorded so that the best_* values
    # are never left as None; later ones must strictly improve the mean.
    if best_lr is None or np.mean(result) > best_result:
        best_result = np.mean(result)
        best_std = np.std(result)
        best_dropout = args.dropout
        best_weight_decay = args.weight_decay
        best_lr = args.lr
        best_hop = args.hops
        best_runtime_average = runtime_average
        best_epoch_average = epoch_average

    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if args.param_tunning:
        print("%s, Hops: %s, Hidden: %s, Low Pass: %s, Symmetric Normalization: %s, Filtertype: %s, learning rate %.4f, weight decay %.6f, dropout %.4f, Test Mean: %.4f, Test Std: %.4f, runtime average time: %4f s, epoch average time: %4f ms"%
              (args.dataset_name, args.hops, args.hidden, args.low_pass, args.sym_filter, args.filter_type, args.lr, args.weight_decay, args.dropout, np.mean(result), np.std(result), runtime_average, epoch_average))
    else:
        print("%s, Model: %s, Hops: %s, Hidden: %s, Low Pass: %s, Symmetric Normalization: %s, Filtertype: %s, learning rate %.4f, weight decay %.6f, dropout %.4f, Test Mean: %.4f, Test Std: %.4f, runtime average time: %.2fs, epoch average time: %.2fms"%
              (args.dataset_name, args.model, args.hops, args.hidden, args.low_pass, args.sym_filter, args.filter_type, args.lr, args.weight_decay, args.dropout, np.mean(result), np.std(result), runtime_average, epoch_average))
 
# Report the best configuration recorded by the grid search: the template is
# filled with the run settings from `args` and the tracked best_* values.
best_summary_template = "Best Result of %s, on Dataset: %s, Hops: %s, Hidden: %s, Low Pass: %s, Symmetric Normalization: %s, Filtertype: %s, learning rate %.4f, weight decay %.6f, dropout %.4f, Test Mean: %.4f, Test Std: %.4f, epoch average/runtime average time: %.2fms/%.2fs"
best_summary_values = (
    args.model,
    args.dataset_name,
    best_hop,
    args.hidden,
    args.low_pass,
    args.sym_filter,
    args.filter_type,
    best_lr,
    best_weight_decay,
    best_dropout,
    best_result,
    best_std,
    best_epoch_average,
    best_runtime_average,
)
print(best_summary_template % best_summary_values)
 
