import numpy as np
import numba
from numba import jit
import scipy as sp  # for sparse matrices
import scipy.sparse.linalg as sln
from scipy.sparse import csr_matrix
from tqdm import tqdm


@jit(nopython=True)
def solve_ot(a, b, x, y, p):
    """Computes the 1D Optimal Transport between two histograms.

    The plan is built by a single monotone sweep over the sorted supports:
    starting from the pair (0, 0), each step moves the largest possible
    mass between the current pair (i, j) and then advances whichever index
    has been exhausted. The dual potentials are filled in along the same
    path so that f[i] + g[j] = |x[i] - y[j]|^p on every visited pair.

    _Important:_ one should have np.sum(a) = np.sum(b).

    _Important:_ x and y need to be sorted.

    Parameters
    ----------
    a: vector of length n with positive entries

    b: vector of length m with positive entries

    x: vector of reals of length n

    y: vector of reals of length m

    p: real, should be >= 1


    Returns
    ----------
    I: vector of length q=n+m-1 of non-decreasing integers in {0,...,n-1}

    J: vector of length q of non-decreasing integers in {0,...,m-1}

    P: vector of length q of non-negative transported masses

    f: dual vector of length n

    g: dual vector of length m

    cost: (dual) OT cost
        sum_i a_i f_i + sum_j b_j g_j
        It should be equal to the primal cost
        = sum_k P(k) |x(i)-y(j)|^p where i=I(k), j=J(k)
    """
    n = len(a)
    m = len(b)
    q = m + n - 1
    # Work on copies: the remaining masses are consumed in place below.
    a1 = a.copy()
    b1 = b.copy()
    I = np.zeros(q).astype(numba.int64)
    J = np.zeros(q).astype(numba.int64)
    P = np.zeros(q)
    f = np.zeros(n)
    g = np.zeros(m)
    # f[0] is fixed to 0, which determines g[0] on the first pair.
    g[0] = np.abs(x[0] - y[0]) ** p
    for k in range(q - 1):
        i = I[k]
        j = J[k]
        # Advance i when its mass runs out first. Ties also advance i, so
        # the path always makes progress (the next entry then carries zero
        # mass). At a boundary only the other index can still move.
        if (i < n - 1) and ((j == m - 1) or (a1[i] <= b1[j])):
            I[k + 1] = i + 1
            J[k + 1] = j
            f[i + 1] = np.abs(x[i + 1] - y[j]) ** p - g[j]
        else:
            I[k + 1] = i
            J[k + 1] = j + 1
            g[j + 1] = np.abs(x[i] - y[j + 1]) ** p - f[i]
        t = min(a1[i], b1[j])
        P[k] = t
        a1[i] = a1[i] - t
        b1[j] = b1[j] - t
    # The last path entry is always (n-1, m-1). Index it by q - 1 rather
    # than the loop variable so the n == m == 1 case (empty loop) works.
    P[q - 1] = max(a1[-1], b1[-1])  # remaining mass
    cost = np.sum(f * a) + np.sum(g * b)
    return I, J, P, f, g, cost


def logsumexp(f, a, stable_lse=True):
    """Return log(sum_i a_i * exp(f_i)).

    Parameters
    ----------
    f: vector of exponents

    a: vector of weights, same length as f

    stable_lse: if true, the largest value of f + log(a) is subtracted
        before exponentiating and added back afterwards; otherwise the
        formula is evaluated directly.
    """
    if stable_lse:
        shifted = f + np.log(a)
        peak = np.amax(shifted)
        return peak + np.log(np.sum(np.exp(shifted - peak)))
    return np.log(np.sum(a * np.exp(f)))


def rescale_potentials(f, g, a, b, rho1, rho2=None, stable_lse=True):
    """Return the scalar shift that balances the reweighted masses.

    The returned value lam is such that
    sum(a * exp(-(f + lam) / rho1)) == sum(b * exp(-(g - lam) / rho2)).

    Parameters
    ----------
    f, g: potentials attached to a and b

    a, b: weight vectors

    rho1: scale applied to f

    rho2: scale applied to g; defaults to rho1 when None

    stable_lse: forwarded to logsumexp
    """
    if rho2 is None:
        rho2 = rho1
    scale = (rho1 * rho2) / (rho1 + rho2)
    lse_f = logsumexp(-f / rho1, a, stable_lse=stable_lse)
    lse_g = logsumexp(-g / rho2, b, stable_lse=stable_lse)
    return scale * (lse_f - lse_g)


def primal_dual_gap(a, b, x, y, p, f, g, P, I, J, rho1, rho2=None):
    """Return the difference between a primal and a dual quantity.

    The primal term is the transport cost of the sparse plan (I, J, P),
    sum_k P[k] * |x[I[k]] - y[J[k]]|^p. The dual term is
    sum(f * exp(-f / rho1) * a) + sum(g * exp(-g / rho2) * b).

    rho2 defaults to rho1 when None.
    """
    if rho2 is None:
        rho2 = rho1
    pair_costs = np.abs(x[I] - y[J]) ** p
    primal = np.sum(P * pair_costs)
    dual_f = np.sum(f * np.exp(-f / rho1) * a)
    dual_g = np.sum(g * np.exp(-g / rho2) * b)
    return primal - (dual_f + dual_g)


def solve_uot(a, b, x, y, p, rho1, rho2=None, niter=1000, tol=1e-10,
              stable_lse=True):
    """Iteratively solve a 1D unbalanced OT problem via balanced 1D solves.

    Each iteration shifts the potentials so the reweighted marginals
    A = a * exp(-f / rho1) and B = b * exp(-g / rho2) have equal mass,
    solves the balanced 1D problem between A and B with solve_ot, and
    moves (f, g) towards the returned potentials with step 2 / (2 + k).
    Iterations stop early once primal_dual_gap drops below tol.

    Parameters
    ----------
    a: vector of length n with positive entries

    b: vector of length m with positive entries

    x: sorted vector of reals of length n

    y: sorted vector of reals of length m

    p: real, should be >= 1

    rho1: scale of the reweighting of a (presumably the strength of the
        marginal penalty on a -- confirm against the reference paper)

    rho2: same for b; defaults to rho1 when None

    niter: maximum number of iterations

    tol: threshold on the primal-dual gap used for early stopping

    stable_lse: forwarded to rescale_potentials


    Returns
    ----------
    I, J, P: sparse transport plan between the final reweighted marginals,
        as returned by solve_ot

    f, g: final potentials (after the last mass-balancing shift)

    cost: dual OT cost returned by solve_ot on the final reweighted
        marginals
    """
    if rho2 is None:
        rho2 = rho1

    # Initialize potentials
    f, g = np.zeros_like(a), np.zeros_like(b)

    for k in range(niter):
        # Output FW descent direction
        transl = rescale_potentials(f, g, a, b, rho1, rho2,
                                    stable_lse=stable_lse)
        f, g = f + transl, g - transl
        A, B = a * np.exp(-f / rho1), b * np.exp(-g / rho2)
        I, J, P, fd, gd, _ = solve_ot(A, B, x, y, p)

        # Line search - convex update
        t = 2. / (2. + k)
        f = f + t * (fd - f)
        g = g + t * (gd - g)

        # Pass the resolved rho2 so the gap uses the same parameters as
        # the updates above (rho2=None would fall back to rho1).
        pdg = primal_dual_gap(a, b, x, y, p, f, g, P, I, J, rho1, rho2=rho2)
        if pdg < tol:
            break

    # Last iter before output
    transl = rescale_potentials(f, g, a, b, rho1, rho2,
                                stable_lse=stable_lse)
    f, g = f + transl, g - transl
    A, B = a * np.exp(-f / rho1), b * np.exp(-g / rho2)
    I, J, P, _, _, cost = solve_ot(A, B, x, y, p)
    return I, J, P, f, g, cost


def normalize(x):
    """Return x divided by the sum of its entries."""
    total = np.sum(x)
    return x / total


def generate_random_measure(n):
    """Draw a random discrete measure with n atoms.

    Returns
    ----------
    a: vector of length n of uniform draws rescaled by normalize

    x: sorted vector of length n of uniform draws (the support)
    """
    # Weights are drawn before positions; keep this order so seeded runs
    # reproduce the same measures.
    weights = np.random.uniform(size=n)
    support = np.random.uniform(size=n)
    return normalize(weights), np.sort(support)


if __name__ == '__main__':
    # Numerical experiment: build the matrix of pairwise UOT costs between
    # random 1D measures, restrict it to zero-sum vectors, and inspect the
    # sign of the resulting quadratic form.
    size_meas = 15
    num_meas = 6
    p = 2.
    rho = 1e-1
    constant_mass = False
    # np.random.seed(43)

    # build projection
    # After these steps C has shape (num_meas, num_meas - 1): its first row
    # is all ones and the remaining rows form minus the identity, so every
    # column sums to zero.
    z = np.zeros(shape=num_meas)
    z[0] = -1
    C = sp.linalg.circulant(z)
    C[0, :] = 1.
    C = C[:, 1:]

    # generate and store measures
    list_meas = []
    for k in range(num_meas):
        a, x = generate_random_measure(size_meas)
        x = x - np.mean(x)
        shift = np.random.normal()
        if constant_mass:
            mass = 1.
        else:
            mass = np.random.uniform(low=0.2, high=20.)
        list_meas.append([mass * a, x + 1 * shift])

    # Pairwise costs; only the upper triangle is computed and mirrored.
    kernel = np.zeros(shape=(num_meas, num_meas))
    for i in range(num_meas):
        for j in range(i, num_meas):
            [a, x] = list_meas[i]
            [b, y] = list_meas[j]
            _, _, _, _, _, cost = solve_uot(a, b, x, y, p, rho)
            kernel[i, j] = cost
            kernel[j, i] = cost

    # kernel = np.sqrt(kernel)
    kernel = C.T @ kernel @ C
    print("Kernel matrix", kernel)
    print('\n')
    eigval = np.linalg.eigvalsh(kernel)
    print("Eigenvalues of UOT kernel matrix", eigval)

    # Spectrum of the entrywise exponential of the projected matrix.
    exp_ker = np.exp(-kernel / (1e3))
    eigval = np.linalg.eigvalsh(exp_ker)
    print("Eigenvalues of exponentiated UOT kernel matrix", eigval)
    print('\n\n')

    print("Summing kernel against random vectors:")
    Ntry = 10000
    s = np.zeros(shape=Ntry)
    for i in range(Ntry):
        # Quadratic form z^T kernel z for a random centered vector z.
        z = np.random.normal(size=num_meas - 1)
        z = z - np.mean(z)
        s[i] = np.sum(z[:, None] * z[None, :] * kernel)
    print("Ratio of strictly positive values = ", np.mean((s > 0)))