"""Implementations of algorithms for continuous control."""
from pickletools import optimize
from typing import Sequence

import jax
import jax.numpy as jnp
import numpy as np
import optax

from jaxrl.agents.rvs import joint_actor
from jaxrl.datasets.rvs_d4rl_dataset import RvsBatch
from jaxrl.networks import policies
from jaxrl.networks.common import InfoDict, Model


# JIT-compiled wrappers around the joint-actor MSE training step and
# evaluation step. They are built once at module import, so every learner
# instance reuses the same jitted callables.
_mse_update_jit, _mse_eval_jit = (
    jax.jit(joint_actor.mse_update),
    jax.jit(joint_actor.mse_eval),
)


class JointRvsLearner(object):
    """RvS learner whose actor predicts one action per outcome bin.

    The actor emits ``action_dim * n_bins`` values per observation. At
    sampling time these are reshaped to ``(batch, n_bins, action_dim)``, and
    for each row the action belonging to the bin whose support value is
    closest to the requested outcome is returned.
    """

    def __init__(self,
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 actor_lr: float = 1e-3,
                 num_steps: int = int(1e6),
                 hidden_dims: Sequence[int] = (256, 256),
                 distribution: str = 'det',
                 n_bins:  int = 11,
                 v_min: int = 0,
                 v_max: int = 1000,
                 **kwargs):
        """Build the actor network, its optimizer and the outcome support.

        Args:
            seed: seed for the JAX PRNG used for initialisation and sampling.
            observations: example observations used to initialise the actor.
            actions: example actions; only the last dimension is read.
            actor_lr: Adam learning rate for the actor.
            num_steps: accepted for interface compatibility; unused here.
            hidden_dims: hidden layer sizes of the actor network.
            distribution: actor output type; only ``'det'`` is supported.
            n_bins: number of outcome bins (one action head per bin).
            v_min: lowest value of the evenly spaced outcome support.
            v_max: highest value of the evenly spaced outcome support.
            **kwargs: extra configuration entries; ignored.

        Raises:
            NotImplementedError: if ``distribution`` is not ``'det'``.
        """
        self.distribution = distribution
        # Fail fast, before building any networks, for unsupported heads.
        self._check_distribution()

        rng = jax.random.PRNGKey(seed)
        rng, actor_key = jax.random.split(rng)

        self.n_bins = n_bins
        # Bin centres used to map a requested outcome to an action head.
        self.support = jnp.linspace(v_min, v_max, n_bins)

        action_dim = actions.shape[-1]
        actor_def = policies.MSEPolicy(hidden_dims,
                                       action_dim * n_bins,
                                       dropout_rate=0.0)

        optimizer = optax.adam(actor_lr)

        self.actor = Model.create(actor_def,
                                  inputs=[actor_key, observations],
                                  tx=optimizer)
        # PRNG state threaded through sampling / update / eval calls.
        self.rng = rng

    def _check_distribution(self):
        """Raise if the configured distribution has no implementation."""
        if self.distribution != 'det':
            raise NotImplementedError(
                f"Unsupported distribution: {self.distribution!r}")

    def sample_actions(self,
                       observations: np.ndarray,
                       outcomes: np.ndarray,
                       temperature: float = 1.0):
        """Sample per-bin actions and keep the one matching each outcome.

        Args:
            observations: batch of observations fed to the actor.
            outcomes: requested outcome per observation. Any shape that
                flattens to one value per row (e.g. ``(B,)`` or ``(B, 1)``)
                is accepted, as is a single value shared by the whole batch.
            temperature: forwarded to ``policies.sample_actions``.

        Returns:
            NumPy array of the selected actions, one row per observation,
            clipped to ``[-1, 1]``.
        """
        self.rng, actions = policies.sample_actions(self.rng,
                                                    self.actor.apply_fn,
                                                    self.actor.params,
                                                    observations,
                                                    temperature,
                                                    self.distribution)
        # Split the flat actor output into one action vector per bin.
        actions = jnp.reshape(actions, (actions.shape[0], self.n_bins, -1))

        # Column vector so each row's outcome is compared against every bin;
        # a flat (B,) array would otherwise broadcast along the bin axis.
        targets = jnp.reshape(jnp.asarray(outcomes), (-1, 1))
        idx = jnp.argmin(jnp.abs(self.support[None, :] - targets), axis=1)
        selected_actions = actions[jnp.arange(actions.shape[0]), idx]

        # Convert to NumPy for the caller.
        return np.clip(np.asarray(selected_actions), -1, 1)

    def update(self, batch: RvsBatch):
        """Run one jitted MSE update step and return its info output.

        Raises:
            NotImplementedError: if the distribution is not ``'det'``.
        """
        self._check_distribution()
        self.rng, self.actor, info = _mse_update_jit(
            self.actor, batch, self.rng, self.support)
        return info

    def eval(self, batch: RvsBatch):
        """Run the jitted MSE evaluation on ``batch`` and return its info.

        Raises:
            NotImplementedError: if the distribution is not ``'det'``.
        """
        self._check_distribution()
        self.rng, info = _mse_eval_jit(
            self.actor, batch, self.rng, self.support)
        return info