import os
import torch
import imageio
import json
from transformers import CLIPTextModel, CLIPTokenizer
from types import SimpleNamespace
from diffusers import StableDiffusionPipeline, DDIMScheduler, DDIMInverseScheduler

import sys
sys.path.insert(0, ".")
from pipelines.merge_best_pipeline import Merge_Best_Pipeline
from pipelines.lavie_models import UNet3DConditionModel
from utils.lora import add_lora_weight

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, DPMSolverMultistepScheduler
from diffusers.utils import export_to_gif, export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import numpy as np
from PIL import Image
# Device and dtype shared by the model-loading code in this script.
device = "cuda"
dtype = torch.float16

def ensure_directory_exists(path):
    """Create ``path`` (including missing parents) and report what happened.

    The directory is created optimistically instead of checking for it
    first, so a concurrent process creating the same directory between a
    check and the creation cannot make this function crash.

    Args:
        path: Directory path to create.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        print(f"Directory already exists: {path}")
    else:
        print(f"Directory created: {path}")

def read_txt_file(file_path):
    """Return the lines of a UTF-8 text file with surrounding whitespace removed.

    Blank lines are kept as empty strings. If the file does not exist, a
    message is printed and an empty list is returned.
    """
    try:
        handle = open(file_path, 'r', encoding='utf-8')
    except FileNotFoundError:
        print("File Not Exist")
        return []
    with handle:
        return [raw.strip() for raw in handle]

def find_model(model_name):
    """
    Loads a checkpoint from the local path ``model_name`` with ``torch.load``.

    If the loaded object contains an "ema" entry, that entry is returned
    instead of the full checkpoint; otherwise the checkpoint is returned as is.
    """
    # The map_location callable hands back the deserialized storage unchanged.
    loaded = torch.load(model_name, map_location=lambda storage, loc: storage)
    if "ema" not in loaded:
        return loaded
    print('Ema existing!')
    return loaded["ema"]

def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
    """Instantiate ``cls`` from a pretrained checkpoint directory using 3D block types.

    The directory's ``config.json`` is read, its block types are replaced with
    the 3D variants listed below, and the model is built with
    ``cls.from_config``. The weights file is then loaded, after filling in
    entries for model parameters that the checkpoint is not expected to hold.

    Args:
        cls: Model class exposing ``from_config``; the returned instance must
            support ``state_dict`` and ``load_state_dict``.
        pretrained_model_path: Directory containing ``config.json`` and the
            weights file.
        subfolder: Optional sub-directory of ``pretrained_model_path`` to use.

    Returns:
        The instantiated model with the weights loaded.

    Raises:
        RuntimeError: If the config file or the weights file is missing.
        KeyError: If an ``attn_fcross``/``norm_fcross`` parameter has no
            matching ``attn1``/``norm1`` entry in the checkpoint.
    """
    if subfolder is not None:
        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)

    config_file = os.path.join(pretrained_model_path, 'config.json')
    if not os.path.isfile(config_file):
        raise RuntimeError(f"{config_file} does not exist")
    with open(config_file, "r") as f:
        config = json.load(f)
    config["_class_name"] = cls.__name__
    config["down_block_types"] = [
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "DownBlock3D"
    ]
    config["up_block_types"] = [
        "UpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D"
    ]

    config["use_first_frame"] = False

    from diffusers.utils import WEIGHTS_NAME  # diffusion_pytorch_model.bin

    model = cls.from_config(config)
    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
    if not os.path.isfile(model_file):
        raise RuntimeError(f"{model_file} does not exist")
    state_dict = torch.load(model_file, map_location="cpu")
    for key, value in model.state_dict().items():
        if '_temp' in key:
            # Keep the freshly initialised value for '_temp' parameters.
            state_dict[key] = value
        if 'attn_fcross' in key:
            # Copy the attn1 parameters to attn_fcross. The previous code
            # rebound the key to the attn1 name and re-assigned that entry to
            # itself, so the attn_fcross entry was never created.
            state_dict[key] = state_dict[key.replace('attn_fcross', 'attn1')]
        if 'norm_fcross' in key:
            # Same copy for the matching norm layer: norm1 -> norm_fcross.
            state_dict[key] = state_dict[key.replace('norm_fcross', 'norm1')]

    model.load_state_dict(state_dict)

    return model

def _build_animatediff_pipeline(base_model, step):
    """Load the AnimateDiff-Lightning pipeline that produces the initial video.

    Args:
        base_model: Model id passed to ``AnimateDiffPipeline.from_pretrained``.
        step: Step count of the Lightning checkpoint to download; the original
            comment lists the options as [1, 2, 4, 8].

    Returns:
        The configured ``AnimateDiffPipeline``.
    """
    repo = "ByteDance/AnimateDiff-Lightning"
    # The checkpoint name follows ``step`` (it used to be hard-coded to 8).
    ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"

    adapter = MotionAdapter().to(device, dtype)
    adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
    pipe = AnimateDiffPipeline.from_pretrained(base_model, motion_adapter=adapter, torch_dtype=dtype).to(device)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
    )
    pipe.enable_model_cpu_offload()
    pipe.enable_vae_slicing()
    return pipe


def _build_merge_pipeline(lora_path):
    """Load the T2I and T2V models and assemble them into a ``Merge_Best_Pipeline``.

    Args:
        lora_path: Path of the LoRA weights applied to the T2I pipeline via
            ``add_lora_weight``.

    Returns:
        The configured ``Merge_Best_Pipeline``.
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    unet3d_id = "models/stable-diffusion-v1-4/unet"
    text3d_id = "models/stable-diffusion-v1-4/text_encoder"
    tokenizer3d_id = "models/stable-diffusion-v1-4/tokenizer"
    video_unet_id = "models/lavie_base.pt"

    #### Load T2I model
    zero_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    zero_pipe = add_lora_weight(zero_pipe, lora_path)

    #### Load T2V model
    unet3d = UNet3DConditionModel.from_pretrained_2d(unet3d_id).to(dtype=dtype)
    unet3d.load_state_dict(find_model(video_unet_id))
    tokenizer3d = CLIPTokenizer.from_pretrained(tokenizer3d_id)
    text_encoder_3d = CLIPTextModel.from_pretrained(text3d_id, torch_dtype=dtype)

    pipe = Merge_Best_Pipeline(
        vae=zero_pipe.vae,
        text_encoder=zero_pipe.text_encoder,
        tokenizer=zero_pipe.tokenizer,
        unet=zero_pipe.unet,
        unet3d=unet3d,
        tokenizer3d=tokenizer3d,
        scheduler=zero_pipe.scheduler,
        text_encoder3d=text_encoder_3d,
        safety_checker=zero_pipe.safety_checker,
        feature_extractor=zero_pipe.feature_extractor,
        requires_safety_checker=False,
    )

    #### T2I scheduler
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

    #### T2V scheduler
    pipe.video_scheduler = DDIMScheduler.from_pretrained(
        "models/stable-diffusion-v1-4/",
        subfolder="scheduler",
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
    )
    pipe.video_inverse_scheduler = DDIMInverseScheduler.from_config(pipe.video_scheduler.config)
    pipe.enable_vae_slicing()
    pipe.enable_model_cpu_offload()
    pipe.enable_xformers_memory_efficient_attention()
    # NOTE(review): moving the pipeline to the GPU after enabling CPU offload
    # is kept from the original script; confirm whether both are intended.
    pipe.to(device)
    return pipe


def main():
    """Generate one video per prompt in ``all_category.txt`` and save it as MP4.

    For each prompt an initial clip is produced with the AnimateDiff pipeline
    and passed as ``video`` to the merge pipeline; the resulting frames are
    written to ``our_lavie/<prompt>-0.mp4``.
    """
    pos_prompt = " ,best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
    neg_prompt = "text, watermark, copyright, blurry, nsfw, noise, quick motion, bad quality, flicker, dirty, ugly, fast motion, quick cuts, fast editing, cuts"

    prompt_list = read_txt_file("all_category.txt")
    save_dir = "our_lavie/"
    num_frames = 16
    step = 8  # Options: [1,2,4,8]
    stable_steps = [96, 93]
    lora_path = "models/toonyou_beta6.safetensors"
    base = "frankjoshua/toonyou_beta6"

    ensure_directory_exists(save_dir)
    if not prompt_list:
        # Nothing to generate; skip the expensive model loading.
        return

    # Model loading does not depend on the prompt, so it is done once here
    # instead of being repeated for every prompt.
    animatediff_pipe = _build_animatediff_pipeline(base, step)
    merge_pipe = _build_merge_pipeline(lora_path)

    for prompt in prompt_list:
        if not prompt:
            # read_txt_file keeps blank lines as empty strings; skip them.
            continue

        frames = animatediff_pipe(
            prompt=prompt + pos_prompt,
            negative_prompt=neg_prompt,
            guidance_scale=1.0,
            num_inference_steps=step,
            num_frames=num_frames,
            height=512,
            width=512,
        ).frames[0]
        video = list(frames)

        # Build the output path from the fixed directory on every iteration;
        # appending to a running string corrupted every path after the first.
        # Path separators inside the prompt are replaced so the file stays
        # inside save_dir.
        out_file = os.path.join(save_dir, prompt.replace(os.sep, "_") + "-0.mp4")

        # torch.Generator's positional argument is a device, not a seed, so
        # the seed is set explicitly. A fresh generator per prompt gives every
        # prompt the same seed.
        # NOTE(review): the generator is created on `device`; confirm that
        # Merge_Best_Pipeline samples its noise on that device.
        generator = torch.Generator(device=device).manual_seed(42)

        output = merge_pipe(
            prompt=prompt + pos_prompt,
            stable_steps=stable_steps,
            stable_num=1,
            negative_prompt=neg_prompt,
            num_inference_steps=100,
            alpha=1.5,
            video_length=num_frames,
            generator=generator,
            strength=0.1,
            video=video,
        )

        result = [(r * 255).astype("uint8") for r in output.images]
        imageio.mimsave(out_file, result, fps=8)


if __name__ == "__main__":
    main()
