

import os
import sys
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

import torch.distributed as dist
from omegaconf import OmegaConf
import argparse
import torch
import time
import gc
import json
from einops import rearrange
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision.transforms.functional as TF
from PIL import Image
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed, DistributedDataParallelKwargs
from tqdm import tqdm
import logging
import datasets
import transformers
import shutil
import diffusers
from diffusers.utils import export_to_video
from diffusers.training_utils import EMAModel
from copy import deepcopy

# Make this script's directory and its two nearest ancestors importable so
# that the ``videox_fun`` imports below resolve when the file is run directly.
current_file_path = os.path.abspath(__file__)
project_roots = [os.path.dirname(current_file_path)]
project_roots.append(os.path.dirname(project_roots[-1]))
project_roots.append(os.path.dirname(project_roots[-1]))
for project_root in project_roots:
    if project_root in sys.path:
        continue
    sys.path.insert(0, project_root)

from videox_fun.models.wan_dmd import DMD
from videox_fun.models.wan_RL import wan_RL
from videox_fun.data.dataset import ImageVideoDataset, BucketDistributedSampler, cycle, DistributedKRepeatSampler
from videox_fun.utils.utils import get_video_to_video_latent, resize_mask
from videox_fun.pipeline import WanFunInpaintPipeline, WanFunCausalInpaintPipeline
from videox_fun.pipeline import WanFunControlPipeline, WanFunCausalControlPipeline

from collections import defaultdict
import flow_grpo.rewards
from concurrent import futures

class Trainer:
    def __init__(self, config):
        """Set up everything needed for RL fine-tuning of the generator.

        Creates the ``Accelerator``, configures logging, builds the ``wan_RL``
        model wrapper, the optional EMA copy, the AdamW optimizer, the reward
        function, the prompt dataloader and the pipeline used for validation.

        Args:
            config: Run configuration (``main`` loads it with
                ``OmegaConf.load``). Every hyper-parameter used by the trainer
                is read from it as an attribute.

        Raises:
            ValueError: If ``config.generator_name`` is neither ``"wan"`` nor
                ``"causal_wan"``.
        """
        self.config = config

        # Resolve the validation pipeline class first so that an unsupported
        # ``generator_name`` fails with a clear message before any model is
        # loaded, instead of an UnboundLocalError at the end of ``__init__``.
        pipeline_classes = {
            "wan": WanFunControlPipeline,
            "causal_wan": WanFunCausalControlPipeline,
        }
        if config.generator_name not in pipeline_classes:
            raise ValueError(
                f"Unsupported generator_name {config.generator_name!r}; "
                f"expected one of {sorted(pipeline_classes)}"
            )
        pipeline_cls = pipeline_classes[config.generator_name]

        logging_dir = os.path.join(config.output_dir, config.logging_dir)
        accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)

        self.accelerator = accelerator = Accelerator(
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            mixed_precision=config.mixed_precision,
            log_with=config.report_to,
            project_config=accelerator_project_config,
        )
        deepspeed_plugin = accelerator.state.deepspeed_plugin
        if deepspeed_plugin is not None:
            print(f"Using DeepSpeed Zero stage: {int(deepspeed_plugin.zero_stage)}")
        else:
            print("DeepSpeed is not enabled.")

        # Only the main process owns a TensorBoard writer; other ranks get
        # ``None`` so the attribute always exists.
        self.writer = SummaryWriter(log_dir=logging_dir) if accelerator.is_main_process else None

        # Make one log on every process with the configuration for debugging.
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
            level=logging.INFO,
        )
        logging.info(accelerator.state)
        if accelerator.is_local_main_process:
            datasets.utils.logging.set_verbosity_warning()
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            datasets.utils.logging.set_verbosity_error()
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        # Create the output directory once, on the main process.
        if accelerator.is_main_process:
            if config.output_dir is not None:
                os.makedirs(config.output_dir, exist_ok=True)

        if accelerator.is_main_process:
            # Flatten the config into tracker-friendly values: primitives are
            # kept, one level of dict is expanded as "key.subkey", everything
            # else is stringified.
            # NOTE(review): ``vars(config)`` exposes the object's ``__dict__``;
            # for an OmegaConf config that may be internal fields rather than
            # the user keys -- confirm the logged names are what is wanted.
            log_config = {}
            for k, v in vars(config).items():
                if isinstance(v, (int, float, str, bool, torch.Tensor)):
                    log_config[k] = v
                elif isinstance(v, dict):
                    for kk, vv in v.items():
                        log_config[f"{k}.{kk}"] = str(vv)
                else:
                    log_config[k] = str(v)
            accelerator.init_trackers(config.tracker_project_name, config=log_config)

        self.dtype = torch.bfloat16 if config.mixed_precision == 'bf16' else torch.float32
        self.device = torch.cuda.current_device()

        # Step 2: Initialize the model and optimizer
        self.distillation_model = wan_RL(config, device=self.device, dtype=self.dtype)

        if config.use_ema:
            ema_generator = deepcopy(self.distillation_model.generator)
            self.ema_generator = EMAModel(ema_generator.parameters(), model_cls=ema_generator.__class__, model_config=ema_generator.config)
            self.ema_generator.to(self.device)

        # Only parameters that require gradients are optimized.
        self.generator_optimizer = torch.optim.AdamW(
            [param for param in self.distillation_model.generator.parameters()
             if param.requires_grad],
            lr=config.lr,
            betas=(config.beta1, config.beta2)
        )

        self.distillation_model.generator = accelerator.prepare(self.distillation_model.generator)
        self.generator_optimizer = accelerator.prepare(self.generator_optimizer)

        # Step 3: Initialize the reward function and the dataloader
        self.backward_simulation = getattr(config, "backward_simulation", True)

        self.reward_fn = getattr(flow_grpo.rewards, 'video_reward')(accelerator.device, config.reward_fn)
        # Rewards are computed on this thread pool while sampling continues.
        self.executor = futures.ThreadPoolExecutor(max_workers=8)

        dataset = ImageVideoDataset(
            data_root=config.train_data_dir,
            dataset_file=config.train_data_meta,
            caption_column=config.caption_column,
            video_column=config.video_column,
            resolution_buckets=config.video_resolution_buckets,
            frame_interval=config.video_sample_stride,
        )

        self.sampler = DistributedKRepeatSampler(
            dataset=dataset,
            batch_size=config.batch_size,
            k=config.num_video_per_prompt,
            num_replicas=accelerator.num_processes,
            rank=accelerator.process_index,
            seed=42
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=self.sampler,
            num_workers=config.dataloader_num_workers,
            pin_memory=True,
        )
        # ``train_one_step`` pulls batches with ``next`` on this iterator.
        self.dataloader = iter(dataloader)

        self.step = 0
        self.global_step = 0
        self.avg_loss_dict = {}
        self.max_grad_norm = 1.0
        self.previous_time = None

        # The validation pipeline shares the (unwrapped) generator weights with
        # the model being trained.
        self.pipeline = pipeline_cls(
            vae=self.distillation_model.vae,
            text_encoder=self.distillation_model.text_encoder,
            tokenizer=self.distillation_model.tokenizer,
            transformer=accelerator.unwrap_model(self.distillation_model.generator),
            scheduler=self.distillation_model.scheduler,
            clip_image_encoder=self.distillation_model.clip_image_encoder,
        )
        self.pipeline.to(device=self.device, dtype=self.dtype)


    def train_one_step(self):
        """Run one RL round: sample videos, score them, then update the generator.

        Sampling phase: for ``config.num_batches_per_epoch`` batches, encode the
        batch with the VAE, text encoder and CLIP image encoder, roll out videos
        with ``distillation_model.sampling`` and submit the reward computation
        to the thread pool. Rewards are then collected, gathered across
        processes and turned into advantages as
        ``(reward - mean) / (std + 1e-4)`` over all gathered rewards.

        Training phase: for each inner epoch, each sampled batch and each of
        ``config.train_num_steps`` timesteps, recompute the log-probability and
        apply the clipped-ratio policy loss, plus a KL term weighted by
        ``config.beta`` when ``beta > 0``.

        The training phase is skipped while ``self.global_step < 2``.
        """
        ################################Sampling##########################################
        self.distillation_model.eval()
        samples = []
        # NOTE(review): this list is never filled, so the reward function always
        # receives an empty prompt list -- confirm whether ``text_prompts`` was
        # intended instead.
        prompts = []
        cache_samples = []

        def _batch_encode_vae(pixel_values):
            """Encode ``pixel_values`` with the VAE one item at a time and sample latents."""
            pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
            bs = 1
            new_pixel_values = []
            for start in range(0, pixel_values.shape[0], bs):
                pixel_values_bs = pixel_values[start : start + bs]
                pixel_values_bs = self.distillation_model.vae.encode(pixel_values_bs)[0]
                new_pixel_values.append(pixel_values_bs.sample())
            return torch.cat(new_pixel_values, dim=0)

        for i in tqdm(
            range(self.config.num_batches_per_epoch),
            desc=f"Epoch {self.global_step}: sampling",
            disable=not self.accelerator.is_local_main_process,
            position=0,
        ):
            # Step 1: Get the next batch of text prompts
            self.sampler.set_epoch(self.global_step * self.config.num_batches_per_epoch + i)
            batch = next(self.dataloader)
            text_prompts = batch["text"]
            pixel_values = batch["pixel_values"].to(device=self.device, dtype=self.dtype)
            clip_pixel_values = batch["clip_pixel_values"].to(device=self.device, dtype=self.dtype)
            control_pixel_values = batch["control_pixel_values"].to(device=self.device, dtype=self.dtype)
            ref_pixel_values = batch["ref_pixel_values"].to(device=self.device, dtype=self.dtype)

            with torch.no_grad():
                latents = _batch_encode_vae(pixel_values)
                clean_latent = latents

                # Encode control and reference latents.
                control_latents = _batch_encode_vae(control_pixel_values)
                ref_latents = _batch_encode_vae(ref_pixel_values)

                if self.config.train_mode == "control_ref_repeat":
                    # Repeat the reference latent along the frame axis.
                    ref_latents_conv_in = ref_latents.repeat(1, 1, latents.size()[2], 1, 1)
                else:
                    # Put the reference latent in the first frame only.
                    ref_latents_conv_in = torch.zeros_like(latents).to(ref_latents.device, ref_latents.dtype)
                    ref_latents_conv_in[:, :, :1] = ref_latents

                # Make first frame control to zero
                control_latents[:, :, :1] = 0

                control_latents = torch.cat([control_latents, ref_latents_conv_in], dim=1)

                clip_context = []
                for clip_pixel_value in clip_pixel_values:
                    clip_image = Image.fromarray(np.uint8(clip_pixel_value.float().cpu().numpy()))
                    clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(self.device, self.dtype)
                    _clip_context = self.distillation_model.clip_image_encoder([clip_image[:, None, :, :]])

                    # With probability 0.1 the CLIP context is replaced by zeros.
                    zero_init_clip_in = np.random.choice([True, False], p=[0.1, 0.9])
                    clip_context.append(_clip_context if not zero_init_clip_in else torch.zeros_like(_clip_context))

                clip_context = torch.cat(clip_context)

            batch_size = len(text_prompts)

            clean_latent = rearrange(clean_latent, "b c f h w -> b f c h w")
            control_latents = rearrange(control_latents, "b c f h w -> b f c h w")
            image_or_video_shape = clean_latent.shape

            # Step 2: Extract the conditional infos
            with torch.no_grad():
                conditional_dict = self.distillation_model.encode_text(text_prompts)

                if not getattr(self, "unconditional_dict", None):
                    unconditional_dict = self.distillation_model.encode_text([self.config.negative_prompt] * batch_size)
                    unconditional_dict = {k: v.detach()
                                        for k, v in unconditional_dict.items()}
                    self.unconditional_dict = unconditional_dict  # cache the unconditional_dict
                else:
                    unconditional_dict = self.unconditional_dict

            # Step 3: Sampling
            with self.accelerator.autocast():
                with torch.no_grad():
                    images, sample, cache_sample = self.distillation_model.sampling(
                        image_or_video_shape=image_or_video_shape,
                        conditional_dict=conditional_dict,
                        unconditional_dict=unconditional_dict,
                        clean_latent=clean_latent,
                        inpaint_latents=control_latents,
                        clip_context=clip_context,
                    )

            # Every 50 global steps, dump the sampled videos for inspection.
            # Each file gets a unique name per rank, batch and video index.
            if self.global_step % 50 == 0:
                video_dir = os.path.join(self.config.output_dir, "samples", f"step-{self.global_step}")
                os.makedirs(video_dir, exist_ok=True)
                for k in range(images.shape[0]):
                    video = images[k].permute(0, 2, 3, 1).float().cpu().numpy()
                    save_path = os.path.join(
                        video_dir,
                        f"rank{self.accelerator.process_index}_batch{i}_video{k}.mp4",
                    )
                    export_to_video(video, save_path, fps=8)

            # compute rewards asynchronously
            rewards = self.executor.submit(
                self.reward_fn,
                images,
                prompts,
                traj=control_pixel_values.repeat(self.config.mini_num_image_per_prompt, 1, 1, 1, 1),
                only_strict=True,
                random_block=cache_sample["random_block"],
            )

            # Yield so the reward worker thread gets a chance to start.
            time.sleep(0)
            sample.update({"rewards": rewards})
            samples.append(sample)
            cache_samples.append(cache_sample)

        # wait for all rewards to be computed
        for sample in tqdm(
            samples,
            desc="Waiting for rewards",
            disable=not self.accelerator.is_local_main_process,
            position=0,
        ):
            rewards, reward_metadata = sample["rewards"].result()
            sample["rewards"] = {
                key: torch.as_tensor(value, device=self.accelerator.device).float()
                for key, value in rewards.items()
            }

        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {
            k: (
                torch.cat([s[k] for s in samples], dim=0)
                if not isinstance(samples[0][k], dict)
                else {
                    sub_key: torch.cat([s[k][sub_key] for s in samples], dim=0)
                    for sub_key in samples[0][k]
                }
            )
            for k in samples[0].keys()
        }

        samples["rewards"]["ori_avg"] = samples["rewards"]["avg"]
        # The purpose of repeating `adv` along the timestep dimension here is to make it easier to introduce timestep-dependent advantages later, such as adding a KL reward.
        samples["rewards"]["avg"] = samples["rewards"]["avg"].unsqueeze(1).repeat(1, self.config.train_num_steps)
        # gather rewards across processes
        gathered_rewards = {key: self.accelerator.gather(value) for key, value in samples["rewards"].items()}
        gathered_rewards = {key: value.cpu().numpy() for key, value in gathered_rewards.items()}
        # log rewards
        if self.accelerator.is_main_process:
            self.accelerator.log(
                {
                    "epoch": self.global_step,
                    **{f"reward_{key}": value.mean() for key, value in gathered_rewards.items() if '_strict_accuracy' not in key and '_accuracy' not in key},
                },
                step=self.global_step,
            )

        print("###########gathered_rewards",gathered_rewards['avg'])
        # Normalise with the mean/std over all gathered rewards.
        advantages = (gathered_rewards['avg'] - gathered_rewards['avg'].mean()) / (gathered_rewards['avg'].std() + 1e-4)

        # ungather advantages; we only need to keep the entries corresponding to the samples on this process
        advantages = torch.as_tensor(advantages)
        samples["advantages"] = (
            advantages.reshape(self.accelerator.num_processes, -1, advantages.shape[-1])[self.accelerator.process_index]
            .to(self.accelerator.device)
        )
        if self.accelerator.is_local_main_process:
            print("advantages: ", samples["advantages"].abs().mean())

        del samples["rewards"]

        total_batch_size, _ = samples["timesteps"].shape
        # The first two global steps only sample and log rewards.
        if self.global_step < 2:
            return
        ###########################################Training##########################################
        for inner_epoch in range(self.config.num_inner_epochs):
            # rebatch for training
            samples_batched = {
                k: v.reshape(-1, total_batch_size//self.config.num_batches_per_epoch, *v.shape[1:])
                for k, v in samples.items()
            }

            # dict of lists -> list of dicts for easier iteration
            samples_batched = [
                dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
            ]

            # train
            self.distillation_model.generator.train()
            info = defaultdict(list)
            for i, sample in tqdm(
                list(enumerate(samples_batched)),
                desc=f"Epoch {self.global_step}.{inner_epoch}: training",
                position=0,
                disable=not self.accelerator.is_local_main_process,
            ):
                for j in tqdm(
                    range(self.config.train_num_steps),
                    desc="Timestep",
                    position=1,
                    leave=False,
                    disable=not self.accelerator.is_local_main_process,
                ):
                    with self.accelerator.accumulate(self.distillation_model.generator):
                        with self.accelerator.autocast():
                            _, log_prob, prev_sample_mean, std_dev_t, _ = self.distillation_model.inference_pipeline.compute_log_prob(sample, i, j, cache_samples[i])
                            if self.config.beta > 0:
                                # Reference prediction for the KL term; no gradients needed.
                                with torch.no_grad():
                                    _, _, prev_sample_mean_ref, _, dt_ref = self.distillation_model.inference_pipeline.compute_log_prob(sample, i, j, cache_samples[i], is_ref=True)

                        # grpo logic
                        advantages = torch.clamp(
                            sample["advantages"][:, j],
                            -self.config.adv_clip_max,
                            self.config.adv_clip_max,
                        )
                        ratio = torch.exp(log_prob - sample["log_probs"][:, j])
                        unclipped_loss = -advantages * ratio
                        clipped_loss = -advantages * torch.clamp(
                            ratio,
                            1.0 - self.config.clip_range,
                            1.0 + self.config.clip_range,
                        )
                        policy_loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss))
                        print("############policy_loss",policy_loss)
                        if self.config.beta > 0:
                            kl_loss = ((prev_sample_mean - prev_sample_mean_ref) ** 2).mean(dim=(1,2,3), keepdim=True) / (2 * std_dev_t * dt_ref ** 2)
                            kl_loss = torch.mean(kl_loss)
                            loss = policy_loss + self.config.beta * kl_loss
                            # ``kl_loss`` only exists in this branch, so print it here.
                            print("############kl_loss",kl_loss)
                        else:
                            loss = policy_loss
                        info["approx_kl"].append(
                            0.5
                            * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2)
                        )
                        info["clipfrac"].append(
                            torch.mean(
                                (
                                    torch.abs(ratio - 1.0) > self.config.clip_range
                                ).float()
                            )
                        )
                        info["clipfrac_gt_one"].append(
                            torch.mean(
                                (
                                    ratio - 1.0 > self.config.clip_range
                                ).float()
                            )
                        )
                        info["clipfrac_lt_one"].append(
                            torch.mean(
                                (
                                    1.0 - ratio > self.config.clip_range
                                ).float()
                            )
                        )
                        info["policy_loss"].append(policy_loss)
                        if self.config.beta > 0:
                            info["kl_loss"].append(kl_loss)

                        info["loss"].append(loss)

                        # backward pass
                        self.accelerator.backward(loss)
                        if self.accelerator.sync_gradients:
                            self.accelerator.clip_grad_norm_(
                                self.distillation_model.generator.parameters(), self.config.max_grad_norm
                            )
                        self.generator_optimizer.step()
                        self.generator_optimizer.zero_grad()
                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if self.accelerator.sync_gradients:
                        # log training-related stuff
                        info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                        info = self.accelerator.reduce(info, reduction="mean")
                        info.update({"epoch": self.global_step, "inner_epoch": inner_epoch})
                        self.accelerator.log(info, step=self.global_step)
                        # Write to TensorBoard before resetting ``info``;
                        # otherwise an empty dict would be logged.
                        if self.accelerator.is_main_process:
                            self.writer.add_scalars("loss", info, self.global_step)
                        info = defaultdict(list)
                if self.config.use_ema:
                    self.ema_generator.step(self.distillation_model.generator.parameters())

        

    def add_visualization(self, generator_log_dict, critic_log_dict, input_dict):
        visual_dict = {}
        visual_dict.update({
            "original_pixel_values": input_dict["original_pixel_values"],
            "input_clean_latent": self.pipeline.decode_latents(input_dict["clean_latent"].permute(0,2,1,3,4))[0].permute(0,2,1,3,4),
            "input_control_latents": self.pipeline.decode_latents(input_dict["control_latents"].permute(0,2,1,3,4))[0].permute(0,2,1,3,4),
        })

        if critic_log_dict:
            critictrain_latent, critictrain_noisy_latent, critictrain_pred_image = map(
                lambda x: self.pipeline.decode_latents(x.permute(0,2,1,3,4))[0].permute(0,2,1,3,4),
                [critic_log_dict['critictrain_latent'], critic_log_dict['critictrain_noisy_latent'],
                    critic_log_dict['critictrain_pred_image']]
            )
            visual_dict.update({
                "critictrain_latent": critictrain_latent,
                "critictrain_noisy_latent": critictrain_noisy_latent,
                "critictrain_pred_image": critictrain_pred_image
            })

        if "dmdtrain_clean_latent" in generator_log_dict:
            (dmdtrain_generator_noisy_input, dmdtrain_clean_latent, dmdtrain_noisy_latent, dmdtrain_pred_real_image, dmdtrain_pred_fake_image) = map(
                lambda x: self.pipeline.decode_latents(x.permute(0,2,1,3,4))[0].permute(0,2,1,3,4) if x is not None else None,
                [generator_log_dict.get('generator_noisy_input', None), generator_log_dict['dmdtrain_clean_latent'], generator_log_dict['dmdtrain_noisy_latent'],
                    generator_log_dict['dmdtrain_pred_real_image'], generator_log_dict['dmdtrain_pred_fake_image']]
            )

            visual_dict.update(
                {
                    "dmdtrain_generator_noisy_input": dmdtrain_generator_noisy_input,
                    "dmdtrain_clean_latent": dmdtrain_clean_latent,
                    "dmdtrain_noisy_latent": dmdtrain_noisy_latent,
                    "dmdtrain_pred_real_image": dmdtrain_pred_real_image,
                    "dmdtrain_pred_fake_image": dmdtrain_pred_fake_image
                }
            )
        # wandb_loss_dict.update(visual_dict)
        for k, v in visual_dict.items():
            if v is None:
                continue
            for i in range(v.shape[0]):
                video = v[i].permute(0, 2, 3, 1).float().cpu().numpy()
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                os.makedirs(os.path.dirname(final_save_path), exist_ok=True)
                export_to_video(video, save_path, fps=8)
                # import shutil
                # shutil.copy(save_path, final_save_path)

    def train(self):
        """Run the training loop until ``config.max_train_steps`` global steps.

        Validates once before training. After each ``train_one_step`` call a
        checkpoint is saved when ``global_step + 1`` is a multiple of
        ``config.checkpoint_iters`` (unless ``config.no_save``), and validation
        runs when it is a multiple of ``config.validation_iters``.
        ``global_step`` advances once every
        ``config.gradient_accumulation_steps`` calls.

        When the final global step is not a multiple of ``checkpoint_iters``,
        a last checkpoint is saved (again honouring ``config.no_save``) and a
        last validation is run.
        """
        self.validate()

        progress_bar = tqdm(
            range(0, self.config.max_train_steps),
            initial=0,
            desc="Steps",
            # Only show the progress bar once on each machine.
            disable=not self.accelerator.is_local_main_process,
        )
        while True:
            self.train_one_step()
            if (not self.config.no_save) and (self.global_step + 1) % self.config.checkpoint_iters == 0:
                self.save()
                torch.cuda.empty_cache()

            if (self.global_step + 1) % self.config.validation_iters == 0:
                self.validate()

            self.step += 1
            if self.step % self.config.gradient_accumulation_steps == 0:
                self.global_step += 1
                progress_bar.update(1)

            if self.global_step >= self.config.max_train_steps:
                break
        progress_bar.close()

        if self.global_step % self.config.checkpoint_iters != 0:
            # The in-loop save is gated on ``no_save``; the final one must be
            # too, otherwise ``--no_save`` would still write a checkpoint.
            if not self.config.no_save:
                self.save()
            self.validate()
    
    def save(self):
        if self.accelerator.is_main_process:
            if self.config.checkpoints_total_limit is not None:
                checkpoints = os.listdir(self.config.output_dir)
                checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                if len(checkpoints) >= self.config.checkpoints_total_limit:
                    num_to_remove = len(checkpoints) - self.config.checkpoints_total_limit + 1
                    removing_checkpoints = checkpoints[0:num_to_remove]

                    logging.info(
                        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                    )
                    logging.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                    for removing_checkpoint in removing_checkpoints:
                        removing_checkpoint = os.path.join(self.config.output_dir, removing_checkpoint)
                        shutil.rmtree(removing_checkpoint, ignore_errors=True)

            output_dir = os.path.join(self.config.output_dir, f"checkpoint-{self.global_step+1}")
            os.makedirs(output_dir, exist_ok=True)
            generator = self.accelerator.unwrap_model(self.distillation_model.generator)
            generator.save_pretrained(os.path.join(output_dir, "generator"))
            print(f"Save generator model to {os.path.join(output_dir, 'generator')}")
            if self.config.use_ema:
                self.ema_generator.save_pretrained(os.path.join(output_dir, "ema_generator"))
                print(f"Save ema generator model to {os.path.join(output_dir, 'ema_generator')}")
        self.accelerator.wait_for_everyone()

    @torch.no_grad()
    def validate(self):
        """Generate and score the validation videos on the main process.

        For each entry of the JSON list at ``config.validation_data_meta`` the
        pipeline generates a video, the reward function scores it, the mean
        ``avg`` reward is written to TensorBoard under
        ``validation/reward_avg_<index>`` and the video is exported to
        ``<output_dir>/validation/step-<global_step + 1>/<index>.mp4``.

        When ``config.use_ema`` is set, the EMA weights are copied into the
        generator for the duration of validation and the training weights are
        restored afterwards, even if validation raises. All processes meet at
        a barrier at the end.
        """
        if self.accelerator.is_main_process:
            use_ema = self.config.use_ema
            if use_ema:
                self.ema_generator.store(self.distillation_model.generator.parameters())
                self.ema_generator.copy_to(self.distillation_model.generator.parameters())

            try:
                print(f"Start validation at step {self.global_step+1}...")
                with open(self.config.validation_data_meta, encoding="utf-8") as meta_file:
                    validation_samples = json.load(meta_file)

                if self.config.seed is None:
                    generator = None
                else:
                    generator = torch.Generator(device=self.device).manual_seed(self.config.seed)

                if isinstance(self.config.video_resolution_buckets[0], str):
                    # NOTE(review): ``eval`` executes arbitrary code from the
                    # config file; only use configs from trusted sources (or
                    # consider ``ast.literal_eval``).
                    self.config.video_resolution_buckets = [eval(bucket) for bucket in self.config.video_resolution_buckets]

                video_dir = os.path.join(self.config.output_dir, "validation", f"step-{self.global_step + 1}")
                os.makedirs(video_dir, exist_ok=True)

                for index, sample in enumerate(validation_samples):
                    prompt = sample["prompt"]
                    control_video = sample["control_video"]
                    video_length = sample["length"]
                    fps = sample.get("fps", 8)
                    ref_image = sample["image_start"]
                    with Image.open(ref_image) as ref_image_file:
                        width, height = ref_image_file.size
                    # Pick the bucket with the closest aspect ratio, then the
                    # closest absolute size.
                    sorted_buckets = sorted(self.config.video_resolution_buckets, key=lambda x: (abs(x[1]/x[2]-height/width), abs(x[1]-height)+abs(x[2]-width)))
                    sample_size = sorted_buckets[0][1:3]

                    video_length = int(video_length // self.distillation_model.vae.config.temporal_compression_ratio * self.distillation_model.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
                    input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=ref_image)
                    video = self.pipeline(
                        prompt,
                        num_frames = video_length,
                        negative_prompt = "色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走",
                        height      = sample_size[0],
                        width       = sample_size[1],
                        guidance_scale = self.config.guidance_scale,
                        generator   = generator,

                        control_video = input_video,
                        ref_image = ref_image,
                        clip_image = clip_image,
                        mode = self.config.train_mode,
                        timesteps = self.config.denoising_step_list if self.config.generator_name == "wan" else None,
                        context_perturbation = self.config.context_perturbation,
                    ).videos[0]
                    rewards = self.executor.submit(self.reward_fn, video.permute(1, 0, 2, 3).unsqueeze(0), prompt, traj=input_video.permute(0, 2, 1, 3, 4) *2-1, only_strict=True, random_block=-1)

                    reward_avg = torch.tensor([score for score in rewards.result()[0]['avg']], device=self.device).float().mean().item()
                    self.writer.add_scalar(f"validation/reward_avg_{index:03d}", reward_avg, self.global_step+1)
                    video = video.permute(1, 2, 3, 0).cpu().numpy()  # (c f h w) -> (f h w c)
                    # The original code used an undefined ``save_path``.
                    save_path = os.path.join(video_dir, f"{index:03d}.mp4")
                    export_to_video(video, save_path, fps=8)
            finally:
                # Always put the training weights back, even if validation failed.
                if use_ema:
                    self.ema_generator.restore(self.distillation_model.generator.parameters())

            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        self.accelerator.wait_for_everyone()


def main():
    """Parse the command line, load the config file and run training.

    ``--config_path`` (required) is loaded with ``OmegaConf.load``; the
    ``--no_save`` and ``--no_visualize`` switches are copied onto the loaded
    config before it is handed to ``Trainer``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config_path", type=str, required=True)
    for switch in ("--no_save", "--no_visualize"):
        cli.add_argument(switch, action="store_true")
    options = cli.parse_args()

    config = OmegaConf.load(options.config_path)
    config.no_save, config.no_visualize = options.no_save, options.no_visualize

    Trainer(config).train()

if __name__ == "__main__":
    main()
