import torch
from diffusers.models import UNet2DConditionModel
from models.pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
from models.unet_motion_model import UNetMotionModel
from diffusers.utils import export_to_gif

from diffusers.utils.export_utils import export_to_video
from io import BytesIO
from PIL import Image
from utils.video_utils import read_from_video
from utils.image_utils import load_size
from config_video import RunConfig, Range
from pathlib import Path
import glob
import os
# Stable Diffusion 1.5 checkpoint shared by the pipeline, the standalone UNet and the scheduler.
BASE_MODEL_ID = "SG161222/Realistic_Vision_V5.1_noVAE"
# Motion-adapter checkpoint loaded via MotionAdapter and combined with the 2D UNet below.
MOTION_ADAPTER_ID = "guoyww/animatediff-motion-adapter-v1-5"

# Load the motion adapter and the video-to-video pipeline in float16 on the GPU.
adapter = MotionAdapter.from_pretrained(MOTION_ADAPTER_ID, torch_dtype=torch.float16).to("cuda")
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    BASE_MODEL_ID, motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Build a motion UNet from the base 2D UNet plus the adapter and attach it to the pipeline.
# NOTE(review): `estimator` is an attribute of the project-local pipeline, not of stock
# diffusers — confirm the pipeline reads `self.estimator` rather than `self.unet`.
unet = UNet2DConditionModel.from_pretrained(
    BASE_MODEL_ID, subfolder="unet", torch_dtype=torch.float16
).to("cuda")
# from_unet2d constructs a new model from the UNet's config; depending on the implementation
# it may be created in float32, so cast explicitly to match the float16 pipeline.
unet = UNetMotionModel.from_unet2d(unet, adapter).to("cuda", dtype=torch.float16)
pipe.estimator = unet

pipe.scheduler = DDIMScheduler.from_pretrained(BASE_MODEL_ID, subfolder="scheduler")


# Inputs: every (video, style image) pair is processed once.
video_files = ["./dataset/videos/a_woman_is_walking.mp4"]
image_files = ["./dataset/images/tree.png"]

# Sampling hyper-parameters shared by inversion and generation.
STRENGTH = 0.68
NUM_INFERENCE_STEPS = 100
NUM_FRAMES = 6  # number of leading video frames that are inverted and processed
NUM_OUTPUTS = 3  # one exported video per entry in the 3-prompt batch built below

# `config` is passed to the pipeline call below but was previously never defined (NameError).
# NOTE(review): assumes every RunConfig field has a usable default — confirm in config_video.
config = RunConfig()


def _prompt_from_path(path):
    """Turn a file path into a text prompt, e.g. ".../a_woman-walking.mp4" -> "a woman walking"."""
    return Path(path).stem.replace("_", " ").replace("-", " ")


for video_path in video_files:
    content_prompt = _prompt_from_path(video_path)
    # Read the clip once per video instead of once per (video, image) pair.
    content_frames = read_from_video(video_path, as_pil=True)[:NUM_FRAMES]

    for image_path in image_files:
        style_prompt = _prompt_from_path(image_path)
        print(f"{content_prompt=}")
        print(f"{style_prompt=}")

        save_dir = f"./video_results/{content_prompt.replace(' ', '_')}_{style_prompt.replace(' ', '_')}"
        save_files = [f"{save_dir}/test_{i}.mp4" for i in range(NUM_OUTPUTS)]
        # Skip only when every output already exists, so a run interrupted mid-export is redone.
        if all(os.path.exists(f) for f in save_files):
            print(f"{save_dir} already exists, skipping")
            continue
        os.makedirs(save_dir, exist_ok=True)

        style_image = Image.fromarray(load_size(image_path))
        # A static pseudo-video of the style image with the same frame count as the content
        # clip, so the latents stacked below have matching shapes.
        style_frames = [style_image] * len(content_frames)

        content_zts, content_xts = pipe.invert(
            video=content_frames,
            prompt=content_prompt,
            strength=STRENGTH,
            num_inference_steps=NUM_INFERENCE_STEPS,
        )
        style_zts, style_xts = pipe.invert(
            video=style_frames,
            prompt=style_prompt,
            strength=STRENGTH,
            num_inference_steps=NUM_INFERENCE_STEPS,
        )

        # Batch entries follow the `prompts` order: content, style, content.
        # NOTE(review): the role of each entry is defined by the project-local pipeline.
        init_latents = torch.stack([content_xts[0], style_xts[0], content_xts[0]])
        init_zs = torch.stack([content_zts, style_zts, content_zts])
        prompts = [content_prompt, style_prompt, content_prompt]
        output = pipe(
            latents=init_latents,
            prompt=prompts,
            strength=STRENGTH,
            zs=init_zs,
            config=config,
            num_inference_steps=NUM_INFERENCE_STEPS,
        )

        # Export one video per batch entry.
        for i, save_file in enumerate(save_files):
            export_to_video(output[0][i], save_file)
            print(f"video saved to {save_file}")
