from collections.abc import Iterator

import torch
from torch import optim
from .landing import LandingStiefelSGD
from .landingadam import LandingStiefelAdam

# Every name accepted by the ``opt`` argument of ``setup_optimizer``; any
# other value makes that function raise ``ValueError``.
OPTIMIZERS_NAMES = [
    "Adam",
    "Adagrad",
    "NAdam",
    "SGD",
    "LandingSGD",
    "LandingAdam",
]


def setup_optimizer(
    parameters: Iterator[torch.nn.Parameter],
    opt: str,
    learning_rate: float,
    decay1: float = 0.9,
    decay2: float = 0.999,
    momentum: float = 0.9,
    weight_decay: float = 0,
    lambda_regul: float = 1.,
    nesterov: bool = False,
    dynamic_project: bool = False,
):
    """Build the optimizer named ``opt`` over ``parameters``.

    Args:
        parameters: Parameters to optimize. Any iterable is accepted; it is
            materialised into a list once before dispatching.
        opt: One of "Adagrad", "Adam", "NAdam", "SGD", "LandingAdam",
            "LandingSGD".
        learning_rate: Step size given to every optimizer.
        decay1: First element of ``betas`` for Adam, NAdam and LandingAdam;
            unused by the other optimizers.
        decay2: Second element of ``betas`` for Adam, NAdam and LandingAdam;
            unused by the other optimizers.
        momentum: Momentum for SGD and LandingSGD; unused by the others.
        weight_decay: Weight decay given to every optimizer.
        lambda_regul: Forwarded to LandingAdam and LandingSGD only.
        nesterov: Nesterov momentum flag for SGD and LandingSGD; unused by
            the others. For "SGD", torch rejects ``nesterov=True`` when
            ``momentum`` is 0.
        dynamic_project: Forwarded as ``dproject`` to LandingAdam and
            LandingSGD only.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``opt`` is not one of the recognised names (the
            underlying optimizer may also raise ``ValueError`` for invalid
            hyper-parameters or an empty parameter list).
    """
    parameters = list(parameters)  # LandingSGD needs a list, not an iterator

    if opt == "Adagrad":
        return optim.Adagrad(parameters, lr=learning_rate, weight_decay=weight_decay)
    if opt == "Adam":
        return optim.Adam(
            parameters,
            lr=learning_rate,
            betas=(decay1, decay2),
            weight_decay=weight_decay,
        )
    if opt == "NAdam":
        return optim.NAdam(
            parameters,
            lr=learning_rate,
            betas=(decay1, decay2),
            weight_decay=weight_decay,
        )
    if opt == "SGD":
        # Forward ``nesterov`` here too, matching the LandingSGD branch;
        # previously it was silently ignored for plain SGD.
        return optim.SGD(
            parameters,
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
    # NOTE(review): safe_step=0.5 and check_type=False are fixed for both
    # Landing optimizers; their meaning is defined by those classes — confirm.
    if opt == "LandingAdam":
        return LandingStiefelAdam(
            parameters, learning_rate, betas=(decay1, decay2),
            weight_decay=weight_decay, check_type=False,
            lambda_regul=lambda_regul, safe_step=0.5,
            dproject=dynamic_project,
        )
    if opt == 'LandingSGD':
        return LandingStiefelSGD(
            parameters, learning_rate, safe_step=0.5,
            momentum=momentum, weight_decay=weight_decay, nesterov=nesterov, check_type=False,
            lambda_regul=lambda_regul, dproject=dynamic_project,
        )

    raise ValueError(f"Unknown optimiser called {opt}")
