# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime
warnings.filterwarnings('ignore')
import random
import torch
import torch.distributed as dist
from wan.timeutils import ClockContext
from PIL import Image
from io import BytesIO
from base64 import b64encode, b64decode
import torchvision
import sys
import imageio.v3 as iio
from wan.bench_speed.textimage2video_causal_server_prefill_all import WanTI2VCausalServer
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.utils.utils import save_video, str2bool
from generate_causal import _validate_args, _parse_args, _init_logging


def extract_first_k_frames(video_path, k):
    """Read up to the first ``k`` frames of a video and return them as base64 JPEGs.

    Frames are decoded with imageio's FFMPEG plugin and JPEG-encoded one at a
    time, so at most one decoded frame is held in memory at once.

    Args:
        video_path: path of the video file to read.
        k: maximum number of frames to read. Values ``<= 0`` return an empty list.

    Returns:
        list[str]: one base64-encoded JPEG (as a UTF-8 str) per frame, in
        frame order. The list is shorter than ``k`` if the video has fewer frames.
    """
    import itertools
    from contextlib import closing

    jpeg_message_list = []
    # Close the frame generator explicitly so the underlying reader is released
    # as soon as we stop, rather than whenever the generator is garbage-collected.
    with closing(iio.imiter(video_path, plugin="FFMPEG")) as reader:
        # islice stops after k frames without pulling (decoding) frame k+1.
        for frame in itertools.islice(reader, max(k, 0)):
            # Frames arrive as H, W, C arrays; encode_jpeg takes a C, H, W uint8 tensor.
            chw = torch.tensor(frame, dtype=torch.uint8).permute(2, 0, 1)
            jpeg_tensor = torchvision.io.encode_jpeg(chw)
            jpeg_message_list.append(b64encode(jpeg_tensor.numpy().tobytes()).decode("utf-8"))

    return jpeg_message_list

def test_encode_error(wan_ti2v, imgs, start_frame=65, range_frame=16,
                      pre_cat_values=range(1, 10, 4)):
    """Diagnostic: compare chunked VAE encodes of a clip window with a full-clip encode.

    The whole clip is encoded once. The latents covering frames
    ``[start_frame, start_frame + range_frame)`` are sliced out as ``z_gt``.
    For each ``pre_cat`` in ``pre_cat_values``, the function encodes the window
    two ways:
      * on its own, giving ``z16``;
      * with ``pre_cat`` preceding frames prepended, giving ``z17``. The leading
        latents that belong to those prepended frames are dropped.

    It then prints the L2 norms of ``z_gt - z16`` and ``z_gt - z17``.

    Args:
        wan_ti2v: pipeline object. Only ``.device`` and ``.vae.encode`` are used.
        imgs: base64-encoded JPEG strings, one per frame (e.g. the output of
            ``extract_first_k_frames``).
        start_frame: index of the first frame of the compared window.
        range_frame: number of frames in the compared window.
        pre_cat_values: numbers of context frames to prepend to the window.

    Raises:
        ValueError: if ``imgs`` has fewer than ``start_frame + range_frame``
            frames, or if a ``pre_cat`` falls outside ``[0, start_frame]``.
            Either case would otherwise slice silently to the wrong frames.
    """
    pre_cats = tuple(pre_cat_values)
    needed = start_frame + range_frame
    if needed > len(imgs):
        raise ValueError(
            f"need at least {needed} frames for window [{start_frame}, {needed}), "
            f"got {len(imgs)}")
    bad = [p for p in pre_cats if p < 0 or p > start_frame]
    if bad:
        raise ValueError(
            f"pre_cat values {bad} must lie in [0, {start_frame}] so the context "
            f"slice does not start before frame 0")

    frames = [
        torchvision.io.decode_jpeg(
            torch.frombuffer(b64decode(img), dtype=torch.uint8),
            mode=torchvision.io.ImageReadMode.RGB)
        for img in imgs
    ]
    frame_data = torch.stack(frames, dim=1)  # C, T, H, W
    # Map uint8 [0, 255] to float [-1, 1].
    frame_data = frame_data.float().div_(255.0).sub_(0.5).div_(0.5).to(wan_ti2v.device)

    # NOTE(review): the index math assumes the VAE keeps frame 0 as its own
    # latent frame and packs every following 4 frames into one latent frame.
    # That makes the window latent-aligned when (start_frame - 1) % 4 == 0 and
    # range_frame % 4 == 0. TODO confirm against the VAE config.
    range_latent_frame = range_frame // 4
    start_latent_frame = (start_frame - 1) // 4 + 1

    with torch.no_grad():  # pure measurement; no autograd graph needed
        latent_all = wan_ti2v.vae.encode([frame_data])[0].unsqueeze(0)
        # Reference latents: the window as seen by the full-clip encode.
        z_gt = latent_all[:, :, start_latent_frame:start_latent_frame + range_latent_frame, :]
        for pre_cat in pre_cats:
            print(f"pre_cat: {pre_cat}")
            window_frames = frame_data[:, start_frame:start_frame + range_frame]
            context_frames = frame_data[:, start_frame - pre_cat:start_frame + range_frame]
            z16 = wan_ti2v.vae.encode([window_frames])[0].unsqueeze(0)
            # Drop the latent frames produced by the prepended context frames.
            z17 = wan_ti2v.vae.encode([context_frames])[0].unsqueeze(0)[:, :, (pre_cat - 1) // 4 + 1:]
            print(f"{(z_gt - z16).norm()=}")
            print(f"{(z_gt - z17).norm()=}")


# The ``test_`` prefix is historical: this is a manual diagnostic, not a pytest
# test. Stop pytest from collecting it (it would fail on missing fixtures).
test_encode_error.__test__ = False
            
            
def generate(args, new_args):
    """Run chunked causal TI2V generation on sampled dataset clips and save the results.

    The function sorts the ``*.txt`` prompt files in ``<dataset_path>/metas``,
    shuffles them after ``random.seed(42)``, and processes the first 10. For
    each one it:
      * reads the prompt;
      * loads up to ``args.frame_num`` frames from the matching
        ``<dataset_path>/videos/<name>.mp4``;
      * calls ``wan_ti2v.generate`` repeatedly, advancing ``n_new_frame``
        frames per call. Each call receives the most recent (up to
        ``n_new_frame``) real frames as ``img``.

    The output of the last call for each clip is saved under
    ``new_args.output_dir``.

    Args:
        args: generation arguments parsed by ``generate_causal._parse_args``.
        new_args: script-specific arguments providing ``dataset_path`` and
            ``output_dir``.
    """
    assert not args.use_prompt_extend, "Prompt extension is not supported in WanTI2VCausal generation."
    # Single-process run. The distributed branches below are kept for parity
    # with the multi-GPU entry point but stay inactive while world_size == 1.
    rank = 0
    world_size = 1
    local_rank = 0
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        args.offload_model = False if world_size > 1 else True
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1
        ), f"sequence parallel are not supported in non-distributed environments."

    if args.ulysses_size > 1:
        assert args.ulysses_size == world_size, f"The number of ulysses_size should be equal to the world size."
        init_distributed_group()

    cfg = WAN_CONFIGS[args.task]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    # NOTE: args.prompt is only logged here. The prompts actually used are
    # read per clip from the metas directory below.
    logging.info(f"Input prompt: {args.prompt}")

    logging.info("Creating WanTI2VCausal pipeline.")
    wan_ti2v = WanTI2VCausalServer(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        pt_dir=args.pt_dir,
        device_id=device,
        rank=rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_sp=(args.ulysses_size > 1),
        t5_cpu=args.t5_cpu,
        convert_model_dtype=args.convert_model_dtype,
    )

    logging.info(f"Generating video ...")
    meta_dir = os.path.join(new_args.dataset_path, "metas")
    # os.listdir order is filesystem-dependent. Sort first so the seeded
    # shuffle picks the same clips on every machine, and skip non-.txt entries.
    txt_files = sorted(name for name in os.listdir(meta_dir) if name.endswith(".txt"))
    # NOTE: this seeds the *global* RNG, so it also affects any later use of
    # `random` (possibly including seed selection inside the pipeline).
    random.seed(42)
    random.shuffle(txt_files)
    logging.info(f"frame num: {args.frame_num}")
    n_new_frame = 8  # frames generated per chunk
    for txt_file in txt_files[:10]:
        output_prefix = os.path.splitext(txt_file)[0]
        mp4_path = os.path.join(new_args.dataset_path, 'videos', output_prefix + ".mp4")
        with open(os.path.join(meta_dir, txt_file), "r", encoding="utf-8") as f:
            prompt = f.read().strip()
        print(prompt)
        imgs = extract_first_k_frames(mp4_path, args.frame_num)
        # End-to-end timing notes (measured with 16 new frames per step and
        # 113 frames total; n_new_frame is currently 8):
        #   stateless native:                    32.88s
        #   incremental decode:                  24.30s
        #   prefill only last 4 latent frames:   ~22s
        #   incremental encode:                  17.5s
        # Debug hook: uncomment to measure VAE chunked-encode error on this clip.
        # test_encode_error(wan_ti2v, imgs)
        video = None  # stays None if the chunk loop below does not run
        with ClockContext(f"{'end to end gen':-^30}"):
            for num_conditional_frames in range(1, args.frame_num, n_new_frame):
                print(f"Generating {num_conditional_frames} to {num_conditional_frames + n_new_frame}")
                with ClockContext(f"{'generate chunk':-^30}"):
                    video = wan_ti2v.generate(
                        prompt,
                        # Most recent (up to n_new_frame) ground-truth frames.
                        img=imgs[max(0, num_conditional_frames - n_new_frame): num_conditional_frames],
                        size=SIZE_CONFIGS[args.size],
                        max_area=MAX_AREA_CONFIGS[args.size],
                        frame_num=num_conditional_frames + n_new_frame,
                        num_conditional_frames=num_conditional_frames,
                        shift=args.sample_shift,
                        sample_solver=args.sample_solver,
                        sampling_steps=args.sample_steps,
                        guide_scale=args.sample_guide_scale,
                        seed=args.base_seed,
                        offload_model=args.offload_model
                    )
        wan_ti2v.clean_all_state()

        if video is None:
            logging.warning(
                f"No chunks generated for {txt_file} (frame_num={args.frame_num}); skipping save.")
            continue

        if rank == 0:
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            formatted_prompt = output_prefix.replace(" ", "_").replace("/","_")[:70]
            suffix = '.mp4'
            save_file_path = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix
            save_file_path = os.path.join(new_args.output_dir, save_file_path)

            logging.info(f"Saving generated video to {save_file_path}")
            save_video(
                tensor=video[None],
                save_file=save_file_path,
                fps=cfg.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1))
        # Drop the reference to this clip's output before starting the next clip.
        del video

    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")


if __name__ == "__main__":
    # Shared generation flags are parsed by generate_causal._parse_args. The
    # leftover tokens it returns are parsed here for this script's own options.
    args, left_args = _parse_args()
    parser = argparse.ArgumentParser()
    # Required: generate() joins this path with "metas" and "videos". Without
    # it, the run would fail later with os.path.join(None, ...).
    parser.add_argument(
        "--dataset_path", type=str, required=True,
        help="Root directory containing metas/*.txt prompts and videos/*.mp4 clips.")
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Directory for generated videos (created if missing).")
    new_args = parser.parse_args(left_args)
    os.makedirs(new_args.output_dir, exist_ok=True)
    generate(args, new_args)
