from data_handling import get_zinc_data
import numpy as np
import torch
import math
import torch.optim as optim
from models import *
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import DataLoader
from torch_geometric.nn import global_mean_pool, global_add_pool
from prepare_data import get_eff_res_data
from pathlib import Path
from scipy.stats import loguniform

import argparse

# Command-line configuration for one training run.
parser = argparse.ArgumentParser(description='training parameters')

# NOTE(review): --nhid is parsed but the model construction further down
# hard-codes a hidden width per architecture, so this value is not used there.
parser.add_argument('--nhid', type=int, default=185,
                    help='number of hidden node features')
parser.add_argument('--model_name', type=str, default='GatedGCN',
                    help='which model: GCN, GIN, GraphSAGE, GatedGCN')
parser.add_argument('--nlayers', type=int, default=15,
                    help='number of layers (overwritten by a random draw after parsing)')
parser.add_argument('--epochs', type=int, default=3000,
                    help='max epochs')
parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                    help='computing device')
parser.add_argument('--batch', type=int, default=128,
                    help='batch size')
parser.add_argument('--mix_type', type=int, default=1,
                    help='type of node mixing')
parser.add_argument('--alpha', type=float, default=0.8,
                    help='alpha value passed to get_eff_res_data')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (overwritten by a random draw after parsing)')
parser.add_argument('--reduce_point', type=int, default=20,
                    help='length of patience')
parser.add_argument('--start_reduce', type=int, default=250,
                    help='epoch when to start reducing')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
args = parser.parse_args()

# Random hyper-parameter draw: whatever was passed for --lr / --nlayers is
# replaced here. lr is log-uniform on [1e-4, 1e-2]; nlayers is a uniform
# integer in 1..32 (randint's upper bound is exclusive).
# These draws run before the RNGs are seeded with --seed, so they are not
# controlled by the seed.
args.lr = float(loguniform.rvs(0.0001,0.01))
args.nlayers = int(np.random.randint(1,33))

print(args)

# Request deterministic cuDNN kernels and seed the torch / numpy RNGs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Instantiate the requested architecture. Each one uses its own hard-coded
# hidden width (args.nhid is not consulted); only the depth comes from args.
if args.model_name == 'GCN':
    model = GCN(nfeat=1,nhid=225,nclass=1,nlayers=args.nlayers).to(args.device)
elif args.model_name == 'GIN':
    model = GIN(nfeat=1,nhid=225,nclass=1,nlayers=args.nlayers).to(args.device)
elif args.model_name == 'GraphSAGE':
    model = GraphSAGE(nfeat=1,nhid=185,nclass=1,nlayers=args.nlayers).to(args.device)
elif args.model_name == 'GatedGCN':
    model = GatedGCN(nfeat=1, nhid=145, nclass=1, nlayers=args.nlayers).to(args.device)
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(
        'unknown model_name ' + repr(args.model_name)
        + '; expected one of GCN, GIN, GraphSAGE, GatedGCN'
    )

# Total number of scalar entries across all model parameters.
nparams = sum(p.numel() for p in model.parameters())
print('number of parameters: ',nparams)

# Fetch the three dataset splits for the configured alpha / mixing type.
alpha = args.alpha
train_dataset, test_dataset, val_dataset = get_eff_res_data(
    alpha, mix_type=args.mix_type
)

# Only the training split is shuffled and uses the CLI batch size; the two
# evaluation splits are read in order with batches of 1000 graphs.
train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True)
test_loader, val_loader = (
    DataLoader(split, batch_size=1000, shuffle=False)
    for split in (test_dataset, val_dataset)
)

# L1 (absolute error) objective, optimised with Adam at the selected rate.
lf = torch.nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Bookkeeping for the training loop: best validation score seen so far and
# the count of non-improving epochs used for learning-rate halving.
best_eval = 1000000
patience = 0

def test(loader):
    """Return the normalised L1 error of the global ``model`` over ``loader``.

    The result is the batch-size-weighted mean of the per-batch ``lf`` loss
    divided by the mean absolute target, both taken over every batch the
    loader yields. When the loader yields exactly one batch this equals
    ``lf(output, y) / mean(|y|)`` for that batch.

    Previously the per-batch value overwrote the running result, so only the
    last batch contributed.

    Args:
        loader: iterable of batches exposing ``.to(device)`` and a ``.y``
            target tensor.

    Returns:
        The normalised error as a Python float. If all targets are zero the
        result is ``inf`` (or ``nan`` when the error is zero as well).

    Raises:
        ValueError: if ``loader`` yields no target values at all.
    """
    model.eval()
    abs_error_sum = 0.0
    abs_target_sum = 0.0
    n_targets = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(args.device)
            output = model(data)
            batch_targets = data.y.numel()
            # lf is mean-reduced, so weight it by the batch's target count
            # to make batches of different sizes contribute proportionally.
            abs_error_sum += lf(output, data.y).item() * batch_targets
            abs_target_sum += torch.sum(torch.abs(data.y)).item()
            n_targets += batch_targets

    if n_targets == 0:
        raise ValueError('test() received a loader with no target values')
    if abs_target_sum == 0.0:
        # Keep the x/0 -> inf and 0/0 -> nan outcome of tensor division
        # instead of raising ZeroDivisionError on Python floats.
        return float('inf') if abs_error_sum > 0.0 else float('nan')
    return abs_error_sum / abs_target_sum

# Test score recorded at the epoch with the best validation score so far.
# Starts as None so a run that never improves is detected explicitly.
best_test_loss = None

for epoch in range(args.epochs):
    model.train()
    for data in train_loader:
        data = data.to(args.device)
        optimizer.zero_grad()
        out = model(data)
        loss = lf(out,data.y)
        loss.backward()
        optimizer.step()

    val_loss = test(val_loader)

    if (val_loss < best_eval):
        # New best validation score: record the matching test score.
        best_eval = val_loss
        best_test_loss = test(test_loader)

    elif (val_loss >= best_eval and (epoch + 1) >= args.start_reduce):
        # Non-improving epoch once the reduction phase has started.
        # NOTE(review): patience is only reset when the learning rate is
        # halved, not when validation improves, so it counts non-improving
        # epochs cumulatively rather than consecutively - confirm intended.
        patience += 1

    if (epoch + 1) >= args.start_reduce and patience == args.reduce_point:
        # Halve the learning rate in every parameter group and start counting again.
        patience = 0
        args.lr /= 2.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr

    # Stop once the learning rate has decayed below the floor.
    if (args.lr < 1e-5):
        break

if best_test_loss is None:
    # Happens with zero epochs or when the validation score never dropped
    # below the initial best_eval (e.g. NaN losses); nothing valid to report.
    raise RuntimeError(
        'validation score never improved; no test loss was recorded'
    )

print('final test loss: ',best_test_loss)

# Append the score to results/mixing/mix_type<k>/<model_name>/<nlayers>.txt.
out_dir = Path('results/mixing') / ('mix_type' + str(args.mix_type)) / args.model_name
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / (str(args.nlayers) + '.txt'), 'a') as f:
    f.write(str(best_test_loss) + '\n')
