import functools
from typing import Callable, Dict, Iterator, NamedTuple, Optional, Tuple, Union

import haiku as hk
import jax
import jax.numpy as jnp
import optax
from acme.utils import loggers

from uncertainty import UncertaintyLearner

class TrainingState(NamedTuple):
    """Holds the agent's training state.

    As built by ``EnsembleLearner``, the parameter and optimiser fields are
    stacked along a leading ensemble axis (one slice per ensemble member),
    because they are produced by ``jax.vmap``-ed init functions.
    """
    # Trainable network parameters, stacked over ensemble members.
    params: hk.Params
    # Fixed prior-network parameters, stacked over ensemble members; the
    # learner carries them through unchanged on every update.
    prior_params: hk.Params
    # Optimiser state for ``params``, stacked over ensemble members.
    opt_state: optax.OptState
    # Number of completed updates. Starts as the Python int 0, but becomes a
    # JAX scalar array once incremented inside a jitted update.
    steps: Union[int, jnp.ndarray]
    # PRNG keys that were used to initialise ``params`` and ``prior_params``.
    rng_keys: jnp.ndarray

class EnsembleLearner(UncertaintyLearner):
    """Learns an ensemble of networks and predicts uncertainty from it.

    Ensemble member ``i`` predicts ``net_i(x) + prior_scale * prior_i(x)``,
    where ``prior_i`` is a fixed, randomly initialised copy of ``prior_net``.
    All members are trained in parallel (via ``jax.vmap``) to regress Gaussian
    noise targets that are seeded from each sampled transition's index.
    ``uncertainty`` combines the mean and the across-member variance of the
    squared predictions.
    """

    def __init__(self, net: hk.Module,
                 opt: optax.GradientTransformation,
                 prior_net: hk.Module,
                 dataset: Iterator,
                 normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                 seed: int,
                 num_actions: int,
                 n_comp: int,
                 feature_dim: int,
                 noise_scale: jnp.float32,
                 prior_scale: jnp.float32,
                 beta: jnp.float32,
                 logger: loggers.Logger):
        """Initialise ensemble, prior and optimiser state.

        Args:
            net: trainable network exposing ``init``/``apply``; ``init`` is
                vmapped over ``n_comp`` PRNG keys.
            opt: optimiser, applied independently to every ensemble member.
            prior_net: prior network exposing ``init``/``apply``; its
                parameters are never updated.
            dataset: iterator yielding ``(index, transition)`` pairs.
            normalize_fn: applied to observations before any network call.
            seed: seed of the PRNG key used to initialise all parameters.
            num_actions: number of actions in the network output.
            n_comp: number of ensemble members.
            feature_dim: dimension of the regression target per action.
            noise_scale: scale applied to standard-normal noise targets.
            prior_scale: multiplier applied to prior network outputs.
            beta: weight of the across-member variance in ``uncertainty``.
            logger: receives the training metrics written by ``step``.
        """
        self.net = net
        self.opt = opt
        self.prior_net = prior_net
        self.dataset = dataset
        self.num_actions = num_actions
        self.n_comp = n_comp
        self.feature_dim = feature_dim
        self.noise_scale = noise_scale
        self.prior_scale = prior_scale
        self.beta = beta
        self.normalize_fn = normalize_fn

        key = jax.random.PRNGKey(seed)
        # First n_comp keys initialise the trainable nets, the rest the priors.
        rng_keys = jax.random.split(key, 2 * self.n_comp)

        ensemble_init = jax.vmap(self.net.init, in_axes=(0,))
        params = ensemble_init(rng_keys[:self.n_comp])
        prior_init = jax.vmap(self.prior_net.init, in_axes=(0,))
        prior_params = prior_init(rng_keys[self.n_comp:])

        ensemble_opt_init = jax.vmap(self.opt.init, in_axes=(0,))
        opt_state = ensemble_opt_init(params)

        self._state = TrainingState(params=params, prior_params=prior_params,
                                    opt_state=opt_state,
                                    steps=0, rng_keys=rng_keys)
        self._logger = logger

    def _prior_fns(self, x: jnp.ndarray,
                   prior_params: Optional[hk.Params] = None) -> jnp.ndarray:
        """Evaluate every (scaled) prior network on ``x``.

        Args:
            x: normalised input batch.
            prior_params: stacked prior parameters, one slice per ensemble
                member. Defaults to ``self._state.prior_params``. Jitted
                callers must pass the parameters from their ``state``
                argument. Reading ``self._state`` under ``jax.jit`` (with
                ``self`` static) bakes the values in as compile-time
                constants, which then go stale after ``restore``.

        Returns:
            Array of shape (n_comp, batch_size, feature_dim, num_actions).
        """
        if prior_params is None:
            prior_params = self._state.prior_params
        preds = jax.vmap(self.prior_net.apply, in_axes=(0, None))(prior_params, x)
        preds = jnp.reshape(preds, (preds.shape[0], preds.shape[1],
                                    self.feature_dim, self.num_actions))
        return self.prior_scale * preds

    def _loss_single(self, params: hk.Params, x: jnp.ndarray,
                     a: jnp.ndarray, y: jnp.ndarray,
                     priors: jnp.ndarray) -> jnp.ndarray:
        """Mean squared error of a single ensemble member.

        Args:
            params: parameters of this member's ``net``.
            x: normalised input batch.
            a: action indices, shape (batch_size,); ignored when
                ``num_actions == 1``.
            y: regression targets, shape (batch_size, feature_dim).
            priors: this member's scaled prior predictions,
                shape (batch_size, feature_dim, num_actions).

        Returns:
            Scalar MSE between the prediction (net + prior) at action ``a``
            and ``y``.
        """
        preds = self.net.apply(params, x)
        preds = jnp.reshape(preds, (x.shape[0], self.feature_dim, self.num_actions))
        preds = preds + priors
        if self.num_actions > 1:
            # Advanced indexing on axes 0 and 2 -> shape (batch_size, feature_dim).
            selected_preds = preds[jnp.arange(y.shape[0]), :, a.astype(jnp.int32)]
        else:
            selected_preds = preds[:, :, 0]
        return jnp.mean(jnp.square(selected_preds - y))

    def _step_single(self, params: hk.Params, opt_state: optax.OptState,
                     x: jnp.ndarray, a: jnp.ndarray, noise: jnp.ndarray,
                     priors: jnp.ndarray
                     ) -> Tuple[hk.Params, optax.OptState, jnp.ndarray]:
        """Apply one optimiser update to a single ensemble member.

        Returns:
            ``(new_params, new_opt_state, loss)`` where ``loss`` is the scalar
            pre-update MSE against the ``noise`` targets.
        """
        loss, grads = jax.value_and_grad(self._loss_single)(params, x, a, noise, priors)
        # Passing params lets optimisers that need them (e.g. weight decay) work.
        updates, opt_state = self.opt.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, opt_state, loss

    @functools.partial(jax.jit, static_argnums=0)
    def _step(self, state: TrainingState,
              x: jnp.ndarray, a: jnp.ndarray,
              noise: jnp.ndarray) -> Tuple[TrainingState, Dict]:
        """Jitted update of all ensemble members on one batch.

        ``noise`` has shape (batch_size, n_comp, feature_dim); axis 1 is split
        across ensemble members so each member gets its own targets.
        """
        x = self.normalize_fn(x)
        # Use the priors carried in ``state`` (a traced argument), not
        # ``self._state``, so ``restore`` takes effect without recompiling.
        priors = self._prior_fns(x, state.prior_params)

        ensemble_update = jax.vmap(self._step_single, in_axes=(0, 0, None, None, 1, 0))
        params, opt_state, loss = ensemble_update(state.params,
                                                  state.opt_state,
                                                  x, a, noise, priors)
        state = TrainingState(params, state.prior_params,
                              opt_state,
                              state.steps + 1,
                              state.rng_keys)
        metrics = {'step': state.steps,
                   'loss': jnp.mean(loss),
                   'loss_std': jnp.std(loss),
                   'noise_std': jnp.std(noise)}
        return state, metrics

    def step(self):
        """Sample a batch from ``dataset``, update the ensemble, log metrics."""
        index, transition = next(self.dataset)
        x = transition.data.observation
        a = transition.data.action

        # Noise targets are a deterministic function of the sample index, so
        # a given index always yields the same targets. Integer modulus keeps
        # the seed exact; the previous float ``% 1e8`` dropped low bits of
        # large indices (float32 for JAX ints), making distinct indices collide.
        # The result also stays inside the int32 range.
        seeds = (index % 100_000_000).astype(jnp.int32)
        keys = jax.vmap(jax.random.PRNGKey)(seeds)
        noise = self.noise_scale * jax.vmap(jax.random.normal,
                                            in_axes=(0, None))(keys, (self.n_comp, self.feature_dim))

        self._state, metrics = self._step(self._state, x, a, noise)
        self._logger.write(metrics)

    @functools.partial(jax.jit, static_argnums=0)
    def _predict(self, state: TrainingState, x: jnp.ndarray) -> jnp.ndarray:
        """Jitted ensemble prediction (net + scaled prior) using ``state``.

        Returns:
            Array of shape (n_comp, batch_size, feature_dim, num_actions).
        """
        x = self.normalize_fn(x)
        priors = self._prior_fns(x, state.prior_params)
        ensemble_predict = jax.vmap(self.net.apply, in_axes=(0, None))
        preds = ensemble_predict(state.params, x)
        preds = jnp.reshape(preds, (preds.shape[0], preds.shape[1],
                                    self.feature_dim, self.num_actions))
        return preds + priors

    def predict(self, x: jnp.ndarray) -> jnp.ndarray:
        """Predict with the current state; see ``_predict`` for the shape."""
        return self._predict(self._state, x)

    def uncertainty(self, x: jnp.ndarray) -> jnp.ndarray:
        """Per-action uncertainty, shape (batch_size, num_actions).

        Computes ``sqrt(mean_k m_k + beta * var_k m_k)``, where ``m_k`` is
        member ``k``'s mean squared prediction over the feature axis.
        """
        preds = self.predict(x)  # (n_comp, batch_size, feature_dim, num_actions)
        mse = jnp.mean(jnp.square(preds), axis=2)  # (n_comp, batch_size, num_actions)
        mean_mse = jnp.mean(mse, axis=0)  # (batch_size, num_actions)
        var_mse = jnp.var(mse, axis=0)  # (batch_size, num_actions)
        # Non-negative whenever beta >= 0; a negative beta can yield NaN here.
        unc = mean_mse + self.beta * var_mse
        return jnp.sqrt(unc)

    def eval(self, x: jnp.ndarray, a: jnp.ndarray, info):
        """Return ``info`` extended with uncertainty/prediction statistics.

        The caller's ``info`` mapping is copied, not mutated.
        """
        pred = self.predict(x)
        unc_pred = self.uncertainty(x)
        if self.num_actions > 1:
            unc_pred_a = unc_pred[jnp.arange(a.shape[0]),
                                  a.astype(jnp.int32)]
        else:
            unc_pred_a = unc_pred
        metrics = {} if info is None else dict(info)
        metrics.update({'unc_mean': jnp.mean(unc_pred),
                        'unc_a': jnp.mean(unc_pred_a),
                        'pred_mean': jnp.mean(pred),
                        'pred_std': jnp.std(pred)})
        return metrics

    def restore(self, state: TrainingState):
        """Replace the current training state with ``state``."""
        self._state = state
