import numpy as np
from scipy.linalg import sqrtm, det, inv, circulant
import matplotlib.pyplot as plt


def _get_matrices_unbalanced(cov_a, cov_b, gamma):
    Id = np.eye(len(cov_a))
    lb = gamma / 2
    tau = 1
    cov_at = 0.5 * gamma * (Id - lb * inv(cov_a + lb * Id))
    cov_bt = 0.5 * gamma * (Id - lb * inv(cov_b + lb * Id))
    C = sqrtm(cov_at.dot(cov_bt) / tau )
    Cinv = inv(C)
    F = cov_bt.dot(Cinv)
    G = Cinv.dot(cov_at)

    return C, F, G, cov_at, cov_bt


def closed_form_unbalanced(cov_a, cov_b, mean_a=None, mean_b=None,
                           return_params=False,
                           gamma=None, mass_a=1, mass_b=1):
    """Closed-form unbalanced transport loss between two Gaussian measures.

    NOTE(review): this appears to implement the closed-form unbalanced
    (entropic) OT between Gaussians; confirm the exact formulation against
    the reference derivation before relying on the plan parameters.

    Parameters
    ----------
    cov_a, cov_b : ndarray, shape (d, d)
        Covariance matrices of the two measures.
    mean_a, mean_b : ndarray, shape (d,), optional
        Means of the two measures; default to the zero vector.
    return_params : bool, default False
        If True, also return the covariance, mean and mass of the plan.
    gamma : float
        Positive penalty parameter. It is required despite the ``None``
        default, which is kept only for signature compatibility.
    mass_a, mass_b : float, default 1
        Total masses of the two measures.

    Returns
    -------
    loss
        ``gamma * (mass_a + mass_b - 2 * plan_mass)``.
    (loss, plan_cov, plan_mean, plan_mass)
        When ``return_params`` is True. ``plan_cov`` is the (2d, 2d) block
        matrix ``[[H1, H2], [H3, H4]]`` and ``plan_mean`` has shape (2d,).

    Raises
    ------
    TypeError
        If ``gamma`` is not given.
    ValueError
        If ``gamma`` is not positive.
    """
    # Fail early with a clear message instead of an opaque `None / 2`
    # TypeError (or a silent division by zero) deep inside the helpers.
    if gamma is None:
        raise TypeError("closed_form_unbalanced() requires a value for gamma")
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma!r}")

    d = len(cov_a)
    if mean_a is None:
        mean_a = np.zeros(d)
    if mean_b is None:
        mean_b = np.zeros(d)
    Id = np.eye(d)
    C, F, G, cov_at, cov_bt = _get_matrices_unbalanced(cov_a, cov_b, gamma)
    lb = gamma / 2
    X = cov_a + cov_b + lb * Id
    Xinv = inv(X)
    det_ab = det(cov_a.dot(cov_b))
    det_atbt = det(cov_at.dot(cov_bt))

    # Mass of the optimal plan: a Gaussian-like factor in the mean difference
    # times a ratio of determinants.
    diff = mean_a - mean_b
    exp_mass = np.exp(- 0.5 * diff.dot(Xinv.dot(diff)) / 2)
    num = mass_a * mass_b * det(C) * (det_atbt / det_ab) ** 0.5
    num **= 0.5
    num *= exp_mass
    den = det(C - 2 / gamma * cov_at.dot(cov_bt)) ** 0.5
    plan_mass = num / den

    # UOT at optimality depends on the mass of the plan and those of the inputs
    loss = gamma * (mass_a + mass_b - 2 * plan_mass)
    # NOTE(review): alternative term kept from the original; `s` is undefined.
    # loss += 2 * s * (mass_a * mass_b - plan_mass)

    if return_params:
        H1 = (Id + C / lb).dot(cov_a - cov_a.dot(Xinv).dot(cov_a))
        H4 = (Id + C.T / lb).dot(cov_b - cov_b.dot(Xinv).dot(cov_b))
        H2 = C + (Id + C / lb).dot(cov_a.dot(Xinv).dot(cov_b))
        H3 = C.T + (Id + C.T / lb).dot(cov_b).dot(Xinv).dot(cov_a)
        plan_cov = np.block([[H1, H2], [H3, H4]])
        plan_mean_a = mean_a + cov_a.dot(Xinv).dot(mean_b - mean_a)
        plan_mean_b = mean_b + cov_b.dot(Xinv).dot(mean_a - mean_b)
        # flatten() also handles means passed as column vectors.
        plan_mean = np.concatenate([plan_mean_a, plan_mean_b]).flatten()

        return loss, plan_cov, plan_mean, plan_mass
    return loss


def ghk_loss(mean_a, mean_b, cov_a, cov_b, mass_a, mass_b, gamma):
    """Return only the loss computed by `closed_form_unbalanced`.

    Convenience wrapper that takes its arguments in (means, covariances,
    masses, gamma) order and never requests the plan parameters.
    """
    options = dict(return_params=False, gamma=gamma,
                   mass_a=mass_a, mass_b=mass_b)
    return closed_form_unbalanced(cov_a=cov_a, cov_b=cov_b,
                                  mean_a=mean_a, mean_b=mean_b, **options)



if __name__ == '__main__':
    # Experiment: for a log-spaced grid of gamma values, build the matrix of
    # pairwise ghk_loss values between a few 1-D Gaussian measures, restrict
    # it to the subspace of vectors whose entries sum to zero, and record the
    # eigenvalues of the restricted matrix (negative ones are reported).
    num_meas = 5
    deterministic = False
    constant_mass = True

    if deterministic:
        list_mean = [np.array([0.0]), np.array([1.0]), np.array([5.0]), np.array([2.0]), np.array([0.5])]
        list_cov = [k * np.ones(shape=(1, 1)) for k in [0.5, 0.2, 5.0, 1., 2.]]
        # Keep num_meas consistent with the hard-coded measures above.
        num_meas = len(list_mean)
    else:
        list_mean = [np.random.normal(size=1) for _ in range(num_meas)]
        list_cov = [5 * np.random.uniform(size=(1, 1)) for _ in range(num_meas)]

    if constant_mass:
        list_mass = [1.0] * num_meas
    else:
        # Draw scalar masses. A size-(1,) array here makes every loss a
        # 1-element array, and assigning that into the scalar slot
        # kernel[j, l] is deprecated since NumPy 1.25.
        list_mass = [2 * np.random.uniform() for _ in range(num_meas)]

    num = 20
    gamma_min, gamma_max = -4, 4  # exponents: gamma spans 10**-4 .. 10**4
    gamma_list = np.logspace(gamma_min, gamma_max, num=num)
    eigval_list = np.zeros(shape=(num, num_meas - 1))

    # Basis of the subspace of vectors summing to zero: column k is
    # e_0 - e_{k+1}. circulant(z) with z = [-1, 0, ...] is -I; overwriting
    # the first row with ones and dropping the first column gives that basis.
    z = np.zeros(shape=num_meas)
    z[0] = -1
    basis = circulant(z)
    basis[0, :] = 1.
    proj = basis[:, 1:]
    print(proj)
    # Orthonormalise so that proj.T @ kernel @ proj is the restriction of
    # kernel to the zero-sum subspace.
    proj, _ = np.linalg.qr(proj)

    for i, gamma in enumerate(gamma_list):
        kernel = np.zeros(shape=(num_meas, num_meas))
        # The loss matrix is symmetric: compute the upper triangle only.
        for j in range(num_meas):
            for l in range(j, num_meas):
                loss = ghk_loss(list_mean[j], list_mean[l], list_cov[j], list_cov[l],
                                list_mass[j], list_mass[l], gamma)
                kernel[j, l] = loss
                kernel[l, j] = loss

        proj_kernel = proj.T @ kernel @ proj
        # eigvalsh returns eigenvalues in ascending order.
        eigval_list[i] = np.linalg.eigvalsh(proj_kernel)

    # Make the plot; the x-axis is log10(gamma).
    t = np.linspace(gamma_min, gamma_max, num=num)
    for j in range(num_meas - 1):
        plt.plot(t, eigval_list[:, j], label=f'eigenvalue {j}')
    plt.plot(t, np.zeros_like(t), c='k', linestyle='dotted')
    plt.xlabel('log10(gamma)')
    plt.legend()
    # plt.show()

    # True wherever an eigenvalue is negative.
    print(np.greater(np.zeros_like(eigval_list), eigval_list))
