import torch
import time, sys
import argparse
import logging
import os, pwd
import pytorch_lightning as pl
from torch.utils.data import Subset
from torch._utils import _accumulate
from torch.utils.data import DataLoader
from fssm.modules.nonstationarypred import NonStationaryPredProcess
from fssm.modules.stationary_scalarnoise import StationaryProcess
from fssm.tools.utils import load_yaml
from fssm.datasets.sim_dataset import StationaryDataset
from pytorch_lightning.loggers import TensorBoardLogger
import warnings
warnings.filterwarnings('ignore')

SUPPORTED_MODELS = ('stationary', 'nonstationary', 'nonstationarypred')


def _train_val_split(data, num_validation_samples):
    """Split ``data`` into contiguous, unshuffled train/validation subsets.

    The last ``num_validation_samples`` items become the validation subset and
    every item before them the training subset.

    Args:
        data: any sized, indexable dataset.
        num_validation_samples: number of trailing items held out for
            validation.

    Returns:
        ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if ``num_validation_samples`` is negative or larger than
            ``len(data)`` (which would otherwise yield a silently wrong split).
    """
    n_total = len(data)
    if not 0 <= num_validation_samples <= n_total:
        raise ValueError(
            f"N_VAL_SAMPLES must be in [0, {n_total}], got {num_validation_samples}")
    n_train = n_total - num_validation_samples
    indices = list(range(n_total))
    return Subset(data, indices[:n_train]), Subset(data, indices[n_train:])


def _build_model(args, cfg):
    """Instantiate the model selected by ``args.model`` from the config.

    Args:
        args: parsed CLI namespace; reads ``model`` and ``learnablenoise``.
        cfg: experiment config mapping; reads the ``VAE`` and ``MCC`` sections
            and, for ``nonstationarypred``, the ``PREDICTOR`` section.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``args.model`` is not one of ``SUPPORTED_MODELS``.
    """
    if args.model not in SUPPORTED_MODELS:
        raise ValueError(
            f"Unknown model {args.model!r}; expected one of {SUPPORTED_MODELS}")
    vae = cfg['VAE']
    if args.model == 'stationary':
        return StationaryProcess(input_dim=vae['INPUT_DIM'],
                                 length=vae['LENGTH'],
                                 z_dim=vae['LATENT_DIM'],
                                 lag=vae['LAG'],
                                 hidden_dim=vae['ENC']['HIDDEN_DIM'],
                                 trans_prior=vae['TRANS_PRIOR'],
                                 lr=vae['LR'],
                                 beta=vae['BETA'],
                                 gamma=vae['GAMMA'],
                                 decoder_dist=vae['DEC']['DIST'],
                                 correlation=cfg['MCC']['CORR'],
                                 fixed_noise=not args.learnablenoise)
    if args.model == 'nonstationary':
        # Imported locally: only this branch needs it.
        from fssm.modules.nonstationary2_scalarnoise import NonStationaryProcess
        return NonStationaryProcess(input_dim=vae['INPUT_DIM'],
                                    length=vae['LENGTH'],
                                    z_dim=vae['LATENT_DIM'],
                                    lag=vae['LAG'],
                                    hidden_dim=vae['ENC']['HIDDEN_DIM'],
                                    trans_prior=vae['TRANS_PRIOR'],
                                    lr=vae['LR'],
                                    beta=vae['BETA'],
                                    gamma=vae['GAMMA'],
                                    delta=vae['DELTA'],
                                    epsilon=vae['EPSILON'],
                                    a_distribution=vae['DEC']['A_DIST'],
                                    decoder_dist=vae['DEC']['DIST'],
                                    fixed_noise=not args.learnablenoise,
                                    correlation=cfg['MCC']['CORR'])
    # 'nonstationarypred': no ``fixed_noise`` argument is passed, so
    # ``--learnablenoise`` has no effect for this model.
    pred = cfg['PREDICTOR']
    return NonStationaryPredProcess(input_dim=vae['INPUT_DIM'],
                                    length=vae['LENGTH'],
                                    z_dim=vae['LATENT_DIM'],
                                    lag=vae['LAG'],
                                    hidden_dim=vae['ENC']['HIDDEN_DIM'],
                                    trans_prior=vae['TRANS_PRIOR'],
                                    predict_mode=pred['PREDICT_MODE'],
                                    predict_witha=pred['PRE_WITHA'],
                                    lstm_layer=pred['LSTM_LAYER'],
                                    lr=vae['LR'],
                                    beta=vae['BETA'],
                                    gamma=vae['GAMMA'],
                                    delta=vae['DELTA'],
                                    epsilon=vae['EPSILON'],
                                    alpha=vae['ALPHA'],
                                    decoder_dist=vae['DEC']['DIST'],
                                    a_distribution=vae['DEC']['A_DIST'],
                                    prediction_sample_times=vae['SAMPLE_TIMES'],
                                    correlation=cfg['MCC']['CORR'])


def _version_name(model_type, cfg):
    """Return the log-version folder name encoding the key hyperparameters.

    Args:
        model_type: one of ``SUPPORTED_MODELS``.
        cfg: experiment config mapping (``VAE`` and, for
            ``nonstationarypred``, ``PREDICTOR`` sections).

    Raises:
        ValueError: if ``model_type`` is not one of ``SUPPORTED_MODELS``.
    """
    vae = cfg['VAE']
    if model_type == 'stationary':
        return f"lr{vae['LR']}_beta{vae['BETA']}_gamma{vae['GAMMA']}"
    if model_type == 'nonstationary':
        return (f"lr{vae['LR']}_beta{vae['BETA']}_gamma{vae['GAMMA']}"
                f"_delta{vae['DELTA']}_epsilon{vae['EPSILON']}")
    if model_type == 'nonstationarypred':
        pred = cfg['PREDICTOR']
        return (f"lr{vae['LR']}_alpha{vae['ALPHA']}_beta{vae['BETA']}_gamma{vae['GAMMA']}"
                f"_delta{vae['DELTA']}_epsilon{vae['EPSILON']}_predlr{pred['LR']}"
                f"_predmode{pred['PREDICT_MODE']}_predwitha{pred['PRE_WITHA']}"
                f"_predlstmlayer{pred['LSTM_LAYER']}")
    raise ValueError(
        f"Unknown model {model_type!r}; expected one of {SUPPORTED_MODELS}")


def main(args):
    """Train the model selected on the command line.

    Loads ``../fssm/configs/<args.exp>.yaml`` (relative to this script) and
    builds the train/validation loaders and the model. It then sets up
    TensorBoard and file logging under
    ``cfg['LOG']/<user>/<exp suffix>/<model>/<version>`` and runs
    ``trainer.fit``. ``<exp suffix>`` is the part of ``args.exp`` after its
    last underscore.

    Args:
        args: parsed CLI namespace with ``exp``, ``seed``, ``model``,
            ``epoch``, ``logfolder_suffixmodel``, ``logfolder_suffix`` and
            ``learnablenoise``.

    Raises:
        ValueError: if ``args.exp`` is missing or ``args.model`` is unknown.

    If the run's log directory already exists, prints a message and exits
    the process (status 0) without training.
    """
    # Explicit check rather than ``assert`` so it is not stripped by ``-O``.
    if args.exp is None:
        raise ValueError("FATAL: " + __file__ + ": You must specify an exp configs file (e.g., *.yaml)")
    # Fail fast, before any data is loaded, on an unknown model name.
    if args.model not in SUPPORTED_MODELS:
        raise ValueError(
            f"Unknown model {args.model!r}; expected one of {SUPPORTED_MODELS}")

    current_user = pwd.getpwuid(os.getuid()).pw_name
    script_dir = os.path.dirname(__file__)
    abs_file_path = os.path.join(script_dir, '../fssm/configs', '%s.yaml' % args.exp)
    cfg = load_yaml(abs_file_path)

    # Command-line overrides of config values.
    if args.epoch > 0:
        cfg['VAE']['EPOCHS'] = args.epoch

    # Without CUDA: no GPUs for the Trainer and single-process data loading.
    # Decided up front so the loaders are built with the right worker count.
    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        cfg['VAE']['GPU'] = None
    num_workers = cfg['VAE']['CPU'] if use_cuda else 0

    pl.seed_everything(args.seed)
    data = StationaryDataset(directory=cfg['ROOT'],
                             transition=cfg['DATASET'])
    train_data, val_data = _train_val_split(data, cfg['VAE']['N_VAL_SAMPLES'])
    train_loader = DataLoader(train_data,
                              batch_size=cfg['VAE']['TRAIN_BS'],
                              pin_memory=cfg['VAE']['PIN'],
                              num_workers=num_workers,
                              drop_last=False,
                              shuffle=True)
    val_loader = DataLoader(val_data,
                            batch_size=cfg['VAE']['VAL_BS'],
                            pin_memory=cfg['VAE']['PIN'],
                            num_workers=num_workers,
                            shuffle=False)

    model = _build_model(args, cfg)

    # Log locations.
    log_dir = os.path.join(cfg["LOG"], current_user, args.exp.split('_')[-1])
    if args.logfolder_suffixmodel is None:
        model_name = args.model
    else:
        model_name = args.model + "_" + args.logfolder_suffixmodel
    version_name = _version_name(args.model, cfg)
    if args.logfolder_suffix:
        version_name = version_name + "_" + args.logfolder_suffix

    logger_TB = TensorBoardLogger(log_dir, name=model_name, version=version_name)
    log_dir = os.path.join(log_dir, model_name, version_name)
    logfile_name = f"{version_name}_{int(time.time())}.log"
    # Refuse to overwrite an existing run with the same name.
    if os.path.exists(log_dir):
        print(f"{model_name}_{version_name} exist!")
        sys.exit()
    os.makedirs(log_dir)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    output_file_handler = logging.FileHandler(os.path.join(log_dir, logfile_name), mode='w')
    logger.addHandler(output_file_handler)
    # Record the run settings so the log file identifies the run.
    logging.info("######### Configuration #########")
    logging.info("args: %s", vars(args))
    logging.info("cfg: %s", cfg)
    logging.info("#################################")

    # NOTE(review): limit_train_batches=0.01 uses only 1% of the training
    # batches per epoch; confirm this is intended. ``gpus``,
    # ``auto_select_gpus`` and ``progress_bar_refresh_rate`` belong to an
    # older PyTorch Lightning Trainer API and are not accepted by current
    # releases.
    trainer = pl.Trainer(limit_train_batches=0.01,
                         logger=logger_TB,
                         default_root_dir=log_dir,
                         gpus=cfg['VAE']['GPU'],
                         auto_select_gpus=True,
                         max_epochs=cfg['VAE']['EPOCHS'],
                         progress_bar_refresh_rate=0
                         )

    # Train the model
    trainer.fit(model, train_loader, val_loader)

def _build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser(
        description="Train a model from a YAML experiment config.")
    parser.add_argument('-e', '--exp', type=str, required=True,
                        help="config name (without .yaml) under ../fssm/configs, "
                             "relative to this script")
    parser.add_argument('-s', '--seed', type=int, default=770,
                        help="random seed passed to pl.seed_everything")
    parser.add_argument('-m', '--model', type=str, default='stationary',
                        choices=('stationary', 'nonstationary', 'nonstationarypred'),
                        help="model variant to train")
    parser.add_argument('--epoch', type=int, default=-1,
                        help="override VAE.EPOCHS when > 0")
    parser.add_argument('--logfolder_suffixmodel', type=str, default=None,
                        help="suffix appended to the model log folder name")
    parser.add_argument('--logfolder_suffix', type=str, default=None,
                        help="suffix appended to the version log folder name")
    parser.add_argument('--learnablenoise', action='store_true', default=False,
                        help="pass fixed_noise=False to the model "
                             "(stationary/nonstationary only)")
    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())
