import torch, os, imageio, argparse
import shutil
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
import json
import datetime
from pathlib import Path
from tqdm import tqdm
class WristConditionDataset(torch.utils.data.Dataset):
    """Dataset serving pre-computed training samples stored as ``*.tensors.pth`` files.

    Each file is expected to hold a dict with at least the keys ``latents``,
    ``prompt_emb``, ``image_emb`` and ``ext_frame_feats``; ``image_emb`` must in
    turn contain ``clip_feature``, ``y_wrist16`` and ``control_latents``.  Only
    the first file is checked at construction time.

    The dataset has a virtual length of ``steps_per_epoch``: every item is a
    file picked at a random offset from the requested index, so one "epoch"
    does not necessarily visit every file.

    Args:
        tensors_dir: Directory searched for ``*.tensors.pth`` files when
            ``file_list`` is not given.
        steps_per_epoch: Value reported by ``len(dataset)``.
        file_list: Optional explicit list of tensor file paths; when given,
            ``tensors_dir`` is only used in log messages.
        return_filename: If true, ``__getitem__`` returns ``(data, file_name)``
            instead of just ``data``.

    Raises:
        AssertionError: If no tensor file is available.
        ValueError: If the first file lacks a required key.
    """

    def __init__(self, tensors_dir, steps_per_epoch, file_list=None, return_filename: bool = False):
        self.tensors_dir = Path(tensors_dir)
        self.steps_per_epoch = steps_per_epoch
        self.return_filename = return_filename

        # Use the explicit list when provided, otherwise discover files on disk.
        if file_list is not None:
            self.tensor_files = [Path(p) for p in file_list]
        else:
            # glob() yields entries in filesystem order, which differs between
            # machines; sort so that index -> file is reproducible for a fixed seed.
            self.tensor_files = sorted(self.tensors_dir.glob("*.tensors.pth"))
        print(f"Found {len(self.tensor_files)} tensor files in {tensors_dir}")
        assert len(self.tensor_files) > 0, f"No .tensors.pth files found in {tensors_dir}"

        # Fail early if the on-disk format is not what the training step expects.
        self._validate_data_structure()

    def _detach_all(self, obj):
        """Recursively detach every tensor inside nested dicts, lists and tuples."""
        if torch.is_tensor(obj):
            return obj.detach().requires_grad_(False)
        if isinstance(obj, dict):
            return {k: self._detach_all(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            t = [self._detach_all(v) for v in obj]
            return type(obj)(t) if not isinstance(obj, list) else t
        return obj

    def _validate_data_structure(self):
        """Check that the first tensor file contains every required key.

        Raises:
            ValueError: If a top-level key or an ``image_emb`` key is missing.
        """
        sample_data = torch.load(self.tensor_files[0], weights_only=True, map_location="cpu")
        required_keys = ["latents", "prompt_emb", "image_emb", "ext_frame_feats"]
        missing_keys = [key for key in required_keys if key not in sample_data]
        if missing_keys:
            raise ValueError(f"Missing required keys in tensor data: {missing_keys}")

        # Validate image_emb structure
        image_emb = sample_data["image_emb"]
        required_img_keys = ["clip_feature", "y_wrist16", "control_latents"]
        missing_img_keys = [key for key in required_img_keys if key not in image_emb]
        if missing_img_keys:
            raise ValueError(f"Missing required keys in image_emb: {missing_img_keys}")

        print("Data structure validation passed")
        print(f" - latents shape: {tuple(sample_data['latents'].shape)}")
        print(f" - ext_frame_feats shape: {tuple(sample_data['ext_frame_feats'].shape)}")
        print(f" - prompt_emb.context shape: {tuple(sample_data['prompt_emb']['context'].shape)}")

    def __getitem__(self, index):
        """Load one sample: a random offset is added to ``index`` to choose the file."""
        num_files = len(self.tensor_files)
        # Draw from torch's RNG (so seeding torch controls the sequence) and
        # convert to a plain int before using it as a list index.
        offset = int(torch.randint(0, num_files, (1,)).item())
        file_idx = (offset + index) % num_files

        tensor_path = self.tensor_files[file_idx]
        data = torch.load(tensor_path, weights_only=True, map_location="cpu")

        # Ensure no tensor carries grad history into DataLoader collate
        data = self._detach_all(data)

        if self.return_filename:
            return data, tensor_path.name
        return data

    def __len__(self):
        """Return the configured number of steps per epoch (not the file count)."""
        return self.steps_per_epoch



class LightningModelForTrain(pl.LightningModule):
    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
        pretrained_lora_path=None,
        max_frames_for_ext=200,  # for temporal embedding capacity
        ext_feat_dim=1280,
        text_dim=4096,
        ext_proj_hidden=2048,
        dropout=0.1,
        # validation/inference configs
        vae_path=None,
        tiled=False,
        tile_size_height=34,
        tile_size_width=34,
        tile_stride_height=18,
        tile_stride_width=16,
        num_inference_steps=50,
        use_first_frame_guide=False,
    ):
        """Build the pipeline, the extra conditioning modules and the trainable parameter set.

        Steps, in order: load base checkpoints from hard-coded paths, optionally
        overlay a fine-tuned DiT checkpoint, create the pipeline, create the
        temporal/view embeddings and the projection MLP, freeze the pipeline,
        widen the patch embedding to 32 input channels, inject LoRA (or unfreeze
        the whole denoiser), optionally restore previously saved weights, and
        finally re-enable gradients on the patch embedding weight.

        Args:
            dit_path: If truthy, a directory that must contain
                ``checkpoint/mp_rank_00_model_states.pt``; that file is loaded
                into the ``wan_video_dit`` model with ``strict=False``.
            learning_rate: Stored for ``configure_optimizers``.
            lora_rank, lora_alpha, lora_target_modules, init_lora_weights:
                Forwarded to ``add_lora_to_model`` when ``train_architecture``
                is ``"lora"``.
            train_architecture: ``"lora"`` injects adapters; any other value
                marks every denoiser parameter as trainable.
            use_gradient_checkpointing, use_gradient_checkpointing_offload:
                Stored and passed to the denoiser in ``training_step``.
            pretrained_lora_path: Optional weights to restore.  Either a single
                path, or two comma-separated paths where the first is merged
                into the denoiser state dict and the second must contain a
                ``sign`` key.
            max_frames_for_ext: Number of rows in the temporal embedding table.
            ext_feat_dim: Width of the temporal/view embeddings and input width
                of the projection MLP.
            text_dim: Output width of the projection MLP.
            ext_proj_hidden: Hidden width of the projection MLP.
            dropout: Dropout probability inside the projection MLP.
            vae_path: Accepted but not read by this constructor.
            tiled, tile_size_height, tile_size_width, tile_stride_height,
                tile_stride_width: Stored for VAE decoding in ``validation_step``.
            num_inference_steps: Stored as the number of validation denoising steps.
            use_first_frame_guide: Stored on the instance; not otherwise used here.

        Raises:
            FileNotFoundError: If ``dit_path`` is given but the checkpoint file
                does not exist.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")

        # Load base models (paths are hard-coded; the first checkpoint is kept in float32)
        model_manager.load_models(
            ["XXX/models--Wan-AI--Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"],
            torch_dtype=torch.float32,
        )

        # Diffusion model, T5 text encoder and VAE checkpoints, loaded in bfloat16
        model_manager.load_models(
            [
                "XXX/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a/diffusion_pytorch_model.safetensors",
                "XXX/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a/models_t5_umt5-xxl-enc-bf16.pth",
                "XXX/models--Wan-AI--Wan2.1-T2V-1.3B/snapshots/37ec512624d61f7aa208f7ea8140a131f93afc9a/Wan2.1_VAE.pth",
            ],
            torch_dtype=torch.bfloat16,
        )
        checkpoint_path = dit_path
        # Optionally overlay a fine-tuned DiT checkpoint on top of the base weights
        if checkpoint_path:
            checkpoint_file = Path(checkpoint_path) / "checkpoint" / "mp_rank_00_model_states.pt"
            if not checkpoint_file.exists():
                raise FileNotFoundError(f"model file not exist: {checkpoint_file}")
            try:
                # NOTE(review): torch.load is called without weights_only=True, so
                # depending on the installed torch version this may unpickle
                # arbitrary objects — only load trusted checkpoints.
                state_dict = torch.load(str(checkpoint_file), map_location="cpu")
                dit_model = model_manager.fetch_model("wan_video_dit")
                if dit_model is not None:
                    # NOTE(review): strict=False silently ignores key mismatches;
                    # presumably the file holds a flat DiT state dict — confirm.
                    dit_model.load_state_dict(state_dict, strict=False)
            except Exception as e:
                # Print the traceback for the training log, then propagate the error.
                import traceback
                traceback.print_exc()
                raise e

        # Weights come only from the hard-coded paths and the optional checkpoint
        # above; vae_path is not consulted.

        self.pipe = WanVideoPipeline.from_model_manager(model_manager,torch_dtype=torch.bfloat16)
        self.pipe.enable_vram_management(num_persistent_param_in_dit=60 * 10**9)
        # Training uses a 1000-step schedule; the hooks below restore it after validation.
        self.pipe.scheduler.set_timesteps(1000, training=True)
        self.training_num_timesteps = 1000

        # Initialize learnable embeddings and projection for pseudo text tokens
        self.max_frames_for_ext = max_frames_for_ext
        self.ext_feat_dim = ext_feat_dim
        self.text_dim = text_dim
        # One embedding row per frame index, and one per view (two views).
        self.temporal_embed = torch.nn.Embedding(self.max_frames_for_ext, self.ext_feat_dim)
        self.view_embed = torch.nn.Embedding(2, self.ext_feat_dim)
        # MLP mapping enriched external features to the text-context width.
        self.ext_proj = torch.nn.Sequential(
            torch.nn.Linear(self.ext_feat_dim, ext_proj_hidden),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(ext_proj_hidden, self.text_dim),
        )

        # Freeze the base pipeline, then widen the patch embedding to 32 input
        # channels BEFORE LoRA injection.
        self.freeze_parameters()
        self.ensure_in_channels_32_before_lora()
        if train_architecture == "lora":
            # NOTE(review): add_lora_to_model calls load_state_dict on the raw
            # pretrained_lora_path string, and the block below loads the same
            # path again — confirm the comma-separated form is only used with
            # non-LoRA training.
            self.add_lora_to_model(
                self.pipe.denoising_model(),
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_target_modules=lora_target_modules,
                init_lora_weights=init_lora_weights,
                pretrained_lora_path=pretrained_lora_path,
            )
        else:
            ...
            # Non-LoRA mode: every denoiser parameter is trained. (Original note:
            # due to previous issues, some new parameters such as ext_proj were
            # not saved, so they are trained separately.)
            self.pipe.denoising_model().requires_grad_(True)
        # Load previously saved weights, if a path was given
        if pretrained_lora_path is not None:

            if ',' in pretrained_lora_path:
                # Two comma-separated files: [0] is merged into the denoiser state
                # dict, [1] is a file written by on_save_checkpoint (it carries the
                # 'sign' marker and keys prefixed with the owning module name).
                pretrained_lora_path = pretrained_lora_path.split(",")
                state_dict = load_state_dict(pretrained_lora_path[1])
                assert 'sign' in state_dict
                dit_state_dict = load_state_dict(pretrained_lora_path[0])
                loaded_keys=["denoising_model","ext_proj","temporal_embed","view_embed"]
                final_state_dict = {key:{} for key in loaded_keys}
                # Split "<module>.<param>" keys into one state dict per module.
                for name,param in state_dict.items():
                    if '.' in name and name.split(".")[0] in loaded_keys:
                        final_state_dict[name.split(".")[0]][name[len(name.split(".")[0])+1:]] = param
                # Keys from the first file override same-named keys from the second.
                final_state_dict['denoising_model'].update(dit_state_dict)
                # NOTE(review): with strict=True torch raises on any missing or
                # unexpected key, so both lists are presumably always empty here.
                missing_keys, unexpected_keys = self.pipe.denoising_model().load_state_dict(final_state_dict['denoising_model'], strict=True)
                all_keys = [i for i, _ in self.pipe.denoising_model().named_parameters()]
                num_updated_keys = len(all_keys) - len(missing_keys)
                num_unexpected_keys = len(unexpected_keys)
                self.ext_proj.load_state_dict(final_state_dict['ext_proj'])
                self.temporal_embed.load_state_dict(final_state_dict['temporal_embed'])
                self.view_embed.load_state_dict(final_state_dict['view_embed'])
                print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
            else:
                state_dict = load_state_dict(pretrained_lora_path)
                if 'sign' in state_dict:
                    # Single file written by on_save_checkpoint: restore the denoiser
                    # weights plus the three auxiliary modules.
                    loaded_keys=["denoising_model","ext_proj","temporal_embed","view_embed"]
                    final_state_dict = {key:{} for key in loaded_keys}
                    print(state_dict.keys())
                    for name,param in state_dict.items():
                        if '.' in name and name.split(".")[0] in loaded_keys:
                            final_state_dict[name.split(".")[0]][name[len(name.split(".")[0])+1:]] = param
                    # NOTE(review): strict=True requires the file to cover every
                    # denoiser key; on_save_checkpoint only stores trainable
                    # parameters — confirm this branch works in LoRA mode.
                    missing_keys, unexpected_keys = self.pipe.denoising_model().load_state_dict(final_state_dict['denoising_model'], strict=True)
                    all_keys = [i for i, _ in self.pipe.denoising_model().named_parameters()]
                    num_updated_keys = len(all_keys) - len(missing_keys)
                    num_unexpected_keys = len(unexpected_keys)
                    self.ext_proj.load_state_dict(final_state_dict['ext_proj'])
                    self.temporal_embed.load_state_dict(final_state_dict['temporal_embed'])
                    self.view_embed.load_state_dict(final_state_dict['view_embed'])
                    print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
                else:
                    # File without the 'sign' marker: treat it as a plain (partial)
                    # denoiser state dict.
                    missing_keys, unexpected_keys = self.pipe.denoising_model().load_state_dict(state_dict, strict=False)
                    all_keys = [i for i, _ in self.pipe.denoising_model().named_parameters()]
                    num_updated_keys = len(all_keys) - len(missing_keys)
                    num_unexpected_keys = len(unexpected_keys)
                    print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
        # Re-enable gradient on the expanded input layer specifically, after LoRA injection
        self.enable_patch_embedding_grad()

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload

        # validation/inference config
        self.val_num_inference_steps = num_inference_steps
        self.val_tiled = tiled
        self.val_tile_size = (tile_size_height, tile_size_width)
        self.val_tile_stride = (tile_stride_height, tile_stride_width)
        # validation sample counter (reset each validation epoch)
        self.val_sample_counter = 0
        self.use_first_frame_guide = use_first_frame_guide
    
    def freeze_parameters(self):
        """Disable gradients on the whole pipeline and put it in eval mode.

        The denoising model alone is switched back to train mode afterwards;
        its parameters stay frozen until something re-enables them.
        """
        pipeline = self.pipe
        pipeline.requires_grad_(False)
        pipeline.eval()
        denoiser = pipeline.denoising_model()
        denoiser.train()
    
    def ensure_in_channels_32_before_lora(self,model=None):
        """Widen ``patch_embedding`` to 32 input channels before LoRA injection.

        The existing weights are kept for the original input channels and the
        newly added channels are initialised from N(0, 0.02).  The layer is
        modified in place; if it is wrapped (exposes a ``module`` attribute)
        the inner module is the one that gets modified.

        Args:
            model: Model to patch.  Defaults to the pipeline's denoising model.

        Raises:
            RuntimeError: If the layer already has more than 32 input channels,
                since it cannot be widened to 32.
        """
        target_in_channels = 32
        if model is None:
            model = self.pipe.denoising_model()
        if not hasattr(model, 'patch_embedding'):
            print("[WARN] denoising model has no patch_embedding; please verify architecture")
            return

        # Unwrap the layer when a wrapper exposes the real module as `.module`.
        patch_emb = model.patch_embedding
        if hasattr(patch_emb, 'module'):
            patch_emb = patch_emb.module

        original_weight = patch_emb.weight.data
        out_channels, in_channels, *kernel_shape = original_weight.shape
        if in_channels == target_in_channels:
            print(f"[INFO] patch_embedding input channels already 32")
            return
        if in_channels > target_in_channels:
            # Previously this surfaced as an opaque shape-mismatch error during the copy.
            raise RuntimeError(
                f"patch_embedding has {in_channels} input channels; cannot widen it to {target_in_channels}"
            )

        print(f"[INFO] Expanding patch_embedding input channels from {in_channels} to 32 before LoRA")
        new_weight = torch.zeros(
            (out_channels, target_in_channels, *kernel_shape),
            dtype=original_weight.dtype,
            device=original_weight.device,
        )
        # Keep the pretrained weights for the original channels ...
        new_weight[:, :in_channels] = original_weight
        # ... and start the extra channels from small random values.
        torch.nn.init.normal_(new_weight[:, in_channels:], mean=0.0, std=0.02)
        patch_emb.weight.data = new_weight
        # Keep the module's metadata consistent with the new weight shape.
        if hasattr(patch_emb, 'in_channels'):
            patch_emb.in_channels = target_in_channels
    
    def enable_patch_embedding_grad(self):
        """Mark the denoiser's ``patch_embedding`` weight as trainable.

        Works whether the layer is wrapped (the real module is reachable via
        ``.module``) or not, mirroring ``ensure_in_channels_32_before_lora``.
        Does nothing when the model has no ``patch_embedding`` or the layer has
        no ``weight``.
        """
        model = self.pipe.denoising_model()
        if not hasattr(model, 'patch_embedding'):
            return
        patch_emb = model.patch_embedding
        # Unwrap only when a wrapper is present; an unwrapped layer used to
        # raise AttributeError here.
        if hasattr(patch_emb, 'module'):
            patch_emb = patch_emb.module
        if hasattr(patch_emb, 'weight'):
            patch_emb.weight.requires_grad_(True)
            print("[INFO] Enabled gradient for patch_embedding.weight")
    
    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
        """Inject LoRA adapters into ``model`` and optionally load saved weights.

        Args:
            model: Module that receives the adapters.  Nothing is returned, so
                the injection is presumably done in place — verify against peft.
            lora_rank: Adapter rank ``r``.
            lora_alpha: Adapter scaling; also stored on ``self.lora_alpha``.
            lora_target_modules: Comma-separated module name patterns.
            init_lora_weights: ``"kaiming"`` is translated to ``True``; any
                other value is forwarded unchanged.
            pretrained_lora_path: Optional state-dict file loaded with
                ``strict=False`` after injection.
            state_dict_converter: Optional callable applied to the loaded state
                dict before it is handed to the model.
        """
        self.lora_alpha = lora_alpha

        adapter_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights=True if init_lora_weights == "kaiming" else init_lora_weights,
            target_modules=lora_target_modules.split(","),
        )
        model = inject_adapter_in_model(adapter_config, model)

        # Every parameter that currently requires grad is kept in fp32.
        for weight in model.parameters():
            if not weight.requires_grad:
                continue
            weight.data = weight.to(torch.float32)

        if pretrained_lora_path is None:
            return

        # Restore previously trained weights on top of the fresh adapters.
        loaded_weights = load_state_dict(pretrained_lora_path)
        if state_dict_converter is not None:
            loaded_weights = state_dict_converter(loaded_weights)
        missing_keys, unexpected_keys = model.load_state_dict(loaded_weights, strict=False)
        parameter_names = [param_name for param_name, _ in model.named_parameters()]
        num_updated_keys = len(parameter_names) - len(missing_keys)
        num_unexpected_keys = len(unexpected_keys)
        print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")

    def training_step(self, batch, batch_idx):
        """Run one denoising training step and return the weighted MSE loss.

        The external frame features get temporal and view embeddings added and
        are projected to ``text_dim`` by ``ext_proj``.  With ``n`` projected
        pseudo tokens, the text context is truncated to its first
        ``max(512 - n, 0)`` tokens and the pseudo tokens are appended.  The
        control latents are concatenated to the noisy latents along dim 1
        before the denoiser is called.

        Note:
            ``batch["prompt_emb"]["context"]`` is overwritten in place with the
            combined context.

        Args:
            batch: Dict with ``latents``, ``prompt_emb`` (containing
                ``context``), ``image_emb`` (containing ``control_latents``)
                and ``ext_frame_feats``.
            batch_idx: Index of the batch (unused).

        Returns:
            Scalar loss tensor.
        """
        # Total token budget shared by text tokens and pseudo tokens.
        max_context_tokens = 512

        # Extract data
        latents = batch["latents"].squeeze(0).squeeze(1).to(self.device)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"].to(self.device)
        image_emb = batch["image_emb"]
        ext_frame_feats = batch["ext_frame_feats"].to(self.device)
        control_latents = image_emb["control_latents"].squeeze(1).to(self.device)

        # Build pseudo text tokens from ext features using learnable embeddings.
        # A 3-D input is treated as a single un-batched sample.
        ext = ext_frame_feats
        if ext.dim() == 3:
            ext = ext.unsqueeze(0)
        B, V, T, C = ext.shape
        assert V == 2 and C == self.ext_feat_dim, f"ext_frame_feats shape expected (*,2,T,{self.ext_feat_dim}), got {tuple(ext.shape)}"
        assert T <= self.max_frames_for_ext, f"T={T} exceeds max_frames_for_ext={self.max_frames_for_ext}"
        temporal_ids = torch.arange(T, device=self.device)
        view_ids = torch.arange(V, device=self.device)
        temporal_pe = self.temporal_embed(temporal_ids).view(1, 1, T, C)
        view_pe = self.view_embed(view_ids).view(1, V, 1, C)
        enhanced_feats = ext + temporal_pe + view_pe
        proj_in = enhanced_feats.reshape(B * V * T, C)
        proj_out = self.ext_proj(proj_in).reshape(B, V * T, self.text_dim)

        # Replace the tail of the text context with the pseudo tokens.
        text_context = prompt_emb["context"].squeeze(1)
        if proj_out.dtype != text_context.dtype:
            proj_out = proj_out.to(text_context.dtype)
        # Clamp at zero: a negative slice bound would keep text tokens instead
        # of dropping them when the pseudo tokens alone exceed the budget.
        num_text_tokens = max(max_context_tokens - proj_out.shape[1], 0)
        context = torch.cat([text_context[:, :num_text_tokens], proj_out], dim=1)
        prompt_emb["context"] = context

        # No image-embedding kwargs are passed to the denoiser in this setup.
        final_image_emb = {}

        # Loss
        self.pipe.device = self.device
        noise = torch.randn_like(latents)
        # ensure we sample a valid index within current scheduler timesteps
        num_steps = len(self.pipe.scheduler.timesteps)
        timestep_idx = torch.randint(0, num_steps, (1,), device=self.pipe.scheduler.timesteps.device)
        timestep = self.pipe.scheduler.timesteps[timestep_idx].to(dtype=self.pipe.torch_dtype, device=self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
        # Condition the denoiser by stacking the control latents on dim 1.
        noisy_latents = torch.cat([noisy_latents, control_latents.squeeze(0)], dim=1)
        noise_pred = self.pipe.denoising_model()(
            noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **final_image_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing,
            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload,
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Logs
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("context_length", context.shape[1], prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Create the AdamW optimizer over everything that should be trained.

        That is: every denoiser parameter that currently requires grad,
        followed by all parameters of the temporal embedding, the view
        embedding and the projection MLP.

        Returns:
            A ``torch.optim.AdamW`` instance using ``self.learning_rate``.
        """
        denoiser_params = [p for p in self.pipe.denoising_model().parameters() if p.requires_grad]
        # The auxiliary modules are always trained in full, in this order.
        auxiliary_modules = (self.temporal_embed, self.view_embed, self.ext_proj)
        auxiliary_params = [p for module in auxiliary_modules for p in module.parameters()]
        return torch.optim.AdamW(denoiser_params + auxiliary_params, lr=self.learning_rate)

    def on_save_checkpoint(self, checkpoint):
        """Replace the checkpoint contents with trainable weights and export them.

        The incoming ``checkpoint`` dict is emptied and refilled with only the
        denoiser entries whose parameters require grad.  In addition, a
        ``lora_<timestamp>.safetensors`` file is written under
        ``<checkpoint dir>/lora_weights``; it holds those denoiser weights plus
        the full state of ``ext_proj``, ``temporal_embed`` and ``view_embed``,
        each key prefixed with its owner's name, and a ``sign`` marker tensor.

        Args:
            checkpoint: Checkpoint dict to modify in place.
        """
        checkpoint.clear()

        # Keep only the denoiser tensors that belong to trainable parameters.
        denoiser = self.pipe.denoising_model()
        trainable_names = {
            param_name
            for param_name, parameter in denoiser.named_parameters()
            if parameter.requires_grad
        }
        lora_state_dict = {
            key: tensor
            for key, tensor in denoiser.state_dict().items()
            if key in trainable_names
        }

        # Flatten all sections into one "<section>.<key>" namespace.
        sections = (
            ("denoising_model", lora_state_dict),
            ("ext_proj", self.ext_proj.state_dict()),
            ("temporal_embed", self.temporal_embed.state_dict()),
            ("view_embed", self.view_embed.state_dict()),
        )
        final_state_dict = {
            f"{section}.{key}": tensor
            for section, tensors in sections
            for key, tensor in tensors.items()
        }
        # Marker key that identifies files produced by this method.
        final_state_dict["sign"] = torch.zeros([1])

        from safetensors.torch import save_file
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        save_dir = os.path.join(self.trainer.checkpoint_callback.dirpath, "lora_weights")
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"lora_{timestamp}.safetensors")
        save_file(final_state_dict, save_path)
        print(f"Saved LoRA weights to {save_path}")

        # The checkpoint itself carries the trainable denoiser weights only.
        checkpoint.update(lora_state_dict)


    def on_train_start(self):
        """Put the scheduler back on the training timestep schedule.

        Validation (or any other step) may have reconfigured the scheduler
        before training begins.
        """
        scheduler = self.pipe.scheduler
        scheduler.set_timesteps(self.training_num_timesteps, training=True)

    def on_train_epoch_start(self):
        """Re-apply the training timestep schedule at the start of every epoch."""
        scheduler = self.pipe.scheduler
        scheduler.set_timesteps(self.training_num_timesteps, training=True)

    def on_validation_epoch_start(self):
        """Restart the per-epoch numbering used in validation video file names."""
        self.val_sample_counter = 0

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Generate a validation video and save it next to its source videos.

        Steps:
            1. Build the conditioning context exactly like ``training_step``
               (text tokens truncated to ``max(512 - n, 0)``, then ``n`` pseudo
               tokens from the external frame features).
            2. Denoise random latents with classifier-free guidance
               (scale 5.0, fixed negative prompt), keeping the control latents
               concatenated on dim 1 throughout.
            3. Decode with the VAE and stack each generated frame under the
               matching ``condition`` and ``wrist_rgb`` frames.
            4. Save the result as an mp4 under ``<default_root_dir>/val_videos``
               and, when a W&B logger is attached, log it there too.
            5. Restore the training timestep schedule.

        Note:
            ``batch["prompt_emb"]["context"]`` is overwritten in place.

        Args:
            batch: Dict with ``latents``, ``prompt_emb``, ``ext_frame_feats``,
                ``image_emb["control_latents"]`` and ``meta["paths"]`` holding
                the ``condition`` and ``wrist_rgb`` video paths.
            batch_idx: Index of the validation batch; used in log keys.

        Raises:
            ValueError: If ``meta.paths`` or one of its required entries is
                missing, or if no frame could be read from a source video.
        """
        # Total token budget shared by text tokens and pseudo tokens.
        max_context_tokens = 512

        self.pipe.device = self.device
        # Only the shape of the stored latents is used; generation starts from noise.
        latents_shape = batch["latents"].squeeze(0).shape
        B, Cx, F, H, W = latents_shape
        assert Cx == 16, f"Expected x channels=16, got {Cx}"

        # Prompt context + pseudo tokens from ext
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"].to(self.device)
        ext_frame_feats = batch["ext_frame_feats"].to(self.device)
        ext = ext_frame_feats
        if ext.dim() == 3:
            ext = ext.unsqueeze(0)
        Bx, V, T, C = ext.shape
        assert V == 2 and C == self.ext_feat_dim
        temporal_ids = torch.arange(T, device=self.device)
        view_ids = torch.arange(V, device=self.device)
        temporal_pe = self.temporal_embed(temporal_ids).view(1, 1, T, C)
        view_pe = self.view_embed(view_ids).view(1, V, 1, C)
        enhanced_feats = ext + temporal_pe + view_pe
        proj_in = enhanced_feats.reshape(Bx * V * T, C)
        proj_out = self.ext_proj(proj_in).reshape(Bx, V * T, self.text_dim)
        text_context = prompt_emb["context"].squeeze(1)
        if proj_out.dtype != text_context.dtype:
            proj_out = proj_out.to(text_context.dtype)
        # Clamp at zero: a negative slice bound would keep text tokens instead
        # of dropping them when the pseudo tokens alone exceed the budget.
        num_text_tokens = max(max_context_tokens - proj_out.shape[1], 0)
        context = torch.cat([text_context[:, :num_text_tokens], proj_out], dim=1)
        prompt_emb["context"] = context

        # Classifier-Free Guidance settings for validation
        cfg_scale = 5.0
        negative_prompt_text = "low quality, distorted, ugly, bad anatomy"
        # Load text encoder to encode negative prompt
        self.pipe.load_models_to_device(["text_encoder"])
        prompt_emb_nega = self.pipe.encode_prompt(negative_prompt_text, positive=False)

        # Image conditions: only the control latents are used.
        image_emb = batch["image_emb"]
        control_latents = image_emb["control_latents"].to(self.device)
        final_image_emb = {}

        # Denoising
        self.pipe.scheduler.set_timesteps(self.val_num_inference_steps, denoising_strength=1.0, shift=5.0)
        latents = torch.randn((B, 16, F, H, W), device=self.device, dtype=self.pipe.torch_dtype)
        latents = torch.cat([latents, control_latents.squeeze(0)], dim=1)
        extra_input = self.pipe.prepare_extra_input(latents)
        timesteps = self.pipe.scheduler.timesteps
        for step_id, timestep in tqdm(enumerate(timesteps), total=len(timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.pipe.torch_dtype, device=self.device)
            # Positive branch
            noise_pred_posi = self.pipe.denoising_model()(
                latents, timestep=timestep, **prompt_emb, **extra_input, **final_image_emb,
                use_gradient_checkpointing=False,
                use_gradient_checkpointing_offload=False,
            )
            # Negative branch
            noise_pred_nega = self.pipe.denoising_model()(
                latents, timestep=timestep, **prompt_emb_nega, **extra_input, **final_image_emb,
                use_gradient_checkpointing=False,
                use_gradient_checkpointing_offload=False,
            )
            # CFG combine
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

            # Only the first 16 channels are denoised; the control channels stay fixed.
            latents[:, :16] = self.pipe.scheduler.step(noise_pred, self.pipe.scheduler.timesteps[step_id], latents[:, :16])

        # Decode
        self.pipe.load_models_to_device(['vae'])
        frames_tensor = self.pipe.decode_video(latents[:, :16], tiled=self.val_tiled, tile_size=self.val_tile_size, tile_stride=self.val_tile_stride)
        self.pipe.load_models_to_device([])
        frames = self.pipe.tensor2video(frames_tensor[0])

        # Build concatenated video with original condition and wrist_rgb videos (top-to-bottom)
        if "meta" not in batch or "paths" not in batch["meta"]:
            raise ValueError("'meta.paths' not found in validation batch; cannot build concatenated video.")
        paths = batch["meta"]["paths"]
        required_video_keys = ["condition", "wrist_rgb"]
        for k in required_video_keys:
            if k not in paths:
                raise ValueError(f"Missing '{k}' in meta.paths; cannot build concatenated video.")

        condition_path = paths["condition"]
        wrist_path = paths["wrist_rgb"]

        def read_video_frames(video_path, max_frames):
            """Read frames from a video, stopping once ``max_frames`` are collected."""
            # Collation may have wrapped the path in a list.
            if isinstance(video_path, list):
                video_path = video_path[0]
            reader = imageio.get_reader(video_path)
            frames_list = []
            try:
                for frame in reader:
                    frames_list.append(frame)
                    if len(frames_list) >= max_frames:
                        break
            finally:
                reader.close()
            return frames_list

        num_gen_frames = len(frames)
        gen_frames_pil = frames  # list of PIL.Image
        # Read original videos
        cond_frames_np = read_video_frames(condition_path, num_gen_frames)
        wrist_frames_np = read_video_frames(wrist_path, num_gen_frames)
        if len(cond_frames_np) == 0 or len(wrist_frames_np) == 0:
            raise ValueError("Failed to read frames from condition or wrist_rgb videos during validation.")
        # Align by the minimum available frame count to avoid index errors
        T_aligned = min(num_gen_frames, len(cond_frames_np), len(wrist_frames_np))

        # Target size from generated frames
        target_w, target_h = gen_frames_pil[0].size

        concatenated_frames = []
        for i in range(T_aligned):
            gen_im = gen_frames_pil[i].resize((target_w, target_h), Image.BILINEAR)
            cond_im = Image.fromarray(cond_frames_np[i]).resize((target_w, target_h), Image.BILINEAR)
            wrist_im = Image.fromarray(wrist_frames_np[i]).resize((target_w, target_h), Image.BILINEAR)
            # Vertical stack: condition (top), wrist_rgb (middle), generated (bottom)
            stacked = np.concatenate([
                np.array(cond_im),
                np.array(wrist_im),
                np.array(gen_im)
            ], axis=0)
            concatenated_frames.append(stacked)

        # Save concatenated video to disk with a per-epoch counter
        save_dir = os.path.join(self.trainer.default_root_dir, "val_videos")
        os.makedirs(save_dir, exist_ok=True)
        # Get current process rank (useful for distributed training)
        if hasattr(self.trainer, "strategy") and hasattr(self.trainer.strategy, "local_rank"):
            rank = self.trainer.strategy.local_rank
        elif hasattr(self.trainer, "global_rank"):
            rank = self.trainer.global_rank
        else:
            # Single-GPU / non-distributed fallback
            rank = 0
        save_path = os.path.join(save_dir, f"epoch{self.current_epoch:03d}_sample{self.val_sample_counter:02d}_rank{rank:02d}.mp4")
        imageio.mimsave(save_path, concatenated_frames, fps=8)
        # Increment counter after successful save
        self.val_sample_counter += 1

        # Log to wandb (if available); failures here must not abort validation.
        try:
            # This file imports `lightning`, whose WandbLogger class is distinct
            # from the one in the legacy `pytorch_lightning` package, so accept
            # whichever of the two is installed.
            wandb_logger_types = []
            try:
                from lightning.pytorch.loggers import WandbLogger as LightningWandbLogger
                wandb_logger_types.append(LightningWandbLogger)
            except ImportError:
                pass
            try:
                from pytorch_lightning.loggers import WandbLogger as LegacyWandbLogger
                wandb_logger_types.append(LegacyWandbLogger)
            except ImportError:
                pass
            wandb_logger_types = tuple(wandb_logger_types)

            loggers = self.trainer.loggers if hasattr(self.trainer, 'loggers') else ([self.logger] if self.logger is not None else [])
            for lg in loggers:
                if wandb_logger_types and isinstance(lg, wandb_logger_types):
                    import wandb
                    lg.experiment.log({
                        f"val_video/epoch_{self.current_epoch:03d}_sample_{batch_idx:02d}": wandb.Video(save_path, fps=8, format="mp4")
                    }, step=self.global_step)
        except Exception as e:
            print(f"[WARN] wandb video log failed: {e}")

        # Also log a scalar to ensure validation loop is visible
        self.log("val_generated", float(batch_idx), prog_bar=False)

        # Restore training timesteps after validation step
        self.pipe.scheduler.set_timesteps(self.training_num_timesteps, training=True)

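    # The standalone inference/save interface was removed so that the training script does not expose an inference-export entry point.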


# The helper that loaded the inference file list has been removed.


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional sequence of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so existing ``parse_args()``
            callers are unaffected.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")

    # --- task and paths ---
    parser.add_argument(
        "--task",
        type=str,
        # "train" is the only choice, so the default is sufficient; marking the
        # option as required would make the default unreachable.
        default="train",
        choices=["train"],
        help="Task. `train`.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        required=True,
        help="The path of the Dataset. For train task, this should be the directory containing .tensors.pth files.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Path of text encoder.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        help="Path of image encoder.",
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        help="Path of VAE.",
    )
    parser.add_argument(
        "--dit_path",
        type=str,
        default=None,
        help="Path of DiT.",
    )

    # --- VAE tiling ---
    parser.add_argument(
        "--tiled",
        default=False,
        action="store_true",
        help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
    )
    parser.add_argument(
        "--tile_size_height",
        type=int,
        default=34,
        help="Tile size (height) in VAE.",
    )
    parser.add_argument(
        "--tile_size_width",
        type=int,
        default=34,
        help="Tile size (width) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_height",
        type=int,
        default=18,
        help="Tile stride (height) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_width",
        type=int,
        default=16,
        help="Tile stride (width) in VAE.",
    )

    # --- data shape and loading ---
    parser.add_argument(
        "--steps_per_epoch",
        type=int,
        default=2000,
        help="Number of steps per epoch.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=17,
        help="Number of frames.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Image height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=832,
        help="Image width.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )

    # --- optimization ---
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate.",
    )
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="The number of batches in gradient accumulation.",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=1,
        help="Number of epochs.",
    )
    parser.add_argument(
        "--training_strategy",
        type=str,
        default="auto",
        choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
        help="Training strategy",
    )
    parser.add_argument(
        "--use_gradient_checkpointing",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing_offload",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing offload.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size.",
    )

    # --- model architecture / LoRA ---
    parser.add_argument(
        "--train_architecture",
        type=str,
        default="lora",
        choices=["lora", "full"],
        help="Model structure to train. LoRA training or full training.",
    )
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q,k,v,o,ffn.0,ffn.2",
        help="Layers with LoRA modules.",
    )
    parser.add_argument(
        "--init_lora_weights",
        type=str,
        default="kaiming",
        choices=["gaussian", "kaiming"],
        help="The initializing method of LoRA weight.",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=4,
        help="The dimension of the LoRA update matrices.",
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=4.0,
        help="The weight of the LoRA update matrices.",
    )
    parser.add_argument(
        "--pretrained_lora_path",
        type=str,
        default=None,
        help="Pretrained LoRA path. Required if the training is resumed.",
    )
    parser.add_argument(
        "--use-first-frame-guide",
        # Underscore alias so the flag matches the naming of every other option.
        "--use_first_frame_guide",
        dest="use_first_frame_guide",
        default=False,
        action="store_true",
        help="Use first-frame guidance in both training and validation/inference.",
    )

    # --- checkpointing, validation and logging ---
    parser.add_argument(
        "--checkpoint_every_n_steps",
        type=int,
        default=2000,
        help="Save a checkpoint every N training steps (in addition to end of epoch).",
    )
    parser.add_argument(
        "--use_swanlab",
        default=False,
        action="store_true",
        help="Whether to use SwanLab logger.",
    )
    parser.add_argument(
        "--swanlab_mode",
        default=None,
        help="SwanLab mode (cloud or local).",
    )
    parser.add_argument(
        "--use_wandb",
        dest="use_wandb",
        default=True,
        action="store_true",
        help="Use WandB logger (offline by default).",
    )
    # `--use_wandb` defaults to True, so on its own it can never be switched
    # off; this companion flag writes False to the same destination.
    parser.add_argument(
        "--no_wandb",
        dest="use_wandb",
        action="store_false",
        help="Disable the WandB logger (it is enabled by default).",
    )
    parser.add_argument(
        "--val_num_samples",
        type=int,
        default=3,
        help="Number of validation samples (videos) per epoch.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=25,
        help="Number of denoising steps for validation generation.",
    )
    # Standalone-inference options (inference_only / inference_jsonl /
    # inference_num) have been removed from this script.
    return parser.parse_args(argv)


# def data_process(args):
#     dataset = TextVideoDataset(
#         args.dataset_path,
#         os.path.join(args.dataset_path, "metadata/dense_prompt.jsonl"),
#         max_num_frames=args.num_frames,
#         frame_interval=1,
#         num_frames=args.num_frames,
#         height=args.height,
#         width=args.width,
#         is_i2v=args.image_encoder_path is not None,
#     )
#     dataloader = torch.utils.data.DataLoader(
#         dataset,
#         shuffle=False,
#         batch_size=1,
#         num_workers=args.dataloader_num_workers
#     )
#     model = LightningModelForDataProcess(
#         text_encoder_path=args.text_encoder_path,
#         image_encoder_path=args.image_encoder_path,
#         vae_path=args.vae_path,
#         tiled=args.tiled,
#         tile_size=(args.tile_size_height, args.tile_size_width),
#         tile_stride=(args.tile_stride_height, args.tile_stride_width),
#     )
#     trainer = pl.Trainer(
#         accelerator="gpu",
#         devices="auto",
#         default_root_dir=args.output_path,
#     )
#     trainer.test(model, dataloader)
    
    
def train(args):
    """Build datasets, model, loggers and callbacks, then run Lightning training.

    Args:
        args: Namespace produced by ``parse_args``; every option read here is
            defined there.
    """
    dataset = WristConditionDataset(
        args.dataset_path,  # expected to be the directory of .tensors.pth files
        steps_per_epoch=args.steps_per_epoch,
    )

    # Select validation files deterministically: the first N in sorted order,
    # always at least one and never more than exist.
    all_files = sorted(str(p) for p in dataset.tensor_files)
    num_val_files = max(1, min(args.val_num_samples, len(all_files)))
    val_files = all_files[:num_val_files]
    val_dataset = WristConditionDataset(
        args.dataset_path,
        steps_per_epoch=len(val_files),
        file_list=val_files,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        # Keep validation order fixed so sample indices stay comparable
        # across epochs.
        shuffle=False,
        batch_size=1,
        num_workers=0,
    )

    model = LightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        train_architecture=args.train_architecture,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        pretrained_lora_path=args.pretrained_lora_path,
        vae_path=args.vae_path,
        tiled=args.tiled,
        tile_size_height=args.tile_size_height,
        tile_size_width=args.tile_size_width,
        tile_stride_height=args.tile_stride_height,
        tile_stride_width=args.tile_stride_width,
        num_inference_steps=args.num_inference_steps,
        use_first_frame_guide=args.use_first_frame_guide,
    )

    # Loggers
    loggers = []
    if args.use_swanlab:
        from swanlab.integration.pytorch_lightning import SwanLabLogger
        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
        swanlab_config.update(vars(args))
        swanlab_logger = SwanLabLogger(
            project="wan",
            name="wan",
            config=swanlab_config,
            mode=args.swanlab_mode,
            logdir=os.path.join(args.output_path, "swanlog"),
        )
        loggers.append(swanlab_logger)
    if args.use_wandb:
        # Offline unless the environment already selects a WandB mode.
        os.environ.setdefault("WANDB_MODE", "offline")
        try:
            # NOTE: the validation step in this file detects the WandB logger
            # with isinstance() against this same class, so the import path
            # must stay in sync with it.
            from pytorch_lightning.loggers import WandbLogger
            wandb_logger = WandbLogger(project="wan", name="wan_train", save_dir=os.path.join(args.output_path, "wandb"))
            loggers.append(wandb_logger)
        except Exception as e:
            # Best effort: training continues without WandB if it cannot start.
            print(f"[WARN] Failed to init WandB logger: {e}")

    # Step-based checkpoints configured from the CLI options.
    checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
        dirpath=os.path.join(args.output_path, "checkpoints"),
        filename="wan-{epoch:02d}-{train_loss:.4f}",
        save_top_k=-1,
        every_n_train_steps=args.checkpoint_every_n_steps,
        save_on_train_epoch_end=True,
    )
    callbacks = [
        checkpoint_callback,
        # Default checkpoint callback that keeps every checkpoint it writes.
        pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1),
    ]

    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision="bf16",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=callbacks,
        # None lets Lightning fall back to its default logger.
        logger=loggers if loggers else None,
        check_val_every_n_epoch=5,
        log_every_n_steps=50,
        num_sanity_val_steps=1,
    )
    trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=val_dataloader)


# Script entry point: parse the command-line options and start training.
if __name__ == '__main__':
    train(parse_args())

"""
Usage example:

Train wrist_rgb conditional generation model:

python train_wan_t2v_vggt_condition.py \
    --task train \
    --dataset_path XXX/condition_dataset_out_tensors/wrist_rgb \
    --output_path ./outputs \
    --dit_path /path/to/your/dit_model.pth \
    --learning_rate 1e-5 \
    --max_epochs 10 \
    --steps_per_epoch 1000 \
    --lora_rank 4 \
    --lora_alpha 4 \
    --use_gradient_checkpointing \
    --accumulate_grad_batches 4

Notes:
1. --dataset_path should point to the directory containing .tensors.pth files
2. The model expands input channels to 48 (16 for x + 32 for y) before LoRA, and re-enables grad on that layer after LoRA
3. 80% probability to drop first-frame guidance online during training
4. ext1/ext2 features are converted to pseudo text tokens and concatenated to the text context
"""
