import argparse
import math
import os
from pathlib import Path
from fastvideo.utils.parallel_states import (
    initialize_sequence_parallel_state,
    destroy_sequence_parallel_group,
    get_sequence_parallel_state,
    nccl_info,
)
from torchvision.io import write_video
from einops import rearrange
import numpy as np
from fastvideo.utils.communications import sp_parallel_dataloader_wrapper, broadcast
import time
from torch.utils.data import DataLoader
import torch
from omegaconf import OmegaConf
from torch.distributed.fsdp import (
    StateDictType,
    FullStateDictConfig,
)
from torch.utils.data.distributed import DistributedSampler
from fastvideo.utils.dataset_utils import LengthGroupedSampler
import wandb
from accelerate.utils import set_seed
from tqdm.auto import tqdm
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from fastvideo.dataset.mc_rl_datasets import MCDatasetValStd
import torch.distributed as dist
from fastvideo.utils.checkpoint import save_checkpoint
from fastvideo.utils.logging_ import main_print
import cv2
from diffusers.video_processor import VideoProcessor
from collections import deque
from torch.nn import functional as F
import importlib
import pickle
from fastvideo.models.WM.reward_model.IDM.inverse_dynamics_model import IDMAgent
from typing import List

# Raises an error if the installed diffusers is older than the minimum required version.
# Remove at your own risk.
check_min_version("0.31.0")

def compare_actions(json_actions, predicted_actions, num_frames=8):
    """Compare predicted actions with ground-truth actions, frame by frame.

    Args:
        json_actions: Sequence of per-frame mappings from action name to the
            ground-truth value. An action missing from a frame counts as 0.
        predicted_actions: Mapping from action name to an indexable whose
            first row (``[0]``) holds the per-frame predictions.
        num_frames: Number of leading frames to score. Defaults to 8, the
            value that was previously hard-coded.

    Returns:
        Dict mapping every tracked action name to an array of shape
        ``(1, num_frames)`` holding 1 where the prediction equals the ground
        truth and 0 otherwise. Actions absent from ``predicted_actions`` get
        an all-zero row.

    Raises:
        IndexError: If a predicted action is scored and either its prediction
            row or ``json_actions`` has fewer than ``num_frames`` entries.
    """
    # A tuple (rather than a set) keeps the key order of the result stable
    # from run to run.
    tracked_actions = (
        'back', 'drop', 'forward', 'jump', 'left', 'right',
        'sneak', 'sprint', 'attack', 'use',
    )

    comparison_results = {}
    for action_name in tracked_actions:
        if action_name not in predicted_actions:
            # No prediction for this action: every frame counts as incorrect.
            comparison_results[action_name] = np.array([[0] * num_frames])
            continue

        # Ground-truth sequence for this action, one value per frame.
        true_actions = [frame.get(action_name, 0) for frame in json_actions]
        # Predictions are stored with a leading singleton row; take that row.
        pred_actions = predicted_actions[action_name][0]

        frame_comparison = [
            1 if true_actions[i] == pred_actions[i] else 0
            for i in range(num_frames)
        ]
        comparison_results[action_name] = np.array([frame_comparison])

    return comparison_results

def calculate_frame_accuracy(comparison_results):
    """Collapse per-action correctness rows into one scalar reward.

    The first row of every entry in ``comparison_results`` is stacked into a
    ``(num_actions, num_frames)`` matrix. A frame keeps its score only when
    its mean across actions is not below 1; every other frame is zeroed. The
    reward is the mean of the resulting per-frame scores.
    """
    rows = [result[0] for result in comparison_results.values()]
    stacked = np.vstack(rows)  # (num_actions, num_frames)
    per_frame = stacked.mean(axis=0)  # (num_frames,)
    # Frames with any mismatching action (mean < 1) contribute nothing.
    strict_per_frame = np.where(per_frame < 1, 0, per_frame)
    return np.mean(strict_per_frame)


def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object.

    The text after the last dot is the attribute name; everything before it
    is the module to import. Import caches are invalidated first unless
    ``invalidate_cache`` is false, and the module is reloaded before lookup
    when ``reload`` is true.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    module_obj = importlib.import_module(module_name)
    if reload:
        importlib.reload(module_obj)
        # Fetch the module again after reloading, then look the name up on it.
        module_obj = importlib.import_module(module_name, package=None)
    return getattr(module_obj, attr_name)


def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted import path; the optional
    ``config["params"]`` mapping supplies the keyword arguments for the call.
    The sentinel values ``"__is_first_stage__"`` and
    ``"__is_unconditional__"`` yield ``None``. Any other config without a
    ``target`` key raises ``KeyError``.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)
    if config == "__is_first_stage__" or config == "__is_unconditional__":
        return None
    raise KeyError("[sgm.util][instantiate_from_config] Expected key `target` to instantiate.")

def load_model_from_config(config, ckpt):
    """Instantiate the model described by an OmegaConf file.

    ``config`` is handed to ``OmegaConf.load``. Before instantiation the
    ``model`` node gets ``params.ckpt_path`` set to ``ckpt`` and
    ``strict_loading`` set to ``False``.
    """
    loaded = OmegaConf.load(config)
    model_cfg = loaded.model
    model_cfg.params.ckpt_path = ckpt
    model_cfg.strict_loading = False
    return instantiate_from_config(model_cfg)

def sd3_time_shift(shift, t):
    """Warp ``t`` with the SD3 shift ``shift * t / (1 + (shift - 1) * t)``."""
    numerator = shift * t
    denominator = 1 + (shift - 1) * t
    return numerator / denominator

def flux_step(
    model_output: torch.Tensor,
    latents: torch.Tensor,
    eta: float,
    sigmas: torch.Tensor,
    index: int,
    prev_sample: torch.Tensor,
    grpo: bool,
    sde_solver: bool,
):
    """Advance ``latents`` one sampler step, from ``sigmas[index]`` to ``sigmas[index + 1]``.

    Slot 0 along dim 1 of ``latents`` is copied unchanged into the step mean,
    the clean-sample estimate, and any freshly drawn sample.

    Args:
        model_output: Model prediction for the current latents.
        latents: Current latents.
        eta: Noise scale; the per-step std is ``eta * sqrt(sigma - sigma_next)``.
        sigmas: Sigma schedule indexed by ``index`` and ``index + 1``.
        index: Position of the current step in ``sigmas``.
        prev_sample: Sample to score. When ``None`` (and ``grpo`` is true) a
            new sample is drawn around the step mean.
        grpo: When true, also return the per-sample log-probability.
        sde_solver: When true, add the score-based correction to the mean.

    Returns:
        ``(prev_sample, pred_original_sample, log_prob)`` when ``grpo`` is
        true, otherwise ``(prev_sample_mean, pred_original_sample)``.
    """
    sigma_now = sigmas[index]
    sigma_next = sigmas[index + 1]
    dsigma = sigma_next - sigma_now
    first_slot = latents[:, 0]

    # Deterministic update and the implied clean-sample estimate.
    step_mean = latents + dsigma * model_output
    x0_estimate = latents - sigma_now * model_output
    step_mean[:, 0] = first_slot
    x0_estimate[:, 0] = first_slot

    noise_std = eta * math.sqrt(sigma_now - sigma_next)

    if sde_solver:
        score = -(latents - x0_estimate * (1 - sigma_now)) / sigma_now**2
        correction = -0.5 * eta**2 * score
        step_mean = step_mean + correction * dsigma
        # The correction touches every slot, so restore slot 0 afterwards.
        step_mean[:, 0] = first_slot

    if not grpo:
        return step_mean, x0_estimate

    if prev_sample is None:
        prev_sample = step_mean + torch.randn_like(step_mean) * noise_std
        # Slot 0 must not receive noise.
        prev_sample[:, 0] = first_slot

    # Gaussian log-density of prev_sample under N(step_mean, noise_std^2).
    residual = prev_sample.detach().to(torch.float32) - step_mean.to(torch.float32)
    log_prob = (
        -(residual ** 2) / (2 * (noise_std**2))
    ) - math.log(noise_std) - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))

    # Average over every dimension except the batch dimension.
    reduce_dims = tuple(range(1, log_prob.ndim))
    log_prob = log_prob.mean(dim=reduce_dims)
    return prev_sample, x0_estimate, log_prob




def assert_eq(x, y, msg=None):
    """Raise ``AssertionError`` unless ``x == y``.

    The exception is raised explicitly rather than through an ``assert``
    statement, so the check still runs when Python is started with ``-O``.

    Args:
        x: Left-hand value.
        y: Right-hand value.
        msg: Optional prefix for the error message; defaults to
            ``"Assertion failed"``.

    Raises:
        AssertionError: If the two values do not compare equal.
    """
    if not (x == y):
        raise AssertionError(f"{msg or 'Assertion failed'}: {x} != {y}")


def run_sample_step(
        args,
        z,
        action,
        progress_bar,
        sigma_schedule,
        transformer,
        grpo_sample,
    ):
    """Roll the sampler over every step in ``progress_bar``, recording the trajectory.

    Args:
        args: Namespace providing ``eta`` (noise scale passed to ``flux_step``).
        z: Starting latents; dim 1 is the per-position axis whose slot 0 is
            given timestep 0 (``noise_ctx``) at every step.
        action: Conditioning passed straight to ``transformer``.
        progress_bar: Iterable of step indices into ``sigma_schedule``.
        sigma_schedule: Sigma values; step ``i`` reads entries ``i`` and ``i + 1``.
        transformer: Callable ``transformer(z, timesteps, action)`` with an
            ``eval()`` method.
        grpo_sample: Only the GRPO path is implemented. When false the
            function returns ``None``.

    Returns:
        ``(z, latents, all_latents, all_log_probs)`` where ``z`` is the final
        sample, ``latents`` is the last clean-sample estimate cast to float32,
        ``all_latents`` stacks the trajectory along dim 1 (``num_steps + 1``
        entries including the starting latents) and ``all_log_probs`` stacks
        the per-step log-probabilities along dim 1. Returns ``None`` when
        ``grpo_sample`` is false.
    """
    if not grpo_sample:
        # No non-GRPO sampling path exists; keep the historical None result.
        return None

    noise_ctx = 0
    transformer.eval()
    all_latents = [z]
    all_log_probs = []
    for i in progress_bar:
        timestep_value = int(sigma_schedule[i] * 1000)
        timesteps = torch.full([z.shape[1]], timestep_value, device=z.device, dtype=torch.long)
        # Slot 0 is the conditioning context and is always presented at timestep 0.
        timesteps[0] = noise_ctx
        timesteps = timesteps.unsqueeze(0)

        with torch.no_grad(), torch.autocast("cuda", torch.bfloat16):
            model_pred = transformer(z, timesteps, action)

        # The sample returned by flux_step is stored as-is. The original code
        # had a bare `z.to(torch.bfloat16)` here whose result was discarded,
        # so no cast was ever applied; that no-op has been removed.
        z, pred_original, log_prob = flux_step(
            model_pred,
            z.to(torch.bfloat16),
            args.eta,
            sigmas=sigma_schedule,
            index=i,
            prev_sample=None,
            grpo=True,
            sde_solver=True,
        )
        all_latents.append(z)
        all_log_probs.append(log_prob)

    latents = pred_original.to(torch.float32)
    all_latents = torch.stack(all_latents, dim=1)  # (batch_size, num_steps + 1, ...)
    all_log_probs = torch.stack(all_log_probs, dim=1)  # (batch_size, num_steps)
    return z, latents, all_latents, all_log_probs

        
def grpo_one_step(
            args,
            latents,
            pre_latents,
            video,
            action,
            transformer,
            timesteps,
            i,
            sigma_schedule,
):
    """Recompute the log-probability of one stored transition under the current policy.

    Args:
        args: Namespace providing ``eta``.
        latents: Latents before the step; dim 1 is the per-position axis.
        pre_latents: The sample that was actually drawn for this step. It is
            scored, not re-sampled.
        video: Unused; kept so existing callers do not need to change.
        action: Conditioning passed straight to ``transformer``.
        transformer: Policy model; switched to train mode before the forward pass.
        timesteps: Timestep value for this step, repeated across dim 1 of
            ``latents`` with slot 0 forced to 0.
        i: Index of this step in ``sigma_schedule``.
        sigma_schedule: Sigma values used by ``flux_step``.

    Returns:
        The log-probability tensor returned by ``flux_step``.
    """
    noise_ctx = 0
    timesteps = timesteps.repeat(latents.shape[1])
    # Slot 0 is the conditioning context and is always presented at timestep 0.
    timesteps[0] = noise_ctx
    timesteps = timesteps.unsqueeze(0)

    transformer.train()
    with torch.autocast("cuda", torch.bfloat16):
        model_pred = transformer(latents, timesteps, action)

    # Only the log-probability is needed; the sample and x0 estimate are dropped.
    _, _, log_prob = flux_step(
        model_pred,
        latents.to(torch.float32),
        args.eta,
        sigma_schedule,
        i,
        prev_sample=pre_latents.to(torch.float32),
        grpo=True,
        sde_solver=True,
    )
    return log_prob

def sample_reference_model(
    args,
    step,
    device,
    transformer,
    video,
    action,
    inferencer,
    action_json,
    gt_video,
):
    """Sample one rollout per row of ``video``, decode it, save it, and score it.

    Args:
        args: Namespace providing ``sampling_steps``, ``use_same_noise``,
            ``use_idm``, ``eta``, ``fps``, ``output_dir`` and ``video_output_dir``.
        step: Unused.
        device: Unused.
        transformer: Wrapper exposing ``model`` (the denoiser) and
            ``tokenizer_decode``.
        video: Conditioning latents, one row per rollout.
        action: Conditioning actions, indexed like ``video``.
        inferencer: Reward model exposing ``predict_actions``.
        action_json: Ground-truth actions handed to ``compare_actions``.
            NOTE(review): the same ``action_json`` is used for every rollout;
            confirm this is right when a batch holds more than one clip.
        gt_video: Unused.

    Returns:
        ``(all_rewards, all_latents, all_log_probs, sigma_schedule)`` with the
        first three concatenated over rollouts along dim 0.

    Raises:
        RuntimeError: If ``args.use_idm`` is false. No other reward source
            exists, so the reward list would be empty. Previously this only
            surfaced as an opaque ``torch.cat`` failure after all sampling
            had finished.
    """
    if not args.use_idm:
        raise RuntimeError(
            "sample_reference_model needs a reward source: args.use_idm is disabled, "
            "so no rewards would be produced."
        )

    sample_steps = args.sampling_steps
    shift = 3.0
    sigma_schedule = torch.linspace(1, 0.001, sample_steps + 1)
    sigma_schedule = 1.0 - sigma_schedule
    # SD3-style time shift, then flip so the schedule runs from high to low sigma.
    sigma_schedule = (shift * sigma_schedule / (1 + (shift - 1) * sigma_schedule)).flip(dims=[0])

    assert_eq(
        len(sigma_schedule),
        sample_steps + 1,
        "sigma_schedule must have length sample_steps + 1",
    )

    B = video.shape[0]
    noise_abs_max = 20
    batch_size = 1
    num_noise_frames = 7  # positions appended after the conditioning input
    batch_indices = torch.chunk(torch.arange(B), B // batch_size)

    def draw_noise():
        """Draw clamped noise for the positions that will be generated."""
        noise = torch.randn(
            (batch_size, num_noise_frames, *video.shape[-3:]),
            dtype=torch.bfloat16,
            device=video.device,
        )
        return torch.clamp(noise, -noise_abs_max, +noise_abs_max)

    # Output location does not depend on the rollout, so resolve it once.
    rank = int(os.environ["RANK"])
    # Use the last part of output_dir as video directory name
    project_name = os.path.basename(args.output_dir) if args.output_dir else "grpo_model"
    video_dir = f"{args.video_output_dir}/{project_name}"
    os.makedirs(video_dir, exist_ok=True)

    all_latents = []
    all_log_probs = []
    all_rewards = []

    # With use_same_noise every rollout starts from the same noise; otherwise
    # fresh noise is drawn per rollout from the global RNG. (The original code
    # also built a torch.Generator seeded with 42 but never passed it to
    # randn; that dead object has been removed.)
    shared_chunk = draw_noise() if args.use_same_noise else None

    for index, batch_idx in enumerate(batch_indices):
        batch_video = video[batch_idx]
        batch_action = action[batch_idx]

        chunk = shared_chunk if args.use_same_noise else draw_noise()
        batch_video = torch.cat([batch_video, chunk], dim=1)

        progress_bar = tqdm(range(0, sample_steps), desc="Sampling Progress")
        with torch.no_grad():
            _, latents, batch_latents, batch_log_probs = run_sample_step(
                args,
                batch_video,
                batch_action,
                progress_bar,
                sigma_schedule,
                transformer.model,
                True,
            )
        # Accumulate latents and log_probs from all batches
        all_latents.append(batch_latents)
        all_log_probs.append(batch_log_probs)

        with torch.inference_mode():
            video_output = transformer.tokenizer_decode(latents)
            video_output = rearrange(video_output, "b t c h w -> (b t) c h w")
        # Map from [-1, 1] to [0, 1], then to channel-last uint8 frames.
        video_output = (video_output.clamp(-1, 1) + 1) / 2
        samples_to_save = rearrange(video_output, "t c h w -> t h w c")
        idm_in = (samples_to_save.cpu().numpy() * 255).astype(np.uint8)

        write_video(f"{video_dir}/sample_{rank}_{index}.mp4", idm_in, fps=args.fps)

        try:
            with torch.no_grad():
                predicted_actions = inferencer.predict_actions(idm_in)
                reward = compare_actions(action_json, predicted_actions)
                reward = calculate_frame_accuracy(reward)
                reward = torch.tensor(reward).to(video.device)
        except Exception as e:
            # Best effort: a failed reward computation scores the rollout as 0
            # instead of aborting the whole training step.
            print("Error in IDM prediction:")
            print(e)
            reward = torch.tensor(0.0).to(video.device)
        all_rewards.append(reward.unsqueeze(0))

    all_latents = torch.cat(all_latents, dim=0)
    all_log_probs = torch.cat(all_log_probs, dim=0)
    all_rewards = torch.cat(all_rewards, dim=0)

    return all_rewards, all_latents, all_log_probs, sigma_schedule

def gather_tensor(tensor):
    """Concatenate ``tensor`` from every rank along dim 0.

    When no process group is initialised the input is returned unchanged.
    """
    if dist.is_initialized():
        buffers = []
        for _ in range(dist.get_world_size()):
            buffers.append(torch.zeros_like(tensor))
        dist.all_gather(buffers, tensor)
        return torch.cat(buffers, dim=0)
    return tensor

def train_one_step(
    args,
    device,
    transformer,
    inferencer,
    optimizer,
    lr_scheduler,
    loader,
    max_grad_norm,
    step,
):
    """Run one GRPO iteration: sample rollouts, score them, and update the policy.

    Flow, as implemented below:
      1. Pull one batch from ``loader`` and keep only the first frame of the video.
      2. If ``args.use_group`` is set, encode that frame with
         ``transformer.tokenizer_encode`` and repeat video/action
         ``args.num_generations`` times along dim 0.
      3. Sample rollouts with ``sample_reference_model`` and collect per-step
         latents, log-probs and per-rollout rewards.
      4. Normalise rewards to advantages within each group of
         ``args.num_generations`` consecutive rollouts.
      5. Optionally keep only the highest/lowest-advantage rollouts (best-of-n).
      6. Shuffle the step order per rollout and apply a clipped policy-gradient
         loss per step, stepping the optimizer every
         ``args.gradient_accumulation_steps`` backward passes.

    Args:
        args: Run configuration namespace.
        device: Overwritten below with the device of the sampled latents.
        transformer: Wrapper exposing ``tokenizer_encode``, ``model`` and ``parameters()``.
        inferencer: Reward model forwarded to ``sample_reference_model``.
        optimizer: Optimizer stepped inside this function.
        lr_scheduler: Scheduler stepped together with the optimizer.
        loader: Iterator yielding ``(video, action, action_json, video_name)``.
        max_grad_norm: Threshold for gradient-norm clipping.
        step: Global step index, forwarded to ``sample_reference_model``.

    Returns:
        ``(total_loss, grad_norm, rewards)``: the sum of the rank-averaged
        per-step losses, the last clipped gradient norm as a float (0.0 if the
        optimizer never stepped), and the ``rewards`` entry of the last
        rollout iterated.
    """
    total_loss = 0.0
    grad_norm = torch.tensor(0.0)  # Initialize grad_norm
    optimizer.zero_grad()
    
    video, action, action_json, video_name = next(loader)
    gt_video = video
    # Keep only the first frame (dim 1) as the conditioning input.
    video = video[:, 0:1, :, :, :]
    print(video_name)
    
    # NOTE(review): the tokenizer encoding only happens inside this branch, so
    # with use_group disabled the raw first frame is passed on unencoded —
    # confirm that configuration is supported.
    if args.use_group:
        def repeat_tensor(tensor):
            # Repeat each row num_generations times (rows of a group stay adjacent).
            if tensor is None:
                return None
            return torch.repeat_interleave(tensor, args.num_generations, dim=0)
        with torch.no_grad():
            video = transformer.tokenizer_encode(video)

        video = repeat_tensor(video)
        action = repeat_tensor(action)
        
    all_rewards, all_latents, all_log_probs, sigma_schedule = sample_reference_model(
            args,
            step,
            device, 
            transformer,
            video, 
            action, 
            inferencer,
            action_json,
            gt_video,
        )
    batch_size = all_latents.shape[0]
    # Integer timestep per sampling step, replicated for every rollout.
    timestep_value = [int(sigma * 1000) for sigma in sigma_schedule][:args.sampling_steps]
    timestep_values = [timestep_value[:] for _ in range(batch_size)]
    device = all_latents.device
    timesteps =  torch.tensor(timestep_values, device=all_latents.device, dtype=torch.long)
    # Every per-step entry is additionally sliced with [:, :-1], i.e. the final
    # sampling step is excluded from training.
    samples = {
        "timesteps": timesteps.detach().clone()[:, :-1],
        "latents": all_latents[
            :, :-1
        ][:, :-1],  # each entry is the latent before timestep t
        "next_latents": all_latents[
            :, 1:
        ][:, :-1],  # each entry is the latent after timestep t
        "log_probs": all_log_probs[:, :-1],
        "rewards": all_rewards.to(torch.float32),
        "video": video,
        "action": action,
    }
    # Rewards from all ranks are gathered for logging only; the advantages
    # below are computed from this rank's rewards.
    gathered_reward = gather_tensor(samples["rewards"])
    if dist.get_rank() == 0:
        print("gathered_reward", gathered_reward)
        with open(f'{args.log_dir}/reward.txt', 'a') as f:  
            f.write(f"{gathered_reward.mean().item()}\n")

        

    # Group-relative advantages: standardise rewards within each block of
    # num_generations consecutive rollouts.
    # NOTE(review): Tensor.std() is the unbiased estimator by default, so a
    # group of size 1 would yield NaN here — confirm num_generations > 1.
    n = len(samples["rewards"]) // (args.num_generations)
    advantages = torch.zeros_like(samples["rewards"])
    for i in range(n):
        start_idx = i * args.num_generations
        end_idx = (i + 1) * args.num_generations
        group_rewards = samples["rewards"][start_idx:end_idx]
        group_mean = group_rewards.mean()
        group_std = group_rewards.std() + 1e-8
        advantages[start_idx:end_idx] = (group_rewards - group_mean) / group_std
    
    samples["advantages"] = advantages    

    # best-of-n selection
    total_scores = samples["advantages"] 

    sorted_indices = torch.argsort(total_scores)

    # Unary minus binds tighter than //, so the top slice uses
    # (-bestofn) // 2: for an odd bestofn it takes one more element than the
    # bottom slice.
    top_indices = sorted_indices[-args.bestofn//2:]     
    bottom_indices = sorted_indices[:args.bestofn//2]     
    selected_indices = torch.cat([top_indices, bottom_indices])
    shuffled_order = torch.randperm(len(selected_indices), device=selected_indices.device)
    selected_indices = selected_indices[shuffled_order]     

    # The selection is applied only when bestofn differs from num_generations.
    if args.num_generations != args.bestofn:
        for key in samples:
            samples[key] = samples[key][selected_indices]
        batch_size = len(selected_indices)
    
    # Independent random ordering of the steps for every rollout.
    perms = torch.stack(
        [
            torch.randperm(len(samples["timesteps"][0]))
            for _ in range(batch_size)
        ]
    ).to(device) 
    for key in ["timesteps", "latents", "next_latents", "log_probs"]:
        samples[key] = samples[key][
            torch.arange(batch_size).to(device) [:, None],
            perms,
        ]
    # Re-insert a batch dimension of 1 so each rollout can be iterated on its own.
    samples_batched = {
        k: v.unsqueeze(1)
        for k, v in samples.items()
    }
    # dict of lists -> list of dicts for easier iteration
    samples_batched_list = [
        dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())
    ]

    # Only the first timestep_fraction of the (shuffled) steps is trained on.
    train_timesteps = int(len(samples["timesteps"][0])*args.timestep_fraction)

    # Global counter of backward passes, shared across rollouts.
    # NOTE(review): the counter advances once per step (not per rollout), so
    # optimizer.step() fires every gradient_accumulation_steps backward
    # passes, while the loss is divided by
    # gradient_accumulation_steps * train_timesteps — confirm this pairing is
    # intended. Gradients left over when the loops end are not applied here.
    gradient_accumulation_counter = 0

    for i,sample in list(enumerate(samples_batched_list)):
        # Initialize variables to avoid UnboundLocalError
        loss = torch.tensor(0.0, device=device)
        ratio = torch.tensor(1.0, device=device)
        advantages = torch.tensor(0.0, device=device)
        final_loss = torch.tensor(0.0, device=device)
        
        # `_` is used as the position within this rollout's shuffled steps.
        for _ in range(train_timesteps):
            clip_range = 1e-4
            adv_clip_max = 0.2
            
            # perms[i][_] is the original step index, needed to look up the
            # matching entries of sigma_schedule.
            new_log_probs = grpo_one_step(
                args,
                sample["latents"][:,_],
                sample["next_latents"][:,_],
                sample["video"],
                sample["action"],
                transformer.model,
                sample["timesteps"][:,_],
                perms[i][_],
                sigma_schedule,
            )
            
            # Importance ratio between the current policy and the sampling policy.
            ratio = torch.exp(new_log_probs - sample["log_probs"][:,_])

            # print("sample[advantages]", sample["advantages"])
            advantages = torch.clamp(
                sample["advantages"],
                -adv_clip_max,
                adv_clip_max,
            )
            # Clipped surrogate objective: take the larger (more pessimistic)
            # of the unclipped and clipped losses.
            unclipped_loss = -advantages * ratio
            clipped_loss = -advantages * torch.clamp(
                ratio,
                1.0 - clip_range,
                1.0 + clip_range,
            )
            loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) / (args.gradient_accumulation_steps * train_timesteps)

            final_loss = loss
            
            final_loss.backward()
            
            # Average the loss across ranks for reporting only.
            avg_loss = final_loss.detach().clone()
            dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
            total_loss += avg_loss.item()
            
            # Increment gradient accumulation counter
            gradient_accumulation_counter += 1
            
            # Check if gradient accumulation steps reached
            if gradient_accumulation_counter % args.gradient_accumulation_steps == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(transformer.parameters(), max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
        
        # Print diagnostics from every rank whose index is a multiple of 8
        # (presumably one rank per 8-GPU node — confirm).
        if dist.get_rank()%8==0:
            print("loss", loss.item())
            print("reward", sample["rewards"].item())
            print(" ratio", ratio)
            print("advantage", advantages.item())
            print("final loss", final_loss.item())
        dist.barrier()
    # NOTE(review): the returned reward is that of the last rollout iterated,
    # not an aggregate over the batch.
    return total_loss, grad_norm.item(), sample["rewards"]

def main(args):
    """Set up distributed state, models, data and logging, then run the GRPO loop.

    Expects ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` in the environment and
    initialises an NCCL process group.

    Args:
        args: Parsed command-line namespace. ``args.num_train_epochs`` is
            overwritten with a value derived from ``args.max_train_steps``.

    Raises:
        ValueError: If ``args.max_train_steps`` is ``None``. The loop bounds
            and the epoch computation both require it; previously this only
            surfaced later as a ``TypeError``, after the models had loaded.
    """
    if args.max_train_steps is None:
        raise ValueError("max_train_steps must be set: the training loop is bounded by it.")

    torch.backends.cuda.matmul.allow_tf32 = True

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(local_rank)
    device = torch.cuda.current_device()
    initialize_sequence_parallel_state(args.sp_size)

    # If passed along, set the training seed now.
    # NOTE(review): every rank receives the same args.seed here; nothing in
    # this function offsets the seed per rank.
    if args.seed is not None:
        # TODO: t within the same seq parallel group should be the same. Noise should be different.
        set_seed(args.seed)

    # Create the output directory on the first rank only.
    if rank <= 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    # Create log directory
    if rank <= 0:
        os.makedirs(args.log_dir, exist_ok=True)

    inferencer = None

    # Load IDM as Reward Model
    if args.use_idm:
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # point idm_model_path at trusted checkpoints.
        with open(args.idm_model_path, "rb") as model_file:
            agent_parameters = pickle.load(model_file)
        net_kwargs = agent_parameters["model"]["args"]["net"]["args"]
        pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
        pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
        inferencer = IDMAgent(idm_net_kwargs=net_kwargs, pi_head_kwargs=pi_head_kwargs)
        inferencer.load_weights(args.idm_weights_path)
        inferencer.to("cuda")
        # The reward model is frozen and only used for inference.
        for param in inferencer.parameters():
            param.requires_grad = False
        inferencer.eval()

    main_print(f"--> loading model from {args.model_type}")

    # Load NFD as transformer
    transformer = load_model_from_config(args.model_config_path, args.model_checkpoint_path)

    main_print(
        f"  Total training parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M"
    )
    # The message below is kept for log compatibility; no FSDP wrapping
    # happens in this function.
    main_print(
        f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}"
    )

    # Model is small, no need for FSDP and gradient checkpointing
    # NOTE(review): the transformer is also not wrapped in DDP here, so this
    # script does not all-reduce gradients across ranks — confirm that is
    # intended when world_size > 1.
    main_print("--> model loaded")

    # Set model as trainable.
    transformer.to("cuda")
    transformer.train()

    params_to_optimize = transformer.parameters()
    params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))

    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
        eps=1e-8,
    )

    init_steps = 0
    main_print(f"optimizer: {optimizer}")

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=1000,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
        last_epoch=init_steps - 1,
    )

    # Load dataset with action sequences
    train_dataset = MCDatasetValStd(
        data_dir=args.data_dir,
        filelist_path=args.filelist_path,
        is_latent_input=False,
        is_action_index=False,
        max_frame_per_video=8,
        video_key="mp4"
    )

    sampler = DistributedSampler(
            train_dataset, rank=rank, num_replicas=world_size, shuffle=True, seed=args.sampler_seed
        )

    train_dataloader = DataLoader(
        train_dataset,
        sampler=sampler,
        pin_memory=False,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        drop_last=True,
    )

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader)
        / args.gradient_accumulation_steps
        * args.sp_size
        / args.train_sp_batch_size
    )
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if rank <= 0:
        wandb.init(project=args.wandb_project, entity=args.wandb_entity)

    # Train!
    total_batch_size = (
        args.train_batch_size
        * world_size
        * args.gradient_accumulation_steps
        / args.sp_size
        * args.train_sp_batch_size
    )
    main_print("***** Running training *****")
    main_print(f"  Num examples = {len(train_dataset)}")
    main_print(f"  Dataloader size = {len(train_dataloader)}")
    main_print(f"  Resume training from step {init_steps}")
    main_print(f"  Instantaneous batch size per device = {args.train_batch_size}")
    main_print(
        f"  Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}"
    )
    main_print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    main_print(f"  Total optimization steps per epoch = {args.max_train_steps}")
    main_print(
        f"  Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
    )
    # print dtype
    main_print(f"  Master weight dtype: {next(transformer.parameters()).dtype}")

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=init_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=local_rank > 0,
    )

    loader = sp_parallel_dataloader_wrapper(
        train_dataloader,
        device,
        args.train_batch_size,
        args.sp_size,
        args.train_sp_batch_size,
    )

    # Rolling window of recent step durations for the average step time.
    step_times = deque(maxlen=100)

    # A single pass: the loop below is bounded by max_train_steps, not by epochs.
    for epoch in range(1):
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(epoch) # Crucial for distributed shuffling per epoch

        for step in range(init_steps+1, args.max_train_steps+1):
            start_time = time.time()
            # The checkpoint is written before the training step of the same index runs.
            if step % args.checkpointing_steps == 0:
                save_checkpoint(transformer, rank, args.output_dir, step, epoch)
                dist.barrier()
            loss, grad_norm, reward = train_one_step(
                args,
                device,
                transformer,
                inferencer,
                optimizer,
                lr_scheduler,
                loader,
                args.max_grad_norm,
                step,
            )

            step_time = time.time() - start_time
            step_times.append(step_time)
            avg_step_time = sum(step_times) / len(step_times)

            progress_bar.set_postfix(
                {
                    "loss": f"{loss:.4f}",
                    "step_time": f"{step_time:.2f}s",
                    "grad_norm": grad_norm,
                }
            )
            progress_bar.update(1)
            if rank <= 0:
                wandb.log(
                    {
                        "train_loss": loss,
                        "learning_rate": lr_scheduler.get_last_lr()[0],
                        "step_time": step_time,
                        "avg_step_time": avg_step_time,
                        "grad_norm": grad_norm,
                        "reward": reward,
                    },
                    step=step,
                )

    if get_sequence_parallel_state():
        destroy_sequence_parallel_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type", type=str, default="nfd_hf", help="The type of model to train."
    )
    # dataset & dataloader
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=10,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO

    # text encoder & vae & diffusion model
    parser.add_argument("--dit_model_name_or_path", type=str, default=None)

    # diffusion setting
    parser.add_argument("--ema_decay", type=float, default=0.999)
    parser.add_argument("--ema_start_step", type=int, default=0)
    parser.add_argument("--cfg", type=float, default=0.1)
    parser.add_argument(
        "--precondition_outputs",
        action="store_true",
        help="Whether to precondition the outputs of the model.",
    )

    # validation & logs
    parser.add_argument("--uncond_prompt_dir", type=str)
    parser.add_argument(
        "--validation_sampling_steps",
        type=str,
        default="64",
        help="use ',' to split multi sampling steps",
    )
    parser.add_argument(
        "--validation_guidance_scale",
        type=str,
        default="4.5",
        help="use ',' to split multi scale",
    )
    parser.add_argument("--validation_steps", type=int, default=50)
    # parser.add_argument("--log_validation", action="store_true")
    parser.add_argument("--tracker_project_name", type=str, default=None)
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--resume_from_lora_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )

    # optimizer & scheduler & Training
    parser.add_argument(
        "--num_train_epochs",
        default=100,
        type=int,
    )
    parser.add_argument(
        "--max_train_steps", default=None, type=int,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps", default=1, type=int,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate", default=1e-4, type=float,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr", default=False, action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps", default=10, type=int,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=2.0,
        help="Max gradient norm.",
    )
    # No help text is attached to this option.
    parser.add_argument(
        "--selective_checkpointing",
        default=1.0,
        type=float,
    )
    # Boolean switch; no explicit default is passed for it here.
    parser.add_argument(
        "--allow_tf32", action="store_true",
        help="Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
        " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices",
    )
    # Restricted to the three listed spellings; None when the flag is omitted.
    parser.add_argument(
        "--mixed_precision", default=None, type=str,
        choices=["no", "fp16", "bf16"],
        help="Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
        " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
        " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.",
    )
    # Boolean switch; no explicit default is passed for it here.
    parser.add_argument(
        "--use_cpu_offload", action="store_true",
        help="Whether to use CPU offload for param & gradient & optimizer states.",
    )

    # Sequence-parallel sizing options.
    parser.add_argument(
        "--sp_size",
        default=1,
        type=int,
        help="For sequence parallel",
    )
    parser.add_argument(
        "--train_sp_batch_size", default=1, type=int,
        help="Batch size for sequence parallel training",
    )
    # NOTE: the option is spelled "startegy"; the spelling is kept verbatim
    # because the flag and the derived attribute name are part of the CLI.
    parser.add_argument(
        "--fsdp_sharding_startegy",
        default="full",
    )
    # lr_scheduler
    parser.add_argument(
        "--lr_scheduler", default="constant_with_warmup", type=str,
        help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
        ' "constant", "constant_with_warmup"]',
    )
    parser.add_argument(
        "--lr_num_cycles", default=1, type=int,
        help="Number of cycles in the learning rate scheduler.",
    )
    parser.add_argument(
        "--lr_power", default=1.0, type=float,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--weight_decay",
        default=0.01,
        type=float,
        help="Weight decay to apply.",
    )
    # Free-form string: no `choices` restriction is applied to the value.
    parser.add_argument(
        "--master_weight_type", default="fp32", type=str,
        help="Weight type to use - fp32 or bf16.",
    )
    # Video geometry and sampling options; each is None when the flag is omitted.
    parser.add_argument(
        "--h",
        type=int,
        default=None,
        help="video height",
    )
    parser.add_argument(
        "--w",
        type=int,
        default=None,
        help="video width",
    )
    parser.add_argument(
        "--t",
        type=int,
        default=None,
        help="video length",
    )
    parser.add_argument(
        "--sampling_steps",
        type=int,
        default=None,
        help="sampling steps",
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=None,
        help="noise eta",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=None,
        help="fps of stored video",
    )
    parser.add_argument(
        "--sampler_seed",
        type=int,
        default=None,
        help="seed of sampler",
    )
    # Boolean switches below are off unless the flag is given.
    parser.add_argument(
        "--use_group",
        action="store_true",
        default=False,
        help="whether to use group",
    )
    parser.add_argument(
        "--num_generations",
        type=int,
        default=16,
        help="num_generations per prompt",
    )
    parser.add_argument(
        "--use_same_noise",
        action="store_true",
        default=False,
        help="whether to use same noise",
    )
    # The previous help text described a different ("videoalign") reward model;
    # this flag is named after the IDM reward model whose paths are supplied
    # through --idm_model_path / --idm_weights_path.
    parser.add_argument(
        "--use_idm",
        action="store_true",
        default=False,
        help="whether to use the IDM reward model",
    )
    parser.add_argument(
        "--timestep_fraction",
        type=float,
        default=1.0,
        help="timestep_fraction",
    )
    parser.add_argument(
        "--shift",
        type=float,
        default=1.0,
        help="shift value",
    )
    parser.add_argument(
        "--bestofn",
        type=int,
        default=8,
        help="the chosen samples in best-of-n",
    )
    parser.add_argument(
        "--coef",
        type=float,
        default=1.0,
        help="coef",
    )
    
    # Model and data paths
    # The first four string options are mandatory and declared in this order.
    for _req_flag, _req_help in (
        ("--model_config_path", "Path to model configuration file"),
        ("--model_checkpoint_path", "Path to model checkpoint file"),
        ("--data_dir", "Path to data directory"),
        ("--filelist_path", "Path to file list"),
    ):
        parser.add_argument(_req_flag, type=str, required=True, help=_req_help)

    # The remaining string options are optional and carry an explicit default.
    for _opt_flag, _opt_default, _opt_help in (
        ("--idm_model_path", None, "Path to IDM model file"),
        ("--idm_weights_path", None, "Path to IDM weights file"),
        ("--video_output_dir", "./videos", "Directory to save output videos"),
        ("--log_dir", "./logs", "Directory to save log files"),
        ("--wandb_project", "nfd_finetune", "Wandb project name"),
        ("--wandb_entity", None, "Wandb entity name"),
    ):
        parser.add_argument(_opt_flag, type=str, default=_opt_default, help=_opt_help)

    # Parse the command line and hand the resulting namespace to main().
    args = parser.parse_args()
    main(args)