# %% Part 0 D4RL Batch Buffer
import numpy as np

import os
import torch
import pickle
import random
from Utils import Batch_Class
from loguru import logger
np.set_printoptions(precision=4, linewidth=180, suppress=True)

def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along dim 0.

    ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``, i.e. the
    (discounted) return-to-go when ``x`` holds per-step rewards.

    Args:
        x: torch tensor whose first dimension is time.
        gamma: discount factor applied per step.

    Returns:
        A tensor with the same shape, dtype and device as ``x``.  An empty
        input yields an empty output instead of raising ``IndexError``.
    """
    out = torch.zeros_like(x)
    if x.shape[0] == 0:
        return out
    out[-1] = x[-1]
    # Walk backwards so every step can reuse the already accumulated tail.
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

import time

import matplotlib.pyplot as plt

# %% Part 1 batch_buffer definition
class batch_buffer(object):
    def __init__(self, config, env_name0, dataset, device, input_type, buffer_mode='normal', buffer_normalization=False,
                 discretize=False, discrete_mode='average'):
        """Load an offline dataset, build a noise-corrupted copy and index it.

        Trajectories are stored as lists of per-trajectory tensors on
        ``device``, ordered from the highest to the lowest total return.
        For every trajectory a randomly chosen subset of time steps receives
        uniform noise (scaled by the per-dimension std) on the elements
        selected by ``config["attack_element"]``; the result is kept in
        ``self.noised_state`` / ``self.noised_action`` / ``self.noised_reward``
        and the 0/1 attack mask in ``self.attacked_idx``.

        Args:
            config: mapping; this method reads ``"dataset_path"``,
                ``"attack_ratio"`` and ``"attack_element"``.  The mapping is
                stored as ``self.config`` and read again by other methods.
            env_name0: environment name.  Everything before the first ``-``
                is lower-cased and used as the env key; a name without ``-``
                is used verbatim.
            dataset: dataset flavour, only used to build the default file
                name ``datasets/<env>-<dataset>-v?.pkl``.
            device: torch device on which all tensors are stored.
            input_type: stored as ``self.input_type``.
            buffer_mode: accepted for compatibility; not read here.
            buffer_normalization: stored as ``self.normalization``.
            discretize: when true, also store discretized (bin index) copies
                of states, actions, rewards and returns-to-go.
            discrete_mode: ``'average'`` for equally spaced bin edges or
                ``'percentile'`` for percentile based edges.

        Raises:
            ValueError: if no ``config["dataset_path"]`` is given and the env
                key has no default dataset file, or if ``discretize`` is
                requested with an unknown ``discrete_mode``.
        """
        self.config = config

        dash_pos = env_name0.find('-')
        if dash_pos >= 0:
            env_name = env_name0[:dash_pos].lower()
        else:
            env_name = env_name0

        # Default on-disk location, only used when config["dataset_path"] is empty.
        if env_name in ('halfcheetah', 'walker2d', 'hopper'):
            dataset_path = f'datasets/{env_name}-{dataset}-v2.pkl'
        elif env_name in ('door', 'hammer', 'relocate', 'pen'):
            dataset_path = f'datasets/{env_name}-{dataset}-v1.pkl'
        elif env_name in ('kitchen', 'antmaze'):
            dataset_path = f'datasets/{env_name}-{dataset}-v0.pkl'
        else:
            dataset_path = None

        # NOTE(security): torch.load / pickle.load execute arbitrary code from
        # the file; only point this class at trusted dataset files.
        if config["dataset_path"]:
            if "ratio" not in config["dataset_path"]:
                trajectories = []
                with open(config["dataset_path"], 'rb') as f:
                    trajs = torch.load(f)

                transition_keys = ["observations", "actions", "next_observations", "rewards", "terminals"]

                def dummy_dict():
                    return {key: [] for key in transition_keys}

                new_traj = dummy_dict()

                # Split the flat transition arrays into trajectories at every
                # timeout flag.  The step carrying the flag is not stored
                # (its successor belongs to the next episode); the step
                # before it is marked terminal instead.
                # NOTE(review): whatever is collected after the last timeout
                # is never appended, so a trailing segment is dropped.
                for i in range(trajs["rewards"].shape[0] - 1):
                    if trajs["timeouts"][i] == 1:
                        if not new_traj["terminals"]:
                            # Timeout at index 0 or two timeouts in a row:
                            # nothing has been collected, nothing to close.
                            continue
                        new_traj["terminals"][-1] = 1
                        for key in transition_keys:
                            new_traj[key] = np.array(new_traj[key])
                        trajectories.append(new_traj)
                        new_traj = dummy_dict()
                    else:
                        new_traj["observations"].append(trajs["observations"][i])
                        new_traj["actions"].append(trajs["actions"][i])
                        new_traj["next_observations"].append(trajs["observations"][i+1])
                        new_traj["rewards"].append(trajs["rewards"][i])
                        new_traj["terminals"].append(trajs["timeouts"][i])
            else:
                # Files whose path contains "ratio" already hold a list of trajectories.
                print("RATIO!!!")
                with open(config["dataset_path"], 'rb') as f:
                    trajectories = torch.load(f)
        else:
            if dataset_path is None:
                raise ValueError(
                    f"Unknown environment '{env_name0}' and no config['dataset_path'] given: "
                    "cannot locate a dataset file."
                )
            with open(dataset_path, 'rb') as f:
                trajectories = pickle.load(f)

        # Dataset-wide statistics (computed in the original, unsorted order).
        states, actions, rewards, traj_lens, returns = [], [], [], [], []
        for path in trajectories:
            states.append(path['observations'])
            actions.append(path['actions'])
            rewards.append(path['rewards'])
            traj_lens.append(len(path['observations']))
            returns.append(path['rewards'].sum())
        traj_lens, returns = np.array(traj_lens), np.array(returns)
        states = np.concatenate(states, axis=0)
        actions = np.concatenate(actions, axis=0)
        rewards = np.concatenate(rewards, axis=0)
        sorted_idx = np.argsort(-returns)  # Total returns: from highest to lowest
        num_timesteps = sum(traj_lens)

        # Divisor applied to returns-to-go in sample(); the same for every env.
        self.scale = 1000.
        self.env_name = env_name
        self.device = device
        self.normalization = buffer_normalization

        self.obs_dim = trajectories[0]['observations'].shape[1]
        self.act_dim = trajectories[0]['actions'].shape[1]

        self.max_action = 1.

        # The small epsilon keeps constant dimensions from producing a zero std.
        self.state_mean = torch.FloatTensor(np.mean(states, axis=0)).to(self.device)
        self.state_std = torch.FloatTensor(np.std(states, axis=0) + 1e-6).to(self.device)

        self.action_mean = torch.FloatTensor(np.mean(actions, axis=0)).to(self.device)
        self.action_std = torch.FloatTensor(np.std(actions, axis=0) + 1e-6).to(self.device)

        self.reward_mean = torch.FloatTensor(np.mean(rewards.reshape(-1, 1), axis=0)).to(self.device)
        self.reward_std = torch.FloatTensor(np.std(rewards.reshape(-1, 1), axis=0) + 1e-6).to(self.device)

        self.traj_lens = torch.FloatTensor(traj_lens[sorted_idx]).to(self.device)
        self.total_returns = torch.FloatTensor(returns[sorted_idx]).to(self.device)
        self.traj_nums = len(traj_lens)

        print('=' * 70)
        if buffer_normalization:
            print(f'Buffer Information: {env_name} {dataset} with Normalization on State')
        else:
            print(f'Buffer Information: {env_name} {dataset} without Normalization')
        print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
        print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
        print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
        print(f'State Mean: {self.state_mean.cpu().numpy()}')
        print(f'State Std : {self.state_std.cpu().numpy()}')
        print('=' * 70)

        self.state = []
        self.action = []
        self.next_state = []
        self.reward = []
        self.terminal = []
        self.timestep = []
        self.rtg = []

        return_to_gos = []

        self.noised_state = []
        self.noised_action = []
        self.noised_reward = []

        self.attacked_idx = []
        self.states_after_denoising = []
        self.filted_safe_idxes = []

        attack_element = config["attack_element"]

        # Core: NOTE that the trajectory are listed from highest return to the lowest one
        for i in range(len(trajectories)):

            traj = trajectories[sorted_idx[i]]
            self.state.append(torch.FloatTensor(traj['observations']).to(self.device))
            self.next_state.append(torch.FloatTensor(traj['next_observations']).to(self.device))
            self.action.append(torch.FloatTensor(traj['actions']).to(self.device))
            self.reward.append(torch.FloatTensor(traj['rewards']).to(self.device))
            self.terminal.append(torch.LongTensor(traj['terminals']).to(self.device))
            self.timestep.append(torch.LongTensor(np.arange(0, traj['observations'].shape[0])).to(self.device))
            self.rtg.append(discount_cumsum(self.reward[-1], gamma=1.))

            return_to_gos.append(self.rtg[i].cpu().numpy())

            state_size = self.state[-1].shape
            action_size = self.action[-1].shape
            reward_size = self.reward[-1].shape

            not_attack_size = int(state_size[0] * (1 - config["attack_ratio"]))

            # Noise template: zero rows for untouched steps followed by
            # U(-1, 1) rows for attacked steps.  The draw order (state,
            # action, reward, permutation) is kept fixed so that seeded runs
            # stay reproducible.
            tensor_state = torch.cat([
                torch.zeros(not_attack_size, state_size[1]),
                torch.rand(size=(state_size[0] - not_attack_size, state_size[1])) * 2 - 1
            ])
            tensor_action = torch.cat([
                torch.zeros(not_attack_size, action_size[1]),
                torch.rand(size=(action_size[0] - not_attack_size, action_size[1])) * 2 - 1
            ])
            tensor_reward = torch.cat([
                torch.zeros(not_attack_size, 1),
                torch.rand(size=(reward_size[0] - not_attack_size, 1)) * 2 - 1
            ])

            # One shared permutation scatters the attacked rows over the
            # trajectory, so state/action/reward are attacked at the same steps.
            attack_indexes = torch.randperm(state_size[0])

            attack_tensor_state = tensor_state[attack_indexes].to(self.device)

            # Get the attacked index. This part is not that elegant:
            # a step counts as attacked when its state noise is non-negligible.
            attack_tensor_2 = (attack_tensor_state ** 2).mean(-1)
            attack_idx = torch.where(attack_tensor_2 > 1e-4, 1, 0).cpu()

            attack_tensor_action = tensor_action[attack_indexes].to(self.device)
            attack_tensor_reward = tensor_reward[attack_indexes].to(self.device)

            clean_state = self.state[-1]
            clean_action = self.action[-1]
            clean_reward = self.reward[-1].reshape(-1, 1)

            # The un-attacked branches store clones so that later in-place
            # edits of the noised copies can never touch the clean tensors.
            if attack_element in ["state", "transition"]:
                self.noised_state.append(clean_state + attack_tensor_state * self.state_std)
            else:
                self.noised_state.append(clean_state.clone())

            if attack_element in ["action", "transition"]:
                self.noised_action.append(clean_action + attack_tensor_action * self.action_std)
            else:
                self.noised_action.append(clean_action.clone())

            if attack_element in ["reward", "transition"]:
                self.noised_reward.append(clean_reward + attack_tensor_reward * self.reward_std)
            else:
                self.noised_reward.append(clean_reward.clone())

            self.attacked_idx.append(attack_idx)

        return_to_gos = np.concatenate(return_to_gos, axis=0)

        # Note: Discretize subslice is in the sequence of obs_dim + act_dim + reward + return_to_gos (divided into 101 pieces)
        # Only per-column statistics are taken below, so it does not matter
        # that return_to_gos is in sorted order while the rest is not.
        data = np.concatenate((states, actions, rewards.reshape(-1, 1), return_to_gos.reshape(-1, 1)), axis=1)
        if discrete_mode == 'average':
            self.bins = np.linspace(np.min(data, axis=0), np.max(data, axis=0), 101).T
        elif discrete_mode == 'percentile':
            self.bins = np.percentile(data, np.linspace(0, 100, 101), axis=0).T
        elif discretize:
            raise ValueError(
                f"Unknown discrete_mode '{discrete_mode}': expected 'average' or 'percentile'."
            )

        if discretize:
            # NOTE(review): indices are clipped to [0, 100] here, whereas
            # discretize() and reconstruct() clip to [0, 99].
            max_code = 101 - 1
            self.dsct_state = []
            self.dsct_action = []
            self.dsct_next_state = []
            self.dsct_reward = []
            self.dsct_rtg = []
            for i in range(len(trajectories)):
                traj = trajectories[sorted_idx[i]]
                traj_state = traj['observations']
                dsct_traj_state = [np.digitize(traj_state[..., dim], self.bins[dim]) - 1
                                   for dim in np.arange(0, self.obs_dim)]
                dsct_traj_state = np.stack(np.clip(dsct_traj_state, 0, max_code), axis=-1)
                self.dsct_state.append(torch.LongTensor(dsct_traj_state).to(self.device))

                # Action bins follow the state bins inside self.bins.
                traj_action = traj['actions']
                dsct_traj_action = [np.digitize(traj_action[..., dim-self.obs_dim], self.bins[dim]) - 1
                                    for dim in np.arange(self.obs_dim, self.obs_dim + self.act_dim)]
                dsct_traj_action = np.stack(np.clip(dsct_traj_action, 0, max_code), axis=-1)
                self.dsct_action.append(torch.LongTensor(dsct_traj_action).to(self.device))

                # Reward and return-to-go use the last two rows of self.bins.
                traj_reward = traj['rewards']
                dsct_traj_reward = np.digitize(traj_reward[...], self.bins[-2]) - 1
                dsct_traj_reward = np.clip(dsct_traj_reward, 0, max_code)
                self.dsct_reward.append(torch.LongTensor(dsct_traj_reward).to(self.device))

                traj_rtg = discount_cumsum(torch.FloatTensor(traj_reward), gamma=1.).numpy()
                dsct_traj_rtg = np.digitize(traj_rtg[...], self.bins[-1]) - 1
                dsct_traj_rtg = np.clip(dsct_traj_rtg, 0, max_code)
                self.dsct_rtg.append(torch.LongTensor(dsct_traj_rtg).to(self.device))

        # A 0% evaluation split: every trajectory ends up in train_idx.
        self.eval_idx = np.random.choice(np.arange(self.traj_nums), size=int(0.0 * self.traj_nums), replace=False)
        print("Eval idx is set to empty.")

        self.train_idx = np.delete(np.arange(self.traj_nums), self.eval_idx)

        # idxs for clean dataset
        self.mask_detected_clean_list = []

        logger.info('Data Buffer has been initialized!')
        logger.info('Please be noted that the trajectory are listed by return from high to low!')

        self.input_type = input_type

    def discretize(self, x, subslice=(None, None)):
        # Discretize 'x' with shape of (N,dim) according to the [subslice[0]: subslice[1]]
        if torch.is_tensor(x):
            return_tensor = True
            device = x.device
            x = x.detach().cpu().numpy()
        else:
            return_tensor = False
        if x.ndim == 1:
            x = x[None]
        bins = self.bins[subslice[0]: subslice[1]]
        discrete_data = [np.digitize(x[..., dim], bins[dim]) - 1 for dim in range(x.shape[-1])]
        discrete_data = np.stack(np.clip(discrete_data, 0, 99), axis=-1)

        if return_tensor:
            return torch.LongTensor(discrete_data).to(device=self.device, dtype=torch.int32)
        else:
            return discrete_data

    def reconstruct(self, indices, subslice=(None, None)):
        if torch.is_tensor(indices):
            return_tensor = True
            device = indices.device
            indices = indices.detach().cpu().numpy()
        else:
            return_tensor = False
        if indices.ndim == 1:
            indices = indices[None]
        indices = np.clip(indices, 0, 100 - 1)
        bin_data = (self.bins[subslice[0]: subslice[1], :-1] + self.bins[subslice[0]: subslice[1], 1:]) / 2
        recon = [bin_data[dim, indices[..., dim]] for dim in range(indices.shape[-1])]
        recon = np.stack(recon, axis=-1)
        if return_tensor:
            return torch.FloatTensor(recon).to(device=device, dtype=torch.float32)
        else:
            return recon

    def pre_pad(self, element_seq, i):
        input_dim = self.obs_dim + self.act_dim + 1

        pre_elements = element_seq[i - self.config["condition_length"]:i]
        real_lens = pre_elements.shape[0]
        
        padding = torch.zeros(max(self.config["condition_length"]-real_lens, 0), input_dim).cuda()
        pre_elements = torch.cat((padding, pre_elements), dim=0)

        return pre_elements

    def next_pad(self, element_seq, i):
        input_dim = self.obs_dim + self.act_dim + 1

        next_elements = element_seq[i + 1:i + 1 + self.config["condition_length"]]
        real_lens = next_elements.shape[0]

        padding = torch.zeros(max(self.config["condition_length"]-real_lens, 0), input_dim).cuda()
        next_elements = torch.cat((next_elements, padding), dim=0)

        return next_elements

    def prepare(self, element_seq_original):
        element_seq = element_seq_original.clone()
        input_noised_elements_list = []
        for i in range(element_seq.shape[0]):
            input_noised_element, pre_elements, next_elements = \
                element_seq[i].reshape(1, -1).cuda(), \
                self.pre_pad(element_seq, i,), \
                self.next_pad(element_seq, i,)
            input_noised_elements = torch.cat(
                (pre_elements, input_noised_element, next_elements),
                dim=0
            )
            input_noised_elements_list.append(input_noised_elements.unsqueeze(0))
        input_noised_elements = torch.cat(
            input_noised_elements_list,
            dim=0
        )
        return input_noised_elements

    def cal_ther_for_clean(self, denoiser, detector, topk, early_stop=None):
        pred_noise_list, attack_list = [], []

        with torch.no_grad():
            for k, (state_seq, action_seq, reward_seq, t_state_seq, t_action_seq, t_reward_seq, attack_id) in enumerate(zip(self.noised_state, self.noised_action, self.noised_reward, self.state, self.action, self.reward, self.attacked_idx)):
                input_noised_list = torch.cat((state_seq, action_seq, reward_seq), dim=-1)

                input_noised_list = self.prepare(input_noised_list)

                # For filting, timesteps is always 1
                noise_p = detector.predictor.model(
                    input_noised_list,
                    # Step = 1
                    torch.ones((input_noised_list.shape[0],), device=self.device).long(),
                )

                # ||predicted noised||
                pred_noise = (noise_p ** 2).mean(-1)[:, self.config["condition_length"]]
                pred_noise_list.append(pred_noise)

                attack_list.append(attack_id)

                if early_stop != None and k >= early_stop:
                    break
        
        threshold = torch.quantile(torch.cat(pred_noise_list, 0), topk)

        return threshold, pred_noise_list

    def get_dataset(self, denoiser, detector, dataset_name="test", topk=0.7, early_stop=None):

        threshold, pred_noise_list = self.cal_ther_for_clean(denoiser, detector, topk, early_stop)

        print(f"{topk * 100}% data will be detected as clean")

        def dummy_dict():
            return {
                "s": [],
                "a": [],
                "ns": [],
                "r": [],
                "t": [],
                "w": [],
            }

        # denoised_data, true_data, noised_data, filted_data, filted_true_data = dummy_dict(), dummy_dict(), dummy_dict(), dummy_dict(), dummy_dict()

        denoised_data, noised_data = dummy_dict(), dummy_dict()

        noised_states_stack, true_states_stack = [], []
        noised_actions_stack, true_actions_stack = [], []
        noised_rewards_stack, true_rewards_stack = [], []
        noised_transition_stack, true_transition_stack = [], []

        d_clean_ind_stack, d_noised_ind_stack, d_noised_len_stack = [], [], []
        
        cur_len, cur_noised = 0, 0        

        def lzy_mse(a, b):
            return torch.mean((a-b)**2).cpu().numpy()

        for k, (noised_state, noised_action, noised_reward, true_state, true_action, true_reward, attack_idx) in enumerate(zip(self.noised_state, self.noised_action, self.noised_reward, self.state, self.action, self.reward, self.attacked_idx)):

            for dict_data in [denoised_data, noised_data]:
                # s, a, r will not be included
                dict_data["ns"].append(self.next_state[k])
                # Force Terminal
                self.terminal[k][-1] = 1
                dict_data["t"].append(self.terminal[k])
                dict_data["w"].append(pred_noise_list[k])

            cloned_noised_state,  cloned_true_state  = noised_state.clone(),  true_state.clone()
            cloned_noised_action, cloned_true_action = noised_action.clone(), true_action.clone()
            cloned_noised_reward, cloned_true_reward = noised_reward.clone(), true_reward.clone()

            detected_noised_indexes = torch.where(pred_noise_list[k] >  threshold)[0]
            detected_clean_indexes  = torch.where(pred_noise_list[k] <= threshold)[0]

            the_attack_idx = torch.where(attack_idx==1)[0].cuda()

            # Clean the noised indexes
            detected_noised_indexes = detected_noised_indexes[torch.isin(detected_noised_indexes, the_attack_idx)]
            
            noised_states_stack.append(cloned_noised_state)
            true_states_stack.append(cloned_true_state)

            noised_actions_stack.append(cloned_noised_action)
            true_actions_stack.append(cloned_true_action)

            noised_rewards_stack.append(cloned_noised_reward)
            true_rewards_stack.append(cloned_true_reward)

            # Noised dataset
            noised_data["s"].append(cloned_noised_state.clone())
            noised_data["a"].append(cloned_noised_action.clone())
            noised_data["r"].append(cloned_noised_reward.clone())

            noised_transition_stack.append(
                torch.cat(
                    (
                        cloned_noised_state, cloned_noised_action, cloned_noised_reward
                    ),
                    dim=-1
                )
            )

            true_transition_stack.append(
                torch.cat(
                    (
                        cloned_true_state,
                        cloned_true_action,
                        cloned_true_reward.reshape(-1, 1)
                    ),
                    dim=-1
                )
            )

            d_clean_ind_stack.append(detected_clean_indexes)
            d_noised_ind_stack.append(detected_noised_indexes)
            d_noised_len_stack.append(detected_noised_indexes.shape[0])

            cur_len += cloned_noised_state.shape[0]
            cur_noised += detected_noised_indexes.shape[0]

            if not (cur_len + cloned_noised_state.shape[0] < self.config["stack_length"] and k != len(self.state) - 1):

                original_mse = lzy_mse(
                    torch.cat(noised_transition_stack, dim=0), torch.cat(true_transition_stack, dim=0)
                )

                time0 = time.perf_counter()

                if cur_noised > 0:
                    for _ in range(self.config["detect_denoise_loops"]):
                        input_noised_element_list = []
                        transition_conditions_list = []
                        for sub_s_seq, sub_a_seq, sub_r_seq, sub_t_s_seq, sub_t_a_seq, sub_t_r_seq, sub_transition_seq, sub_t_transition_seq, noised_ind in zip(noised_states_stack, noised_actions_stack, noised_rewards_stack, true_states_stack, true_actions_stack, true_rewards_stack, noised_transition_stack, true_transition_stack, d_noised_ind_stack):
                            input_noised_element_list.append(
                                self.prepare(sub_transition_seq)[noised_ind]
                            )
                        input_noised_element_list = torch.cat(input_noised_element_list, dim=0)

                        denoised_elements = denoiser.denoise_element(
                            input_noised_element_list,
                            self.config["detect_denoise_steps"]
                        )

                        denoised_elements = torch.split(denoised_elements, d_noised_len_stack)

                        for sub_s_seq, sub_a_seq, sub_r_seq, sub_t_seq, de_s, noised_ind in zip(noised_states_stack, noised_actions_stack, noised_rewards_stack, true_states_stack, denoised_elements, d_noised_ind_stack):
                            sub_s_seq[noised_ind] = de_s[:, self.config["condition_length"]][:,:self.obs_dim].clone()
                            sub_a_seq[noised_ind] = de_s[:, self.config["condition_length"]][:,self.obs_dim:self.obs_dim+self.act_dim].clone()
                            sub_r_seq[noised_ind] = de_s[:, self.config["condition_length"]][:,self.obs_dim+self.act_dim:].clone()

                noised_states_stack = torch.cat(noised_states_stack, dim=0)
                true_states_stack = torch.cat(true_states_stack, dim=0)

                noised_actions_stack = torch.cat(noised_actions_stack, dim=0)
                true_actions_stack = torch.cat(true_actions_stack, dim=0)

                noised_rewards_stack = torch.cat(noised_rewards_stack, dim=0)
                true_rewards_stack = torch.cat(true_rewards_stack, dim=0)

                noised_transition_stack = torch.cat(
                    (
                        noised_states_stack,
                        noised_actions_stack,
                        noised_rewards_stack,
                    ),
                    dim=-1
                )

                denoised_data["s"].append(noised_states_stack)
                denoised_data["a"].append(noised_actions_stack)
                denoised_data["r"].append(noised_rewards_stack)

                fixed_mse = lzy_mse(
                    noised_transition_stack, 
                    torch.cat(true_transition_stack, dim=0)
                )

                print(f"mse from {original_mse} -> {fixed_mse}, time used: {time.perf_counter() - time0}")

                noised_actions_stack, true_actions_stack = [], []
                noised_states_stack, true_states_stack = [], []
                noised_rewards_stack, true_rewards_stack = [], []

                d_clean_ind_stack, d_noised_ind_stack, d_noised_len_stack = [], [], []
                cur_len, cur_noised = 0, 0

                noised_transition_stack, true_transition_stack = [], []

        for dict_data in [denoised_data, noised_data]:

            dict_data["s"] = torch.cat(dict_data["s"], 0)
            dict_data["a"] = torch.cat(dict_data["a"], 0)
            dict_data["ns"] = torch.cat(dict_data["ns"], 0)
            dict_data["r"] = torch.cat(dict_data["r"], 0)
            dict_data["t"] = torch.cat(dict_data["t"], 0)
            dict_data["w"] = torch.cat(dict_data["w"], 0)

            print(dict_data["s"].shape[0], dict_data["a"].shape[0], dict_data["w"].shape[0])

            assert (dict_data["s"].shape[0] == dict_data["a"].shape[0]) and (dict_data["s"].shape[0] == dict_data["w"].shape[0])

        def save_data(data, filename):
            """Helper function to save data to a pickle file."""
            data_to_save = {
                "observations": data["s"].cpu().numpy(),
                "actions": data["a"].cpu().numpy(),
                "next_observations": data["ns"].cpu().numpy(),
                "rewards": data["r"].cpu().numpy(),
                "terminals": data["t"].cpu().numpy(),
                "weights": data["w"].cpu().numpy()
            }
            with open(filename, "wb") as f:
                pickle.dump(data_to_save, f)

        # Save different datasets
        save_data(denoised_data, f"{dataset_name}_denoised_data.pkl")
        save_data(noised_data, f"{dataset_name}_noised_data.pkl")

    def simple_detect(self, denoiser, detector, render=None, image_name=None):

        threshold, pred_noise_list = self.cal_ther_for_clean(denoiser, detector, 0.7, 0)

        states, actions, next_states, rewards, terminals, timesteps, rtgs = [], [], [], [], [], [], []
        for k, (noised_state, noised_action, noised_reward, true_state, true_action, true_reward, attack_ids) in enumerate(zip(self.noised_state, self.noised_action, self.noised_reward, self.state, self.action, self.reward, self.attacked_idx)):
            
            print(f"Denoising the {k} th traj.")

            cloned_noised_state = noised_state.clone()
            cloned_noised_action = noised_action.clone()
            cloned_noised_reward = noised_reward.clone()

            input_noised_elements = torch.cat(
                (
                    cloned_noised_state,
                    cloned_noised_action,
                    cloned_noised_reward,
                ),
                dim=-1
            )

            true_elements = torch.cat(
                (
                    true_state,
                    true_action,
                    true_reward.reshape(-1, 1),
                ),
                dim=-1
            )

            original_mse = torch.mean((input_noised_elements - true_elements)**2)

            print(f"Original MSE: {original_mse}")

            detected_noised_indexes = torch.where(pred_noise_list[k] >  threshold)[0]
            detected_clean_indexes  = torch.where(pred_noise_list[k] <= threshold)[0]

            the_attack_idx = torch.where(attack_ids==1)[0].cuda()
            detected_noised_indexes = detected_noised_indexes[torch.isin(detected_noised_indexes, the_attack_idx)]

            for loop_cnt in range(self.config["detect_denoise_loops"]):

                input_noised_elements = torch.cat(
                    (
                        cloned_noised_state,
                        cloned_noised_action,
                        cloned_noised_reward,
                    ),
                    dim=-1
                )

                input_noised_elements_list = self.prepare(input_noised_elements)
                
                input_elements = input_noised_elements_list[detected_noised_indexes]

                denoised_elements = denoiser.denoise_element(
                    input_elements,
                    self.config["detect_denoise_steps"]
                )
                input_noised_elements[detected_noised_indexes] = denoised_elements[:, self.config["condition_length"]]

                cloned_noised_state[detected_noised_indexes] = input_noised_elements[detected_noised_indexes][:,:self.obs_dim]
                cloned_noised_action[detected_noised_indexes] = input_noised_elements[detected_noised_indexes][:,self.obs_dim:self.obs_dim+self.act_dim]
                cloned_noised_reward[detected_noised_indexes] = input_noised_elements[detected_noised_indexes][:,self.obs_dim+self.act_dim:]

                mse = torch.mean((input_noised_elements - true_elements)**2)

                print(f"Fixed MSE in loop: {mse}")

            attacked_ids = torch.where(attack_ids==1)[0]

            intersection = torch.tensor([val for val in attacked_ids if val in detected_clean_indexes])

            print(f"{intersection.shape[0]} of {detected_clean_indexes.shape[0]} elements was wrong")

            if k >=0:
                break

    def sample(self, batch_size, max_len=20, data_usage=0.9, eval=False, real_timestep=True, pad_front=False, pad_front_len=1, pad_tail=False):

        sample_idx = self.train_idx

        p_sample = (self.traj_lens[sample_idx] / sum(self.traj_lens[sample_idx])).cpu().numpy()
        batch_idx = np.random.choice(
            sample_idx,
            size=batch_size,
            replace=True,
            p=p_sample,
        )

        s, sn, a, an, r, rn, d, rtg, timesteps, mask = [], [], [], [], [], [], [], [], [], []
        for i in range(batch_size):
            idx = batch_idx[i]
            idx_start = 0
            idx_end = self.reward[idx].shape[0] - 1
            si = random.randint(idx_start, idx_end)

            ll, rr = max(si - max_len, 0), si + max_len + 1

            s.append(self.state[idx][ll:rr].reshape(1, -1, self.obs_dim))
            a.append(self.action[idx][ll:rr].reshape(1, -1, self.act_dim))
            r.append(self.reward[idx][ll:rr].reshape(1, -1, 1))
            d.append(self.terminal[idx][ll:rr].reshape(1, -1))
            timesteps.append(self.timestep[idx][ll:rr].reshape(1, -1))
            if not real_timestep:
                timesteps[-1] = timesteps[-1] - si
            rtg.append(self.rtg[idx][ll:rr].reshape(1, -1, 1))

            sn.append(self.noised_state[idx][ll:rr].reshape(1, -1, self.obs_dim))
            an.append(self.noised_action[idx][ll:rr].reshape(1, -1, self.act_dim))
            rn.append(self.noised_reward[idx][ll:rr].reshape(1, -1, 1))

            front_pad_n = max_len - si
            if front_pad_n > 0:

                s[-1] = torch.cat((torch.zeros(1, front_pad_n, self.obs_dim).to(self.device), s[-1]), 1)
                if self.normalization:
                    s[-1] = (s[-1] - self.state_mean) / self.state_std
                a[-1] = torch.cat((torch.zeros(1, front_pad_n, self.act_dim).to(self.device), a[-1]), 1)
                r[-1] = torch.cat((torch.zeros(1, front_pad_n, 1).to(self.device), r[-1]), 1)
                d[-1] = torch.cat((torch.ones(1, front_pad_n).to(self.device), d[-1]), 1)
                rtg[-1] = torch.cat((torch.zeros(1, front_pad_n, 1).to(self.device), rtg[-1]), 1) / self.scale
                timesteps[-1] = torch.cat((torch.zeros(1, front_pad_n,).to(self.device), timesteps[-1]), 1)

                sn[-1] = torch.cat((torch.zeros(1, front_pad_n, self.obs_dim).to(self.device), sn[-1]), 1)
                an[-1] = torch.cat((torch.zeros(1, front_pad_n, self.act_dim).to(self.device), an[-1]), 1)
                rn[-1] = torch.cat((torch.zeros(1, front_pad_n, 1).to(self.device), rn[-1]), 1)

            tail_pad_n = si + max_len - (self.reward[idx].shape[0] - 1)
            if tail_pad_n > 0:
                s[-1] = torch.cat((s[-1], torch.zeros(1, tail_pad_n, self.obs_dim).to(self.device)), 1)
                if self.normalization:
                    s[-1] = (s[-1] - self.state_mean) / self.state_std
                a[-1] = torch.cat((a[-1], torch.zeros(1, tail_pad_n, self.act_dim).to(self.device)), 1)
                r[-1] = torch.cat((r[-1], torch.zeros(1, tail_pad_n, 1).to(self.device)), 1)
                d[-1] = torch.cat((d[-1], torch.ones(1, tail_pad_n).to(self.device)), 1)
                rtg[-1] = torch.cat((rtg[-1], torch.zeros(1, tail_pad_n, 1).to(self.device)), 1) / self.scale
                timesteps[-1] = torch.cat((timesteps[-1], torch.zeros(1, tail_pad_n,).to(self.device)), 1)

                sn[-1] = torch.cat((sn[-1], torch.zeros(1, tail_pad_n, self.obs_dim).to(self.device)), 1)
                an[-1] = torch.cat((an[-1], torch.zeros(1, tail_pad_n, self.act_dim).to(self.device)), 1)
                rn[-1] = torch.cat((rn[-1], torch.zeros(1, tail_pad_n, 1).to(self.device)), 1)

            mask.append(torch.ones(1, 2 * max_len + 1).to(self.device))
        
        s = torch.cat(s, 0)
        a = torch.cat(a, 0)
        r = torch.cat(r, 0)
        d = torch.cat(d, 0).to(dtype=torch.long)
        rtg = torch.cat(rtg, 0)
        timesteps = torch.cat(timesteps, 0).to(dtype=torch.long)

        sn = torch.cat(sn, 0)
        an = torch.cat(an, 0)
        rn = torch.cat(rn, 0)

        mask = torch.cat(mask, 0)

        return s, sn, a, an, r, rn, d, rtg, timesteps, mask
# %%

