import flax.struct
import jax.nn
import distrax
import numpy as np

from networks.base import MLP, PlainMLP, SplineCurve, CosineQuantileHead, LFF
import flax.linen as nn
import jax.numpy as jnp
from typing import Callable
from functools import partial
from networks.base import GaussianFourierWithGrad


@jax.jit
def x_exp_xsqure(x):
    """Element-wise activation ``x * sinc(exp(-|x|))``.

    ``jnp.sinc`` is the normalized sinc, ``sin(pi * t) / (pi * t)``.

    NOTE(review): the function name hints at ``x * exp(-x**2)``, but no
    squared term is computed here -- confirm which form is intended.
    """
    decay = jnp.exp(-jnp.abs(x))
    gate = jnp.sinc(decay)
    return x * gate


@jax.jit
def mish(x):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""
    soft = jax.nn.softplus(x)
    squashed = jnp.tanh(soft)
    return x * squashed


class QCritic(nn.Module):
    """Single critic that scores an (observation, action) pair.

    The observation and action are joined along the last axis and passed
    through one ``MLP`` built with two hidden layers of 256 units, layer
    normalisation enabled and ``d2rl`` disabled.
    """
    # Non-linearity forwarded to the MLP.
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, observations: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Return the MLP output for ``concat(observations, action)``."""
        # First positional argument is presumably the output width (one Q-value)
        # -- verify against networks.base.MLP.
        q_net = MLP(
            1,
            hidden_sizes=(256, 256),
            activation_fn=self.activation_fn,
            layer_norm=True,
            d2rl=False,
        )
        state_action = jnp.concatenate((observations, action), axis=-1)
        return q_net(state_action)


class VectorCritic(nn.Module):
    """Ensemble of ``n_critics`` independent ``QCritic`` networks.

    The ensemble is built with ``nn.vmap`` so that every member owns its own
    parameters while all members receive the same inputs.
    """
    # Number of ensemble members (size of the vmapped axis).
    n_critics: int = 2
    # Non-linearity handed to every QCritic.
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    # NOTE(review): not read anywhere in this class.
    cmv: bool = False

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
        """Evaluate every critic on ``(obs, action)``.

        Member outputs are stacked along a new trailing axis
        (``out_axes=-1``).
        """
        # Idea taken from https://github.com/perrin-isir/xpag
        # Similar to https://github.com/tinkoff-ai/CORL for PyTorch
        vmap_options = dict(
            variable_axes={"params": 0},  # one parameter set per critic
            split_rngs={"params": True},  # distinct init RNG per critic
            in_axes=None,  # inputs are broadcast, not split
            out_axes=-1,
            axis_size=self.n_critics,
        )
        ensemble_cls = nn.vmap(QCritic, **vmap_options)
        ensemble = ensemble_cls(activation_fn=self.activation_fn)
        return ensemble(obs, action)


class FFNet(nn.Module):
    """Learned Fourier-feature map.

    Projects the input with a bias-free dense layer, multiplies by ``2*pi``
    and returns ``[sin(.), cos(.), x]`` concatenated on the last axis.
    """
    # Width of the dense projection (the output gains 2 * n_dim features).
    n_dim: int
    # Scale passed to the variance-scaling kernel initializer.
    scale: float = 3e-3

    def setup(self) -> None:
        """Create the kernel initializer and the projection layer."""
        init_fn = nn.initializers.variance_scaling(
            self.scale, mode='fan_in', distribution='normal'
        )
        self.kernel_initializer = init_fn
        self.dense = nn.Dense(
            self.n_dim,
            kernel_init=self.kernel_initializer,
            use_bias=False,
        )

    def __call__(self, x):
        """Return ``concat(sin(p), cos(p), x)`` with ``p = 2*pi*dense(x)``."""
        phase = 2 * jnp.pi * self.dense(x)
        pieces = (jnp.sin(phase), jnp.cos(phase), x)
        return jnp.concatenate(pieces, axis=-1)


class CosineQuantileCritic(nn.Module):
    """Quantile critic: MLP embedding followed by a ``CosineQuantileHead``.

    The observation and action are concatenated, optionally augmented with
    learned Fourier features, embedded by an ``MLP`` and then evaluated at
    the requested quantile fractions by the head.
    """
    # Non-linearity forwarded to the embedding MLP.
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    # Passed to both the MLP (first positional arg) and the quantile head.
    embedding_size: int = 64
    # Forwarded to CosineQuantileHead as ``n_cosine``.
    n_cosine: int = 64
    # Forwarded to CosineQuantileHead as ``smooth``.
    smooth: bool = True
    # When True, prepend sin/cos Fourier features to the input.
    ff_feature: bool = False
    # Width of the Fourier projection (previously hard-coded to 250).
    ff_dim: int = 250
    # Variance-scaling factor of the Fourier projection kernel
    # (previously hard-coded to 1e-3).
    ff_scale: float = 1e-3

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray,
                 taus: jnp.ndarray) -> jnp.ndarray:
        """Return the head's output for ``(obs, action)`` at fractions ``taus``.

        Args:
            obs: observations; concatenated with ``action`` on the last axis.
            action: actions.
            taus: quantile fractions handed unchanged to the head.
        """
        feature = jnp.concatenate([obs, action], axis=-1)
        if self.ff_feature:
            # Runs whenever this Python body executes (i.e. at trace time
            # under jit), confirming that the Fourier branch is active.
            print("FOURIER FEATURE")
            ff_init = nn.initializers.variance_scaling(
                self.ff_scale, mode='fan_in', distribution='normal')
            ff_proj = nn.Dense(self.ff_dim, kernel_init=ff_init,
                               use_bias=False)(feature)
            ff = 2 * jnp.pi * ff_proj
            # Keep the raw input alongside its sin/cos encoding.
            feature = jnp.concatenate([jnp.sin(ff), jnp.cos(ff), feature], axis=-1)

        feature = MLP(self.embedding_size, activation_fn=self.activation_fn,
                      layer_norm=True)(feature)

        head = CosineQuantileHead(smooth=self.smooth,
                                  n_cosine=self.n_cosine,
                                  embedding_size=self.embedding_size)
        return head(feature, taus)


class SplineQunatileCritic(nn.Module):
    """Quantile critic that feeds an MLP embedding into a ``SplineCurve``.

    NOTE(review): ``activation_fn``, ``n_splines`` and ``n_bins`` are declared
    but never read in this class; the spline is always built with the literal
    ``8`` -- confirm whether ``n_bins`` was meant to be used there.
    """
    # First positional argument given to the feature-extractor MLP.
    features_dim: int = 64
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.PReLU()
    n_splines: int = 2
    n_bins: int = 8

    def setup(self) -> None:
        """Instantiate the feature extractor and the spline module."""
        self.feature_extractor = MLP(self.features_dim)
        self.spline = SplineCurve(8)

    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, taus: jnp.ndarray) -> jnp.ndarray:
        """Return the spline output for ``(obs, action)`` at fractions ``taus``."""
        state_action = jnp.concatenate((obs, action), axis=-1)
        embedding = self.feature_extractor(state_action)
        return self.spline(embedding, taus)


class VectorQuantileCritic(nn.Module):
    """Ensemble of ``n_critics`` quantile critics built with ``nn.vmap``.

    The member type is ``SplineQunatileCritic`` when ``monotone`` is set and
    ``CosineQuantileCritic`` otherwise.
    """
    # Number of ensemble members (size of the vmapped axis).
    n_critics: int = 2
    # Non-linearity handed to each member critic.
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.PReLU()
    # Selects the spline-based critic instead of the cosine one.
    monotone: bool = False
    # Forwarded to CosineQuantileCritic only.
    smooth: bool = True
    # Forwarded to CosineQuantileCritic only.
    ff_feature: bool = False

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, taus: jnp.ndarray):
        """Evaluate every member on ``(obs, action, taus)``.

        Member outputs are stacked along a new trailing axis
        (``out_axes=-1``).
        """
        # Idea taken from https://github.com/perrin-isir/xpag
        # Similar to https://github.com/tinkoff-ai/CORL for PyTorch
        shared = {"activation_fn": self.activation_fn}
        if not self.monotone:
            member_cls = partial(
                CosineQuantileCritic,
                smooth=self.smooth,
                ff_feature=self.ff_feature,
                **shared,
            )
        else:
            member_cls = partial(SplineQunatileCritic, **shared)

        ensemble_cls = nn.vmap(
            member_cls,
            variable_axes={"params": 0},  # one parameter set per critic
            split_rngs={"params": True},  # distinct init RNG per critic
            in_axes=None,  # inputs are broadcast, not split
            out_axes=-1,
            axis_size=self.n_critics,
        )
        # activation_fn is supplied again here with the same value that the
        # partial above already binds.
        ensemble = ensemble_cls(activation_fn=self.activation_fn)
        return ensemble(obs, action, taus)


class VectorSpline(nn.Module):
    """Ensemble of ``n_critics`` independent ``SplineCurve`` modules.

    Built with ``nn.vmap`` so that each member owns its own parameters while
    all members receive the same ``(features, taus)`` inputs.

    The previous version read ``self.n_critics`` and ``self.activation_fn``
    without declaring either as a field, so calling the module raised
    ``AttributeError``. ``n_critics`` is now a proper field, and the stray
    ``activation_fn`` argument is no longer passed to ``SplineCurve``.
    """
    # Number of ensemble members (size of the vmapped axis).
    n_critics: int = 2
    # Forwarded to SplineCurve as ``num_range`` (was the literal 8).
    num_range: int = 8

    @nn.compact
    def __call__(self, features, taus):
        """Evaluate every spline on ``(features, taus)``.

        Member outputs are stacked along a new trailing axis
        (``out_axes=-1``).
        """
        vmap_spline = nn.vmap(
            SplineCurve,
            variable_axes={"params": 0},  # parameters not shared between the splines
            split_rngs={"params": True},
            in_axes=None,  # inputs are broadcast, not split
            out_axes=-1,
            axis_size=self.n_critics,
        )
        return vmap_spline(num_range=self.num_range)(features, taus)



