from typing import Any

import numpy as np
from tianshou.data import PrioritizedVectorReplayBuffer

from _tianshou_custom.data.buffer.CustomPrioritizedReplayBufferManager import CustomPrioritizedReplayBufferManager
from _tianshou_custom.data.buffer.CustomReplayBuffer import CustomReplayBuffer


class CustomPrioritizedVectorReplayBuffer(
    PrioritizedVectorReplayBuffer, CustomPrioritizedReplayBufferManager
):
    """Prioritized vector replay buffer built from ``CustomReplayBuffer`` sub-buffers.

    Mirrors tianshou's ``PrioritizedVectorReplayBuffer`` constructor, but each
    per-environment sub-buffer is a ``CustomReplayBuffer`` and the combined
    buffer is initialised through ``CustomPrioritizedReplayBufferManager``.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        """Create ``buffer_num`` sub-buffers that together hold ``total_size`` items.

        :param total_size: requested total capacity across all sub-buffers.
        :param buffer_num: number of sub-buffers (e.g. one per environment);
            must be positive.
        :param kwargs: forwarded unchanged to every ``CustomReplayBuffer``.
        :raises AssertionError: if ``buffer_num`` is not positive.
        """
        assert buffer_num > 0
        # Each sub-buffer gets ceil(total_size / buffer_num) slots, so the real
        # total capacity may slightly exceed ``total_size``.
        size = int(np.ceil(total_size / buffer_num))
        buffer_list = [CustomReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
        # Call the manager initializer explicitly. A bare ``super().__init__``
        # resolves to ``PrioritizedVectorReplayBuffer.__init__`` (the first base
        # in the MRO), whose signature is ``(total_size, buffer_num, **kwargs)``;
        # passing ``buffer_list`` there raises TypeError.
        # NOTE(review): assumes CustomPrioritizedReplayBufferManager.__init__
        # takes the sub-buffer list, like tianshou's PrioritizedReplayBufferManager.
        CustomPrioritizedReplayBufferManager.__init__(self, buffer_list)
