import torch

from projop.utils import plus_fn

def affine_projection(x: torch.Tensor, M: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Project ``x`` onto the affine set ``{z : M @ z = b}``.

    Computes ``x - pinv(M) @ (M @ x - b)``. When ``M @ z = b`` has a solution
    this is the orthogonal (least-squares) projection of ``x`` onto the
    solution set.

    Args:
        x: Point to project; must be compatible with ``M @ x``.
        M: Constraint matrix.
        b: Right-hand side, broadcastable against ``M @ x``.

    Returns:
        The projected point, with the same shape as ``x``.
    """
    # The Moore-Penrose pseudo-inverse also covers rank-deficient M.
    M_pinv = torch.linalg.pinv(M)
    residual = M @ x - b
    return x - M_pinv @ residual

def valency_projection (A: torch.tensor, X: torch.tensor, valencies: torch.tensor, hidden_Hs=True):
    # For Molecular data only with given atom valencies
    assert ((A.shape[0] == A.shape[1]) and (X.shape[0] == A.shape[0]) and 
            (X.shape[1] == valencies.shape[0]))
    Xnorm = torch.div (X.T, X.sum(dim=1)).T
    Xnorm = torch.nan_to_num(Xnorm, nan=0)
    Xsmpld = torch.zeros_like(X)
    Xsmpld[range(X.shape[0]), X.argmax(dim=1)] = 1
    atoms_exist = torch.cat([torch.any(Xnorm[i] > 0.5).ravel() for i in range(X.shape[0])])
    A_n = A[atoms_exist][:, atoms_exist] if hidden_Hs else A
    X_n = Xnorm[atoms_exist] if hidden_Hs else Xnorm
    wtd_vals = X_n @ valencies[:, None] # for each node we now have a weighted valency
    n = A_n.shape[0]
    a = A_n.reshape(-1, 1)
    M_val = torch.zeros(n, n**2, dtype=a.dtype, device=a.device)
    for i in range(M_val.shape[0]):
        M_val[i, i*n:(i+1)*n] = 1
    if hidden_Hs:
        a_proj = a - M_val.T @ plus_fn (M_val @ a - wtd_vals)/n # since Minv = I/n
        A_proj = A.clone()
        A_proj[torch.outer(atoms_exist, atoms_exist)] = a_proj.reshape(-1)
        # print (A_proj[atoms_exist][:, atoms_exist].reshape(-1), a_proj.reshape(-1))
        # print (A_proj[atoms_exist][:, atoms_exist].reshape(-1), a_proj.reshape(-1))
        # print (A_proj[atoms_exist][:, atoms_exist].sum(dim=1), wtd_vals + 0.01*wtd_vals)
        # print (torch.all(M_val @ a_proj <= wtd_vals + 0.01*wtd_vals))
        # if (not (torch.all(A_proj[atoms_exist][:, atoms_exist].sum(dim=1) <= (wtd_vals + 0.01*wtd_vals).reshape(-1)))):
        #     print (A_proj[atoms_exist][:, atoms_exist].sum(dim=1) <= (wtd_vals + 0.01*wtd_vals).reshape(-1))
        #     exit()
    else:
        a_proj = a - M_val.T @ (M_val @ a - wtd_vals)/n # since Minv = I/n
        A_proj = a_proj.reshape(n, n)
    return A_proj, Xnorm

def valency_projection_multiple (As: torch.tensor, Xs: torch.tensor, valencies: torch.tensor, hidden_Hs=True):
    # For Molecular data only with given atom valencies
    assert ((As.shape[0] == Xs.shape[0]) and (As.shape[1] == As.shape[2]) and 
            (Xs.shape[1] == As.shape[1]) and (Xs.shape[2] == valencies.shape[0]))
    # print (torch.any(Xs.isnan()), torch.any(As.isnan()))
    Xsnorm = Xs / Xs.sum(keepdim=True, dim=2)
    Xsnorm = torch.nan_to_num(Xsnorm, nan=0)
    # Xsmpld = torch.zeros_like(Xs).scatter(dim=2, index=Xsnorm.argmax(keepdim=True, dim=2), src=1)
    # wtd_vals = torch.matmul (Xs, valencies[:, None])
    wtd_vals = torch.matmul (Xsnorm, valencies[:, None])
    N, n = As.shape[0], As.shape[1]
    a = As.reshape(N, -1, 1)
    M_val = torch.zeros(N, n, n**2, dtype=a.dtype, device=a.device)
    for i in range(M_val.shape[1]):
        M_val[:, i, i*n:(i+1)*n] = 1
    if hidden_Hs:
        atoms_exist = torch.any(Xsnorm > 0.5, dim=2, keepdim=True).squeeze()
        # torch.save(atoms_exist, "atoms_exist_proj.pt")
        # X_atoms_exist = atoms_exist.repeat(1, 1, Xs.shape[2])
        a_atoms_exist = torch.einsum ('ij,ik->ijk', atoms_exist, atoms_exist).reshape(N, -1)
        a_atoms0, a_atoms2 = torch.where(~a_atoms_exist)
        M_val[~atoms_exist] = 0
        M_val[a_atoms0, :, a_atoms2] = 0
        # since Minv = I/n
        num_atoms = atoms_exist.sum(dim=1)[:, None, None]
        as_proj = a - torch.div(torch.matmul(torch.transpose(M_val, 1, 2), 
                                             plus_fn (M_val @ a - wtd_vals)),
                                torch.where(num_atoms == 0, torch.ones_like(num_atoms), num_atoms))
    else:
        as_proj = a - torch.transpose(M_val, 1, 2) @ (M_val @ a - wtd_vals) / n
    As_proj = as_proj.reshape(N, n, n)
    return As_proj, Xsnorm

def atomCount_projection(A: torch.Tensor, X: torch.Tensor, counts: torch.Tensor):
    """Adjust the columns of ``X`` against per-feature count limits.

    For molecular data only, with given per-atom-type counts.

    For every feature column ``c`` the correction
    ``plus_fn(X[:, c].sum() - counts[c])`` is computed and an equal share
    ``correction / n`` is subtracted from each of the ``n`` entries of that
    column. ``plus_fn`` comes from ``projop.utils`` -- presumably the positive
    part ``max(., 0)``, which would make this the projection onto
    ``column_sum <= counts``; confirm.

    Args:
        A: ``(n, n)`` adjacency matrix; returned unchanged.
        X: ``(n, f)`` node-feature matrix.
        counts: ``(f,)`` count limit for each feature column.

    Returns:
        ``(A, X_proj)`` where ``X_proj`` is the adjusted ``(n, f)`` matrix.
    """
    assert ((A.shape[0] == A.shape[1]) and (X.shape[0] == A.shape[0]) and
            (X.shape[1] == counts.shape[0]))
    n = A.shape[0]
    # Total of each feature column, as an (f, 1) column vector.
    col_sums = X.sum(dim=0)[:, None]
    excess = plus_fn(col_sums - counts[:, None])
    # The column-sum constraint matrix M satisfies M @ M.T = n * I, so
    # pinv(M) = M.T / n: each column's correction is split evenly across the
    # n nodes. M is never materialised.
    X_proj = X - excess.T / n
    return A, X_proj

def atomCount_projection_multiple(As: torch.Tensor, Xs: torch.Tensor, counts: torch.Tensor):
    """Batched adjustment of feature columns against per-feature count limits.

    For molecular data only, with given per-atom-type counts.

    For every graph ``b`` and feature column ``c`` the correction
    ``plus_fn(Xs[b, :, c].sum() - counts[c])`` is computed and an equal share
    ``correction / n`` is subtracted from each of the ``n`` entries of that
    column. ``plus_fn`` comes from ``projop.utils`` -- presumably the positive
    part ``max(., 0)``, which would make this the projection onto
    ``column_sum <= counts``; confirm.

    Args:
        As: ``(N, n, n)`` batch of adjacency matrices; returned unchanged.
        Xs: ``(N, n, f)`` batch of node-feature matrices.
        counts: ``(f,)`` count limit for each feature column, shared by all
            graphs in the batch.

    Returns:
        ``(As, Xs_proj)`` where ``Xs_proj`` is the adjusted ``(N, n, f)`` batch.
    """
    assert ((As.shape[0] == Xs.shape[0]) and (As.shape[1] == As.shape[2]) and
            (Xs.shape[1] == As.shape[1]) and (Xs.shape[2] == counts.shape[0]))
    n = As.shape[1]
    # Per-graph total of each feature column, shape (N, f, 1).
    col_sums = Xs.sum(dim=1)[:, :, None]
    excess = plus_fn(col_sums - counts[:, None])
    # The column-sum constraint matrix M satisfies M @ M.T = n * I, so
    # pinv(M) = M.T / n: each column's correction is split evenly across the
    # n nodes. M is never materialised.
    Xs_proj = Xs - torch.transpose(excess, 1, 2) / n
    return As, Xs_proj