"running the baseline file: main.py"
import warnings
warnings.filterwarnings('ignore')

import argparse
import os, pwd, yaml, sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import lightning.pytorch as pl
from torch.utils.data import DataLoader, random_split

"utils file (SAME)"
from IFactor.tools.utils import load_yaml
# Stationary: 
from IFactor.datasets.sim_dataset import SimulationDatasetTSTwoSample 
from lightning.pytorch.callbacks import Callback

"baseline list"
from IFactor.baselines.TCL.model import TCL
from IFactor.baselines.PCL.model import PCL 
from IFactor.baselines.iVAE.model import iVAE
from IFactor.baselines.BetaVAE.model import BetaVAE
from IFactor.baselines.SlowVAE.model import SlowVAE
from IFactor.baselines.FactorVAE.model import FactorVAE
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback
from IFactor.datasets.sim_dataset import SimulationDatasetPCL


# Model names this script knows how to build; anything else is rejected early.
_SUPPORTED_MODELS = ("PCL", "TCL", "iVAE", "BetaVAE", "SlowVAE", "FactorVAE")


def _make_loaders(data, section):
    """Split ``data`` into train/validation subsets and build their loaders.

    Args:
        data: Dataset to split; must support ``len()``.
        section: Config mapping providing ``N_VAL_SAMPLES``, ``TRAIN_BS``,
            ``VAL_BS``, ``PIN`` and ``CPU``.

    Returns:
        A ``(train_loader, val_loader)`` tuple.

    Raises:
        ValueError: If ``N_VAL_SAMPLES`` is negative or larger than ``len(data)``.
    """
    num_validation_samples = section['N_VAL_SAMPLES']
    # Without this check a negative training length would be handed to
    # random_split, which is never a meaningful split.
    if not 0 <= num_validation_samples <= len(data):
        raise ValueError(
            "N_VAL_SAMPLES=%r is outside the valid range [0, %d] for this dataset"
            % (num_validation_samples, len(data))
        )

    train_data, val_data = random_split(
        data, [len(data) - num_validation_samples, num_validation_samples])

    train_loader = DataLoader(train_data,
                              batch_size=section['TRAIN_BS'],
                              pin_memory=section['PIN'],
                              num_workers=section['CPU'],
                              drop_last=False,
                              shuffle=True)

    val_loader = DataLoader(val_data,
                            batch_size=section['VAL_BS'],
                            pin_memory=section['PIN'],
                            num_workers=section['CPU'],
                            shuffle=False)
    return train_loader, val_loader


def _build_model(cfg):
    """Instantiate the baseline named by ``cfg['MODEL']`` from its config values.

    Args:
        cfg: Full experiment configuration mapping.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``cfg['MODEL']`` is not one of ``_SUPPORTED_MODELS``.
    """
    name = cfg['MODEL']

    if name == "PCL":
        return PCL(input_dim=cfg['PCL']['INPUT_DIM'],
                   z_dim=cfg['PCL']['LATENT_DIM'],
                   lags=cfg['PCL']['LAG'],
                   hidden_dims=cfg['PCL']['HIDDEN_DIM'],
                   encoder_layers=cfg['PCL']['ENCODER_LAYER'],
                   scoring_layers=cfg['PCL']['SCORE_LAYER'],
                   correlation=cfg['MCC']['CORR'],
                   lr=cfg['PCL']['LR'])

    if name == "TCL":
        return TCL(input_dim=cfg['VAE']['INPUT_DIM'],
                   z_dim=cfg['VAE']['LATENT_DIM'],
                   nclass=cfg['TCL']['NCLASS'],
                   hidden_dim=cfg['VAE']['ENC']['HIDDEN_DIM'],
                   lr=cfg['TCL']['LR'],
                   correlation=cfg['MCC']['CORR'])

    if name == "iVAE":
        return iVAE(input_dim=cfg['VAE']['INPUT_DIM'],
                    z_dim=cfg['VAE']['LATENT_DIM'],
                    hidden_dim=cfg['VAE']['ENC']['HIDDEN_DIM'],
                    lr=cfg['iVAE']['LR'],
                    correlation=cfg['MCC']['CORR'])

    if name == "BetaVAE":
        # NOTE(review): beta1/beta2 are read from the 'SlowVAE' section rather
        # than 'BetaVAE' -- kept as-is; confirm this is intended.
        return BetaVAE(input_dim=cfg['VAE']['INPUT_DIM'],
                       z_dim=cfg['VAE']['LATENT_DIM'],
                       hidden_dim=cfg['VAE']['ENC']['HIDDEN_DIM'],
                       beta=cfg['BetaVAE']['BETA'],
                       beta1=cfg['SlowVAE']['beta1_VAE'],
                       beta2=cfg['SlowVAE']['beta2_VAE'],
                       lr=cfg['BetaVAE']['LR'],
                       correlation=cfg['MCC']['CORR'])

    if name == "SlowVAE":
        return SlowVAE(input_dim=cfg['VAE']['INPUT_DIM'],
                       z_dim=cfg['VAE']['LATENT_DIM'],
                       hidden_dim=cfg['VAE']['ENC']['HIDDEN_DIM'],
                       beta=cfg['SlowVAE']['BETA'],
                       gamma=cfg['SlowVAE']['GAMMA'],
                       beta1=cfg['SlowVAE']['beta1_VAE'],
                       beta2=cfg['SlowVAE']['beta2_VAE'],
                       lr=cfg['SlowVAE']['LR'],
                       rate_prior=cfg['SlowVAE']['RATE_PRIOR'],
                       correlation=cfg['MCC']['CORR'])

    if name == "FactorVAE":
        return FactorVAE(input_dim=cfg['VAE']['INPUT_DIM'],
                         z_dim=cfg['VAE']['LATENT_DIM'],
                         hidden_dim=cfg['VAE']['ENC']['HIDDEN_DIM'],
                         gamma=cfg['FactorVAE']['GAMMA'],
                         lr_VAE=cfg['FactorVAE']['LR_VAE'],
                         beta1_VAE=cfg['FactorVAE']['beta1_VAE'],
                         beta2_VAE=cfg['FactorVAE']['beta2_VAE'],
                         lr_D=cfg['FactorVAE']['LR_D'],
                         beta1_D=cfg['FactorVAE']['beta1_D'],
                         beta2_D=cfg['FactorVAE']['beta2_D'],
                         correlation=cfg['MCC']['CORR'])

    raise ValueError("Unsupported MODEL %r; expected one of %s"
                     % (name, ", ".join(_SUPPORTED_MODELS)))


def main(args):
    """Train the baseline model selected by an experiment config file.

    Loads ``../IFactor/configs/<args.exp>.yaml`` relative to this script,
    seeds the run, builds the dataset, loaders and model named by
    ``cfg['MODEL']``, and fits it with a Lightning ``Trainer``.

    Args:
        args: Namespace with ``exp`` (config name without the ``.yaml``
            suffix) and ``seed`` (int passed to ``pl.seed_everything``).

    Raises:
        AssertionError: If ``args.exp`` is ``None``.
        ValueError: If ``cfg['MODEL']`` is not a supported baseline, or the
            configured validation-set size does not fit the dataset.
    """
    # Explicit raise (not `assert`) so the check survives `python -O`; the
    # exception type and message are unchanged.
    if args.exp is None:
        raise AssertionError("FATAL: "+__file__+": You must specify an exp config file (e.g., *.yaml)")

    current_user = pwd.getpwuid(os.getuid()).pw_name
    script_dir = os.path.dirname(__file__)
    rel_path = os.path.join('../IFactor/configs', '%s.yaml'%args.exp)
    abs_file_path = os.path.join(script_dir, rel_path)
    cfg = load_yaml(abs_file_path)
    print("######### Configuration #########")
    print(yaml.dump(cfg, default_flow_style=False))
    print("#################################")

    # Fail fast, before any data is loaded: previously an unknown name fell
    # through the if/elif chain and only surfaced as an unbound `model` at fit.
    if cfg['MODEL'] not in _SUPPORTED_MODELS:
        raise ValueError("Unsupported MODEL %r; expected one of %s"
                         % (cfg['MODEL'], ", ".join(_SUPPORTED_MODELS)))

    pl.seed_everything(args.seed)

    # Order matters for reproducibility under the seed: dataset, split/loaders,
    # then model construction -- same sequence as before.
    if cfg['MODEL'] == "PCL":
        section = cfg['PCL']
        data = SimulationDatasetPCL(directory=os.path.expanduser(cfg['ROOT']),
                                    transition=os.path.expanduser(cfg['DATASET']),
                                    lags=cfg['PCL']['LAG'])
    else:
        # Every non-PCL baseline shares the 'VAE' data/loader settings.
        section = cfg['VAE']
        data = SimulationDatasetTSTwoSample(directory=os.path.expanduser(cfg['ROOT']),
                                            transition=os.path.expanduser(cfg['DATASET']))

    max_epochs = section['EPOCHS']
    train_loader, val_loader = _make_loaders(data, section)
    model = _build_model(cfg)

    # Runs are separated per user, experiment and seed.
    log_dir = os.path.join(cfg["LOG"], current_user, args.exp, f"seed{args.seed}")

    checkpoint_callback = ModelCheckpoint(monitor='ave_r2',
                                          save_top_k=1,
                                          mode='max')

    # NOTE(review): this callback is constructed but not included in the
    # Trainer's `callbacks` list below, so early stopping is not active.
    # Confirm whether it should be registered or removed.
    early_stop_callback = EarlyStopping(monitor="ave_r2",
                                        min_delta=0.00,
                                        patience=50,
                                        verbose=False,
                                        mode="max")

    class YamlCopyCallback(Callback):
        """Write the experiment config next to the run's logs when fitting starts."""

        def __init__(self, cfg):
            super().__init__()
            self.cfg = cfg

        def on_fit_start(self, trainer, pl_module):
            log_dir = trainer.logger.log_dir
            filename = os.path.join(log_dir, "config.yaml")
            os.makedirs(log_dir, exist_ok=True)
            with open(filename, "w") as f:
                yaml.dump(self.cfg, f)

    yaml_copy_callback = YamlCopyCallback(cfg)
    trainer = pl.Trainer(default_root_dir=log_dir,
                         accelerator='gpu',
                         val_check_interval = cfg['MCC']['FREQ'],
                         max_epochs=max_epochs,
                         deterministic=True,
                         callbacks=[checkpoint_callback, yaml_copy_callback],
                         strategy='ddp_find_unused_parameters_true')

    # Train the model
    trainer.fit(model, train_loader, val_loader)

if __name__ == "__main__":
    # Command-line entry point: collect the experiment name and seed, then train.
    cli = argparse.ArgumentParser(description=__doc__)
    # Experiment config name (the '.yaml' suffix is appended inside main()).
    cli.add_argument('-e', '--exp', type=str)
    # Random seed for the run; 0 when not given.
    cli.add_argument('-s', '--seed', type=int, default=0)
    main(cli.parse_args())