import torch
import os
import sys
import diffusers
import time
import shutil
import argparse
import logging

from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from qdiff.utils import apply_func_to_submodules, seed_everything, setup_logging

def _read_prompts(prompt_path):
    """Read one prompt per line from ``prompt_path``.

    Returns a list of ``(line_index, prompt)`` pairs, where ``prompt`` is the
    stripped line text. Lines that are empty after stripping are skipped, but
    ``line_index`` still counts them, so the output file number of every
    non-blank prompt matches its 0-based line position in the file.
    """
    with open(prompt_path, 'r') as f:
        stripped_lines = [line.strip() for line in f]
    return [(idx, text) for idx, text in enumerate(stripped_lines) if text]


def main(args: argparse.Namespace) -> None:
    """Generate one video per prompt with ``CogVideoXPipeline`` and export it.

    Args:
        args: Parsed command-line options. Reads ``seed``, ``log``,
            ``sage_attn``, ``ckpt``, ``prompt``, ``num_sampling_steps`` and
            ``cfg_scale``.

    Side effects:
        * Creates the log directory (``args.log``, or ``./logs`` when it is
          not given) and passes ``<log dir>/run.log`` to ``setup_logging``.
        * With ``args.sage_attn`` set, rebinds
          ``torch.nn.functional.scaled_dot_product_attention`` to
          ``sageattn`` for the rest of the process.
        * Writes ``output_<i>.mp4`` files into ``<log dir>/generated_videos``.
    """
    seed_everything(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Fall back to a default directory instead of failing on
    # os.path.join(None, ...) when --log is omitted.
    log_dir = args.log if args.log is not None else "./logs"
    os.makedirs(log_dir, exist_ok=True)
    setup_logging(os.path.join(log_dir, 'run.log'))
    logger = logging.getLogger(__name__)

    if args.sage_attn:
        # Imported lazily so sageattention is only required when requested.
        import torch.nn.functional as F
        from sageattention import sageattn
        # Module-level monkey patch: every later lookup of
        # F.scaled_dot_product_attention resolves to sageattn.
        F.scaled_dot_product_attention = sageattn
        logger.info('using sage_attn INT8')

    ckpt_path = args.ckpt if args.ckpt is not None else "./models--THUDM--CogVideoX-5b/snapshots/8d6ea3f817438460b25595a120f109b88d5fdfad"
    pipe = CogVideoXPipeline.from_pretrained(
        ckpt_path,
        torch_dtype=torch.bfloat16
    ).to(device)

    # INFO: if memory intense
    # pipe.enable_model_cpu_offload()
    # pipe.vae.enable_tiling()

    # Read the prompts (one per line; blank lines are skipped).
    prompt_path = args.prompt if args.prompt is not None else "./prompts.txt"
    prompts = _read_prompts(prompt_path)
    if not prompts:
        logger.warning("no prompts found in %s", prompt_path)

    # The output directory does not depend on the prompt, so create it once.
    save_path = os.path.join(log_dir, "generated_videos")
    os.makedirs(save_path, exist_ok=True)

    for i, prompt in prompts:
        video = pipe(
            prompt=prompt,
            num_videos_per_prompt=1,
            num_inference_steps=args.num_sampling_steps, # 50
            num_frames=49,
            guidance_scale=args.cfg_scale,
            # Re-seeded for every prompt, and created on the same device the
            # pipeline runs on so CPU-only machines do not fail here.
            generator=torch.Generator(device=device).manual_seed(args.seed),
        ).frames[0]
        print(f"Export video to output_{i}.mp4")
        export_to_video(video, os.path.join(save_path, f"output_{i}.mp4"), fps=8)

if __name__ == "__main__":
    # Script entry point: declare the command-line options as a table of
    # (flag, add_argument keyword arguments), register them in order, then
    # hand the parsed namespace to main().
    option_table = (
        ("--log", {"type": str}),
        ("--cfg-scale", {"type": float, "default": 4.0}),
        ("--num-sampling-steps", {"type": int, "default": 50}),
        ("--prompt", {"type": str, "default": None}),
        ("--seed", {"type": int, "default": 42}),
        ("--ckpt", {"type": str, "default": None}),
        ("--sage_attn", {"action": "store_true"}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    main(cli.parse_args())
