import os
import sys
sys.path.append("..")
import time
import datetime
import argparse
import numpy as np
from tqdm import tqdm
from random import SystemRandom

# Argument parser for command-line options.
# NOTE: the first positional parameter of ArgumentParser is ``prog`` (the program
# name printed in the usage line), so the title is passed as ``description``.
parser = argparse.ArgumentParser(description='ITS Forecasting with HyperIMTS')
parser.add_argument('--state', type=str, default='def')
parser.add_argument('-n',  type=int, default=int(1e8), help="Size of the dataset")
parser.add_argument('--epoch', type=int, default=1000, help="training epochs")
parser.add_argument('--patience', type=int, default=10, help="patience for early stop")
parser.add_argument('--history', type=int, default=24, help="number of hours (months for ushcn and ms for activity) as historical window")
parser.add_argument('--pred_window', type=int, default=1, help="number of hours (months for ushcn) as pred window")
parser.add_argument('--logmode', type=str, default="a", help='File mode of logging.')
parser.add_argument('--lr',  type=float, default=1e-3, help="Starting learning rate.")
parser.add_argument('-b', '--batch_size', type=int, default=32)
parser.add_argument('--load', type=str, default=None, help="ID of the experiment to load for evaluation. If None, run a new experiment.")
parser.add_argument('--seed', type=int, default=1, help="Random seed")
parser.add_argument('--dataset', type=str, default='physionet', help="Dataset to load. Available: physionet, mimic, ushcn, activity")
parser.add_argument('--quantization', type=float, default=0.0, help="Quantization on the physionet dataset.")
parser.add_argument('--model', type=str, default='HyperIMTS', help="Model name")
parser.add_argument('--nhead', type=int, default=4, help="heads in Transformer")
parser.add_argument('--nlayer', type=int, default=2, help="# of layers")
parser.add_argument('-hd', '--hid_dim', type=int, default=64, help="Hidden dim")
parser.add_argument('--gpu', type=str, default='0', help='which gpu to use.')

args = parser.parse_args()
# Set before torch is imported (below) so the CUDA runtime only exposes the
# requested device(s) to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# cuBLAS workspace setting needed for deterministic GEMMs, which
# torch.use_deterministic_algorithms(True) (enabled below) requires on CUDA.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch
import torch.optim as optim
# Reproducibility: make PyTorch use deterministic kernels globally, and stop
# cuDNN from auto-tuning (benchmark mode may pick different algorithms per run).
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import utils as utils
from parse_datasets import parse_datasets
from evaluation import *
from ExpConfigs import ExpConfigs
from baselines.models.HyperIMTS import Model as HyperIMTS
import warnings
# Suppress every Python warning for the remainder of the process.
warnings.simplefilter("ignore")

# CUDA_VISIBLE_DEVICES was narrowed to --gpu before torch was imported, so the
# requested GPU (the first one, if several were listed) is addressed as cuda:0.
if torch.cuda.is_available():
    args.device = torch.device("cuda:0")
else:
    args.device = torch.device("cpu")
args.PID = os.getpid()
print("PID, device:", args.PID, args.device)

def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Parameters whose ``requires_grad`` flag is False (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

if __name__ == '__main__':
    utils.setup_seed(args.seed)

    # Reuse the ID given via --load; otherwise draw a random integer in [0, 100000).
    experimentID = args.load
    if experimentID is None:
        experimentID = int(SystemRandom().random() * 100000)

    input_command = " ".join(sys.argv)

    # Parse dataset
    data_obj = parse_datasets(args, patch_ts=False)
    input_dim = data_obj["input_dim"]

    ### Model setting ###
    args.ndim = input_dim
    args.task = 'forecasting'
    # Keep any pre-existing args.maxlen; default to 0 otherwise.
    args.maxlen = getattr(args, 'maxlen', 0)

    # Create ExpConfigs from args for HyperIMTS model
    configs = ExpConfigs(args)
    model = HyperIMTS(configs).to(args.device)

    # Small runs (n < 12000) are tagged "debug" and share one short log name;
    # full runs encode the main hyperparameters in the log file name.
    if args.n < 12000:
        args.state = "debug"
        log_path = "logs/{}_{}_{}.log".format(args.dataset, args.model, args.state)
    else:
        log_path = "logs/{}_{}_{}_{}hdims_{}nlayers_{}history_{}pred_{}nhead_{}lr_{}seed.log".format(
            args.dataset, args.model, args.state, args.hid_dim, args.nlayer, args.history, args.pred_window, args.nhead, args.lr, args.seed)

    if not os.path.exists("logs/"):
        utils.makedirs("logs/")
    logger = utils.get_logger(logpath=log_path, filepath=os.path.abspath(__file__), mode=args.logmode)
    logger.info(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    logger.info(input_command)
    logger.info(args)
    logger.info(model)
    logger.info(f'parameters:{count_parameters(model)}')

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    num_batches = data_obj["n_train_batches"]

    best_val_mse = np.inf
    # Initialised before the loop: if no epoch ever beats best_val_mse (e.g. the
    # validation MSE is NaN, and NaN < inf is False), the early-stopping check
    # below would otherwise raise NameError on the very first epoch.
    best_iter = 0
    test_res = None

    for itr in range(args.epoch):
        st = time.time()
        ### Training ###
        model.train()
        # Loss of the last processed batch in THIS epoch; stays None if every
        # batch was skipped. The tensor is kept (rather than .item()) so there is
        # no per-batch host sync; it is converted once when logging.
        last_loss = None
        for _ in tqdm(range(num_batches)):
            optimizer.zero_grad()
            batch_dict = utils.get_next_batch(data_obj["train_dataloader"])
            if batch_dict is None:
                continue
            train_res = compute_all_losses(model, batch_dict)
            train_res["loss"].backward()
            optimizer.step()
            last_loss = train_res["loss"].detach()
        logger.info('training time per epoch :{:.3f}s'.format(time.time() - st))
        torch.cuda.empty_cache()

        ### Validation ###
        model.eval()
        with torch.no_grad():
            val_res, _ = evaluation(model, data_obj["val_dataloader"], data_obj["n_val_batches"])
            torch.cuda.empty_cache()

            ### Testing ###
            # The test set is evaluated only when validation MSE improves, so
            # test_res always reflects the model from epoch best_iter.
            if val_res["mse"] < best_val_mse:
                best_val_mse = val_res["mse"]
                best_iter = itr
                test_res, _ = evaluation(model, data_obj["test_dataloader"], data_obj["n_test_batches"])
                torch.cuda.empty_cache()

            logger.info('- Epoch {:03d}, ExpID {}'.format(itr, experimentID))
            if last_loss is not None:
                logger.info("Train - Loss (one batch): {:.5f}".format(last_loss.item()))
            else:
                logger.info("Train - no batch was processed this epoch")
            logger.info("Val - Loss, MSE, RMSE, MAE, MAPE: {:.5f}, {:.5f}, {:.5f}, {:.5f}, {:.2f}%".format(
                val_res["loss"], val_res["mse"], val_res["rmse"], val_res["mae"], val_res["mape"]*100))
            if test_res is not None:
                logger.info("Test - Best epoch, Loss, MSE, RMSE, MAE, MAPE: {}, {:.5f}, {:.5f}, {:.5f}, {:.5f}, {:.2f}%".format(
                    best_iter, test_res["loss"], test_res["mse"], test_res["rmse"], test_res["mae"], test_res["mape"]*100))
            logger.info("Time spent: {:.2f}s".format(time.time()-st))

        # Early stopping: quit once validation MSE has not improved for
        # `patience` consecutive epochs since best_iter.
        if itr - best_iter >= args.patience:
            logger.info("Exp has been early stopped!")
            sys.exit(0)
