"""Replay buffer class to save data then use it to train a value function"""
import os
import random
from collections import namedtuple, deque

import numpy as np
import torch


class TorchReplayMemory:
    """Bounded FIFO store of transitions with helpers to turn batches into tensors.

    Every stored item is a ``Transition(obs, action, next_obs, reward)`` named
    tuple. ``process_batch`` additionally expects each of those four fields to
    be a mapping with ``"mdp"`` and ``"monitor"`` entries.

    Attributes:
        max_size: Maximum number of transitions kept; passed to the deque as
            ``maxlen``, so the oldest entries are dropped once it is reached.
        device: Default torch device used by ``process_batch`` when the caller
            does not pass one.
        buffer_size: Number of transitions currently stored (kept equal to
            ``len(self.memory)``).
        memory: The underlying ``collections.deque``.
        transition: The ``Transition`` named-tuple class.
    """

    def __init__(self, max_size: int, device: str = None):
        self.max_size = max_size
        self.device = device
        self.buffer_size = 0
        self.memory = deque([], maxlen=self.max_size)
        self.transition = namedtuple("Transition", ("obs", "action", "next_obs", "reward"))

    @staticmethod
    def _buffer_path(log_dir: str, seed: int) -> str:
        """Return the file path shared by ``save`` and ``load`` for a given seed."""
        return os.path.join(log_dir, "replay_buffer_{}.npy".format(seed))

    @staticmethod
    def _as_tensor(values, device, dtype=None):
        """Stack ``values`` into one numpy array and wrap it in a torch tensor."""
        # Going through np.array first avoids torch's slow list-of-arrays path.
        return torch.tensor(np.array(values), device=device, dtype=dtype)

    def push(self, *args):
        """Save a transition.

        Args:
            *args: The four ``Transition`` fields, in the order
                ``obs, action, next_obs, reward``.
        """
        self.memory.append(self.transition(*args))
        # The deque silently evicts old items once full, so the size must be
        # read back from it rather than incremented unconditionally.
        self.buffer_size = len(self.memory)

    def sample(self, n_samples):
        """Return ``n_samples`` distinct transitions drawn uniformly at random.

        Args:
            n_samples: Number of transitions to draw.

        Returns:
            A list of ``n_samples`` stored transitions.

        Raises:
            ValueError: If the buffer is empty or holds fewer than
                ``n_samples`` transitions.
        """
        stored = len(self.memory)
        if stored == 0:
            raise ValueError("The replay buffer is empty")

        # Asking for exactly as many samples as are stored is valid.
        if n_samples > stored:
            raise ValueError("the number of sample is bigger than the buffer size")
        return random.sample(self.memory, n_samples)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

    def get_transition(self):
        """Return the ``Transition`` named-tuple class used by this buffer."""
        return self.transition

    def save(self, log_dir: str, seed: int):
        """Write the buffer to ``<log_dir>/replay_buffer_<seed>.npy``.

        Args:
            log_dir: Existing directory to write into.
            seed: Identifier embedded in the file name.
        """
        np.save(self._buffer_path(log_dir, seed), self.memory)

    def load(self, log_dir: str, seed: int = 1):
        """Replace the buffer content with the one stored by ``save``.

        The restored deque keeps the ``max_size`` bound, and each row is
        rebuilt as a ``Transition`` so loaded and pushed items have the same
        type.

        Args:
            log_dir: Directory that contains the saved file.
            seed: Identifier embedded in the file name.
        """
        # NOTE(security): allow_pickle=True executes arbitrary pickled code;
        # only load files produced by a trusted run.
        stored = np.load(self._buffer_path(log_dir, seed), allow_pickle=True)[()]
        self.memory = deque((self.transition(*row) for row in stored), maxlen=self.max_size)
        self.buffer_size = len(self.memory)

    def process_batch(self, batch, device: str = None):
        """Convert a batch of transitions into a dict of torch tensors.

        Args:
            batch: Iterable of transitions (e.g. the output of ``sample``).
            device: Torch device for the created tensors. Falls back to the
                ``device`` given to the constructor when omitted.

        Returns:
            A dict with:
              * ``real_reward_idx``: list of batch positions whose ``"mdp"``
                reward is not NaN (may be empty);
              * ``non_final_mask``: bool tensor (default device), True where
                ``next_obs["mdp"]`` is not None;
              * ``non_final_next_states``: stacked non-None ``next_obs["mdp"]``;
              * ``non_final_next_monitor_states``: stacked non-None
                ``next_obs["monitor"]``, with an axis inserted at dim 1;
              * ``mdp_obs``: stacked ``obs["mdp"]``;
              * ``mon_obs``, ``mdp_reward``, ``mon_reward``, ``mdp_action``,
                ``mon_action``: stacked values with an axis inserted at dim 1.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        batch = list(batch)
        if not batch:
            raise ValueError("Cannot process an empty batch")
        if device is None:
            device = self.device

        fields = self.transition(*zip(*batch))

        # NOTE: a batch whose "mdp" rewards are all NaN yields an empty index
        # list instead of an error; callers must cope with that case.
        mdp_reward_values = [r["mdp"] for r in fields.reward]
        real_reward_idx = [i for i, value in enumerate(mdp_reward_values) if not np.isnan(value)]

        non_final_mask = torch.tensor(
            tuple(s["mdp"] is not None for s in fields.next_obs), dtype=torch.bool
        )
        # NOTE(review): the mask is driven by the "mdp" entry only; this
        # presumes "monitor" is None exactly when "mdp" is None - confirm.
        non_final_next_states = self._as_tensor(
            [s["mdp"] for s in fields.next_obs if s["mdp"] is not None], device, torch.float
        )
        non_final_next_monitor_states = self._as_tensor(
            [s["monitor"] for s in fields.next_obs if s["monitor"] is not None], device, torch.float
        )

        mdp_obs = self._as_tensor([state["mdp"] for state in fields.obs], device, torch.float)
        mon_obs = torch.tensor(
            np.array([state["monitor"] for state in fields.obs]).astype(np.float32),
            device=device,
            dtype=torch.float,
        )

        mdp_reward = self._as_tensor(mdp_reward_values, device, torch.float)
        mon_reward = self._as_tensor([r["monitor"] for r in fields.reward], device, torch.float)

        # Actions keep whatever dtype numpy infers for them.
        mdp_action = self._as_tensor([a["mdp"] for a in fields.action], device)
        mon_action = self._as_tensor([a["monitor"] for a in fields.action], device)

        return {
            "real_reward_idx": real_reward_idx,
            "mdp_obs": mdp_obs,
            "mon_obs": mon_obs.unsqueeze(1),
            "non_final_mask": non_final_mask,
            "non_final_next_states": non_final_next_states,
            "non_final_next_monitor_states": non_final_next_monitor_states.unsqueeze(1),
            "mdp_reward": mdp_reward.unsqueeze(1),
            "mon_reward": mon_reward.unsqueeze(1),
            "mdp_action": mdp_action.unsqueeze(1),
            "mon_action": mon_action.unsqueeze(1),
        }
