import jax
from ot_jax.optimal_transport.jax_math import (
    coordiantes,
    cost_matrix,
    min_pool,
    sum_pool,
)
from ot_jax.optimal_transport.jax_transport import (
    solve_scaled_ot,
    c_transform,
    transport_cost_dual,
    transport_cost_primal_sparse,
    upscale_potential,
    upscale_coupling,
    proportional_fitting,
    solve_regularized_ot,
    weighted_average_pool,
    network_simplex_ot,
    center_weights,
    weighted_total_variation,
)
from functools import reduce
from operator import mul



def bilevel_lower_bound(
    x: jax.Array,
    y: jax.Array,
    p: float = 2,
    scale_factor: int = 4,
) -> float:
    """
    Compute the Lower Bound to Wasserstein distance between two measures using the Bilevel approximation.

    A coarse OT problem is solved via ``solve_scaled_ot`` at ``scale_factor``,
    its first potential is upscaled back to the resolution of ``x``, and a
    double c-transform (``g = c_transform(potential)``, then
    ``f = c_transform(g)``) yields a dual pair on the fine grid. The dual
    objective evaluated at ``(f, g)`` is clamped at zero and raised to ``1 / p``.

    Args:
        x: Source measure; ``x.squeeze().shape`` defines the grid.
        y: Target measure on the same grid as ``x``.
        p: Exponent of the ground cost and of the returned distance.
        scale_factor: Downscaling factor used for the coarse problem.

    Returns:
        ``max(dual_cost, 0) ** (1 / p)`` as a JAX scalar.
    """
    import jax.numpy as jnp  # local import keeps the module import block unchanged

    coarse_ot = solve_scaled_ot(x, y, scale_factor, p)
    cost = cost_matrix(coordiantes(x.squeeze().shape), p)
    upscaled_potential = upscale_potential(coarse_ot.potentials[0], x)
    # Double c-transform of the upscaled potential.
    # For large arrays both transforms are extremely slow on CPU.
    g = c_transform(upscaled_potential, cost)
    f = c_transform(g, cost)
    dual_cost = transport_cost_dual(x.flatten(), y.flatten(), f, g)
    # A poor coarse potential or float round-off (x close to y) can push the
    # dual value slightly below zero, and a fractional power of a negative
    # number is NaN. Since W_p^p >= 0, 0 is still a valid lower bound; this
    # mirrors the clamp used by entropy_lower_bound.
    return jnp.maximum(dual_cost, 0.0) ** (1 / p)


def min_cost_lower_bound(
    x: jax.Array,
    y: jax.Array,
    p: float = 2,
    scale_factor: int = 4,
) -> float:
    """
    Lower-bound the Wasserstein distance between two measures via the Min-Cost approximation.

    The fine ground-cost matrix is reshaped to a ``grid + grid`` tensor and
    min-pooled with ``scale_factor``, then flattened to a square coarse cost
    matrix. An exact OT problem under that cost is solved with
    ``network_simplex_ot`` between the sum-pooled versions of ``x`` and ``y``.
    (Assumes ``min_pool`` pools every axis by ``scale_factor`` — TODO confirm.)

    Args:
        x: Source measure; ``x.squeeze().shape`` defines the fine grid.
        y: Target measure on the same grid as ``x``.
        p: Exponent of the ground cost and of the returned distance.
        scale_factor: Integer pooling factor applied along each grid axis.

    Returns:
        The optimal coarse transport value raised to ``1 / p``.
    """
    grid = x.squeeze().shape
    # Number of coarse cells: product of the floor-divided side lengths.
    n_coarse = reduce(mul, (side // scale_factor for side in grid))

    fine_cost = cost_matrix(coordiantes(grid), p)
    pooled_cost = min_pool(fine_cost.reshape(grid + grid), scale_factor)
    coarse_cost = pooled_cost.reshape(n_coarse, n_coarse)

    solution = network_simplex_ot(
        sum_pool(x, scale_factor),
        sum_pool(y, scale_factor),
        coarse_cost,
    )
    return solution.value ** (1 / p)


def bilevel_upper_bound(
    x: jax.Array,
    y: jax.Array,
    p: float = 2,
    scale_factor: int = 4,
    threshold: float = 1e-6,
) -> float:
    """
    Upper-bound the Wasserstein distance between two measures via the Bilevel approximation.

    The coarse coupling from ``solve_scaled_ot`` is reshaped to a
    ``coarse_grid + coarse_grid`` tensor, upscaled to the fine grid, and
    refined with ``proportional_fitting`` against the flattened measures.
    The result is the primal cost of the fitted plan raised to ``1 / p``,
    plus a weighted total-variation term for each measure, comparing ``x``
    to the plan's row sums and ``y`` to its column sums.

    Args:
        x: Source measure; ``x.squeeze().shape`` defines the grid.
        y: Target measure on the same grid as ``x``.
        p: Exponent of the ground cost and of the returned distance.
        scale_factor: Downscaling factor used for the coarse problem.
        threshold: Forwarded to ``proportional_fitting`` (presumably its
            convergence tolerance — confirm against that helper).

    Returns:
        ``primal ** (1 / p) + wtv(x, row_sums) + wtv(y, col_sums)``.
    """
    coarse_ot = solve_scaled_ot(x, y, scale_factor, p)
    grid = x.squeeze().shape
    coarse_grid = tuple(side // scale_factor for side in grid)

    coarse_plan = coarse_ot.coupling.reshape(coarse_grid + coarse_grid)
    fine_guess = upscale_coupling(coarse_plan, scale_factor)
    plan = proportional_fitting(x.flatten(), y.flatten(), fine_guess, threshold)

    transport_term = transport_cost_primal_sparse(plan, shape=grid, p=p) ** (1 / p)

    weights = center_weights(grid, p)
    # The plan appears to be a sparse array (it exposes .todense()).
    row_sums = plan.sum(axis=1).todense()
    col_sums = plan.sum(axis=0).todense()
    source_gap = weighted_total_variation(x.flatten(), row_sums, weights, p)
    target_gap = weighted_total_variation(y.flatten(), col_sums, weights, p)

    return transport_term + source_gap + target_gap


def weighted_cost_upper_bound(
    x: jax.Array,
    y: jax.Array,
    p: float = 2,
    scale_factor: int = 4,
) -> float:
    """
    Upper-bound the Wasserstein distance between two measures using Weighted Average Pooling.

    The fine ground-cost matrix is pooled into a coarse cost with
    ``weighted_average_pool`` (which also receives ``x`` and ``y``). An exact
    OT problem is then solved with ``network_simplex_ot`` between the
    sum-pooled measures.

    Args:
        x: Source measure; ``x.squeeze().shape`` defines the fine grid.
        y: Target measure on the same grid as ``x``.
        p: Exponent of the ground cost and of the returned distance.
        scale_factor: Pooling factor for the coarse problem.

    Returns:
        The optimal coarse transport value raised to ``1 / p``.
    """
    grid = x.squeeze().shape
    fine_cost = cost_matrix(coordiantes(grid), p)
    coarse_cost = weighted_average_pool(fine_cost, x, y, scale_factor)

    source = sum_pool(x, scale_factor)
    target = sum_pool(y, scale_factor)
    solution = network_simplex_ot(source, target, coarse_cost)
    return solution.value ** (1 / p)


def entropy_upper_bound(
    x: jax.Array,
    y: jax.Array,
    p: float = 2,
    epsilon_factor: float = 1e-2,
) -> float:
    """
    Upper-bound the Wasserstein distance between two measures using regularized OT.

    Solves ``solve_regularized_ot`` between the squeezed measures (presumably
    entropic regularization, per the function name). The result is its
    primal cost raised to ``1 / p``, plus a weighted total-variation term for
    each measure, comparing ``x`` to the solution's axis-1 marginal and ``y``
    to its axis-0 marginal.

    Args:
        x: Source measure.
        y: Target measure.
        p: Exponent of the ground cost and of the returned distance.
        epsilon_factor: Regularization strength. When truthy it is multiplied
            by ``x.shape[-1] ** p``; a falsy value is passed through as-is.

    Returns:
        ``primal ** (1 / p) + wtv(x, marginal_1) + wtv(y, marginal_0)``.
    """
    epsilon = epsilon_factor
    if epsilon:
        # Scale by the size of the last axis of x raised to p.
        epsilon = epsilon * x.shape[-1] ** p

    source = x.squeeze()
    target = y.squeeze()
    solution = solve_regularized_ot(source, target, epsilon=epsilon, p=p)
    transport_term = solution.primal_cost ** (1 / p)

    weights = center_weights(source.shape, p)
    source_gap = weighted_total_variation(
        x.flatten(), solution.marginal(axis=1), weights, p
    )
    target_gap = weighted_total_variation(
        y.flatten(), solution.marginal(axis=0), weights, p
    )
    return transport_term + source_gap + target_gap

def entropy_lower_bound(
    x: jax.Array,
    y: jax.Array,
    p: float = 2,
    epsilon_factor: float = 1e-2,
) -> float:
    """
    Compute a Lower Bound to Wasserstein distance between two measures using regularized OT.

    Solves ``solve_regularized_ot`` between the squeezed measures and returns
    its dual objective, clamped at zero, raised to ``1 / p``. The
    regularization is presumably entropic (per the function name).

    Args:
        x: Source measure.
        y: Target measure.
        p: Exponent of the ground cost and of the returned distance.
        epsilon_factor: Regularization strength. When truthy it is multiplied
            by ``x.shape[-1] ** p``; a falsy value is passed through as-is.

    Returns:
        ``max(dual_cost, 0) ** (1 / p)`` as a JAX scalar.
    """
    import jax.numpy as jnp  # local import keeps the module import block unchanged

    if epsilon_factor:
        # Scale with the size of the last axis of x raised to p.
        N: int = x.shape[-1]
        epsilon_factor *= N ** p
    regularized_ot = solve_regularized_ot(
        x.squeeze(), y.squeeze(), epsilon=epsilon_factor, p=p
    )
    # jnp.maximum instead of Python max(): max() needs a concrete bool, so it
    # fails while tracing under jax.jit, and its result type depended on the
    # sign (plain Python float when clamped, JAX array otherwise). The clamp
    # itself keeps the fractional power from producing NaN on a slightly
    # negative dual value.
    return jnp.maximum(regularized_ot.dual_cost, 0.0) ** (1 / p)

def exact_wasserstein(x: jax.Array, y: jax.Array, p: float) -> float:
    """
    Compute the Wasserstein distance between two measures exactly with the network simplex.

    Args:
        x: Source measure; ``x.squeeze().shape`` defines the grid.
        y: Target measure on the same grid as ``x``.
        p: Exponent of the ground cost and of the returned distance.

    Returns:
        The optimal transport value raised to ``1 / p``.
    """
    ground_cost = cost_matrix(coordiantes(x.squeeze().shape), p=p)
    optimum = network_simplex_ot(x, y, ground_cost)
    return optimum.value ** (1 / p)


if __name__ == "__main__":
    # Benchmark script: compares the distance bounds above on one DOTmark pair.
    import sys
    import time
    from typing import Callable

    from ot_jax.data.datasets.DOTmark import (
        DOTmarkClass,
        DOTmarkLoader,
        DOTmarkResolution,
    )

    # Lift CPython's limit on int -> str conversion length (Python 3.11+).
    sys.set_int_max_str_digits(0)

    def benchmark_wass(dist_fun: Callable, x: jax.Array, y: jax.Array, **kwargs):
        """Run ``dist_fun(x, y, **kwargs)`` once; print its value and wall time."""
        t0 = time.perf_counter_ns()
        # block_until_ready so asynchronous JAX dispatch is included in the timing.
        w = jax.block_until_ready(dist_fun(x, y, **kwargs))
        duration = time.perf_counter_ns() - t0
        print(f"{dist_fun.__name__}: {w:.5f}")
        print(f"Time: {duration * 1e-9:.3f} s")

    res = DOTmarkResolution.MEDIUM
    dot_class = DOTmarkClass.Microscopy_Images
    loader = DOTmarkLoader(
        dot_class=dot_class,
        resolution=res,
        normalize=True,
        array_backend=jax.numpy,
    )
    dot = loader.as_list()
    pair = dot[dot_class][res]
    x = pair[0]
    y = pair[1]

    # These names appear verbatim in the {name=} f-string below.
    test_epsilon = 1e-3
    test_scale = 2
    test_p = 2
    print(f"{res=},{test_epsilon=},{test_scale=},{test_p=}")

    # Disabled checks — uncomment to run:
    # with jax.checking_leaks():
    #     _ = jax.block_until_ready(
    #         weighted_cost_upper_bound(x, y, scale_factor=test_scale, p=test_p))
    #
    # print("Network simplex OT:")
    # w_exact = jax.block_until_ready(exact_wasserstein(x, y, p=test_p))
    # print(f"{w_exact=:.5f}")
    #
    # print("Upper Bounds:")
    # benchmark_wass(weighted_cost_upper_bound, x, y, scale_factor=test_scale, p=test_p)
    # benchmark_wass(bilevel_upper_bound, x, y, scale_factor=test_scale, p=test_p)
    # benchmark_wass(entropy_upper_bound, x, y, p=test_p, epsilon_factor=test_epsilon)

    print("Lower Bounds:")
    lower_bound_runs = (
        (bilevel_lower_bound, {"scale_factor": test_scale, "p": test_p}),
        (entropy_lower_bound, {"p": test_p, "epsilon_factor": test_epsilon}),
        (min_cost_lower_bound, {"scale_factor": test_scale, "p": test_p}),
    )
    for bound_fn, options in lower_bound_runs:
        benchmark_wass(bound_fn, x, y, **options)
