import os
import wandb
import numpy as np
import torch
import collections
import pathlib
import tqdm
import h5py
import math
import dill
import wandb.sdk.data_types.video as wv
from diffusion_policy.gym_util.async_vector_env import AsyncVectorEnv
from diffusion_policy.gym_util.sync_vector_env import SyncVectorEnv
from diffusion_policy.gym_util.multistep_wrapper import MultiStepWrapper
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper, VideoRecorder
from diffusion_policy.model.common.rotation_transformer import RotationTransformer

from diffusion_policy.policy.base_image_policy import BaseImagePolicy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
from diffusion_policy.env.robomimic.robomimic_image_wrapper import RobomimicImageWrapper
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils



from diffusion_policy.model.common.rotation_transformer import RotationTransformer

# Target rotation representation for converted actions. It only matters on the
# code path of convert_actions() that actually invokes the transformer below.
rotation_rep = 'rotation_6d'
# Shared module-level converter from axis-angle rotations to `rotation_rep`.
rotation_transformer = RotationTransformer(from_rep='axis_angle', to_rep=rotation_rep)

def convert_actions(raw_actions, abs_action):
    """Re-encode raw actions for absolute-action control.

    The rotation conversion is only meaningful for absolute actions, so it is
    applied when ``abs_action`` is truthy and skipped otherwise. (The previous
    condition was inverted, contradicting the note next to ``rotation_rep``
    and the runner, which only undoes the transform when ``abs_action`` is
    set.)

    Args:
        raw_actions: numpy array whose last axis holds, per arm, 3 position
            values, 3 axis-angle rotation values and the remaining gripper
            value(s). A last axis of size 14 is treated as two 7-value arms.
        abs_action: when falsy, ``raw_actions`` is returned unchanged (same
            object); when truthy, the rotation slice is passed through the
            module-level ``rotation_transformer``.

    Returns:
        ``raw_actions`` itself when ``abs_action`` is falsy. Otherwise a
        float32 array ``[pos, converted_rot, gripper]`` concatenated on the
        last axis; dual-arm input is returned flattened to shape ``(-1, 20)``.
    """
    if not abs_action:
        # Delta actions are consumed as-is; nothing to convert.
        return raw_actions

    is_dual_arm = raw_actions.shape[-1] == 14
    if is_dual_arm:
        # Split the 14-value vector into two 7-value arms so the slicing
        # below applies per arm. Note: this collapses all leading axes.
        raw_actions = raw_actions.reshape(-1, 2, 7)

    pos = raw_actions[..., :3]
    rot = rotation_transformer.forward(raw_actions[..., 3:6])
    gripper = raw_actions[..., 6:]
    actions = np.concatenate([pos, rot, gripper], axis=-1).astype(np.float32)

    if is_dual_arm:
        # Merge the two converted arms back into one vector per sample.
        actions = actions.reshape(-1, 20)
    return actions



def create_env(env_meta, shape_meta, enable_render=True):
    """Instantiate a robomimic environment from dataset metadata.

    Args:
        env_meta: environment metadata forwarded unchanged to
            ``EnvUtils.create_env_from_metadata``.
        shape_meta: dict whose ``'obs'`` entry maps observation keys to
            attribute dicts; an optional ``'type'`` attribute names the
            modality of each key.
        enable_render: toggles both offscreen rendering and image
            observations of the created environment.

    Returns:
        The environment object produced by robomimic's ``EnvUtils``.
    """
    # Group observation keys by modality; keys that declare no 'type' are
    # filed under 'low_dim'.
    keys_by_modality = collections.defaultdict(list)
    for obs_key, obs_attr in shape_meta['obs'].items():
        modality = obs_attr.get('type', 'low_dim')
        keys_by_modality[modality].append(obs_key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(keys_by_modality)

    # On-screen rendering is always off; depth observations are not requested.
    return EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )


class RobomimicImageRunner(BaseImageRunner):
    """
    Robomimic envs already enforce the number of steps.
    """

    def __init__(self,
                 output_dir,
                 dataset_path,
                 shape_meta: dict,
                 n_train=10,
                 n_train_vis=3,
                 train_start_idx=0,
                 n_test=22,
                 n_test_vis=6,
                 test_start_seed=10000,
                 max_steps=400,
                 n_obs_steps=2,
                 n_action_steps=8,
                 render_obs_key='agentview_image',
                 fps=10,
                 crf=22,
                 past_action=False,
                 abs_action=False,
                 tqdm_interval_sec=5.0,
                 n_envs=None
                 ):
        """Set up the vectorized evaluation environments and their initializers.

        Args:
            output_dir: forwarded to the base runner; rollout videos are
                written to ``<output_dir>/media``.
            dataset_path: path to the robomimic HDF5 dataset (``~`` is
                expanded). Environment metadata and the first simulator state
                of each training demo are read from it.
            shape_meta: observation/action shape description, forwarded to
                ``create_env`` and ``RobomimicImageWrapper``.
            n_train: number of rollouts that start from a dataset demo's
                initial state.
            n_train_vis: the first ``n_train_vis`` training rollouts record
                a video.
            train_start_idx: index of the first demo used for training
                rollouts.
            n_test: number of rollouts that start from a seeded reset.
            n_test_vis: the first ``n_test_vis`` test rollouts record a video.
            test_start_seed: seed of the first test rollout; subsequent
                rollouts use consecutive seeds.
            max_steps: passed as ``max_episode_steps`` to ``MultiStepWrapper``.
            n_obs_steps: passed to ``MultiStepWrapper``.
            n_action_steps: passed to ``MultiStepWrapper``.
            render_obs_key: passed to ``RobomimicImageWrapper``.
            fps: video frame rate; also determines how many simulator steps
                elapse between rendered frames.
            crf: passed to ``VideoRecorder.create_h264``.
            past_action: stored on the runner for use during rollouts.
            abs_action: when true, the controller is switched to absolute
                control (``control_delta = False``) and an axis-angle to
                rotation_6d transformer is stored on the runner.
            tqdm_interval_sec: stored on the runner for progress-bar updates.
            n_envs: number of parallel environments; defaults to
                ``n_train + n_test``.
        """
        super().__init__(output_dir)

        if n_envs is None:
            n_envs = n_train + n_test

        # assert n_obs_steps <= n_action_steps
        dataset_path = os.path.expanduser(dataset_path)
        robosuite_fps = 20
        steps_per_render = max(robosuite_fps // fps, 1)

        # Read env metadata from the dataset. Example (ToolHang):
        # {'env_name': 'ToolHang', 'type': 1, 'env_kwargs': {
        #     'control_freq': 20, 'use_object_obs': True,
        #     'controller_configs': {'type': 'OSC_POSE', 'control_delta': True, ...},
        #     'robots': ['Panda'], 'camera_heights': 240, 'camera_widths': 240,
        #     'camera_names': ['sideview', 'robot0_eye_in_hand'], ...}}
        env_meta = FileUtils.get_env_metadata_from_dataset(
            dataset_path)

        # disable object state observation
        env_meta['env_kwargs']['use_object_obs'] = False

        rotation_transformer = None
        if abs_action:
            env_meta['env_kwargs']['controller_configs']['control_delta'] = False
            rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d')

        def wrap_env(robomimic_env):
            # Wrapper stack shared by the real and the dummy environment:
            # image wrapper -> video recording -> multi-step execution.
            return MultiStepWrapper(
                VideoRecordingWrapper(
                    RobomimicImageWrapper(
                        env=robomimic_env,
                        shape_meta=shape_meta,
                        init_state=None,
                        render_obs_key=render_obs_key
                    ),
                    video_recoder=VideoRecorder.create_h264(
                        fps=fps,
                        codec='h264',
                        input_pix_fmt='rgb24',
                        crf=crf,
                        thread_type='FRAME',
                        thread_count=1
                    ),
                    file_path=None,
                    steps_per_render=steps_per_render
                ),
                n_obs_steps=n_obs_steps,
                n_action_steps=n_action_steps,
                max_episode_steps=max_steps
            )

        def env_fn():
            robomimic_env = create_env(
                env_meta=env_meta,
                shape_meta=shape_meta
            )
            # Robosuite's hard reset causes excessive memory consumption.
            # Disabled to run more envs.
            # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
            robomimic_env.env.hard_reset = False
            return wrap_env(robomimic_env)

        # For each process the OpenGL context can only be initialized once
        # Since AsyncVectorEnv uses fork to create worker process,
        # a separate env_fn that does not create OpenGL context (enable_render=False)
        # is needed to initialize spaces.
        def dummy_env_fn():
            robomimic_env = create_env(
                env_meta=env_meta,
                shape_meta=shape_meta,
                enable_render=False
            )
            return wrap_env(robomimic_env)

        env_fns = [env_fn] * n_envs
        env_seeds = list()
        env_prefixs = list()
        env_init_fn_dills = list()

        # train: each rollout restarts from the first recorded state of a demo
        with h5py.File(dataset_path, 'r') as f:
            for i in range(n_train):
                train_idx = train_start_idx + i
                enable_render = i < n_train_vis
                init_state = f[f'data/demo_{train_idx}/states'][0]

                # Loop values are bound as defaults so every serialized
                # function keeps its own init_state / enable_render.
                def init_fn(env, init_state=init_state,
                            enable_render=enable_render):
                    # setup rendering
                    # video_wrapper
                    assert isinstance(env.env, VideoRecordingWrapper)
                    env.env.video_recoder.stop()
                    env.env.file_path = None
                    if enable_render:
                        filename = pathlib.Path(output_dir).joinpath(
                            'media', wv.util.generate_id() + ".mp4")
                        filename.parent.mkdir(parents=False, exist_ok=True)
                        filename = str(filename)
                        env.env.file_path = filename

                    # switch to init_state reset
                    assert isinstance(env.env.env, RobomimicImageWrapper)
                    env.env.env.init_state = init_state

                env_seeds.append(train_idx)
                env_prefixs.append('train/')
                env_init_fn_dills.append(dill.dumps(init_fn))

        # test: each rollout starts from a seeded reset
        for i in range(n_test):
            seed = test_start_seed + i
            enable_render = i < n_test_vis

            def init_fn(env, seed=seed,
                        enable_render=enable_render):
                # setup rendering
                # video_wrapper
                assert isinstance(env.env, VideoRecordingWrapper)
                env.env.video_recoder.stop()
                env.env.file_path = None
                if enable_render:
                    filename = pathlib.Path(output_dir).joinpath(
                        'media', wv.util.generate_id() + ".mp4")
                    filename.parent.mkdir(parents=False, exist_ok=True)
                    filename = str(filename)
                    env.env.file_path = filename

                # switch to seed reset
                assert isinstance(env.env.env, RobomimicImageWrapper)
                env.env.env.init_state = None
                env.seed(seed)

            env_seeds.append(seed)
            env_prefixs.append('test/')
            env_init_fn_dills.append(dill.dumps(init_fn))

        env = AsyncVectorEnv(env_fns, dummy_env_fn=dummy_env_fn)
        # env = SyncVectorEnv(env_fns)

        self.env_meta = env_meta
        self.env = env
        self.env_fns = env_fns
        self.env_seeds = env_seeds
        self.env_prefixs = env_prefixs
        self.env_init_fn_dills = env_init_fn_dills
        self.fps = fps
        self.crf = crf
        self.n_obs_steps = n_obs_steps
        self.n_action_steps = n_action_steps
        self.past_action = past_action
        self.max_steps = max_steps
        self.rotation_transformer = rotation_transformer
        self.abs_action = abs_action
        self.tqdm_interval_sec = tqdm_interval_sec

    def run(self, policy: BaseImagePolicy):
        """Roll out ``policy`` on every configured initial condition and log results.

        The serialized init functions are processed in chunks of ``n_envs``.
        For each chunk the environments are re-initialized and reset, and the
        policy is queried until every environment reports ``done``. Only the
        last ``policy.n_obs_steps`` observation steps are passed to the policy.

        Args:
            policy: policy exposing ``device``, ``n_obs_steps``, ``reset()``
                and ``predict_action(obs_dict)`` returning a dict with an
                ``'action'`` tensor.

        Returns:
            dict with, per rollout, ``<prefix>sim_max_reward_<seed>`` and
            (when a video was recorded) ``<prefix>sim_video_<seed>``, plus
            ``<prefix>mean_score`` per prefix.

        Raises:
            RuntimeError: if the policy emits a NaN or Inf action.
        """
        device = policy.device
        env = self.env

        # plan for rollout
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # allocate data
        all_video_paths = [None] * n_inits
        all_rewards = [None] * n_inits

        for chunk_idx in range(n_chunks):
            start = chunk_idx * n_envs
            end = min(n_inits, start + n_envs)
            this_global_slice = slice(start, end)
            this_n_active_envs = end - start
            this_local_slice = slice(0, this_n_active_envs)

            # Slicing copies the list, so padding below does not touch
            # self.env_init_fn_dills.
            this_init_fns = self.env_init_fn_dills[this_global_slice]
            n_diff = n_envs - len(this_init_fns)
            if n_diff > 0:
                # Pad the last chunk so every env receives an init function;
                # results of the padded envs are discarded via this_local_slice.
                this_init_fns.extend([self.env_init_fn_dills[0]] * n_diff)
            assert len(this_init_fns) == n_envs

            # init envs
            env.call_each('run_dill_function',
                          args_list=[(x,) for x in this_init_fns])

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            env_name = self.env_meta['env_name']
            # The context manager closes the bar even if the rollout raises.
            with tqdm.tqdm(total=self.max_steps,
                           desc=f"Eval {env_name}Image {chunk_idx + 1}/{n_chunks}",
                           leave=False,
                           mininterval=self.tqdm_interval_sec) as pbar:
                done = False
                while not done:
                    # create obs dict
                    np_obs_dict = dict(obs)
                    if self.past_action and (past_action is not None):
                        # TODO: not tested
                        np_obs_dict['past_action'] = past_action[
                            :, -(self.n_obs_steps - 1):].astype(np.float32)

                    # device transfer
                    obs_dict = dict_apply(
                        np_obs_dict,
                        lambda x: torch.from_numpy(x).to(device=device))

                    # Keep only the most recent observation steps the policy
                    # consumes (the env may provide more than it needs).
                    input_obs_dict = {
                        key: value[:, -policy.n_obs_steps:]
                        for key, value in obs_dict.items()
                    }

                    # run policy
                    with torch.no_grad():
                        action_dict = policy.predict_action(input_obs_dict)

                    # device_transfer
                    np_action_dict = dict_apply(
                        action_dict,
                        lambda x: x.detach().to('cpu').numpy())

                    action = np_action_dict['action']
                    if not np.all(np.isfinite(action)):
                        print(action)
                        raise RuntimeError("Nan or Inf action")

                    # step env
                    env_action = action
                    if self.abs_action:
                        env_action = self.undo_transform_action(action)

                    obs, _, done, _ = env.step(env_action)
                    done = np.all(done)
                    past_action = action

                    # update pbar
                    pbar.update(action.shape[1])

            # collect data for this round
            all_video_paths[this_global_slice] = env.render()[this_local_slice]
            all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
        # clear out video buffer
        _ = env.reset()

        # log
        max_rewards = collections.defaultdict(list)
        log_data = dict()
        # results reported in the paper are generated using the commented out line below
        # which will only report and average metrics from first n_envs initial condition and seeds
        # fortunately this won't invalidate our conclusion since
        # 1. This bug only affects the variance of metrics, not their mean
        # 2. All baseline methods are evaluated using the same code
        # to completely reproduce reported numbers, uncomment this line:
        # for i in range(len(self.env_fns)):
        # and comment out this line
        for i in range(n_inits):
            seed = self.env_seeds[i]
            prefix = self.env_prefixs[i]
            max_reward = np.max(all_rewards[i])
            max_rewards[prefix].append(max_reward)
            log_data[prefix + f'sim_max_reward_{seed}'] = max_reward

            # visualize sim
            video_path = all_video_paths[i]
            if video_path is not None:
                log_data[prefix + f'sim_video_{seed}'] = wandb.Video(video_path)

        # log aggregate metrics
        for prefix, rewards in max_rewards.items():
            log_data[prefix + 'mean_score'] = np.mean(rewards)

        return log_data
    def onestep_run(self, policy: BaseImagePolicy):
        """Roll out ``policy`` while executing a single predicted action per query.

        Same chunked evaluation as a regular rollout, except that after each
        policy query only ``action_pred[:, 1:2]`` is sent to the environments,
        so the policy is re-queried after every executed action. The full
        observation dict is passed to the policy unsliced.

        Args:
            policy: policy exposing ``device``, ``reset()`` and
                ``predict_action(obs_dict)`` returning a dict with ``'action'``
                and ``'action_pred'`` tensors.

        Returns:
            dict with, per rollout, ``<prefix>sim_max_reward_<seed>`` and
            (when a video was recorded) ``<prefix>sim_video_<seed>``, plus
            ``<prefix>mean_score`` per prefix.

        Raises:
            RuntimeError: if the policy's ``'action'`` output contains NaN or Inf.
        """
        device = policy.device
        env = self.env

        # plan for rollout
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # allocate data
        all_video_paths = [None] * n_inits
        all_rewards = [None] * n_inits

        for chunk_idx in range(n_chunks):
            start = chunk_idx * n_envs
            end = min(n_inits, start + n_envs)
            this_global_slice = slice(start, end)
            this_n_active_envs = end - start
            this_local_slice = slice(0, this_n_active_envs)

            # Slicing copies the list, so padding below does not touch
            # self.env_init_fn_dills.
            this_init_fns = self.env_init_fn_dills[this_global_slice]
            n_diff = n_envs - len(this_init_fns)
            if n_diff > 0:
                # Pad the last chunk so every env receives an init function;
                # results of the padded envs are discarded via this_local_slice.
                this_init_fns.extend([self.env_init_fn_dills[0]] * n_diff)
            assert len(this_init_fns) == n_envs

            # init envs
            env.call_each('run_dill_function',
                          args_list=[(x,) for x in this_init_fns])

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            env_name = self.env_meta['env_name']
            # The context manager closes the bar even if the rollout raises.
            with tqdm.tqdm(total=self.max_steps,
                           desc=f"Eval {env_name}Image {chunk_idx + 1}/{n_chunks}",
                           leave=False,
                           mininterval=self.tqdm_interval_sec) as pbar:
                done = False
                while not done:
                    # create obs dict
                    np_obs_dict = dict(obs)
                    if self.past_action and (past_action is not None):
                        # TODO: not tested
                        np_obs_dict['past_action'] = past_action[
                            :, -(self.n_obs_steps - 1):].astype(np.float32)

                    # device transfer
                    obs_dict = dict_apply(
                        np_obs_dict,
                        lambda x: torch.from_numpy(x).to(device=device))

                    # run policy
                    with torch.no_grad():
                        action_dict = policy.predict_action(obs_dict)

                    # device_transfer
                    np_action_dict = dict_apply(
                        action_dict,
                        lambda x: x.detach().to('cpu').numpy())

                    action = np_action_dict['action']
                    if not np.all(np.isfinite(action)):
                        print(action)
                        raise RuntimeError("Nan or Inf action")

                    # step env with a single action taken from the predicted
                    # sequence.
                    # NOTE(review): index 1 is hard-coded; presumably it is the
                    # first executable step for n_obs_steps == 2 - confirm
                    # before using other observation horizons.
                    env_action = np_action_dict['action_pred'][:, 1:2]

                    if self.abs_action:
                        env_action = self.undo_transform_action(env_action)

                    obs, _, done, _ = env.step(env_action)
                    done = np.all(done)
                    past_action = action

                    # Advance by the number of actions actually executed (one),
                    # not by the length of the full predicted chunk; otherwise
                    # the bar overshoots total=max_steps.
                    pbar.update(env_action.shape[1])

            # collect data for this round
            all_video_paths[this_global_slice] = env.render()[this_local_slice]
            all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
        # clear out video buffer
        _ = env.reset()

        # log
        max_rewards = collections.defaultdict(list)
        log_data = dict()
        # results reported in the paper are generated using the commented out line below
        # which will only report and average metrics from first n_envs initial condition and seeds
        # fortunately this won't invalidate our conclusion since
        # 1. This bug only affects the variance of metrics, not their mean
        # 2. All baseline methods are evaluated using the same code
        # to completely reproduce reported numbers, uncomment this line:
        # for i in range(len(self.env_fns)):
        # and comment out this line
        for i in range(n_inits):
            seed = self.env_seeds[i]
            prefix = self.env_prefixs[i]
            max_reward = np.max(all_rewards[i])
            max_rewards[prefix].append(max_reward)
            log_data[prefix + f'sim_max_reward_{seed}'] = max_reward

            # visualize sim
            video_path = all_video_paths[i]
            if video_path is not None:
                log_data[prefix + f'sim_video_{seed}'] = wandb.Video(video_path)

        # log aggregate metrics
        for prefix, rewards in max_rewards.items():
            log_data[prefix + 'mean_score'] = np.mean(rewards)

        return log_data

    # def aug_run_before(self, s_policy,policy: BaseImagePolicy):
    #     device = policy.device
    #     dtype = policy.dtype
    #     env = self.env
    #
    #     # plan for rollout
    #     n_envs = len(self.env_fns)
    #     n_inits = len(self.env_init_fn_dills)
    #     n_chunks = math.ceil(n_inits / n_envs)
    #
    #     # allocate data
    #     all_video_paths = [None] * n_inits
    #     all_rewards = [None] * n_inits
    #     ori_mse_list = []
    #     guide_mse_list = []
    #     guide_ori_mse_list = []
    #     save_data = dict()
    #     save_data_list = [dict() for _ in range(n_chunks)]
    #     for chunk_idx in range(n_chunks):
    #         save_data = save_data_list[chunk_idx]
    #         start = chunk_idx * n_envs
    #         end = min(n_inits, start + n_envs)
    #         this_global_slice = slice(start, end)
    #         this_n_active_envs = end - start
    #         this_local_slice = slice(0, this_n_active_envs)
    #
    #         this_init_fns = self.env_init_fn_dills[this_global_slice]
    #         n_diff = n_envs - len(this_init_fns)
    #         if n_diff > 0:
    #             this_init_fns.extend([self.env_init_fn_dills[0]] * n_diff)
    #         assert len(this_init_fns) == n_envs
    #
    #         # init envs
    #         env.call_each('run_dill_function',
    #                       args_list=[(x,) for x in this_init_fns])
    #
    #         # start rollout
    #         obs = env.reset()
    #         # obs.keys()
    #         # odict_keys(['robot0_eef_pos', 'robot0_eef_quat', 'robot0_eye_in_hand_image', 'robot0_gripper_qpos', 'robot1_eef_pos',
    #         #             'robot1_eef_quat', 'robot1_eye_in_hand_image', 'robot1_gripper_qpos', 'shouldercamera0_image',
    #         #             'shouldercamera1_image'])
    #         # obs['robot0_eef_pos'].shape
    #         # (14, 2, 3)
    #         # obs['robot0_eef_pos'][0]
    #         # array([[0.00288391, -0.3695618, 0.99539196],
    #         #        [0.00288391, -0.3695618, 0.99539196]], dtype=float32)
    #         # (
    #         # obs['robot0_eef_pos'][1]
    #         # array([[0.02061568, -0.36231545, 1.0037795],
    #         #        [0.02061568, -0.36231545, 1.0037795]], dtype=float32)
    #         # save_data['robot0_eef_pos'].shape
    #         # (14, 1,2, 3)
    #
    #         past_action = None
    #         policy.reset()
    #
    #         env_name = self.env_meta['env_name']
    #         pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Image {chunk_idx + 1}/{n_chunks}",
    #                          leave=False, mininterval=self.tqdm_interval_sec)
    #
    #         done = False
    #
    #         env_number = obs['robot0_eef_pos'].shape[0]
    #
    #         num_tik = 0
    #         while not done:
    #             # create obs dict
    #             num_tik += 1
    #
    #             np_obs_dict = dict(obs)
    #             if self.past_action and (past_action is not None):
    #                 # TODO: not tested
    #                 np_obs_dict['past_action'] = past_action[
    #                                              :, -(self.n_obs_steps - 1):].astype(np.float32)
    #
    #             # device transfer
    #             obs_dict = dict_apply(np_obs_dict,
    #                                   lambda x: torch.from_numpy(x).to(
    #                                       device=device))
    #
    #             # robot0_eef_pos
    #             # shape: torch.Size([14, 2, 3])
    #             # robot0_eef_quat
    #             # shape: torch.Size([14, 2, 4])
    #             # robot0_eye_in_hand_image
    #             # shape: torch.Size([14, 2, 3, 84, 84])
    #             # robot0_gripper_qpos
    #             # shape: torch.Size([14, 2, 2])
    #             # robot1_eef_pos
    #             # shape: torch.Size([14, 2, 3])
    #             # robot1_eef_quat
    #             # shape: torch.Size([14, 2, 4])
    #             # robot1_eye_in_hand_image
    #             # shape: torch.Size([14, 2, 3, 84, 84])
    #             # robot1_gripper_qpos
    #             # shape: torch.Size([14, 2, 2])
    #             # shouldercamera0_image
    #             # shape: torch.Size([14, 2, 3, 84, 84])
    #             # shouldercamera1_image
    #             # shape: torch.Size([14, 2, 3, 84, 84])
    #
    #             if 'obs' not in save_data:
    #                 save_data['obs'] = dict()
    #                 for key in obs_dict.keys():
    #                     save_data['obs'][key] = obs_dict[key].unsqueeze(1).cpu()
    #             else:
    #                 # for key in save_data['obs'].keys():
    #                 #     print(key)
    #                 for key in obs_dict.keys():
    #                     save_data['obs'][key] = torch.cat((save_data['obs'][key], obs_dict[key].unsqueeze(1).cpu()), dim=1)
    #             # for key in obs_dict.keys():
    #             #     print(key,'shape:',obs_dict[key].shape)
    #             # save_data['obs']['robot0_eye_in_hand_image'].shape
    #             # torch.Size([14, 2, 3, 84, 84])
    #
    #             # RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [14, 2, 3, 76, 76]
    #
    #             # run policy
    #             # with torch.no_grad():
    #             #     action_dict = policy.predict_action(obs_dict)
    #             # device_transfer
    #             # key_list=list(obs_dict.keys())
    #             # runner_obs_step=obs_dict[key_list[0]].shape[1]
    #
    #             # if runner_obs_step>policy.n_obs_steps:
    #             #     obs_dict_policy=dict()
    #             #     for key in obs_dict.keys():
    #             #         obs_dict_policy[key]=obs_dict[key][:,-policy.n_obs_steps:]
    #             # else:
    #             #     obs_dict_policy=obs_dict
    #             with torch.no_grad():
    #                 s_action_dict = s_policy.predict_action(obs_dict)
    #                 action_dict = policy.predict_action_mse_guide(obs_dict,s_action_dict['action'])
    #                 # ori_mse = torch.mean((-s_action_dict['action']+t_ori_action_dict['action'])**2)
    #                 # ori_mse_list.append(ori_mse.item())
    #                 # guide_mse = torch.mean((-s_action_dict['action']+action_dict['action'])**2)
    #                 # guide_mse_list.append(guide_mse.item())
    #                 # guide_ori_mse = torch.mean((-t_ori_action_dict['action'] + action_dict['action']) ** 2)
    #                 # guide_ori_mse_list.append(guide_ori_mse.item())
    #
    #
    #
    #             # step env
    #
    #             if 'action' not in save_data:
    #                 save_data['action'] = action_dict['action_pred'][:,:s_policy.horizon].unsqueeze(1).cpu()
    #             else:
    #                 save_data['action'] = torch.cat((save_data['action'], action_dict['action_pred'][:,:s_policy.horizon].unsqueeze(1).cpu()), dim=1)
    #
    #
    #             np_action_dict = dict_apply(action_dict,
    #                                         lambda x: x.detach().to('cpu').numpy())
    #
    #             action = np_action_dict['action']
    #             if not np.all(np.isfinite(action)):
    #                 print(action)
    #                 raise RuntimeError("Nan or Inf action")
    #
    #             env_action = np_action_dict['action_pred'][:,1:2]
    #             if self.abs_action:
    #                 env_action = self.undo_transform_action(env_action)
    #                 # env_action_pred = self.undo_transform_action(np_action_dict['action_pred'])
    #                 # env_action_pred = torch.from_numpy(env_action_pred).to(device=device)
    #             # action.shape
    #             # (14, 8, 20)
    #             # env_action.shape
    #             # (14, 8, 14)
    #
    #             obs, reward, done, info = env.step(env_action)
    #             # for test
    #             # if num_tik == 2:
    #             #     reward = torch.ones_like(torch.from_numpy(reward))
    #             #     save_data['reward'] = torch.cat((save_data['reward'], reward.unsqueeze(1)), dim=1)
    #             #
    #             #     break
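    #             # TODO: the code after this break does not exit training directly but keeps executing, because there are still several parallel envs; (14, 8) is four (14, 2) blocks combined. The loop has to be broken one level further out; leaving it for now.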
    #             # torch.Size([14, 8])
    #             # save_data['reward'][0]
    #             # tensor([0., 1., 0., 1., 0., 1., 0., 1.], dtype=torch.float64)
    #             if torch.any(torch.from_numpy(reward)==1):
    #                 print('这是50')
    #             if 'reward' not in save_data:
    #                 save_data['reward'] = torch.from_numpy(reward).unsqueeze(1)
    #             else:
    #                 save_data['reward'] = torch.cat((save_data['reward'], torch.from_numpy(reward).unsqueeze(1)), dim=1)
    #                 #  14,1
    #             # for test
    #
    #             # obs['robot0_eef_pos'][0]
    #             # array([[0.01355493, -0.37373725, 0.9203873],
    #             #        [0.01482035, -0.37595108, 0.92129064]], dtype=float32)
    #             # obs['robot0_eef_pos'][1]
    #             # array([[0.14259064, -0.55685365, 1.1557218],
    #             #        [0.13804738, -0.5758146, 1.1942073]], dtype=float32)
    #
    #             # test for env.step()
    #             # obs = env.reset()
    #             # obs, reward, done, info = env.step(env_action[:,:2])
    #             # obs['robot0_eef_pos'][0]
    #             # array([[-0.00370188, -0.37130702, 0.98732823],
    #             #        [0.00369851, -0.3710973, 0.96389383]], dtype=float32)
    #             # obs['robot0_eef_pos'][1]
    #             # array([[0.0297156, -0.37939224, 1.0014757],
    #             #        [0.04958566, -0.41708672, 1.0024276]], dtype=float32)
    #             # obs = env.reset()
    #             # obs, reward, done, info = env.step(env_action[:,:1])
    #             # obs['robot0_eef_pos'][0]
    #             # array([[0.00288391, -0.3695618, 0.99539196],
    #             #        [-0.00334014, -0.37123954, 0.98799133]], dtype=float32)
    #             # obs['robot0_eef_pos'][1]
    #             # array([[0.02061568, -0.36231545, 1.0037795],
    #             #        [0.02904901, -0.37806112, 1.0016063]], dtype=float32)
    #
    #             # reward.shape
    #             # (14,)
    #             # done.shape
    #             # (14,)
    #
    #             done = np.all(done)
    #             past_action = action
    #
    #             # update pbar
    #             pbar.update(env_action.shape[1])
    #         pbar.close()
    #
    #         # collect data for this round
    #         # all_video_paths[this_global_slice] = env.render()[this_local_slice]
    #         all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
    #         save_data_list[chunk_idx] = save_data
    #     # clear out video buffer
    #     _ = env.reset()
    #     # print(f"Ori MSE: {np.mean(ori_mse_list)}")
    #     # print(f"Guide MSE: {np.mean(guide_mse_list)}")
    #     # print(f"Guide Ori MSE: {np.mean(guide_ori_mse_list)}")
    #     # log
    #     max_rewards = collections.defaultdict(list)
    #     log_data = dict()
    #     # results reported in the paper are generated using the commented out line below
    #     # which will only report and average metrics from first n_envs initial condition and seeds
    #     # fortunately this won't invalidate our conclusion since
    #     # 1. This bug only affects the variance of metrics, not their mean
    #     # 2. All baseline methods are evaluated using the same code
    #     # to completely reproduce reported numbers, uncomment this line:
    #     # for i in range(len(self.env_fns)):
    #     # and comment out this line
    #     for i in range(n_inits):
    #         seed = self.env_seeds[i]
    #         prefix = self.env_prefixs[i]
    #         max_reward = np.max(all_rewards[i])
    #         max_rewards[prefix].append(max_reward)
    #         log_data[prefix + f'sim_max_reward_{seed}'] = max_reward
    #
    #         # visualize sim
    #         video_path = all_video_paths[i]
    #         if video_path is not None:
    #             sim_video = wandb.Video(video_path)
    #             log_data[prefix + f'sim_video_{seed}'] = sim_video
    #
    #     # log aggregate metrics
    #     for prefix, value in max_rewards.items():
    #         name = prefix + 'mean_score'
    #         value = np.mean(value)
    #         log_data[name] = value
    #     # Part of the data is flattened and added to replay_buffer; the rest is added to self.data.
    #
    #     # if 'reward' not in save_data:
    #     #     save_data['reward'] = torch.from_numpy(reward).unsqueeze(1)
    #     # else:
    #     #     save_data['reward'] = torch.cat((save_data['reward'], torch.from_numpy(reward).unsqueeze(1)), dim=1)
    #
    #     traj_list = []
    #     # save_data['obs']['robot0_eye_in_hand_image'].shape
    #     # torch.Size([14, 8, 2, 3, 84, 84])
    #     for j in range(n_chunks):
    #         # save_data_list = [dict() for _ in range(n_chunks)]
    #         save_data = save_data_list[j]
    #         for i in range(env_number):
    #             reward_tensor = save_data['reward'][i]  # traj_len
    #             if torch.any(reward_tensor == 1):
    #                 position = (reward_tensor == 1).nonzero(as_tuple=True)[0][0]
    #                 traj_dict = {'action': save_data['action'][i, :position + self.n_action_steps+8].numpy(), 'obs': {}}
    #                 for key in save_data['obs'].keys():
    #                     traj_dict['obs'][key] = save_data['obs'][key][i, :position + self.n_action_steps+8].numpy()
    #                 traj_list.append(traj_dict)
    #     return log_data, traj_list

    def dagger_run(self, policy: BaseImagePolicy, t_model, epoch_num):
        """Roll out a teacher/student action blend and collect rewarded trajectories.

        At every control step both ``t_model`` and ``policy`` predict an action
        from the current observations (each trimmed to its own ``n_obs_steps``
        most recent frames). The environments are stepped with
        ``beta * teacher + (1 - beta) * student`` where
        ``beta = 0.9 ** epoch_num``. The blended actions, a window of the
        observations returned by ``env.step`` and the rewards are recorded for
        every environment.

        Args:
            policy: Student policy. Supplies ``device``, ``reset()``,
                ``n_obs_steps`` and ``predict_action``.
            t_model: Teacher model. Supplies ``n_obs_steps`` and
                ``predict_action``.
            epoch_num: Exponent of the ``0.9`` decay that weights the teacher
                action.

        Returns:
            ``(log_data, traj_list)``. ``log_data`` is currently always an
            empty dict because metric/video logging is disabled. ``traj_list``
            holds one dict per environment whose recorded reward hit ``1``:
            ``{'actions': ndarray, 'obs': {key: ndarray}}``, truncated along
            the time axis to the first rewarded index plus
            ``self.n_action_steps + 8``.

        Raises:
            RuntimeError: If the blended action contains NaN or Inf.
        """
        device = policy.device
        env = self.env

        # Number of extra recorded steps kept after the first rewarded step
        # (on top of n_action_steps) when a trajectory is trimmed.
        success_tail_steps = 8

        # plan for rollout
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # Teacher weight for the action blend.
        # https://arxiv.org/pdf/1011.0686 Figure 4 set best beta=0.5
        expert_scale = 0.9 ** epoch_num

        # Window (along axis 1) of the post-step observation that is recorded.
        obs_lo = self.n_obs_steps - self.n_action_steps - 1
        obs_hi = self.n_obs_steps - 1

        # Gathered per chunk but not reported: the logging code that consumed
        # it is disabled, so log_data stays empty.
        all_rewards = [None] * n_inits
        chunk_records = []

        for chunk_idx in range(n_chunks):
            start = chunk_idx * n_envs
            end = min(n_inits, start + n_envs)
            this_global_slice = slice(start, end)
            this_local_slice = slice(0, end - start)

            # Slicing yields a new list, so padding it leaves
            # self.env_init_fn_dills untouched.
            this_init_fns = self.env_init_fn_dills[this_global_slice]
            n_diff = n_envs - len(this_init_fns)
            if n_diff > 0:
                # Pad a short final chunk by re-using the first init function.
                this_init_fns.extend([self.env_init_fn_dills[0]] * n_diff)
            assert len(this_init_fns) == n_envs

            # init envs
            env.call_each('run_dill_function',
                          args_list=[(x,) for x in this_init_fns])

            # start rollout
            obs = env.reset()
            past_action = None
            # NOTE(review): only the student is reset here; t_model.reset() is
            # never called - confirm the teacher keeps no per-episode state.
            policy.reset()

            env_name = self.env_meta['env_name']
            pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Image {chunk_idx + 1}/{n_chunks}",
                             leave=False, mininterval=self.tqdm_interval_sec)

            # Per-step pieces are collected in lists and concatenated once per
            # chunk; concatenating inside the loop re-copies the whole history
            # on every step.
            action_steps = []
            reward_steps = []
            obs_steps = collections.defaultdict(list)

            done = False
            try:
                while not done:
                    # create obs dict
                    np_obs_dict = dict(obs)
                    if self.past_action and (past_action is not None):
                        # TODO: not tested
                        np_obs_dict['past_action'] = past_action[
                            :, -(self.n_obs_steps - 1):].astype(np.float32)

                    # device transfer
                    obs_dict = dict_apply(np_obs_dict,
                                          lambda x: torch.from_numpy(x).to(
                                              device=device))

                    with torch.no_grad():
                        # Teacher and student may use different history lengths.
                        teacher_obs = {key: value[:, -t_model.n_obs_steps:]
                                       for key, value in obs_dict.items()}
                        student_obs = {key: value[:, -policy.n_obs_steps:]
                                       for key, value in obs_dict.items()}
                        # Teacher is queried before the student, as before.
                        action_t = t_model.predict_action(teacher_obs)['action']
                        action_s = policy.predict_action(student_obs)['action']

                    action = expert_scale * action_t + (1 - expert_scale) * action_s
                    # The blended (pre undo-transform) action is what gets recorded.
                    action_steps.append(action.cpu())

                    action = action.detach().cpu().numpy()
                    if not np.all(np.isfinite(action)):
                        print(action_t)
                        raise RuntimeError("Nan or Inf action")

                    if self.abs_action:
                        action = self.undo_transform_action(action)

                    # step env
                    obs, reward, done, _ = env.step(action)

                    # Record the observation window straight from the NumPy
                    # arrays; clone() so the stored tensors do not alias
                    # buffers owned by the environment.
                    for key, value in dict(obs).items():
                        obs_steps[key].append(
                            torch.from_numpy(value)[:, obs_lo:obs_hi].clone())

                    # One reward per env per step, repeated n_action_steps
                    # times (presumably to share the actions' time axis).
                    reward_steps.append(
                        torch.from_numpy(reward).unsqueeze(1).repeat(1, self.n_action_steps))

                    done = np.all(done)
                    past_action = action

                    # update pbar
                    pbar.update(action.shape[1])
            finally:
                pbar.close()

            # collect data for this round
            all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
            chunk_records.append({
                'action': torch.cat(action_steps, dim=1),
                'reward': torch.cat(reward_steps, dim=1),
                'obs': {key: torch.cat(pieces, dim=1)
                        for key, pieces in obs_steps.items()},
            })

        # clear out video buffer
        _ = env.reset()

        # Metric and video logging is disabled; callers get an empty dict.
        log_data = dict()

        # Keep only environments that reached reward == 1, trimmed shortly
        # after the first rewarded step.
        # NOTE(review): padded environments of a short final chunk are scanned
        # too, so init 0 may contribute extra trajectories - confirm intended.
        traj_list = []
        for record in chunk_records:
            for i, reward_tensor in enumerate(record['reward']):
                hits = (reward_tensor == 1).nonzero(as_tuple=True)[0]
                if hits.numel() == 0:
                    continue
                stop = int(hits[0]) + self.n_action_steps + success_tail_steps
                traj_list.append({
                    'actions': record['action'][i, :stop].numpy(),
                    'obs': {key: value[i, :stop].numpy()
                            for key, value in record['obs'].items()},
                })

        return log_data, traj_list
    def augwog_run(self, policy: BaseImagePolicy):
        """Roll out ``policy`` alone and collect rewarded trajectories.

        The environments are stepped with the action predicted by ``policy``
        from the ``policy.n_obs_steps`` most recent observation frames. The
        predicted actions, a window of the observations returned by
        ``env.step`` and the rewards are recorded for every environment.

        Args:
            policy: Policy to roll out. Supplies ``device``, ``reset()``,
                ``n_obs_steps`` and ``predict_action``.

        Returns:
            ``(log_data, traj_list)``. ``log_data`` is currently always an
            empty dict because metric/video logging is disabled. ``traj_list``
            holds one dict per environment whose recorded reward hit ``1``:
            ``{'actions': ndarray, 'obs': {key: ndarray}}``, truncated along
            the time axis to the first rewarded index plus
            ``self.n_action_steps + 8``.

        Raises:
            RuntimeError: If the predicted action contains NaN or Inf.
        """
        device = policy.device
        env = self.env

        # Number of extra recorded steps kept after the first rewarded step
        # (on top of n_action_steps) when a trajectory is trimmed.
        success_tail_steps = 8

        # plan for rollout
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # Window (along axis 1) of the post-step observation that is recorded.
        obs_lo = self.n_obs_steps - self.n_action_steps - 1
        obs_hi = self.n_obs_steps - 1

        # Gathered per chunk but not reported: the logging code that consumed
        # it is disabled, so log_data stays empty.
        all_rewards = [None] * n_inits
        chunk_records = []

        for chunk_idx in range(n_chunks):
            start = chunk_idx * n_envs
            end = min(n_inits, start + n_envs)
            this_global_slice = slice(start, end)
            this_local_slice = slice(0, end - start)

            # Slicing yields a new list, so padding it leaves
            # self.env_init_fn_dills untouched.
            this_init_fns = self.env_init_fn_dills[this_global_slice]
            n_diff = n_envs - len(this_init_fns)
            if n_diff > 0:
                # Pad a short final chunk by re-using the first init function.
                this_init_fns.extend([self.env_init_fn_dills[0]] * n_diff)
            assert len(this_init_fns) == n_envs

            # init envs
            env.call_each('run_dill_function',
                          args_list=[(x,) for x in this_init_fns])

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            env_name = self.env_meta['env_name']
            pbar = tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Image {chunk_idx + 1}/{n_chunks}",
                             leave=False, mininterval=self.tqdm_interval_sec)

            # Per-step pieces are collected in lists and concatenated once per
            # chunk; concatenating inside the loop re-copies the whole history
            # on every step.
            action_steps = []
            reward_steps = []
            obs_steps = collections.defaultdict(list)

            done = False
            try:
                while not done:
                    # create obs dict
                    np_obs_dict = dict(obs)
                    if self.past_action and (past_action is not None):
                        # TODO: not tested
                        np_obs_dict['past_action'] = past_action[
                            :, -(self.n_obs_steps - 1):].astype(np.float32)

                    # device transfer
                    obs_dict = dict_apply(np_obs_dict,
                                          lambda x: torch.from_numpy(x).to(
                                              device=device))

                    # run policy on its own history length
                    with torch.no_grad():
                        policy_obs = {key: value[:, -policy.n_obs_steps:]
                                      for key, value in obs_dict.items()}
                        action_dict = policy.predict_action(policy_obs)

                    predicted = action_dict['action']
                    # The raw (pre undo-transform) prediction is what gets recorded.
                    action_steps.append(predicted.cpu())

                    # Only the 'action' entry is needed for stepping, so the
                    # other policy outputs are not converted.
                    action = predicted.detach().to('cpu').numpy()
                    if not np.all(np.isfinite(action)):
                        print(action)
                        raise RuntimeError("Nan or Inf action")

                    if self.abs_action:
                        action = self.undo_transform_action(action)

                    # step env
                    obs, reward, done, _ = env.step(action)

                    # Record the observation window straight from the NumPy
                    # arrays; clone() so the stored tensors do not alias
                    # buffers owned by the environment.
                    for key, value in dict(obs).items():
                        obs_steps[key].append(
                            torch.from_numpy(value)[:, obs_lo:obs_hi].clone())

                    # One reward per env per step, repeated n_action_steps
                    # times (presumably to share the actions' time axis).
                    reward_steps.append(
                        torch.from_numpy(reward).unsqueeze(1).repeat(1, self.n_action_steps))

                    done = np.all(done)
                    past_action = action

                    # update pbar
                    pbar.update(action.shape[1])
            finally:
                pbar.close()

            # collect data for this round
            all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
            chunk_records.append({
                'action': torch.cat(action_steps, dim=1),
                'reward': torch.cat(reward_steps, dim=1),
                'obs': {key: torch.cat(pieces, dim=1)
                        for key, pieces in obs_steps.items()},
            })

        # clear out video buffer
        _ = env.reset()

        # Metric and video logging is disabled; callers get an empty dict.
        log_data = dict()

        # Keep only environments that reached reward == 1, trimmed shortly
        # after the first rewarded step.
        # NOTE(review): padded environments of a short final chunk are scanned
        # too, so init 0 may contribute extra trajectories - confirm intended.
        traj_list = []
        for record in chunk_records:
            for i, reward_tensor in enumerate(record['reward']):
                hits = (reward_tensor == 1).nonzero(as_tuple=True)[0]
                if hits.numel() == 0:
                    continue
                stop = int(hits[0]) + self.n_action_steps + success_tail_steps
                traj_list.append({
                    'actions': record['action'][i, :stop].numpy(),
                    'obs': {key: value[i, :stop].numpy()
                            for key, value in record['obs'].items()},
                })

        return log_data, traj_list
    def aug_run(self, s_policy, policy: BaseImagePolicy, current_epoch):
        """Roll out the guided policy and harvest successful trajectories.

        The initial conditions in ``self.env_init_fn_dills`` are processed in
        chunks of ``len(self.env_fns)`` environments. At every environment
        step ``s_policy`` predicts an action from the most recent
        ``s_policy.n_obs_steps`` observation frames, and ``policy`` produces
        the executed action via ``predict_action_mse_guide`` from the most
        recent ``policy.n_obs_steps`` frames, the ``s_policy`` action and
        ``current_epoch``. The executed actions, a window of the post-step
        observation frames and the step rewards are recorded per environment.

        Args:
            s_policy: Policy exposing ``n_obs_steps`` and ``predict_action``;
                the ``'action'`` entry of its output is used as guidance.
            policy: Policy exposing ``device``, ``n_obs_steps``, ``reset`` and
                ``predict_action_mse_guide``.
            current_epoch: Forwarded unchanged to
                ``policy.predict_action_mse_guide``.

        Returns:
            tuple: ``(log_data, traj_list)``. ``log_data`` is an empty dict
            (no metrics are logged by this method). ``traj_list`` contains one
            ``{'actions': ndarray, 'obs': {key: ndarray}}`` dict for every
            environment slot whose recorded reward reached 1, truncated
            shortly after the first rewarded step.

        Raises:
            RuntimeError: If the policy emits a NaN or infinite action.
        """
        device = policy.device
        env = self.env

        # plan for rollout
        n_envs = len(self.env_fns)
        n_inits = len(self.env_init_fn_dills)
        n_chunks = math.ceil(n_inits / n_envs)

        # Frames of each post-step observation that are recorded.
        # NOTE(review): presumably chosen so the recorded frames line up with
        # the executed action steps -- confirm against the wrapper's
        # n_obs_steps / n_action_steps configuration.
        obs_window = slice(self.n_obs_steps - self.n_action_steps - 1,
                           self.n_obs_steps - 1)

        # Not consumed below: this method returns an empty log_data.
        all_rewards = [None] * n_inits
        save_data_list = []

        for chunk_idx in range(n_chunks):
            start = chunk_idx * n_envs
            end = min(n_inits, start + n_envs)
            this_global_slice = slice(start, end)
            this_local_slice = slice(0, end - start)

            # Pad the last chunk with the first init so every env gets one.
            this_init_fns = self.env_init_fn_dills[this_global_slice]
            n_diff = n_envs - len(this_init_fns)
            if n_diff > 0:
                this_init_fns.extend([self.env_init_fn_dills[0]] * n_diff)
            assert len(this_init_fns) == n_envs

            # init envs
            env.call_each('run_dill_function',
                          args_list=[(x,) for x in this_init_fns])

            # start rollout
            obs = env.reset()
            past_action = None
            policy.reset()

            # Per-step records; concatenated once after the rollout instead of
            # re-copying the whole history with torch.cat on every step.
            action_chunks = []
            reward_chunks = []
            obs_chunks = collections.defaultdict(list)

            env_name = self.env_meta['env_name']
            # The context manager closes the bar even if the rollout raises.
            with tqdm.tqdm(total=self.max_steps, desc=f"Eval {env_name}Image {chunk_idx + 1}/{n_chunks}",
                           leave=False, mininterval=self.tqdm_interval_sec) as pbar:
                done = False
                while not done:
                    # create obs dict
                    np_obs_dict = dict(obs)
                    if self.past_action and (past_action is not None):
                        # TODO: not tested
                        np_obs_dict['past_action'] = past_action[
                            :, -(self.n_obs_steps - 1):].astype(np.float32)

                    # device transfer
                    obs_dict = dict_apply(
                        np_obs_dict,
                        lambda x: torch.from_numpy(x).to(device=device))

                    # run both policies; each only sees its own number of
                    # most recent observation frames
                    with torch.no_grad():
                        input_obs_dict_s = {
                            key: value[:, -s_policy.n_obs_steps:]
                            for key, value in obs_dict.items()
                        }
                        input_obs_dict_t = {
                            key: value[:, -policy.n_obs_steps:]
                            for key, value in obs_dict.items()
                        }
                        s_action_dict = s_policy.predict_action(input_obs_dict_s)
                        action_dict = policy.predict_action_mse_guide(
                            input_obs_dict_t, s_action_dict['action'], current_epoch)

                    # clone so the record never aliases the array that is
                    # handed to env.step below
                    action_chunks.append(
                        action_dict['action'].detach().cpu().clone())

                    np_action_dict = dict_apply(
                        action_dict,
                        lambda x: x.detach().to('cpu').numpy())
                    action = np_action_dict['action']

                    if not np.all(np.isfinite(action)):
                        print(action)
                        raise RuntimeError("Nan or Inf action")

                    if self.abs_action:
                        action = self.undo_transform_action(action)

                    # step env
                    obs, reward, done, info = env.step(action)

                    # Record straight from the numpy observations: no round
                    # trip through the policy device, and clone() decouples
                    # the record from buffers the env may reuse.
                    for key, frames in dict(obs).items():
                        obs_chunks[key].append(
                            torch.from_numpy(frames[:, obs_window]).clone())

                    # One reward per env step, repeated so it has
                    # n_action_steps entries along the time axis.
                    reward_chunks.append(
                        torch.from_numpy(reward).unsqueeze(1).repeat(
                            1, self.n_action_steps))

                    done = np.all(done)
                    past_action = action

                    # update pbar
                    pbar.update(action.shape[1])

            # collect data for this round
            all_rewards[this_global_slice] = env.call('get_attr', 'reward')[this_local_slice]
            save_data_list.append({
                'action': torch.cat(action_chunks, dim=1),
                'obs': {
                    key: torch.cat(chunks, dim=1)
                    for key, chunks in obs_chunks.items()
                },
                'reward': torch.cat(reward_chunks, dim=1),
            })

        # clear out video buffer
        _ = env.reset()

        log_data = dict()

        # Keep only env slots that reached reward 1, cut shortly after the
        # first rewarded step.
        # NOTE(review): every slot is scanned, including the padded slots of
        # a partially filled last chunk (which re-run the first init) --
        # confirm that the extra trajectories are intended.
        # NOTE(review): the "+ 8" margin is inherited from the original code;
        # its rationale is not visible here.
        traj_list = []
        for save_data in save_data_list:
            for i in range(save_data['reward'].shape[0]):
                reward_tensor = save_data['reward'][i]
                success_steps = (reward_tensor == 1).nonzero(as_tuple=True)[0]
                if success_steps.numel() == 0:
                    continue
                cutoff = int(success_steps[0]) + self.n_action_steps + 8
                traj_list.append({
                    'actions': save_data['action'][i, :cutoff].numpy(),
                    'obs': {
                        key: value[i, :cutoff].numpy()
                        for key, value in save_data['obs'].items()
                    },
                })

        return log_data, traj_list

    def undo_transform_action(self, action):
        """Map policy-space actions back to the environment's action layout.

        The last axis is read as ``[position (3), rotation, gripper (1)]``.
        The rotation part is passed through the module-level
        ``rotation_transformer.inverse`` and the pieces are concatenated
        again. An input whose last dimension is 20 is treated as two 10-dim
        arm actions and is returned with a last dimension of 14.

        Args:
            action: Array of actions; leading dimensions are preserved for
                dual-arm input.

        Returns:
            Array with the rotation part replaced by the inverse transform.
        """
        input_shape = action.shape
        two_arms = input_shape[-1] == 20
        if two_arms:
            # dual arm: handle each arm's 10-dim action on its own row
            action = action.reshape(-1, 2, 10)

        # Rotation occupies everything between the 3 position entries and
        # the single trailing gripper entry.
        rot_stop = 3 + (action.shape[-1] - 4)
        position = action[..., :3]
        rotation_in = action[..., 3:rot_stop]
        grip = action[..., [-1]]
        rotation_out = rotation_transformer.inverse(rotation_in)
        restored = np.concatenate((position, rotation_out, grip), axis=-1)

        if not two_arms:
            return restored
        # dual arm: merge the two 7-dim arm actions back into one vector
        return restored.reshape(*input_shape[:-1], 14)
