from collections import deque
import numpy as np
import torch


class LAP(object):
    """Bounded replay buffer that also stores recurrent hidden states.

    Transitions are pushed onto the *left* of a ``deque``, so index 0 always
    holds the newest transition. Once ``max_size`` entries are stored, the
    deque drops the oldest (rightmost) entry automatically.

    Sampling is uniform without replacement over the ``c_k`` most recently
    added transitions (see :meth:`sample`).
    """

    def __init__(
            self,
            device,
            max_size: int = 1e6,
            batch_size: int = 256,
            max_action: float = 1,
            normalize_actions: bool = True,
    ):
        """Create an empty buffer.

        Args:
            device: torch device that sampled tensors are moved to.
            max_size: capacity of the buffer. It is cast to ``int`` because the
                default is the float ``1e6``.
            batch_size: number of transitions returned by :meth:`sample`.
            max_action: divisor applied to ``action`` and ``prev_action`` in
                :meth:`add` when ``normalize_actions`` is true.
            normalize_actions: if false, actions are stored unscaled.
        """
        max_size = int(max_size)
        self.max_size = max_size

        self.device = device
        self.batch_size = batch_size

        self.buffer = deque(maxlen=self.max_size)

        # Divisor for stored actions; 1 means "store actions as given".
        self.normalize_actions = max_action if normalize_actions else 1

        # Number of transitions currently stored (kept in sync with the deque).
        self.size = 0

    def add(self, state, action, prev_action, next_state, reward, done, hidden_states, next_hidden_states):
        """Store one transition as the newest entry of the buffer.

        ``action`` and ``prev_action`` are divided by the normalization factor
        chosen in ``__init__``. ``done`` is stored as ``1. - done`` (a
        "not done" mask).

        ``hidden_states`` / ``next_hidden_states`` must be torch tensors. They
        are detached, moved to CPU and converted to numpy, and their axis 1
        (which must have size 1) is squeezed out. :meth:`sample` later
        re-stacks and permutes them as 3-D tensors, so the inputs are expected
        to be 3-D — presumably ``(num_layers, 1, hidden_size)`` as produced
        by a GRU for a single sequence; TODO confirm against the caller.
        """
        self.buffer.appendleft((state,
                                action / self.normalize_actions,
                                prev_action / self.normalize_actions,
                                next_state,
                                reward,
                                1. - done,
                                np.squeeze(hidden_states.detach().cpu().numpy(), axis=1),
                                np.squeeze(next_hidden_states.detach().cpu().numpy(), axis=1)
                                ))
        # The deque evicts the oldest entry itself once full, so its length is
        # the authoritative count. Without this, ``size`` would stay 0 forever.
        self.size = len(self.buffer)

    def sample(self, c_k: int):
        """Sample ``batch_size`` transitions uniformly from the newest ``c_k``.

        Args:
            c_k: size of the recency window. Because new transitions are added
                on the left, indices ``0 .. c_k - 1`` are the ``c_k`` most
                recent entries. The window is clamped to the current buffer
                length.

        Returns:
            A tuple of float32 tensors on ``self.device``:
            ``(state, action, prev_action, next_state, reward, not_done,
            gru_hx, gru_nhx)``. The first six are stacked along a new leading
            batch dimension. The two hidden-state tensors are permuted with
            ``(1, 0, 2)``, so the batch dimension is axis 1.

        Raises:
            ValueError: if the window holds fewer than ``batch_size``
                transitions (sampling is without replacement).
        """
        # indexing purposes: never look past the end of the buffer
        c_k = min(c_k, len(self.buffer))
        if c_k < self.batch_size:
            raise ValueError(
                f"cannot sample {self.batch_size} transitions without replacement "
                f"from a window of {c_k} (buffer holds {len(self.buffer)})"
            )

        indices = np.random.choice(c_k, self.batch_size, replace=False)
        # NOTE(review): deque indexing away from the ends is O(n); for very large
        # buffers a preallocated ring buffer of arrays would sample faster.
        batch = [self.buffer[idx] for idx in indices]
        state, action, prev_action, next_state, reward, not_done, gru_hx, gru_nhx = map(np.stack, zip(*batch))

        return (
            torch.from_numpy(state).float().to(self.device),
            torch.from_numpy(action).float().to(self.device),
            torch.from_numpy(prev_action).float().to(self.device),
            torch.from_numpy(next_state).float().to(self.device),
            torch.from_numpy(reward).float().to(self.device),
            torch.from_numpy(not_done).float().to(self.device),
            torch.from_numpy(gru_hx).float().to(self.device).permute(1, 0, 2).contiguous(),
            torch.from_numpy(gru_nhx).float().to(self.device).permute(1, 0, 2).contiguous())
