"""
Train a CNF/ODE model with PyTorch Lightning on the pendulum or tumor dataset.

Code adapted from: https://github.com/microsoft/cf-ode/tree/main
"""
#!/usr/bin/env python
import argparse
import sys
sys.path.insert(0,"../")

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger, MLFlowLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor, Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
#from lightning.pytorch import seed_everything
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from predict_multistep import multistep_predict


import os

from data_loaders import data_utils, tumor_data_utils  #, cv_data_utils, covid_data_utils  moabb_data_utils, robert_data_utils,
import buildFlow
from buildFlow import cnfModule
import models
#from azureml.core.run import Run
import torch

#python train_model.py --logger_type=wandb --seed=44 --max_epochs=2

def get_logger(model_type, dataset_name, loss_func, run_note, args, entity="usrname"):
    """Create the Weights & Biases logger for one training run.

    The run is placed in the ``crlode`` project, named
    ``<dataset_name>_<model_type>_<run_note>``, grouped by ``loss_func`` and
    given ``args`` as its config. Model checkpoints are not uploaded
    (``log_model=False``).

    Args:
        model_type: Model label used in the run name.
        dataset_name: Dataset label used in the run name.
        loss_func: Value used as the W&B run group.
        run_note: Free-form suffix for the run name.
        args: Hyperparameters stored as the run config.
        entity: Accepted for interface compatibility but NOT passed to
            ``WandbLogger`` (forwarding it is intentionally disabled).

    Returns:
        The configured ``WandbLogger``.
    """
    run_name = f"{dataset_name}_{model_type}_{run_note}"
    return WandbLogger(
        name=run_name,
        project="crlode",
        log_model=False,
        group=loss_func,
        config=args,
    )

def get_logdir(logger):
    """Return the directory this run's artifacts (e.g. checkpoints) go to.

    Prefers the logger's own ``log_dir`` attribute. When that attribute is
    missing or falsy (for example ``None``), falls back to
    ``logger.experiment.dir``, the directory of the underlying experiment
    object (for a W&B logger, the run directory).

    Args:
        logger: A Lightning logger instance.

    Returns:
        ``logger.log_dir`` if set, otherwise ``logger.experiment.dir``.
    """
    # `or` also covers loggers whose `log_dir` exists but is None, so the
    # experiment directory is used as the fallback in both cases.
    return getattr(logger, "log_dir", None) or logger.experiment.dir

def main(model_cls, dataset_cls, args):
    """Train a model on the selected dataset, then evaluate it.

    Workflow:

    1. Build and prepare the datamodule.
    2. Create the W&B logger and the checkpoint, early-stopping and
       LR-monitor callbacks.
    3. Seed, construct the model, and fit it with a Lightning ``Trainer``.
    4. Evaluate:

       * for ``tumor_data_utils.TumorDataModule``, run ``multistep_predict``;
       * for any other datamodule, run ``trainer.test`` on the best checkpoint.

    Args:
        model_cls: Model class handed to ``trainer.fit``. It is constructed
            with ``obs_dim``, ``action_dim``, ``checkpoint_propensity`` and
            every parsed CLI argument as keyword arguments.
        dataset_cls: Datamodule class. It is constructed with every parsed CLI
            argument plus ``num_workers=1``, and must expose ``action_dim``,
            ``input_dim`` and ``seed``.
        args: ``argparse.Namespace`` with all parsed hyperparameters. This
            function reads ``gpu``, ``dataset_name``, ``loss_func``,
            ``run_note``, ``entity`` and ``max_epochs`` directly, and
            ``data_path`` for the tumor dataset.
    """
    # SLURM_CPUS_PER_TASK is only reported; the dataloader is deliberately
    # pinned to a single worker.
    cpus = int(os.environ.get('SLURM_CPUS_PER_TASK', 1))
    print("number of CPUS (using only 1 for dataloader)")
    print(cpus)
    dataset = dataset_cls(**vars(args), num_workers=1)
    dataset.prepare_data()

    action_dim = dataset.action_dim
    obs_dim = dataset.input_dim

    cuda_available = torch.cuda.is_available()
    print(cuda_available)
    # Passed as Trainer(devices=...); without CUDA a single CPU device is used.
    devices = args.gpu if cuda_available else 1
    print('Running with hyperparameters:')
    print(args)

    model_type = "ODE"
    logger = get_logger(model_type, args.dataset_name, args.loss_func,
                        args.run_note, args, entity=args.entity)

    print('logging')
    print(args)

    log_dir = get_logdir(logger)

    # NOTE(review): checkpoints are ranked by val_rmse while early stopping
    # watches val_loss -- confirm both metrics are logged by the model and
    # that using two different metrics is intended.
    checkpoint_cb = ModelCheckpoint(
        dirpath=log_dir,
        monitor='val_rmse',
        mode='min',
        verbose=True,
        save_last=True,
    )
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=0.00,
        patience=50,
        verbose=False,
        mode='min',
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')

    checkpoint_propensity_model_path = None

    # NOTE(review): seeding happens after the datamodule was built and
    # prepare_data() ran, so any randomness in those steps is not covered
    # by this seed -- verify that is intended.
    seed_everything(dataset.seed)
    model = model_cls(
        obs_dim=obs_dim,
        action_dim=action_dim,
        checkpoint_propensity=checkpoint_propensity_model_path,
        **vars(args),
    )
    trainer = pl.Trainer(
        logger=logger,
        log_every_n_steps=10,
        max_epochs=args.max_epochs,
        accelerator='gpu' if cuda_available else 'cpu',
        devices=devices,
        callbacks=[checkpoint_cb, early_stop_callback, lr_monitor],
        # Clip gradients element-wise to [-0.5, 0.5].
        gradient_clip_val=0.5,
        gradient_clip_algorithm="value",
        deterministic=True,
    )

    trainer.fit(model, datamodule=dataset)

    if dataset_cls is tumor_data_utils.TumorDataModule:
        print('Testing multistep prediction for tumor data')
        # NOTE(review): `model` holds the weights from the end of fit(), not
        # the best checkpoint that trainer.test uses in the other branch --
        # confirm this is intended.
        multistep_predict(model, args.data_path, 5, trainer, args)
    else:
        print('Testing the model')
        trainer.test(ckpt_path="best", dataloaders=dataset.test_dataloader())

if __name__ == '__main__':
    # Two-stage parsing: first read only the generic flags (notably
    # --dataset_name) with parse_known_args so the model/dataset classes can
    # be chosen; they then register their own arguments and the full command
    # line is parsed.
    # NOTE(review): add_help=False is kept for the final parse too, so the
    # script does not accept -h/--help.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--max_epochs', type=int, default=300)
    parser.add_argument('--gpu', default=1, type=int)
    parser.add_argument('--model', default="cnf", type=str)
    parser.add_argument('--dataset_name', default="pendulum", type=str, help="dataset to train on")
    parser.add_argument('--entity', default="usrname", type=str, help="name of the wandb logger entity")
    parser.add_argument("--run_note", type=str, default='test', help="Note for naming the run")

    partial_args, _ = parser.parse_known_args()

    # Only the CNF module is wired up; --model is parsed but not consulted here.
    model_cls = cnfModule

    if partial_args.dataset_name == "pendulum":
        dataset_cls = data_utils.PendulumDataModule
    elif partial_args.dataset_name == "tumor":
        dataset_cls = tumor_data_utils.TumorDataModule
    else:
        # `raise "..."` is itself a TypeError in Python 3 and hides the real
        # problem; raise a proper exception that names the offending value.
        raise ValueError(
            f"Invalid dataset name: {partial_args.dataset_name!r} "
            "(expected 'pendulum' or 'tumor')"
        )

    parser = model_cls.add_model_specific_args(parser)
    parser = dataset_cls.add_dataset_specific_args(parser)
    args = parser.parse_args()

    main(model_cls, dataset_cls, args)


