from typing import Any, Callable, Optional, Sequence

import functools
import distrax
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.training import train_state


class EnsembleDense(nn.Module):
    """Dense layer that keeps a separate weight matrix per ensemble member.

    The kernel has shape ``(ensemble_num, in_dim, features)`` and the optional
    bias has shape ``(ensemble_num, features)``. Member ``i`` of the ensemble
    multiplies row ``i`` of the input by its own ``(in_dim, features)`` matrix.

    Attributes:
        ensemble_num: number of ensemble members (leading axis of the params).
        features: output width of every member.
        use_bias: whether a per-member bias is added to the output.
        dtype: dtype the inputs and parameters are cast to before computing.
        precision: accepted as a field but not read by ``__call__``.
        kernel_init: initializer for the ``"kernel"`` parameter.
        bias_init: initializer for the ``"bias"`` parameter.
    """
    ensemble_num: int
    features: int
    use_bias: bool = True
    dtype: Any = jnp.float32
    precision: Any = None
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, inputs: jnp.array) -> jnp.array:
        """Apply each member's affine map to its own input row.

        Args:
            inputs: 2-D array whose first axis is shared with the kernel's
                ensemble axis (the einsum index ``i``) and whose last axis is
                the input feature dimension.

        Returns:
            Array with one ``features``-wide row per ensemble member.
        """
        x = jnp.asarray(inputs, self.dtype)

        # One (in_dim, features) matrix per ensemble member.
        kernel_shape = (self.ensemble_num, x.shape[-1], self.features)
        weights = jnp.asarray(
            self.param("kernel", self.kernel_init, kernel_shape), self.dtype)

        # Index i pairs input row i with kernel slice i; j is contracted.
        out = jnp.einsum("ij,ijk->ik", x, weights)
        if not self.use_bias:
            return out

        bias_shape = (self.ensemble_num, self.features)
        offsets = jnp.asarray(
            self.param("bias", self.bias_init, bias_shape), self.dtype)
        return out + offsets


class EnsembleCritic(nn.Module):
    """Ensemble Q-function: a three-layer MLP built from ``EnsembleDense``.

    Attributes:
        ensemble_num: number of ensemble members in every layer.
        hid_dim: width of the two hidden layers.
    """
    ensemble_num: int
    hid_dim: int = 256

    def setup(self):
        """Create the two hidden layers and the single-output head."""
        # Every layer shares the same ensemble size; only width/name differ.
        make_layer = functools.partial(EnsembleDense,
                                       ensemble_num=self.ensemble_num)
        self.l1 = make_layer(features=self.hid_dim, name="fc1")
        self.l2 = make_layer(features=self.hid_dim, name="fc2")
        self.l3 = make_layer(features=1, name="fc3")

    def __call__(self, observations, actions):
        """Return one Q-value per ensemble member.

        Args:
            observations: observation array, concatenated with ``actions``
                along the last axis.
            actions: action array with matching leading axes.

        Returns:
            The output of the final layer with its trailing size-1 axis
            removed.
        """
        critic_input = jnp.concatenate([observations, actions], axis=-1)
        hidden = nn.relu(self.l1(critic_input))
        hidden = nn.relu(self.l2(hidden))
        q_values = self.l3(hidden)
        # The head emits a single feature; drop that axis.
        return jnp.squeeze(q_values, axis=-1)


class EnsembleDoubleCritic(nn.Module):
    """Pair of independent ``EnsembleCritic`` networks evaluated together.

    Attributes:
        ensemble_num: number of ensemble members in each critic.
        hid_dim: hidden width passed to each critic.
    """
    ensemble_num: int
    hid_dim: int = 256

    def setup(self):
        """Instantiate the two critics with identical hyper-parameters."""
        self.q1 = EnsembleCritic(ensemble_num=self.ensemble_num,
                                 hid_dim=self.hid_dim)
        self.q2 = EnsembleCritic(ensemble_num=self.ensemble_num,
                                 hid_dim=self.hid_dim)

    def __call__(self, observations, actions):
        """Evaluate both critics on the same inputs.

        Returns:
            Tuple ``(q1, q2)`` holding the output of each critic.
        """
        return (self.q1(observations, actions),
                self.q2(observations, actions))


class EnsembleActor(nn.Module):
    """Ensemble of tanh-squashed Gaussian policies built from EnsembleDense.

    A shared-shape two-layer ReLU trunk feeds two heads: ``mu`` (pre-tanh
    mean) and ``std`` (raw scale, mapped through softplus and offset by
    ``min_scale``).

    Attributes:
        ensemble_num: number of ensemble members in every layer.
        act_dim: action dimensionality (width of both heads).
        hid_dim: width of the two hidden layers.
        max_action: declared as a field but not read by any method here.
        min_scale: constant added to the softplus output so the standard
            deviation is strictly positive.
    """
    ensemble_num: int
    act_dim: int
    hid_dim: int = 256
    max_action: float = 1.0
    min_scale: float = 1e-3

    def setup(self):
        """Create the trunk layers and the mean / scale heads."""
        self.l1 = EnsembleDense(ensemble_num=self.ensemble_num,
                                features=self.hid_dim,
                                name="fc1")
        self.l2 = EnsembleDense(ensemble_num=self.ensemble_num,
                                features=self.hid_dim,
                                name="fc2")
        self.mu_layer = EnsembleDense(ensemble_num=self.ensemble_num,
                                      features=self.act_dim,
                                      name="mu")
        self.std_layer = EnsembleDense(ensemble_num=self.ensemble_num,
                                       features=self.act_dim,
                                       name="std")

    def _mu_and_std(self, observation):
        """Run the trunk and both heads; return ``(mu, std)`` pre-squash."""
        x = nn.relu(self.l1(observation))
        x = nn.relu(self.l2(x))
        mu = self.mu_layer(x)
        # softplus keeps the scale positive; min_scale bounds it away from 0.
        std = jax.nn.softplus(self.std_layer(x)) + self.min_scale
        return mu, std

    def __call__(self, observation: jnp.ndarray):
        """Return the deterministic action and the stochastic policy.

        Args:
            observation: input passed to the first ensemble layer.

        Returns:
            Tuple ``(mean_action, action_distribution)`` where
            ``mean_action`` is ``tanh(mu)`` and ``action_distribution`` is a
            diagonal Gaussian ``(mu, std)`` pushed through a tanh bijector
            blocked over the last axis.
        """
        mu, std = self._mu_and_std(observation)
        mean_action = nn.tanh(mu)

        action_distribution = distrax.Transformed(
            distrax.MultivariateNormalDiag(mu, std),
            distrax.Block(distrax.Tanh(), ndims=1))
        return mean_action, action_distribution

    def get_logprob(self, observation, action):
        """Log-density of a squashed ``action`` under the current policy.

        Args:
            observation: input passed to the first ensemble layer.
            action: tanh-squashed action(s); values are expected to lie in
                ``[-1, 1]``.

        Returns:
            Log-probability summed over the last (action) axis.
        """
        mu, std = self._mu_and_std(observation)

        # Invert the tanh squash. Clipping keeps arctanh finite when an
        # action sits exactly on (or numerically beyond) the +/-1 boundary.
        eps = 1e-6
        raw_action = jnp.arctanh(jnp.clip(action, -1.0 + eps, 1.0 - eps))

        gaussian_log_prob = distrax.Normal(mu, std).log_prob(raw_action).sum(-1)
        # Change of variables: log(1 - tanh(u)^2) written in the numerically
        # stable form 2 * (log 2 - u - softplus(-2u)).
        log_det_jacobian = 2 * (jnp.log(2) - raw_action -
                                jax.nn.softplus(-2 * raw_action)).sum(-1)
        return gaussian_log_prob - log_det_jacobian


class EnsembleScalar(nn.Module):
    """Module wrapping a single learnable parameter named ``"value"``.

    Attributes:
        init_value: array used as the parameter's initial value.
    """
    init_value: jnp.ndarray

    def setup(self):
        """Register the ``"value"`` parameter, initialised to ``init_value``."""

        def constant_init(_rng):
            # The RNG key supplied by ``self.param`` is ignored on purpose:
            # the initial value is fixed.
            return jnp.array(self.init_value)

        self.value = self.param("value", constant_init)

    def __call__(self):
        """Return the current parameter value."""
        return self.value
