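"""
Validate and benchmark several implementations of the squared 2-Wasserstein
(W2^2) distance between 1-D empirical distributions, including a reference
computed with POT's ``ot.emd2_1d``.
"""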
import timeit

import numpy as np
import ot
import torch
# Define the functions
def quantile(samples_sorted):
    """Return the empirical quantile function (inverse CDF) of sorted samples.

    For ``n`` sorted samples the returned function maps a probability ``p``
    to ``samples_sorted[floor(p * n)]``: the right-continuous step function
    equal to the k-th order statistic on ``[k/n, (k+1)/n)``.

    Args:
        samples_sorted: 1-D tensor of samples in ascending order.

    Returns:
        A function taking a tensor of probabilities and returning the matching
        quantiles. Indices are clamped to ``[0, n - 1]`` so ``p = 1`` yields
        the largest sample instead of indexing out of bounds.
    """
    n = len(samples_sorted)

    def quantile_func(p):
        idx = torch.floor(p * n).long().clamp(0, n - 1)
        return samples_sorted[idx]

    return quantile_func


def W22(u_samples, v_samples):
    """Squared 2-Wasserstein distance between two 1-D empirical distributions.

    Uses W2^2 = integral over [0, 1] of (F_u^{-1}(p) - F_v^{-1}(p))^2 dp.
    Both quantile functions are piecewise constant with breakpoints at
    ``k / len(u_samples)`` and ``k / len(v_samples)``. The integral is
    therefore an exact finite sum over the merged breakpoint grid. The two
    samples may have different lengths.

    Args:
        u_samples: 1-D tensor of samples from the first distribution.
        v_samples: 1-D tensor of samples from the second distribution.

    Returns:
        0-d tensor with the squared W2 distance. It is built from sorted
        values and indexing only, so gradients flow back to both inputs.
    """
    u_sorted, _ = u_samples.sort()
    v_sorted, _ = v_samples.sort()
    # Build the grid on the inputs' device so CUDA tensors work as well.
    device = u_samples.device
    u_grid = torch.linspace(0, 1, steps=len(u_samples) + 1, device=device)
    v_grid = torch.linspace(0, 1, steps=len(v_samples) + 1, device=device)
    grids = torch.unique(torch.cat((u_grid, v_grid)), sorted=True)
    # Evaluate each step function at the midpoint of every sub-interval, not
    # at its left end. At a breakpoint, float rounding can make
    # floor(p * n) land one index too low. The midpoint lies strictly inside
    # an interval on which both quantile functions are constant.
    midpoints = (grids[:-1] + grids[1:]) / 2
    u_icdf = quantile(u_sorted)(midpoints)
    v_icdf = quantile(v_sorted)(midpoints)
    return torch.sum((u_icdf - v_icdf) ** 2 * torch.diff(grids))
            

def W2_distance(u_samples, v_samples):
    """Squared 2-Wasserstein distance via POT's exact 1-D solver.

    ``ot.emd2_1d`` is called with its defaults: uniform weights on both
    sample sets and a squared-Euclidean ground cost. The result is therefore
    the *squared* distance W2^2 (the same quantity as ``W22``), despite the
    function's name.
    """
    return ot.emd2_1d(u_samples, v_samples)

def W2_distance_loss(ys, ys_truth):
    """Sum of ``W22`` distances over the first axis, using feature 0 only.

    For each index ``t`` in ``range(ys.shape[0])``, the 1-D samples
    ``ys[t, :, 0]`` and ``ys_truth[t, :, 0]`` are compared with ``W22``, and
    the results are added. Returns the integer ``0`` when the first axis of
    ``ys`` is empty.
    """
    return sum(W22(ys[t, :, 0], ys_truth[t, :, 0]) for t in range(ys.shape[0]))


def W_2_distance(ys, ys_truth):
    """Squared 2-Wasserstein distance between two equal-size 1-D samples.

    With equal sample counts, the optimal 1-D coupling pairs the k-th
    smallest value of ``ys`` with the k-th smallest value of ``ys_truth``.
    The distance is therefore the mean squared difference of the two sorted
    samples.

    Args:
        ys: 1-D tensor of ``n`` samples.
        ys_truth: 1-D tensor of ``n`` samples.

    Returns:
        0-d tensor ``sum((sort(ys) - sort(ys_truth)) ** 2) / n``. Gradients
        flow to both inputs through the sorted values.

    Raises:
        ValueError: if the two samples have different lengths.
    """
    if ys.shape[0] != ys_truth.shape[0]:
        raise ValueError(
            "W_2_distance needs equal-size samples, "
            f"got {ys.shape[0]} and {ys_truth.shape[0]}"
        )
    # Subtracting the two sorted tensors is the same as scattering sorted
    # ys_truth onto the rank positions of ys and taking the squared error.
    # Doing it in torch avoids per-element Python work and keeps the inputs'
    # device and dtype.
    diff = ys.sort().values - ys_truth.sort().values
    return torch.sum(diff ** 2) / ys.shape[0]

def W_2_loss(ys, ys_truth):
    """Sum of ``W_2_distance`` values over the first axis, using feature 0 only.

    Each index ``t`` in ``range(ys.shape[0])`` contributes
    ``W_2_distance(ys[t, :, 0], ys_truth[t, :, 0])``. Returns the integer
    ``0`` when the first axis of ``ys`` is empty.
    """
    return sum(
        W_2_distance(ys[t, :, 0], ys_truth[t, :, 0]) for t in range(ys.shape[0])
    )

def W2_mingtao_loss(ys, ys_truth):
    """Batched squared-W2 loss on feature 0, summed over the first axis.

    For every slice ``i`` along axis 0, the samples ``ys[i, :, 0]`` and
    ``ys_truth[i, :, 0]`` (both of length ``n = ys.shape[1]``) are sorted and
    paired by rank. The squared differences are summed over all slices and
    divided by ``n``, which equals
    ``sum_i W_2_distance(ys[i, :, 0], ys_truth[i, :, 0])``.

    Args:
        ys: 3-D tensor indexed ``[i, sample, feature]``; only feature 0 is used.
        ys_truth: 3-D tensor with the same first two dimensions as ``ys``.

    Returns:
        0-d tensor with the summed loss. Gradients flow to both inputs through
        the sorted values.

    Raises:
        ValueError: if the first two dimensions of the inputs differ.
    """
    if ys.shape[:2] != ys_truth.shape[:2]:
        raise ValueError(
            "W2_mingtao_loss needs matching (slices, samples) dimensions, "
            f"got {tuple(ys.shape[:2])} and {tuple(ys_truth.shape[:2])}"
        )
    # Sorting each slice along the sample axis gives the optimal 1-D coupling.
    # Only feature 0 enters the loss, as in W_2_loss; any other features of
    # ys are ignored rather than compared against zeros.
    pred = ys[:, :, 0].sort(dim=1).values
    truth = ys_truth[:, :, 0].sort(dim=1).values
    return torch.sum((pred - truth) ** 2) / ys.shape[1]

# Validation script: check that the implementations agree, then time them.

# Generate random data
torch.manual_seed(0)
np.random.seed(0)
u_samples_torch = torch.randn(1000)
v_samples_torch = torch.randn(1000)

# Batched losses over 41 slices of 1000 single-feature samples; the three
# printed values should match.
ys = torch.randn(41, 1000, 1)
ys_truth = torch.randn(41, 1000, 1)
print(W2_mingtao_loss(ys, ys_truth))
print(W2_distance_loss(ys, ys_truth))
print(W_2_loss(ys, ys_truth))

# Compute the squared W2 distance with all three single-sample functions.
W22_value = W22(u_samples_torch, v_samples_torch)
W2_distance_value = W2_distance(u_samples_torch, v_samples_torch)
W_2_distance_value = W_2_distance(u_samples_torch, v_samples_torch)
# Print explicitly: a bare tuple expression is only echoed in a notebook/REPL.
print((W22_value, W2_distance_value, W_2_distance_value))
# Output recorded from an earlier run:
# (tensor(0.0094), tensor(0.0094), tensor(0.0094))

# Compare the speed of the three functions (1000 calls each).
for impl in (W22, W2_distance, W_2_distance):
    print(timeit.timeit(lambda: impl(u_samples_torch, v_samples_torch), number=1000))

# Fresh equal-size samples: compare values, then time 100 calls each.
u_samples_torch = torch.randn(1000)
v_samples_torch = torch.randn(1000)
for impl in (W22, W2_distance, W_2_distance):
    print(impl(u_samples_torch, v_samples_torch))
for impl in (W22, W2_distance, W_2_distance):
    print(timeit.timeit(lambda: impl(u_samples_torch, v_samples_torch), number=100))

# Output recorded from an earlier run (timings are machine-dependent):
# tensor(0.0088)
# tensor(0.0088)
# tensor(0.0088)
# 0.04397249998874031
# 0.04694950001430698
# 0.9402489000058267

# Unequal sample sizes. W_2_distance pairs samples one-to-one and needs equal
# sizes, so only W22 and the POT reference are compared here.
for n_u, n_v in ((500, 1000), (1001, 293)):
    u_samples_torch = torch.randn(n_u)
    v_samples_torch = torch.randn(n_v)
    print(W22(u_samples_torch, v_samples_torch))
    print(W2_distance(u_samples_torch, v_samples_torch))
    print(timeit.timeit(lambda: W22(u_samples_torch, v_samples_torch), number=1000))
    print(timeit.timeit(lambda: W2_distance(u_samples_torch, v_samples_torch), number=1000))