def forward(transformed_list, blind_list, blind_mat, unblind_mat, loss_fn, is_blind, image, target):
    """Run the layers one at a time, cutting the autograd graph between them.

    Each layer is fed a tensor flagged ``requires_grad`` (recorded in
    ``grad_list``) and its raw output is recorded in ``loss_list``.  From the
    second layer on, the input is a detached copy of the previous entry of
    ``node_list``, so every layer builds its own independent graph that
    ``backward`` can replay in reverse order.

    ``blind`` and ``unblind`` are helpers defined outside this function and
    are only called when ``is_blind`` is truthy.  Presumably they apply and
    undo a masking transform driven by ``blind_mat`` / ``unblind_mat``; that
    is not visible here and should be confirmed against their definitions.

    Args:
        transformed_list: Iterable of callables, applied in order.
        blind_list: Per-layer flags indexed like ``transformed_list``.  When a
            flag equals ``1`` and ``is_blind`` is truthy, the layer input is
            passed through ``unblind`` and the layer output through ``blind``.
        blind_mat: Second argument forwarded to ``blind``.
        unblind_mat: Second argument forwarded to ``unblind``.
        loss_fn: Callable invoked as ``loss_fn(res, target)``.
        is_blind: Enables the ``blind`` / ``unblind`` calls.
        image: Input tensor.  When ``is_blind`` is falsy this very object is
            flagged ``requires_grad = True`` and becomes ``grad_list[0]``.
        target: Second argument forwarded to ``loss_fn``.

    Returns:
        Tuple ``(loss, res, node_list, loss_list, grad_list)`` where ``res``
        is the detached final output (flagged ``requires_grad``) on which
        ``loss`` was computed, ``node_list`` holds what each layer handed to
        the next one, ``loss_list`` holds the raw layer outputs and
        ``grad_list`` holds the layer inputs.

    Raises:
        ValueError: If ``transformed_list`` yields no layers.
    """
    prev_val = image
    if is_blind:
        prev_val = blind(image, blind_mat).detach()

    node_list = []
    loss_list = []
    grad_list = []

    res = None
    for idx, layer in enumerate(transformed_list):
        # Start a fresh graph for every layer after the first one.
        if idx != 0:
            prev_val = node_list[idx - 1].detach()

        # Evaluated once per layer; operand order kept so blind_list is
        # always indexed, exactly as before.
        toggles_blinding = blind_list[idx] == 1 and is_blind

        if toggles_blinding:
            prev_val = unblind(prev_val, unblind_mat).detach()

        # The layer input must collect a gradient so backward() can hand it
        # to the preceding layer.
        prev_val.requires_grad = True
        grad_list.append(prev_val)

        res = layer(prev_val)
        # Keep the graph-attached output; backward() calls .backward() on it.
        loss_list.append(res)

        if toggles_blinding:
            res = blind(res, blind_mat).detach()
        node_list.append(res)

    if not node_list:
        # Previously this fell through to an unbound ``res``.
        raise ValueError("transformed_list must contain at least one layer")

    # Loss calculation on a detached leaf so that res.grad seeds backward().
    if is_blind:
        res = unblind(res, unblind_mat).detach()
    else:
        res = res.detach()
    res.requires_grad = True
    # Use the loss callable that was passed in (was the undefined name
    # ``loss_func``).
    loss = loss_fn(res, target)

    return loss, res, node_list, loss_list, grad_list

def backward(loss, res, node_list, loss_list, grad_list):
    """Replay the per-layer graphs in reverse, chaining gradients by hand.

    ``loss.backward()`` is run first, which fills ``res.grad``.  The stored
    layer outputs in ``loss_list`` are then visited from last to first; each
    one is back-propagated with the gradient collected on the following
    stage's input, and that consumed gradient is deleted afterwards.

    Args:
        loss: Value whose ``backward()`` populates ``res.grad``.
        res: Tensor the loss was computed on; its ``grad`` seeds the chain
            and is deleted after the first step.
        node_list: Only its length is used, to locate the last index.
        loss_list: Graph-attached layer outputs to back-propagate through.
        grad_list: Layer inputs whose ``grad`` is read after each step.  All
            entries but the first have their ``grad`` deleted again.

    Returns:
        None.  The gradient of the first layer input is left in
        ``grad_list[0].grad``.
    """
    loss.backward()
    upstream = res.grad

    last_pos = len(node_list) - 1
    for step in range(len(loss_list)):
        pos = last_pos - step

        loss_list[pos].backward(gradient=upstream)

        # Drop the gradient that was just consumed: the loss input on the
        # first step, the following layer's input on later steps.
        consumed = res if step == 0 else grad_list[pos + 1]
        del consumed.grad

        upstream = grad_list[pos].grad

def forward_val(transformed_list, blind_list, blind_mat, unblind_mat, is_blind, image):
    """Feed ``image`` through the layers and return the final output.

    Mirrors the layer-by-layer scheme of the training pass without any loss
    or gradient bookkeeping: every layer after the first receives a detached
    copy of what the previous layer handed on.

    ``blind`` and ``unblind`` are helpers defined outside this function and
    are only called when ``is_blind`` is truthy.

    Args:
        transformed_list: Iterable of callables, applied in order.
        blind_list: Per-layer flags indexed like ``transformed_list``.  When a
            flag equals ``1`` and ``is_blind`` is truthy, the layer input is
            passed through ``unblind`` and the layer output through ``blind``.
        blind_mat: Second argument forwarded to ``blind``.
        unblind_mat: Second argument forwarded to ``unblind``.
        is_blind: Enables the ``blind`` / ``unblind`` calls.
        image: Input tensor.

    Returns:
        The last layer's output; when ``is_blind`` is truthy it is passed
        through ``unblind`` once more before being returned.
    """
    current = blind(image, blind_mat).detach() if is_blind else image

    for position, layer in enumerate(transformed_list):
        if position:
            # Continue from the previous stage's result, cut off from its graph.
            current = result.detach()

        toggles_blinding = blind_list[position] == 1 and is_blind

        if toggles_blinding:
            current = unblind(current, unblind_mat).detach()

        result = layer(current)

        if toggles_blinding:
            result = blind(result, blind_mat).detach()

    if is_blind:
        return unblind(result, unblind_mat)
    return result