import torch
import math

def get_scheduler(params):
    """Look up the learning-rate scheduler class named by ``params.scheduler``.

    Args:
    - params: Any object exposing a ``scheduler`` attribute.

    Returns:
    - The scheduler class itself (not an instance):
      ``'step'`` -> ``torch.optim.lr_scheduler.StepLR``,
      ``'plateau'`` -> ``torch.optim.lr_scheduler.ReduceLROnPlateau``.

    Raises:
    - ValueError: if the name is not one of the supported schedulers.
    """
    requested = params.scheduler
    if requested == 'step':
        return torch.optim.lr_scheduler.StepLR
    if requested == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau
    raise ValueError(f'Scheduler {requested} not found')

def get_optimizer(params):
    """Look up the optimizer class named by ``params.optimizer``.

    Args:
    - params: Any object exposing an ``optimizer`` attribute.

    Returns:
    - The optimizer class itself (not an instance):
      ``'adam'`` -> ``torch.optim.Adam``,
      ``'sgd'`` -> ``torch.optim.SGD``,
      ``'rmsprop'`` -> ``torch.optim.RMSprop``.

    Raises:
    - ValueError: if the name is not one of the supported optimizers.
    """
    requested = params.optimizer
    if requested == 'adam':
        return torch.optim.Adam
    if requested == 'sgd':
        return torch.optim.SGD
    if requested == 'rmsprop':
        return torch.optim.RMSprop
    raise ValueError(f'Optimizer {requested} not found')

def get_loss(params):
    """Build the loss module named by ``params.loss``.

    Both supported losses use ``reduction="none"``, so they return one loss
    value per element rather than a reduced scalar. Any averaging or summing
    is left to the caller.

    Args:
    - params: Any object exposing a ``loss`` attribute.

    Returns:
    - A new loss module instance:
      ``"cross_entropy"`` -> ``torch.nn.CrossEntropyLoss``,
      ``"mse"`` -> ``torch.nn.MSELoss``.

    Raises:
    - ValueError: if the name is not one of the supported losses.
    """
    requested = params.loss
    if requested == "cross_entropy":
        return torch.nn.CrossEntropyLoss(reduction="none")
    if requested == "mse":
        return torch.nn.MSELoss(reduction="none")
    raise ValueError(f'Loss {requested} not found')

def cosine_annealing(epoch, total_epochs, initial_factor, final_factor, rounds=1):
    """
    Cosine annealing function to adjust a factor based on epochs.

    Computes
    ``final_factor + 0.5 * (initial_factor - final_factor)
    * (1 + cos(pi * rounds * epoch / total_epochs))``.

    At ``epoch == 0`` the result is ``initial_factor``. Each unit of
    ``rounds`` is one half cosine period over ``total_epochs``:
    - With ``rounds=1`` (the default), the factor moves monotonically from
      ``initial_factor`` to ``final_factor`` at ``epoch == total_epochs``.
    - With an even ``rounds``, it ends back at ``initial_factor``.

    Args:
    - epoch: Current epoch number.
    - total_epochs: Total number of epochs; must be non-zero.
    - initial_factor: Initial value of the factor.
    - final_factor: Final value of the factor.
    - rounds: Number of half cosine periods spanned by ``total_epochs``.
      Defaults to 1, i.e. a single initial -> final sweep.

    Returns:
    - The adjusted factor for the current epoch.

    Raises:
    - ZeroDivisionError: if ``total_epochs`` is 0.
    """
    # Same operation order as the original expression so results are bit-identical.
    phase = math.pi * rounds * epoch / total_epochs
    return final_factor + 0.5 * (initial_factor - final_factor) * (1 + math.cos(phase))