import torch
from projop.utils import *

def halfspace_projection(x, c: torch.Tensor, b: float):
    """Project a single point onto the half-space ``{y : c . y <= b}``.

    The correction applied is ``plus_fn(c . x - b) * c / ||c||^2``.
    ``plus_fn`` comes from ``projop.utils`` (not visible here). It is
    presumably the positive part ``max(0, .)`` -- TODO confirm. Under that
    assumption, points already inside the half-space are returned unchanged.

    Args:
        x: Point to project, 1-D tensor of shape ``(d,)``.
        c: Normal of the bounding hyperplane, 1-D tensor of shape ``(d,)``.
        b: Offset of the bounding hyperplane.

    Returns:
        The projected point as a column vector of shape ``(d, 1)``.
    """
    x, c = x[:, None], c[:, None]      # column vectors, (d, 1)
    cxb = c.T @ x - b                   # (1, 1) signed constraint residual
    # Keep c on the left: (d, 1) @ (1, 1) -> (d, 1). The opposite order,
    # (1, 1) @ (d, 1), is not a valid matmul for any d > 1.
    return x - c @ plus_fn(cxb) / (torch.norm(c) ** 2)

def halfspace_projection_multiple(xs, c: torch.Tensor, b: float):
    """Project every row of ``xs`` onto the half-space ``{y : c . y <= b}``.

    ``plus_fn`` comes from ``projop.utils`` (not visible here). It is
    presumably the positive part ``max(0, .)`` -- TODO confirm.

    Args:
        xs: Batch of points, shape ``(N, d)``.
        c: Normal of the bounding hyperplane, shape ``(d,)``. A ``(1, d)``
            row is also accepted.
        b: Offset of the bounding hyperplane. Either a scalar, or anything
            that broadcasts against the ``(N,)`` residuals.

    Returns:
        Tensor of shape ``(N, d)`` with every row projected.
    """
    # Broadcast a single copy of c over the batch instead of materialising
    # an (N, d) ``repeat`` and computing the same norm N times.
    c = c.reshape(-1)
    cxb = xs @ c - b                                   # (N,) residuals
    step = plus_fn(cxb) / (c.norm() ** 2)              # (N,) step lengths
    return xs - step[:, None] * c

def hyperplane_projection(x, c: torch.tensor, b: torch.float):
    """Orthogonally project ``x`` onto the hyperplane ``{y : c . y = b}``.

    Args:
        x: Point to project, 1-D tensor of shape ``(d,)``.
        c: Hyperplane normal, 1-D tensor of shape ``(d,)``.
        b: Hyperplane offset.

    Returns:
        ``x - (c . x - b) * c / ||c||^2``, with the same shape as ``x``.
    """
    residual = torch.dot(c, x) - b
    squared_norm = torch.norm(c) ** 2
    return x - (residual * c) / squared_norm

def hyperplane_projection_multiple(xs, c: torch.Tensor, b: float):
    """Project every row of ``xs`` onto the hyperplane ``{y : c . y = b}``.

    Args:
        xs: Batch of points, shape ``(N, d)``.
        c: Hyperplane normal, shape ``(d,)``. A ``(1, d)`` row is also
            accepted.
        b: Hyperplane offset. Either a scalar, or anything that broadcasts
            against the ``(N,)`` residuals.

    Returns:
        Tensor of shape ``(N, d)`` with every row projected.
    """
    # Broadcast a single copy of c over the batch instead of materialising
    # an (N, d) ``repeat`` and computing the same norm N times.
    c = c.reshape(-1)
    cxb = xs @ c - b                                   # (N,) residuals
    return xs - (cxb / (c.norm() ** 2))[:, None] * c

def molwt_projection(A, X, weights: torch.Tensor, max_weight: float):
    """Project ``X`` onto ``{X : sum_ij weights[j] * X[i, j] <= max_weight}``.

    ``A`` is passed through unchanged. ``plus_fn`` comes from
    ``projop.utils`` (not visible here). It is presumably the positive part
    ``max(0, .)`` -- TODO confirm.

    Args:
        A: Returned as-is.
        X: Matrix of shape ``(n, f)``.
        weights: Per-column weights, shape ``(f,)`` or ``(1, f)``. They are
            tiled across the ``n`` rows of ``X``.
        max_weight: Upper bound on the weighted sum.

    Returns:
        ``(A, X_proj)`` where ``X_proj`` has shape ``(n, f)``.
    """
    n, f = X.shape[0], X.shape[1]
    x = X.reshape(-1, 1)                # (n*f, 1), row-major flatten of X
    rep_c = weights.repeat(1, n)        # (1, n*f): [w_1..w_f, w_1..w_f, ...]
    cxb = rep_c @ x - max_weight        # (1, 1) constraint residual
    # rep_c.T (n*f, 1) @ (1, 1) -> (n*f, 1), matching x. Subtracting a
    # (1, n*f) row instead would broadcast to (n*f, n*f).
    x_proj = x - rep_c.T @ plus_fn(cxb) / (torch.norm(rep_c) ** 2)
    return A, x_proj.reshape(n, f)

def molwt_projection_multiple(As, Xs, weights: torch.Tensor, max_weight: float):
    """Batched version of ``molwt_projection``.

    Each ``Xs[k]`` is projected onto
    ``{X : sum_ij weights[j] * X[i, j] <= max_weight}``. ``As`` is passed
    through unchanged. ``plus_fn`` comes from ``projop.utils`` (not visible
    here). It is presumably the positive part ``max(0, .)`` -- TODO confirm.

    Args:
        As: Returned as-is.
        Xs: Batch of matrices, shape ``(N, n, f)``.
        weights: Per-column weights, shape ``(f,)`` or ``(1, f)``.
        max_weight: Upper bound on the weighted sum. Either a scalar, or
            anything that broadcasts against ``(N,)``.

    Returns:
        ``(As, Xs_proj)`` where ``Xs_proj`` has shape ``(N, n, f)``.
    """
    N, n, f = Xs.shape
    x = Xs.reshape(N, -1)                       # (N, n*f)
    # Tile the weights once per row of X: [w_1..w_f, w_1..w_f, ...].
    # The same vector is broadcast over the batch, so no N-fold copy.
    tiled_w = weights.reshape(-1).repeat(n)     # (n*f,)
    cxb = x @ tiled_w - max_weight              # (N,) residuals
    step = plus_fn(cxb) / (tiled_w.norm() ** 2)
    x_proj = x - step[:, None] * tiled_w
    return As, x_proj.reshape(N, n, f)


def reg_projection(A, X, c_theta: torch.Tensor, b_theta: float):
    """Project ``(A, X)`` jointly onto ``{c_theta . [vec(X); vec(A)] <= b_theta}``.

    ``X`` and ``A`` are flattened row-major and concatenated, ``X`` first.
    ``plus_fn`` comes from ``projop.utils`` (not visible here). It is
    presumably the positive part ``max(0, .)`` -- TODO confirm.

    Args:
        A: Square matrix, shape ``(n, n)``.
        X: Matrix of shape ``(n, f)``.
        c_theta: 1-D constraint vector of length ``n*f + n*n``.
        b_theta: Constraint offset.

    Returns:
        ``(A_proj, X_proj)`` with shapes ``(n, n)`` and ``(n, f)``.

    Raises:
        AssertionError: if the shapes of ``A``, ``X`` and ``c_theta`` are
            inconsistent.
    """
    assert ((A.shape[0] == A.shape[1]) and (X.shape[0] == A.shape[0]) and
            (c_theta.shape[0] == X.shape[0] * X.shape[1] + A.shape[0] * A.shape[1])), (
        "expected square A (n, n), X (n, f) and c_theta of length n*f + n*n")
    n, f = X.shape[0], X.shape[1]
    xa = torch.cat((X.reshape(-1, 1), A.reshape(-1, 1)))   # (D, 1)
    cxb = c_theta @ xa - b_theta                            # (1,) residual
    # Use c_theta as a (D, 1) column so the correction matches xa. The
    # product (1,) @ (D,) is only valid when D == 1.
    c_col = c_theta.reshape(-1, 1)
    xa_proj = xa - c_col * plus_fn(cxb) / (torch.norm(c_theta) ** 2)
    return xa_proj[n * f:].reshape(n, n), xa_proj[:n * f].reshape(n, f)

def reg_projection_multiple(As, Xs, c_theta: torch.Tensor, b_theta: float):
    """Batched version of ``reg_projection``.

    Each pair ``(As[k], Xs[k])`` is projected onto
    ``{c_theta . [vec(X); vec(A)] <= b_theta}``. ``plus_fn`` comes from
    ``projop.utils`` (not visible here). It is presumably the positive part
    ``max(0, .)`` -- TODO confirm.

    Args:
        As: Batch of square matrices, shape ``(N, n, n)``.
        Xs: Batch of matrices, shape ``(N, n, f)``.
        c_theta: 1-D constraint vector of length ``n*f + n*n``.
        b_theta: Constraint offset. Either a scalar, or anything that
            broadcasts against ``(N,)``.

    Returns:
        ``(As_proj, Xs_proj)`` with shapes ``(N, n, n)`` and ``(N, n, f)``.

    Raises:
        AssertionError: if the shapes of ``As``, ``Xs`` and ``c_theta`` are
            inconsistent.
    """
    assert ((As.shape[1] == As.shape[2]) and (Xs.shape[1] == As.shape[1]) and
            (c_theta.shape[0] == Xs.shape[1] * Xs.shape[2] + As.shape[1] * As.shape[2])), (
        "expected As (N, n, n), Xs (N, n, f) and c_theta of length n*f + n*n")
    N, n, f = Xs.shape
    xa = torch.cat((Xs.reshape(N, -1), As.reshape(N, -1)), dim=1)   # (N, D)
    # Broadcast one copy of c_theta over the batch instead of repeating it.
    c_flat = c_theta.reshape(-1)
    cxb = xa @ c_flat - b_theta                                      # (N,)
    step = plus_fn(cxb) / (c_flat.norm() ** 2)
    xa_proj = xa - step[:, None] * c_flat
    # Split along the feature axis (dim 1), not the batch axis.
    return (xa_proj[:, n * f:].reshape(N, n, n),
            xa_proj[:, :n * f].reshape(N, n, f))

