from typing import NamedTuple, Tuple

import jax, jax.numpy as jnp
import optax
import haiku as hk

from relax.algorithm.base import Algorithm
from relax.network.ddsq import DDSQNet, DDSQParams, norm_fn
from relax.utils.experience import Experience
from relax.utils.typing import Metric


class DDSQOptStates(NamedTuple):
    """Optimizer states for every trainable component of DDSQ.

    Field order matters: the update step unpacks this tuple positionally.
    """

    # State of the shared Q optimizer, initialised from the first critic's parameters.
    q1: optax.OptState
    # State of the shared Q optimizer, initialised from the second critic's parameters.
    q2: optax.OptState
    # State of the dedicated optimizer for the `log_alpha` temperature.
    log_alpha: optax.OptState
    # State of the dedicated optimizer for the `log_beta` temperature.
    log_beta: optax.OptState


class DDSQTrainState(NamedTuple):
    """Complete training state threaded through every DDSQ update step.

    DDSQ initialises the scalar fields as `jnp` scalars (int32 / float32), so
    the whole tuple can be passed through the jitted update function.
    """

    # Critic parameters, target-critic parameters and the two log-temperatures.
    params: DDSQParams
    # Optimizer states matching the trainable entries of `params`.
    opt_state: DDSQOptStates
    # Number of updates performed so far; drives all delayed-update schedules.
    step: int
    # Initialised to -1.0 and carried through unchanged by the update in this file.
    mean_q1_std: float
    # Initialised to -1.0 and carried through unchanged by the update in this file.
    mean_q2_std: float
    # Most recent entropy estimate computed with temperature exp(log_alpha).
    entropy_alpha: float
    # Most recent entropy estimate computed with temperature exp(log_beta).
    entropy_beta: float
    # Standard deviation over the batch of the per-state entropies behind `entropy_alpha`.
    entropy_std_alpha: float
    # Standard deviation over the batch of the per-state entropies behind `entropy_beta`.
    entropy_std_beta: float


class DDSQ(Algorithm):
    """Twin-critic Q-learning with two entropy-tuned temperatures.

    Each update step

    * regresses both critics onto a one-step TD target built from the minimum
      of the two target critics, evaluated at an action returned by
      `agent.get_softmax`;
    * periodically re-estimates the entropy of the density proportional to
      `exp(Q_normalized / temperature)` over the action box `[-1, 1]^act_dim`
      using uniformly sampled actions, once for each temperature;
    * periodically moves `log_alpha` / `log_beta` so that the corresponding
      entropy estimate approaches `agent.target_entropy`;
    * periodically Polyak-averages the critics into the target critics.
    """

    def __init__(
        self,
        agent: DDSQNet,
        params: DDSQParams,
        *,
        gamma: float = 0.99,
        lr: float = 1e-4,
        alpha_lr: float = 0.0,
        beta_lr: float = 0.0,
        delay_update: int = 1,
        delay_alpha_update: int = 100,
        delay_beta_update: int = 100,
        entropy_samples: int = 1000,
        reward_scale: float = 1.0,
        tau: float = 0.005,
    ):
        """Build the optimizers, the initial train state and the jitted update.

        Args:
            agent: Network bundle providing `q`, `get_softmax`, `get_action`,
                `get_deterministic_action`, `act_dim` and `target_entropy`.
            params: Initial parameters (critics, target critics, log-temperatures).
            gamma: Discount factor used in the TD target.
            lr: Adam learning rate for both critics.
            alpha_lr: Adam learning rate for `log_alpha`.
            beta_lr: Adam learning rate for `log_beta`.
            delay_update: Target critics are updated when `step % delay_update == 0`.
            delay_alpha_update: The alpha entropy estimate and `log_alpha` are
                updated when `step % delay_alpha_update == 0`.
            delay_beta_update: Same schedule for the beta entropy estimate and `log_beta`.
            entropy_samples: Number of uniform actions drawn per state for the
                entropy estimate.
            reward_scale: Multiplier applied to rewards before building the TD target.
            tau: Step size of the incremental (Polyak) target update.
        """
        self.agent = agent
        self.gamma = gamma
        self.lr = lr
        self.alpha_lr = alpha_lr
        self.beta_lr = beta_lr
        self.delay_update = delay_update
        self.delay_alpha_update = delay_alpha_update
        self.delay_beta_update = delay_beta_update
        self.entropy_samples = entropy_samples
        self.reward_scale = reward_scale
        self.tau = tau
        self.optim = optax.adam(lr)
        self.alpha_optim = optax.adam(alpha_lr)
        self.beta_optim = optax.adam(beta_lr)

        self.state = DDSQTrainState(
            params=params,
            opt_state=DDSQOptStates(
                q1=self.optim.init(params.q1),
                q2=self.optim.init(params.q2),
                log_alpha=self.alpha_optim.init(params.log_alpha),
                log_beta=self.beta_optim.init(params.log_beta),
            ),
            step=jnp.int32(0),
            mean_q1_std=jnp.float32(-1.0),
            mean_q2_std=jnp.float32(-1.0),
            # act_dim * log(2) is the entropy of the uniform density on [-1, 1]^act_dim.
            entropy_alpha=jnp.float32(self.agent.act_dim * jnp.log(2)),
            entropy_beta=jnp.float32(self.agent.act_dim * jnp.log(2)),
            entropy_std_alpha=jnp.float32(0.0),
            entropy_std_beta=jnp.float32(0.0),
        )

        @jax.jit
        def stateless_update(key: jax.Array, state: DDSQTrainState, data: Experience) -> Tuple[DDSQTrainState, Metric]:
            """Run one training step and return the new state plus logging metrics.

            Args:
                key: PRNG key; split into one key for the next-action draw and
                    one for the entropy action samples.
                state: Train state before the step.
                data: Batch of transitions (`obs`, `action`, `reward`, `next_obs`, `done`).

            Returns:
                The train state after the step and a dict of scalar metrics.
            """
            obs, action, next_obs, done = data.obs, data.action, data.next_obs, data.done
            reward = data.reward * self.reward_scale
            q1_params, q2_params, target_q1_params, target_q2_params, log_alpha, log_beta = state.params
            q1_opt_state, q2_opt_state, log_alpha_opt_state, log_beta_opt_state = state.opt_state
            step, mean_q1_std, mean_q2_std = state.step, state.mean_q1_std, state.mean_q2_std
            next_eval_key, entropy_key = jax.random.split(key, 2)

            # log-volume of the action box [-1, 1]^act_dim that entropy actions are sampled from
            log_space_volume = self.agent.act_dim * jnp.log(2.0)

            def q_for_all_actions_serial(q_param, obs, action):
                """Evaluate Q for `action` of shape [B, N, act_dim]; returns [B, N]."""

                def get_q(carry, act):
                    return carry, self.agent.q(q_param, obs, act)  # None, [B, ]

                # [N, B], serial computation to alleviate GPU need
                _, qs = jax.lax.scan(get_q, None, action.transpose((1, 0, 2)))
                return qs.transpose((1, 0))

            # Temperatures as seen by this step; these (pre-update) values are what gets logged.
            alpha = jnp.exp(log_alpha)
            beta = jnp.exp(log_beta)

            # ---- TD target -------------------------------------------------
            pi_next = self.agent.get_softmax(next_eval_key, (log_alpha, log_beta, q1_params, q2_params), next_obs)
            q1_next = self.agent.q(target_q1_params, next_obs, pi_next)
            q2_next = self.agent.q(target_q2_params, next_obs, pi_next)
            q_next = jnp.minimum(q1_next, q2_next)

            q_backup = jax.lax.stop_gradient(reward + (1 - done) * self.gamma * q_next)

            # ---- critic losses ---------------------------------------------
            def q_loss_fn(q_params: hk.Params) -> Tuple[jax.Array, jax.Array]:
                """Mean squared TD error; the per-sample Q values are returned as aux."""
                q = self.agent.q(q_params, obs, action)
                backup_loss = (q - q_backup) ** 2
                return jnp.mean(backup_loss), q

            q_value_and_grad = jax.value_and_grad(q_loss_fn, has_aux=True)
            (q1_loss, q1), q1_grads = q_value_and_grad(q1_params)
            (q2_loss, q2), q2_grads = q_value_and_grad(q2_params)

            # ---- entropy estimates -----------------------------------------
            def cal_entropy(log_lam):
                """Return (mean, std) over the batch of the estimated per-state entropy.

                The density is proportional to exp(Q_normalized / exp(log_lam)) on
                the action box; expectations are estimated by self-normalised
                importance sampling with uniformly drawn actions. Uses the
                critics as they are *before* this step's gradient update.
                """
                # The same key is used for every call, so the alpha and beta
                # estimates share one set of sampled actions.
                entropy_actions = jax.random.uniform(
                    key=entropy_key, shape=(obs.shape[0], self.entropy_samples, self.agent.act_dim), minval=-1, maxval=1
                )
                q1_ent = q_for_all_actions_serial(q1_params, obs, entropy_actions)  # [B, N]
                q2_ent = q_for_all_actions_serial(q2_params, obs, entropy_actions)  # [B, N]
                q_ent = jnp.minimum(q1_ent, q2_ent)  # [B, N]
                # Per-state rescaling by norm_fn (imported helper; semantics not visible here).
                normalized_q_ent = q_ent / (norm_fn(q_ent)[:, None] + 1e-6)

                # temperature
                lam = jnp.exp(log_lam)
                ent_energy = normalized_q_ent / lam  # [B, N]

                # One logsumexp serves both the SNIS weights and the log-partition estimate.
                log_sum_exp = jax.scipy.special.logsumexp(ent_energy, axis=1, keepdims=True)  # [B, 1]
                ent_snis_ratio = jnp.exp(ent_energy - log_sum_exp)  # [B, N], rows sum to 1

                # log Z ~= log(volume) + log(mean_n exp(energy_n))
                logz = log_space_volume - jnp.log(self.entropy_samples) + log_sum_exp[:, 0]  # [B, ]
                # expected energy under the SNIS weights
                eu = jnp.sum(ent_snis_ratio * ent_energy, axis=1)  # [B, ]

                # entropy over state samples: H = log Z - E[energy]
                batch_entropy = logz - eu
                entropy_std = jnp.std(batch_entropy)

                return jax.lax.stop_gradient(jnp.mean(batch_entropy)), jax.lax.stop_gradient(entropy_std)

            def delayed_entropy(period, log_lam, prev_entropy, prev_entropy_std):
                """Recompute the entropy every `period` steps, otherwise keep the stored value."""
                return jax.lax.cond(
                    step % period == 0,
                    lambda: cal_entropy(log_lam),
                    lambda: (prev_entropy, prev_entropy_std),
                )

            entropy_alpha, entropy_std_alpha = delayed_entropy(
                self.delay_alpha_update, log_alpha, state.entropy_alpha, state.entropy_std_alpha
            )
            entropy_beta, entropy_std_beta = delayed_entropy(
                self.delay_beta_update, log_beta, state.entropy_beta, state.entropy_std_beta
            )

            # ---- parameter updates -----------------------------------------
            def param_update(optim, params, grads, opt_state):
                """Apply one optimizer step and return (new_params, new_opt_state)."""
                update, new_opt_state = optim.update(grads, opt_state)
                new_params = optax.apply_updates(params, update)
                return new_params, new_opt_state

            def delay_target_update(params, target_params, tau):
                """Polyak-average `params` into `target_params` every `delay_update` steps."""
                return jax.lax.cond(
                    step % self.delay_update == 0,
                    lambda target_params: optax.incremental_update(params, target_params, tau),
                    lambda target_params: target_params,
                    target_params,
                )

            def delayed_temperature_update(period, optim, log_lam, opt_state, entropy, entropy_std):
                """Take one optimizer step on a log-temperature every `period` steps.

                The loss is linear in the log-temperature, so its gradient is
                (entropy - target_entropy) / max(entropy_std, 1): the temperature
                shrinks when the estimated entropy exceeds the target and grows
                otherwise. `entropy` / `entropy_std` enter as constants.
                """

                def temperature_loss_fn(log_lam: jax.Array) -> jax.Array:
                    return -1.0 * log_lam * (-entropy + self.agent.target_entropy) / jnp.maximum(entropy_std, 1)

                return jax.lax.cond(
                    step % period == 0,
                    lambda params, opt_state: param_update(
                        optim, params, jax.grad(temperature_loss_fn)(params), opt_state
                    ),
                    lambda params, opt_state: (params, opt_state),
                    log_lam,
                    opt_state,
                )

            # Updated values get fresh names so closures above keep seeing the pre-update parameters.
            new_q1_params, new_q1_opt_state = param_update(self.optim, q1_params, q1_grads, q1_opt_state)
            new_q2_params, new_q2_opt_state = param_update(self.optim, q2_params, q2_grads, q2_opt_state)

            new_log_alpha, new_log_alpha_opt_state = delayed_temperature_update(
                self.delay_alpha_update,
                self.alpha_optim,
                log_alpha,
                log_alpha_opt_state,
                entropy_alpha,
                entropy_std_alpha,
            )
            new_log_beta, new_log_beta_opt_state = delayed_temperature_update(
                self.delay_beta_update,
                self.beta_optim,
                log_beta,
                log_beta_opt_state,
                entropy_beta,
                entropy_std_beta,
            )

            # Targets track the critics *after* this step's gradient update.
            new_target_q1_params = delay_target_update(new_q1_params, target_q1_params, self.tau)
            new_target_q2_params = delay_target_update(new_q2_params, target_q2_params, self.tau)

            state = DDSQTrainState(
                params=DDSQParams(
                    new_q1_params,
                    new_q2_params,
                    new_target_q1_params,
                    new_target_q2_params,
                    new_log_alpha,
                    new_log_beta,
                ),
                opt_state=DDSQOptStates(
                    q1=new_q1_opt_state,
                    q2=new_q2_opt_state,
                    log_alpha=new_log_alpha_opt_state,
                    log_beta=new_log_beta_opt_state,
                ),
                step=step + 1,
                mean_q1_std=mean_q1_std,
                mean_q2_std=mean_q2_std,
                entropy_alpha=entropy_alpha,
                entropy_beta=entropy_beta,
                entropy_std_alpha=entropy_std_alpha,
                entropy_std_beta=entropy_std_beta,
            )

            info = {
                "reward_sq": jnp.mean(reward**2),
                "alpha": alpha,
                "beta": beta,
                "entropy_alpha_std": entropy_std_alpha,
                "entropy_beta_std": entropy_std_beta,
                "action": jnp.mean(jnp.fabs(action)),
                "q1_loss": q1_loss,
                "q1_mean": jnp.mean(q1),
                "q1_min": jnp.min(q1),
                "q1_max": jnp.max(q1),
                "q2_loss": q2_loss,
                "q2_mean": jnp.mean(q2),
                "q2_min": jnp.min(q2),
                "q2_max": jnp.max(q2),
                "mean_q1_std": mean_q1_std,
                "mean_q2_std": mean_q2_std,
                "entropy_alpha": entropy_alpha,
                "entropy_beta": entropy_beta,
            }
            return state, info

        self._implement_common_behavior(stateless_update, self.agent.get_action, self.agent.get_deterministic_action)

    def get_policy_params(self):
        """Return `(log_alpha, log_beta, q1, q2)` from the current train state.

        This is the same ordering the update step passes to `agent.get_softmax`.
        """
        return (
            self.state.params.log_alpha,
            self.state.params.log_beta,
            self.state.params.q1,
            self.state.params.q2,
        )
