# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Core code for Stochastic Weight Averaging."""

from __future__ import annotations

import logging
import warnings
from typing import Any, Optional

import torch
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.swa_utils import SWALR, AveragedModel

from composer.core import Algorithm, Event, State, Time, TimeUnit
from composer.loggers import Logger

log = logging.getLogger(__name__)

__all__ = ['SWA']


def _assert_valid_duration(time: Time):
    """Validate that a duration-unit ``time`` is a fraction within ``[0, 1]``.

    Times expressed in any unit other than ``TimeUnit.DURATION`` are not checked.

    Raises:
        ValueError: If ``time`` is in duration units and lies outside ``[0, 1]``.
    """
    if time.unit != TimeUnit.DURATION:
        # Only fractional-duration times have a bounded valid range.
        return
    out_of_range = time < 0 or time > 1
    if out_of_range:
        raise ValueError(f'time in duration units must be [0, 1], got {time}')


class SWA(Algorithm):
    """Applies Stochastic Weight Averaging (`Izmailov et al, 2018 <https://arxiv.org/abs/1803.05407>`_).

    Stochastic Weight Averaging (SWA) averages model weights sampled at
    different times near the end of training. This leads to better
    generalization than just using the final trained weights.

    Because this algorithm needs to maintain both the current value of the
    weights and the average of all of the sampled weights, it doubles the
    model's memory consumption. Note that this does not mean that the total
    memory required doubles, however, since stored activations and the
    optimizer state are not doubled.

    .. note::

       The AveragedModel is currently stored on the CPU device, which may
       cause slow training if the model weights are large.

    Uses PyTorch's `torch.optim.swa_utils
    <https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging>`_
    under the hood.

    See the :doc:`Method Card </method_cards/swa>` for more details.

    Example:
        .. testcode::

            from composer.algorithms import SWA
            from composer.trainer import Trainer

            swa_algorithm = SWA(
                swa_start="6ep",
                swa_end="8ep"
            )
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="10ep",
                algorithms=[swa_algorithm],
                optimizers=[optimizer]
            )

    Args:
        swa_start (str, optional): The time string denoting the amount of training
            completed before stochastic weight averaging begins. Currently only units of
            duration ('dur') and epoch ('ep') are supported. Default: ``'0.7dur'``.
        swa_end (str, optional): The time string denoting the amount of training
            completed before the baseline (non-averaged) model is replaced with the
            stochastic weight averaged model. It's important to have at least one epoch
            of training after the baseline model is replaced by the SWA model so that the
            SWA model can have its buffers (most importantly its batch norm statistics)
            updated. If ``swa_end`` occurs during the final epoch of training (e.g.
            ``swa_end = 0.9dur`` and ``max_duration = "5ep"``, or ``swa_end = 1.0dur``),
            the SWA model will not have its buffers updated, which can negatively impact
            accuracy, so ensure ``swa_end`` < :math:`\\frac{N_{epochs}-1}{N_{epochs}}`.
            Currently only units of duration ('dur') and epoch ('ep') are supported.
            Default: ``'0.97dur'``.
        update_interval (str, optional): Time string denoting how often the averaged
            model is updated. For example, ``'1ep'`` means the averaged model will be
            updated once per epoch and ``'5ba'`` means the averaged model will be updated
            every 5 batches. Note that for single-epoch training runs (e.g. many NLP
            training runs), ``update_interval`` must be specified in units of ``'ba'``,
            otherwise SWA won't happen. Also note that very small update intervals (e.g.
            ``"1ba"``) can substantially slow down training. Default: ``'1ep'``.
        schedule_swa_lr (bool, optional): Flag to determine whether to apply an
            SWA-specific LR schedule during the period in which SWA is active. Default:
            ``False``.
        anneal_strategy (str, optional): SWA learning rate annealing schedule strategy.
            ``"linear"`` for linear annealing, ``"cos"`` for cosine annealing. Default:
            ``"linear"``.
        anneal_steps (int, optional): Number of SWA model updates over which to
            anneal SWA learning rate. Note that updates are determined by the
            ``update_interval`` argument. For example, if ``anneal_steps = 10`` and
            ``update_interval = '1ep'``, then the SWA LR will be annealed once per epoch
            for 10 epochs; if ``anneal_steps = 20`` and ``update_interval = '8ba'``, then
            the SWA LR will be annealed once every 8 batches over the course of 160
            batches (20 steps * 8 batches/step). Default: ``10``.
        swa_lr (float, optional): The final learning rate to anneal towards with the SWA
            LR scheduler. Only used when ``schedule_swa_lr`` is ``True``. Set to ``None``
            for no annealing, in which case the learning rate last reported by the
            current scheduler (``1.0`` if there is no scheduler) is used as the target.
            Default: ``None``.
    """

    def __init__(
        self,
        swa_start: str = '0.7dur',
        swa_end: str = '0.97dur',
        update_interval: str = '1ep',
        schedule_swa_lr: bool = False,
        anneal_strategy: str = 'linear',
        anneal_steps: int = 10,
        swa_lr: Optional[float] = None,
    ):

        warnings.warn(
            'SWA has known issues when resuming from a checkpoint on multiple GPUs, which will cause an error when resuming without `load_weights_only=True`.',
        )
        self.schedule_swa_lr = schedule_swa_lr
        self.anneal_strategy = anneal_strategy
        self.anneal_steps = anneal_steps
        self.swa_lr = swa_lr
        # ``_initialize_swa`` runs more than once (on INIT and again when SWA
        # starts) and stores the resolved target LR in ``self.swa_lr``; keep the
        # caller's request separately so it is honored on every initialization.
        self._requested_swa_lr = swa_lr
        self.swa_model: Optional[torch.nn.Module] = None
        self.swa_scheduler: Optional[SWALR] = None
        self.swa_completed = False
        self.swa_started = False

        # Check timestrings are parsable and convert into time objects
        self.swa_start = Time.from_timestring(swa_start)
        self.swa_end = Time.from_timestring(swa_end)
        self.update_interval = Time.from_timestring(update_interval)

        self._validate_time()

        if anneal_steps <= 0:
            raise ValueError('anneal_steps must be greater than 0')

        # Normalize the annealing strategy to the names SWALR accepts.
        if self.anneal_strategy.lower() in ['linear', 'lin']:
            self.anneal_strategy = 'linear'
        elif self.anneal_strategy.lower() in ['cos', 'cosine']:
            self.anneal_strategy = 'cos'
        else:
            raise ValueError("anneal_strategy must be one of {'linear', 'cos'}.")

        # Keeps track of # steps so that we can know when to update averaged model
        self.step_counter = 0

        # ``_validate_time`` guarantees update_interval is BATCH or EPOCH, so one
        # of these branches always sets ``match_event``.
        if self.update_interval.unit == TimeUnit.BATCH:
            self.match_event = Event.BATCH_END
        elif self.update_interval.unit == TimeUnit.EPOCH:
            self.match_event = Event.EPOCH_END

    def _validate_time(self):
        """Check the units and ordering of ``swa_start``, ``swa_end`` and ``update_interval``.

        Raises:
            ValueError: On mismatched/unsupported units, ``swa_start >= swa_end``, or a
                duration outside ``[0, 1]``.
        """
        # validate time units
        if self.swa_start.unit != self.swa_end.unit:
            raise ValueError(f'swa_start and swa_end must have same units, got {self.swa_start} and {self.swa_end}')
        if self.swa_start.unit not in [TimeUnit.DURATION, TimeUnit.EPOCH]:
            raise ValueError(f'swa_start must be DURATION or EPOCH, got {self.swa_start.unit}')
        if self.update_interval.unit not in [TimeUnit.BATCH, TimeUnit.EPOCH]:
            raise ValueError(f'update_interval must be BATCH or EPOCH, got {self.update_interval.unit}')

        # validate time
        if self.swa_start >= self.swa_end:
            raise ValueError('swa_end must be > swa_start.')
        if self.swa_end.unit == TimeUnit.DURATION and self.swa_end == 1:
            log.warning(
                "'swa_end' = '1dur'. Batch norm statistics of averaged model "
                'will not be updated. This will negatively impact accuracy. '
                'See the documentation for the `swa_end` parameter for details.',
            )

        _assert_valid_duration(self.swa_start)
        _assert_valid_duration(self.swa_end)

    def _get_time(self, state: State):
        """helper function to retrieve either the epoch or the duration depending on the units"""
        unit = self.swa_start.unit

        if unit == TimeUnit.EPOCH:
            return state.timestamp.epoch
        elif unit == TimeUnit.DURATION:
            time_elapsed = state.get_elapsed_duration()
            assert time_elapsed is not None, 'Time should have been set on BATCH_END or EPOCH_END.'
            return time_elapsed
        else:
            raise ValueError('units must be in epoch or duration.')

    def _get_last_lr(self, schedulers: list[LRScheduler]):
        """Retrieve the last LR from the current schedulers.

        Returns ``1.0`` when there are no schedulers.

        Raises:
            RuntimeError: If there is more than one scheduler, or the scheduler reports
                more than one learning rate.
        """
        # NOTE(review): with no scheduler this falls back to 1.0 rather than the
        # optimizer's configured LR — confirm this is intended.
        if len(schedulers) == 0:
            return 1.0
        if len(schedulers) != 1:
            raise RuntimeError(f'SWA supports only one scheduler, got {len(schedulers)}')
        scheduler = schedulers[0]
        last_lr = scheduler.get_last_lr()
        if len(last_lr) != 1:
            raise RuntimeError(f'SWA supports only one LR; instead found {len(last_lr)}')
        return last_lr[0]

    def match(self, event: Event, state: State) -> bool:
        """Run on INIT, and on ``match_event`` once ``swa_start`` is reached until SWA completes."""
        if event == Event.INIT:
            return True

        # only match on BATCH_END or EPOCH_END, depending on the setting
        if event != self.match_event or self.swa_completed:
            return False

        return self._get_time(state) >= self.swa_start

    def _initialize_swa(self, state: State) -> None:
        """(Re)create the averaged model and, if enabled, the SWA LR scheduler.

        Raises:
            RuntimeError: If LR scheduling is enabled and there is not exactly one
                optimizer, or the LR must be read from an unsupported scheduler setup.
        """
        if self.schedule_swa_lr:
            # Honor an explicitly requested target LR; otherwise anneal towards the
            # scheduler's current LR.
            if self._requested_swa_lr is not None:
                self.swa_lr = self._requested_swa_lr
            else:
                self.swa_lr = self._get_last_lr(state.schedulers)

            if len(state.optimizers) != 1:
                raise RuntimeError('SWA supports only one optimizer')

            self.swa_scheduler = SWALR(
                state.optimizers[0],
                swa_lr=self.swa_lr,
                anneal_epochs=self.anneal_steps,
                anneal_strategy=self.anneal_strategy,  # type: ignore
            )

        self.swa_model = AveragedModel(state.model, device=torch.device('cpu'))

    def apply(self, event: Event, state: State, logger: Logger) -> None:
        """Initialize on INIT; afterwards update the average and swap it in at ``swa_end``."""
        if event == Event.INIT:
            # on trainer init, we create the schedulers and models
            # so that the checkpoints can be loaded
            self._initialize_swa(state)
            return

        if not self.swa_started:
            # First matched event at/after swa_start: rebuild the averaged model
            # (and SWA scheduler) from the current training state.
            self._initialize_swa(state)
            self.swa_started = True

        assert self.swa_model is not None

        if self.step_counter % self.update_interval.value == 0:
            self.swa_model.update_parameters(state.model)  # type: ignore

            if self.schedule_swa_lr:
                assert self.swa_scheduler is not None
                self.swa_scheduler.step()

        self.step_counter += 1

        # Determine whether it's time to end SWA
        if self._get_time(state) >= self.swa_end:
            self.swa_completed = True

            if state.get_elapsed_duration() == 1:
                log.warning((
                    'The baseline model was replaced with the SWA model after the end of '
                    'training. This means that SWA model will not have its batch norm '
                    'statistics updated. This will negatively impact accuracy. See the '
                    'documentation for the `swa_end` parameter for details.'
                ))

            state.model.load_state_dict(self.swa_model.module.state_dict())  # type: ignore
            log.info('Set model to the averaged model')

    def state_dict(self) -> dict[str, Any]:
        """Serialize SWA progress, the averaged model and the SWA scheduler."""
        state_dict = super().state_dict()

        # we pop the anneal_func from the SWALR state
        # since it is set in the SWALR __init__
        swa_scheduler_state = None
        if self.swa_scheduler is not None:
            swa_scheduler_state = self.swa_scheduler.state_dict()
            swa_scheduler_state.pop('anneal_func')

        state_dict = {
            'swa_model': self.swa_model.state_dict() if self.swa_model is not None else None,
            'swa_completed': self.swa_completed,
            'swa_started': self.swa_started,
            'swa_scheduler': swa_scheduler_state,
            'step_counter': self.step_counter,
            **state_dict,
        }
        return state_dict

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Restore state produced by :meth:`state_dict`.

        The averaged model and scheduler are only restored into objects that already
        exist (they are created on INIT, before checkpoint loading).
        """
        self.swa_completed = state['swa_completed']
        self.step_counter = state['step_counter']
        self.swa_started = state['swa_started']

        if self.swa_scheduler is not None and state['swa_scheduler'] is not None:
            self.swa_scheduler.load_state_dict(state['swa_scheduler'])
        if self.swa_model is not None and state['swa_model'] is not None:
            self.swa_model.load_state_dict(state['swa_model'])
