import time
import flax.struct as fstruct
from typing import Any, Dict, Tuple, Literal
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from functools import partial
from flax.linen.initializers import constant, orthogonal
import jax.random as jrnd
import distrax
import flax.core as fcore
import chex
from dataclasses import dataclass
from utils.type import EnvStep
from omegaconf import DictConfig

# Shared layer factory for every network in this file: biases start at zero
# and kernels use an orthogonal initialiser scaled by sqrt(2).
Dense = partial(
    nn.Dense,
    bias_init=constant(0.0),
    kernel_init=orthogonal(np.sqrt(2)),
)


class QNetwork(nn.Module):
    """State-action value network: two 256-unit ReLU layers and a 1-unit head."""

    @nn.compact
    def __call__(self, s: jnp.ndarray, a: jnp.ndarray):
        """Score each ``(s, a)`` pair.

        ``s`` and ``a`` are concatenated along their last axis before being
        fed through the MLP; the returned array has a trailing axis of size 1.
        """
        hidden = jnp.concatenate([s, a], axis=-1)
        # Layers are created in the same order on every call, so Flax assigns
        # them the same automatic names.
        for width in (256, 256):
            hidden = nn.relu(Dense(width)(hidden))
        return Dense(1)(hidden)


# Clip bounds applied by Actor to its second output head. SAC._sample_policy
# exponentiates that output to obtain the Gaussian's standard deviation.
LOG_STD_MIN = -5
LOG_STD_MAX = 2


class Actor(nn.Module):
    """Policy network producing the parameters of a diagonal Gaussian.

    A shared trunk of two 256-unit ReLU layers feeds two linear heads, each
    ``action_dim`` wide: a mean and a clipped log standard deviation.
    """

    action_dim: int

    @nn.compact
    def __call__(self, s):
        """Return ``(mean, log_std)`` for the observations ``s``."""
        features = s
        for width in (256, 256):
            features = nn.relu(Dense(width)(features))

        # Head creation order (mean first, then log-std) is kept fixed so the
        # automatically assigned parameter names stay stable.
        mean = Dense(self.action_dim)(features)
        log_std = Dense(self.action_dim)(features)

        # No squashing happens here; SAC._sample_policy applies tanh to the
        # sampled action and to the mean.
        return mean, jnp.clip(log_std, LOG_STD_MIN, LOG_STD_MAX)


@fstruct.dataclass
class AlgoState:
    """Complete learner state that ``SAC.make`` builds and ``SAC.update`` advances."""

    # Policy network (parameters + optimizer state).
    actor: TrainState
    # The two online critics trained by SAC._update_critic.
    q1: TrainState
    q2: TrainState
    # Target critics: start as copies of q1/q2 and are moved toward them with
    # optax.incremental_update in SAC._update_target.
    target_q1: TrainState
    target_q2: TrainState

    # Entropy temperature; its params are a mapping with the single key
    # "log_alpha" holding a scalar.
    log_alpha: TrainState
    # Counter that SAC.update increases by one on every call.
    update_times: jnp.ndarray


@fstruct.dataclass
class Transitions:
    """Container for a batch of transition fields.

    NOTE(review): nothing in the visible part of this file uses this class;
    the SAC update methods read ``obs``/``acts``/``rwds``/``terms``/``next_obs``
    from an ``EnvStep`` instead. Confirm whether it is still needed.
    """

    states: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    # presumably termination flags, by analogy with EnvStep.terms -- TODO confirm
    terms: jnp.ndarray
    next_states: jnp.ndarray


@dataclass
class Cfg:
    """Hyper-parameters read by ``SAC.__init__``."""

    # Width of the dummy observation used to initialise the networks.
    state_dim: int
    # Width of the action vector; SAC also sets target_entropy to -act_dim.
    act_dim: int
    # Learning rate passed to optax.adamw for the actor, both critics and log_alpha.
    lr: float
    # Discount factor multiplying the next-state value in the critic target.
    gamma: float
    # Step size given to optax.incremental_update for the target critics.
    tau: float

class SAC:
    def __init__(
        self,
        cfg: Cfg,
    ) -> None:
        self.state_dim = cfg.state_dim
        self.act_dim = cfg.act_dim
        self.lr = cfg.lr
        self.gamma = cfg.gamma
        self.tau = cfg.tau

        self.target_entropy = -self.act_dim

    def gen_tree(self):
        return jnp.array(0)

    def make(self, key: jnp.ndarray):
        _actor = Actor(action_dim=self.act_dim)
        _q1 = QNetwork()
        _q2 = QNetwork()
        _target_q1 = QNetwork()
        _target_q2 = QNetwork()
        act_key, q1_key, q2_key = jrnd.split(key, 3)

        _q1_params = _q1.init(
            q1_key,
            jnp.zeros((1, self.state_dim)),
            jnp.zeros((1, self.act_dim)),
        )
        _q2_params = _q2.init(
            q2_key,
            jnp.zeros((1, self.state_dim)),
            jnp.zeros((1, self.act_dim)),
        )

        _actor_state = TrainState.create(
            apply_fn=_actor.apply,
            tx=optax.adamw(self.lr),
            params=_actor.init(act_key, jnp.zeros((1, self.state_dim))),
        )
        _q1_state = TrainState.create(
            apply_fn=_q1.apply, tx=optax.adamw(self.lr), params=_q1_params
        )
        _q2_state = TrainState.create(
            apply_fn=_q2.apply, tx=optax.adamw(self.lr), params=_q2_params
        )
        _targetq1_state = TrainState.create(
            apply_fn=_target_q1.apply, tx=optax.adamw(1e-7), params=_q1_params
        )
        _targetq2_state = TrainState.create(
            apply_fn=_target_q2.apply, tx=optax.adamw(1e-7), params=_q2_params
        )
        _alpha_state = TrainState.create(
            apply_fn=None,
            tx=optax.adamw(self.lr),
            params=fcore.freeze(dict(log_alpha=jnp.zeros(()))),
        )

        return AlgoState(
            actor=_actor_state,
            q1=_q1_state,
            target_q1=_targetq1_state,
            q2=_q2_state,
            target_q2=_targetq2_state,
            log_alpha=_alpha_state,
            update_times=jnp.array(0),
        )

    def make_action(self, key, algo_state: AlgoState, obs: jnp.ndarray):
        acts, _, _ = self._sample_policy(
            algo_state.actor, algo_state.actor.params, obs, key
        )

        return acts

    def _update_critic(
        self,
        algo_state: AlgoState,
        transitions: EnvStep,
        key: jnp.ndarray,
    ):
        states, acts, rwds, dones, next_states = (
            transitions.obs,
            transitions.acts,
            transitions.rwds,
            transitions.terms,
            transitions.next_obs,
        )
        q1, q1_target, q2, q2_target, log_alpha = (
            algo_state.q1,
            algo_state.target_q1,
            algo_state.q2,
            algo_state.target_q2,
            algo_state.log_alpha,
        )

        key, sample_key = jax.random.split(key)
        next_state_actions, next_state_logprobs, _ = self._sample_policy(
            algo_state.actor, algo_state.actor.params, next_states, sample_key
        )
        # next_state_actions = jnp.clip(actor.apply_fn(actor.params, next_states), -1, 1)
        # chex.assert_shape(next_state_actions, (states.shape[0], self.act_dim))

        qf1_next_target = q1_target.apply_fn(
            q1_target.params, next_states, next_state_actions
        ).reshape(-1)
        qf2_next_target = q2_target.apply_fn(
            q2_target.params, next_states, next_state_actions
        ).reshape(-1)
        min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target) - jnp.exp(
            log_alpha.params["log_alpha"]
        ) * jnp.squeeze(next_state_logprobs, -1)
        next_q_value = (rwds + (1 - dones) * self.gamma * (min_qf_next_target)).reshape(
            -1
        )

        @partial(jax.value_and_grad, has_aux=True)
        def mse_loss(params, qf: TrainState):
            qf_a_values = qf.apply_fn(params, states, acts).squeeze()
            return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()

        (qf1_loss_value, qf1_a_values), grads1 = mse_loss(q1.params, q1)
        (qf2_loss_value, qf2_a_values), grads2 = mse_loss(q2.params, q2)

        q1 = q1.apply_gradients(grads=grads1)
        q2 = q2.apply_gradients(grads=grads2)

        return algo_state.replace(q1=q1, q2=q2)

    def _update_actor(
        self,
        algo_state: AlgoState,
        states: jnp.ndarray,
        key,
    ):

        @jax.value_and_grad
        def actor_loss(params, algo_state: AlgoState, states, key):
            actor, q1, q2, log_alpha = (
                algo_state.actor,
                algo_state.q1,
                algo_state.q2,
                algo_state.log_alpha,
            )
            alpha = jnp.exp(log_alpha.params["log_alpha"])
            pi, log_pi, _ = self._sample_policy(algo_state.actor, params, states, key)
            q1v = q1.apply_fn(q1.params, states, pi)
            q2v = q2.apply_fn(q2.params, states, pi)
            min_q = jnp.minimum(q1v, q2v)
            _loss = ((alpha * log_pi) - min_q).mean()
            return _loss

        _, grads = actor_loss(algo_state.actor.params, algo_state, states, key)
        actor_state = algo_state.actor.apply_gradients(grads=grads)
        return algo_state.replace(actor=actor_state)

    def _update_alpha(self, algo_state: AlgoState, states, key):
        @jax.value_and_grad
        def alpha_loss(log_alpha_params, algo_state: AlgoState, key):
            _, log_pi, _ = self._sample_policy(
                algo_state.actor, algo_state.actor.params, states, key
            )
            _loss = -(
                log_alpha_params["log_alpha"] * (log_pi + self.target_entropy)
            ).mean()
            return _loss

        _, grads = alpha_loss(algo_state.log_alpha.params, algo_state, key)
        log_alpha_state = algo_state.log_alpha.apply_gradients(grads=grads)

        return algo_state.replace(log_alpha=log_alpha_state)

    def _update_target(self, algo_state: AlgoState):
        q1, q2, q1_target, q2_target = (
            algo_state.q1,
            algo_state.q2,
            algo_state.target_q1,
            algo_state.target_q2,
        )

        algo_state = algo_state.replace(
            target_q1=algo_state.target_q1.replace(
                params=optax.incremental_update(q1.params, q1_target.params, self.tau)
            )
        )
        algo_state = algo_state.replace(
            target_q2=algo_state.target_q2.replace(
                params=optax.incremental_update(q2.params, q2_target.params, self.tau)
            )
        )
        return algo_state

    def _sample_policy(self, actor, params, obs: jnp.ndarray, key):
        act_mean, act_log_std = actor.apply_fn(params, obs)
        act_std = jnp.exp(act_log_std)

        dist = distrax.Normal(act_mean, act_std)

        x = dist.sample(seed=key)
        acts = jnp.tanh(x)
        log_prob = dist.log_prob(x)

        log_prob -= jnp.log((1 - jnp.square(acts)) + 1e-6)
        log_prob = log_prob.sum(-1, keepdims=True)

        mean = jnp.tanh(act_mean)

        return acts, log_prob, mean

    def update(self, key, algo_state: AlgoState, transition: EnvStep):

        key, critic_key = jrnd.split(key)
        algo_state = self._update_critic(algo_state, transition, critic_key)

        key, actor_key = jrnd.split(key)
        algo_state = self._update_actor(algo_state, transition.obs, actor_key)
        key, alpha_key = jrnd.split(key)
        algo_state = self._update_alpha(algo_state, transition.obs, alpha_key)
        algo_state = self._update_target(algo_state)

        return algo_state.replace(update_times=algo_state.update_times + 1)


def refill_cfg(cfg: DictConfig) -> Tuple[DictConfig, Dict[str, Any]]:
    """Fill in missing config entries and derive the run-length values.

    Missing values (as reported by ``OmegaConf.is_missing``) are defaulted in
    place: ``episodes_to_train`` becomes 20 for the ``"cartpole"`` environment
    and 22 otherwise, and ``algo.training_start_timesteps`` becomes 500.

    Args:
        cfg: run configuration; mutated in place.

    Returns:
        ``(cfg, dcfg)`` where ``dcfg`` holds ``"total_training_times"``
        (``episodes_to_train * env.episode_length``) and ``"frames_per_split"``
        (that total divided by ``split_nums``).

    Raises:
        AssertionError: if the total is not an exact multiple of
            ``cfg.split_nums``.
    """
    from omegaconf.omegaconf import OmegaConf

    if OmegaConf.is_missing(cfg, "episodes_to_train"):
        cfg.episodes_to_train = 20 if cfg.env.name == "cartpole" else 22

    total_training_times = cfg.episodes_to_train * cfg.env.episode_length

    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; the exception type is unchanged for callers.
    if total_training_times % cfg.split_nums != 0:
        raise AssertionError(
            f"total_training_times ({total_training_times}) must be divisible "
            f"by split_nums ({cfg.split_nums})"
        )

    if OmegaConf.is_missing(cfg.algo, "training_start_timesteps"):
        cfg.algo.training_start_timesteps = 500

    dcfg = {
        "total_training_times": total_training_times,
        # Floor division avoids the float round-trip of ``int(a / b)``; the
        # divisibility check above guarantees it is exact.
        "frames_per_split": int(total_training_times // cfg.split_nums),
    }
    return cfg, dcfg
