"""Implementations of algorithms for continuous control."""

from typing import Sequence, Tuple
import functools
import numpy as np
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

from src.common import InfoDict, Batch
from src.models import DoubleCritic, QuadCritic, MSEPolicy, sample_actions
from src.agents.common import target_update
from src.agents.td3_boxd4.actor import update as update_actor
from src.agents.td3_boxd4.critic import (
    update_max as update_critic_max,
    update as update_critic,
)


@functools.partial(
    jax.vmap, in_axes=(0, None, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None)
)
def _update(
    rng: jax.random.PRNGKey,
    step: int,
    actor_explore: TrainState,
    target_actor_explore: TrainState,
    actor: TrainState,
    target_actor: TrainState,
    critic_explore: TrainState,
    target_critic_explore: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    batch: Batch,
    discount: float,
    tau: float,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    InfoDict,
]:
    """Run one update of the task and exploration actor/critic pairs.

    The function is vmapped over axis 0 of ``rng``, of every train state and
    of ``batch``; ``step``, ``discount``, ``tau`` and ``args`` are shared
    across the mapped axis.

    Both critics are updated on every call. The two actors are only updated
    when ``step % args.policy_update_freq == 0``; on other steps they are
    returned unchanged together with zero-valued placeholder metrics. All
    four target networks are passed through ``target_update`` on every call.

    Args:
        rng: PRNG key; split so that each stochastic update gets its own key.
        step: Update counter used to gate the delayed actor update.
        actor_explore: Exploration actor state.
        target_actor_explore: Target of the exploration actor.
        actor: Task actor state.
        target_actor: Target of the task actor.
        critic_explore: Exploration critic state.
        target_critic_explore: Target of the exploration critic.
        critic: Task critic state.
        target_critic: Target of the task critic.
        batch: Training batch forwarded to the critic and actor updates.
        discount: Discount factor forwarded to the critic updates.
        tau: Coefficient forwarded to ``target_update``.
        args: Config object providing ``policy_noise``, ``noise_clip``,
            ``max_action``, ``critic_k_samples``, ``actor_k_samples`` and
            ``policy_update_freq``.

    Returns:
        ``(rng, actor_explore, target_actor_explore, actor, target_actor,
        critic_explore, target_critic_explore, critic, target_critic, info)``
        with every state replaced by its updated version and ``info`` being
        the merged critic and actor metrics.
    """

    # One independent subkey per stochastic consumer. Handing the same key to
    # both the task and the exploration update would correlate whatever
    # randomness those helpers draw from it.
    (
        rng,
        critic_key,
        critic_explore_key,
        actor_key,
        actor_explore_key,
    ) = jax.random.split(rng, 5)

    # Task critic: a single sample, regressed against the task target actor.
    new_critic, critic_info = update_critic(
        critic_key,
        target_actor,
        critic,
        target_critic,
        batch,
        discount,
        policy_noise=args.policy_noise,
        noise_clip=args.noise_clip,
        max_action=args.max_action,
        k_samples=1,
    )
    # Exploration critic: uses the "max" update variant with
    # ``critic_k_samples`` samples and the exploration target actor.
    new_critic_explore, critic_info_explore = update_critic_max(
        critic_explore_key,
        target_actor_explore,
        critic_explore,
        target_critic_explore,
        batch,
        discount,
        policy_noise=args.policy_noise,
        noise_clip=args.noise_clip,
        max_action=args.max_action,
        k_samples=args.critic_k_samples,
    )
    # NOTE(review): on a key clash the exploration critic's entry wins;
    # confirm the two helpers report distinct metric names.
    critic_info = {**critic_info, **critic_info_explore}

    # Names of the metrics reported by one actor update; the exploration
    # actor reports the same names with an "_explore" suffix.
    actor_metric_names = (
        "actor_loss",
        "actor_q1_var",
        "actor_q2_var",
        "actor_q3_var",
        "actor_q4_var",
    )

    def _apply_actor(_):
        # Each actor is trained against its own freshly updated critic.
        new_actor, info = update_actor(
            actor_key,
            actor,
            new_critic,
            batch,
            k_samples=1,
        )
        new_actor_explore, info_explore = update_actor(
            actor_explore_key,
            actor_explore,
            new_critic_explore,
            batch,
            k_samples=args.actor_k_samples,
            optimistic=True,
        )
        info_explore = {f"{k}_explore": v for k, v in info_explore.items()}
        return new_actor_explore, new_actor, {**info, **info_explore}

    def _skip_actor(_):
        # ``lax.cond`` needs both branches to return the same structure, so
        # emit zeros under exactly the keys ``_apply_actor`` would produce.
        placeholder_info = {}
        for name in actor_metric_names:
            placeholder_info[name] = jnp.array(0.0)
            placeholder_info[f"{name}_explore"] = jnp.array(0.0)
        return actor_explore, actor, placeholder_info

    # Delayed policy update: only every ``policy_update_freq``-th step.
    new_actor_explore, new_actor, actor_info = jax.lax.cond(
        step % args.policy_update_freq == 0,
        _apply_actor,
        _skip_actor,
        operand=None,
    )

    # Targets are refreshed on every call, including steps that skipped the
    # actor update (in which case the actor inputs are the unchanged actors).
    new_target_critic_explore = target_update(
        new_critic_explore, target_critic_explore, tau
    )
    new_target_critic = target_update(new_critic, target_critic, tau)
    new_target_actor_explore = target_update(
        new_actor_explore, target_actor_explore, tau
    )
    new_target_actor = target_update(new_actor, target_actor, tau)
    info = {**critic_info, **actor_info}

    return (
        rng,
        new_actor_explore,
        new_target_actor_explore,
        new_actor,
        new_target_actor,
        new_critic_explore,
        new_target_critic_explore,
        new_critic,
        new_target_critic,
        info,
    )


@functools.partial(
    jax.jit, static_argnames=("discount", "tau", "num_updates", "args")
)
def _do_multiple_updates(
    rng: jax.random.PRNGKey,
    step: int,
    actor_explore: TrainState,
    target_actor_explore: TrainState,
    actor: TrainState,
    target_actor: TrainState,
    critic_explore: TrainState,
    target_critic_explore: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    batches: Batch,
    discount: float,
    tau: float,
    num_updates: int,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    jax.Array,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    InfoDict,
]:
    """Apply ``_update`` ``num_updates`` times inside a single jitted call.

    ``discount``, ``tau``, ``num_updates`` and ``args`` are static jit
    arguments, so they must be hashable and a change in any of them triggers
    a re-trace.

    Args:
        rng: PRNG key(s) threaded through every update.
        step: Update counter; incremented by one per update performed.
        actor_explore: Exploration actor state.
        target_actor_explore: Target of the exploration actor.
        actor: Task actor state.
        target_actor: Target of the task actor.
        critic_explore: Exploration critic state.
        target_critic_explore: Target of the exploration critic.
        critic: Task critic state.
        target_critic: Target of the task critic.
        batches: Pytree of arrays. Axis 1 of every leaf is indexed by the
            update number; axis 0 is the axis ``_update`` is vmapped over.
        discount: Discount factor forwarded to ``_update``.
        tau: Target-update coefficient forwarded to ``_update``.
        num_updates: Number of updates to run. The first update always runs,
            so values below 1 still perform one update.
        args: Config object forwarded to ``_update``.

    Returns:
        The final carry ``(rng, step, actor_explore, target_actor_explore,
        actor, target_actor, critic_explore, target_critic_explore, critic,
        target_critic, info)`` where ``info`` holds the metrics of the last
        update performed.
    """

    def one_step(i, carry):
        # Carry layout: (rng, step, <8 train states in _update order>, info).
        # The incoming info is discarded; each update produces a fresh one.
        key, current_step, *states, _ = carry

        # Select the i-th update's slice along axis 1 of every batch leaf.
        batch_i = jax.tree.map(lambda x: jnp.take(x, i, axis=1), batches)

        new_key, *new_states, info = _update(
            key,
            current_step,
            *states,
            batch_i,
            discount,
            tau,
            args,
        )
        return (new_key, current_step + 1, *new_states, info)

    # Run update 0 outside the loop: ``fori_loop`` needs an initial carry whose
    # structure already matches the loop body's output, and the structure of
    # ``info`` is only known once ``_update`` has been traced.
    initial_carry = one_step(
        0,
        (
            rng,
            step,
            actor_explore,
            target_actor_explore,
            actor,
            target_actor,
            critic_explore,
            target_critic_explore,
            critic,
            target_critic,
            {},
        ),
    )

    # Remaining updates 1 .. num_updates - 1.
    return jax.lax.fori_loop(1, num_updates, one_step, initial_carry)


@jax.jit
@jax.vmap
def _random_set_temp(
    rng: jax.random.PRNGKey,
) -> Tuple[jax.random.PRNGKey, jax.Array]:
    """Draw one uniform sample per key and advance each key.

    Vmapped over axis 0 of ``rng`` and jitted.

    Args:
        rng: Batch of PRNG keys, one per row.

    Returns:
        A pair ``(next_rng, draws)``: the advanced keys and one scalar
        sample from ``jax.random.uniform`` per input key.
    """
    next_rng, draw_key = jax.random.split(rng)
    draw = jax.random.uniform(draw_key, shape=())
    return next_rng, draw


def random_set_temp(
    rng: jax.random.PRNGKey, thres: float, temperature: float = 1.0
) -> Tuple[jax.random.PRNGKey, jax.Array]:
    """Randomly keep the exploration temperature or switch it off.

    One uniform sample is drawn per key in ``rng`` and the samples are
    averaged. If the mean is strictly above ``thres`` the given
    ``temperature`` is kept (the caller then uses the exploration policy);
    otherwise ``0`` is returned (the caller then uses the task policy).

    Args:
        rng: Batch of PRNG keys.
        thres: Threshold compared against the mean of the uniform samples.
        temperature: Temperature to return when exploration is kept.

    Returns:
        ``(rng, temperature_or_zero)`` where ``rng`` holds the advanced keys.
        The second element is the plain ``temperature`` value or the int
        ``0``, not an array.
    """
    rng, draws = _random_set_temp(rng)
    if draws.mean() > thres:
        return rng, temperature
    return rng, 0


def annealing_set_temp(
    rng: jax.random.PRNGKey,
    step: int,
    temperature: float = 1.0,
    anneal_interval: float = 3e5,
    num_stages: float = 10,
) -> Tuple[jax.random.PRNGKey, float]:
    """Keep or switch off the exploration temperature on an annealed schedule.

    One uniform sample is drawn per key in ``rng`` and the samples are
    averaged. The mean is compared against a threshold that rises in steps:

        threshold = (step // anneal_interval) / num_stages

    If the mean is strictly above the threshold the given ``temperature`` is
    kept (the caller then uses the exploration policy); otherwise ``0`` is
    returned (the caller then uses the task policy). Because the threshold
    grows with ``step``, exploration is chosen less often as training
    progresses, and never once the threshold reaches 1.

    Args:
        rng: Batch of PRNG keys.
        step: Current training step.
        temperature: Temperature to return when exploration is kept.
        anneal_interval: Number of steps between threshold increases.
        num_stages: Number of increases after which the threshold reaches 1.

    Returns:
        ``(rng, temperature_or_zero)`` where ``rng`` holds the advanced keys.
    """
    rng, temp = _random_set_temp(rng)
    # Stepwise-increasing threshold; the defaults reproduce the original
    # schedule of +0.1 every 3e5 steps.
    step_thres = (step // anneal_interval) / num_stages

    return rng, temperature if temp.mean() > step_thres else 0


@dataclass(frozen=True)
class ConfigArgs:
    """Immutable bundle of hyper-parameters passed to the jitted update.

    ``frozen=True`` (with the default ``eq=True``) makes instances hashable,
    which is needed because the object is passed as the static ``args``
    argument of ``_do_multiple_updates``.
    """

    # The actors are updated only when ``step % policy_update_freq == 0``.
    policy_update_freq: int
    # Forwarded as ``policy_noise`` to both critic updates.
    policy_noise: float
    # Forwarded as ``noise_clip`` to both critic updates.
    noise_clip: float
    # Forwarded as ``max_action`` to both critic updates.
    max_action: float
    # Stored for reference; not read through ``args`` anywhere in this file.
    max_steps: int
    # Stored for reference; the critic networks receive the dropout rate
    # directly from the learner's constructor argument.
    dropout_rate: float
    # ``k_samples`` for the exploration actor update.
    actor_k_samples: int
    # ``k_samples`` for the exploration critic update.
    critic_k_samples: int
    # Exploration schedule read by ``TD3Boxd4Learner.sample_actions``:
    # "annealing", "fixed" or "half"; any other value leaves the temperature
    # unchanged.
    set_explore_type: str


class TD3Boxd4Learner:
    """Learner holding a task actor/critic and an exploration actor/critic.

    Each of the four networks has a matching target network, and every
    network is replicated over ``num_parallel_seeds`` consecutive seeds by
    vmapping the initialiser. Acting switches between the exploration actor
    and the task actor depending on the (possibly randomly zeroed)
    temperature; see ``sample_actions``.
    """

    def __init__(
        self,
        seed: int,
        # env settings
        state_dim: int,
        action_dim: int,
        # common RL settings
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        hidden_dims: Sequence[int] = (256, 256),
        discount: float = 0.99,
        tau: float = 0.005,
        num_parallel_seeds: int = 1,
        max_steps: int = 3e6,
        # td3
        exploration_noise: float = 0.1,
        policy_update_freq: int = 2,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        max_action: float = 1.0,
        # droq
        dropout_rate: float = 0.001,
        layernorm: bool = True,
        # boxd
        actor_k_samples: int = 1,
        critic_k_samples: int = 1,
        set_explore_type: str = "annealing",
    ):
        """Build all networks, optimisers and PRNG keys.

        Args:
            seed: First seed; seeds ``seed .. seed + num_parallel_seeds - 1``
                are used for the parallel learners.
            state_dim: Size of the observation vector.
            action_dim: Size of the action vector.
            actor_lr: Adam learning rate for both actors.
            critic_lr: Adam learning rate for both critics.
            hidden_dims: Hidden layer sizes passed to every network.
            discount: Discount factor passed to the update.
            tau: Target-update coefficient passed to the update.
            num_parallel_seeds: Number of learners trained in parallel.
            max_steps: Stored on the learner and in ``args``. Note that the
                default is the float ``3e6``.
            exploration_noise: Scale of the Gaussian noise added to sampled
                actions.
            policy_update_freq: Actors are updated every this many steps.
            policy_noise: Passed to the critic updates.
            noise_clip: Passed to the critic updates.
            max_action: Passed to the critic updates.
            dropout_rate: Dropout rate of the critic networks.
            layernorm: Whether the critic networks use layer normalisation.
            actor_k_samples: ``k_samples`` for the exploration actor update.
            critic_k_samples: ``k_samples`` for the exploration critic update.
            set_explore_type: Exploration schedule: "annealing", "fixed" or
                "half"; any other value always keeps the given temperature.
        """

        self.seeds = jnp.arange(seed, seed + num_parallel_seeds)
        self.tau = tau
        self.discount = discount
        self.exploration_noise = exploration_noise
        self.max_steps = max_steps

        self.args = ConfigArgs(
            policy_update_freq=policy_update_freq,
            policy_noise=policy_noise,
            noise_clip=noise_clip,
            max_action=max_action,
            max_steps=max_steps,
            dropout_rate=dropout_rate,
            actor_k_samples=actor_k_samples,
            critic_k_samples=critic_k_samples,
            set_explore_type=set_explore_type,
        )

        dummy_observations = jnp.ones([1, state_dim], dtype=jnp.float32)
        dummy_actions = jnp.ones([1, action_dim], dtype=jnp.float32)

        def _init_models(seed):
            rng = jax.random.PRNGKey(seed)
            rng, actor_key, critic_key = jax.random.split(rng, 3)

            # optimisers
            critic_optimiser = optax.adam(learning_rate=critic_lr)
            actor_optimiser = optax.adam(learning_rate=actor_lr)

            def new_state(module, params, optimiser):
                # Fresh train state (and optimiser state) around ``params``.
                return TrainState.create(
                    apply_fn=module.apply, params=params, tx=optimiser
                )

            # actors
            # All four actor states start from the same parameters because
            # they are initialised from the same key.
            # NOTE(review): this makes the exploration and task actors
            # identical at start, not just each actor and its target;
            # confirm this is intended.
            actor_def = MSEPolicy(hidden_dims, action_dim)
            actor_params = actor_def.init(actor_key, dummy_observations)
            actor_explore, target_actor_explore, actor, target_actor = (
                new_state(actor_def, actor_params, actor_optimiser)
                for _ in range(4)
            )

            # critics
            # Each critic and its target start from identical parameters.
            critic_def_explore = QuadCritic(
                hidden_dims, layernorm=layernorm, dropout_rate=dropout_rate
            )
            critic_explore_params = critic_def_explore.init(
                critic_key, dummy_observations, dummy_actions
            )
            critic_explore, target_critic_explore = (
                new_state(
                    critic_def_explore,
                    critic_explore_params,
                    critic_optimiser,
                )
                for _ in range(2)
            )

            critic_def = DoubleCritic(
                hidden_dims, layernorm=layernorm, dropout_rate=dropout_rate
            )
            critic_params = critic_def.init(
                critic_key, dummy_observations, dummy_actions
            )
            critic, target_critic = (
                new_state(critic_def, critic_params, critic_optimiser)
                for _ in range(2)
            )

            return (
                actor_explore,
                target_actor_explore,
                actor,
                target_actor,
                critic_explore,
                target_critic_explore,
                critic,
                target_critic,
                rng,
            )

        # Vmapped over the seeds so every attribute carries a leading
        # axis of length ``num_parallel_seeds``.
        self.init_models = jax.jit(jax.vmap(_init_models))

        self._initialise_models()
        self.trainable_models = ["actor", "critic"]
        self.step = 0

    def _initialise_models(self) -> None:
        """(Re)create every network, target network and PRNG key from the seeds."""
        (
            self.actor_explore,
            self.target_actor_explore,
            self.actor,
            self.target_actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.rng,
        ) = self.init_models(self.seeds)

    def sample_actions(
        self, observations: np.ndarray, temperature: float = 1.0
    ) -> np.ndarray:
        """Sample actions from the exploration actor or the task actor.

        A positive ``temperature`` requests exploration. Depending on
        ``args.set_explore_type`` the request may be randomly overridden to
        ``0`` so that the task actor is used instead:

        * "annealing": the override threshold rises with ``self.step``.
        * "fixed": fixed threshold of 0.1.
        * "half": fixed threshold of 0.5.
        * anything else: the temperature is never overridden.

        With a positive final temperature the exploration actor is queried,
        otherwise the task actor. Gaussian noise scaled by
        ``exploration_noise * temperature`` is then added, so no noise is
        added when the temperature is zero.

        Args:
            observations: Observations forwarded to the policy.
            temperature: Requested temperature; ``<= 0`` always selects the
                task actor without consulting the schedule.

        Returns:
            NumPy array of actions clipped to ``[-1, 1]``.
        """

        # Thresholds for the schedules that use a constant threshold.
        fixed_thresholds = {"fixed": 0.1, "half": 0.5}

        # Use the temperature as the switch between exploring and not:
        # the schedules below sometimes zero it for more conservative acting.
        if temperature > 0:
            explore_type = self.args.set_explore_type
            if explore_type == "annealing":
                self.rng, temperature = annealing_set_temp(
                    self.rng, self.step, temperature
                )
            elif explore_type in fixed_thresholds:
                self.rng, temperature = random_set_temp(
                    self.rng, fixed_thresholds[explore_type], temperature
                )
            # Any other schedule name keeps the requested temperature.

        policy = self.actor_explore if temperature > 0 else self.actor
        self.rng, actions = sample_actions(
            self.rng,
            policy.apply_fn,
            policy.params,
            observations,
            temperature,
            distribution="det",
        )

        actions = np.asarray(actions)
        # Noise comes from NumPy's global generator, not from ``self.rng``.
        actions = (
            actions
            + np.random.normal(size=actions.shape)
            * self.exploration_noise
            * temperature
        )
        # NOTE(review): the clip range is hard-coded rather than taken from
        # ``max_action``; confirm the policies output in [-1, 1].
        return np.clip(actions, -1, 1)

    def update(self, batch: Batch, num_updates: int = 1) -> InfoDict:
        """Run ``num_updates`` jitted updates and store the new state.

        Args:
            batch: Pytree of arrays whose axis 1 is indexed by the update
                number inside ``_do_multiple_updates``.
            num_updates: Number of updates to run in one jitted call.

        Returns:
            Metrics reported by the last update performed.
        """
        (
            self.rng,
            self.step,
            self.actor_explore,
            self.target_actor_explore,
            self.actor,
            self.target_actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            info,
        ) = _do_multiple_updates(
            self.rng,
            self.step,
            self.actor_explore,
            self.target_actor_explore,
            self.actor,
            self.target_actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            batch,
            self.discount,
            self.tau,
            num_updates,
            self.args,
        )
        return info

    def reset(self):
        """Re-initialise all networks and PRNG keys from the original seeds.

        ``self.step`` is left untouched.
        """
        self._initialise_models()

    def apply_once(self, batch: Batch):
        """Not supported by this learner; always raises ``NotImplementedError``."""
        raise NotImplementedError
