from typing import Callable, Union

from omegaconf import DictConfig


def make_learning_rate_schedule(init_lr: float, config: DictConfig) -> Callable:
    """Builds a linearly decaying learning rate schedule.

    The returned callable maps a step count to
    ``init_lr * (1 - (count // (ppo_epochs * num_minibatches)) / num_updates)``, where
    ``ppo_epochs``, ``num_minibatches`` and ``num_updates`` are read from ``config.system``.

    Args:
    ----
        init_lr: initial learning rate.
        config: system configuration. Must provide ``system.ppo_epochs``,
            ``system.num_minibatches`` and ``system.num_updates``.

    Returns:
    -------
        A function taking a step count and returning the learning rate for that count.

    Note:
    ----
        The linear schedule follows the suggestions from a blog on PPO implementation details
        which can be viewed at http://tinyurl.com/mr3chs4p
        More complex schedules can be supported by adding the relevant arguments to the system
        config and parsing them here.

        The config values are looked up every time the schedule is called, not when it is
        created, so changes made to ``config.system`` afterwards are reflected in the schedule.

    """

    def linear_schedule(count: int) -> float:
        # Deliberately read from `config` on every call rather than caching the values.
        steps_per_update = config.system.ppo_epochs * config.system.num_minibatches
        # Floor division: the rate only changes once a whole group of steps has elapsed.
        completed_updates = count // steps_per_update
        remaining_fraction: float = 1.0 - completed_updates / config.system.num_updates
        return init_lr * remaining_fraction

    return linear_schedule


def make_learning_rate(init_lr: float, config: DictConfig) -> Union[float, Callable]:
    """Returns either a constant learning rate or a learning rate schedule.

    Args:
    ----
        init_lr: initial learning rate.
        config: system configuration. ``config.system.decay_learning_rates`` selects between
            the two behaviours.

    Returns:
    -------
        ``init_lr`` unchanged when ``config.system.decay_learning_rates`` is falsy, otherwise
        the schedule produced by ``make_learning_rate_schedule(init_lr, config)``.

    """
    # No decay requested: the learning rate stays fixed at its initial value.
    if not config.system.decay_learning_rates:
        return init_lr

    return make_learning_rate_schedule(init_lr, config)


def adjust_config_for_gradient_accumulation(config: DictConfig) -> DictConfig:
    """Adjusts the number of updates, evaluations and parallel environments to account for
    gradient accumulation.

    For example, if we have 64 parallel environments, 1220 updates and 122 evaluations, and we
    want to accumulate gradients over 4 steps, we would need to adjust the number of parallel
    environments to 64 / 4 = 16, the number of updates to 1220 * 4 = 4880 and the number of
    evaluations to 122 * 4 = 488.

    The configuration is modified in place and the same object is returned. When
    ``config.arch.grad_accumulation_steps`` is 1 the configuration is returned untouched.

    Args:
    ----
        config: system configuration. Reads ``arch.grad_accumulation_steps`` and updates
            ``arch.num_envs``, ``system.num_updates`` and ``arch.num_evaluation``.

    Returns:
    -------
        The adjusted system configuration.

    Raises:
    ------
        ValueError: If the number of gradient accumulation steps is less than 1.
        AssertionError: If the number of parallel environments is not divisible by the number of
        gradient accumulation steps.

    """
    accumulation_steps = config.arch.grad_accumulation_steps

    if accumulation_steps == 1:
        return config

    # Reject zero/negative step counts explicitly: zero would otherwise surface as an opaque
    # ZeroDivisionError below, and a negative divisor would yield negative counts.
    if accumulation_steps < 1:
        raise ValueError(
            "The number of gradient accumulation steps must be at least 1. "
            f"Got {accumulation_steps}."
        )

    if config.arch.num_envs % accumulation_steps != 0:
        raise AssertionError(
            "The number of parallel environments must be divisible by the number of gradient "
            f"accumulation steps. Got {config.arch.num_envs} parallel environments and "
            f"{config.arch.grad_accumulation_steps} gradient accumulation steps."
        )

    # Fewer environments per step, proportionally more updates and evaluations.
    config.arch.num_envs //= accumulation_steps
    config.system.num_updates *= accumulation_steps
    config.arch.num_evaluation *= accumulation_steps
    return config
