import torch

def estimate_dW2(X, Y, W1, B0, W2, ridge_lambda, k):
    """Estimate a rank-``k`` ridge-regression update for the output weights ``W2``.

    The hidden activations ``a = relu(X @ W1.T)`` are held fixed. After
    centering, the residual of the current output layer is regressed on the
    centered activations with a ridge penalty. The fitted coefficient matrix is
    then projected onto the top-``k`` right singular subspace of the fitted
    values and added to ``W2``.

    Args:
        X: Inputs, ``(n, l1)``.
        Y: Targets, ``(n, l3)``.
        W1: Hidden-layer weights, ``(l2, l1)``.
        B0: Output bias; assumed broadcastable against ``(n, l3)`` -- TODO confirm.
        W2: Output-layer weights, ``(l3, l2)``.
        ridge_lambda: Ridge penalty. A scalar (or 0-d tensor) is applied as
            ``ridge_lambda * I``. A 1-D tensor of length ``l2`` becomes a
            diagonal penalty. A 2-D ``(l2, l2)`` tensor is used as the penalty
            matrix as-is.
        k: Rank of the update applied to ``W2``.

    Returns:
        ``(updated_W2, Z0)``. ``updated_W2`` has the shape of ``W2``. ``Z0`` is
        the column mean of ``Y - B0 - a @ updated_W2.T``, flattened to 1-D.
    """
    with torch.no_grad():
        T = Y - B0
        T_centered = T - torch.mean(T, dim=0)
        a = torch.nn.functional.relu(W1 @ X.t()).t()  # (n, l2)
        a_centered = a - torch.mean(a, dim=0)
        # Residual left by the current output layer on the centered data.
        target = T_centered - a_centered @ W2.t()

        # Ridge regression: solve (A^T A + Lambda) Z2 = A^T target.
        # The penalty must sit on the diagonal. Adding a scalar to every
        # entry of the Gram matrix is only a rank-one shift and does not
        # regularize, so collinear activations would stay singular.
        gram = a_centered.t() @ a_centered
        reg = torch.as_tensor(ridge_lambda, dtype=gram.dtype, device=gram.device)
        if reg.dim() < 2:
            eye = torch.eye(gram.shape[0], dtype=gram.dtype, device=gram.device)
            reg = reg * eye
        Z2 = torch.linalg.solve(gram + reg, a_centered.t() @ target)  # (l2, l3)

        # Keep only the top-k directions of the fitted values in output space.
        Y_hat = a_centered @ Z2  # (n, l3)
        _, _, Vh = torch.linalg.svd(Y_hat.t() @ Y_hat)
        Vk = Vh[:k, :]
        P = Vk.t() @ Vk  # (l3, l3) orthogonal projector of rank <= k

        updated_W2 = W2 + (Z2 @ P).t()
        # Intercept correction given the updated weights (uncentered activations).
        Z0 = torch.mean(Y - B0 - a @ updated_W2.t(), dim=0).view(-1)
        return updated_W2, Z0

def estimate_dW1(X, Y, W1, k,  device):
    n, l1 = X.size()
    n, l3 = Y.size()
    with torch.no_grad():
        # Standardize
        x_bar = torch.mean(X, dim=0)
        X_dm = X - x_bar

        cov = (X_dm.t() @ X_dm) / n
        prec = torch.linalg.inv(cov)
        B = prec @ X_dm.t()

        # Stack-SVD
        stack_list = list()
        for i in range(l3):
            Yi = Y[:, i].view(-1, 1)
            EYSx = (B @ (Yi * B.t()) - torch.sum(Yi) * prec) / n
            stack_list.append(EYSx)
        EYSx_stacked = torch.concat(stack_list, dim=1)
        V, S, _ = torch.linalg.svd(EYSx_stacked)

        Var = (S ** 2).cpu()
        Var_per = Var / torch.sum(Var)
        cum_per = torch.cumsum(Var_per, dim=0)
        best_rank = torch.argmax((cum_per >= 0.75).float()).item() + 1


        if best_rank == l1:
            W1_copy = W1.detach()
            updated_W1 = W1_copy
            Z1_k = updated_W1 - W1_copy
        else:
            Vk = V.detach()[:, :best_rank]
            W1_copy = W1.detach()
            ImP = torch.eye(l1, device=device) - Vk @ Vk.t()
            Z1 = -W1_copy @ ImP
            A, B, CT = torch.linalg.svd(Z1)
            Z1_k = A[:, :k] @ torch.diag(B[:k]) @ CT[:k, :]
            updated_W1 = W1_copy + Z1_k
    return updated_W1
