import jax  
import jax.numpy as jnp  
from typing import Callable, NamedTuple, Dict, Iterator, Tuple, List
import haiku as hk
import optax
from acme.utils import loggers
import functools
import rlax
from acme.jax import networks as networks_lib


class TrainingState(NamedTuple):
  """Holds the agent's training state.

  Immutable snapshot of what `BCLearner` carries between update steps; a new
  instance is built on every step rather than mutating the old one.
  """
  # Network parameters; initialised from `net.init(key)` in `BCLearner`.
  params: hk.Params
  # Optimizer state; initialised from `opt.init(params)` in `BCLearner`.
  opt_state: optax.OptState
  # Number of update steps taken; starts at 0 and is incremented by one per
  # `BCLearner._step` call.
  steps: int
  # PRNG key created from the learner's seed; `BCLearner._step` carries it
  # through unchanged.
  rng_keys: jnp.ndarray

class BCLearner:
    """Learner that fits a discrete-action policy network by behavioural cloning.

    The network maps (normalized) observations to action logits and is trained
    by minimising the mean negative log-likelihood of the dataset actions.
    """

    def __init__(self, net: networks_lib.FeedForwardNetwork,
                 opt: optax.GradientTransformation,
                 dataset: Iterator,
                 normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                 seed: int,
                 logger: loggers.Logger):
        """Initialise parameters, optimizer state and bookkeeping.

        Args:
            net: network whose `init(key)` returns parameters and whose
                `apply(params, x)` returns action logits.
            opt: optax gradient transformation used to update the parameters.
            dataset: iterator consumed by `step`; each item is unpacked as a
                pair whose second element exposes `.data.observation` and
                `.data.action`.
            normalize_fn: applied to observations before every network call.
            seed: seed for the PRNG key used to initialise the parameters.
            logger: receives the metrics dict produced by each `step`.
        """
        self.net = net
        self.opt = opt
        self.dataset = dataset
        self.normalize_fn = normalize_fn

        key = jax.random.PRNGKey(seed)
        params = self.net.init(key)
        opt_state = self.opt.init(params)

        self._state = TrainingState(params=params, opt_state=opt_state,
                                    steps=0, rng_keys=key)
        self._logger = logger

    def _loss(self, params: hk.Params, x: jnp.ndarray,
              a: jnp.ndarray) -> jnp.ndarray:
        """Return the mean negative log-likelihood of actions `a` given `x`.

        `x` is expected to be already normalized. The indexing below requires
        2-D logits (batch, num_actions) and one action index per batch row.
        """
        logits = self.net.apply(params, x)
        log_probs = jax.nn.log_softmax(logits)
        # Pick, for every row, the log-probability of the action actually taken.
        loss = - log_probs[jnp.arange(a.shape[0]), a.astype(jnp.int32)]
        return jnp.mean(loss)

    # `self` is marked static so that jit treats the learner instance as a
    # compile-time constant instead of trying to trace it as an array.
    @functools.partial(jax.jit, static_argnums=0)
    def _step(self, state: TrainingState,
              x: jnp.ndarray, a: jnp.ndarray
              ) -> Tuple[TrainingState, Dict]:
        """Run one jitted gradient update.

        Args:
            state: current training state.
            x: batch of raw (un-normalized) observations.
            a: batch of action indices.

        Returns:
            The updated training state and a metrics dict with the new step
            count (`'step'`) and the loss evaluated before the update
            (`'loss'`).
        """
        x = self.normalize_fn(x)
        loss, grads = jax.value_and_grad(self._loss)(state.params, x, a)
        # Pass the current params as well: optax transformations that need
        # them (e.g. weight decay) fail without, and the others ignore them.
        updates, opt_state = self.opt.update(grads, state.opt_state,
                                             state.params)
        new_params = optax.apply_updates(state.params, updates)
        state = TrainingState(new_params, opt_state,
                              state.steps + 1,
                              state.rng_keys)
        # `state` is the new state here, so 'step' is the incremented count.
        # `loss` is already the scalar mean returned by `_loss`.
        metrics = {'step': state.steps,
                   'loss': loss}
        return state, metrics

    def step(self):
        """Fetch one batch from the dataset, update the state and log metrics."""
        _, transition = next(self.dataset)
        x = transition.data.observation
        a = transition.data.action
        self._state, metrics = self._step(self._state, x, a)
        self._logger.write(metrics)

    def get_probs(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return the softmax action probabilities for raw observations `x`."""
        x = self.normalize_fn(x)
        logits = self.net.apply(self._state.params, x)
        return jax.nn.softmax(logits)

    def sample_action(self, x: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
        """Sample actions from the current policy for raw observations `x`.

        Args:
            x: raw (un-normalized) observations.
            key: PRNG key used for sampling.

        Returns:
            Actions sampled from the categorical distribution given by
            `get_probs(x)`.
        """
        probs = self.get_probs(x)
        # The sample must be returned; the previous `raise` made every call
        # fail with a TypeError because an array is not an exception.
        return rlax.categorical_sample(key, probs)

    def eval(self, x: jnp.ndarray, a: jnp.ndarray, info):
        """Compute evaluation metrics on a batch without updating parameters.

        Args:
            x: raw (un-normalized) observations.
            a: action indices, one per batch row.
            info: dict of extra metrics. NOTE: it is updated in place and the
                same object is returned.

        Returns:
            `info` extended with the batch means `'loss'` (negative
            log-likelihood), `'prob_a'` (probability assigned to the given
            actions) and `'entropy'` (entropy of the predicted distribution).
        """
        x = self.normalize_fn(x)
        logits = self.net.apply(self._state.params, x)
        log_probs = jax.nn.log_softmax(logits)
        rows = jnp.arange(a.shape[0])
        actions = a.astype(jnp.int32)
        loss = - log_probs[rows, actions]

        probs = jnp.exp(log_probs)
        p_a = probs[rows, actions]
        entropy = jnp.sum(- log_probs * probs, axis=-1)
        metrics = info
        metrics.update({'loss': jnp.mean(loss),
                        'prob_a': jnp.mean(p_a),
                        'entropy': jnp.mean(entropy)})
        return metrics

    def get_variables(self, names: List[str]) -> List[hk.Params]:
        """Return the current parameters as a one-element list.

        `names` is accepted but not used: the same parameters are returned
        whatever is requested.
        """
        return [self._state.params]

    def save(self) -> TrainingState:
        """Return the current training state (e.g. for checkpointing)."""
        return self._state

    def restore(self, state: TrainingState):
        """Replace the current training state with `state`."""
        self._state = state