"""Provides the PGAEmitter."""

import logging
from typing import List

import numpy as np
from ribs.emitters import EmitterBase
from ribs.archives import ArchiveBase
from src.qd.td3 import TD3
from src.utils import Trajectory

# Module-level logger; PGAEmitter.ask() logs which sampling path it takes.
logger = logging.getLogger(__name__)


class PGAEmitter(EmitterBase):
    """Emitter based on the PG variation from PGA-ME.

    Every batch returned by :meth:`ask` holds the TD3 greedy actor in slot 0.
    The remaining slots are either Gaussian perturbations of ``x0`` (during
    the first ``init_iters`` asks, or while the archive is empty) or archive
    elites improved by one call to TD3 gradient ascent.
    """

    def __init__(
        self,
        archive: ArchiveBase,
        *,
        td3: TD3,
        x0: np.ndarray,
        sigma0: float,
        batch_size: int,
        init_iters: int,
        bounds=None,
        seed=None,
    ):
        """Initializes the emitter.

        Args:
            archive: Archive that elites are sampled from; its ``dtype`` is
                used for ``x0`` and for Gaussian samples.
            td3: TD3 instance providing the greedy actor, gradient ascent on
                elites, and the replay buffer filled by
                :meth:`add_experience`.
            x0: Center of the initial Gaussian; its length sets
                ``solution_dim``.
            sigma0: Standard deviation of the initial Gaussian.
            batch_size: Number of solutions returned by each :meth:`ask`.
            init_iters: Number of initial :meth:`ask` calls that use Gaussian
                sampling regardless of the archive contents.
            bounds: Forwarded to ``EmitterBase``. Note that :meth:`ask` does
                not enforce them.
            seed: Seed for the emitter's NumPy random generator.
        """
        self._rng = np.random.default_rng(seed)
        self._batch_size = batch_size
        self._x0 = np.array(x0, dtype=archive.dtype)
        self._sigma0 = sigma0
        self.td3 = td3
        self.init_iters = init_iters
        # Counts calls to ask(); compared against init_iters to decide when
        # to switch from Gaussian sampling to PG variation.
        self.num_asks = 0

        EmitterBase.__init__(
            self,
            archive,
            solution_dim=len(x0),
            bounds=bounds,
        )

        # Objective of the greedy solution from the most recent tell(); None
        # until the first tell().
        self._greedy_eval = None

    @property
    def batch_size(self):
        """Number of solutions returned by each call to :meth:`ask`."""
        return self._batch_size

    @property
    def greedy_eval(self):
        """Performance of the last evaluated greedy solution."""
        return self._greedy_eval

    def ask(self):
        """Returns batch_size solutions.

        Slot 0 always holds the current TD3 greedy actor
        (``td3.get_actor()``), on both sampling paths.

        For the first ``init_iters`` calls, or whenever the archive is empty,
        the remaining slots are sampled from a Gaussian centered at x0 with
        std sigma0.

        Otherwise, batch_size elites are sampled from the archive. Slots
        1..batch_size - 1 are replaced by the result of applying TD3
        gradient ascent to them.

        NOTE(review): assumes ``td3.get_actor()`` returns a flat vector that
        fits in one row of the batch (length ``solution_dim``). Confirm.

        WARNING: Bounds are currently not enforced.
        """
        self.num_asks += 1
        if self.num_asks <= self.init_iters or self.archive.empty:
            logger.info("Sampling solutions from Gaussian distribution")
            sols = np.expand_dims(self._x0, axis=0) + self._rng.normal(
                scale=self._sigma0,
                size=(self._batch_size, self.solution_dim),
            ).astype(self.archive.dtype)
        else:
            logger.info("Sampling solutions with PG variation")
            sols = self.archive.sample_elites(self._batch_size)["solution"]
            # Slot 0 is overwritten by the greedy actor below, so it is not
            # worth spending a gradient-ascent step on it.
            for i in range(1, self.batch_size):
                sols[i] = self.td3.gradient_ascent(sols[i])
            logger.info("Solutions with PG variation: %d", len(sols))

        sols[0] = self.td3.get_actor()
        return sols

    def tell(self, solution, objective, measures, add_info, **fields):
        """Records the greedy solution's objective, then defers to EmitterBase.

        ``ask`` places the greedy actor at index 0, so ``objective[0]`` is
        stored as :attr:`greedy_eval`. This assumes the caller passes results
        in the same order as :meth:`ask` returned the solutions.
        """
        self._greedy_eval = objective[0]
        super().tell(solution, objective, measures, add_info, **fields)

    def add_experience(self, trajs: List[Trajectory]):
        """Adds transitions from collected trajectories to TD3's replay buffer.

        This method only fills the buffer. Critic and actor training happen
        in :meth:`train_td3`.

        For each trajectory, the arrays passed to the buffer are built from
        ``states[:-1]`` (observations), ``actions[:-1]``, ``states[1:]``
        (next observations) and ``rewards[:-1]``. The done flag is 0 except
        for the trajectory's last transition, which is set to 1.

        NOTE(review): this assumes ``states``, ``actions`` and ``rewards`` are
        aligned per step with equal length. TODO: confirm against
        ``Trajectory``. Every trajectory end is marked terminal, so time-limit
        truncation is not distinguished from true termination.

        Args:
            trajs: Collected trajectories. Trajectories too short to yield a
                transition (empty ``rewards[:-1]``) are skipped. If no
                transitions remain (e.g. ``trajs`` is empty), the buffer is
                left untouched.
        """
        obs_batch = []
        action_batch = []
        next_obs_batch = []
        reward_batch = []
        done_batch = []
        for traj in trajs:
            rewards = traj.rewards[:-1]
            if len(rewards) == 0:
                # No transition to store, and done[-1] on an empty array would
                # raise IndexError.
                continue
            done = np.zeros_like(rewards)
            done[-1] = 1
            obs_batch.append(traj.states[:-1])
            action_batch.append(traj.actions[:-1])
            next_obs_batch.append(traj.states[1:])
            reward_batch.append(rewards)
            done_batch.append(done)

        if not obs_batch:
            # np.concatenate([]) raises ValueError; nothing to add anyway.
            return

        buffer = self.td3.buffer
        buffer.add_batch(
            np.concatenate(obs_batch),
            np.concatenate(action_batch),
            np.concatenate(next_obs_batch),
            np.concatenate(reward_batch),
            np.concatenate(done_batch),
        )

    def train_td3(self):
        """Delegates one training call to ``td3.train_td3()``."""
        self.td3.train_td3()
