# %% Part 0 D4RL Batch Buffer
import numpy as np

import os
import torch
import pickle
import random
from Utils import Batch_Class
from loguru import logger
np.set_printoptions(precision=4, linewidth=180, suppress=True)

def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along dim 0.

    The result satisfies ``out[-1] = x[-1]`` and
    ``out[t] = x[t] + gamma * out[t + 1]`` for earlier ``t``.

    Args:
        x: tensor whose first dimension is the time axis.
        gamma: discount factor applied to the running sum.

    Returns:
        A tensor shaped like ``x``. An input with zero timesteps yields an
        empty tensor instead of raising ``IndexError``.
    """
    result = torch.zeros_like(x)
    horizon = x.shape[0]
    if horizon == 0:
        # Nothing to accumulate; indexing [-1] below would fail.
        return result
    result[-1] = x[-1]
    for t in reversed(range(horizon - 1)):
        result[t] = x[t] + gamma * result[t + 1]
    return result

import time
from tqdm import tqdm

import matplotlib.pyplot as plt

from typing import List, Optional, Tuple, Callable

from torch import nn

class Squeeze(nn.Module):
    """Module form of ``torch.squeeze`` applied to one fixed dimension."""

    def __init__(self, dim=-1):
        super().__init__()
        # Dimension that forward() squeezes.
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with dimension ``self.dim`` squeezed."""
        squeezed = torch.squeeze(x, dim=self.dim)
        return squeezed

class MLP(nn.Module):
    """Feed-forward network assembled from a list of layer widths.

    Every hidden layer is ``Linear -> activation_fn() [-> Dropout]``; the
    final ``Linear`` is optionally followed by ``output_activation_fn()``
    and, when ``squeeze_output`` is set, by ``Squeeze(-1)``.

    Raises:
        ValueError: if fewer than two dims are given, or if
            ``squeeze_output`` is requested with a last dim other than 1.
    """

    def __init__(
        self,
        dims,
        activation_fn: Callable[[], nn.Module] = nn.ReLU,
        output_activation_fn: Callable[[], nn.Module] = None,
        squeeze_output: bool = False,
        dropout: Optional[float] = None,
    ):
        super().__init__()
        if len(dims) < 2:
            raise ValueError("MLP requires at least two dims (input and output)")

        modules = []
        # Hidden blocks: consecutive width pairs, excluding the output layer.
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(activation_fn())
            if dropout is not None:
                modules.append(nn.Dropout(dropout))

        # Output projection and its optional post-processing.
        modules.append(nn.Linear(dims[-2], dims[-1]))
        if output_activation_fn is not None:
            modules.append(output_activation_fn())
        if squeeze_output:
            if dims[-1] != 1:
                raise ValueError("Last dim must be 1 when squeezing")
            modules.append(Squeeze(-1))

        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stacked layers to ``x``."""
        out = self.net(x)
        return out

class TwinQ(nn.Module):
    """Two independent Q-networks evaluated on concatenated (state, action)."""

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int = 256, n_hidden: int = 2
    ):
        super().__init__()
        # Input is state and action concatenated; each head outputs one scalar.
        layer_sizes = [state_dim + action_dim] + [hidden_dim] * n_hidden + [1]
        self.q1 = MLP(layer_sizes, squeeze_output=True)
        self.q2 = MLP(layer_sizes, squeeze_output=True)

    def both(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the outputs of both heads for the given batch."""
        joint = torch.cat((state, action), dim=1)
        first = self.q1(joint)
        second = self.q2(joint)
        return first, second

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the element-wise minimum of the two heads."""
        first, second = self.both(state, action)
        return torch.min(first, second)

# %% Part 1 batch_buffer definition
class batch_buffer(object):
    def __init__(self, config, env_name0, dataset, device, buffer_mode='normal', buffer_normalization=False,
                 discretize=False, discrete_mode='average'):
        """Load trajectories, build per-trajectory tensors and noised states.

        Steps performed, in order:
          1. Resolve the dataset file and load the trajectories.
          2. Compute global state statistics and sort trajectories by
             total return (highest first).
          3. Convert every trajectory to tensors on ``device``.
          4. Build (or load from a pickle cache) adversarially noised
             observations and the attacked indices per trajectory.
          5. Compute 101 bin edges per feature and, optionally,
             discretized copies of the data.

        Args:
            config: mapping; this constructor reads ``config["dataset_path"]``.
            env_name0: environment name; when it contains ``'-'`` only the
                lower-cased text before the first ``'-'`` is used.
            dataset: dataset tag used in the default dataset path and in
                the noised-data cache path.
            device: device on which the tensors are placed.
            buffer_mode: not read anywhere in this constructor.
            buffer_normalization: stored as ``self.normalization``.
            discretize: when True, precompute the ``dsct_*`` lists.
            discrete_mode: ``'average'`` (evenly spaced edges between min
                and max) or ``'percentile'`` (percentile edges).
        """
        
        self.config = config

        # Keep only the part before the first '-' (lower-cased).
        # NOTE(review): without a '-' the name is used as-is, not lower-cased.
        if env_name0.find('-') >= 0:
            env_name = env_name0[:env_name0.find('-')].lower()
        else:
            env_name = env_name0
        # else:
        #     raise FileExistsError(f"Please check the input Env Name: {env_name0}")

        # Default dataset file; the version suffix depends on the env family.
        # NOTE(review): for any other env name `dataset_path` stays unbound and
        # the fallback branch below fails with UnboundLocalError.
        if env_name == 'halfcheetah' or env_name == 'walker2d' or env_name == 'hopper':
            dataset_path = f'datasets/{env_name}-{dataset}-v2.pkl'
        elif env_name == 'door' or env_name == 'hammer' or env_name == 'relocate' or env_name == 'pen':
            dataset_path = f'datasets/{env_name}-{dataset}-v1.pkl'
        elif env_name == 'kitchen' or env_name == 'antmaze':
            dataset_path = f'datasets/{env_name}-{dataset}-v0.pkl'

        self.dataset_type = dataset

        # An explicit config["dataset_path"] overrides the default file.
        # NOTE(review): torch.load / pickle.load execute arbitrary code from the
        # file; only load trusted datasets.
        if config["dataset_path"]:
            if "ratio" not in config["dataset_path"]:
                # Flat dataset (one array per key): cut it into trajectories
                # at every step whose "timeouts" flag equals 1.
                trajectories = []
                with open(config["dataset_path"], 'rb') as f:
                    trajs = torch.load(f)
                
                def dummy_dict():
                    """Return a fresh, empty per-trajectory accumulator."""
                    return {
                        "observations": [],
                        "actions": [],
                        "next_observations": [],
                        "rewards": [],
                        "terminals": [],
                    }

                new_traj = dummy_dict()

                # NOTE(review): the step carrying the timeout flag is not stored,
                # steps collected after the last timeout are never appended to
                # `trajectories`, and a timeout hit while `new_traj` is empty
                # raises IndexError on the [-1] assignment.
                for i in range(trajs["rewards"].shape[0] - 1):
                    if trajs["timeouts"][i] == 1:
                        # Close the current trajectory: mark its last stored
                        # step as terminal and convert the lists to arrays.
                        new_traj["terminals"][-1] = 1
                        for key in ["observations", "actions", "next_observations", "rewards", "terminals"]:
                            new_traj[key] = np.array(new_traj[key])
                        trajectories.append(new_traj)
                        new_traj = dummy_dict()
                    else:
                        new_traj["observations"].append(trajs["observations"][i])
                        new_traj["actions"].append(trajs["actions"][i])
                        new_traj["next_observations"].append(trajs["observations"][i+1])
                        new_traj["rewards"].append(trajs["rewards"][i])
                        new_traj["terminals"].append(trajs["timeouts"][i])
            else:
                # Paths containing "ratio" are loaded as-is and then used as a
                # list of trajectory dicts.
                print("RATIO!!!")
                with open(config["dataset_path"], 'rb') as f:
                    trajectories = torch.load(f)
        else:
            with open(dataset_path, 'rb') as f:
                trajectories = pickle.load(f)

        # Gather per-trajectory arrays, lengths and undiscounted returns.
        states, actions, rewards, traj_lens, returns = [], [], [], [], []
        for path in trajectories:
            states.append(path['observations'])
            actions.append(path['actions'])
            rewards.append(path['rewards'])
            traj_lens.append(len(path['observations']))
            returns.append(path['rewards'].sum())
        traj_lens, returns = np.array(traj_lens), np.array(returns)
        # NOTE: these flat arrays stay in the ORIGINAL trajectory order.
        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        rewards = np.concatenate(rewards, axis=0)
        sorted_idx = np.argsort(-returns)  # Total returns: from highest to lowest
        num_timesteps = sum(traj_lens)

        # NOTE(review): both branches assign the same value.
        if env_name == 'hopper' or env_name == 'halfcheetah' or env_name == 'walker2d':
            self.scale = 1000.
        else:
            self.scale = 1000.
        self.env_name = env_name
        self.device = device
        self.normalization = buffer_normalization

        self.obs_dim = trajectories[0]['observations'].shape[1]
        self.act_dim = trajectories[0]['actions'].shape[1]
        self.max_action = 1.
        # Global per-feature state statistics; 1e-6 is added to the std.
        self.state_mean = torch.FloatTensor(np.mean(states, axis=0)).to(self.device)
        self.state_std = torch.FloatTensor(np.std(states, axis=0) + 1e-6).to(self.device)
        # Lengths and returns reordered from highest to lowest return.
        self.traj_lens = torch.FloatTensor(traj_lens[sorted_idx]).to(self.device)
        self.total_returns = torch.FloatTensor(returns[sorted_idx]).to(self.device)
        self.traj_nums = len(traj_lens)

        print('=' * 70)
        if buffer_normalization:
            print(f'Buffer Information: {env_name} {dataset} with Normalization on State')
        else:
            print(f'Buffer Information: {env_name} {dataset} without Normalization')
        print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
        print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
        print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
        print(f'State Mean: {self.state_mean.cpu().numpy()}')
        print(f'State Std : {self.state_std.cpu().numpy()}')
        print('=' * 70)

        # Per-trajectory tensors, all indexed in return-sorted order.
        self.state = []
        self.action = []
        self.next_state = []
        self.reward = []
        self.terminal = []
        self.timestep = []
        self.rtg = []

        return_to_gos = []

        # Noised states and related bookkeeping.
        self.noised_state = []
        self.attacked_idx = []
        self.states_after_denoising = []
        self.filted_safe_idxes = []

        # Fixed-seed generators used by the adversarial corruption.
        self._np_rng = np.random.RandomState(0)
        self._th_rng = torch.Generator()
        self._th_rng.manual_seed(0)

        # Fraction threshold used by sample_indexs() to pick attacked steps.
        self.corruption_rate = 0.3

        # Core: NOTE that the trajectories are listed from the highest return to the lowest.
        for i in tqdm(range(len(trajectories))):

            traj = trajectories[sorted_idx[i]]
            self.state.append(torch.FloatTensor(traj['observations']).to(self.device))

            # Debug output for trajectories whose length is not 999.
            if (traj['observations'].shape[0]!=999):
                print(traj['observations'].shape, i)

            self.next_state.append(torch.FloatTensor(traj['next_observations']).to(self.device))
            self.action.append(torch.FloatTensor(traj['actions']).to(self.device))
            self.reward.append(torch.FloatTensor(traj['rewards']).to(self.device))
            self.terminal.append(torch.LongTensor(traj['terminals']).to(self.device))
            self.timestep.append(torch.LongTensor(np.arange(0, traj['observations'].shape[0])).to(self.device))
            # Undiscounted return-to-go (gamma = 1).
            self.rtg.append(discount_cumsum(self.reward[-1], gamma=1.))

            return_to_gos.append(self.rtg[i].cpu().numpy())

        # Build the adversarially noised observations once and cache them on
        # disk; later runs reuse the cached file.
        # NOTE(review): the directory is never created here, so writing the
        # cache fails if "adversarial_noised_datasets_and_ids/" does not exist.
        if not os.path.exists(f"adversarial_noised_datasets_and_ids/{self.env_name}_{dataset}.pkl"):

            for i in tqdm(range(len(trajectories))):

                traj = trajectories[sorted_idx[i]]

                attack_idx, clean_idx = self.sample_indexs(traj)
                # NOTE: attacked timestep indices are stored as a float tensor.
                self.attacked_idx.append(torch.FloatTensor(attack_idx).to(self.device))
                attack_obs = self.adversarial_attack(traj, attack_idx)
                # Start from the clean observations and overwrite attacked rows.
                noised_obs = torch.FloatTensor(traj['observations']).to(self.device)
                noised_obs[attack_idx] = torch.FloatTensor(attack_obs).to(self.device)
                self.noised_state.append(noised_obs)
        
            data_2_save = {
                "noised_state": self.noised_state,
                "attack_idx": self.attacked_idx,
            }

            with open(f"adversarial_noised_datasets_and_ids/{self.env_name}_{dataset}.pkl", "wb") as file:
                pickle.dump(data_2_save, file)
        
        else:
            
            with open(f"adversarial_noised_datasets_and_ids/{self.env_name}_{dataset}.pkl", "rb") as file:
                loaded_data = pickle.load(file)

            self.noised_state = loaded_data["noised_state"]
            self.attacked_idx = loaded_data["attack_idx"]

            # Cached indices that are not tensors are converted (decided by
            # inspecting the last entry only).
            if not isinstance(self.attacked_idx[-1], torch.Tensor):
                for i in range(len(self.attacked_idx)):
                    self.attacked_idx[i] = torch.FloatTensor(self.attacked_idx[i]).to(self.device)

        return_to_gos = np.concatenate(return_to_gos, axis=0)

        # Note: Discretize subslice is in the sequence of obs_dim + act_dim + reward + return_to_gos (divided into 101 pieces)
        # NOTE(review): `states`/`actions`/`rewards` are in original trajectory
        # order while `return_to_gos` is in return-sorted order, so rows of
        # `data` mix different timesteps; only the per-column values are used
        # for the bin edges below.
        data = np.concatenate((states, actions, rewards.reshape(-1, 1), return_to_gos.reshape(-1, 1)), axis=1)
        # NOTE(review): any other `discrete_mode` leaves `self.bins` undefined.
        if discrete_mode == 'average':
            self.bins = np.linspace(np.min(data, axis=0), np.max(data, axis=0), 101).T
        elif discrete_mode == 'percentile':
            self.bins = np.percentile(data, np.linspace(0, 100, 101), axis=0).T
        # self.discrete_data =

        if discretize:
            # Discretized copies of each trajectory, in return-sorted order.
            # NOTE(review): indices are clipped to [0, 100] here, whereas
            # discretize() and reconstruct() clip to [0, 99].
            self.dsct_state = []
            self.dsct_action = []
            self.dsct_next_state = []
            self.dsct_reward = []
            self.dsct_rtg = []
            for i in range(len(trajectories)):
                traj = trajectories[sorted_idx[i]]
                traj_state = traj['observations']
                dsct_traj_state = [np.digitize(traj_state[..., dim], self.bins[dim]) - 1
                                   for dim in np.arange(0, self.obs_dim)]
                dsct_traj_state = np.stack(np.clip(dsct_traj_state, 0, 101 - 1), axis=-1)
                self.dsct_state.append(torch.LongTensor(dsct_traj_state).to(self.device))

                # Action bins follow the state bins inside self.bins.
                traj_action = traj['actions']
                dsct_traj_action = [np.digitize(traj_action[..., dim-self.obs_dim], self.bins[dim]) - 1
                                    for dim in np.arange(self.obs_dim, self.obs_dim + self.act_dim)]
                dsct_traj_action = np.stack(np.clip(dsct_traj_action, 0, 101 - 1), axis=-1)
                self.dsct_action.append(torch.LongTensor(dsct_traj_action).to(self.device))

                # Reward uses the second-to-last row of bin edges.
                traj_reward = traj['rewards']
                dsct_traj_reward = np.digitize(traj_reward[...], self.bins[-2]) - 1
                dsct_traj_reward = np.clip(dsct_traj_reward, 0, 101 - 1)
                self.dsct_reward.append(torch.LongTensor(dsct_traj_reward).to(self.device))

                # Return-to-go uses the last row of bin edges.
                traj_rtg = discount_cumsum(torch.FloatTensor(traj_reward), gamma=1.).numpy()
                dsct_traj_rtg = np.digitize(traj_rtg[...], self.bins[-1]) - 1
                dsct_traj_rtg = np.clip(dsct_traj_rtg, 0, 101 - 1)
                self.dsct_rtg.append(torch.LongTensor(dsct_traj_rtg).to(self.device))

        # The requested size is int(0.0 * traj_nums) == 0, so the evaluation
        # split is always empty and every trajectory is a training trajectory.
        self.eval_idx = np.random.choice(np.arange(self.traj_nums), size=int(0.0 * self.traj_nums), replace=False)
        print("Eval idx is set to empty.")

        self.train_idx = np.delete(np.arange(self.traj_nums), self.eval_idx)

        # Indices for the clean dataset.
        self.mask_detected_clean_list = []
        # self.p_clean_indices = []
        # self.p_clean_lens = []

        logger.info('Data Buffer has been initialized!')
        logger.info('Please be noted that the trajectory are listed by return from high to low!')

    def sample_indexs(self, traj):
        """Randomly split the timesteps of ``traj`` into attacked and clean.

        One uniform number is drawn from ``self._np_rng`` per timestep; a
        timestep is attacked when its draw is below ``self.corruption_rate``.

        Args:
            traj: mapping whose ``"rewards"`` entry has one item per timestep.

        Returns:
            ``(attacked, original)``: two ascending index arrays that together
            cover every timestep exactly once.
        """
        n_steps = len(traj["rewards"])
        all_steps = np.arange(n_steps)
        draws = self._np_rng.random(n_steps)
        is_attacked = draws < self.corruption_rate
        return all_steps[is_attacked], all_steps[~is_attacked]

    def sample_para(self, data, std):
        """Draw uniform noise in ``[-std, std)`` with the shape of ``data``.

        The random numbers come from ``self._th_rng`` and are moved to
        ``self.device`` before being centred and scaled.

        Args:
            data: tensor whose shape determines the shape of the noise.
            std: scale, broadcast against ``data.shape``.

        Returns:
            Tensor equal to ``2 * std * (u - 0.5)`` with ``u ~ U[0, 1)``.
        """
        uniform = torch.rand(data.shape, generator=self._th_rng)
        centered = uniform.to(self.device) - 0.5
        return 2 * 1.0 * std * centered

    def loss_Q(self, para, observation, action, std):
        """Return the mean critic value at the perturbed observation.

        The perturbed observation is ``observation + para * std``; it is
        passed with ``action`` to ``self.critic``.
        """
        perturbed = observation + para * std
        return self.critic(perturbed, action).mean()

    def optimize_para(self, para, std, obs, act=None):
        """Refine ``para`` with 100 gradient steps that lower ``loss_Q``.

        On every step the current value is wrapped in a fresh
        ``nn.Parameter`` with a freshly created Adam optimizer (lr 0.01),
        one update is applied, and the result is clamped to ``[-1, 1]`` and
        detached.

        Args:
            para: initial perturbation coefficients.
            std: scale multiplied into the perturbation.
            obs: observations being perturbed.
            act: actions passed to the critic.

        Returns:
            The final perturbation ``para * std``.
        """
        n_steps = 100
        step_size = 0.01 * 1
        current = para
        for _ in range(n_steps):
            # A new Parameter and optimizer each step: no Adam state carries over.
            current = torch.nn.Parameter(current.clone(), requires_grad=True)
            opt = torch.optim.Adam([current], lr=step_size)
            objective = self.loss_Q(current, obs, act, std)
            opt.zero_grad()
            objective.backward()
            opt.step()
            current = torch.clamp(current, -1, 1).detach()
        return current * std

    def adversarial_attack(self, traj, attack_idx):
        """Produce adversarially perturbed observations for selected steps.

        A ``TwinQ`` critic is rebuilt and its weights are loaded from
        ``IQL_model/<env>-<dataset>-<version>/3000.pt`` on every call. The
        selected observations are processed in 10 consecutive chunks: each
        chunk starts from uniform noise (``sample_para``) which is refined by
        ``optimize_para`` and added to the clean observations.

        Args:
            traj: mapping with ``"observations"`` and ``"actions"`` arrays.
            attack_idx: indices of the timesteps to perturb.

        Returns:
            NumPy array shaped like ``traj["observations"][attack_idx]`` that
            holds the perturbed observations.
        """
        # Fresh critic in eval mode; weights are loaded below.
        self.critic = TwinQ(self.obs_dim, self.act_dim, n_hidden=2).to(self.device).eval()

        # Checkpoint directory name depends on the environment family.
        name = self.env_name
        if name in ('halfcheetah', 'walker2d', 'hopper'):
            real_env_name = f'{name}-{self.dataset_type}-v2'
        elif name in ('door', 'hammer', 'relocate', 'pen', 'kitchen', 'antmaze'):
            real_env_name = f'{name}-{self.dataset_type}-v0'

        checkpoint = torch.load(f"IQL_model/{real_env_name}/3000.pt", map_location=self.device)
        self.critic.load_state_dict(checkpoint["qf"])

        clean_obs = traj["observations"][attack_idx].copy()
        clean_act = traj["actions"][attack_idx].copy()
        # Per-feature std over the whole trajectory scales the perturbation.
        std_torch = torch.from_numpy(
            np.std(traj["observations"], axis=0, keepdims=True)
        ).to(self.device)
        act_torch = torch.from_numpy(clean_act.copy()).to(self.device)
        obs_torch = torch.from_numpy(clean_obs.copy()).to(self.device)

        # Attack in `split` chunks; the last chunk takes the remainder.
        attack_obs = np.zeros_like(clean_obs)
        split = 10
        total = clean_obs.shape[0]
        chunk = total // split
        start = 0
        for part in range(split):
            size = total - start if part == split - 1 else chunk
            stop = start + size
            act_part = act_torch[start:stop]
            obs_part = obs_torch[start:stop]
            para = self.sample_para(obs_part, std_torch)
            para = self.optimize_para(para, std_torch, obs_part, act_part)
            attack_obs[start:stop] = para.cpu().numpy() + obs_part.cpu().numpy()
            start = stop

        print(attack_obs.shape)

        return attack_obs

    def discretize(self, x, subslice=(None, None)):
        """Map continuous features to bin indices in ``[0, 99]``.

        Args:
            x: array or tensor of shape ``(N, dim)``; a 1-D input is treated
                as a single row.
            subslice: ``(start, stop)`` rows of ``self.bins`` matching the
                feature columns of ``x``.

        Returns:
            Integer indices shaped like the 2-D input. A NumPy array is
            returned for array input; tensor input yields an ``int32``
            tensor on ``self.device``.
        """
        was_tensor = torch.is_tensor(x)
        if was_tensor:
            x = x.detach().cpu().numpy()
        if x.ndim == 1:
            x = x[None]

        edges = self.bins[slice(subslice[0], subslice[1])]
        codes = []
        for dim in range(x.shape[-1]):
            # np.digitize is 1-based for in-range values; shift to 0-based.
            codes.append(np.digitize(x[..., dim], edges[dim]) - 1)
        codes = np.stack(np.clip(codes, 0, 99), axis=-1)

        if not was_tensor:
            return codes
        return torch.LongTensor(codes).to(device=self.device, dtype=torch.int32)

    def reconstruct(self, indices, subslice=(None, None)):
        """Map bin indices back to continuous values (bin midpoints).

        Args:
            indices: integer array or tensor of shape ``(N, dim)``; a 1-D
                input is treated as a single row. Values are clipped to
                ``[0, 99]``.
            subslice: ``(start, stop)`` rows of ``self.bins`` matching the
                columns of ``indices``.

        Returns:
            Midpoints shaped like the 2-D input. A NumPy array is returned
            for array input; tensor input yields a ``float32`` tensor on the
            input's device.
        """
        was_tensor = torch.is_tensor(indices)
        if was_tensor:
            src_device = indices.device
            indices = indices.detach().cpu().numpy()
        if indices.ndim == 1:
            indices = indices[None]
        indices = np.clip(indices, 0, 100 - 1)

        edges = self.bins[slice(subslice[0], subslice[1])]
        # Midpoint of every pair of neighbouring edges.
        centers = (edges[:, :-1] + edges[:, 1:]) / 2
        columns = []
        for dim in range(indices.shape[-1]):
            columns.append(centers[dim, indices[..., dim]])
        recon = np.stack(columns, axis=-1)

        if not was_tensor:
            return recon
        return torch.FloatTensor(recon).to(device=src_device, dtype=torch.float32)

    def pre_pad(self, state_seq, i):
        """Return the ``condition_length`` states preceding step ``i``.

        When fewer than ``condition_length`` states exist before ``i``, the
        available ones are kept and zero rows are prepended.

        Args:
            state_seq: tensor of shape ``(T, obs_dim)``.
            i: index of the current timestep.

        Returns:
            Tensor of shape ``(condition_length, obs_dim)`` on the device of
            ``state_seq``.
        """
        cond = self.config["condition_length"]
        # Clamp the start at 0: a negative start would wrap around to the end
        # of the sequence and silently drop the real history for i < cond.
        start = max(i - cond, 0)
        pre_states = state_seq[start:i]
        n_missing = cond - pre_states.shape[0]
        # Allocate the padding next to the data instead of forcing CUDA.
        padding = state_seq.new_zeros(n_missing, self.obs_dim)
        return torch.cat((padding, pre_states), dim=0)

    def next_pad(self, state_seq, i):
        """Return the ``condition_length`` states following step ``i``.

        When fewer than ``condition_length`` states exist after ``i``, the
        available ones are kept and zero rows are appended.

        Args:
            state_seq: tensor of shape ``(T, obs_dim)``.
            i: index of the current timestep.

        Returns:
            Tensor of shape ``(condition_length, obs_dim)`` on the device of
            ``state_seq``.
        """
        cond = self.config["condition_length"]
        next_states = state_seq[i + 1:i + 1 + cond]
        n_missing = cond - next_states.shape[0]
        # Allocate the padding next to the data instead of forcing CUDA.
        padding = state_seq.new_zeros(n_missing, self.obs_dim)
        return torch.cat((next_states, padding), dim=0)

    def prepare(self, state_seq_original):
        """Build one context window around every timestep of a trajectory.

        Window ``t`` holds the ``condition_length`` states before ``t``, the
        state at ``t``, and the ``condition_length`` states after ``t``;
        positions outside the trajectory are zero rows. All windows are cut
        from one zero-padded copy of the sequence, so the input is not
        modified and the result stays on the input's device.

        Args:
            state_seq_original: tensor of shape ``(T, obs_dim)``.

        Returns:
            Tensor of shape ``(T, 2 * condition_length + 1, obs_dim)``; an
            empty trajectory yields an empty tensor of that layout.
        """
        cond = self.config["condition_length"]
        width = 2 * cond + 1
        n_steps = state_seq_original.shape[0]
        if n_steps == 0:
            return state_seq_original.new_zeros(0, width, self.obs_dim)

        # Zero rows on both ends stand in for the missing history / future.
        edge = state_seq_original.new_zeros(cond, self.obs_dim)
        padded = torch.cat((edge, state_seq_original, edge), dim=0)

        # unfold appends the window axis last: (T, obs_dim, width).
        windows = padded.unfold(0, width, 1)
        # Reorder to (T, width, obs_dim) and materialize a standalone copy.
        return windows.permute(0, 2, 1).contiguous()

    def cal_ther_for_clean(self, denoiser, detector, topk, early_stop=None):
        """Score every noised timestep and derive a detection threshold.

        For each trajectory the context windows from ``prepare`` are fed to
        ``detector.predictor.model``. The score of a timestep is the mean
        squared model output over the last dimension, read at the centre
        position (index ``condition_length``) of its window.

        Args:
            denoiser: not used by this method.
            detector: object exposing ``predictor.model``.
            topk: quantile in ``[0, 1]`` passed to ``torch.quantile``.
            early_stop: when not None, stop after the trajectory whose index
                reaches this value.

        Returns:
            ``(threshold, scores)``: the ``topk`` quantile over all collected
            scores, and the list of per-trajectory score tensors.
        """
        scores = []
        attack_list = []  # collected for parity with the scores; not returned

        with torch.no_grad():
            for k, (state_seq, attack_id) in enumerate(zip(self.noised_state, self.attacked_idx)):
                windows = self.prepare(state_seq)
                n_windows = windows.shape[0]
                # noise ~ diffusion(noised_states)
                noise_p = detector.predictor.model(
                    windows,
                    torch.ones((n_windows,), device=self.device).long(),
                    None,
                    None,
                    torch.ones((n_windows,), device=self.device).long(),
                )
                # Mean squared predicted noise at the centre of each window.
                centre_score = (noise_p ** 2).mean(-1)[:, self.config["condition_length"]]
                scores.append(centre_score)
                attack_list.append(attack_id.int())

                if early_stop is not None and k >= early_stop:
                    break

        threshold = torch.quantile(torch.cat(scores, 0), topk)
        return threshold, scores

    def get_dataset(self, denoiser, detector, dataset_name="test", topk=0.7, early_stop=None):
        """Detect, denoise and export five variants of the buffer as pickles.

        Every timestep is scored with ``cal_ther_for_clean``; timesteps whose
        score is above the ``topk`` quantile are treated as noised. The
        variants written to ``{dataset_name}_*.pkl`` are:

          * ``denoised_data``: noised states with detected steps denoised
          * ``true_data``: the clean states
          * ``noised_data``: the noised states
          * ``filted_data``: noised states restricted to detected-clean steps
          * ``filted_true_data``: clean states restricted to the same steps

        Args:
            denoiser: object exposing ``denoise_state(windows, steps)``.
            detector: forwarded to ``cal_ther_for_clean``.
            dataset_name: prefix of the output file names.
            topk: quantile used as the clean/noised threshold.
            early_stop: forwarded to ``cal_ther_for_clean``.

        Side effects:
            Sets the last terminal flag of every trajectory in
            ``self.terminal`` to 1 and writes five pickle files.
        """

        threshold, pred_noise_list = self.cal_ther_for_clean(denoiser, detector, topk, early_stop)

        print(f"{topk * 100}% data will be detected as clean")

        def dummy_dict():
            """Return an empty accumulator (s, a, ns, r, t, w lists)."""
            return {
                "s": [],
                "a": [],
                "ns": [],
                "r": [],
                "t": [],
                "w": [],
            }

        denoised_data, true_data, noised_data, filted_data, filted_true_data = dummy_dict(), dummy_dict(), dummy_dict(), dummy_dict(), dummy_dict()

        # Trajectories are buffered in these "stack" lists and denoised together
        # in one batch; cur_len / cur_noised count the buffered timesteps and
        # the buffered detected-noised timesteps.
        # NOTE: d_clean_ind_stack is filled but never read.
        noised_states_stack, true_states_stack, d_clean_ind_stack, d_noised_ind_stack, d_noised_len_stack, cur_len, cur_noised = [], [], [], [], [], 0, 0

        def lzy_mse(a, b):
            """Return the mean squared error between ``a`` and ``b`` as NumPy."""
            return torch.mean((a-b)**2).cpu().numpy()

        for k, (noised_state, true_state, attack_idx) in enumerate(zip(self.noised_state, self.state, self.attacked_idx)):

            # Fields shared by the three full-length variants.
            for dict_data in [denoised_data, true_data, noised_data]:
                dict_data["a"].append(self.action[k])
                dict_data["ns"].append(self.next_state[k])
                dict_data["r"].append(self.reward[k])
                # Force the last step of the trajectory to be terminal
                # (mutates self.terminal in place).
                self.terminal[k][-1] = 1
                dict_data["t"].append(self.terminal[k])
                # "w" stores the per-step predicted-noise score.
                dict_data["w"].append(pred_noise_list[k])

            # True and noised data
            true_data["s"].append(self.state[k])
            noised_data["s"].append(self.noised_state[k])

            # Work on copies so the denoising below does not touch the buffer.
            cloned_noised_state, cloned_true_state = noised_state.clone(), true_state.clone()

            detected_noised_indexes = torch.where(pred_noise_list[k] >  threshold)[0]
            detected_clean_indexes  = torch.where(pred_noise_list[k] <= threshold)[0]

            # Fields of the two filtered variants: detected-clean steps only.
            for dict_data in [filted_data, filted_true_data]:
                dict_data["a"].append(self.action[k][detected_clean_indexes])
                dict_data["ns"].append(self.next_state[k][detected_clean_indexes])
                dict_data["r"].append(self.reward[k][detected_clean_indexes])
                # Force Terminal
                self.terminal[k][-1] = 1
                dict_data["t"].append(self.terminal[k][detected_clean_indexes])
                dict_data["w"].append(pred_noise_list[k][detected_clean_indexes])

            # Cheat: keep only the detected indices that were really attacked,
            # i.e. the ground-truth attack indices are used here.
            detected_noised_indexes = detected_noised_indexes[torch.isin(detected_noised_indexes, attack_idx)]
            
            # print(detected_noised_indexes.shape)

            noised_states_stack.append(cloned_noised_state)
            true_states_stack.append(cloned_true_state)

            d_clean_ind_stack.append(detected_clean_indexes)
            d_noised_ind_stack.append(detected_noised_indexes)
            d_noised_len_stack.append(detected_noised_indexes.shape[0])
            cur_len += cloned_noised_state.shape[0]
            cur_noised += detected_noised_indexes.shape[0]

            # Filtered data
            filted_data["s"].append(self.noised_state[k][detected_clean_indexes])
            filted_true_data["s"].append(self.state[k][detected_clean_indexes])

            # Flush the buffered trajectories once adding another trajectory of
            # the current length would reach config["stack_length"], or when
            # this is the last trajectory.
            if not (cur_len + cloned_noised_state.shape[0] < self.config["stack_length"] and k != len(self.state) - 1):

                original_mse = lzy_mse(torch.cat(noised_states_stack, dim=0), torch.cat(true_states_stack, dim=0))
                time0 = time.perf_counter()
                if cur_noised > 0:
                    for _ in range(self.config["detect_denoise_loops"]):
                        # Collect the context windows of every detected-noised
                        # step across the buffered trajectories.
                        input_noised_state_list = []
                        for sub_s_seq, noised_ind in zip(noised_states_stack, d_noised_ind_stack):
                            input_noised_state_list.append(
                                self.prepare(sub_s_seq)[noised_ind]
                            )
                        input_noised_state_list = torch.cat(input_noised_state_list, dim=0)
                        denoised_states = denoiser.denoise_state(input_noised_state_list, self.config["detect_denoise_steps"])
                        # Split the batch back into per-trajectory chunks.
                        denoised_states = torch.split(denoised_states, d_noised_len_stack)
                        # Write the centre of every denoised window back in place,
                        # so the next loop iteration sees the updated states.
                        # NOTE: sub_t_seq is not used inside this loop.
                        for sub_s_seq, sub_t_seq, de_s, noised_ind in zip(noised_states_stack, true_states_stack, denoised_states, d_noised_ind_stack):
                            sub_s_seq[noised_ind] = de_s[:, self.config["condition_length"]].clone()
                noised_states_stack = torch.cat(noised_states_stack, dim=0)
                true_states_stack = torch.cat(true_states_stack, dim=0)
                fixed_mse = lzy_mse(noised_states_stack, true_states_stack)

                print(f"mse from {original_mse} -> {fixed_mse}, time used: {time.perf_counter() - time0}")

                denoised_data["s"].append(noised_states_stack)
                # Reset the buffers for the next group of trajectories.
                noised_states_stack, true_states_stack, d_clean_ind_stack, d_noised_ind_stack, d_noised_len_stack, cur_len, cur_noised = [], [], [], [], [], 0, 0

        # Concatenate every field over trajectories and sanity-check lengths.
        for dict_data in [denoised_data, true_data, noised_data, filted_data, filted_true_data]:
            dict_data["s"] = torch.cat(dict_data["s"], 0)
            dict_data["a"] = torch.cat(dict_data["a"], 0)
            dict_data["ns"] = torch.cat(dict_data["ns"], 0)
            dict_data["r"] = torch.cat(dict_data["r"], 0)
            dict_data["t"] = torch.cat(dict_data["t"], 0)
            dict_data["w"] = torch.cat(dict_data["w"], 0)

            print(dict_data["s"].shape[0], dict_data["a"].shape[0], dict_data["w"].shape[0])

            assert (dict_data["s"].shape[0] == dict_data["a"].shape[0]) and (dict_data["s"].shape[0] == dict_data["w"].shape[0])

        def save_data(data, filename):
            """Helper function to save data to a pickle file."""
            data_to_save = {
                "observations": data["s"].cpu().numpy(),
                "actions": data["a"].cpu().numpy(),
                "next_observations": data["ns"].cpu().numpy(),
                "rewards": data["r"].cpu().numpy(),
                "terminals": data["t"].cpu().numpy(),
                "weights": data["w"].cpu().numpy()
            }
            with open(filename, "wb") as f:
                pickle.dump(data_to_save, f)

        # Save different datasets
        save_data(denoised_data, f"{dataset_name}_denoised_data.pkl")
        save_data(true_data, f"{dataset_name}_true_data.pkl")
        save_data(noised_data, f"{dataset_name}_noised_data.pkl")
        save_data(filted_data, f"{dataset_name}_filted_data.pkl")
        save_data(filted_true_data, f"{dataset_name}_filted_true_data.pkl")

    def simple_detect(self, denoiser, detector, render=None, image_name=None):
        """Run detection and denoising on the first trajectory and print stats.

        The threshold is the 0.7 quantile of the first trajectory's scores
        (``cal_ther_for_clean`` is called with ``early_stop=0``). Steps scored
        above it are denoised ``config["detect_denoise_loops"]`` times; the
        MSE against the clean states is printed before and after each pass.
        Finally the number of truly attacked steps that were classified as
        clean is printed.

        Args:
            denoiser: object exposing ``denoise_state(windows, steps)``.
            detector: forwarded to ``cal_ther_for_clean``.
            render: unused.
            image_name: unused.
        """
        threshold, pred_noise_list = self.cal_ther_for_clean(denoiser, detector, 0.7, 0)
        cond = self.config["condition_length"]

        for k, (noised_state, true_state, attack_ids) in enumerate(zip(self.noised_state, self.state, self.attacked_idx)):

            print(f"Denoising the {k} th traj.")

            cloned_noised_state = noised_state.clone()

            detected_noised_indexes = torch.where(pred_noise_list[k] > threshold)[0]
            detected_clean_indexes = torch.where(pred_noise_list[k] <= threshold)[0]

            print(torch.mean((cloned_noised_state - true_state) ** 2))

            for _ in range(self.config["detect_denoise_loops"]):
                windows = self.prepare(cloned_noised_state)
                denoised_states = denoiser.denoise_state(
                    windows[detected_noised_indexes], self.config["detect_denoise_steps"]
                )
                # Keep only the centre state of every denoised window.
                cloned_noised_state[detected_noised_indexes] = denoised_states[:, cond]

                print(torch.mean((cloned_noised_state - true_state) ** 2))

            # `attack_ids` holds the attacked timestep INDICES (not a 0/1 mask),
            # so false negatives are the attacked indices that were classified
            # as clean.
            attacked = attack_ids.to(detected_clean_indexes.dtype)
            intersection = detected_clean_indexes[torch.isin(detected_clean_indexes, attacked)]

            print(f"{intersection.shape[0]} of {detected_clean_indexes.shape[0]} elements was wrong")

            # Scores exist for the first trajectory only (early_stop=0 above).
            break

    def sample(self, batch_size, max_len=20, data_usage=0.9, eval=False, real_timestep=True, pad_front=False, pad_front_len=1, pad_tail=False):
        """Sample a batch of fixed-width windows centred on random timesteps.

        Trajectories are drawn from ``self.train_idx`` with probability
        proportional to their length; inside each one a centre step ``si`` is
        drawn uniformly. The window covers ``si - max_len .. si + max_len``
        and is padded on either side where it leaves the trajectory: zeros
        for states, actions, rewards, return-to-go and timesteps, ones for
        terminals.

        State normalization (when ``self.normalization`` is set) and
        return-to-go scaling by ``self.scale`` are applied exactly once to
        every window, after padding, regardless of how much padding the
        window needed.

        Args:
            batch_size: number of windows to draw.
            max_len: number of steps kept on each side of the centre.
            real_timestep: when False, timesteps are made relative to the
                centre step before padding.
            data_usage, eval, pad_front, pad_front_len, pad_tail: accepted
                for interface compatibility; not used.

        Returns:
            ``(s, sn, a, r, d, rtg, timesteps, mask)`` with window length
            ``W = 2 * max_len + 1``:
            ``s``/``sn`` clean/noised states ``(B, W, obs_dim)`` (``sn`` is
            never normalized), ``a`` ``(B, W, act_dim)``, ``r`` and ``rtg``
            ``(B, W, 1)``, ``d`` and ``timesteps`` long ``(B, W)``, ``mask``
            all-ones ``(B, W)``.

        Raises:
            ValueError: if a trajectory's clean and noised states differ in
                shape.
        """

        def pad_time(x, front, tail, fill=0):
            """Pad ``x`` of shape (1, L, ...) along dim 1 with ``fill``."""
            trailing = tuple(x.shape[2:])
            if front > 0:
                x = torch.cat((x.new_full((1, front) + trailing, fill), x), 1)
            if tail > 0:
                x = torch.cat((x, x.new_full((1, tail) + trailing, fill)), 1)
            return x

        sample_idx = self.train_idx

        # Longer trajectories are proportionally more likely to be picked.
        lens = self.traj_lens[sample_idx]
        p_sample = (lens / lens.sum()).cpu().numpy()
        batch_idx = np.random.choice(
            sample_idx,
            size=batch_size,
            replace=True,
            p=p_sample,
        )

        window = 2 * max_len + 1
        s, sn, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], [], []
        for idx in batch_idx:
            if self.state[idx].shape != self.noised_state[idx].shape:
                raise ValueError(
                    f"state/noised_state shape mismatch for trajectory {idx}: "
                    f"{tuple(self.state[idx].shape)} vs {tuple(self.noised_state[idx].shape)}"
                )

            traj_len = self.reward[idx].shape[0]
            si = random.randint(0, traj_len - 1)

            ll, rr = max(si - max_len, 0), si + max_len + 1
            # Steps missing before the start / after the end of the trajectory.
            front = max(max_len - si, 0)
            tail = max(si + max_len - (traj_len - 1), 0)

            states = pad_time(self.state[idx][ll:rr].reshape(1, -1, self.obs_dim), front, tail)
            if self.normalization:
                # Applied once per window, after padding.
                states = (states - self.state_mean) / self.state_std
            s.append(states)

            sn.append(pad_time(self.noised_state[idx][ll:rr].reshape(1, -1, self.obs_dim), front, tail))
            a.append(pad_time(self.action[idx][ll:rr].reshape(1, -1, self.act_dim), front, tail))
            r.append(pad_time(self.reward[idx][ll:rr].reshape(1, -1, 1), front, tail))
            # Padded positions are flagged as terminal.
            d.append(pad_time(self.terminal[idx][ll:rr].reshape(1, -1), front, tail, fill=1))

            # Scaled once per window, independent of the amount of padding.
            rtg.append(pad_time(self.rtg[idx][ll:rr].reshape(1, -1, 1), front, tail) / self.scale)

            steps = self.timestep[idx][ll:rr].reshape(1, -1)
            if not real_timestep:
                steps = steps - si
            timesteps.append(pad_time(steps, front, tail))

            mask.append(torch.ones(1, window, device=self.device))

        s = torch.cat(s, 0)
        sn = torch.cat(sn, 0)
        a = torch.cat(a, 0)
        r = torch.cat(r, 0)
        d = torch.cat(d, 0).to(dtype=torch.long)
        rtg = torch.cat(rtg, 0)
        timesteps = torch.cat(timesteps, 0).to(dtype=torch.long)
        mask = torch.cat(mask, 0)

        return s, sn, a, r, d, rtg, timesteps, mask
# %%

