### EMA weights helper class
### Inspired by OpenAI

from contextlib import contextmanager

import torch as t
import torch.nn as nn


@contextmanager
def temporary_weight_swap(model: nn.Module, new_weights: list[t.Tensor]):
    """Temporarily run ``model`` with ``new_weights`` in place of its parameters.

    On entry, the ``.data`` of every parameter of ``model`` is exchanged with
    the ``.data`` of the tensor at the same position in ``new_weights``. On
    exit the exchange is undone, even if the ``with`` body raised, so both the
    model and ``new_weights`` end up holding their original values. While the
    context is active, the tensors in ``new_weights`` hold the model's
    original values.

    Args:
        model: Module whose parameters are swapped in place.
        new_weights: One tensor per parameter, in ``model.parameters()``
            order, each with the same shape as its parameter.

    Raises:
        ValueError: If the number of tensors differs from the number of
            parameters.
        AssertionError: If a tensor's shape differs from its parameter's.
    """
    params = list(model.parameters())
    new_weights = list(new_weights)

    # Validate everything before touching any parameter, so that a mismatch
    # can never leave the model half-swapped.
    if len(params) != len(new_weights):
        raise ValueError(
            f"expected {len(params)} weight tensors, got {len(new_weights)}"
        )
    for index, (param, weight) in enumerate(zip(params, new_weights)):
        if param.shape != weight.shape:
            raise AssertionError(
                f"shape mismatch at parameter {index}: "
                f"{tuple(param.shape)} vs {tuple(weight.shape)}"
            )

    def _swap() -> None:
        # Exchanging `.data` is its own inverse: calling this twice restores
        # the original state of both sides.
        for param, weight in zip(params, new_weights):
            param.data, weight.data = weight.data, param.data

    _swap()
    try:
        yield
    finally:
        # Restore the original weights even when the body raised.
        _swap()


class EmaModel:
    def __init__(self, model: nn.Module, ema_multiplier, update_after_step=10):
        self.model = model
        self.ema_multiplier = ema_multiplier
        self.ema_weights = [t.zeros_like(x, requires_grad=False) for x in model.parameters()]
        self.ema_steps = 0
        self.update_after_step = update_after_step

    def step(self):
        if self.ema_steps < self.update_after_step:
            pass
        else:
            t._foreach_lerp_(
                self.ema_weights,
                list(self.model.parameters()),
                1 - self.ema_multiplier,
            )

        self.ema_steps += 1

    # context manager for setting the autoencoder weights to the EMA weights
    @contextmanager
    def use_ema_weights(self):
        assert self.ema_steps > 0

        # apply bias correction
        bias_correction = 1 - self.ema_multiplier**self.ema_steps
        ema_weights_bias_corrected = t._foreach_div(self.ema_weights, bias_correction)

        with t.no_grad():
            with temporary_weight_swap(self.model, ema_weights_bias_corrected):
                yield

    def update_model_weights(self):
        for p, ema_p in zip(self.model.parameters(), self.ema_weights):
            p.data.copy_(ema_p.data)
