from os import set_blocking
import torch.optim as optim

def sblock_optimizer(s_block, lr : float = 0.001, momentum: float  = 0.9, method: str = "SGD"):
    """Build a torch optimizer over all parameters of an S-block network.

    Args:
        s_block: Network with S blocks; must expose ``parameters()``.
        lr (float, optional): Learning rate. Defaults to 0.001.
        momentum (float, optional): Momentum factor, only used when
            ``method`` is "SGD". Defaults to 0.9.
        method (str, optional): Optimization method, "SGD" or "Adam".
            Defaults to "SGD".

    Returns:
        torch.optim.Optimizer or None: An ``optim.SGD`` or ``optim.Adam``
        instance, or ``None`` when ``method`` is not a supported name.
    """
    # Guard-clause dispatch: each supported name returns immediately.
    if method == "SGD":
        return optim.SGD(s_block.parameters(), lr=lr, momentum=momentum)
    if method == "Adam":
        # Adam has no plain momentum argument, so ``momentum`` is ignored here.
        return optim.Adam(s_block.parameters(), lr=lr)
    # Unsupported names are reported with None rather than an exception.
    return None

def rmatrix_optimizer(rmatrix, width:int, height:int, lr: float = 0.001, momentum : float = 0.9, method : str = "SGD"):
    """Build a matrix of optimizers, one per R-block of an R-matrix.

    Args:
        rmatrix: An R-Matrix consisting of the different R-blocks. Must expose
            ``get_rblock([i, j])`` returning an object with ``parameters()``.
        width (int): Width of the R-matrix; number of rows produced (index ``i``).
        height (int): Height of the R-matrix; entries per row (index ``j``).
        lr (float, optional): Learning rate. Defaults to 0.001.
        momentum (float, optional): Momentum, only used by "SGD". Defaults to 0.9.
        method (str, optional): Optimization method, "SGD" or "Adam".
            Defaults to "SGD".

    Returns:
        list[list[torch.optim.Optimizer]]: ``width`` lists of ``height``
        optimizers; entry ``[i][j]`` optimizes ``rmatrix.get_rblock([i, j])``.

    Raises:
        ValueError: If ``method`` is not "SGD" or "Adam".
    """
    # Validate once, before touching any block, so an unsupported method fails
    # fast with a clear message instead of an UnboundLocalError mid-loop.
    if method not in ("SGD", "Adam"):
        raise ValueError(
            f"Unsupported optimizer method {method!r}; expected 'SGD' or 'Adam'."
        )

    def _make_optimizer(params):
        # Build the optimizer selected by ``method`` for one block's parameters.
        if method == "SGD":
            return optim.SGD(params, lr=lr, momentum=momentum)
        return optim.Adam(params, lr=lr)

    # Outer index i over width, inner j over height (same order as before).
    return [
        [_make_optimizer(rmatrix.get_rblock([i, j]).parameters()) for j in range(height)]
        for i in range(width)
    ]