import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os

from util import *
from model import *

# Command-line interface: every hyper-parameter of the experiment is a flag.
# The script trains on the node labels of a single graph, so the description
# and help texts below talk about node classification.
parser = argparse.ArgumentParser(description='PyTorch graph neural net for node classification')
parser.add_argument('--dataset', type=str, default="cora",
                    help='name of dataset (default: cora)')
parser.add_argument('--datapath', type=str, default="/data/datasets/graphdata",
                    help='dataset path')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=300,
                    help='number of epochs to train (default: 300)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--layers', type=int, default=2,
                    help='number of GNN layers (default: 2)')
parser.add_argument('--MLP_layers', type=int, default=2,
                    help='number of MLP layers in each GNN layer (default: 2)')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed for the stratified train/val/test split (default: 0)')
parser.add_argument('--hidden_dim', type=int, default=64,
                    help='number of hidden units (default: 64)')
parser.add_argument('--agg', type=str, default="cat", choices=["cat", "sum"],
                    help='aggregate input and its neighbors, can be extended to other method like mean, max etc.')
parser.add_argument('--phi', type=str, default="MLP", choices=["power", "identical", "MLP","vdmd","powvdmd","GCN"],
                    help='phi transform used in each GNN layer, or "GCN" to train the GCN model instead (default: MLP)')
parser.add_argument('--dropout', type=float, default=0,
                        help='final layer dropout (default: 0)')
parser.add_argument('--weight_decay', type=float, default=0,
                        help='weight decay in the optimizer (default: 0)')
parser.add_argument('--train_ratio', type=float, default=0.8,
                        help='The training size ratio (default: 0.8)')
parser.add_argument('--filename', type = str, default = "auto",
                    help='CSV file for the per-epoch results; "auto" derives a path under ./results, '
                         '"no" disables saving (default: auto)')
# Parse the command line once and copy the frequently used settings into
# short module-level names used by the rest of the script.
args = parser.parse_args()

dataset, datapath = args.dataset, args.datapath
device = args.device  # still the integer GPU index here; turned into a torch.device further down
seed = args.seed
tr_size = args.train_ratio
agg = args.agg
hid_dim = args.hidden_dim
nlayers, nlayers_mlp = args.layers, args.MLP_layers
dropout, weight_decay = args.dropout, args.weight_decay

# Decide where the per-epoch results are written.  An explicit --filename is
# used verbatim; "auto" derives a path under ./results/<dataset>/ that encodes
# the hyper-parameters of this run.
if args.filename != "auto":
    filename = args.filename
else:
    # makedirs(exist_ok=True) builds the whole tree in one call, completes a
    # partially created tree, and does not fail when a concurrent run created
    # the directory between a check and the mkdir.
    os.makedirs("./results/{}/embeddings".format(dataset), exist_ok=True)
    filename = "./results/{}/{}_hid{}_mlp{}_lr{}_wd{}_{}_tr{}_dropout{}_s{}.csv"  \
        .format(dataset,args.phi,hid_dim,nlayers_mlp,args.lr, weight_decay, agg, tr_size, dropout, seed )

# Do not redo a run whose result file is already present.
if os.path.isfile(filename):
    print('%s, file exists.'%(filename))
    # SystemExit goes through normal interpreter shutdown and therefore
    # flushes stdout; os._exit would drop the message above when the output
    # is buffered (e.g. redirected to a log file).
    raise SystemExit(0)
# if os.path.isfile('./results/{}/embeddings/{}_hid{}_lr{}_tr{}_ep300_s{}.pkl'.format(dataset, args.phi,hid_dim,args.lr,tr_size,   seed)):
#     print('%s, file exists.'%('./results/{}/embeddings/{}_hid{}_lr{}_tr{}_ep300_s{}.pkl'.format(dataset, args.phi,hid_dim,args.lr,tr_size,  seed)))
#     os._exit(0)

# Seed NumPy and PyTorch with the constant 0.  The --seed flag is not used
# here; in this file it only feeds the dataset split.
np.random.seed(0)
torch.manual_seed(0)

# Turn the integer GPU index into a torch.device, falling back to the CPU
# when CUDA is not available.
if torch.cuda.is_available():
    device = torch.device("cuda:" + str(device))
    torch.cuda.manual_seed_all(0)
else:
    device = torch.device("cpu")

# Loss shared by train() and test().
criterion = nn.CrossEntropyLoss()
  
def train( model, graph, idx_train, optimizer, epoch):
    """Perform one optimisation step on the training nodes of ``graph``.

    The model is run once on the whole graph; loss and accuracy are computed
    on the rows selected by ``graph.node_index`` and then ``idx_train``.

    Args:
        model: network called as ``model([graph])``; its first return value
            is used as the class scores.
        graph: graph object providing ``node_tags`` and ``node_index``.
        idx_train: indices of the training nodes.
        optimizer: optimizer whose gradients are reset and stepped here.
        epoch: epoch number, used only in the progress line.

    Returns:
        ``(loss, accuracy)`` on the training nodes as Python numbers.
    """
    model.train()
    optimizer.zero_grad()

    targets = graph.node_tags[idx_train]
    scores, _ = model([graph])
    train_scores = scores[graph.node_index][idx_train]

    loss = criterion(train_scores, targets)
    acc_train = accuracy(train_scores, targets)

    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    acc_value = acc_train.item()
    print('Epoch: {:04d}'.format(epoch),
          'loss_train: {:.4f}'.format(loss_value),
          'acc_train: {:.4f}'.format(acc_value))

    return loss_value, acc_value


def test( model, graph, idx_val, idx_test, epoch):
    """Evaluate ``model`` on the validation and test nodes of ``graph``.

    A single forward pass is made with gradients disabled; loss and accuracy
    are then computed separately for the two index sets.

    Args:
        model: network called as ``model([graph])``; its first return value
            is used as the class scores.
        graph: graph object providing ``node_tags`` and ``node_index``.
        idx_val: indices of the validation nodes.
        idx_test: indices of the test nodes.
        epoch: epoch number (only referenced by the disabled embedding dump).

    Returns:
        ``(loss_val, acc_val, loss_test, acc_test)`` as Python numbers.
    """
    model.eval()

    targets_val = graph.node_tags[idx_val]
    targets_te = graph.node_tags[idx_test]

    with torch.no_grad():
        scores, _ = model([graph])
        node_scores = scores[graph.node_index]
        scores_val = node_scores[idx_val]
        scores_te = node_scores[idx_test]

        loss_val = criterion(scores_val, targets_val)
        acc_val = accuracy(scores_val, targets_val)
        loss_te = criterion(scores_te, targets_te)
        acc_te = accuracy(scores_te, targets_te)

    # Disabled: periodic dump of the node outputs for later inspection.
#     if epoch%100==0:
#         save_obj({'node_embedding':output[graph.node_index].cpu().numpy(), 'label':graph.node_tags, 'idx_val':idx_val, 'idx_test':idx_test},
#                  './results/{}/embeddings/{}_hid{}_lr{}_tr{}_ep{}_s{}.pkl'.format(dataset, args.phi,hid_dim,args.lr,tr_size,  epoch, seed))

    metrics = (loss_val.item(), acc_val.item(), loss_te.item(), acc_te.item())

    print('loss_val: {:.4f}'.format(metrics[0]),
          'acc_val: {:.4f}'.format(metrics[1]))
    print("loss_te {:.4f}".format(metrics[2]),
          "acc_te {:.4f}".format(metrics[3]))

    return metrics

# Load the graph.  The index sets returned by load_graph are overwritten
# right below by a stratified random split driven by --train_ratio/--seed.
graph, idx_train, idx_val, idx_test = load_graph(dataset,datapath)
labels = graph.node_tags

# Step 1: train vs. held-out nodes, stratified by label.
train_splitter = StratifiedShuffleSplit(n_splits=1, train_size=tr_size, random_state = seed)
idx_train, idx_test = next(train_splitter.split(np.zeros(len(labels)), labels))

# Step 2: cut the held-out nodes in half (validation / test), again
# stratified.  The splitter yields positions inside idx_test, which are
# mapped back to node indices afterwards.
holdout_splitter = StratifiedShuffleSplit(n_splits=1, train_size=0.5, random_state = seed)
val_pos, test_pos = next(holdout_splitter.split(np.zeros(len(idx_test)), labels[idx_test]))
idx_val, idx_test = idx_test[val_pos], idx_test[test_pos]

# idx_train=range(len(graph.node_tags))
# graph.to(device)

# Quantities that size the model.
num_classes = len(graph.unique_node)
m = graph.max_neighbor
in_dim = graph.node_features.shape[1]
print('max neigbor:{}, feature number:{}, number of classes: {}'.format(m,in_dim,num_classes))
# Hidden widths handed to the model: one tuple of `nlayers_mlp` entries of
# `hid_dim` for each of the `nlayers` GNN layers.
mlp_widths = (hid_dim,) * nlayers_mlp
out_features = (mlp_widths,) * nlayers

# When False, the first layer gets a pass-through instead of a phi module and
# its phi width is simply the raw feature width.
firstphi=False
n_tail = nlayers - 1


def _identity(x):
    """Pass-through used where no phi transform is applied."""
    return x


def _power_width(d):
    """Phi output width used by the "power" and "vdmd" options."""
    return d*m+1


def _powvd_width(d):
    """Phi output width used by the "powvdmd" option."""
    return 2*d*m+1-m


# "power", "vdmd" and "powvdmd" share one layout and differ only in the phi
# class and the width formula, so pick those two first.  The class names are
# only touched in the selected branch.
phi_spec = None
if args.phi=="power":
    phi_spec = (PHI, _power_width)
elif args.phi == "vdmd":
    phi_spec = (vdPHI, _power_width)
elif args.phi == "powvdmd":
    phi_spec = (powvdPHI, _powvd_width)

if phi_spec is not None:
    phi_cls, phi_width = phi_spec
    tail_widths = (phi_width(hid_dim),)*n_tail
    if firstphi:
        phi_features = (phi_width(in_dim),)+tail_widths
        ph = [phi_cls(m) for _ in range(nlayers)]
    else:
        phi_features = (in_dim,)+tail_widths
        ph = [_identity]+[phi_cls(m) for _ in range(n_tail)]
elif args.phi=="identical":
    # NOTE(review): every layer is declared with width in_dim here, although
    # layers after the first presumably see hid_dim-wide inputs -- confirm
    # against ExpGraphNN_ND.
    phi_features = (in_dim, )*nlayers
    ph = [_identity]*nlayers
elif args.phi=="MLP":
    if firstphi:
        phi_features = (hid_dim,)*nlayers
        ph = [MLP(in_dim,(hid_dim,hid_dim), batch_norm=True)]+[MLP(hid_dim,(hid_dim,hid_dim), batch_norm=True) for _ in range(n_tail)]
    else:
        phi_features = (in_dim, )+(hid_dim,)*n_tail
        ph = [_identity]+[MLP(hid_dim,(hid_dim,hid_dim), batch_norm=True) for _ in range(n_tail)]

# Build the network.  "GCN" bypasses the phi configuration above entirely.
if args.phi.upper()=="GCN":
    # Add the identity to the edge matrix and normalise it with the project
    # helper before handing the graph to the GCN model.
    graph.edge_mat = normalize(graph.edge_mat + sp.eye(graph.edge_mat.shape[0])).tocoo()
    model = GCN(in_dim,hid_dim,num_classes,nlayers=nlayers,dropout=dropout).to(device)
else:
    model = ExpGraphNN_ND(in_dim,phi_features,out_features, n_class=num_classes, dropout=dropout, phis=ph,
                      batch_norm=False, agg=agg).to(device)
    
# Report the total number of parameter elements in the model.
n_params = sum(p.numel() for p in model.parameters())
print('model size: %d' % n_params)

# The graph is moved only after the model has been built.
graph.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=weight_decay)

# NOTE(review): the only scheduler.step() calls visible in this file are
# commented out, so this StepLR (halve every 50 steps) currently leaves the
# learning rate untouched -- confirm whether that is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# Per-epoch history; each list gets one entry per epoch and becomes one
# column of the result CSV.
acc_tr=[]
acc_val=[]
acc_te=[]
loss_tr=[]
loss_val=[]
# Book-keeping for the (currently disabled) early-stopping logic below.
bestacc=0
bestloss=np.inf
best_epoc = 0
for epoch in range(1, args.epochs + 1):
#     scheduler.step()

    # One optimisation step, then evaluation on validation and test nodes.
    loss_train, acc_train = train( model, graph, idx_train, optimizer, epoch)
    loss_validation, acc_validation, loss_test, acc_test = test( model, graph, idx_val, idx_test, epoch)
    # scheduler.step(avg_loss)

    acc_tr.append(acc_train)
    acc_val.append(acc_validation)
    acc_te.append(acc_test)
    loss_tr.append(loss_train)
    loss_val.append(loss_validation)

    # Disabled early stopping (refers to `avg_loss`, which is not defined in
    # this file).
#     if acc_train>bestacc or avg_loss<bestloss:
#         bestacc=max(acc_train, bestacc)
#         bestloss=min(avg_loss, bestloss)
#         best_epoc=epoch

#     if epoch-best_epoc>=50:
#         break

res = pd.DataFrame({"acc_tr":acc_tr,"acc_val":acc_val,"acc_te":acc_te,"loss_tr":loss_tr, "loss_val":loss_val})

# The literal file name "no" disables saving.
if filename!='no':
    # A user-supplied --filename may point into a directory that does not
    # exist yet; create it so a finished training run is not lost on the
    # final write.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    res.to_csv(filename)

    


