import torch
import numpy as np
from scipy.stats import norm
def rand_projections(dim, num_projections=1000,device='cpu'):
    """Draw random unit-norm projection directions.

    Args:
        dim: Length of each direction vector.
        num_projections: Number of directions to draw.
        device: Torch device on which the samples are created.

    Returns:
        Tensor of shape ``(num_projections, dim)``; every row is a standard
        normal sample divided by its Euclidean norm.
    """
    gaussian = torch.randn(num_projections, dim, device=device)
    squared_norms = (gaussian ** 2).sum(dim=1, keepdim=True)
    return gaussian / squared_norms.sqrt()


def one_dimensional_Wasserstein_prod(X,Y,theta,p):
    """Per-direction mean of ``|sorted(X @ theta^T) - sorted(Y @ theta^T)| ** p``.

    Both inputs are projected onto the rows of ``theta``, each projection is
    sorted along the sample axis (dim 0), and the p-th power of the absolute
    element-wise gap is averaged over samples. No p-th root is taken.

    Args:
        X: Sample tensor whose last dimension matches ``theta.shape[1]``.
        Y: Sample tensor; its sorted projections are subtracted element-wise
            from those of ``X``, so the two must be shape-compatible.
        theta: Tensor of shape ``(num_directions, dim)``.
        p: Exponent applied to the absolute gaps.

    Returns:
        Tensor of shape ``(1, K)`` where ``K`` is the flattened size of the
        projected trailing dimensions (``num_directions`` for 2-D inputs).
    """
    directions = theta.transpose(0, 1)

    projected_x = torch.matmul(X, directions)
    projected_y = torch.matmul(Y, directions)

    # Collapse everything after the sample axis into a single column axis.
    projected_x = projected_x.view(projected_x.shape[0], -1)
    projected_y = projected_y.view(projected_y.shape[0], -1)

    sorted_x, _ = torch.sort(projected_x, dim=0)
    sorted_y, _ = torch.sort(projected_y, dim=0)

    gaps = torch.abs(sorted_x - sorted_y)
    return torch.mean(torch.pow(gaps, p), dim=0, keepdim=True)

def SW(X, Y, L=10, p=2, device="cpu"):
    """Monte-Carlo sliced-Wasserstein estimate between two sample sets.

    Draws ``L`` random unit directions with ``rand_projections`` and averages
    the per-direction values returned by ``one_dimensional_Wasserstein_prod``.

    Args:
        X: Sample tensor; ``X.size(1)`` sets the direction dimension.
        Y: Sample tensor compared against ``X``.
        L: Number of random projection directions.
        p: Exponent forwarded to the one-dimensional computation.
        device: Torch device on which the directions are created.

    Returns:
        Scalar tensor: the mean over directions (no p-th root is applied).
    """
    directions = rand_projections(X.size(1), num_projections=L, device=device)
    per_direction = one_dimensional_Wasserstein_prod(X, Y, directions, p=p)
    return per_direction.mean()

def _spiral_points(L):
    """Return the ``(L, 3)`` float64 NumPy array of deterministic spiral points.

    Heights are ``z_i = 1 - (2i - 1) / L`` for ``i = 1..L``; the polar angle is
    ``arccos(z_i)`` and the azimuth is ``1.8 * sqrt(L) * polar`` wrapped into
    ``[0, 2*pi)``.
    """
    heights = (1 - (2 * np.arange(1, L + 1) - 1) / L).reshape(-1, 1)
    polar = np.arccos(heights)
    azimuth = np.mod(1.8 * np.sqrt(L) * polar, 2 * np.pi)
    return np.concatenate(
        [np.sin(polar) * np.cos(azimuth), np.sin(polar) * np.sin(azimuth), np.cos(polar)],
        axis=1,
    )


def _refine_spiral_points(L, pairwise_loss, steps=10, lr=1):
    """Run a few SGD steps on the spiral points, re-normalising after each step.

    ``pairwise_loss`` maps the ``(L, L)`` matrix of pairwise L1 distances to a
    scalar loss. The optimisation runs in float64 on the CPU, so it needs no
    accelerator; the result is returned detached from the autograd graph.
    """
    points = torch.from_numpy(_spiral_points(L)).requires_grad_(True)
    optimizer = torch.optim.SGD([points], lr=lr)
    for _ in range(steps):
        loss = pairwise_loss(torch.cdist(points, points, p=1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Project back onto the unit sphere without recording the operation.
        with torch.no_grad():
            points /= torch.sqrt(torch.sum(points ** 2, dim=1, keepdim=True))
    return points.detach()


def _qsw_directions(L, kind, randomized):
    """Build the ``(L, 3)`` float32 CPU direction set for the given QSW variant.

    Raises:
        ValueError: If ``kind`` is not one of the supported variants.
    """
    if kind == 'nqsw':
        # Sobol points -> Gaussian quantiles -> normalised to unit length.
        engine = torch.quasirandom.SobolEngine(dimension=3, scramble=randomized)
        uniform = torch.clamp(engine.draw(L), min=1e-6, max=1 - 1e-6)
        # The 1e-6 offset presumably guards against an all-zero row (e.g. the
        # point (0.5, 0.5, 0.5)), which could not be normalised.
        gaussian = torch.from_numpy(norm.ppf(uniform.numpy()) + 1e-6).float()
        return gaussian / torch.sqrt(torch.sum(gaussian ** 2, dim=1, keepdim=True))

    if kind == 'qsw':
        # 2-D Sobol points mapped to the sphere: z = 1 - 2*tau, azimuth = 2*pi*alpha.
        engine = torch.quasirandom.SobolEngine(dimension=2, scramble=randomized)
        net = engine.draw(L)
        alpha = net[:, [0]]
        tau = net[:, [1]]
        radius = 2 * torch.sqrt(tau - tau ** 2)
        return torch.cat(
            [radius * torch.cos(2 * np.pi * alpha), radius * torch.sin(2 * np.pi * alpha), 1 - 2 * tau],
            dim=1,
        )

    if kind == 'sqsw':
        # Deterministic spiral points; ``randomized`` has no effect here.
        heights = (1 - (2 * torch.arange(1, L + 1) - 1) / L).view(-1, 1)
        polar = torch.arccos(heights)
        azimuth = torch.remainder(1.8 * np.sqrt(L) * polar, 2 * np.pi)
        return torch.cat(
            [torch.sin(polar) * torch.cos(azimuth), torch.sin(polar) * torch.sin(azimuth), torch.cos(polar)],
            dim=1,
        )

    if kind == 'odqsw':
        # Spiral points pushed apart by maximising the mean pairwise L1 distance.
        return _refine_spiral_points(L, lambda dist: -dist.mean()).float()

    if kind == 'ocqsw':
        # Spiral points refined by minimising the mean inverse pairwise L1 distance.
        return _refine_spiral_points(L, lambda dist: (1 / (dist + 1e-6)).mean()).float()

    supported = ('nqsw', 'qsw', 'sqsw', 'odqsw', 'ocqsw')
    raise ValueError(f"Unknown QSW type {kind!r}; expected one of {supported}.")


def QSW(X, Y, L=10, p=2, type='normal', randomized=True, device="cpu"):
    """Quasi-Monte-Carlo sliced-Wasserstein estimate for 3-dimensional samples.

    Builds ``L`` three-dimensional projection directions according to ``type``
    and averages the per-direction values returned by
    ``one_dimensional_Wasserstein_prod``. Every variant produces directions in
    3-D, so ``X`` and ``Y`` must have a last dimension of size 3.

    Args:
        X: Sample tensor with last dimension 3.
        Y: Sample tensor with last dimension 3.
        L: Number of projection directions.
        p: Exponent forwarded to the one-dimensional computation.
        type: Direction construction, one of ``'nqsw'``, ``'qsw'``, ``'sqsw'``,
            ``'odqsw'`` or ``'ocqsw'``. The default ``'normal'`` is not a
            supported variant and raises ``ValueError``.
        randomized: Whether the Sobol sequence is scrambled (only used by
            ``'nqsw'`` and ``'qsw'``).
        device: Torch device the directions are moved to; should match the
            device of ``X`` and ``Y``.

    Returns:
        Scalar tensor: the mean over directions (no p-th root is applied).

    Raises:
        ValueError: If ``type`` is not a supported variant.
    """
    # Move the complete direction tensor, not just part of an expression,
    # so every variant ends up on the requested device.
    theta = _qsw_directions(L, type, randomized).to(device)
    return one_dimensional_Wasserstein_prod(X, Y, theta, p=p).mean()




