from typing import Tuple
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training.train_state import TrainState

from src.common import InfoDict


class Temperature(nn.Module):
    """Learnable scalar temperature, parameterised in log space.

    The module owns a single 0-d parameter named ``log_temp``. Calling the
    module returns ``exp(log_temp)``, so the returned temperature is always
    strictly positive regardless of what value the optimiser drives
    ``log_temp`` to.

    Attributes:
        initial_temperature: temperature returned right after initialisation;
            ``log_temp`` starts at ``log(initial_temperature)``. It should be
            > 0, since the log of a non-positive value is not finite.
    """

    initial_temperature: float = 1.0

    @nn.compact
    def __call__(self) -> jax.Array:
        """Return the current temperature ``exp(log_temp)`` as a 0-d array."""

        def _init_log_temp(key):
            # Initialisation is deterministic, so the PRNG key is ignored.
            del key
            return jnp.full((), jnp.log(self.initial_temperature))

        log_temperature = self.param("log_temp", init_fn=_init_log_temp)
        return jnp.exp(log_temperature)


def update(
    temp: TrainState, entropy: float, target_entropy: float
) -> Tuple[TrainState, InfoDict]:
    """Apply one gradient step to the temperature parameters.

    The loss is ``temperature * mean(entropy - target_entropy)``. Its gradient
    w.r.t. the temperature has the sign of ``entropy - target_entropy``, so a
    gradient-descent optimiser lowers the temperature when the entropy is
    above target and raises it when the entropy is below target.

    Args:
        temp: train state holding the temperature parameters; ``apply_fn``
            is called with ``temp.params`` and is expected to return the
            scalar temperature (NOTE(review): presumably a ``Temperature``
            module's ``apply`` — confirm against the caller).
        entropy: measured entropy, either a scalar or an array of values;
            the difference to the target is averaged.
        target_entropy: desired entropy level.

    Returns:
        ``(new_temp, info)``: ``new_temp`` is ``temp`` after
        ``apply_gradients``; ``info`` has keys ``"temperature"`` and
        ``"temp_loss"``, both evaluated at the pre-update parameters.
    """

    def temperature_loss_fn(temp_params):
        temperature = temp.apply_fn(temp_params)
        # Use jnp.mean instead of the ``.mean()`` method so that a plain
        # Python float ``entropy`` (as the annotation allows) works too;
        # for array inputs the result is identical.
        temp_loss = temperature * jnp.mean(entropy - target_entropy)
        info = {"temperature": temperature, "temp_loss": temp_loss}

        return temp_loss, info

    grads, info = jax.grad(temperature_loss_fn, has_aux=True)(temp.params)
    new_temp = temp.apply_gradients(grads=grads)

    return new_temp, info
