# Shared functions for the CORL algorithms.
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, DefaultDict, Any

import gin
import tqdm
# import gym
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset


from diffusion.norm import MinMaxNormalizer
# from diffusion.utils import make_inputs, construct_diffusion_model, split_diffusion_samples 

TensorBatch = List[torch.Tensor]


def combine_two_tensors(tensor1, tensor2):
    """Concatenate two tensors along the first dimension, ``tensor1`` rows first."""
    pair = (tensor1, tensor2)
    return torch.cat(pair, dim=0)


def zero_after_first_zero(inputs, dones):
    """Zero ``inputs`` from the first position where ``dones == 0`` onward.

    The scan runs along dim 1, which is treated as the time axis. A 1-D
    ``inputs`` and/or 1-D ``dones`` is handled as a single row, and the result
    is returned with the original rank of ``inputs``. After that promotion,
    ``dones`` only has to broadcast against ``inputs``. For example, a
    ``(B, T, 1)`` mask works for ``(B, T, D)`` inputs.

    Args:
        inputs: Tensor to mask.
        dones: Tensor whose zero entries start the masked span along dim 1.

    Returns:
        A new tensor with the masked positions set to 0. ``inputs`` itself is
        not modified.
    """
    was_1d = inputs.dim() == 1
    x = inputs.unsqueeze(0) if was_1d else inputs
    flags = dones.unsqueeze(0) if dones.dim() == 1 else dones
    # A running max over the 0/1 indicator stays True from the first zero on.
    mask = (flags == 0).int().cummax(dim=1)[0].bool()
    # Out-of-place masked_fill: broadcasts the mask and never touches `inputs`.
    x = x.masked_fill(mask, 0)
    return x.squeeze(0) if was_1d else x


def _transition_gap(obs_next, next_obs) -> float:
    """Return the Euclidean norm of ``obs_next - next_obs`` for torch tensors or NumPy arrays."""
    diff = obs_next - next_obs
    if isinstance(diff, torch.Tensor):
        return float(torch.linalg.norm(diff))
    return float(np.linalg.norm(diff))


def get_return_to_go(
    dataset: Dict, env, config, max_ep_len=None
) -> Tuple[List[float], List[float], List[float]]:
    """Compute per-step return targets for a flat dataset of transitions.

    An episode ends at any of these points:

    * a truthy ``terminals[t]``;
    * a step where ``observations[t + 1]`` differs from ``next_observations[t]``
      (norm > 1e-6);
    * the horizon, which is ``max_ep_len`` or, when that is None,
      ``env._max_episode_steps``;
    * the last transition.

    Args:
        dataset: Mapping with ``rewards``, ``terminals``, ``observations`` and
            ``next_observations``, indexable per step (torch tensors or arrays).
        env: Provides ``_max_episode_steps`` (only when ``max_ep_len`` is None)
            and ``ref_min_score`` (only for sparse rewards).
        config: Provides ``discount``, ``is_sparse_reward``, ``reward_scale`` and
            ``reward_bias``.
        max_ep_len: Optional horizon override.

    Returns:
        ``(returns, ep_returns, rtgs)``, three lists with one float per transition:

        * ``returns``: discounted return from each step, with bootstrapping cut
          at terminal flags.
        * ``ep_returns``: the discounted return of the episode's first step,
          repeated over the episode.
        * ``rtgs``: the episode's undiscounted reward sum, repeated over the
          episode.
    """
    # Parenthesised explicitly: the horizon defaults to the env limit.
    horizon = max_ep_len if max_ep_len is not None else env._max_episode_steps
    observations = dataset["observations"]
    next_observations = dataset["next_observations"]
    n_steps = len(dataset["rewards"])

    returns: List[float] = []
    ep_returns: List[float] = []
    rtgs: List[float] = []
    cur_rewards: List[float] = []
    cur_terminals: List[float] = []

    for t, (r, d) in enumerate(zip(dataset["rewards"], dataset["terminals"])):
        cur_rewards.append(float(r))
        cur_terminals.append(float(d))
        ep_len = len(cur_rewards)

        # `t == n_steps - 1` is checked first so `observations[t + 1]` never overruns.
        is_last_step = (
            t == n_steps - 1
            or ep_len == horizon
            or _transition_gap(observations[t + 1], next_observations[t]) > 1e-6
        )
        if not (d or is_last_step):
            continue

        if (
            config.is_sparse_reward
            and r == env.ref_min_score * config.reward_scale + config.reward_bias
        ):
            # Episode ended on the minimum sparse reward. r / (1 - discount) is the
            # discounted sum of that reward repeated forever.
            discounted = [float(r) / (1 - config.discount)] * ep_len
        else:
            discounted = [0.0] * ep_len
            running = 0.0
            for i in reversed(range(ep_len)):
                running = cur_rewards[i] + config.discount * running * (1 - cur_terminals[i])
                discounted[i] = running

        returns.extend(discounted)
        ep_returns.extend([discounted[0]] * ep_len)
        rtgs.extend([sum(cur_rewards)] * ep_len)
        cur_rewards, cur_terminals = [], []

    return returns, ep_returns, rtgs


@dataclass
class DiffusionConfig:
    """Location and usage limits of diffusion-generated training data.

    Attributes:
        path: Path to model checkpoints (``.pt``) or an ``.npz`` file with
            diffusion samples; ``None`` means use the plain dataset.
        num_steps: Number of diffusion steps.
        sample_limit: If not -1, limit the number of diffusion samples to this number.
    """

    path: Optional[str] = None
    num_steps: int = 128
    sample_limit: int = -1


def return_reward_range(dataset, max_episode_steps):
    """Return ``(min, max)`` of undiscounted episode returns.

    An episode is closed on a truthy terminal flag or after
    ``max_episode_steps`` steps. A trailing unfinished episode is counted
    toward the length sanity check, but not toward the returns.
    """
    episode_returns, episode_lengths = [], []
    running_return, running_len = 0.0, 0
    for reward, terminal in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(reward)
        running_len += 1
        if not (terminal or running_len == max_episode_steps):
            continue
        episode_returns.append(running_return)
        episode_lengths.append(running_len)
        running_return, running_len = 0.0, 0
    # Keep the partial tail so the step count below still matches the dataset.
    episode_lengths.append(running_len)
    assert sum(episode_lengths) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)


class RewardNormalizer:
    """Affine reward transform ``(reward + shift) * scale`` chosen by environment name.

    For halfcheetah/hopper/walker2d, ``scale`` maps the dataset's episode-return
    range to ``max_episode_steps``. For antmaze, rewards are shifted by -1.
    Any other environment uses the identity.
    """

    def __init__(self, dataset, env_name, max_episode_steps=1000):
        self.env_name = env_name
        scale, shift = 1., 0.
        if any(family in env_name for family in ("halfcheetah", "hopper", "walker2d")):
            lowest, highest = return_reward_range(dataset, max_episode_steps)
            scale = max_episode_steps / (highest - lowest)
        elif "antmaze" in env_name:
            shift = -1.
        self.scale = scale
        self.shift = shift

    def __call__(self, reward):
        shifted = reward + self.shift
        return shifted * self.scale


class StateNormalizer:
    """Standardize states with fixed statistics: ``(state - mean) / std``."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def to_torch(self, device: str):
        """Replace the stored ``mean``/``std`` with ``torch.tensor`` copies on ``device``."""
        self.mean, self.std = (torch.tensor(stat, device=device) for stat in (self.mean, self.std))

    def __call__(self, state):
        centered = state - self.mean
        return centered / self.std


class ReplayBufferBase:
    """Sampling front-end shared by replay buffers.

    Subclasses implement :meth:`_sample`, which returns raw
    ``[states, actions, rewards, next_states, dones]``. :meth:`sample` then
    applies the optional reward and state normalizers.
    """

    def __init__(
            self,
            device: str = "cpu",
            reward_normalizer: Optional[RewardNormalizer] = None,
            state_normalizer: Optional[StateNormalizer] = None,
    ):
        self.reward_normalizer = reward_normalizer
        self.state_normalizer = state_normalizer
        # Move the normalizer statistics onto the buffer's device.
        if state_normalizer is not None:
            state_normalizer.to_torch(device)
        self._device = device

    def _sample(self, batch_size: int, **kwargs) -> TensorBatch:
        """Return un-normalized samples; subclasses must override this."""
        raise NotImplementedError

    def sample(self, batch_size: int, **kwargs) -> TensorBatch:
        """Draw a batch through :meth:`_sample`, then normalize rewards and states."""
        states, actions, rewards, next_states, dones = self._sample(batch_size, **kwargs)
        normalize_reward = self.reward_normalizer
        normalize_state = self.state_normalizer
        if normalize_reward is not None:
            rewards = normalize_reward(rewards)
        if normalize_state is not None:
            states = normalize_state(states)
            next_states = normalize_state(next_states)
        return [states, actions, rewards, next_states, dones]

class calq_ReplayBuffer:
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
        discount: float = 1.0,
        top_curi: float = 1.0,
    ):
        self._buffer_size = buffer_size
        self._pointer = 0
        self._size = 0

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._mc_returns = torch.zeros(
            (buffer_size, 1), dtype=torch.float32, device=device
        )
        self._rtgs = torch.zeros(
            (buffer_size, 1), dtype=torch.float32, device=device
        )
        self.top_curi = top_curi

        self._reset_tmp_cache()
        self._device = device
        self._discount = discount
        
    @property
    def empty(self):
        return self._pointer == 0

    @property
    def full(self):
        return self._pointer == self._buffer_size

    def __len__(self):
        return self._pointer

    def _to_tensor(self, data) -> torch.Tensor:
        if isinstance(data, torch.Tensor):
            return data
        else:
            return torch.tensor(data, dtype=torch.float32, device=self._device)

    def mc_returns_max(self):
        return self._mc_returns.max().detach().cpu().numpy()

    def compute_grpo(self, use_rtg: bool = False):
        if use_rtg:
            mean = torch.mean(self._rtgs)
            std = torch.std(self._rtgs)
        else:
            mean = torch.mean(self._rewards)
            std = torch.std(self._rewards)
        return mean, std

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_d4rl_dataset(self, data: Dict[str, np.ndarray], online: bool = False, max_ep_len: int = None):
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        flag = 0
        # if online:
        # 	flag = 1
        # 	assert max_ep_len is not None

        self._states[flag:flag+n_transitions] = self._to_tensor(data["observations"])
        self._actions[flag:flag+n_transitions] = self._to_tensor(data["actions"])
        self._rewards[flag:flag+n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[flag:flag+n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[flag:flag+n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._mc_returns[flag:flag+n_transitions] = self._to_tensor(data["mc_returns"][..., None])
        self._rtgs[flag:flag+n_transitions] = self._to_tensor(data["rtgs"][..., None])
        self._size += n_transitions+flag
        self._pointer = min(self._size, n_transitions)

        print(f"Dataset size: {n_transitions+flag}, "
        f"Average trajectory returns: {self._rtgs[:self._size].mean().item() :4f}")
        # mean = rew_sum.mean()
        # std = rew_sum.std()
        # adv = (rew_sum - mean) / std
        # print(adv.reshape(-1, max_ep_len).mean(dim=-1))
    
    def _reset_tmp_cache(self):
        self._tmp_transition = dict(
            observations = [],
            actions = [],
            rewards = [],
            next_observations = [],
            terminals = [],
            mc_returns = [],
            rtgs = [],
        )

    def sample(
        self, 
        batch_size: int, 
        next_step: bool = False, 
        t_len: int = 1, 
        use_rtg: bool = False, 
        cond: Optional[np.ndarray] = None,
        curiosity: bool = False,
     ) -> TensorBatch:
        assert t_len == 1 or (not next_step and t_len > 1), \
            "sampling next step batch does not support seq data."

        if next_step:
            indices = np.random.randint(0, self._size-1, size=batch_size)
            states = self._states[indices]
            actions = self._actions[indices]
            rewards = self._rewards[indices]
            next_states = self._next_states[indices]
            dones = self._dones[indices]
            mc_returns = self._mc_returns[indices]

            n_states = self._states[indices+1]
            n_actions = self._actions[indices+1]
            n_rewards = self._rewards[indices+1]
            n_next_states = self._next_states[indices+1]
            n_dones = self._dones[indices+1]
            n_mc_returns = self._mc_returns[indices+1]

            return [states, actions, rewards, next_states, dones, mc_returns], \
                [n_states, n_actions, n_rewards, n_next_states, n_dones, n_mc_returns]
        
        elif t_len > 1:
            indices = np.random.randint(0, self._size-t_len, size=batch_size)
            states = self._states[indices: indices+t_len]
            actions = self._actions[indices: indices+t_len]
            rewards = self._rewards[indices: indices+t_len]
            next_states = self._next_states[indices: indices+t_len]
            dones = self._dones[indices: indices+t_len]
            mc_returns = self._mc_returns[indices:indices+t_len]
            ### mask transitions after terminal step
            data = [states, actions, rewards, next_states, dones, mc_returns]
            for i in range(len(data)):
                data[i] = zero_after_first_zero(data[i], dones)

            return data

        else:
            if cond is None:
                indices = np.random.randint(0, self._size, size=batch_size)
            else:
                top_frac_indices = np.argsort(cond, axis=0)[-int(self.top_curi * self._size):].reshape(-1)
                indices = np.random.choice(top_frac_indices, batch_size, replace=True)
                # p_sample = cond / (cond.sum() + 1e-10)
                # indices = np.random.choice(np.arange(self._size),
                #                             size=batch_size,
                #                             replace=True,
                #                             p=p_sample.reshape(-1))  # reweights
            states = self._states[indices]
            actions = self._actions[indices]
            rewards = self._rewards[indices]
            next_states = self._next_states[indices]
            dones = self._dones[indices]
            mc_returns = self._mc_returns[indices]

            data = [states, actions, rewards, next_states, dones, mc_returns]
            if use_rtg:
                rtgs = self._rtgs[indices]
                data.append(rtgs)

            return data

    def get_ind(self, indices, use_rtg: bool = False) -> TensorBatch:
        states = self._states[indices]
        actions = self._actions[indices]
        rewards = self._rewards[indices]
        next_states = self._next_states[indices]
        dones = self._dones[indices]
        mc_returns = self._mc_returns[indices]

        data = [states, actions, rewards, next_states, dones, mc_returns]
        if use_rtg:
            rtgs = self._rtgs[indices]
            data.append(rtgs)
        return data
    
    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        # Use this method to add new data into the replay buffer during fine-tuning.
        self._states[self._pointer] = self._to_tensor(state)
        self._actions[self._pointer] = self._to_tensor(action)
        self._rewards[self._pointer] = self._to_tensor(reward)
        self._next_states[self._pointer] = self._to_tensor(next_state)
        self._dones[self._pointer] = self._to_tensor(done)
        self._mc_returns[self._pointer] = 0.0

        self._pointer = (self._pointer + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)

    def add_transition_batch(self, batch: TensorBatch):
        ep_returns, mc_returns, rtgs = None, None, None
        if len(batch) == 5:
            states, actions, rewards, next_states, dones = batch
        elif len(batch) == 6:
            states, actions, rewards, next_states, dones, mc_returns = batch
        elif len(batch) == 7:
            states, actions, rewards, next_states, dones, mc_returns, rtgs = batch
        batch_size = states.shape[0]

        states = self._to_tensor(states)
        actions = self._to_tensor(actions)
        rewards = self._to_tensor(rewards).reshape(-1, 1)
        next_states = self._to_tensor(next_states)
        dones = self._to_tensor(dones).reshape(-1, 1)
        if mc_returns is not None:
            mc_returns = self._to_tensor(mc_returns).reshape(-1, 1)
        if rtgs is not None:
            rtgs = self._to_tensor(rtgs).reshape(-1, 1)

        # If the buffer is full, do nothing.
        if self.full:
            return
        if self._pointer + batch_size > self._buffer_size:
            # # Trim the samples to fit the buffer size.
            # states = states[: self._buffer_size - self._pointer]
            # actions = actions[: self._buffer_size - self._pointer]
            # rewards = rewards[: self._buffer_size - self._pointer]
            # next_states = next_states[: self._buffer_size - self._pointer]
            # dones = dones[: self._buffer_size - self._pointer]
   
            self._states[: batch_size] = torch.zeros_like(states)
            self._actions[: batch_size] = torch.zeros_like(actions)
            self._rewards[: batch_size] = torch.zeros_like(rewards)
            self._next_states[: batch_size] = torch.zeros_like(next_states)
            self._dones[: batch_size] = torch.zeros_like(dones)
            if len(batch) > 5 and mc_returns is not None:
                # mc_returns = mc_returns[: self._buffer_size - self._pointer]
                self._mc_returns[: batch_size] = torch.zeros_like(mc_returns)
            if rtgs is not None:
                # rtgs = rtgs[: self._buffer_size - self._pointer]
                self._rtgs[: batch_size] = torch.zeros_like(rtgs)
            batch_size = states.shape[0]
            self._pointer = 0

        self._states[self._pointer: self._pointer + batch_size] = states
        self._actions[self._pointer: self._pointer + batch_size] = actions
        self._rewards[self._pointer: self._pointer + batch_size] = rewards
        self._next_states[self._pointer: self._pointer + batch_size] = next_states
        self._dones[self._pointer: self._pointer + batch_size] = dones
        if mc_returns is not None:
            self._mc_returns[self._pointer: self._pointer + batch_size] = mc_returns 
        if rtgs is not None:
            self._rtgs[self._pointer: self._pointer + batch_size] = rtgs
        
        self._pointer += batch_size
        self._size = min(self._size + batch_size, self._buffer_size)

    def add_transition_to_tmp(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        # Use this method to add new data into the replay buffer during fine-tuning.
        self._tmp_transition["observations"].append(self._to_tensor(state))
        self._tmp_transition["actions"].append(self._to_tensor(action))
        self._tmp_transition["rewards"].append(self._to_tensor(reward))
        self._tmp_transition["next_observations"].append(self._to_tensor(next_state))
        self._tmp_transition["terminals"].append(self._to_tensor(done))
        self._tmp_transition["mc_returns"].append(torch.zeros_like(self._to_tensor(reward)))
        self._tmp_transition["rtgs"].append(torch.zeros_like(self._to_tensor(reward)))

    def merge_tmp_transitions(self, env, config, max_ep_len):
        for k in self._tmp_transition.keys():
            self._tmp_transition[k] = torch.stack(self._tmp_transition[k], dim=0)
        # self._tmp_transition["mc_returns"] = torch.tensor(mc_returns, dtype=torch.float32, device=self._device)
        # self._tmp_transition["rtgs"] = torch.tensor(rtgs, dtype=torch.float32, device=self._device)
        # print(self._tmp_transition["mc_returns"].mean(), self._tmp_transition["ep_returns"].mean())

        # ep_len = self._tmp_transition["rewards"].shape[0]
        mc_returns, ep_returns, rtgs = get_return_to_go(self._tmp_transition, env, config, max_ep_len)
        # self._tmp_transition["mc_returns"] = torch.tensor(mc_returns, dtype=torch.float32, device=self._device)
        self._tmp_transition["mc_returns"] = torch.zeros_like(self._tmp_transition["rewards"], dtype=torch.float32, device=self._device)
        self._tmp_transition["rtgs"] = torch.tensor(mc_returns, dtype=torch.float32, device=self._device)
    #     self._tmp_transition["rtgs"] = self._tmp_transition["rewards"].reshape(ep_len).sum().reshape(
    #   -1, 1).repeat_interleave(ep_len, dim=0)
        # print(self._tmp_transition["rtgs"].mean(), self._tmp_transition["rtgs"].std())

        data = [
            self._tmp_transition["observations"],
            self._tmp_transition["actions"],
            self._tmp_transition["rewards"],
            self._tmp_transition["next_observations"],
            self._tmp_transition["terminals"],
            self._tmp_transition["mc_returns"],
            self._tmp_transition["rtgs"]
        ]
        self.add_transition_batch(data)
        self._reset_tmp_cache()

    def compute_return_to_go(self, env, config, max_ep_len):
        current_data = {
            "observations": self._states[:self._size],
            "actions": self._actions[:self._size],
            "rewards": self._rewards[:self._size],
            "next_observations": self._next_states[:self._size],
            "terminals": self._dones[:self._size],
        }
        # print("*** current data dict ***")
        self._mc_returns, _ = self.get_return_to_go(current_data, env, config, max_ep_len)
        self._mc_returns = torch.tensor(self._mc_returns, dtype=torch.float32, device=self._device)

    def combine_replay_buffer(
            self,
            diffusion_replay_buffer,
            offline_replay_buffer,
            batch_size_offline,
            batch_size_online,
            device,
        ):
        if diffusion_replay_buffer.empty:
            diffusion_batch = offline_replay_buffer.sample(batch_size=batch_size_offline)
            # diffusion_batch = self.sample(batch_size=batch_size_offline)
        else:
            diffusion_batch = diffusion_replay_buffer.sample(batch_size=batch_size_offline)
        online_batch = self.sample(batch_size=batch_size_online)
        (diffusion_obs1, diffusion_acts, diffusion_rews, diffusion_obs2, 
         diffusion_done, diffusion_mc_return) = diffusion_batch
        online_obs1, online_acts, online_rews, online_obs2, online_done, online_mc_return = online_batch
        obs_tensor = combine_two_tensors(online_obs1, diffusion_obs1).to(device)
        obs_next_tensor = combine_two_tensors(online_obs2, diffusion_obs2).to(device)
        acts_tensor = combine_two_tensors(online_acts, diffusion_acts).to(device)
        rews_tensor = combine_two_tensors(online_rews, diffusion_rews).to(device)
        done_tensor = combine_two_tensors(online_done, diffusion_done).to(device)
        mc_returns_tensor = combine_two_tensors(online_mc_return, diffusion_mc_return).to(device)

        return [obs_tensor, acts_tensor, rews_tensor, obs_next_tensor, done_tensor, mc_returns_tensor]


class ReplayBuffer(ReplayBufferBase):
    """Plain ``(s, a, r, s', done)`` buffer backed by preallocated float32 tensors.

    Raw batches come from :meth:`_sample`. :meth:`sample` routes through
    ``ReplayBufferBase.sample``, so the reward/state normalizers given at
    construction are actually applied.
    """

    def __init__(
            self,
            state_dim: int,
            action_dim: int,
            buffer_size: int,
            device: str = "cpu",
            reward_normalizer: Optional[RewardNormalizer] = None,
            state_normalizer: Optional[StateNormalizer] = None,
    ):
        super().__init__(
            device, reward_normalizer, state_normalizer,
        )
        self._buffer_size = buffer_size
        self._pointer = 0  # next slot to write
        self._size = 0  # number of slots holding real data

        self._states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._actions = torch.zeros(
            (buffer_size, action_dim), dtype=torch.float32, device=device
        )
        self._rewards = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)
        self._next_states = torch.zeros(
            (buffer_size, state_dim), dtype=torch.float32, device=device
        )
        self._dones = torch.zeros((buffer_size, 1), dtype=torch.float32, device=device)

    @property
    def empty(self):
        """True when no transitions are stored."""
        return self._size == 0

    @property
    def full(self):
        """True when every slot holds a transition."""
        return self._size == self._buffer_size

    def __len__(self):
        """Number of stored transitions (not the write position, which wraps)."""
        return self._size

    def _to_tensor(self, data) -> torch.Tensor:
        """Return ``data`` as a float32 tensor on the buffer device.

        Tensors are cast/moved rather than re-wrapped.
        """
        if isinstance(data, torch.Tensor):
            return data.to(dtype=torch.float32, device=self._device)
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Copy a D4RL-style dict into this (empty) buffer, starting at slot 0.

        Raises:
            ValueError: if the buffer is not empty or is too small for ``data``.
        """
        if not self.empty:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        # Starting at slot 0 means a buffer sized exactly to the dataset fits,
        # no all-zero row is ever sampled, and the next write lands after the data.
        self._states[:n_transitions] = self._to_tensor(data["observations"])
        self._actions[:n_transitions] = self._to_tensor(data["actions"])
        self._rewards[:n_transitions] = self._to_tensor(data["rewards"][..., None])
        self._next_states[:n_transitions] = self._to_tensor(data["next_observations"])
        self._dones[:n_transitions] = self._to_tensor(data["terminals"][..., None])
        self._size = n_transitions
        self._pointer = n_transitions

        print(f"Dataset size: {n_transitions}")

    def _sample(self, batch_size: int, **kwargs) -> TensorBatch:
        """Uniformly sample raw (un-normalized) transitions from the filled slots."""
        indices = np.random.randint(0, self._size, size=batch_size)
        return [
            self._states[indices],
            self._actions[indices],
            self._rewards[indices],
            self._next_states[indices],
            self._dones[indices],
        ]

    def sample(self, batch_size: int) -> TensorBatch:
        """Uniformly sample transitions and apply the configured normalizers."""
        return super().sample(batch_size)

    def add_transition_batch(self, batch: TensorBatch):
        """Append a batch after the write pointer.

        Transitions that do not fit in the remaining space are dropped.
        """
        states, actions, rewards, next_states, dones = batch
        room = self._buffer_size - self._pointer
        # If the buffer is full, do nothing.
        if room <= 0:
            return
        # Trim the samples to fit the buffer size; rewards/dones may be (B,) or (B, 1).
        states = self._to_tensor(states)[:room]
        actions = self._to_tensor(actions)[:room]
        rewards = self._to_tensor(rewards).reshape(-1, 1)[:room]
        next_states = self._to_tensor(next_states)[:room]
        dones = self._to_tensor(dones).reshape(-1, 1)[:room]
        batch_size = states.shape[0]

        end = self._pointer + batch_size
        self._states[self._pointer:end] = states
        self._actions[self._pointer:end] = actions
        self._rewards[self._pointer:end] = rewards
        self._next_states[self._pointer:end] = next_states
        self._dones[self._pointer:end] = dones
        self._pointer = end
        self._size = min(self._size + batch_size, self._buffer_size)

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Write one transition at the write pointer (ring-buffer semantics).

        Use this method to add new data into the replay buffer during fine-tuning.
        """
        # The pointer can equal buffer_size after a batch insert; wrap before writing.
        slot = self._pointer % self._buffer_size
        self._states[slot] = self._to_tensor(state)
        self._actions[slot] = self._to_tensor(action)
        self._rewards[slot] = self._to_tensor(reward)
        self._next_states[slot] = self._to_tensor(next_state)
        self._dones[slot] = self._to_tensor(done)

        self._pointer = (slot + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)

    def combine_replay_buffer(
            self,
            diffusion_replay_buffer,
            batch_size,
            device
        ):
        """Concatenate equal-size batches from this buffer and ``diffusion_replay_buffer``.

        Rows from this buffer come first in every returned tensor.
        """
        diffusion_batch = diffusion_replay_buffer.sample(batch_size=batch_size)
        online_batch = self.sample(batch_size=batch_size)
        diffusion_obs1, diffusion_acts, diffusion_rews, diffusion_obs2, diffusion_done = diffusion_batch
        online_obs1, online_acts, online_rews, online_obs2, online_done = online_batch
        obs_tensor = combine_two_tensors(online_obs1, diffusion_obs1).to(device)
        obs_next_tensor = combine_two_tensors(online_obs2, diffusion_obs2).to(device)
        acts_tensor = combine_two_tensors(online_acts, diffusion_acts).to(device)
        rews_tensor = combine_two_tensors(online_rews, diffusion_rews).to(device)
        done_tensor = combine_two_tensors(online_done, diffusion_done).to(device)
        return [obs_tensor, acts_tensor, rews_tensor, obs_next_tensor, done_tensor]


def prepare_replay_buffer(
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        dataset: Dict[str, np.ndarray],
        env_name: str,
        diffusion_config: DiffusionConfig,
        device: str = "cpu",
        reward_normalizer: Optional[RewardNormalizer] = None,
        state_normalizer: Optional[StateNormalizer] = None,
):
    """Build the training replay buffer from one of three sources.

    ``diffusion_config.path`` selects the source:

    * ``None``: load ``dataset`` into a ``ReplayBuffer`` of ``buffer_size``.
    * ``*.npz``: load the stored diffusion samples, optionally truncated to
      ``sample_limit``, into a ``ReplayBuffer`` sized to the samples.
    * ``*.pt``: parse ``config.gin`` next to the checkpoint, then return a
      ``DiffusionGenerator``.

    Raises:
        ValueError: for any other path format.
    """
    buffer_args = dict(
        reward_normalizer=reward_normalizer,
        state_normalizer=state_normalizer,
        device=device,
    )
    source = diffusion_config.path

    if source is None:
        print('Loading standard D4RL dataset.')
        buffer = ReplayBuffer(
            state_dim=state_dim,
            action_dim=action_dim,
            buffer_size=buffer_size,
            **buffer_args,
        )
        buffer.load_d4rl_dataset(dataset)
        return buffer

    if source.endswith(".npz"):
        print('Loading diffusion dataset.')
        archive = np.load(source)
        samples = {key: archive[key] for key in archive.files}

        limit = diffusion_config.sample_limit
        if limit != -1:
            # Limit the number of samples
            samples = {key: values[:limit] for key, values in samples.items()}
            print('Limited diffusion dataset to {} samples'.format(limit))

        buffer = ReplayBuffer(
            state_dim=state_dim,
            action_dim=action_dim,
            buffer_size=samples['rewards'].shape[0],
            **buffer_args,
        )
        buffer.load_d4rl_dataset(samples)
        return buffer

    if source.endswith(".pt"):
        print('Loading diffusion model.')
        # The gin config is expected to sit in the same directory as the checkpoint.
        gin.parse_config_file(
            os.path.join(os.path.dirname(source), 'config.gin'), skip_unknown=True
        )
        return DiffusionGenerator(
            env_name=env_name,
            diffusion_path=source,
            use_ema=True,
            num_steps=diffusion_config.num_steps,
            max_samples=diffusion_config.sample_limit,
            **buffer_args,
        )

    raise ValueError("Unknown diffusion_path format")
