from dataclasses import dataclass
from functools import partial, reduce
from itertools import product as cart_product
from operator import mul
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO
from jax.image import resize, ResizeMethod
from ot.lp.emd_wrap import emd_c
from ott.geometry.pointcloud import PointCloud
from ott.geometry.grid import Grid
from ott.solvers.linear import solve as solve_sinkhorn, sinkhorn
from ott.geometry.costs import EuclideanP

from ot_jax.optimal_transport.jax_math import (
    sum_pool,
    scaled_coordinates,
    cost_matrix,
    coordiantes,
)


@jax.tree_util.register_dataclass
@dataclass
class OptimalTransport:
    """Result of an optimal-transport solve.

    Registered as a JAX pytree via ``jax.tree_util.register_dataclass``, so
    instances can flow through JAX transformations such as ``jax.jit``.
    """

    # Transport plan, stored as a sparse BCOO matrix.
    coupling: BCOO
    # Pair of dual potentials; presumably ordered (source, target) — confirm
    # with callers (network_simplex_ot stores emd_c's (u, v) here).
    potentials: tuple[jax.Array, jax.Array]
    # Objective value of the solve (network_simplex_ot stores emd_c's cost).
    value: jax.Array
    
def center_weights(shape: tuple[int], p: int) -> jax.Array:
    """Return the distance of every grid point to the grid center, to the ``p``.

    The center is the integer midpoint ``N // 2`` along each axis of ``shape``.
    The norm is taken over ``axis=1``, so ``coordiantes(shape)`` is presumably
    one row of coordinates per grid point — confirm against its definition.
    """
    midpoint = jnp.array(tuple(size // 2 for size in shape))
    offsets = coordiantes(shape) - midpoint
    distances = jnp.linalg.norm(offsets, axis=1)
    return distances**p

def weighted_total_variation(
    x: jax.Array, y: jax.Array, weights: jax.Array, p: int) -> jax.Array:
    """Return ``2 ** (1 - 1/p) * (sum(weights * |x - y|)) ** (1/p)``.

    This has the form of the weighted total-variation upper bound on the
    Wasserstein-p distance (presumably Villani, Thm. 6.15, with ``weights``
    holding ``d(x0, .)**p``) — confirm with callers.
    """
    weighted_l1 = jnp.sum(weights * jnp.abs(x - y))
    prefactor = 2 ** (1 - 1 / p)
    return prefactor * weighted_l1 ** (1 / p)

def weighted_average_pool(
    A: jax.Array, x_weight: jax.Array, y_weight: jax.Array, scale_factor: int
) -> jax.Array:
    """Coarsen matrix ``A`` by weighted averaging over pooled blocks.

    ``A`` is viewed as a tensor of shape ``x_weight.shape * 2`` (so ``x_weight``
    and ``y_weight`` must have the same number of elements). Each entry is
    weighted by ``x_weight[i] * y_weight[j]``. Both the weighted values and the
    weights are summed with ``sum_pool`` (presumably over non-overlapping
    ``scale_factor`` windows per axis — confirm), and the ratio is returned.

    Args:
        A: Matrix to coarsen.
        x_weight: Weights of the row grid.
        y_weight: Weights of the column grid.
        scale_factor: Pooling factor forwarded to ``sum_pool``.

    Returns:
        Matrix whose dimensions are ``A``'s divided by
        ``scale_factor ** x_weight.ndim``. Blocks with zero total weight hold
        ``inf``.
    """
    tensor_shape = x_weight.shape * 2
    weight_tensor = jnp.outer(x_weight.flatten(), y_weight.flatten()).reshape(
        tensor_shape
    )

    numerator = sum_pool(A.reshape(tensor_shape) * weight_tensor, scale_factor)
    denominator = sum_pool(weight_tensor, scale_factor)
    empty = denominator == 0
    # Divide by a guarded denominator so the discarded branch of the outer
    # ``where`` never computes 0/0. Otherwise reverse-mode gradients pick up
    # NaNs even though the forward value is masked.
    safe_denominator = jnp.where(empty, 1, denominator)
    cost_tensor = jnp.where(empty, jnp.inf, numerator / safe_denominator)
    return cost_tensor.reshape([N // scale_factor**x_weight.ndim for N in A.shape])


@jax.jit
def proportional_fitting(
    x: jax.Array, y: jax.Array, coupling: jax.Array, threshold: float
) -> jax.Array:

    def iteration(
        scale_vectors: tuple[jax.Array, jax.Array]
    ) -> tuple[jax.Array, jax.Array]:
        a, b = scale_vectors
        a = x / (coupling @ b)
        b = y / (coupling.T @ a)
        return a, b

    def condition(scale_vectors: tuple[jax.Array, jax.Array]) -> bool:
        a, b = scale_vectors
        x_hat = a * (coupling @ b)
        y_hat = b * (coupling.T @ a)
        error = jnp.linalg.norm(x - x_hat, ord=1) + jnp.linalg.norm(y - y_hat, ord=1)
        return error > threshold

    a, b = jnp.zeros(x.shape), jnp.ones(y.shape)
    a, b = jax.lax.while_loop(condition, iteration, (a, b))
    return a.reshape(-1, 1) * (coupling * b.reshape(1, -1))


@jax.jit
def upscale_potential(potential: jax.Array, x: jax.Array) -> jax.Array:
    """Linearly resample a coarse potential onto the grid of ``x``.

    ``potential`` is viewed as a hypercube with ``n`` points per axis. The
    number of axes is that of ``x`` after dropping singleton dimensions. The
    cube is resized to ``x.squeeze().shape`` with ``ResizeMethod.LINEAR``.

    Returns:
        The resampled potential as a column vector of shape
        ``(x.squeeze().size, 1)``.
    """
    target_shape = x.squeeze().shape
    ndims = len(target_shape)
    # Side length of the coarse hypercube; this is all static shape math.
    side = round(potential.size ** (1 / ndims))
    assert side**ndims == potential.size, f"n**ndims = {side**ndims}, potential.size = {potential.size}"
    cube = potential.reshape((side,) * ndims)
    return resize(cube, target_shape, ResizeMethod.LINEAR).reshape(-1, 1)


@jax.jit
def transport_cost_dual(
    x: jax.Array, y: jax.Array, f: jax.Array, g: jax.Array
) -> jax.Array:
    """Evaluate the dual objective ``f.T @ x + g.T @ y`` with singleton axes squeezed."""
    source_term = f.T @ x
    target_term = g.T @ y
    total = source_term + target_term
    return total.squeeze()


@jax.jit
def transport_cost_primal(coupling: jax.Array | BCOO, cost: jax.Array) -> jax.Array:
    return ((coupling * cost).sum()).squeeze()


@partial(jax.jit, static_argnames=["shape", "p"])
def transport_cost_primal_sparse(
    sparse_coupling: BCOO, shape: Sequence[int], p: float
) -> jax.Array:
    """Primal OT cost of a sparse coupling on a regular grid.

    The coupling matrix is reshaped to a ``shape * 2`` tensor. The ground cost
    of each stored entry is the Euclidean distance between its source and
    target grid indices, raised to ``p``. ``shape`` is a static argument and
    must therefore be hashable (a tuple).
    """
    ndims = len(shape)
    coupling_tensor = sparse_coupling.reshape(shape * 2)
    stored = coupling_tensor.indices
    displacement = stored[:, :ndims] - stored[:, ndims:]
    ground_cost = jnp.linalg.norm(displacement, axis=1) ** p
    return coupling_tensor.data @ ground_cost


@partial(jax.jit, static_argnames=["scale_factor"])
def upscale_sparse_indices(
    index: jax.Array, scale_factor: int, base_factor: jax.Array
) -> jax.Array:
    """Expand one coarse coupling entry into the matrix indices of its fine cells.

    Args:
        index: Tensor index ``(source coords..., target coords...)`` of one
            stored coupling entry, already multiplied by ``scale_factor``.
        scale_factor: Number of fine cells per coarse cell along every axis.
        base_factor: ``(ndims, 1)`` strides that flatten a fine-grid coordinate
            into a matrix row/column index.

    Returns:
        Integer array of shape ``(scale_factor ** len(index), 2)`` holding
        ``(row, column)`` pairs, in ``itertools.product`` order of the window
        (last axis varies fastest).
    """
    # Stay in integer arithmetic. float32 represents integers exactly only up
    # to 2**24, so float index math silently merges neighbouring cells on
    # large grids.
    if jnp.issubdtype(index.dtype, jnp.integer):
        index_dtype = index.dtype
    else:
        index_dtype = jnp.int32
    if jnp.issubdtype(base_factor.dtype, jnp.integer):
        strides = base_factor
    else:
        strides = jnp.rint(base_factor)
    strides = strides.astype(index_dtype).reshape(-1)
    ndims = strides.size

    offsets = jax.lax.iota(dtype=index_dtype, size=scale_factor)
    window = [i + offsets for i in index.astype(index_dtype)]
    tensor_window_indices = jnp.asarray(list(cart_product(*window)))
    rows = (tensor_window_indices[:, :ndims] * strides).sum(axis=1, keepdims=True)
    cols = (tensor_window_indices[:, ndims:] * strides).sum(axis=1, keepdims=True)
    return jnp.concatenate((rows, cols), axis=1)


@partial(jax.jit, static_argnames=["scale_factor"])
def upscale_coupling(sparse_tensor: BCOO, scale_factor: int) -> BCOO:
    """
    Upscale coupling by scale_factor.

    Args:
        sparse_tensor: Coarse coupling as a ``2 * ndims``-dimensional BCOO
            tensor. The first ``ndims`` axes index the source grid and the last
            ``ndims`` the target grid; both grids are assumed to have shape
            ``sparse_tensor.shape[:ndims]``.
        scale_factor: Number of fine cells per coarse cell along every axis.

    Returns:
        Sparse ``(n_fine, n_fine)`` coupling matrix, where ``n_fine`` is the
        number of fine grid cells. Each coarse entry's mass is spread uniformly
        over the ``scale_factor ** (2 * ndims)`` fine entries it covers, so the
        total mass is preserved.
    """
    ndims = sparse_tensor.ndim // 2
    coarse_shape = sparse_tensor.shape[:ndims]
    fine_shape = [scale_factor * n for n in coarse_shape]
    # Row-major strides of the fine grid, computed exactly as static Python
    # ints. They equal ``(scale_factor * shape[0]) ** k`` on square grids and
    # stay correct when the grid is not square.
    strides = [reduce(mul, fine_shape[axis + 1 :], 1) for axis in range(ndims)]
    base_factor = jnp.asarray(strides, dtype=jnp.int32).reshape(-1, 1)

    sparse_matrix_indices = jnp.concatenate(
        jax.vmap(
            partial(
                upscale_sparse_indices,
                scale_factor=scale_factor,
                base_factor=base_factor,
            )
        )(sparse_tensor.indices * scale_factor)
    ).astype(jnp.int32)

    kernel_size = scale_factor**sparse_tensor.ndim  # k^(2d)
    kernel = jnp.ones(kernel_size) / kernel_size

    # kron keeps each entry's kernel_size fine values contiguous. This matches
    # the order in which vmap + concatenate emit that entry's fine indices.
    sparse_matrix_values = jnp.kron(sparse_tensor.data, kernel)
    coupling_dims = (reduce(mul, coarse_shape) * (scale_factor**ndims),) * 2
    return BCOO((sparse_matrix_values, sparse_matrix_indices), shape=coupling_dims)


@jax.jit
def c_transform(potential: jax.Array, cost: jax.Array) -> jax.Array:
    """Compute the c-transform ``g[j] = min_i (cost[i, j] - potential[i])``.

    Args:
        potential: One value per row of ``cost``, given either as ``(n,)`` or
            ``(n, 1)``.
        cost: ``(n, m)`` cost matrix.

    Returns:
        ``(m, 1)`` column with the transformed potential.
    """
    # Force a column so potential[i] always pairs with row i of ``cost``. A flat
    # (n,) vector would otherwise broadcast along the columns and silently
    # produce a wrong result when n == m.
    column = jnp.reshape(potential, (-1, 1))
    return jnp.min(cost - column, axis=0).reshape(-1, 1)


@jax.jit
def estimate_dual_null_weights(
    alpha0: jax.Array, beta0: jax.Array, a: jax.Array, b: jax.Array, M: jax.Array
) -> tuple[jax.Array, jax.Array]:
    """Lower dual potentials at zero-weight entries to repair constraint violations.

    The dual constraint is ``alpha[i] + beta[j] <= M[i, j]``.
    - For every ``i`` with ``a[i] == 0``, ``alpha[i]`` is lowered by the largest
      positive violation in row ``i``.
    - For every ``j`` with ``b[j] == 0``, ``beta[j]`` is lowered by the largest
      positive violation in column ``j``.

    Both corrections use the violations of the *input* potentials. Entries with
    non-zero weight are left unchanged (for finite violations). This appears to
    mirror POT's ``estimate_dual_null_weights`` without its final centering
    step — confirm if exact parity matters.
    """
    slack = alpha0[:, None] + beta0[None, :] - M
    row_excess = jnp.maximum(slack.max(axis=1), 0)
    col_excess = jnp.maximum(slack.max(axis=0), 0)
    alpha = alpha0 + -1 * (a == 0) * row_excess
    beta = beta0 + -1 * (b == 0) * col_excess
    return alpha, beta


def network_simplex_ot(
    x: jax.Array,
    y: jax.Array,
    C: jax.Array,
    *,
    max_iter: int = 1_000_000,
    n_threads: int = 1,
) -> OptimalTransport:
    """Solve an exact OT problem between ``x`` and ``y`` with POT's network simplex.

    Args:
        x: Source weights; flattened before solving.
        y: Target weights; flattened and rescaled to the total mass of ``x``.
            The caller's array is not modified.
        C: Cost matrix with one row per entry of ``x`` and one column per entry
            of ``y``.
        max_iter: Iteration limit forwarded to ``emd_c``, converted to ``int``.
        n_threads: Thread count forwarded to ``emd_c``.

    Returns:
        ``OptimalTransport`` holding the coupling as a ``BCOO`` matrix, the
        dual potentials ``(u, v)`` and the cost reported by ``emd_c``.
        Potentials at zero-weight entries are repaired with
        ``estimate_dual_null_weights``.

    Warns:
        UserWarning, via POT's ``check_result``, when ``emd_c`` terminates
        without reaching optimality (infeasible, unbounded, or ``max_iter``
        reached).
    """
    from ot.lp.emd_wrap import check_result

    # Rebind rather than ``*=`` so a NumPy array passed by the caller is never
    # modified in place.
    y = y * (x.sum() / y.sum())

    a = np.ascontiguousarray(jax.device_get(x.flatten()), dtype=np.float64)
    b = np.ascontiguousarray(jax.device_get(y.flatten()), dtype=np.float64)
    M = np.ascontiguousarray(jax.device_get(C), dtype=np.float64)
    # Re-balance in float64, as POT's own ``ot.emd`` does before calling
    # emd_c. The rescale above ran in the input precision (often float32) and
    # can leave a small mass mismatch.
    b = b * (a.sum() / b.sum())

    # emd_c takes an integer iteration count; int() also accepts float values
    # such as the former default ``1e6``.
    G, cost, u, v, result_code = emd_c(  # TODO: register as foreign function
        a, b, M, int(max_iter), n_threads
    )
    check_result(result_code)

    u, v = jax.device_put(u), jax.device_put(v)
    if jnp.any(x == 0) or jnp.any(y == 0):
        u, v = estimate_dual_null_weights(u, v, x.flatten(), y.flatten(), C)

    return OptimalTransport(
        coupling=jax.device_put(BCOO.fromdense(G)),
        potentials=(u, v),
        value=jax.device_put(cost),
    )


def solve_scaled_ot(
    x: jax.Array, y: jax.Array, scale_factor: int, p: float, **kwargs
) -> OptimalTransport:
    """Solve exact OT between coarsened versions of ``x`` and ``y``.

    Both inputs are reduced with ``sum_pool(..., scale_factor)``. The coarse
    cost comes from ``cost_matrix`` over ``scaled_coordinates`` of the squeezed
    grid shape with exponent ``p``. Extra keyword arguments (e.g. ``max_iter``,
    ``n_threads``) are forwarded to ``network_simplex_ot``.
    """
    grid_shape = x.squeeze().shape
    coarse_cost = cost_matrix(scaled_coordinates(grid_shape, scale_factor), p)
    coarse_x, coarse_y = sum_pool(x, scale_factor), sum_pool(y, scale_factor)
    return network_simplex_ot(coarse_x, coarse_y, coarse_cost, **kwargs)


@partial(jax.jit, static_argnames=["p", "threshold", "max_iterations"])
def solve_regularized_ot(
    x: jax.Array,
    y: jax.Array,
    epsilon: float = 1.0,
    p: int = 2,
    threshold: float = 1e-6,
    max_iterations: int = 10_000,
) -> sinkhorn.SinkhornOutput:
    """Entropically regularized OT between ``x`` and ``y`` with ott's Sinkhorn.

    ``y`` is rescaled to the total mass of ``x``. Both marginals live on the
    points ``coordiantes(x.shape)``: the ``PointCloud`` is built from ``x``
    only, so ott reuses those points as the target support.

    Args:
        x: Source weights, flattened to Sinkhorn's ``a``.
        y: Target weights, rescaled and flattened to Sinkhorn's ``b``.
        epsilon: Entropic regularization strength.
        p: Exponent passed to ``EuclideanP``; static under ``jit``.
        threshold: Sinkhorn convergence threshold; static under ``jit``.
        max_iterations: Sinkhorn iteration cap. It must be an integer because
            ott sizes its error buffers from it, hence static under ``jit``.

    Returns:
        ott's ``SinkhornOutput``.
    """
    y = y * (x.sum() / y.sum())
    geom = PointCloud(
        x=coordiantes(x.shape),
        cost_fn=EuclideanP(p),
        epsilon=epsilon,
    )
    return solve_sinkhorn(
        geom,
        a=x.ravel(),
        b=y.ravel(),
        threshold=threshold,
        max_iterations=int(max_iterations),
    )