from collections.abc import Iterable

import numpy as np
from numpy.random import RandomState

from .misc import (
    make_vector,
    non_negative_float,
    positive_float,
    positive_float_or_none,
    positive_int,
    readonly_view,
)
from .optimizers import Adam


class PGPE:
    """PGPE algorithm from:
    "Policy Gradients with Parameter-based Exploration for Control"
    Sehnke et al., 2010

    Implementation adapted from:
    https://github.com/nnaisense/pgpelib
    """

    def __init__(
        self,
        *,
        solution_length: int,
        popsize: int,
        symmetric_sampling: bool = True,
        update_type: str = "reinforce",
        normalize_fitness: bool = True,
        center_learning_rate: float = 0.15,
        stdev_learning_rate: float = 0.1,
        stdev_init: float | Iterable[float] = 0.1,
        stdev_max_change: float | None = 0.2,
        use_lr_scheduler: bool = True,
        max_generations: int = 1000,
        center_init: np.ndarray | None = None,
        optimizer_config: dict | None = None,
        seed: int | None = None,
        dtype: np.dtype | str = "float32",
    ):
        """Initializes the PGPE instance."""
        # ... (most of the __init__ code is the same) ...
        self._length = positive_int(solution_length)
        self._popsize = positive_int(popsize)
        self._symmetric_sampling = bool(symmetric_sampling)
        self._normalize_fitness = bool(normalize_fitness)

        if self._symmetric_sampling and self._popsize % 2 != 0:
            raise ValueError("For symmetric sampling, popsize must be even.")

        if update_type.lower() not in ("reinforce", "natural"):
            raise ValueError("update_type must be 'reinforce' or 'natural'.")
        self._update_type = update_type.lower()

        self._initial_center_lr = positive_float(center_learning_rate)
        self._initial_stdev_lr = non_negative_float(stdev_learning_rate)
        self._stdev_lr = self._initial_stdev_lr

        self._logstd = np.log(make_vector(stdev_init, self._length, np.dtype(dtype)))
        self._stdev_max_change = positive_float_or_none(stdev_max_change)

        self._optimizer = Adam(
            stepsize=self._initial_center_lr,
            solution_length=self._length,
            dtype=dtype,
            **optimizer_config if optimizer_config else {},
        )
        self._center = make_vector(0.0, self._length, np.dtype(dtype))
        if center_init is not None:
            self._center[:] = center_init

        self._use_lr_scheduler = bool(use_lr_scheduler)
        self._max_generations = positive_int(max_generations)
        self._generation_count = 0

        self._running_mean = 0
        self._running_var = 1

        self._rndgen = RandomState(seed)
        self._noises: np.ndarray
        self._solutions: np.ndarray

    def ask(self) -> np.ndarray:
        if self._symmetric_sampling:
            num_base_noises = self._popsize // 2
            base_noises = self._rndgen.randn(num_base_noises, self._length)
            self._noises = np.concatenate([base_noises, -base_noises], axis=0)
        else:
            self._noises = self._rndgen.randn(self._popsize, self._length)

        self._solutions = self._center + np.exp(self._logstd) * self._noises
        return self._solutions.copy().astype(self._solutions.dtype)

    def tell(self, fitnesses: Iterable[float]) -> None:
        assert len(fitnesses) == self._popsize, "Invalid number of fitnesses."

        stdev = np.exp(self._logstd)
        fitness = np.asarray(fitnesses, dtype=stdev.dtype)

        if self._symmetric_sampling:
            num_pairs = self._popsize // 2
            fitness_pos = fitness[:num_pairs]
            fitness_neg = fitness[num_pairs:]
            fitness = fitness_pos - fitness_neg
            base_noises = self._noises[:num_pairs]
        else:
            base_noises = self._noises

        if self._normalize_fitness:
            self._running_mean += (np.mean(fitness) - self._running_mean) / (
                self._generation_count + 1
            )
            self._running_var += (np.var(fitness) - self._running_var) / (
                self._generation_count + 1
            )
            fitness = (fitness - self._running_mean) / (
                np.sqrt(self._running_var + 1e-8)
            )

        grad_center = np.mean(fitness[:, None] * base_noises / stdev, axis=0)
        grad_log_stdev = np.mean(fitness[:, None] * (base_noises**2 - 1), axis=0)

        if self._update_type == "natural":
            grad_center *= stdev**2
            grad_log_stdev /= 2

        self._center += self._optimizer.ascent(grad_center)

        delta_logstd = self._stdev_lr * grad_log_stdev
        if self._stdev_max_change is not None:
            c = self._stdev_max_change
            delta_logstd = np.clip(delta_logstd, np.log(1.0 - c), np.log(1.0 + c))

        self._logstd += delta_logstd

        self._update_learning_rates()

    def _update_learning_rates(self) -> None:
        """Applies the selected learning rate decay schedule."""
        self._generation_count += 1

        new_center_lr = self._initial_center_lr
        new_stdev_lr = self._initial_stdev_lr

        if self._generation_count <= self._max_generations:
            decay_fraction = self._generation_count / self._max_generations

            multiplier = 1.0 - (0.8 * decay_fraction)

            new_center_lr = self._initial_center_lr * multiplier
            new_stdev_lr = self._initial_stdev_lr * multiplier

        # Update stdev learning rate
        # self._stdev_lr = new_stdev_lr

        # Update the learning rate within the Adam optimizer instance
        if self._optimizer and hasattr(self._optimizer, "stepsize"):
            self._optimizer.stepsize = new_center_lr

    @property
    def stdev(self) -> np.ndarray:
        return readonly_view(np.exp(self._logstd))

    @property
    def center(self) -> np.ndarray:
        return readonly_view(self._center)
