"""Implementations of algorithms for continuous control."""

import functools
from typing import Sequence, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import optax

from jaxrl.agents.rvs.critic import update as update_critic
from jaxrl.agents.rvs.actor import mse_update as update_actor

from jaxrl.datasets.rvs_d4rl_dataset import RvsBatch
from jaxrl.networks import rvs_policies, rvs_critics
from jaxrl.networks.common import InfoDict, Model, PRNGKey


@jax.jit
def _update_jit(actor: Model, critic: Model,
                batch: RvsBatch, rng: PRNGKey):
    """Run one critic update followed by one actor update under ``jax.jit``.

    No static arguments are declared: the function has no ``update_target``
    parameter, and naming a non-existent parameter in ``static_argnames`` is
    rejected by recent JAX releases when the decorator is applied.

    Args:
        actor: Current actor model.
        critic: Current critic model.
        batch: Batch passed unchanged to both update functions.
        rng: PRNG key consumed by the actor update.

    Returns:
        A tuple ``(rng, new_actor, new_critic, info)`` where ``rng`` is the
        key returned by the actor update and ``info`` merges the critic and
        actor info dicts.
    """
    # The critic update only needs the batch; it does not consume randomness.
    new_critic, critic_info = update_critic(critic, batch)

    # The actor update is given the original (pre-update) actor and returns
    # the advanced PRNG key alongside the new model.
    rng, new_actor, actor_info = update_actor(actor, batch, rng)

    # Actor entries are unpacked last, so they win on any key collision.
    info = {**critic_info, **actor_info}
    return rng, new_actor, new_critic, info


class GenRvsLearner(object):
    """Learner holding an outcome-conditioned actor and a quantile critic.

    Attributes set by ``__init__``: ``actor``, ``critic``, ``rng``,
    ``num_quantiles``, ``distribution`` and the update counter ``step``.
    """

    def __init__(self,
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 num_quantiles: int = 101,
                 actor_lr: float = 3e-4,
                 critic_lr: float = 3e-4,
                 hidden_dims: Sequence[int] = (256, 256),
                 distribution: str = 'det',
                 discount: float = 0.99,
                 tau: float = 0.005,
                 target_update_period: int = 1,
                 exploration_noise: float = 0.1,
                 **kwargs):
        """
        Generalized RVS. Also trains a critic with quantile regression

        Args:
            seed: Seed for the learner's PRNG key.
            observations: Example observations used to create both models.
            actions: Example actions; only the size of the last axis is read.
            num_quantiles: Passed to ``QRValue`` and used by
                ``sample_actions`` to turn a quantile into a column index.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for the critic.
            hidden_dims: Hidden layer sizes shared by actor and critic.
            distribution: Forwarded to ``rvs_policies.sample_actions``.
            discount, tau, target_update_period, exploration_noise, **kwargs:
                Accepted but not used by the code in this class.
        """

        action_dim = actions.shape[-1]
        self.num_quantiles = num_quantiles
        self.distribution = distribution

        rng = jax.random.PRNGKey(seed)
        rng, actor_key, critic_key = jax.random.split(rng, 3)

        critic_def = rvs_critics.QRValue(hidden_dims, num_quantiles)
        critic = Model.create(critic_def,
                              inputs=[critic_key, observations],
                              tx=optax.adam(learning_rate=critic_lr))
        # First critic output column on the example observations; it serves
        # as the example "outcome" input when the actor model is created.
        outcomes = critic(observations)[:, 0:1]

        actor_def = rvs_policies.MSEPolicy(hidden_dims, action_dim)
        actor = Model.create(actor_def,
                             inputs=[actor_key, observations, outcomes],
                             tx=optax.adam(learning_rate=actor_lr))

        self.actor = actor
        self.critic = critic
        self.rng = rng

        # Incremented once per call to ``update``.
        self.step = 1

    def sample_actions(self,
                       observations: np.ndarray,
                       quantiles: np.ndarray,
                       temperature: float = 1.0):
        """Sample actions conditioned on critic outcomes at given quantiles.

        Args:
            observations: Observations fed to both critic and actor.
            quantiles: Quantile value(s); each is mapped to the column index
                ``int((1 - q) * num_quantiles)`` of the critic values. A
                scalar or list is accepted as well as an array.
            temperature: Forwarded to ``rvs_policies.sample_actions``.

        Returns:
            A NumPy array of actions clipped to ``[-1, 1]``.
        """
        # np.asarray so that plain floats/lists also support ``.astype``.
        scaled = (1 - np.asarray(quantiles)) * self.num_quantiles
        idx = scaled.astype(int).flatten()

        outcomes = rvs_critics.get_values(self.critic.apply_fn,
                                          self.critic.params,
                                          observations)
        # A quantile of 0 maps to index ``num_quantiles``, one past the last
        # column. Clamp explicitly so the lookup never relies on the
        # array library's out-of-bounds behaviour.
        idx = np.minimum(idx, outcomes.shape[1] - 1)
        outcomes = outcomes[:, idx]
        self.rng, actions = rvs_policies.sample_actions(self.rng,
                                                        self.actor.apply_fn,
                                                        self.actor.params,
                                                        observations,
                                                        outcomes,
                                                        temperature,
                                                        self.distribution)

        actions = np.asarray(actions)
        return np.clip(actions, -1, 1)

    def update(self, batch: RvsBatch):
        """Apply one jitted actor/critic update and return its info dict.

        Advances ``self.step`` and replaces ``self.rng``, ``self.actor`` and
        ``self.critic`` with the values returned by ``_update_jit``.
        """
        self.step += 1

        self.rng, self.actor, self.critic, info = _update_jit(
            self.actor, self.critic, batch, self.rng)

        return info

    def eval(self, batch: RvsBatch):
        """Return an empty dict; no evaluation metrics are computed."""
        return {}
