import torch
import imageio
import os
from diffusers.models import UNet3DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from types import SimpleNamespace
from diffusers import StableDiffusionPipeline, DDIMScheduler, DDIMInverseScheduler, TextToVideoSDPipeline

import sys
sys.path.insert(0, ".")
from pipelines.merge_best_pipeline import Merge_Best_Pipeline
from utils.lora import add_lora_weight

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import numpy as np
from PIL import Image
# Device that the motion adapter and AnimateDiff pipeline are moved to.
device = "cuda"
# Weight dtype used when loading the motion adapter and AnimateDiff pipeline.
dtype = torch.float16

def ensure_directory_exists(path):
    """Create directory *path* (including missing parents) if it is absent.

    Prints whether the directory was newly created or already existed.

    Args:
        path: Filesystem path of the directory to create.

    Raises:
        OSError: Propagated from ``os.makedirs`` for failures other than the
            path already existing (e.g. an empty path or missing permissions).
    """
    try:
        # Try the creation directly (EAFP) instead of checking first: a
        # separate exists() check races with any other process creating the
        # same directory between the check and os.makedirs().
        os.makedirs(path)
    except FileExistsError:
        print(f"Directory already exists: {path}")
    else:
        print(f"Directory created: {path}")

def read_txt_file(file_path):
    """Read a UTF-8 text file and return its lines with whitespace stripped.

    Args:
        file_path: Path of the text file to read.

    Returns:
        One stripped string per line of the file (blank lines become ``""``),
        or an empty list, after printing a message, when the file is missing.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return [raw_line.strip() for raw_line in handle]
    except FileNotFoundError:
        print("File Not Exist")
        return []

def _load_animatediff_pipe(base, step):
    """Build the AnimateDiff-Lightning pipeline that generates the initial frames.

    Args:
        base: Hugging Face repo id of the Stable Diffusion checkpoint the
            motion adapter is combined with.
        step: Number of inference steps. It also selects which
            AnimateDiff-Lightning checkpoint is downloaded (1, 2, 4 or 8).

    Returns:
        An ``AnimateDiffPipeline`` with an Euler scheduler (trailing timestep
        spacing, linear betas), model CPU offload and VAE slicing enabled.
    """
    repo = "ByteDance/AnimateDiff-Lightning"
    # The checkpoint must follow ``step``. It was previously hard-coded to 8,
    # so changing ``step`` ran the 8-step checkpoint with a different number
    # of steps.
    ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"

    adapter = MotionAdapter().to(device, dtype)
    adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
    pipe = AnimateDiffPipeline.from_pretrained(
        base, motion_adapter=adapter, torch_dtype=dtype
    ).to(device)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
    )
    pipe.enable_model_cpu_offload()
    pipe.enable_vae_slicing()
    return pipe


def _load_merge_pipe(lora_path):
    """Build a Merge_Best_Pipeline from an SD-1.5 (+LoRA) model and zeroscope.

    Args:
        lora_path: Path of the LoRA weights applied to the SD-1.5 pipeline via
            ``add_lora_weight``.

    Returns:
        A ``Merge_Best_Pipeline``. Its image branch uses the SD-1.5 VAE, text
        encoder, tokenizer and UNet. Its video branch uses the zeroscope UNet,
        tokenizer and text encoder. Its image/video schedulers and their
        inverses are DDIM / DDIM-inverse schedulers derived from the SD-1.5
        scheduler config.
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    video_model_id = "cerspense/zeroscope_v2_576w"

    zero_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    zero_pipe = add_lora_weight(zero_pipe, lora_path)

    # zeroscope is a text-to-video checkpoint, so load it with the matching
    # pipeline class. Only its UNet, tokenizer and text encoder are reused.
    scope_pipe = TextToVideoSDPipeline.from_pretrained(video_model_id, torch_dtype=dtype)
    unet3d = scope_pipe.unet
    unet3d.enable_forward_chunking(chunk_size=8, dim=1)

    pipe = Merge_Best_Pipeline(
        vae=zero_pipe.vae,
        text_encoder=zero_pipe.text_encoder,
        tokenizer=zero_pipe.tokenizer,
        unet=zero_pipe.unet,
        unet3d=unet3d,
        tokenizer3d=scope_pipe.tokenizer,
        scheduler=zero_pipe.scheduler,
        text_encoder3d=scope_pipe.text_encoder,
        safety_checker=zero_pipe.safety_checker,
        feature_extractor=zero_pipe.feature_extractor,
        requires_safety_checker=False,
    )

    # Schedulers for the T2I branch.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
    # Schedulers for the T2V branch (also derived from the SD-1.5 config).
    pipe.video_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.video_inverse_scheduler = DDIMInverseScheduler.from_config(
        pipe.video_scheduler.config
    )

    pipe.enable_vae_slicing()
    pipe.enable_model_cpu_offload()
    # NOTE(review): moving the whole pipeline to the GPU after enabling model
    # CPU offload likely cancels most of the offload's memory savings. Kept
    # because it is unclear whether Merge_Best_Pipeline's offload sequence
    # covers unet3d / text_encoder3d. Confirm before removing.
    pipe.to(device)
    return pipe


if __name__ == "__main__":

    pos_prompt = " ,best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
    neg_prompt = "text, watermark, copyright, blurry, nsfw, noise, quick motion, bad quality, flicker, dirty, ugly, fast motion, quick cuts, fast editing, cuts"

    prompt_list = read_txt_file("all_category.txt")
    save_path = "our_modelscope/"
    num_frames = 16
    ensure_directory_exists(save_path)

    lora_path = "models/toonyou_beta6.safetensors"
    base = "frankjoshua/toonyou_beta6"

    step = 8  # AnimateDiff-Lightning step count. Options: [1,2,4,8]
    stable_steps = [96, 93]

    # Build both pipelines once instead of re-downloading and rebuilding
    # them for every prompt.
    animatediff_pipe = _load_animatediff_pipe(base, step)
    merge_pipe = _load_merge_pipe(lora_path)

    for prompt in prompt_list:
        # A blank line in the prompt file would otherwise render from the
        # style suffix alone and be saved as "-0.mp4".
        if not prompt:
            continue

        # Stage 1: generate the initial frames with AnimateDiff-Lightning.
        frames = animatediff_pipe(
            prompt=prompt + pos_prompt,
            negative_prompt=neg_prompt,
            guidance_scale=1.0,
            num_inference_steps=step,
            num_frames=num_frames,
            height=512,
            width=512,
        ).frames[0]
        video = list(frames)

        # Build each output path from the fixed directory. Reassigning
        # ``save_path`` here would append every later file name to the
        # previous one. Path separators in the prompt are replaced so the
        # file lands directly inside ``save_path``.
        file_name = prompt.replace("/", "_").replace("\\", "_") + "-0.mp4"
        out_path = os.path.join(save_path, file_name)

        # torch.Generator's positional argument is the device, not a seed.
        # Seed explicitly, with a fresh generator per prompt so each video is
        # reproducible on its own.
        generator = torch.Generator(device=device).manual_seed(42)

        # Stage 2: pass the frames to Merge_Best_Pipeline as ``video``.
        output = merge_pipe(
            prompt=prompt + pos_prompt,
            stable_steps=stable_steps,
            stable_num=1,
            negative_prompt=neg_prompt,
            num_inference_steps=100,
            alpha=1.5,
            video_length=num_frames,
            generator=generator,
            strength=0.1,
            video=video,
        )
        # Presumably output.images holds float arrays in [0, 1] (they are
        # scaled by 255). Clip first so out-of-range values cannot wrap
        # around when cast to uint8.
        result = [(np.clip(r, 0.0, 1.0) * 255).astype("uint8") for r in output.images]
        imageio.mimsave(out_path, result, fps=8)
