#!/usr/bin/env python
# coding=utf-8
import imageio
import os
import argparse
import contextlib
import gc
import logging
import math
import os
import pickle
import random
import shutil
import sys
from collections import deque
from pathlib import Path
import torchvision
import accelerate
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.utils import check_min_version
from einops import rearrange
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from torch.utils.data import RandomSampler
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter

# Ensure the project root directories are on the import path. The script's own
# directory and its two parent directories are added so the project-local
# packages imported below resolve regardless of the working directory the
# script is launched from. Each missing entry is pushed to the front of
# sys.path, so later entries in ``project_roots`` end up with higher priority.
current_file_path = os.path.abspath(__file__)
project_roots = []
_ancestor = current_file_path
for _level in range(3):
    _ancestor = os.path.dirname(_ancestor)
    project_roots.append(_ancestor)
for project_root in project_roots:
    if project_root in sys.path:
        continue
    sys.path.insert(0, project_root)

from MoRe4D.models.wan_vae import AutoencoderKLWan
from prompt_tuning.models.resnet import VAEEncoderadaptor, VAEDecoderadaptor
from prompt_tuning.datasets.flow_dataset import SceneFlowDataset

# Accelerate's logger wrapper; its methods accept ``main_process_only`` (used by
# ``logger.info(accelerator.state, main_process_only=False)`` in ``main``).
logger = get_logger(__name__)


class LossTracker:
    """Running statistics over recently observed training losses.

    Keeps a bounded window of the most recent loss values (for a windowed mean
    and population standard deviation) plus a sample-weighted running total over
    every value recorded so far.
    """

    def __init__(self, window_size=1000):
        self.window_size = window_size
        # Most recent loss values; the deque discards the oldest once full.
        self.loss_history = deque(maxlen=window_size)
        # Sum of ``loss * batch_size`` and the number of samples it covers.
        self.total_loss = 0.0
        self.total_samples = 0

    def update(self, loss_value, batch_size=1):
        """Record one loss (a number or single-element tensor) covering ``batch_size`` samples."""
        value = loss_value.item() if torch.is_tensor(loss_value) else loss_value
        self.loss_history.append(value)
        self.total_loss += value * batch_size
        self.total_samples += batch_size

    def get_global_average(self):
        """Return the sample-weighted mean over all updates, or 0.0 before the first one."""
        if not self.total_samples:
            return 0.0
        return self.total_loss / self.total_samples

    def get_window_average(self):
        """Return the unweighted mean of the window, or 0.0 when it is empty."""
        count = len(self.loss_history)
        return sum(self.loss_history) / count if count else 0.0

    def get_window_std(self):
        """Return the population standard deviation of the window (0.0 for fewer than two values)."""
        count = len(self.loss_history)
        if count < 2:
            return 0.0
        centre = self.get_window_average()
        squared_dev = sum((value - centre) ** 2 for value in self.loss_history)
        return math.sqrt(squared_dev / count)


def should_skip_batch(loss_value, loss_tracker, args):
    """Decide whether the current batch's loss is an outlier that should be skipped.

    A batch is skipped when its loss is non-finite, when it exceeds
    ``args.loss_skip_absolute_threshold`` (checked unconditionally), or, once the
    tracker window holds at least ``args.loss_skip_min_samples`` values, when it
    exceeds a threshold derived from the window statistics:

    * ``mean + args.loss_skip_std_multiplier * std`` in general, or
    * ``mean * args.loss_skip_multiplier`` when the window std is below 1e-6.

    Args:
        loss_value: the batch loss, as a Python number or single-element tensor.
        loss_tracker: object exposing ``loss_history``, ``get_window_average()``
            and ``get_window_std()`` (e.g. ``LossTracker``).
        args: namespace carrying the ``loss_skip_*`` options.

    Returns:
        True if the batch should be skipped (a warning is logged), else False.
    """
    if torch.is_tensor(loss_value):
        loss_value = loss_value.item()
    loss_value = float(loss_value)

    # Check the Python float directly: round-tripping through torch.tensor()
    # casts to float32 and would report finite losses above ~3.4e38 as inf.
    if not math.isfinite(loss_value):
        logger.warning(f"Skip batch: loss is not finite (loss={loss_value})")
        return True

    # The absolute cap does not depend on window statistics, so it must also
    # apply during the warm-up period (it used to be checked only afterwards).
    if loss_value > args.loss_skip_absolute_threshold:
        logger.warning(f"Skip batch: loss={loss_value:.6f} > absolute_threshold={args.loss_skip_absolute_threshold}")
        return True

    # Not enough history yet for meaningful relative statistics.
    if len(loss_tracker.loss_history) < args.loss_skip_min_samples:
        return False

    window_mean = loss_tracker.get_window_average()
    window_std = loss_tracker.get_window_std()

    if window_std < 1e-6:
        # Near-constant window: fall back to a multiplicative bound.
        # NOTE(review): if window_mean <= 0 this bound is <= 0, so every
        # positive loss would be skipped; confirm losses are strictly positive.
        threshold = window_mean * args.loss_skip_multiplier
    else:
        threshold = window_mean + args.loss_skip_std_multiplier * window_std

    if loss_value > threshold:
        logger.warning(
            f"Skip batch: loss={loss_value:.6f} > threshold={threshold:.6f} "
            f"(mean={window_mean:.6f}, std={window_std:.6f})"
        )
        return True

    return False


def filter_kwargs(cls, kwargs):
    """Return the subset of ``kwargs`` whose names ``cls.__init__`` declares.

    ``self`` and ``cls`` are never forwarded, even if present in ``kwargs``.
    Insertion order of the surviving items is preserved.
    """
    import inspect

    accepted = set(inspect.signature(cls.__init__).parameters)
    accepted.difference_update(("self", "cls"))
    return {name: value for name, value in kwargs.items() if name in accepted}


def collate_fn(examples):
    """Merge a list of sample dicts into a single batch dict.

    For every key of the first sample, the tensors stored under that key in all
    samples are concatenated along dim 0 (presumably each sample already carries
    a leading batch axis of size 1 — confirm against SceneFlowDataset).
    """
    return {
        key: torch.cat([example[key] for example in examples], dim=0)
        for key in examples[0]
    }


@torch.no_grad()
def log_validation(encoder_prompt, decoder_prompt, vae, args, accelerator, weight_dtype, step, is_final_validation=False):
    """Round-trip the validation scene-flow files through the adaptors and VAE and save videos.

    For each pickle in ``args.validation_sceneflow`` the coordinates are normalised
    the same way the training loop normalises them. They are then encoded into a
    pseudo video by ``encoder_prompt``, passed through ``vae.encode``/``vae.decode``,
    and mapped back to coordinates by ``decoder_prompt``. The reconstructed points
    are un-normalised and splatted into a pinhole camera using ``data["colors"]``
    to give a "projected" video. Both videos are written as mp4 files to
    ``<output_dir>/sample``.

    Args:
        encoder_prompt, decoder_prompt, vae: the (possibly accelerate-wrapped)
            models being trained. When ``is_final_validation`` is True they are
            ignored: fresh adaptors are built and loaded from
            ``<output_dir>/{encoder,decoder}_prompt/pytorch_model.bin``, and the
            VAE is reloaded from ``args.vae_model_path``.
        args: parsed command-line arguments.
        accelerator: supplies the device used for inference.
        weight_dtype: input dtype for in-training validation, which runs under
            ``torch.autocast("cuda")``. Final validation runs in float32.
        step: global step, used in the output file names.
        is_final_validation: reload the saved adaptor weights from disk.

    Returns:
        ``(videos, projected_videos)``, one entry per validation file:
        - pseudo video, VAE reconstruction and adaptor reconstruction,
          concatenated along dim -2 (on CPU);
        - the projected video as a ``[1, T, C, H, W]`` float tensor in [0, 1].

    NOTE(review): in-training validation does not switch the modules to eval
    mode, and final validation uses the base VAE weights even when the decoder
    was fine-tuned. Confirm both are intended.
    """
    # Only needed for the point-cloud projection.
    from project_utils import project
    from torch_scatter import scatter

    logger.info("Running validation... ")
    device = accelerator.device

    if not is_final_validation:
        encoder_prompt = accelerator.unwrap_model(encoder_prompt)
        decoder_prompt = accelerator.unwrap_model(decoder_prompt)
        vae = accelerator.unwrap_model(vae)
        input_dtype = weight_dtype
        inference_ctx = torch.autocast("cuda")
    else:
        encoder_prompt = VAEEncoderadaptor()
        decoder_prompt = VAEDecoderadaptor()
        vae = AutoencoderKLWan.from_pretrained(args.vae_model_path)
        encoder_prompt.load_state_dict(
            torch.load(os.path.join(args.output_dir, "encoder_prompt", "pytorch_model.bin"), map_location="cpu")
        )
        decoder_prompt.load_state_dict(
            torch.load(os.path.join(args.output_dir, "decoder_prompt", "pytorch_model.bin"), map_location="cpu")
        )
        # Freshly built modules start on the CPU in fp32 and run without
        # autocast: move them next to the inputs and feed them fp32 data,
        # otherwise the forward pass fails on a device/dtype mismatch.
        for module in (encoder_prompt, decoder_prompt, vae):
            module.to(device)
            module.eval()
        input_dtype = torch.float32
        inference_ctx = contextlib.nullcontext()

    # Source video resolution and projected-image resolution. fx/fy rescale one
    # image axis to account for the aspect-ratio difference; the matching crop
    # is applied to every projected frame below.
    H_ori, W_ori = 720, 960
    H, W = 368, 512
    if W_ori / W > H_ori / H:
        fx, fy = 1, W_ori / H_ori / (W / H)
    else:
        fx, fy = H_ori / W_ori / (H / W), 1

    intrinsic = torch.tensor([[fx, 0, 0.5], [0, fy, 0.5], [0, 0, 1]], dtype=torch.float32, device=device)
    extrinsic = torch.eye(4, dtype=torch.float32, device=device)  # identity camera pose

    if fy > 1:
        margin = (H - H_ori / W_ori * W) // 2
        crop_box = (0, margin, W, H - margin)
    elif fx > 1:
        margin = (W - W_ori / H_ori * H) // 2
        crop_box = (margin, 0, W - margin, H)
    else:
        crop_box = None

    # Same precedence as the training loop: normalize_track > first_frame > z.
    if args.normalize_track:
        mode = "global"
    elif args.normalize_track_first_frame:
        mode = "first_frame"
    elif args.normalize_track_z:
        mode = "z"
    else:
        mode = "delta"

    videos = []
    projected_videos = []
    for data_path in args.validation_sceneflow or []:
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(data_path, "rb") as f:
            data = pickle.load(f)

        n_frames = min(49, (len(data["coords"]) // 4) * 4 + 1)
        coords = torch.from_numpy(data["coords"]).to(device, input_dtype)
        # [1, 3, T, 384, 512]: absolute xyz of the first n_frames frames.
        coords = coords.reshape(coords.shape[0], 384, 512, 3)[:n_frames].permute(3, 0, 1, 2).unsqueeze(0)
        # Absolute xyz of frame 0, [3, 384, 512]. It must come from the raw
        # coordinates (as in training), not from the displacements, where it
        # is identically zero.
        frame0 = coords[0, :, 0].clone()

        if mode == "global":
            targets = coords / coords.abs().amax(dim=(1, 2, 3, 4), keepdim=True)
        else:
            # Displacement of every point from its position in frame 0.
            targets = coords - coords[:, :, 0:1]

        if mode == "first_frame":
            flat0 = frame0.reshape(3, -1)
            # One scale shared by x, y and z: the largest extent of frame 0.
            diff = (flat0.max(dim=1).values - flat0.min(dim=1).values).max().repeat(3)
            diff[diff == 0] = 1.0  # avoid division by zero
            targets = targets / diff.view(3, 1, 1, 1)
        elif mode == "z":
            depth0 = frame0[2]  # view: the fix-ups below also update frame0
            depth0[~torch.isfinite(depth0) | (depth0 == 0)] = 1.0
            x_norm = depth0 / fx
            y_norm = depth0 / fy
            targets[:, 0:1] = targets[:, 0:1] / x_norm
            targets[:, 1:2] = targets[:, 1:2] / y_norm
            targets[:, 2:3] = targets[:, 2:3] / depth0

        with inference_ctx:
            pseudo_video = encoder_prompt(targets) * 2 - 1
            latents = vae.encode(pseudo_video).latent_dist.sample()
            recon_video = vae.decode(latents).sample
            reconstructions = decoder_prompt(recon_video)
        videos.append(torch.cat([pseudo_video.cpu(), recon_video.cpu(), reconstructions.cpu()], dim=-2))

        # [T, H, W, 3]; .float() first because numpy has no bfloat16.
        recon_coords = reconstructions[0].permute(1, 2, 3, 0).float().cpu().numpy()
        frame0_np = frame0.permute(1, 2, 0).float().cpu().numpy()  # [H, W, 3]
        if mode == "first_frame":
            recon_coords = recon_coords * diff.view(1, 1, 1, 3).float().cpu().numpy() + frame0_np
        elif mode == "z":
            scale = torch.stack([x_norm, y_norm, depth0], dim=-1).float().cpu().numpy()  # [H, W, 3]
            recon_coords = recon_coords * scale + frame0_np
        elif mode == "delta":
            # Targets were displacements: add frame 0 back to get positions.
            recon_coords = recon_coords + frame0_np

        # Per-point colours, indexed with the same per-point mask as the points.
        colors = torch.from_numpy(data["colors"]).to(device)

        frames = []
        for t in range(n_frames):
            world_points = torch.from_numpy(recon_coords[t]).reshape(-1, 3).to(device)
            uv, depth = project(world_points, extrinsic, intrinsic)
            visible = (
                (uv[..., 0] >= 0) & (uv[..., 0] <= 1)
                & (uv[..., 1] >= 0) & (uv[..., 1] <= 1)
                & (depth >= 0)
            )

            image = np.zeros((H, W, 3), dtype=np.uint8)  # black if nothing is visible
            if bool(visible.any()):
                vis_colors = colors[visible]
                vis_depth = depth[visible]
                vis_uv = uv[visible]
                # Column-major pixel index: x_pixel * H + y_pixel.
                pixel = (vis_uv[:, 0] * W).floor().clamp(0, W - 1) * H + (vis_uv[:, 1] * H).floor().clamp(0, H - 1)
                # Z-buffer: keep only the nearest point(s) landing in each pixel.
                unique_pixels, inverse = torch.unique(pixel, return_inverse=True)
                min_depth = torch.ones_like(unique_pixels, dtype=vis_depth.dtype) * vis_depth.max()
                min_depth.index_reduce_(0, inverse, vis_depth, "amin")
                nearest = vis_depth == min_depth[inverse]

                color_image = scatter(vis_colors[nearest], pixel[nearest].long(), dim=0, reduce="mean")
                if len(color_image) < H * W:
                    color_image = torch.cat(
                        [color_image, torch.zeros((H * W - len(color_image), 3), device=device)], dim=0
                    )
                image = color_image.reshape(W, H, 3).transpose(0, 1).cpu().numpy().astype(np.uint8)

            image_proj = Image.fromarray(image)
            if crop_box is not None:
                image_proj = image_proj.crop(crop_box)
            frames.append(np.array(image_proj))

        projected_video = np.stack(frames, axis=0)  # [T, H', W', 3]
        projected_videos.append(torch.from_numpy(projected_video).permute(0, 3, 1, 2).unsqueeze(0) / 255.0)  # [1, T, C, H, W]

    tracker_key = "test" if is_final_validation else "validation"
    sample_dir = os.path.join(args.output_dir, "sample")
    os.makedirs(sample_dir, exist_ok=True)

    # round() instead of truncation: x / 255 * 255 can land just below an
    # integer in float32 and would otherwise lose one intensity level.
    for idx, video in enumerate(videos):
        # Assumed [1, 3, T, H, W] in [-1, 1] -> [T, H, W, 3] uint8.
        video = ((video[0].permute(1, 0, 2, 3) + 1) / 2).clamp(0, 1).to(torch.float32)
        video_np = (video.permute(0, 2, 3, 1) * 255).round().to(torch.uint8).cpu().numpy()
        imageio.mimwrite(os.path.join(sample_dir, f"{tracker_key}_video_{step}_{idx}.mp4"), video_np, fps=8)

    for idx, video in enumerate(projected_videos):
        video = video[0].to(torch.float32)  # [T, C, H, W] in [0, 1]
        video_np = (video.permute(0, 2, 3, 1) * 255).round().to(torch.uint8).cpu().numpy()
        imageio.mimwrite(os.path.join(sample_dir, f"{tracker_key}_projected_video_{step}_{idx}.mp4"), video_np, fps=8)

    # NOTE(review): logging these videos to accelerator.trackers is not
    # implemented; the mp4 files above are the only validation output.
    gc.collect()

    return videos, projected_videos


def parse_args(input_args=None):
    """Parse the command-line options of the WAN VAE fine-tuning script.

    Args:
        input_args: optional list of argument strings; ``sys.argv`` is used when None.

    Returns:
        The parsed ``argparse.Namespace``. If the ``LOCAL_RANK`` environment
        variable is set (and not -1), it overrides ``--local_rank``.
    """
    parser = argparse.ArgumentParser(description="WAN VAE fine-tuning script.")
    
    parser.add_argument("--vae_model_path", type=str, required=True, help="Path to pretrained VAE model.")
    parser.add_argument("--data_root", type=str, required=True, help="Root directory of scene flow data.")
    parser.add_argument("--video_column", type=str, required=True, help="Path to video list file.")
    parser.add_argument("--data_posfix", type=str, default="_render", help="Postfix for data files.")
    parser.add_argument("--output_dir", type=str, default="outputs/wan_vae_finetuned", help="Output directory.")
    
    parser.add_argument("--finetune_vae_decoder", action="store_true", help="Whether to finetune the VAE decoder.")
    # The three normalisation modes are mutually exclusive in practice; the
    # training and validation code pick the first one set, in this order.
    parser.add_argument("--normalize_track", action="store_true",
                        help="Train on the dataset's pre-normalized coordinates ('coords_normalized'). "
                             "Takes precedence over the other --normalize_track_* flags.")
    parser.add_argument("--normalize_track_first_frame", action="store_true",
                        help="Use displacements from the first frame, divided by the largest xyz extent "
                             "of the first frame. Ignored if --normalize_track is set.")
    parser.add_argument("--normalize_track_z", action="store_true",
                        help="Use displacements from the first frame, with x/y divided by first-frame depth "
                             "over the focal scale and z divided by first-frame depth. Ignored if "
                             "--normalize_track or --normalize_track_first_frame is set.")
    
    parser.add_argument("--num_frames", type=int, default=49, help="Number of frames to use.")
    parser.add_argument("--loss_skip_min_samples", type=int, default=100, help="Minimum samples before enabling loss skipping.")
    parser.add_argument("--loss_skip_std_multiplier", type=float, default=6.0, help="Skip batch if loss > mean + std_multiplier * std.")
    parser.add_argument("--loss_skip_multiplier", type=float, default=10.0, help="Skip batch if loss > multiplier * mean (when std is small).")
    parser.add_argument("--loss_skip_absolute_threshold", type=float, default=1e7, help="Absolute threshold for skipping batches.")
    parser.add_argument("--loss_tracker_window_size", type=int, default=1000, help="Window size for loss tracking.")
    
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument("--max_train_steps", type=int, default=None,
                        help="Total number of training steps to perform. If provided, overrides num_train_epochs.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", type=float, default=4.5e-6, help="Initial learning rate.")
    parser.add_argument("--scale_lr", action="store_true", default=False,
                        help="Scale learning rate by the number of GPUs, gradient accumulation steps, and batch size.")
    parser.add_argument("--lr_scheduler", type=str, default="constant",
                        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
                        help="The scheduler type to use.")
    parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--gradient_checkpointing", action="store_true", help="Enable gradient checkpointing.")
    
    parser.add_argument("--validation_steps", type=int, default=500,
                        help="Run validation every X steps.")
    parser.add_argument("--checkpointing_steps", type=int, default=500,
                        help="Save a checkpoint of the training state every X updates.")
    parser.add_argument("--checkpoints_total_limit", type=int, default=None,
                        help="Max number of checkpoints to store.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Whether training should be resumed from a previous checkpoint.")
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],
                        help="Whether to use mixed precision. Choose among fp16, bf16, no.")
    parser.add_argument("--dataloader_num_workers", type=int, default=0,
                        help="Number of subprocesses to use for data loading.")
    parser.add_argument("--use_8bit_adam", action="store_true", help="Whether to use 8-bit Adam from bitsandbytes.")
    parser.add_argument("--report_to", type=str, default="tensorboard", 
                        help="The integration to report the results and logs to.")
    parser.add_argument("--validation_sceneflow", type=str, default=None, nargs="+",
                        help="Path to validation scene flow data.")
    parser.add_argument("--rec_loss", type=str, default="l1", choices=["l1", "l2"], help="Reconstruction loss type.")
    parser.add_argument("--kl_scale", type=float, default=1e-6, help="Scale factor for KL divergence loss.")
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether to use xformers.")
    parser.add_argument("--set_grads_to_none", action="store_true",
                        help="Save memory by setting grads to None instead of zero.")
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument("--pretrained_model_path", type=str, default=None, help="Path to pretrained encoder/decoder models.")
    
    args = parser.parse_args(input_args) if input_args is not None else parser.parse_args()

    # A launcher-provided LOCAL_RANK wins over the command-line value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


def main(args):
    logging_dir = Path(args.output_dir, "logs")
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            os.makedirs(os.path.join(args.output_dir, "encoder_prompt"), exist_ok=True)
            os.makedirs(os.path.join(args.output_dir, "decoder_prompt"), exist_ok=True)
            os.makedirs(os.path.join(args.output_dir, "vae"), exist_ok=True)

    encoder_prompt = VAEEncoderadaptor()
    decoder_prompt = VAEDecoderadaptor()
    
    vae = AutoencoderKLWan.from_pretrained(
        args.vae_model_path,
        additional_kwargs={
            "latent_channels": 16,
            "temporal_compression_ratio": 4,
            "spatial_compression_ratio": 8
        }
    )

    encoder_prompt.requires_grad_(True)
    encoder_prompt.train()
    
    decoder_prompt.requires_grad_(True)
    decoder_prompt.train()
    
    vae.model.encoder.requires_grad_(False)
    vae.model.encoder.eval()
    
    if args.finetune_vae_decoder:
        vae.model.decoder.requires_grad_(True)
        vae.model.decoder.train()
    else:
        vae.model.decoder.requires_grad_(False)
        vae.model.decoder.eval()

    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            i = len(weights) - 1

            while len(weights) > 0:
                weights.pop()
                model = models[i]

                if isinstance(model, AutoencoderKLWan):
                    sub_dir = "vae"
                    if args.finetune_vae_decoder:
                        os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)
                        torch.save(model.state_dict(), os.path.join(output_dir, sub_dir, "pytorch_model.bin"))
                elif isinstance(model, VAEEncoderadaptor):
                    sub_dir = "encoder_prompt"
                    os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)
                    torch.save(model.state_dict(), os.path.join(output_dir, sub_dir, "pytorch_model.bin"))
                elif isinstance(model, VAEDecoderadaptor):
                    sub_dir = "decoder_prompt"
                    os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)
                    torch.save(model.state_dict(), os.path.join(output_dir, sub_dir, "pytorch_model.bin"))
                else:
                    raise ValueError(f"Model type {type(model)} not recognized!")
                i -= 1

    def load_model_hook(models, input_dir):
        while len(models) > 0:
            model = models.pop()
            
            if isinstance(model, AutoencoderKLWan):
                if os.path.exists(os.path.join(input_dir, "vae", "pytorch_model.bin")):
                    vae_state_dict = torch.load(os.path.join(input_dir, "vae", "pytorch_model.bin"))
                    model.load_state_dict(vae_state_dict)
                    del vae_state_dict
                else:
                    pass
            elif isinstance(model, VAEDecoderadaptor):
                if os.path.exists(os.path.join(input_dir, "decoder_prompt", "pytorch_model.bin")):
                    decoder_state_dict = torch.load(os.path.join(input_dir, "decoder_prompt", "pytorch_model.bin"))
                    model.load_state_dict(decoder_state_dict)
                    del decoder_state_dict
            elif isinstance(model, VAEEncoderadaptor):
                if os.path.exists(os.path.join(input_dir, "encoder_prompt", "pytorch_model.bin")):
                    encoder_state_dict = torch.load(os.path.join(input_dir, "encoder_prompt", "pytorch_model.bin"))
                    model.load_state_dict(encoder_state_dict)
                    del encoder_state_dict

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        if hasattr(vae, "enable_gradient_checkpointing"):
            vae.enable_gradient_checkpointing()
        else:
            logger.warning("VAE model does not support gradient checkpointing.")

    if args.enable_xformers_memory_efficient_attention:
        if hasattr(vae, "enable_xformers_memory_efficient_attention"):
            vae.enable_xformers_memory_efficient_attention()
        else:
            logger.warning("VAE model does not support xformers optimization.")

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("Please install bitsandbytes to use 8-bit Adam.")
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    params_to_optimize = [
        {"params": encoder_prompt.parameters(), "lr": args.learning_rate},
        {"params": decoder_prompt.parameters(), "lr": args.learning_rate},
    ]
    if args.finetune_vae_decoder:
        params_to_optimize.append({"params": vae.model.decoder.parameters(), "lr": args.learning_rate * 0.1})

    total_params = sum(p.numel() for p in encoder_prompt.parameters()) + sum(p.numel() for p in decoder_prompt.parameters())
    total_params += sum(p.numel() for p in vae.model.parameters())
    trainable_params = sum(p.numel() for p in encoder_prompt.parameters() if p.requires_grad) + \
        sum(p.numel() for p in decoder_prompt.parameters() if p.requires_grad)
    if args.finetune_vae_decoder:
        trainable_params += sum(p.numel() for p in vae.model.decoder.parameters() if p.requires_grad)
    logger.info(f"Total parameters: {total_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
    optimizer = optimizer_cls(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    train_dataset = SceneFlowDataset(
        data_root=args.data_root,
        video_column=args.video_column,
        posfix=args.data_posfix,
        max_frames=args.num_frames,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    else:
        overrode_max_train_steps = False

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    encoder_prompt, decoder_prompt, vae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        encoder_prompt, decoder_prompt, vae, optimizer, train_dataloader, lr_scheduler
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        if "validation_sceneflow" in tracker_config:
            tracker_config.pop("validation_sceneflow")
        accelerator.init_trackers("wan_vae_finetuning", config=tracker_config)
        writer = SummaryWriter(log_dir=logging_dir)

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    global_step = 0
    first_epoch = 0
    
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])
            first_epoch = global_step // num_update_steps_per_epoch

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    loss_tracker = LossTracker(window_size=args.loss_tracker_window_size)
    skipped_batches = 0
    unwrapped_vae = accelerator.unwrap_model(vae).model
    for epoch in range(first_epoch, args.num_train_epochs):
        encoder_prompt.train()
        decoder_prompt.train()
        unwrapped_vae.encoder.eval()
        
        if args.finetune_vae_decoder:
            unwrapped_vae.decoder.train()
        else:
            unwrapped_vae.decoder.eval()
        H_ori, W_ori = [720, 960]
        H, W = [368, 512]
        
        if W_ori / W > H_ori / H:  
            fx = 1
            fy = W_ori / H_ori / (W / H)
        else:
            fy = 1
            fx = H_ori / W_ori / (H / W)
            
        for step, batch in enumerate(train_dataloader):
            try:
                if args.normalize_track:
                    targets = batch["coords_normalized"].to(dtype=weight_dtype)  # B, 3, T, H, W
                elif args.normalize_track_first_frame:
                    flow = batch["coords"][:, :, :args.num_frames, :, :].to(dtype=weight_dtype)    # B, 3, T, H, W
                    targets = []
                    for b in range(flow.size(0)):
                        # Get xyz coordinates of frame 0
                        frame0 = flow[b, :, 0, :, :]  # [3, H, W]
                        
                        # Compute max and min for each channel (x, y, z)
                        max_vals = frame0.view(3, -1).max(dim=1)[0]  # [3]
                        min_vals = frame0.view(3, -1).min(dim=1)[0] # [3]
                        diff = (max_vals - min_vals).max().repeat(3)  # [3]
                        
                        # Avoid division by zero
                        diff[diff == 0] = 1.0
                        
                        targets.append(batch["coords_delta"][b, :, :args.num_frames, :, :].to(dtype=weight_dtype)   / diff.view(3,1, 1, 1))  # Normalize the first frame

                    targets =  torch.stack(targets, dim=0)  # B, 3, T, H, W
                elif args.normalize_track_z:
                    flow = batch["coords"][:, :, :args.num_frames, :, :].to(dtype=weight_dtype)    # B, 3, T, H, W
                    targets = []
                    for b in range(flow.size(0)):
                        frame0 = flow[b, :, 0, :, :].clone()  # [3, H, W]
                        frame0[2,:,:][torch.isnan(frame0[2,:,:])] = 1.0
                        frame0[2,:,:][frame0[2,:,:]==0] = 1
                        frame0[2,:,:][torch.isinf(frame0[2,:,:])] = 1.0
                        current_x_norm = frame0[2,:,:] / fx
                        current_y_norm = frame0[2,:,:] / fy
                        temp = batch["coords_delta"][b, :, :args.num_frames, :, :].to(dtype=weight_dtype)
                        temp[0:1, :, :, :] = temp[0:1, :, :, :] / current_x_norm
                        temp[1:2, :, :, :] = temp[1:2, :, :, :] / current_y_norm
                        temp[2:3, :, :, :] = temp[2:3, :, :, :] / frame0[2:3, :, :]
                        targets.append(temp)  # Normalize the first frame
                    targets =  torch.stack(targets, dim=0)  # B, 3, T, H, W
                else:
                    targets = batch["coords"][:, :, :args.num_frames, :, :].to(dtype=weight_dtype)  # B, 3, T, H, W
                    targets = targets - targets[:, :, 0:1, :, :]
                    
                
                with accelerator.accumulate([encoder_prompt, decoder_prompt, vae]):
                    
                    pseudo_video = accelerator.unwrap_model(encoder_prompt)(targets)  # B, 3, T, H, W
                    
                    pseudo_video = pseudo_video * 2 - 1
                    with torch.no_grad():
                        posterior = accelerator.unwrap_model(vae).encode_memory_saver(pseudo_video).latent_dist
                        latents = posterior.sample()
                        
                        del pseudo_video
                        torch.cuda.empty_cache()
                    
                    if not args.finetune_vae_decoder:
                        with torch.no_grad():
                            recon_video = accelerator.unwrap_model(vae).decode_memory_saver(latents).sample
                    else:
                        recon_video = accelerator.unwrap_model(vae).decode_memory_saver(latents).sample
                    del latents
                    torch.cuda.empty_cache()
                    reconstructions = accelerator.unwrap_model(decoder_prompt)(recon_video)
                    
                    # 计算重构损失
                    if args.rec_loss == "l2":
                        rec_loss = F.mse_loss(reconstructions.float(), targets.float(), reduction="none")
                    elif args.rec_loss == "l1":
                        rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction="none")
                    else:
                        raise ValueError(f"Invalid reconstruction loss type: {args.rec_loss}")
                    
                    nll_loss = torch.sum(rec_loss) / rec_loss.shape[0]
                    kl_loss = posterior.kl()
                    kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
                    
                    loss = nll_loss + args.kl_scale * kl_loss
                    
                    if should_skip_batch(loss, loss_tracker, args):
                        skipped_batches += 1
                        accelerator.backward(loss)
                        with torch.no_grad():
                            optimizer.zero_grad(set_to_none=args.set_grads_to_none)
                            
                            logger.info(f"Skip batch {step}: loss={loss.detach().mean().item()}")
                            intermediates = [recon_video, reconstructions]
                            for tensor in intermediates:
                                if isinstance(tensor, torch.Tensor):
                                    tensor.detach_()
                                del tensor
                            torch.cuda.empty_cache()
                            gc.collect()
                        continue
                    
                    loss_tracker.update(loss, batch_size=targets.size(0))
                    
                    logs = {
                        "loss": loss.detach().mean().item(),
                        "nll_loss": nll_loss.detach().mean().item(),
                        "kl_loss": kl_loss.detach().mean().item() * args.kl_scale,
                        "lr": lr_scheduler.get_last_lr()[0],
                        "window_avg_loss": loss_tracker.get_window_average(),
                    }
                    
                    accelerator.backward(loss)
                    
                    if accelerator.sync_gradients:
                        if args.finetune_vae_decoder:
                            params_to_clip = list(encoder_prompt.parameters()) + list(decoder_prompt.parameters()) + list(unwrapped_vae.decoder.parameters())
                        else:
                            params_to_clip = list(encoder_prompt.parameters()) + list(decoder_prompt.parameters())
                        
                        grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                        
                        if torch.is_tensor(grad_norm):
                            grad_norm = grad_norm.item()
                        
                        if grad_norm is None or not math.isfinite(grad_norm):
                            logger.warning(f"Skip batch {step}: grad_norm={grad_norm}")
                            optimizer.zero_grad(set_to_none=args.set_grads_to_none)
                            continue
                    
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=args.set_grads_to_none)
                
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    global_step += 1
                    
                    accelerator.log({"train_loss": loss.detach().mean().item()}, step=global_step)
                    
                    if global_step % args.checkpointing_steps == 0:
                        if accelerator.is_main_process:
                            if args.checkpoints_total_limit is not None:
                                checkpoints = os.listdir(args.output_dir)
                                checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                                checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
                                
                                if len(checkpoints) >= args.checkpoints_total_limit:
                                    num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                    removing_checkpoints = checkpoints[0:num_to_remove]
                                    
                                    logger.info(f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
                                    logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
                                    
                                    for removing_checkpoint in removing_checkpoints:
                                        removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                        shutil.rmtree(removing_checkpoint)
                            
                            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                            accelerator.save_state(save_path)
                            logger.info(f"Saved state to {save_path}")
                    
                    if args.validation_sceneflow and global_step % args.validation_steps == 0:
                        log_validation(
                            encoder_prompt,
                            decoder_prompt,
                            vae,
                            args,
                            accelerator,
                            weight_dtype,
                            global_step,
                        )
                
                progress_bar.set_postfix(**logs)
                
                if global_step >= args.max_train_steps:
                    break
            except Exception as e:
                logger.error(f"Exception in step {step} of epoch {epoch}: {e}")
                import traceback
                traceback.print_exc()
                continue
    
    if accelerator.is_main_process:
        logger.info(f"Training completed. Total skipped batches: {skipped_batches}")
        logger.info(f"Final global average loss: {loss_tracker.get_global_average():.6f}")
        logger.info(f"Final window average loss: {loss_tracker.get_window_average():.6f}")
    
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        encoder_prompt = accelerator.unwrap_model(encoder_prompt)
        decoder_prompt = accelerator.unwrap_model(decoder_prompt)
        vae = accelerator.unwrap_model(vae)
        
        torch.save(encoder_prompt.state_dict(), os.path.join(args.output_dir, "encoder_prompt", "pytorch_model.bin"))
        torch.save(decoder_prompt.state_dict(), os.path.join(args.output_dir, "decoder_prompt", "pytorch_model.bin"))
        
        if args.finetune_vae_decoder:
            torch.save(vae.state_dict(), os.path.join(args.output_dir, "vae", "pytorch_model.bin"))
    
    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly,
    # not when it is imported as a module.
    # Build the run configuration with parse_args() and pass it to main().
    # NOTE(review): this assignment also binds `args` as a module-level
    # global; confirm no other function in this file reads the global
    # `args` implicitly before renaming or inlining it.
    args = parse_args()
    main(args)