"""Critics for gym environments."""

import torch
import torch.nn.functional as F
from torch import nn

from collections import namedtuple

import numpy as np


class MLPCritic(nn.Module):
    """Twin-Q MLP critic for continuous control.

    Holds two independent three-layer MLPs (Q1 and Q2). Each maps the
    concatenation of a state batch and an action batch to one scalar per row.

    Adapted from TD3 code by Scott Fujimoto:
    https://github.com/sfujim/TD3/blob/master/TD3.py

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        hidden_dim: Width of the two hidden layers of each Q-network.
            Defaults to 256, the value that was previously hard-coded.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        input_dim = state_dim + action_dim

        # The layers keep the names l1..l6 and are created in this order so
        # that state_dict keys and seeded initialisation stay the same.

        # Q1 architecture
        self.l1 = nn.Linear(input_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)

        # Q2 architecture
        self.l4 = nn.Linear(input_dim, hidden_dim)
        self.l5 = nn.Linear(hidden_dim, hidden_dim)
        self.l6 = nn.Linear(hidden_dim, 1)

    def _q1_head(self, sa):
        """Runs the Q1 network on an already concatenated state-action batch."""
        hidden = F.relu(self.l1(sa))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)

    def _q2_head(self, sa):
        """Runs the Q2 network on an already concatenated state-action batch."""
        hidden = F.relu(self.l4(sa))
        hidden = F.relu(self.l5(hidden))
        return self.l6(hidden)

    def forward(self, state, action):
        """Returns the tuple (q1, q2) of both critics' outputs.

        ``state`` and ``action`` are concatenated along dimension 1, so they
        must be 2D tensors with the same number of rows.
        """
        sa = torch.cat([state, action], 1)
        return self._q1_head(sa), self._q2_head(sa)

    def q1(self, state, action):
        """Returns only the Q1 output, skipping the Q2 network entirely."""
        sa = torch.cat([state, action], 1)
        return self._q1_head(sa)


# One stored transition, addressed by field name.
Experience = namedtuple("Experience", "obs action reward next_obs done")

# Same fields as Experience, but each field holds a batch of items, e.g. a
# batch of obs, a batch of actions.
BatchExperience = namedtuple("BatchExperience", Experience._fields)


class ReplayBuffer:
    """Fixed-capacity circular store of experience for training RL algorithms.

    Transitions are kept in preallocated float32 NumPy arrays. Once
    ``capacity`` entries have been written, new entries overwrite the oldest
    ones.

    Based on TD3 implementation and PGA-MAP-Elites implementation:
    https://github.com/sfujim/TD3/blob/master/utils.py
    https://github.com/ollenilsson19/PGA-MAP-Elites/blob/master/utils.py

    Attributes:
        capacity: Maximum number of entries held at once.
        ptr: Index of the slot the next entry will be written to.
        size: Number of valid entries currently stored.
        additions: Total number of entries ever added, including those that
            have since been overwritten.
    """

    def __init__(
        self,
        capacity: int,
        state_dim: int,
        action_dim: int,
        seed: int = None,
    ):
        """Allocates the storage arrays.

        Args:
            capacity: Maximum number of entries; must be positive.
            state_dim: Length of each observation vector.
            action_dim: Length of each action vector.
            seed: Seed for the sampling RNG. None (the default) is passed to
                ``np.random.default_rng`` unchanged.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        # A buffer without slots cannot store anything; previously this only
        # surfaced later as an IndexError / ZeroDivisionError when adding.
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}.")

        self.capacity = capacity
        self.ptr = 0
        self.size = 0
        self.additions = 0

        # np.empty leaves the contents uninitialised; only slots below
        # self.size hold valid data.
        self.obs = np.empty((capacity, state_dim), dtype=np.float32)
        self.action = np.empty((capacity, action_dim), dtype=np.float32)
        self.reward = np.empty(capacity, dtype=np.float32)
        self.next_obs = np.empty((capacity, state_dim), dtype=np.float32)
        self.done = np.empty(capacity, dtype=np.float32)

        self.rng = np.random.default_rng(seed)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def add(self, e: "Experience"):
        """Adds a single experience, overwriting the oldest one when full."""
        slot = self.ptr
        self.obs[slot] = e.obs
        self.action[slot] = e.action
        self.next_obs[slot] = e.next_obs
        self.reward[slot] = e.reward
        self.done[slot] = e.done

        self.ptr = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
        self.additions += 1

    def add_batch(
        self,
        obs_batch: np.ndarray,
        action_batch: np.ndarray,
        next_obs_batch: np.ndarray,
        reward_batch: np.ndarray,
        done_batch: np.ndarray,
    ):
        """Adds a batch of experiences in order, wrapping around when needed.

        The result is the same as adding the items one at a time. The batch
        size is taken from ``obs_batch.shape[0]``; only that many leading rows
        of the other arrays are used.

        Note that the argument order here (next_obs before reward) differs
        from the field order of Experience.

        Args:
            obs_batch: Observations, one row per item.
            action_batch: Actions, one row per item.
            next_obs_batch: Next observations, one row per item.
            reward_batch: Rewards, one entry per item.
            done_batch: Done flags, one entry per item.
        """
        batch_size = obs_batch.shape[0]

        # When the batch exceeds the capacity, the leading items would be
        # overwritten by later items of the same batch, so they are skipped
        # and only the write pointer is advanced past them.
        skipped = max(batch_size - self.capacity, 0)
        kept = batch_size - skipped
        start = (self.ptr + skipped) % self.capacity

        # The kept items (at most `capacity`) land in up to two contiguous
        # runs: `head` items from `start` to the end of the storage, and the
        # remaining `tail` items from index 0 after wrapping around.
        head = min(kept, self.capacity - start)
        tail = kept - head
        split = skipped + head

        pairs = (
            (self.obs, obs_batch),
            (self.action, action_batch),
            (self.next_obs, next_obs_batch),
            (self.reward, reward_batch),
            (self.done, done_batch),
        )
        for storage, batch in pairs:
            storage[start : start + head] = batch[skipped:split]
            storage[:tail] = batch[split:batch_size]

        self.ptr = (self.ptr + batch_size) % self.capacity
        self.size = min(self.size + batch_size, self.capacity)
        self.additions += batch_size

    def __len__(self):
        """Number of Experience in the buffer."""
        return self.size

    def sample_tensors(self, n: int) -> "BatchExperience":
        """Samples ``n`` experiences uniformly at random, with replacement.

        Args:
            n: Number of experiences to sample.

        Returns:
            A BatchExperience whose fields are tensors on ``self.device``,
            each holding the sampled values for that field.

        Raises:
            ValueError: If the buffer is empty.
        """
        if len(self) == 0:
            raise ValueError("No entries currently in ReplayBuffer.")

        # This is used in the TD3 and PGA-ME implementation - if the buffer is
        # big enough, sampling duplicates is not an issue.
        # https://github.com/sfujim/TD3/blob/385b33ac7de4767bab17eb02ade4a268d3e4e24f/utils.py#L32
        # One concern with this is that the indices are out of range. However,
        # since we are only adding to the buffer, indices in the range [ 0,
        # len(self) ) will always be occupied. If we also remove from the
        # buffer, then we would want to have some offset here.
        indices = self.rng.integers(len(self), size=n)

        return BatchExperience(
            torch.as_tensor(self.obs[indices], device=self.device),
            torch.as_tensor(self.action[indices], device=self.device),
            torch.as_tensor(self.reward[indices], device=self.device),
            torch.as_tensor(self.next_obs[indices], device=self.device),
            torch.as_tensor(self.done[indices], device=self.device),
        )
