import numpy as np
import torch
import math
import torch.optim as optim
from models import *
from torch_geometric.data import DataLoader
from prepare_data import get_eff_res_data
from pathlib import Path

import argparse

# Hidden width used for each architecture when --nhid is not given. These are
# the widths this script has always used per model (presumably chosen so the
# architectures have comparable parameter counts -- TODO confirm).
DEFAULT_NHID = {
    'GCN': 225,
    'GIN': 225,
    'GraphSAGE': 185,
    'GatedGCN': 145,
}

parser = argparse.ArgumentParser(description='training parameters')

parser.add_argument('--nhid', type=int, default=None,
                    help='number of hidden node features '
                         '(default: per-model value from DEFAULT_NHID)')
parser.add_argument('--model_name', type=str, default='GCN',
                    choices=list(DEFAULT_NHID),
                    help='which model: GCN, GIN, GraphSAGE, GatedGCN')
parser.add_argument('--nlayers', type=int, default=11,
                    help='number of layers')
parser.add_argument('--epochs', type=int, default=3000,
                    help='max epochs')
parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                    help='computing device')
parser.add_argument('--batch', type=int, default=128,
                    help='batch size')
parser.add_argument('--mix_type', type=int, default=0,
                    help='type of node mixing')
parser.add_argument('--alpha', type=float, default=0.5,
                    help='alpha passed to get_eff_res_data; also names the results file')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate')
parser.add_argument('--reduce_point', type=int, default=20,
                    help='non-improving validation epochs (counted from --start_reduce, '
                         'reset at each halving) before the learning rate is halved')
parser.add_argument('--start_reduce', type=int, default=250,
                    help='epoch when to start reducing')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')

args = parser.parse_args()

# Resolve the architecture-specific width here (instead of hard-coding it at
# construction) so that --nhid is honored and print(args) shows the real value.
if args.nhid is None:
    args.nhid = DEFAULT_NHID[args.model_name]

# Seed the RNGs and force deterministic cuDNN before the model is constructed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Model classes come from `from models import *`; only the selected one is
# looked up, so a missing unrelated class does not break other runs.
if args.model_name == 'GCN':
    model_cls = GCN
elif args.model_name == 'GIN':
    model_cls = GIN
elif args.model_name == 'GraphSAGE':
    model_cls = GraphSAGE
elif args.model_name == 'GatedGCN':
    model_cls = GatedGCN
else:  # unreachable from the CLI because of `choices`; kept as an explicit guard
    raise ValueError(f'unknown model_name: {args.model_name!r}')

model = model_cls(nfeat=1, nhid=args.nhid, nclass=1,
                  nlayers=args.nlayers).to(args.device)

print(args)
# Total number of scalar weights across all parameter tensors of the model.
nparams = sum(param.numel() for param in model.parameters())
print('number of parameters: ', nparams)

# Datasets from get_eff_res_data; unpacking order (train, test, val) follows
# that function's return value -- confirm against prepare_data.
alpha = args.alpha
train_dataset, test_dataset, val_dataset = get_eff_res_data(alpha, mix_type=args.mix_type)
train_loader = DataLoader(train_dataset, batch_size=args.batch, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=args.batch, shuffle=False)

optimizer = optim.Adam(model.parameters(), lr=args.lr)
lf = torch.nn.L1Loss()  # training objective: mean absolute error

# State for best-validation model selection and the LR-halving schedule.
patience = 0
# Start at +inf so the first finite validation error always counts as an
# improvement (a finite sentinel left best_test_loss unset if errors stayed above it).
best_eval = math.inf
# Test error at the best validation epoch. Stays NaN (and is written as 'nan')
# if validation never improves, e.g. all-NaN losses or --epochs 0, instead of
# crashing with a NameError when the result is written.
best_test_loss = math.nan

def test(loader):
    """Evaluate the module-level ``model`` on every batch yielded by ``loader``.

    Switches the model to eval mode (the training loop switches it back with
    ``model.train()``) and runs without gradient tracking. Returns the sum of
    absolute differences between predictions and ``batch.y`` over all batches,
    divided by ``len(loader.dataset)``.
    """
    model.eval()
    with torch.no_grad():
        batches = (raw.to(args.device) for raw in loader)
        total_abs_error = sum(
            (model(batch) - batch.y).abs().sum().item() for batch in batches
        )
    return total_abs_error / len(loader.dataset)

for epoch in range(args.epochs):
    # One pass over the shuffled training set.
    model.train()
    for batch in train_loader:
        batch = batch.to(args.device)
        optimizer.zero_grad()
        loss = lf(model(batch), batch.y)
        loss.backward()
        optimizer.step()

    val_loss = test(val_loader)
    schedule_active = epoch + 1 >= args.start_reduce

    if val_loss < best_eval:
        # New best validation error: record the test error at this epoch.
        best_eval = val_loss
        best_test_loss = test(test_loader)
    elif schedule_active and val_loss >= best_eval:
        # NOTE(review): patience is not reset when validation improves, so it
        # counts non-improving epochs cumulatively since the last halving rather
        # than consecutively -- confirm this is intended.
        patience += 1

    if schedule_active and patience == args.reduce_point:
        # Halve the learning rate and push it into every optimizer param group.
        patience = 0
        args.lr = args.lr / 2.
        for group in optimizer.param_groups:
            group['lr'] = args.lr

    # Stop training once the learning rate has been halved below the floor.
    if args.lr < 1e-5:
        break

# Append this run's test error (taken at the best validation epoch) as one line
# of results/eff_res/<model_name>/<alpha>.txt; append mode lets repeated runs
# accumulate in the same file. `with` guarantees the file is closed even if
# the write fails.
results_dir = Path('results', 'eff_res', args.model_name)
results_dir.mkdir(parents=True, exist_ok=True)
with open(results_dir / (str(alpha) + '.txt'), 'a', encoding='utf-8') as f:
    f.write(str(best_test_loss) + '\n')