import argparse
import collections
import datetime
import logging
import inspect
import math
import os
import pickle
from typing import Dict, Optional, Tuple
from omegaconf import OmegaConf

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import glob
import diffusers
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from tuneavideo.models.unet import UNet3DConditionModel
from tuneavideo.data.dataset import TuneAVideoDataset
from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
from tuneavideo.util import save_videos_grid, ddim_inversion
from einops import rearrange
from models.umt import UMT
from tasks.retrieval_utils import extract_text_feats

# Fail fast: check_min_version errors out if the installed diffusers is older
# than the version this script was written against. Remove at your own risk.
check_min_version(min_version="0.10.0.dev0")

# Module-level logger; accelerate's wrapper (supports main_process_only=...).
logger = get_logger(name=__name__, log_level="INFO")

from transformers import AutoProcessor, CLIPModel, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from tasks.retrieval import  eval_after_training,calculate_umtscore
import cv2
import numpy as np

from transformers import AutoProcessor, CLIPModel, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

def gif2jpg(dir_path, gif_path):
    """Write every frame of ``gif_path`` as an RGB JPEG named ``<dir_path>NNNNN.jpg``.

    ``dir_path`` is used as a plain string prefix (not joined as a path), so
    pass it with a trailing separator (e.g. ``"frames/"``) to write into that
    directory. Frames are numbered from 0 with five-digit zero padding.
    """
    # NOTE(review): ``Image`` (presumably ``PIL.Image``) is not imported in this
    # file, so calling this function raises NameError until it is.
    # The context manager closes the underlying file handle once all frames
    # are written (the original leaked it).
    with Image.open(gif_path) as gif:
        frame_index = 0
        while True:
            try:
                gif.seek(frame_index)
            except EOFError:
                # Seeking past the last frame signals the end of the animation.
                break
            gif.convert('RGB').save(dir_path + "{:05d}.jpg".format(frame_index))
            frame_index += 1

# Lazily-populated cache of retrieval models, keyed by model name.
clip_candidates = {'viclip': None, 'clip': None}


def get_clip(name='viclip'):
    """Return the cached retrieval model for ``name``, building it on first use.

    Only ``'viclip'`` has a builder: it is returned as a
    ``(ViCLIP model, tokenizer)`` tuple and stored in ``clip_candidates`` so
    later calls reuse the same instance instead of constructing a new one.

    Raises:
        Exception: if ``name`` has no builder (including names that are not
            keys of ``clip_candidates``).
    """
    # ``.get`` so unknown names reach the descriptive error below instead of
    # surfacing as a bare KeyError.
    m = clip_candidates.get(name)
    if m is None:
        if name == 'viclip':
            # NOTE(review): ``_Tokenizer`` and ``ViCLIP`` are not imported in this
            # file; this branch raises NameError until they are.
            tokenizer = _Tokenizer()
            m = (ViCLIP(tokenizer), tokenizer)
            # Cache the result; the original never stored it, so every call
            # rebuilt the model.
            clip_candidates[name] = m
        else:
            raise Exception('the target clip model is not found.')
    return m


def get_text_feat_dict(texts, clip, tokenizer, text_feat_d=None):
    """Map each text to ``clip.get_text_features(text, tokenizer, text_feat_d)``.

    Args:
        texts: iterable of strings to encode.
        clip: object exposing ``get_text_features(text, tokenizer, feat_dict)``.
        tokenizer: forwarded unchanged to ``clip.get_text_features``.
        text_feat_d: optional dict to update in place. It is also handed to
            ``get_text_features`` (presumably as a cache; verify against the
            model). A fresh dict is used when omitted.

    Returns:
        The updated dict (``text_feat_d`` itself when one was supplied).
    """
    if text_feat_d is None:
        # A ``{}`` default would be one dict shared by every call, so texts from
        # earlier calls would leak into later results.
        text_feat_d = {}
    for t in texts:
        text_feat_d[t] = clip.get_text_features(t, tokenizer, text_feat_d)
    return text_feat_d


def get_vid_feat(frames, clip):
    """Delegate to ``clip.get_vid_features`` and return its result unchanged."""
    extract = clip.get_vid_features
    features = extract(frames)
    return features
def _frame_from_video(video):
    """Yield frames from ``video`` until a read fails or ``isOpened()`` is false.

    ``video`` must expose ``isOpened()`` and ``read() -> (ok, frame)``
    (the ``cv2.VideoCapture`` interface).
    """
    while video.isOpened():
        ok, frame = video.read()
        if not ok:
            return
        yield frame


# Per-channel mean / std, shaped (1, 1, 3) to broadcast over an (H, W, 3) frame.
v_mean = np.asarray([0.485, 0.456, 0.406]).reshape((1, 1, 3))
v_std = np.asarray([0.229, 0.224, 0.225]).reshape((1, 1, 3))


def normalize(data):
    """Scale ``data`` from [0, 255] to [0, 1], then standardise per channel."""
    unit = data / 255.0
    centred = unit - v_mean
    return centred / v_std


def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):
    """Pick ``fnum`` evenly strided frames and stack them into one float tensor.

    Each chosen frame (indexed as ``H x W x C``) has its channel axis reversed,
    is resized with ``cv2.resize(frame, target_size)`` and passed through
    :func:`normalize`.

    Returns:
        A float32 tensor on ``device`` of shape ``(1, fnum, C, H, W)``.
    """
    assert (len(vid_list) >= fnum)
    stride = len(vid_list) // fnum
    chosen = vid_list[::stride][:fnum]

    per_frame = []
    for frame in chosen:
        flipped = frame[:, :, ::-1]  # presumably OpenCV BGR -> RGB
        resized = cv2.resize(flipped, target_size)
        per_frame.append(normalize(resized)[np.newaxis, np.newaxis])

    stacked = np.concatenate(per_frame, axis=1)  # (1, T, H, W, C)
    stacked = stacked.transpose(0, 1, 4, 2, 3)   # (1, T, C, H, W)
    return torch.from_numpy(stacked).to(device, non_blocking=True).float()

# NOTE(review): re-creates the module logger already assigned near the imports
# with the same name and level; redundant and can be removed.
logger = get_logger(__name__, log_level="INFO")
def _frame_from_video(video):
    """Yield successive frames from a ``cv2.VideoCapture``-like object.

    NOTE(review): this redefines an identical generator earlier in the module;
    being later in the file, this is the definition bound at import time.
    Iteration stops on the first failed ``read()`` or once ``isOpened()`` is false.
    """
    while video.isOpened():
        grabbed, image = video.read()
        if grabbed:
            yield image
            continue
        break


def retrieve_text(frames, texts, name='viclip', topk=5, device=torch.device('cuda')):
    """Return the video features of ``frames`` from the ``name`` retrieval model.

    NOTE(review): despite the name, no text retrieval happens here -- ``texts``
    and ``topk`` are accepted but unused, and only the video feature is returned.
    The model is moved to ``device`` and the frames are converted with
    ``frames2tensor`` on the same device.
    """
    model, _unused_tokenizer = get_clip(name)
    model = model.to(device)
    video_tensor = frames2tensor(frames, device=device)
    return get_vid_feat(video_tensor, model)
import copy
def _load_umt_evaluator():
    """Build the UMT model whose empty-prompt text feature conditions the UNet.

    Loads ``./mt_config.pkl``, points its data and weight paths at local files,
    and hands it to ``eval_after_training``.

    Returns:
        ``(model_without_ddp, tokenizer_umt, device_umt)`` taken from
        ``eval_after_training``; its first two return values are unused here.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only load a
    # config file you trust. The ``with`` closes the handle (the original leaked it).
    with open('./mt_config.pkl', 'rb') as f:
        umt_config = pickle.load(f)
    umt_config['test_file']['test'][0] = './msrvtt_ret_test1k.json'
    umt_config['train_file'][0] = './msrvtt_ret_train9k.json'
    umt_config['model']['text_encoder']['pretrained'] = './bert-large-uncased'
    umt_config['model']['vision_encoder']['pretrained'] = './l16_ptk710_f8_res224.pth'
    umt_config.pretrained_path = './l16_25m.pth'
    _eval_config, _loaders, model_without_ddp, tokenizer_umt, device_umt = eval_after_training(umt_config)
    return model_without_ddp, tokenizer_umt, device_umt


def _empty_prompt_umt_feats(tokenizer_umt, umt_model, umt_device):
    """Return the UMT text feature of the empty prompt, cloned and detached."""
    feats = extract_text_feats([""], 32, tokenizer_umt, umt_model, umt_device)[0]
    return feats.clone().detach()


def _set_stage_lr(optimizer, lr_scheduler, lr):
    """Switch every param group of ``optimizer`` to learning rate ``lr``.

    The scheduler's ``base_lrs`` are updated as well. The schedulers returned
    by diffusers' ``get_scheduler`` are ``LambdaLR`` instances that recompute
    ``lr = base_lr * factor`` on each ``step()``. Overriding only the param
    groups (as the original did) is therefore undone after the first step.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr
    # ``accelerator.prepare`` wraps the torch scheduler; it lives on ``.scheduler``.
    inner = getattr(lr_scheduler, "scheduler", lr_scheduler)
    if hasattr(inner, "base_lrs"):
        inner.base_lrs = [lr] * len(inner.base_lrs)


def _diffusion_loss(batch, *, vae, text_encoder, unet, noise_scheduler, weight_dtype,
                    motion_feats, use_huber=False, unet_kwargs=None):
    """Compute the denoising loss for one batch.

    Steps:
    - Encode ``batch["pixel_values"]`` (layout ``b f c h w``) to VAE latents scaled by 0.18215.
    - Noise the latents at one random timestep per sample.
    - Compare the UNet prediction with the target implied by
      ``noise_scheduler.prediction_type``.

    The UNet is conditioned on ``[CLIP embedding of batch["prompt_ids"], motion_feats]``.

    Returns:
        ``(loss, latents)``: ``loss`` is the mean MSE, or the mean Huber loss
        (delta=1) when ``use_huber``; ``latents`` are the clean latents in
        ``b c f h w`` layout.
    """
    pixel_values = batch["pixel_values"].to(weight_dtype)
    video_length = pixel_values.shape[1]
    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
    latents = vae.encode(pixel_values).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    latents = latents * 0.18215

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
    timesteps = timesteps.long()
    # Forward diffusion: noise the latents according to each sample's timestep.
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    encoder_hidden_states = text_encoder(batch["prompt_ids"])[0]

    if noise_scheduler.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")

    model_pred = unet(noisy_latents, timesteps, [encoder_hidden_states, motion_feats],
                      **(unet_kwargs or {})).sample
    if use_huber:
        loss = F.huber_loss(model_pred.float(), target.float(), reduction="mean", delta=1)
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    return loss, latents


def _save_validation_samples(*, validation_pipeline, ddim_inv_scheduler, validation_data,
                             latents, motion_feats, weight_dtype, seed, sample_dir):
    """Render each prompt in ``validation_data.prompts`` and save it as a GIF.

    When ``validation_data.use_inv_latent`` is set, the starting latents are the
    DDIM inversion of ``latents`` under the empty prompt. Otherwise
    ``latents=None`` is passed to the pipeline. Each result is written to
    ``<sample_dir>/<edited_type>/<edited_prompt>.gif``.
    """
    with torch.no_grad():
        ddim_inv_latent = None
        if validation_data.use_inv_latent:
            ddim_inv_latent = ddim_inversion(
                validation_pipeline, ddim_inv_scheduler, video_latent=latents,
                num_inv_steps=validation_data.num_inv_steps, prompt="",
                motion_prompt=motion_feats)[-1].to(weight_dtype)
        # Without inversion there is no inverted latent to read the frame count
        # from (the original crashed on ``None.shape``); use the training
        # latents' frame axis (layout b c f h w) instead.
        video_length = (ddim_inv_latent if ddim_inv_latent is not None else latents).shape[2]
        for edited_type, edited_prompt in validation_data.prompts.items():
            # torch.manual_seed(None) raises, so only reseed when a seed is set.
            generator = torch.manual_seed(seed) if seed is not None else None
            sample = validation_pipeline(
                edited_prompt,
                motion_prompt=torch.cat([motion_feats, motion_feats]),
                latents=ddim_inv_latent,
                video_length=video_length,
                generator=generator,
                **validation_data).videos
            save_videos_grid(sample, f"{sample_dir}/{edited_type}/{edited_prompt}.gif")


def _train_stage(*, accelerator, unet, optimizer, lr_scheduler, train_dataloader,
                 vae, text_encoder, noise_scheduler, weight_dtype, motion_feats,
                 train_batch_size, gradient_accumulation_steps, max_grad_norm,
                 num_epochs, max_steps, start_step=0, first_epoch=0, resume_step=0,
                 description="Steps", use_huber=False, unet_kwargs=None,
                 set_train_mode=True, on_step_end=None):
    """Run one training stage and return the global step it finished on.

    Behaviour:
    - Iterates epochs ``first_epoch .. num_epochs - 1`` over ``train_dataloader``.
    - Returns as soon as the step counter, which starts at ``start_step``,
      reaches ``max_steps``. This exits both loops; the original ``break``
      only left the inner one.
    - In epoch ``first_epoch`` the first ``resume_step`` batches are skipped,
      for resuming from a checkpoint.
    - ``on_step_end(global_step, latents)`` runs after every optimisation step.
    """
    progress_bar = tqdm(range(start_step, max_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description(description)
    global_step = start_step
    for epoch in range(first_epoch, num_epochs):
        if set_train_mode:
            unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            if epoch == first_epoch and step < resume_step:
                if step % gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                loss, latents = _diffusion_loss(
                    batch, vae=vae, text_encoder=text_encoder, unet=unet,
                    noise_scheduler=noise_scheduler, weight_dtype=weight_dtype,
                    motion_feats=motion_feats, use_huber=use_huber, unet_kwargs=unet_kwargs)

                # Gather the losses across all processes for logging.
                avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                train_loss += avg_loss.item() / gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # True once an accumulation cycle completes, i.e. a real optimiser step.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0
                if on_step_end is not None:
                    on_step_end(global_step, latents)

            progress_bar.set_postfix(step_loss=loss.detach().item(), lr=lr_scheduler.get_last_lr()[0])

            if global_step >= max_steps:
                return global_step
    return global_step


def main(
    pretrained_model_path: str,
    output_dir: str,
    train_data: Dict,
    validation_data: Dict,
    validation_steps: int = 100,
    trainable_modules: Tuple[str] = (
        "attn1.to_q",
        "attn2.to_q",
        "attn_temp",
    ),
    train_batch_size: int = 1,
    max_train_steps: int = 500,
    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 0,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = True,
    checkpointing_steps: int = 500,
    resume_from_checkpoint: Optional[str] = None,
    mixed_precision: Optional[str] = "fp16",
    use_8bit_adam: bool = False,
    enable_xformers_memory_efficient_attention: bool = True,
    seed: Optional[int] = None,
    is_PEFT=None,
):
    """Fine-tune a Tune-A-Video ``UNet3DConditionModel`` in up to three stages.

    1. Default training: parameters of modules whose names end with one of
       ``trainable_modules`` (plus ``attn_temp_2`` and ``_motion``), MSE loss.
    2. Norm tuning: parameters whose names contain ``norm`` but not ``mona``,
       MSE loss, learning rate ``norm_firststage_lr``.
    3. Adapter tuning: parameters whose names contain ``adapter_mona``, Huber
       loss (delta=1), UNet called with ``is_adapter_tuning=1``.

    Every stage conditions the UNet on the CLIP embedding of the training
    prompt plus the UMT text feature of the empty prompt. Stages are toggled
    by the local ``is_*`` flags below.

    NOTE(review): ``validation_steps`` is overridden to 500.
    ``checkpointing_steps`` and ``gradient_checkpointing`` have no effect,
    because checkpoint saving and gradient checkpointing are both disabled in
    this script.
    """
    validation_steps = 500
    # Fresh list: the default is a tuple (no ``append``), and a config-supplied
    # list should not be mutated in place.
    trainable_modules = list(trainable_modules) + ['attn_temp_2', '_motion']

    umt_model, tokenizer_umt, device_umt = _load_umt_evaluator()

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if seed is not None:
        set_seed(seed)

    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", is_PEFT=is_PEFT)

    # Freeze everything, then unfreeze the modules selected for stage 1.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    trainable_suffixes = tuple(trainable_modules)
    for name, module in unet.named_modules():
        if name.endswith(trainable_suffixes):
            for params in module.parameters():
                params.requires_grad = True

    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # NOTE(review): gradient checkpointing is deliberately left disabled (it was
    # guarded by ``if False:``), so the ``gradient_checkpointing`` argument is ignored.

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    # All UNet parameters go to the optimizer so later stages can unfreeze
    # different subsets without rebuilding it.
    optimizer = optimizer_cls(
        unet.parameters(),
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    train_dataset = TuneAVideoDataset(**train_data)
    train_dataset.prompt_ids = tokenizer(
        train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    ).input_ids[0]
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size
    )

    validation_pipeline = TuneAVideoPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    )
    validation_pipeline.enable_vae_slicing()
    ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler')
    ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps)

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # The frozen text encoder and VAE are only used for inference, so they can
    # be cast to the mixed-precision dtype.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # The prepared dataloader may have a different length, so recompute.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    if accelerator.is_main_process:
        accelerator.init_trackers("text2video-fine-tune")

    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")

    # Stage switches and per-stage hyper-parameters.
    is_default_training = True
    is_firststage_normtuning = True
    is_adaptertuning = True
    is_save_defalut_training = False
    is_save_firststage_normtuning = False
    is_save_adaptertuning = True
    norm_firststage_lr = 5.0e-05
    norm_firststage_steps = 400
    norm_validation_steps = norm_firststage_steps
    adapter_max_train_steps = 70
    adapter_validation_steps = adapter_max_train_steps

    global_step = 0
    first_epoch = 0
    resume_step = 0
    if resume_from_checkpoint:
        if resume_from_checkpoint != "latest":
            path = os.path.basename(resume_from_checkpoint)
        else:
            dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1]
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(output_dir, path))
        global_step = int(path.split("-")[1])
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = global_step % num_update_steps_per_epoch

    # Arguments shared by every stage's training loop.
    common = dict(
        accelerator=accelerator, unet=unet, optimizer=optimizer, lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader, vae=vae, text_encoder=text_encoder,
        noise_scheduler=noise_scheduler, weight_dtype=weight_dtype,
        train_batch_size=train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_grad_norm=max_grad_norm,
    )

    def validation_hook(enabled, every, motion_feats, sample_dir_for_step):
        """Build an ``on_step_end`` callback that saves samples on the main
        process every ``every`` optimisation steps when ``enabled``."""
        def hook(step, latents):
            if enabled and step % every == 0 and accelerator.is_main_process:
                _save_validation_samples(
                    validation_pipeline=validation_pipeline,
                    ddim_inv_scheduler=ddim_inv_scheduler,
                    validation_data=validation_data,
                    latents=latents,
                    motion_feats=motion_feats,
                    weight_dtype=weight_dtype,
                    seed=seed,
                    sample_dir=sample_dir_for_step(step),
                )
        return hook

    # Checkpoint saving is disabled in this script; the original only computed
    # an unused trainable state dict at checkpoint steps.

    if is_default_training:
        motion_feats = _empty_prompt_umt_feats(tokenizer_umt, umt_model, device_umt)
        # Continue from a resumed step instead of resetting it to 0.
        _train_stage(
            **common,
            motion_feats=motion_feats,
            num_epochs=num_train_epochs,
            max_steps=max_train_steps,
            start_step=global_step,
            first_epoch=first_epoch,
            resume_step=resume_step,
            description="Steps",
            on_step_end=validation_hook(
                is_save_defalut_training, validation_steps, motion_feats,
                lambda step: f"{output_dir}/samples/sample-{step}"),
        )

    if is_firststage_normtuning:
        unet.requires_grad_(False)
        norm_count = 0
        for name, param in unet.named_parameters():
            if 'norm' in name and "mona" not in name:
                norm_count += 1
                param.requires_grad = True
        print('Norm para number: ', norm_count)
        _set_stage_lr(optimizer, lr_scheduler, norm_firststage_lr)
        motion_feats = _empty_prompt_umt_feats(tokenizer_umt, umt_model, device_umt)
        # Stops after ``norm_firststage_steps`` steps (not ``max_train_steps``)
        # and does not reuse the default stage's resume offsets.
        _train_stage(
            **common,
            motion_feats=motion_feats,
            num_epochs=norm_firststage_steps,
            max_steps=norm_firststage_steps,
            description="Normtuning Steps",
            on_step_end=validation_hook(
                is_save_firststage_normtuning, norm_validation_steps, motion_feats,
                lambda step: f"{output_dir}/samples/afternorm-sample-{step}"),
        )

    if is_adaptertuning:
        unet.requires_grad_(False)
        adapter_count = 0
        for name, param in unet.named_parameters():
            if 'adapter_mona' in name:
                adapter_count += 1
                param.requires_grad = True
        print('Adapter para number: ', adapter_count)
        _set_stage_lr(optimizer, lr_scheduler, 1.0e-05)
        motion_feats = _empty_prompt_umt_feats(tokenizer_umt, umt_model, device_umt)
        # ``unet.train()`` is intentionally not re-applied each epoch in this stage.
        # NOTE(review): samples always go to ``sample-500`` regardless of the step.
        _train_stage(
            **common,
            motion_feats=motion_feats,
            num_epochs=adapter_max_train_steps,
            max_steps=adapter_max_train_steps,
            description="Adaptertuning Steps",
            use_huber=True,
            unet_kwargs={'is_adapter_tuning': 1},
            set_train_mode=False,
            on_step_end=validation_hook(
                is_save_adaptertuning, adapter_validation_steps, motion_feats,
                lambda _step: f"{output_dir}/samples/sample-500"),
        )

    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required: every key of this OmegaConf file is forwarded to main() as a
    # keyword argument. A missing path would otherwise reach
    # OmegaConf.load(None) and fail with an unhelpful error.
    parser.add_argument("--config", type=str, required=True)
    # NOTE(review): --gpu_id is parsed but never used by this script.
    parser.add_argument("--gpu_id", type=str, default='0')
    args = parser.parse_args()

    main(**OmegaConf.load(args.config))
