# optim/optimizer_wrapper.py
"""
Small wrappers for optimizer, scheduler and gradient-clipping utilities.
Provides:
  - make_optimizer(model, lr, weight_decay)
  - make_linear_warmup_scheduler(optimizer, warmup_steps, total_steps)
  - clip_gradients(model, max_norm)
"""
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from typing import Iterable

def make_optimizer(
    model: torch.nn.Module, lr: float = 1e-3, weight_decay: float = 0.0
) -> optim.Adam:
    """Create an Adam optimizer over every parameter of ``model``.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.
            It must expose ``.parameters()``; a bare iterable of tensors
            is not accepted.
        lr: Learning rate forwarded to ``torch.optim.Adam``.
        weight_decay: L2 penalty forwarded to ``torch.optim.Adam``. Adam adds
            it to the gradient, so this is not AdamW-style decoupled decay.

    Returns:
        A ``torch.optim.Adam`` instance with a single parameter group.
    """
    return optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

def make_linear_warmup_scheduler(optimizer, warmup_steps: int, total_steps: int):
    """
    Build a ``LambdaLR`` that warms up linearly, then decays linearly to zero.

    The factor multiplied into each param group's base learning rate is:
      * ``step / warmup_steps`` while ``step < warmup_steps``, so it starts at 0.0;
      * ``(total_steps - step) / (total_steps - warmup_steps)`` afterwards,
        clamped below at 0.0 so it stays at zero once ``step >= total_steps``.
    Both denominators are floored at 1, which avoids division by zero when
    ``warmup_steps == 0`` or ``total_steps <= warmup_steps``.
    """
    def _lr_factor(current_step):
        in_warmup = current_step < warmup_steps
        if not in_warmup:
            # Decay phase: shrink linearly over the remaining steps, never below 0.
            steps_left = float(total_steps - current_step)
            decay_len = float(max(1, total_steps - warmup_steps))
            return max(0.0, steps_left / decay_len)
        # Warmup phase: ramp linearly from 0 toward 1.
        warmup_len = float(max(1, warmup_steps))
        return float(current_step) / warmup_len

    return LambdaLR(optimizer, _lr_factor)

def clip_gradients(model, max_norm: float = 1.0, norm_type: float = 2.0) -> torch.Tensor:
    """Clip the gradients of ``model`` in place so their total norm is at most ``max_norm``.

    The norm is computed over all parameter gradients together, as if they
    were concatenated into one vector. This is delegated to
    ``torch.nn.utils.clip_grad_norm_``.

    Args:
        model: Module whose ``parameters()`` gradients are clipped.
        max_norm: Maximum allowed total gradient norm.
        norm_type: Order of the norm. The default 2.0 preserves the previous
            behavior; use ``float('inf')`` for the max-abs norm.

    Returns:
        The total gradient norm measured *before* clipping. Returning it lets
        callers log it or detect divergence without recomputing it. Callers
        that ignore the return value are unaffected.
    """
    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type=norm_type)
