import torch, numpy as np

def get_eqsw_projections(L, device):
    sob = torch.quasirandom.SobolEngine(dimension=2, scramble=False)
    net = sob.draw(L).to(device)
    alpha, tau = net[:, [0]], net[:, [1]]
    theta = torch.cat([
        2 * torch.sqrt((tau - tau**2).clamp_min(0)) * torch.cos(2*np.pi*alpha),
        2 * torch.sqrt((tau - tau**2).clamp_min(0)) * torch.sin(2*np.pi*alpha),
        1 - 2 * tau
    ], dim=1)
    return theta

def get_gqsw_projections(L, device):
    """Return L unit directions on S^2 from Sobol points pushed through the Gaussian PPF.

    Each unscrambled Sobol point u in [0, 1]^3 is mapped coordinate-wise to
    z = Phi^{-1}(u) (standard-normal inverse CDF) and normalised to unit length.

    The unscrambled Sobol sequence contains the centre point (0.5, 0.5, 0.5),
    which Phi^{-1} maps exactly to the origin. Normalising it cannot produce a
    direction; it used to come out as an all-zero row. Such rows are skipped,
    and one extra Sobol point is drawn so that L directions are still returned.

    Args:
        L: number of directions.
        device: torch device for the result.

    Returns:
        Tensor of shape (L, 3) with unit-norm rows.
    """
    sob = torch.quasirandom.SobolEngine(dimension=3, scramble=False)
    # Clamp away from 0 and 1, where Phi^{-1} is -inf / +inf.
    u = sob.draw(L + 1).clamp(1e-6, 1 - 1e-6).to(device)
    # Gaussian PPF via erfinv: Phi^{-1}(u) = sqrt(2) * erfinv(2u - 1)
    z = torch.sqrt(torch.tensor(2.0, device=device)) * torch.erfinv(2 * u - 1)
    norms = z.norm(dim=1, keepdim=True)
    # Only a point with every coordinate at 0.5 lands on the origin, and that
    # point appears once in the sequence, so L + 1 draws always leave >= L rows.
    keep = norms.squeeze(1) > 1e-6
    z, norms = z[keep][:L], norms[keep][:L]
    return z / norms

def get_sqsw_projections(L, device):
    """Return L unit directions on S^2 laid out along a generalized spiral.

    For i = 1..L the height is z_i = 1 - (2i - 1)/L, the polar angle is
    acos(z_i), and the azimuth is 1.8 * sqrt(L) * polar (mod 2*pi).
    Computation is in float32.

    Args:
        L: number of directions.
        device: torch device for the result.

    Returns:
        Tensor of shape (L, 3).
    """
    idx = torch.arange(1, L + 1, device=device, dtype=torch.float32)
    height = 1 - (2 * idx - 1) / L
    polar = torch.acos(height)
    spin = 1.8 * torch.sqrt(torch.tensor(float(L), device=device))
    azimuth = torch.remainder(spin * polar, 2 * np.pi)
    ring = torch.sin(polar)
    return torch.stack(
        [ring * torch.cos(azimuth), ring * torch.sin(azimuth), torch.cos(polar)],
        dim=-1,
    )

def get_dqsw_projections(L, device, iters=100, lr=1.0, p=1.0):
    """Return L unit directions in R^3 spread apart by maximising pairwise distance.

    Starts from ``get_sqsw_projections(L, device)`` and runs ``iters`` steps of
    projected SGD on ``-mean_{i,j} ||theta_i - theta_j||_p``. After each step
    every row is re-normalised to unit length.

    Args:
        L: number of directions.
        device: torch device for the result.
        iters: number of SGD steps.
        lr: SGD learning rate.
        p: norm used for the pairwise distances (forwarded to ``torch.cdist``).
            The default ``1.0`` reproduces the previous behaviour; ``2.0`` uses
            Euclidean distances.

    Returns:
        Tensor of shape (L, 3) with unit-norm rows, detached from autograd.
    """
    # The optimisation needs autograd even when the caller is inside
    # torch.no_grad() (e.g. during evaluation); otherwise loss.backward() raises
    # because nothing in the graph requires grad.
    with torch.enable_grad():
        theta = get_sqsw_projections(L, device).clone().requires_grad_(True)
        opt = torch.optim.SGD([theta], lr=lr)
        # Self-distances are always 0 and carry no useful gradient; masking them
        # avoids differentiating a norm at zero (fragile for p=2). The mean still
        # averages over all L*L entries, so the step size is unchanged.
        off_diag = ~torch.eye(L, dtype=torch.bool, device=theta.device)
        for _ in range(iters):
            opt.zero_grad()
            dist = torch.cdist(theta, theta, p=p)
            loss = -torch.where(off_diag, dist, dist.new_zeros(())).mean()
            loss.backward()
            opt.step()
            # Project back onto the unit sphere outside the autograd graph.
            with torch.no_grad():
                theta.div_(theta.norm(dim=1, keepdim=True).clamp_min(1e-12))
    return theta.detach()

def get_cqsw_projections(L, device, iters=100, lr=1.0, p=1.0):
    """Return L unit directions in R^3 spread apart by minimising a Coulomb-style energy.

    Starts from ``get_sqsw_projections(L, device)`` and runs ``iters`` steps of
    projected SGD on the energy ``1 / (||theta_i - theta_j||_p + 1e-6)``
    summed over pairs i != j and divided by L*L. After each step every row is
    re-normalised to unit length.

    Args:
        L: number of directions.
        device: torch device for the result.
        iters: number of SGD steps.
        lr: SGD learning rate.
        p: norm used for the pairwise distances (forwarded to ``torch.cdist``).
            The default ``1.0`` reproduces the previous behaviour; ``2.0`` uses
            Euclidean distances, as in the classical Coulomb (Thomson) energy.

    Returns:
        Tensor of shape (L, 3) with unit-norm rows, detached from autograd.
    """
    # The optimisation needs autograd even when the caller is inside
    # torch.no_grad() (e.g. during evaluation); otherwise loss.backward() raises.
    with torch.enable_grad():
        theta = get_sqsw_projections(L, device).clone().requires_grad_(True)
        opt = torch.optim.SGD([theta], lr=lr)
        # A point's self-distance is 0, so its "self-energy" is the constant
        # 1/1e-6 = 1e6. Mask those terms so they neither dominate the loss value
        # nor send gradient through a zero distance; the mean still averages over
        # all L*L entries, so the step size is unchanged.
        off_diag = ~torch.eye(L, dtype=torch.bool, device=theta.device)
        for _ in range(iters):
            opt.zero_grad()
            dist = torch.cdist(theta, theta, p=p) + 1e-6
            energy = torch.where(off_diag, 1.0 / dist, dist.new_zeros(()))
            loss = energy.mean()
            loss.backward()
            opt.step()
            # Project back onto the unit sphere outside the autograd graph.
            with torch.no_grad():
                theta.div_(theta.norm(dim=1, keepdim=True).clamp_min(1e-12))
    return theta.detach()
