from collections import deque
from typing import Deque, Generic, Optional, TypeVar, Tuple, List, cast, Union
import numpy as np
import torch
import random

# Torch device name: "cuda" when torch reports that CUDA is available,
# otherwise "cpu". Evaluated once, at import time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Type alias for a 5-tuple of tensors.
# NOTE(review): the name suggests (state, action, reward, next_state,
# next_action) -- confirm the field order against the code that builds these.
SARSA = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

# Element type stored by ReplayBuffer.
E = TypeVar("E")


class ReplayBuffer(Generic[E]):
    """Fixed-capacity ring buffer with random sampling.

    Elements are written into a pre-allocated list of ``capacity`` slots.
    Once every slot is in use, each further ``append`` overwrites the oldest
    element.

    Attributes:
        buffer: Backing list of length ``capacity``; unused slots hold ``None``.
        capacity: Maximum number of elements kept at once.
        start: Physical index of the oldest stored element.
        end: Physical index the next ``append`` writes to.
        empty: ``True`` until the first ``append`` (and again after ``clear``);
            needed because ``start == end`` means both "empty" and "full".

    ``None`` is used as the "unused slot" marker, so ``None`` itself cannot be
    stored and sampled: ``sample`` asserts that no drawn element is ``None``.
    """

    def __init__(
        self,
        capacity: Optional[int] = int(1e6),
    ):
        """Allocate ``capacity`` empty slots.

        Args:
            capacity: Maximum number of elements; must be a positive integer.

        Raises:
            AssertionError: If ``capacity`` is ``None``.
            ValueError: If ``capacity`` is not positive.
        """
        assert capacity is not None, "capacity must not be None"
        if capacity <= 0:
            # A buffer with no slots could never accept an element; fail here
            # rather than with an IndexError on the first append().
            raise ValueError(
                f"capacity must be a positive integer, got {capacity!r}"
            )
        self.buffer = cast(List[E], [None] * capacity)
        self.empty = True
        self.start = 0
        self.end = 0
        self.capacity = capacity

    def append(self, e: E):
        """Store ``e``, overwriting the oldest element when the buffer is full.

        Returns:
            ``self``, so calls can be chained.
        """
        # start only moves once the buffer is full, and from then on it moves
        # in lockstep with end, so end can never fall behind start.
        assert self.end >= self.start, "ring buffer cursors are inconsistent"
        if self.end == self.start and not self.empty:
            # Full: the slot about to be overwritten holds the oldest element,
            # so the "oldest" cursor moves on to the next slot.
            self.start = (self.start + 1) % self.capacity
        self.buffer[self.end] = e
        self.end = (self.end + 1) % self.capacity
        self.empty = False
        return self

    def clear(self):
        """Remove all elements and release the references held to them.

        Returns:
            ``self``, so calls can be chained.
        """
        # Stored elements always occupy the first ``len`` physical slots
        # (start stays at 0 until the buffer is full, and a full buffer uses
        # every slot), so resetting that prefix drops every reference. The
        # list is updated in place so its identity and length are preserved.
        occupied = self.len
        self.buffer[:occupied] = cast(List[E], [None] * occupied)
        self.start = 0
        self.end = 0
        self.empty = True
        return self

    @property
    def size(self) -> int:
        """Maximum number of elements the buffer can hold (its capacity)."""
        return self.capacity

    @property
    def len(self) -> int:
        """Number of elements currently stored."""
        if self.empty:
            return 0
        if self.end == self.start:
            # Non-empty with coinciding cursors means every slot is in use.
            return self.size
        return self.end - self.start

    def sample(self, size: int, p: Optional[List[float]] = None) -> List[E]:
        """Draw ``size`` stored elements at random, with replacement.

        Indices are drawn with ``np.random.choice`` over ``range(self.len)``
        and used directly as positions in ``self.buffer``. ``p[i]`` is
        therefore the probability of the element in physical slot ``i``; this
        matches insertion order only until the buffer wraps around.

        Args:
            size: Number of elements to draw.
            p: Optional per-slot probabilities, passed through to
                ``np.random.choice``; uniform when ``None``.

        Returns:
            A list of ``size`` sampled elements.

        Raises:
            AssertionError: If the buffer is empty, or a drawn element is
                ``None``.
        """
        assert self.len > 0, "cannot sample from an empty buffer"
        chosen = np.random.choice(self.len, size=size, p=p)
        batch: List[E] = [self.buffer[slot] for slot in chosen]

        # Sanity check: every sampled slot must have been written to.
        assert not any(item is None for item in batch), "sampled an unset slot"

        return batch
