"""
Utility functions.

Author:
Date: October 28, 2023
"""
import os
from typing import Optional

import hydra
import numpy as np
from omegaconf import OmegaConf
import torch
import torch.nn.functional as F


def load_model(path: str, checkpoint: Optional[int] = None, map_location='cpu'):
    """Load a trained model from a run directory.

    The architecture is instantiated with hydra from the
    ``model.architecture`` entry of ``<path>/config.yaml``. The weights are
    then restored from the first applicable file:

    1. ``<path>/checkpoints/ep_<checkpoint>.pt`` when ``checkpoint`` is given.
    2. Otherwise, among the entries of ``path`` whose name contains ``best``,
       the one with the largest epoch number. Names are parsed by stripping
       the first ``len('best_ep_')`` characters and the last 3 characters,
       i.e. they are expected to look like ``best_ep_<epoch>.pt``.
    3. Otherwise ``<path>/final.pt``.

    Args:
        path: Directory holding ``config.yaml`` and the saved weights.
        checkpoint: Epoch number of a specific checkpoint to load, or
            ``None`` to fall back to the best/final weights.
        map_location: Forwarded to ``torch.load`` for every branch.

    Returns:
        The instantiated model after ``load_state_dict`` has been applied.

    Raises:
        ValueError: If a file containing ``best`` does not carry an integer
            epoch in the expected position of its name.
    """
    config = OmegaConf.load(os.path.join(path, 'config.yaml'))
    model = hydra.utils.instantiate(config['model']['architecture'])
    if checkpoint is not None:
        weights_path = os.path.join(path, 'checkpoints',
                                    f'ep_{checkpoint}.pt')
    else:
        best_files = [fn for fn in os.listdir(path) if 'best' in fn]
        if best_files:
            # Several "best" snapshots may coexist; keep the latest epoch.
            latest_best = max(
                best_files,
                key=lambda fn: int(fn[len('best_ep_'):-3]),
            )
            weights_path = os.path.join(path, latest_best)
        else:
            weights_path = os.path.join(path, 'final.pt')
    # Single load call so the caller's map_location is honored on every
    # branch (the explicit-checkpoint branch used to force 'cpu').
    state_dict = torch.load(weights_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model


def get_activation(name: str):
    """Look up an activation function by name.

    Args:
        name: Case-insensitive activation name. Supported values are
            ``'relu'`` and ``'tanh'``.

    Returns:
        ``F.relu`` or ``F.tanh`` from ``torch.nn.functional``.

    Raises:
        ValueError: If ``name`` is not one of the supported activations.
    """
    activations = {
        'relu': F.relu,
        'tanh': F.tanh,
    }
    try:
        return activations[name.lower()]
    except KeyError:
        # `from None` keeps the traceback focused on the ValueError callers
        # already expect, rather than the internal dictionary lookup.
        raise ValueError(f'Unknown activation {name}') from None


class IdentityLayer(torch.nn.Module):
    """Module whose forward pass hands back its input untouched."""

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        """Return ``x_in`` itself; no copy or computation is performed."""
        return x_in


class NoSchedule:
    """Object exposing a ``step`` method that does nothing.

    NOTE(review): presumably a stand-in for a learning-rate scheduler when
    no scheduling is wanted -- confirm against callers.
    """

    def step(self):
        """Do nothing and return ``None``."""
        return None
