import math

import numpy as np
import torch
from torch import multiprocessing as mp


class ReplayBufferStorage:
    """Fixed-capacity circular store of transitions held in torch tensors.

    Each slot holds a state, an action, ``action_repeat`` rewards, a next
    state and a done flag. New transitions are written at ``next_idx`` and
    wrap around to the start once the end of the buffer is reached.
    """

    def __init__(
        self, size, obs_shape, act_shape, action_repeat, obs_dtype=torch.float32
    ):
        """Allocate zero-filled tensors for ``size`` transitions.

        Args:
            size: number of transitions the buffer can hold.
            obs_shape: tuple giving the shape of a single observation.
            act_shape: tuple giving the shape of a single action.
            action_repeat: number of reward entries stored per transition.
            obs_dtype: torch dtype used for the state / next-state tensors.
        """
        self.s_dtype = obs_dtype

        # buffer arrays
        self.s_stack = torch.zeros((size,) + obs_shape, dtype=self.s_dtype)
        self.action_stack = torch.zeros((size,) + act_shape, dtype=torch.float32)
        self.reward_stack = torch.zeros((size, action_repeat), dtype=torch.float32)
        self.s1_stack = torch.zeros((size,) + obs_shape, dtype=self.s_dtype)
        self.done_stack = torch.zeros((size, 1), dtype=torch.int)

        self.obs_shape = obs_shape
        self.size = size
        # Plain ints until share_memory_() swaps them for mp.Value objects.
        self._next_idx = 0
        self._max_filled = 0

        self._shared = False

    def __len__(self):
        """Return the number of slots that have been written so far."""
        return self.max_filled

    @property
    def next_idx(self):
        """Index of the slot the next transition will be written to."""
        if self._shared:
            return self._next_idx.value
        return self._next_idx

    @next_idx.setter
    def next_idx(self, v):
        if self._shared:
            self._next_idx.value = v
        else:
            self._next_idx = v

    @property
    def max_filled(self):
        """Number of filled slots (never exceeds ``size``)."""
        if self._shared:
            return self._max_filled.value
        return self._max_filled

    @max_filled.setter
    def max_filled(self, v):
        if self._shared:
            self._max_filled.value = v
        else:
            self._max_filled = v

    def add(self, s, a, r, s_1, d):
        """Write one transition, or a batch of transitions, into the buffer.

        The input is treated as a batch when ``s`` has more dimensions than
        ``obs_shape``; its first dimension is then the number of samples.

        Non-tensor inputs are converted to tensors. Their rewards are
        zero-padded along the second dimension up to ``action_repeat``
        entries per transition. Tensor inputs are moved to the CPU and
        written as given.

        Args:
            s: state(s); array-like with a ``shape`` attribute, or a tensor.
            a: action(s).
            r: reward(s). Non-tensor, single transition: a sequence of up to
                ``action_repeat`` rewards. Non-tensor, batched: one row of
                rewards per transition (a 1-D batch is read as one reward
                per transition).
            s_1: next state(s).
            d: done flag(s).

        Returns:
            numpy array of the buffer indices that were written.
        """
        # this buffer supports batched experience
        batched = len(s.shape) > len(self.obs_shape)
        num_samples = len(s) if batched else 1

        if not isinstance(s, torch.Tensor):
            # np.asarray also materialises array-likes such as LazyFrames
            s = torch.from_numpy(np.asarray(s))
            s_1 = torch.from_numpy(np.asarray(s_1))
            a = torch.from_numpy(np.asarray(a)).float()

            r = torch.Tensor(r).float()
            if not batched:
                # the rewards of a single transition form one row
                r = r.unsqueeze(0)
            elif r.dim() == 1:
                # one reward per batched transition -> one column
                r = r.unsqueeze(1)

            # Zero-pad every row up to the reward width of the buffer.
            steps_short = self.reward_stack.shape[1] - r.shape[1]
            if steps_short > 0:
                r = torch.cat((r, torch.zeros(r.shape[0], steps_short)), dim=1)

            d = torch.Tensor(d if batched else [d]).int()

            # make sure tensors are floats not doubles
            if self.s_dtype is torch.float32:
                s = s.float()
                s_1 = s_1.float()

        else:
            # move to cpu
            s = s.cpu()
            a = a.cpu()
            r = r.cpu()
            s_1 = s_1.cpu()
            # as_tensor lets a plain Python/NumPy flag through as well as a
            # tensor, so a single (unbatched) transition can be stored.
            d = torch.as_tensor(d).int().cpu()

        # Store at end of buffer. Wrap around if past end.
        start = self.next_idx
        R = np.arange(start, start + num_samples) % self.size
        self.s_stack[R] = s
        self.action_stack[R] = a
        self.reward_stack[R] = r
        self.s1_stack[R] = s_1
        self.done_stack[R] = d
        # Advance index.
        self.max_filled = min(max(start + num_samples, self.max_filled), self.size)
        self.next_idx = (start + num_samples) % self.size
        return R

    def __getitem__(self, indices):
        """Return the transitions stored at ``indices``.

        Args:
            indices: an iterable of buffer indices (list, array, tensor...).

        Returns:
            Tuple ``(state, action, reward, next_state, done)``. States,
            actions and next states are converted with ``.float()``.

        Raises:
            IndexError: if ``indices`` is not iterable.
        """
        try:
            iter(indices)
        except TypeError as err:
            # iter() signals a non-iterable argument with TypeError.
            raise IndexError(
                "ReplayBufferStorage getitem called with indices object that is not iterable"
            ) from err

        # converting states and actions to float here instead of inside the learning loop
        # of each agent seems fine for now.
        state = self.s_stack[indices].float()
        action = self.action_stack[indices].float()
        reward = self.reward_stack[indices]
        next_state = self.s1_stack[indices].float()
        done = self.done_stack[indices]
        return (state, action, reward, next_state, done)

    def __setitem__(self, indices, experience):
        """Overwrite the transitions at ``indices``.

        Args:
            indices: index expression applied to every buffer tensor.
            experience: tuple ``(s, a, r, s1, d)`` of tensors.

        The fill counters (``next_idx`` / ``max_filled``) are not changed.
        """
        s, a, r, s1, d = experience
        # Cast states to the storage dtype (not unconditionally to float32)
        # so the written values match the buffer's own dtype.
        self.s_stack[indices] = s.to(self.s_dtype)
        self.action_stack[indices] = a.float()
        self.reward_stack[indices] = r
        self.s1_stack[indices] = s1.to(self.s_dtype)
        self.done_stack[indices] = d

    def get_all_transitions(self):
        """Return slices of every buffer tensor covering the filled slots.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)``, each
            sliced to the first ``max_filled`` entries.
        """
        # Read the counter once so all five slices have the same length.
        filled = self.max_filled
        return (
            self.s_stack[:filled],
            self.action_stack[:filled],
            self.reward_stack[:filled],
            self.s1_stack[:filled],
            self.done_stack[:filled],
        )

    def share_memory_(self):
        """Move the tensors and counters into shared memory (idempotent)."""
        if self._shared:
            return

        self._shared = True
        self.s_stack.share_memory_()
        self.action_stack.share_memory_()
        self.reward_stack.share_memory_()
        self.s1_stack.share_memory_()
        self.done_stack.share_memory_()
        # Replace the int counters with mp.Value objects; the properties
        # above read and write ``.value`` once ``_shared`` is set.
        self._max_filled = mp.Value("i", self._max_filled)
        self._next_idx = mp.Value("i", self._next_idx)


class ReplayBuffer:
    """Replay buffer front-end that lazily allocates a ReplayBufferStorage.

    The backing storage is created on first use (first ``push`` or
    ``share_memory_`` call) rather than in the constructor.
    """

    def __init__(
        self,
        size,
        state_shape=None,
        action_shape=None,
        action_repeat=1,
        state_dtype=float,
    ):
        """Record the buffer configuration; no tensors are allocated yet.

        Args:
            size: maximum number of transitions to hold.
            state_shape: shape of a single state (required).
            action_shape: shape of a single action (required).
            action_repeat: number of reward entries stored per transition.
            state_dtype: dtype specifier for states, mapped to a torch dtype
                by ``_convert_dtype``.

        Raises:
            AssertionError: if ``state_shape`` or ``action_shape`` is falsy.
            ValueError: if ``state_dtype`` is not a recognised specifier.
        """
        self._maxsize = size
        self.state_shape = state_shape
        self.state_dtype = self._convert_dtype(state_dtype)
        self.action_shape = action_shape
        self._storage = None
        self.action_repeat = action_repeat
        assert self.state_shape, "Must provide shape of state space to ReplayBuffer"
        assert self.action_shape, "Must provide shape of action space to ReplayBuffer"

    def _convert_dtype(self, dtype):
        """Map a Python / NumPy / torch dtype specifier to a storage dtype.

        Integer and uint8 specifiers map to ``torch.uint8``, float specifiers
        to ``torch.float32`` and int32 specifiers to ``torch.int32``.

        Raises:
            ValueError: if ``dtype`` is none of the recognised specifiers.
        """
        if dtype in [int, np.uint8, torch.uint8]:
            return torch.uint8
        elif dtype in [float, np.float32, np.float64, torch.float32, torch.float64]:
            return torch.float32
        elif dtype in ["int32", np.int32, torch.int32]:
            return torch.int32
        else:
            raise ValueError(f"Unrecognized replay buffer dtype: {dtype}")

    def __len__(self):
        """Return the number of stored transitions (0 before any storage)."""
        return len(self._storage) if self._storage is not None else 0

    def _ensure_storage(self):
        """Allocate the backing storage on first use and return it."""
        if self._storage is None:
            self._storage = ReplayBufferStorage(
                self._maxsize,
                obs_shape=self.state_shape,
                act_shape=self.action_shape,
                action_repeat=self.action_repeat,
                obs_dtype=self.state_dtype,
            )
        return self._storage

    def push(self, state, action, reward, next_state, done):
        """Store a transition (or batch) and return the storage's result."""
        return self._ensure_storage().add(state, action, reward, next_state, done)

    def sample(self, batch_size, get_idxs=False):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Args:
            batch_size: number of transitions to draw.
            get_idxs: when true, also return the sampled indices as a
                numpy array.

        Returns:
            The sampled transitions, or ``(transitions, indices)`` when
            ``get_idxs`` is true.

        Raises:
            ValueError: if the buffer holds no transitions.
        """
        num_stored = len(self)
        if num_stored == 0:
            # Fail with a clear message instead of an incidental error from
            # len(None) or an empty randint range.
            raise ValueError("Cannot sample from an empty ReplayBuffer")

        random_idxs = torch.randint(num_stored, (batch_size,))
        batch = self._storage[random_idxs]
        if get_idxs:
            return batch, random_idxs.cpu().numpy()
        return batch

    def get_all_transitions(self):
        """Return every stored transition, as provided by the storage."""
        return self._storage.get_all_transitions()

    def load_experience(self, s, a, r, s1, d):
        """Push a whole dataset of transitions in one call.

        1-D reward and done arrays get a trailing axis added before the
        data is passed to ``push``.

        Raises:
            AssertionError: if the dataset has more rows than the buffer size.
        """
        assert (
            s.shape[0] <= self._maxsize
        ), "Experience dataset is larger than the buffer."
        if len(r.shape) < 2:
            r = np.expand_dims(r, 1)
        if len(d.shape) < 2:
            d = np.expand_dims(d, 1)
        self.push(s, a, r, s1, d)

    def share_memory_(self):
        """Put the backing storage in shared memory.

        The storage is allocated first if nothing has been pushed yet, so
        the buffer can be shared before it receives any data.
        """
        self._ensure_storage().share_memory_()
