from diffusers import StableVideoDiffusionPipeline
from omegaconf import OmegaConf
import numpy as np
# import cv2
import torch
import torch.nn.functional as F
import torch.nn as nn
import einops
from accelerate import Accelerator
import datetime
import os
from accelerate.logging import get_logger
from tqdm.auto import tqdm
import wandb
import imageio
from video_models.pipeline import MaskStableVideoDiffusionPipeline
import json
from decord import VideoReader, cpu
import swanlab
import mediapy

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    Build a fixed 2D sine/cosine positional embedding for a square grid.

    embed_dim: embedding width per position (forwarded to
        get_2d_sincos_pos_embed_from_grid, which requires it to be even)
    grid_size: int, side length of the grid (height == width)
    cls_token, extra_tokens: when cls_token is truthy and extra_tokens > 0,
        extra_tokens all-zero rows are prepended to the table
    return:
    pos_embed: [grid_size*grid_size, embed_dim], or
        [extra_tokens+grid_size*grid_size, embed_dim] with the zero rows
    """
    axis_coords = np.arange(grid_size, dtype=np.float32)
    # The same coordinate vector serves both axes; meshgrid yields the
    # width-varying plane first, then the height-varying plane.
    mesh = np.stack(np.meshgrid(axis_coords, axis_coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return pos_embed

    zero_rows = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([zero_rows, pos_embed], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-plane coordinate grid into sine/cosine embeddings.

    embed_dim: total output width; must be even because each plane gets half
    grid: indexable with 0 and 1, each entry an array of positions
    return: (M, embed_dim) array where M is the number of positions per plane;
        columns [0, embed_dim/2) come from grid[0] and the rest from grid[1]
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_plane = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane])  # (M, D/2)
        for plane in (0, 1)
    ]
    return np.concatenate(per_plane, axis=1)  # (M, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions; it is flattened to size (M,)
    out: (M, D) array laid out as [sin(pos * w_0..w_{D/2-1}), cos(pos * w_0..w_{D/2-1})]
        with w_k = 1 / 10000 ** (k / (D/2))
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.)
    freqs = 1. / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    # Outer product of positions and frequencies via broadcasting.
    angles = flat_pos[:, None] * freqs[None, :]  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)

class Action_encoder2(nn.Module):
    """
    MLP that maps action vectors to 1024-wide conditioning tokens, optionally
    adding a text embedding on top.

    action_dim: size of one action vector
    action_num: number of action steps per sample
    hidden_size: stored on the module; the MLP width itself is fixed at 1024
    text_cond: when False, text inputs passed to forward() are ignored
    frame_independent: True -> the MLP consumes one action step at a time
        (input width action_dim); False -> it consumes all steps flattened
        together (input width action_dim * action_num)
    """

    def __init__(self, action_dim, action_num, hidden_size, text_cond=True, frame_independent=True):
        super().__init__()
        self.action_dim = action_dim
        self.action_num = action_num
        self.hidden_size = hidden_size
        self.text_cond = text_cond

        if frame_independent:
            input_dim = int(action_dim)
        else:
            input_dim = int(action_dim*action_num)

        mlp_layers = [
            nn.Linear(input_dim, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
        ]
        self.action_encode = nn.Sequential(*mlp_layers)

        # Kaiming init on the first two linear layers only; the last linear
        # layer keeps PyTorch's default initialisation.
        for layer_idx in (0, 2):
            nn.init.kaiming_normal_(self.action_encode[layer_idx].weight, mode='fan_in', nonlinearity='relu')

    def forward(self, action,  texts=None, text_tokinizer=None, text_encoder=None, frame_independent=True,):
        """
        action: (B, action_num, action_dim)
        texts / text_tokinizer / text_encoder: used only when texts is not
            None and self.text_cond is set; the encoder's `text_embeds` are
            tiled twice along the channel axis and added to every token
        frame_independent: False flattens all steps into one token first
        return: (B, T, 1024) when frame_independent, otherwise (B, 1, 1024)
        """
        batch, steps, dim = action.shape
        if not frame_independent:
            # Merge the time and feature axes into one token per sample.
            action = action.reshape(batch, 1, steps * dim)
        encoded = self.action_encode(action)

        if texts is None or not self.text_cond:
            return encoded

        # The text branch runs without gradients.
        with torch.no_grad():
            inputs = text_tokinizer(texts, padding='max_length', return_tensors="pt", truncation=True).to(text_encoder.device)
            outputs = text_encoder(**inputs)
            hidden_text = outputs.text_embeds  # (B, C)
            hidden_text = einops.repeat(hidden_text, 'b c -> b 1 (n c)', n=2)  # (B, 1, 2*C)

        # Broadcasts the single text token over all action tokens.
        return encoded + hidden_text


class FuseSVD(nn.Module):
    """
    Action-conditioned Stable Video Diffusion training wrapper.

    Loads a StableVideoDiffusionPipeline and a CLIP text encoder, freezes the
    VAE, image encoder and text encoder, and trains the UNet together with an
    Action_encoder2 that produces the UNet's `encoder_hidden_states`.
    """

    def __init__(self, args, device='cuda'):
        """
        args: config object; reads pretrained_model_path, clip_model_path,
            action_dim, num_history, num_frames, text_cond, frame_independent
        device: accepted for backward compatibility but not used here; the
            caller moves the module to its device
        """
        super().__init__()

        self.args = args
        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(args.pretrained_model_path)
        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # NOTE: an earlier variant replaced unet.conv_in with a 24-channel
        # conv (first 8 input channels copied, the rest zero-initialised) for
        # 3-image conditioning. It is disabled; the stock conv_in is used.

        from transformers import AutoTokenizer, CLIPTextModelWithProjection
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(args.clip_model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(args.clip_model_path, use_fast=False)

        # Only the UNet (and the action encoder below) receive gradients.
        self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)

        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        self.action_encoder = Action_encoder2(
            action_dim=args.action_dim,
            action_num=int(args.num_history+args.num_frames),
            hidden_size=1024,
            text_cond=args.text_cond,
            frame_independent=args.frame_independent,
        )

        # Report the parameter count of each component.
        components = (
            ("unet", self.unet),
            ("vae", self.vae),
            ("image_encoder", self.image_encoder),
            ("text_encoder", self.text_encoder),
            # Fixed label: this count was previously printed as "vision_encoder".
            ("action_encoder", self.action_encoder),
        )
        for label, module in components:
            num_params = sum(p.numel() for p in module.parameters())
            print(f"Number of parameters in the {label}: {num_params/1000000:.2f}M")

    def forward(self, batch, get_new_img=False):
        """
        Compute the denoising loss for one batch.

        batch: dict with
            'latent' - frame latents; the first args.num_history frames are
                history and the rest are prediction targets
                (assumed (B, num_history+num_frames, C, H, W) - TODO confirm
                against the dataset)
            'text'   - passed to the action encoder as `texts`
            'action' - action tensor, one step per frame
        get_new_img: unused; kept for interface compatibility
        return: (loss, zero) where `zero` is a constant 0.0 tensor placeholder
            for the second ("L2 distance") slot the training loop expects
        """
        latents = batch['latent']
        texts = batch['text']
        dtype = self.unet.dtype
        device = self.unet.device
        # Mean / std of the log-normal noise-level distribution.
        P_mean = 0.7
        P_std = 1.6
        noise_aug_strength = 0.0

        num_history = self.args.num_history

        latents = latents.to(device)
        bsz, num_frames = latents.shape[:2]

        # Conditioning image: the first frame after the history window,
        # lightly noised (sigma in [0, 0.2)) and rescaled.
        image_latent = latents[:, num_history]
        cond_sigma = torch.rand([bsz, 1, 1, 1], device=device) * 0.2
        cond_c_in = 1 / (cond_sigma**2 + 1) ** 0.5
        image_latent = cond_c_in*(image_latent + torch.randn_like(image_latent) * cond_sigma)

        condition_latent = einops.repeat(image_latent, 'b c h w -> b f c h w', f=num_frames)
        if self.args.his_cond_zero:
            # einops.repeat may hand back a broadcast (expanded) view for
            # torch tensors; PyTorch refuses in-place writes into such views,
            # so materialise a real copy before zeroing the history slots.
            condition_latent = condition_latent.clone()
            condition_latent[:, :num_history] = 0.0

        # Action (and optional text) conditioning tokens.
        action = batch['action'].to(device)
        encoder_hidden_states = self.action_encoder(
            action, texts, self.tokenizer, self.text_encoder,
            frame_independent=self.args.frame_independent,
        )
        # Replace the conditioning with zeros for ~5% of the samples.
        uncond_hidden_states = torch.zeros_like(encoder_hidden_states)
        text_mask = (torch.rand(encoder_hidden_states.shape[0], device=device) > 0.05).unsqueeze(1).unsqueeze(2)
        encoder_hidden_states = encoder_hidden_states*text_mask + uncond_hidden_states*(~text_mask)

        # Sample one noise level per sample and derive the preconditioning
        # coefficients from it.
        rnd_normal = torch.randn([bsz, 1, 1, 1, 1], device=device)
        sigma = (rnd_normal * P_std + P_mean).exp()
        c_skip = 1 / (sigma**2 + 1)
        c_out = -sigma / (sigma**2 + 1) ** 0.5
        c_in = 1 / (sigma**2 + 1) ** 0.5
        c_noise = (sigma.log() / 4).reshape([bsz])
        loss_weight = (sigma ** 2 + 1) / sigma ** 2

        noisy_latents = latents + torch.randn_like(latents) * sigma

        # History frames get their own small, per-frame noise level instead
        # of the main sigma.
        sigma_h = torch.randn([bsz, num_history, 1, 1, 1], device=device) * 0.3
        history = latents[:, :num_history]
        noisy_history = 1/(sigma_h**2+1)**0.5 * (history + sigma_h * torch.randn_like(history))
        input_latents = torch.cat([noisy_history, c_in*noisy_latents[:, num_history:]], dim=1)

        # Channel-wise concat of the denoising input and the conditioning image.
        input_latents = torch.cat([input_latents, condition_latent/self.vae.config.scaling_factor], dim=2)

        added_time_ids = self.pipeline._get_add_time_ids(
            self.args.fps, self.args.motion_bucket_id, noise_aug_strength,
            encoder_hidden_states.dtype, bsz, 1, False,
        )
        added_time_ids = added_time_ids.to(device)

        # Calculate the loss on the future frames only.
        model_pred = self.unet(
            input_latents, c_noise,
            encoder_hidden_states=encoder_hidden_states,
            added_time_ids=added_time_ids,
            frame_independent=self.args.frame_independent,
        ).sample
        predict_x0 = c_out * model_pred + c_skip * noisy_latents
        loss = ((predict_x0[:, num_history:] - latents[:, num_history:])**2 * loss_weight).mean()
        return loss, torch.tensor(0.0, device=device, dtype=dtype)




def main():
    """
    Train FuseSVD with Hugging Face Accelerate.

    Builds the model (optionally restoring args.ckpt_path), the train/val
    datasets and loaders, then runs the optimisation loop until
    args.max_train_steps optimizer steps have been taken. Every 100 steps the
    mean loss is logged; checkpoints and validation videos are written by the
    main process at the configured intervals.
    """
    import math

    args = Args()
    logger = get_logger(__name__, log_level="INFO")
    swanlab.sync_wandb()
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with='wandb',
        project_dir=args.output_dir
    )
    if accelerator.is_main_process:
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        tag = args.tag
        run_name = f"train_{now}_{tag}"
        accelerator.init_trackers(args.project_name, config={}, init_kwargs={"wandb": {"name": run_name}})
    os.makedirs(args.output_dir, exist_ok=True)

    model = FuseSVD(args)
    if args.ckpt_path is not None:
        print(f"Loading checkpoint from {args.ckpt_path}!!!!!!")
        state_dict = torch.load(args.ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict, strict=True)
    model.to(accelerator.device)
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    from video_dataset.dataset_droid_exp33 import Dataset_mix
    train_dataset = Dataset_mix(args, mode='train')
    val_dataset = Dataset_mix(args, mode='val')

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=args.shuffle
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.train_batch_size,
        shuffle=args.shuffle
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader
    )

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    # One optimizer step happens every gradient_accumulation_steps batches
    # (plus one at the end of each pass over the loader), so this many epochs
    # are enough to reach max_train_steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    # Log the epoch count actually used, not the unused args.num_train_epochs.
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    logger.info(f"  checkpointing_steps = {args.checkpointing_steps}")
    logger.info(f"  validation_steps = {args.validation_steps}")

    log_every = 100
    global_step = 0
    train_loss = 0.0
    train_loss_l2 = 0.0
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for _ in range(num_train_epochs):
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    loss_gen, loss_l2 = model(batch)

                # Average the scalar losses across processes for logging.
                avg_loss = accelerator.gather(loss_gen.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps
                avg_loss_l2 = accelerator.gather(loss_l2.repeat(args.train_batch_size)).mean()
                train_loss_l2 += avg_loss_l2.item() / args.gradient_accumulation_steps

                # The second loss term is currently weighted by zero.
                accelerator.backward(loss_gen + loss_l2 * 0.00)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                optimizer.zero_grad()

            # sync_gradients is True only on micro-batches that completed an
            # optimizer step.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % log_every == 0:
                    mean_loss = train_loss / log_every
                    # Show the same mean value that is sent to the tracker.
                    progress_bar.set_postfix({"loss": mean_loss})
                    accelerator.log({"train_loss": mean_loss}, step=global_step)
                    accelerator.log({"L2_distance": train_loss_l2 / log_every}, step=global_step)
                    train_loss = 0.0
                    train_loss_l2 = 0.0

                if global_step % args.checkpointing_steps == 0 and accelerator.is_main_process:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
                    torch.save(accelerator.unwrap_model(model).state_dict(), save_path)
                    logger.info(f"Saved checkpoint to {save_path}")

                # "== 5" triggers validation shortly after start and then
                # once per validation_steps interval.
                if global_step % args.validation_steps == 5 and accelerator.is_main_process:
                    model.eval()
                    with accelerator.autocast():
                        for video_idx in range(args.video_num):
                            validate_video_generation(model, val_dataset, args, global_step, args.output_dir, video_idx, accelerator)
                    model.train()

            # Stop once the step budget is used up; previously the loop kept
            # running past max_train_steps.
            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    progress_bar.close()
    # Flush and close the experiment trackers.
    accelerator.end_training()



def main_val():
    """
    Standalone validation entry point.

    Builds FuseSVD, restores the weights stored at args.val_model_path and
    generates one validation video from the annotation files
    (load_from_dataset=False), writing it under 'output/samples'.
    """
    args = Args()
    accelerator = Accelerator()
    model = FuseSVD(args)
    # load from val_model_path
    print("load from val_model_path",args.val_model_path)
    # Load onto CPU first (as main() does) so the checkpoint does not depend
    # on the device it was saved from; the model is moved afterwards.
    state_dict = torch.load(args.val_model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(accelerator.device)
    model.eval()
    validate_video_generation(model, None, args, 0, 'output', 0, accelerator, load_from_dataset=False)
    
            

def _normalize_bound(data, data_min, data_max, clip_min=-1, clip_max=1, eps=1e-8):
    """Linearly map [data_min, data_max] to [-1, 1] and clip to [clip_min, clip_max]."""
    ndata = 2 * (data - data_min) / (data_max - data_min + eps) - 1
    return np.clip(ndata, clip_min, clip_max)


def _pad_with_last(array, target_len):
    """Repeat the last entry along axis 0 until `array` has `target_len` entries."""
    missing = target_len - array.shape[0]
    if missing <= 0:
        return array
    return np.concatenate([array, np.repeat(array[-1:], missing, axis=0)], axis=0)


def _decode_latents_in_chunks(vae, latents, chunk_size):
    """Decode (B, F, C, h, w) scaled latents with `vae`, `chunk_size` frames per call."""
    bsz, frame_num = latents.shape[:2]
    flat = latents.flatten(0, 1)
    decoded = []
    with torch.no_grad():
        for start in range(0, flat.shape[0], chunk_size):
            chunk = flat[start:start + chunk_size] / vae.config.scaling_factor
            decoded.append(vae.decode(chunk, num_frames=chunk.shape[0]).sample)
    frames = torch.cat(decoded, dim=0)
    return frames.reshape(bsz, frame_num, *frames.shape[1:])


def _to_uint8_video(frames):
    """Map (B, F, C, H, W) frames in [-1, 1] to a uint8 (B, F, H, W, C) numpy array."""
    frames = (frames / 2.0 + 0.5).clamp(0, 1) * 255
    return frames.to(torch.float32).detach().cpu().numpy().transpose(0, 1, 3, 4, 2).astype(np.uint8)


def _load_clips_from_annotations(args, vae, device):
    """
    Load the clips listed in args.val_id from annotation JSON + video files,
    normalise their actions and VAE-encode the frames.

    Returns (true_history, true_video, actions): latents of the first
    args.num_history sampled frames, latents of the remaining frames, and the
    normalised action tensor.
    """
    with open(f"{args.data_stat_path}", 'r') as f:
        data_stat = json.load(f)
    state_p01 = np.array(data_stat['state_01'])[None, :]
    state_p99 = np.array(data_stat['state_99'])[None, :]

    skip = args.skip_step
    start_idx = args.start_idx
    # History and future frames are both needed: the model consumes
    # num_history + num_frames action steps (asserted by the caller).
    # NOTE(review): history is sampled with the same stride as the future
    # frames here; confirm this matches how the training dataset samples it.
    end_idx = start_idx + int((args.num_history + args.num_frames) * skip)

    clip_videos = []
    clip_actions = []
    for clip_id in (int(v) for v in args.val_id.split('+')):
        annotation_path = f"{args.val_dataset_dir}/annotation/val/{clip_id}.json"
        with open(annotation_path) as f:
            anno = json.load(f)
        if 'action' in anno:
            length = len(anno['action'])
        else:
            length = anno["video_length"]
        video_path = f"{args.val_dataset_dir}/{anno['videos'][0]['video_path']}"
        action = np.array(anno['actions'])

        # load videos and actions
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        frames = vr.get_batch(range(length))
        # Depending on the reader's output type the conversion method differs.
        frames = frames.asnumpy() if hasattr(frames, 'asnumpy') else frames.numpy()

        # Clips that end too early are padded by repeating the last entry.
        frames = _pad_with_last(frames, end_idx)[start_idx:end_idx:skip]
        action = _pad_with_last(action, end_idx)[start_idx:end_idx:skip]

        if 'xhand' in args.val_dataset_dir:
            # Negate components 3:7 of every step whose component 3 is negative.
            flip = action[:, 3] < 0
            action[flip, 3:7] = -action[flip, 3:7]
        action = _normalize_bound(action, state_p01, state_p99)

        clip_videos.append(frames)
        clip_actions.append(action)

    pixels = torch.from_numpy(np.stack(clip_videos)).to(device).float()
    actions = torch.from_numpy(np.stack(clip_actions)).to(device).float()
    print("action after encode",actions.shape)
    print("true_video before encode",pixels.shape)

    # encode video: (B, F, H, W, C) in [0, 255] -> (B*F, C, H, W) in [-1, 1]
    bsz, frame_num = pixels.shape[:2]
    x = pixels.flatten(0, 1).permute(0, 3, 1, 2) / 255.0*2-1
    encode_batch = 32
    with torch.no_grad():
        encoded = [
            vae.encode(x[i:i + encode_batch]).latent_dist.sample().mul_(vae.config.scaling_factor)
            for i in range(0, len(x), encode_batch)
        ]
    latents = torch.cat(encoded, dim=0)
    latents = latents.reshape(bsz, frame_num, *latents.shape[1:])
    return latents[:, :args.num_history], latents[:, args.num_history:], actions


def validate_video_generation(model, val_dataset, args, train_steps, videos_dir, id, accelerator, load_from_dataset=True):
    """
    Generate validation videos and save them side by side with ground truth.

    model: FuseSVD, possibly wrapped by accelerator.prepare
    val_dataset: indexable dataset whose items hold 'latent', 'text' and
        'action'; only used when load_from_dataset is True
    args: config object (see Args)
    train_steps, id: only used to name the output file (and, with `id`, to
        pick which dataset samples to render)
    videos_dir: output root; the mp4 goes to {videos_dir}/samples
    accelerator: provides the device and unwraps the model
    load_from_dataset: True -> take samples from val_dataset; False -> read
        the clips named in args.val_id from args.val_dataset_dir

    Writes {videos_dir}/samples/train_steps_{train_steps}_{id}.mp4 and, when
    load_from_dataset is True, logs it to wandb. Returns None.
    """
    device = accelerator.device
    # Works for both wrapped (DDP) and plain models.
    net = accelerator.unwrap_model(model)
    pipeline = net.pipeline
    videos_row = args.video_num if not args.debug else 1
    videos_col = 2

    if load_from_dataset:
        # Spread the rendered samples evenly over the dataset; keep the
        # stride >= 1 so small datasets do not produce range(step=0).
        stride = max(1, int(len(val_dataset)/videos_row/videos_col))
        batch_id = list(range(0, len(val_dataset), stride))
        batch_id = batch_id[int(id*(videos_col)):int((id+1)*(videos_col))]
        batch_list = [val_dataset[sample_idx] for sample_idx in batch_id]
        true_video = torch.stack([t['latent'] for t in batch_list], dim=0).to(device, non_blocking=True)
        text = [t['text'] for t in batch_list]
        actions = torch.stack([t['action'] for t in batch_list], dim=0).to(device, non_blocking=True)
        print("actions",actions.shape)
        true_history, true_video = true_video[:, :args.num_history], true_video[:, args.num_history:]
    else:
        true_history, true_video, actions = _load_clips_from_annotations(args, pipeline.vae, device)
        # No captions are read on this path; the action encoder skips text
        # conditioning when texts is None.
        text = None

    # The first future frame is the conditioning image.
    image = true_video[:, 0]

    print("image",image.shape, 'action', actions.shape)
    assert image.shape[1:] == (4, 72, 40)
    assert actions.shape[1:] == (int(args.num_frames+args.num_history), args.action_dim)

    # start generate
    with torch.no_grad():
        text_token = net.action_encoder(actions, text, net.tokenizer, net.text_encoder, args.frame_independent)
        print("text_token",text_token.shape)

        _, latents = MaskStableVideoDiffusionPipeline.__call__(
            pipeline,
            image=image,
            text=text_token,
            width=320,
            height=576,
            num_frames=args.num_frames,
            history=true_history,
            num_inference_steps=args.num_inference_steps,
            decode_chunk_size=args.decode_chunk_size,
            max_guidance_scale=args.guidance_scale,
            fps=args.fps,
            motion_bucket_id=args.motion_bucket_id,
            mask=None,
            output_type='latent',
            return_dict=False,
            frame_independent=args.frame_independent,
            his_cond_zero=args.his_cond_zero,
        )

    # Split the latent height into 3 equal tiles and move them to the batch axis.
    latents = einops.rearrange(latents, 'b f c (m h) (n w) -> (b m n) f c h w', m=3, n=1)

    true_video = torch.cat([true_history, true_video], dim=1)
    true_video = einops.rearrange(true_video, 'b f c (m h) (n w) -> (b m n) f c h w', m=3, n=1)

    # Ground truth is decoded unless it already has 3 (RGB) channels; the
    # generated output is always latent (output_type='latent').
    if true_video.shape[2] != 3:
        true_video = _decode_latents_in_chunks(pipeline.vae, true_video, args.decode_chunk_size)
    videos = _decode_latents_in_chunks(pipeline.vae, latents, args.decode_chunk_size)

    true_video = _to_uint8_video(true_video)
    videos = _to_uint8_video(videos)
    # Prepend the ground-truth history so both rows have the same length.
    videos = np.concatenate([true_video[:, :args.num_history], videos], axis=1)

    # Ground truth on top, generation below, then all samples side by side.
    videos = np.concatenate([true_video, videos], axis=-3)
    videos = np.concatenate(list(videos), axis=-2).astype(np.uint8)

    os.makedirs(f"{videos_dir}/samples", exist_ok=True)
    filename = f"{videos_dir}/samples/train_steps_{train_steps}_{id}.mp4"
    mediapy.write_video(filename, videos, fps=2)
    name = videos_dir.split('/')[-1]
    if load_from_dataset:
        wandb.log({f"{name}_train_steps_{train_steps}": wandb.Video(filename, fps=4, format="mp4")})
    return


class Args:
    """Static experiment configuration; every setting is a class attribute."""

    # ---- pretrained weights -------------------------------------------
    pretrained_model_path = "/cephfs/shared/llm/stable-video-diffusion-img2vid"
    clip_model_path = "/cephfs/shared/llm/clip-vit-base-patch32"

    # Checkpoint restored before training starts; None starts from the
    # pretrained pipeline weights only.
    ckpt_path = '/cephfs/cjyyj/code/video_evaluation/output2/exp33_ablation_noframecross_notext/checkpoint-150000.pt'

    # ---- run identity / output ------------------------------------------
    debug = False
    tag = 'exp33_ablation_noframecross_notext2'
    output_dir = f"output2/{tag}"
    project_name = "droid-memory"
    action_dim = 7

    # ---- optimisation ---------------------------------------------------
    load_path = None
    learning_rate = 1e-5
    gradient_accumulation_steps = 4
    mixed_precision = 'fp16'
    train_batch_size = 3
    shuffle = True

    num_train_epochs = 100
    max_train_steps = 500000
    checkpointing_steps = 5000
    validation_steps = 2500
    max_grad_norm = 1.0

    # ---- dataset --------------------------------------------------------
    dataset_names = 'droid_svd_v2'
    dataset_cfgs = 'droid_svd_v2'
    prob = [1.0]

    dataset_dir = "/cephfs/shared/droid_hf"
    data_root_path = "/cephfs/shared/droid_hf"
    annotation_name = 'annotation_all_skip1'  # annotation dirname under dataset_dir path

    tie_weight = True
    normalize = True
    pre_encode = True
    num_workers = 4
    video_size = [256, 256]

    only_one_clip = True
    get_action = True

    # ---- generation / sampling ------------------------------------------
    clip_img_size = 224
    use_img_cond = False
    motion_bucket_id = 127
    fps = 7
    guidance_scale = 7.5
    num_inference_steps = 30
    decode_chunk_size = 7
    width = 320
    height = 160
    validation_num = 32
    video_num = 12
    num_frames = 5
    num_actions = 5
    num_history = 6
    skip_his = 8
    data_stat_path = 'video_dataset/xhand_stat.json'
    data_json_path = 'video_dataset/xhand_mix_seq16_train.json'
    val_data_json_path = "video_dataset/xhand_mix_seq16_val.json"

    # ---- standalone validation (main_val) ---------------------------------
    val_model_path = 'output_unit_test/test_xhand_action_cond3_cfg/checkpoint-10000.pt'
    val_dataset_dir = '/localssd/gyj/opensource_robotdata/xhand_1024_v2'
    val_id = '50+150+200+250+300+350+400+450'
    skip_step = 2
    start_idx = 10

    # ---- ablation switches ----------------------------------------------
    frame_independent = False
    his_cond_zero = False

    text_cond = False

    


if __name__ == "__main__":

    # Script entry point: run training. The standalone validation entry
    # (main_val) is kept below, commented out, as the alternative.
    main()
    # main_val()

    # CUDA_VISIBLE_DEVICES=0,1 WANDB_MODE=offline accelerate launch --main_process_port 29503 exp33_ablation_noframecross.py
    # CUDA_VISIBLE_DEVICES=0 accelerate launch --main_process_port 29506 unit_test2.py

    # args = Args()
    # from video_dataset.dataset_droid_exp33 import Dataset_mix
    # dataset = Dataset_mix(args,mode='val')
    # from torch.utils.data import DataLoader
    # dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=2)
    # model = FuseSVD(args).to('cuda')
    # # print model parameter num
    # num_params = sum(p.numel() for p in model.parameters())
    # print(f"Number of parameters in the model: {num_params/1000000:.2f}M")
    # optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-6)
    # total_elements = sum(p.numel() for group in optimizer.param_groups for p in group['params'])
    # print(f"Total number of learnable parameters: {total_elements}")
    # model.train()
    

    # for batch in dataloader:
    #     print(batch['latent'].shape)
    #     print(batch['text'])
    #     print(batch['action'].shape)

    #     loss,_ = model(batch)
    #     loss.backward()
    #     optimizer.step()
    #     optimizer.zero_grad()
    #     print(loss.item())





    # device = 'cuda'
    # video_encoder = VideoEncoder(hidden_size=1024).to(device)
    # # count the parameters of the model
    # num_params = sum(p.numel() for p in video_encoder.parameters())
    # print(f"Number of parameters in the model: {num_params/1000000:.2f}M")
    # vae_latent = torch.randn(8, 1, 4, 32, 32).to(device)
    # clip_latent = torch.randn(8, 20, 512).to(device)
    # image_latent = video_encoder(vae_latent, clip_latent)
    # print(image_latent.shape)  # (8, 1, 4, 32, 32)


    # pos_emb = get_2d_sincos_pos_embed(1024, 16)
    # print(pos_emb.shape)  # (256, 1024)
    # clip_emb = get_1d_sincos_pos_embed_from_grid(1024, np.arange(20))
    # print(clip_emb.shape)  # (20, 512)
