import torch

from accelerate.utils import DummyOptim, DummyScheduler
from torch.optim import AdamW, Optimizer
from transformers import get_scheduler
from typing import Iterable

from CoLM.option import TrainArg


def build_optimizer(
    params: Iterable[torch.nn.parameter.Parameter],
    args: TrainArg,
    is_deepspeed: bool = False,
) -> Optimizer:
    """Create the optimizer for ``params``.

    Only ``args.learning_rate`` is read from ``args``; every other
    hyperparameter is left at the default of the selected optimizer class.

    Args:
        params: Parameters handed to the optimizer.
        args: Training options; must expose ``learning_rate``.
        is_deepspeed: When true, an ``accelerate`` ``DummyOptim`` is built
            instead of a ``torch.optim.AdamW``.

    Returns:
        A ``DummyOptim`` if ``is_deepspeed`` is true, otherwise an ``AdamW``.
    """
    # Both classes are constructed with the same keywords, so only the
    # class itself depends on the flag.
    optimizer_cls = DummyOptim if is_deepspeed else AdamW
    return optimizer_cls(params=params, lr=args.learning_rate)


def build_scheduler(
    args: TrainArg,
    optimizer: Optimizer,
    is_deepspeed: bool = False,
):
    """Create the learning-rate scheduler attached to ``optimizer``.

    Args:
        args: Training options. Reads ``lr_scheduler``, ``num_warmup_steps``
            and ``max_update``; the DeepSpeed path additionally reads
            ``distributed_world_size``.
        optimizer: Optimizer whose learning rate the scheduler controls.
        is_deepspeed: When true, an ``accelerate`` ``DummyScheduler`` is
            built instead of calling ``transformers.get_scheduler``.

    Returns:
        The object returned by ``get_scheduler`` for the schedule named by
        ``args.lr_scheduler``, or a ``DummyScheduler`` if ``is_deepspeed``
        is true.
    """
    if not is_deepspeed:
        schedule_name = args.lr_scheduler
        warmup_steps = args.num_warmup_steps
        training_steps = args.max_update
        return get_scheduler(
            schedule_name,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=training_steps,
        )

    # NOTE(review): the total step count is scaled by the world size while
    # the warmup step count is passed through unscaled -- confirm that this
    # asymmetry is intended.
    total_steps = args.max_update * args.distributed_world_size
    warmup_steps = args.num_warmup_steps
    return DummyScheduler(
        optimizer,
        total_num_steps=total_steps,
        warmup_num_steps=warmup_steps,
    )