
from openpi.training import config as config_pi
from openpi.policies import policy_config
from openpi_client import image_tools
# from openpi.shared import download

import numpy as np


from accelerate import Accelerator
import torch
from diffusers import StableVideoDiffusionPipeline
from omegaconf import OmegaConf
import numpy as np
# import cv2
import torch
import torch.nn.functional as F
import torch.nn as nn
import einops
from accelerate import Accelerator
import datetime
import os
from accelerate.logging import get_logger
from tqdm.auto import tqdm
import wandb
import imageio
from video_models.pipeline import MaskStableVideoDiffusionPipeline
import json
from decord import VideoReader, cpu
import swanlab
import mediapy
import sys
from scipy.spatial.transform import Rotation as R

def get_tf_mat(i, dh):
    """Return the 4x4 homogeneous transform for row ``i`` of a DH table.

    Each row of ``dh`` is read as ``(a, d, alpha, theta)``. The matrix layout is
    the modified (Craig) Denavit-Hartenberg form
    ``Rot_x(alpha) @ Trans_x(a) @ Rot_z(theta) @ Trans_z(d)``.
    """
    row = dh[i]
    a, d, alpha, theta = row[0], row[1], row[2], row[3]

    ct, st = np.cos(theta), np.sin(theta)
    ca, sa = np.cos(alpha), np.sin(alpha)

    return np.array([
        [ct, -st, 0, a],
        [st * ca, ct * ca, -sa, -sa * d],
        [st * sa, ct * sa, ca, ca * d],
        [0, 0, 0, 1],
    ])


def get_fk_solution(joint_angles):
    """Forward kinematics of a 7-joint arm described by a modified-DH table.

    ``joint_angles`` provides theta for the first seven table rows. Only the
    first eight rows (the seven joints plus the fixed 0.107 offset row) are
    chained; the trailing ``-pi/4`` rotation row and ``0.1034`` offset row are
    defined but NOT applied.
    NOTE(review): the constants look like Franka Panda flange/hand offsets --
    confirm that stopping before the last two (tool) rows is intended.

    Returns the 4x4 homogeneous transform of the last chained frame relative
    to the base frame.
    """
    q = joint_angles
    dh_params = [
        [0, 0.333, 0, q[0]],
        [0, 0, -np.pi/2, q[1]],
        [0, 0.316, np.pi/2, q[2]],
        [0.0825, 0, np.pi/2, q[3]],
        [-0.0825, 0.384, -np.pi/2, q[4]],
        [0, 0, np.pi/2, q[5]],
        [0.088, 0, np.pi/2, q[6]],
        [0, 0.107, 0, 0],
        [0, 0, 0, -np.pi/4],
        [0.0, 0.1034, 0, 0],
    ]

    num_chained = 7 + 1  # seven joints + fixed link; tool rows excluded
    pose = np.eye(4)
    for link in range(num_chained):
        pose = pose @ get_tf_mat(link, dh_params)
    return pose
    

class agent():
    """Evaluate a VLA policy inside an action-conditioned video world model.

    ``__init__`` builds everything from :class:`Args`:
    - an openpi DROID policy (pi0 / pi0fast / pi05);
    - the ``FuseSVD`` world model;
    - the ``state_01`` / ``state_99`` bounds used to normalize world-model actions;
    - a learned joint-dynamics model, when ``args.dynamics_model_path`` is set.

    ``forward_policy`` maps camera frames plus robot state to a future
    end-effector trajectory. ``forward_wm`` renders such a trajectory as
    predicted video.
    """

    # Gripper-opening caps keyed by instruction keyword. They are checked in
    # order and the LAST keyword found in the instruction wins.
    _GRIPPER_LIMITS = (
        ('block', 0.65),
        ('towel', 0.9),
        ('tape', 0.9),
        ('marker', 0.9),
        ('glove', 0.9),
        ('sponge', 0.4),
    )
    # Cap used when no keyword matches.
    _DEFAULT_GRIPPER_MAX = 0.95

    def __init__(self):
        args = Args()
        self.args = args
        self.accelerator = Accelerator()
        self.device = self.accelerator.device

        # --- policy ---
        if 'pi05' in args.policy_type:
            config = config_pi.get_config("pi05_droid")
            checkpoint_dir = '/cephfs/shared/llm/openpi/openpi-assets-preview/checkpoints/pi05_droid'
        elif 'pi0fast' in args.policy_type:
            config = config_pi.get_config("pi0fast_droid")
            checkpoint_dir = '/cephfs/shared/llm/openpi/openpi-assets/checkpoints/pi0fast_droid'
        elif 'pi0' in args.policy_type:
            config = config_pi.get_config("pi0_droid")
            checkpoint_dir = '/cephfs/shared/llm/openpi/openpi-assets/checkpoints/pi0_droid'
        else:
            raise ValueError(f"Unknown policy type: {args.policy_type}")
        if args.pi_model_path is not None:
            # A fine-tuned checkpoint overrides the default directory.
            checkpoint_dir = args.pi_model_path
        self.policy = policy_config.create_trained_policy(config, checkpoint_dir)
        # Report the directory actually loaded (args.pi_model_path may be None).
        print('load policy success from:', checkpoint_dir)

        # --- world model ---
        sys.path.append('./')
        from exp31_droid_framecond_s9 import FuseSVD
        self.model = FuseSVD(args)
        # Deserialize on CPU; the .to() below is then the only device transfer.
        self.model.load_state_dict(torch.load(args.val_model_path, map_location='cpu'))
        self.model.to(self.device)
        self.model.eval()
        print("load world model success")
        with open(f"{args.data_stat_path}", 'r') as f:
            data_stat = json.load(f)
        # Lower / upper normalization bounds, shaped (1, D) to broadcast over time steps.
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

        # --- optional learned dynamics model ---
        if args.dynamics_model_path is not None:
            from output_dynamics.train2 import Dynamics
            self.dynamics_model = Dynamics(action_dim=7, action_num=15, hidden_size=512).to(self.device)
            self.dynamics_model.load_state_dict(torch.load(args.dynamics_model_path, map_location=self.device))
            # Inference only: freeze dropout / batch-norm behavior if the model has any.
            self.dynamics_model.eval()

        # report cuda memory usage
        if torch.cuda.is_available():
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
            print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")

    def normalize_bound(
        self,
        data: np.ndarray,
        data_min: np.ndarray,
        data_max: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Map ``[data_min, data_max]`` affinely onto ``[-1, 1]``, then clip.

        ``eps`` avoids division by zero for constant dimensions. The result is
        clipped to ``[clip_min, clip_max]``.
        """
        ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def get_traj_info(self, id, start_idx=0, steps=8):
        """Load ground-truth states, joints, frames and VAE latents for one episode.

        Args:
            id: Episode id. The annotation is read from
                ``{val_dataset_dir}/annotation/val/{id}.json``.
            start_idx: First frame index to sample.
            steps: Number of frames to sample, ``args.skip_step`` apart. Indices
                past the end of the episode are clamped to the last frame.

        Returns:
            A tuple ``(states, joints, videos, latents)``:
            - ``states`` / ``joints``: the annotation's ``'states'`` / ``'joints'``
              rows at the sampled frames;
            - ``videos``: one uint8 ``(steps, H, W, 3)`` array per camera;
            - ``latents``: the matching list of scaled VAE latents on
              ``self.device``.
        """
        args = self.args
        val_dataset_dir = args.val_dataset_dir
        skip = args.skip_step
        annotation_path = f"{val_dataset_dir}/annotation/val/{id}.json"
        with open(annotation_path) as f:
            anno = json.load(f)
        try:
            length = len(anno['action'])
        except (KeyError, TypeError):
            # Annotations without an action list store the length directly.
            length = anno["video_length"]

        frames_ids = np.arange(start_idx, start_idx + steps * skip, skip)
        print(frames_ids)
        # Clamp so that sampling past the episode end repeats the last frame.
        frames_ids = np.minimum(frames_ids, length - 1).astype(int)
        print("frames_ids", frames_ids)

        states = np.array(anno['states'])[frames_ids]
        joint_pos = np.array(anno['joints'])[frames_ids]

        video_dict = []
        video_latent = []
        vae = self.model.pipeline.vae
        batch_size = 32
        for view in anno['videos']:
            video_path = f"{val_dataset_dir}/{view['video_path']}"
            vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
            batch = vr.get_batch(range(length))
            # decord returns its own NDArray (``asnumpy``) by default, or a torch
            # tensor (``numpy``) when the torch bridge is active. Decode only once.
            true_video = batch.asnumpy() if hasattr(batch, 'asnumpy') else batch.numpy()
            true_video = true_video[frames_ids]
            video_dict.append(true_video)

            # uint8 NHWC in [0, 255] -> float NCHW in [-1, 1], then VAE-encode in chunks.
            x = torch.from_numpy(true_video).float().to(self.device)
            x = x.permute(0, 3, 1, 2) / 255.0 * 2 - 1
            with torch.no_grad():
                latents = [
                    vae.encode(x[i:i + batch_size]).latent_dist.sample().mul_(vae.config.scaling_factor)
                    for i in range(0, len(x), batch_size)
                ]
                video_latent.append(torch.cat(latents, dim=0))

        return states, joint_pos, video_dict, video_latent

    @staticmethod
    def _decode_latents(latents, vae, chunk_size):
        """Decode ``(batch, frames, C, h, w)`` scaled VAE latents into uint8 ``(batch, frames, H, W, 3)``.

        Runs under ``torch.no_grad()``. The frames are only converted to numpy,
        so building an autograd graph would just waste GPU memory.
        """
        bsz, frame_num = latents.shape[:2]
        flat = latents.flatten(0, 1)
        decoded = []
        with torch.no_grad():
            for i in range(0, flat.shape[0], chunk_size):
                chunk = flat[i:i + chunk_size] / vae.config.scaling_factor
                decoded.append(vae.decode(chunk, num_frames=chunk.shape[0]).sample)
        frames = torch.cat(decoded, dim=0)
        frames = frames.reshape(bsz, frame_num, *frames.shape[1:])
        frames = (frames / 2.0 + 0.5).clamp(0, 1) * 255
        return frames.to(torch.float32).detach().cpu().numpy().transpose(0, 1, 3, 4, 2).astype(np.uint8)

    def forward_wm(self, action_cond, video_latent_true, video_latent_cond, his_cond=None, text=None):
        """Roll the world model forward under an end-effector trajectory.

        Args:
            action_cond: Un-normalized states of shape
                ``(num_frames + num_history, action_dim)``. They are normalized
                here with the stored bounds.
            video_latent_true: Ground-truth latent clips, one per camera. Used
                only for the side-by-side comparison.
            video_latent_cond: Conditioning latent of shape ``(1, 4, 72, 40)``.
            his_cond: History latents, forwarded to the pipeline as ``history``.
            text: Optional instruction. When given, it is fed to the action encoder.

        Returns:
            A tuple ``(videos_cat, true_video, videos, latents)``:
            - ``true_video`` / ``videos``: decoded uint8
              ``(views, frames, H, W, 3)`` arrays for ground truth and prediction;
            - ``videos_cat``: ground truth stacked above prediction, with views
              side by side, shape ``(frames, 2H, views*W, 3)``;
            - ``latents``: the predicted latents per view.
        """
        args = self.args
        image_cond = video_latent_cond

        # The world model consumes actions normalized to [-1, 1].
        action_cond = self.normalize_bound(
            action_cond, self.state_p01, self.state_p99, clip_min=-1, clip_max=1
        )
        action_cond = torch.tensor(action_cond).unsqueeze(0).to(self.device).float()
        assert image_cond.shape[1:] == (4, 72, 40)
        assert action_cond.shape[1:] == (args.num_frames + args.num_history, args.action_dim)

        pipeline = self.model.pipeline
        with torch.no_grad():
            if text is not None:
                text_token = self.model.action_encoder(action_cond, text, self.model.tokenizer, self.model.text_encoder)
            else:
                text_token = self.model.action_encoder(action_cond)
            _, latents = MaskStableVideoDiffusionPipeline.__call__(
                pipeline,
                image=image_cond,
                text=text_token,
                width=320,
                height=int(72*8),
                num_frames=args.num_frames,
                history=his_cond,
                num_inference_steps=args.num_inference_steps,
                decode_chunk_size=args.decode_chunk_size,
                max_guidance_scale=args.guidance_scale,
                fps=args.fps,
                motion_bucket_id=args.motion_bucket_id,
                mask=None,
                output_type='latent',
                return_dict=False,
                frame_independent=True,
            )
        # Split the height-stacked canvas back into its 3 camera views.
        latents = einops.rearrange(latents, 'b f c (m h) (n w) -> (b m n) f c h w', m=3, n=1)

        vae = pipeline.vae
        true_video = self._decode_latents(torch.stack(video_latent_true, dim=0), vae, args.decode_chunk_size)
        videos = self._decode_latents(latents, vae, args.decode_chunk_size)

        # Ground truth above prediction (height axis), then views side by side (width axis).
        videos_cat = np.concatenate([true_video, videos], axis=-3)
        videos_cat = np.concatenate(list(videos_cat), axis=-2).astype(np.uint8)

        return videos_cat, true_video, videos, latents

    @staticmethod
    def _resize_view(image):
        """Bilinearly resize an HWC uint8 frame to 180x320 and return it as uint8 numpy."""
        t = torch.from_numpy(image).to(torch.uint8)
        t = torch.nn.functional.interpolate(
            t.permute(2, 0, 1).unsqueeze(0).float(), size=(180, 320), mode='bilinear', align_corners=False
        )
        return t.squeeze(0).permute(1, 2, 0).to(torch.uint8).numpy()

    @staticmethod
    def _action_indices(policy_type, use_dynamics):
        """Return ``(indices into the policy action chunk, integration step delta_t)``.

        ``delta_t`` is only used when no dynamics model is loaded.
        """
        if 'pi05' in policy_type:
            idx, delta_t = [0, 3, 6, 9, 12], 1/5
        else:
            idx, delta_t = list(range(10)), 1/12
        if use_dynamics:
            # The dynamics model takes 15 steps. Non-pi05 chunks repeat index 9 to reach 15.
            idx = list(range(15)) if 'pi05' in policy_type else list(range(10)) + [9] * 5
        return idx, delta_t

    @classmethod
    def _gripper_max_for_text(cls, text):
        """Return the gripper-opening cap for an instruction (the last matching keyword wins)."""
        gripper_max = cls._DEFAULT_GRIPPER_MAX
        for keyword, limit in cls._GRIPPER_LIMITS:
            if keyword in text:
                gripper_max = limit
        return gripper_max

    @staticmethod
    def _integrate_joint_velocity(current_joint, joint_vel, delta_t):
        """Integrate joint velocities with forward Euler: ``q_{k+1} = q_k + v_k * delta_t``.

        Returns a ``(len(joint_vel), 7)`` array of future joint positions. The
        start position itself is not included.
        """
        joint_future = current_joint
        future = []
        for vel in joint_vel[:, :7]:
            joint_future = joint_future + vel[None, :] * delta_t
            future.append(joint_future)
        joint_pos = np.concatenate(future, axis=0)
        assert joint_pos.shape[1] == 7, f"Expected joint_pos shape (N, 7), got {joint_pos.shape}"
        return joint_pos

    def forward_policy(self, videos, state, joints, text):
        """Query the policy and convert its action chunk into a future trajectory.

        Args:
            videos: Per-camera HWC uint8 frames. ``videos[1]`` is fed as the
                exterior image and ``videos[2]`` as the wrist image; both must
                be 192x320.
            state: Current end-effector state. Only used in a debug print.
            joints: Length-8 robot state: seven joint positions, then the
                gripper position.
            text: Language instruction. It also selects the gripper cap.

        Returns:
            A tuple ``(policy_in_out, joint_pos_skip, state_fk_skip)``:
            - ``policy_in_out``: holds the full-rate ``joint_pos`` (current joints
              first), ``joint_vel`` and ``state_fk`` (xyz, xyz-Euler, gripper per
              step);
            - ``state_fk_skip``: every 3rd step of ``state_fk``;
            - ``joint_pos_skip``: every 3rd step of ``joint_pos``, with the
              gripper column appended.
        """
        assert videos[1].shape == (192, 320, 3), "Image 1 shape should be (192, 320, 3), got {}".format(videos[1].shape)
        image1 = self._resize_view(videos[1])
        image2 = self._resize_view(videos[2])
        example = {
            "observation/exterior_image_1_left": image_tools.resize_with_pad(image1, 224, 224),
            "observation/wrist_image_left": image_tools.resize_with_pad(image2, 224, 224),
            "observation/joint_position": joints[:7],
            "observation/gripper_position": joints[-1:],
            "prompt": text,
        }

        action_chunk = self.policy.infer(example)["actions"]

        # Policy output: columns 0-6 are joint velocities, column 7 is gripper position.
        joint_vel = action_chunk[:, :7]
        gripper_pos = action_chunk[:, 7:]
        current_joint = joints[:7][None, :]  # (1, 7)
        current_state = state[:8][None, :]

        use_dynamics = self.args.dynamics_model_path is not None
        idx, delta_t = self._action_indices(self.args.policy_type, use_dynamics)
        joint_vel = joint_vel[idx]
        gripper_pos = np.clip(gripper_pos[idx], 0, self._gripper_max_for_text(text))

        # Future joint positions, from Euler integration or the learned dynamics model.
        if not use_dynamics:
            joint_pos = self._integrate_joint_velocity(current_joint, joint_vel, delta_t)
        else:
            print(current_joint.shape, joint_vel.shape, current_state.shape)
            with torch.no_grad():
                joint_pos = self.dynamics_model(current_joint, joint_vel, None, training=False)

        # Prepend the current joints and keep at most 15 steps. Also never keep
        # more steps than there are gripper commands: without this, the
        # integration path (N future + 1 current rows vs N gripper rows)
        # indexed past the end of gripper_pos.
        joint_pos = np.concatenate([current_joint, joint_pos], axis=0)[:15]
        joint_pos = joint_pos[:gripper_pos.shape[0]]

        # Forward kinematics -> (xyz, xyz-Euler, gripper) for each step.
        state_fk = []
        for joint_t, grip_t in zip(joint_pos, gripper_pos):
            tf = get_fk_solution(joint_t[:7])
            euler = R.from_matrix(tf[:3, :3]).as_euler('xyz')
            state_fk.append(np.concatenate([tf[:3, 3], euler, grip_t], axis=0))
        state_fk = np.array(state_fk)

        policy_in_out = {
            'joint_pos': joint_pos,
            'joint_vel': joint_vel,
            'state_fk': state_fk,
        }

        # Subsample every 3rd step and append the gripper column to the joints.
        state_fk_skip = state_fk[::3]
        joint_pos_skip = np.concatenate([joint_pos[::3], state_fk_skip[:, -1:]], axis=-1)
        print("joint_pos", joint_pos_skip.shape, "state_fk", state_fk_skip.shape)

        return policy_in_out, joint_pos_skip, state_fk_skip



class Args:
    """Static configuration for the policy-in-world-model evaluation script.

    It is read both as class attributes (``Args.val_id``) and through an
    instance (``agent().args``). The instance is also passed to
    ``FuseSVD(args)``; the training-related fields are presumably consumed
    there.
    """

    pretrained_model_path = "/cephfs/shared/llm/stable-video-diffusion-img2vid"
    # alt: "/cephfs/cjyyj/code/video_robot_svd/output/svd/train_2025-05-08T01-09-10/checkpoint-320000"
    clip_model_path = "/cephfs/shared/llm/clip-vit-base-patch32"

    debug = False
    output_dir = "output_unit_test/yay_robot_4img_cond"
    project_name = "unit_test_svd"
    tag = 'action_cond'
    action_dim = 7

    load_path = None
    learning_rate = 5e-6
    gradient_accumulation_steps = 1
    mixed_precision = 'fp16'
    train_batch_size = 12
    shuffle = True

    num_train_epochs = 100
    max_train_steps = 500000
    checkpointing_steps = 20000
    validation_steps = 2000
    max_grad_norm = 1.0

    dataset_names = 'droid_svd_v2'
    dataset_dir = "/cephfs/shared/droid_hf"
    data_root_path = "/cephfs/shared/droid_hf"  # alt: '/cephfs/shared/droid_hf/opensource_robotdata'
    annotation_name = 'annotation'  # annotation dirname under dataset_dir path
    tie_weight = True
    normalize = True
    pre_encode = True
    num_workers = 4
    video_size = [256, 256]

    only_one_clip = True
    get_action = True

    clip_img_size = 224
    use_img_cond = False
    motion_bucket_id = 127
    fps = 7
    guidance_scale = 2.0  # previously tried: 7.5, 3.0
    num_inference_steps = 30
    decode_chunk_size = 7
    width = 320
    height = 192
    validation_num = 32
    video_num = 3
    num_frames = 5
    num_history = 6
    data_json_path = 'video_dataset/xhand_mix_seq16_train.json'  # alt: 'video_dataset/xhand_skip2_seq16_train.json'
    val_data_json_path = "video_dataset/xhand_mix_seq16_val.json"  # alt: 'video_dataset/xhand_skip2_seq16_val.json'

    data_stat_path = 'exp_cfg/droid_svd_v2/stat.json'
    val_model_path = '/cephfs/cjyyj/code/video_evaluation/output2/exp33_210_s11/checkpoint-10000.pt'
    # Other world-model checkpoints used before:
    #   /cephfs/cjyyj/code/video_evaluation/output2/exp33_210/checkpoint-20000.pt
    #   /cephfs/cjyyj/code/video_evaluation/output2/exp31_droid_cond9_eef_text_0804_accu4/checkpoint-10000.pt
    #   /cephfs/cjyyj/code/video_evaluation/output2/exp31_droid_cond9_eef_text_0804_accu4/checkpoint-5000.pt
    #   /cephfs/cjyyj/code/video_evaluation/output2/exp31_droid_cond9_eef_text_0804/checkpoint-40000.pt
    #   /cephfs/cjyyj/code/video_evaluation/output2/exp31_210_post/checkpoint-90000.pt
    pred_step = 4
    skip_step = 1

    # ---- evaluation episodes ----
    # Active set. Keep exactly one set live: a later assignment in a class body
    # silently overrides an earlier one.
    val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0811/droid_pi05'
    val_id = ['151811', '151847', '161434', '161521', '161616', '161730', '162156']
    start_idx = [0]*len(val_id)
    instruction = ["fold the towel from left side"]*len(val_id)

    # Alternative sets (uncomment one group to use it instead):
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0724_2/droid_pi0'
    # val_id = [215713, 215713, 215713]  # [215713, 215802, 215647]
    # start_idx = [0, 0, 0]
    # instruction = ['pick up the blue block and place in white plate'] * 3
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0728/droid_pi05'
    # val_id = [160821, 161002, 161348]
    # start_idx = [0, 0, 0]
    # instruction = ['fold the blue towel', 'fold the red towel', 'fold the yellow towel']
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0803/droid_pi05'
    # val_id = [111445, 112356, 112558, 112640, 113729, 121541]
    # start_idx = [16, 8, 8, 8, 8, 8]
    # instruction = ['fold the towel'] * 6
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0804/droid_pi05'
    # val_id = ['094037', '094117', '094146']
    # start_idx = [0, 0, 0]
    # instruction = ["pick up the sponge and place in box"] * 3
    # instruction = ["pick up the blue glove and place in box"] * 3
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0804_3/droid_pi05'
    # val_id = [210338, 210509, 211038]
    # start_idx = [8, 8, 8]
    # instruction = ['pick up the blue block and place in white plate'] * 3
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0804_3/droid_pi05'
    # val_id = [152741, 152905, 153315]
    # start_idx = [8, 8, 8]
    # instruction = ['fold the towel'] * 3
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0804_3/droid_pi05'
    # val_id = [174031, 173838, 173955]
    # start_idx = [8, 8, 8]
    # instruction = ["pick the sponge and place in box"] * 3
    #
    # # sponge
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0804_3/droid_pi05'
    # val_id = [173838, 173955, 174031, 174143, 174230, 174307, 174346, 174529, 174717, 174807]
    # start_idx = [0] * 10
    # instruction = ["pick the sponge and place in box"] * 10
    #
    # # sponge, same episode with varying start frames
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0804_3/droid_pi05'
    # val_id = [174717] * 10
    # start_idx = [0, 1, 3, 4, 2, 0, 1, 3, 4, 2]
    # instruction = ["pick the green sponge and place in box"] + ["pick the sponge and place in box"] * 9
    #
    # # glove
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0804_3/droid_pi05'
    # val_id = [181116, 181140, 181218, 181253, 181529, 192801, 192844, 193042, 193153, 195050]
    # start_idx = [0, 1, 3, 4, 2, 0, 1, 3, 4, 2]
    # instruction = ["pick the blue glove and place in box"] * 10
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0812/droid_pi05'
    # val_id = ['150724', '150914', '151030', '151207', '151333', '151444', '151645', '151758', '151919']
    # (full list: '145413', '145513', '145630', '145902', '150052', '150405', '150528', + the above)
    # start_idx = [0]*len(val_id)
    # instruction = ["fold the towel from left side"]*len(val_id)
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0803/droid_pi05'
    # val_id = ['111115', '111221', '112356', '112558', '112640', '113458', '113619', '113729', '121541']
    # start_idx = [0]*len(val_id)
    # instruction = ["fold the towel from left side"]*len(val_id)
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0812/droid_pi05'
    # val_id = ['145413', '145513', '145630', '145902', '150052', '150405', '150528', '150724', '150914', '151030']
    # start_idx = [0]*len(val_id)
    # instruction = ["fold the towel from left"]*len(val_id)
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0812/droid_pi05'
    # val_id = ['150724', '150914', '151030', '151207', '151333', '151444', '151645', '151758', '151919']
    # start_idx = [0]*len(val_id)
    # instruction = ["fold the towel from right side"]*len(val_id)
    #
    # val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0812/droid_pi05'
    # val_id = ['150724', '150724']
    # start_idx = [0]*len(val_id)
    # instruction = ["fold the towel from left side", "fold the towel from right side"]

    task_name = 'towel_ft'

    dynamics_model_path = '/cephfs/cjyyj/code/video_evaluation/output_dynamics/model2_15_9.pth'
    # alt: None (Euler integration), 'output_dynamics/model2_epoch_9.pth', 'output_dynamics/model3_epoch_6.pth'
    pi_model_path = '/cephfs/cjyyj/code/openpi/checkpoints/pi05_droid_finetune/droid_0817/5000'
    policy_type = 'pi05'  # one of: 'pi05', 'pi0', 'pi0fast'
    text_cond = True

    # CUDA_VISIBLE_DEVICES=1 python exp33_evel_pi05_real.py

        
if __name__ == "__main__":

    Agent = agent()
    args = Agent.args
    interact_num = 12  # policy / world-model interaction steps per episode
    pred_step = args.pred_step  # predicted frames consumed per interaction step (important parameter)
    num_history = args.num_history
    num_frames = args.num_frames
    # Offsets into the history buffers that form the world model's conditioning
    # context. The (11, 7) action_input assert requires num_history entries.
    history_idx = [-15, -15, -8, -6, -4, -2]
    # history_idx = [-15, -15, -15, -9, -6, -3]  # earlier spacing
    # history_idx = [-1, -1, -1, -1, -1, -1]  # ablation: last frame for all history slots

    for val_id_i, text_i, start_idx_i in zip(args.val_id, args.instruction, args.start_idx):

        # Ground truth for the whole rollout (8 extra frames cover the last window).
        eef_gt, joint_pos_gt, video_dict, video_latents = Agent.get_traj_info(
            val_id_i, start_idx=start_idx_i, steps=int(pred_step*interact_num+8))
        print("At the episode beginnig, the state is:", "eef pose 0", eef_gt[0], "joint 0", joint_pos_gt[0])

        video_to_save = []
        info_to_save = []

        # Seed the history buffers with the first ground-truth observation.
        first_latent = torch.cat([v[0] for v in video_latents], dim=1).unsqueeze(0)  # (1, 4, 72, 40)
        assert first_latent.shape == (1, 4, 72, 40), f"Expected first_latent shape (1, 4, 72, 40), got {first_latent.shape}"
        his_len = num_history * 4
        his_cond = [first_latent] * his_len
        his_joint = [joint_pos_gt[0:1]] * his_len
        his_eef = [eef_gt[0:1]] * his_len

        video_dict_pred = None
        for i in range(interact_num):
            start = i * pred_step
            video_latent_true = [v[start:start + num_frames] for v in video_latents]

            # Policy input: latest (possibly imagined) joints, eef state and frames.
            joint_first = his_joint[-1][0]  # (8,)
            state_first = his_eef[-1][0]  # (7,)
            if i == 0:
                video_first = [v[0] for v in video_dict]
            else:
                # Frame pred_step-1 of the last prediction is the frame whose
                # joints, eef state and latent were appended to the history
                # buffers. The policy must observe that same frame; v[-1] is a
                # later frame whenever num_frames > pred_step.
                video_first = [v[pred_step - 1] for v in video_dict_pred]
            assert joint_first.shape == (8,), f"Expected joint_first shape (8,), got {joint_first.shape}"
            assert state_first.shape == (7,), f"Expected state_first shape (7,), got {state_first.shape}"

            print("################ policy forward ####################")
            policy_in_out, joint_pos, state_pos = Agent.forward_policy(video_first, state_first, joint_first, text=text_i)
            print("policy output eef pose", state_pos)  # output xyz and gripper for debug

            print("################ world model forward ################")
            print(f'traj_id:{val_id_i}, interact step: {i}!!!!!!!!')

            his_eef_input = np.concatenate([his_eef[idx] for idx in history_idx], axis=0)
            action_input = np.concatenate([his_eef_input, state_pos], axis=0)
            his_cond_input = torch.cat([his_cond[idx] for idx in history_idx], dim=0).unsqueeze(0)
            video_latent_first = his_cond[-1]

            assert video_latent_first.shape == (1, 4, 72, 40), f"Expected video_latent_first shape (1, 4, 72, 40), got {video_latent_first.shape}"
            assert action_input.shape == (11, 7), f"Expected action_input shape (11, 7), got {action_input.shape}"
            assert his_cond_input.shape == (1, 6, 4, 72, 40), f"Expected his_cond_input shape (1, 6, 4, 72, 40), got {his_cond_input.shape}"

            videos_cat, true_videos, video_dict_pred, predict_latents = Agent.forward_wm(
                action_input, video_latent_true, video_latent_first,
                his_cond=his_cond_input, text=text_i if args.text_cond else None)

            # Advance all history buffers to frame pred_step-1 of this prediction.
            his_joint.append(joint_pos[pred_step-1:pred_step])  # (1, 8)
            his_eef.append(state_pos[pred_step-1:pred_step])
            his_cond.append(torch.cat([v[pred_step-1] for v in predict_latents], dim=1).unsqueeze(0))  # (1, 4, 72, 40)

            # Frame pred_step-1 becomes the first frame of the next step, so only
            # the first pred_step-1 frames are saved here.
            video_to_save.append(videos_cat[:pred_step-1])
            info_to_save.append(policy_in_out)

        video = np.concatenate(video_to_save, axis=0)

        # save rollout video and info, named by the run parameters
        task_name = args.task_name
        wm_id = args.val_model_path.split('/')[-1].split('.')[0]
        text_id = text_i.replace(' ', '_').replace(',', '').replace('.', '').replace('\'', '').replace('\"', '')[:30]
        num_inference_steps = args.num_inference_steps
        guidance_scale = args.guidance_scale
        videos_dir = '/'.join(args.val_model_path.split('/')[:-1])
        uuid = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_stem = f"time_{uuid}_traj_{val_id_i}_{start_idx_i}_{pred_step}_{wm_id}_{guidance_scale}_{num_inference_steps}_{text_id}"

        filename_video = f"{videos_dir}/{task_name}/video/{file_stem}.mp4"
        os.makedirs(os.path.dirname(filename_video), exist_ok=True)
        mediapy.write_video(filename_video, video, fps=5)
        print(f"Saving video to {filename_video}")
        print("##########################################################################")

        # Gather per-step policy info. forward_policy subsamples every 3rd step,
        # so the first pred_step*3 full-rate steps match the pred_step frames consumed.
        steps_per_interaction = int(pred_step*3)
        info = {'success': 1, 'start_idx': 0, 'end_idx': video.shape[0]-1, 'instructions': text_i}
        for key in info_to_save[0]:
            info[key] = [
                row
                for step_info in info_to_save
                for row in step_info[key].tolist()[:steps_per_interaction]
            ]

        filename_info = f"{videos_dir}/{task_name}/info/{file_stem}.json"
        os.makedirs(os.path.dirname(filename_info), exist_ok=True)
        with open(filename_info, 'w') as f:
            json.dump(info, f, indent=4)
        
        
        
