from dataclasses import dataclass
from typing import Callable, NamedTuple, Sequence, Tuple

import jax, jax.numpy as jnp
import haiku as hk

from relax.network.blocks import Activation, DistributionalQNet2, PolicyNet
from relax.network.common import WithSquashedGaussianPolicy
from numpyro.distributions import Normal
import math

class DSACEParams(NamedTuple):
    """Container for every parameter set created by ``create_dsace_net``.

    Field order matters: ``create_dsace_net`` builds this tuple positionally.
    """
    # Parameters of the two critics built from the distributional Q network.
    q1: hk.Params
    q2: hk.Params
    # Target critics; initialised to the same values as q1 / q2.
    target_q1: hk.Params
    target_q2: hk.Params
    # A second pair of critics using the same Q network definition
    # (NOTE(review): the role of the "qe" critics is not visible here — confirm).
    qe1: hk.Params
    qe2: hk.Params
    # Target counterparts of qe1 / qe2; initialised to the same values.
    target_qe1: hk.Params
    target_qe2: hk.Params
    # Policy network parameters and their target copy (initially identical).
    policy: hk.Params
    target_policy: hk.Params
    # Scalar float32 array, initialised to 1.0
    # (presumably the log of the entropy temperature — TODO confirm).
    log_alpha: jax.Array


@dataclass
class DSACENet(WithSquashedGaussianPolicy):
    """Apply functions and hyper-parameters used by the DSACE agent.

    ``self.policy`` is called as ``policy(params, obs)`` and is expected to
    return a ``(mean, std)`` pair; ``self.q`` is called as
    ``q(params, obs, act)`` and returns a ``(mean, std)`` pair as well.
    """
    # Distributional critic apply function: (params, obs, act) -> (q_mean, q_std).
    q: Callable[[hk.Params, jax.Array, jax.Array], Tuple[jax.Array, jax.Array]]
    # Entropy target; create_dsace_net sets it to -act_dim * entropy_ratio.
    target_entropy: float

    def get_action(self, key: jax.Array, policy_params: hk.Params, obs: jax.Array) -> jax.Array:
        """Sample a squashed action for data collection.

        Args:
            key: PRNG key consumed by the Gaussian sample.
            policy_params: Parameters passed to ``self.policy``.
            obs: Observation(s) passed to ``self.policy``.

        Returns:
            ``tanh(mean + std * z)`` with ``z`` standard normal, clipped to
            ``[-1, 1]``.
        """
        mean, std = self.policy(policy_params, obs)
        z = jax.random.normal(key, mean.shape)
        # No extra exploration noise is added on top of the squashed sample,
        # so ``key`` is used exactly once and never split after use.
        act = jnp.tanh(mean + std * z)
        # tanh is already bounded; the clip only guards against rounding.
        return act.clip(-1, 1)

    def q_evaluate(
        self, key: jax.Array, q_params: hk.Params, obs: jax.Array, act: jax.Array
    ) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """Evaluate the distributional critic and draw one sample from it.

        Args:
            key: PRNG key used for the standard-normal draw.
            q_params: Parameters passed to ``self.q``.
            obs: Observation(s).
            act: Action(s).

        Returns:
            ``(q_mean, q_std, q_value)`` where
            ``q_value = q_mean + q_std * z`` and ``z`` is a standard-normal
            sample clipped to ``[-3, 3]``.
        """
        q_mean, q_std = self.q(q_params, obs, act)
        z = jax.random.normal(key, q_mean.shape)
        # Clipping (not re-sampling) limits the sample to three std devs.
        z = jnp.clip(z, -3.0, 3.0)  # NOTE: Why not truncated normal?
        q_value = q_mean + q_std * z
        return q_mean, q_std, q_value

def create_dsace_net(
    key: jax.Array,
    obs_dim: int,
    act_dim: int,
    hidden_sizes: Sequence[int],
    activation: Activation = jax.nn.relu,
    entropy_ratio: float = 1.0,
) -> Tuple[DSACENet, DSACEParams]:
    """Build the DSACE network bundle together with its initial parameters.

    Args:
        key: PRNG key; split five ways (four critics, one policy).
        obs_dim: Size of the observation vector.
        act_dim: Size of the action vector.
        hidden_sizes: Hidden layer widths forwarded to both network types.
        activation: Activation forwarded to both network types.
        entropy_ratio: Multiplier applied to ``-act_dim`` to obtain the
            entropy target.

    Returns:
        A ``(net, params)`` pair. Every target parameter set starts out equal
        to its online counterpart and ``log_alpha`` starts at ``1.0``.
    """

    def q_forward(obs, act):
        return DistributionalQNet2(hidden_sizes, activation)(obs, act)

    def policy_forward(obs):
        return PolicyNet(act_dim, hidden_sizes, activation)(obs)

    q = hk.without_apply_rng(hk.transform(q_forward))
    policy = hk.without_apply_rng(hk.transform(policy_forward))

    @jax.jit
    def init(rng, obs, act):
        # Key order: q1, q2, qe1, qe2, then the policy.
        *critic_rngs, policy_rng = jax.random.split(rng, 5)
        q1, q2, qe1, qe2 = (q.init(critic_rng, obs, act) for critic_rng in critic_rngs)
        pi = policy.init(policy_rng, obs)
        return DSACEParams(
            q1=q1,
            q2=q2,
            target_q1=q1,
            target_q2=q2,
            qe1=qe1,
            qe2=qe2,
            target_qe1=qe1,
            target_qe2=qe2,
            policy=pi,
            target_policy=pi,
            log_alpha=jnp.array(1.0, dtype=jnp.float32),
        )

    # Single-row dummy inputs are enough to shape the parameters.
    dummy_obs = jnp.zeros((1, obs_dim))
    dummy_act = jnp.zeros((1, act_dim))
    params = init(key, dummy_obs, dummy_act)

    net = DSACENet(
        policy=policy.apply,
        q=q.apply,
        target_entropy=-act_dim*entropy_ratio,
    )
    return net, params
