from pathlib import Path
import os
import numpy as np
import torch


def set_bn_eval(module):
    """Switch every batch-norm-like submodule of ``module`` to eval mode.

    A submodule counts as batch norm when its class name contains
    ``batchnorm`` (case-insensitive). This name-based test also catches
    custom or third-party variants that follow the naming convention.
    ``module`` itself is part of the search because ``modules()`` yields it
    too. Nothing else is touched.
    """
    bn_layers = [
        layer
        for layer in module.modules()
        if 'batchnorm' in layer.__class__.__name__.lower()
    ]
    for layer in bn_layers:
        layer.train(False)


def set_train(model):
    """Put ``model`` in training mode but keep its batch-norm layers frozen.

    All submodules are first switched to training mode. Then every submodule
    that ``set_bn_eval`` recognizes as batch norm is switched back to eval
    mode. Batch normalization is therefore not disabled. Those layers still
    normalize, but for layers that track running statistics they use the
    stored statistics and do not update them during training.
    """
    model.train()
    # Re-freeze BN after the blanket train() call above; the order matters.
    set_bn_eval(model)


def set_eval(model):
    """Put ``model`` and all of its submodules into evaluation mode.

    Unlike ``set_train``, no layer types are special-cased.
    """
    model.eval()


def get_task_signature(args):
    """Build a hyphen-separated identifier for a run from its arguments.

    Format: ``<dataset>-stepsize<step_size>-lr<lr>-optimizer<optimizer>-<comment>``.
    It is used as a directory name when saving or caching models.
    """
    pieces = [
        f'{args.dataset}',
        f'stepsize{args.step_size}',
        f'lr{args.lr}',
        f'optimizer{args.optimizer}',
        f'{args.comment}',
    ]
    return '-'.join(pieces)


def save_model(model, args):
    """Save ``model``'s state dict into a per-task, per-step directory.

    The file goes to
    ``<args.logdir>/models/<task signature>/steps<num_steps:05d>/model.pth``.
    Missing directories are created first, and the saved path is printed.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        args: namespace providing ``logdir``, ``num_steps`` (must support the
            ``05d`` format, i.e. an int), and the fields read by
            ``get_task_signature``.
    """
    # The task/step part stays one formatted string rather than separate
    # os.path.join components: a signature starting with '/' would otherwise
    # make os.path.join discard logdir.
    model_dir = os.path.join(args.logdir, 'models/{}/steps{:05d}'.format(
        get_task_signature(args), args.num_steps))
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    # Build the file path once so the saved and printed paths always agree.
    model_path = os.path.join(model_dir, 'model.pth')
    torch.save(model.state_dict(), model_path)
    print('saved ' + model_path)


def cache_model(model, args):
    """Save ``model``'s state dict to the task's single cache slot.

    The file is written to the path returned by ``cached_model_path(args)``,
    which also creates the cache directory. Each call writes to the same
    path, so a previous cached checkpoint is overwritten. The saved path is
    printed.
    """
    # Reuse the shared helper so the cache location is defined in one place
    # and the save/load paths cannot drift apart.
    cached_model = cached_model_path(args)
    torch.save(model.state_dict(), cached_model)
    print('saved ' + cached_model)


def cached_model_path(args):
    """Return the path of the cached checkpoint for the task described by ``args``.

    The path is ``<args.logdir>/models/<task signature>/cache/model.pth``.
    Side effect: the ``cache`` directory and any missing parents are created,
    so the caller can write to the returned path directly. The file itself is
    not checked for existence.
    """
    directory = os.path.join(
        args.logdir, f'models/{get_task_signature(args)}/cache')
    Path(directory).mkdir(exist_ok=True, parents=True)
    checkpoint = os.path.join(directory, 'model.pth')
    return checkpoint
