
from openpi.training import config as config_pi
from openpi.policies import policy_config
from openpi_client import image_tools
# from openpi.shared import download

import numpy as np


from accelerate import Accelerator
import torch
from diffusers import StableVideoDiffusionPipeline
from omegaconf import OmegaConf
import numpy as np
# import cv2
import torch
import torch.nn.functional as F
import torch.nn as nn
import einops
from accelerate import Accelerator
import datetime
import os
from accelerate.logging import get_logger
from tqdm.auto import tqdm
import wandb
import imageio
from video_models.pipeline import MaskStableVideoDiffusionPipeline
import json
from decord import VideoReader, cpu
import swanlab
import mediapy
import sys
from scipy.spatial.transform import Rotation as R

def get_tf_mat(i, dh):
    """Build the 4x4 homogeneous transform described by row ``i`` of a DH table.

    Args:
        i: Index of the row to use.
        dh: Indexable table; each row is read as ``[a, d, alpha, theta]``.

    Returns:
        A 4x4 ``numpy`` array combining the rotation and translation of the link.
    """
    row = dh[i]
    a, d, alpha, theta = row[0], row[1], row[2], row[3]

    cos_t, sin_t = np.cos(theta), np.sin(theta)
    cos_a, sin_a = np.cos(alpha), np.sin(alpha)

    # NOTE(review): the element layout looks like the modified (Craig) DH
    # convention -- confirm against the robot documentation.
    return np.array([
        [cos_t, -sin_t, 0, a],
        [sin_t * cos_a, cos_t * cos_a, -sin_a, -sin_a * d],
        [sin_t * sin_a, cos_t * sin_a, cos_a, cos_a * d],
        [0, 0, 0, 1],
    ])


def get_fk_solution(joint_angles):
    """Chain the per-link DH transforms into one 4x4 pose for the given joint angles.

    Args:
        joint_angles: Indexable with at least 7 entries; entries 0..6 are used
            as the ``theta`` value of the first seven DH rows.

    Returns:
        The 4x4 ``numpy`` array obtained by multiplying the transforms of DH
        rows 0..7 in order, starting from the identity.
    """
    q = joint_angles
    half_pi = np.pi/2
    # Rows are [a, d, alpha, theta].
    # NOTE(review): the constants look like Franka Panda link parameters -- confirm.
    dh_params = [
        [0, 0.333, 0, q[0]],
        [0, 0, -half_pi, q[1]],
        [0, 0.316, half_pi, q[2]],
        [0.0825, 0, half_pi, q[3]],
        [-0.0825, 0.384, -half_pi, q[4]],
        [0, 0, half_pi, q[5]],
        [0.088, 0, half_pi, q[6]],
        [0, 0.107, 0, 0],
        [0, 0, 0, -np.pi/4],
        [0.0, 0.1034, 0, 0],
    ]

    # Only the first 8 rows are applied; the last two rows of the table are
    # defined but never multiplied in.
    # NOTE(review): confirm that skipping them is intended.
    num_applied_rows = 7 + 1
    pose = np.eye(4)
    for row_index in range(num_applied_rows):
        pose = pose @ get_tf_mat(row_index, dh_params)
    return pose
    

class agent():
    def __init__(self):
        """Load the policy, the world model and (optionally) the dynamics model.

        All settings are read from a fresh ``Args()`` instance, which is kept as
        ``self.args``. The constructor also reads the state statistics JSON at
        ``args.data_stat_path`` and stores its ``state_01`` / ``state_99``
        entries (with a leading axis added) as ``self.state_p01`` /
        ``self.state_p99``.

        Raises:
            ValueError: If ``args.policy_type`` contains none of ``pi05``,
                ``pi0fast`` or ``pi0``.
        """
        args = Args()
        self.args = args
        self.accelerator = Accelerator()
        self.device = self.accelerator.device

        # load policy
        # Order matters: 'pi0' is a substring of both 'pi05' and 'pi0fast',
        # so it has to be tested last.
        policy_table = (
            ('pi05', "pi05_droid", '/cephfs/shared/llm/openpi/openpi-assets-preview/checkpoints/pi05_droid'),
            ('pi0fast', "pi0fast_droid", '/cephfs/shared/llm/openpi/openpi-assets/checkpoints/pi0fast_droid'),
            ('pi0', "pi0_droid", '/cephfs/shared/llm/openpi/openpi-assets/checkpoints/pi0_droid'),
        )
        for keyword, config_name, checkpoint_dir in policy_table:
            if keyword in args.policy_type:
                config = config_pi.get_config(config_name)
                break
        else:
            raise ValueError(f"Unknown policy type: {args.policy_type}")
        self.policy = policy_config.create_trained_policy(config, checkpoint_dir)

        # load world model
        sys.path.append('./')
        from exp31_droid_framecond_s9 import FuseSVD
        self.model = FuseSVD(args)
        # Deserialize onto the CPU first: without map_location, torch.load tries
        # to restore tensors on the device they were saved from, which fails if
        # that CUDA index is not visible in this process. The model is moved to
        # the accelerator device right after.
        world_model_state = torch.load(args.val_model_path, map_location="cpu")
        self.model.load_state_dict(world_model_state)
        self.model.to(self.accelerator.device)
        self.model.eval()
        print("load world model success")

        with open(f"{args.data_stat_path}", 'r') as f:
            data_stat = json.load(f)
        # Leading axis added so the bounds broadcast over a (T, D) array.
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

        # load dynamics model (self.dynamics_model only exists when a path is configured)
        if args.dynamics_model_path is not None:
            from output_dynamics.train2 import Dynamics
            # from output_dynamics.train3 import Dynamics
            self.dynamics_model = Dynamics(action_dim=7, action_num=15, hidden_size=512).to(self.device)
            self.dynamics_model.load_state_dict(torch.load(args.dynamics_model_path, map_location=self.device))

        # report cuda memory usage
        if torch.cuda.is_available():
            print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2):.2f} MB")
            print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / (1024 ** 2):.2f} MB")

    def normalize_bound(
        self,
        data: np.ndarray,
        data_min: np.ndarray,
        data_max: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Linearly map ``data`` from ``[data_min, data_max]`` to ``[-1, 1]`` and clip.

        Args:
            data: Values to normalize.
            data_min: Lower bound that is mapped to -1.
            data_max: Upper bound that is mapped to (approximately) +1.
            clip_min: Smallest value allowed in the result.
            clip_max: Largest value allowed in the result.
            eps: Added to the range so a zero-width range does not divide by zero.

        Returns:
            The rescaled values, clipped to ``[clip_min, clip_max]``.
        """
        value_range = data_max - data_min + eps
        offset = data - data_min
        rescaled = 2 * offset / value_range - 1
        return np.clip(rescaled, clip_min, clip_max)

    def get_traj_info(self, id, start_idx=0, steps=8):
        """Load states, joints, frames and VAE latents for one validation trajectory.

        The annotation is read from
        ``{val_dataset_dir}/annotation/val/{id}.json``. ``steps`` frame indices
        are sampled starting at ``start_idx`` with stride ``args.skip_step``;
        indices past the end of the trajectory are clamped to the last frame.

        Args:
            id: Trajectory identifier used in the annotation file name. (The
                name shadows the builtin; it is kept for caller compatibility.)
            start_idx: Index of the first sampled frame.
            steps: Number of frames to sample.

        Returns:
            A tuple ``(action, joint_pos, video_dict, video_latent)``:
            ``action`` is the annotation's ``'states'`` array at the sampled
            indices, ``joint_pos`` the ``'joints'`` array at the same indices,
            ``video_dict`` a list with one array of sampled frames per entry of
            ``anno['videos']``, and ``video_latent`` a list with the matching
            VAE latents (torch tensors on ``self.device``).
        """
        args = self.args
        val_dataset_dir = args.val_dataset_dir
        skip = args.skip_step

        annotation_path = f"{val_dataset_dir}/annotation/val/{id}.json"
        with open(annotation_path) as f:
            anno = json.load(f)

        # Trajectory length: prefer the action list, fall back to the recorded
        # video length when 'action' is missing or has no length.
        try:
            length = len(anno['action'])
        except (KeyError, TypeError):
            length = anno["video_length"]

        frames_ids = np.arange(start_idx, start_idx + steps * skip, skip)
        print(frames_ids)
        # Clamp to the last valid frame so short trajectories repeat their final frame.
        frames_ids = np.minimum(frames_ids, length - 1).astype(int)
        print("frames_ids", frames_ids)

        action = np.array(anno['states'])[frames_ids]
        joint_pos = np.array(anno['joints'])[frames_ids]

        vae = self.model.pipeline.vae
        device = self.device
        encode_batch_size = 32  # frames per VAE encode call, bounds peak memory

        # get videos
        video_dict = []
        video_latent = []
        for video_info in anno['videos']:
            video_path = f"{val_dataset_dir}/{video_info['video_path']}"

            # Decode once; the returned batch exposes .asnumpy() or .numpy()
            # depending on the decord bridge in use.
            vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
            decoded = vr.get_batch(range(length))
            decoded = decoded.asnumpy() if hasattr(decoded, "asnumpy") else decoded.numpy()
            true_video = decoded[frames_ids]
            video_dict.append(true_video)

            # encode video: channels-first, rescaled from [0, 255] to [-1, 1]
            frames = torch.from_numpy(true_video).float().to(device)
            x = frames.permute(0, 3, 1, 2).to(device) / 255.0*2-1
            # resize to 192*320
            # x = F.interpolate(x, size=(args.height, args.width), mode='bilinear', align_corners=False)

            with torch.no_grad():
                latents = []
                for begin in range(0, len(x), encode_batch_size):
                    batch = x[begin:begin + encode_batch_size]
                    latent = vae.encode(batch).latent_dist.sample().mul_(vae.config.scaling_factor)
                    latents.append(latent)
                x = torch.cat(latents, dim=0)
            video_latent.append(x)

        return action, joint_pos, video_dict, video_latent

    def forward_wm(self, action_cond, video_latent_true, video_latent_cond, his_cond=None, text=None):
        """Roll the world model forward and decode predicted and reference frames.

        Args:
            action_cond: Un-normalized action/state sequence; it is normalized
                here with ``self.state_p01`` / ``self.state_p99`` and must end
                up with shape ``(num_frames + num_history, action_dim)``.
            video_latent_true: List of ground-truth latent clips (one per
                camera view) that are stacked and decoded for comparison.
            video_latent_cond: Conditioning latent; its trailing dimensions
                must be ``(4, 72, 40)``.
            his_cond: Optional history latents, passed to the pipeline as ``history``.
            text: Optional instruction; when given it is passed to the action
                encoder together with the model's tokenizer and text encoder.

        Returns:
            A tuple ``(videos_cat, true_video, videos, latents)``:
            ``videos_cat`` is a uint8 array with the ground truth stacked above
            the prediction and the views placed side by side, ``true_video``
            and ``videos`` are the decoded uint8 ground-truth and predicted
            clips, and ``latents`` are the predicted latents split into views.
        """
        args = self.args
        pipeline = self.model.pipeline
        vae = pipeline.vae

        image_cond = video_latent_cond

        # action should be normed
        action_cond = self.normalize_bound(
            action_cond, self.state_p01, self.state_p99, clip_min=-1, clip_max=1
        )
        action_cond = torch.tensor(action_cond).unsqueeze(0).to(self.device).float()  # leading batch axis of 1
        assert image_cond.shape[1:] == (4, 72, 40)
        assert action_cond.shape[1:] == (args.num_frames+args.num_history, args.action_dim)

        def decode_to_uint8(latent_clips):
            """Decode (clips, frames, C, h, w) latents to uint8 (clips, frames, H, W, 3)."""
            clip_count, frame_count = latent_clips.shape[:2]
            flat = latent_clips.flatten(0, 1)
            decoded_chunks = []
            for begin in range(0, flat.shape[0], args.decode_chunk_size):
                chunk = flat[begin:begin + args.decode_chunk_size] / vae.config.scaling_factor
                decoded_chunks.append(vae.decode(chunk, num_frames=chunk.shape[0]).sample)
            frames = torch.cat(decoded_chunks, dim=0)
            frames = frames.reshape(clip_count, frame_count, *frames.shape[1:])
            # [-1, 1] -> [0, 255]
            frames = ((frames / 2.0 + 0.5).clamp(0, 1)*255)
            return frames.to(torch.float32).detach().cpu().numpy().transpose(0, 1, 3, 4, 2).astype(np.uint8)

        # Inference only: keep both the diffusion sampling and the VAE decoding
        # under no_grad so no autograd graph is built for the decoded frames.
        with torch.no_grad():
            # predict future frames
            if text is not None:
                text_token = self.model.action_encoder(action_cond, text, self.model.tokenizer, self.model.text_encoder)
            else:
                text_token = self.model.action_encoder(action_cond)

            _, latents = MaskStableVideoDiffusionPipeline.__call__(
                pipeline,
                image=image_cond,
                text=text_token,
                width=320,
                height=int(72*8),
                num_frames=args.num_frames,
                history=his_cond,
                num_inference_steps=args.num_inference_steps,
                decode_chunk_size=args.decode_chunk_size,
                max_guidance_scale=args.guidance_scale,
                fps=args.fps,
                motion_bucket_id=args.motion_bucket_id,
                mask=None,
                output_type='latent',
                return_dict=False,
                frame_independent=True,
            )
            # The three views are tiled along the latent height; split them
            # back into separate clips on the batch axis.
            latents = einops.rearrange(latents, 'b f c (m h) (n w) -> (b m n) f c h w', m=3, n=1)

            # decode ground truth video
            true_video = decode_to_uint8(torch.stack(video_latent_true, dim=0))
            # decode predicted video
            videos = decode_to_uint8(latents)

        # concatenate true videos and video: ground truth on top of the
        # prediction (height axis), then the views side by side (width axis)
        videos_cat = np.concatenate([true_video, videos], axis=-3)
        videos_cat = np.concatenate(list(videos_cat), axis=-2).astype(np.uint8)

        return videos_cat, true_video, videos, latents
   
    def forward_policy(self, videos, state, joints, text, time_step=1):
        """Query the policy and turn its action chunk into future joints and EEF states.

        Args:
            videos: Sequence of per-view frames; entries 1 and 2 are used as the
                exterior and wrist images. Entry 1 must have shape (192, 320, 3).
            state: Current end-effector state; only used for a debug print on
                the dynamics-model path.
            joints: Current joint vector; the first 7 entries are the arm
                joints and the last entry is the gripper position.
            text: Instruction passed to the policy as the prompt. It also
                selects the gripper clipping limit by keyword.
            time_step: Unused; kept for caller compatibility.

        Returns:
            A tuple ``(policy_in_out, joint_pos_skip, state_fk_skip)``:
            ``policy_in_out`` holds the full ``joint_pos``, ``joint_vel`` and
            ``state_fk`` arrays; ``joint_pos_skip`` and ``state_fk_skip`` are
            those trajectories subsampled every ``args.policy_skip_step`` rows
            and truncated to 5 rows, with the gripper column appended to the
            joints.
        """
        args = self.args
        is_pi05 = 'pi05' in args.policy_type
        use_dynamics = args.dynamics_model_path is not None

        def to_policy_image(frame_tensor):
            """Bilinearly resize an (H, W, 3) uint8 tensor to (180, 320, 3) and return numpy."""
            resized = torch.nn.functional.interpolate(
                frame_tensor.permute(2, 0, 1).unsqueeze(0).float(),
                size=(180, 320), mode='bilinear', align_corners=False,
            )
            return resized.squeeze(0).permute(1, 2, 0).to(torch.uint8).numpy()

        # inference policy
        image1 = torch.from_numpy(videos[1]).to(torch.uint8)
        image2 = torch.from_numpy(videos[2]).to(torch.uint8)
        assert image1.shape == (192, 320, 3), "Image 1 shape should be (192, 320, 3), got {}".format(image1.shape)
        image1 = to_policy_image(image1)
        image2 = to_policy_image(image2)
        example = {
            "observation/exterior_image_1_left": image_tools.resize_with_pad(image1, 224, 224),
            "observation/wrist_image_left": image_tools.resize_with_pad(image2, 224, 224),
            "observation/joint_position": joints[:7],
            "observation/gripper_position": joints[-1:],
            "prompt": text,
        }
        action_chunk = self.policy.infer(example)["actions"]

        # policy output joint velocity and gripper position
        joint_vel = action_chunk[:, :7]
        gripper_pos = action_chunk[:, 7:]
        current_joint = joints[:7][None, :]  # (1, 7)
        current_gripper = joints[-1:][None, :]  # (1, 1)

        # Which rows of the action chunk to keep, and the integration step used
        # when velocities are summed directly.
        if use_dynamics:
            # the dynamics model takes 15 steps; non-pi05 chunks repeat row 9 to pad
            idx = list(range(15)) if is_pi05 else list(range(10)) + [9] * 5
        elif is_pi05:
            idx = [0, 3, 6, 9, 12]
        else:
            idx = list(range(10))
        delta_t = 1/5 if is_pi05 else 1/12

        joint_vel = joint_vel[idx]
        gripper_pos = gripper_pos[idx]

        # Task-specific gripper limit. Every matching keyword is applied in
        # order, so the last match wins.
        gripper_max = args.gripper_max
        keyword_limits = (
            ('block', args.gripper_max_block),
            ('towel', args.gripper_max_towel),  # was a hard-coded 0.9 that ignored the config value
            ('sponge', args.gripper_max_sponge),
            ('toy', 0.4),
            ('bat', 0.7),
            ('marker', 0.7),
            ('tape', 0.5),
        )
        for keyword, limit in keyword_limits:
            if keyword in text:
                gripper_max = limit
        gripper_pos = np.clip(gripper_pos, 0, gripper_max)

        # calculate future joint positions
        if not use_dynamics:
            # directly add: j_{t+1} = j_{t} + jv_{t} * dt
            joint_future = current_joint
            joint_pos_future = []
            for i in range(joint_vel.shape[0]):
                joint_future = joint_future + (joint_vel[i:i+1, :7]) * delta_t
                joint_pos_future.append(joint_future)
            joint_pos = np.array(joint_pos_future)
            if joint_pos.ndim == 3:
                # each step is (1, 7); drop the singleton axis
                joint_pos = joint_pos[:, 0, :]
            assert joint_pos.shape[1] == 7, f"Expected joint_pos shape (N, 7), got {joint_pos.shape}"
        else:
            # dynamics model
            print(current_joint.shape, joint_vel.shape, state[:8][None, :].shape)
            joint_pos = self.dynamics_model(current_joint, joint_vel, None, training=False)  # train2

        # Prepend the current configuration and cap both trajectories at 15 rows.
        joint_pos = np.concatenate([current_joint, joint_pos], axis=0)[:15]
        gripper_pos = np.concatenate([current_gripper, gripper_pos], axis=0)[:15]

        # fk: end-effector position + xyz Euler angles + gripper for every row
        state_fk = []
        for joint_row, gripper_row in zip(joint_pos, gripper_pos):
            pose = get_fk_solution(joint_row[:7])
            xyz = pose[:3, 3]
            euler = R.from_matrix(pose[:3, :3]).as_euler('xyz')
            state_fk.append(np.concatenate([xyz, euler, gripper_row], axis=0))
        state_fk = np.array(state_fk)

        policy_in_out = {
            'joint_pos': joint_pos,
            'joint_vel': joint_vel,
            'state_fk': state_fk,
        }

        skip = args.policy_skip_step
        state_fk_skip = state_fk[::skip][:5]
        joint_pos_skip = joint_pos[::skip][:5]
        joint_pos_skip = np.concatenate([joint_pos_skip, state_fk_skip[:, -1:]], axis=-1)  # add gripper
        print("joint_pos", joint_pos_skip.shape, "state_fk", state_fk_skip.shape)

        return policy_in_out, joint_pos_skip, state_fk_skip



class Args:
    """Static configuration for the policy / world-model rollout script.

    Every setting is a plain class attribute. ``agent`` instantiates this class
    without arguments and reads the attributes directly; the ``__main__`` loop
    also reads some of them from the class itself.
    """

    # ------------------------------------------------------------------
    # Pretrained backbones
    # ------------------------------------------------------------------
    pretrained_model_path = "/cephfs/shared/llm/stable-video-diffusion-img2vid"
    clip_model_path = "/cephfs/shared/llm/clip-vit-base-patch32"

    # ------------------------------------------------------------------
    # Run bookkeeping and optimisation settings
    # (not read anywhere in this file; presumably consumed by FuseSVD or the
    # training code -- TODO confirm)
    # ------------------------------------------------------------------
    debug = False
    output_dir = "output_unit_test/yay_robot_4img_cond"
    project_name = "unit_test_svd"
    tag = 'action_cond'
    action_dim = 7

    load_path = None
    learning_rate = 5e-6
    gradient_accumulation_steps = 1
    mixed_precision = 'fp16'
    train_batch_size = 12
    shuffle = True

    num_train_epochs = 100
    max_train_steps = 500000
    checkpointing_steps = 20000
    validation_steps = 2000
    max_grad_norm = 1.0

    # ------------------------------------------------------------------
    # Dataset description
    # ------------------------------------------------------------------
    dataset_names = 'droid_svd_v2'
    dataset_dir = "/cephfs/shared/droid_hf"
    data_root_path = "/cephfs/shared/droid_hf"
    annotation_name = 'annotation'  # annotation dirname under dataset_dir path
    tie_weight = True
    normalize = True
    pre_encode = True
    num_workers = 4
    video_size = [256, 256]

    only_one_clip = True
    get_action = True

    # ------------------------------------------------------------------
    # Video diffusion sampling
    # ------------------------------------------------------------------
    clip_img_size = 224
    use_img_cond = False
    motion_bucket_id = 127
    fps = 7
    guidance_scale = 2.0
    num_inference_steps = 50
    decode_chunk_size = 7
    width = 320
    height = 192
    validation_num = 32
    video_num = 3
    num_frames = 5
    num_history = 6
    data_json_path = 'video_dataset/xhand_mix_seq16_train.json'
    val_data_json_path = "video_dataset/xhand_mix_seq16_val.json"

    # ------------------------------------------------------------------
    # World-model checkpoint and rollout stepping
    # ------------------------------------------------------------------
    data_stat_path = 'exp_cfg/droid_svd_v2/stat.json'
    val_model_path = '/cephfs/cjyyj/code/video_evaluation/output2/exp33_210_s11/checkpoint-10000.pt'
    # Other checkpoints previously used here (all under
    # /cephfs/cjyyj/code/video_evaluation/output2/):
    #   exp33_210/checkpoint-20000.pt
    #   exp31_droid_cond9_eef_text_0804_accu4/checkpoint-10000.pt
    #   exp31_droid_cond9_eef_text_0804_accu4/checkpoint-5000.pt
    #   exp31_droid_cond9_eef_text_0804/checkpoint-40000.pt
    #   exp31_210_post/checkpoint-90000.pt
    pred_step = 5
    skip_step = 1

    # ------------------------------------------------------------------
    # Evaluation episodes
    # ------------------------------------------------------------------
    # Inactive evaluation sets, kept as a catalogue (paths are relative to
    # /cephfs/shared/droid_hf/; start_idx is 0 for every id unless noted):
    #   droid_real_all/droid_real0724_2/droid_pi0
    #       ids [215713, 215713, 215713]  (also seen: 215802, 215647)
    #       'pick up the blue block and place in white plate'
    #   droid_real0913/droid_pi05
    #       ids [221217,221602,221924,222414,222532,222842,223201,223301]
    #       'pick up the <colour> block and place in white plate'
    #       colours: blue, blue, green, green, blue, blue, blue, blue
    #   droid_real0914/droid_pi05
    #       ids [203038,203715,203803,203837,204021,204112,204202,204331,204437,204502]
    #       'pick up the <colour> block and place in white plate'
    #       colours: blue, blue, green, green, blue, green, green, green, red, red
    #   droid_real0914_2/droid_pi05
    #       ids ['000018','000044','000120','000228','000255','000336','000403','000427','000453',
    #            '000643','000739','000803','000833','000902','235555','235713','235826','235933']
    #       'fold the towel'
    #   droid_real0913/droid_pi05
    #       ids [224640,224723,224832,225306,234949]  (an older variant also had 225213)
    #       'pick up the sponge and place in the drawer'
    #   droid_real_all/droid_real0729/droid_pi05
    #       ids [155941,155941,155941]  (also seen: 111445,112356,112558), start_idx [68, 58, 48, 0]
    #       'pick up the yellow tape and place in drawer'
    #   droid_real0917/droid_pi05
    #       ids [161310,161310,152638,153618], start_idx [5,5,0,0]
    #       'pull one tissue out of the box', 'pull one white tissue out of the box',
    #       'moving the towel from left to right', 'close the laptop'
    #   droid_real0918/droid_pi05
    #       ids ['135334','135425','135525','135623']: 'pull one tissue out of the box'
    #       ids ['135749','135849','135931']: 'close the laptop'

    # Active evaluation set.
    val_dataset_dir = '/cephfs/shared/droid_hf/droid_real0918/droid_pi05'
    val_id = [
        '134750', '134908', '135009', '135048', '135205', '135334',
        '135425', '135525', '135623', '135749', '135849', '135931',
    ]
    start_idx = [0] * len(val_id)
    # NOTE(review): only 5 instructions for 12 ids. The main loop zips these
    # lists, so only the first 5 ids are actually rolled out -- confirm that
    # this is intended.
    instruction = [
        'moving the towel from left to right',
        'moving the towel from right to left',
        'moving the towel from left to right',
        'moving the towel from left to right',
        'moving the towel from left to right',
    ]

    # ------------------------------------------------------------------
    # Policy post-processing
    # ------------------------------------------------------------------
    glove_pos = 'top right'
    gripper_max = 1.0
    gripper_max_block = 0.8
    gripper_max_towel = 0.9
    gripper_max_sponge = 0.5
    policy_skip_step = 2
    task_name = f'replay_demo_0917_cfg{int(guidance_scale*10)}_{policy_skip_step}_{gripper_max}'

    # Set to None to integrate joint velocities directly instead of using the
    # learned dynamics model.
    dynamics_model_path = '/cephfs/cjyyj/code/video_evaluation/output_dynamics/model2_15_9.pth'
    policy_type = 'pi05'  # one of 'pi05', 'pi0', 'pi0fast'

    text_cond = True

    # Usage, one process per GPU:
    #   CUDA_VISIBLE_DEVICES=0 python exp33_generate_data_demo.py
    #   CUDA_VISIBLE_DEVICES=1 python exp33_generate_data_demo.py
    #   CUDA_VISIBLE_DEVICES=2 python exp33_generate_data_demo.py
    #   CUDA_VISIBLE_DEVICES=3 python exp33_generate_data_demo.py
    #   CUDA_VISIBLE_DEVICES=7 python exp33_generate_data_demo.py
        
if __name__ == "__main__":
    # Closed-loop replay: for every configured episode, alternate between the
    # policy (which proposes future end-effector states) and the world model
    # (which renders the resulting frames), then save the stitched video and
    # the policy outputs.

    Agent = agent()
    interact_num = 14  # policy / world-model rounds per episode
    pred_step = Agent.args.pred_step # important parameters
    num_history = Agent.args.num_history
    num_frames = Agent.args.num_frames

    # zip() below stops at the shortest list, so a length mismatch would
    # silently drop episodes; make that visible.
    episode_counts = (len(Args.val_id), len(Args.instruction), len(Args.start_idx))
    if len(set(episode_counts)) != 1:
        print(f"WARNING: val_id/instruction/start_idx lengths differ {episode_counts}; "
              f"only the first {min(episode_counts)} episodes will be run")

    # Positions in the history buffers that condition the world model.
    # Alternatives tried before: [-15,-15,-12,-9,-6,-3], [-15,-15,-8,-6,-4,-2],
    # [-15,-15,-15,-3,-2,-1], [-15,-15,-15,-12,-8,-4]
    history_idx = [0, 0, 0, -9, -6, -3]

    for val_id_i, text_i, start_idx_i in zip(Args.val_id, Args.instruction, Args.start_idx):

        # get initial state, ground truth actions
        eef_gt, joint_pos_gt, video_dict, video_latents = Agent.get_traj_info(
            val_id_i, start_idx=start_idx_i, steps=int(pred_step*interact_num+8))
        print("At the episode beginnig, the state is:", "eef pose 0", eef_gt[0], "joint 0", joint_pos_gt[0])

        video_to_save = []
        info_to_save = []

        # t=0 for each episode: fill the history buffers with the first observation
        first_latent = torch.cat([v[0] for v in video_latents], dim=1).unsqueeze(0)  # (1, 4, 72, 40)
        assert first_latent.shape == (1, 4, 72, 40), f"Expected first_latent shape (1, 4, 72, 40), got {first_latent.shape}"
        history_len = num_history * 4
        his_cond = [first_latent] * history_len
        his_joint = [joint_pos_gt[0:1]] * history_len
        his_eef = [eef_gt[0:1]] * history_len

        # interact loop
        for step in range(interact_num):
            frame_begin = int(step*pred_step)
            video_latent_true = [v[frame_begin:int(step*pred_step+num_frames)] for v in video_latents]

            # prepare input for policy
            joint_first = his_joint[-1][0]
            state_first = his_eef[-1][0]
            if step == 0:
                video_first = [v[0] for v in video_dict]
            else:
                # continue from the last frame predicted in the previous round
                video_first = [v[-1] for v in video_dict_pred]
            assert joint_first.shape == (8,), f"Expected joint_first shape (8,), got {joint_first.shape}"
            assert state_first.shape == (7,), f"Expected state_first shape (7,), got {state_first.shape}"

            # forward policy
            print("################ policy forward ####################")
            policy_in_out, joint_pos, state_pos = Agent.forward_policy(
                video_first, state_first, joint_first, text=text_i, time_step=step)
            print("policy output eef pose", state_pos) # output xyz and gripper for debug

            # forward world model
            print("################ world model forward ################")
            print(f'traj_id:{val_id_i}, interact step: {step}!!!!!!!!')

            his_eef_input = np.concatenate([his_eef[idx] for idx in history_idx], axis=0)
            action_input = np.concatenate([his_eef_input, state_pos], axis=0)
            his_cond_input = torch.cat([his_cond[idx] for idx in history_idx], dim=0).unsqueeze(0)
            video_latent_first = his_cond[-1]  # (1, 4, 72, 40)

            assert video_latent_first.shape == (1, 4, 72, 40), f"Expected video_latent_first shape (1, 4, 72, 40), got {video_latent_first.shape}"
            assert action_input.shape == (11, 7), f"Expected action_input shape (11, 7), got {action_input.shape}"
            assert his_cond_input.shape == (1, 6, 4, 72, 40), f"Expected his_cond_input shape (1, 6, 4, 72, 40), got {his_cond_input.shape}"

            videos_cat, _, video_dict_pred, predict_latents = Agent.forward_wm(
                action_input, video_latent_true, video_latent_first,
                his_cond=his_cond_input, text=text_i if Agent.args.text_cond else None)

            # The state reached after pred_step frames becomes the newest history entry.
            his_joint.append(joint_pos[pred_step-1:pred_step])
            his_eef.append(state_pos[pred_step-1:pred_step])
            his_cond.append(torch.cat([v[pred_step-1] for v in predict_latents], dim=1).unsqueeze(0))  # (1, 4, 72, 40)

            # last frame is the first frame of next step, so we only save the first pred_step-1 frames
            video_to_save.append(videos_cat[:pred_step-1])
            num = int(3*(pred_step-1))
            # keep only the first `num` rows of every policy output array
            info_to_save.append({key: value[:num] for key, value in policy_in_out.items()})

        video = np.concatenate(video_to_save, axis=0)

        # save rollout video and info with parameters
        task_name = Args.task_name
        wm_id = Agent.args.val_model_path.split('/')[-1].split('.')[0]
        # filesystem-friendly instruction tag: spaces -> '_', punctuation/quotes removed
        text_id = text_i.translate(str.maketrans({' ': '_', ',': None, '.': None, '\'': None, '\"': None}))[:30]
        num_inference_steps = Agent.args.num_inference_steps
        guidance_scale = Agent.args.guidance_scale
        # outputs are stored next to the world-model checkpoint
        videos_dir = '/'.join(Args.val_model_path.split('/')[:-1])
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        run_tag = f"time_{timestamp}_traj_{val_id_i}_{start_idx_i}_{pred_step}_{wm_id}_{guidance_scale}_{num_inference_steps}_{text_id}"

        filename_video = f"{videos_dir}/{task_name}/video/{run_tag}.mp4"
        os.makedirs(os.path.dirname(filename_video), exist_ok=True)
        mediapy.write_video(filename_video, video, fps=5)
        print(f"Saving video to {filename_video}")
        print("##########################################################################")

        # save info: merge the per-round dicts into one flat list per key
        info = {'success': 1, 'start_idx': 0, 'end_idx': video.shape[0]-1, 'instructions': text_i}
        for key in info_to_save[0]:
            info[key] = []
            for round_info in info_to_save:
                info[key] += round_info[key].tolist()[:int(pred_step*3)]

        # save to json
        filename_info = f"{videos_dir}/{task_name}/info/{run_tag}.json"
        os.makedirs(os.path.dirname(filename_info), exist_ok=True)
        with open(filename_info, 'w') as f:
            json.dump(info, f, indent=4)
        
        
        
