"""
Optimizer and learning rate scheduler configurations.
"""

import torch
from torch.optim import Adam, AdamW, SGD
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    ReduceLROnPlateau,
    StepLR,
    MultiStepLR
)
from typing import Optional


def build_optimizer(
    model: torch.nn.Module,
    optimizer_type: str = 'adam',
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    **kwargs
) -> torch.optim.Optimizer:
    """
    Build an optimizer over all parameters of ``model``.

    Args:
        model: Model whose ``parameters()`` are handed to the optimizer
        optimizer_type: Type of optimizer ('adam', 'adamw', 'sgd'),
            matched case-insensitively
        lr: Learning rate
        weight_decay: Weight decay, passed to the optimizer as ``weight_decay``
        **kwargs: Additional optimizer-specific arguments. Every keyword that
            the selected optimizer's constructor accepts is forwarded
            (e.g. ``betas``, ``eps``, ``amsgrad`` for Adam/AdamW;
            ``momentum``, ``nesterov``, ``dampening`` for SGD). Keywords the
            selected optimizer does not accept are ignored, so one shared
            config dict can be used with any optimizer type.

    Returns:
        Optimizer instance

    Raises:
        ValueError: If ``optimizer_type`` is not one of the supported names
    """
    # Function-scope import so this block does not depend on file-level imports.
    import inspect

    params = model.parameters()

    # Per-optimizer defaults used when the caller does not override them.
    adam_defaults = {'betas': (0.9, 0.999), 'eps': 1e-8}
    registry = {
        'adam': (Adam, adam_defaults),
        'adamw': (AdamW, adam_defaults),
        'sgd': (SGD, {'momentum': 0.9, 'nesterov': False}),
    }

    key = optimizer_type.lower()
    if key not in registry:
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")
    optimizer_cls, defaults = registry[key]

    # Forward only what this constructor understands; names already supplied
    # explicitly below are excluded to avoid "multiple values" errors.
    accepted = inspect.signature(optimizer_cls.__init__).parameters
    reserved = {'self', 'params', 'lr', 'weight_decay'}
    options = {
        name: value
        for name, value in {**defaults, **kwargs}.items()
        if name in accepted and name not in reserved
    }

    return optimizer_cls(params, lr=lr, weight_decay=weight_decay, **options)


def build_scheduler(
    optimizer: torch.optim.Optimizer,
    scheduler_type: Optional[str] = 'cosine',
    **kwargs
) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
    """
    Build a learning rate scheduler.

    Args:
        optimizer: Optimizer to schedule
        scheduler_type: Type of scheduler ('cosine', 'plateau', 'step',
            'multistep', 'none'), matched case-insensitively. ``None`` is
            treated like 'none'.
        **kwargs: Scheduler-specific arguments. Every keyword that the
            selected scheduler's constructor accepts is forwarded; keywords
            it does not accept are ignored. For 'cosine', the period may be
            given as either ``t_max`` or ``T_max`` (``t_max`` wins if both
            are present).

    Returns:
        Scheduler instance, or None when no scheduler is requested

    Raises:
        ValueError: If ``scheduler_type`` is not one of the supported names
    """
    # Function-scope import so this block does not depend on file-level imports.
    import inspect

    if scheduler_type is None:
        return None
    key = scheduler_type.lower()
    if key == 'none':
        return None

    # Per-scheduler defaults used when the caller does not override them.
    registry = {
        'cosine': (CosineAnnealingLR, {'T_max': 100, 'eta_min': 0.0}),
        'plateau': (
            ReduceLROnPlateau,
            {'mode': 'min', 'factor': 0.1, 'patience': 10, 'verbose': True},
        ),
        'step': (StepLR, {'step_size': 30, 'gamma': 0.1}),
        'multistep': (MultiStepLR, {'milestones': [30, 60, 90], 'gamma': 0.1}),
    }
    if key not in registry:
        raise ValueError(f"Unknown scheduler type: {scheduler_type}")
    scheduler_cls, defaults = registry[key]

    options = {**defaults, **kwargs}
    # Keep honoring the historical lowercase spelling of CosineAnnealingLR's
    # ``T_max`` argument.
    if 't_max' in kwargs:
        options['T_max'] = kwargs['t_max']

    # Pass only arguments the installed constructor declares. This also drops
    # the plateau ``verbose`` default on PyTorch versions whose
    # ReduceLROnPlateau does not take that argument, instead of raising
    # TypeError.
    accepted = inspect.signature(scheduler_cls.__init__).parameters
    reserved = {'self', 'optimizer'}
    filtered = {
        name: value
        for name, value in options.items()
        if name in accepted and name not in reserved
    }

    return scheduler_cls(optimizer, **filtered)