import random
from typing import List, NamedTuple

import numpy as np
import torch


class BatchOutput(NamedTuple):
    """Batch of sampled sequence chunks, one tensor per stored quantity.

    As built by ``ReplayMemory.sample``, every field is time-major: its first
    two dimensions are ``(chunk_size, batch_size)``.
    """

    # Observations; sample() yields (chunk_size, batch_size, observation_size).
    obs: torch.Tensor
    # Action vectors; sample() yields (chunk_size, batch_size, action_size).
    action: torch.Tensor
    # Stored action-repeat value per step; sample() yields (chunk_size, batch_size, 1).
    action_repeat: torch.Tensor
    # Stored action-repeat probability vectors; sample() yields
    # (chunk_size, batch_size, action_repeat_size).
    action_repeat_prob: torch.Tensor
    # Scalar reward per step; sample() yields (chunk_size, batch_size).
    reward: torch.Tensor
    # 1.0 where the stored `done` flag was false, 0.0 where it was true;
    # sample() yields (chunk_size, batch_size, 1).
    non_terminal: torch.Tensor


class ReplayMemory:
    """Fixed-capacity circular buffer of transitions with chunked sampling.

    Transitions are written into preallocated float32 NumPy arrays at
    ``self.idx``, which wraps around modulo ``capacity``; once the buffer has
    wrapped (``self.full``), each new transition overwrites the oldest one.
    ``sample`` draws batches of fixed-length runs of consecutive transitions.
    """

    def __init__(self, capacity: int, action_size: int, action_repeat_size: int, observation_size: int):
        """Preallocate storage for ``capacity`` transitions.

        Args:
            capacity: Maximum number of stored transitions. Coerced with
                ``int()`` so values such as ``1e6`` are accepted.
            action_size: Length of each stored action vector.
            action_repeat_size: Length of each stored action-repeat
                probability vector.
            observation_size: Length of each stored (flat) observation vector.
        """
        self.capacity = int(capacity)
        # Allocate with the coerced integer: NumPy rejects float dimensions.
        size = self.capacity
        self.observations = np.empty((size, observation_size), dtype=np.float32)
        self.actions = np.empty((size, action_size), dtype=np.float32)
        self.actions_repeat = np.empty((size,), dtype=np.float32)
        self.actions_repeat_prob = np.empty((size, action_repeat_size), dtype=np.float32)
        self.rewards = np.empty((size,), dtype=np.float32)
        self.non_terminals = np.empty((size, 1), dtype=np.float32)

        self.full = False  # True once the write pointer has wrapped around
        self.idx = 0  # Next slot to be written
        self.episodes = 0  # Number of pushes with a truthy `done`
        self.steps = 0  # Total number of pushes
        self.action_repeat_size = action_repeat_size

    def _sample_idx(self, L):
        """Return the buffer indices of one uniformly sampled chunk of length ``L``.

        The chunk is a run of ``L`` consecutive transitions that lies entirely
        inside the stored data and never runs across the write pointer
        ``self.idx`` (where the newest and the oldest data meet). Every such
        chunk is equally likely. Chunks are not aligned to episode boundaries.

        Args:
            L: Number of consecutive transitions in the chunk.

        Returns:
            1-D integer array of ``L`` indices into the storage arrays,
            wrapped modulo ``capacity``.

        Raises:
            ValueError: If fewer than ``L`` transitions are stored.
        """
        stored = len(self)
        if L > stored:
            raise ValueError(
                f"cannot sample a chunk of length {L}: only {stored} transitions are stored"
            )
        # Oldest stored transition: slot 0 until the buffer wraps, afterwards
        # the slot that the next push will overwrite.
        oldest = self.idx if self.full else 0
        # Valid starts are offsets 0 .. stored - L from the oldest entry
        # (randint's upper bound is exclusive, hence the + 1).
        start = oldest + np.random.randint(0, stored - L + 1)
        return np.arange(start, start + L) % self.capacity

    def push(self, observation, action, repeat, repeat_prob, reward, done):
        """Store one transition at the write pointer and advance it.

        Args:
            observation: Observation, assignable to a row of length
                ``observation_size``.
            action: Action, assignable to a row of length ``action_size``.
            repeat: Scalar action-repeat value.
            repeat_prob: Action-repeat probabilities, assignable to a row of
                length ``action_repeat_size``.
            reward: Scalar reward.
            done: Truthy if this transition ended an episode; stored inverted
                as the non-terminal flag.
        """
        slot = self.idx
        self.observations[slot] = observation
        self.actions[slot] = action
        self.actions_repeat[slot] = repeat
        self.actions_repeat_prob[slot] = repeat_prob
        self.rewards[slot] = reward
        self.non_terminals[slot] = not done

        self.idx = (slot + 1) % self.capacity
        # The pointer returning to 0 means every slot has been written once.
        self.full = self.full or self.idx == 0
        self.steps += 1
        if done:
            self.episodes += 1

    def sample(self, batch_size, chunk_size, device='cpu') -> "BatchOutput":
        """Sample ``batch_size`` chunks of ``chunk_size`` consecutive transitions.

        Args:
            batch_size: Number of independently sampled chunks.
            chunk_size: Number of consecutive transitions per chunk.
            device: Torch device the returned tensors are placed on.

        Returns:
            A ``BatchOutput`` of time-major tensors: ``obs``, ``action`` and
            ``action_repeat_prob`` are ``(chunk_size, batch_size, feature)``,
            ``action_repeat`` and ``non_terminal`` are
            ``(chunk_size, batch_size, 1)`` and ``reward`` is
            ``(chunk_size, batch_size)``.

        Raises:
            ValueError: If fewer than ``chunk_size`` transitions are stored.
        """
        # Shape (chunk_size, batch_size): row t holds the buffer index of
        # step t for every chunk, so indexing with it yields time-major data.
        idxs = np.asarray(
            [self._sample_idx(chunk_size) for _ in range(batch_size)]
        ).transpose()

        def to_tensor(array):
            return torch.as_tensor(array, device=device)

        return BatchOutput(
            to_tensor(self.observations[idxs]),
            to_tensor(self.actions[idxs]),
            # Stored as a flat vector; add the trailing feature axis.
            to_tensor(self.actions_repeat[idxs][..., np.newaxis]),
            to_tensor(self.actions_repeat_prob[idxs]),
            to_tensor(self.rewards[idxs]),
            to_tensor(self.non_terminals[idxs]),
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        # Until the buffer wraps, idx is the next free slot and therefore
        # equals the number of transitions written so far.
        return self.capacity if self.full else self.idx
