"""Implementations of algorithms for continuous control."""

from typing import Optional, Sequence, Tuple
import functools
import numpy as np
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

from src.common import InfoDict, Batch
from src.models import DoubleCritic, NormalTanhPolicy, sample_actions
from src.agents.common import target_update
from src.agents.sac.actor import update as update_actor
from src.agents.sac.critic import update as update_critic
from src.agents.sac.temperature import Temperature, update as update_temp


@functools.partial(
    jax.vmap, in_axes=(0, None, 0, 0, 0, 0, 0, None, None, None)
)
def _update(
    rng: jax.random.PRNGKey,
    step: int,
    actor: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp: TrainState,
    batch: Batch,
    discount: float,
    tau: float,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    InfoDict,
]:
    """Run a single SAC update for every parallel seed.

    Vmapped over the leading (seed) axis of ``rng``, the four train states
    and ``batch``; ``step``, ``discount``, ``tau`` and ``args`` are shared
    by all seeds (``in_axes=None``).

    The critic and the target critic are updated on every call. The actor
    and the temperature are updated only when
    ``step % args.policy_update_freq == 0``; on other steps both are
    returned unchanged.

    Args:
        rng: Per-seed PRNG key.
        step: Global update counter, used for the delayed policy update.
        actor: Per-seed actor train state.
        critic: Per-seed critic train state.
        target_critic: Per-seed target critic train state.
        temp: Per-seed temperature train state.
        batch: Per-seed batch of transitions.
        discount: Discount factor forwarded to the critic update.
        tau: Coefficient forwarded to ``target_update``.
        args: ``ConfigArgs`` providing ``policy_update_freq`` and
            ``target_entropy``.

    Returns:
        ``(rng, actor, critic, target_critic, temp, info)``. ``info`` merges
        the critic, actor and temperature diagnostics. On skipped policy
        steps, ``actor_loss``/``actor_entropy`` are 0.0 placeholders and the
        temperature diagnostics describe a candidate update that was
        discarded.
    """
    rng, critic_key, actor_key = jax.random.split(rng, 3)

    new_critic, critic_info = update_critic(
        critic_key,
        actor,
        critic,
        target_critic,
        temp,
        batch,
        discount,
        soft_critic=True,
    )

    # ``step`` is not vmapped (in_axes=None), so this is one scalar
    # predicate shared by every seed.
    update_policy = step % args.policy_update_freq == 0

    def _apply_actor(_):
        # The actor is trained against the freshly updated critic.
        new_actor, info = update_actor(
            actor_key,
            actor,
            new_critic,
            temp,
            batch,
        )
        return new_actor, info

    def _skip_actor(_):
        # lax.cond needs both branches to return the same pytree structure,
        # so these placeholders must mirror update_actor's info dict.
        # NOTE(review): assumes update_actor reports exactly these two scalar
        # keys — keep in sync.
        dummy_info = {
            "actor_loss": jnp.array(0.0),
            "actor_entropy": jnp.array(0.0),
        }
        return actor, dummy_info

    new_actor, actor_info = jax.lax.cond(
        update_policy,
        _apply_actor,
        _skip_actor,
        operand=None,
    )
    new_target_critic = target_update(new_critic, target_critic, tau)

    # The temperature update consumes the actor's entropy estimate. On
    # skipped steps that value is the 0.0 placeholder above, so the update
    # would be driven by a fake entropy. Keep the new temperature state only
    # when the actor was actually updated; otherwise retain the old one
    # (params, optimiser state and step counter alike).
    candidate_temp, alpha_info = update_temp(
        temp, actor_info["actor_entropy"], args.target_entropy
    )
    new_temp = jax.tree.map(
        lambda updated, old: jnp.where(update_policy, updated, old),
        candidate_temp,
        temp,
    )
    info = {**critic_info, **actor_info, **alpha_info}

    return (
        rng,
        new_actor,
        new_critic,
        new_target_critic,
        new_temp,
        info,
    )


@functools.partial(
    jax.jit, static_argnames=("discount", "tau", "num_updates", "args")
)
def _do_multiple_updates(
    rng: jax.random.PRNGKey,
    step: int,
    actor: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp: TrainState,
    batches: Batch,
    discount: float,
    tau: float,
    num_updates: int,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    jax.Array,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    InfoDict,
]:
    """Apply ``num_updates`` consecutive ``_update`` calls inside one jit.

    ``discount``, ``tau``, ``num_updates`` and ``args`` are static jit
    arguments; changing any of them triggers a recompilation.

    Args:
        rng: Per-seed PRNG keys.
        step: Global update counter at the start of this call.
        actor: Per-seed actor train state.
        critic: Per-seed critic train state.
        target_critic: Per-seed target critic train state.
        temp: Per-seed temperature train state.
        batches: Pre-sampled batches. Update ``i`` uses index ``i`` along
            axis 1 of every leaf; axis 0 is consumed by ``_update``'s vmap.
            Presumably shaped (num_seeds, num_updates, batch, ...) — confirm
            against the caller.
        discount: Discount factor forwarded to ``_update``.
        tau: Target-update coefficient forwarded to ``_update``.
        num_updates: Number of updates to run. At least one update is
            always performed, even if this is < 1.
        args: ``ConfigArgs`` forwarded to ``_update``.

    Returns:
        ``(rng, step, actor, critic, target_critic, temp, info)``. ``step``
        is advanced by the number of updates performed, and ``info`` comes
        from the last update only.
    """

    def one_step(i, state):
        # The incoming info is discarded; each step reports its own.
        rng, step, actor, critic, target_critic, temp, _ = state
        batch_i = jax.tree.map(lambda x: jnp.take(x, i, axis=1), batches)
        (
            new_rng,
            new_actor,
            new_critic,
            new_target_critic,
            new_temp,
            info,
        ) = _update(
            rng,
            step,
            actor,
            critic,
            target_critic,
            temp,
            batch_i,
            discount,
            tau,
            args,
        )
        return (
            new_rng,
            step + 1,
            new_actor,
            new_critic,
            new_target_critic,
            new_temp,
            info,
        )

    # Run the first update outside the loop. fori_loop requires a carry with
    # a fixed pytree structure, and the placeholder ``{}`` does not match the
    # info dict that _update produces.
    carry = one_step(0, (rng, step, actor, critic, target_critic, temp, {}))

    return jax.lax.fori_loop(1, num_updates, one_step, carry)


@dataclass(frozen=True)
class ConfigArgs:
    """Static SAC settings threaded through the jitted update functions.

    Instances are passed to ``jax.jit`` as a static argument, which requires
    them to be hashable; ``frozen=True`` gives the dataclass a field-based
    ``__hash__`` and forbids mutation after construction.
    """

    # The actor is updated only on steps where
    # ``step % policy_update_freq == 0``.
    policy_update_freq: int
    # Entropy target handed to the temperature update.
    target_entropy: float


class SACLearner(object):
    """Soft Actor-Critic learner that trains several seeds in parallel.

    All model states (actor, critic, target critic, temperature) and the
    PRNG key carry a leading axis of size ``num_parallel_seeds``. They are
    created by a vmapped initialiser and advanced by
    ``_do_multiple_updates``, which vmaps ``_update`` over that axis.
    """

    def __init__(
        self,
        seed: int,
        # env settings
        state_dim: int,
        action_dim: int,
        # common RL settings
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        hidden_dims: Sequence[int] = (256, 256),
        discount: float = 0.99,
        tau: float = 0.005,
        num_parallel_seeds: int = 1,
        max_steps: int = int(3e6),
        # sac
        policy_update_freq: int = 1,
        target_entropy: Optional[float] = None,
        init_temperature: float = 1.0,
        temp_lr: float = 3e-4,
    ):
        """Build optimisers, networks and train states for every seed.

        Args:
            seed: First seed. Seeds ``seed`` through
                ``seed + num_parallel_seeds - 1`` are used, one per run.
            state_dim: Observation size, used for the dummy init inputs.
            action_dim: Action size.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for the critic.
            hidden_dims: Hidden layer sizes shared by actor and critic.
            discount: Discount factor forwarded to the critic update.
            tau: Coefficient forwarded to the target-critic update.
            num_parallel_seeds: Number of independent runs trained together.
            max_steps: Stored as ``self.max_steps``; not used in this class.
            policy_update_freq: Update the actor every this many steps.
            target_entropy: Entropy target for the temperature update.
                Defaults to ``-action_dim / 2``.
            init_temperature: Initial value passed to ``Temperature``.
            temp_lr: Adam learning rate for the temperature.
        """
        self.seeds = jnp.arange(seed, seed + num_parallel_seeds)
        self.tau = tau
        self.discount = discount
        self.max_steps = max_steps
        if target_entropy is None:
            target_entropy = -action_dim / 2

        # Frozen (hashable) so it can be a static argument of the jitted
        # update.
        self.args = ConfigArgs(
            policy_update_freq,
            target_entropy,
        )

        dummy_observations = jnp.ones([1, state_dim], dtype=jnp.float32)
        dummy_actions = jnp.ones([1, action_dim], dtype=jnp.float32)

        def _init_models(seed):
            # Build every train state for one seed; vmapped below.
            rng = jax.random.PRNGKey(seed)
            rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4)

            # optimisers
            critic_optimiser = optax.adam(learning_rate=critic_lr)
            actor_optimiser = optax.adam(learning_rate=actor_lr)
            temp_optimiser = optax.adam(learning_rate=temp_lr)

            # actors
            actor_def = NormalTanhPolicy(hidden_dims, action_dim)
            actor = TrainState.create(
                apply_fn=actor_def.apply,
                params=actor_def.init(actor_key, dummy_observations),
                tx=actor_optimiser,
            )

            # critics
            critic_def = DoubleCritic(hidden_dims)
            critic = TrainState.create(
                apply_fn=critic_def.apply,
                params=critic_def.init(
                    critic_key, dummy_observations, dummy_actions
                ),
                tx=critic_optimiser,
            )
            # The target starts as an exact copy of the online critic.
            # Re-running init with the same key and inputs produced identical
            # params, so reuse them instead of initialising twice.
            target_critic = TrainState.create(
                apply_fn=critic_def.apply,
                params=critic.params,
                tx=critic_optimiser,
            )

            # temperature
            temp_def = Temperature(init_temperature)
            temp = TrainState.create(
                apply_fn=temp_def.apply,
                params=temp_def.init(temp_key),
                tx=temp_optimiser,
            )
            return actor, critic, target_critic, temp, rng

        # One independent initialisation per entry of ``self.seeds``.
        self.init_models = jax.jit(jax.vmap(_init_models))

        (
            self.actor,
            self.critic,
            self.target_critic,
            self.temp,
            self.rng,
        ) = self.init_models(self.seeds)
        # NOTE(review): not used in this class; presumably read by external
        # reset/plasticity code — confirm.
        self.trainable_models = ["actor", "critic"]
        # Global update counter; becomes a jax.Array after the first update().
        self.step = 0

    def sample_actions(
        self, observations: np.ndarray, temperature: float = 1.0
    ) -> np.ndarray:
        """Sample actions from the current actor and advance the PRNG key.

        Args:
            observations: Observations passed straight to the
                ``sample_actions`` helper. Presumably batched per seed —
                confirm against that helper.
            temperature: Sampling temperature forwarded to the helper.

        Returns:
            A NumPy array of actions clipped to ``[-1, 1]``.
        """
        rng, actions = sample_actions(
            self.rng,
            self.actor.apply_fn,
            self.actor.params,
            observations,
            temperature,
            distribution="log_prob",
        )
        self.rng = rng
        actions = np.asarray(actions)
        return np.clip(actions, -1, 1)

    def update(self, batch: Batch, num_updates: int = 1) -> InfoDict:
        """Run ``num_updates`` SAC updates and store the new states.

        Args:
            batch: Pre-sampled batches. Update ``i`` uses index ``i`` along
                axis 1 of every leaf, and axis 0 is the seed axis.
            num_updates: Number of updates; at least one always runs.

        Returns:
            The diagnostics dict from the last update.
        """
        (
            self.rng,
            self.step,
            self.actor,
            self.critic,
            self.target_critic,
            self.temp,
            info,
        ) = _do_multiple_updates(
            self.rng,
            self.step,
            self.actor,
            self.critic,
            self.target_critic,
            self.temp,
            batch,
            self.discount,
            self.tau,
            num_updates,
            self.args,
        )
        return info

    def reset(self):
        """Re-initialise all models and PRNG keys from the original seeds.

        ``self.step`` is not reset, so the delayed-policy-update schedule
        continues from the current counter.
        """
        (
            self.actor,
            self.critic,
            self.target_critic,
            self.temp,
            self.rng,
        ) = self.init_models(self.seeds)

    def apply_once(self, batch: Batch):
        """Not supported by this learner."""
        raise NotImplementedError
