
from openpi.training import config as config_pi
from openpi.policies import policy_config
from openpi_client import image_tools
# from openpi.shared import download

import numpy as np


from accelerate import Accelerator
import torch
from diffusers import StableVideoDiffusionPipeline
from omegaconf import OmegaConf
import numpy as np
# import cv2
import torch
import torch.nn.functional as F
import torch.nn as nn
import einops
from accelerate import Accelerator
import datetime
import os
from accelerate.logging import get_logger
from tqdm.auto import tqdm
import wandb
import imageio
from video_models.pipeline import MaskStableVideoDiffusionPipeline
import json
from decord import VideoReader, cpu
import swanlab
import mediapy
import sys
from scipy.spatial.transform import Rotation as R

def get_tf_mat(i, dh):
    """Build the 4x4 homogeneous link transform for row ``i`` of a DH table.

    The matrix layout is the *modified* (Craig) Denavit-Hartenberg form:
    rotate by ``alpha`` about x, translate ``a`` along x, rotate by ``theta``
    about z, translate ``d`` along z.

    Args:
        i: index of the row to use.
        dh: indexable table whose rows are ``[a, d, alpha, theta]``; only the
            first four entries of the row are read.

    Returns:
        A ``(4, 4)`` float numpy array.
    """
    row = dh[i]
    a, d, alpha, theta = row[0], row[1], row[2], row[3]

    cos_t, sin_t = np.cos(theta), np.sin(theta)
    cos_a, sin_a = np.cos(alpha), np.sin(alpha)

    return np.array([
        [cos_t, -sin_t, 0, a],
        [sin_t * cos_a, cos_t * cos_a, -sin_a, -sin_a * d],
        [sin_t * sin_a, cos_t * sin_a, cos_a, cos_a * d],
        [0, 0, 0, 1],
    ])


def get_fk_solution(joint_angles, *, include_hand=False):
    """Forward kinematics of a 7-DoF arm from a modified-DH parameter table.

    The table values appear to match the published Franka Emika Panda
    modified-DH parameters. Rows 0-6 are the revolute joints, row 7 a fixed
    0.107 offset along z (the flange), and rows 8-9 a fixed -pi/4 rotation
    plus a 0.1034 offset (hand / TCP) — NOTE(review): confirm which frame the
    dataset's end-effector poses are expressed in.

    Args:
        joint_angles: sequence with at least 7 joint angles (radians); any
            extra entries (e.g. a gripper value) are ignored.
        include_hand: keyword-only. ``False`` (default, original behaviour)
            chains the first 8 transforms; ``True`` also chains the last two
            rows so the result includes the hand/TCP offset.

    Returns:
        The ``(4, 4)`` homogeneous transform from the base frame.
    """
    dh_params = [[0, 0.333, 0, joint_angles[0]],
                 [0, 0, -np.pi/2, joint_angles[1]],
                 [0, 0.316, np.pi/2, joint_angles[2]],
                 [0.0825, 0, np.pi/2, joint_angles[3]],
                 [-0.0825, 0.384, -np.pi/2, joint_angles[4]],
                 [0, 0, np.pi/2, joint_angles[5]],
                 [0.088, 0, np.pi/2, joint_angles[6]],
                 [0, 0.107, 0, 0],
                 [0, 0, 0, -np.pi/4],
                 [0.0, 0.1034, 0, 0]]

    # 7 joints + flange offset by default; the two hand rows only on request.
    num_transforms = len(dh_params) if include_hand else 7 + 1

    T = np.eye(4)
    for i in range(num_transforms):
        T = T @ get_tf_mat(i, dh_params)
    return T
    

class agent():
    """Evaluation harness that bundles a pi0-family policy, the FuseSVD world
    model and an optional dynamics model, all configured from ``Args``.

    Constructing an instance loads every checkpoint referenced by ``Args``.
    """

    def __init__(self):
        args = Args()
        self.args = args
        self.accelerator = Accelerator()
        self.device = self.accelerator.device

        # --- policy ----------------------------------------------------------
        # 'pi05' and 'pi0fast' both contain the substring 'pi0', so the more
        # specific names must be tested first.
        if 'pi05' in args.policy_type:
            config = config_pi.get_config("pi05_droid")
            checkpoint_dir = '/cephfs/shared/llm/openpi/openpi-assets-preview/checkpoints/pi05_droid'
        elif 'pi0fast' in args.policy_type:
            config = config_pi.get_config("pi0fast_droid")
            checkpoint_dir = '/cephfs/shared/llm/openpi/openpi-assets/checkpoints/pi0fast_droid'
        elif 'pi0' in args.policy_type:
            config = config_pi.get_config("pi0_droid")
            checkpoint_dir = '/cephfs/shared/llm/openpi/openpi-assets/checkpoints/pi0_droid'
        else:
            raise ValueError(f"Unknown policy type: {args.policy_type}")
        self.policy = policy_config.create_trained_policy(config, checkpoint_dir)

        # --- world model -----------------------------------------------------
        # FuseSVD is imported from a module resolved relative to the CWD.
        sys.path.append('./')
        from exp31_droid_framecond_s9 import FuseSVD
        self.model = FuseSVD(args)
        # Load on CPU: the module is moved to the target device right after,
        # so this avoids an extra full copy on whichever device the
        # checkpoint was saved from.
        self.model.load_state_dict(torch.load(args.val_model_path, map_location='cpu'))
        self.model.to(self.device)
        self.model.eval()
        print("load world model success")

        # Per-dimension bounds used by normalize_bound (presumably the 1st /
        # 99th percentiles of the states, judging by the key names).
        with open(args.data_stat_path, 'r') as f:
            data_stat = json.load(f)
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

        # --- dynamics model (optional) ---------------------------------------
        # Always define the attribute so callers can test it against None.
        self.dynamics_model = None
        if args.dynamics_model_path is not None:
            from output_dynamics.train2 import Dynamics
            # from output_dynamics.train3 import Dynamics
            self.dynamics_model = Dynamics(action_dim=7, action_num=15, hidden_size=512).to(self.device)
            self.dynamics_model.load_state_dict(torch.load(args.dynamics_model_path, map_location=self.device))
            self.dynamics_model.eval()  # inference only, like the world model

        # report cuda memory usage
        if torch.cuda.is_available():
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
            print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")

    def normalize_bound(
        self,
        data: np.ndarray,
        data_min: np.ndarray,
        data_max: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Affinely map ``data`` from ``[data_min, data_max]`` to ``[-1, 1]``.

        The result is clipped to ``[clip_min, clip_max]``; ``eps`` guards
        against division by zero when a dimension has ``data_min == data_max``.
        Inputs broadcast under normal numpy rules.
        """
        ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def get_traj_info(self, id, start_idx=0, steps=8, delta=0.0):
        """Load a window of a validation trajectory and encode its videos.

        Args:
            id: trajectory id; the annotation is read from
                ``{val_dataset_dir}/annotation/val/{id}.json``.
            start_idx: first frame index of the window.
            steps: number of frames to sample with stride ``args.skip_step``.
                Indices past the end of the trajectory are clamped to the
                last frame.
            delta: offset added to column 1 of the sampled states, ramped
                linearly over the first 15 sampled frames and then held.

        Returns:
            ``(action, joint_pos, video_dict, video_latent, instruction)``:
            the sampled ``states`` and ``joints`` rows, one frame array and
            one VAE-latent tensor per entry of ``anno['videos']``, and the
            first entry of ``anno['texts']``.
        """
        args = self.args
        val_dataset_dir = args.val_dataset_dir
        skip = args.skip_step
        traj_id = id  # avoid using the builtin-shadowing name in the body

        with open(f"{val_dataset_dir}/annotation/val/{traj_id}.json") as f:
            anno = json.load(f)
        try:
            length = len(anno['action'])
        except (KeyError, TypeError):
            # Annotations without per-step actions only record the video length.
            length = anno["video_length"]

        frames_ids = np.arange(start_idx, start_idx + steps * skip, skip)
        print(frames_ids)
        frames_ids = np.minimum(frames_ids, length - 1).astype(int)
        print("frames_ids", frames_ids)

        instruction = anno['texts'][0]
        action = np.array(anno['states'])[frames_ids]
        joint_pos = np.array(anno['joints'])[frames_ids]

        # Ramp the column-1 offset from 0 to `delta` over the first 15
        # frames, then hold it (same arithmetic as the per-row form).
        step_idx = np.arange(action.shape[0])
        action[:, 1] += np.where(step_idx < 15, step_idx * delta / 15, delta)

        video_dict = []
        video_latent = []
        for view in anno['videos']:
            video_path = f"{val_dataset_dir}/{view['video_path']}"
            frames = self._read_frames(video_path, length, frames_ids)
            video_dict.append(frames)
            video_latent.append(self._encode_frames(frames))

        return action, joint_pos, video_dict, video_latent, instruction

    @staticmethod
    def _read_frames(video_path, length, frames_ids):
        """Decode the first ``length`` frames of a video and return ``frames_ids`` of them."""
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        batch = vr.get_batch(range(length))
        try:
            frames = batch.asnumpy()
        except AttributeError:
            # Presumably a non-native decord bridge (e.g. torch) is active,
            # whose tensors expose .numpy() instead of .asnumpy().
            frames = batch.numpy()
        return frames[frames_ids]

    def _encode_frames(self, frames):
        """Encode a ``(T, H, W, C)`` frame array into scaled VAE latents on ``self.device``."""
        vae = self.model.pipeline.vae
        x = torch.from_numpy(frames).float().to(self.device)
        # HWC in [0, 255] -> CHW in [-1, 1]
        x = x.permute(0, 3, 1, 2).to(self.device) / 255.0 * 2 - 1
        batch_size = 32
        with torch.no_grad():
            latents = [
                vae.encode(x[i:i + batch_size]).latent_dist.sample().mul_(vae.config.scaling_factor)
                for i in range(0, len(x), batch_size)
            ]
        return torch.cat(latents, dim=0)

    def _decode_latents(self, latents):
        """Decode ``(B, F, C, h, w)`` latents into uint8 frames ``(B, F, H, W, C)`` on the CPU.

        Latents are decoded ``args.decode_chunk_size`` frames at a time.
        """
        vae = self.model.pipeline.vae
        chunk_size = self.args.decode_chunk_size
        bsz, frame_num = latents.shape[:2]
        flat = latents.flatten(0, 1)
        with torch.no_grad():
            decoded = []
            for start in range(0, flat.shape[0], chunk_size):
                chunk = flat[start:start + chunk_size] / vae.config.scaling_factor
                decoded.append(vae.decode(chunk, num_frames=chunk.shape[0]).sample)
            frames = torch.cat(decoded, dim=0)
        frames = frames.reshape(bsz, frame_num, *frames.shape[1:])
        frames = (frames / 2.0 + 0.5).clamp(0, 1) * 255
        return frames.to(torch.float32).detach().cpu().numpy().transpose(0, 1, 3, 4, 2).astype(np.uint8)

    def forward_wm(self, action_cond, video_latent_true, video_latent_cond, his_cond=None, text=None):
        """Predict one chunk with the world model and decode it next to ground truth.

        Args:
            action_cond: raw (un-normalised) states of shape
                ``(num_history + num_frames, action_dim)``; normalised here
                with ``normalize_bound``.
            video_latent_true: list of per-view ground-truth latent tensors
                for the same window; stacked along a new leading axis.
            video_latent_cond: conditioning latent whose trailing shape must
                be ``(4, 72, 40)``.
            his_cond: history latents forwarded to the pipeline as ``history``.
            text: optional instruction, encoded together with the actions.

        Returns:
            ``(videos_cat, true_video, videos, latents)``. ``true_video`` and
            ``videos`` are uint8 arrays ``(views, frames, H, W, C)``.
            ``videos_cat`` stacks ground truth above the prediction and puts
            the views side by side. ``latents`` are the predicted latents
            with the 3 views split out of the height axis.
        """
        args = self.args
        image_cond = video_latent_cond
        pipeline = self.model.pipeline

        # The world model is conditioned on normalised actions.
        action_cond = self.normalize_bound(
            action_cond, self.state_p01, self.state_p99, clip_min=-1, clip_max=1
        )
        action_cond = torch.tensor(action_cond).unsqueeze(0).to(self.device).float()  # (1, T, action_dim)
        assert image_cond.shape[1:] == (4, 72, 40)
        assert action_cond.shape[1:] == (args.num_frames + args.num_history, args.action_dim)

        # Pure inference: VAE decoding stays inside no_grad too, so autograd
        # never records the (large) decode graph.
        with torch.no_grad():
            if text is not None:
                text_token = self.model.action_encoder(
                    action_cond, text, self.model.tokenizer, self.model.text_encoder
                )
            else:
                text_token = self.model.action_encoder(action_cond)

            _, latents = MaskStableVideoDiffusionPipeline.__call__(
                pipeline,
                image=image_cond,
                text=text_token,
                width=320,
                height=int(72*8),
                num_frames=args.num_frames,
                history=his_cond,
                num_inference_steps=args.num_inference_steps,
                decode_chunk_size=args.decode_chunk_size,
                max_guidance_scale=args.guidance_scale,
                fps=args.fps,
                motion_bucket_id=args.motion_bucket_id,
                mask=None,
                output_type='latent',
                return_dict=False,
                frame_independent=True,
            )
            # The 3 camera views are stacked along the latent height axis;
            # move them into the batch axis.
            latents = einops.rearrange(latents, 'b f c (m h) (n w) -> (b m n) f c h w', m=3, n=1)

            true_video = self._decode_latents(torch.stack(video_latent_true, dim=0))
            videos = self._decode_latents(latents)

        # Ground truth above prediction (height axis), views side by side (width axis).
        videos_cat = np.concatenate([true_video, videos], axis=-3)
        videos_cat = np.concatenate(list(videos_cat), axis=-2).astype(np.uint8)

        return videos_cat, true_video, videos, latents


class Args:
    """Static configuration for the DROID world-model replay script.

    Every setting is a class attribute, so it can be read from the class
    itself (``Args.val_id``) or from an instance (``Args()``).
    Usage: CUDA_VISIBLE_DEVICES=0 python exp33_replay_demo.py
    """

    # --- pretrained backbones -------------------------------------------------
    # alternative: "/cephfs/cjyyj/code/video_robot_svd/output/svd/train_2025-05-08T01-09-10/checkpoint-320000"
    pretrained_model_path = "/cephfs/shared/llm/stable-video-diffusion-img2vid"
    clip_model_path = "/cephfs/shared/llm/clip-vit-base-patch32"

    # --- bookkeeping -----------------------------------------------------------
    debug = False
    output_dir = "output_unit_test/yay_robot_4img_cond"
    project_name = "unit_test_svd"
    tag = 'action_cond'
    action_dim = 7  # width of each action/state row fed to the world model

    # --- training hyper-parameters (not used by the replay loop itself) --------
    load_path = None
    learning_rate = 5e-6
    gradient_accumulation_steps = 1
    mixed_precision = 'fp16'
    train_batch_size = 12
    shuffle = True

    num_train_epochs = 100
    max_train_steps = 500000
    checkpointing_steps = 20000
    validation_steps = 2000
    max_grad_norm = 1.0

    # --- dataset ---------------------------------------------------------------
    dataset_names = 'droid_svd_v2'
    dataset_dir = "/cephfs/shared/droid_hf"
    data_root_path = "/cephfs/shared/droid_hf"  # '/cephfs/shared/droid_hf/opensource_robotdata'
    annotation_name = 'annotation'  # annotation dirname under dataset_dir path
    # dataset='xhand'
    # prob=[1.0]
    tie_weight = True
    normalize = True
    pre_encode = True
    num_workers = 4
    video_size = [256, 256]

    only_one_clip = True
    get_action = True

    # --- video diffusion sampling ----------------------------------------------
    clip_img_size = 224
    use_img_cond = False
    motion_bucket_id = 127
    fps = 7
    guidance_scale = 2.0  # other values tried: 7.5, 3.0
    num_inference_steps = 50
    decode_chunk_size = 7  # frames per VAE decode call
    width = 320
    height = 192
    validation_num = 32
    video_num = 3
    num_frames = 5  # predicted frames per world-model call
    num_history = 6  # history frames conditioning each call
    # sequence_length = 8
    data_json_path = 'video_dataset/xhand_mix_seq16_train.json'  # 'video_dataset/xhand_skip2_seq16_train.json'
    val_data_json_path = "video_dataset/xhand_mix_seq16_val.json"  # 'video_dataset/xhand_skip2_seq16_val.json'
    data_stat_path = 'exp_cfg/droid_svd_v2/stat.json'
    # alternative: '/cephfs/cjyyj/code/video_evaluation/output2/exp33_210_s11_2/checkpoint-20000.pt'
    val_model_path = '/cephfs/cjyyj/code/video_evaluation/output2/exp33_210_s11/checkpoint-10000.pt'
    pred_step = 5  # frames per interaction step
    skip_step = 1  # frame stride when sampling a trajectory

    # --- rollouts: one per entry of `deltas` -------------------------------------
    val_dataset_dir = '/cephfs/shared/droid_hf/droid_svd_v2'
    deltas = [0.0]  # offset ramped onto state column 1 (see agent.get_traj_info)
    val_id = ['18599'] * len(deltas)
    start_idx = [14] * len(deltas)
    instruction = [""] * len(val_id)

    interact_num = 8  # world-model calls per rollout

    task_name = 'replay_demo'
    gripper_max = 1.0

    # alternatives: None, '/cephfs/cjyyj/code/video_evaluation/output_dynamics/model2_epoch_9.pth',
    # 'output_dynamics/model3_epoch_6.pth', 'output_dynamics/model_epoch_1.pth'
    dynamics_model_path = '/cephfs/cjyyj/code/video_evaluation/output_dynamics/model2_15_9.pth'
    policy_type = 'pi05'  # one of 'pi05', 'pi0', 'pi0fast'

    text_cond = True  # condition the world model on the instruction text

        
if __name__ == "__main__":

    Agent = agent()
    interact_num = Agent.args.interact_num
    pred_step = Agent.args.pred_step  # important parameters
    num_history = Agent.args.num_history
    num_frames = Agent.args.num_frames

    # Per-rollout settings are consumed in lockstep with Args.val_id; fail
    # fast instead of raising IndexError halfway through the sweep.
    if min(len(Args.start_idx), len(Args.deltas)) < len(Args.val_id):
        raise ValueError("Args.start_idx and Args.deltas need at least one entry per Args.val_id")

    for val_id_i, start_idx_i, delta in zip(Args.val_id, Args.start_idx, Args.deltas):
        # get initial state, ground-truth eef poses / joints and video latents
        try:
            eef_gt, joint_pos_gt, video_dict, video_latents, instruction = Agent.get_traj_info(
                val_id_i, start_idx=start_idx_i, steps=int(pred_step*interact_num+8), delta=delta
            )
        except Exception as exc:
            # Best effort: a missing/broken trajectory must not abort the whole
            # sweep, but report why it was skipped instead of hiding it.
            print(f"Error in loading data for traj id {val_id_i}, continue to next")
            print(f"  reason: {exc!r}")
            continue
        text_i = instruction
        print("text_i:", instruction, "eef pose 0", eef_gt[0], "joint 0", joint_pos_gt[0])

        # t=0 for each episode
        video_to_save = []
        info_to_save = []  # used by the (currently disabled) policy-info logging below

        # Seed the history buffers with the first observation so the negative
        # history indices used in the loop are valid from the first step.
        first_latent = torch.cat([v[0] for v in video_latents], dim=1).unsqueeze(0)  # (1, 4, 72, 40)
        assert first_latent.shape == (1, 4, 72, 40), f"Expected first_latent shape (1, 4, 72, 40), got {first_latent.shape}"
        history_len = num_history * 4
        his_cond = [first_latent] * history_len  # each (1, 4, 72, 40)
        his_joint = [joint_pos_gt[0:1]] * history_len  # each (1, joint_dim)
        his_eef = [eef_gt[0:1]] * history_len  # each (1, state_dim)

        # interact loop
        for i in range(interact_num):
            # Consecutive windows overlap by one frame.
            start_id = int(i*(pred_step-1))
            end_id = start_id + pred_step
            video_latent_true = [v[start_id:end_id] for v in video_latents]

            # prepare input for policy
            joint_first = his_joint[-1][0]  # (8,)
            state_first = his_eef[-1][0]  # (7,)
            if i == 0:
                video_first = [v[0] for v in video_dict]
            else:
                video_first = [v[-1] for v in video_dict_pred]
            assert joint_first.shape == (8,), f"Expected joint_first shape (8,), got {joint_first.shape}"
            assert state_first.shape == (7,), f"Expected state_first shape (7,), got {state_first.shape}"

            # forward policy
            print("################ policy forward ####################")
            # policy_in_out, joint_pos, state_pos= Agent.forward_policy(video_first, state_first, joint_first, text=text_i, time_step=i)

            # state_pos = eef_gt[int(i*pred_step):int(i*pred_step+pred_step)]  # use gt state for each step

            state_pos = eef_gt[start_id:end_id]  # (pred_step, 7)
            # [0.6571552157402039, -0.23330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0]
            # [0.6629241108894348, -0.23121032118797302, 0.1249917671084404, 3.033341884613037, -0.07552512735128403, -0.5513433814048767, 0.0]
            # [0.6683378219604492, -0.23391945660114288, 0.10137390345335007, 3.064735174179077, -0.05264944210648537, -0.5547294616699219, 0.0]
            # [0.6717594265937805, -0.24730165302753448, 0.08032714575529099, 3.097148895263672, -0.04734605923295021, -0.550068199634552, 0.0]
            # [0.6743124723434448, -0.264811635017395, 0.06130369007587433, 3.118835210800171, -0.05084453523159027, -0.5484786629676819, 0.0]
            # [0.6779722571372986, -0.27722054719924927, 0.04469338804483414, 3.1374242305755615, -0.05760820955038071, -0.5545875430107117, 0.0]
            # [0.6827353239059448, -0.27674779295921326, 0.030309824272990227, -3.1302576065063477, -0.058622077107429504, -0.5516106486320496, 0.0]
            # [0.6880980134010315, -0.2685723602771759, 0.019759373739361763, -3.1161246299743652, -0.052770618349313736, -0.5417410731315613, 0.0]
            # [0.691747784614563, -0.262683629989624, 0.012840778566896915, -3.112227201461792, -0.04243526607751846, -0.5334827303886414, 0.08370043337345123]

            # if i == 0:
            #     state_pos = np.array([[0.6571552157402039, -0.23330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.21330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.19330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.17330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.15330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             ])
            # elif i == 1:
            #     state_pos = np.array([[0.6571552157402039, -0.15330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.13330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.11330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.09330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.14893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             ])
            # elif i == 2:
            #     state_pos = np.array([[0.6571552157402039, -0.07330219089984894, 0.16893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.18893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.20893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.22893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.24893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             ])
            # elif i == 3:
            #     state_pos = np.array([[0.6571552157402039, -0.07330219089984894, 0.24893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.26893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.28893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.30893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.32893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             ])
            # elif i == 4:
            #     state_pos = np.array([[0.6571552157402039, -0.07330219089984894, 0.32893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.0],
            #                             [0.6571552157402039, -0.07330219089984894, 0.32893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.2],
            #                             [0.6571552157402039, -0.07330219089984894, 0.32893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.4],
            #                             [0.6571552157402039, -0.07330219089984894, 0.32893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.6],
            #                             [0.6571552157402039, -0.07330219089984894, 0.32893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.9],
            #                             ])
            # elif i == 5:
            #     state_pos = np.array([[0.6571552157402039, -0.07330219089984894, 0.32893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.9],
            #                             [0.6571552157402039, -0.07330219089984894, 0.30893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.9],
            #                             [0.6571552157402039, -0.07330219089984894, 0.28893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.9],
            #                             [0.6571552157402039, -0.07330219089984894, 0.26893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.9],
            #                             [0.6571552157402039, -0.07330219089984894, 0.24893116056919098, 3.024440050125122, -0.10950495302677155, -0.5373510122299194, 0.9],
            #                             ])

            print("policy output eef pose", state_pos)  # output xyz and gripper for debug

            # forward world model
            print("################ world model forward ################")
            print(f'traj_id:{val_id_i}, interact step: {i}!!!!!!!!')

            # Offsets into the history buffers. Earlier variants tried:
            # [-15,-15,-15,-9,-6,-3] and [-15,-15,-8,-6,-4,-2].
            history_idx = [-10, -10, -8, -6, -4, -2]
            his_eef_input = np.concatenate([his_eef[idx] for idx in history_idx], axis=0)  # (len(history_idx), 7)
            action_input = np.concatenate([his_eef_input, state_pos], axis=0)
            his_cond_input = torch.cat([his_cond[idx] for idx in history_idx], dim=0).unsqueeze(0)
            video_latent_first = his_cond[-1]  # (1, 4, 72, 40)
            expected_action_shape = (len(history_idx) + pred_step, Agent.args.action_dim)
            expected_his_shape = (1, len(history_idx), 4, 72, 40)
            assert video_latent_first.shape == (1, 4, 72, 40), f"Expected video_latent_first shape (1, 4, 72, 40), got {video_latent_first.shape}"
            assert action_input.shape == expected_action_shape, f"Expected action_input shape {expected_action_shape}, got {action_input.shape}"
            assert his_cond_input.shape == expected_his_shape, f"Expected his_cond_input shape {expected_his_shape}, got {his_cond_input.shape}"

            videos_cat, true_videos, video_dict_pred, predict_latents = Agent.forward_wm(
                action_input, video_latent_true, video_latent_first,
                his_cond=his_cond_input, text=text_i if Agent.args.text_cond else None,
            )

            # his_joint.append(joint_pos[pred_step-1:pred_step])  # (1, 8)
            his_eef.append(state_pos[pred_step-1:pred_step])
            his_cond.append(torch.cat([v[pred_step-1] for v in predict_latents], dim=1).unsqueeze(0))  # (1, 4, 72, 40)
            if i == interact_num - 1:
                video_to_save.append(videos_cat)  # save all frames for the last step
            else:
                # The last frame is the first frame of the next step, so only
                # the first pred_step-1 frames are saved.
                video_to_save.append(videos_cat[:pred_step-1])

            # num = int(3*(pred_step-1))
            # policy_in_out = {key: value[:num] for key, value in policy_in_out.items()}  # only save the first pred_step-1 frames
            # info_to_save.append(policy_in_out)  # save policy output info

        if not video_to_save:
            # interact_num == 0: nothing was generated, np.concatenate would fail.
            print(f"No frames generated for traj id {val_id_i}, skip saving")
            continue
        video = np.concatenate(video_to_save, axis=0)

        # save rollout video and info with parameters
        task_name = Args.task_name
        wm_id = Agent.args.val_model_path.split('/')[-1].split('.')[0]
        text_id = text_i.replace(' ', '_').replace(',', '').replace('.', '').replace('\'', '').replace('\"', '')[:30]
        num_inference_steps = Agent.args.num_inference_steps
        guidance_scale = Agent.args.guidance_scale
        # path for save predicted videos
        videos_dir = '/'.join(Args.val_model_path.split('/')[:-1])
        uuid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename_video = f"{videos_dir}/{task_name}/video/time_{uuid}_traj_{val_id_i}_{start_idx_i}_{pred_step}_{wm_id}_{guidance_scale}_{num_inference_steps}_{text_id}.mp4"
        os.makedirs(os.path.dirname(filename_video), exist_ok=True)
        mediapy.write_video(filename_video, video, fps=5)
        print(f"Saving video to {filename_video}")
        print("##########################################################################")

        # save info
        # gather all the dict
        # info = {'success': 1, 'start_idx': 0, 'end_idx': video.shape[0]-1, 'instructions':text_i}
        # for key in info_to_save[0].keys():
        #     info[key] = []
        #     for i in range(len(info_to_save)):
        #         info[key]+=info_to_save[i][key].tolist()[:int(pred_step*3)]

        # save to json
        # filename_info = f"{videos_dir}/{task_name}/info/time_{uuid}_traj_{val_id_i}_{start_idx_i}_{pred_step}_{wm_id}_{guidance_scale}_{num_inference_steps}_{text_id}.json"
        # os.makedirs(os.path.dirname(filename_info), exist_ok=True)
        # with open(filename_info, 'w') as f:
        #     json.dump(info, f, indent=4)
