import jax
import jax.numpy as jnp
import chex
from typing import Tuple
from ..strategy import Strategy
from ..utils import GradientOptimizer, OptState, OptParams
from flax import struct


@struct.dataclass
class EvoState:
    """Per-generation state of the ``SoGradES`` strategy."""

    # Current search mean; zeros at initialization and overwritten by
    # SoGradES.tell_strategy.
    mean: chex.Array
    # Scale multiplied into the Gaussian perturbations in ask_strategy;
    # initialized from EvoParams.sigma_init.
    sigma: float
    # Best candidate seen so far; initialized to zeros. Presumably maintained
    # by the base Strategy class — TODO confirm.
    best_member: chex.Array
    # Starts at the largest float32 value, which suggests fitness is minimized
    # — NOTE(review): confirm against the base Strategy.
    best_fitness: float = jnp.finfo(jnp.float32).max
    # Generation counter; not updated by code in this module (presumably by
    # the base Strategy — TODO confirm).
    gen_counter: int = 0
    # Fraction of population columns whose random perturbations are kept in
    # ask_strategy; multiplicatively decayed (bounded below by min_e) in
    # tell_strategy.
    e: float = 1.0


@struct.dataclass
class EvoParams:
    """Hyper-parameters of the ``SoGradES`` strategy.

    Every entry is a real dataclass field, so each can be overridden through
    the constructor, ``replace`` or ``SoGradES.update_params``.
    """

    # Initial value of EvoState.sigma.
    sigma_init: float = 0.04
    # Multiplicative decay applied to EvoState.e in tell_strategy.
    e_decay: float = 0.99
    # Lower bound for the decayed EvoState.e. Declared as a static
    # (pytree_node=False) field so it stays a concrete Python value under
    # jax.jit; a bare, unannotated class attribute would not be a dataclass
    # field at all and could not be overridden.
    min_e: float = struct.field(pytree_node=False, default=0.1)
    # Not read by the strategy code in this module — TODO confirm usage.
    # Static so that, if used as a shape, it remains a concrete int under jit.
    grad_popsize: int = struct.field(pytree_node=False, default=10)


class SoGradES(Strategy):
    """Evolution strategy mixing random perturbations with buffered gradients.

    ``ask_strategy`` draws Gaussian perturbations scaled by ``state.sigma`` and
    keeps only a fraction ``state.e`` of the population columns. Gradients
    registered via ``update_grads`` are stored in a fixed-capacity buffer and
    returned alongside the perturbations. ``tell_strategy`` then forms the new
    mean from both fitness-weighted sums.

    The gradient buffer (``self.grads``) and its fill level (``self.cursor``)
    are ordinary Python attributes mutated in place. They live outside the
    functional ``EvoState``.
    """

    def __init__(
        self,
        num_dims: int,
        popsize: int,
        part_size: int,
        padding: int = 0,
        grad_buffer_size: int = 100,
    ):
        """Create the strategy.

        Args:
            num_dims: Dimensionality forwarded to the base ``Strategy``; also
                the width of each buffered gradient row.
            popsize: Population size; must be even.
            part_size: Number of rows of noise sampled per ask (leading axis
                of the perturbation matrix).
            padding: Number of trailing entries dropped from the mean computed
                in ``tell_strategy``.
            grad_buffer_size: Maximum number of gradients kept by
                ``update_grads`` (defaults to the former hard-coded 100).
        """
        super().__init__(num_dims, popsize)
        assert not self.popsize & 1, "Population size must be even"
        self.strategy_name = "SoGradES"
        self.grads = jnp.zeros((grad_buffer_size, self.num_dims))
        self.cursor = 0
        self.part_size = part_size
        self.padding = padding

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy."""
        return EvoParams()

    @staticmethod
    def update_params(params: EvoParams, args) -> EvoParams:
        """Update parameters of evolution strategy.

        Each attribute of ``args`` whose name also appears in
        ``params.__dict__`` overrides that entry; all other attributes of
        ``args`` are ignored.
        """
        overrides = {
            key: value
            for key, value in vars(args).items()
            if key in vars(params)
        }
        return params.replace(**overrides) if overrides else params

    def initialize_strategy(
            self, rng: chex.PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the evolution strategy.

        The mean and best member start at the origin and ``sigma`` at
        ``params.sigma_init``. ``rng`` is not used.
        """
        origin = jnp.zeros((self.num_dims,))
        return EvoState(
            mean=origin,
            sigma=params.sigma_init,
            best_member=origin,
        )

    def update_grads(self, grad):
        """Append ``grad`` as the next row of the gradient buffer.

        Once the buffer is full, further gradients are discarded. It is not a
        ring buffer.

        Args:
            grad: Gradient to store; must fit one buffer row, presumably of
                shape ``(num_dims,)`` — TODO confirm with callers.
        """
        if self.cursor >= self.grads.shape[0]:
            # Buffer full: drop the gradient. This matches the previous
            # behaviour; the unreachable cursor reset after ``return`` was
            # removed.
            return
        self.grads = self.grads.at[self.cursor].set(grad)
        self.cursor += 1

    def ask_strategy(
            self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, chex.Array, EvoState]:
        """`ask` for new parameter candidates to evaluate next.

        Returns:
            x: Perturbations of shape ``(part_size, popsize)`` drawn from a
               standard normal and scaled by ``state.sigma``. Columns with
               index >= ``state.e * popsize`` are zeroed.
            g: The whole gradient buffer transposed, i.e. shape
               ``(num_dims, buffer_size)``. Rows not yet filled remain zero.
            state: The unchanged strategy state.
        """
        # Plain (non-antithetic) Gaussian sampling.
        noise = jax.random.normal(rng, (self.part_size, self.popsize))
        x = state.sigma * noise
        # Keep only the first ``e * popsize`` population columns.
        keep_columns = jnp.arange(self.popsize) < state.e * x.shape[1]
        x = jnp.where(keep_columns, x, 0)
        g = self.grads.T
        return x, g, state

    def tell_strategy(
            self,
            x: chex.Array,
            g: chex.Array,
            fitness: chex.Array,
            g_fitness: chex.Array,
            state: EvoState,
            params: EvoParams,
    ) -> EvoState:
        """`tell` performance data for strategy state update.

        The new mean is ``x @ fitness + g @ g_fitness`` (each flattened) with
        the last ``padding`` entries removed. It replaces, rather than
        increments, the previous mean.

        Args:
            x: Perturbation matrix returned by ``ask_strategy``.
            g: Transposed gradient buffer returned by ``ask_strategy``.
            fitness: Weights applied to the columns of ``x``.
            g_fitness: Weights applied to the columns of ``g``.
            state: Current strategy state.
            params: Strategy hyper-parameters (``e_decay``, ``min_e``).

        Returns:
            The state with the new ``mean`` and possibly a decayed ``e``.
        """
        # ``[:-0]`` would yield an empty array, so padding == 0 means "keep all".
        unpadded = slice(None, -self.padding if self.padding > 0 else None)
        mean = (
            jnp.dot(x, fitness).reshape(-1)[unpadded]
            + jnp.dot(g, g_fitness).reshape(-1)[unpadded]
        )
        # NOTE(review): ``e`` decays only on calls where exactly one gradient
        # has been buffered; confirm this is the intended schedule.
        if self.cursor == 1:
            # jnp.maximum rather than a Python ``if`` so the clamp also works
            # when ``e`` / ``e_decay`` are traced values under jax.jit.
            new_e = jnp.maximum(state.e * params.e_decay, params.min_e)
            state = state.replace(e=new_e)
        return state.replace(mean=mean)