from __future__ import annotations

import torch


def polyak_update(source: torch.nn.Module, target: torch.nn.Module, tau: float):
    """Blend ``source``'s parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``, so
    ``tau = 0`` leaves ``target`` unchanged and ``tau = 1`` copies ``source``.

    Only ``parameters()`` are blended; buffers registered on the modules are
    not touched by this function.

    Args:
        source: Module whose parameters are blended in.
        target: Module whose parameters are updated in place.
        tau: Interpolation weight given to ``source``.

    Raises:
        ValueError: If the modules have a different number of parameters, or
            a pair of corresponding parameters differ in shape. The check runs
            before any parameter is modified, so ``target`` is never left
            partially updated by a structural mismatch.
    """
    src_params = list(source.parameters())
    tgt_params = list(target.parameters())

    # zip() would silently stop at the shorter list and leave the remaining
    # target parameters stale, so validate the pairing up front.
    if len(src_params) != len(tgt_params):
        raise ValueError(
            f"source has {len(src_params)} parameters but target has "
            f"{len(tgt_params)}"
        )
    for i, (p, p_t) in enumerate(zip(src_params, tgt_params)):
        if p.shape != p_t.shape:
            raise ValueError(
                f"parameter {i}: source shape {tuple(p.shape)} != "
                f"target shape {tuple(p_t.shape)}"
            )

    with torch.no_grad():
        for p, p_t in zip(src_params, tgt_params):
            # ``alpha=tau`` applies the scale inside the add kernel instead of
            # allocating a temporary ``tau * p`` tensor. Writing through
            # ``.data`` keeps the in-place update untracked by autograd, as
            # in the original implementation.
            p_t.data.mul_(1 - tau).add_(p.data, alpha=tau)


def hard_update(source: torch.nn.Module, target: torch.nn.Module):
    """Overwrite ``target``'s state with the current state of ``source``.

    The full ``state_dict`` of ``source`` is loaded into ``target`` with
    strict key matching, so the two modules must expose the same keys.

    Args:
        source: Module whose state is read.
        target: Module whose state is overwritten.
    """
    snapshot = source.state_dict()
    # strict=True is load_state_dict's default; spelled out to make the
    # "keys must match exactly" contract visible at the call site.
    target.load_state_dict(snapshot, strict=True)

