"""Implementations of algorithms for continuous control."""

from typing import Optional, Sequence, Tuple
import functools
import numpy as np
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

from src.common import InfoDict, Batch
from src.models import DoubleCritic, NormalTanhPolicy, sample_actions, RND
from src.agents.common import target_update
from src.agents.sac_derl.actor import update as update_actor
from src.agents.sac_derl.critic import update as update_critic
from src.agents.sac_derl.temperature import (
    Temperature,
    update as update_temp,
)
from src.agents.sac_derl.intrinsicrewards import update_rnd
from src.agents.sac_derl.rnd_states import RNDTrainState, RunningMeanStd


@functools.partial(
    jax.vmap, in_axes=(0, None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None)
)
def _update(
    rng: jax.random.PRNGKey,
    step: int,
    actor_explore: TrainState,
    actor: TrainState,
    critic_explore: TrainState,
    target_critic_explore: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp_explore: TrainState,
    temp: TrainState,
    rnd: RNDTrainState,
    batch: Batch,
    discount: float,
    tau: float,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    RNDTrainState,
    InfoDict,
]:
    """Perform one update of every network for the exploit/explore agents.

    The function is vmapped over the leading axis of every argument except
    ``step``, ``discount``, ``tau`` and ``args``, which are shared.

    Two agents are updated side by side:

    * the exploitation agent (``actor``, ``critic``, ``temp``), whose
      updates are called with ``rnd_coeff=0.0``;
    * the exploration agent (``*_explore``), whose updates are called with
      ``rnd_coeff=args.rnd_coeff``.

    Args:
        rng: PRNG key; split into critic, actor and RND keys.
        step: Update counter, used to gate the actor updates.
        actor_explore: Exploration policy state.
        actor: Exploitation policy state.
        critic_explore: Exploration critic state.
        target_critic_explore: Target network of the exploration critic.
        critic: Exploitation critic state.
        target_critic: Target network of the exploitation critic.
        temp_explore: Temperature state of the exploration agent.
        temp: Temperature state of the exploitation agent.
        rnd: RND state passed to the critic/actor updates and updated here.
        batch: Transitions used for every update in this step.
        discount: Discount factor forwarded to the critic updates.
        tau: Coefficient forwarded to ``target_update``.
        args: Object exposing ``policy_update_freq``, ``target_entropy`` and
            ``rnd_coeff``.

    Returns:
        ``(rng, actor_explore, actor, critic_explore,
        target_critic_explore, critic, target_critic, temp_explore, temp,
        rnd, info)`` with all states replaced by their updated versions.
        Metrics of the exploration agent carry an ``_explore`` suffix in
        ``info``.
    """
    rng, critic_key, actor_key, rnd_key = jax.random.split(rng, 4)

    # NOTE(review): `critic_key` (and `actor_key` below) is shared by the
    # exploit and the explore update, so both see the same random stream.
    # Confirm this is intended before splitting further keys, since that
    # would change the random numbers of existing runs.
    new_critic, critic_info = update_critic(
        critic_key,
        actor,
        critic,
        target_critic,
        temp,
        rnd,
        batch,
        discount,
        soft_critic=True,
        rnd_coeff=0.0,
    )
    new_critic_explore, critic_info_explore = update_critic(
        critic_key,
        actor_explore,
        critic_explore,
        target_critic_explore,
        temp_explore,
        rnd,
        batch,
        discount,
        soft_critic=True,
        rnd_coeff=args.rnd_coeff,
    )
    critic_info = {
        **critic_info,
        **{f"{k}_explore": v for k, v in critic_info_explore.items()},
    }

    def _apply_actor(_):
        exploit_actor, exploit_info = update_actor(
            actor_key,
            actor,
            new_critic,
            temp,
            rnd,
            batch,
            rnd_coeff=0.0,
        )
        # The exploration policy must be improved against its own critic
        # (the one updated with rnd_coeff=args.rnd_coeff). Passing the
        # exploitation critic here left the exploration critic without any
        # influence on either policy.
        explore_actor, explore_info = update_actor(
            actor_key,
            actor_explore,
            new_critic_explore,
            temp_explore,
            rnd,
            batch,
            rnd_coeff=args.rnd_coeff,
        )
        explore_info = {f"{k}_explore": v for k, v in explore_info.items()}
        return explore_actor, exploit_actor, {**exploit_info, **explore_info}

    def _skip_actor(_):
        # jax.lax.cond needs both branches to return the same pytree
        # structure, so the skipped branch reports placeholder metrics.
        placeholder_info = {
            "actor_loss": jnp.array(0.0),
            "actor_entropy": jnp.array(0.0),
            "actor_loss_explore": jnp.array(0.0),
            "actor_entropy_explore": jnp.array(0.0),
        }
        return actor_explore, actor, placeholder_info

    # Delayed policy updates: actors only move every `policy_update_freq`
    # steps.
    new_actor_explore, new_actor, actor_info = jax.lax.cond(
        step % args.policy_update_freq == 0,
        _apply_actor,
        _skip_actor,
        operand=None,
    )

    # The RND metrics are not reported in `info`.
    new_rnd, _ = update_rnd(rnd_key, rnd, batch)

    new_target_critic_explore = target_update(
        new_critic_explore, target_critic_explore, tau
    )
    new_target_critic = target_update(new_critic, target_critic, tau)

    # NOTE(review): on steps where the actor update is skipped, the
    # temperatures are still updated, using the placeholder entropy of 0.0.
    # This only matters for policy_update_freq > 1 -- confirm it is intended.
    new_temp, alpha_info = update_temp(
        temp, actor_info["actor_entropy"], args.target_entropy
    )
    new_temp_explore, alpha_info_explore = update_temp(
        temp_explore, actor_info["actor_entropy_explore"], args.target_entropy
    )
    alpha_info_explore = {
        f"{k}_explore": v for k, v in alpha_info_explore.items()
    }
    info = {**critic_info, **actor_info, **alpha_info, **alpha_info_explore}

    return (
        rng,
        new_actor_explore,
        new_actor,
        new_critic_explore,
        new_target_critic_explore,
        new_critic,
        new_target_critic,
        new_temp_explore,
        new_temp,
        new_rnd,
        info,
    )


@functools.partial(
    jax.jit, static_argnames=("discount", "tau", "num_updates", "args")
)
def _do_multiple_updates(
    rng: jax.random.PRNGKey,
    step: int,
    actor_explore: TrainState,
    actor: TrainState,
    critic_explore: TrainState,
    target_critic_explore: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp_explore: TrainState,
    temp: TrainState,
    rnd: RNDTrainState,
    batches: Batch,
    discount: float,
    tau: float,
    num_updates: int,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    jax.Array,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    RNDTrainState,
    InfoDict,
]:
    """Run ``num_updates`` consecutive calls of ``_update`` under one jit.

    Update ``i`` uses the slice ``jnp.take(x, i, axis=1)`` of every leaf of
    ``batches``; axis 0 of that slice is the axis ``_update`` is vmapped
    over.

    ``discount``, ``tau``, ``num_updates`` and ``args`` are static jit
    arguments, so they must be hashable and a new value triggers a
    recompilation.

    At least one update is always performed: the first one runs outside the
    loop so that the structure of the ``info`` dict is known before it
    becomes part of the ``fori_loop`` carry.

    Returns:
        ``(rng, step, actor_explore, actor, critic_explore,
        target_critic_explore, critic, target_critic, temp_explore, temp,
        rnd, info)`` where ``step`` has been incremented once per update
        and ``info`` holds the metrics of the last update only.
    """

    def one_step(i, carry):
        # Carry layout: (rng, step, <nine train states>, info). The info of
        # the previous update is dropped and replaced by the new one.
        rng, step, *states, _ = carry
        batch_i = jax.tree.map(lambda x: jnp.take(x, i, axis=1), batches)
        new_rng, *new_states, info = _update(
            rng, step, *states, batch_i, discount, tau, args
        )
        return (new_rng, step + 1, *new_states, info)

    carry = (
        rng,
        step,
        actor_explore,
        actor,
        critic_explore,
        target_critic_explore,
        critic,
        target_critic,
        temp_explore,
        temp,
        rnd,
        {},
    )
    # The first update is called directly, so the empty placeholder dict
    # never enters the loop carry.
    carry = one_step(0, carry)

    return jax.lax.fori_loop(1, num_updates, one_step, carry)


@dataclass(frozen=True)
class ConfigArgs:
    """Immutable, hashable bundle of hyper-parameters for the update step.

    Being frozen makes instances hashable, which lets them be passed as a
    static argument to the jitted ``_do_multiple_updates``.
    """

    # The actors are updated only when ``step % policy_update_freq == 0``.
    policy_update_freq: int
    # Entropy target handed to both temperature updates.
    target_entropy: float
    # ``rnd_coeff`` value passed to the exploration agent's updates.
    rnd_coeff: float


class SACDERLLearner(object):
    """SAC learner that trains an exploration and an exploitation agent.

    Each agent has its own actor, double critic, target critic and
    temperature; a single RND state is shared. All states are created with
    a leading axis of size ``num_parallel_seeds`` (one entry per seed) by
    vmapping the initialisation over ``self.seeds``.
    """

    def __init__(
        self,
        seed: int,
        # env settings
        state_dim: int,
        action_dim: int,
        # common RL settings
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        hidden_dims: Sequence[int] = (256, 256),
        discount: float = 0.99,
        tau: float = 0.005,
        num_parallel_seeds: int = 1,
        max_steps: int = 3e6,
        # sac
        policy_update_freq: int = 1,
        target_entropy: Optional[float] = None,
        init_temperature: float = 1.0,
        temp_lr: float = 3e-4,
        # intrinsic rewards
        rnd_lr: float = 3e-4,
        rnd_coeff: float = 1.0,
    ):
        """Build all networks and optimisers.

        Args:
            seed: First seed; seeds ``seed .. seed + num_parallel_seeds - 1``
                are trained in parallel.
            state_dim: Size of an observation vector.
            action_dim: Size of an action vector.
            actor_lr: Adam learning rate of both actors.
            critic_lr: Adam learning rate of all critics.
            hidden_dims: Hidden layer sizes handed to the actor, critic and
                RND modules.
            discount: Discount factor.
            tau: Coefficient forwarded to ``target_update``.
            num_parallel_seeds: Number of independent learners to vmap over.
            max_steps: Stored on the instance as-is. Note that the default
                ``3e6`` is a float even though the annotation says ``int``.
            policy_update_freq: Actors are updated every this many steps.
            target_entropy: Entropy target; defaults to ``-action_dim / 2``.
            init_temperature: Initial value given to ``Temperature``.
            temp_lr: Adam learning rate of both temperatures.
            rnd_lr: Adam learning rate of the RND network.
            rnd_coeff: ``rnd_coeff`` used by the exploration agent's updates.
        """
        self.seeds = jnp.arange(seed, seed + num_parallel_seeds)
        self.tau = tau
        self.discount = discount
        self.max_steps = max_steps
        if target_entropy is None:
            target_entropy = -action_dim / 2

        self.args = ConfigArgs(
            policy_update_freq,
            target_entropy,
            rnd_coeff,
        )

        dummy_observations = jnp.ones([1, state_dim], dtype=jnp.float32)
        dummy_actions = jnp.ones([1, action_dim], dtype=jnp.float32)

        def _init_models(seed):
            """Create every train state for a single seed."""
            rng = jax.random.PRNGKey(seed)
            rng, actor_key, critic_key, temp_key, rnd_key = jax.random.split(
                rng, 5
            )

            # optimisers
            critic_optimiser = optax.adam(learning_rate=critic_lr)
            actor_optimiser = optax.adam(learning_rate=actor_lr)
            temp_optimiser = optax.adam(learning_rate=temp_lr)
            rnd_optimiser = optax.adam(learning_rate=rnd_lr)

            # actors: both start from the same parameters (same init key)
            actor_def = NormalTanhPolicy(hidden_dims, action_dim)
            actor_params = actor_def.init(actor_key, dummy_observations)
            actor_explore = TrainState.create(
                apply_fn=actor_def.apply,
                params=actor_params,
                tx=actor_optimiser,
            )
            actor = TrainState.create(
                apply_fn=actor_def.apply,
                params=actor_params,
                tx=actor_optimiser,
            )

            # critics: all four (online and target, explore and exploit)
            # start from the same parameters (same init key)
            critic_def = DoubleCritic(hidden_dims)
            critic_params = critic_def.init(
                critic_key, dummy_observations, dummy_actions
            )
            critic_explore = TrainState.create(
                apply_fn=critic_def.apply,
                params=critic_params,
                tx=critic_optimiser,
            )
            target_critic_explore = TrainState.create(
                apply_fn=critic_def.apply,
                params=critic_params,
                tx=critic_optimiser,
            )
            critic = TrainState.create(
                apply_fn=critic_def.apply,
                params=critic_params,
                tx=critic_optimiser,
            )
            target_critic = TrainState.create(
                apply_fn=critic_def.apply,
                params=critic_params,
                tx=critic_optimiser,
            )

            # temperature
            temp_def = Temperature(init_temperature)
            temp_params = temp_def.init(temp_key)
            temp_explore = TrainState.create(
                apply_fn=temp_def.apply,
                params=temp_params,
                tx=temp_optimiser,
            )
            temp = TrainState.create(
                apply_fn=temp_def.apply,
                params=temp_params,
                tx=temp_optimiser,
            )

            # RND
            rnd_def = RND(hidden_dims)
            rnd = RNDTrainState.create(
                apply_fn=rnd_def.apply,
                params=rnd_def.init(
                    rnd_key, dummy_observations, dummy_actions
                ),
                tx=rnd_optimiser,
                rms=RunningMeanStd.create(),
            )

            return (
                actor_explore,
                actor,
                critic_explore,
                target_critic_explore,
                critic,
                target_critic,
                temp_explore,
                temp,
                rnd,
                rng,
            )

        # One independent set of models per entry of `self.seeds`.
        self.init_models = jax.jit(jax.vmap(_init_models))

        (
            self.actor_explore,
            self.actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.temp_explore,
            self.temp,
            self.rnd,
            self.rng,
        ) = self.init_models(self.seeds)
        self.trainable_models = ["actor", "critic"]
        self.step = 0

    def sample_actions(
        self, observations: np.ndarray, temperature: float = 1.0
    ) -> np.ndarray:
        """Sample actions and advance the internal PRNG state.

        ``temperature`` doubles as the explore/exploit switch: a value
        greater than zero samples from the exploration actor, anything else
        from the exploitation actor.

        Returns:
            The sampled actions as a NumPy array clipped to ``[-1, 1]``.
        """
        policy = self.actor_explore if temperature > 0 else self.actor
        self.rng, actions = sample_actions(
            self.rng,
            policy.apply_fn,
            policy.params,
            observations,
            temperature,
            distribution="log_prob",
        )
        return np.clip(np.asarray(actions), -1, 1)

    def update(self, batch: Batch, num_updates: int = 1) -> InfoDict:
        """Run ``num_updates`` gradient updates and store the new states.

        Update ``i`` consumes index ``i`` along axis 1 of every leaf of
        ``batch``. At least one update is always performed.

        Returns:
            The metrics of the last update.
        """
        (
            self.rng,
            self.step,
            self.actor_explore,
            self.actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.temp_explore,
            self.temp,
            self.rnd,
            info,
        ) = _do_multiple_updates(
            self.rng,
            self.step,
            self.actor_explore,
            self.actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.temp_explore,
            self.temp,
            self.rnd,
            batch,
            self.discount,
            self.tau,
            num_updates,
            self.args,
        )
        return info

    def reset(self):
        """Re-initialise every model and the PRNG state from ``self.seeds``.

        ``self.step`` is left untouched.
        """
        # `init_models` returns ten values; all ten must be unpacked
        # (`self.temp` included), otherwise the assignment raises ValueError.
        (
            self.actor_explore,
            self.actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.temp_explore,
            self.temp,
            self.rnd,
            self.rng,
        ) = self.init_models(self.seeds)

    def apply_once(self, batch: Batch):
        """Not supported by this learner."""
        raise NotImplementedError
