from typing import Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from diffusion_policy.model.common.rotation_transformer import RotationTransformer

from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
import robomimic.utils.obs_utils as ObsUtils
import robomimic.models.base_nets as rmbn
import diffusion_policy.model.vision.crop_randomizer as dmvc
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
from time import perf_counter
import pytorch3d.transforms as pt
import numpy as np


class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
    def __init__(self, 
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,
            horizon, 
            n_action_steps, 
            n_obs_steps,
            num_inference_steps=None,
            obs_as_global_cond=True,
            crop_shape=(76, 76),
            diffusion_step_embed_dim=256,
            down_dims=(256,512,1024),
            kernel_size=5,
            n_groups=8,
            cond_predict_scale=True,
            obs_encoder_group_norm=False,
            eval_fixed_crop=False,
            # parameters passed to step
            **kwargs):
        """
        Build the observation encoder (taken from a robomimic policy) and the
        ConditionalUnet1D denoiser, and store the sampling hyper-parameters.

        shape_meta: dict with
            'action': {'shape': (Da,)}  -- the action shape must be 1-D
            'obs': {key: {'shape': ..., 'type': 'rgb' | 'low_dim'}}
            A missing 'type' is treated as 'low_dim'.
        noise_scheduler: stored as self.noise_scheduler; its
            config.num_train_timesteps is the default for num_inference_steps.
        horizon, n_action_steps, n_obs_steps: stored on the policy unchanged.
        num_inference_steps: number of denoising steps; None means
            noise_scheduler.config.num_train_timesteps.
        obs_as_global_cond: if True the U-Net input is the action only and the
            observation features of all n_obs_steps are given as global
            conditioning; otherwise the features are appended to the action
            channels of the U-Net input.
        crop_shape: (height, width) written into every CropRandomizer entry of
            the robomimic encoder config, or None to remove those randomizers.
        diffusion_step_embed_dim, down_dims, kernel_size, n_groups,
        cond_predict_scale: forwarded to ConditionalUnet1D.
        obs_encoder_group_norm: replace every nn.BatchNorm2d of the encoder by
            an nn.GroupNorm with num_features // 16 groups.
        eval_fixed_crop: replace robomimic's CropRandomizer modules by
            diffusion_policy's CropRandomizer with the same settings
            (presumably a fixed crop at eval time -- TODO confirm in
            diffusion_policy.model.vision.crop_randomizer).
        kwargs: stored as self.kwargs (extra parameters passed to step).

        Raises AssertionError if the action shape is not 1-D and RuntimeError
        for an observation type other than 'rgb' / 'low_dim'.
        """
        super().__init__()

        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            shape = attr['shape']
            obs_key_shapes[key] = list(shape)

            # named obs_type so that the builtin `type` is not shadowed
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                obs_config['rgb'].append(key)
            elif obs_type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # get raw robomimic config; the fixed arguments only select the
        # template config, whose observation section is overwritten below
        config = get_robomimic_config(
            algo_name='bc_rnn',
            hdf5_type='image',
            task_name='square',
            dataset_type='ph')
        
        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                # disable random cropping altogether
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state (must happen before the policy is built)
        ObsUtils.initialize_obs_utils_with_config(config)

        # load model
        policy: PolicyAlgo = algo_factory(
                algo_name=config.algo_name,
                config=config,
                obs_key_shapes=obs_key_shapes,
                ac_dim=action_dim,
                device='cpu',
            )

        # only the observation encoder of the robomimic policy is kept
        obs_encoder = policy.nets['policy'].nets['encoder'].nets['obs']
        
        if obs_encoder_group_norm:
            # replace batch norm with group norm
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=x.num_features//16, 
                    num_channels=x.num_features)
            )
        
        if eval_fixed_crop:
            # swap robomimic's crop randomizer for the diffusion_policy one,
            # keeping the same crop configuration
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
                func=lambda x: dmvc.CropRandomizer(
                    input_shape=x.input_shape,
                    crop_height=x.crop_height,
                    crop_width=x.crop_width,
                    num_crops=x.num_crops,
                    pos_enc=x.pos_enc
                )
            )

        # create diffusion model
        obs_feature_dim = obs_encoder.output_shape()[0]
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            # observations are not part of the denoised trajectory
            input_dim = action_dim
            global_cond_dim = obs_feature_dim * n_obs_steps

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

        print("Diffusion params: %e" % sum(p.numel() for p in self.model.parameters()))
        print("Vision params: %e" % sum(p.numel() for p in self.obs_encoder.parameters()))
    
    # ========= inference  ============
    def conditional_sample(self, 
            condition_data, condition_mask,
            local_cond=None, global_cond=None,
            generator=None,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        model = self.model
        scheduler = self.noise_scheduler

        trajectory = torch.randn(
            size=condition_data.shape, 
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)
    
        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(trajectory, t, 
                local_cond=local_cond, global_cond=global_cond)

            # 3. compute previous image: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory, 
                generator=generator,
                **kwargs
                ).prev_sample
        
        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]        

        return trajectory
    
    def conditional_sample_with_scheduler(self, 
            condition_data, condition_mask, scheduler: DDIMScheduler, steps=-1,
            local_cond=None, global_cond=None,
            generator=None, 
            # keyword arguments to scheduler.step
            **kwargs
            ):
        model = self.model

        trajectory = torch.randn(
            size=condition_data.shape, 
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)
    
        # set step values
        if steps == -1:
            steps = self.num_inference_steps
        scheduler.set_timesteps(steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(trajectory, t, 
                local_cond=local_cond, global_cond=global_cond)

            # 3. compute previous image: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory, 
                generator=generator,
                **kwargs
                ).prev_sample
        
        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]        

        return trajectory


    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Encode the observations, sample an action trajectory with
        conditional_sample and return the un-normalized actions.

        obs_dict: observation tensors keyed by name; must not contain
            'past_action'. Each tensor is sliced as x[:, :n_obs_steps], so
            dims 0/1 are presumably (batch, time) -- TODO confirm.
        result: dict with
            'action_pred': the full un-normalized predicted trajectory
            'action': the n_action_steps steps of it starting at index
                n_obs_steps - 1
        """
        assert 'past_action' not in obs_dict # not implemented yet

        normalized_obs = self.normalizer.normalize(obs_dict)
        first_entry = next(iter(normalized_obs.values()))
        batch_size, _ = first_entry.shape[:2]
        horizon = self.horizon
        action_dim = self.action_dim
        feature_dim = self.obs_feature_dim
        n_obs = self.n_obs_steps

        device = self.device
        dtype = self.dtype

        def keep_obs_window(x):
            # keep the first n_obs time steps and fold them into the batch dim
            return x[:, :n_obs, ...].reshape(-1, *x.shape[2:])

        conditioned_globally = self.obs_as_global_cond
        features = self.obs_encoder(dict_apply(normalized_obs, keep_obs_window))

        if conditioned_globally:
            # features of all observation steps form one conditioning vector;
            # nothing in the trajectory itself is pinned
            global_cond = features.reshape(batch_size, -1)
            cond_data = torch.zeros(
                size=(batch_size, horizon, action_dim), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # inpainting: features occupy the channels after the action dims
            # of the first n_obs steps and are marked as known
            global_cond = None
            features = features.reshape(batch_size, n_obs, -1)
            cond_data = torch.zeros(
                size=(batch_size, horizon, action_dim + feature_dim),
                device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :n_obs, action_dim:] = features
            cond_mask[:, :n_obs, action_dim:] = True

        sample = self.conditional_sample(
            cond_data,
            cond_mask,
            local_cond=None,
            global_cond=global_cond,
            **self.kwargs)

        # drop any observation channels, then undo the action normalization
        action_pred = self.normalizer['action'].unnormalize(sample[..., :action_dim])

        first = n_obs - 1
        last = first + self.n_action_steps
        return {
            'action': action_pred[:, first:last],
            'action_pred': action_pred
        }

    def predict_action_with_scheduler(self, obs_dict: Dict[str, torch.Tensor], scheduler: DDIMScheduler, steps=-1) -> Dict[str, torch.Tensor]:
        """
        Same as predict_action, but samples with
        conditional_sample_with_scheduler using the given scheduler.

        obs_dict: observation tensors keyed by name; must not contain
            'past_action'. Each tensor is sliced as x[:, :n_obs_steps], so
            dims 0/1 are presumably (batch, time) -- TODO confirm.
        scheduler: the scheduler to use
        steps: number of inference steps, forwarded to
            conditional_sample_with_scheduler (-1 selects its default)
        result: dict with
            'action_pred': the full un-normalized predicted trajectory
            'action': the n_action_steps steps of it starting at index
                n_obs_steps - 1
        """
        assert 'past_action' not in obs_dict # not implemented yet

        normalized_obs = self.normalizer.normalize(obs_dict)
        first_entry = next(iter(normalized_obs.values()))
        batch_size, _ = first_entry.shape[:2]
        horizon = self.horizon
        action_dim = self.action_dim
        feature_dim = self.obs_feature_dim
        n_obs = self.n_obs_steps

        device = self.device
        dtype = self.dtype

        def keep_obs_window(x):
            # keep the first n_obs time steps and fold them into the batch dim
            return x[:, :n_obs, ...].reshape(-1, *x.shape[2:])

        conditioned_globally = self.obs_as_global_cond
        features = self.obs_encoder(dict_apply(normalized_obs, keep_obs_window))

        if conditioned_globally:
            # features of all observation steps form one conditioning vector
            global_cond = features.reshape(batch_size, -1)
            cond_data = torch.zeros(
                size=(batch_size, horizon, action_dim), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # inpainting: features occupy the channels after the action dims
            # of the first n_obs steps and are marked as known
            global_cond = None
            features = features.reshape(batch_size, n_obs, -1)
            cond_data = torch.zeros(
                size=(batch_size, horizon, action_dim + feature_dim),
                device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :n_obs, action_dim:] = features
            cond_mask[:, :n_obs, action_dim:] = True

        sample = self.conditional_sample_with_scheduler(
            cond_data,
            cond_mask,
            scheduler=scheduler,
            steps=steps,
            local_cond=None,
            global_cond=global_cond,
            **self.kwargs)

        # drop any observation channels, then undo the action normalization
        action_pred = self.normalizer['action'].unnormalize(sample[..., :action_dim])

        first = n_obs - 1
        last = first + self.n_action_steps
        return {
            'action': action_pred[:, first:last],
            'action_pred': action_pred
        }
    
    # ========= attacking ============
    def predict_action_attacked(self, obs_dict: Dict[str, torch.Tensor], eps=.03, num_steps=50, alpha=.001875, ntarget: torch.Tensor=None) -> Dict[str, torch.Tensor]:
        """
        Perturb the image observations with an iterative signed-gradient
        attack on the single-step denoising loss, then predict actions from
        the perturbed observations.

        obs_dict: observation tensors keyed by name; must not contain
            'past_action'. Every key containing the substring 'image' is
            attacked, all other entries are passed through unchanged.
            Tensors are sliced as x[:, :n_obs_steps], so dims 0/1 are
            presumably (batch, time) -- TODO confirm.
        eps: maximum per-element deviation from the original observation
        num_steps: number of attack iterations
        alpha: step size of each signed-gradient update
        ntarget: the normalized target action, shape B, T, Da. If None the
            attack is non-targeted and the policy's own prediction on the
            clean observations is used as the reference action.

        result: the output of predict_action on the attacked observations
            ('action', 'action_pred') plus
            'attacked obs': the perturbed observation dict (detached tensors)
            'attack_time': seconds from entry until the attack finished; in
                non-targeted mode this includes the clean sampling
            'action_gt', 'action_pred_gt': clean prediction (non-targeted)
            'action_target': un-normalized ntarget (targeted)

        Gradients are taken w.r.t. the image inputs only, so the .grad fields
        of the model / encoder parameters are left untouched.
        """
        start_time = perf_counter()
        assert 'past_action' not in obs_dict # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        value = next(iter(nobs.values()))
        B = value.shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        # build input
        device = self.device
        dtype = self.dtype

        local_cond = None

        def collapse_time(x):
            # keep the first To steps and collapse (batch, time) into batch
            return x[:,:To,...].reshape(-1,*x.shape[2:])

        # TODO: Note all the current Attacks assume that we are using obs_as_global_cond

        if ntarget is None:
            # Non-targeted attack, this method requires a ground truth action
            # Here we assume the GT action is the action that the model would have predicted
            nobs_features = self.obs_encoder(dict_apply(nobs, collapse_time))
            if self.obs_as_global_cond:
                # condition through global feature
                global_cond = nobs_features.reshape(B, -1)
                # empty data for action
                cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
                cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            else:
                # condition through inpainting
                global_cond = None
                nobs_features = nobs_features.reshape(B, To, -1)
                cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
                cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
                cond_data[:,:To,Da:] = nobs_features
                cond_mask[:,:To,Da:] = True

            # run sampling
            nsample = self.conditional_sample(
                cond_data, 
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                **self.kwargs)
            
            # unnormalize prediction
            naction_pred = nsample[...,:Da]
            action_pred = self.normalizer['action'].unnormalize(naction_pred)

            # get action
            start = To - 1
            end = start + self.n_action_steps
            action = action_pred[:,start:end]
            
            # store the gt action for reference
            result = {
                'action_gt': action,
                'action_pred_gt': action_pred
            }
            reference = naction_pred.detach().clone()
        else:
            # the clean observations do not need to be encoded for a
            # targeted attack: only the target trajectory is diffused
            result = {
                'action_target': self.normalizer['action'].unnormalize(ntarget)
            }
            reference = ntarget.detach().clone()

        # figure out the attack observation keys
        attack_images_keys = [key for key in obs_dict.keys() if 'image' in key]

        # conduct the attack on a private copy of the observation
        temp_obs_dict = {
            key: value.detach().clone() for key, value in obs_dict.items()}

        with torch.enable_grad():
            for _ in range(num_steps):
                obs_clone = {
                    key: value.detach().clone()
                    for key, value in temp_obs_dict.items()}
                for key in attack_images_keys:
                    # make the image requires grad
                    obs_clone[key].requires_grad_(True)

                # get image features
                nobs_clone = self.normalizer.normalize(obs_clone)
                nobs_features = self.obs_encoder(dict_apply(nobs_clone, collapse_time))
                # reshape back to B, Do
                global_cond = nobs_features.reshape(B, -1)

                # forward sample: diffuse the reference to a random timestep
                noise = torch.randn(B, T, Da, device=device, dtype=dtype)
                timesteps = torch.randint(
                    0, self.noise_scheduler.config.num_train_timesteps, 
                    (B,), device=device
                ).long()
                noisy_trajectory = self.noise_scheduler.add_noise(
                    reference, noise, timesteps)
                    
                # predict noise
                noise_pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)

                # l2 loss
                if ntarget is None:
                    # non-targeted attack is to maximize the loss between new noise and the GT noise
                    loss = -F.mse_loss(noise_pred, noise)
                else:
                    # targeted attack is to minimize the loss between new noise and the target noise
                    loss = F.mse_loss(noise_pred, noise)

                if attack_images_keys:
                    # differentiate w.r.t. the images only; loss.backward()
                    # would also accumulate into every parameter's .grad
                    gradients = torch.autograd.grad(
                        loss, [obs_clone[key] for key in attack_images_keys])
                    for key, gradient in zip(attack_images_keys, gradients):
                        # update the image
                        stepped = obs_clone[key].detach() - alpha * gradient.sign()
                        # clip the image within the epsilon ball
                        obs_clone[key] = torch.max(
                            torch.min(stepped, obs_dict[key] + eps),
                            obs_dict[key] - eps)
                temp_obs_dict = obs_clone

        # clip the image within 0, 1
        for key in attack_images_keys:
            temp_obs_dict[key] = torch.clamp(temp_obs_dict[key], 0, 1)

        result['attacked obs'] = temp_obs_dict
        end_time = perf_counter()
        used_time = end_time - start_time
        result['attack_time'] = used_time
        # recompute action
        result.update(self.predict_action(temp_obs_dict))

        return result
    
    def predict_action_full_chain_attacked(self, obs_dict: Dict[str, torch.Tensor], eps=.03, num_steps=50, alpha=.001875, scheduler: DDIMScheduler=None, ntarget: torch.Tensor=None, scheduler_steps=-1) -> Dict[str, torch.Tensor]:
        """
        obs_dict: must include "obs" key
        result: must include "action" key
        eps: epsilon for the attack
        step: number of steps for the attack
        alpha: step size for the attack
        scheduler: the scheduler to use
        ntarget: the normalized target action, shape B, T, Da
        """
        start_time = perf_counter()
        # TODO: Note all the current Attacks assume that we are using obs_as_global_cond

        if ntarget is None:
            action_gt_result = self.predict_action(obs_dict)
            
            # store the gt action for reference
            result = {
                'action_gt': action_gt_result['action'],
                'action_pred_gt': action_gt_result['action_pred']
            }
        else:
            result = {
                'action_target': self.normalizer['action'].unnormalize(ntarget)
            }
            ntarget = ntarget.detach().clone()

        # figure out the attack observation keys
        attack_images_keys = []
        for key in obs_dict.keys():
            if 'image' in key:
                attack_images_keys.append(key)


        # conduct the attack

        # clone the observation
        temp_obs_dict = dict()
        for key, value in obs_dict.items():
            temp_obs_dict[key] = value.detach().clone()

        loss_list = []

        with torch.enable_grad():
            for _ in range(num_steps):
                obs_clone = dict()
                for key, value in temp_obs_dict.items():
                    obs_clone[key] = value.detach().clone()
                for key in attack_images_keys:
                    # make the image requires grad
                    obs_clone[key].requires_grad = True
                # # images are: 'sideview_image', 'robot0_eye_in_hand_image'
                # nobs_clone['sideview_image'].requires_grad = True
                # nobs_clone['robot0_eye_in_hand_image'].requires_grad = True

                # full-chain attack
                if scheduler is None:
                    scheduler = self.noise_scheduler # use default scheduler for attack
                    scheduler_steps = self.num_inference_steps

                temp_action_pred = self.predict_action_with_scheduler(obs_clone, scheduler=scheduler, steps=scheduler_steps)
                if ntarget is None:
                    # Non-targeted attack
                    loss = -F.mse_loss(result['action_gt'], temp_action_pred['action'])
                else:
                    # Targeted attack
                    loss = F.mse_loss(result['action_target'], temp_action_pred['action_pred'])
                loss.backward()

                for keys in attack_images_keys:
                    # Get the gradient
                    new_gradient = obs_clone[keys].grad.data.sign()
                    # update the image
                    obs_clone[keys] = obs_clone[keys] - alpha * new_gradient
                    # clip the image within the epsilon ball
                    obs_clone[keys] = torch.max(torch.min(obs_clone[keys], obs_dict[keys] + eps), obs_dict[keys] - eps)
                # new_gradient = nobs_clone['sideview_image'].grad.data.sign()
                # # update the image
                # nobs_clone['sideview_image'] = nobs_clone['sideview_image'] - alpha * new_gradient
                # # clip the image within the epsilon ball
                # nobs_clone['sideview_image'] = torch.max(torch.min(nobs_clone['sideview_image'], nobs['sideview_image'] + eps), nobs['sideview_image'] - eps)
                temp_obs_dict = obs_clone
                loss_list.append(loss.cpu().item())
            print('Loss history:', loss_list)

        # clip the image within 0, 1
        for key in attack_images_keys:
            temp_obs_dict[key] = torch.clamp(temp_obs_dict[key], 0, 1)
        # nobs['sideview_image'] = torch.clamp(nobs['sideview_image'], -1, 1)
        # nobs['robot0_eye_in_hand_image'] = torch.clamp(nobs['robot0_eye_in_hand_image'], -1, 1)

        result['attacked obs'] = temp_obs_dict
        end_time = perf_counter()
        used_time = end_time - start_time
        result['attack_time'] = used_time
        # recompute action
        result.update(self.predict_action(temp_obs_dict))

        return result

    def get_encoded_obs(self, obs_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Return the observation-encoder features of the first n_obs_steps
        observation steps, flattened to one vector per batch element.

        obs_dict: observation tensors keyed by name; must not contain
            'past_action'. The tensors are detached and cloned first, so the
            caller's tensors are not modified. Each tensor is sliced as
            x[:, :n_obs_steps], so dims 0/1 are presumably (batch, time)
            -- TODO confirm.
        result: tensor of shape (B, -1) holding the encoder output for all
            kept observation steps. This is exactly the global conditioning
            vector used when obs_as_global_cond is True; the same flattened
            features are returned when it is False (the previous
            implementation returned None in that case).
        """
        assert 'past_action' not in obs_dict # not implemented yet
        # first clone the obs
        obs_dict = dict_apply(obs_dict, lambda x: x.detach().clone())
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        value = next(iter(nobs.values()))
        B = value.shape[0]
        To = self.n_obs_steps

        # keep the first To steps and collapse (batch, time) into batch
        this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)

        # reshape back to B, To * Do
        return nobs_features.reshape(B, -1)

    def get_attacked_encoded_obs(self, obs_dict: Dict[str, torch.Tensor], eps=.03, num_steps=50, alpha=.001875, ntarget: torch.Tensor=None) -> torch.Tensor:
        """
        Perturb the image observations with an iterative signed-gradient
        attack on the denoising loss, then return the encoding of the
        perturbed observations.

        obs_dict: observation tensors keyed by observation name, with
            leading dimensions (B, To, ...). Every key containing 'image'
            is attacked; all other entries are passed through unchanged.
        eps: maximum absolute per-element deviation from the values in
            obs_dict (applied to the un-normalized observations)
        num_steps: number of gradient steps of the attack
        alpha: step size of each signed-gradient update
        ntarget: the normalized target action, shape B, T, Da, or a python
            scalar that is broadcast to that shape. If None the attack is
            non-targeted and the action the model itself predicts from the
            clean observation is used as ground truth.

        returns: the result of self.get_encoded_obs on the attacked
            observations (image entries clamped to [0, 1])
        """
        assert 'past_action' not in obs_dict # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        value = next(iter(nobs.values()))
        B = value.shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        # build input
        device = self.device
        dtype = self.dtype

        local_cond = None
        global_cond = None

        # TODO: Note all the current Attacks assume that we are using obs_as_global_cond

        if ntarget is None:
            # Non-targeted attack, this method requires a ground truth action
            # Here we assume the GT action is the action that the model would have predicted
            if self.obs_as_global_cond:
                # condition through global feature; collapse (B, To) into batch
                this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
                nobs_features = self.obs_encoder(this_nobs)
                # reshape back to B, Do
                global_cond = nobs_features.reshape(B, -1)
                # empty data for action
                cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
                cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            else:
                # condition through impainting
                this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
                nobs_features = self.obs_encoder(this_nobs)
                # reshape back to B, To, Do
                nobs_features = nobs_features.reshape(B, To, -1)
                cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype)
                cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
                cond_data[:,:To,Da:] = nobs_features
                cond_mask[:,:To,Da:] = True

            # run sampling
            nsample = self.conditional_sample(
                cond_data,
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                **self.kwargs)
            # the clean (normalized) trajectory that gets noised in the attack
            clean_trajectory = nsample[...,:Da].detach().clone()
        else:
            # Targeted attack: the clean trajectory is the requested target
            if isinstance(ntarget, (int, float)):
                ntarget = torch.full((B, T, Da), ntarget, device=device, dtype=dtype)
            clean_trajectory = ntarget.detach().clone()

        # figure out the attack observation keys
        attack_images_keys = [key for key in obs_dict if 'image' in key]

        # work on a private copy so the caller's tensors are never modified
        temp_obs_dict = {key: value.detach().clone() for key, value in obs_dict.items()}

        # conduct the attack (nothing to optimize when there is no image key)
        if attack_images_keys:
            num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
            with torch.enable_grad():
                for _ in range(num_steps):
                    # fresh leaf tensors every step so gradients do not chain across steps
                    obs_clone = {key: value.detach().clone() for key, value in temp_obs_dict.items()}
                    for key in attack_images_keys:
                        obs_clone[key].requires_grad_(True)

                    # get image features
                    nobs_clone = self.normalizer.normalize(obs_clone)
                    this_nobs = dict_apply(nobs_clone, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:]))
                    nobs_features = self.obs_encoder(this_nobs)
                    # reshape back to B, Do
                    global_cond = nobs_features.reshape(B, -1)

                    # forward sample
                    noise = torch.randn(B, T, Da, device=device)
                    timesteps = torch.randint(
                        0, num_train_timesteps,
                        (B,), device=device
                    ).long()
                    noisy_trajectory = self.noise_scheduler.add_noise(
                        clean_trajectory, noise, timesteps)

                    # predict noise
                    noise_pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)

                    # l2 loss
                    mse = F.mse_loss(noise_pred, noise)
                    # non-targeted: maximize the denoising error (minimize -mse)
                    # targeted: minimize the denoising error towards the target
                    loss = -mse if ntarget is None else mse

                    # Only the image gradients are needed. autograd.grad leaves the
                    # .grad buffers of the model / encoder parameters untouched,
                    # whereas loss.backward() would accumulate into them.
                    image_grads = torch.autograd.grad(
                        loss, [obs_clone[key] for key in attack_images_keys])

                    with torch.no_grad():
                        for key, grad in zip(attack_images_keys, image_grads):
                            # signed-gradient descent step on the loss
                            stepped = obs_clone[key] - alpha * grad.sign()
                            # clip the image within the epsilon ball around the clean image
                            obs_clone[key] = torch.max(
                                torch.min(stepped, obs_dict[key] + eps),
                                obs_dict[key] - eps)
                    temp_obs_dict = obs_clone

        # clip the image within 0, 1
        for key in attack_images_keys:
            temp_obs_dict[key] = torch.clamp(temp_obs_dict[key], 0, 1)

        return self.get_encoded_obs(temp_obs_dict)

    def get_random_attacked_feature(self, obs_dict: Dict[str, torch.Tensor], eps=.03) -> torch.Tensor:
        """
        Add uniform random noise to the image observations and return the
        encoding of the perturbed observations.

        obs_dict: observation tensors keyed by observation name. Every key
            containing 'image' is perturbed; all other entries are passed
            through unchanged. The caller's tensors are not modified.
        eps: half-width of the noise; each element is shifted by a value
            drawn uniformly from [-eps, eps)

        returns: the result of self.get_encoded_obs on the perturbed
            observations (image entries clamped to [0, 1])
        """
        # work on a private copy so the caller's tensors are never modified
        temp_obs_dict = {key: value.detach().clone() for key, value in obs_dict.items()}
        attack_images_keys = [key for key in obs_dict if 'image' in key]
        for key in attack_images_keys:
            image = temp_obs_dict[key]
            # rand_like samples from [0, 1); rescale to the symmetric range
            # [-eps, eps) so the perturbation can both darken and brighten
            noise = (torch.rand_like(image) * 2 - 1) * eps
            # keep the perturbed image a valid image
            temp_obs_dict[key] = torch.clamp(image + noise, 0, 1)
        return self.get_encoded_obs(temp_obs_dict)
    
    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """
        Copy the state of the given normalizer into this policy's own
        normalizer (the policy keeps its own module; only the state dict
        is transferred).
        """
        source_state = normalizer.state_dict()
        self.normalizer.load_state_dict(source_state)

    def compute_loss(self, batch):
        """
        Compute the diffusion training loss for one batch.

        batch: dict with an 'obs' entry (dict of observation tensors) and an
            'action' entry; must not contain 'valid_mask'.

        returns: scalar tensor, the masked MSE between the model prediction
            and the regression target selected by the scheduler's
            prediction_type ('epsilon' -> the sampled noise, 'sample' -> the
            clean trajectory).

        raises: ValueError for any other prediction_type.
        """
        assert 'valid_mask' not in batch
        # bring observations and actions into normalized space
        norm_obs = self.normalizer.normalize(batch['obs'])
        norm_actions = self.normalizer['action'].normalize(batch['action'])
        n_batch, n_steps = norm_actions.shape[0], norm_actions.shape[1]

        # the observation is injected either as a global feature vector or
        # by in-painting it into the trajectory itself
        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # fold the first n_obs_steps frames into the batch dimension
            n_obs = self.n_obs_steps
            flat_obs = dict_apply(
                norm_obs, lambda x: x[:, :n_obs, ...].reshape(-1, *x.shape[2:]))
            # one flat conditioning vector per sample: B, Do
            global_cond = self.obs_encoder(flat_obs).reshape(n_batch, -1)
            cond_data = norm_actions
            trajectory = norm_actions
        else:
            # fold every frame into the batch dimension
            flat_obs = dict_apply(norm_obs, lambda x: x.reshape(-1, *x.shape[2:]))
            # per-step features: B, T, Do
            obs_features = self.obs_encoder(flat_obs).reshape(n_batch, n_steps, -1)
            cond_data = torch.cat([norm_actions, obs_features], dim=-1)
            trajectory = cond_data.detach()

        # entries marked True are conditioning and are not denoised
        condition_mask = self.mask_generator(trajectory.shape)

        # forward diffusion: one random timestep per sample
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (trajectory.shape[0],), device=trajectory.device
        ).long()
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # only the non-conditioned entries contribute to the loss
        loss_mask = ~condition_mask

        # overwrite the conditioned entries with their clean values
        noisy_trajectory[condition_mask] = cond_data[condition_mask]

        # run the denoiser
        pred = self.model(noisy_trajectory, timesteps,
            local_cond=local_cond, global_cond=global_cond)

        # pick what the network is regressing
        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type not in ('epsilon', 'sample'):
            raise ValueError(f"Unsupported prediction type {pred_type}")
        regression_target = noise if pred_type == 'epsilon' else trajectory

        # masked MSE: mean per sample, then mean over the batch
        elementwise = F.mse_loss(pred, regression_target, reduction='none')
        masked = elementwise * loss_mask.type(elementwise.dtype)
        per_sample = reduce(masked, 'b ... -> b (...)', 'mean')
        return per_sample.mean()

    def compute_targeted_loss(self, batch, target):
        """
        Compute the diffusion loss of a constant target action given the
        observations of the batch.

        batch: dict with an 'obs' entry (dict of observation tensors) and an
            'action' entry; must not contain 'valid_mask'. The action entry
            is only used for its shape, dtype and device.
        target: value multiplied with a ones tensor shaped like
            batch['action']. It is used as the trajectory directly and is
            NOT passed through the action normalizer, so it is presumably
            expected in normalized action space -- confirm with the caller.

        returns: scalar tensor, the masked MSE between the model prediction
            and the regression target selected by the scheduler's
            prediction_type.

        raises: ValueError for a prediction_type other than 'epsilon' or
            'sample'.
        """
        assert 'valid_mask' not in batch
        # only the observations are normalized; the actions are replaced
        norm_obs = self.normalizer.normalize(batch['obs'])
        target_actions = target * torch.ones_like(batch['action']) # use target action
        n_batch, n_steps = target_actions.shape[0], target_actions.shape[1]

        # the observation is injected either as a global feature vector or
        # by in-painting it into the trajectory itself
        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # fold the first n_obs_steps frames into the batch dimension
            n_obs = self.n_obs_steps
            flat_obs = dict_apply(
                norm_obs, lambda x: x[:, :n_obs, ...].reshape(-1, *x.shape[2:]))
            # one flat conditioning vector per sample: B, Do
            global_cond = self.obs_encoder(flat_obs).reshape(n_batch, -1)
            cond_data = target_actions
            trajectory = target_actions
        else:
            # fold every frame into the batch dimension
            flat_obs = dict_apply(norm_obs, lambda x: x.reshape(-1, *x.shape[2:]))
            # per-step features: B, T, Do
            obs_features = self.obs_encoder(flat_obs).reshape(n_batch, n_steps, -1)
            cond_data = torch.cat([target_actions, obs_features], dim=-1)
            trajectory = cond_data.detach()

        # entries marked True are conditioning and are not denoised
        condition_mask = self.mask_generator(trajectory.shape)

        # forward diffusion: one random timestep per sample
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (trajectory.shape[0],), device=trajectory.device
        ).long()
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # only the non-conditioned entries contribute to the loss
        loss_mask = ~condition_mask

        # overwrite the conditioned entries with their clean values
        noisy_trajectory[condition_mask] = cond_data[condition_mask]

        # run the denoiser
        pred = self.model(noisy_trajectory, timesteps,
            local_cond=local_cond, global_cond=global_cond)

        # pick what the network is regressing (kept in its own name so the
        # `target` argument is not shadowed)
        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type not in ('epsilon', 'sample'):
            raise ValueError(f"Unsupported prediction type {pred_type}")
        regression_target = noise if pred_type == 'epsilon' else trajectory

        # masked MSE: mean per sample, then mean over the batch
        elementwise = F.mse_loss(pred, regression_target, reduction='none')
        masked = elementwise * loss_mask.type(elementwise.dtype)
        per_sample = reduce(masked, 'b ... -> b (...)', 'mean')
        return per_sample.mean()

    def compute_targeted_stay_loss(self, batch):
        """
        Compute the diffusion loss of a "stay in place" target action: the
        end-effector pose observed at the first time step, repeated for
        n_action_steps.

        batch: dict with an 'obs' entry that contains 'robot0_eef_pos'
            (indexed as B, T, 3) and 'robot0_eef_quat' (indexed as B, T, 4,
            xyzw order); must not contain 'valid_mask'.

        The target action has 10 channels: [0:3] position, [3:9] 6D
        rotation; channel 9 is left at zero (presumably the gripper
        command -- confirm against the action definition).

        returns: scalar tensor, the masked MSE between the model prediction
            and the regression target selected by the scheduler's
            prediction_type.

        raises: ValueError for a prediction_type other than 'epsilon' or
            'sample'.
        """
        # normalize input
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])

        eef_pos = batch['obs']['robot0_eef_pos']
        eef_quat = batch['obs']['robot0_eef_quat']
        B = next(iter(batch['obs'].values())).shape[0]
        # allocate the target on the same device as the observations instead
        # of the CPU default, so no host round trip or device mix-up occurs
        device = eef_pos.device

        ee_positions = eef_pos[:, 0, :] # B, 3

        # set target as current position, note the convention difference, obs is in xyzw, pytorch3d expects wxyz
        ee_quats = torch.zeros((B, 4), device=device) # B, 4
        ee_quats[:, 0] = eef_quat[:, 0, -1] # w
        ee_quats[:, 1:] = eef_quat[:, 0, :-1] # x, y, z

        # somehow the action requires an offset rotation
        offset_quat = torch.zeros((B, 4), device=device)
        offset_quat[:, 1] = -np.sqrt(2)/2
        offset_quat[:, 2] = -np.sqrt(2)/2

        ee_quats = pt.quaternion_multiply(ee_quats, offset_quat)
        action = torch.zeros((B, self.n_action_steps, 10), device=device)
        # inverse() of a rotation_6d -> quaternion transformer maps quaternion -> rotation_6d
        tf_quat_to_rot6 = RotationTransformer('rotation_6d', 'quaternion')
        # the single pose is broadcast over all action steps
        action[..., 3:3+6] = tf_quat_to_rot6.inverse(ee_quats[:, None, :])
        action[..., :3] = ee_positions[:, None, :]

        nactions = self.normalizer['action'].normalize(action)

        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        trajectory = nactions
        cond_data = trajectory
        if self.obs_as_global_cond:
            # reshape B, T, ... to B*T
            this_nobs = dict_apply(nobs, 
                lambda x: x[:,:self.n_obs_steps,...].reshape(-1,*x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # reshape back to B, Do
            global_cond = nobs_features.reshape(batch_size, -1)
        else:
            # reshape B, T, ... to B*T
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # reshape back to B, T, Do
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            cond_data = torch.cat([nactions, nobs_features], dim=-1)
            trajectory = cond_data.detach()

        # generate impainting mask
        condition_mask = self.mask_generator(trajectory.shape)

        # Sample noise that we'll add to the trajectory
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        bsz = trajectory.shape[0]
        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, 
            (bsz,), device=trajectory.device
        ).long()
        # Add noise to the clean trajectory according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)
        
        # compute loss mask
        loss_mask = ~condition_mask

        # apply conditioning
        noisy_trajectory[condition_mask] = cond_data[condition_mask]
        
        # Predict the noise residual
        pred = self.model(noisy_trajectory, timesteps, 
            local_cond=local_cond, global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type 
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        # masked MSE: mean per sample, then mean over the batch
        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        return loss
