import os
from typing import Dict
import math

import yaml
from dino_wm.planner import Planner
from dino_wm.planner_pusht import PushTPlanner
import hydra
import cv2
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.dataset.robomimic_image_dynamics_dataset import RobomimicMultiStepWithHistoryImageDynamicsModelDataset
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

from einops import rearrange, reduce
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
from diffusion_policy.common.robomimic_config_util import get_robomimic_config
from robomimic.algo import algo_factory
from robomimic.algo.algo import PolicyAlgo
import robomimic.utils.obs_utils as ObsUtils
import robomimic.models.obs_core as rmbn
import diffusion_policy.model.vision.crop_randomizer as dmvc
from diffusion_policy.common.pytorch_util import dict_apply, replace_submodules
from diffusion_policy.dataset.pusht_image_dynamics_dataset import DynamicsModelDataset, MultiStepDynamicsModelDataset, MultiStepImageDynamicsModelDataset, MultiStepWithHistoryImageDynamicsModelDataset
from diffusion_policy.model.OOD.knn import KNN_torch, MahalanobisOODModule
from diffusion_policy.model.OOD.svm import SVMOODModule, train_ocsvm
# from diffusion_policy.model.dynamics.dynamics_model import CVAE, HybridDynamicsModel, ImageBasedDynamicsModel, StateBasedDynamicsModel, StateBasedEncoderDecoderDynamicsModel, StateEncodeBasedDynamicsModel
# from diffusion_policy.model.dynamics.dynamics_model_transformer import FullModel, FullModelHistory
# from vit_dino_wm_skip_multiview import WorldModel

def boundary_penalty(action, lower_bound=-1.0, upper_bound=1.0):
    """Return the total amount by which ``action`` leaves ``[lower_bound, upper_bound]``.

    Elements inside the interval contribute nothing. An element above
    ``upper_bound`` contributes ``action - upper_bound``, and an element below
    ``lower_bound`` contributes ``lower_bound - action``.

    Args:
        action: Tensor of action values; the penalty is applied elementwise.
        lower_bound: Lower edge of the allowed interval.
        upper_bound: Upper edge of the allowed interval.

    Returns:
        A 0-dim tensor holding the summed out-of-bounds magnitude.
    """
    overshoot = torch.relu(action - upper_bound)
    undershoot = torch.relu(lower_bound - action)
    return (overshoot + undershoot).sum()


class DiffusionUnetHybridImagePolicy(BaseImagePolicy):
    def __init__(self, 
            shape_meta: dict,
            noise_scheduler: DDPMScheduler,
            horizon, 
            n_action_steps, 
            n_obs_steps,
            num_inference_steps=None,
            obs_as_global_cond=True,
            crop_shape=(76, 76),
            diffusion_step_embed_dim=256,
            down_dims=(256,512,1024),
            kernel_size=5,
            n_groups=8,
            cond_predict_scale=True,
            obs_encoder_group_norm=False,
            eval_fixed_crop=False,
            # parameters passed to step
            **kwargs):
        """Build the robomimic observation encoder and the conditional 1D UNet.

        Args:
            shape_meta: Dict with ``'action'`` (whose ``'shape'`` must have
                exactly one entry, the action dimension) and ``'obs'`` (mapping
                each observation key to a dict with ``'shape'`` and an optional
                ``'type'`` of ``'rgb'`` or ``'low_dim'``, default ``'low_dim'``).
            noise_scheduler: DDPM scheduler stored on ``self.noise_scheduler``.
            horizon: Stored on ``self.horizon``.
            n_action_steps: Stored on ``self.n_action_steps``.
            n_obs_steps: Number of observation steps. The global-conditioning size
                is ``obs_feature_dim * n_obs_steps``, and this value is also the
                mask generator's ``max_n_obs_steps``.
            num_inference_steps: Denoising steps. Defaults to the scheduler's
                ``config.num_train_timesteps``.
            obs_as_global_cond: If True, encoded observations become the UNet's
                global conditioning and the UNet input is just the action.
                Otherwise the UNet input is action + observation features.
            crop_shape: ``(height, width)`` applied to every ``CropRandomizer`` in
                the robomimic encoder config. ``None`` removes those randomizers.
            diffusion_step_embed_dim, down_dims, kernel_size, n_groups,
            cond_predict_scale: Forwarded to ``ConditionalUnet1D``.
            obs_encoder_group_norm: Replace every ``nn.BatchNorm2d`` in the
                encoder with ``nn.GroupNorm(num_features // 16, num_features)``.
            eval_fixed_crop: Replace robomimic ``CropRandomizer`` modules with the
                ``diffusion_policy.model.vision.crop_randomizer`` implementation.
            **kwargs: Stored on ``self.kwargs`` (the original comment says they
                are "parameters passed to step").

        Raises:
            AssertionError: If the action shape does not have exactly one entry.
            RuntimeError: If an observation type other than ``'rgb'`` or
                ``'low_dim'`` is given.
        """
        super().__init__()

        # parse shape_meta
        action_shape = shape_meta['action']['shape']
        assert len(action_shape) == 1
        action_dim = action_shape[0]
        obs_shape_meta = shape_meta['obs']
        obs_config = {
            'low_dim': [],
            'rgb': [],
            'depth': [],
            'scan': []
        }
        obs_key_shapes = dict()
        for key, attr in obs_shape_meta.items():
            obs_key_shapes[key] = list(attr['shape'])

            # ``obs_type`` (not ``type``) so the builtin is not shadowed.
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                obs_config['rgb'].append(key)
            elif obs_type == 'low_dim':
                obs_config['low_dim'].append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # get raw robomimic config
        config = get_robomimic_config(
            algo_name='bc_rnn',
            hdf5_type='image',
            task_name='square',
            dataset_type='ph')
        
        with config.unlocked():
            # set config with shape_meta
            config.observation.modalities.obs = obs_config

            if crop_shape is None:
                # Disable random cropping entirely.
                for _, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality['obs_randomizer_class'] = None
            else:
                # set random crop parameter
                ch, cw = crop_shape
                for _, modality in config.observation.encoder.items():
                    if modality.obs_randomizer_class == 'CropRandomizer':
                        modality.obs_randomizer_kwargs.crop_height = ch
                        modality.obs_randomizer_kwargs.crop_width = cw

        # init global state
        ObsUtils.initialize_obs_utils_with_config(config)

        # load model; only its observation encoder is kept below
        policy: PolicyAlgo = algo_factory(
                algo_name=config.algo_name,
                config=config,
                obs_key_shapes=obs_key_shapes,
                ac_dim=action_dim,
                device='cpu',
            )

        obs_encoder = policy.nets['policy'].nets['encoder'].nets['obs']
        
        if obs_encoder_group_norm:
            # replace batch norm with group norm
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, nn.BatchNorm2d),
                func=lambda x: nn.GroupNorm(
                    num_groups=x.num_features//16, 
                    num_channels=x.num_features)
            )
        
        if eval_fixed_crop:
            # swap robomimic's CropRandomizer for diffusion_policy's version
            replace_submodules(
                root_module=obs_encoder,
                predicate=lambda x: isinstance(x, rmbn.CropRandomizer),
                func=lambda x: dmvc.CropRandomizer(
                    input_shape=x.input_shape,
                    crop_height=x.crop_height,
                    crop_width=x.crop_width,
                    num_crops=x.num_crops,
                    pos_enc=x.pos_enc
                )
            )

        # create diffusion model
        obs_feature_dim = obs_encoder.output_shape()[0]
        input_dim = action_dim + obs_feature_dim
        global_cond_dim = None
        if obs_as_global_cond:
            input_dim = action_dim
            global_cond_dim = obs_feature_dim * n_obs_steps

        model = ConditionalUnet1D(
            input_dim=input_dim,
            local_cond_dim=None,
            global_cond_dim=global_cond_dim,
            diffusion_step_embed_dim=diffusion_step_embed_dim,
            down_dims=down_dims,
            kernel_size=kernel_size,
            n_groups=n_groups,
            cond_predict_scale=cond_predict_scale
        )

        self.obs_encoder = obs_encoder
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.mask_generator = LowdimMaskGenerator(
            action_dim=action_dim,
            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
            max_n_obs_steps=n_obs_steps,
            fix_obs_steps=True,
            action_visible=False
        )
        self.normalizer = LinearNormalizer()
        self.dynamics_model_normalizer = LinearNormalizer()
        self.horizon = horizon
        self.obs_feature_dim = obs_feature_dim
        self.action_dim = action_dim
        self.n_action_steps = n_action_steps
        self.n_obs_steps = n_obs_steps
        self.obs_as_global_cond = obs_as_global_cond
        self.kwargs = kwargs
        if num_inference_steps is None:
            num_inference_steps = noise_scheduler.config.num_train_timesteps
        self.num_inference_steps = num_inference_steps

        print("Diffusion params: %e" % sum(p.numel() for p in self.model.parameters()))
        print("Vision params: %e" % sum(p.numel() for p in self.obs_encoder.parameters()))
    
    # deprecated
    def set_ood_quantification(self, 
                               demo_data_file_path, 
                               demo_data_val_ratio, 
                               play_data_file_path,
                               dynamics_model_type, 
                               dynamics_model_ckpt, 
                               ood_method='vanila_nn'):
        """Fit a KNN OOD detector on latents from a single-step dynamics model.

        Deprecated. NOTE(review): the dynamics model classes used here
        (``ImageBasedDynamicsModel``, ``StateBasedDynamicsModel``,
        ``StateBasedEncoderDecoderDynamicsModel``, ``StateEncodeBasedDynamicsModel``,
        ``CVAE``, ``HybridDynamicsModel``) come from an import that is commented
        out at the top of the file. Unless they are defined elsewhere in the
        module, calling this raises ``NameError``.

        Args:
            demo_data_file_path: Demo dataset whose current observations fill the
                KNN bank. Its normalizer also becomes ``self.normalizer``.
            demo_data_val_ratio: ``val_ratio`` for the demo dataset.
            play_data_file_path: Dataset whose normalizer becomes
                ``self.dynamics_model_normalizer``.
            dynamics_model_type: ``'image'``, ``'state'``,
                ``'state_encoder_decoder'``, ``'state_encoder'`` or ``'hybrid'``.
            dynamics_model_ckpt: Checkpoint path. A ``config.yaml`` must exist in
                the same directory.
            ood_method: ``method`` argument forwarded to ``KNN_torch``.

        Raises:
            FileNotFoundError: If ``config.yaml`` is missing.
            ValueError: For an unknown ``dynamics_model_type``, when the
                ``'state_encoder'`` config lacks ``hidden_dim``/``wandb_dir``,
                or when the demo dataset is empty.
        """
        import warnings
        warnings.warn("set_ood_quantification is deprecated", DeprecationWarning, stacklevel=2)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        dynamics_dataset = DynamicsModelDataset(play_data_file_path, val_ratio=0.0, random=False)
        demo_dataset = DynamicsModelDataset(demo_data_file_path, val_ratio=demo_data_val_ratio, random=False)
        # for convenience the policy normalizer is taken from the demo data
        self.normalizer = demo_dataset.get_normalizer().to(device)
        train_loader = DataLoader(demo_dataset, batch_size=32, shuffle=False, num_workers=4)

        normalizer = dynamics_dataset.get_normalizer()
        print('Loaded dynamics_model_normalizer max', normalizer.params_dict['state']['input_stats']['max'])
        print('Loaded dynamics_model_normalizer min', normalizer.params_dict['state']['input_stats']['min'])
        self.dynamics_model_normalizer = normalizer.to(device)

        dynamics_config_path = os.path.join(os.path.dirname(dynamics_model_ckpt), 'config.yaml')
        print('dynamics_config_path ', dynamics_config_path)
        if not os.path.exists(dynamics_config_path):
            raise FileNotFoundError(f"{dynamics_config_path} does not exist!")
        with open(dynamics_config_path, 'r') as f:
            # safe_load returns None for an empty file
            dynamics_config = yaml.safe_load(f) or {}
        # hidden_dim / wandb_dir are read from the ``_content.value`` section.
        config_values = dynamics_config.get('_content', {}).get('value', {})
        hidden_dim = config_values.get('hidden_dim')
        run_name = config_values.get('wandb_dir')

        self.dynamics_model_type = dynamics_model_type
        if dynamics_model_type == 'image':
            model = ImageBasedDynamicsModel(action_dim=2, latent_size=32)
        elif dynamics_model_type == 'state':
            model = StateBasedDynamicsModel(state_dim=6, action_dim=2, hidden_dim=32)
        elif dynamics_model_type == 'state_encoder_decoder':
            model = StateBasedEncoderDecoderDynamicsModel(state_dim=6, action_dim=2, hidden_dim=32)
        elif dynamics_model_type == 'state_encoder':
            # Only this branch needs the config values; validate them here
            # instead of crashing on int(None) for every model type.
            if hidden_dim is None or run_name is None:
                raise ValueError(
                    f"{dynamics_config_path} must define 'hidden_dim' and 'wandb_dir' "
                    "for the 'state_encoder' dynamics model")
            # a run directory name containing 'vae' selects the CVAE variant
            if 'vae' not in run_name:
                model = StateEncodeBasedDynamicsModel(state_dim=6, action_dim=2, hidden_dim=int(hidden_dim))
            else:
                model = CVAE(state_dim=6, action_dim=2, latent_dim=int(hidden_dim))
        elif dynamics_model_type == 'hybrid':
            model = HybridDynamicsModel(state_dim=6, action_dim=2, latent_size=32)
        else:
            raise ValueError(f"Invalid model_type: {dynamics_model_type}")
        self.dynamics_model = model.to(device)

        checkpoint = torch.load(dynamics_model_ckpt, map_location=device)
        self.dynamics_model.load_state_dict(checkpoint['model_state_dict'])
        # Switch to eval mode BEFORE encoding so dropout/batch-norm behave
        # deterministically and batch-norm running stats are not updated.
        self.dynamics_model.eval()

        train_latents = []
        images = []
        with torch.no_grad():
            for batch in train_loader:
                img_t = batch['o_t']['image'].to(device)
                images.append(img_t)
                img_t = self.dynamics_model_normalizer['image'].normalize(img_t)

                state_t = batch['o_t']['state'].to(device)
                state_t = self.dynamics_model_normalizer['state'].normalize(state_t)

                if dynamics_model_type == 'image':
                    z_t = self.dynamics_model.encode(img_t)
                elif 'state' in dynamics_model_type:
                    z_t = self.dynamics_model.encode(state_t)
                else:
                    # 'hybrid' -- every other type was rejected above
                    z_t = self.dynamics_model.encode(img_t, state_t)

                # encoders may return a tuple; its first element is the latent
                if isinstance(z_t, tuple):
                    z_t = z_t[0]
                train_latents.append(z_t)

        if not train_latents:
            raise ValueError(f"demo dataset {demo_data_file_path} produced no samples")
        train_latents = torch.cat(train_latents, dim=0)
        images = torch.cat(images, dim=0)

        self.ood_module = KNN_torch(train_latents=train_latents, images=images, method=ood_method).to(device)


    def set_ood_quantification_h_step(self, 
                               demo_data_file_path, 
                               demo_data_val_ratio,
                               dynamics_model_ckpt, 
                               use_embed=False,
                               ood_method='vanila_nn',
                               temp=0.1,
                               use_history=False,
                               early_stop=False,
                               early_stop_threshold=0.08):
        """Fit a KNN OOD detector on latents of the multi-step dynamics model.

        Reads ``config.yaml`` next to ``dynamics_model_ckpt``. Its ``dataset`` and
        ``hyperparameters`` sections (each under ``value``) are used to rebuild
        the dynamics model and its training dataset. The dataset is loaded only
        for its normalizer. After loading the checkpoint, the current frame of
        every demo sample is encoded into a latent bank for ``KNN_torch``.

        NOTE(review): ``FullModel`` / ``FullModelHistory`` come from an import
        that is commented out at the top of the file. Confirm they are in scope.

        Args:
            demo_data_file_path: Demo dataset whose frames fill the KNN bank. Its
                normalizer also becomes ``self.normalizer``.
            demo_data_val_ratio: ``val_ratio`` for the demo dataset.
            dynamics_model_ckpt: Checkpoint path. ``config.yaml`` must sit beside it.
            use_embed: Without history, additionally pass latents through
                ``dynamics_model.dynamics_model.state_emb``.
            ood_method: ``method`` forwarded to ``KNN_torch``.
            temp: ``temp`` forwarded to ``KNN_torch``.
            use_history: Use the history-conditioned dataset/model. The current
                frame is then the last entry of ``history_imgs``.
            early_stop: Stored on ``self.early_stop``.
            early_stop_threshold: Stored on ``self.early_stop_threshold``.

        Raises:
            FileNotFoundError: If ``config.yaml`` is missing.
            ValueError: If the demo dataset yields no samples.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        dynamics_config_path = os.path.join(os.path.dirname(dynamics_model_ckpt), 'config.yaml')
        if not os.path.exists(dynamics_config_path):
            raise FileNotFoundError(f"{dynamics_config_path} does not exist!")
        with open(dynamics_config_path, 'r') as f:
            # safe_load returns None for an empty file
            dynamics_config = yaml.safe_load(f) or {}

        dataset_param = dynamics_config.get('dataset', {}).get('value', {})
        hyper = dynamics_config.get('hyperparameters', {}).get('value', {})

        if use_history:
            dynamics_dataset = MultiStepWithHistoryImageDynamicsModelDataset(
                dataset_param['train_zarr_path'], horizon=hyper['horizon'], val_ratio=0.0,
                n_neg=dataset_param['n_neg'], neg_sampling=dataset_param['neg_sampling'])
            demo_dataset = MultiStepWithHistoryImageDynamicsModelDataset(
                demo_data_file_path, horizon=1, val_ratio=demo_data_val_ratio)
        else:
            dynamics_dataset = MultiStepImageDynamicsModelDataset(
                dataset_param['train_zarr_path'], horizon=hyper['horizon'], val_ratio=0.0)
            # NOTE: horizon=1 for normal execution; 8 was used for debugging and
            # visualizing prediction errors.
            demo_dataset = MultiStepImageDynamicsModelDataset(
                demo_data_file_path, horizon=1, val_ratio=demo_data_val_ratio)

        # for convenience the policy normalizer is taken from the demo data
        self.normalizer = demo_dataset.get_normalizer().to(device)
        train_loader = DataLoader(demo_dataset, batch_size=32, shuffle=False, num_workers=4)

        normalizer = dynamics_dataset.get_normalizer()
        print('Loaded dynamics_model_normalizer max', normalizer.params_dict['state']['input_stats']['max'])
        print('Loaded dynamics_model_normalizer min', normalizer.params_dict['state']['input_stats']['min'])
        self.dynamics_model_normalizer = normalizer.to(device)
        # the training dataset was only needed for its normalizer
        del dynamics_dataset

        if use_history:
            model = FullModelHistory(
                latent_dim=hyper['latent_dim'],
                action_dim=hyper['action_dim'],
                horizon=hyper['horizon'],
                n_layer=hyper['n_layer'],
                n_head=hyper['n_head'],
                n_emb=hyper['n_emb'],
                p_drop_emb=hyper['p_drop_emb'],
                p_drop_attn=hyper['p_drop_attn'],
                causal=hyper['causal'],
                pretrained_encoder=hyper['use_pretrained'],  # Use pretrained ResNet-18
                backbone_frozen=hyper['backbone_frozen'],
                decode=hyper.get('decode', False),
                history_encoder_only=hyper.get('history_encoder_only', False),
            )
        else:
            model = FullModel(
                latent_dim=hyper['latent_dim'],
                action_dim=hyper['action_dim'],
                horizon=hyper['horizon'],
                n_layer=hyper['n_layer'],
                n_head=hyper['n_head'],
                n_emb=hyper['n_emb'],
                p_drop_emb=hyper['p_drop_emb'],
                p_drop_attn=hyper['p_drop_attn'],
                causal=hyper['causal'],
                pretrained_encoder=hyper['use_pretrained'],  # Use pretrained ResNet-18
                backbone_frozen=hyper['backbone_frozen'],
                project_to_latent=hyper.get('project_to_latent', False),
                decode=hyper.get('decode', False),
                interleave=hyper.get('interleave', False),
                pred_action=hyper.get('pred_action', False),
            )
        self.dynamics_model = model.to(device)

        checkpoint = torch.load(dynamics_model_ckpt, map_location=device)
        self.dynamics_model.load_state_dict(checkpoint['model_state_dict'])
        self.dynamics_model.eval()

        image_normalizer = self.dynamics_model_normalizer['image']
        train_latents = []
        images = []
        with torch.no_grad():
            for batch in train_loader:
                if use_history:
                    # history_imgs is (batch, h+1, 3, 96, 96) per the dataset
                    # comments; the last entry is the current frame.
                    img_t = batch['history_imgs'].to(device)[:, -1, ...]
                else:
                    img_t = batch['img_t'].to(device)
                # raw (un-normalized) frames are kept on CPU for the KNN module
                images.append(img_t.cpu())
                z_t = self.dynamics_model.encode_obs(image_normalizer.normalize(img_t))
                if use_embed and not use_history:
                    z_t = self.dynamics_model.dynamics_model.state_emb(z_t)
                train_latents.append(z_t)

        if not train_latents:
            raise ValueError(f"demo dataset {demo_data_file_path} produced no samples")
        train_latents = torch.cat(train_latents, dim=0)
        print('train_latents shape ', train_latents.size())
        images = torch.cat(images, dim=0)

        self.ood_module = KNN_torch(train_latents=train_latents, 
                                    images=images, 
                                    method=ood_method, 
                                    temp=temp, 
                                    cont_loss=hyper['cont_loss'],
                                    ).to(device)
        self.early_stop = early_stop
        self.early_stop_threshold = early_stop_threshold

    def set_ood_quantification_h_step_robomimic(self, 
                               demo_dataset_config,
                               demo_data_file_path, 
                               demo_data_val_ratio,
                               dynamics_model_ckpt, 
                               ood_method='vanila_nn',
                               temp=0.1,
                               use_history=False,
                               early_stop=False,
                               early_stop_threshold=0.08,
                               nn_dataset_path='data/robomimic/datasets/square/mh/demo_v141_224_success_nn_computation.hdf5'):
        """Fit a KNN OOD detector on multi-view latents of the robomimic world model.

        The demo dataset is built with ``hydra.utils.instantiate(demo_dataset_config)``
        after ``demo_dataset_config.dataset_path`` is overwritten with
        ``nn_dataset_path``. This mutates the caller's config object, as before.
        The world model is rebuilt from the ``hyperparameters.value`` section of
        ``config.yaml`` beside the checkpoint. Its normalizer is loaded from
        ``normalizer.pth`` in the same directory.

        NOTE(review): ``demo_data_file_path``, ``demo_data_val_ratio`` and
        ``use_history`` are accepted but not used by this method. ``WorldModel``
        comes from an import that is commented out at the top of the file;
        confirm it is in scope.

        Args:
            demo_dataset_config: Hydra config for the demo dataset.
            dynamics_model_ckpt: World-model checkpoint path.
            ood_method: ``method`` forwarded to ``KNN_torch``.
            temp: ``temp`` forwarded to ``KNN_torch``.
            early_stop: Stored on ``self.early_stop``.
            early_stop_threshold: Stored on ``self.early_stop_threshold``.
            nn_dataset_path: Dataset path written into ``demo_dataset_config``.
                The default is the previously hard-coded path. ``None`` keeps the
                config's own ``dataset_path``.

        Raises:
            FileNotFoundError: If ``config.yaml`` is missing.
            ValueError: If the demo dataset yields no samples.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        ckpt_dir = os.path.dirname(dynamics_model_ckpt)

        dynamics_config_path = os.path.join(ckpt_dir, 'config.yaml')
        if not os.path.exists(dynamics_config_path):
            raise FileNotFoundError(f"{dynamics_config_path} does not exist!")
        with open(dynamics_config_path, 'r') as f:
            # safe_load returns None for an empty file
            dynamics_config = yaml.safe_load(f) or {}
        hyper = dynamics_config.get('hyperparameters', {}).get('value', {})

        if nn_dataset_path is not None:
            demo_dataset_config.dataset_path = nn_dataset_path
        demo_dataset: BaseImageDataset = hydra.utils.instantiate(demo_dataset_config)
        train_loader = DataLoader(demo_dataset, batch_size=32, shuffle=False, num_workers=4)

        dynamics_model_normalizer = LinearNormalizer()
        # load on CPU first so a normalizer saved from GPU also loads without CUDA
        dynamics_model_normalizer.load_state_dict(
            torch.load(os.path.join(ckpt_dir, 'normalizer.pth'), map_location='cpu'))
        self.dynamics_model_normalizer = dynamics_model_normalizer.to(device)

        self.dynamics_model = WorldModel(**hyper).to(device)
        checkpoint = torch.load(dynamics_model_ckpt, map_location=device)
        self.dynamics_model.load_state_dict(checkpoint['model_state_dict'])
        self.dynamics_model.eval()

        # world-model view name -> robomimic observation key
        view_to_obs_key = {
            'agentview': 'agentview_image',
            'robot0_eye_in_hand': 'robot0_eye_in_hand_image',
        }
        image_normalizer = self.dynamics_model_normalizer['image']
        train_latents = []
        images = []
        with torch.no_grad():
            for batch in train_loader:
                obs = batch['obs']
                # keep only the last step along dim 1 (the ``-1:`` slice keeps that axis)
                img_t = {
                    view: obs[obs_key][:, -1:, ...].to(device)
                    for view, obs_key in view_to_obs_key.items()
                }
                # raw (un-normalized) frames are kept on CPU for the KNN module
                images.append({view: tensor.cpu() for view, tensor in img_t.items()})

                img_t = {view: image_normalizer.normalize(tensor) for view, tensor in img_t.items()}
                z_t = self.dynamics_model.encode_obs(img_t)
                # presumably drops the singleton step axis kept by the slice above
                train_latents.append(z_t.squeeze(1))

        if not train_latents:
            raise ValueError("robomimic demo dataset produced no samples")
        train_latents = torch.cat(train_latents, dim=0)
        print('train_latents shape ', train_latents.size())
        images = {
            view: torch.cat([per_batch[view] for per_batch in images], dim=0)
            for view in images[0]
        }
        for view, stacked in images.items():
            print('images shape ', view, stacked.size())

        self.ood_module = KNN_torch(train_latents=train_latents, 
                                    images=images, 
                                    method=ood_method, 
                                    temp=temp, 
                                    cont_loss=hyper['cont_loss'],
                                    ).to(device)
        self.early_stop = early_stop
        self.early_stop_threshold = early_stop_threshold

    def initialize_robomimic_planner(self, 
                                     demo_dataset_config,
                                     dynamics_model_ckpt,
                                     decoder_path,
                                     value_func_path,
                                     action_step,
                                     output_dir,
                                     method,
                                     guidance_start_timestep,
                                     guidance_scale,
                                     threshold):
        """Attach a robomimic ``Planner`` and record the action-optimization settings.

        The planner receives ``self.normalizer['action']`` as its policy action
        normalizer. ``guidance_start_timestep`` and ``guidance_scale`` are stored
        only when ``method`` contains ``'classifier_guidance'``.
        ``self.ood_score_array`` is reset to an empty list.
        """
        planner = Planner(demo_dataset_config, dynamics_model_ckpt, decoder_path,
                          value_func_path, action_step, output_dir)
        self.planner = planner

        wants_guidance = 'classifier_guidance' in method
        if wants_guidance:
            self.guidance_start_timestep = guidance_start_timestep
            self.guidance_scale = guidance_scale

        planner.set_policy_action_normalizer(self.normalizer['action'])
        self.opt_method, self.threshold = method, threshold
        self.ood_score_array = []

    def initialize_pusht_planner(self,
                                 demo_dataset_config,
                                 dynamics_model_ckpt,
                                 decoder_path,
                                 value_func_path,
                                 action_step,
                                 output_dir,
                                 method,
                                 guidance_start_timestep,
                                 guidance_scale,
                                 threshold):
        """Attach a ``PushTPlanner`` and record the action-optimization settings.

        The planner receives ``self.normalizer['action']`` as its policy action
        normalizer. ``guidance_start_timestep`` and ``guidance_scale`` are stored
        only when ``method`` contains ``'classifier_guidance'``. Like
        ``initialize_robomimic_planner``, this resets ``self.ood_score_array``
        to an empty list, so code that records OOD scores works for both planners.
        """
        self.planner = PushTPlanner(demo_dataset_config, dynamics_model_ckpt, decoder_path, value_func_path, action_step, output_dir)
        if 'classifier_guidance' in method:
            self.guidance_start_timestep = guidance_start_timestep
            self.guidance_scale = guidance_scale
        self.planner.set_policy_action_normalizer(self.normalizer['action'])
        self.opt_method = method
        self.threshold = threshold
        # Previously only the robomimic initializer created this list.
        self.ood_score_array = []
        
    # ========= inference  ============
    def conditional_sample(self, 
            condition_data, condition_mask,
            local_cond=None, global_cond=None,
            generator=None,
            classifier_guidance=False,
            current_obs=None,
            frame_assembled_success=False,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Run the reverse-diffusion loop, optionally with planner-gradient guidance.

        Starts from Gaussian noise shaped like ``condition_data`` and denoises it
        over the timesteps set by ``self.num_inference_steps``. The entries
        selected by ``condition_mask`` are overwritten with ``condition_data``
        before every model call and once more at the end.

        Guidance ("variant1"): when ``classifier_guidance`` is True, the current
        cost (the negated ``self.planner.compute_current_reward(current_obs)``) is
        computed once. At every timestep with ``t < self.guidance_start_timestep``,
        while that cost is above ``self.threshold`` and ``frame_assembled_success``
        is False, the model output is replaced by
        ``model_output - guidance_scale * sqrt(1 - alphas_cumprod[t]) * grad``,
        where ``grad = self.planner.cond_fn(trajectory, current_obs)``.

        NOTE(review): ``trajectory`` is never detached between steps. If this is
        called with grad mode enabled and the model has trainable parameters, the
        autograd graph spans every denoising step. Most callers in this class wrap
        the call in ``torch.no_grad()``.

        Args:
            condition_data: tensor whose shape/dtype/device define the sample.
            condition_mask: boolean mask of entries copied from ``condition_data``.
            local_cond, global_cond: forwarded to ``self.model``.
            generator: optional torch RNG for the initial noise and ``scheduler.step``.
            classifier_guidance: enable planner-gradient guidance.
            current_obs: observation dict for the planner (read only when
                ``classifier_guidance`` is True).
            frame_assembled_success: when True, guidance is never applied.
            **kwargs: forwarded to ``scheduler.step``.

        Returns:
            The denoised tensor, same shape as ``condition_data``.
        """
        model = self.model
        scheduler = self.noise_scheduler

        # Initial sample: pure Gaussian noise with the conditioning tensor's shape/dtype/device.
        trajectory = torch.randn(
            size=condition_data.shape, 
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)
    
        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        if classifier_guidance:
            # Cost of the current observation (negated planner reward), evaluated
            # once per call; guidance below is gated on it exceeding self.threshold.
            current_cost = -1 * self.planner.compute_current_reward(current_obs)
            current_cost = current_cost.item()
            print('current_cost ', current_cost)
            # current_cost = 0
        
        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]

            # 2. predict model output
            model_output = model(trajectory, t, 
                local_cond=local_cond, global_cond=global_cond)

            # `current_cost` is only bound when classifier_guidance is True. The
            # leading `classifier_guidance and ...` short-circuits, so it is never
            # read otherwise.
            if classifier_guidance and t < self.guidance_start_timestep and current_cost > self.threshold and not frame_assembled_success:
                grad = self.planner.cond_fn(trajectory, current_obs)
                # # artificial gradient sanity check
                # grad = torch.ones_like(trajectory)
                guidance_scale = self.guidance_scale
                # Scale the planner gradient by sqrt(1 - alphas_cumprod[t]) before
                # subtracting it from the model output.
                grad_scale = guidance_scale * (1 - scheduler.alphas_cumprod[t]).sqrt()
                # print('grad_scale ', grad_scale)
                # if current_cost <= 50:
                #     grad_scale *= 0.001
                # if current_cost > 40:
                #     grad_scale = grad_scale * 0.5
                # if current_cost > 50:
                #     grad_scale = grad_scale * 0.5
                # if current_cost > 60:
                #     grad_scale = grad_scale * 0.5
                # if current_cost > 70:
                #     grad_scale = grad_scale * 0.5
                    
                model_output = model_output - grad_scale * grad

            # 3. compute previous image: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory, 
                generator=generator,
                **kwargs
                ).prev_sample

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]        

        return trajectory

    # ========= inference  ============
    def guided_conditional_sample(self, 
            condition_data, condition_mask,
            local_cond=None, global_cond=None,
            generator=None,
            classifier_guidance=False,
            current_obs=None,
            frame_assembled_success=False,
            # keyword arguments to scheduler.step
            **kwargs
            ):
        """Reverse-diffusion loop with guidance applied to the noisy sample ("variant2").

        Each step re-imposes the conditioning, then detaches ``trajectory`` and
        marks it as requiring grad before calling the model. When guidance is
        active (``classifier_guidance`` is True, ``t < self.guidance_start_timestep``
        and the current cost exceeds ``self.threshold``), the step does three things:

        1. Computes the scheduler's ``pred_original_sample`` from the model output.
        2. Evaluates ``loss = self.planner.compute_loss(x0_pred, current_obs)``.
        3. Shifts the noisy trajectory by
           ``+ guidance_scale * sqrt(1 - alphas_cumprod[t]) * (-d loss / d trajectory)``.

        ``scheduler.step`` is then applied using the model output computed on the
        pre-shift trajectory.

        Unlike ``conditional_sample``, ``frame_assembled_success`` is accepted but
        not used to gate guidance (the check is commented out).

        NOTE(review): when guidance is active this relies on
        ``torch.autograd.grad``. It must therefore run with grad mode enabled; under
        ``torch.no_grad()`` the loss has no graph and autograd raises. Without
        guidance (as in ``predict_action``) it is safe under ``no_grad``.

        Args:
            condition_data: tensor whose shape/dtype/device define the sample.
            condition_mask: boolean mask of entries copied from ``condition_data``.
            local_cond, global_cond: forwarded to ``self.model``.
            generator: optional torch RNG for the initial noise and the main
                ``scheduler.step``.
            classifier_guidance: enable planner-loss guidance.
            current_obs: observation dict for the planner (read only when
                guidance is enabled).
            frame_assembled_success: unused.
            **kwargs: forwarded to the main ``scheduler.step`` call.

        Returns:
            The denoised tensor, same shape as ``condition_data``.
        """
        # print('variant2')
        model = self.model
        scheduler = self.noise_scheduler

        # Initial sample: pure Gaussian noise with the conditioning tensor's shape/dtype/device.
        trajectory = torch.randn(
            size=condition_data.shape, 
            dtype=condition_data.dtype,
            device=condition_data.device,
            generator=generator)
    
        # set step values
        scheduler.set_timesteps(self.num_inference_steps)

        if classifier_guidance:
            # Cost of the current observation (negated planner reward), evaluated once per call.
            current_cost = -1 * self.planner.compute_current_reward(current_obs)
            current_cost = current_cost.item()
            # self.ood_score_array.append(current_cost)
            print('current_cost ', current_cost)
            # current_cost = 0
        
        for t in scheduler.timesteps:
            # 1. apply conditioning
            trajectory[condition_mask] = condition_data[condition_mask]
            # Fresh leaf each step so the guidance gradient is taken w.r.t. x_t only
            # and the previous step's graph is dropped.
            trajectory = trajectory.detach().requires_grad_()

            # 2. predict model output
            model_output = model(trajectory, t, 
                local_cond=local_cond, global_cond=global_cond)

            # `current_cost` is only bound when classifier_guidance is True; the
            # short-circuit on classifier_guidance keeps it from being read otherwise.
            if classifier_guidance and t < self.guidance_start_timestep and current_cost > self.threshold:
                # Predicted clean sample x0 from (x_t, model_output). NOTE(review):
                # this extra step() call omits `generator`, so for stochastic
                # schedulers it may draw from the global RNG.
                trajectory0 = scheduler.step(model_output, t, trajectory).pred_original_sample
                loss = self.planner.compute_loss(trajectory0, current_obs)
                # Negative loss gradient w.r.t. the noisy trajectory (descent direction).
                cond_grad = -torch.autograd.grad(loss, trajectory)[0]
                guidance_scale = self.guidance_scale
                grad_scale = guidance_scale * (1 - scheduler.alphas_cumprod[t]).sqrt()
                # if frame_assembled_success:
                #     grad_scale *= 0.1
                print('cond_grad norm ', cond_grad.norm().item())
                print('grad_scale ', grad_scale)
                # Shift x_t; model_output above still reflects the un-shifted x_t.
                trajectory = trajectory.detach() + grad_scale * cond_grad

            # 3. compute previous image: x_t -> x_t-1
            trajectory = scheduler.step(
                model_output, t, trajectory, 
                generator=generator,
                **kwargs
                ).prev_sample

        # finally make sure conditioning is enforced
        trajectory[condition_mask] = condition_data[condition_mask]        

        return trajectory

    def predict_action(self, obs_dict: Dict[str, torch.Tensor], avoid_ood=False) -> Dict[str, torch.Tensor]:
        """Sample an action chunk from the diffusion policy without guidance.

        Args:
            obs_dict: observation tensors, each indexed as (B, To, ...); must not
                contain ``'past_action'``.
            avoid_ood: accepted for interface compatibility; not used here.

        Returns:
            dict with ``'action'`` (the ``n_action_steps`` actions starting at the
            last observed step) and ``'action_pred'`` (the whole predicted
            horizon), both unnormalized.
        """
        assert 'past_action' not in obs_dict # not implemented yet

        nobs = self.normalizer.normalize(obs_dict)
        batch_size, _ = next(iter(nobs.values())).shape[:2]
        horizon = self.horizon
        act_dim = self.action_dim
        obs_feat_dim = self.obs_feature_dim
        n_obs = self.n_obs_steps
        device = self.device
        dtype = self.dtype

        # Encode the first n_obs frames of every stream as one flat batch.
        flat_obs = dict_apply(nobs, lambda x: x[:, :n_obs, ...].reshape(-1, *x.shape[2:]))
        obs_features = self.obs_encoder(flat_obs)

        local_cond = None
        if self.obs_as_global_cond:
            # Observations condition the denoiser globally; the sample is actions only.
            global_cond = obs_features.reshape(batch_size, -1)
            cond_data = torch.zeros(size=(batch_size, horizon, act_dim), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # Inpainting: observation features fill the trailing channels of the first n_obs steps.
            global_cond = None
            cond_data = torch.zeros(size=(batch_size, horizon, act_dim + obs_feat_dim), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :n_obs, act_dim:] = obs_features.reshape(batch_size, n_obs, -1)
            cond_mask[:, :n_obs, act_dim:] = True

        with torch.no_grad():
            latest_obs = dict_apply(obs_dict, lambda x: x[:, -1:, ...])
            nsample = self.guided_conditional_sample(
                cond_data,
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                current_obs=latest_obs,
                **self.kwargs)

        action_pred = self.normalizer['action'].unnormalize(nsample[..., :act_dim])
        first = n_obs - 1
        return {
            'action': action_pred[:, first:first + self.n_action_steps],
            'action_pred': action_pred,
        }

    def predict_action_gd(self, obs_dict: Dict[str, torch.Tensor], avoid_ood=False, frame_assembled_success=False) -> Dict[str, torch.Tensor]:
        """Sample an action chunk, then refine it with the planner's ``optimize_gd``.

        Args:
            obs_dict: observation tensors, each indexed as (B, To, ...); must not
                contain ``'past_action'``.
            avoid_ood: accepted for interface compatibility; not used here.
            frame_assembled_success: forwarded to ``self.planner.optimize_gd``.

        Returns:
            dict with ``'action'`` (the chunk returned by ``optimize_gd``) and
            ``'action_pred'`` (the unnormalized diffusion prediction over the
            whole horizon).
        """
        assert 'past_action' not in obs_dict # not implemented yet

        def latest_step(obs):
            # Keep only the most recent observation step, preserving the time axis.
            return dict_apply(obs, lambda x: x[:, -1:, ...])

        nobs = self.normalizer.normalize(obs_dict)
        batch_size, _ = next(iter(nobs.values())).shape[:2]
        horizon = self.horizon
        act_dim = self.action_dim
        obs_feat_dim = self.obs_feature_dim
        n_obs = self.n_obs_steps
        device = self.device
        dtype = self.dtype

        # Encode the first n_obs frames of every stream as one flat batch.
        flat_obs = dict_apply(nobs, lambda x: x[:, :n_obs, ...].reshape(-1, *x.shape[2:]))
        obs_features = self.obs_encoder(flat_obs)

        local_cond = None
        if self.obs_as_global_cond:
            # Observations condition the denoiser globally; the sample is actions only.
            global_cond = obs_features.reshape(batch_size, -1)
            cond_data = torch.zeros(size=(batch_size, horizon, act_dim), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # Inpainting: observation features fill the trailing channels of the first n_obs steps.
            global_cond = None
            cond_data = torch.zeros(size=(batch_size, horizon, act_dim + obs_feat_dim), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :n_obs, act_dim:] = obs_features.reshape(batch_size, n_obs, -1)
            cond_mask[:, :n_obs, act_dim:] = True

        with torch.no_grad():
            nsample = self.conditional_sample(
                cond_data,
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                current_obs=latest_step(obs_dict),
                **self.kwargs)

        action_pred = self.normalizer['action'].unnormalize(nsample[..., :act_dim])
        first = n_obs - 1
        proposal = action_pred[:, first:first + self.n_action_steps]

        refined = self.planner.optimize_gd(
            proposal,
            latest_step(obs_dict),
            frame_assembled_success=frame_assembled_success)
        return {
            'action': refined,
            'action_pred': action_pred,
        }
    
    def predict_action_pusht(self, obs_dict: Dict[str, torch.Tensor], avoid_ood=False, h_step=False, optim_lr=1e-4, num_iters=20, use_embed=1, weight_decay=0.0, input_type='state', use_history=False, info=None) -> Dict[str, torch.Tensor]:
        """Sample a Push-T action chunk, optionally optimised to avoid OOD states.

        The diffusion policy proposes an action chunk. When ``avoid_ood`` is True,
        the chunk is re-normalized with the dynamics model's action normalizer and
        optimised with Adam for up to ``num_iters`` steps. The objective is the OOD
        score of the dynamics model's last predicted latent state plus
        ``0.1 * boundary_penalty(action)``. Optimisation stops early (when
        ``self.early_stop``) once the last two recorded OOD scores are below
        ``self.early_stop_threshold``.

        Args:
            obs_dict: observation tensors indexed as (B, To, ...); must not contain
                ``'past_action'`` and must contain ``'image'`` when ``avoid_ood``.
            avoid_ood: enable the OOD-avoidance optimisation.
            h_step, input_type: unused; kept for interface compatibility.
            optim_lr, weight_decay: Adam hyper-parameters.
            num_iters: maximum number of optimisation steps.
            use_embed: score ``state_emb`` of the last latent instead of the raw latent.
            use_history: roll the dynamics model out from the last 9 images and
                8 actions stored in ``info``. No optimisation happens while
                ``info`` is None or holds fewer than 9 images.
            info: rollout history dict with ``'image'`` and ``'action'`` arrays
                (read only when ``use_history``). It is stored on ``self.info``
                when ``avoid_ood`` is True.

        Returns:
            dict with ``'action'`` (unnormalized chunk of ``n_action_steps``
            actions; detached from autograd) and ``'action_pred'`` (unnormalized
            diffusion prediction over the full horizon).
        """
        assert 'past_action' not in obs_dict # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        B = next(iter(nobs.values())).shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps
        device = self.device
        dtype = self.dtype

        # Encode the first To observation frames as one flattened batch.
        this_nobs = dict_apply(nobs, lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)

        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # condition through global feature; the sample holds actions only
            global_cond = nobs_features.reshape(B, -1)
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting of the observation channels
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True

        # run sampling (unguided, so no autograd graph is needed)
        with torch.no_grad():
            nsample = self.conditional_sample(
                cond_data,
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                **self.kwargs)

        naction_pred = nsample[..., :Da]
        start = To - 1
        end = start + self.n_action_steps
        action = naction_pred[:, start:end]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        if not avoid_ood:
            return {
                'action': self.normalizer['action'].unnormalize(action),
                'action_pred': action_pred,
            }

        # Re-express the proposal in the dynamics model's normalized action space
        # and make it the only optimisation variable.
        action_unnormalized = self.normalizer['action'].unnormalize(action)
        action_copy = self.dynamics_model_normalizer['action'].normalize(action_unnormalized).detach().clone()
        action_copy.requires_grad = True
        optimizer = optim.Adam([action_copy], lr=optim_lr, weight_decay=weight_decay)

        # Latest frame: a numpy image for the OOD module and a normalized tensor
        # as the dynamics model's initial state.
        initial_obs_dict = dict_apply(obs_dict, lambda x: x[:, -1, ...])
        current_img = initial_obs_dict['image'].detach().cpu().numpy()
        initial_state = self.dynamics_model_normalizer['image'].normalize(initial_obs_dict['image']).detach()

        self.info = info

        # History-conditioned rollouts use the last 9 images and last 8 actions.
        history_ready = False
        if use_history:
            history_ready = info is not None and info['image'].shape[0] >= 9
            if history_ready:
                np_info_dict = dict_apply(info, lambda x: np.expand_dims(x, axis=0))
                info_dict = dict_apply(np_info_dict, lambda x: torch.from_numpy(x).to(device=device, dtype=dtype))
                history_obs = self.dynamics_model_normalizer['image'].normalize(info_dict['image'][:, -9:, ...]).detach()
                history_action = self.dynamics_model_normalizer['action'].normalize(info_dict['action'][:, -8:, ...]).detach()
            else:
                # Not enough history yet: keep the diffusion proposal unchanged.
                num_iters = 0

        total_loss_array = []
        for optim_iter in range(num_iters):
            # Early stop before spending another dynamics forward pass.
            if self.early_stop and \
                    len(total_loss_array) > 1 and \
                    total_loss_array[-1] < self.early_stop_threshold and \
                    total_loss_array[-2] < self.early_stop_threshold:
                break

            optimizer.zero_grad()
            if use_history:
                _, predicted_latent_states, _ = self.dynamics_model(history_obs, history_action, action_copy)
            else:
                _, predicted_latent_states, _, _ = self.dynamics_model(initial_state, action_copy)

            # Score only the last predicted latent state.
            if use_embed:
                embedded_last = self.dynamics_model.dynamics_model.state_emb(predicted_latent_states[:, -1, :])
                ood_score_last = self.ood_module(embedded_last, current_img)
            else:
                predicted_latent_states = predicted_latent_states.squeeze(0)
                ood_score_last = self.ood_module(predicted_latent_states[-1, :].unsqueeze(0), current_img)
            total_ood_score = ood_score_last.sum()
            print(f"Total OOD score at step {optim_iter}: {total_ood_score}")

            # apply boundary penalty to updated action
            boundary_penalty_loss = boundary_penalty(action_copy)
            total_loss = total_ood_score + 0.1 * boundary_penalty_loss
            total_loss.backward()
            assert (action_copy.requires_grad and action_copy.grad is not None)
            optimizer.step()
            total_loss_array.append(total_ood_score.item())

        # Detach so callers get a plain tensor, not one tracking the optimisation graph.
        optimized_action = self.dynamics_model_normalizer['action'].unnormalize(action_copy.detach())
        return {
            'action': optimized_action,
            'action_pred': action_pred,
        }

    def predict_action_shooting(self, obs_dict: Dict[str, torch.Tensor], state, eval_id, timestep, num_samples) -> Dict[str, torch.Tensor]:
        """Sample ``num_samples`` candidate trajectories and let the planner pick one.

        The conditioning is broadcast with ``expand`` to ``num_samples`` rows, so
        the observation batch B must be 1 (or already equal ``num_samples``).

        Args:
            obs_dict: observation tensors indexed as (B, To, ...); must not contain
                ``'past_action'``.
            state: forwarded positionally to ``self.planner.plan_shooting``.
            eval_id, timestep: forwarded to ``plan_shooting`` as keywords.
            num_samples: number of diffusion samples drawn.

        Returns:
            dict with ``'action'`` (as returned by ``plan_shooting``) and
            ``'action_pred'`` (the unnormalized full prediction of the selected
            candidate, with a leading axis of size 1).
        """
        assert 'past_action' not in obs_dict # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        B = next(iter(nobs.values())).shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps
        device = self.device
        dtype = self.dtype

        # Encode the first To observation frames as one flattened batch.
        this_nobs = dict_apply(nobs, lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)

        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # condition through global feature; the sample holds actions only
            global_cond = nobs_features.reshape(B, -1)
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting of the observation channels
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True

        # Broadcast the single conditioning to num_samples candidates.
        bs = num_samples
        cond_data = cond_data.expand(bs, -1, -1)
        cond_mask = cond_mask.expand(bs, -1, -1)
        if global_cond is not None:
            # global_cond is None in inpainting mode; there is nothing to expand then.
            global_cond = global_cond.expand(bs, -1)

        latest_obs = dict_apply(obs_dict, lambda x: x[:, -1:, ...])
        # Unguided sampling needs no autograd graph. Without no_grad, one would be
        # kept across every denoising step for all num_samples candidates.
        with torch.no_grad():
            nsample = self.conditional_sample(
                cond_data,
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                current_obs=latest_obs,
                **self.kwargs)

        # unnormalize prediction
        naction_pred = nsample[..., :Da]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        # get action
        start = To - 1
        end = start + self.n_action_steps
        print('start ', start)
        print('end ', end)
        print('action samples ', action_pred.shape)
        # The planner scores every candidate (from the last observed step onward)
        # and returns the action to execute plus the index of the chosen candidate.
        action, best_index = self.planner.plan_shooting(latest_obs, state, action_pred[:, start:], eval_id=eval_id, timestep=timestep)
        print('action exe', action.shape)
        return {
            'action': action,
            'action_pred': action_pred[best_index].unsqueeze(0),
        }
    
    def predict_action_classifier_guidance(self, obs_dict: Dict[str, torch.Tensor], avoid_ood=False, frame_assembled_success=False) -> Dict[str, torch.Tensor]:
        """Sample an action chunk with planner guidance inside the diffusion loop.

        ``self.opt_method`` selects the sampler. If it contains ``'variant1'``,
        ``conditional_sample`` is used (the model output is shifted by the planner
        gradient). If it contains ``'variant2'``, ``guided_conditional_sample`` is
        used (the noisy sample is shifted by the planner-loss gradient).

        Args:
            obs_dict: observation tensors indexed as (B, To, ...); must not contain
                ``'past_action'``.
            avoid_ood: accepted for interface compatibility; not used here.
            frame_assembled_success: forwarded to the sampler.

        Returns:
            dict with ``'action'`` (``n_action_steps`` unnormalized actions starting
            at the last observed step) and ``'action_pred'`` (the full unnormalized
            prediction). Both are detached from autograd.

        Raises:
            ValueError: if ``self.opt_method`` names neither variant.
        """
        assert 'past_action' not in obs_dict # not implemented yet

        # Pick the sampler up front. Previously an unknown method fell through to
        # an UnboundLocalError on `nsample`.
        if 'variant1' in self.opt_method:
            sampler = self.conditional_sample
        elif 'variant2' in self.opt_method:
            sampler = self.guided_conditional_sample
        else:
            raise ValueError(
                f"Unsupported classifier-guidance method {self.opt_method!r}: "
                "expected it to contain 'variant1' or 'variant2'")

        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        B = next(iter(nobs.values())).shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps
        device = self.device
        dtype = self.dtype

        # Encode the first To observation frames as one flattened batch.
        this_nobs = dict_apply(nobs, lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)

        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # condition through global feature; the sample holds actions only
            global_cond = nobs_features.reshape(B, -1)
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting of the observation channels
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True

        # No torch.no_grad() here: variant2 calls torch.autograd.grad inside the loop.
        nsample = sampler(
            cond_data,
            cond_mask,
            local_cond=local_cond,
            global_cond=global_cond,
            classifier_guidance=True,
            current_obs=dict_apply(obs_dict, lambda x: x[:, -1:, ...]),
            frame_assembled_success=frame_assembled_success,
            **self.kwargs)
        # The sample may still reference the denoiser's autograd graph; release it.
        nsample = nsample.detach()

        # unnormalize prediction
        action_pred = self.normalizer['action'].unnormalize(nsample[..., :Da])

        # get action
        start = To - 1
        end = start + self.n_action_steps
        return {
            'action': action_pred[:, start:end],
            'action_pred': action_pred,
        }
    
    def predict_action_robomimic(self, obs_dict: Dict[str, torch.Tensor], avoid_ood=False, optim_lr=1e-2, num_iters=20):
        """Sample a robomimic action chunk, optionally optimised to avoid OOD states.

        The diffusion policy proposes an action chunk. When ``avoid_ood`` is True,
        the chunk is re-normalized with the dynamics model's action normalizer and
        optimised with Adam for up to ``num_iters`` steps. The objective is the OOD
        score of the dynamics model's last predicted latent state, rolled out from
        the current agentview and eye-in-hand images with an empty action history,
        plus ``0.1 * boundary_penalty(action)``. Optimisation stops early (when
        ``self.early_stop``) once the last two recorded OOD scores are below
        ``self.early_stop_threshold``.

        Args:
            obs_dict: observation tensors indexed as (B, To, ...); must not contain
                ``'past_action'``. When ``avoid_ood`` is True it must contain
                ``'agentview_image'`` and ``'robot0_eye_in_hand_image'``. The image
                conversion below squeezes the batch axis, so B is expected to be 1.
            avoid_ood: enable the OOD-avoidance optimisation.
            optim_lr: Adam learning rate.
            num_iters: maximum number of optimisation steps.

        Returns:
            dict with ``'action'`` (unnormalized chunk of ``n_action_steps``
            actions; detached from autograd) and ``'action_pred'`` (unnormalized
            diffusion prediction over the full horizon).
        """
        assert 'past_action' not in obs_dict # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        B = next(iter(nobs.values())).shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps
        device = self.device
        dtype = self.dtype

        # Encode the first To observation frames as one flattened batch.
        this_nobs = dict_apply(nobs, lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)

        local_cond = None
        global_cond = None
        if self.obs_as_global_cond:
            # condition through global feature; the sample holds actions only
            global_cond = nobs_features.reshape(B, -1)
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting of the observation channels
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True

        # run sampling (unguided, so no autograd graph is needed)
        with torch.no_grad():
            nsample = self.conditional_sample(
                cond_data,
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                **self.kwargs)

        naction_pred = nsample[..., :Da]
        start = To - 1
        end = start + self.n_action_steps
        action = naction_pred[:, start:end]
        action_pred = self.normalizer['action'].unnormalize(naction_pred)

        if not avoid_ood:
            return {
                'action': self.normalizer['action'].unnormalize(action),
                'action_pred': action_pred,
            }

        # Re-express the proposal in the dynamics model's normalized action space
        # and make it the only optimisation variable.
        action_unnormalized = self.normalizer['action'].unnormalize(action)
        action_copy = self.dynamics_model_normalizer['action'].normalize(action_unnormalized).detach().clone()
        action_copy.requires_grad = True
        optimizer = optim.Adam([action_copy], lr=optim_lr)

        # Latest frame of each camera, renamed to the dynamics model's view keys.
        obs_last = dict_apply(obs_dict, lambda x: x[:, -1:, ...])
        initial_obs_dict = {
            'agentview': obs_last['agentview_image'],
            'robot0_eye_in_hand': obs_last['robot0_eye_in_hand_image'],
        }

        # HWC uint8 images for the OOD module (RGB -> BGR for 3-channel images).
        current_img = {}
        for key, img in initial_obs_dict.items():
            img_np = img.squeeze(1).detach().cpu().numpy().squeeze(0)
            img_np = (img_np * 255).astype(np.uint8)
            img_np = np.transpose(img_np, (1, 2, 0))
            if img_np.shape[2] == 3:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            current_img[key] = img_np

        initial_state = {
            key: self.dynamics_model_normalizer['image'].normalize(obs).detach()
            for key, obs in initial_obs_dict.items()
        }

        # The rollout starts from the current frame with no past actions. This
        # tensor is loop-invariant, so build it once. It is sized from the
        # optimised actions rather than a hard-coded (1, 0, 10).
        dummy_history_action = torch.empty(
            (action_copy.shape[0], 0, action_copy.shape[-1]), device=device, dtype=dtype)

        total_loss_array = []
        for optim_iter in range(num_iters):
            # Early stop before spending another dynamics forward pass.
            if self.early_stop and \
                    len(total_loss_array) > 1 and \
                    total_loss_array[-1] < self.early_stop_threshold and \
                    total_loss_array[-2] < self.early_stop_threshold:
                break

            optimizer.zero_grad()
            predicted_latent_states, _, _ = self.dynamics_model(o_history=initial_state, a_history=dummy_history_action, a_future=action_copy)

            # Score only the last predicted latent state.
            predicted_latent_states = predicted_latent_states.squeeze(0)
            ood_score_last = self.ood_module(predicted_latent_states[-1, :].unsqueeze(0), current_img)
            total_ood_score = ood_score_last.sum()
            print(f"Total OOD score at step {optim_iter}: {total_ood_score}")

            # apply boundary penalty to updated action
            boundary_penalty_loss = boundary_penalty(action_copy)
            total_loss = total_ood_score + 0.1 * boundary_penalty_loss
            total_loss.backward()
            assert (action_copy.requires_grad and action_copy.grad is not None)
            optimizer.step()
            total_loss_array.append(total_ood_score.item())

        # Detach so callers get a plain tensor, not one tracking the optimisation graph.
        optimized_action = self.dynamics_model_normalizer['action'].unnormalize(action_copy.detach())
        return {
            'action': optimized_action,
            'action_pred': action_pred,
        }

    def plan_action_robomimic(self, obs_dict: Dict[str, torch.Tensor], avoid_ood=False, optim_lr=1e-2, num_iters=20):
        """Sample an action chunk and optionally refine it away from OOD states.

        A base action sequence is sampled from the diffusion model via
        ``self.conditional_sample``. When ``avoid_ood`` is True, the executed
        chunk is re-normalized with ``self.dynamics_model_normalizer`` and used
        as the initial mean of a cross-entropy-method (CEM) search:

        - Candidate chunks are drawn around the mean.
        - Each candidate is rolled out through ``self.dynamics_model`` from the
          most recent observation.
        - Candidates are ranked by the negated
          ``self.ood_module.compute_ood_score`` of the final predicted latent.

        The search runs for at most 15 iterations and stops early once the
        current mean scores below 0.15.

        Args:
            obs_dict: observation dict (must include the observation keys used
                by ``self.normalizer``; must not contain 'past_action'). When
                ``avoid_ood`` is True it must also contain 'agentview_image'
                and 'robot0_eye_in_hand_image'.
            avoid_ood: run the CEM refinement described above.
            optim_lr: unused by this method; kept for interface compatibility.
            num_iters: unused by this method; kept for interface compatibility.

        Returns:
            dict with
            - 'action': the chunk to execute, unnormalized.
            - 'action_pred': the full sampled action horizon, unnormalized.
        """
        assert 'past_action' not in obs_dict  # not implemented yet
        # normalize input
        nobs = self.normalizer.normalize(obs_dict)
        first_obs = next(iter(nobs.values()))
        B = first_obs.shape[0]
        T = self.horizon
        Da = self.action_dim
        Do = self.obs_feature_dim
        To = self.n_obs_steps

        # build input
        device = self.device
        dtype = self.dtype

        # Encode the first To observation frames as one flat (B*To, ...) batch;
        # both conditioning modes below start from these features.
        local_cond = None
        global_cond = None
        this_nobs = dict_apply(nobs, lambda x: x[:, :To, ...].reshape(-1, *x.shape[2:]))
        nobs_features = self.obs_encoder(this_nobs)
        if self.obs_as_global_cond:
            # condition through global feature: reshape back to (B, -1)
            global_cond = nobs_features.reshape(B, -1)
            # empty data for action
            cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
        else:
            # condition through inpainting: reshape back to (B, To, Do)
            nobs_features = nobs_features.reshape(B, To, -1)
            cond_data = torch.zeros(size=(B, T, Da + Do), device=device, dtype=dtype)
            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
            cond_data[:, :To, Da:] = nobs_features
            cond_mask[:, :To, Da:] = True

        # run sampling
        with torch.no_grad():
            nsample = self.conditional_sample(
                cond_data,
                cond_mask,
                local_cond=local_cond,
                global_cond=global_cond,
                **self.kwargs)
        naction_pred = nsample[..., :Da]

        # Executed chunk: n_action_steps actions starting at the last observed step.
        start = To - 1
        end = start + self.n_action_steps
        action = naction_pred[:, start:end]

        action_pred = self.normalizer['action'].unnormalize(naction_pred)
        original_action = self.normalizer['action'].unnormalize(action)

        if not avoid_ood:
            return {
                'action': original_action,
                'action_pred': action_pred
            }

        # The CEM search works in the dynamics model's action normalization;
        # the diffusion policy's chunk is the initial mean.
        mean = self.dynamics_model_normalizer['action'].normalize(original_action).detach().clone()
        bs = 16  # number of candidates rolled out per dynamics-model call

        # Only the most recent observation frame seeds the rollouts.
        obs_dict = dict_apply(obs_dict, lambda x: x[:, -1:, ...])
        initial_obs_dict = {
            'agentview': obs_dict['agentview_image'],
            'robot0_eye_in_hand': obs_dict['robot0_eye_in_hand_image'],
        }

        # Current camera frames for the OOD module: scaled by 255 to uint8,
        # transposed (C, H, W) -> (H, W, C), RGB->BGR when 3-channel.
        current_img = {}
        for key, img in initial_obs_dict.items():
            img = img.squeeze(1).detach().cpu().numpy().squeeze(0)
            img = np.transpose((img * 255).astype(np.uint8), (1, 2, 0))
            if img.shape[2] == 3:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            current_img[key] = img

        # Normalized observation, broadcast to a batch of bs for candidate
        # rollouts (expand requires the original batch dimension to be 1).
        initial_state = {}
        for key, obs in initial_obs_dict.items():
            state = self.dynamics_model_normalizer['image'].normalize(obs).detach()
            initial_state[key] = state.expand(bs, -1, -1, -1, -1)
        # Batch-one view of the same state, used to score the current mean.
        initial_state_bs_one = {key: state[0:1] for key, state in initial_state.items()}
        # Empty action histories: rollouts condition on the current frame only.
        history_one = torch.empty((1, 0, Da)).to(device=device, dtype=dtype)
        history_bs = torch.empty((bs, 0, Da)).to(device=device, dtype=dtype)

        # Hand-set per-step sampling std (8 steps x 10 action dims; last dim 0),
        # plus a constant first-step floor, scaled by 1.2.
        std = torch.Tensor([
                [0.000,  0.000,  0.000,  0.000,  0.000,  0.000,  0.000,  0.000,  0.000,
                0.000],
                [0.0065, 0.0059, 0.0042, 0.0147, 0.0124, 0.0046, 0.0124, 0.0146, 0.0041,
                0.0000],
                [0.0126, 0.0114, 0.0058, 0.0287, 0.0243, 0.0089, 0.0244, 0.0285, 0.0081,
                0.0000],
                [0.0183, 0.0167, 0.0093, 0.0423, 0.0359, 0.0131, 0.0360, 0.0421, 0.0119,
                0.0000],
                [0.0236, 0.0218, 0.0110, 0.0557, 0.0472, 0.0171, 0.0473, 0.0554, 0.0157,
                0.0000],
                [0.0287, 0.0267, 0.0141, 0.0689, 0.0583, 0.0211, 0.0585, 0.0686, 0.0193,
                0.0000],
                [0.0335, 0.0314, 0.0158, 0.0818, 0.0692, 0.0249, 0.0694, 0.0815, 0.0229,
                0.0000],
                [0.0381, 0.0361, 0.0186, 0.0946, 0.0799, 0.0285, 0.0801, 0.0942, 0.0263,
                0.0000]]
            ) + \
            torch.Tensor([
                [0.0065, 0.0059, 0.0042, 0.0147, 0.0124, 0.0046, 0.0124, 0.0146, 0.0041,
                0.0000]]
            )
        std = 1.2 * std
        std = std.to(device=device, dtype=dtype)

        num_samples = 512
        num_elites = 64
        # Populated by the CEM update; remain None if the initial mean already
        # passes the OOD check on the first iteration.
        elite_actions = None
        score = None

        for iteration in range(15):
            # Stop as soon as the current mean leads to an in-distribution state.
            with torch.no_grad():
                predicted_latent_states, _, _ = self.dynamics_model(o_history=initial_state_bs_one, a_history=history_one, a_future=mean)
                predicted_latent_states = predicted_latent_states[:, -1, ...]
                ood_score_of_current_mean = self.ood_module.compute_ood_score(predicted_latent_states, current_img=current_img)
                ood_score_of_current_mean = ood_score_of_current_mean[0].item()
                print('{} th iteration, value of current mean action is {}'.format(iteration, ood_score_of_current_mean))
                if ood_score_of_current_mean < 0.15:
                    break

            # Sample candidates around the mean, clipped to [-1, 1].
            actions = torch.clamp(
                mean + std * torch.randn(num_samples, *std.shape, device=std.device), -1, 1)

            # value = -OOD score of each candidate's final predicted state
            # (higher is better), evaluated bs candidates at a time.
            value = torch.zeros((num_samples, 1), device=device, dtype=dtype)
            for batch in range(num_samples // bs):
                rows = slice(batch * bs, (batch + 1) * bs)
                with torch.no_grad():
                    predicted_latent_states, _, _ = self.dynamics_model(o_history=initial_state, a_history=history_bs, a_future=actions[rows])
                predicted_latent_states = predicted_latent_states[:, -1, ...]
                value[rows] = -1 * self.ood_module.compute_ood_score(predicted_latent_states, current_img=current_img)

            # Refit mean/std from the top-k candidates weighted by exp(0.5 * (v - max v)).
            elite_idxs = torch.topk(value.squeeze(1), num_elites, dim=0).indices
            elite_value, elite_actions = value[elite_idxs], actions[elite_idxs, :, :]
            max_value = elite_value.max(0)[0]
            score = torch.exp(0.5 * (elite_value - max_value))
            score /= score.sum(0)
            _mean = torch.sum(score.view(num_elites, 1, 1) * elite_actions, dim=0) / (score.sum(0) + 1e-9)
            _std = torch.sqrt(torch.sum(score.view(num_elites, 1, 1) * (elite_actions - _mean.unsqueeze(0)) ** 2, dim=0) / (score.sum(0) + 1e-9))
            # New mean: 1% previous mean, 99% weighted elite mean.
            mean, std = 0.01 * mean + (1 - 0.01) * _mean, _std

        if elite_actions is None:
            # Early stop on the very first iteration: no elite set exists, so
            # keep the (already in-distribution) initial mean.
            action_copy = mean
        else:
            # Draw one elite chunk with probability equal to its CEM weight.
            probs = score.squeeze(1).cpu().numpy()
            chosen = np.random.choice(np.arange(probs.shape[0]), p=probs)
            action_copy = elite_actions[chosen, :, :].unsqueeze(0)

        action_copy = self.dynamics_model_normalizer['action'].unnormalize(action_copy)
        # Keep the diffusion policy's value for the last action dimension
        # (its sampling std above is 0; presumably the gripper command).
        action_copy[0, :, -1] = original_action[0, :, -1]
        return {
            'action': action_copy,
            'action_pred': action_pred
        }

    def check_dynamics_loss(self, use_history):
        """Report the dynamics model's latent prediction error on ``self.info``.

        Every entry of ``self.info`` gets a leading batch axis of size 1 and is
        moved to the policy's device/dtype. For each window start ``i`` in
        ``[9, T - 9)``, where ``T`` is the number of time steps in ``'image'``,
        the dynamics model predicts future latents from the observations at or
        before ``i`` and the 8 actions ``actions[i:i+8]``. Those predictions
        are compared with MSE to ``self.dynamics_model.z_future`` of the true
        images ``images[i+1:i+9]``.

        Args:
            use_history: if True, also condition on the 8 previous actions and
                the 9 images ``images[i-8:i+1]``; otherwise condition on the
                single image at step ``i``.

        Returns:
            The mean per-window MSE as a float. Returns None when the
            trajectory has 18 or fewer steps, because no window fits.
        """
        horizon = 8  # future steps predicted per window
        np_info_dict = dict_apply(self.info, lambda x: np.expand_dims(x, axis=0))
        info_dict = dict_apply(
            np_info_dict,
            lambda x: torch.from_numpy(x).to(device=self.device, dtype=self.dtype))

        images = info_dict['image']
        actions = info_dict['action']
        print('images shape ', images.shape)
        print('actions shape ', actions.shape)

        num_steps = len(images[0])
        first = horizon + 1
        last = num_steps - (horizon + 1)
        num_windows = last - first
        if num_windows <= 0:
            # Without this guard the average below divides by zero or a negative count.
            print('not enough steps to evaluate dynamics loss: ', num_steps)
            return None

        image_normalizer = self.dynamics_model_normalizer['image']
        action_normalizer = self.dynamics_model_normalizer['action']
        mse_loss = nn.MSELoss()
        total_dynamics_loss = 0.0
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for i in range(first, last):
                future_actions = action_normalizer.normalize(actions[:, i:i + horizon, ...])
                future_imgs = image_normalizer.normalize(images[:, i + 1:i + 1 + horizon, ...])
                if use_history:
                    history_actions = action_normalizer.normalize(actions[:, i - horizon:i, ...])
                    history_imgs = image_normalizer.normalize(images[:, i - horizon:i + 1, ...])
                    _, z_hat_future, _ = self.dynamics_model(history_imgs, history_actions, future_actions)
                else:
                    img_t = image_normalizer.normalize(images[:, i, ...])
                    _, z_hat_future, _, _ = self.dynamics_model(img_t, future_actions)
                z_future = self.dynamics_model.z_future(future_imgs)

                dynamics_loss = mse_loss(z_hat_future, z_future)
                total_dynamics_loss += dynamics_loss.item()
                print('dynamics_loss ', dynamics_loss)

        avg_dynamics_loss = total_dynamics_loss / num_windows
        print('avg_dynamics_loss is ', avg_dynamics_loss)
        return avg_dynamics_loss

    def save_nn_video(self, save_dir):
        """Have the OOD module write its recorded image lists as videos into ``save_dir``.

        The ``nn`` list video is written first, then the current-image list video.
        """
        ood = self.ood_module
        for write_video in (ood.save_nn_list_as_video, ood.save_current_image_list_as_video):
            write_video(save_dir)
        
    # ========= training  ============
    def set_normalizer(self, normalizer: LinearNormalizer):
        """Copy ``normalizer``'s parameters into this policy's normalizer.

        ``self.normalizer`` is not replaced; its state is overwritten through
        ``load_state_dict``.
        """
        source_state = normalizer.state_dict()
        self.normalizer.load_state_dict(source_state)

    def compute_loss(self, batch):
        """Return the masked diffusion training loss for one batch.

        Observations and actions are normalized with ``self.normalizer``. The
        trajectory to denoise depends on ``obs_as_global_cond``:

        - True: the first ``n_obs_steps`` observation frames are encoded into a
          flat global condition, and the trajectory is the action sequence.
        - False: every frame is encoded and concatenated to the actions along
          the last axis, and the concatenation is the trajectory.

        Training step:

        1. Noise the trajectory at one random scheduler timestep per sample.
        2. Overwrite the entries selected by ``self.mask_generator`` with the
           clean conditioning data.
        3. Compare the model output (MSE) to the noise (prediction type
           'epsilon') or to the clean trajectory ('sample').

        Conditioned entries are zeroed out of the loss.

        Args:
            batch: mapping with an 'obs' dict and an 'action' tensor;
                'valid_mask' is not supported.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if the scheduler's prediction type is not supported.
        """
        assert 'valid_mask' not in batch
        obs_norm = self.normalizer.normalize(batch['obs'])
        act_norm = self.normalizer['action'].normalize(batch['action'])
        n_batch = act_norm.shape[0]
        n_steps = act_norm.shape[1]

        local_cond = None
        global_cond = None
        if not self.obs_as_global_cond:
            # Inpainting-style conditioning: encode every frame as (B*T, ...),
            # restore (B, T, Do) and append the features to the actions.
            flat_obs = dict_apply(obs_norm, lambda x: x.reshape(-1, *x.shape[2:]))
            features = self.obs_encoder(flat_obs).reshape(n_batch, n_steps, -1)
            cond_data = torch.cat([act_norm, features], dim=-1)
            trajectory = cond_data.detach()
        else:
            # Global conditioning: encode the first n_obs_steps frames and
            # flatten them into one (B, -1) vector per sample.
            flat_obs = dict_apply(
                obs_norm,
                lambda x: x[:, :self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
            global_cond = self.obs_encoder(flat_obs).reshape(n_batch, -1)
            trajectory = act_norm
            cond_data = trajectory

        condition_mask = self.mask_generator(trajectory.shape)

        # Forward diffusion: corrupt the clean trajectory at a random timestep.
        dev = trajectory.device
        noise = torch.randn(trajectory.shape, device=dev)
        num_train_steps = self.noise_scheduler.config.num_train_timesteps
        timesteps = torch.randint(0, num_train_steps, (trajectory.shape[0],), device=dev).long()
        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)

        # Only unconditioned entries contribute to the loss.
        loss_mask = ~condition_mask
        noisy_trajectory[condition_mask] = cond_data[condition_mask]

        pred = self.model(noisy_trajectory, timesteps,
                          local_cond=local_cond, global_cond=global_cond)

        pred_type = self.noise_scheduler.config.prediction_type
        if pred_type == 'sample':
            target = trajectory
        elif pred_type == 'epsilon':
            target = noise
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        per_element = F.mse_loss(pred, target, reduction='none')
        per_element = per_element * loss_mask.type(per_element.dtype)
        return reduce(per_element, 'b ... -> b (...)', 'mean').mean()
