"""Implementations of algorithms for continuous control."""

from typing import Sequence, Tuple
import functools
import numpy as np
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

from src.common import InfoDict, Batch
from src.models import DoubleCritic, MSEPolicy, sample_actions, RND
from src.agents.common import target_update
from src.agents.td3_derl.actor import update as update_actor
from src.agents.td3_derl.critic import update as update_critic
from src.agents.td3_derl.intrinsicrewards import update_rnd
from src.agents.td3_derl.rnd_states import RNDTrainState, RunningMeanStd


@functools.partial(
    jax.vmap, in_axes=(0, None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None)
)
def _update(
    rng: jax.random.PRNGKey,
    step: int,
    actor_explore: TrainState,
    target_actor_explore: TrainState,
    actor: TrainState,
    target_actor: TrainState,
    critic_explore: TrainState,
    target_critic_explore: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    rnd: RNDTrainState,
    batch: Batch,
    discount: float,
    tau: float,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    RNDTrainState,
    InfoDict,
]:
    """Perform one update of every network (vmapped over parallel seeds).

    The exploitation pair (``critic`` / ``actor``) is trained with
    ``rnd_coeff=0``. The exploration pair (``critic_explore`` /
    ``actor_explore``) is trained with ``args.rnd_coeff``, and each actor is
    optimised against the critic of its own pair.

    Actors are updated only when ``step % args.policy_update_freq == 0``. The
    RND model is updated on every call. Every target network is passed
    through ``target_update`` with rate ``tau`` on every call.

    Args:
        rng: per-seed PRNG key (mapped over axis 0).
        step: global update counter, shared by all seeds (not mapped).
        actor_explore .. target_critic: per-seed train states (mapped).
        rnd: per-seed RND train state (mapped).
        batch: per-seed batch (mapped over axis 0).
        discount: TD discount, shared (not mapped).
        tau: target update rate, shared (not mapped).
        args: ``ConfigArgs``-like static configuration (not mapped).

    Returns:
        ``(rng, actor_explore, target_actor_explore, actor, target_actor,
        critic_explore, target_critic_explore, critic, target_critic, rnd,
        info)`` with every state updated. ``info`` merges the critic, actor
        and RND metrics; exploration-side keys carry an ``_explore`` suffix.
    """

    rng, critic_key, rnd_key = jax.random.split(rng, 3)

    # NOTE(review): both critic updates receive the same `critic_key`, so any
    # noise update_critic samples from it (presumably TD3 target-policy
    # smoothing noise) is identical for the two critics.
    new_critic, critic_info = update_critic(
        critic_key,
        target_actor,
        critic,
        target_critic,
        rnd,
        batch,
        discount,
        policy_noise=args.policy_noise,
        noise_clip=args.noise_clip,
        max_action=args.max_action,
        rnd_coeff=0,
    )
    new_critic_explore, critic_info_explore = update_critic(
        critic_key,
        target_actor_explore,
        critic_explore,
        target_critic_explore,
        rnd,
        batch,
        discount,
        policy_noise=args.policy_noise,
        noise_clip=args.noise_clip,
        max_action=args.max_action,
        rnd_coeff=args.rnd_coeff,
    )

    critic_info_explore = {
        f"{k}_explore": v for k, v in critic_info_explore.items()
    }
    critic_info = {**critic_info, **critic_info_explore}

    def _apply_actor(_):
        new_actor, info = update_actor(
            actor,
            new_critic,
            rnd,
            batch,
            rnd_coeff=0.0,
        )
        # The exploration actor must be optimised against the exploration
        # critic (trained with args.rnd_coeff). Otherwise critic_explore would
        # be trained every step yet never influence any policy.
        new_actor_explore, info_explore = update_actor(
            actor_explore,
            new_critic_explore,
            rnd,
            batch,
            rnd_coeff=args.rnd_coeff,
        )
        info_explore = {f"{k}_explore": v for k, v in info_explore.items()}
        return new_actor_explore, new_actor, {**info, **info_explore}

    def _skip_actor(_):
        # jax.lax.cond requires both branches to return the same pytree
        # structure. These keys must therefore match exactly the metric keys
        # produced by _apply_actor.
        dummy_info = {
            "actor_loss": jnp.array(0.0),
            "actor_loss_explore": jnp.array(0.0),
        }
        return actor_explore, actor, dummy_info

    new_actor_explore, new_actor, actor_info = jax.lax.cond(
        step % args.policy_update_freq == 0,
        _apply_actor,
        _skip_actor,
        operand=None,
    )

    new_rnd, rnd_info = update_rnd(rnd_key, rnd, batch)

    new_target_critic_explore = target_update(
        new_critic_explore, target_critic_explore, tau
    )
    new_target_critic = target_update(new_critic, target_critic, tau)
    new_target_actor_explore = target_update(
        new_actor_explore, target_actor_explore, tau
    )
    new_target_actor = target_update(new_actor, target_actor, tau)
    info = {**critic_info, **actor_info, **rnd_info}

    return (
        rng,
        new_actor_explore,
        new_target_actor_explore,
        new_actor,
        new_target_actor,
        new_critic_explore,
        new_target_critic_explore,
        new_critic,
        new_target_critic,
        new_rnd,
        info,
    )


@functools.partial(
    jax.jit, static_argnames=("discount", "tau", "num_updates", "args")
)
def _do_multiple_updates(
    rng: jax.random.PRNGKey,
    step: int,
    actor_explore: TrainState,
    target_actor_explore: TrainState,
    actor: TrainState,
    target_actor: TrainState,
    critic_explore: TrainState,
    target_critic_explore: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    rnd: RNDTrainState,
    batches: Batch,
    discount: float,
    tau: float,
    num_updates: int,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    jax.Array,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    RNDTrainState,
    InfoDict,
]:
    """Run ``num_updates`` consecutive ``_update`` calls inside one jit.

    ``batches`` holds one batch per update along axis 1. Axis 0 is the seed
    axis consumed by the vmap inside ``_update``. Update ``i`` uses
    ``batches[:, i]``.

    At least one update is always performed, because the first one runs
    outside the loop. Any ``num_updates <= 1`` therefore yields exactly one
    update.

    Returns:
        ``(rng, step, actor_explore, target_actor_explore, actor,
        target_actor, critic_explore, target_critic_explore, critic,
        target_critic, rnd, info)``. ``step`` is advanced by one per update
        performed, and ``info`` holds the metrics of the last update only.
    """

    def one_step(i, state):
        # The carried info dict is only a placeholder that fixes the carry
        # structure; each update replaces it with fresh metrics.
        rng, step, *models, _ = state
        batch = jax.tree.map(lambda x: jnp.take(x, i, axis=1), batches)
        new_rng, *new_models, info = _update(
            rng, step, *models, batch, discount, tau, args
        )
        return (new_rng, step + 1, *new_models, info)

    # The first update runs outside fori_loop so that `info` already has its
    # final pytree structure. fori_loop requires the carry to keep the same
    # structure on every iteration, which the empty placeholder would not.
    state = one_step(
        0,
        (
            rng,
            step,
            actor_explore,
            target_actor_explore,
            actor,
            target_actor,
            critic_explore,
            target_critic_explore,
            critic,
            target_critic,
            rnd,
            {},
        ),
    )
    return jax.lax.fori_loop(1, num_updates, one_step, state)


@dataclass(frozen=True, eq=True)
class ConfigArgs:
    """Hashable hyper-parameter bundle used as the static ``args`` argument.

    ``frozen=True``, together with dataclass's default ``eq=True`` (spelled
    out here), makes dataclass synthesize ``__hash__``. ``jax.jit`` requires
    this for arguments listed in ``static_argnames``. Field order matters
    because this module constructs instances positionally.
    """

    # Actors are updated only on steps where step % policy_update_freq == 0.
    policy_update_freq: int
    # Forwarded to the critic update (presumably TD3 target-policy smoothing).
    policy_noise: float
    noise_clip: float
    max_action: float
    # Training horizon; not read by the update functions in this module.
    max_steps: int
    # RND coefficient used for the exploration critic and actor updates.
    rnd_coeff: float


class TD3DERLLearner(object):
    """TD3 learner with decoupled exploration and exploitation policies.

    Two actor/critic pairs are trained from the same batches:

    * ``actor_explore`` / ``critic_explore`` use the RND coefficient
      ``rnd_coeff`` and act when ``temperature > 0``;
    * ``actor`` / ``critic`` use an RND coefficient of 0 and act when
      ``temperature <= 0``.

    Every model is replicated ``num_parallel_seeds`` times along a leading
    axis (one independent replica per seed) via ``jax.vmap``.
    """

    def __init__(
        self,
        seed: int,
        # env settings
        state_dim: int,
        action_dim: int,
        # common RL settings
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        hidden_dims: Sequence[int] = (256, 256),
        discount: float = 0.99,
        tau: float = 0.005,
        num_parallel_seeds: int = 1,
        max_steps: int = 3_000_000,
        # td3
        exploration_noise: float = 0.1,
        policy_update_freq: int = 2,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        max_action: float = 1.0,
        # intrinsic rewards
        rnd_lr: float = 3e-4,
        rnd_coeff: float = 1.0,
    ):
        """Create optimisers and networks for every parallel seed.

        Args:
            seed: first PRNG seed. Seeds ``seed .. seed + num_parallel_seeds - 1``
                are used, one per parallel replica.
            state_dim: observation dimensionality.
            action_dim: action dimensionality.
            actor_lr: Adam learning rate of both actors.
            critic_lr: Adam learning rate of both critics.
            hidden_dims: hidden-layer sizes shared by actor, critic and RND.
            discount: TD discount factor.
            tau: rate passed to ``target_update`` for the target networks.
            num_parallel_seeds: number of replicas trained in lock-step.
            max_steps: stored on the learner and in ``self.args``.
            exploration_noise: scale of the normal noise added to actions in
                ``sample_actions`` (further multiplied by ``temperature``).
            policy_update_freq: actors are updated once every this many
                update steps.
            policy_noise: forwarded to the critic update.
            noise_clip: forwarded to the critic update.
            max_action: forwarded to the critic update.
            rnd_lr: Adam learning rate of the RND model.
            rnd_coeff: RND coefficient for the exploration critic and actor;
                the exploitation pair uses 0.
        """
        self.seeds = jnp.arange(seed, seed + num_parallel_seeds)
        self.tau = tau
        self.discount = discount
        self.exploration_noise = exploration_noise
        self.max_steps = max_steps

        self.args = ConfigArgs(
            policy_update_freq,
            policy_noise,
            noise_clip,
            max_action,
            max_steps,
            rnd_coeff,
        )

        dummy_observations = jnp.ones([1, state_dim], dtype=jnp.float32)
        dummy_actions = jnp.ones([1, action_dim], dtype=jnp.float32)

        def _init_models(seed):
            rng = jax.random.PRNGKey(seed)
            rng, actor_key, critic_key, rnd_key = jax.random.split(rng, 4)

            # optimisers
            critic_optimiser = optax.adam(learning_rate=critic_lr)
            actor_optimiser = optax.adam(learning_rate=actor_lr)
            rnd_optimiser = optax.adam(learning_rate=rnd_lr)

            # All four actors (and, below, all four critics) start from the
            # same key and thus identical weights. Initialise each network
            # once and share the immutable parameter pytree between its
            # TrainStates.
            actor_def = MSEPolicy(hidden_dims, action_dim)
            actor_params = actor_def.init(actor_key, dummy_observations)
            (
                actor_explore,
                target_actor_explore,
                actor,
                target_actor,
            ) = [
                TrainState.create(
                    apply_fn=actor_def.apply,
                    params=actor_params,
                    tx=actor_optimiser,
                )
                for _ in range(4)
            ]

            # critics
            critic_def = DoubleCritic(hidden_dims)
            critic_params = critic_def.init(
                critic_key, dummy_observations, dummy_actions
            )
            (
                critic_explore,
                target_critic_explore,
                critic,
                target_critic,
            ) = [
                TrainState.create(
                    apply_fn=critic_def.apply,
                    params=critic_params,
                    tx=critic_optimiser,
                )
                for _ in range(4)
            ]

            # RND
            rnd_def = RND(hidden_dims)
            rnd = RNDTrainState.create(
                apply_fn=rnd_def.apply,
                params=rnd_def.init(
                    rnd_key, dummy_observations, dummy_actions
                ),
                tx=rnd_optimiser,
                rms=RunningMeanStd.create(),
            )

            return (
                actor_explore,
                target_actor_explore,
                actor,
                target_actor,
                critic_explore,
                target_critic_explore,
                critic,
                target_critic,
                rnd,
                rng,
            )

        # One model set per seed, stacked along a leading axis.
        self.init_models = jax.jit(jax.vmap(_init_models))

        (
            self.actor_explore,
            self.target_actor_explore,
            self.actor,
            self.target_actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.rnd,
            self.rng,
        ) = self.init_models(self.seeds)
        self.trainable_models = ["actor", "critic"]
        self.step = 0

    def sample_actions(
        self, observations: np.ndarray, temperature: float = 1.0
    ) -> np.ndarray:
        """Return noisy actions from the exploration or exploitation actor.

        ``temperature > 0`` selects ``actor_explore``; otherwise ``actor`` is
        used. Normal noise with scale ``exploration_noise * temperature`` is
        then added, and the result is clipped to ``[-1, 1]``.

        NOTE(review): the noise comes from NumPy's global RNG rather than
        ``self.rng``, so it is not governed by ``seed``. The clip bound is
        fixed at 1 regardless of ``max_action``.
        """
        # use temperature as tool to determine to "explore" or not
        policy = self.actor_explore if temperature > 0 else self.actor
        rng, actions = sample_actions(
            self.rng,
            policy.apply_fn,
            policy.params,
            observations,
            temperature,
            distribution="det",
        )

        self.rng = rng
        actions = np.asarray(actions)
        actions = (
            actions
            + np.random.normal(size=actions.shape)
            * self.exploration_noise
            * temperature
        )
        return np.clip(actions, -1, 1)

    def update(self, batch: Batch, num_updates: int = 1) -> InfoDict:
        """Apply ``num_updates`` update steps to every seed's networks.

        Args:
            batch: pytree whose arrays are indexed ``x[seed, update_idx, ...]``.
                Axis 0 is the vmapped seed axis and axis 1 is selected per
                update. Presumably ``(num_seeds, num_updates, batch_size,
                ...)`` — TODO confirm against the caller.
            num_updates: number of consecutive updates. At least one update
                is always performed.

        Returns:
            The metrics dict of the last update performed.
        """
        (
            self.rng,
            self.step,
            self.actor_explore,
            self.target_actor_explore,
            self.actor,
            self.target_actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.rnd,
            info,
        ) = _do_multiple_updates(
            self.rng,
            self.step,
            self.actor_explore,
            self.target_actor_explore,
            self.actor,
            self.target_actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.rnd,
            batch,
            self.discount,
            self.tau,
            num_updates,
            self.args,
        )
        return info

    def reset(self):
        """Re-initialise all networks and ``self.rng`` from ``self.seeds``.

        ``self.step`` is left unchanged, so the actor-update phase
        (``step % policy_update_freq``) continues from the current step.
        """
        (
            self.actor_explore,
            self.target_actor_explore,
            self.actor,
            self.target_actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.rnd,
            self.rng,
        ) = self.init_models(self.seeds)

    def apply_once(self, batch: Batch):
        """Not supported by this learner."""
        raise NotImplementedError
