import torch
from .training import TrainingManager
from .files import FileHandler
from data.data import available_datasets

def save_experiment(
    p : dict,
    trainer : TrainingManager,
    fh : FileHandler,
    save_dir : str,
    checkpoint_steps = None,
    files = 'all',
    new_eval_subdir=False,
    topological_losses=None,
    ):
    """Save the model, the evaluation metrics and/or the parameters of a run.

    Arguments:
    - p: run parameters dict. It is written to disk as the 'param' file, and
      p['data']['dataset'] decides whether a custom dataset must also be saved.
    - trainer: training manager; the model is written with trainer.save and the
      metrics with trainer.save_eval_metrics.
    - fh: file handler used to build the output paths (fh.get_paths_from_param).
    - save_dir: the directory where to save the files.
    - checkpoint_steps: forwarded to fh.get_paths_from_param as curr_step. If None,
      files are saved with no suffix; otherwise they get the suffix
      '_{}'.format(checkpoint_steps).
    - files: 'all' saves every file. Otherwise pass one name, or an iterable of
      names, taken from ['model', 'eval', 'param']; only those files are saved.
    - new_eval_subdir: if True, eval and param are saved in a subfolder of
      save_dir. This is helpful when you want to save multiple evaluations of
      the same model.
    - topological_losses: forwarded to trainer.save_eval_metrics.

    Side effect: when p['data']['dataset'] is a custom dataset, the dataset data
    (trainer.data.dataset._data) is always saved next to the parameters file,
    whatever `files` selects.

    Returns:
    - Tuple (model_path, param_path, eval_path). This is the same order for
      'all' and for partial saves. Entries for files that were not saved are None.

    Raises:
    - ValueError if `files` contains a name other than all/model/eval/param.
    """
    # Materialize so that a one-shot iterable (e.g. a generator) is not consumed
    # by the validation loop before the membership checks below.
    files = [files] if isinstance(files, str) else list(files)

    valid_names = ('all', 'model', 'eval', 'param')
    for f in files:
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if f not in valid_names:
            raise ValueError('specified file to save must be either all, model, eval, param')

    model_path, param_path, eval_path = fh.get_paths_from_param(p,
                                                            folder_path=save_dir,
                                                            curr_step=checkpoint_steps,
                                                            make_new_dir = True,
                                                            new_eval_subdir=new_eval_subdir,
                                                            )

    available = available_datasets()
    if p['data']['dataset'].lower() in available['custom']:
        # if we are using a custom dataset, we need to save the dataset state dict
        dataset_path = param_path.replace('parameters', 'train_dataset')
        print('saving dataset to {}'.format(dataset_path))
        torch.save(trainer.data.dataset._data, dataset_path)

    selected = {'model', 'eval', 'param'} if 'all' in files else set(files)

    # Same write order as before: model, eval, param.
    if 'model' in selected:
        trainer.save(model_path)
    if 'eval' in selected:
        trainer.save_eval_metrics(eval_path, topological_losses=topological_losses)
    if 'param' in selected:
        torch.save(p, param_path)

    # Match the (model, param, eval) order of fh.get_paths_from_param. Before this
    # fix, partial saves returned (model, eval, param), which swapped the last two.
    return (
        model_path if 'model' in selected else None,
        param_path if 'param' in selected else None,
        eval_path if 'eval' in selected else None,
    )


def load_experiment(
    p : dict,
    trainer : TrainingManager,
    fh : FileHandler,
    save_dir,
    checkpoint_steps=None,
    get_topological_losses=False,
    ):
    """Restore a saved run into `trainer` and return its evaluation metrics.

    Arguments:
    - p: run parameters dict used to locate the files; p['data']['dataset']
      decides whether a custom dataset must be restored as well.
    - trainer: training manager that receives the model (trainer.load) and,
      for custom datasets, trainer.data.dataset._data.
    - fh: file handler used to build the input paths (fh.get_paths_from_param,
      called with make_new_dir=False).
    - save_dir: the directory the run was saved to.
    - checkpoint_steps: forwarded to fh.get_paths_from_param as curr_step.
    - get_topological_losses: forwarded to trainer.load_eval_metrics.

    Returns:
    - Whatever trainer.load_eval_metrics returns for the eval file.
    """
    model_file, params_file, metrics_file = fh.get_paths_from_param(
        p,
        folder_path=save_dir,
        # if None, will load the latest checkpoint (checks filename suffix)
        curr_step=checkpoint_steps,
        make_new_dir=False,
    )

    registry = available_datasets()
    if p['data']['dataset'].lower() in registry['custom']:
        # Custom datasets: restore the data that save_experiment wrote next to
        # the parameters file. NOTE: torch.load unpickles, so only load trusted files.
        dataset_file = params_file.replace('parameters', 'train_dataset')
        print('loading dataset from {}'.format(dataset_file))
        trainer.data.dataset._data = torch.load(dataset_file)

    print('loading from model file {}'.format(model_file))
    trainer.load(model_file)

    print('loading from eval file {}'.format(metrics_file))
    metrics = trainer.load_eval_metrics(metrics_file, get_topological_losses=get_topological_losses)
    return metrics
    