
import copy
import torch.optim as optim
from timm.scheduler.cosine_lr import CosineLRScheduler
import torch.distributed as dist


def get_param_list(config, model):
    """Collect the model's parameters into optimizer parameter groups.

    Every parameter yielded by ``model.parameters()`` is placed in a single
    group that uses ``config.train.weight_decay``. To give different parts of
    the model different weight decay, split them into more groups here.

    Args:
        config: config object exposing ``config.train.weight_decay``.
        model: module whose ``parameters()`` are collected.

    Returns:
        A one-element list containing a param-group dict with the keys
        ``'params'`` (a list of the parameters) and ``'weight_decay'``.
    """
    all_params = list(model.parameters())
    decay = config.train.weight_decay
    single_group = dict(params=all_params, weight_decay=decay)
    return [single_group]

def build_optimizer(config, model, *, betas=(0.9, 0.98), eps=1e-8):
    """Build an AdamW optimizer for ``model``.

    If ``model`` has a ``module`` attribute (e.g. a ``DistributedDataParallel``
    or ``DataParallel`` wrapper), the wrapped module is used instead. This
    lets the parameter groups be built from the underlying model.

    Args:
        config: config object exposing ``config.train.base_lr`` (the learning
            rate) and ``config.train.weight_decay`` (read by
            ``get_param_list``).
        model: module to optimize, optionally wrapped.
        betas: AdamW ``betas`` coefficients; defaults to ``(0.9, 0.98)``.
        eps: AdamW ``eps`` term; defaults to ``1e-8``.

    Returns:
        A ``torch.optim.AdamW`` instance over the model's parameter groups.
    """
    model = model.module if hasattr(model, 'module') else model
    params = get_param_list(config, model)

    # Weight decay is carried per param group (set by get_param_list), which
    # overrides AdamW's own constructor default.
    optimizer = optim.AdamW(params,
                            lr=config.train.base_lr,
                            betas=betas, eps=eps)

    return optimizer


def build_scheduler(config, optimizer, n_iter_per_epoch):
    """Create a per-iteration cosine learning-rate schedule with warmup.

    The epoch counts in the config are converted to iteration counts. The
    timm ``CosineLRScheduler`` is then created with ``t_in_epochs=False``, so
    it is meant to be driven by update counts (``step_update``) rather than
    by epoch.

    The warmup ramps up from a learning rate of 0 over the first
    ``warmup_epochs * n_iter_per_epoch`` iterations. The cosine curve decays
    to ``base_lr / 100`` and runs for one cycle. Because ``warmup_prefix`` is
    left at timm's default, the cosine span covers the full
    ``epochs * n_iter_per_epoch`` iterations, including the warmup.

    Args:
        config: config exposing ``config.train.epochs``,
            ``config.train.warmup_epochs`` and ``config.train.base_lr``.
        optimizer: optimizer whose learning rate is scheduled.
        n_iter_per_epoch: number of iterations per epoch.

    Returns:
        The configured ``CosineLRScheduler``.
    """
    train_cfg = config.train
    total_iters = int(train_cfg.epochs * n_iter_per_epoch)
    warmup_iters = int(train_cfg.warmup_epochs * n_iter_per_epoch)
    # The cosine decay bottoms out at 1% of the base learning rate.
    floor_lr = train_cfg.base_lr / 100

    return CosineLRScheduler(
        optimizer,
        t_initial=total_iters,
        lr_min=floor_lr,
        warmup_lr_init=0,
        warmup_t=warmup_iters,
        cycle_limit=1,
        t_in_epochs=False,
    )

# def build_scheduler(config, optimizer, steps_per_epoch):
#     #------------ simple
#     sched_config = config.scheduler
#     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=sched_config.milestones, gamma=sched_config.lr_decay)
#
#
#     # warmup_epochs = config.scheduler.warmup.epochs
#     # main_epochs = config.scheduler.epochs - warmup_epochs
#     #
#     # warmup_scheduler = _create_warmup(config, warmup_epochs)
#     # main_scheduler = _create_main_scheduler(config, main_epochs)
#     #
#     # scheduler_func = CombinedScheduler([warmup_scheduler, main_scheduler])
#     # scheduler_func.multiply_steps(steps_per_epoch)
#     #
#     # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_func)
#     return scheduler