import numpy as np
import ot

def test_sk(source_supports, target_supports, source_masses, target_masses, lamda, epsilon):
    """Evaluate ``ot.sinkhorn2`` on a squared-Euclidean ground cost.

    The cost matrix is built with ``ot.dist`` between the source and target
    support points. ``lamda`` is forwarded as the ``reg`` argument and
    ``epsilon`` as the ``stopThr`` argument of ``ot.sinkhorn2``.

    Returns:
        Whatever ``ot.sinkhorn2`` returns for the given masses and cost.
    """
    cost_matrix = ot.dist(source_supports, target_supports, metric="sqeuclidean")
    return ot.sinkhorn2(
        source_masses,
        target_masses,
        cost_matrix,
        reg=lamda,
        stopThr=epsilon,
    )

def test_sk_ROT(source_supports, target_supports, source_masses, target_masses, lamda, epsilon):
    """Evaluate ``ot.sinkhorn2`` on centered supports and add the mean-shift term.

    Each support set is shifted by its mass-weighted sum of points
    (``supports.T @ masses``, i.e. the mean when the masses sum to one).
    The Sinkhorn value is computed on the squared-Euclidean cost between the
    shifted supports, and the squared Euclidean norm of the difference
    between the two shift vectors is added to it.

    ``lamda`` is forwarded as ``reg`` and ``epsilon`` as ``stopThr``.
    """
    centroid_src = source_supports.T @ source_masses
    centroid_tgt = target_supports.T @ target_masses

    centered_cost = ot.dist(
        source_supports - centroid_src,
        target_supports - centroid_tgt,
        metric="sqeuclidean",
    )
    sinkhorn_term = ot.sinkhorn2(
        source_masses, target_masses, centered_cost, reg=lamda, stopThr=epsilon
    )
    # Squared distance between the two shift vectors.
    mean_shift_term = np.linalg.norm(centroid_src - centroid_tgt) ** 2
    return sinkhorn_term + mean_shift_term

def exact_emd(source_supports, target_supports, source_masses, target_masses):
    """Evaluate ``ot.emd2`` on a squared-Euclidean ground cost.

    The cost matrix is built with ``ot.dist`` between the source and target
    support points; the masses are passed to ``ot.emd2`` unchanged.

    Returns:
        Whatever ``ot.emd2`` returns for the given masses and cost.
    """
    cost_matrix = ot.dist(source_supports, target_supports, metric="sqeuclidean")
    return ot.emd2(source_masses, target_masses, cost_matrix)

# compute deviation
def compute_mean_and_std_following_col(mat):
    """Return the per-column mean and standard deviation of ``mat``.

    Args:
        mat: Array-like with at least two dimensions, typically a 2-D matrix.
            Nested lists are accepted and converted with ``np.asarray``.

    Returns:
        A ``(mean, std)`` tuple of 1-D arrays with one entry per index along
        axis 1. The standard deviation uses NumPy's default ``ddof=0``.

    Raises:
        ValueError: If ``mat`` has fewer than two dimensions.
    """
    mat = np.asarray(mat)
    if mat.ndim < 2:
        raise ValueError(
            f"mat must have at least 2 dimensions, got {mat.ndim}"
        )
    # Reduce over every axis except the column axis. For a 2-D matrix this
    # is simply axis 0; for higher ranks it matches reducing mat[:, col] to
    # a single scalar per column.
    reduce_axes = tuple(axis for axis in range(mat.ndim) if axis != 1)
    return mat.mean(axis=reduce_axes), mat.std(axis=reduce_axes)
def element_wise_add(arr1, arr2):
    """Add two sequences pairwise and return the sums as a NumPy array.

    Pairs are formed with ``zip``, so the result has the length of the
    shorter input.
    """
    sums = []
    for left, right in zip(arr1, arr2):
        sums.append(left + right)
    return np.array(sums)

def generate_dist(name, size):
    """Draw random samples from a named distribution using NumPy's global RNG.

    Args:
        name: One of ``"Gaussian"``, ``"Uniform"``, ``"Poisson"``,
            ``"Geometric"`` or ``"Gamma"``.
        size: Output shape. ``"Gaussian"`` and ``"Uniform"`` pass it to NumPy
            unchanged. The other three distributions only read ``size[0]``
            and always produce one-dimensional data.

    Returns:
        For ``"Gaussian"`` (mean 0, standard deviation 1) and ``"Uniform"``
        (on ``[0, 1)``), an array of shape ``size``. For ``"Poisson"``
        (``lam=1``), ``"Geometric"`` (``p=0.5``) and ``"Gamma"`` (shape 2,
        scale 2), a column vector of shape ``(size[0], 1)``.

    Raises:
        ValueError: If ``name`` is not one of the supported distributions.
    """
    if name == "Gaussian":
        std_dev = 1
        return np.random.normal(0, std_dev, size)

    if name == "Uniform":
        left_r = 0
        right_r = 1
        return np.random.uniform(left_r, right_r, size)

    # The remaining distributions only generate one-dimensional data: draw
    # size[0] samples and expose them as a single column.
    if name == "Poisson":
        lamb = 1
        return np.random.poisson(lam=lamb, size=size[0])[:, np.newaxis]

    if name == "Geometric":
        p = 0.5
        return np.random.geometric(p=p, size=size[0])[:, np.newaxis]

    if name == "Gamma":
        shape, scale = 2., 2.
        return np.random.gamma(shape, scale, size=size[0])[:, np.newaxis]

    # Fail loudly instead of implicitly returning None for an unknown name.
    raise ValueError(f"Unknown distribution name: {name!r}")

