from torch.utils.data import DataLoader
import os

from manage.files import FileHandler
from data.data import get_dataset, CurrentDatasetInfo, Modality, StateSpace
from manage.logger import Logger
from manage.generation import GenerationManager
from manage.training import TrainingManager
from evaluate.EvaluationManager import EvaluationManager
from manage.checkpoints import load_experiment, save_experiment
from manage.setup import _get_device, _optimize_gpu, _set_seed

from ddpm_init import init_method_ddpm, init_models_optmizers_ls, init_learning_schedule

from script_utils import *


# Directory that holds the experiment YAML configuration files; passed to
# eval_exp() when this file is run as a script.
CONFIG_PATH = './configs/'

def eval_exp(config_path):
    """Evaluate saved checkpoints of an experiment and save the results.

    Command-line arguments are obtained from ``parse_args()``. Parameters are
    read from ``<args.config>.yml`` inside ``config_path`` and checkpoints are
    loaded from / saved to ``checkpoints/<args.name>``.

    If ``args.reset_eval`` is set, the experiment is first loaded (at
    ``args.resume_steps``), its evaluation records are reset while keeping the
    losses, and the result is saved back before evaluating.

    Checkpoints are then evaluated either once with a step of ``None`` (when
    ``args.latest_checkpoint`` is set or ``p['run']['eval_freq']`` is ``None``
    or ``0``), or at every multiple of ``eval_freq`` up to ``p['run']['steps']``.
    After each evaluation the ``eval`` and ``param`` files are saved.

    Args:
        config_path: Directory containing the YAML configuration files.
    """
    args = parse_args()

    # Directory used to save and load checkpoints for this experiment.
    save_dir = os.path.join('checkpoints', args.name)

    # Read parameters from the config file, then update them from the CLI args.
    p = FileHandler.get_param_from_config(config_path, args.config + '.yml')
    update_parameters_before_loading(p, args)

    # Only trainer, logger and file_handler are used below. The remaining
    # values are kept as explicit (underscore-prefixed) targets so the
    # unpacking still checks the expected arity and no builtin is shadowed.
    (trainer, logger, file_handler, _models, _optimizers,
     _learning_schedules, _method, _evaluator, _gen_manager) = initialize_experiment(p)

    try:
        if args.reset_eval:
            print('Resetting eval dictionnary')
            load_experiment(
                p=p,
                trainer=trainer,
                fh=file_handler,
                save_dir=save_dir,
                checkpoint_steps=args.resume_steps,
            )
            trainer.eval.reset(keep_losses=True, keep_evals=False)
            # Persist the reset evaluation state right away.
            save_experiment(
                p=p,
                trainer=trainer,
                fh=file_handler,
                save_dir=save_dir,
                checkpoint_steps=p['run']['steps'],
                files='eval',
            )
            print('Eval dictionnary reset and saved.')

        # Log some additional information.
        additional_logging(p, logger, trainer, file_handler, args)

        # Print parameters to stdout.
        print_dict(p)

        eval_freq = p['run']['eval_freq']
        if eval_freq is None or eval_freq == 0:
            print('No evaluation frequency specified. Evaluating on latest checkpoint.')
            args.latest_checkpoint = True

        if args.latest_checkpoint:
            # A single pass with step=None.
            checkpoint_steps = [None]
        else:
            # Every multiple of eval_freq, up to and including the final step.
            checkpoint_steps = range(
                int(eval_freq),
                int(p['run']['steps'] + 1),
                int(eval_freq),
            )

        for step in checkpoint_steps:
            print('Evaluating model at step {}'.format(step))
            load_experiment(
                p=p,
                trainer=trainer,
                fh=file_handler,
                save_dir=save_dir,
                checkpoint_steps=step,
            )
            if args.force_ema_eval:
                trainer.evaluate(evaluate_emas=True)
            else:
                # Evaluate both the EMA and the non-EMA models.
                print('Evaluating EMA and non-EMA models')
                trainer.evaluate(evaluate_emas=True)
                trainer.evaluate(evaluate_emas=False)

            paths = save_experiment(
                p=p,
                trainer=trainer,
                fh=file_handler,
                save_dir=save_dir,
                files=['eval', 'param'],
                new_eval_subdir=True,
                checkpoint_steps=trainer.total_steps,
            )
            # Only the eval and param files are requested above, so the
            # message no longer claims that the model was saved as well.
            print('Saved (eval, param) in ', paths)
    finally:
        # Always terminate the logger, even if loading/evaluating/saving failed.
        if logger is not None:
            logger.stop()


# Script entry point: run the evaluation using the configuration files
# located under CONFIG_PATH.
if __name__ == '__main__':
    eval_exp(CONFIG_PATH)