from typing import Dict, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.transformer_for_diffusion import TransformerForDiffusion
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
import robomimic.utils.obs_utils as ObsUtils
import robomimic.models.base_nets as rmbn
import diffusion_policy.model.vision.crop_randomizer as dmvc
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
from time import perf_counter
from diffusers.schedulers.scheduling_ddim import DDIMScheduler


class DiffusionTransformerHybridImagePolicy(BaseImagePolicy):
    """Diffusion policy with a robomimic image encoder and a transformer denoiser.

    Observations are encoded by the robomimic observation encoder built from a
    ``bc_rnn`` config. The resulting features either condition the
    ``TransformerForDiffusion`` model (``obs_as_cond=True``) or are inpainted
    into the diffused trajectory next to the actions (``obs_as_cond=False``).

    Besides inference and training entry points, the class provides two
    sign-gradient attacks on the image observations, projected onto an
    L-infinity ball of radius ``eps`` around the clean images:
    ``predict_action_attacked`` attacks the single-step denoising loss and
    ``predict_action_full_chain_attacked`` back-propagates through the whole
    sampling chain.
    """

    def __init__(self, 
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,
            # task params
            horizon, 
            n_action_steps, 
            n_obs_steps,
            num_inference_steps=None,
            # image
            crop_shape=(76, 76),
            obs_encoder_group_norm=False,
            eval_fixed_crop=False,
            # arch
            n_layer=8,
            n_cond_layers=0,
            n_head=4,
            n_emb=256,
            p_drop_emb=0.0,
            p_drop_attn=0.3,
            causal_attn=True,
            time_as_cond=True,
            obs_as_cond=True,
            pred_action_steps_only=False,
            # parameters passed to step
            **kwargs):
        """Build the observation encoder, the diffusion transformer and helpers.

        Args:
            shape_meta: dict with an ``'action'`` entry whose ``'shape'`` is
                1-D ``(Da,)``, and an ``'obs'`` entry mapping each observation
                key to ``{'shape': ..., 'type': 'rgb' | 'low_dim'}``. The type
                defaults to ``'low_dim'``; any other type raises RuntimeError.
            noise_scheduler: scheduler used for training and default sampling.
            horizon: length T of the diffused trajectory.
            n_action_steps: number of actions returned per prediction.
            n_obs_steps: number of observation steps (To) fed to the encoder.
            num_inference_steps: denoising steps at inference. Defaults to
                ``noise_scheduler.config.num_train_timesteps``.
            crop_shape: ``(height, width)`` for robomimic's CropRandomizer,
                or ``None`` to remove that randomizer from the config.
            obs_encoder_group_norm: replace every ``nn.BatchNorm2d`` of the
                encoder with ``nn.GroupNorm(num_features // 16, num_features)``.
            eval_fixed_crop: replace robomimic CropRandomizers with
                ``diffusion_policy``'s CropRandomizer using the same settings
                (presumably one that crops deterministically in eval mode).
            n_layer, n_cond_layers, n_head, n_emb, p_drop_emb, p_drop_attn,
                causal_attn, time_as_cond: forwarded to TransformerForDiffusion.
            obs_as_cond: condition on observation features (True) or inpaint
                them into the trajectory (False).
            pred_action_steps_only: diffuse only ``n_action_steps`` actions
                instead of the full horizon.
            **kwargs: extra keyword arguments forwarded to ``scheduler.step``.
        """
        super().__init__()

        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            obs_key_shapes[key] = list(attr['shape'])

            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                obs_config['rgb'].append(key)
            elif obs_type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # get raw robomimic config
        config = get_robomimic_config(
            algo_name='bc_rnn',
            hdf5_type='image',
            task_name='square',
            dataset_type='ph')
        
        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state
        ObsUtils.initialize_obs_utils_with_config(config)

        # load model; only its observation encoder is kept
        policy: PolicyAlgo = algo_factory(
                algo_name=config.algo_name,
                config=config,
                obs_key_shapes=obs_key_shapes,
                ac_dim=action_dim,
                device='cpu',
            )

        obs_encoder = policy.nets['policy'].nets['encoder'].nets['obs']
        
        if obs_encoder_group_norm:
            # replace batch norm with group norm
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=x.num_features//16, 
                    num_channels=x.num_features)
            )
        
        if eval_fixed_crop:
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
                func=lambda x: dmvc.CropRandomizer(
                    input_shape=x.input_shape,
                    crop_height=x.crop_height,
                    crop_width=x.crop_width,
                    num_crops=x.num_crops,
                    pos_enc=x.pos_enc
                )
            )

        # create diffusion model
        obs_feature_dim = obs_encoder.output_shape()[0]
        input_dim = action_dim if obs_as_cond else (obs_feature_dim + action_dim)
        output_dim = input_dim
        cond_dim = obs_feature_dim if obs_as_cond else 0

        model = TransformerForDiffusion(
            input_dim=input_dim,
            output_dim=output_dim,
            horizon=horizon,
            n_obs_steps=n_obs_steps,
            cond_dim=cond_dim,
            n_layer=n_layer,
            n_head=n_head,
            n_emb=n_emb,
            p_drop_emb=p_drop_emb,
            p_drop_attn=p_drop_attn,
            causal_attn=causal_attn,
            time_as_cond=time_as_cond,
            obs_as_cond=obs_as_cond,
            n_cond_layers=n_cond_layers
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if (obs_as_cond) else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_cond = obs_as_cond
        self.pred_action_steps_only = pred_action_steps_only
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

        print("Diffusion params: %e" % sum(p.numel() for p in self.model.parameters()))
        print("Vision params: %e" % sum(p.numel() for p in self.obs_encoder.parameters()))

    # ========= shared helpers ============
    def _encode_obs(self, nobs, batch_size, n_steps):
        """Encode the first ``n_steps`` steps of normalized observations.

        Each tensor in ``nobs`` is sliced to ``x[:, :n_steps]``. It is then
        flattened to ``(batch_size * n_steps, ...)`` for the encoder, and the
        features are reshaped back to ``(batch_size, n_steps, Do)``.
        """
        this_nobs = dict_apply(
            nobs, lambda x: x[:, :n_steps, ...].reshape(-1, *x.shape[2:]))
        return self.obs_encoder(this_nobs).reshape(batch_size, n_steps, -1)

    def _prepare_conditioning(self, nobs):
        """Build ``(cond, cond_data, cond_mask)`` for sampling from normalized obs.

        With ``obs_as_cond`` the encoded features become ``cond``, and
        ``cond_data`` / ``cond_mask`` are all-zero / all-False tensors of shape
        ``(B, T, Da)``, or ``(B, n_action_steps, Da)`` when
        ``pred_action_steps_only`` is set. Otherwise ``cond`` is None and the
        features are inpainted into the first ``n_obs_steps`` steps of a
        ``(B, T, Da + Do)`` trajectory.
        """
        B = next(iter(nobs.values())).shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps
        device = self.device
        dtype = self.dtype

        nobs_features = self._encode_obs(nobs, B, To)
        if self.obs_as_cond:
            cond = nobs_features
            n_steps = self.n_action_steps if self.pred_action_steps_only else T
            cond_data = torch.zeros(size=(B, n_steps, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting
            cond = None
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True
        return cond, cond_data, cond_mask

    def _unnormalize_prediction(self, nsample):
        """Split a sampled trajectory into ``(naction_pred, action, action_pred)``.

        ``naction_pred`` is the normalized action part ``nsample[..., :Da]``.
        ``action_pred`` is its unnormalized version. ``action`` is either all of
        ``action_pred`` (``pred_action_steps_only``) or its ``n_action_steps``
        steps starting at index ``n_obs_steps - 1``.
        """
        naction_pred = nsample[..., :self.action_dim]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)
        if self.pred_action_steps_only:
            action = action_pred
        else:
            start = self.n_obs_steps - 1
            end = start + self.n_action_steps
            action = action_pred[:, start:end]
        return naction_pred, action, action_pred

    # ========= inference  ============
    def _denoise(self, condition_data, condition_mask, scheduler, num_steps,
            cond, generator, step_kwargs):
        """Run ``num_steps`` reverse-diffusion steps of ``scheduler`` from pure noise.

        Entries selected by ``condition_mask`` are overwritten with
        ``condition_data`` before every model call and once more at the end.
        ``step_kwargs`` is forwarded to ``scheduler.step``. It is passed as a
        dict rather than ``**kwargs`` so it cannot collide with this helper's
        own parameter names.
        """
        model = self.model

        trajectory = torch.randn(
            size=condition_data.shape, 
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)

        # set step values
        scheduler.set_timesteps(num_steps)

        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(trajectory, t, cond)

            # 3. compute previous sample: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory, 
                generator=generator,
                **step_kwargs
                ).prev_sample

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]
        return trajectory

    def conditional_sample(self, 
            condition_data, condition_mask,
            cond=None, generator=None,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Sample a trajectory with ``self.noise_scheduler``.

        Uses ``self.num_inference_steps`` denoising steps. The result has the
        same shape as ``condition_data``, with the entries selected by
        ``condition_mask`` copied from ``condition_data``.
        """
        return self._denoise(
            condition_data, condition_mask,
            scheduler=self.noise_scheduler,
            num_steps=self.num_inference_steps,
            cond=cond, generator=generator,
            step_kwargs=kwargs)

    def conditional_sample_with_scheduler(self, 
            condition_data, condition_mask, scheduler: DDIMScheduler, steps=-1,
            cond=None, generator=None,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Like ``conditional_sample`` but with a caller-supplied scheduler.

        ``steps`` is the number of denoising steps. ``-1`` (or ``None``) means
        ``self.num_inference_steps``. Note that ``scheduler.set_timesteps``
        mutates the scheduler passed in.
        """
        if steps is None or steps == -1:
            steps = self.num_inference_steps
        return self._denoise(
            condition_data, condition_mask,
            scheduler=scheduler,
            num_steps=steps,
            cond=cond, generator=generator,
            step_kwargs=kwargs)

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Predict actions from raw observations.

        Args:
            obs_dict: observation key -> tensor of shape
                ``(B, >= n_obs_steps, ...)``, unnormalized. The key
                ``'past_action'`` is not supported.

        Returns:
            ``{'action': ..., 'action_pred': ...}``. ``action_pred`` is the
            full unnormalized predicted action trajectory. ``action`` is the
            part to execute (see ``_unnormalize_prediction``).
        """
        assert 'past_action' not in obs_dict # not implemented yet
        nobs = self.normalizer.normalize(obs_dict)
        cond, cond_data, cond_mask = self._prepare_conditioning(nobs)

        # run sampling
        nsample = self.conditional_sample(
            cond_data, 
            cond_mask,
            cond=cond,
            **self.kwargs)

        _, action, action_pred = self._unnormalize_prediction(nsample)
        return {
            'action': action,
            'action_pred': action_pred
        }

    def predict_action_with_scheduler(self, obs_dict: Dict[str, torch.Tensor], scheduler: DDIMScheduler, steps=-1) -> Dict[str, torch.Tensor]:
        """Same as ``predict_action`` but samples with ``scheduler``.

        Sampling runs for ``steps`` denoising steps, where ``-1`` means
        ``self.num_inference_steps``. No ``torch.no_grad`` is applied here, so
        under ``torch.enable_grad`` the result is differentiable w.r.t. the
        observations.
        """
        assert 'past_action' not in obs_dict # not implemented yet
        nobs = self.normalizer.normalize(obs_dict)
        cond, cond_data, cond_mask = self._prepare_conditioning(nobs)

        # run sampling
        nsample = self.conditional_sample_with_scheduler(
            cond_data, 
            cond_mask,
            cond=cond,
            scheduler=scheduler,
            steps=steps,
            **self.kwargs)

        _, action, action_pred = self._unnormalize_prediction(nsample)
        return {
            'action': action,
            'action_pred': action_pred
        }

    # ========= attacking ============
    @staticmethod
    def _image_keys(obs_dict):
        """Return the observation keys whose name contains ``'image'`` (attack targets)."""
        return [key for key in obs_dict.keys() if 'image' in key]

    @staticmethod
    def _signed_gradient_step(obs_clone, loss, obs_dict, attack_keys, alpha, eps):
        """Take one descent step of size ``alpha`` along ``sign(d loss / d image)``.

        Each updated image is projected back into ``[orig - eps, orig + eps]``
        around ``obs_dict``. Gradients are taken only w.r.t. the attacked images
        via ``torch.autograd.grad``, so model parameters' ``.grad`` are left
        untouched. An image that does not influence ``loss`` gets a zero step.
        The returned images are detached.
        """
        if not attack_keys:
            return obs_clone
        grads = torch.autograd.grad(
            loss, [obs_clone[key] for key in attack_keys], allow_unused=True)
        for key, grad in zip(attack_keys, grads):
            image = obs_clone[key].detach()
            if grad is not None:
                image = image - alpha * grad.sign()
            # clip the image within the epsilon ball
            obs_clone[key] = torch.max(
                torch.min(image, obs_dict[key] + eps), obs_dict[key] - eps)
        return obs_clone

    def predict_action_attacked(self, obs_dict: Dict[str, torch.Tensor], eps=.03, num_steps=50, alpha=.001875, ntarget: torch.Tensor=None) -> Dict[str, torch.Tensor]:
        """Attack the single-step denoising loss by perturbing image observations.

        Args:
            obs_dict: raw observations, as for ``predict_action``. Keys
                containing ``'image'`` are perturbed. Their values are clamped
                to ``[0, 1]`` at the end.
            eps: L-infinity radius of the perturbation, in raw pixel units.
            num_steps: number of attack iterations.
            alpha: step size of each signed-gradient step.
            ntarget: normalized target action trajectory, shape (B, T, Da).
                If None, the attack is untargeted: it maximizes the denoising
                loss around the clean prediction. Otherwise it minimizes the
                denoising loss around ``ntarget``.

        Returns:
            A dict containing ``'action_gt'`` and ``'action_pred'`` (untargeted)
            or ``'action_target'`` (targeted), plus ``'attacked obs'``,
            ``'attack_time'`` (seconds, measured before the final prediction),
            and the ``predict_action`` output on the attacked observations.

        Raises:
            NotImplementedError: if the policy was built with
                ``obs_as_cond=False``. The attack loop feeds an action-only
                trajectory plus ``cond`` to the model.
        """
        start_time = perf_counter()
        assert 'past_action' not in obs_dict
        if not self.obs_as_cond:
            raise NotImplementedError(
                "predict_action_attacked requires obs_as_cond=True")

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type not in ('epsilon', 'sample'):
            raise ValueError(f"Unsupported prediction type {pred_type}")

        nobs = self.normalizer.normalize(obs_dict)
        B = next(iter(nobs.values())).shape[0]
        To = self.n_obs_steps
        device = self.device

        if ntarget is None:
            # clean prediction: the reference trajectory to push away from
            cond, cond_data, cond_mask = self._prepare_conditioning(nobs)
            nsample = self.conditional_sample(
                cond_data, 
                cond_mask,
                cond=cond,
                **self.kwargs)
            naction_pred, action, action_pred = self._unnormalize_prediction(nsample)
            result = {
                'action_gt': action,
                'action_pred': action_pred,
            }
            nclean = naction_pred.detach().clone()
        else:
            result = {
                'action_target': self.normalizer['action'].unnormalize(ntarget),
            }
            nclean = ntarget.detach().clone()

        attack_images_keys = self._image_keys(obs_dict)
        temp_obs_dict = {key: value.detach().clone() for key, value in obs_dict.items()}

        with torch.enable_grad():
            for _ in range(num_steps):
                obs_clone = {key: value.detach().clone() for key, value in temp_obs_dict.items()}
                for key in attack_images_keys:
                    obs_clone[key].requires_grad = True

                # encode the (perturbed) observations
                nobs_clone = self.normalizer.normalize(obs_clone)
                cond = self._encode_obs(nobs_clone, B, To)

                # forward diffusion of the reference trajectory at random timesteps
                noise = torch.randn_like(nclean)
                timesteps = torch.randint(
                    0, self.noise_scheduler.config.num_train_timesteps, 
                    (B,), device=device
                ).long()
                noisy_trajectory = self.noise_scheduler.add_noise(
                    nclean, noise, timesteps)

                pred = self.model(noisy_trajectory, timesteps, cond)

                # regression target must match what the model was trained to predict
                target = noise if pred_type == 'epsilon' else nclean
                loss = F.mse_loss(pred, target)
                if ntarget is None:
                    # untargeted: maximize the denoising loss of the clean prediction
                    loss = -loss

                temp_obs_dict = self._signed_gradient_step(
                    obs_clone, loss, obs_dict, attack_images_keys, alpha, eps)

        # clip the image within 0, 1
        for key in attack_images_keys:
            temp_obs_dict[key] = torch.clamp(temp_obs_dict[key], 0, 1)

        result['attacked obs'] = temp_obs_dict
        result['attack_time'] = perf_counter() - start_time

        # recompute action
        result.update(self.predict_action(temp_obs_dict))
        return result

    def predict_action_full_chain_attacked(self, obs_dict: Dict[str, torch.Tensor], eps=.03, num_steps=50, alpha=.001875, scheduler: DDIMScheduler=None, ntarget: torch.Tensor=None, scheduler_steps=-1) -> Dict[str, torch.Tensor]:
        """Attack the full sampling chain by perturbing image observations.

        Each iteration runs ``predict_action_with_scheduler`` under
        ``torch.enable_grad``. It then back-propagates an action-space MSE
        through every denoising step, so the autograd graph of the whole chain
        is kept alive during each iteration.

        Args:
            obs_dict: raw observations, as for ``predict_action``. Keys
                containing ``'image'`` are perturbed. Their values are clamped
                to ``[0, 1]`` at the end.
            eps: L-infinity radius of the perturbation, in raw pixel units.
            num_steps: number of attack iterations.
            alpha: step size of each signed-gradient step.
            scheduler: scheduler used inside the attack. Defaults to
                ``self.noise_scheduler``.
            ntarget: normalized target action trajectory, shape (B, T, Da).
                If None, the attack maximizes the MSE between the executed
                actions and the clean ones. Otherwise it minimizes the MSE
                between the full predicted trajectory and the unnormalized
                target.
            scheduler_steps: denoising steps per attack iteration
                (``-1`` means ``self.num_inference_steps``).

        Returns:
            Same layout as ``predict_action_attacked``.
        """
        start_time = perf_counter()
        if scheduler is None:
            scheduler = self.noise_scheduler  # use default scheduler for attack

        if ntarget is None:
            action_gt_result = self.predict_action(obs_dict)
            result = {
                'action_gt': action_gt_result['action'],
                'action_pred': action_gt_result['action_pred'],
            }
            # detached so repeated backward passes never touch its graph
            reference = result['action_gt'].detach()
        else:
            result = {
                'action_target': self.normalizer['action'].unnormalize(ntarget),
            }
            ntarget = ntarget.detach().clone()
            reference = result['action_target'].detach()

        attack_images_keys = self._image_keys(obs_dict)
        temp_obs_dict = {key: value.detach().clone() for key, value in obs_dict.items()}

        with torch.enable_grad():
            for _ in range(num_steps):
                obs_clone = {key: value.detach().clone() for key, value in temp_obs_dict.items()}
                for key in attack_images_keys:
                    obs_clone[key].requires_grad = True

                temp_action_pred = self.predict_action_with_scheduler(
                    obs_clone, scheduler, steps=scheduler_steps)

                if ntarget is None:
                    loss = -F.mse_loss(reference, temp_action_pred['action'])
                else:
                    loss = F.mse_loss(reference, temp_action_pred['action_pred'])

                temp_obs_dict = self._signed_gradient_step(
                    obs_clone, loss, obs_dict, attack_images_keys, alpha, eps)

        # clip the image within 0, 1
        for key in attack_images_keys:
            temp_obs_dict[key] = torch.clamp(temp_obs_dict[key], 0, 1)

        result['attacked obs'] = temp_obs_dict
        result['attack_time'] = perf_counter() - start_time

        # recompute action
        result.update(self.predict_action(temp_obs_dict))
        return result

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy the state of ``normalizer`` into this policy's normalizer."""
        self.normalizer.load_state_dict(normalizer.state_dict())

    def get_optimizer(
            self, 
            transformer_weight_decay: float, 
            obs_encoder_weight_decay: float,
            learning_rate: float, 
            betas: Tuple[float, float]
        ) -> torch.optim.Optimizer:
        """Create an AdamW optimizer.

        The transformer's own parameter groups use
        ``transformer_weight_decay``. One extra group holds all encoder
        parameters with ``obs_encoder_weight_decay``.
        """
        optim_groups = self.model.get_optim_groups(
            weight_decay=transformer_weight_decay)
        optim_groups.append({
            "params": self.obs_encoder.parameters(),
            "weight_decay": obs_encoder_weight_decay
        })
        optimizer = torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas
        )
        return optimizer

    def _diffusion_loss(self, nobs, nactions):
        """Masked diffusion training loss for normalized obs and action trajectory.

        Noise is added at a random timestep per sample. The model's output is
        regressed onto the noise (``'epsilon'``) or the clean trajectory
        (``'sample'``), and conditioned entries are excluded from the loss.
        """
        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]
        To = self.n_obs_steps

        # handle different ways of passing observation
        cond = None
        trajectory = nactions
        if self.obs_as_cond:
            cond = self._encode_obs(nobs, batch_size, To)
            if self.pred_action_steps_only:
                start = To - 1
                end = start + self.n_action_steps
                trajectory = nactions[:, start:end]
        else:
            # reshape B, T, ... to B*T (all steps, not only the first To)
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # reshape back to B, T, Do
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            # detached: the encoder receives no gradient in this mode
            trajectory = torch.cat([nactions, nobs_features], dim=-1).detach()

        # generate inpainting mask
        if self.pred_action_steps_only:
            condition_mask = torch.zeros_like(trajectory, dtype=torch.bool)
        else:
            condition_mask = self.mask_generator(trajectory.shape)

        # Sample noise that we'll add to the trajectories
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        bsz = trajectory.shape[0]
        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, 
            (bsz,), device=trajectory.device
        ).long()
        # forward diffusion process
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # conditioned entries are given, so they do not contribute to the loss
        loss_mask = ~condition_mask

        # apply conditioning
        noisy_trajectory[condition_mask] = trajectory[condition_mask]

        pred = self.model(noisy_trajectory, timesteps, cond)

        pred_type = self.noise_scheduler.config.prediction_type 
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        return loss

    def compute_loss(self, batch):
        """Diffusion training loss for ``batch = {'obs': {...}, 'action': tensor}``.

        Both observations and actions are normalized before the loss is
        computed. ``'valid_mask'`` in the batch is not supported.
        """
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        return self._diffusion_loss(nobs, nactions)

    def compute_targeted_loss(self, batch, target):
        """Diffusion loss of ``batch['obs']`` against a constant target trajectory.

        The action trajectory is ``ones_like(batch['action']) * target`` and
        is used without normalization, so ``target`` is presumably already in
        normalized action space (TODO confirm with callers).
        """
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = torch.ones_like(batch['action']) * target
        return self._diffusion_loss(nobs, nactions)
