import numpy as onp
import jax
import jax.numpy as jnp
from chex import Array
from sves.kernels import RBF


def cmapper(algorithm_name):
    """Return the plot style (color, linestyle, display name) for an algorithm.

    Colors from Paul Tol: https://personal.sron.nl/~pault/
    """
    styles = {
        'OG_SVGD': {'color': '#4477AA', 'linestyle': '--', 'name': 'SVGD'},  # Blue
        'GF_SVGD': {'color': '#ff7f0e', 'linestyle': '-', 'name': 'GF-SVGD'},  # orange
        'BB_SVGD_ES': {'color': '#EE6677', 'linestyle': '-', 'name': 'SV-CMA-ES'},  # red
        'MC SVGD': {'color': '#228833', 'linestyle': '-', 'name': 'SV-OpenAI-ES'},  # green
        'Parallel CMA-ES': {'color': '#33BBEE', 'linestyle': '-.', 'name': 'Parallel CMA-ES'},  # cyan
        'Parallel MC SVGD': {'color': '#AA3377', 'linestyle': '-.', 'name': 'Parallel OpenAI-ES'},  # purple
        'CMA_ES': {'color': '#6B6B6B', 'linestyle': '-', 'name': 'CMA-ES'},  # grey
        'OpenES': {'color': 'black', 'linestyle': ':', 'name': 'OpenAI-ES'}  # black
    }
    # Default style for unknown algorithms; reuse the raw name as the display name
    return styles.get(algorithm_name, {'color': 'black', 'linestyle': '-', 'name': algorithm_name})
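

# A minimal sketch, assuming the styles are consumed by matplotlib: the hypothetical
# helper below (not part of the original code) turns a cmapper()-style dict into
# keyword arguments for matplotlib's Axes.plot, using 'name' as the legend label.
# Assumed usage: ax.plot(xs, ys, **sketch_style_to_plot_kwargs(cmapper('OG_SVGD')))
from typing import Any, Dict


def sketch_style_to_plot_kwargs(style: Dict[str, Any], fallback_label: str = "") -> Dict[str, Any]:
    """Map a style dict with 'color', 'linestyle' and optional 'name' to plot kwargs."""
    return {
        "color": style["color"],
        "linestyle": style["linestyle"],
        # Styles without a display name fall back to the given label
        "label": style.get("name", fallback_label),
    }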


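def get_mmd(x: Array, y: Array) -> onp.ndarray:
    """Maximum Mean Discrepancy (MMD) between sample sets in ``x`` and ground truth ``y``.

    Returns the unbiased estimate of the squared MMD, which can be negative.
    If ``x`` has more than two dimensions, its leading axis (e.g. independent runs)
    is looped over in Python and the MMD is vmapped over the next axis; otherwise
    the MMD is vmapped over the leading axis of ``x``.

    Adapted from: https://github.com/jindongwang/transferlearning/blob/master/code/distance/mmd_numpy_sklearn.py
    and ported to JAX using a tfp kernel (``RBF`` from ``sves.kernels``).
    """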
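    def del_diag(matrix: Array) -> Array:
        """Drop the diagonal of a square (n, n) matrix, returning shape (n, n - 1)."""
        return matrix[~jnp.eye(matrix.shape[0], dtype=bool)].reshape(matrix.shape[0], -1)

    kernel = RBF()
    # Gram matrix: kernel_matrix(a, b, bw)[i, j] = kernel(a[i], b[j], bw)
    kernel_matrix = jax.vmap(jax.vmap(kernel, (None, 0, None)), (0, None, None))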

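    def mmd(a: Array, b: Array) -> Array:
        """Unbiased squared MMD between samples ``a`` and ground-truth samples ``b``."""
        # Choose the bandwidth from the distances between ground-truth samples,
        # following the median heuristic of the SVGD-without-gradient paper
        # (Han & Liu, 2018). For multi-dimensional samples, the median is taken
        # over per-coordinate squared differences.
        sq_dists = (b[None] - b[:, None]) ** 2
        median_sq_dist = jnp.median(sq_dists)
        median_ls = median_sq_dist / jnp.log(b.shape[0] + 1)
        bandwidth = jnp.clip(median_ls, 1e-6, 1e6)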

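        # Compute the unbiased squared MMD: the diagonal (self-similarity) terms of
        # the within-set Gram matrices are dropped, so the estimate can be negative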
        xx = kernel_matrix(a, a, bandwidth)
        xy = kernel_matrix(a, b, bandwidth)
        yy = kernel_matrix(b, b, bandwidth)
        xx = del_diag(xx)
        yy = del_diag(yy)

        return xx.mean() - 2 * xy.mean() + yy.mean()

    # Map the metric over the sample sets in x; if x has an extra leading axis
    # (e.g. multiple runs), loop over it in Python
    if x.ndim > 2:
        return onp.array([jax.vmap(lambda xi: mmd(xi, y))(dslice) for dslice in x])
    return onp.array(jax.vmap(lambda xi: mmd(xi, y))(x))
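

# A self-contained sketch of the same unbiased squared-MMD estimator, assuming a
# Gaussian RBF kernel exp(-||x - y||^2 / h). The exact form of sves.kernels.RBF is
# not shown in this file, so this kernel is an assumption. Unlike mmd() above, the
# sketch sums squared differences over the last axis (squared Euclidean distance)
# before taking the median for the bandwidth. All names here are hypothetical.
# When a and b come from the same distribution the result is close to zero and
# may be slightly negative, which is why compute_metrics() takes abs before log10.
# Usage: sketch_unbiased_mmd2(particles, gt_samples), shapes (n, d) and (m, d).
import jax
import jax.numpy as jnp


def _sketch_rbf(x, y, bandwidth):
    # Gaussian RBF kernel between two single samples of shape (d,)
    return jnp.exp(-jnp.sum((x - y) ** 2) / bandwidth)


def sketch_unbiased_mmd2(a, b):
    """Unbiased squared MMD between samples a (n, d) and ground truth b (m, d)."""
    # Median heuristic on pairwise squared Euclidean distances of the ground truth
    sq_dists = jnp.sum((b[None] - b[:, None]) ** 2, axis=-1)
    bandwidth = jnp.clip(jnp.median(sq_dists) / jnp.log(b.shape[0] + 1), 1e-6, 1e6)

    # Gram matrix: gram(p, q, h)[i, j] = k(p[i], q[j])
    gram = jax.vmap(jax.vmap(_sketch_rbf, (None, 0, None)), (0, None, None))
    kxx = gram(a, a, bandwidth)
    kxy = gram(a, b, bandwidth)
    kyy = gram(b, b, bandwidth)

    # U-statistic: exclude the diagonal (self-similarity) terms of kxx and kyy
    n, m = a.shape[0], b.shape[0]
    xx = (kxx.sum() - jnp.trace(kxx)) / (n * (n - 1))
    yy = (kyy.sum() - jnp.trace(kyy)) / (m * (m - 1))
    return xx - 2 * kxy.mean() + yy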


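def get_log10mean_mse(x: Array, y: Array) -> tuple[Array, Array]:
    """Compute the log10 squared error of the estimate of E[X].

    E[X] is estimated by averaging ``x`` over its last two axes and compared with
    the mean over all entries of ``y``.

    Args:
        x: Input samples. Shape (batch size, n iterations, fct dim).
            The batch dimension is the one over which the statistics are taken.
        y: Ground truth samples. Shape (1, fct dim) or flat.

    Returns:
        Mean and standard deviation, over the batch, of the log10 squared error
        of the E[X] estimate.
    """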
    means = jnp.mean(x, axis=(-2, -1))
    sqerr = (means - jnp.mean(y)) ** 2
    sqerr = jnp.log10(sqerr)
    return jnp.mean(sqerr, axis=0), jnp.std(sqerr, axis=0)


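def get_log10var_mse(x: Array, y: Array) -> tuple[Array, Array]:
    """Compute the log10 squared error of the estimate of V[X].

    V[X] is estimated as the variance of ``x`` over its last two axes and compared
    with the variance over all entries of ``y``.

    Args:
        x: Input samples. Shape (batch size, n iterations, fct dim).
            The batch dimension is the one over which the statistics are taken.
        y: Ground truth samples. Shape (1, fct dim) or flat.

    Returns:
        Mean and standard deviation, over the batch, of the log10 squared error
        of the V[X] estimate.
    """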
    variances = jnp.var(x, axis=(-2, -1))
    sqerr = (variances - jnp.var(y)) ** 2
    sqerr = jnp.log10(sqerr)
    return jnp.mean(sqerr, axis=0), jnp.std(sqerr, axis=0)
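

# get_log10mean_mse and get_log10var_mse both reduce each batch entry to a single
# scalar estimate and then summarize log10 squared errors over the batch. A minimal
# sketch of that shared step, with a hypothetical name, assuming the per-entry
# estimates have already been computed:
import jax.numpy as jnp


def sketch_log10_sqerr_stats(estimates, target):
    """Mean and std over axis 0 of log10((estimate - target) ** 2)."""
    log_sqerr = jnp.log10((estimates - target) ** 2)
    # An exact estimate gives log10(0) = -inf, so the mean becomes -inf and the std NaN
    return jnp.mean(log_sqerr, axis=0), jnp.std(log_sqerr, axis=0)


# Worked example: estimates 2.1, 1.9 and 2.01 of a target of 2.0 give squared errors
# 1e-2, 1e-2 and 1e-4, i.e. log10 values -2, -2 and -4, so the mean is -8/3 and the
# (population) std is sqrt(8/9), roughly 0.943.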


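def compute_metrics(data: Array, ground_truth: Array) -> dict:
    """Compute log10 MMD, log10 mean error and log10 variance error against ground truth.

    Each entry of the returned dict holds a (mean, std) pair taken over the leading axis.
    """
    mmds = get_mmd(data, ground_truth)
    mean_sqerr = get_log10mean_mse(data, ground_truth)
    var_sqerr = get_log10var_mse(data, ground_truth)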

    # Store (mean, std) pairs. The unbiased squared-MMD estimate can be negative,
    # so take the absolute value before log10.
    out = {}
    mmds = jnp.log10(jnp.abs(mmds))
    out["mmds"] = (onp.mean(mmds, axis=0), onp.std(mmds, axis=0))
    out["means"] = onp.array(mean_sqerr)
    out["vars"] = onp.array(var_sqerr)

    return out
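

# A minimal sketch, assuming the dict returned by compute_metrics() is reported as
# text: the hypothetical helper below formats every (mean, std) pair and, for
# per-iteration curves (e.g. "mmds" when the input has a run axis), reports a
# single iteration, the last one by default.
import numpy as np


def sketch_format_metrics(out, iteration=-1):
    """Format each (mean, std) entry of a compute_metrics-style dict as a string."""
    lines = []
    for key, (mean, std) in out.items():
        mean, std = np.asarray(mean), np.asarray(std)
        if mean.ndim > 0:  # a curve over iterations: pick one entry
            mean, std = mean[iteration], std[iteration]
        lines.append(f"{key}: {float(mean):.2f} +/- {float(std):.2f}")
    return lines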
