import torch, sys
import random
import argparse
import numpy as np
import ipdb as pdb
import os, pwd, yaml
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import lightning.pytorch as pl
from torch.utils.data import DataLoader, random_split
from IFactor.modules.stationary_IFactor_mine import StationaryIFactorProcess
from IFactor.tools.utils import load_yaml, setup_seed
from IFactor.datasets.sim_dataset import StationaryIFactorDataset

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback
import shutil
import warnings
warnings.filterwarnings('ignore')

def _load_config(args):
    """Load ``IFactor/configs/<args.exp>.yaml`` and overlay the CLI hyper-parameters.

    Args:
        args: parsed command-line namespace carrying ``exp``, ``seed``, ``beta``,
            ``gamma``, ``delta``, ``delta_epoch``, ``dim`` and ``prior``.

    Returns:
        The loaded config, with its ``VAE`` section overridden from ``args``.
    """
    script_dir = os.path.dirname(__file__)
    config_path = os.path.join(script_dir, '../IFactor/configs', f'{args.exp}.yaml')
    cfg = load_yaml(config_path)

    vae_cfg = cfg['VAE']
    vae_cfg['SEED'] = args.seed
    vae_cfg['BETA'] = args.beta
    vae_cfg['GAMMA'] = args.gamma
    vae_cfg['DELTA'] = args.delta
    vae_cfg['DELTA_EPOCH'] = args.delta_epoch
    # All four entries of Z_DIM_LIST share the single ``--dim`` value.
    vae_cfg['Z_DIM_LIST'] = [args.dim] * 4
    vae_cfg['TRANS_PRIOR'] = args.prior
    return cfg


def _build_dataloaders(cfg, data):
    """Split ``data`` into train/validation subsets and wrap each in a ``DataLoader``.

    The validation subset holds ``cfg['VAE']['N_VAL_SAMPLES']`` items; the
    remainder is used for training (shuffled, last partial batch kept).

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``N_VAL_SAMPLES`` is negative or larger than ``len(data)``.
    """
    vae_cfg = cfg['VAE']
    num_val = vae_cfg['N_VAL_SAMPLES']
    # random_split only checks that the lengths sum to len(data), so a
    # negative train length is not rejected up front; validate explicitly.
    if not 0 <= num_val <= len(data):
        raise ValueError(
            f"N_VAL_SAMPLES={num_val} must be between 0 and the dataset size ({len(data)})"
        )
    train_data, val_data = random_split(data, [len(data) - num_val, num_val])

    loader_kwargs = dict(pin_memory=vae_cfg['PIN'], num_workers=vae_cfg['CPU'])
    train_loader = DataLoader(train_data,
                              batch_size=vae_cfg['TRAIN_BS'],
                              drop_last=False,
                              shuffle=True,
                              **loader_kwargs)
    val_loader = DataLoader(val_data,
                            batch_size=vae_cfg['VAL_BS'],
                            shuffle=False,
                            **loader_kwargs)
    return train_loader, val_loader


def _build_model(cfg):
    """Instantiate ``StationaryIFactorProcess`` from the ``VAE`` and ``MCC`` config sections."""
    vae_cfg = cfg['VAE']
    return StationaryIFactorProcess(input_dim=vae_cfg['INPUT_DIM'],
                                    length=vae_cfg['LENGTH'],
                                    z_dim_list=vae_cfg['Z_DIM_LIST'],
                                    action_dim=vae_cfg['ACTION_DIM'],
                                    lag=vae_cfg['LAG'],
                                    hidden_dim=vae_cfg['ENC']['HIDDEN_DIM'],
                                    trans_prior=vae_cfg['TRANS_PRIOR'],
                                    config=cfg,
                                    lr=vae_cfg['LR'],
                                    aux_lr=vae_cfg['AUX_LR'],
                                    beta=vae_cfg['BETA'],
                                    gamma=vae_cfg['GAMMA'],
                                    delta=vae_cfg['DELTA'],
                                    delta_epoch=vae_cfg['DELTA_EPOCH'],
                                    decoder_dist=vae_cfg['DEC']['DIST'],
                                    correlation=cfg['MCC']['CORR'])


def main(args):
    """Train ``StationaryIFactorProcess`` for the experiment named by ``args.exp``.

    Loads the experiment YAML, applies the CLI overrides, seeds all RNGs, builds
    the train/validation loaders and the model, then runs ``Trainer.fit``. The
    single best checkpoint by the ``ave_r2`` metric (higher is better) is kept.

    Early stopping on ``ave_r2`` is enabled only when the config sets
    ``VAE.EARLY_STOP`` to a truthy value. It is off by default.

    Args:
        args: parsed CLI namespace with ``exp``, ``seed``, ``beta``, ``gamma``,
            ``delta``, ``delta_epoch``, ``dim`` and ``prior``.

    Raises:
        ValueError: if ``args.exp`` is None or ``VAE.N_VAL_SAMPLES`` is out of range.
    """
    # An explicit raise (unlike ``assert``) still fires under ``python -O``.
    if args.exp is None:
        raise ValueError("FATAL: " + __file__ + ": You must specify an exp config file (e.g., *.yaml)")

    cfg = _load_config(args)
    print("######### Configuration #########")
    print(yaml.dump(cfg, default_flow_style=False))
    print("#################################")

    current_user = pwd.getpwuid(os.getuid()).pw_name
    run_name = f"seed{args.seed}_dim{args.dim}_deltaepoch{args.delta_epoch}_{args.prior}"
    log_dir = os.path.join(cfg["LOG"], current_user, args.exp, run_name)

    # Seed before the split and model init so both are reproducible.
    pl.seed_everything(args.seed)

    data = StationaryIFactorDataset(directory=os.path.expanduser(cfg['ROOT']),
                                    transition=os.path.expanduser(cfg['DATASET']))
    train_loader, val_loader = _build_dataloaders(cfg, data)
    model = _build_model(cfg)

    class YamlCopyCallback(Callback):
        """Dump the resolved config to ``<log_dir>/config.yaml`` when fitting starts."""

        def __init__(self, cfg):
            super().__init__()
            self.cfg = cfg

        def on_fit_start(self, trainer, pl_module):
            # Every DDP rank runs this hook; only global rank 0 writes, so the
            # ranks do not race on the same file.
            if not trainer.is_global_zero:
                return
            # ``trainer.logger`` is None when logging is disabled; fall back to
            # the trainer's root directory in that case.
            out_dir = getattr(trainer.logger, "log_dir", None) or trainer.default_root_dir
            os.makedirs(out_dir, exist_ok=True)
            with open(os.path.join(out_dir, "config.yaml"), "w") as f:
                yaml.dump(self.cfg, f)

    checkpoint_callback = ModelCheckpoint(monitor='ave_r2',
                                          save_top_k=1,
                                          mode='max')
    callbacks = [checkpoint_callback, YamlCopyCallback(cfg)]
    # Opt-in only: without VAE.EARLY_STOP, training runs up to VAE.EPOCHS.
    if cfg['VAE'].get('EARLY_STOP', False):
        callbacks.append(EarlyStopping(monitor="ave_r2",
                                       min_delta=0.00,
                                       patience=50,
                                       verbose=False,
                                       mode="max"))

    trainer = pl.Trainer(default_root_dir=log_dir,
                         accelerator='gpu',
                         val_check_interval=cfg['MCC']['FREQ'],
                         max_epochs=cfg['VAE']['EPOCHS'],
                         callbacks=callbacks,
                         strategy='ddp_find_unused_parameters_true',
                         )

    # Train the model
    trainer.fit(model, train_loader, val_loader)

def build_arg_parser():
    """Return the command-line parser for this training script.

    ``--exp`` selects ``IFactor/configs/<exp>.yaml``. The other flags override
    the matching entries of that config's ``VAE`` section.
    """
    parser = argparse.ArgumentParser(
        description="Train a StationaryIFactorProcess model from an IFactor experiment config.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('-e', '--exp', type=str, default='stationary_2lag_IFactor',
                        help="experiment config name (loads IFactor/configs/<exp>.yaml)")
    parser.add_argument('-s', '--seed', type=int, default=0,
                        help="random seed; also stored as VAE.SEED")
    parser.add_argument('-b', '--beta', type=float, default=3.0e-3,
                        help="overrides VAE.BETA")
    parser.add_argument('--gamma', type=float, default=1.0e-2,
                        help="overrides VAE.GAMMA")
    parser.add_argument('-d', '--delta', type=float, default=0.1,
                        help="overrides VAE.DELTA")
    parser.add_argument('--dim', type=int, default=2,
                        help="value used for every entry of VAE.Z_DIM_LIST")
    parser.add_argument('--delta_epoch', type=int, default=5,
                        help="overrides VAE.DELTA_EPOCH")
    parser.add_argument('--prior', type=str, default='DNP',
                        help="overrides VAE.TRANS_PRIOR")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())
