import torch

from torch.optim import AdamW, Optimizer
from torch.distributed.optim import ZeroRedundancyOptimizer
from transformers import get_scheduler
from typing import Iterable

from LLMProxy.option import TrainArg


def build_optimizer(
    params: Iterable[torch.nn.parameter.Parameter],
    args: TrainArg
) -> Optimizer:
    """Build the training optimizer for ``params``.

    Both code paths use AdamW, so toggling ``args.use_zero`` only changes how
    optimizer state is stored, not the update rule. Previously the ZeRO path
    used plain ``Adam``, which has a different default ``weight_decay`` (0 vs
    AdamW's 0.01) and couples weight decay into the gradient.

    Args:
        params: Parameters to optimize. The iterable is consumed once, so a
            generator such as ``model.parameters()`` is fine.
        args: Training options. Reads ``use_zero`` and ``learning_rate``.

    Returns:
        A ``ZeroRedundancyOptimizer`` wrapping AdamW when ``args.use_zero`` is
        truthy, otherwise a plain ``AdamW``. The ZeRO variant expects a
        ``torch.distributed`` process group to be initialized already.
    """
    if args.use_zero:
        # Keep the same optimizer class as the non-ZeRO path so results do
        # not depend on whether optimizer state is sharded.
        return ZeroRedundancyOptimizer(
            params,
            optimizer_class=AdamW,
            lr=args.learning_rate
        )
    return AdamW(params=params, lr=args.learning_rate)


def build_scheduler(args: TrainArg, optimizer: Optimizer):
    """Create the learning-rate scheduler that drives ``optimizer``.

    The scheduler type is looked up by name through ``transformers.get_scheduler``.

    Args:
        args: Training options. ``lr_scheduler`` names the schedule,
            ``num_warmup_steps`` is the warmup length, and ``max_update`` is
            used as the total number of training steps.
        optimizer: The optimizer whose learning rate will be scheduled.

    Returns:
        Whatever scheduler object ``get_scheduler`` constructs.
    """
    scheduler_name = args.lr_scheduler
    step_kwargs = {
        "num_warmup_steps": args.num_warmup_steps,
        # The run length is expressed in optimizer updates.
        "num_training_steps": args.max_update,
    }
    return get_scheduler(scheduler_name, optimizer=optimizer, **step_kwargs)