from typing import Dict, Tuple

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.training.train_state import TrainState

from jaxOfflineRL.types import Params, PRNGKey


def update_state_action_reward(
    key: PRNGKey,
    critic: TrainState,
    batch: FrozenDict,
    mean_rewards: float,
) -> Tuple[TrainState, Dict[str, float]]:

    def rm_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, Dict[str, float]]:
        rewards = critic.apply_fn(
            {"params": critic_params},
            batch["observations"],
            batch["actions"],
            training=True,
            rngs={"dropout": key},
        )
        # print(rewards.shape)
        critic_loss = ((rewards - batch["rewards"])**2).mean()

        log_dict = {}
        log_dict["action_reward_loss"] = critic_loss

        log_dict["R_divides_r"] = (rewards/batch["rewards"]).mean()
        log_dict["R_divides_meanR"] = (rewards/mean_rewards).mean()
        log_dict["var"] = rewards.var(axis=0).mean()

        return critic_loss, log_dict

    grads, info = jax.grad(rm_loss_fn, has_aux=True)(critic.params)
    new_critic = critic.apply_gradients(grads=grads)

    return new_critic, info
