import os
import time
import random
import torch
import argparse
import numpy as np
import util
from model import SimST
from engine import trainer

# Limit the number of CPU threads PyTorch uses for this process.
torch.set_num_threads(3)

# Command-line options, one row per flag: (flag, type, default, help).
# Rows are registered with argparse in exactly the order listed here.
_CLI_OPTIONS = (
    # Runtime target and dataset selection.
    ('--device', str, 'cuda:2', 'device name'),
    ('--dataset', str, '04', 'dataset name'),

    # Input construction and model sizes.
    ('--in_neighbor', int, 1, 'add neighbors in input or not'),
    ('--in_neighbor_num', int, 0, 'highlighted neighbors number'),
    ('--in_dim', int, 5, 'input dimension'),
    ('--seq_length', int, 12, 'output sequence length'),
    ('--node_dim', int, 20, 'node embedding dimension'),
    ('--init_dim', int, 64, 'initial hidden dimension'),
    ('--skip_dim', int, 0, 'skip hidden dimension'),
    ('--end_dim', int, 512, 'end hidden dimension'),
    ('--layer', int, 2, 'layer number'),
    ('--bs', int, 1024, 'training batch size'),
    ('--infer_bs', int, 64, 'inference batch size'),

    # Optimisation schedule and regularisation.
    ('--epochs', int, 150, 'epochs'),
    ('--patience', int, 10, 'patience'),
    ('--stop', int, 20, 'early stop'),
    ('--lrate', float, 1e-3, 'learning rate'),
    ('--wdecay', float, 1e-4, 'weight decay rate'),
    ('--dropout', float, 0.1, 'dropout rate'),
    ('--clip', int, 5, 'gradient clip'),
    ('--seed', int, 0, 'random seed'),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# Parse the process arguments once and echo the resolved configuration.
args = parser.parse_args()
print(args)

# Resolve the dataset-specific resources. Only dataset '04' is configured in
# this script; any other value is rejected here with a clear message instead
# of surfacing later as a NameError on adj_data / input_data / node_num.
if args.dataset == '04':
    adj_data = 'adj_mx_04.pkl'
    input_data = 'PEMS-04'
    node_num = 307
else:
    raise ValueError(
        "Unsupported dataset '{}': only '04' is configured".format(args.dataset)
    )

# Directory that receives the per-epoch checkpoints. exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
save = 'save'
os.makedirs(save, exist_ok=True)
# Keep the trailing slash: checkpoint paths are built by string concatenation.
save += '/'
    
def set_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and configures cuDNN for reproducible runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Exported for child processes; hashing in the already-running
    # interpreter is not affected by changing this at runtime.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Seeding alone is not enough on GPU: cuDNN must be restricted to
    # deterministic algorithms and must not auto-tune (benchmark) kernels,
    # otherwise two runs with the same seed can still diverge.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(args.seed)

# neighbor_record stays None when neighbor inputs are disabled. Otherwise it
# maps every node id to a list of neighbor-id lists laid out as:
#   [top_k-1 out], [top_k-1 in], ..., [0 out], [0 in], all-out, all-in
# where "out" is read from the node's row of the loaded matrix and "in" from
# the same row of its transpose.
neighbor_record = None
if args.in_neighbor:
    # Only the third item of the pickle is used; it is indexed below as a 2-D
    # node-by-node weight matrix (presumably the adjacency matrix - confirm
    # against util.load_pickle).
    _, _, fwd_w = util.load_pickle(adj_data)
    bwd_w = np.transpose(fwd_w)
    top_k = args.in_neighbor_num

    neighbor_record = {}
    for node in range(node_num):
        # Columns with a non-zero weight in this node's row, self excluded.
        out_nb = np.nonzero(fwd_w[node])[0]
        in_nb = np.nonzero(bwd_w[node])[0]
        out_nb = out_nb[out_nb != node]
        in_nb = in_nb[in_nb != node]

        # Neighbors holding the last top_k positions of an ascending sort by
        # weight. With top_k == 0 the slice keeps every neighbor, but the
        # result is never read in that case.
        out_top = list(out_nb[np.argsort(fwd_w[node, out_nb])[-top_k:]])
        in_top = list(in_nb[np.argsort(bwd_w[node, in_nb])[-top_k:]])

        # Pad with the node's own id so both lists hold at least top_k items.
        out_top += [node] * (top_k - len(out_top))
        in_top += [node] * (top_k - len(in_top))

        # One singleton list per highlighted neighbor, read from position
        # top_k - 1 down to 0, alternating the two directions.
        entry = []
        for pos in reversed(range(top_k)):
            entry.append([out_top[pos]])
            entry.append([in_top[pos]])

        # Full neighbor lists close the entry; a node without neighbors
        # falls back to itself.
        entry.append(list(out_nb) if len(out_nb) else [node])
        entry.append(list(in_nb) if len(in_nb) else [node])
        neighbor_record[node] = entry

# Target device and data pipeline. The inference batch size is passed twice
# (presumably for the validation and test loaders - confirm in
# util.load_dataset).
device = torch.device(args.device)
dataloader = util.load_dataset(
    input_data,
    neighbor_record,
    args.bs,
    args.infer_bs,
    args.infer_bs,
)
scaler = dataloader['scaler']

# Build the network and report its size.
model = SimST(
    device,
    node_num,
    args.node_dim,
    args.in_dim,
    args.seq_length,
    args.init_dim,
    args.skip_dim,
    args.end_dim,
    args.layer,
    args.dropout,
)
para_num = sum(p.numel() for p in model.parameters())
print('Total parameters:', para_num)

# The trainer exposes the optimizer that the scheduler below drives.
engine = trainer(device, model, scaler, args.lrate, args.wdecay, args.clip)
# Multiply the learning rate by sqrt(0.1) after `patience` scheduler steps
# without improvement of the stepped metric.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# against the installed version before upgrading.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    engine.optimizer,
    factor=np.sqrt(0.1),
    patience=args.patience,
    verbose=True,
)


# Train for up to args.epochs epochs. After every epoch the model is scored on
# the validation loader; each new best validation loss writes a checkpoint, and
# training stops once args.stop consecutive epochs bring no improvement.
print('Start training ...')
his_loss = []    # mean validation loss per epoch (index = epoch - 1)
train_time = []  # seconds spent in the training pass of each epoch
val_time = []    # seconds spent in the validation pass of each epoch
min_loss = float('inf')
wait = 0         # consecutive epochs without a new best validation loss
for i in range(1, args.epochs + 1):
    # ---- training pass ----
    train_loss = []
    train_mape = []
    train_rmse = []
    t1 = time.time()
    dataloader['train_loader'].shuffle()
    # The batch index is never used, so iterate the loader directly rather
    # than binding it to a variable named after the builtin `iter`.
    for x, y, idx in dataloader['train_loader'].get_iterator():
        trainx = torch.Tensor(x).to(device)
        trainy = torch.Tensor(y).to(device)
        trainidx = torch.tensor(idx, dtype=torch.long).to(device)
        metrics = engine.train(trainx, trainy, trainidx)

        # metrics is read as (loss, MAPE, RMSE).
        train_loss.append(metrics[0])
        train_mape.append(metrics[1])
        train_rmse.append(metrics[2])
    t2 = time.time()
    train_time.append(t2 - t1)

    # ---- validation pass ----
    valid_loss = []
    valid_mape = []
    valid_rmse = []
    s1 = time.time()
    for x, y, idx in dataloader['val_loader'].get_iterator():
        validx = torch.Tensor(x).to(device)
        validy = torch.Tensor(y).to(device)
        valididx = torch.tensor(idx, dtype=torch.long).to(device)
        metrics = engine.eval(validx, validy, valididx)

        valid_loss.append(metrics[0])
        valid_mape.append(metrics[1])
        valid_rmse.append(metrics[2])
    s2 = time.time()
    val_time.append(s2 - s1)

    # ---- epoch summary ----
    mtrain_loss = np.mean(train_loss)
    mtrain_mape = np.mean(train_mape)
    mtrain_rmse = np.mean(train_rmse)

    mvalid_loss = np.mean(valid_loss)
    mvalid_mape = np.mean(valid_mape)
    mvalid_rmse = np.mean(valid_rmse)
    his_loss.append(mvalid_loss)
    # The plateau scheduler is driven by the validation loss.
    scheduler.step(mvalid_loss)

    log = 'Epoch: {:03d}, Train Loss: {:.4f}, Train RMSE: {:.4f}, Train MAPE: {:.4f}, Valid Loss: {:.4f}, Valid RMSE: {:.4f}, Valid MAPE: {:.4f}, Train Time: {:.4f}/epoch, Valid Time: {:.4f}/epoch'
    print(log.format(i, mtrain_loss, mtrain_rmse, mtrain_mape, mvalid_loss, mvalid_rmse, mvalid_mape, (t2 - t1), (s2 - s1)))

    # ---- checkpointing / early stopping ----
    if min_loss > mvalid_loss:
        # The file name encodes the epoch and the loss rounded to 2 decimals;
        # the reload step after training rebuilds this exact name.
        torch.save(engine.model.state_dict(), save + 'epoch_' + str(i) + '_' + str(round(mvalid_loss, 2)) + '.pth')
        min_loss = mvalid_loss
        wait = 0
    else:
        wait += 1
        if wait == args.stop:
            print('Early stop')
            break


# Reload the checkpoint of the best validation epoch and re-score it.
# nanargmin is used instead of argmin: argmin returns the index of the first
# NaN, and an epoch with a NaN validation loss never wrote a checkpoint (the
# `min_loss > loss` test is False for NaN), so the load would fail even
# though valid checkpoints exist.
bestid = np.nanargmin(his_loss)
# The path mirrors the name used when the checkpoint was saved:
# epoch number (1-based) plus the loss rounded to 2 decimals.
engine.model.load_state_dict(torch.load(save + 'epoch_' + str(bestid + 1) + '_' + str(round(his_loss[bestid], 2)) + '.pth'))
engine.model.eval()

log = 'Best Valid MAE: {:.4f}'
print(log.format(round(his_loss[bestid], 4)))

# Re-run validation with the restored weights as a sanity check.
valid_loss = []
valid_mape = []
valid_rmse = []
# The batch index is unused, so iterate directly instead of rebinding the
# builtin name `iter`.
for x, y, idx in dataloader['val_loader'].get_iterator():
    validx = torch.Tensor(x).to(device)
    validy = torch.Tensor(y).to(device)
    valididx = torch.tensor(idx, dtype=torch.long).to(device)
    metrics = engine.eval(validx, validy, valididx)

    # metrics is read as (loss, MAPE, RMSE).
    valid_loss.append(metrics[0])
    valid_mape.append(metrics[1])
    valid_rmse.append(metrics[2])
mvalid_loss = np.mean(valid_loss)
mvalid_mape = np.mean(valid_mape)
mvalid_rmse = np.mean(valid_rmse)
log = 'Recheck Valid MAE: {:.4f}, Valid RMSE: {:.4f}, Valid MAPE: {:.4f}'
# The values are already scalar means; no second np.mean is needed.
print(log.format(mvalid_loss, mvalid_rmse, mvalid_mape))


# Run the restored model over the test loader and collect predictions and
# targets batch by batch.
preds = []
reals = []
# No gradients are needed for inference. The batch index is unused, so the
# loader is iterated directly instead of rebinding the builtin name `iter`.
with torch.no_grad():
    for x, y, idx in dataloader['test_loader'].get_iterator():
        testx = torch.Tensor(x).to(device)
        testy = torch.Tensor(y).to(device)
        testidx = torch.tensor(idx, dtype=torch.long).to(device)
        # Swap axes 1 and 3 of the model output, then drop axis 1 when it has
        # size one, matching the squeeze applied to the targets.
        output = engine.model(testx, testidx).transpose(1, 3)
        preds.append(torch.squeeze(output, dim=1))
        reals.append(torch.squeeze(testy, dim=1))

# Only the leading dimension of the test labels is needed; len() reads it
# without copying the whole label set into a new tensor.
sample_num = len(dataloader['y_test'])
# NOTE(review): the row count kept below depends on the inference batch size
# (> 64 keeps sample_num * node_num rows, otherwise sample_num) - presumably
# tied to how the loader pads or flattens batches; confirm against
# util.load_dataset.
if args.infer_bs > 64:
    total = sample_num * node_num
else:
    total = sample_num
print(total)
# Concatenate along the batch axis and keep only the first `total` rows.
preds = torch.cat(preds, dim=0)[:total]
reals = torch.cat(reals, dim=0)[:total]

# Score every forecast horizon on the test set, then report the averages.
test_loss, test_mape, test_rmse = [], [], []
res = []  # compact summary: selected horizons first, averages last
for k in range(args.seq_length):
    # Predictions are mapped back through the scaler; targets are used as-is.
    pred = scaler.inverse_transform(preds[:, :, k])
    real = reals[:, :, k]
    metrics = util.metric(pred, real)
    # metrics is read as (MAE, MAPE, RMSE).
    mae, mape, rmse = metrics[0], metrics[1], metrics[2]

    log = 'Horizon {:d}, Test MAE: {:.4f}, Test RMSE: {:.4f}, Test MAPE: {:.4f}'
    print(log.format(k + 1, mae, rmse, mape))

    test_loss.append(mae)
    test_mape.append(mape)
    test_rmse.append(rmse)
    # Horizons 3, 6 and 12 (0-based indices 2, 5, 11) go into the summary.
    if k in (2, 5, 11):
        res.extend([mae, rmse, mape])

mtest_loss = np.mean(test_loss)
mtest_mape = np.mean(test_mape)
mtest_rmse = np.mean(test_rmse)

log = 'Average Test MAE: {:.4f}, Test RMSE: {:.4f}, Test MAPE: {:.4f}'
print(log.format(mtest_loss, mtest_rmse, mtest_mape))
res.extend([mtest_loss, mtest_rmse, mtest_mape])
res = [round(value, 4) for value in res]
print(res)

# Mean wall-clock time per epoch over all epochs that were run.
print("Average Training Time: {:.4f} secs/epoch".format(np.mean(train_time)))
print("Average Validation Time: {:.4f} secs/epoch".format(np.mean(val_time)))
