from typing import Tuple, Callable
import dataclasses

import jax
import jax.numpy as jnp
import reverb
import rlax
from acme.agents.jax.dqn import learning_lib
from acme.jax import networks as networks_lib

from q_learners import LossFn
from rl_utils import softSPIBB_probs, BCQ_probs, to_qr
from utils import Transition

@dataclasses.dataclass
class SARSA(LossFn):
    """On-policy SARSA loss whose TD errors double as replay priorities.

    Attributes are plain hyper-parameters:
      discount: multiplied into the per-transition discount.
      importance_sampling_exponent: power applied to the inverse sampling
        probabilities before they reweight the per-sample loss.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: delta handed to `rlax.huber_loss`.
    """
    discount: float = 0.99
    importance_sampling_exponent: float = 0.0
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Evaluate the SARSA loss for one replay batch.

        Args:
          network: Q-network; only its `apply` is used.
          uncertainty_fn: ignored by this loss.
          behavior_fn: ignored by this loss.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for the current observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first two
            entries of `info` are read as keys and sampling probabilities.
          key: ignored by this loss.

        Returns:
          The importance-weighted mean Huber loss, and a `LossExtra` carrying
          the `mean_q` metric plus a reverb update with |TD error| priorities.
        """
        # None of these inputs take part in the SARSA update.
        del key, uncertainty_fn, behavior_fn

        sample: Transition = batch.data
        keys, probs, *_ = batch.info

        # Online values at s_{t-1}, target values at s_t.
        q_prev = network.apply(params, normalize_fn(sample.observation))
        q_next_target = network.apply(
            target_params, normalize_fn(sample.next_observation))

        # Effective discount and clipped reward, both as float32.
        discount_t = (sample.discount * self.discount).astype(jnp.float32)
        reward_t = jnp.clip(
            sample.reward,
            -self.max_abs_reward,
            self.max_abs_reward,
        ).astype(jnp.float32)

        # Per-sample SARSA TD error using the logged next action.
        td_error = jax.vmap(rlax.sarsa)(
            q_prev, sample.action, reward_t, discount_t,
            q_next_target, sample.next_action)
        per_sample_loss = rlax.huber_loss(td_error, self.huber_loss_parameter)

        # Inverse-probability weights, rescaled so the largest weight is 1.
        weights = (1. / probs).astype(jnp.float32)
        weights = weights ** self.importance_sampling_exponent
        weights = weights / jnp.max(weights)

        loss = jnp.mean(weights * per_sample_loss)  # []

        priorities = jnp.abs(td_error).astype(jnp.float64)
        reverb_update = learning_lib.ReverbUpdate(keys=keys, priorities=priorities)
        extra = learning_lib.LossExtra(
            metrics={"mean_q": jnp.mean(q_prev)}, reverb_update=reverb_update)
        return loss, extra


@dataclasses.dataclass
class SoftSPIBB(LossFn):
    """Soft-SPIBB DQN loss.

    The bootstrap target is an expected-SARSA backup under the action
    distribution returned by `softSPIBB_probs`.

    Attributes:
      discount: multiplied into the per-transition discount.
      epsilon: scalar broadcast to one value per sample and handed to
        `softSPIBB_probs`.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: delta handed to `rlax.huber_loss`.
    """
    discount: float = 0.99
    epsilon: float = 1.
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Evaluate the Soft-SPIBB loss for one replay batch.

        Args:
          network: Q-network; only its `apply` is used.
          uncertainty_fn: evaluated on the raw next observation; its output is
            treated as a constant (gradient stopped).
          behavior_fn: evaluated on the raw next observation.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for the current observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first
            entry of `info` is read as the replay keys.
          key: ignored by this loss.

        Returns:
          The mean Huber loss, and a `LossExtra` carrying the `mean_q` metric
          plus a reverb update with |TD error| priorities.
        """
        del key
        sample: Transition = batch.data
        keys, _probs, *_ = batch.info  # sampling probabilities are not used here

        # Uncertainty and behaviour estimates come from the un-normalized
        # next observation.
        next_obs = sample.next_observation
        uncertainty_t = jax.lax.stop_gradient(uncertainty_fn(next_obs))
        behavior_probs_t = behavior_fn(next_obs)

        # Online values at s_{t-1}, target values at s_t.
        q_prev = network.apply(params, normalize_fn(sample.observation))
        q_next_target = network.apply(target_params, normalize_fn(next_obs))

        # One epsilon per sample so softSPIBB_probs can be vmapped over the
        # leading axis of all four arguments.
        per_sample_eps = self.epsilon * jnp.ones((q_prev.shape[0],))
        target_policy = jax.vmap(softSPIBB_probs)(
            q_next_target, uncertainty_t, behavior_probs_t, per_sample_eps)
        target_policy = jax.lax.stop_gradient(target_policy)

        # Effective discount and clipped reward, both as float32.
        discount_t = (sample.discount * self.discount).astype(jnp.float32)
        reward_t = jnp.clip(
            sample.reward,
            -self.max_abs_reward,
            self.max_abs_reward,
        ).astype(jnp.float32)

        # Expected-SARSA TD error under the Soft-SPIBB action distribution.
        td_error = jax.vmap(rlax.expected_sarsa)(
            q_prev, sample.action, reward_t, discount_t,
            q_next_target, target_policy)
        loss = jnp.mean(rlax.huber_loss(td_error, self.huber_loss_parameter))

        priorities = jnp.abs(td_error).astype(jnp.float64)
        reverb_update = learning_lib.ReverbUpdate(keys=keys, priorities=priorities)
        extra = learning_lib.LossExtra(
            metrics={"mean_q": jnp.mean(q_prev)}, reverb_update=reverb_update)
        return loss, extra

@dataclasses.dataclass
class SoftSPIBBQuantile(LossFn):
    """Soft-SPIBB QR-DQN.

    Quantile-regression counterpart of the Soft-SPIBB loss: the bootstrap
    target is a quantile expected-SARSA backup under the action distribution
    returned by `softSPIBB_probs`.

    Attributes:
      discount: multiplied into the per-transition discount.
      epsilon: scalar broadcast to one value per sample and handed to
        `softSPIBB_probs`.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: passed through to `rlax.quantile_expected_sarsa`.
      num_actions: forwarded to `to_qr` when reshaping the network output.
      num_quantiles: forwarded to `to_qr`; also sizes `quantiles`.
      quantiles: quantile midpoints (i + 0.5) / num_quantiles.
    """
    discount: float = 0.99
    epsilon: float = 1.
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.
    num_actions: int = 1
    num_quantiles: int = 201
    # Class-level default, evaluated once with the default num_quantiles above.
    # Kept so that reading `quantiles` from the class itself still works.
    quantiles = (jnp.arange(0, num_quantiles) + 0.5) / float(num_quantiles)

    def __post_init__(self):
        """Rebuild `quantiles` so it matches this instance's `num_quantiles`."""
        # Without this, an instance created with a non-default num_quantiles
        # would keep using the 201-point class-level grid.
        self.quantiles = (
            jnp.arange(0, self.num_quantiles) + 0.5) / float(self.num_quantiles)

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data.

        Args:
          network: quantile network; only its `apply` is used.
          uncertainty_fn: evaluated on the raw next observation; its output is
            treated as a constant (gradient stopped).
          behavior_fn: evaluated on the raw next observation.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for the current observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first
            entry of `info` is read as the replay keys.
          key: unused.

        Returns:
          The mean of the per-sample quantile losses, and a `LossExtra`
          carrying the `mean_q` metric plus a reverb update whose priorities
          are the absolute per-sample losses.
        """
        del key
        transitions: Transition = batch.data
        keys, _probs, *_ = batch.info  # sampling probabilities are not used here

        # Uncertainty and behaviour estimates use the un-normalized observation.
        u_t = jax.lax.stop_gradient(uncertainty_fn(transitions.next_observation))
        pi_b_t = behavior_fn(transitions.next_observation)

        norm_obs = normalize_fn(transitions.observation)
        norm_next_obs = normalize_fn(transitions.next_observation)

        # Forward pass. `to_qr` yields a pair that is used here as
        # (Q-values, quantile distribution) — presumably the mean over
        # quantiles and the reshaped raw output; confirm against rl_utils.
        q_tm1, dist_q_tm1 = to_qr(network.apply(params, norm_obs),
                                    self.num_actions, self.num_quantiles)
        q_t_target, dist_q_t_target = to_qr(network.apply(target_params, norm_next_obs),
                                        self.num_actions, self.num_quantiles)

        # One epsilon per sample so softSPIBB_probs can be vmapped over the
        # leading axis of all four arguments.
        batch_softSPIBB_probs = jax.vmap(softSPIBB_probs)
        eps = self.epsilon * jnp.ones((q_tm1.shape[0],))

        probs_a_t = batch_softSPIBB_probs(q_t_target, u_t, pi_b_t, eps)
        probs_a_t = jax.lax.stop_gradient(probs_a_t)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                    self.max_abs_reward).astype(jnp.float32)

        # Quantile expected-SARSA loss under probs_a_t. The quantile midpoints
        # and the Huber parameter are shared across the batch (in_axes=None).
        batch_error = jax.vmap(rlax.quantile_expected_sarsa,
                                    in_axes=(0, None, 0, 0, 0, 0, 0, None))
        losses = batch_error(dist_q_tm1, self.quantiles,
                            transitions.action, r_t, d_t,
                            dist_q_t_target, probs_a_t, self.huber_loss_parameter)
        loss = jnp.mean(losses)

        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(losses).astype(jnp.float64))
        mean_q = jnp.mean(q_tm1)
        extra = learning_lib.LossExtra(metrics={"mean_q": mean_q}, reverb_update=reverb_update)
        return loss, extra

@dataclasses.dataclass
class BCQ(LossFn):
    """Discrete BCQ loss.

    The bootstrap target is an expected-SARSA backup under the action
    distribution returned by `BCQ_probs`.

    Attributes:
      discount: multiplied into the per-transition discount.
      tau: scalar broadcast to one value per sample and handed to `BCQ_probs`.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: delta handed to `rlax.huber_loss`.
    """
    discount: float = 0.99
    tau: float = 0.
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Evaluate the BCQ loss for one replay batch.

        Args:
          network: Q-network; only its `apply` is used.
          uncertainty_fn: ignored by this loss.
          behavior_fn: evaluated on the raw next observation.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for the current observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first
            entry of `info` is read as the replay keys.
          key: ignored by this loss.

        Returns:
          The mean Huber loss, and a `LossExtra` carrying the `mean_q` metric
          plus a reverb update with |TD error| priorities.
        """
        del key, uncertainty_fn
        sample: Transition = batch.data
        keys, _probs, *_ = batch.info  # sampling probabilities are not used here

        next_obs = sample.next_observation

        # Online values at s_{t-1}, target values at s_t.
        q_prev = network.apply(params, normalize_fn(sample.observation))
        q_next_target = network.apply(target_params, normalize_fn(next_obs))

        # The behaviour estimate is taken on the un-normalized next observation.
        behavior_probs_t = behavior_fn(next_obs)

        # One threshold per sample so BCQ_probs can be vmapped over the
        # leading axis of all three arguments.
        per_sample_tau = self.tau * jnp.ones((q_prev.shape[0],))
        target_policy = jax.lax.stop_gradient(
            jax.vmap(BCQ_probs)(q_next_target, behavior_probs_t, per_sample_tau))

        # Effective discount and clipped reward, both as float32.
        discount_t = (sample.discount * self.discount).astype(jnp.float32)
        reward_t = jnp.clip(
            sample.reward,
            -self.max_abs_reward,
            self.max_abs_reward,
        ).astype(jnp.float32)

        # Expected-SARSA TD error under the BCQ action distribution.
        td_error = jax.vmap(rlax.expected_sarsa)(
            q_prev, sample.action, reward_t, discount_t,
            q_next_target, target_policy)
        loss = jnp.mean(rlax.huber_loss(td_error, self.huber_loss_parameter))

        priorities = jnp.abs(td_error).astype(jnp.float64)
        reverb_update = learning_lib.ReverbUpdate(keys=keys, priorities=priorities)
        extra = learning_lib.LossExtra(
            metrics={"mean_q": jnp.mean(q_prev)}, reverb_update=reverb_update)
        return loss, extra


@dataclasses.dataclass
class BCQQuantile(LossFn):
    """BCQ QR-DQN.

    Quantile-regression counterpart of the discrete BCQ loss: the bootstrap
    target is a quantile expected-SARSA backup under the action distribution
    returned by `BCQ_probs`.

    Attributes:
      discount: multiplied into the per-transition discount.
      tau: scalar broadcast to one value per sample and handed to `BCQ_probs`.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: passed through to `rlax.quantile_expected_sarsa`.
      num_actions: forwarded to `to_qr` when reshaping the network output.
      num_quantiles: forwarded to `to_qr`; also sizes `quantiles`.
      quantiles: quantile midpoints (i + 0.5) / num_quantiles.
    """
    discount: float = 0.99
    tau: float = 0.
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.
    num_actions: int = 1
    num_quantiles: int = 201
    # Class-level default, evaluated once with the default num_quantiles above.
    # Kept so that reading `quantiles` from the class itself still works.
    quantiles = (jnp.arange(0, num_quantiles) + 0.5) / float(num_quantiles)

    def __post_init__(self):
        """Rebuild `quantiles` so it matches this instance's `num_quantiles`."""
        # Without this, an instance created with a non-default num_quantiles
        # would keep using the 201-point class-level grid.
        self.quantiles = (
            jnp.arange(0, self.num_quantiles) + 0.5) / float(self.num_quantiles)

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data.

        Args:
          network: quantile network; only its `apply` is used.
          uncertainty_fn: unused.
          behavior_fn: evaluated on the raw next observation.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for the current observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first
            entry of `info` is read as the replay keys.
          key: unused.

        Returns:
          The mean of the per-sample quantile losses, and a `LossExtra`
          carrying the `mean_q` metric plus a reverb update whose priorities
          are the absolute per-sample losses.
        """
        del key
        del uncertainty_fn
        transitions: Transition = batch.data
        keys, _probs, *_ = batch.info  # sampling probabilities are not used here

        norm_obs = normalize_fn(transitions.observation)
        norm_next_obs = normalize_fn(transitions.next_observation)

        # Forward pass. `to_qr` yields a pair that is used here as
        # (Q-values, quantile distribution) — presumably the mean over
        # quantiles and the reshaped raw output; confirm against rl_utils.
        q_tm1, dist_q_tm1 = to_qr(network.apply(params, norm_obs),
                                    self.num_actions, self.num_quantiles)
        q_t_target, dist_q_t_target = to_qr(network.apply(target_params, norm_next_obs),
                                        self.num_actions, self.num_quantiles)
        # The behaviour estimate uses the un-normalized next observation.
        pi_b_t = behavior_fn(transitions.next_observation)

        # One threshold per sample so BCQ_probs can be vmapped over the
        # leading axis of all three arguments.
        batch_BCQ_probs = jax.vmap(BCQ_probs)
        tau = self.tau * jnp.ones((q_tm1.shape[0],))

        probs_a_t = batch_BCQ_probs(q_t_target, pi_b_t, tau)
        probs_a_t = jax.lax.stop_gradient(probs_a_t)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                    self.max_abs_reward).astype(jnp.float32)

        # Quantile expected-SARSA loss under probs_a_t. The quantile midpoints
        # and the Huber parameter are shared across the batch (in_axes=None).
        batch_error = jax.vmap(rlax.quantile_expected_sarsa,
                                    in_axes=(0, None, 0, 0, 0, 0, 0, None))
        losses = batch_error(dist_q_tm1, self.quantiles,
                            transitions.action, r_t, d_t,
                            dist_q_t_target, probs_a_t, self.huber_loss_parameter)
        loss = jnp.mean(losses)

        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(losses).astype(jnp.float64))
        mean_q = jnp.mean(q_tm1)
        extra = learning_lib.LossExtra(metrics={"mean_q": mean_q}, reverb_update=reverb_update)
        return loss, extra

@dataclasses.dataclass
class Pessimism(LossFn):
    """Pessimistic DQN loss.

    Next-state values from both the online and target networks are reduced by
    `alpha` times the uncertainty estimate before a double Q-learning backup.

    Attributes:
      discount: multiplied into the per-transition discount.
      alpha: weight of the uncertainty penalty.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: delta handed to `rlax.huber_loss`.
    """
    discount: float = 0.99
    alpha: float = 1.
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Evaluate the pessimistic DQN loss for one replay batch.

        Args:
          network: Q-network; only its `apply` is used.
          uncertainty_fn: evaluated on the raw next observation.
          behavior_fn: ignored by this loss.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for both current and next observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first
            entry of `info` is read as the replay keys.
          key: ignored by this loss.

        Returns:
          The mean Huber loss, and a `LossExtra` carrying the `mean_q` metric
          plus a reverb update with |TD error| priorities.
        """
        del key, behavior_fn
        sample: Transition = batch.data
        keys, _probs, *_ = batch.info  # sampling probabilities are not used here

        obs = normalize_fn(sample.observation)
        next_obs = normalize_fn(sample.next_observation)

        # Online values at s_{t-1}; online and target values at s_t.
        q_prev = network.apply(params, obs)
        q_next_online = network.apply(params, next_obs)
        q_next_target = network.apply(target_params, next_obs)

        # Effective discount and clipped reward, both as float32.
        discount_t = (sample.discount * self.discount).astype(jnp.float32)
        reward_t = jnp.clip(
            sample.reward,
            -self.max_abs_reward,
            self.max_abs_reward,
        ).astype(jnp.float32)

        # Penalise next-state values by the scaled uncertainty, which is
        # computed from the un-normalized next observation.
        penalty = self.alpha * uncertainty_fn(sample.next_observation)
        q_next_online = q_next_online - penalty
        q_next_target = q_next_target - penalty

        # Double Q-learning TD error: the penalised target values are backed
        # up, with the penalised online values passed as the action selector.
        td_error = jax.vmap(rlax.double_q_learning)(
            q_prev, sample.action, reward_t, discount_t,
            q_next_target, q_next_online)
        loss = jnp.mean(rlax.huber_loss(td_error, self.huber_loss_parameter))

        priorities = jnp.abs(td_error).astype(jnp.float64)
        reverb_update = learning_lib.ReverbUpdate(keys=keys, priorities=priorities)
        extra = learning_lib.LossExtra(
            metrics={"mean_q": jnp.mean(q_prev)}, reverb_update=reverb_update)
        return loss, extra

@dataclasses.dataclass
class PessimismQuantile(LossFn):
    """Pessimistic QR-DQN.

    The target network's next-state quantile distribution is shifted down by
    `alpha` times the uncertainty estimate before a quantile Q-learning backup.

    Attributes:
      discount: multiplied into the per-transition discount.
      alpha: weight of the uncertainty penalty.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: passed through to `rlax.quantile_q_learning`.
      num_actions: forwarded to `to_qr` when reshaping the network output.
      num_quantiles: forwarded to `to_qr`; also sizes `quantiles`.
      quantiles: quantile midpoints (i + 0.5) / num_quantiles.
    """
    discount: float = 0.99
    alpha: float = 1.0
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.
    num_actions: int = 1
    num_quantiles: int = 201
    # Class-level default, evaluated once with the default num_quantiles above.
    # Kept so that reading `quantiles` from the class itself still works.
    quantiles = (jnp.arange(0, num_quantiles) + 0.5) / float(num_quantiles)

    def __post_init__(self):
        """Rebuild `quantiles` so it matches this instance's `num_quantiles`."""
        # Without this, an instance created with a non-default num_quantiles
        # would keep using the 201-point class-level grid.
        self.quantiles = (
            jnp.arange(0, self.num_quantiles) + 0.5) / float(self.num_quantiles)

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data.

        Args:
          network: quantile network; only its `apply` is used.
          uncertainty_fn: evaluated on the raw next observation.
          behavior_fn: unused.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for the current observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first
            entry of `info` is read as the replay keys.
          key: unused.

        Returns:
          The mean of the per-sample quantile losses, and a `LossExtra`
          carrying the `mean_q` metric plus a reverb update whose priorities
          are the absolute per-sample losses.
        """
        del key
        del behavior_fn
        transitions: Transition = batch.data
        keys, _probs, *_ = batch.info  # sampling probabilities are not used here

        norm_obs = normalize_fn(transitions.observation)
        norm_next_obs = normalize_fn(transitions.next_observation)

        # Forward pass. `to_qr` yields a pair that is used here as
        # (Q-values, quantile distribution) — presumably the mean over
        # quantiles and the reshaped raw output; confirm against rl_utils.
        q_tm1, dist_q_tm1 = to_qr(network.apply(params, norm_obs),
                                    self.num_actions, self.num_quantiles)
        # Only the distribution is needed for the next state.
        _, dist_q_t_target = to_qr(network.apply(target_params, norm_next_obs),
                                        self.num_actions, self.num_quantiles)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                    self.max_abs_reward).astype(jnp.float32)

        # Subtract pessimism. The extra axis at position 1 lets u_t broadcast
        # against the distribution (assumes dist layout is
        # [batch, quantiles, actions] and u_t is [batch, actions] — TODO confirm).
        u_t = uncertainty_fn(transitions.next_observation)
        dist_q_t_target = dist_q_t_target - self.alpha * jnp.expand_dims(u_t, 1)

        # Quantile Q-learning loss. The penalised target distribution is used
        # both to select and to evaluate the next action; the quantile
        # midpoints and Huber parameter are shared across the batch.
        batch_loss = jax.vmap(rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None))
        td_losses = batch_loss(dist_q_tm1, self.quantiles, transitions.action, r_t, d_t,
                            dist_q_t_target, dist_q_t_target, self.huber_loss_parameter)
        loss = jnp.mean(td_losses)

        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(td_losses).astype(jnp.float64))
        mean_q = jnp.mean(q_tm1)
        extra = learning_lib.LossExtra(metrics={"mean_q": mean_q}, reverb_update=reverb_update)
        return loss, extra

@dataclasses.dataclass
class CQL(LossFn):
    """Conservative DQN loss.

    A double Q-learning Huber loss plus `alpha` times the batch mean of
    logsumexp(Q(s, .)) - Q(s, a) for the logged action a.

    Attributes:
      discount: multiplied into the per-transition discount.
      alpha: weight of the conservative penalty.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: delta handed to `rlax.huber_loss`.
    """
    discount: float = 0.99
    alpha: float = 1.0
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Evaluate the conservative DQN loss for one replay batch.

        Args:
          network: Q-network; only its `apply` is used.
          uncertainty_fn: ignored by this loss.
          behavior_fn: ignored by this loss.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for both current and next observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first
            entry of `info` is read as the replay keys.
          key: ignored by this loss.

        Returns:
          TD loss plus weighted penalty, and a `LossExtra` carrying the
          `mean_q` metric plus a reverb update with |TD error| priorities.
        """
        del key, behavior_fn, uncertainty_fn
        sample: Transition = batch.data
        keys, _probs, *_ = batch.info  # sampling probabilities are not used here

        obs = normalize_fn(sample.observation)
        next_obs = normalize_fn(sample.next_observation)

        # Online values at s_{t-1}; online and target values at s_t.
        q_prev = network.apply(params, obs)
        q_next_online = network.apply(params, next_obs)
        q_next_target = network.apply(target_params, next_obs)

        # Effective discount and clipped reward, both as float32.
        discount_t = (sample.discount * self.discount).astype(jnp.float32)
        reward_t = jnp.clip(
            sample.reward,
            -self.max_abs_reward,
            self.max_abs_reward,
        ).astype(jnp.float32)

        # Double Q-learning TD error: target values are backed up, with the
        # online values passed as the action selector.
        td_error = jax.vmap(rlax.double_q_learning)(
            q_prev, sample.action, reward_t, discount_t,
            q_next_target, q_next_online)
        td_loss = jnp.mean(rlax.huber_loss(td_error, self.huber_loss_parameter))

        # Conservative penalty: soft-maximum over actions minus the value of
        # the action that was actually logged.
        row_index = jnp.arange(q_prev.shape[0])
        logged_action_q = q_prev[row_index, sample.action]
        soft_max_q = jax.nn.logsumexp(q_prev, axis=-1)
        cql_penalty = jnp.mean(soft_max_q - logged_action_q)
        loss = td_loss + self.alpha * cql_penalty

        priorities = jnp.abs(td_error).astype(jnp.float64)
        reverb_update = learning_lib.ReverbUpdate(keys=keys, priorities=priorities)
        extra = learning_lib.LossExtra(
            metrics={"mean_q": jnp.mean(q_prev)}, reverb_update=reverb_update)
        return loss, extra


@dataclasses.dataclass
class CQLQuantile(LossFn):
    """Conservative QR-DQN.

    A quantile Q-learning loss plus `alpha` times the batch mean of
    logsumexp(Q(s, .)) - Q(s, a) for the logged action a, where Q is the
    first value returned by `to_qr`.

    Attributes:
      discount: multiplied into the per-transition discount.
      alpha: weight of the conservative penalty.
      max_abs_reward: rewards are clipped to [-max_abs_reward, max_abs_reward].
      huber_loss_parameter: passed through to `rlax.quantile_q_learning`.
      num_actions: forwarded to `to_qr` when reshaping the network output.
      num_quantiles: forwarded to `to_qr`; also sizes `quantiles`.
      quantiles: quantile midpoints (i + 0.5) / num_quantiles.
    """
    discount: float = 0.99
    alpha: float = 1.0
    max_abs_reward: float = 1.
    huber_loss_parameter: float = 1.
    num_actions: int = 1
    num_quantiles: int = 201
    # Class-level default, evaluated once with the default num_quantiles above.
    # Kept so that reading `quantiles` from the class itself still works.
    quantiles = (jnp.arange(0, num_quantiles) + 0.5) / float(num_quantiles)

    def __post_init__(self):
        """Rebuild `quantiles` so it matches this instance's `num_quantiles`."""
        # Without this, an instance created with a non-default num_quantiles
        # would keep using the 201-point class-level grid.
        self.quantiles = (
            jnp.arange(0, self.num_quantiles) + 0.5) / float(self.num_quantiles)

    def __call__(self,
                network: networks_lib.FeedForwardNetwork,
                uncertainty_fn: Callable[[jnp.ndarray], jnp.ndarray],
                behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                params: networks_lib.Params,
                target_params: networks_lib.Params,
                batch: reverb.ReplaySample,
                key: networks_lib.PRNGKey) -> Tuple[jnp.DeviceArray, learning_lib.LossExtra]:
        """Calculate a loss on a single batch of data.

        Args:
          network: quantile network; only its `apply` is used.
          uncertainty_fn: unused.
          behavior_fn: unused.
          normalize_fn: applied to observations before they reach the network.
          params: online parameters, used for the current observation.
          target_params: target parameters, used for the next observation.
          batch: replay sample; `data` holds the transitions and the first
            entry of `info` is read as the replay keys.
          key: unused.

        Returns:
          Mean quantile loss plus weighted penalty, and a `LossExtra` carrying
          the `mean_q` metric plus a reverb update whose priorities are the
          absolute per-sample quantile losses (penalty not included).
        """
        del key
        del behavior_fn
        del uncertainty_fn
        transitions: Transition = batch.data
        keys, _probs, *_ = batch.info  # sampling probabilities are not used here

        norm_obs = normalize_fn(transitions.observation)
        norm_next_obs = normalize_fn(transitions.next_observation)

        # Forward pass. `to_qr` yields a pair that is used here as
        # (Q-values, quantile distribution) — presumably the mean over
        # quantiles and the reshaped raw output; confirm against rl_utils.
        q_tm1, dist_q_tm1 = to_qr(network.apply(params, norm_obs),
                                    self.num_actions, self.num_quantiles)
        # Only the distribution is needed for the next state.
        _, dist_q_t_target = to_qr(network.apply(target_params, norm_next_obs),
                                        self.num_actions, self.num_quantiles)

        # Cast and clip rewards.
        d_t = (transitions.discount * self.discount).astype(jnp.float32)
        r_t = jnp.clip(transitions.reward, -self.max_abs_reward,
                    self.max_abs_reward).astype(jnp.float32)

        # Quantile Q-learning loss. This is not a double-Q update: the target
        # distribution is used both to select and to evaluate the next action.
        # Quantile midpoints and Huber parameter are shared across the batch.
        batch_loss = jax.vmap(rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None))
        td_losses = batch_loss(dist_q_tm1, self.quantiles, transitions.action, r_t, d_t,
                            dist_q_t_target, dist_q_t_target, self.huber_loss_parameter)
        loss = jnp.mean(td_losses)

        # CQL penalty: soft-maximum over actions minus the value of the
        # action that was actually logged.
        log_sum_exp_term = jax.nn.logsumexp(q_tm1, axis=-1)
        behavior_term = q_tm1[jnp.arange(q_tm1.shape[0]), transitions.action]
        cql_penalty = jnp.mean(log_sum_exp_term - behavior_term)
        loss += self.alpha * cql_penalty

        reverb_update = learning_lib.ReverbUpdate(
            keys=keys, priorities=jnp.abs(td_losses).astype(jnp.float64))
        mean_q = jnp.mean(q_tm1)
        extra = learning_lib.LossExtra(metrics={"mean_q": mean_q}, reverb_update=reverb_update)
        return loss, extra

