"""Implementations of algorithms for continuous control."""

from typing import Optional, Sequence, Tuple
import functools
import numpy as np
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

from src.common import InfoDict, Batch
from src.models import (
    DoubleCritic,
    NormalTanhPolicy,
    sample_actions,
)
from src.agents.common import target_update
from src.agents.sac_boxd2.actor import update as update_actor
from src.agents.sac_boxd2.critic import (
    update_max as update_critic_max,
    update as update_critic,
)
from src.agents.sac_boxd2.temperature import (
    Temperature,
    update as update_temp,
)


@functools.partial(
    jax.vmap, in_axes=(0, None, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, None, None)
)
def _update(
    rng: jax.random.PRNGKey,
    step: int,
    actor_explore: TrainState,
    actor: TrainState,
    critic_explore: TrainState,
    target_critic_explore: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp_explore: TrainState,
    temp: TrainState,
    batch: Batch,
    discount: float,
    tau: float,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    InfoDict,
]:
    """Run one training step for both the task agent and the exploration agent.

    Vectorised with ``jax.vmap`` over the leading (per-seed) axis of ``rng``,
    every ``TrainState`` and ``batch``. ``step``, ``discount``, ``tau`` and
    ``args`` are shared by all seeds.

    One call does the following:
      1. Updates the task critic with ``update_critic`` and the exploration
         critic with ``update_critic_max``.
      2. Every ``args.policy_update_freq`` steps, updates both actors against
         their freshly-updated critics. Otherwise the actors are kept as-is.
      3. Moves both target critics towards their critics with rate ``tau``.
      4. Updates both temperatures from the actors' entropies, but only on
         steps where the actors were actually updated.

    Returns:
        ``(rng, actor_explore, actor, critic_explore, target_critic_explore,
        critic, target_critic, temp_explore, temp, info)``: the advanced key,
        the updated states, and a merged dict of diagnostics.
    """
    rng, critic_key, actor_key = jax.random.split(rng, 3)

    new_critic, critic_info = update_critic(
        critic_key,
        actor,
        critic,
        target_critic,
        temp,
        batch,
        discount,
        soft_critic=True,
        k_samples=1,
    )
    new_critic_explore, critic_info_explore = update_critic_max(
        critic_key,
        actor_explore,
        critic_explore,
        target_critic_explore,
        temp_explore,
        batch,
        discount,
        soft_critic=True,
        k_samples=args.critic_k_samples,
    )
    # NOTE(review): unlike the actor/alpha diagnostics below, the exploration
    # critic's keys are not suffixed here. If update_critic_max reports the
    # same keys as update_critic, the task critic's values get overwritten —
    # confirm against critic.py.
    critic_info = {**critic_info, **critic_info_explore}

    # `step` is shared across seeds (in_axes=None), so this predicate is an
    # unbatched scalar and lax.cond stays a real branch under vmap.
    do_actor_update = step % args.policy_update_freq == 0

    def _apply_actor(_):
        new_actor, info = update_actor(
            actor_key,
            actor,
            new_critic,
            temp,
            batch,
            k_samples=1,
        )
        new_actor_explore, info_explore = update_actor(
            actor_key,
            actor_explore,
            new_critic_explore,
            temp_explore,
            batch,
            k_samples=args.actor_k_samples,
        )
        info_explore = {f"{k}_explore": v for k, v in info_explore.items()}
        return new_actor_explore, new_actor, {**info, **info_explore}

    def _skip_actor(_):
        # Placeholder diagnostics; lax.cond needs both branches to return the
        # same pytree structure as _apply_actor.
        dummy_info = {
            "actor_loss": jnp.array(0.0),
            "actor_entropy": jnp.array(0.0),
            "actor_q1_var": jnp.array(0.0),
            "actor_q2_var": jnp.array(0.0),
            "actor_loss_explore": jnp.array(0.0),
            "actor_entropy_explore": jnp.array(0.0),
            "actor_q1_var_explore": jnp.array(0.0),
            "actor_q2_var_explore": jnp.array(0.0),
        }
        return actor_explore, actor, dummy_info

    new_actor_explore, new_actor, actor_info = jax.lax.cond(
        do_actor_update,
        _apply_actor,
        _skip_actor,
        operand=None,
    )
    new_target_critic_explore = target_update(
        new_critic_explore, target_critic_explore, tau
    )
    new_target_critic = target_update(new_critic, target_critic, tau)

    new_temp, alpha_info = update_temp(
        temp, actor_info["actor_entropy"], args.target_entropy
    )
    new_temp_explore, alpha_info_explore = update_temp(
        temp_explore, actor_info["actor_entropy_explore"], args.target_entropy
    )

    # On skipped steps the entropies fed to update_temp above are the 0.0
    # placeholders from _skip_actor. Applying that temperature step would push
    # alpha using a fake entropy, so keep the previous temperature state
    # (params, optimiser state, step counter) instead. With
    # policy_update_freq == 1 this always selects the updated state.
    def _keep_unless_actor_updated(updated, previous):
        return jax.tree.map(
            lambda new, old: jnp.where(do_actor_update, new, old),
            updated,
            previous,
        )

    new_temp = _keep_unless_actor_updated(new_temp, temp)
    new_temp_explore = _keep_unless_actor_updated(new_temp_explore, temp_explore)

    # NOTE: on skipped steps alpha_info / alpha_info_explore still describe
    # the discarded temperature step.
    alpha_info_explore = {
        f"{k}_explore": v for k, v in alpha_info_explore.items()
    }
    info = {**critic_info, **actor_info, **alpha_info, **alpha_info_explore}

    return (
        rng,
        new_actor_explore,
        new_actor,
        new_critic_explore,
        new_target_critic_explore,
        new_critic,
        new_target_critic,
        new_temp_explore,
        new_temp,
        info,
    )


@functools.partial(
    jax.jit, static_argnames=("discount", "tau", "num_updates", "args")
)
def _do_multiple_updates(
    rng: jax.random.PRNGKey,
    step: int,
    actor_explore: TrainState,
    actor: TrainState,
    critic_explore: TrainState,
    target_critic_explore: TrainState,
    critic: TrainState,
    target_critic: TrainState,
    temp_explore: TrainState,
    temp: TrainState,
    batches: Batch,
    discount: float,
    tau: float,
    num_updates: int,
    args,
) -> Tuple[
    jax.random.PRNGKey,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    TrainState,
    InfoDict,
]:
    """Apply ``_update`` ``num_updates`` times inside a single jitted call.

    Update ``i`` consumes the slice ``batches[:, i]``: every leaf of
    ``batches`` is indexed with ``jnp.take(..., i, axis=1)``. Axis 0 is the
    per-seed axis consumed by ``_update``'s vmap; the layout is presumably
    (num_seeds, num_updates, ...) — confirm with the caller.

    The first update runs outside the loop so the diagnostics dict has a
    concrete pytree structure before it becomes part of the ``fori_loop``
    carry. As a consequence, at least one update is always performed, even
    when ``num_updates < 1``.

    Returns:
        ``(rng, step, actor_explore, actor, critic_explore,
        target_critic_explore, critic, target_critic, temp_explore, temp,
        info)``, where ``step`` has advanced by the number of updates run and
        ``info`` holds the diagnostics of the last update.
    """

    def one_step(idx, carry):
        # carry = (rng, step, <8 train states>, info); the incoming info is
        # replaced by the one produced by this update.
        key, counter, *states, _ = carry
        batch_slice = jax.tree.map(lambda leaf: jnp.take(leaf, idx, axis=1), batches)
        next_key, *next_states, step_info = _update(
            key, counter, *states, batch_slice, discount, tau, args
        )
        return (next_key, counter + 1, *next_states, step_info)

    train_states = (
        actor_explore,
        actor,
        critic_explore,
        target_critic_explore,
        critic,
        target_critic,
        temp_explore,
        temp,
    )
    first_carry = one_step(0, (rng, step, *train_states, {}))
    return jax.lax.fori_loop(1, num_updates, one_step, first_carry)


@functools.partial(jax.jit)
@functools.partial(jax.vmap, in_axes=(0))
def _random_set_temp(
    rng: jax.random.PRNGKey,
) -> Tuple[jax.random.PRNGKey, jax.Array]:

    rng, temp_key = jax.random.split(rng)
    return rng, jax.random.uniform(temp_key, shape=())


def random_set_temp(
    rng: jax.random.PRNGKey, thres: float, temperature: float = 1.0
) -> Tuple[jax.random.PRNGKey, float]:
    """Randomly choose between the exploration policy and the task policy.

    One uniform [0, 1) sample is drawn per seed and the samples are averaged.
    If that mean is strictly greater than ``thres``, ``temperature`` is
    returned unchanged (the caller then acts with pi_explore). Otherwise ``0``
    is returned (the caller then acts with pi_task). A larger ``thres``
    therefore means less exploration.

    Args:
        rng: Per-seed PRNG keys (leading axis = seeds).
        thres: Decision threshold on the mean of the uniform samples.
        temperature: Temperature to return when exploring.

    Returns:
        The advanced per-seed keys, and either ``temperature`` or ``0``.
    """
    rng, draws = _random_set_temp(rng)
    # One decision is shared by all seeds.
    # NOTE(review): with more than one seed, the mean of the uniforms
    # concentrates around 0.5, so the explore probability is no longer
    # 1 - thres. Confirm this is intended.
    explore = bool(draws.mean() > thres)
    return rng, temperature if explore else 0


def annealing_set_temp(
    rng: jax.random.PRNGKey,
    step: int,
    temperature: float = 1.0,
    stage_length: float = 3e5,
    num_stages: int = 10,
) -> Tuple[jax.random.PRNGKey, float]:
    """Choose explore vs. task policy with a threshold that grows with ``step``.

    The threshold is ``(step // stage_length) / num_stages``. With the
    defaults it rises by 0.1 every 300k steps. Exploration (returning
    ``temperature``) happens only when the mean of the per-seed uniform draws
    exceeds the threshold, so pi_task is used more and more as training
    progresses. With the defaults the threshold reaches 1.0 at step 3e6; from
    then on the draws (all < 1) never exceed it and ``0`` is always returned.

    Args:
        rng: Per-seed PRNG keys (leading axis = seeds).
        step: Current training step (Python int or scalar array).
        temperature: Temperature to return when exploring.
        stage_length: Number of steps per threshold increment.
        num_stages: Number of increments needed to reach a threshold of 1.0.

    Returns:
        The advanced per-seed keys, and either ``temperature`` or ``0``.

    Raises:
        ValueError: if ``stage_length`` or ``num_stages`` is not positive.
    """
    if stage_length <= 0 or num_stages <= 0:
        raise ValueError(
            "stage_length and num_stages must be positive, got "
            f"{stage_length!r} and {num_stages!r}"
        )
    step_thres = (step // stage_length) / num_stages
    return random_set_temp(rng, step_thres, temperature)


@dataclass(frozen=True)
class ConfigArgs:
    """Hyper-parameters forwarded to the jitted update functions as ``args``.

    ``_do_multiple_updates`` takes ``args`` as a static jit argument, so the
    instance must be hashable. ``frozen=True`` (with the default ``eq=True``)
    makes the dataclass generate ``__hash__``. Field order is part of the
    interface: the learner constructs this class positionally.
    """

    # Actors are updated only on steps where step % policy_update_freq == 0.
    policy_update_freq: int
    # Entropy target passed to both temperature updates.
    target_entropy: float
    # Stored for reference; not read by the update code in this module.
    dropout_rate: float
    # k_samples for the exploration actor's update (the task actor uses 1).
    actor_k_samples: int
    # k_samples for the exploration critic's update_critic_max (task critic uses 1).
    critic_k_samples: int
    # Exploration schedule used by sample_actions: "annealing", "fixed" or
    # "half"; any other value leaves the temperature unchanged.
    set_explore_type: str


class SACBoxd2Learner(object):
    """SAC learner that trains a task agent and an exploration agent together.

    Each agent has an actor, a critic, a target critic and a temperature.
    Attributes suffixed ``_explore`` belong to the exploration agent. All
    states are created by a vmapped initialiser, so every one of them carries
    a leading axis of size ``num_parallel_seeds`` (one entry per seed).
    """

    def __init__(
        self,
        seed: int,
        # env settings
        state_dim: int,
        action_dim: int,
        # common RL settings
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        hidden_dims: Sequence[int] = (256, 256),
        discount: float = 0.99,
        tau: float = 0.005,
        num_parallel_seeds: int = 1,
        max_steps: int = 3e6,
        # sac
        policy_update_freq: int = 1,
        target_entropy: Optional[float] = None,
        init_temperature: float = 1.0,
        temp_lr: float = 3e-4,
        # droq
        dropout_rate: float = 0.001,
        layernorm: bool = True,
        # mc
        actor_k_samples: int = 1,
        critic_k_samples: int = 1,
        set_explore_type: str = "annealing",
    ):
        """Build the optimisers and networks for every parallel seed.

        Seeds ``seed, seed + 1, ..., seed + num_parallel_seeds - 1`` each get
        their own initialisation. ``target_entropy`` defaults to
        ``-action_dim / 2``. ``max_steps`` is only stored on the instance.
        """
        self.seeds = jnp.arange(seed, seed + num_parallel_seeds)
        self.tau = tau
        self.discount = discount
        self.max_steps = max_steps
        target_entropy = (
            -action_dim / 2 if target_entropy is None else target_entropy
        )

        # Positional construction: the order must match ConfigArgs' fields.
        self.args = ConfigArgs(
            policy_update_freq,
            target_entropy,
            dropout_rate,
            actor_k_samples,
            critic_k_samples,
            set_explore_type,
        )

        dummy_observations = jnp.ones([1, state_dim], dtype=jnp.float32)
        dummy_actions = jnp.ones([1, action_dim], dtype=jnp.float32)

        def _init_models(seed):
            # Initialise every TrainState for a single seed (vmapped below).
            rng = jax.random.PRNGKey(seed)
            rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4)

            # optimisers
            critic_optimiser = optax.adam(learning_rate=critic_lr)
            actor_optimiser = optax.adam(learning_rate=actor_lr)
            temp_optimiser = optax.adam(learning_rate=temp_lr)

            # actors: same definition and same key, so the exploration and
            # task actors start from identical parameters.
            actor_def = NormalTanhPolicy(hidden_dims, action_dim)
            actor_explore = TrainState.create(
                apply_fn=actor_def.apply,
                params=actor_def.init(actor_key, dummy_observations),
                tx=actor_optimiser,
            )
            actor = TrainState.create(
                apply_fn=actor_def.apply,
                params=actor_def.init(actor_key, dummy_observations),
                tx=actor_optimiser,
            )

            # critics: all four share critic_key, so each target critic
            # starts equal to its online critic.
            critic_def_explore = DoubleCritic(
                hidden_dims, layernorm=layernorm, dropout_rate=dropout_rate
            )
            critic_explore = TrainState.create(
                apply_fn=critic_def_explore.apply,
                params=critic_def_explore.init(
                    critic_key, dummy_observations, dummy_actions
                ),
                tx=critic_optimiser,
            )
            target_critic_explore = TrainState.create(
                apply_fn=critic_def_explore.apply,
                params=critic_def_explore.init(
                    critic_key, dummy_observations, dummy_actions
                ),
                tx=critic_optimiser,
            )
            critic_def = DoubleCritic(
                hidden_dims, layernorm=layernorm, dropout_rate=dropout_rate
            )
            critic = TrainState.create(
                apply_fn=critic_def.apply,
                params=critic_def.init(
                    critic_key, dummy_observations, dummy_actions
                ),
                tx=critic_optimiser,
            )
            target_critic = TrainState.create(
                apply_fn=critic_def.apply,
                params=critic_def.init(
                    critic_key, dummy_observations, dummy_actions
                ),
                tx=critic_optimiser,
            )

            # temperature
            temp_def = Temperature(init_temperature)
            temp_explore = TrainState.create(
                apply_fn=temp_def.apply,
                params=temp_def.init(temp_key),
                tx=temp_optimiser,
            )
            temp = TrainState.create(
                apply_fn=temp_def.apply,
                params=temp_def.init(temp_key),
                tx=temp_optimiser,
            )

            return (
                actor_explore,
                actor,
                critic_explore,
                target_critic_explore,
                critic,
                target_critic,
                temp_explore,
                temp,
                rng,
            )

        self.init_models = jax.jit(jax.vmap(_init_models))

        (
            self.actor_explore,
            self.actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.temp_explore,
            self.temp,
            self.rng,
        ) = self.init_models(self.seeds)
        # NOTE(review): not read in this module; presumably consumed by
        # external code — verify before changing.
        self.trainable_models = ["actor", "critic"]
        self.step = 0

    def sample_actions(
        self, observations: np.ndarray, temperature: float = 1.0
    ) -> jax.Array:
        """Sample actions from the exploration or the task actor.

        If ``temperature > 0``, the ``set_explore_type`` schedule may reset it
        to 0:
          - "annealing" uses a step-dependent threshold;
          - "fixed" uses threshold 0.1;
          - "half" uses threshold 0.5;
          - any other value leaves it unchanged.
        A positive final temperature samples from ``actor_explore``. Otherwise
        the task ``actor`` is used. Actions are returned as a NumPy array
        clipped to [-1, 1].
        """
        if temperature > 0:
            if self.args.set_explore_type == "annealing":
                self.rng, temperature = annealing_set_temp(
                    self.rng, self.step, temperature
                )
            elif self.args.set_explore_type == "fixed":
                self.rng, temperature = random_set_temp(
                    self.rng, 0.1, temperature
                )
            elif self.args.set_explore_type == "half":
                self.rng, temperature = random_set_temp(
                    self.rng, 0.5, temperature
                )

        policy = self.actor_explore if temperature > 0 else self.actor
        self.rng, actions = sample_actions(
            self.rng,
            policy.apply_fn,
            policy.params,
            observations,
            temperature,
            distribution="log_prob",
        )
        return np.clip(np.asarray(actions), -1, 1)

    def update(self, batch: Batch, num_updates: int = 1) -> InfoDict:
        """Run ``num_updates`` jitted training steps and return the last diagnostics.

        ``batch`` is indexed along axis 1 by update index inside
        ``_do_multiple_updates``. ``self.step`` advances by the number of
        updates performed.
        """
        (
            self.rng,
            self.step,
            self.actor_explore,
            self.actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.temp_explore,
            self.temp,
            info,
        ) = _do_multiple_updates(
            self.rng,
            self.step,
            self.actor_explore,
            self.actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.temp_explore,
            self.temp,
            batch,
            self.discount,
            self.tau,
            num_updates,
            self.args,
        )
        return info

    def reset(self):
        """Re-initialise every network, temperature and the PRNG keys from the seeds.

        ``self.step`` is not reset, so step-based schedules (e.g. annealing)
        continue from the current step.
        """
        # init_models returns 9 values; all of them (including self.temp)
        # must be unpacked, otherwise the assignment raises ValueError.
        (
            self.actor_explore,
            self.actor,
            self.critic_explore,
            self.target_critic_explore,
            self.critic,
            self.target_critic,
            self.temp_explore,
            self.temp,
            self.rng,
        ) = self.init_models(self.seeds)

    def apply_once(self, batch: Batch):
        raise NotImplementedError
