from flax import nnx
from basics.layers import FourierFeatureNetwork, create_mlp

import jax.numpy as jnp
import jax


@jax.jit
def distance_matrix(x, y):
    """
    Pairwise squared Euclidean distances between two batched point sets.

    :param x: array shape (b, w, n), read as n points with w coordinates each
    :param y: array shape (b, w, m), read as m points with w coordinates each
    :return: array shape (b, n, m); entry [b, i, j] is the squared L2
        distance between x[b, :, i] and y[b, :, j] (no square root is taken)
    """
    # Broadcast (b, w, n, 1) against (b, w, 1, m) to form every (i, j) pair.
    x_cols = jnp.expand_dims(x, axis=-1)
    y_rows = jnp.expand_dims(y, axis=-2)
    squared_diff = jnp.square(x_cols - y_rows)
    # Collapse the coordinate axis w: (b, w, n, m) -> (b, n, m).
    return jnp.sum(squared_diff, axis=-3)


@jax.jit
def imq_kernel(distance, c):
    """
    Inverse multiquadric (IMQ) kernel value ``c / (c + distance)``.

    :param distance: distance value(s); the callers in this file pass squared
        distances
    :param c: kernel scale, broadcast against ``distance``
    :return: kernel value(s) with the broadcast shape of the two inputs;
        equals 1 where ``distance`` is 0
    """
    shifted = c + distance
    return jnp.divide(c, shifted)


@jax.jit
def flatten_and_remove_diagonal(matrix):
    """
    Flatten a square matrix in row-major order, dropping the main diagonal.

    :param matrix: (n, n) matrix
    :return: array of length n^2 - n holding the off-diagonal entries in
        row-major order (empty when n < 2)
    """
    # Shapes are static under jit, so plain Python control flow is fine here.
    n = matrix.shape[0]
    flat = matrix.reshape(-1)
    if n < 2:
        # 0x0 and 1x1 matrices have no off-diagonal entries.
        return flat[:0]
    # In the flattened matrix the diagonal sits at indices 0, n + 1,
    # 2 * (n + 1), ...  Dropping the final entry (itself a diagonal element)
    # leaves (n - 1) rows of length (n + 1), each starting with a diagonal
    # element; slicing that first column away removes the diagonal without
    # sorting an index array or relying on sort stability.
    return flat[:-1].reshape(n - 1, n + 1)[:, 1:].reshape(-1)


@jax.jit
def imq_mmd(x, y):
    """
    Multi-scale inverse-multiquadric MMD statistic between two sample sets.

    :param x: array shape (Batch, n_dim, N)
    :param y: array shape (Batch, n_dim, K)
    :return: scalar ``mean(k_xx) + mean(k_yy) - 2 * mean(k_xy)``, averaged
        over the batch
    Each batch entry is a separate condition, so samples are only compared
    with samples from the same batch entry. Self-pairs (the diagonal) are
    left out of the x-x and y-y terms.
    """
    scales = jnp.asarray([0.1, 0.2, 0.5, 1, 2, 5, 10])

    def mean_kernel(dist):
        # (Batch, P) -> (Batch, C, P): every kernel scale applied to every pair.
        per_scale = imq_kernel(dist[..., None, :], scales[:, None])
        # Average over the scales first, then over the pairs -> (Batch,)
        return per_scale.mean(axis=-2).mean(axis=-1)

    drop_diagonal = jax.vmap(flatten_and_remove_diagonal, in_axes=0, out_axes=0)

    # Cross term keeps all N * K pairs: (Batch, N, K) -> (Batch, N * K)
    cross = distance_matrix(x, y)
    cross = cross.reshape(cross.shape[0], -1)
    # Within-set terms keep only the off-diagonal pairs.
    within_x = drop_diagonal(distance_matrix(x, x))
    within_y = drop_diagonal(distance_matrix(y, y))

    k_xy = mean_kernel(cross)
    k_xx = mean_kernel(within_x)
    k_yy = mean_kernel(within_y)
    return (k_xx + k_yy - 2 * k_xy).mean()


class QuantileNet(nnx.Module):
    """
    Quantile head whose parameters are stacked along a leading axis.

    The constructor is wrapped in ``nnx.split_rngs(splits=3)`` and ``nnx.vmap``
    so it is run once per RNG split; ``__call__`` is vmapped over the ``Param``
    axis 0 and stacks the per-copy outputs on a new trailing axis.
    NOTE(review): this presumably yields an ensemble of three independently
    initialised heads - confirm against the NNX transform semantics.
    """

    @nnx.split_rngs(splits=3)
    @nnx.vmap(in_axes=(0, None,  None, 0))
    def __init__(self,
                 features_dim: int,
                 reward_dim: int,
                 rngs: nnx.Rngs
                 ):
        """
        :param features_dim: width of the feature vector passed to ``__call__``
        :param reward_dim: channel count at the input and output of the conv stack
        :param rngs: RNG streams used to initialise every sub-layer
        """
        # Layers are created in a fixed order; keep it, since each one draws
        # from ``rngs`` during initialisation.
        self.reward_dim = reward_dim

        # Feature trunk built by the project helper ``create_mlp``.
        merge_layers = create_mlp(features_dim, 256, net_arch=(256,), rngs=rngs)
        self.merge = nnx.Sequential(*merge_layers)

        # Embeds each scalar tau; the final Linear makes the output 256 wide.
        self.taus_emb = nnx.Sequential(
            FourierFeatureNetwork(1, 256, rngs=rngs),
            nnx.Linear(256, 256, rngs=rngs),
            nnx.relu,
        )

        # Three kernel-size-3 1-D convolutions: reward_dim -> 64 -> 32 -> reward_dim.
        kernel = (3, )
        self.conv = nnx.Sequential(
            nnx.Conv(self.reward_dim, 64, kernel_size=kernel, rngs=rngs),
            nnx.LayerNorm(64, rngs=rngs),
            nnx.silu,
            nnx.Conv(64, 32, kernel_size=kernel, rngs=rngs),
            nnx.LayerNorm(32, rngs=rngs),
            nnx.silu,
            nnx.Conv(32, self.reward_dim, kernel_size=kernel, rngs=rngs),
        )

        # Projects the 256-wide product of embeddings to reward_dim + 2 values.
        self.linear = nnx.Linear(256, self.reward_dim + 2, rngs=rngs)

    @nnx.vmap(in_axes=(nnx.StateAxes({ nnx.Param: 0, }), None, None), out_axes=-1)
    def __call__(self, x, taus):
        """
        :param x: feature array; assumed (b, features_dim) - TODO confirm
        :param taus: quantile fractions; assumed (b, w, n) with w equal to
            ``reward_dim`` so it matches the conv input channels - TODO confirm
        :return: per-copy outputs stacked on a new last axis
        """
        # (b, w, n) -> (b, w, n, 1) -> (b, w, n, f)
        tau_embedding = self.taus_emb(taus[..., None])
        # (b, f) -> (b, 1, 1, f) so it broadcasts against the tau embedding
        merged = self.merge(x)
        merged = merged[..., None, None, :]
        # Gate the tau embedding with the features, project, then swap
        # axes -3 and -1 ahead of the convolution.
        projected = self.linear(tau_embedding * merged)
        projected = projected.swapaxes(-3, -1)
        # Apply the conv stack separately to every slice along axis -2.
        conv_per_slice = jax.vmap(self.conv, in_axes=-2, out_axes=-2)
        convolved = conv_per_slice(projected)
        # Undo the swap, then average over the last axis.
        convolved = convolved.swapaxes(-3, -1)
        return convolved.mean(axis=-1)


class EWPCritic(nnx.Module):
    """
    Critic that encodes observation, action and weight vector separately,
    concatenates the three encodings and feeds them, together with quantile
    fractions, to a ``QuantileNet``.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 reward_dim: int,
                 ff_action: bool = True,
                 *,
                 rngs
                 ):
        """
        :param obs_dim: observation input width
        :param action_dim: action input width
        :param reward_dim: width of the weight vector ``w``; also passed to
            ``QuantileNet`` as its ``reward_dim``
        :param ff_action: if True, actions go through a Fourier-feature
            encoder (64 features); if False, raw actions are concatenated as-is
        :param rngs: RNG streams used to initialise every sub-layer
        """
        self.reward_dim = reward_dim
        self.obs_extractor = nnx.Sequential(
            FourierFeatureNetwork(obs_dim, 256, stddev=1e-2, rngs=rngs),
            nnx.Linear(256, 64,
                       rngs=rngs),
            nnx.LayerNorm(64, rngs=rngs),
            nnx.relu,
        )
        self.ff_action = ff_action
        if self.ff_action:
            self.action_extractor = nnx.Sequential(
                FourierFeatureNetwork(action_dim, 64, stddev=1e-3, rngs=rngs),
                nnx.Linear(64, 64,
                           rngs=rngs),
                nnx.LayerNorm(64, rngs=rngs),
                nnx.relu,
            )
            action_features = 64
        else:
            # No encoder is built, so __call__ concatenates the raw action and
            # the merge input has to be sized for action_dim instead of 64.
            action_features = action_dim
        self.weight_extractor = nnx.Sequential(
            FourierFeatureNetwork(self.reward_dim, 64, rngs=rngs),
            nnx.Linear(64, 32, rngs=rngs)
        )
        # 64 observation features + action features + 32 weight features.
        self.merge = QuantileNet(64 + action_features + 32,  self.reward_dim,
                                 rngs=rngs)

    def __call__(self, obs, action, w, taus):
        """
        :param obs: observations, last axis of size ``obs_dim``
        :param action: actions, last axis of size ``action_dim``
        :param w: weight vectors, last axis of size ``reward_dim``
        :param taus: quantile fractions forwarded unchanged to ``QuantileNet``
        :return: the ``QuantileNet`` output for the concatenated encodings
        """
        f_obs = self.obs_extractor(obs)
        # Without the Fourier encoder the raw action is used directly.
        f_action = self.action_extractor(action) if self.ff_action else action
        f_w = self.weight_extractor(w)
        return self.merge(jnp.concatenate([f_obs, f_action, f_w], axis=-1), taus)

    def loss_fn(self, obs, action, td_target, weight, taus):
        """
        MMD loss between the predicted samples and ``td_target``.

        ``imq_mmd`` is evaluated separately for every slice along the last
        axis of the prediction (the axis ``QuantileNet`` stacks its copies
        on), always against the same ``td_target``, and the results are summed.

        :param obs: observations
        :param action: actions
        :param td_target: target samples, passed to ``imq_mmd`` as ``y``
        :param weight: weight vectors, passed to ``__call__`` as ``w``
        :param taus: quantile fractions
        :return: summed MMD over the last prediction axis
        """
        prediction = self(obs, action, weight, taus)

        return jax.vmap(imq_mmd, in_axes=(-1, None), out_axes=-1)(prediction, td_target).sum(axis=-1)


if __name__ == '__main__':
    # Demo: fit the critic's quantile output to a 2-D swiss-roll distribution
    # with the MMD loss, then scatter-plot samples coloured by their taus.
    import os

    # Restrict CUDA to device index 1. Set before any array is created so it
    # is in place when JAX initialises its backend.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'

    import cloudpickle as pickle
    import jax_dataloader as jdl
    import matplotlib.pyplot as plt
    import numpy as np
    import optax
    from sklearn.datasets import make_swiss_roll
    from tqdm import trange

    # Reference plot: uniform points in the unit square, coloured by their
    # coordinates (red = first coordinate, blue = 0.8 * second coordinate).
    t = jax.random.uniform(key=jax.random.PRNGKey(42), shape=(10240,2))

    colors = jnp.stack([t[:, 0], jnp.zeros_like(t[:, 0]), 0.8 * t[:, 1]], axis=1)

    plt.scatter(t[:, 0], t[:, 1], c=colors)
    plt.show()

    # Model (obs_dim=4, action_dim=2, reward_dim=2) and optimizer, split into
    # a static graph and a state pytree so the jitted step stays functional.
    iqn = EWPCritic(4, 2, 2, rngs=nnx.Rngs(42))
    opt = nnx.Optimizer(iqn, optax.chain(
        optax.adabelief(3e-4),
    )
                        )
    graph, state = nnx.split((iqn, opt))

    # Training data: swiss roll projected onto its first and last coordinate.
    data, _ = make_swiss_roll(n_samples=25600, random_state=42)
    data = np.stack([data[:, 0], data[:, -1]], axis=-1)
    dataset = jdl.ArrayDataset(data)
    loader = jdl.DataLoader(dataset, backend='jax', batch_size=256 * 8, drop_last=True)

    # Constant conditioning inputs; only the taus vary.
    dummy_obs = jnp.zeros(shape=(256, 4))
    dummy_action = jnp.zeros(shape=(256, 2))
    dummy_taus = jax.random.uniform(jax.random.PRNGKey(42), shape=(256, 2, 4))

    # Forward-pass check before training; dummy_action doubles as the weight
    # input because action_dim == reward_dim == 2 here.
    out = iqn(dummy_obs, dummy_action, dummy_action, dummy_taus, )[..., 0]


    @jax.jit
    def update_fn(graph, state, x, key):
        """Run one optimisation step; return (loss, new state, next key)."""
        # Each loader batch of 256 * 8 points becomes 256 target sets of
        # 8 two-dimensional samples.
        x = x.reshape(-1, 8, 2).swapaxes(-2, -1)  # (256, 2, 8)

        model, opt = nnx.merge(graph, state)
        zero_obs = jnp.zeros(shape=(x.shape[0], 4))
        zero_action = jnp.zeros(shape=(x.shape[0], 2))

        # Split before sampling so the key used for the taus is never the
        # parent of the key handed to the next step.
        next_key, tau_key = jax.random.split(key)

        def loss_fn(model: EWPCritic):
            taus = jax.random.uniform(tau_key, (zero_obs.shape[0], 2, 32))

            loss = model.loss_fn(zero_obs, zero_action, x, jnp.zeros(shape=zero_action.shape), taus)
            return loss.mean()

        loss, grad = nnx.value_and_grad(loss_fn)(model)
        opt.update(grad)
        _, new_state = nnx.split((model, opt))
        return loss, new_state, next_key


    key = jax.random.PRNGKey(32)
    for _ in trange(1000):
        for x, in loader:
            loss, state, key = update_fn(graph, state, x, key)

    # Rebuild the live objects from the trained state and save the state.
    iqn, opt = nnx.merge(graph, state)
    with open('mmd.pkl', 'wb') as f:
        pickle.dump(state, f)

    # Plot the first stacked output, coloured by the taus that produced it.
    out = iqn(dummy_obs, dummy_action, dummy_action, dummy_taus)[..., 0]
    # (256, 2, 4) -> (256, 4, 2) -> (1024, 2)
    t = dummy_taus.swapaxes(1, -1).reshape(-1, 2)

    colors = np.stack([t[:, 0], np.zeros_like(t[:, 0]), t[:, 1]], axis=1)
    out = out.swapaxes(1, -1).reshape(-1, 2)
    plt.rcParams['font.family'] = 'Times New Roman'

    plt.scatter(out[:, 0], out[:, 1], c=colors)
    plt.tight_layout()
    plt.show()
