import os
import sys
sys.path.append("..")
import time
import datetime
import argparse
import numpy as np
import math
from tqdm import tqdm
from random import SystemRandom

# Argument parser for command-line options
parser = argparse.ArgumentParser(prog="ITS Forecasting")

# --- Run bookkeeping and training schedule ---
parser.add_argument("--state", type=str, default="def")
parser.add_argument(
    "-n",
    type=int,
    default=100_000_000,
    help="Size of the dataset",
)
parser.add_argument(
    "--epoch",
    type=int,
    default=1000,
    help="training epoches",
)
parser.add_argument(
    "--patience",
    type=int,
    default=10,
    help="patience for early stop",
)
parser.add_argument(
    "--history",
    type=int,
    default=24,
    help="number of hours (months for ushcn and ms for activity) as historical window",
)
parser.add_argument(
    "--pred_window",
    type=int,
    default=1,
    help="number of hours (months for ushcn) as pred window",
)
parser.add_argument(
    "--logmode",
    type=str,
    default="a",
    help="File mode of logging.",
)

# --- Optimisation ---
parser.add_argument(
    "--lr",
    type=float,
    default=1e-3,
    help="Starting learning rate.",
)
parser.add_argument(
    "--w_decay",
    type=float,
    default=0.0,
    help="weight decay.",
)
parser.add_argument("-b", "--batch_size", type=int, default=32)
parser.add_argument(
    "--load",
    type=str,
    default=None,
    help="ID of the experiment to load for evaluation. If None, run a new experiment.",
)
parser.add_argument(
    "--seed",
    type=int,
    default=1,
    help="Random seed",
)

# --- Data ---
parser.add_argument(
    "--dataset",
    type=str,
    default="physionet",
    help="Dataset to load. Available: physionet, mimic, ushcn, activity",
)
parser.add_argument(
    "--quantization",
    type=float,
    default=0.0,
    help="Quantization on the physionet dataset.",
)

# --- Model ---
parser.add_argument(
    "--model",
    type=str,
    default="Hi-Patch",
    help="Model name",
)
parser.add_argument(
    "--nhead",
    type=int,
    default=1,
    help="heads in Transformer",
)
parser.add_argument(
    "--nlayer",
    type=int,
    default=1,
    help="# of layer in TSmodel",
)
# The defaults of the float-typed options below are int literals on purpose:
# argparse applies `type` only to values given on the command line.
parser.add_argument(
    "-ps",
    "--patch_size",
    type=float,
    default=24,
    help="window size for a patch",
)
parser.add_argument(
    "--stride",
    type=float,
    default=24,
    help="period stride for patch sliding",
)
parser.add_argument(
    "-hd",
    "--hid_dim",
    type=int,
    default=64,
    help="Hidden dim of node embeddings",
)
parser.add_argument(
    "--alpha",
    type=float,
    default=1,
    help="Proportion of Time decay",
)
parser.add_argument(
    "--res",
    type=float,
    default=1,
    help="Res",
)

# --- Execution ---
parser.add_argument(
    "--gpu",
    type=str,
    default="0",
    help="which gpu to use.",
)
parser.add_argument("--mode", type=str)
parser.add_argument("--type", type=str, default="global")


# Parse the command line at module level; `args` doubles as the run configuration
# object that the rest of this file keeps adding attributes to.
args = parser.parse_args()
# ceil((history - patch_size) / stride) + 1: how many windows of length patch_size,
# advanced by stride, are needed to span the history window.
args.npatch = int(np.ceil((args.history - args.patch_size) / args.stride)) + 1 # number of patches (not the patch window size)
# Both environment variables are set before torch is imported below, so they are
# already in place by the time CUDA could be initialised.
# CUDA_VISIBLE_DEVICES limits which GPUs this process can see.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# NOTE(review): presumably the cuBLAS workspace setting that
# torch.use_deterministic_algorithms(True) asks for on CUDA — confirm against the
# PyTorch reproducibility notes.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch
import torch.optim as optim
# Reproducibility over speed: disable cuDNN autotuning and request deterministic
# implementations only.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic=True
torch.use_deterministic_algorithms(True)
import utils as utils
from parse_datasets import parse_datasets
# Star imports: names used later in this file without an explicit import
# (compute_all_losses, evaluation, Model) presumably come from these two modules —
# verify before narrowing them.
from evaluation import *
from model.ours_ab import *
import warnings
# Suppresses every warning category for the remainder of the run.
warnings.filterwarnings("ignore")

# Script file name with its last three characters (the ".py" suffix) removed.
file_name = os.path.basename(__file__)[:-3]
# "cuda:0" indexes into the devices left visible by CUDA_VISIBLE_DEVICES above;
# falls back to CPU when CUDA is unavailable.
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.PID = os.getpid()
print("PID, device:", args.PID, args.device)
#####################################################################################################
# Function to count model parameters
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Iterates ``model.parameters()`` and adds up ``numel()`` for every parameter
    whose ``requires_grad`` flag is truthy; parameters without it are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Number of hierarchy levels needed to merge n_patch patches pairwise down to one
def layer_of_patches(n_patch):
    """Return how many levels it takes to reduce ``n_patch`` patches to one.

    Each level halves the patch count, with an odd count rounded up to the
    next even number first; the final single patch counts as a level itself.
    The result therefore equals ``ceil(log2(n_patch)) + 1``.

    Args:
        n_patch: Positive whole number of patches. Integral floats such as
            ``4.0`` are accepted and treated as the corresponding integer.

    Returns:
        int: Number of levels, always at least 1.

    Raises:
        ValueError: If ``n_patch`` is not a whole number or is smaller than 1.
    """
    remaining = int(n_patch)
    # Reject these inputs up front: repeated halving never reaches exactly 1
    # for zero, negative or fractional counts, so there is no valid answer.
    if remaining != n_patch or remaining < 1:
        raise ValueError(
            "n_patch must be a positive whole number, got {!r}".format(n_patch)
        )

    layers = 1
    while remaining > 1:
        # Integer "round up to even, then halve" — no float division involved.
        remaining = (remaining + 1) // 2
        layers += 1
    return layers

if __name__ == '__main__':
    # Script entry point: seed everything, build the data object and model, then
    # train with per-epoch validation. The test split is evaluated whenever the
    # validation MSE improves, and training stops early after `--patience`
    # epochs without improvement.
    utils.setup_seed(args.seed)

    experimentID = args.load
    if experimentID is None:
        # No --load given: draw a random ID to tag this run's log lines.
        experimentID = int(SystemRandom().random()*100000)

    # Command line as written to the log, with a single "--load <id>" pair removed.
    input_command = sys.argv
    load_positions = [i for i, token in enumerate(input_command) if token == "--load"]
    if len(load_positions) == 1:
        load_at = load_positions[0]
        input_command = input_command[:load_at] + input_command[(load_at+2):]
    input_command = " ".join(input_command)

    # Parse dataset. Modes 'self' and 'cross' load without max_ts; every other
    # mode (including an unset --mode) additionally passes max_ts=True.
    if args.mode in ('self', 'cross'):
        data_obj = parse_datasets(args, patch_ts=True)
    else:
        data_obj = parse_datasets(args, patch_ts=True, max_ts=True)
    input_dim = data_obj["input_dim"]

    ### Model setting ###
    args.ndim = input_dim
    args.npatch = int(math.ceil((args.history - args.patch_size) / args.stride)) + 1
    args.patch_layer = layer_of_patches(args.npatch)
    # Patch size as a fraction of the full (history + prediction) span.
    args.scale_patch_size = args.patch_size / (args.history + args.pred_window)
    args.task = 'forecasting'

    model = Model(args).to(args.device)
    ##################################################################
    # Small -n values are treated as debug runs and share one short log name.
    if args.n < 12000:
        args.state = "debug"
        log_path = "logs/{}_{}_{}.log".format(args.dataset, args.model, args.state)
    else:
        log_path = "logs/{}_{}_{}_{}hdims_{}nlayers_{}npatch_{}patchsize_{}history_{}pred_{}nhead_{}lr_{}seed.log". \
            format(args.dataset, args.model, args.mode, args.hid_dim, args.nlayer, args.npatch, args.patch_size, args.history, args.pred_window, args.nhead, args.lr, args.seed)

    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs("logs/", exist_ok=True)
    logger = utils.get_logger(logpath=log_path, filepath=os.path.abspath(__file__), mode=args.logmode)
    logger.info(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    logger.info(input_command)
    logger.info(args)
    logger.info(model)
    logger.info(f'parameters:{count_parameters(model)}')

    # --w_decay was parsed but never reached the optimizer; its default of 0.0
    # matches Adam's own default, so runs that do not pass the flag are unchanged.
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.w_decay)

    num_batches = data_obj["n_train_batches"]

    best_val_mse = np.inf
    # Bound before the loop so the early-stopping check is valid even if the
    # validation MSE never improves on best_val_mse (e.g. it is NaN).
    best_iter = 0
    test_res = None

    for itr in range(args.epoch):
        ### Training ###
        # Reset per epoch so the loss logged below can only come from this epoch.
        train_res = None
        st = time.time()
        model.train()
        for _ in tqdm(range(num_batches)):
            optimizer.zero_grad()
            batch_dict = utils.get_next_batch(data_obj["train_dataloader"])
            if batch_dict is None:
                continue
            train_res = compute_all_losses(model, batch_dict)
            train_res["loss"].backward()
            optimizer.step()
        train_time = time.time() - st
        logger.info('training time per epoch :{:.3f}s'.format(train_time))

        ### Validation ###
        model.eval()
        with torch.no_grad():
            st = time.time()
            val_res, _ = evaluation(model, data_obj["val_dataloader"], data_obj["n_val_batches"])
            val_time = time.time() - st
            logger.info("val time per epoch :{:.2f}s".format(val_time))

            ### Testing ###
            # Only evaluate the test split when validation MSE reaches a new best.
            if val_res["mse"] < best_val_mse:
                best_val_mse = val_res["mse"]
                best_iter = itr

                st = time.time()
                test_res, _ = evaluation(model, data_obj["test_dataloader"], data_obj["n_test_batches"])
                test_time = time.time() - st
                logger.info("test time per epoch :{:.2f}s".format(test_time))

            logger.info('- Epoch {:03d}, ExpID {}'.format(itr, experimentID))
            if train_res is not None:
                logger.info("Train - Loss (one batch): {:.5f}".format(train_res["loss"].item()))
            else:
                # Every batch of this epoch was skipped, so there is no loss to report.
                logger.info("Train - Loss (one batch): n/a (no batch processed)")
            logger.info("Val - Loss, MSE, RMSE, MAE, MAPE: {:.5f}, {:.5f}, {:.5f}, {:.5f}, {:.2f}%".format(val_res["loss"], val_res["mse"], val_res["rmse"], val_res["mae"], val_res["mape"]*100))
            if test_res is not None:
                logger.info("Test - Best epoch, Loss, MSE, RMSE, MAE, MAPE: {}, {:.5f}, {:.5f}, {:.5f}, {:.5f}, {:.2f}%".format(best_iter, test_res["loss"], test_res["mse"], test_res["rmse"], test_res["mae"], test_res["mape"]*100))

        # Early stopping: give up after `patience` epochs without a new best.
        if itr - best_iter >= args.patience:
            logger.info("Exp has been early stopped!")
            sys.exit(0)