import torch

def checkpoint(model_path,
               s_block, s_optimizer,
               rmatrix, rmatrix_optimizers,
               width, height,
               save=True, load=False, map_location=None):
    """Save and/or restore the S block, the R-matrix blocks and their optimizers.

    Every state dict is stored in one flat dictionary written with
    ``torch.save``. The keys are:

    - ``'s_block_state_dict'`` and ``'s_optimizer_state_dict'``;
    - ``'r_net_{i}_{j}'`` and ``'r_optimizer_{i}_{j}'`` for every
      ``i in range(width)`` and ``j in range(height)``.

    When both ``save`` and ``load`` are true, the checkpoint is written first
    and then read back into the same objects.

    Args:
        model_path (str): File path the checkpoint is written to / read from.
        s_block: Module whose ``state_dict()`` is checkpointed.
        s_optimizer: Optimizer associated with ``s_block``.
        rmatrix: Object whose ``get_rblock([i, j])`` returns the module stored
            at grid position ``(i, j)``.
        rmatrix_optimizers: Nested sequence indexed ``[i][j]`` holding the
            optimizer for each R block.
        width (int): Number of R blocks along the first index ``i``.
        height (int): Number of R blocks along the second index ``j``.
        save (bool, optional): Whether to save the model. Defaults to True.
        load (bool, optional): Whether to load the model. Defaults to False.
        map_location (optional): Forwarded to ``torch.load`` when loading
            (e.g. ``'cpu'`` to load a GPU checkpoint on a CPU-only machine).
            Defaults to None, which keeps ``torch.load``'s default behavior.

    Raises:
        KeyError: On load, if the file lacks an expected key (for example,
            it was written with a different ``width``/``height``).
    """
    if save:
        # Build a single dict holding the S block/optimizer state together
        # with every R block/optimizer state.
        save_dict = {
            's_block_state_dict': s_block.state_dict(),
            's_optimizer_state_dict': s_optimizer.state_dict(),
        }
        for i in range(width):
            for j in range(height):
                save_dict[f'r_net_{i}_{j}'] = rmatrix.get_rblock([i, j]).state_dict()
                save_dict[f'r_optimizer_{i}_{j}'] = rmatrix_optimizers[i][j].state_dict()
        torch.save(save_dict, model_path)

    if load:
        # Named `state` (not `checkpoint`) to avoid shadowing this function.
        state = torch.load(model_path, map_location=map_location)
        s_block.load_state_dict(state['s_block_state_dict'])
        s_optimizer.load_state_dict(state['s_optimizer_state_dict'])

        for i in range(width):
            for j in range(height):
                rmatrix.get_rblock([i, j]).load_state_dict(state[f'r_net_{i}_{j}'])
                rmatrix_optimizers[i][j].load_state_dict(state[f'r_optimizer_{i}_{j}'])