import jax, flax
import jax.numpy as jnp
import numpy as np
from jaxrl_m.dataset import Dataset, ReplayBuffer, get_size
from jaxrl_m.typing import Data, Batch, PRNGKey
import jax.tree_util as tree_util
import jaxrl_m.examples.mujoco.d4rl_utils as d4rl_utils


class OfflineReplayBuffer(ReplayBuffer):
    """Replay buffer pre-filled from an offline dataset and extended in batches.

    Adds two things on top of ``ReplayBuffer``: an alternate constructor that
    seeds the buffer with an existing ``Dataset`` (optionally leaving spare
    capacity), and ``add_batch``, which writes many entries at once into the
    circular storage, wrapping around and overwriting the oldest entries.
    """

    @classmethod
    def create_from_existing_dataset(
        cls,
        dataset: Dataset,
        max_size_ratio: float = 1.0,
    ):
        """Build a buffer seeded with every entry of ``dataset``.

        Args:
            dataset: Source data; its underlying ``_dict`` is handed to
                ``create_from_initial_dataset``.
            max_size_ratio: Buffer capacity as a multiple of the dataset size.
                Must be large enough (>= 1.0) for the whole dataset to fit.

        Returns:
            The buffer built by ``create_from_initial_dataset`` with capacity
            ``int(len(dataset) * max_size_ratio)``.

        Raises:
            ValueError: If ``max_size_ratio`` yields a capacity smaller than the
                dataset.
        """
        data = dataset._dict
        data_size = get_size(data)
        size = int(data_size * max_size_ratio)
        if size < data_size:
            raise ValueError(
                f"max_size_ratio={max_size_ratio} gives a capacity of {size}, "
                f"which cannot hold the {data_size} entries of the dataset"
            )
        buffer = cls.create_from_initial_dataset(data, size)
        # NOTE(review): create_from_initial_dataset presumably leaves the write
        # pointer at the data size; when the data exactly fills the buffer that
        # equals max_size. Wrap it so it stays in [0, max_size), the same
        # invariant add_batch maintains.
        if buffer.max_size > 0:
            buffer.pointer %= buffer.max_size
        return buffer

    def add_batch(self, batch: Batch):
        """Write a batch of entries into the circular buffer.

        Entries are written starting at the current write pointer and wrap to
        index 0 at the end of storage, overwriting the oldest data. If the
        batch holds more than ``max_size`` entries, only its last ``max_size``
        entries are stored: the earlier ones would be overwritten by the same
        batch anyway.

        Args:
            batch: Pytree with the same structure as the buffer's storage; the
                leading axis of every leaf indexes entries.
        """
        batch_size = get_size(batch)
        capacity = self.max_size

        # A batch larger than the buffer would wrap onto itself; only its
        # newest `capacity` entries can survive, so drop the rest up front.
        overflow = max(batch_size - capacity, 0)
        if overflow:
            batch = tree_util.tree_map(lambda arr: arr[overflow:], batch)
        n_write = batch_size - overflow

        # Skipping `overflow` entries advances the start by the same amount.
        # The modulo also folds a pointer sitting at max_size back to 0.
        start = (self.pointer + overflow) % capacity
        end = start + n_write

        if end <= capacity:

            def insert_batch(buffer, new_data):
                buffer[start:end] = new_data

        else:
            first_part = capacity - start

            def insert_batch(buffer, new_data):
                # Tail of storage first, then wrap to the front. Because
                # n_write <= capacity, the two regions never overlap.
                buffer[start:] = new_data[:first_part]
                buffer[: end - capacity] = new_data[first_part:]

        tree_util.tree_map(insert_batch, self._dict, batch)

        # Equivalent to (old_pointer + batch_size) % capacity.
        self.pointer = end % capacity
        self.size = min(self.size + batch_size, capacity)


def main():
    """Smoke-test OfflineReplayBuffer on the D4RL halfcheetah-medium-v2 data.

    Loads the dataset and wraps it in a buffer with 1.1x its size. It then
    pushes two batches of 256000 sampled rows through ``add_batch``, printing
    the buffer size after each step. Finally it pretty-prints a two-row
    sample.
    """
    from pprint import pprint

    env = d4rl_utils.make_env("halfcheetah-medium-v2")
    offline_data = d4rl_utils.get_dataset(env)
    print(offline_data)

    buffer = OfflineReplayBuffer.create_from_existing_dataset(
        dataset=offline_data, max_size_ratio=1.1
    )
    print(buffer.size)

    for _ in range(2):
        buffer.add_batch(offline_data.sample(256000))
        print(buffer.size)

    pprint(buffer.sample(2))


if __name__ == "__main__":
    main()
