import torch, os, imageio, argparse
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
import pickle
from random import choice
import random
from pytorch_lightning.callbacks import ModelCheckpoint
import cv2
def get_frame_count_opencv(video_path):
    """Return the frame count OpenCV reports for a video file.

    Args:
        video_path: Path of the video to inspect.

    Returns:
        int: The value of ``cv2.CAP_PROP_FRAME_COUNT`` cast to ``int``.

    Raises:
        ValueError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("Fail to open video")
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture on every path, including the failure one.
        cap.release()
import numpy as np
import json

import torch
import torch.nn as nn
import torch.nn.functional as F

class PointMapEncoder(nn.Module):
    """Point-wise feature extractor with a global max-pool and a linear head.

    Three kernel-size-1 ``Conv1d`` layers, each followed by ReLU and then
    ``BatchNorm1d``, lift the channel axis from ``input_dim`` to 256. The
    maximum over axis 2 collapses that axis, and ``fc1`` maps the pooled
    256-vector to ``latent_dim``.
    """

    def __init__(self, input_dim=3, latent_dim=128):
        super().__init__()

        # Point-wise (kernel size 1) convolutions and the output projection.
        # Attribute names and creation order are kept stable so state-dict
        # keys and seeded initialisation stay the same.
        self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=1)
        self.fc1 = nn.Linear(256, latent_dim)

        # One normalisation layer per convolution stage.
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Encode ``x`` (indexed as batch, channel, point) into ``latent_dim`` features."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        features = x
        for conv, norm in stages:
            # The normalisation is applied after the ReLU, not before it.
            features = norm(F.relu(conv(features)))

        # Global pooling: keep the per-channel maximum over axis 2.
        pooled, _ = torch.max(features, dim=2)
        return self.fc1(pooled)

from scipy.spatial.transform import Rotation as R
def imu_to_root_pose(gyro_data, accl_data, dt):
    """Fold a sequence of IMU samples into a single rotation vector.

    For every sample the gyroscope reading is turned into an incremental
    rotation (``angular_rate * dt`` as a rotation vector) and left-multiplied
    onto the running orientation. A tilt correction derived from the
    normalised accelerometer reading is then left-multiplied as well.

    Args:
        gyro_data: Sequence of 3-element angular rates; its length drives the loop.
        accl_data: Sequence of 3-element accelerometer readings, indexed in
            step with ``gyro_data``.
        dt: Time step multiplied with each angular rate.

    Returns:
        np.ndarray: The final orientation as a rotation vector (3 values).
        An empty ``gyro_data`` yields the zero vector.
    """
    rotation = R.from_euler('xyz', [0, 0, 0], degrees=False)
    for i, angular_rate in enumerate(gyro_data):
        # Accept plain lists as well as arrays for each sample.
        angular_rate = np.asarray(angular_rate, dtype=float)
        delta_rotation = R.from_rotvec(angular_rate * dt)
        rotation = delta_rotation * rotation

        accl = np.asarray(accl_data[i], dtype=float)
        norm = np.linalg.norm(accl)
        if not np.isfinite(norm) or norm == 0.0:
            # No usable gravity direction: dividing by this norm would turn
            # the whole orientation into NaN, so keep the gyro-only update.
            continue
        gravity = accl / norm
        pitch = np.arctan2(-gravity[0], np.sqrt(gravity[1]**2 + gravity[2]**2))
        roll = np.arctan2(gravity[1], gravity[2])
        # NOTE(review): pitch is passed as the x angle and roll as the y angle
        # of the 'xyz' sequence; confirm this ordering is intended.
        correction = R.from_euler('xyz', [pitch, roll, 0])
        rotation = correction * rotation

    return rotation.as_rotvec()
from transforms3d.axangles import mat2axangle
def extrinsics_to_root_pose(extrinsics):
    """
    Derive one axis-angle root rotation per frame from camera extrinsics.

    Args:
        extrinsics (np.ndarray): Stack of extrinsic matrices; only the
            upper-left 3x3 block of each entry is read. The original
            documentation gives the shape as (120, 4, 4).

    Returns:
        np.ndarray: Array of shape (num_frames, 3) holding ``axis * angle``
        for every frame, with NaN/Inf entries replaced by ``np.nan_to_num``.
    """
    frame_total = extrinsics.shape[0]
    root_poses = np.zeros((frame_total, 3))

    # Flips the Y and Z axes. The original author describes this as mapping
    # OpenCV camera coordinates (X=right, Y=down, Z=forward) to SMPL-X
    # coordinates (X=right, Y=up, Z=backward).
    flip_yz = np.diag([1, -1, -1]).astype(np.float32)

    for idx in range(frame_total):
        # Transposing the rotation block inverts it
        # (world-to-camera -> camera-to-world).
        world_from_cam = extrinsics[idx, :3, :3].T

        # Axis-angle form of the axis-adjusted rotation.
        axis, angle = mat2axangle(flip_yz @ world_from_cam)

        # Guard against NaN/Inf components before storing the vector.
        root_poses[idx] = np.nan_to_num(axis * angle)

    return root_poses
class TextVideoDataset(torch.utils.data.Dataset):
    """Video clips paired with per-frame SMPL-X pose parameters.

    Indices below ``_EGO_INDEX_LIMIT`` are served from ``dataset_smpl_ego``
    and the remaining indices below ``_NOMINAL_LENGTH`` from
    ``videoexo_paths_true``. In both cases the index is wrapped with a modulo,
    so the nominal length does not have to match the number of files on disk.

    A sample that fails to load is replaced by a randomly drawn one from the
    same source. After ``max_retries`` replacements the last error is
    re-raised inside a ``RuntimeError``.
    """

    # Number of leading indices routed to the ``dataset_smpl_ego`` source.
    _EGO_INDEX_LIMIT = 20000
    # Length reported by ``__len__``; independent of the number of files found.
    _NOMINAL_LENGTH = 40000
    # How many random replacement samples are tried after a failed load.
    max_retries = 100
    # Per-frame entries copied out of the saved SMPL-X dictionaries.
    _SMPLX_KEYS = (
        "smplx_root_pose",
        "smplx_body_pose",
        "smplx_lhand_pose",
        "smplx_rhand_pose",
        "smplx_jaw_pose",
    )

    def __init__(self, base_path="", metadata_path="", max_num_frames=81, frame_interval=4, num_frames=81, height=480, width=832):
        """Index the on-disk sources and build the per-frame transform.

        Args:
            base_path: Unused; kept so existing callers keep working.
            metadata_path: Unused; kept so existing callers keep working.
            max_num_frames: Stored on the instance; the loaders in this class
                do not read it.
            frame_interval: Stride between sampled frame indices.
            num_frames: Number of frames sampled per clip.
            height: Target frame height.
            width: Target frame width.

        All data locations are hard-coded relative paths, so the working
        directory determines what is found.
        """
        self.dataset_smpl_chard = "./egocentric/meshsave_charades/"
        self.dataset_videos_chard = "./third_person/CharadesEgo_v1/"
        videoexo_paths = [self.dataset_videos_chard+p for p in os.listdir(self.dataset_videos_chard) if "EGO" not in p]
        # A set makes the membership test below O(1) per video.
        smpl_keys = {p.replace('.npy', '') for p in os.listdir(self.dataset_smpl_chard)}
        # Keep only videos for which a pose file with the same name exists.
        self.videoexo_paths_true = [p for p in videoexo_paths if p.split("/")[-1] in smpl_keys]

        self.dataset_smpl_ego = np.load("./DiffSynth-Studio/examples/wanvideo/my_list.npy")
        self.dataset_egoexo4d = "./newdataset/takes/"
        self.videos_egoexo4d_past = [self.dataset_egoexo4d+p for p in os.listdir(self.dataset_egoexo4d)]

        # NOTE: pickle executes arbitrary code on load; this file must come
        # from a trusted source.
        with open('./DiffSynth-Studio/task_name.pkl', 'rb') as f:
            self.text = pickle.load(f)

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width

        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.video_metadata_pose = np.load("./DiffSynth-Studio/examples/wanvideo/filepath_all.npy", allow_pickle=True).item()
        self.video_metadata = pd.read_csv('./egocentric/egovid/egovid-kinematic.csv')
        self.videos_key = os.listdir("./egocentric/egovid/poses/")

    def crop_and_resize(self, image):
        """Scale a PIL image so it covers the target ``width`` x ``height``.

        The aspect ratio is preserved; the center crop to the exact target
        size happens later in ``frame_process``.
        """
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read ``num_frames`` frames spaced ``interval`` apart from a video.

        Args:
            file_path: Video to read.
            max_num_frames: Minimum number of frames the video must contain.
            start_frame_id: Index of the first frame to read.
            interval: Stride between consecutive frame indices.
            num_frames: Number of frames to read.
            frame_process: Callable applied to every resized PIL frame.

        Returns:
            tuple: ``(frames, frame_select)`` where ``frames`` is a tensor
            rearranged to ``C T H W`` and ``frame_select`` lists the frame
            indices that were read. ``(None, None)`` is returned when the
            video cannot be opened or is too short.
        """
        try:
            reader = imageio.get_reader(file_path)
        except Exception:
            return None, None

        try:
            start = int(start_frame_id)
            total = reader.count_frames()
            last_needed = start + (num_frames - 1) * interval
            if total < max_num_frames or total - 1 < last_needed:
                # Always return a pair: every caller unpacks two values.
                return None, None

            frames = []
            frame_select = []
            for frame_id in range(num_frames):
                index = start + frame_id * interval
                frame = Image.fromarray(reader.get_data(index))
                frame = self.crop_and_resize(frame)
                frames.append(frame_process(frame))
                frame_select.append(index)
        finally:
            # Close the reader on every path, including decode errors.
            reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        return frames, frame_select

    def load_video(self, videoego_path, frame_num=110, flag=0):
        """Sample a random clip from ``videoego_path``.

        Args:
            videoego_path: Video to sample from.
            frame_num: Ignored; the frame count is always read from the file.
                Kept for interface compatibility.
            flag: Selects the sampling window. ``1`` starts at frame 31 or
                later, ``2`` samples within the first 120 frames, and any
                other value starts at frame 151 or later.

        Returns:
            tuple: The ``(frames, frame_select)`` pair produced by
            ``load_frames_using_imageio`` (``(None, None)`` on failure).

        Raises:
            ValueError: If the video cannot be opened or is too short for the
                requested window.
        """
        frame_num = get_frame_count_opencv(videoego_path)
        span = (self.num_frames - 1) * self.frame_interval

        if flag == 1:
            first_allowed = 31
        elif flag == 2:
            frame_num = 120
            first_allowed = 0
        else:
            first_allowed = 151

        upper = frame_num - span
        if upper <= first_allowed:
            # torch.randint would fail here with an opaque message.
            raise ValueError(
                f"Video too short to sample a clip (flag={flag}): {videoego_path}"
            )
        start_frame_id = int(torch.randint(first_allowed, upper, (1,))[0])
        frames, frame_select = self.load_frames_using_imageio(videoego_path, frame_num, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        return frames, frame_select

    def is_image(self, file_path):
        """Return True when the file extension is jpg, png or webp (case-insensitive)."""
        file_ext_name = file_path.split(".")[-1]
        return file_ext_name.lower() in ["jpg", "png", "webp"]

    def load_image(self, file_path):
        """Load one image as a single-frame clip rearranged to ``C 1 H W``."""
        frame = Image.open(file_path).convert("RGB")
        frame = self.crop_and_resize(frame)
        frame = self.frame_process(frame)
        frame = rearrange(frame, "C H W -> C 1 H W")
        return frame

    def __getitem__(self, data_id):
        """Return the sample for ``data_id``.

        Raises:
            IndexError: If ``data_id`` is not below the nominal length.
        """
        if data_id < self._EGO_INDEX_LIMIT:
            return self.__getitemego__(data_id % len(self.dataset_smpl_ego))
        if data_id < self._NOMINAL_LENGTH:
            return self.__getitemegochard__(data_id % len(self.videoexo_paths_true))
        # Previously this fell through and returned None.
        raise IndexError(f"index {data_id} is out of range for {type(self).__name__}")

    def _load_with_retry(self, loader, data_id, population):
        """Call ``loader(data_id)``, resampling a random index on failure.

        This keeps the original best-effort behaviour of skipping broken
        samples, but bounds the number of attempts instead of recursing
        without limit.
        """
        index = data_id
        last_error = None
        for _ in range(self.max_retries + 1):
            try:
                return loader(index)
            except Exception as e:
                last_error = e
                index = random.randint(0, population - 1)
        raise RuntimeError(
            f"Giving up after {self.max_retries} replacement samples"
        ) from last_error

    def _require_clip(self, video_path, flag=0):
        """Load a clip and raise ``ValueError`` when no frames could be read."""
        frames, frame_select = self.load_video(video_path, flag=flag)
        if frames is None:
            raise ValueError(f"Could not read the requested frames from {video_path}")
        return frames, frame_select

    def _gather_smplx(self, smpl_path, frame_select):
        """Stack the per-frame SMPL-X entries of ``smpl_path`` for ``frame_select``."""
        # The .npy file stores a pickled dict keyed by frame index.
        per_frame = np.load(smpl_path, allow_pickle=True).item()
        return {
            key: torch.stack([per_frame[frame][key] for frame in frame_select])
            for key in self._SMPLX_KEYS
        }

    def _load_egovid5m(self, data_id):
        """Single load attempt for the ``videos_key`` source."""
        videopath = self.videos_key[data_id] + ".mp4"
        videopath = "./egocentric/egovid" + self.video_metadata_pose[videopath][-1]
        videoego, frame_select = self._require_clip(videopath, flag=2)
        camera_parameter = np.load(videopath.split("cropped_videos")[0]+"poses/"+videopath.split('/')[-1].replace('.mp4','')+"/fused_pose.npy")
        # Assumes fused_pose.npy stores one extrinsic matrix per frame -- TODO confirm.
        root_pose = extrinsics_to_root_pose(camera_parameter)[frame_select]
        num_selected = len(frame_select)
        # Only the root pose is derived from data; the other entries are zeros.
        data = {
            "smplx_root_pose": torch.from_numpy(root_pose).float().unsqueeze(1),
            "smplx_body_pose": torch.zeros((num_selected, 1, 63)),
            "smplx_lhand_pose": torch.zeros((num_selected, 1, 45)),
            "smplx_rhand_pose": torch.zeros((num_selected, 1, 45)),
            "smplx_jaw_pose": torch.zeros((num_selected, 1, 3)),
        }
        data["text"] = "None"
        data["video"] = videoego
        data["path"] = videopath
        return data

    def __getitemegovid5m__(self, data_id):
        """Return a sample from the ``videos_key`` source, resampling on failure."""
        return self._load_with_retry(self._load_egovid5m, data_id, len(self.videos_key))

    def _load_egochard(self, data_id):
        """Single load attempt for the ``videoexo_paths_true`` source."""
        videopath = self.videoexo_paths_true[data_id]
        smplpath = videopath.replace(self.dataset_videos_chard, self.dataset_smpl_chard)+'.npy'
        # The clip is read from the file whose name carries the "EGO" suffix.
        path = videopath.replace(".mp4", "EGO.mp4")

        videoego, frame_select = self._require_clip(path, flag=1)
        data = self._gather_smplx(smplpath, frame_select)
        data["text"] = "None"
        data["video"] = videoego
        data["path"] = path
        return data

    def __getitemegochard__(self, data_id):
        """Return a sample from ``videoexo_paths_true``, resampling on failure."""
        return self._load_with_retry(self._load_egochard, data_id, len(self.videoexo_paths_true))

    def _load_ego(self, data_id):
        """Single load attempt for the ``dataset_smpl_ego`` source."""
        smpl_path = self.dataset_smpl_ego[data_id]
        text = "None"
        path = smpl_path.replace(".npy", "").replace("./egocentric/meshsave/", self.dataset_egoexo4d)
        if "cam" in smpl_path:
            path = path.replace("cam", "/frame_aligned_videos/cam").split("cam")[0]
        elif "gp" in smpl_path:
            path = path.replace("gp", "/frame_aligned_videos/gp").split("gp")[0]

        # The first file whose name contains both "214" and "Calib" is used.
        # A missing directory or an empty match now triggers a resample
        # instead of aborting the data loader.
        files = [path+p for p in os.listdir(path) if "214" in p and "Calib" in p]
        path = files[0]

        videoego, frame_select = self._require_clip(path)
        data = self._gather_smplx(smpl_path, frame_select)
        data["text"] = text
        data["video"] = videoego
        data["path"] = path
        return data

    def __getitemego__(self, data_id):
        """Return a sample from ``dataset_smpl_ego``, resampling on failure."""
        return self._load_with_retry(self._load_ego, data_id, len(self.dataset_smpl_ego))

    def __len__(self):
        """Return the fixed nominal length (not the number of files on disk)."""
        return self._NOMINAL_LENGTH


class LightningModelForDataProcess(pl.LightningModule):
    """Lightning wrapper that caches prompt embeddings and video latents to disk.

    The text encoder and the VAE are loaded on the CPU in bfloat16. Each test
    step encodes one sample and writes the result next to the source video as
    ``<path>.tensors.pth``.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoPipeline.from_model_manager(manager)

        # Forwarded unchanged to every ``encode_video`` call.
        self.tiler_kwargs = dict(tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

    def test_step(self, batch, batch_idx):
        """Encode the first sample of ``batch`` and save it as a ``.tensors.pth`` file."""
        text, video, path = batch["text"][0], batch["video"], batch["path"][0]
        # Make the pipeline follow the device Lightning placed this module on.
        self.pipe.device = self.device
        if video is None:
            return
        prompt_emb = self.pipe.encode_prompt(text)
        latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
        torch.save({"latents": latents, "prompt_emb": prompt_emb}, path + ".tensors.pth")



class TensorDataset(torch.utils.data.Dataset):
    """Dataset over cached ``.tensors.pth`` files listed in a metadata CSV.

    The reported length is ``steps_per_epoch``, not the number of cached
    files; every access draws a random file.
    """

    def __init__(self, base_path, metadata_path, steps_per_epoch):
        file_names = pd.read_csv(metadata_path)["file_name"]
        video_paths = [os.path.join(base_path, "train", name) for name in file_names]
        print(len(video_paths), "videos in metadata.")

        # Keep only the entries whose cached tensor file exists.
        cached = []
        for video_path in video_paths:
            tensor_path = video_path + ".tensors.pth"
            if os.path.exists(tensor_path):
                cached.append(tensor_path)
        self.path = cached
        print(len(self.path), "tensors cached in metadata.")
        assert len(self.path) > 0

        self.steps_per_epoch = steps_per_epoch

    def __getitem__(self, index):
        """Load one cached tensor dict, chosen by a random offset added to ``index``."""
        total = len(self.path)
        offset = torch.randint(0, total, (1,))[0]
        # Adding the index keeps the choice varied under a fixed seed.
        chosen = (offset + index) % total
        return torch.load(self.path[chosen], weights_only=True, map_location="cpu")

    def __len__(self):
        """Return the configured number of steps per epoch."""
        return self.steps_per_epoch



class LightningModelForTrain(pl.LightningModule):
    """Lightning module that fine-tunes the Wan video denoising model.

    The denoiser is trained either through injected LoRA adapters (plus the
    parameters whose names contain ``blocks.28`` or ``blocks.29``) or in full.
    Only parameters of the denoising model are optimised and checkpointed.
    """

    def __init__(self, dit_path, text_encoder_path, vae_path, point_map_encoder_path=None, smpl_encoders_paths=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
        """Load all models on the CPU and mark the trainable parameters.

        Args:
            dit_path: Checkpoint of the denoising model.
            text_encoder_path: Checkpoint of the text encoder.
            vae_path: Checkpoint of the VAE.
            point_map_encoder_path: Optional checkpoint of the point-map
                encoder; skipped when ``None``.
            smpl_encoders_paths: Optional list of SMPL encoder checkpoints,
                loaded one by one; skipped when ``None`` or empty.
            tiled: Accepted for interface compatibility; tiling is always
                disabled during training (see ``tiler_kwargs`` below).
            tile_size: Tile size stored in ``tiler_kwargs``.
            tile_stride: Tile stride stored in ``tiler_kwargs``.
            learning_rate: Learning rate passed to AdamW.
            lora_rank: Rank of the LoRA update matrices.
            lora_alpha: LoRA scaling factor.
            train_architecture: ``"lora"`` for adapter training; any other
                value trains the whole denoising model.
            lora_target_modules: Comma-separated module names that receive
                adapters.
            init_lora_weights: LoRA initialisation; ``"kaiming"`` is mapped
                to ``True``.
            use_gradient_checkpointing: Forwarded to the denoising model.
        """
        super().__init__()

        model_manager = ModelManager(torch_dtype=torch.bfloat16, device='cpu')
        core_paths = [dit_path, text_encoder_path, vae_path]
        if point_map_encoder_path is not None:
            core_paths.append(point_map_encoder_path)
        model_manager.load_models(core_paths)
        # Assuming smpl_encoders_paths is a list of paths
        for path in smpl_encoders_paths or []:
            model_manager.load_models([path])

        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
        self.pipe.new_cameraencoder(device="cpu")

        self.pipe.scheduler.set_timesteps(1000, training=True)
        self.freeze_parameters()
        if train_architecture == "lora":
            self.add_lora_to_model(
                self.pipe.denoising_model(),
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_target_modules=lora_target_modules,
                init_lora_weights=init_lora_weights,
            )
            self.pipe.camera_encoder.requires_grad_(True)
            # Besides the adapters, fully unfreeze blocks 28 and 29.
            for name, param in self.pipe.denoising_model().named_parameters():
                if "blocks.29" in name or "blocks.28" in name:
                    param.requires_grad = True
        else:
            self.pipe.denoising_model().requires_grad_(True)

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # NOTE(review): the ``tiled`` argument is ignored here and tiling is
        # forced off; confirm this is intended.
        self.tiler_kwargs = {"tiled": False, "tile_size": tile_size, "tile_stride": tile_stride}

    def freeze_parameters(self):
        """Freeze the whole pipeline, then put the trained parts in train mode."""
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()
        self.pipe.camera_encoder.train()

    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
        """Inject LoRA adapters into ``model`` and upcast trainable weights to float32."""
        self.lora_alpha = lora_alpha
        if init_lora_weights == "kaiming":
            init_lora_weights = True

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights=init_lora_weights,
            target_modules=lora_target_modules.split(","),
        )
        model = inject_adapter_in_model(lora_config, model)
        for param in model.parameters():
            if param.requires_grad:
                # Trainable weights are kept in float32 while the frozen base
                # model stays in bfloat16.
                param.data = param.to(torch.float32)

    def training_step(self, batch, batch_idx):
        """Run one denoising training step and return the weighted loss.

        Reads ``text``, ``parameter_dict``, ``video``, ``point_map`` and
        ``smpl`` (with ``head``/``body``/``feet``/``hands`` entries) from
        ``batch``.
        """
        device = next(self.pipe.vae.model.parameters()).device

        # Encode text
        prompt_emb = self.pipe.encode_prompt(batch["text"])

        # Encode camera parameters
        camera = batch["parameter_dict"].to(device)
        camera_prompt_emb = self.pipe.encode_camera(camera)

        # Combine text and camera embeddings
        prompt_emb["context"] = [torch.cat([prompt_emb["context"][p], camera_prompt_emb[p]], dim=0) for p in range(len(prompt_emb["context"]))]

        # Prepend four copies of the first frame to the clip.
        batch["video"] = torch.cat([batch["video"][:,:,:1,:,:].repeat(1,1,4,1,1), batch["video"]], dim=2)

        # Encode video latents
        latents = self.pipe.encode_video(batch["video"].to(device), device=device, **self.tiler_kwargs)

        # ``frame_len`` was previously undefined (NameError). The loss slices
        # below skip index 0 and split axis 2 at ``1 + frame_len``; the only
        # boundary consistent with the concatenation order is the end of the
        # video latents. NOTE(review): confirm index 0 is the latent of the
        # prepended frames.
        frame_len = latents.shape[2] - 1

        # Point cloud input
        point_map_latents = self.pipe.encode_point_map(batch["point_map"].to(device))
        # NOTE(review): noise is added here without a timestep and again on
        # the combined latents below; confirm both are intended.
        point_map_latents_noised = self.pipe.scheduler.add_noise(point_map_latents, torch.randn_like(point_map_latents))

        # Combine latents
        combined_latents = torch.cat([latents, point_map_latents_noised], dim=2)

        # SMPL input handling
        smpl_parts = ['head', 'body', 'feet', 'hands']
        smpl_latents = []
        for part in smpl_parts:
            part_latents = self.pipe.encode_smpl(batch["smpl"][part].to(device))
            smpl_latents.append(part_latents)
        smpl_combined_latents = torch.cat(smpl_latents, dim=2)

        # Further combine with SMPL latents
        combined_latents = torch.cat([combined_latents, smpl_combined_latents], dim=2)

        # Generate noise for diffused latents
        noise = torch.randn_like(combined_latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)

        # Prepare extra input
        extra_input = self.pipe.prepare_extra_input(combined_latents)
        noisy_latents = self.pipe.scheduler.add_noise(combined_latents, noise, timestep)

        # Train the model
        training_target = self.pipe.scheduler.training_target(combined_latents, noise, timestep)
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            noise_pred = self.pipe.denoising_model()(
                noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
                use_gradient_checkpointing=self.use_gradient_checkpointing
            )

            # Video loss over indices 1..frame_len; the second term covers
            # everything concatenated after the video latents.
            loss_video = torch.nn.functional.mse_loss(noise_pred.float()[:,:,1:1+frame_len,:,:], training_target.float()[:,:,1:1+frame_len,:,:])
            loss_point_map = torch.nn.functional.mse_loss(noise_pred.float()[:,:,1+frame_len:,:,:], training_target.float()[:,:,1+frame_len:,:,:])

            total_loss = loss_video + loss_point_map
            total_loss = total_loss * self.pipe.scheduler.training_weight(timestep)

        # Log loss
        self.log("train_loss", total_loss, prog_bar=True)
        return total_loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the trainable denoising-model parameters."""
        # NOTE(review): the camera encoder is set trainable in __init__ but
        # its parameters are not passed to the optimizer; confirm.
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        """Replace the checkpoint content with the trainable denoiser weights only."""
        checkpoint.clear()
        denoiser = self.pipe.denoising_model()
        trainable_param_names = {
            name for name, param in denoiser.named_parameters() if param.requires_grad
        }
        lora_state_dict = {
            name: param
            for name, param in denoiser.state_dict().items()
            if name in trainable_param_names
        }
        checkpoint.update(lora_state_dict)

def parse_args():
    """Parse the command line of the data-processing / training script.

    Returns:
        argparse.Namespace: One attribute per option declared below.
        ``--task`` and ``--dataset_path`` are mandatory.
    """
    option_table = [
        ("--task", dict(type=str, default="data_process", required=True, choices=["data_process", "train"], help="Task. `data_process` or `train`.")),
        ("--dataset_path", dict(type=str, default=None, required=True, help="The path of the Dataset.")),
        ("--output_path", dict(type=str, default="./", help="Path to save the model.")),
        ("--text_encoder_path", dict(type=str, default="./compare/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", help="Path of text encoder.")),
        ("--vae_path", dict(type=str, default="./compare/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", help="Path of VAE.")),
        ("--dit_path", dict(type=str, default="./compare/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", help="Path of DiT.")),
        ("--tiled", dict(default=False, action="store_true", help="Whether enable tile encode in VAE. This option can reduce VRAM required.")),
        ("--tile_size_height", dict(type=int, default=34, help="Tile size (height) in VAE.")),
        ("--tile_size_width", dict(type=int, default=34, help="Tile size (width) in VAE.")),
        ("--tile_stride_height", dict(type=int, default=18, help="Tile stride (height) in VAE.")),
        ("--tile_stride_width", dict(type=int, default=16, help="Tile stride (width) in VAE.")),
        ("--steps_per_epoch", dict(type=int, default=500, help="Number of steps per epoch.")),
        ("--num_frames", dict(type=int, default=29, help="Number of frames.")),
        ("--height", dict(type=int, default=480, help="Image height.")),
        ("--width", dict(type=int, default=480, help="Image width.")),
        ("--dataloader_num_workers", dict(type=int, default=2, help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.")),
        ("--learning_rate", dict(type=float, default=9e-6, help="Learning rate.")),
        ("--accumulate_grad_batches", dict(type=int, default=1, help="The number of batches in gradient accumulation.")),
        ("--max_epochs", dict(type=int, default=1, help="Number of epochs.")),
        ("--lora_target_modules", dict(type=str, default="q,k,v,o,ffn.0,ffn.2", help="Layers with LoRA modules.")),
        ("--init_lora_weights", dict(type=str, default="kaiming", choices=["gaussian", "kaiming"], help="The initializing method of LoRA weight.")),
        ("--training_strategy", dict(type=str, default="auto", choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], help="Training strategy")),
        ("--lora_rank", dict(type=int, default=128, help="The dimension of the LoRA update matrices.")),
        ("--lora_alpha", dict(type=float, default=4.0, help="The weight of the LoRA update matrices.")),
        ("--use_gradient_checkpointing", dict(default=False, action="store_true", help="Whether to use gradient checkpointing.")),
        ("--train_architecture", dict(type=str, default="lora", choices=["lora", "full"], help="Model structure to train. LoRA training or full training.")),
        ("--use_swanlab", dict(default=False, action="store_true", help="Whether to use SwanLab logger.")),
        ("--swanlab_mode", dict(default=None, help="SwanLab mode (cloud or local).")),
    ]

    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # Options are registered in table order, which is also the --help order.
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def data_process(args):
    """Encode every dataset sample once and cache the tensors on disk.

    Builds the video dataset and a sequential single-sample loader, then runs
    ``LightningModelForDataProcess`` over it with ``Trainer.test``.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    video_dataset = TextVideoDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        max_num_frames=110,
        frame_interval=4,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width
    )
    # One sample per batch, in dataset order.
    loader = torch.utils.data.DataLoader(
        video_dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    tile_size = (args.tile_size_height, args.tile_size_width)
    tile_stride = (args.tile_stride_height, args.tile_stride_width)
    encoder_module = LightningModelForDataProcess(
        text_encoder_path=args.text_encoder_path,
        vae_path=args.vae_path,
        tiled=args.tiled,
        tile_size=tile_size,
        tile_stride=tile_stride,
    )

    runner = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        default_root_dir=args.output_path,
    )
    runner.test(encoder_module, loader)
    
    
def _find_latest_checkpoint(directory):
    """Return the most recently modified ``.ckpt`` file in ``directory``, or None."""
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(".ckpt")
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def train(args):
    """Fine-tune the denoising model, resuming from the newest checkpoint if any.

    Checkpoints are written to (and resumed from) a fixed directory,
    ``./checkpoints/ours_thirdtofirst/``; ``args.output_path`` is only used as
    the trainer's root directory.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    dataset = TextVideoDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        max_num_frames=110,
        frame_interval=3,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    model = LightningModelForTrain(
        dit_path=args.dit_path,
        text_encoder_path=args.text_encoder_path,
        vae_path=args.vae_path,
        # The constructor declares these two parameters, but the CLI has no
        # matching options yet; fall back to "nothing to load".
        point_map_encoder_path=getattr(args, "point_map_encoder_path", None),
        smpl_encoders_paths=getattr(args, "smpl_encoders_paths", None) or [],
        learning_rate=args.learning_rate,
        train_architecture=args.train_architecture,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        tiled=args.tiled,
        tile_size=(args.tile_size_height, args.tile_size_width),
        tile_stride=(args.tile_stride_height, args.tile_stride_width),
    )

    output_path = './checkpoints/ours_thirdtofirst/'
    # Listing a missing directory would raise on the very first run.
    os.makedirs(output_path, exist_ok=True)
    file_path = _find_latest_checkpoint(output_path)
    if file_path is not None:
        dict_save = torch.load(file_path, map_location='cpu')
        if 'state_dict' in dict_save:
            model.load_state_dict(dict_save['state_dict'])
        else:
            # ``on_save_checkpoint`` stores only the trainable denoiser
            # weights as a flat dict, so load them non-strictly.
            # NOTE(review): confirm against the checkpoints actually on disk.
            model.pipe.denoising_model().load_state_dict(dict_save, strict=False)

    if args.use_swanlab:
        from swanlab.integration.pytorch_lightning import SwanLabLogger
        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
        swanlab_config.update(vars(args))
        swanlab_logger = SwanLabLogger(
            project="wan",
            name="wan",
            config=swanlab_config,
            mode=args.swanlab_mode,
            logdir=output_path,
        )
        logger = [swanlab_logger]
    else:
        logger = []

    checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
        dirpath=output_path,  # where checkpoints are written
        filename='model-{epoch}-{step}',  # file name pattern
        save_top_k=-1,  # keep every checkpoint
        every_n_train_steps=200,  # save every 200 training steps
        save_last=True  # also keep a "last" checkpoint
    )

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[checkpoint_callback],
        logger=logger,
    )
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    args = parse_args()
    # argparse restricts --task to these two names; any other value is a no-op.
    task_handlers = {"data_process": data_process, "train": train}
    handler = task_handlers.get(args.task)
    if handler is not None:
        handler(args)



