from typing import Any

import numpy as np
from tianshou.data import VectorReplayBuffer

from _tianshou_custom.data.buffer.CustomReplayBuffer import CustomReplayBuffer
from _tianshou_custom.data.buffer.CustomReplayBufferManager import CustomReplayBufferManager


class CustomVectorReplayBuffer(VectorReplayBuffer, CustomReplayBufferManager):
    """Vector replay buffer whose sub-buffers are ``CustomReplayBuffer`` objects.

    The requested total capacity is divided evenly among ``buffer_num``
    sub-buffers, rounding the per-buffer share up, so the combined capacity
    can be slightly larger than ``total_size``.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        """Create the sub-buffers and pass them on to the manager initializer.

        :param total_size: desired overall capacity across all sub-buffers.
        :param buffer_num: number of sub-buffers to create; must be positive.
        :param kwargs: extra keyword arguments forwarded unchanged to every
            ``CustomReplayBuffer`` constructor.
        """
        assert buffer_num > 0
        # Round up so that buffer_num * per_buffer_capacity >= total_size.
        share = np.ceil(total_size / buffer_num)
        per_buffer_capacity = int(share)
        sub_buffers = [
            CustomReplayBuffer(per_buffer_capacity, **kwargs)
            for _ in range(buffer_num)
        ]
        # Start the MRO lookup *after* VectorReplayBuffer so that its own
        # __init__ is bypassed and the list built above is used instead.
        # NOTE(review): presumably this resolves to
        # CustomReplayBufferManager.__init__ -- confirm against its definition.
        super(VectorReplayBuffer, self).__init__(sub_buffers)
