# encoding:utf-8
import argparse
import numpy as np
import os
import sys
import yaml
import torch
import time
from einops import rearrange
sys.path.append(os.getcwd())
from lib.utils import *
from engine import trainer
from datahandler import load_st_data
def _to_model_input(array, device):
    """Return ``array`` as a float tensor on ``device`` with axes 1 and 3 swapped.

    Every batch yielded by the data loaders (and ``y_test``) goes through the
    same ``torch.Tensor(...).to(device).transpose(1, 3)`` conversion; keeping
    it in one place guarantees all call sites agree. Assumes a 4-D array —
    TODO confirm the axis layout against ``load_dataset``.
    """
    return torch.Tensor(array).to(device).transpose(1, 3)


def main(args, supervisor_config):
    """Train, validate and test one model run, and return its test metrics.

    Trains for ``supervisor_config['train']['epochs']`` epochs and saves a
    checkpoint after every epoch. It then reloads the checkpoint with the
    lowest mean validation loss and evaluates it on the test split, horizon by
    horizon. Finally it saves that model's ``state_dict`` as a ``best_...``
    file in the checkpoint directory.

    Args:
        args: namespace providing ``data_name``, ``seq_len`` and ``horizon``.
        supervisor_config: parsed YAML config with ``device``, ``data``,
            ``special``, ``model`` and ``train`` sections.

    Returns:
        ``[horizon, MSE, MAE, MAPE, RMSE, acc]``, where each metric is the
        mean over the ``args.horizon`` test horizons.

    Raises:
        ValueError: if the config requests fewer than one epoch, so there is
            no checkpoint to evaluate.
    """
    device = torch.device(supervisor_config['device'])
    print("Config: ", supervisor_config)
    data_cfg = supervisor_config['data']
    train_cfg = supervisor_config['train']
    model_cfg = supervisor_config['model']
    checkpoint_dir = model_cfg['checkpoint_path']
    # torch.save (used for the final state_dict below, and presumably by
    # engine.save) fails if the target directory does not exist yet.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Load the graph supports and the train/val/test splits.
    dataset_dir = os.path.join(data_cfg['dataset_dir'],
                               args.data_name, str(args.seq_len))
    _, _, adj_mx = load_adj(data_cfg['graph_pkl_filename'],
                            supervisor_config['special']['adjtype'])
    dataloader = load_dataset(dataset_dir=dataset_dir,
                              batch_size=data_cfg['batch_size'],
                              valid_batch_size=data_cfg['val_batch_size'],
                              test_batch_size=data_cfg['test_batch_size'])
    scaler = dataloader['scaler']
    supports = [torch.tensor(adj).to(device) for adj in adj_mx]

    # Build the model engine.
    engine = trainer(scaler=scaler, lrate=train_cfg['base_lr'],
                     wdecay=train_cfg['weight_decay'], device=device,
                     in_dim=model_cfg['input_dim'],
                     seq_len=args.seq_len,
                     horizon=args.horizon,
                     num_nodes=model_cfg['num_nodes'],
                     supports=supports,
                     hdim=model_cfg['hdim'],
                     dropout=train_cfg['dropout'],
                     gcn_bool=model_cfg['gcn_bool'],
                     addaptadj=model_cfg['addaptadj'],
                     aptinit=None)

    # Training.
    print("start training...", flush=True)
    his_loss = []          # mean validation loss of each epoch
    checkpoint_paths = []  # checkpoint written at the end of each epoch
    val_time = []
    train_time = []
    print_every = train_cfg['print_every']
    for epoch in range(1, train_cfg['epochs'] + 1):
        train_loss, train_mape, train_rmse = [], [], []
        t1 = time.time()
        dataloader['train_loader'].shuffle()
        for step, (x, y) in enumerate(dataloader['train_loader'].get_iterator()):
            trainx = _to_model_input(x, device)
            trainy = _to_model_input(y, device)
            # Only index 0 along axis 1 (after the transpose) is used as the
            # target — presumably the primary feature channel.
            metrics = engine.train(trainx, trainy[:, 0, :, :])
            train_loss.append(metrics[0])
            train_mape.append(metrics[1])
            train_rmse.append(metrics[2])
            if step % print_every == 0:
                log = 'Iter: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}'
                print(log.format(step, train_loss[-1], train_mape[-1], train_rmse[-1]), flush=True)
        t2 = time.time()
        train_time.append(t2 - t1)

        # Validation.
        valid_loss, valid_mape, valid_rmse = [], [], []
        s1 = time.time()
        for x, y in dataloader['val_loader'].get_iterator():
            testx = _to_model_input(x, device)
            testy = _to_model_input(y, device)
            metrics = engine.eval(testx, testy[:, 0, :, :])
            valid_loss.append(metrics[0])
            valid_mape.append(metrics[1])
            valid_rmse.append(metrics[2])
        s2 = time.time()
        log = 'Epoch: {:03d}, Inference Time: {:.4f} secs'
        print(log.format(epoch, (s2 - s1)))
        val_time.append(s2 - s1)
        mtrain_loss = np.mean(train_loss)
        mtrain_mape = np.mean(train_mape)
        mtrain_rmse = np.mean(train_rmse)

        mvalid_loss = np.mean(valid_loss)
        mvalid_mape = np.mean(valid_mape)
        mvalid_rmse = np.mean(valid_rmse)
        his_loss.append(mvalid_loss)

        log = 'Epoch: {:03d}, Train Loss: {:.4f}, Train MAPE: {:.4f}, Train RMSE: {:.4f}, Valid Loss: {:.4f}, Valid MAPE: {:.4f}, Valid RMSE: {:.4f}, Training Time: {:.4f}/epoch'
        print(log.format(epoch, mtrain_loss, mtrain_mape, mtrain_rmse, mvalid_loss, mvalid_mape, mvalid_rmse, (t2 - t1)),
              flush=True)
        save_path = os.path.join(checkpoint_dir,
                                 "epoch_" + str(epoch) + "_" + str(round(mvalid_loss, 2)) + ".pth")
        engine.save(save_path)
        # Remember the exact path so the best epoch is reloaded from the file
        # that was written, instead of re-deriving the name.
        checkpoint_paths.append(save_path)

    if not his_loss:
        raise ValueError("supervisor_config['train']['epochs'] must be >= 1; "
                         "no checkpoint was produced to evaluate")
    print("Average Training Time: {:.4f} secs/epoch".format(np.mean(train_time)))
    print("Average Inference Time: {:.4f} secs".format(np.mean(val_time)))

    # Testing: reload the epoch with the lowest mean validation loss.
    bestid = np.argmin(his_loss)
    engine.load(checkpoint_paths[bestid])
    outputs = []
    realy = _to_model_input(dataloader['y_test'], device)[:, 0, :, :]
    for x, _ in dataloader['test_loader'].get_iterator():
        testx = _to_model_input(x, device)
        with torch.no_grad():
            preds = engine.model(testx)
        # NOTE(review): squeeze() drops every size-1 axis, including the batch
        # axis if a test batch ever holds a single sample — confirm the test
        # loader pads its last batch.
        outputs.append(preds.squeeze())
    yhat = torch.cat(outputs, dim=0)
    # Drop any rows beyond len(y_test) (presumably loader padding).
    yhat = yhat[:realy.size(0), ...]
    print("Training finished")
    print("The valid loss on best model is", str(round(his_loss[bestid], 4)))
    amae, amape, armse, amse, aacc = [], [], [], [], []
    for h in range(args.horizon):
        # Predictions are un-scaled; the y_test targets are compared as loaded.
        pred = scaler.inverse_transform(yhat[:, :, h])
        real = realy[:, :, h]
        metrics = metric(pred, real)
        log = 'Evaluate best model on test data for horizon {:d}, ' \
              'Test MSE: {:.4f}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
        print(log.format(h + 1, metrics[3], metrics[0], metrics[1], metrics[2]))
        amae.append(metrics[0])
        amape.append(metrics[1])
        armse.append(metrics[2])
        amse.append(metrics[3])
        aacc.append(metrics[4])
    log = 'On average over {} horizons, Test MSE: {:.4f}, Test MAE: {:.4f}, ' \
          'Test MAPE: {:.4f}, Test RMSE: {:.4f}'
    print(log.format(args.horizon, np.mean(amse), np.mean(amae), np.mean(amape), np.mean(armse)))
    torch.save(engine.model.state_dict(),
               os.path.join(checkpoint_dir,
                            "best_" + str(bestid + 1) +
                            "_" + "data{}_horizon{}_".format(args.data_name, args.horizon)
                            + str(round(his_loss[bestid], 2)) + ".pth"))
    del engine
    return [args.horizon, np.mean(amse), np.mean(amae), np.mean(amape), np.mean(armse), np.mean(aacc)]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # parser.add_argument('--config_filename', default=r'methods/STHSL/config/STHSL_PEMS.yaml', type=str,
    #                     help='Config file for  model.')
    # parser.add_argument('--data_name', default='PEMS-BAY', type=str,
    #                     help='data name.')
    parser.add_argument('--config_filename', default=r'methods/STHSL/config/STHSL_la.yaml', type=str,
                        help='Config file for  model.')
    parser.add_argument('--data_name', default='METR-LA', type=str,
                        help='data name.')
    # parser.add_argument('--config_filename', default=r'methods/STHSL/config/STHSL_solar.yaml', type=str,
    #                     help='Config file for  model.')
    # parser.add_argument('--data_name', default='solar', type=str,
    #                     help='data name.')
    parser.add_argument('--model',default='STHSL',type=str,help='model name.')

    parser.add_argument('--horizon',default=12,type=int,help="horizon")
    parser.add_argument('--seq_len', default=12, type=int, help="seq_len")
    args = parser.parse_args()
    #for data_str in ['METR-LA']:
    for data_str in ['METR-LA', 'PEMS-BAY', 'solar']:
        args.data_name = data_str
        if args.data_name == 'METR-LA':
            args.config_filename = 'methods/STHSL/config/STHSL_la.yaml'
        elif args.data_name == 'PEMS-BAY':
            args.config_filename = 'methods/STHSL/config/STHSL_PEMS.yaml'
        elif args.data_name == 'solar':
            args.config_filename = 'methods/STHSL/config/STHSL_solar.yaml'
        elif args.data_name == 'pems4':
            args.config_filename = 'methods/STHSL/config/STHSL_pems4.yaml'
        elif args.data_name == 'pems3':
            args.config_filename = 'methods/STHSL/config/STHSL_pems3.yaml'
        elif args.data_name == 'pems7':
            args.config_filename = 'methods/STHSL/config/STHSL_pems7.yaml'
        elif args.data_name == 'pems8':
            args.config_filename = 'methods/STHSL/config/STHSL_pems8.yaml'
        else:
            assert "None data!!"
        with open(args.config_filename) as f:
            supervisor_config = yaml.safe_load(f)
        args.seg_len = 3
        result_save_path = os.path.join(supervisor_config['save_result']['save_path'],
                                        "{}-{}.csv".format(args.data_name, args.model))
        create_csv(result_save_path)
        # len_list = [5,6,7,8,9,10,12]#循环测试
        len_list = sorted([3, 6, 12]*5)  # 循环测试
        for l in len_list:
            args.seq_len = l
            args.horizon = l
            rst = main(args,supervisor_config)
            append_csv(result_save_path, rst)


