from __future__ import annotations

from typing import Iterable
import torch


def freeze(module: torch.nn.Module):
    """Turn off gradient tracking for every parameter of ``module``, in place.

    Args:
        module: The module whose parameters should stop receiving gradients.
    """
    track_grad = False
    for param in module.parameters():
        param.requires_grad_(track_grad)


def unfreeze(module: torch.nn.Module):
    """Turn gradient tracking back on for every parameter of ``module``, in place.

    Args:
        module: The module whose parameters should receive gradients again.
    """
    track_grad = True
    for param in module.parameters():
        param.requires_grad_(track_grad)


def count_parameters(module: torch.nn.Module, *, trainable_only: bool = False) -> int:
    """Return the total number of scalar elements across ``module``'s parameters.

    Args:
        module: The module whose parameters are counted.
        trainable_only: If True, count only parameters whose ``requires_grad``
            flag is set (e.g. to exclude layers disabled via ``freeze``).
            Defaults to False, which counts every parameter.

    Returns:
        The sum of ``numel()`` over the selected parameters; 0 for a module
        with no parameters.
    """
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


def zero_grad(modules: Iterable[torch.nn.Module] | torch.nn.Module):
    """Clear gradients on each object in ``modules`` by setting them to ``None``.

    Args:
        modules: Either a single ``torch.nn.Module``, or an iterable of objects
            exposing a ``zero_grad(set_to_none=...)`` method (modules, and
            duck-typed equivalents such as ``torch.optim.Optimizer``). Items
            without a ``zero_grad`` attribute are skipped silently.
    """
    # A bare Module is wrapped rather than iterated. Most modules are not
    # iterable at all (TypeError), and ModuleDict iterates over its string
    # keys, which have no zero_grad and would leave every gradient untouched.
    if isinstance(modules, torch.nn.Module):
        modules = (modules,)
    for m in modules:
        if hasattr(m, "zero_grad"):
            m.zero_grad(set_to_none=True)

