from typing import Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Tuple

from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
import robomimic.utils.obs_utils as ObsUtils
import robomimic.models.base_nets as rmbn
import diffusion_policy.model.vision.crop_randomizer as dmvc
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules

import math


def progressive_function(x, rollout_every):
    """Linear warm-up ramp that saturates at 1.

    The ramp spans ``3 * rollout_every`` units of ``x``. While ``x`` is within
    that span the result is ``x`` divided by the span length; once ``x``
    exceeds it the result is the integer ``1``.
    """
    ramp_length = rollout_every * 3
    return x / ramp_length if x <= ramp_length else 1


# Example usage: progressive_function(15, 10) -> 0.5

def sigmoid_growth(epoch, total_epochs, start=0.1, end=1.0, k=10):
    """Interpolate from ``start`` towards ``end`` along a sigmoid in ``epoch``.

    The sigmoid is centred on ``total_epochs / 2`` and evaluated on the
    normalised offset ``(epoch - mid) / total_epochs``.

    Args:
        epoch: current epoch (any real number).
        total_epochs: length of the schedule; must be non-zero.
        start: value approached for epochs far below the midpoint.
        end: value approached for epochs far above the midpoint.
        k: steepness of the sigmoid; larger values give a sharper transition.

    Returns:
        ``start + (end - start) * sigmoid(k * (epoch - mid) / total_epochs)``.
    """
    mid = total_epochs / 2  # centre of the transition
    exponent = -k * (epoch - mid) / total_epochs
    try:
        value = 1 / (1 + math.exp(exponent))
    except OverflowError:
        # exp() overflows only for a huge positive exponent, where the
        # sigmoid has already reached its lower limit of 0.
        value = 0.0
    return start + (end - start) * value


class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
    def __init__(self,
                 shape_meta: dict,
                 noise_scheduler: DDPMScheduler,
                 horizon,
                 n_action_steps,
                 n_obs_steps,
                 num_inference_steps=None,
                 obs_as_global_cond=True,
                 crop_shape=(76, 76),
                 diffusion_step_embed_dim=256,
                 down_dims=(256, 512, 1024),
                 kernel_size=5,
                 n_groups=8,
                 cond_predict_scale=True,
                 obs_encoder_group_norm=False,
                 eval_fixed_crop=False,
                 # parameters passed to step
                 **kwargs):
        """Build the observation encoder and the conditional 1D U-Net.

        Args:
            shape_meta: dict with ``shape_meta['action']['shape']`` (must be
                1-D) and ``shape_meta['obs']`` mapping each observation key to
                a dict holding ``'shape'`` and an optional ``'type'``
                (``'rgb'`` or ``'low_dim'``; defaults to ``'low_dim'``).
            noise_scheduler: scheduler stored on the policy and used for both
                training noise and sampling.
            horizon: stored as ``self.horizon`` (length of sampled trajectories).
            n_action_steps: number of action steps returned by prediction.
            n_obs_steps: number of observation steps used for conditioning.
            num_inference_steps: sampling steps; defaults to the scheduler's
                ``num_train_timesteps`` when ``None``.
            obs_as_global_cond: if true, observation features are fed to the
                U-Net as a global condition; otherwise they are concatenated
                to the action dimension.
            crop_shape: ``(height, width)`` for robomimic ``CropRandomizer``
                modalities, or ``None`` to disable the randomizer.
            diffusion_step_embed_dim, down_dims, kernel_size, n_groups,
            cond_predict_scale: forwarded to ``ConditionalUnet1D``.
            obs_encoder_group_norm: replace every ``nn.BatchNorm2d`` in the
                encoder with ``nn.GroupNorm``.
            eval_fixed_crop: replace robomimic ``CropRandomizer`` modules with
                the ``dmvc.CropRandomizer`` variant.
            **kwargs: stored as ``self.kwargs`` and forwarded by the
                ``predict_action*`` methods to the sampler.

        Raises:
            RuntimeError: if an observation declares a type other than
                ``'rgb'`` or ``'low_dim'``.
        """
        super().__init__()
        self.total_epochs = 100
        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            obs_key_shapes[key] = list(attr['shape'])

            # named obs_type so the builtin ``type`` is not shadowed
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                obs_config['rgb'].append(key)
            elif obs_type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # get raw robomimic config; it is only used to construct the
        # observation encoder extracted further below
        config = get_robomimic_config(
            algo_name='bc_rnn',
            hdf5_type='image',
            task_name='square',
            dataset_type='ph')

        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                # disable random cropping entirely
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for key, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state
        ObsUtils.initialize_obs_utils_with_config(config)

        # load model
        policy: PolicyAlgo = algo_factory(
            algo_name=config.algo_name,
            config=config,
            obs_key_shapes=obs_key_shapes,
            ac_dim=action_dim,
            device='cpu',
        )

        obs_encoder = policy.nets['policy'].nets['encoder'].nets['obs']

        if obs_encoder_group_norm:
            # replace batch norm with group norm; clamp to at least one group
            # so layers with fewer than 16 channels do not request 0 groups
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=max(1, x.num_features // 16),
                    num_channels=x.num_features)
            )

        if eval_fixed_crop:
            # swap robomimic's crop randomizer for the project's own variant
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
                func=lambda x: dmvc.CropRandomizer(
                    input_shape=x.input_shape,
                    crop_height=x.crop_height,
                    crop_width=x.crop_width,
                    num_crops=x.num_crops,
                    pos_enc=x.pos_enc
                )
            )

        # create diffusion model
        obs_feature_dim = obs_encoder.output_shape()[0]
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            # observations enter through the global condition, so the
            # diffused trajectory only carries actions
            input_dim = action_dim
            global_cond_dim = obs_feature_dim * n_obs_steps

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs

        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

        print("Diffusion params: %e" % sum(p.numel() for p in self.model.parameters()))
        print("Vision params: %e" % sum(p.numel() for p in self.obs_encoder.parameters()))

    # ========= inference  ============
    def conditional_sample(self,
                           condition_data, condition_mask,
                           local_cond=None, global_cond=None,
                           generator=None,
                           # keyword arguments to scheduler.step
                           **kwargs
                           ):
        """Run the reverse diffusion loop with inpainting-style conditioning.

        Starts from Gaussian noise with the shape, dtype and device of
        ``condition_data`` and, for every scheduler timestep, overwrites the
        entries selected by ``condition_mask`` with ``condition_data`` before
        querying the model and stepping the scheduler.

        Args:
            condition_data: tensor holding the values to enforce.
            condition_mask: boolean mask selecting the enforced entries.
            local_cond, global_cond: forwarded to the model.
            generator: optional ``torch.Generator`` used for the initial noise
                and passed to ``scheduler.step``.
            **kwargs: extra keyword arguments for ``scheduler.step``.

        Returns:
            The final sample, with the conditioned entries enforced once more.
        """
        denoiser = self.model
        sched = self.noise_scheduler

        # initial sample: pure Gaussian noise
        sample = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)

        sched.set_timesteps(self.num_inference_steps)

        for step in sched.timesteps:
            # enforce the known entries before every model call
            sample[condition_mask] = condition_data[condition_mask]

            model_out = denoiser(sample, step,
                                 local_cond=local_cond, global_cond=global_cond)

            # x_t -> x_{t-1}
            step_result = sched.step(
                model_out, step, sample,
                generator=generator,
                **kwargs
            )
            sample = step_result.prev_sample

        # the last scheduler step may have overwritten the conditioned entries
        sample[condition_mask] = condition_data[condition_mask]

        return sample


    def guide_mse_conditional_sample_ddpm(self,
                                          condition_data, condition_mask,
                                          local_cond=None, global_cond=None,
                                          stu_traj=None, epoch=0,
                                          generator=None,
                                          eta: float = 0.0,
                                          grad_scale: float = 10.0,
                                          **kwargs
                                          ):
        """DDPM ancestral sampling with an MSE guidance term towards ``stu_traj``.

        Each reverse step computes the usual DDPM posterior mean from the model
        output and then shifts it by
        ``-grad_scale * sigma_t * d MSE(x_t[:, start:end], stu_traj) / d x_t``,
        where ``start = n_obs_steps - 1`` and ``end = start + stu_traj.shape[1]``.

        Args:
            condition_data: values enforced on the trajectory at every step.
            condition_mask: boolean mask selecting the enforced entries.
            local_cond, global_cond: forwarded to the model.
            stu_traj: guidance target; must be broadcast-compatible with
                ``trajectory[:, start:end]``. ``None`` disables guidance, which
                reduces this method to plain DDPM sampling.
            epoch: accepted for interface compatibility; currently unused.
            generator: optional ``torch.Generator`` for all sampled noise.
            eta: accepted for interface compatibility; currently unused.
            grad_scale: multiplier applied to the guidance gradient.
            **kwargs: accepted and ignored.

        Returns:
            The sampled trajectory with the conditioned entries enforced.

        Raises:
            ValueError: if the scheduler's ``prediction_type`` is not one of
                ``epsilon``, ``sample`` or ``v_prediction``.
        """
        model = self.model
        scheduler = self.noise_scheduler  # expected to be a DDPM-style scheduler

        # start from pure noise x_T
        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator
        )

        scheduler.set_timesteps(self.num_inference_steps)

        # time window of the trajectory that is pulled towards stu_traj
        start = self.n_obs_steps - 1
        guide_target = None
        end = start
        if stu_traj is not None:
            # the target is a constant for guidance purposes
            guide_target = stu_traj.detach()
            end = start + guide_target.shape[1]

        # scheduler.timesteps are already expressed in training-timestep
        # units, so each t indexes alphas_cumprod directly (no rescaling).
        for t in scheduler.timesteps:
            # step 1: enforce conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # step 2: model prediction at (x_t, t)
            model_output = model(
                trajectory, t,
                local_cond=local_cond,
                global_cond=global_cond
            )

            # step 3: gradient of the MSE guidance loss w.r.t. x_t
            grad_full = torch.zeros_like(trajectory)
            if guide_target is not None:
                with torch.enable_grad():
                    # differentiate on a detached leaf so the sampled
                    # trajectory itself never carries autograd state
                    traj_leaf = trajectory.detach().requires_grad_(True)
                    mse_loss = F.mse_loss(
                        traj_leaf[:, start:end], guide_target, reduction='mean')
                    grad = torch.autograd.grad(
                        outputs=mse_loss,
                        inputs=traj_leaf,
                        create_graph=False,
                        retain_graph=False
                    )[0]
                # keep only the guided window
                grad_full[:, start:end] = grad[:, start:end]

            # step 4: DDPM posterior mean (formulas (7) and (15) of
            # https://arxiv.org/pdf/2006.11239.pdf)
            sample = trajectory
            prev_t = scheduler.previous_timestep(t)

            alpha_prod_t = scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.one
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev
            current_alpha_t = alpha_prod_t / alpha_prod_t_prev
            current_beta_t = 1 - current_alpha_t

            # predicted x_0
            if scheduler.config.prediction_type == "epsilon":
                pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            elif scheduler.config.prediction_type == "sample":
                pred_original_sample = model_output
            elif scheduler.config.prediction_type == "v_prediction":
                pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
                    " `v_prediction`  for the DDPMScheduler."
                )

            # clip or threshold predicted x_0
            if scheduler.config.thresholding:
                pred_original_sample = scheduler._threshold_sample(pred_original_sample)
            elif scheduler.config.clip_sample:
                pred_original_sample = pred_original_sample.clamp(
                    -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
                )

            # coefficients of x_0 and x_t in the posterior mean
            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
            current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
            pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

            # posterior standard deviation
            variance = scheduler._get_variance(t, variance_type='fixed_small')
            sigma_theta = torch.sqrt(variance)

            # descend the MSE loss, hence the negative gradient
            guide_term = grad_scale * sigma_theta * (-grad_full)
            mu_guided = pred_prev_sample + guide_term

            # step 5: ancestral noise, omitted on the final step (t == 0)
            if t > 0:
                noise = torch.randn(
                    trajectory.size(),
                    dtype=trajectory.dtype,
                    device=trajectory.device,
                    generator=generator
                )
            else:
                noise = torch.zeros_like(trajectory)

            # step 6: x_{t-1} ~ N(mu_guided, sigma_theta^2)
            trajectory = mu_guided + sigma_theta * noise

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]

        return trajectory
    def guide_mse_conditional_sample_ddim(self,
                                     condition_data, condition_mask,
                                     local_cond=None, global_cond=None,
                                     stu_traj=None, epoch=0,
                                     generator=None,
                                     grad_scale: float = 25.0,
                                     **kwargs
                                     ):
        """Deterministic DDIM sampling with MSE guidance towards ``stu_traj``.

        The guidance gradient is folded into the predicted noise, following
        Algorithm 2 of https://arxiv.org/pdf/2105.05233:
        ``eps_hat = eps - grad_scale * sqrt(1 - alpha_bar_t) * grad``, where
        ``grad`` is the gradient of ``MSE(x_t[:, start:end], stu_traj)`` with
        ``start = n_obs_steps - 1`` and ``end = start + stu_traj.shape[1]``.
        ``eta`` is fixed to 0, so no noise is injected between steps.

        The scheduler must expose the DDIM attributes used below
        (``final_alpha_cumprod`` and ``_get_variance(timestep, prev_timestep)``).

        Args:
            condition_data: values enforced on the trajectory at every step.
            condition_mask: boolean mask selecting the enforced entries.
            local_cond, global_cond: forwarded to the model.
            stu_traj: guidance target; ``None`` disables guidance.
            epoch: accepted for interface compatibility; currently unused.
            generator: optional ``torch.Generator`` for the initial noise.
            grad_scale: multiplier applied to the guidance gradient.
            **kwargs: accepted and ignored.

        Returns:
            The sampled trajectory with the conditioned entries enforced.
        """
        model = self.model
        scheduler = self.noise_scheduler

        trajectory = torch.randn(
            size=condition_data.shape,
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)

        # set step values
        scheduler.set_timesteps(self.num_inference_steps)
        # distance, in training timesteps, between consecutive sampling steps
        step_ratio = scheduler.config.num_train_timesteps // scheduler.num_inference_steps

        # time window of the trajectory that is pulled towards stu_traj
        start = self.n_obs_steps - 1
        guide_target = None
        end = start
        if stu_traj is not None:
            guide_target = stu_traj.detach()
            end = start + guide_target.shape[1]

        eta = 0  # deterministic DDIM

        for timestep in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(trajectory, timestep,
                                 local_cond=local_cond, global_cond=global_cond)

            # 3. previous timestep and cumulative alphas
            prev_timestep = timestep - step_ratio
            alpha_prod_t = scheduler.alphas_cumprod[timestep]
            alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
            beta_prod_t = 1 - alpha_prod_t
            sample = trajectory

            # 4. guidance gradient, restricted to the guided window
            grad_full = torch.zeros_like(sample)
            if guide_target is not None:
                with torch.enable_grad():
                    # differentiate on a detached leaf so the sampled
                    # trajectory is never left with requires_grad=True
                    sample_leaf = sample.detach().requires_grad_(True)
                    dist = F.mse_loss(
                        sample_leaf[:, start:end], guide_target, reduction='mean')
                    grad = torch.autograd.grad([dist.sum()], [sample_leaf])[0]
                grad_full[:, start:end] = grad[:, start:end]

            gradient_guided = beta_prod_t ** (0.5) * grad_full
            # detach so no autograd history is carried across sampling steps
            pred_epsilon = (model_output - gradient_guided * grad_scale).detach()

            # 5. DDIM update with the guided noise estimate
            variance = scheduler._get_variance(timestep, prev_timestep)
            std_dev_t = eta * variance ** (0.5)

            pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon

            pred_original_sample = (sample - beta_prod_t ** (0.5) * pred_epsilon) / alpha_prod_t ** (0.5)

            # 6. clip or threshold "predicted x_0"
            if scheduler.config.thresholding:
                pred_original_sample = scheduler._threshold_sample(pred_original_sample)
            elif scheduler.config.clip_sample:
                pred_original_sample = pred_original_sample.clamp(
                    -scheduler.config.clip_sample_range, scheduler.config.clip_sample_range
                )

            trajectory = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Sample an action sequence for the given observations.

        ``obs_dict`` is normalized, its first ``n_obs_steps`` steps are encoded
        and used as conditioning (either as a global condition or by
        inpainting, depending on ``obs_as_global_cond``), and a trajectory is
        drawn with ``conditional_sample``.

        Returns a dict with:
            ``action``: steps ``[n_obs_steps - 1, n_obs_steps - 1 + n_action_steps)``
                of the unnormalized prediction,
            ``action_pred``: the full unnormalized prediction,
            ``obs_features``: the encoder output used for conditioning.
        """
        assert 'past_action' not in obs_dict  # not implemented yet

        nobs = self.normalizer.normalize(obs_dict)
        first_obs = next(iter(nobs.values()))
        B, _ = first_obs.shape[:2]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        device = self.device
        dtype = self.dtype

        def _flatten_obs_window(x):
            # keep the first To steps and merge batch and time axes
            return x[:, :To, ...].reshape(-1, *x.shape[2:])

        # both conditioning modes encode the same observation window
        this_nobs = dict_apply(nobs, _flatten_obs_window)
        nobs_features = self.obs_encoder(this_nobs)

        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # condition through a flat (B, To * Do) global feature;
            # nothing in the trajectory itself is fixed
            global_cond = nobs_features.reshape(B, -1)
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting: observation features occupy the
            # channels after the action dims for the first To steps
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True

        nsample = self.conditional_sample(
            cond_data,
            cond_mask,
            local_cond=local_cond,
            global_cond=global_cond,
            **self.kwargs)

        # keep the action channels and map them back to the original scale
        action_pred = self.normalizer['action'].unnormalize(nsample[..., :Da])

        # executed window starts at the last observed step
        first_step = To - 1
        last_step = first_step + self.n_action_steps

        return {
            'action': action_pred[:, first_step:last_step],
            'action_pred': action_pred,
            'obs_features': nobs_features
        }

    def predict_action_mse_guide(self, obs_dict: Dict[str, torch.Tensor], stu_sample, epoch=1) -> Dict[str, torch.Tensor]:
        """Sample an action sequence with MSE guidance towards ``stu_sample``.

        Builds the same conditioning as unguided prediction, then samples with
        ``guide_mse_conditional_sample_ddpm``, passing ``stu_sample`` as the
        guidance target and ``epoch`` through unchanged.

        Returns a dict with:
            ``action``: steps ``[n_obs_steps - 1, n_obs_steps - 1 + n_action_steps)``
                of the unnormalized prediction,
            ``action_pred``: the full unnormalized prediction,
            ``obs_features``: the encoder output used for conditioning.
        """
        assert 'past_action' not in obs_dict  # not implemented yet

        nobs = self.normalizer.normalize(obs_dict)
        first_obs = next(iter(nobs.values()))
        B, _ = first_obs.shape[:2]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        device = self.device
        dtype = self.dtype

        # encode the first To observation steps with batch and time merged
        this_nobs = dict_apply(nobs, lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)

        local_cond = None
        global_cond = None
        if not self.obs_as_global_cond:
            # inpainting: observation features are written into the channels
            # after the action dims for the first To steps and held fixed
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True
        else:
            # global conditioning: flat per-sample feature, empty mask
            global_cond = nobs_features.reshape(B, -1)
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)

        nsample = self.guide_mse_conditional_sample_ddpm(
            cond_data,
            cond_mask,
            local_cond=local_cond,
            global_cond=global_cond,
            stu_traj=stu_sample,
            epoch=epoch, **self.kwargs)

        # keep the action channels and map them back to the original scale
        action_pred = self.normalizer['action'].unnormalize(nsample[..., :Da])

        # executed window starts at the last observed step
        first_step = To - 1
        last_step = first_step + self.n_action_steps

        return {
            'action': action_pred[:, first_step:last_step],
            'action_pred': action_pred,
            'obs_features': nobs_features
        }

    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Load the state of ``normalizer`` into this policy's own normalizer."""
        source_state = normalizer.state_dict()
        self.normalizer.load_state_dict(source_state)

    def compute_loss(self, batch):
        """Compute the masked diffusion training loss for one batch.

        Args:
            batch: dict with ``'obs'`` (dict of observation tensors) and
                ``'action'``; must not contain ``'valid_mask'``.

        Returns:
            Tuple ``(loss, pred, obs_features)`` where ``loss`` is the scalar
            masked MSE, ``pred`` is the raw model output and ``obs_features``
            is the encoder output: flattened to ``(batch, -1)`` when
            ``obs_as_global_cond`` is set, otherwise shaped
            ``(batch, horizon, -1)``.

        Raises:
            ValueError: if the scheduler's ``prediction_type`` is neither
                ``'epsilon'`` nor ``'sample'``.
        """
        # normalize input
        assert 'valid_mask' not in batch
        nobs = self.normalizer.normalize(batch['obs'])
        nactions = self.normalizer['action'].normalize(batch['action'])
        batch_size = nactions.shape[0]
        horizon = nactions.shape[1]

        # handle different ways of passing observation
        local_cond = None
        global_cond = None
        trajectory = nactions
        cond_data = trajectory
        if self.obs_as_global_cond:
            # merge batch and time axes for the first n_obs_steps steps
            this_nobs = dict_apply(nobs,
                                   lambda x: x[:, :self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # one flat feature vector per sample
            global_cond = nobs_features.reshape(batch_size, -1)
            output_obs_features = global_cond
        else:
            # merge batch and time axes for every step
            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
            nobs_features = self.obs_encoder(this_nobs)
            # restore the batch and time axes
            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
            cond_data = torch.cat([nactions, nobs_features], dim=-1)
            trajectory = cond_data.detach()
            # this branch must also define the returned features; without it
            # the return statement fails with UnboundLocalError
            output_obs_features = nobs_features

        # generate inpainting mask
        condition_mask = self.mask_generator(trajectory.shape)

        # sample the noise to add and one random timestep per sample
        noise = torch.randn(trajectory.shape, device=trajectory.device)
        bsz = trajectory.shape[0]
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (bsz,), device=trajectory.device
        ).long()
        # forward diffusion: noise the clean trajectory at the drawn timesteps
        noisy_trajectory = self.noise_scheduler.add_noise(
            trajectory, noise, timesteps)

        # only unconditioned entries contribute to the loss
        loss_mask = ~condition_mask

        # apply conditioning
        noisy_trajectory[condition_mask] = cond_data[condition_mask]

        # model prediction for the noised trajectory
        pred = self.model(noisy_trajectory, timesteps,
                          local_cond=local_cond, global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = trajectory
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target, reduction='none')
        loss = loss * loss_mask.type(loss.dtype)
        # mean over all non-batch axes, then over the batch
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        loss = loss.mean()
        return loss, pred, output_obs_features



