"""Implementations of algorithms for continuous control."""

import copy
import functools
from typing import Dict, Optional, Sequence, Tuple

import gym
import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict
from flax.training.train_state import TrainState
from jaxOfflineRL.agents.common import eval_actions_jit

from jaxOfflineRL.networks import DeterministicPolicy
from jaxOfflineRL.types import Params, PRNGKey
from jaxOfflineRL.networks.values import ActionValue, ActionValueEnsemble
from jaxOfflineRL.reward_models.action_reward.action_reward_updater import update_action_reward

# NOTE(review): ``mean_rewards`` is traced by ``jit``, not static. If
# ``update_action_reward`` needs it as a concrete Python value (e.g. for
# ``if`` branching), decorate with
# ``functools.partial(jax.jit, static_argnames=("mean_rewards",))`` and pass a
# hashable scalar. Note that static values trigger a recompile for every
# distinct value.
@jax.jit
def _update_jit(
    rng: PRNGKey,
    rm: TrainState,
    batch: FrozenDict,
    mean_rewards: float,
) -> Tuple[PRNGKey, TrainState, Dict[str, float]]:
    """Run one jitted reward-model update step.

    Args:
        rng: PRNG key. It is split once: one half is consumed by the update,
            and the other half is returned for the next call.
        rm: Reward-model train state to update.
        batch: Training batch, forwarded unchanged to ``update_action_reward``.
        mean_rewards: Forwarded unchanged to ``update_action_reward``.

    Returns:
        ``(rng, new_rm, info)``: the advanced PRNG key, the updated train
        state, and the info mapping from ``update_action_reward`` re-packed
        as a plain dict.
    """
    key, rng = jax.random.split(rng)

    new_rm, info = update_action_reward(key, rm, batch, mean_rewards)

    return rng, new_rm, {**info}


class ActionRewardLearner:
    """Train an action-conditioned reward model (``ActionValueEnsemble``).

    The model parameters and Adam optimizer state are held in a Flax
    ``TrainState``. Each call to :meth:`update` runs one jitted step via
    ``_update_jit`` and advances the learner's PRNG key.
    """

    def __init__(
        self,
        seed: int,
        observation_space: gym.Space,
        action_space: gym.Space,
        hidden_dims: Sequence[int] = (256, 256),
        rm_lr: float = 3e-4,
    ):
        """Build the reward model and its optimizer state.

        Args:
            seed: Seed for the JAX PRNG used for parameter init and updates.
            observation_space: Accepted for interface compatibility. The
                reward model is initialised from actions only, so this space
                is not read.
            action_space: Its ``sample()`` provides the dummy input used to
                initialise the model parameters.
            hidden_dims: Forwarded to ``ActionValueEnsemble`` (presumably
                the hidden-layer widths).
            rm_lr: Adam learning rate for the reward model.
        """
        # The reward model is initialised on actions alone, so a single
        # sampled action is enough to fix the parameter shapes.
        actions = action_space.sample()

        rng = jax.random.PRNGKey(seed)
        rng, rm_key = jax.random.split(rng, 2)

        rm_def = ActionValueEnsemble(hidden_dims)
        rm_params = rm_def.init(rm_key, actions)["params"]
        rm = TrainState.create(
            apply_fn=rm_def.apply,
            params=rm_params,
            tx=optax.adam(learning_rate=rm_lr),
        )

        self._rng = rng
        self._rm = rm

    def update(self, batch: FrozenDict, mean_rewards: float) -> Dict[str, float]:
        """Run one reward-model update step and store the new state.

        Args:
            batch: Training batch forwarded to the jitted update.
            mean_rewards: Forwarded unchanged to the jitted update.

        Returns:
            The info dict produced by the update step.
        """
        new_rng, new_rm, info = _update_jit(
            self._rng,
            self._rm,
            batch,
            mean_rewards,
        )
        self._rng = new_rng
        self._rm = new_rm

        return info