from torch.utils.data import DataLoader
import os

from manage.files import FileHandler
from data.data import get_dataset, CurrentDatasetInfo, Modality, StateSpace
from manage.logger import Logger
from manage.generation import GenerationManager
from manage.training import TrainingManager
from evaluate.EvaluationManager import EvaluationManager
from manage.checkpoints import load_experiment, save_experiment
from manage.setup import _get_device, _optimize_gpu, _set_seed

from ddpm_init import init_method_ddpm, init_models_optmizers_ls, init_learning_schedule, init_optimizer

from script_utils import *


# Relative path of the directory holding the YAML experiment configurations;
# passed to run_exp, which reads "<args.config>.yml" from it.
CONFIG_PATH = "./configs/"


    
def _print_parameter_counts(models):
    """Print the total parameter count of every model in ``models``.

    ``models`` maps a model name to an object exposing ``parameters()`` whose
    items expose ``numel()`` (presumably torch modules -- confirm).
    """
    print('Number of parameters in the models:')
    for name, model in models.items():
        # ``param`` (not ``p``) so it is not confused with the caller's
        # experiment parameter dict.
        n_params = sum(param.numel() for param in model.parameters())
        print(f'{name}: {n_params}')


def _resume_if_possible(p, args, trainer, file_handler, save_dir):
    """Restore a previous run from ``save_dir`` if a model checkpoint exists.

    A checkpoint is considered present when some immediate subdirectory of
    ``save_dir`` contains a file matching ``model*.pt``. Otherwise a message is
    printed and the trainer is left untouched (training starts from scratch).
    """
    import pathlib

    if not os.path.exists(save_dir):
        print('Checkpoint directory does not exist. Starting training from scratch.')
        return
    # any() stops at the first match instead of materialising every path.
    if not any(pathlib.Path(save_dir).glob('*/model*.pt')):
        print('No model to load. Starting training from scratch.')
        return
    load_experiment(
        p=p,
        trainer=trainer,
        fh=file_handler,
        save_dir=save_dir,
        checkpoint_steps=args.resume_steps,
    )


def _load_pretrained(p, args, trainer, models, optimizers, learning_schedules):
    """Load pretrained weights and restart the training state from zero.

    Resets the step/epoch counters and rebuilds every optimizer and learning
    schedule. ``optimizers`` and ``learning_schedules`` are updated in place
    (the same dict objects are reassigned onto the trainer), so the caller's
    references see the new instances.
    """
    print('Loading pretrained weights from {}'.format(args.from_pretrained))
    trainer.load(args.from_pretrained)
    # Reinitialize training steps.
    trainer.total_steps = 0
    trainer.epochs = 0
    # Assumes pre-training was done with a different optimizer.
    print('Reintializing optimizer (assuming pre-training was done with a different optimizer)')
    for name in optimizers:
        optimizers[name] = init_optimizer(p, models[name])
    trainer.optimizers = optimizers
    # Schedules are bound to the new optimizers, so rebuild them too.
    for name in learning_schedules:
        learning_schedules[name] = init_learning_schedule(p, optimizers[name])
    trainer.learning_schedules = learning_schedules


def run_exp(config_path):
    """Run one training experiment end to end.

    Parses the command-line arguments, reads ``<args.config>.yml`` from
    ``config_path`` and builds the experiment. If requested, it resumes from
    ``checkpoints/<args.name>`` and/or loads pretrained weights. It then
    trains, saves a final checkpoint and stops the logger. The logger is
    stopped even when training raises or is interrupted.

    Args:
        config_path: directory containing the YAML configuration files.
    """
    args = parse_args()

    # Directory used both to save and to resume checkpoints.
    checkpoint_dir = 'checkpoints'
    save_dir = os.path.join(checkpoint_dir, args.name)

    # Open the config file and read the experiment parameters.
    p = FileHandler.get_param_from_config(config_path, args.config + '.yml')
    update_parameters_before_loading(p, args)

    # The method, evaluation and generation managers are not used here.
    (trainer, logger, file_handler, models, optimizers, learning_schedules,
     _method, _eval_manager, _gen_manager) = initialize_experiment(p)

    try:
        _print_parameter_counts(models)

        # Load if necessary. Must be done here in case we have different
        # hashes afterward.
        if args.resume:
            _resume_if_possible(p, args, trainer, file_handler, save_dir)

        if args.from_pretrained is not None:
            _load_pretrained(p, args, trainer, models, optimizers, learning_schedules)

        # Update parameters after loading, e.g. a new optimizer learning rate.
        update_experiment_after_loading(
            p,
            optimizers,
            learning_schedules,
            init_learning_schedule,
            args,
        )

        # Log some additional information.
        additional_logging(p, logger, trainer, file_handler, args)

        # Print parameters to stdout.
        print_dict(p)

        def checkpoint_callback(checkpoint_steps):
            print('saved files to',
                  save_experiment(p=p,
                                  trainer=trainer,
                                  fh=file_handler,
                                  save_dir=save_dir,
                                  checkpoint_steps=checkpoint_steps))

        # Run the training loop with parameters from the configuration file.
        # Arguments given here override the configuration for this run only.
        trainer.train(
            total_steps=p['run']['steps'],
            checkpoint_callback=checkpoint_callback,
            no_ema_eval=args.no_ema_eval,  # if True, skip evaluation with EMA models
            progress=p['run']['progress'],  # if True, show a progress bar
            max_batch_per_epoch=args.n_max_batch,
            # forwarded as-is; presumably a loss value below which training
            # stops early -- confirm against TrainingManager.train
            stop_lower_loss_threshold=args.stop_lower_loss_threshold,
        )

        # In any case, save the final model.
        save_loc = save_experiment(p=p,
                                   trainer=trainer,
                                   fh=file_handler,
                                   save_dir=save_dir,
                                   checkpoint_steps=trainer.total_steps)
        print('Saved (model, eval, param) in ', save_loc)
    finally:
        # Terminate the logger even if training failed or was interrupted,
        # so the logging run is not left open.
        if logger is not None:
            logger.stop()



if __name__ == '__main__':
    # Script entry point: run a single experiment using the default config directory.
    run_exp(config_path=CONFIG_PATH)