from jax import random #,grad
import jax.numpy as jnp
from jax.config import config
config.update("jax_enable_x64", True)
import flax
import optax
import numpy as np


# A simplified JAX implementation based on https://github.com/hardmaru/estool/blob/master/es.py, which is itself based
# on the work of https://github.com/openai/evolution-strategies-starter

class OpenES:
  def __init__(self, num_params,             # number of model parameters
               sigma_init=0.5,               # initial standard deviation of the sampling noise
               popsize=10,                   # population size
               learning_rate=0.1,            # Adam learning rate for the mean mu
               seed=0,                       # seed for the JAX PRNG key
               sigma_decay=0.999,            # multiplicative sigma decay applied per tell()
               sigma_limit=0.005,            # sigma stops decaying once it is <= this value
              ):
    """Create an OpenAI-style evolution strategy with an isotropic Gaussian.

    The search distribution is N(mu, sigma^2 I), with ``mu`` starting at zero.
    ``tell`` moves ``mu`` with Adam (b1=0.99) and multiplies ``sigma`` by
    ``sigma_decay`` as long as it is still above ``sigma_limit``.
    """
    self.sigma_decay = sigma_decay
    self.sigma_limit = sigma_limit
    self.num_params = num_params
    self.sigma = sigma_init
    self.popsize = popsize
    self.mu = jnp.zeros(self.num_params, dtype="float64")
    # (sic) attribute name kept as-is for compatibility with existing callers.
    self.first_interation = True
    self.learning_rate = learning_rate

    # optax minimizes; tell() feeds it the negated ascent direction.
    self.optimizer = optax.adam(learning_rate=self.learning_rate, b1=0.99)
    self.state = self.optimizer.init(self.mu)

    # Derive the PRNG key from the caller's seed so different seeds give
    # different sample streams.
    self.rnd_key = random.PRNGKey(seed)

    # Not read by ask()/tell(); kept for callers that may use it.
    self.step = 0
  
  def ask(self):
    """Sample a new population of candidate parameter vectors.

    Draws ``popsize`` Gaussian perturbations of shape
    ``(popsize, num_params)``, scales them by ``sigma`` and adds them to the
    current mean ``mu``. The perturbations are kept in ``self.epsilon`` and the
    candidates in ``self.solutions`` so that the next ``tell`` call can use them.

    Returns:
      Array of shape ``(popsize, num_params)``, one candidate per row.
    """
    # Split before sampling: the stored key is only ever split, and a fresh
    # subkey is consumed by the draw, so no key is used twice.
    self.rnd_key, sample_key = random.split(self.rnd_key)
    self.epsilon = random.normal(
        sample_key, (self.popsize, self.num_params), dtype="float64")
    self.solutions = self.mu.reshape(1, self.num_params) + self.epsilon * self.sigma
    return self.solutions

  def tell(self, reward):
    """Update the search distribution from the rewards of the last population.

    Args:
      reward: one reward per row of the most recent ``ask()`` result, higher
        is better. Any 1-D array-like (Python list, NumPy or JAX array) of
        length ``popsize`` is accepted.

    Raises:
      ValueError: if ``reward`` is not 1-D with one entry per sampled solution.

    Side effects:
      * ``curr_best_reward`` / ``curr_best_mu``: best reward and solution of
        this generation (the first one on ties).
      * ``best_reward`` / ``best_mu``: best reward and solution seen so far.
      * ``mu`` and the optimizer ``state`` take one Adam step along the
        estimated reward gradient.
      * ``sigma`` is multiplied by ``sigma_decay`` while above ``sigma_limit``.
    """
    reward = jnp.asarray(reward)
    n_solutions = self.epsilon.shape[0]
    if reward.ndim != 1 or reward.shape[0] != n_solutions:
      raise ValueError(
          f"expected {n_solutions} rewards (one per solution from ask()), "
          f"got array of shape {reward.shape}")

    # Track the best candidate of this generation and of the whole run.
    best_idx = jnp.argmax(reward)
    self.curr_best_reward = reward[best_idx]
    self.curr_best_mu = self.solutions[best_idx]

    # Short-circuit: on the first call best_reward does not exist yet.
    if self.first_interation or self.curr_best_reward > self.best_reward:
      self.first_interation = False
      self.best_reward = self.curr_best_reward
      self.best_mu = self.curr_best_mu

    # Standardize rewards (zero mean, unit variance). When all rewards are
    # equal, std is exactly 0 or only floating-point round-off; dividing by it
    # would yield NaN or amplified noise and corrupt mu and the Adam state, so
    # such a generation contributes a zero gradient instead (Adam's momentum
    # may still move mu).
    centered = reward - jnp.mean(reward)
    std = jnp.std(reward)
    round_off = reward.size * jnp.finfo(std.dtype).eps * jnp.max(jnp.abs(reward))
    degenerate = std <= round_off
    normalized_reward = jnp.where(
        degenerate, 0.0, centered / jnp.where(degenerate, 1.0, std))

    # ES estimate of the reward gradient w.r.t. mu:
    # (1 / (popsize * sigma)) * sum_i epsilon_i * normalized_reward_i
    change_mu = jnp.dot(self.epsilon.T, normalized_reward) / (self.popsize * self.sigma)

    # optax minimizes, so pass the negated ascent direction as the "gradient".
    updates, self.state = self.optimizer.update(-change_mu, self.state, self.mu)
    self.mu = optax.apply_updates(self.mu, updates)

    if self.sigma > self.sigma_limit:
      self.sigma *= self.sigma_decay