from typing import List

import torch
from torch.optim.optimizer import Optimizer
from torch.optim.sgd import SGD
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
from torch.optim.adadelta import Adadelta

from .ranger import Ranger
from options import OptimizationConfig

def load_optimizer(parameters: List[torch.Tensor], cfg: OptimizationConfig) -> Optimizer:
    """Construct the optimizer named by ``cfg.optimizer_type``.

    Before constructing anything, the optimizer type, learning rate, momentum
    and weight decay are echoed to stdout between two dashed rules.

    Args:
        parameters: Tensors to optimize, passed unchanged as the first
            argument of the selected optimizer's constructor.
        cfg: Optimization settings. ``optimizer_type`` must be one of
            ``"sgd"``, ``"adam"``, ``"adamw"``, ``"adadelta"`` or
            ``"ranger"``. ``lr`` and ``weight_decay`` are used by every
            optimizer. ``momentum`` is passed only to SGD, and ``betas`` only
            to Adam, AdamW and Ranger. ``ranger_use_gc`` is passed only to
            Ranger. Note that ``momentum`` is always printed, even when the
            chosen optimizer ignores it.

    Returns:
        The freshly constructed optimizer instance.

    Raises:
        ValueError: If ``cfg.optimizer_type`` matches none of the names above.
    """
    rule = '-' * 50
    print(rule)
    # Attributes are fetched one by one inside the loop, so each value is
    # read immediately before it is printed.
    for caption, field_name in (
        ('OPTIMIZER TYPE:', 'optimizer_type'),
        ('LR:', 'lr'),
        ('MOMENTUM:', 'momentum'),
        ('WEIGHT DECAY:', 'weight_decay'),
    ):
        print(caption, getattr(cfg, field_name))
    print(rule)

    # Each factory is a zero-argument callable, so optimizer-specific
    # settings (e.g. ``betas``, ``ranger_use_gc``) are read only for the
    # optimizer that is actually selected.
    factories = (
        ("sgd", lambda: SGD(parameters, lr=cfg.lr, momentum=cfg.momentum,
                            weight_decay=cfg.weight_decay)),
        ("adam", lambda: Adam(parameters, lr=cfg.lr, betas=cfg.betas,
                              weight_decay=cfg.weight_decay)),
        ("adamw", lambda: AdamW(parameters, lr=cfg.lr, betas=cfg.betas,
                                weight_decay=cfg.weight_decay)),
        ("adadelta", lambda: Adadelta(parameters, lr=cfg.lr,
                                      weight_decay=cfg.weight_decay)),
        # NOTE(review): gc_conv_only is pinned to False; presumably this lets
        # gradient centralization (when enabled) cover non-conv layers too.
        # Confirm against the Ranger implementation.
        ("ranger", lambda: Ranger(parameters, lr=cfg.lr, betas=cfg.betas,
                                  weight_decay=cfg.weight_decay,
                                  use_gc=cfg.ranger_use_gc,
                                  gc_conv_only=False)),
    )

    requested = cfg.optimizer_type
    for known_name, build in factories:
        if requested == known_name:
            return build()
    raise ValueError(f"Unknown Optimizer Type: {requested}")

