import jax
import jax.numpy as jnp
import chex
from typing import Tuple
from ..strategy import Strategy
from flax import struct


@struct.dataclass
class EvoState:
    """State carried between the `ask` and `tell` calls of SoES."""

    # Current mean vector. SoES.initialize_strategy sets it to zeros of shape
    # (num_dims,); SoES.tell_strategy overwrites it with the fitness-weighted
    # noise vector.
    mean: chex.Array
    # Scale applied to the standard-normal samples drawn in SoES.ask_strategy.
    sigma: float
    # Best candidate seen so far. NOTE(review): not updated anywhere in this
    # file; presumably maintained by the base `Strategy` — confirm.
    best_member: chex.Array
    # Starts at the largest finite float32, which suggests lower fitness is
    # treated as better — confirm against the base `Strategy`.
    best_fitness: float = jnp.finfo(jnp.float32).max
    # Generation index. Not incremented in this file; presumably advanced by
    # the base `Strategy` — confirm.
    gen_counter: int = 0


@struct.dataclass
class EvoParams:
    """Hyperparameters for SoES."""

    # Initial noise scale; SoES.initialize_strategy copies it into EvoState.sigma.
    sigma_init: float = 0.04

class SoES(Strategy):
    """Gaussian-noise strategy whose `tell` sets the mean to ``x @ fitness``.

    `ask` draws an i.i.d. standard-normal matrix of shape
    ``(part_size, popsize)`` scaled by ``state.sigma``. `tell` contracts that
    matrix with the fitness vector, drops the last ``padding`` entries, and
    stores the result as the new mean.

    Args:
        num_dims: Dimensionality forwarded to the base `Strategy`; also the
            length of the initial mean.
        popsize: Number of candidates per generation (columns of the sample).
        part_size: Number of rows in each sampled matrix.
        padding: Number of trailing entries removed from the `tell` result.
            Must be non-negative.

    Raises:
        ValueError: If ``padding`` is negative.
    """

    def __init__(self, num_dims: int, popsize: int, part_size: int, padding: int = 0):
        super().__init__(num_dims, popsize)
        if padding < 0:
            # A negative value used to be silently treated as 0 by `tell`;
            # reject it up front instead of hiding the misconfiguration.
            raise ValueError(f"padding must be non-negative, got {padding}")
        self.strategy_name = "SoES"
        self.part_size = part_size
        self.padding = padding

    @property
    def params_strategy(self) -> EvoParams:
        """Return default parameters of evolution strategy."""
        return EvoParams()

    @staticmethod
    def update_params(params: EvoParams, args) -> EvoParams:
        """Return a copy of `params` with matching attributes of `args` applied.

        Only attributes of ``args`` whose names are fields of ``params`` are
        used; all other attributes of ``args`` are ignored.
        """
        overrides = {
            key: value
            for key, value in vars(args).items()
            if key in vars(params)
        }
        return params.replace(**overrides)

    def initialize_strategy(
            self, rng: chex.PRNGKey, params: EvoParams
    ) -> EvoState:
        """`initialize` the evolution strategy.

        The mean and best member both start as zeros of shape ``(num_dims,)``,
        and sigma starts at ``params.sigma_init``. ``rng`` is unused.
        """
        initialization = jnp.zeros((self.num_dims,))
        return EvoState(
            mean=initialization,
            sigma=params.sigma_init,
            best_member=initialization,
        )

    def ask_strategy(
            self, rng: chex.PRNGKey, state: EvoState, params: EvoParams
    ) -> Tuple[chex.Array, EvoState]:
        """`ask` for new parameter candidates to evaluate next.

        Returns a ``(part_size, popsize)`` matrix of standard-normal noise
        scaled by ``state.sigma``, together with the unchanged state. The
        popsize axis is the one `tell_strategy` contracts against fitness.
        """
        # Plain (non-antithetic) Gaussian sampling.
        # NOTE(review): samples are not offset by state.mean — confirm intended.
        noise = jax.random.normal(rng, (self.part_size, self.popsize))
        return state.sigma * noise, state

    def tell_strategy(
            self,
            x: chex.Array,
            fitness: chex.Array,
            state: EvoState,
            params: EvoParams,
    ) -> EvoState:
        """`tell` performance data for strategy state update.

        The new mean is ``x @ fitness``, flattened, with the last
        ``self.padding`` entries dropped. If ``x`` is the
        ``(part_size, popsize)`` sample from `ask_strategy`, this is the
        fitness-weighted sum of its columns. ``state.sigma`` is unchanged.
        """
        # Fitness acts as per-member weights on the sampled noise.
        weighted = jnp.dot(x, fitness).reshape(-1)
        # A `[:-0]` slice would be empty, so only trim when padding is positive.
        if self.padding > 0:
            weighted = weighted[: -self.padding]
        # NOTE(review): the mean now has length part_size - padding (for 1-D
        # fitness), while initialize_strategy creates it with length num_dims;
        # these presumably must agree for shape-stable jit/scan — confirm.
        return state.replace(mean=weighted)