from typing import NamedTuple, Tuple

import jax, jax.numpy as jnp
import optax
import haiku as hk

from relax.algorithm.base import Algorithm
from relax.network.dsace import DSACENet, DSACEParams
from relax.utils.experience import Experience
from relax.utils.typing import Metric


def compute_grad_norm(grads):
    """Return the global L2 norm taken over every leaf of the ``grads`` PyTree.

    The squared entries of each leaf are summed, the per-leaf totals are added
    together, and the square root of the grand total is returned.
    """
    squared_total = 0
    for leaf in jax.tree_util.tree_leaves(grads):
        squared_total = squared_total + jnp.sum(jnp.square(leaf))
    return jnp.sqrt(squared_total)

class DSACEOptStates(NamedTuple):
    """Optimizer states, one per trainable parameter group of DSACE.

    The field order matters: the update step unpacks this tuple positionally.
    """
    # Optimizer states for the two critics `q1` / `q2`.
    q1: optax.OptState
    q2: optax.OptState
    # Optimizer states for the two critics `qe1` / `qe2`.
    qe1: optax.OptState
    qe2: optax.OptState
    # Optimizer state for the policy parameters.
    policy: optax.OptState
    # Optimizer state for `log_alpha` (created by a separate optimizer in DSACE).
    log_alpha: optax.OptState


class DSACETrainState(NamedTuple):
    """Complete training state carried from one DSACE update to the next."""
    # All network parameters (online and target) plus `log_alpha`.
    params: DSACEParams
    # Optimizer states matching the trainable entries of `params`.
    opt_state: DSACEOptStates
    # Number of updates performed so far; DSACE initialises it to 0 and adds 1 per update.
    step: int
    # Running averages of each critic's mean predicted std.
    # DSACE initialises them to -1.0, which the update treats as "not yet initialised".
    mean_q1_std: float
    mean_q2_std: float
    mean_qe1_std: float
    mean_qe2_std: float

class DSACE(Algorithm):
    """Actor-critic learner with twin mean/std critics for reward and for log-probability.

    Four critics, each returning a mean and a std, are trained with the same loss:

    * ``q1`` / ``q2`` back up the scaled reward,
    * ``qe1`` / ``qe2`` back up ``-gamma * next_logp``.

    The policy minimises ``alpha * logp - min(q1, q2) - alpha * min(qe1, qe2)``
    and ``log_alpha`` is tuned against ``agent.target_entropy``.

    NOTE(review): this summary is read directly off the update rule below; the
    file cites no paper or design document, so confirm the intended algorithm.
    """

    def __init__(
        self,
        agent: DSACENet,
        params: DSACEParams,
        *,
        gamma: float = 0.99,
        lr: float = 1e-4,
        alpha_lr: float = 3e-4,
        tau: float = 0.005,
        delay_update: int = 2,
        reward_scale: float = 0.1,
    ):
        """Create optimizers, the initial train state and the jitted update function.

        Args:
            agent: Network bundle; the update uses its ``evaluate``, ``q``,
                ``q_evaluate`` and ``target_entropy`` members, and its action /
                entropy getters are handed to ``_implement_common_behavior``.
            params: Initial parameters for every network and ``log_alpha``.
            gamma: Discount factor used in both backups.
            lr: Adam learning rate for the critics and the policy (their
                gradients are first clipped to a global norm of 1.0).
            alpha_lr: Adam learning rate for ``log_alpha`` (no clipping).
            tau: Step size of the target-network soft update; also the
                averaging rate of the running mean critic std.
            delay_update: The policy, ``log_alpha`` and all target networks are
                only updated on steps where ``step % delay_update == 0``.
            reward_scale: Factor the batch reward is multiplied by.
        """
        self.agent = agent
        self.gamma = gamma
        self.tau = tau
        self.delay_update = delay_update
        self.reward_scale = reward_scale
        self.optim = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adam(lr),
        )
        self.alpha_optim = optax.adam(alpha_lr)

        self.state = DSACETrainState(
            params=params,
            opt_state=DSACEOptStates(
                q1=self.optim.init(params.q1),
                q2=self.optim.init(params.q2),
                qe1=self.optim.init(params.qe1),
                qe2=self.optim.init(params.qe2),
                policy=self.optim.init(params.policy),
                log_alpha=self.alpha_optim.init(params.log_alpha),
            ),
            step=jnp.int32(0),
            # -1.0 is the "not yet initialised" sentinel checked in the critic loss.
            mean_q1_std=jnp.float32(-1.0),
            mean_q2_std=jnp.float32(-1.0),
            mean_qe1_std=jnp.float32(-1.0),
            mean_qe2_std=jnp.float32(-1.0),
        )

        @jax.jit
        def stateless_update(
            key: jax.Array, state: DSACETrainState, data: Experience
        ) -> Tuple[DSACETrainState, Metric]:
            """Run one update on a batch; return the new train state and a metrics dict."""
            obs, action, reward, next_obs, done = data.obs, data.action, data.reward, data.next_obs, data.done
            # Positional unpacking: depends on the field order of the params / opt-state tuples.
            (q1_params, q2_params, target_q1_params, target_q2_params,
             qe1_params, qe2_params, target_qe1_params, target_qe2_params,
             policy_params, target_policy_params, log_alpha) = state.params
            (q1_opt_state, q2_opt_state, qe1_opt_state, qe2_opt_state,
             policy_opt_state, log_alpha_opt_state) = state.opt_state
            step = state.step
            mean_q1_std, mean_q2_std = state.mean_q1_std, state.mean_q2_std
            mean_qe1_std, mean_qe2_std = state.mean_qe1_std, state.mean_qe2_std
            (next_eval_key, new_eval_key, new_q1_eval_key, new_q2_eval_key,
             new_qe1_eval_key, new_qe2_eval_key) = jax.random.split(key, 6)

            reward = reward * self.reward_scale

            # ---- targets: target policy picks the next action, target critics score it ----
            next_action, next_logp = self.agent.evaluate(next_eval_key, target_policy_params, next_obs)
            next_q1_mean, _, next_q1_sample = self.agent.q_evaluate(new_q1_eval_key, target_q1_params, next_obs, next_action)
            next_q2_mean, _, next_q2_sample = self.agent.q_evaluate(new_q2_eval_key, target_q2_params, next_obs, next_action)
            next_qe1_mean, _, next_qe1_sample = self.agent.q_evaluate(new_qe1_eval_key, target_qe1_params, next_obs, next_action)
            next_qe2_mean, _, next_qe2_sample = self.agent.q_evaluate(new_qe2_eval_key, target_qe2_params, next_obs, next_action)
            # Pessimistic (element-wise smaller) estimate of each critic pair.
            next_q_mean = jnp.minimum(next_q1_mean, next_q2_mean)
            next_qe_mean = jnp.minimum(next_qe1_mean, next_qe2_mean)
            next_q_sample = jnp.where(next_q1_sample < next_q2_sample, next_q1_sample, next_q2_sample)
            next_qe_sample = jnp.where(next_qe1_sample < next_qe2_sample, next_qe1_sample, next_qe2_sample)

            # Bootstrapping is switched off where `done` is set.
            bootstrap_discount = (1 - done) * self.gamma
            q_backup = reward + bootstrap_discount * next_q_mean
            q_backup_sample = reward + bootstrap_discount * next_q_sample

            # NOTE(review): the log-prob term is NOT masked by (1 - done), unlike the
            # bootstrap term — confirm this is intended for terminal transitions.
            next_logp_term = -self.gamma * next_logp
            qe_backup = next_logp_term + bootstrap_discount * next_qe_mean
            qe_backup_sample = next_logp_term + bootstrap_discount * next_qe_sample

            # ---- critic losses ----
            def make_critic_loss_fn(backup: jax.Array, backup_sample: jax.Array):
                """Build the critic loss for a given (mean backup, sampled backup) pair.

                All four critics share this loss; only the backup differs.
                """
                def critic_loss_fn(
                    q_params: hk.Params, mean_q_std: jax.Array
                ) -> Tuple[jax.Array, Tuple[jax.Array, jax.Array, jax.Array]]:
                    q_mean, q_std = self.agent.q(q_params, obs, action)
                    new_mean_q_std = jnp.mean(q_std)
                    # Running average of the batch-mean std; the -1.0 sentinel means
                    # "first update", in which case the batch value is taken as is.
                    mean_q_std = jax.lax.stop_gradient(
                        (mean_q_std == -1.0) * new_mean_q_std +
                        (mean_q_std != -1.0) * (self.tau * new_mean_q_std + (1 - self.tau) * mean_q_std)
                    )
                    # Sampled backup clipped to within 3 running stds of the current mean.
                    q_backup_bounded = jax.lax.stop_gradient(
                        q_mean + jnp.clip(backup_sample - q_mean, -3 * mean_q_std, 3 * mean_q_std)
                    )
                    q_std_detach = jax.lax.stop_gradient(jnp.maximum(q_std, 0))
                    epsilon = 0.1  # keeps the denominators away from zero
                    # First term drives the mean towards `backup`; second term drives the
                    # std towards the clipped sampled error. Only `q_mean` / `q_std`
                    # outside stop_gradient carry gradient.
                    q_loss = -(mean_q_std ** 2 + epsilon) * jnp.mean(
                        q_mean * jax.lax.stop_gradient(backup - q_mean) / (q_std_detach ** 2 + epsilon) +
                        q_std * ((jax.lax.stop_gradient(q_mean) - q_backup_bounded) ** 2 - q_std_detach ** 2) / (q_std_detach ** 3 + epsilon)
                    )
                    return q_loss, (q_mean, q_std, mean_q_std)

                return critic_loss_fn

            q_value_and_grad = jax.value_and_grad(make_critic_loss_fn(q_backup, q_backup_sample), has_aux=True)
            qe_value_and_grad = jax.value_and_grad(make_critic_loss_fn(qe_backup, qe_backup_sample), has_aux=True)

            (q1_loss, (q1_mean, q1_std, mean_q1_std)), q1_grads = q_value_and_grad(q1_params, mean_q1_std)
            (q2_loss, (q2_mean, q2_std, mean_q2_std)), q2_grads = q_value_and_grad(q2_params, mean_q2_std)
            (qe1_loss, (qe1_mean, qe1_std, mean_qe1_std)), qe1_grads = qe_value_and_grad(qe1_params, mean_qe1_std)
            (qe2_loss, (qe2_mean, qe2_std, mean_qe2_std)), qe2_grads = qe_value_and_grad(qe2_params, mean_qe2_std)

            # ---- policy loss (uses the critics' pre-update parameters) ----
            def policy_loss_fn(policy_params: hk.Params) -> Tuple[jax.Array, jax.Array]:
                new_action, new_logp = self.agent.evaluate(new_eval_key, policy_params, obs)
                new_q1_mean, _ = self.agent.q(q1_params, obs, new_action)
                new_q2_mean, _ = self.agent.q(q2_params, obs, new_action)
                new_q_mean = jnp.minimum(new_q1_mean, new_q2_mean)

                new_qe1_mean, _ = self.agent.q(qe1_params, obs, new_action)
                new_qe2_mean, _ = self.agent.q(qe2_params, obs, new_action)
                new_qe_mean = jnp.minimum(new_qe1_mean, new_qe2_mean)

                alpha = jnp.exp(log_alpha)
                policy_loss = jnp.mean(alpha * new_logp - new_q_mean - alpha * new_qe_mean)
                return policy_loss, new_logp

            (policy_loss, new_logp), policy_grads = jax.value_and_grad(policy_loss_fn, has_aux=True)(policy_params)

            # ---- temperature loss ----
            def log_alpha_loss_fn(log_alpha: jax.Array) -> jax.Array:
                log_alpha_loss_current = -jnp.mean(log_alpha * (new_logp + self.agent.target_entropy))
                # NOTE(review): qe1_mean / qe2_mean are the critic outputs at the replayed
                # batch actions (aux of the critic loss), not at the freshly sampled
                # policy actions — confirm this is intended.
                qe_mean = jnp.minimum(qe1_mean, qe2_mean)
                log_alpha_loss_Qe = jnp.mean(
                    log_alpha * (qe_mean - self.agent.target_entropy * self.gamma / (1 - self.gamma))
                )
                return log_alpha_loss_current + log_alpha_loss_Qe

            log_alpha_grads = jax.grad(log_alpha_loss_fn)(log_alpha)

            # ---- apply updates ----
            def param_update(optim, params, grads, opt_state):
                """Apply one optimizer step and return (new_params, new_opt_state)."""
                update, new_opt_state = optim.update(grads, opt_state)
                new_params = optax.apply_updates(params, update)
                return new_params, new_opt_state

            def delay_param_update(optim, params, grads, opt_state):
                """Like `param_update`, but a no-op unless step % delay_update == 0."""
                return jax.lax.cond(
                    step % self.delay_update == 0,
                    lambda params, opt_state: param_update(optim, params, grads, opt_state),
                    lambda params, opt_state: (params, opt_state),
                    params, opt_state
                )

            def delay_target_update(params, target_params, tau):
                """Soft-update `target_params` towards `params` on delayed steps only."""
                return jax.lax.cond(
                    step % self.delay_update == 0,
                    lambda target_params: optax.incremental_update(params, target_params, tau),
                    lambda target_params: target_params,
                    target_params
                )

            # Critics step every update; policy, temperature and targets are delayed.
            q1_params, q1_opt_state = param_update(self.optim, q1_params, q1_grads, q1_opt_state)
            q2_params, q2_opt_state = param_update(self.optim, q2_params, q2_grads, q2_opt_state)
            qe1_params, qe1_opt_state = param_update(self.optim, qe1_params, qe1_grads, qe1_opt_state)
            qe2_params, qe2_opt_state = param_update(self.optim, qe2_params, qe2_grads, qe2_opt_state)
            policy_params, policy_opt_state = delay_param_update(self.optim, policy_params, policy_grads, policy_opt_state)
            log_alpha, log_alpha_opt_state = delay_param_update(self.alpha_optim, log_alpha, log_alpha_grads, log_alpha_opt_state)
            # Targets track the parameters that were just updated above.
            target_q1_params = delay_target_update(q1_params, target_q1_params, self.tau)
            target_q2_params = delay_target_update(q2_params, target_q2_params, self.tau)
            target_qe1_params = delay_target_update(qe1_params, target_qe1_params, self.tau)
            target_qe2_params = delay_target_update(qe2_params, target_qe2_params, self.tau)
            target_policy_params = delay_target_update(policy_params, target_policy_params, self.tau)

            state = DSACETrainState(
                params=DSACEParams(q1_params, q2_params, target_q1_params, target_q2_params,
                                   qe1_params, qe2_params, target_qe1_params, target_qe2_params,
                                   policy_params, target_policy_params, log_alpha),
                opt_state=DSACEOptStates(q1=q1_opt_state, q2=q2_opt_state,
                                         qe1=qe1_opt_state, qe2=qe2_opt_state,
                                         policy=policy_opt_state, log_alpha=log_alpha_opt_state),
                step=step + 1,
                mean_q1_std=mean_q1_std,
                mean_q2_std=mean_q2_std,
                mean_qe1_std=mean_qe1_std,
                mean_qe2_std=mean_qe2_std,
            )
            info = {
                "q1_loss": q1_loss,
                "q1_mean": jnp.mean(q1_mean),
                "q1_std": jnp.mean(q1_std),
                "q2_loss": q2_loss,
                "q2_mean": jnp.mean(q2_mean),
                "q2_std": jnp.mean(q2_std),
                "qe1_loss": qe1_loss,
                "qe1_mean": jnp.mean(qe1_mean),
                "qe1_std": jnp.mean(qe1_std),
                "qe2_loss": qe2_loss,
                "qe2_mean": jnp.mean(qe2_mean),
                "qe2_std": jnp.mean(qe2_std),
                "qe_grad_norm": compute_grad_norm(qe1_grads) + compute_grad_norm(qe2_grads),
                "policy_loss": policy_loss,
                "entropy": -jnp.mean(new_logp),
                "next_logp": jnp.mean(next_logp),
                # Reported after this step's (possibly skipped) temperature update.
                "alpha": jnp.exp(log_alpha),
                "mean_q1_std": mean_q1_std,
                "mean_q2_std": mean_q2_std,
                "mean_qe1_std": mean_qe1_std,
                "mean_qe2_std": mean_qe2_std,
            }
            return state, info

        self._implement_common_behavior(
            stateless_update,
            self.agent.get_action,
            self.agent.get_deterministic_action,
            self.agent.get_entropy,
        )

