import argparse
import os

import torch
from tqdm import tqdm
from diffusers import AnimateDiffSDXLPipeline, AutoencoderKLTemporalDecoder, DDIMScheduler
from diffusers.models import MotionAdapter
from diffusers.utils import export_to_gif, export_to_video

from utils.pipeline_animatediff import AnimateDiffPipeline_GN
from utils.pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline_GN
from utils.pipeline_modelscope_t2v import ModelScopeT2V_GN
from utils.pipeline_latte import LattePipeline_GN


def parse_args(argv=None):
    """Parse the command-line options for the smoothed-path sampling script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) ``argparse`` reads ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with the fields ``method``,
        ``recall_timesteps``, ``noise_type``, ``ensemble``, ``momentum``,
        ``traj_momentum``, ``ensemble_rate``, ``fast_ensemble`` and ``tag``.

    Raises:
        SystemExit: if the arguments are invalid (e.g. an unknown ``--method``).
    """
    parser = argparse.ArgumentParser(description="Smoothed Path Optimization")
    # ----------Model Checkpoint Loading Arguments----------
    parser.add_argument(
        "--method",
        type=str,
        default="animatediffv3",
        choices=["animatediffv3", "animatediffv3_sdxl", "modelscopet2v", "latte"],
        help="The name of the base model to use.",
    )
    # ----------Sampling hyper-parameters forwarded to the pipeline----------
    parser.add_argument(
        "--recall_timesteps",
        type=int,
        default=1,
        help="Value assigned to pipe.recall_timesteps.",
    )
    parser.add_argument(
        "--noise_type",
        type=str,
        default="gaussian",
        help="Value assigned to pipe.noise_type.",
    )
    parser.add_argument(
        "--ensemble",
        type=int,
        default=20,
        help="Value assigned to pipe.ensemble.",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.15,
        help="Value assigned to pipe.momentum.",
    )
    parser.add_argument(
        "--traj_momentum",
        type=float,
        default=0.05,
        help="Value assigned to pipe.traj_momentum.",
    )
    parser.add_argument(
        "--ensemble_rate",
        type=float,
        default=0.05,
        help="Value assigned to pipe.ensemble_rate.",
    )
    parser.add_argument(
        "--fast_ensemble",
        action="store_true",
        default=False,
        help="Value assigned to pipe.fast_ensemble.",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="",
        help="Free-form label for this run.",
    )
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: load the requested base model, attach the
    # sampling hyper-parameters from the CLI, then render one GIF per prompt.
    args = parse_args()

    # Number of frames generated per video (shared by every method below).
    num_frames = 16

    # ---------- Load the base pipeline (each branch installs its scheduler) ----------
    if args.method == "animatediffv3":
        adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3").to(
            dtype=torch.float16, device=torch.device("cuda")
        )
        model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
        pipe = AnimateDiffPipeline_GN.from_pretrained(model_id, motion_adapter=adapter).to(
            dtype=torch.float16, device=torch.device("cuda")
        )
        pipe.scheduler = DDIMScheduler.from_pretrained(
            model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
        )
    elif args.method == "animatediffv3_sdxl":
        adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-sdxl-beta", torch_dtype=torch.float16)
        model_id = "stabilityai/stable-diffusion-xl-base-1.0"
        scheduler = DDIMScheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
            clip_sample=False,
            timestep_spacing="linspace",
            beta_schedule="linear",
            steps_offset=1,
        )
        # The scheduler is passed at construction time, so no reassignment is needed later.
        pipe = AnimateDiffSDXLPipeline_GN.from_pretrained(
            model_id,
            motion_adapter=adapter,
            scheduler=scheduler,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
    elif args.method == "modelscopet2v":
        pipe = ModelScopeT2V_GN.from_pretrained(
            "ali-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
        ).to(dtype=torch.float16, device=torch.device("cuda"))
        pipe.scheduler = DDIMScheduler.from_pretrained(
            "ali-vilab/text-to-video-ms-1.7b",
            subfolder="scheduler",
            clip_sample=False,
            timestep_spacing="linspace",
            steps_offset=1,
        )
    elif args.method == "latte":
        # Latte keeps the scheduler shipped with its checkpoint; only the VAE is swapped.
        pipe = LattePipeline_GN.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to(
            device=torch.device("cuda")
        )
        pipe.vae = AutoencoderKLTemporalDecoder.from_pretrained(
            "maxin-cn/Latte-1", subfolder="vae_temporal_decoder", torch_dtype=torch.float16
        ).to(device=torch.device("cuda"))
    else:
        raise ValueError(f"Method {args.method} not supported")

    pipe.enable_vae_slicing()
    pipe.enable_model_cpu_offload()

    # ---------- Hyper-parameters read by the custom *_GN pipelines ----------
    pipe.recall_timesteps = args.recall_timesteps
    pipe.ensemble = args.ensemble
    pipe.momentum = args.momentum
    # Latte uses a fixed trajectory momentum; every other method honours the CLI value.
    pipe.traj_momentum = 0.95 if args.method == "latte" else args.traj_momentum
    pipe.ensemble_rate = args.ensemble_rate
    pipe.fast_ensemble = args.fast_ensemble
    pipe.noise_type = args.noise_type

    prompts = [
        "Spiderman is surfing",
        "Yellow and black tropical fish dart through the sea",
        "An epic tornado attacking above aglowing city at night",
        "Slow pan upward of blazing oak fire in an indoor fireplace",
        "a cat wearing sunglasses and working as a lifeguard at pool",
        "A dog in astronaut suit and sunglasses floating in space"
    ]

    # An empty --tag (the default) keeps the original "./result_rgs_{i}.gif" names.
    tag_suffix = f"_{args.tag}" if args.tag else ""

    for i, prompt in enumerate(prompts):
        local_path = f"./result_rgs{tag_suffix}_{i}.gif"
        if args.method == "latte":
            output = pipe(
                prompt,
                video_length=num_frames,
                output_type="pil",
            )
        elif args.method == "animatediffv3_sdxl":
            output = pipe(
                prompt=prompt,
                negative_prompt="",
                num_inference_steps=20,
                guidance_scale=8,
                width=1024,
                height=1024,
                num_frames=num_frames,
            )
        else:
            output = pipe(
                prompt=prompt,
                num_inference_steps=50,
                guidance_scale=7,
                num_frames=num_frames,
            )
        frames = output.frames[0]
        export_to_gif(frames, local_path)

