import torch

@torch.no_grad()
def first_order_EIF(
    e: torch.Tensor,
    A: torch.Tensor,
    u1: torch.Tensor,
    u2: torch.Tensor,
    P_yax: torch.Tensor,
    e_tol: float = 1e-8,
) -> torch.Tensor:
    """Evaluate a per-sample first-order correction term for two test functions.

    With ``m[j, k] = sum_i P_yax[i, j, k] * u[i]`` (computed separately for
    ``u1`` -> ``m1`` and ``u2`` -> ``m2``) and ``row_k = |1 - A_k|``, the
    result for sample k is::

        A_k / e_k * (u1_k - m1[row_k, k]) + (m1[0, k] - mean(m1[0, :]))
        + (1 - A_k) / (1 - e_k) * (u2_k - m2[row_k, k]) + (m2[1, k] - mean(m2[1, :]))

    Row 0 of ``P_yax``'s second axis is thus the one paired with ``A == 1``
    (and with ``u1``); row 1 is paired with ``A == 0`` (and with ``u2``).
    NOTE(review): presumably ``P_yax[:, j, k]`` holds outcome weights so that
    ``m`` is a regression of ``u(Y)`` on (treatment, x_k) — confirm with caller.

    Parameters
    ----------
    e : (n,) torch.Tensor
        Propensity values; clamped to ``[e_tol, 1 - e_tol]`` before use.
    A : (n,) torch.Tensor
        Treatment indicator with values in {0, 1}.
    u1 : (n,) torch.Tensor
        Per-sample values used with the ``A == 1`` arm.
    u2 : (n,) torch.Tensor
        Per-sample values used with the ``A == 0`` arm.
    P_yax : (n, 2, n) torch.Tensor
        Weights contracted with ``u1``/``u2`` over the first axis.
    e_tol : float
        Clipping margin for ``e`` that keeps both inverse weights finite.

    Returns
    -------
    out : (n,) torch.Tensor
    """
    # Flatten to 1-D; A serves both as a 0/1 weight and as an integer index.
    e = e.reshape(-1)
    A = A.reshape(-1).long()
    u1 = u1.reshape(-1)
    u2 = u2.reshape(-1)

    # Clip propensities away from 0 and 1 (out-of-place: caller's e untouched).
    e = torch.clamp(e, e_tol, 1.0 - e_tol)

    # Inverse propensity weights; each is non-zero only on its own arm.
    ipw1 = A / e
    ipw0 = (1.0 - A) / (1.0 - e)

    # m[j, k] = sum_i P_yax[i, j, k] * u[i]  ->  shape (2, n)
    m1 = torch.einsum('ijk,i->jk', P_yax, u1)
    m2 = torch.einsum('ijk,i->jk', P_yax, u2)

    # Each sample's own row: 0 when A == 1, 1 when A == 0.
    rows = torch.abs(1 - A)
    cols = torch.arange(P_yax.shape[-1], device=m1.device)
    yz_u1 = m1[rows, cols]
    yz_u2 = m2[rows, cols]

    p1_u1 = m1[0, :]  # A == 1 row, contracted with u1
    p0_u2 = m2[1, :]  # A == 0 row, contracted with u2 (must match the u2 residual)

    # Mean centering across samples.
    p1_u1_c = p1_u1 - p1_u1.mean()
    p0_u2_c = p0_u2 - p0_u2.mean()

    return (ipw1 * (u1 - yz_u1) + p1_u1_c) + (ipw0 * (u2 - yz_u2) + p0_u2_c)


@torch.no_grad()
def second_order_EIF_operator(
    omega: torch.Tensor,  # (2, n)
    K: torch.Tensor,      # (n, n)
    A: torch.Tensor,      # (n,) in {0,1}
    P_yax: torch.Tensor,  # (n, 2, n)
    P_ax: torch.Tensor,   # (n, 2)
) -> torch.Tensor:
    """Build an (n, n) operator matrix from per-arm weights and a matrix K.

    With ``row_k = 1 - A_k`` and ``PK[a, k, l] = sum_i P_yax[i, a, k] * K[i, l]``::

        h[a, k, l] = omega[a, k] * PK[a, k, l]
        s[k, l]    = sum_a P_ax[k, a] * h[a, k, l]
        out[k, l]  = omega[row_k, k] * K[k, l] - h[row_k, k, l]
                     + s[k, l] - mean_k' s[k', l]

    Parameters
    ----------
    omega : (2, n) torch.Tensor
        Per-row, per-sample weights.
    K : (n, n) torch.Tensor
        Matrix contracted with ``P_yax``; its device is used for all work.
    A : (n,) torch.Tensor
        Indicator with values in {0, 1}; selects row ``1 - A_k`` per sample.
    P_yax : (n, 2, n) torch.Tensor
        Weights contracted with ``K`` over the first axis.
    P_ax : (n, 2) torch.Tensor
        Per-sample weights over the two rows (``P_ax[k, a]``).
        NOTE(review): presumably P(row a | x_k) — confirm. The transposed
        ``(2, n)`` layout is also accepted when ``n != 2``.

    Returns
    -------
    out : (n, n) torch.Tensor on ``K.device``.
    """
    device = K.device
    omega = omega.to(device)
    A = A.to(device).reshape(-1).long()
    P_yax = P_yax.to(device)
    P_ax = P_ax.to(device)

    n = A.shape[0]
    cols = torch.arange(n, device=device)
    rows = 1 - A  # row index paired with each sample (0 when A == 1)

    # Documented layout is (n, 2); keep accepting (2, n) where unambiguous.
    if n != 2 and P_ax.shape == (2, n):
        P_ax = P_ax.T

    # g[k, l] = omega[rows[k], k] * K[k, l]
    g = omega[rows, cols][:, None] * K                # (n, n)

    # PK[a, k, l] = sum_i P_yax[i, a, k] * K[i, l]
    PK = torch.einsum('iak,il->akl', P_yax, K)        # (2, n, n)

    # h[a, k, l] = omega[a, k] * PK[a, k, l]
    h = omega[:, :, None] * PK                        # (2, n, n)

    # s[k, l] = sum_a P_ax[k, a] * h[a, k, l]
    s = torch.einsum('ka,akl->kl', P_ax, h)           # (n, n)

    # h_sel[k, l] = h[rows[k], k, l]
    h_sel = h[rows, cols]                             # (n, n)

    # Subtract the column-wise mean of s over samples k.
    return g - h_sel + s - s.mean(dim=0, keepdim=True)