# %% Part 0 D4RL Batch Buffer
import numpy as np

import os
import torch
import pickle
import random
from Utils import Batch_Class
from loguru import logger
# Compact fixed-point formatting for numpy arrays printed by this module (e.g. state mean/std).
np.set_printoptions(precision=4, linewidth=180, suppress=True)

def discount_cumsum(x, gamma):
    """Return the discounted cumulative sum of ``x`` along its first axis.

    Computes ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``.
    With ``gamma = 1`` this is the reverse cumulative sum (return-to-go).

    Args:
        x: tensor whose first axis is time; trailing axes are carried along.
        gamma: discount factor.

    Returns:
        A tensor with the same shape and dtype as ``x``. An empty ``x``
        (zero-length first axis) yields an empty result.
    """
    discounted = torch.zeros_like(x)
    if x.shape[0] == 0:
        # Nothing to accumulate; indexing x[-1] below would raise.
        return discounted
    discounted[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        discounted[t] = x[t] + gamma * discounted[t + 1]
    return discounted

import time
from tqdm import tqdm

import matplotlib.pyplot as plt

from typing import List, Optional, Tuple, Callable

from torch import nn

class Squeeze(nn.Module):
    """Module wrapper around ``torch.squeeze`` for one fixed dimension."""

    def __init__(self, dim=-1):
        super().__init__()
        # Dimension removed by forward() when it has size 1.
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with ``self.dim`` squeezed out (no-op if that size is not 1)."""
        return torch.squeeze(x, self.dim)

class MLP(nn.Module):
    """Feed-forward network of ``nn.Linear`` layers separated by activations.

    Args:
        dims: layer widths ``[in, hidden..., out]``; at least two entries.
        activation_fn: factory for the activation after each hidden Linear.
        output_activation_fn: optional factory appended after the last Linear.
        squeeze_output: if True, append ``Squeeze(-1)`` (requires ``dims[-1] == 1``).
        dropout: if not None, an ``nn.Dropout(dropout)`` follows every hidden activation.

    Raises:
        ValueError: if fewer than two dims are given, or if squeezing is
            requested for an output width other than 1.
    """

    def __init__(
        self,
        dims,
        activation_fn: Callable[[], nn.Module] = nn.ReLU,
        output_activation_fn: Optional[Callable[[], nn.Module]] = None,
        squeeze_output: bool = False,
        dropout: Optional[float] = None,
    ):
        super().__init__()
        if len(dims) < 2:
            raise ValueError("MLP requires at least two dims (input and output)")

        modules = []
        # Every consecutive (fan_in, fan_out) pair except the last one is a hidden layer.
        for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
            modules += [nn.Linear(fan_in, fan_out), activation_fn()]
            if dropout is not None:
                modules.append(nn.Dropout(dropout))

        modules.append(nn.Linear(dims[-2], dims[-1]))
        if output_activation_fn is not None:
            modules.append(output_activation_fn())

        if squeeze_output:
            if dims[-1] != 1:
                raise ValueError("Last dim must be 1 when squeezing")
            modules.append(Squeeze(-1))

        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x``."""
        return self.net(x)

class TwinQ(nn.Module):
    """Two independent Q-networks evaluated on the concatenated (state, action) input."""

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dim: int = 256, n_hidden: int = 2
    ):
        super().__init__()
        layer_dims = [state_dim + action_dim] + [hidden_dim] * n_hidden + [1]
        # Each head squeezes its scalar output to shape (batch,).
        self.q1 = MLP(layer_dims, squeeze_output=True)
        self.q2 = MLP(layer_dims, squeeze_output=True)

    def both(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(q1, q2)`` estimates for the given state/action batch."""
        state_action = torch.cat((state, action), dim=1)
        return self.q1(state_action), self.q2(state_action)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return the element-wise minimum of the two Q estimates."""
        q1_value, q2_value = self.both(state, action)
        return torch.min(q1_value, q2_value)

# %% Part 1 batch_buffer definition
class batch_buffer(object):
    """Trajectory buffer over a clean / adversarially corrupted offline-RL dataset.

    ``__init__`` loads a ``.pth`` file holding an ``"original"`` and a
    ``"corrupted"`` copy of the data and splits it into trajectories. It keeps
    per-trajectory tensors for the clean data (``state``, ``action``,
    ``reward``, ...) and the corrupted data (``noised_state``,
    ``noised_action``, ``noised_reward``). Trajectories are stored sorted from
    the highest to the lowest total return.

    Further methods:

    * score steps with an external detector and denoise them with an external
      denoiser (``cal_ther_for_clean``, ``get_dataset``, ``simple_detect``);
    * map data to / from 100 bins per dimension (``discretize``, ``reconstruct``);
    * sample fixed-length windows for sequence models (``sample``).
    """

    def __init__(self, config, env_name0, dataset, device, input_type, buffer_mode='normal', buffer_normalization=False,
                 discretize=False, discrete_mode='average'):
        """Load the corrupted dataset and build per-trajectory tensors.

        Args:
            config: dict of hyper-parameters. The methods read
                ``"condition_length"``, ``"stack_length"``,
                ``"detect_denoise_loops"`` and ``"detect_denoise_steps"``.
            env_name0: environment name. The text before the first ``'-'``
                (lower-cased) selects the dataset version / file suffix.
            dataset: dataset flavour, used in the file name.
            device: torch device all tensors are moved to.
            input_type: stored as ``self.input_type``.
            buffer_mode: unused here; kept for interface compatibility.
            buffer_normalization: if True, ``sample`` normalizes states.
            discretize: unused here; kept for interface compatibility.
            discrete_mode: ``'average'`` (equal-width bins) or ``'percentile'``
                (equal-mass bins) for ``self.bins``. Any other value leaves
                ``self.bins`` unset.

        Raises:
            ValueError: if the file yields no complete trajectory.
        """
        self.config = config

        # Short environment name: text before the first '-', lower-cased.
        if '-' in env_name0:
            env_name = env_name0[:env_name0.find('-')].lower()
        else:
            env_name = env_name0

        version = "v2"
        suffix = ""
        if env_name in ["relocate", "hammer", "door"]:
            version = "v0"
            suffix = "_ratio_0.01"
        if env_name == "kitchen":
            version = "v0"

        dataset_path = f'adv_attack_element_full/{env_name0}-{dataset}-{version}_adv_all_0.3_1.0{suffix}.pth'
        # NOTE(review): torch.load unpickles the file; only load trusted data.
        trajs = torch.load(dataset_path)
        original, corrupted = trajs["original"], trajs["corrupted"]

        traj_keys = ["observations", "actions", "next_observations", "rewards", "terminals",
                     "noised_observations", "noised_actions", "noised_rewards"]

        def dummy_dict():
            return {key: [] for key in traj_keys}

        trajectories = []
        new_traj = dummy_dict()

        # Transition i needs observation i + 1, hence the "- 1".
        for i in range(original["timeouts"].shape[0] - 1):
            # A trajectory ends at a terminal step or once it holds 999 steps.
            # The boundary transition i itself is dropped; the step before it
            # is flagged terminal instead. A trailing partial trajectory at the
            # end of the data is discarded.
            if original["terminals"][i] == 1 or len(new_traj["terminals"]) >= 999:
                if not new_traj["terminals"]:
                    # Two boundaries in a row: there is no open trajectory to close.
                    continue
                print(len(new_traj["terminals"]))
                new_traj["terminals"][-1] = 1
                for key in traj_keys:
                    arr = np.array(new_traj[key])
                    # Rewards are stored as column vectors of shape (T, 1).
                    new_traj[key] = arr.reshape(-1, 1) if "reward" in key else arr
                trajectories.append(new_traj)
                new_traj = dummy_dict()
            else:
                new_traj["observations"].append(original["observations"][i])
                new_traj["actions"].append(original["actions"][i])
                new_traj["rewards"].append(original["rewards"][i])
                new_traj["next_observations"].append(original["observations"][i + 1])
                new_traj["terminals"].append(original["terminals"][i])

                new_traj["noised_observations"].append(corrupted["observations"][i])
                new_traj["noised_actions"].append(corrupted["actions"][i])
                new_traj["noised_rewards"].append(corrupted["rewards"][i])

        if not trajectories:
            raise ValueError(f"No complete trajectory found in {dataset_path}")

        states, actions, rewards, traj_lens, returns = [], [], [], [], []
        attack_indexs = []
        for path in trajectories:
            states.append(path['observations'])
            actions.append(path['actions'])
            rewards.append(path['rewards'])
            traj_lens.append(len(path['observations']))
            returns.append(path['rewards'].sum())

            # Steps whose corrupted observation differs from the clean one.
            # NOTE(review): only observation corruption is detected here;
            # action/reward-only corruption is not recorded.
            error = np.mean((path['noised_observations'] - path['observations']) ** 2, -1)
            attack_indexs.append(np.where(error > 0)[0])

        traj_lens, returns = np.array(traj_lens), np.array(returns)
        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        rewards = np.concatenate(rewards, axis=0)

        sorted_idx = np.argsort(-returns)  # Total returns: from highest to lowest
        num_timesteps = sum(traj_lens)

        # Return-to-go scale applied by ``sample`` (the same for every environment).
        self.scale = 1000.
        self.env_name = env_name
        self.device = device
        self.normalization = buffer_normalization

        self.obs_dim = trajectories[0]['observations'].shape[1]
        self.act_dim = trajectories[0]['actions'].shape[1]

        self.max_action = 1.

        self.state_mean = torch.FloatTensor(np.mean(states, axis=0)).to(self.device)
        self.state_std = torch.FloatTensor(np.std(states, axis=0) + 1e-6).to(self.device)

        self.action_mean = torch.FloatTensor(np.mean(actions, axis=0)).to(self.device)
        self.action_std = torch.FloatTensor(np.std(actions, axis=0) + 1e-6).to(self.device)

        self.reward_mean = torch.FloatTensor(np.mean(rewards.reshape(-1, 1), axis=0)).to(self.device)
        self.reward_std = torch.FloatTensor(np.std(rewards.reshape(-1, 1), axis=0) + 1e-6).to(self.device)

        self.traj_lens = torch.FloatTensor(traj_lens[sorted_idx]).to(self.device)
        self.total_returns = torch.FloatTensor(returns[sorted_idx]).to(self.device)
        self.traj_nums = len(traj_lens)

        print('=' * 70)
        if buffer_normalization:
            print(f'Buffer Information: {env_name} {dataset} with Normalization on State')
        else:
            print(f'Buffer Information: {env_name} {dataset} without Normalization')
        print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
        print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
        print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
        print(f'State Mean: {self.state_mean.cpu().numpy()}')
        print(f'State Std : {self.state_std.cpu().numpy()}')
        print('=' * 70)

        # Per-trajectory tensors, ordered from highest to lowest return.
        self.state = []
        self.action = []
        self.next_state = []
        self.reward = []
        self.terminal = []
        self.timestep = []
        self.rtg = []

        self.noised_state = []
        self.noised_action = []
        self.noised_reward = []

        # Indexes of steps whose observation was corrupted (stored as float tensors).
        self.attacked_idx = []
        self.states_after_denoising = []
        self.filted_safe_idxes = []

        self._np_rng = np.random.RandomState(0)
        self._th_rng = torch.Generator()
        self._th_rng.manual_seed(0)

        self.dataset_type = dataset

        return_to_gos = []
        # Core: NOTE that the trajectory are listed from highest return to the lowest one
        for i in tqdm(range(len(trajectories))):
            traj = trajectories[sorted_idx[i]]
            self.state.append(torch.FloatTensor(traj['observations']).to(self.device))
            self.next_state.append(torch.FloatTensor(traj['next_observations']).to(self.device))
            self.action.append(torch.FloatTensor(traj['actions']).to(self.device))
            self.reward.append(torch.FloatTensor(traj['rewards']).to(self.device))
            self.terminal.append(torch.LongTensor(traj['terminals']).to(self.device))
            self.timestep.append(torch.LongTensor(np.arange(0, traj['observations'].shape[0])).to(self.device))
            # Undiscounted return-to-go, shape (T, 1).
            self.rtg.append(discount_cumsum(self.reward[-1], gamma=1.))
            return_to_gos.append(self.rtg[-1].cpu().numpy())

            self.noised_state.append(torch.FloatTensor(traj['noised_observations']).to(self.device))
            self.noised_action.append(torch.FloatTensor(traj['noised_actions']).to(self.device))
            self.noised_reward.append(torch.FloatTensor(traj['noised_rewards']).to(self.device))

            self.attacked_idx.append(torch.FloatTensor(attack_indexs[sorted_idx[i]]).to(self.device))

        return_to_gos = np.concatenate(return_to_gos, axis=0)

        # Note: Discretize subslice is in the sequence of obs_dim + act_dim + reward + return_to_gos (divided into 101 pieces)
        data = np.concatenate((states, actions, rewards.reshape(-1, 1), return_to_gos.reshape(-1, 1)), axis=1)
        if discrete_mode == 'average':
            self.bins = np.linspace(np.min(data, axis=0), np.max(data, axis=0), 101).T
        elif discrete_mode == 'percentile':
            self.bins = np.percentile(data, np.linspace(0, 100, 101), axis=0).T

        # Evaluation split is deliberately empty (size 0); every trajectory is used for training.
        self.eval_idx = np.random.choice(np.arange(self.traj_nums), size=int(0.0 * self.traj_nums), replace=False)
        self.train_idx = np.delete(np.arange(self.traj_nums), self.eval_idx)
        print("Eval idx is set to empty.")

        self.mask_detected_clean_list = []

        self.input_type = input_type

    def sample_indexs(self, traj):
        """Randomly split the step indexes of ``traj`` into (attacked, original).

        Each step is attacked with probability ``self.corruption_rate``.
        NOTE(review): ``corruption_rate`` is never set in this class; the
        caller must assign it before using this method.
        """
        indexs = np.arange(len(traj["rewards"]))
        random_num = self._np_rng.random(len(indexs))
        attacked = np.where(random_num < self.corruption_rate)[0]
        original = np.where(random_num >= self.corruption_rate)[0]
        return indexs[attacked], indexs[original]

    def loss_Q(self, para0, para1, observation, action, std_obs, std_act):
        """Return the mean of ``self.critic`` on perturbed observations and actions.

        NOTE(review): ``self.critic`` is never set in this class; the caller
        must assign it before using this method.
        """
        noised_obs = observation + para0 * std_obs
        noised_act = action + para1 * std_act
        qvalue = self.critic(noised_obs, noised_act)
        return qvalue.mean()

    def discretize(self, x, subslice=(None, None)):
        """Map ``x`` of shape (N, dim) (or (dim,)) to bin indexes in [0, 99].

        ``self.bins[subslice[0]:subslice[1]]`` supplies one row of 101 bin
        edges per column of ``x``.

        Returns:
            An int array of shape (N, dim). A tensor input gives an int32
            tensor on the input's device.
        """
        if torch.is_tensor(x):
            return_tensor = True
            device = x.device
            x = x.detach().cpu().numpy()
        else:
            return_tensor = False
        if x.ndim == 1:
            x = x[None]
        bins = self.bins[subslice[0]: subslice[1]]
        discrete_data = [np.digitize(x[..., dim], bins[dim]) - 1 for dim in range(x.shape[-1])]
        # 101 edges define 100 bins; values on/after the last edge fall into bin 99.
        discrete_data = np.stack(np.clip(discrete_data, 0, 99), axis=-1)

        if return_tensor:
            # Return on the input's device (consistent with ``reconstruct``).
            return torch.LongTensor(discrete_data).to(device=device, dtype=torch.int32)
        return discrete_data

    def reconstruct(self, indices, subslice=(None, None)):
        """Inverse of ``discretize``: map bin indexes to bin-centre values.

        Returns:
            A float array of the same shape as ``indices`` (after adding a
            leading axis to 1-D input). A tensor input gives a float32 tensor
            on the input's device.
        """
        if torch.is_tensor(indices):
            return_tensor = True
            device = indices.device
            indices = indices.detach().cpu().numpy()
        else:
            return_tensor = False
        if indices.ndim == 1:
            indices = indices[None]
        indices = np.clip(indices, 0, 100 - 1)
        bin_data = (self.bins[subslice[0]: subslice[1], :-1] + self.bins[subslice[0]: subslice[1], 1:]) / 2
        recon = [bin_data[dim, indices[..., dim]] for dim in range(indices.shape[-1])]
        recon = np.stack(recon, axis=-1)
        if return_tensor:
            return torch.FloatTensor(recon).to(device=device, dtype=torch.float32)
        return recon

    def pre_pad(self, element_seq, i):
        """Return the ``condition_length`` rows before row ``i``, zero-padded at the front.

        The result has shape (condition_length, D) and lives on ``element_seq``'s device.
        """
        cond_len = self.config["condition_length"]
        # Clamp at 0: a negative start index would wrap around to the end of the sequence.
        pre_elements = element_seq[max(i - cond_len, 0):i]
        padding = element_seq.new_zeros(cond_len - pre_elements.shape[0], element_seq.shape[-1])
        return torch.cat((padding, pre_elements), dim=0)

    def next_pad(self, element_seq, i):
        """Return the ``condition_length`` rows after row ``i``, zero-padded at the end.

        The result has shape (condition_length, D) and lives on ``element_seq``'s device.
        """
        cond_len = self.config["condition_length"]
        next_elements = element_seq[i + 1:i + 1 + cond_len]
        padding = element_seq.new_zeros(cond_len - next_elements.shape[0], element_seq.shape[-1])
        return torch.cat((next_elements, padding), dim=0)

    def prepare(self, element_seq_original):
        """Build the context window around every row of a (T, D) sequence.

        Returns:
            Tensor of shape (T, 2 * condition_length + 1, D). Row ``i`` equals
            ``cat(pre_pad(seq, i), seq[i:i + 1], next_pad(seq, i))``, so the
            element itself sits at index ``condition_length``. The result is a
            new tensor on the input's device.
        """
        cond_len = self.config["condition_length"]
        window = 2 * cond_len + 1
        num_rows, dim = element_seq_original.shape[0], element_seq_original.shape[-1]
        if num_rows == 0:
            return element_seq_original.new_zeros(0, window, dim)
        padding = element_seq_original.new_zeros(cond_len, dim)
        padded = torch.cat((padding, element_seq_original, padding), dim=0)
        # unfold -> (T, D, window); move the window axis in front of the feature axis.
        return padded.unfold(0, window, 1).permute(0, 2, 1).contiguous()

    def cal_ther_for_clean(self, denoiser, detector, topk, early_stop=None):
        """Score each corrupted step with the detector and return a quantile threshold.

        Args:
            denoiser: unused; kept for interface compatibility.
            detector: object exposing ``detector.predictor.model(windows, timesteps)``.
            topk: quantile in [0, 1]. Callers treat steps scoring at or below
                the threshold as clean.
            early_stop: if not None, stop after trajectory index ``early_stop``
                (so ``early_stop + 1`` trajectories are scored).

        Returns:
            ``(threshold, pred_noise_list)``. ``pred_noise_list[k]`` holds one
            score per step of scored trajectory ``k``.
        """
        cond_len = self.config["condition_length"]
        pred_noise_list = []

        with torch.no_grad():
            for k, (state_seq, action_seq, reward_seq) in enumerate(zip(self.noised_state, self.noised_action, self.noised_reward)):
                windows = self.prepare(torch.cat((state_seq, action_seq, reward_seq), dim=-1))

                # NOTE(review): the detector is always queried at timestep 31
                # (an older comment said "Step = 1"); confirm against its training.
                noise_p = detector.predictor.model(
                    windows,
                    31 * torch.ones((windows.shape[0],), device=self.device).long(),
                )

                # Mean squared predicted noise of the centre element of each window
                # (assumes noise_p keeps the (T, window, D) layout of its input).
                pred_noise_list.append((noise_p ** 2).mean(-1)[:, cond_len])

                if early_stop is not None and k >= early_stop:
                    break

        threshold = torch.quantile(torch.cat(pred_noise_list, 0), topk)

        return threshold, pred_noise_list

    def get_dataset(self, denoiser, detector, dataset_name="test", topk=0.7, early_stop=None):
        """Denoise detected-and-attacked steps and pickle the noised and denoised datasets.

        Steps scoring above the ``topk`` quantile that are also in
        ``self.attacked_idx`` are replaced with ``denoiser.denoise_element``
        output. This repeats ``config["detect_denoise_loops"]`` times, and
        each loop refines the previous loop's result. Trajectories are
        processed in chunks of roughly ``config["stack_length"]`` steps.

        Writes ``{dataset_name}_denoised_data.pkl`` and
        ``{dataset_name}_noised_data.pkl``. Side effect: sets the last step of
        every processed ``self.terminal`` entry to 1.
        """
        threshold, pred_noise_list = self.cal_ther_for_clean(denoiser, detector, topk, early_stop)

        print(f"{topk * 100}% data will be detected as clean")

        def dummy_dict():
            return {"s": [], "a": [], "ns": [], "r": [], "t": [], "w": []}

        denoised_data, noised_data = dummy_dict(), dummy_dict()

        cond_len = self.config["condition_length"]
        obs_dim, act_dim = self.obs_dim, self.act_dim

        def lzy_mse(a, b):
            return torch.mean((a - b) ** 2).cpu().numpy()

        # With early_stop, only the first len(pred_noise_list) trajectories have scores.
        num_trajs = len(pred_noise_list)

        # Per-chunk accumulators (the state/action/reward tensors are denoised in place).
        states_chunk, actions_chunk, rewards_chunk = [], [], []
        noised_transitions, true_transitions = [], []
        noised_ind_chunk, noised_len_chunk = [], []
        cur_len, cur_noised = 0, 0

        for k in range(num_trajs):
            pred_noise = pred_noise_list[k]

            # Force Terminal
            self.terminal[k][-1] = 1
            for dict_data in (denoised_data, noised_data):
                dict_data["ns"].append(self.next_state[k])
                dict_data["t"].append(self.terminal[k])
                dict_data["w"].append(pred_noise)

            noised_state = self.noised_state[k].clone()
            noised_action = self.noised_action[k].clone()
            noised_reward = self.noised_reward[k].clone()

            # Only denoise steps that are both flagged and actually attacked.
            detected_noised = torch.where(pred_noise > threshold)[0]
            attack_idx = self.attacked_idx[k].to(device=detected_noised.device, dtype=torch.long)
            detected_noised = detected_noised[torch.isin(detected_noised, attack_idx)]

            states_chunk.append(noised_state)
            actions_chunk.append(noised_action)
            rewards_chunk.append(noised_reward)

            # Noised dataset keeps an untouched copy.
            noised_data["s"].append(noised_state.clone())
            noised_data["a"].append(noised_action.clone())
            noised_data["r"].append(noised_reward.clone())

            noised_transitions.append(torch.cat((noised_state, noised_action, noised_reward), dim=-1))
            true_transitions.append(
                torch.cat((self.state[k], self.action[k], self.reward[k].reshape(-1, 1)), dim=-1)
            )

            noised_ind_chunk.append(detected_noised)
            noised_len_chunk.append(detected_noised.shape[0])

            traj_len = noised_state.shape[0]
            cur_len += traj_len
            cur_noised += detected_noised.shape[0]

            # Keep accumulating unless the next trajectory (estimated by this one's
            # length) would exceed stack_length, or this is the last trajectory.
            if cur_len + traj_len < self.config["stack_length"] and k != num_trajs - 1:
                continue

            true_chunk = torch.cat(true_transitions, dim=0)
            original_mse = lzy_mse(torch.cat(noised_transitions, dim=0), true_chunk)

            time0 = time.perf_counter()

            if cur_noised > 0:
                for _ in range(self.config["detect_denoise_loops"]):
                    # Rebuild transitions from the current (partially denoised) tensors
                    # so each loop refines the previous loop's output.
                    windows = torch.cat(
                        [
                            self.prepare(torch.cat((s_seq, a_seq, r_seq), dim=-1))[ind]
                            for s_seq, a_seq, r_seq, ind in zip(states_chunk, actions_chunk, rewards_chunk, noised_ind_chunk)
                            if ind.numel() > 0
                        ],
                        dim=0,
                    )

                    denoised = denoiser.denoise_element(windows, self.config["detect_denoise_steps"])
                    pieces = torch.split(denoised, noised_len_chunk)

                    for s_seq, a_seq, r_seq, ind, piece in zip(states_chunk, actions_chunk, rewards_chunk, noised_ind_chunk, pieces):
                        if ind.numel() == 0:
                            continue
                        center = piece[:, cond_len]
                        s_seq[ind] = center[:, :obs_dim].clone()
                        a_seq[ind] = center[:, obs_dim:obs_dim + act_dim].clone()
                        r_seq[ind] = center[:, obs_dim + act_dim:].clone()

            denoised_states = torch.cat(states_chunk, dim=0)
            denoised_actions = torch.cat(actions_chunk, dim=0)
            denoised_rewards = torch.cat(rewards_chunk, dim=0)

            denoised_data["s"].append(denoised_states)
            denoised_data["a"].append(denoised_actions)
            denoised_data["r"].append(denoised_rewards)

            fixed_mse = lzy_mse(
                torch.cat((denoised_states, denoised_actions, denoised_rewards), dim=-1),
                true_chunk,
            )

            print(f"mse from {original_mse} -> {fixed_mse}, time used: {time.perf_counter() - time0}")

            states_chunk, actions_chunk, rewards_chunk = [], [], []
            noised_transitions, true_transitions = [], []
            noised_ind_chunk, noised_len_chunk = [], []
            cur_len, cur_noised = 0, 0

        for dict_data in [denoised_data, noised_data]:
            for key in ("s", "a", "ns", "r", "t", "w"):
                dict_data[key] = torch.cat(dict_data[key], 0)

            print(dict_data["s"].shape[0], dict_data["a"].shape[0], dict_data["w"].shape[0])

            assert (dict_data["s"].shape[0] == dict_data["a"].shape[0]) and (dict_data["s"].shape[0] == dict_data["w"].shape[0])

        def save_data(data, filename):
            """Helper function to save data to a pickle file."""
            data_to_save = {
                "observations": data["s"].cpu().numpy(),
                "actions": data["a"].cpu().numpy(),
                "next_observations": data["ns"].cpu().numpy(),
                "rewards": data["r"].cpu().numpy(),
                "terminals": data["t"].cpu().numpy(),
                "weights": data["w"].cpu().numpy()
            }
            with open(filename, "wb") as f:
                pickle.dump(data_to_save, f)

        # Save different datasets
        save_data(denoised_data, f"{dataset_name}_denoised_data.pkl")
        save_data(noised_data, f"{dataset_name}_noised_data.pkl")

    def simple_detect(self, denoiser, detector, render=None, image_name=None):
        """Debug helper: denoise the first trajectory and print MSE diagnostics.

        The threshold is the 0.7 quantile of scores from the first trajectory
        only (``early_stop=0``). ``render`` and ``image_name`` are unused.
        """
        threshold, pred_noise_list = self.cal_ther_for_clean(denoiser, detector, 0.7, 0)
        if not pred_noise_list:
            return

        cond_len = self.config["condition_length"]
        obs_dim, act_dim = self.obs_dim, self.act_dim

        k = 0  # only the first trajectory is inspected
        noised_state = self.noised_state[k].clone()
        noised_action = self.noised_action[k].clone()
        noised_reward = self.noised_reward[k].clone()

        true_elements = torch.cat((self.state[k], self.action[k], self.reward[k].reshape(-1, 1)), dim=-1)
        noised_elements = torch.cat((noised_state, noised_action, noised_reward), dim=-1)
        original_mse = torch.mean((noised_elements - true_elements) ** 2)

        print(f"Original MSE: {original_mse}")

        pred_noise = pred_noise_list[k]
        detected_noised = torch.where(pred_noise > threshold)[0]
        detected_clean = torch.where(pred_noise <= threshold)[0]

        attack_idx = self.attacked_idx[k].to(device=pred_noise.device, dtype=torch.long)
        detected_noised = detected_noised[torch.isin(detected_noised, attack_idx)]

        for _ in range(self.config["detect_denoise_loops"]):
            noised_elements = torch.cat((noised_state, noised_action, noised_reward), dim=-1)
            windows = self.prepare(noised_elements)[detected_noised]

            denoised = denoiser.denoise_element(windows, self.config["detect_denoise_steps"])
            center = denoised[:, cond_len]

            noised_state[detected_noised] = center[:, :obs_dim]
            noised_action[detected_noised] = center[:, obs_dim:obs_dim + act_dim]
            noised_reward[detected_noised] = center[:, obs_dim + act_dim:]

            noised_elements = torch.cat((noised_state, noised_action, noised_reward), dim=-1)
            mse = torch.mean((noised_elements - true_elements) ** 2)

            print(f"Fixed MSE in loop: {mse}")

        # Attacked steps that the detector classified as clean (missed detections).
        missed = int(torch.isin(attack_idx, detected_clean).sum())
        print(f"{missed} of {detected_clean.shape[0]} elements was wrong")

    @staticmethod
    def _pad_time(x, front, tail, value=0):
        """Pad a (1, L, ...) tensor along dim 1 with ``front``/``tail`` entries equal to ``value``."""
        pieces = []
        if front > 0:
            pieces.append(x.new_full((1, front, *x.shape[2:]), value))
        pieces.append(x)
        if tail > 0:
            pieces.append(x.new_full((1, tail, *x.shape[2:]), value))
        return torch.cat(pieces, dim=1)

    def sample(self, batch_size, max_len=20, data_usage=0.9, eval=False, real_timestep=True, pad_front=False, pad_front_len=1, pad_tail=False):
        """Sample ``batch_size`` windows of length ``2 * max_len + 1`` centred on random steps.

        Trajectories are drawn with probability proportional to their length.
        Windows are zero-padded outside the trajectory (dones are padded with
        1). If normalization is on, states are normalized once per window.
        Return-to-go is always divided by ``self.scale`` once.
        ``data_usage``, ``eval``, ``pad_front``, ``pad_front_len`` and
        ``pad_tail`` are unused.

        Returns:
            ``(s, sn, a, an, r, rn, d, rtg, timesteps, mask)``. Sequences have
            shape (batch_size, 2 * max_len + 1, ...); ``d`` and ``timesteps``
            are long tensors. NOTE(review): ``mask`` is all ones and does not
            mark padding, and noised states ``sn`` are never normalized.
        """
        sample_idx = self.train_idx

        lens = self.traj_lens[sample_idx]
        p_sample = (lens / lens.sum()).cpu().numpy()
        batch_idx = np.random.choice(
            sample_idx,
            size=batch_size,
            replace=True,
            p=p_sample,
        )

        s, sn, a, an, r, rn, d, rtg, timesteps, mask = [], [], [], [], [], [], [], [], [], []
        for idx in batch_idx:
            traj_len = self.reward[idx].shape[0]
            si = random.randint(0, traj_len - 1)

            ll, rr = max(si - max_len, 0), si + max_len + 1
            front = max(max_len - si, 0)
            tail = max(si + max_len - (traj_len - 1), 0)

            def cut(seq, *feat, value=0):
                # Window [ll:rr] of trajectory ``idx``, padded to 2 * max_len + 1 steps.
                return self._pad_time(seq[idx][ll:rr].reshape(1, -1, *feat), front, tail, value)

            states = cut(self.state, self.obs_dim)
            if self.normalization:
                states = (states - self.state_mean) / self.state_std
            s.append(states)
            a.append(cut(self.action, self.act_dim))
            r.append(cut(self.reward, 1))
            d.append(cut(self.terminal, value=1))
            rtg.append(cut(self.rtg, 1) / self.scale)

            steps = self.timestep[idx][ll:rr].reshape(1, -1)
            if not real_timestep:
                steps = steps - si
            timesteps.append(self._pad_time(steps, front, tail))

            sn.append(cut(self.noised_state, self.obs_dim))
            an.append(cut(self.noised_action, self.act_dim))
            rn.append(cut(self.noised_reward, 1))

            mask.append(torch.ones(1, 2 * max_len + 1).to(self.device))

        s = torch.cat(s, 0)
        a = torch.cat(a, 0)
        r = torch.cat(r, 0)
        d = torch.cat(d, 0).to(dtype=torch.long)
        rtg = torch.cat(rtg, 0)
        timesteps = torch.cat(timesteps, 0).to(dtype=torch.long)

        sn = torch.cat(sn, 0)
        an = torch.cat(an, 0)
        rn = torch.cat(rn, 0)

        mask = torch.cat(mask, 0)

        return s, sn, a, an, r, rn, d, rtg, timesteps, mask
# %%

