from jax import numpy as jnp
from jax import jit, value_and_grad
from jax import random
import optax


class GreedyAgent:
    """
    Greedily maximize the reward signal, assuming we know the environment's dynamics.

    On every call to :meth:`policy` the agent runs a short Adam optimization of
    ``-env.check_step_reward(state, action)`` over the action and returns the
    resulting action.
    """

    def __init__(self, env, use_jit=True, learning_rate=1e-1):
        """
        :param env: environment exposing ``check_step_reward(state, action)``; it must
            be differentiable w.r.t. ``action`` because the step uses ``value_and_grad``.
        :param use_jit: whether to wrap the optimization step with ``jax.jit``.
        :param learning_rate: Adam learning rate (defaults to the original ``1e-1``).
        """
        self.env = env
        self.optimizer, self.optax_step = self.initialize_optimizer(
            use_jit, learning_rate=learning_rate
        )

    def initialize_optimizer(self, use_jit, learning_rate=1e-1):
        """
        Initialize the loss function and the (optionally jitted) gradient step function.

        The step function closes over ``optimizer`` and ``self.env``, so it should be
        created only once per agent rather than on every call to :meth:`policy`.

        :param use_jit: wrap the step function with ``jax.jit``.
        :param learning_rate: Adam learning rate.
        :return: ``(optimizer, optax_step)`` where
            ``optax_step(params, opt_state, current_state)`` returns
            ``(new_params, new_opt_state, loss_value, updates)``.
        """
        optimizer = optax.adam(learning_rate)

        def loss_function(params: optax.Params, state):
            """params should have the action phi in [0, 1] per group"""
            action = params["threshold_action"]
            # Minimizing the negative reward == maximizing the reward.
            return -self.env.check_step_reward(state, action)

        # Returning a closure with optimizer. Roughly equivalent to an implicit static arg.
        def optax_step(params_, opt_state_, current_state):
            loss_value, grads = value_and_grad(loss_function, argnums=0)(
                params_, current_state
            )
            updates, opt_state_ = optimizer.update(grads, opt_state_, params_)
            params_ = optax.apply_updates(params_, updates)
            return params_, opt_state_, loss_value, updates

        if use_jit:
            optax_step = jit(optax_step)

        return optimizer, optax_step

    def policy(self, observation, info, seed=None, max_iters=100, tol=1e-10):
        """
        Choose an action by gradient descent on the negative one-step reward.

        :param observation: unused; the state is read from ``info["current_state"]``.
        :param info: mapping with ``"current_state"`` and, optionally, ``"prev_action"``.
            When ``seed`` is None and ``prev_action`` is present and not None, it is
            used as the warm start; otherwise the start is ``[0.5, 0.5]``.
        :param seed: if given, start from a uniform random action in ``[0, 1)`` of
            shape ``(2,)`` drawn with ``jax.random.PRNGKey(seed)``.
        :param max_iters: maximum number of optimizer steps.
        :param tol: stop early once the L2 norm of the optimizer update is ``<= tol``.
        :return: the optimized action, clipped to ``[0, 0.999999]``.
        """
        # Initial policy
        if seed is not None:
            init_action = random.uniform(random.PRNGKey(seed), shape=(2,))
        else:
            prev_action = info.get("prev_action")
            init_action = jnp.array([0.5, 0.5]) if prev_action is None else prev_action

        # A plain list would become a pytree of Python scalars, and value_and_grad
        # rejects integer inputs, so normalize the start to a floating JAX array.
        init_action = jnp.asarray(init_action)
        if not jnp.issubdtype(init_action.dtype, jnp.floating):
            init_action = init_action.astype(float)

        params = {"threshold_action": init_action}
        opt_state = self.optimizer.init(params)
        current_state = info["current_state"]

        for _ in range(max_iters):
            params, opt_state, _loss, updates = self.optax_step(
                params, opt_state, current_state
            )
            if jnp.linalg.norm(updates["threshold_action"]) <= tol:
                break

        # NOTE(review): the optimization is unconstrained, so the iterate can leave
        # [0, 1]; the returned action is clipped back into range. The 0.999999 upper
        # bound presumably keeps the env away from a singularity at exactly 1 --
        # confirm against env.check_step_reward.
        return jnp.clip(params["threshold_action"], 0, 0.999999)
