# %% Part 0 D4RL Batch Buffer
import numpy as np

import os
import torch
import pickle
import random
from Utils import Batch_Class
from loguru import logger
np.set_printoptions(precision=4, linewidth=180, suppress=True)

def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along dim 0.

    The result satisfies ``out[-1] = x[-1]`` and
    ``out[t] = x[t] + gamma * out[t + 1]`` for every earlier ``t``.

    Args:
        x: Tensor whose first dimension is the time axis.
        gamma: Discount factor applied per step.

    Returns:
        A tensor with the same shape and dtype as ``x``. An input with zero
        timesteps yields an empty tensor instead of raising ``IndexError``.
    """
    # Use a name distinct from the function so the function is not shadowed.
    result = torch.zeros_like(x)
    if x.shape[0] == 0:
        return result
    result[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        result[t] = x[t] + gamma * result[t + 1]
    return result

import time

import matplotlib.pyplot as plt

# %% Part 1 batch_buffer definition
class batch_buffer(object):
    def __init__(self, config, env_name0, dataset, device, buffer_mode='normal', buffer_normalization=False,
                 discretize=False, discrete_mode='average'):
        """Load trajectories, build per-trajectory tensors and noised ("attacked") states.

        Args:
            config: Mapping read here for "dataset_path", "attack_scale" and
                "attack_ratio" (other methods read further keys).
            env_name0: Environment name; the part before the first '-' (lower-cased)
                selects the default dataset file.
            dataset: Dataset tag used in the default dataset file name and in log output.
            device: Torch device on which all buffer tensors are stored.
            buffer_mode: Accepted but not used in this constructor.
            buffer_normalization: Stored as ``self.normalization`` (read by ``sample``).
            discretize: When true, also store discretized copies of states, actions,
                rewards and returns-to-go.
            discrete_mode: 'average' (equal-width bins) or 'percentile' (equal-mass bins).
        """

        self.config = config

        # Strip the "-suffix" part (if any) to get the bare environment family name.
        if env_name0.find('-') >= 0:
            env_name = env_name0[:env_name0.find('-')].lower()
        else:
            env_name = env_name0
        # else:
        #     raise FileExistsError(f"Please check the input Env Name: {env_name0}")

        # Default dataset file, with a version suffix that depends on the environment family.
        # NOTE(review): for any other env_name, dataset_path stays unbound; that only
        # matters when config["dataset_path"] is falsy (NameError at the open() below).
        if env_name == 'halfcheetah' or env_name == 'walker2d' or env_name == 'hopper':
            dataset_path = f'datasets/{env_name}-{dataset}-v2.pkl'
        elif env_name == 'door' or env_name == 'hammer' or env_name == 'relocate' or env_name == 'pen':
            dataset_path = f'datasets/{env_name}-{dataset}-v1.pkl'
        elif env_name == 'kitchen' or env_name == 'antmaze':
            dataset_path = f'datasets/{env_name}-{dataset}-v0.pkl'

        # An explicit config["dataset_path"] overrides the default file.
        # NOTE(review): torch.load / pickle.load execute arbitrary pickled code; only
        # load trusted files.
        if config["dataset_path"]:
            if "ratio" not in config["dataset_path"]:
                # Flat dataset (one array per key): split it into trajectories at timeouts.
                trajectories = []
                with open(config["dataset_path"], 'rb') as f:
                    trajs = torch.load(f)

                def dummy_dict():
                    # Fresh, empty per-trajectory accumulator.
                    return {
                        "observations": [],
                        "actions": [],
                        "next_observations": [],
                        "rewards": [],
                        "terminals": [],
                    }

                new_traj = dummy_dict()

                # The last index is skipped because next_observations reads index i + 1.
                # NOTE(review): the step at which timeouts[i] == 1 is itself not stored,
                # `new_traj["terminals"][-1]` raises IndexError if the current trajectory
                # is empty at a timeout, and steps after the final timeout are never
                # appended to `trajectories`.
                for i in range(trajs["rewards"].shape[0] - 1):
                    if trajs["timeouts"][i] == 1:
                        new_traj["terminals"][-1] = 1
                        for key in ["observations", "actions", "next_observations", "rewards", "terminals"]:
                            new_traj[key] = np.array(new_traj[key])
                        trajectories.append(new_traj)
                        new_traj = dummy_dict()
                    else:
                        new_traj["observations"].append(trajs["observations"][i])
                        new_traj["actions"].append(trajs["actions"][i])
                        new_traj["next_observations"].append(trajs["observations"][i+1])
                        new_traj["rewards"].append(trajs["rewards"][i])
                        new_traj["terminals"].append(trajs["timeouts"][i])
            else:
                # Paths containing "ratio" are loaded as-is and used as the trajectory list.
                print("RATIO!!!")
                with open(config["dataset_path"], 'rb') as f:
                    trajectories = torch.load(f)
                # print(trajectories)
                # for cnt, traj in enumerate(trajectories):
                #     print(traj["rewards"].shape[0], cnt)
                # import sys; sys.exit()
        else:
            with open(dataset_path, 'rb') as f:
                trajectories = pickle.load(f)

        # Gather per-trajectory arrays (original file order) for dataset-wide statistics.
        states, actions, rewards, traj_lens, returns = [], [], [], [], []
        for path in trajectories:
            states.append(path['observations'])
            actions.append(path['actions'])
            rewards.append(path['rewards'])
            traj_lens.append(len(path['observations']))
            returns.append(path['rewards'].sum())
        traj_lens, returns = np.array(traj_lens), np.array(returns)
        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        rewards = np.concatenate(rewards, axis=0)
        sorted_idx = np.argsort(-returns)  # Total returns: from highest to lowest
        num_timesteps = sum(traj_lens)

        # Return-to-go scale used by sample(); both branches currently assign the same value.
        if env_name == 'hopper' or env_name == 'halfcheetah' or env_name == 'walker2d':
            self.scale = 1000.
        else:
            self.scale = 1000.
        self.env_name = env_name
        self.device = device
        self.normalization = buffer_normalization

        self.obs_dim = trajectories[0]['observations'].shape[1]
        self.act_dim = trajectories[0]['actions'].shape[1]
        self.max_action = 1.
        # Per-dimension state statistics over the whole dataset (std gets a small epsilon).
        self.state_mean = torch.FloatTensor(np.mean(states, axis=0)).to(self.device)
        self.state_std = torch.FloatTensor(np.std(states, axis=0) + 1e-6).to(self.device)
        # Lengths and returns are stored in return-sorted order, matching the lists below.
        self.traj_lens = torch.FloatTensor(traj_lens[sorted_idx]).to(self.device)
        self.total_returns = torch.FloatTensor(returns[sorted_idx]).to(self.device)
        self.traj_nums = len(traj_lens)

        print('=' * 70)
        if buffer_normalization:
            print(f'Buffer Information: {env_name} {dataset} with Normalization on State')
        else:
            print(f'Buffer Information: {env_name} {dataset} without Normalization')
        print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
        print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
        print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
        print(f'State Mean: {self.state_mean.cpu().numpy()}')
        print(f'State Std : {self.state_std.cpu().numpy()}')
        print('=' * 70)

        # Per-trajectory tensors, one list entry per trajectory (return-sorted order).
        self.state = []
        self.action = []
        self.next_state = []
        self.reward = []
        self.terminal = []
        self.timestep = []
        self.rtg = []

        return_to_gos = []

        # noised states
        self.noised_state = []
        self.attacked_idx = []
        self.states_after_denoising = []
        self.filted_safe_idxes = []

        # Core: NOTE that the trajectory are listed from highest return to the lowest one
        for i in range(len(trajectories)):

            traj = trajectories[sorted_idx[i]]
            self.state.append(torch.FloatTensor(traj['observations']).to(self.device))
            self.next_state.append(torch.FloatTensor(traj['next_observations']).to(self.device))
            self.action.append(torch.FloatTensor(traj['actions']).to(self.device))
            self.reward.append(torch.FloatTensor(traj['rewards']).to(self.device))
            self.terminal.append(torch.LongTensor(traj['terminals']).to(self.device))
            self.timestep.append(torch.LongTensor(np.arange(0, traj['observations'].shape[0])).to(self.device))
            # Undiscounted (gamma = 1) return-to-go for every timestep.
            self.rtg.append(discount_cumsum(self.reward[-1], gamma=1.))

            return_to_gos.append(self.rtg[i].cpu().numpy())

            size = self.state[-1].shape
            # NOTE(review): noise_std is read from the config but never used below.
            noise_std = config["attack_scale"]

            # Number of timesteps left clean; the remaining rows receive noise.
            not_attack_size = int(size[0] * (1 - config["attack_ratio"]))

            # Zero rows for clean steps, uniform noise in [-1, 1) for attacked steps.
            tensor = torch.cat([
                torch.zeros(not_attack_size, size[1]),
                torch.rand(size=(size[0] - not_attack_size, size[1])) * 2 - 1
            ])

            # Shuffle rows so the attacked timesteps are spread randomly over the trajectory.
            attack_tensor = tensor[torch.randperm(size[0])].to(self.device)
            # A row counts as attacked when its mean squared noise exceeds 1e-4.
            attack_tensor_2 = (attack_tensor ** 2).mean(-1)
            attack_idx = torch.where(attack_tensor_2 > 1e-4, 1, 0).cpu()

            # noised states
            # Noise is scaled per dimension by the dataset state std.
            self.noised_state.append(torch.FloatTensor(traj['observations']).to(self.device) + attack_tensor * self.state_std)

            self.attacked_idx.append(attack_idx)

        return_to_gos = np.concatenate(return_to_gos, axis=0)

        # Note: Discretize subslice is in the sequence of obs_dim + act_dim + reward + return_to_gos (divided into 101 pieces)
        # NOTE(review): states/actions/rewards are concatenated in file order while
        # return_to_gos is in return-sorted order, so rows are not aligned; only the
        # per-column min/max/percentiles are used, which do not depend on row order.
        data = np.concatenate((states, actions, rewards.reshape(-1, 1), return_to_gos.reshape(-1, 1)), axis=1)
        # self.bins has one row per column of `data`, holding 101 bin edges.
        # NOTE(review): self.bins is left unset for any other discrete_mode value.
        if discrete_mode == 'average':
            self.bins = np.linspace(np.min(data, axis=0), np.max(data, axis=0), 101).T
        elif discrete_mode == 'percentile':
            self.bins = np.percentile(data, np.linspace(0, 100, 101), axis=0).T
        # self.discrete_data =

        if discretize:
            self.dsct_state = []
            self.dsct_action = []
            # NOTE(review): dsct_next_state is created but never filled in this constructor.
            self.dsct_next_state = []
            self.dsct_reward = []
            self.dsct_rtg = []
            # NOTE(review): indices here are clipped to [0, 100], whereas the discretize()
            # method clips to [0, 99] and reconstruct() accepts [0, 99] — confirm intent.
            for i in range(len(trajectories)):
                traj = trajectories[sorted_idx[i]]
                traj_state = traj['observations']
                dsct_traj_state = [np.digitize(traj_state[..., dim], self.bins[dim]) - 1
                                   for dim in np.arange(0, self.obs_dim)]
                dsct_traj_state = np.stack(np.clip(dsct_traj_state, 0, 101 - 1), axis=-1)
                self.dsct_state.append(torch.LongTensor(dsct_traj_state).to(self.device))

                traj_action = traj['actions']
                # Action bins follow the state bins, hence the obs_dim offset.
                dsct_traj_action = [np.digitize(traj_action[..., dim-self.obs_dim], self.bins[dim]) - 1
                                    for dim in np.arange(self.obs_dim, self.obs_dim + self.act_dim)]
                dsct_traj_action = np.stack(np.clip(dsct_traj_action, 0, 101 - 1), axis=-1)
                self.dsct_action.append(torch.LongTensor(dsct_traj_action).to(self.device))

                traj_reward = traj['rewards']
                # Second-to-last bin row belongs to the reward column.
                dsct_traj_reward = np.digitize(traj_reward[...], self.bins[-2]) - 1
                dsct_traj_reward = np.clip(dsct_traj_reward, 0, 101 - 1)
                self.dsct_reward.append(torch.LongTensor(dsct_traj_reward).to(self.device))

                # Last bin row belongs to the return-to-go column.
                traj_rtg = discount_cumsum(torch.FloatTensor(traj_reward), gamma=1.).numpy()
                dsct_traj_rtg = np.digitize(traj_rtg[...], self.bins[-1]) - 1
                dsct_traj_rtg = np.clip(dsct_traj_rtg, 0, 101 - 1)
                self.dsct_rtg.append(torch.LongTensor(dsct_traj_rtg).to(self.device))

        # size = int(0.0 * n) is always 0, so no trajectory is held out for evaluation.
        self.eval_idx = np.random.choice(np.arange(self.traj_nums), size=int(0.0 * self.traj_nums), replace=False)
        print("Eval idx is set to empty.")

        # With an empty eval set, every trajectory index is a training index.
        self.train_idx = np.delete(np.arange(self.traj_nums), self.eval_idx)

        # idxs for clean dataset
        self.mask_detected_clean_list = []
        # self.p_clean_indices = []
        # self.p_clean_lens = []

        logger.info('Data Buffer has been initialized!')
        logger.info('Please be noted that the trajectory are listed by return from high to low!')

    def discretize(self, x, subslice=(None, None)):
        """Convert continuous values into bin indices using ``self.bins``.

        Args:
            x: Array or tensor of shape (N, dim); a 1-D input is treated as a
                single row. Column ``d`` is binned with row ``d`` of the selected bins.
            subslice: ``(start, stop)`` selecting the rows of ``self.bins`` to use.

        Returns:
            Indices clipped to ``[0, 99]``: a NumPy array for array input, or an
            int32 tensor on ``self.device`` for tensor input.
        """
        from_tensor = torch.is_tensor(x)
        values = x.detach().cpu().numpy() if from_tensor else x
        if values.ndim == 1:
            values = values[None]

        edges = self.bins[subslice[0]: subslice[1]]
        per_dim = []
        for dim in range(values.shape[-1]):
            # np.digitize is 1-based for in-range values; shift to 0-based indices.
            per_dim.append(np.digitize(values[..., dim], edges[dim]) - 1)
        codes = np.stack(np.clip(per_dim, 0, 99), axis=-1)

        if not from_tensor:
            return codes
        return torch.LongTensor(codes).to(device=self.device, dtype=torch.int32)

    def reconstruct(self, indices, subslice=(None, None)):
        """Map bin indices back to continuous values (the bin centres).

        Args:
            indices: Integer array or tensor of shape (N, dim); a 1-D input is
                treated as a single row. Values are clipped to ``[0, 99]``.
            subslice: ``(start, stop)`` selecting the rows of ``self.bins`` to use.

        Returns:
            Bin centres with the same leading shape: a NumPy array for array
            input, or a float32 tensor on the input's device for tensor input.
        """
        as_tensor = torch.is_tensor(indices)
        if as_tensor:
            out_device = indices.device
            indices = indices.detach().cpu().numpy()
        if indices.ndim == 1:
            indices = indices[None]
        indices = np.clip(indices, 0, 100 - 1)

        edges = self.bins[subslice[0]: subslice[1]]
        # Midpoint of every pair of adjacent bin edges.
        centers = (edges[:, :-1] + edges[:, 1:]) / 2
        columns = []
        for dim in range(indices.shape[-1]):
            columns.append(centers[dim, indices[..., dim]])
        recon = np.stack(columns, axis=-1)

        if as_tensor:
            return torch.FloatTensor(recon).to(device=out_device, dtype=torch.float32)
        return recon

    def pre_pad(self, state_seq, i):
        """Return the ``condition_length`` states preceding index ``i``, zero-padded in front.

        Args:
            state_seq: Tensor of shape (T, obs_dim).
            i: Index of the current timestep.

        Returns:
            Tensor of shape (condition_length, obs_dim) on ``state_seq``'s device:
            zero rows first, then the available states ``state_seq[i - c:i]``.
        """
        cond_len = self.config["condition_length"]
        # Clamp the start at 0: a negative start would wrap around to the end of the
        # sequence, which silently dropped the real preceding states for i < cond_len.
        pre_states = state_seq[max(i - cond_len, 0):i]
        # Allocate on the input's device/dtype rather than forcing CUDA.
        padding = torch.zeros(
            cond_len - pre_states.shape[0], self.obs_dim,
            dtype=state_seq.dtype, device=state_seq.device,
        )
        return torch.cat((padding, pre_states), dim=0)

    def next_pad(self, state_seq, i):
        """Return the ``condition_length`` states following index ``i``, zero-padded at the end.

        Args:
            state_seq: Tensor of shape (T, obs_dim).
            i: Index of the current timestep.

        Returns:
            Tensor of shape (condition_length, obs_dim) on ``state_seq``'s device:
            the available states ``state_seq[i + 1:i + 1 + c]`` first, then zero rows.
        """
        cond_len = self.config["condition_length"]
        next_states = state_seq[i + 1:i + 1 + cond_len]
        # Allocate on the input's device/dtype rather than forcing CUDA.
        padding = torch.zeros(
            cond_len - next_states.shape[0], self.obs_dim,
            dtype=state_seq.dtype, device=state_seq.device,
        )
        return torch.cat((next_states, padding), dim=0)

    def prepare(self, state_seq_original):
        """Build a context window around every timestep of a state sequence.

        For each timestep ``i`` the window holds the ``condition_length`` states
        before ``i``, the state at ``i`` and the ``condition_length`` states after
        it; positions outside the sequence are zero. The state at ``i`` therefore
        sits at index ``condition_length`` of its window.

        Args:
            state_seq_original: Tensor of shape (T, obs_dim). It is not modified.

        Returns:
            A new tensor of shape (T, 2 * condition_length + 1, obs_dim) on the
            input's device.
        """
        cond_len = self.config["condition_length"]
        seq_len, obs_dim = state_seq_original.shape
        window = 2 * cond_len + 1

        if seq_len == 0:
            return state_seq_original.new_zeros((0, window, obs_dim))

        # Zero-pad both ends once, then slide a window over the time axis. This
        # replaces the per-timestep Python loop and keeps the real preceding states
        # for the first `cond_len` timesteps (a negative slice start used to wrap
        # around and drop them).
        pad = state_seq_original.new_zeros((cond_len, obs_dim))
        padded = torch.cat((pad, state_seq_original, pad), dim=0)

        # unfold yields (T, obs_dim, window); reorder to (T, window, obs_dim) and
        # copy so the result does not alias the padded buffer.
        return padded.unfold(0, window, 1).permute(0, 2, 1).contiguous()

    def cal_ther_for_clean(self, denoiser, detector, topk, early_stop=None):
        """Score every noised state and derive the clean/noised threshold.

        Each trajectory in ``self.noised_state`` is turned into context windows and
        passed through ``detector.predictor.model``; the per-timestep score is the
        mean squared model output at the centre position of the window.

        Args:
            denoiser: Accepted for interface compatibility; not used here.
            detector: Object exposing ``predictor.model`` as a callable.
            topk: Quantile in [0, 1]; scores at or below it are treated as clean by callers.
            early_stop: If not None, stop after the trajectory with this index, so
                the returned list covers only trajectories ``0..early_stop``.

        Returns:
            ``(threshold, pred_noise_list)`` where ``threshold`` is the ``topk``
            quantile of all computed scores and ``pred_noise_list`` holds one score
            tensor per processed trajectory.
        """
        cond_len = self.config["condition_length"]
        # Step index passed to the model (presumably the diffusion timestep — verify
        # against the model); it does not depend on the trajectory.
        ka = self.config["tn_v_T"] * self.config["T"] + 1
        pred_noise_list = []

        with torch.no_grad():
            for k, state_seq in enumerate(self.noised_state):
                windows = self.prepare(state_seq)
                n_steps = windows.shape[0]
                noise_p = detector.predictor.model(
                    windows,
                    ka * torch.ones((n_steps,), device=self.device).long(),
                    None,
                    None,
                    torch.ones((n_steps,), device=self.device).long()
                )
                # ||predicted noise|| at the centre of each window
                pred_noise_list.append((noise_p ** 2).mean(-1)[:, cond_len])

                if early_stop is not None and k >= early_stop:
                    break

        threshold = torch.quantile(torch.cat(pred_noise_list, 0), topk)

        return threshold, pred_noise_list

    def get_dataset(self, denoiser, detector, dataset_name="test", topk=0.7, early_stop=None):
        """Detect, denoise and export five variants of the buffer as pickle files.

        Writes ``{dataset_name}_denoised_data.pkl``, ``_true_data.pkl``,
        ``_noised_data.pkl``, ``_filted_data.pkl`` and ``_filted_true_data.pkl``
        relative to the current working directory.

        Args:
            denoiser: Object exposing ``denoise_state(windows, steps)``.
            detector: Passed to ``cal_ther_for_clean`` to score the noised states.
            dataset_name: Prefix of the output file names.
            topk: Quantile of scores treated as clean.
            early_stop: Forwarded to ``cal_ther_for_clean``.
                NOTE(review): a non-None value shortens ``pred_noise_list`` while the
                loop below still indexes it for every trajectory — confirm it is only
                ever called with ``early_stop=None``.
        """

        threshold, pred_noise_list = self.cal_ther_for_clean(denoiser, detector, topk, early_stop)

        print(f"{topk * 100}% data will be detected as clean")

        def dummy_dict():
            # s: states, a: actions, ns: next states, r: rewards, t: terminals, w: detector scores
            return {
                "s": [],
                "a": [],
                "ns": [],
                "r": [],
                "t": [],
                "w": [],
            }

        denoised_data, true_data, noised_data, filted_data, filted_true_data = dummy_dict(), dummy_dict(), dummy_dict(), dummy_dict(), dummy_dict()

        # Accumulators for the current "stack" of trajectories that is denoised in one batch.
        noised_states_stack, true_states_stack, d_clean_ind_stack, d_noised_ind_stack, d_noised_len_stack, cur_len, cur_noised = [], [], [], [], [], 0, 0

        def lzy_mse(a, b):
            # Mean squared error between two tensors, returned as a NumPy scalar.
            return torch.mean((a-b)**2).cpu().numpy()

        for k, (noised_state, true_state, attack_idx) in enumerate(zip(self.noised_state, self.state, self.attacked_idx)):

            # Everything except the states is shared by the denoised/true/noised variants.
            for dict_data in [denoised_data, true_data, noised_data]:
                dict_data["a"].append(self.action[k])
                dict_data["ns"].append(self.next_state[k])
                dict_data["r"].append(self.reward[k])
                # Force Terminal
                # (in-place write: this also modifies the buffer's own self.terminal[k])
                self.terminal[k][-1] = 1
                dict_data["t"].append(self.terminal[k])
                dict_data["w"].append(pred_noise_list[k])

            # True and noised data
            true_data["s"].append(self.state[k])
            noised_data["s"].append(self.noised_state[k])

            # Work on copies so the denoising below does not alter the buffer.
            cloned_noised_state, cloned_true_state = noised_state.clone(), true_state.clone()

            # Split timesteps by detector score: above threshold = noised, otherwise clean.
            detected_noised_indexes = torch.where(pred_noise_list[k] >  threshold)[0]
            detected_clean_indexes  = torch.where(pred_noise_list[k] <= threshold)[0]

            # The filtered variants keep only the timesteps detected as clean.
            for dict_data in [filted_data, filted_true_data]:
                dict_data["a"].append(self.action[k][detected_clean_indexes])
                dict_data["ns"].append(self.next_state[k][detected_clean_indexes])
                dict_data["r"].append(self.reward[k][detected_clean_indexes])
                # Force Terminal
                self.terminal[k][-1] = 1
                dict_data["t"].append(self.terminal[k][detected_clean_indexes])
                dict_data["w"].append(pred_noise_list[k][detected_clean_indexes])

            # NOTE(review): hard-coded .cuda(); this fails on a CPU-only setup.
            the_attack_idx = torch.where(attack_idx==1)[0].cuda()

            # Cheat
            # Uses the ground-truth attack mask: only detected timesteps that were really
            # attacked are denoised.
            detected_noised_indexes = detected_noised_indexes[torch.isin(detected_noised_indexes, the_attack_idx)]

            noised_states_stack.append(cloned_noised_state)
            true_states_stack.append(cloned_true_state)

            d_clean_ind_stack.append(detected_clean_indexes)
            d_noised_ind_stack.append(detected_noised_indexes)
            d_noised_len_stack.append(detected_noised_indexes.shape[0])
            cur_len += cloned_noised_state.shape[0]
            cur_noised += detected_noised_indexes.shape[0]

            # Filted Data
            filted_data["s"].append(self.noised_state[k][detected_clean_indexes])
            filted_true_data["s"].append(self.state[k][detected_clean_indexes])

            # Flush the stack once it is large enough or the last trajectory is reached.
            # NOTE(review): cur_len already includes this trajectory, so its length is
            # counted twice in this condition — confirm that is intended.
            if not (cur_len + cloned_noised_state.shape[0] < self.config["stack_length"] and k != len(self.state) - 1):

                original_mse = lzy_mse(torch.cat(noised_states_stack, dim=0), torch.cat(true_states_stack, dim=0))
                time0 = time.perf_counter()
                if cur_noised > 0:
                    for _ in range(self.config["detect_denoise_loops"]):
                        # Context windows of all detected timesteps across the stack.
                        input_noised_state_list = []
                        for sub_s_seq, noised_ind in zip(noised_states_stack, d_noised_ind_stack):
                            input_noised_state_list.append(
                                self.prepare(sub_s_seq)[noised_ind]
                            )
                        input_noised_state_list = torch.cat(input_noised_state_list, dim=0)
                        denoised_states = denoiser.denoise_state(input_noised_state_list, self.config["detect_denoise_steps"])
                        # Split the batched output back into one chunk per stacked trajectory.
                        denoised_states = torch.split(denoised_states, d_noised_len_stack)
                        for sub_s_seq, sub_t_seq, de_s, noised_ind in zip(noised_states_stack, true_states_stack, denoised_states, d_noised_ind_stack):
                            # Write the window centre back in place so the next loop
                            # iteration conditions on the already-denoised states.
                            sub_s_seq[noised_ind] = de_s[:, self.config["condition_length"]].clone()
                noised_states_stack = torch.cat(noised_states_stack, dim=0)
                true_states_stack = torch.cat(true_states_stack, dim=0)
                fixed_mse = lzy_mse(noised_states_stack, true_states_stack)

                print(f"mse from {original_mse} -> {fixed_mse}, time used: {time.perf_counter() - time0}")

                denoised_data["s"].append(noised_states_stack)
                # Start a fresh stack.
                noised_states_stack, true_states_stack, d_clean_ind_stack, d_noised_ind_stack, d_noised_len_stack, cur_len, cur_noised = [], [], [], [], [], 0, 0

        # Concatenate the per-trajectory pieces into flat tensors and sanity-check lengths.
        for dict_data in [denoised_data, true_data, noised_data, filted_data, filted_true_data]:
            dict_data["s"] = torch.cat(dict_data["s"], 0)
            dict_data["a"] = torch.cat(dict_data["a"], 0)
            dict_data["ns"] = torch.cat(dict_data["ns"], 0)
            dict_data["r"] = torch.cat(dict_data["r"], 0)
            dict_data["t"] = torch.cat(dict_data["t"], 0)
            dict_data["w"] = torch.cat(dict_data["w"], 0)

            print(dict_data["s"].shape[0], dict_data["a"].shape[0], dict_data["w"].shape[0])

            assert (dict_data["s"].shape[0] == dict_data["a"].shape[0]) and (dict_data["s"].shape[0] == dict_data["w"].shape[0])

        def save_data(data, filename):
            """Helper function to save data to a pickle file."""
            data_to_save = {
                "observations": data["s"].cpu().numpy(),
                "actions": data["a"].cpu().numpy(),
                "next_observations": data["ns"].cpu().numpy(),
                "rewards": data["r"].cpu().numpy(),
                "terminals": data["t"].cpu().numpy(),
                "weights": data["w"].cpu().numpy()
            }
            with open(filename, "wb") as f:
                pickle.dump(data_to_save, f)

        # Save different datasets
        save_data(denoised_data, f"{dataset_name}_denoised_data.pkl")
        save_data(true_data, f"{dataset_name}_true_data.pkl")
        save_data(noised_data, f"{dataset_name}_noised_data.pkl")
        save_data(filted_data, f"{dataset_name}_filted_data.pkl")
        save_data(filted_true_data, f"{dataset_name}_filted_true_data.pkl")

    def simple_detect(self, denoiser, detector, render=None, image_name=None):
        """Run detection and denoising on the first trajectory and print diagnostics.

        After every denoising loop the mean squared error between the repaired
        and the true states is printed, followed by how many truly attacked
        timesteps were classified as clean.

        Args:
            denoiser: Object exposing ``denoise_state(windows, steps)``.
            detector: Passed to ``cal_ther_for_clean`` to score the noised states.
            render: Accepted but unused.
            image_name: Accepted but unused.
        """
        # early_stop=0 means scores are computed for the first trajectory only.
        threshold, pred_noise_list = self.cal_ther_for_clean(denoiser, detector, 0.7, 0)
        cond_len = self.config["condition_length"]

        for k, (noised_state, true_state, attack_ids) in enumerate(zip(self.noised_state, self.state, self.attacked_idx)):

            # Work on a copy so the buffer's noised states stay untouched.
            cloned_noised_state = noised_state.clone()

            detected_noised_indexes = torch.where(pred_noise_list[k] > threshold)[0]
            detected_clean_indexes = torch.where(pred_noise_list[k] <= threshold)[0]

            for _ in range(self.config["detect_denoise_loops"]):
                windows = self.prepare(cloned_noised_state)
                denoised_states = denoiser.denoise_state(
                    windows[detected_noised_indexes], self.config["detect_denoise_steps"]
                )
                # The repaired state is the centre position of each denoised window.
                cloned_noised_state[detected_noised_indexes] = denoised_states[:, cond_len]

                print(torch.mean((cloned_noised_state - true_state) ** 2))

            # Attacked timesteps that slipped through as "clean". torch.isin replaces a
            # Python-level membership test per element; both operands must share a device.
            attacked_ids = torch.where(attack_ids == 1)[0].to(detected_clean_indexes.device)
            missed = int(torch.isin(attacked_ids, detected_clean_indexes).sum())

            print(f"{missed} of {detected_clean_indexes.shape[0]} elements was wrong")

            # Scores exist only for the first trajectory, so stop after it.
            break

    def sample(self, batch_size, max_len=20, data_usage=0.9, eval=False, real_timestep=True, pad_front=False, pad_front_len=1, pad_tail=False):
        """Sample a batch of windows of length ``2 * max_len + 1`` centred on random timesteps.

        Trajectories are drawn (with replacement) from ``self.train_idx`` with
        probability proportional to their length; a centre timestep is then drawn
        uniformly within the trajectory. Positions falling outside the trajectory
        are padded: zeros for states, noised states, actions, rewards,
        returns-to-go and timesteps, ones for terminals.

        When ``self.normalization`` is set, the (padded) clean states are
        normalised with the dataset mean/std, and returns-to-go are divided by
        ``self.scale``. Both are applied exactly once to every window. Previously
        they ran only for windows that needed padding, and twice when both ends
        were padded.
        NOTE(review): noised states are not normalised, and the mask is all ones
        even on padded positions — kept as in the original; confirm intent.

        Args:
            batch_size: Number of windows to draw.
            max_len: Number of timesteps taken on each side of the centre.
            data_usage, eval, pad_front, pad_front_len, pad_tail: Accepted but unused.
            real_timestep: If False, timesteps are made relative to the centre
                before padding.

        Returns:
            ``(s, sn, a, r, d, rtg, timesteps, mask)`` where, with
            ``W = 2 * max_len + 1``: ``s``/``sn`` are (B, W, obs_dim), ``a`` is
            (B, W, act_dim), ``r``/``rtg`` are (B, W, 1) and ``d``/``timesteps``
            (int64)/``mask`` are (B, W).
        """
        sample_idx = self.train_idx

        # Length-proportional sampling; reduce with torch instead of a Python loop.
        lens = self.traj_lens[sample_idx]
        p_sample = (lens / lens.sum()).cpu().numpy()
        batch_idx = np.random.choice(
            sample_idx,
            size=batch_size,
            replace=True,
            p=p_sample,
        )

        def pad_window(x, n_front, n_tail, fill=0):
            """Pad ``x`` of shape (1, L, ...) along dim 1 with ``fill`` on both sides."""
            trailing = tuple(x.shape[2:])
            parts = []
            if n_front > 0:
                parts.append(torch.full((1, n_front) + trailing, fill, dtype=x.dtype, device=x.device))
            parts.append(x)
            if n_tail > 0:
                parts.append(torch.full((1, n_tail) + trailing, fill, dtype=x.dtype, device=x.device))
            return torch.cat(parts, 1)

        s, sn, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], [], []
        for idx in batch_idx:
            traj_len = self.reward[idx].shape[0]
            si = random.randint(0, traj_len - 1)

            # Window [si - max_len, si + max_len], clipped to the trajectory bounds.
            window = slice(max(si - max_len, 0), si + max_len + 1)
            # Missing positions before the start / after the end (<= 0 means none).
            n_front = max_len - si
            n_tail = si + max_len - (traj_len - 1)

            state_win = pad_window(self.state[idx][window].reshape(1, -1, self.obs_dim), n_front, n_tail)
            if self.normalization:
                state_win = (state_win - self.state_mean) / self.state_std
            s.append(state_win)

            sn.append(pad_window(self.noised_state[idx][window].reshape(1, -1, self.obs_dim), n_front, n_tail))
            a.append(pad_window(self.action[idx][window].reshape(1, -1, self.act_dim), n_front, n_tail))
            r.append(pad_window(self.reward[idx][window].reshape(1, -1, 1), n_front, n_tail))
            # Padded positions are marked terminal.
            d.append(pad_window(self.terminal[idx][window].reshape(1, -1), n_front, n_tail, fill=1))

            steps = self.timestep[idx][window].reshape(1, -1)
            if not real_timestep:
                steps = steps - si
            timesteps.append(pad_window(steps, n_front, n_tail))

            rtg.append(pad_window(self.rtg[idx][window].reshape(1, -1, 1), n_front, n_tail) / self.scale)

            mask.append(torch.ones(1, 2 * max_len + 1).to(self.device))

        s = torch.cat(s, 0)
        a = torch.cat(a, 0)
        r = torch.cat(r, 0)
        d = torch.cat(d, 0).to(dtype=torch.long)
        rtg = torch.cat(rtg, 0)
        timesteps = torch.cat(timesteps, 0).to(dtype=torch.long)

        sn = torch.cat(sn, 0)

        mask = torch.cat(mask, 0)

        return s, sn, a, r, d, rtg, timesteps, mask
# %%

