import argparse
import json
import torch
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from omegaconf import OmegaConf
from tqdm import tqdm
from torchvision import transforms
from torchvision.io import write_video
from einops import rearrange
import torch.distributed as dist
from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pipeline import CausalDiffusionInferencePipeline,CausalInferencePipeline
from utils.dataset import TextDataset, TextImagePairDataset
from utils.misc import set_seed

from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller

# Command-line interface for checkpoint inference.
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str,
                    help="Path to the config file (merged over configs/default_config.yaml)")
parser.add_argument("--checkpoint_path", type=str,
                    help="Folder containing checkpoint_model_<epoch>/model.pt")
parser.add_argument("--data_path", type=str, help="Path to the dataset")
parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
parser.add_argument("--output_folder", type=str, help="Output folder")
parser.add_argument("--num_output_frames", type=int, default=21,
                    help="Number of latent frames to generate (for --i2v this includes "
                         "the frame encoded from the input image)")
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
parser.add_argument("--use_ema", action="store_true",
                    help="Load 'generator_ema' weights (falls back to 'generator' if absent)")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--first_block_steps", type=int, default=4,
                    help="Number of denoising steps used for the first block")
parser.add_argument("--denoising_steps", type=int, default=4,
                    help="Number of denoising steps used for the remaining blocks")
parser.add_argument("--epoch", type=str, default="000000",
                    help="Checkpoint suffix to load: <checkpoint_path>/checkpoint_model_<epoch>/model.pt")
parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
parser.add_argument("--save_with_index", action="store_true",
                    help="Name videos '<idx>-<sample>_<regular|ema>.mp4' instead of using the "
                         "first 100 characters of the prompt")
args = parser.parse_args()

# Initialize distributed inference.
# A LOCAL_RANK environment variable (exported e.g. by torchrun) selects
# multi-process mode; without it this is a single-process run on the
# default CUDA device.
if "LOCAL_RANK" not in os.environ:
    local_rank = 0
    world_size = 1
    device = torch.device("cuda")
    set_seed(args.seed)
else:
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    # Make this rank's GPU the current CUDA device for the process.
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    world_size = dist.get_world_size()
    # Offset the seed by rank so that ranks do not share one random stream.
    set_seed(args.seed + local_rank)

# Query free VRAM once so the printed value is exactly the one the
# low-memory decision is based on.
free_vram_gb = get_cuda_free_memory_gb(gpu)
print(f'Free VRAM {free_vram_gb} GB')
# Below 40 GB free, the text encoder is wrapped with DynamicSwapInstaller and
# the pipeline is called with low_memory=True.
low_memory = free_vram_gb < 40

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)

# Values from --config_path override those in configs/default_config.yaml
# (the latter is resolved relative to the current working directory).
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)

# Initialize pipeline: a config with `denoising_step_list` selects few-step
# inference, otherwise multi-step diffusion inference.
if hasattr(config, 'denoising_step_list'):
    # Few-step inference
    pipeline = CausalInferencePipeline(config, device=device)
else:
    # Multi-step diffusion inference
    pipeline = CausalDiffusionInferencePipeline(config, device=device)

# Checkpoint weights are loaded per checkpoint inside the generation loop.
pipeline = pipeline.to(dtype=torch.bfloat16)

# Under torch.distributed, place the models on this rank's own GPU (`device`),
# the same device the pipeline is built with and the noise is sampled on.
# NOTE(review): `gpu` from demo_utils.memory may be fixed at import time,
# before torch.cuda.set_device(local_rank) runs; verify. Single-process runs
# keep using `gpu` exactly as before.
model_device = device if dist.is_initialized() else gpu
if low_memory:
    DynamicSwapInstaller.install_model(pipeline.text_encoder, device=model_device)
pipeline.generator.to(device=model_device)
pipeline.vae.to(device=model_device)
pipeline.text_encoder.to(device=model_device)

# Create dataset
if args.i2v:
    # Explicit check: an `assert` would be stripped when running `python -O`.
    if dist.is_initialized():
        raise NotImplementedError("I2V does not support distributed inference yet")
    # Resize to 480x832 and map pixels from [0, 1] to [-1, 1].
    # Presumably this matches the 60x104 latent grid of the sampled noise (8x
    # spatial downsampling); confirm against the VAE.
    transform = transforms.Compose([
        transforms.Resize((480, 832)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = TextImagePairDataset(args.data_path, transform=transform)
else:
    dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
num_prompts = len(dataset)
print(f"Number of prompts: {num_prompts}")

if dist.is_initialized():
    # drop_last=True trims the dataset to a multiple of world_size, so the last
    # `num_prompts % world_size` prompts are never generated. Make that visible.
    sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
    num_dropped = num_prompts % world_size
    if num_dropped and local_rank == 0:
        print(f"Warning: {num_dropped} trailing prompt(s) will be skipped "
              f"(num_prompts={num_prompts} is not divisible by world_size={world_size})")
else:
    sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)

# Create output directory (only on main process to avoid race conditions)
if local_rank == 0:
    os.makedirs(args.output_folder, exist_ok=True)

# Other ranks wait here until rank 0 has created the output directory.
if dist.is_initialized():
    dist.barrier()


def encode(self, videos: torch.Tensor) -> torch.Tensor:
    """Encode each item of ``videos`` separately and stack the encodings.

    This is a module-level function that nevertheless takes ``self``. ``self``
    must expose ``mean`` and ``std`` (tensors supporting ``.to``) and ``model``
    (an object with ``encode(x, scale)``).
    NOTE(review): not called anywhere in this file as shown; it looks like a
    copy of a VAE wrapper method.

    Args:
        self: Object providing ``mean``, ``std`` and ``model`` as above.
        videos: Indexable, iterable collection of per-sample tensors, for
            example a batched tensor. Each item gets a leading batch dim of 1
            before being passed to ``self.model.encode``.

    Returns:
        float32 tensor with the per-item encodings stacked along dim 0.
    """
    reference = videos[0]
    placement = dict(device=reference.device, dtype=reference.dtype)
    # `scale` is [mean, 1/std], cast to match the input clips.
    shift = self.mean.to(**placement)
    inv_std = 1.0 / self.std.to(**placement)
    scale = [shift, inv_std]

    encodings = []
    for clip in videos:
        latent = self.model.encode(clip.unsqueeze(0), scale)
        encodings.append(latent.float().squeeze(0))
    return torch.stack(encodings, dim=0)

def remove_fsdp_prefix(state_dict):
    """Return a new dict whose keys have every ``'_fsdp_wrapped_module.'``
    substring removed.

    All occurrences are stripped, not just a leading prefix, so nested
    wrappers are removed as well. Values are carried over as-is (not copied).
    If two keys collapse to the same name, the later one wins.
    """
    marker = '_fsdp_wrapped_module.'
    return {key.replace(marker, ''): value for key, value in state_dict.items()}


# Checkpoint step(s) to evaluate. Each entry `k` loads
# <checkpoint_path>/checkpoint_model_{k}/model.pt; list several entries
# (e.g. ["001000", "002000"]) to sweep checkpoints in one run.
epochs = [args.epoch]
first_block_steps = args.first_block_steps
denoising_steps = args.denoising_steps

for k in epochs:
    # Re-seed for every checkpoint so each one sees the same noise. Keep the
    # per-rank offset applied at start-up (local_rank is 0 when not
    # distributed); re-seeding with plain args.seed would give every rank an
    # identical noise stream.
    set_seed(args.seed + local_rank)

    state_dict = torch.load(
        os.path.join(args.checkpoint_path, f"checkpoint_model_{k}/model.pt"),
        map_location="cpu",
    )
    load_ema = args.use_ema and "generator_ema" in state_dict
    if args.use_ema and not load_ema and local_rank == 0:
        print(f"Warning: no 'generator_ema' in checkpoint_model_{k}; "
              f"loading 'generator' weights instead")
    # remove_fsdp_prefix leaves keys without FSDP wrapper names untouched, so it
    # is safe to apply to either weight set.
    pipeline.generator.load_state_dict(
        remove_fsdp_prefix(state_dict['generator_ema' if load_ema else 'generator'])
    )
    # Release the CPU copy of the checkpoint before generating.
    del state_dict

    for batch_data in tqdm(dataloader, disable=(local_rank != 0)):
        # With batch_size=1 the default collate yields a dict of length-1
        # fields; a list container is unwrapped to its single item.
        batch = batch_data[0] if isinstance(batch_data, list) else batch_data
        idx = int(batch['idx'])

        if args.i2v:
            # Image-to-video: the batch holds an image and its caption.
            prompt = batch['prompts'][0]
            prompts = [prompt] * args.num_samples

            # Assuming a (1, C, H, W) image: squeeze/unsqueeze to (1, C, 1, H, W),
            # i.e. a single-frame clip for the VAE.
            image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)

            # The encoded image becomes the first latent frame, so one fewer
            # noise frame is sampled.
            initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
            initial_latent = initial_latent.repeat(args.num_samples, 1, 1, 1, 1)

            sampled_noise = torch.randn(
                [args.num_samples, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
            )
        else:
            # Text-to-video: prefer the extended prompt when the dataset has one.
            prompt = batch['prompts'][0]
            extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
            conditioning = extended_prompt if extended_prompt is not None else prompt
            prompts = [conditioning] * args.num_samples
            initial_latent = None

            sampled_noise = torch.randn(
                [args.num_samples, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
            )

        video, _latents = pipeline.inference_with_first_block_different_steps(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True,
            initial_latent=initial_latent,
            low_memory=low_memory,
            first_block_steps=first_block_steps,
            denoising_steps=denoising_steps,
        )
        # (b, t, c, h, w) -> (b, t, h, w, c) on CPU, scaled by 255 for
        # write_video (assumes the pipeline returns values in [0, 1]).
        video = 255.0 * rearrange(video, 'b t c h w -> b t h w c').cpu()

        # Clear VAE cache
        pipeline.vae.model.clear_cache()

        # Only save real prompts; indices >= num_prompts are treated as dummy.
        if idx >= num_prompts:
            continue

        model = "regular" if not args.use_ema else "ema"
        if args.save_with_index:
            save_dir = os.path.join(args.output_folder, f"epoch_{k}")
        else:
            save_dir = os.path.join(args.output_folder, f"epoch_{k}_{model}")
        os.makedirs(save_dir, exist_ok=True)
        # A '/' (or the OS path separator) in the prompt would otherwise be
        # parsed as a nonexistent sub-directory and the write would fail.
        prompt_stem = prompt[:100].replace(os.sep, "_").replace("/", "_")
        for seed_idx in range(args.num_samples):
            # Every rank writes the videos for its own share of the prompts.
            if args.save_with_index:
                filename = f'{idx}-{seed_idx}_{model}.mp4'
            else:
                filename = f'{prompt_stem}-{seed_idx}.mp4'
            write_video(os.path.join(save_dir, filename), video[seed_idx], fps=16)
