import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os

from tqdm import tqdm

from util import *
from model import *

# Command-line options for the experiment. Each help string repeats the
# option's default, so the two must be kept in sync by hand.
parser = argparse.ArgumentParser(description='PyTorch graph neural net for whole-graph classification')
parser.add_argument('--dataset', type=str, default="QM9",
                    help='name of dataset (default: QM9)')
parser.add_argument('--datapath', type=str, default="/data/datasets/graphdata",
                    help='dataset path')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
# The help text previously advertised 32 although the default is 64.
parser.add_argument('--batch_size', type=int, default=64,
                    help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=300,
                    help='number of epochs to train (default: 300)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--seed', type=int, default=0,
                    help='random seed for splitting the dataset into 10 (default: 0)')
parser.add_argument('--hidden_dim', type=int, default=64,
                    help='number of hidden units (default: 64)')
parser.add_argument('--layers', type=int, default=5,
                    help='number of GNN layers (default: 5)')
parser.add_argument('--agg', type=str, default="cat", choices=["cat", "sum"],
                    help='aggregate input and its neighbors, can be extended to other method like mean, max etc.')
parser.add_argument('--phi', type=str, default="power",
                    choices=["power", "identical", "MLP", "vdmd", "powvdmd", 'GCN'],
                    help='transformation before aggregation')
parser.add_argument('--dropout', type=float, default=0,
                    help='final layer dropout (default: 0)')
parser.add_argument('--weight_decay', type=float, default=0.0,
                    help='weight decay in the optimizer (default: 0)')
# An empty --filename means "derive the result path automatically".
parser.add_argument('--filename', type=str, default="",
                    help='save result to file')
args = parser.parse_args()

# Unpack the options that the rest of the script reads as module-level names.
device = args.device
dataset = args.dataset
datapath = args.datapath
agg = args.agg
hid_dim = args.hidden_dim
dropout = args.dropout
weight_decay = args.weight_decay
# When True, the first layer also gets a phi transformation instead of the
# identity (see the phi configuration further down).
firstphi = True
nlayers = args.layers

# Result file: an explicit --filename wins; otherwise a name is derived from
# the run's settings under ./results/<dataset>/.
if args.filename != "":
    filename = args.filename
else:
    result_dir = "./results/{}".format(dataset)
    # Creates ./results as well when it is missing, and tolerates an existing
    # directory without a separate isdir() check.
    os.makedirs(result_dir, exist_ok=True)
    filename = "{}/{}_hid{}_lr{}.csv".format(result_dir, args.phi, hid_dim, args.lr)

# Never overwrite an existing result file: report it and stop successfully.
if os.path.isfile(filename):
    print('%s, file exists.'%(filename))
    # SystemExit (unlike os._exit) flushes stdout, so the message above is not
    # lost when output is piped or redirected.
    raise SystemExit(0)

# Seed every RNG from --seed. The default of 0 reproduces the previously
# hard-coded seeds, which ignored the option entirely.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda:" + str(device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

# Summed (not averaged) absolute error over all elements of a batch.
criterion = nn.L1Loss(reduction='sum')
def train(args, model, device, train_graphs, optimizer, epoch):
    """Run one shuffled pass over ``train_graphs`` and return the reported loss.

    Mini-batches of ``args.batch_size`` graphs are taken from a fresh NumPy
    permutation of the training set. Each batch is scored with the
    module-level ``criterion`` against the vertically stacked ``graph.label``
    arrays, followed by a backward pass and an optimizer step unless
    ``optimizer`` is ``None``.

    Args:
        args: namespace providing ``batch_size``.
        model: callable returning a pair whose second element is the prediction.
        device: device the stacked labels are moved to.
        train_graphs: sequence of graphs exposing a ``label`` attribute.
        optimizer: optimizer to step, or ``None`` to only evaluate the loss.
        epoch: epoch number, used only in the printed message.

    Returns:
        The accumulated batch losses multiplied by
        ``args.batch_size / len(train_graphs)``; the same value is printed.
    """
    model.train()

    order = np.random.permutation(len(train_graphs))
    running_loss = 0

    for start in range(0, len(order), args.batch_size):
        batch = [train_graphs[j] for j in order[start:start + args.batch_size]]

        # Only the second element of the model's output is scored.
        _, prediction = model(batch)
        target = torch.Tensor(np.vstack([g.label for g in batch])).to(device)
        batch_loss = criterion(prediction, target)

        if optimizer is not None:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        # Detach before accumulating so no autograd graph is kept alive.
        running_loss += batch_loss.detach().cpu().numpy()

    average_loss = running_loss * args.batch_size / len(order)
    print("epoch:%d, loss training: %f" % (epoch, average_loss))

    return average_loss


def test(args, model, device, train_graphs, val_graphs, test_graphs, epoch):
    model.eval()

    with torch.no_grad():
        tr_batches = [train_graphs[i:i + args.batch_size] for i in range(0, len(train_graphs), args.batch_size)]  
        val_batches = [val_graphs[i:i + args.batch_size] for i in range(0, len(val_graphs), args.batch_size)] 
        te_batches = [test_graphs[i:i + args.batch_size] for i in range(0, len(test_graphs), args.batch_size)] 
        
        pred_tr = torch.cat( [model(graphs)[1] for graphs in tr_batches],dim=0 ).cpu().numpy()
        pred_val = torch.cat([model(graphs)[1] for graphs in val_batches], dim=0 ).cpu().numpy()
        pred_te = torch.cat([model(graphs)[1] for graphs in te_batches],dim=0).cpu().numpy()
         
        label_tr = np.vstack([graph.label for graph in train_graphs])
        label_val = np.vstack([graph.label for graph in val_graphs])
        label_te = np.vstack([graph.label for graph in test_graphs])
        
        mae_tr = np.absolute( (pred_tr-label_tr)).mean(axis=0)
        mae_val = np.absolute((pred_val-label_val)).mean(axis=0)
        mae_te = np.absolute( (pred_te-label_te)).mean(axis=0)

    print("MAE train: %f, val: %f,  test: %f" % (mae_tr.sum(),mae_val.sum(), mae_te.sum()))
    return mae_tr, mae_val, mae_te



# Load the already-split QM9 graphs. The directory comes from --datapath; the
# call previously hard-coded the default path, so the option had no effect.
train_graphs, val_graphs, test_graphs, train_labels_mean, train_labels_std = load_qm9(datapath=datapath)

# Number of regression targets, read from the second axis of a graph's label.
num_var = train_graphs[0].label.shape[1]

# Move every graph to the selected device once, up front.
train_graphs = [g.to(device) for g in train_graphs]
val_graphs = [g.to(device) for g in val_graphs]
test_graphs = [g.to(device) for g in test_graphs]

# Largest `max_neighbor` over the training graphs; sizes the phi transformations.
m = max(graph.max_neighbor for graph in train_graphs)
in_dim = train_graphs[0].node_features.shape[1]
print('max neigbor:{}, feature number:{}, number of variables to predict: {}'.format(m,in_dim,num_var))

# One (hid_dim, hid_dim) pair per layer; handed to the model as `out_features`.
out_features = tuple((hid_dim, hid_dim) for _ in range(nlayers))

# Number of layers after the first one.
_n_rest = nlayers - 1

# Transformations that widen the features: for each name, the width formula
# used for `phi_features` given an input width d, and a factory for one
# transformation. The lambdas defer the class lookup, so only the selected
# transformation's class is ever resolved.
_expanding = {
    "power": (lambda d: d * m + 1, lambda: PHI(m)),
    "vdmd": (lambda d: d * m + 1, lambda: vdPHI(m)),
    "powvdmd": (lambda d: 2 * d * m + 1 - m, lambda: powvdPHI(m)),
}

# With `firstphi` unset, the first layer uses the identity and therefore keeps
# the raw input width. "GCN" matches no branch, so `phi_features` and `ph`
# stay undefined for it.
if args.phi in _expanding:
    _width, _make_phi = _expanding[args.phi]
    if firstphi:
        phi_features = (_width(in_dim),) + (_width(hid_dim),) * _n_rest
        ph = [_make_phi() for _ in range(nlayers)]
    else:
        phi_features = (in_dim,) + (_width(hid_dim),) * _n_rest
        ph = [lambda x: x] + [_make_phi() for _ in range(_n_rest)]
elif args.phi == "identical":
    phi_features = (in_dim,) * nlayers
    ph = [lambda x: x] * nlayers
elif args.phi == "MLP":
    if firstphi:
        phi_features = (hid_dim,) * nlayers
        ph = [MLP(in_dim, (hid_dim, hid_dim), batch_norm=True)]
    else:
        phi_features = (in_dim,) + (hid_dim,) * _n_rest
        ph = [lambda x: x]
    # The remaining layers are built after the first entry, in layer order.
    ph += [MLP(hid_dim, (hid_dim, hid_dim), batch_norm=True) for _ in range(_n_rest)]

# Build the network selected by --phi: the GCN baseline, or AttDGraphNN
# configured with the phi transformations prepared above.
if args.phi.upper() == "GCN":
    model = GCN(in_dim, hid_dim, num_var, nlayers=nlayers, dropout=dropout)
else:
    model = AttDGraphNN(in_dim, phi_features, out_features, n_class=num_var,
                        dropout=dropout, phis=ph, batch_norm=True, agg=agg)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=weight_decay)
# Halve the learning rate every 50 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)


# Result table, one row per epoch: the summed train/val/test errors as
# returned by test(), followed by the per-target errors of each split
# multiplied by `train_labels_std`.
result_columns = ["avg_err_tr", "avg_err_val", "avg_err_te"] + \
    ['%s%s' % (s, i) for s in ('tr', 'val', 'te') for i in range(num_var)]
# Size the index from --epochs (previously fixed at 300), so shorter runs do
# not leave empty trailing rows in the CSV.
res = pd.DataFrame(columns=result_columns, index=range(1, args.epochs + 1))

for epoch in range(1, args.epochs + 1):
    avg_loss = train(args, model, device, train_graphs, optimizer, epoch)
    err_tr, err_val, err_te = test(args, model, device, train_graphs, val_graphs, test_graphs, epoch)

    res.loc[epoch] = [err_tr.sum(), err_val.sum(), err_te.sum()] + (err_tr*train_labels_std).tolist() + \
        (err_val*train_labels_std).tolist() + (err_te*train_labels_std).tolist()

    # Rewrite the file after every epoch so partial results are on disk.
    res.to_csv(filename)

    # Advance the LR schedule only after this epoch's optimizer steps. Stepping
    # first (as before) triggers a PyTorch warning and applies each decay one
    # epoch early.
    scheduler.step()

    


