from typing import Tuple

import jax
import jax.numpy as jnp

from jaxrl.datasets.rvs_d4rl_dataset import RvsBatch
from jaxrl.networks.common import InfoDict, Model, Params, PRNGKey



def mse_update(actor: Model, batch: RvsBatch,
               rng: PRNGKey, support: jnp.ndarray):
    """Apply one gradient step of the outcome-binned MSE loss to ``actor``.

    The actor's output is reshaped to ``(batch, len(support), -1)``, i.e. one
    action per support bin. For every sample the bin whose support value is
    nearest (in absolute difference) to ``batch.outcomes`` is chosen, and that
    bin's action is regressed onto ``batch.actions`` with a mean squared error.

    Args:
        actor: model exposing ``apply_fn`` and ``apply_gradient``.
        batch: batch providing ``observations``, ``outcomes`` and ``actions``.
            NOTE(review): ``outcomes`` is assumed to broadcast against a
            ``(batch, len(support))`` array (e.g. shape ``(batch, 1)``) --
            confirm against the dataset.
        rng: PRNG key; it is split and one half drives dropout.
        support: per-bin outcome values; ``support.shape[0]`` is the bin count.

    Returns:
        A tuple whose first element is the advanced PRNG key, followed by
        whatever ``actor.apply_gradient`` returns (presumably the updated
        model and the info dict -- confirm in ``jaxrl.networks.common``).
    """
    next_rng, dropout_key = jax.random.split(rng)

    num_samples = batch.observations.shape[0]
    num_bins = support.shape[0]

    # One copy of the support row per sample: (num_samples, num_bins, ...).
    tiled_support = jnp.repeat(support[None], num_samples, axis=0)

    # Bin selection does not depend on the parameters, so it is computed once
    # outside the differentiated function.
    bin_idx = jnp.argmin(jnp.abs(tiled_support - batch.outcomes), axis=1)
    row_idx = jnp.arange(num_samples)

    def loss_fn(actor_params: Params):
        raw = actor.apply_fn({'params': actor_params},
                             batch.observations,
                             training=True,
                             rngs={'dropout': dropout_key})
        per_bin = jnp.reshape(raw, (num_samples, num_bins, -1))
        chosen = per_bin[row_idx, bin_idx]

        residual = chosen - batch.actions
        loss = jnp.mean(jnp.square(residual))
        return loss, {'actor_loss': loss}

    update_out = actor.apply_gradient(loss_fn)
    return (next_rng,) + tuple(update_out)



def mse_eval(actor: Model, batch: RvsBatch,
             rng: PRNGKey, support: jnp.ndarray):
    """Compute evaluation metrics for the outcome-binned MSE actor.

    Mirrors the loss in ``mse_update``: the actor output is reshaped to
    ``(batch, len(support), -1)``, the bin whose support value is nearest to
    ``batch.outcomes`` is chosen per sample, and that action is compared to
    ``batch.actions`` with a mean squared error. Unlike training, the forward
    pass is run with ``training=False`` so the metrics are deterministic.

    Args:
        actor: model exposing ``apply_fn`` and ``params``.
        batch: batch providing ``observations``, ``outcomes`` and ``actions``.
            NOTE(review): ``outcomes`` is assumed to broadcast against a
            ``(batch, len(support))`` array (e.g. shape ``(batch, 1)``) --
            confirm against the dataset.
        rng: PRNG key; it is split so the returned key advances exactly as
            in ``mse_update``.
        support: per-bin outcome values; ``support.shape[0]`` is the bin count.

    Returns:
        ``(rng, info)``: the advanced PRNG key and a dict with
        ``actor_loss`` (MSE of the selected bin's action),
        ``idx_zero_freq`` (fraction of samples mapped to bin 0),
        ``idx_max_freq`` (fraction mapped to the last bin) and
        ``idx_mean`` (mean selected bin index).
    """
    rng, key = jax.random.split(rng)
    batch_size = batch.observations.shape[0]
    n_bins = support.shape[0]
    batch_support = jnp.repeat(jnp.expand_dims(support, 0),
                               batch_size, axis=0)

    # Evaluate deterministically. With training=True (as previously copied
    # from the update step) dropout stays active, making the reported metrics
    # noisy and unrepresentative of the network used at inference time.
    # The dropout key is still passed; it is simply unused in this mode.
    actions = actor.apply_fn({'params': actor.params},
                             batch.observations,
                             training=False,
                             rngs={'dropout': key})

    # Select, per sample, the action head of the bin nearest to the outcome.
    actions = jnp.reshape(actions, (batch_size, n_bins, -1))
    idx = jnp.argmin(jnp.abs(batch_support - batch.outcomes), axis=1)
    selected_actions = actions[jnp.arange(actions.shape[0]), idx]

    actor_loss = ((selected_actions - batch.actions)**2).mean()
    # Metrics only: make sure no gradient can flow if this is ever traced
    # inside a differentiated function.
    actor_loss = jax.lax.stop_gradient(actor_loss)

    # How often the selection hits the edge bins -- a hint that outcomes fall
    # outside the support range.
    idx_zero_freq = jnp.mean(idx == 0)
    idx_mean = jnp.mean(idx)
    idx_max_freq = jnp.mean(idx == n_bins - 1)
    return (rng, {'actor_loss': actor_loss,
                  'idx_zero_freq': idx_zero_freq,
                  'idx_max_freq': idx_max_freq,
                  'idx_mean': idx_mean})
