import argparse
import datetime
import inspect
import os
from omegaconf import OmegaConf

import torch
import torchvision.transforms as transforms
import torch.nn.functional as F

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.models.sparse_controlnet import SparseControlNetModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.util import load_weights
from diffusers.utils.import_utils import is_xformers_available
from animatediff.models.additional import LatentRectify
from animatediff.models.vae_decoder import DetailAutoencoderKL

from einops import rearrange, repeat

import csv, pdb, glob, math
from pathlib import Path
from PIL import Image
import numpy as np


def _load_checkpoint(path, *wrapper_keys):
    """Load a checkpoint onto the CPU and unwrap nested containers.

    Each key in ``wrapper_keys`` is tried in order; when it is present the
    checkpoint is replaced by the value stored under it (e.g. a
    ``{"state_dict": {...}}`` wrapper).

    NOTE: ``torch.load`` unpickles the file -- only load trusted checkpoints.
    """
    checkpoint = torch.load(path, map_location="cpu")
    for key in wrapper_keys:
        if key in checkpoint:
            checkpoint = checkpoint[key]
    return checkpoint


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with every ``module.`` removed from the keys.

    Presumably these prefixes come from a DataParallel-style training wrapper.
    """
    return {name.replace('module.', ''): param for name, param in state_dict.items()}


def _load_image(path, mode, transform):
    """Open the image at ``path``, convert it to PIL ``mode`` and apply ``transform``.

    The file handle is closed before returning.
    """
    with Image.open(path) as image:
        return transform(image.convert(mode))


def _save_preview(image, path):
    """Save a (c, h, w) CPU tensor (values expected in [0, 1]) as an 8-bit image."""
    array = (255. * image.numpy().transpose(1, 2, 0)).astype(np.uint8)
    if array.shape[-1] == 1:
        # Single-channel masks are written as plain (h, w) grayscale images.
        array = array.squeeze()
    Image.fromarray(array).save(path)


def _normalize_condition_image(image):
    """Collapse a (c, h, w) image to grayscale, replicate it to 3 channels and min-max scale it to [0, 1]."""
    image = image.mean(dim=0, keepdim=True).repeat(3, 1, 1)
    image -= image.min()
    # Guard: a constant image would otherwise be divided by zero and become NaN.
    image /= image.max().clamp(min=1e-8)
    return image


def _read_condition_images(image_paths, image_transform, savedir, normalize=False):
    """Load RGB condition frames, write PNG previews and return them stacked on the GPU.

    Previews go to ``<savedir>/control_images/<i>.png``. The returned tensor
    has layout (1, c, f, h, w).
    """
    images = [_load_image(path, "RGB", image_transform) for path in image_paths]
    if normalize:
        images = [_normalize_condition_image(image) for image in images]
    os.makedirs(os.path.join(savedir, "control_images"), exist_ok=True)
    for i, image in enumerate(images):
        _save_preview(image, f"{savedir}/control_images/{i}.png")
    stacked = torch.stack(images).unsqueeze(0).cuda()
    return rearrange(stacked, "b f c h w -> b c f h w")


def _encode_frames(vae, frames):
    """Encode (b, c, f, h, w) pixel frames (expected in [0, 1]) into scaled VAE latents.

    Returns the sampled latents multiplied by 0.18215, laid out as
    (b, c_latent, f, h_latent, w_latent).
    """
    num_frames = frames.shape[2]
    flat = rearrange(frames, "b c f h w -> (b f) c h w")
    latents = vae.encode(flat * 2. - 1.).latent_dist.sample() * 0.18215
    return rearrange(latents, "(b f) c h w -> b c f h w", f=num_frames)


def _load_vae(pretrained_model_path, detail_vae_path):
    """Build the VAE on CUDA in eval mode with slicing enabled.

    An empty ``detail_vae_path`` gives the stock ``AutoencoderKL``. Otherwise a
    ``DetailAutoencoderKL`` is built, and only its extra decoder modules
    (``ref_attn_blocks`` / ``channel_matching_convs``) are loaded from the
    checkpoint at ``detail_vae_path``.
    """
    if detail_vae_path == "":
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
    else:
        # low_cpu_mem_usage=False since it has new parameters
        vae = DetailAutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", low_cpu_mem_usage=False)
        print(f"load pretrained vae decoder from {detail_vae_path}")
        state_dict = _load_checkpoint(detail_vae_path, "state_dict")
        extra_param_state_dict = {
            key.replace('module.', ''): param
            for key, param in state_dict.items()
            if 'ref_attn_blocks' in key or 'channel_matching_convs' in key
        }
        _, unexpected = vae.decoder.load_state_dict(extra_param_state_dict, strict=False)
        assert len(unexpected) == 0, f"unexpected keys in detail VAE checkpoint: {unexpected}"
        vae.cuda()
    vae.eval()
    vae.enable_slicing()
    return vae


def _load_unet(pretrained_model_path, inference_config, unet_path):
    """Inflate the 2D UNet into a ``UNet3DConditionModel`` on CUDA.

    If ``unet_path`` is non-empty, that checkpoint is loaded on top of it
    non-strictly: missing keys are tolerated, unexpected keys are not.
    """
    unet = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path, subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).cuda()
    unet.eval()
    if unet_path != "":
        print(f"load pretrained unet from {unet_path}")
        unet_state_dict = _strip_module_prefix(_load_checkpoint(unet_path, "state_dict"))
        _, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
        assert len(unexpected) == 0, f"unexpected keys in unet checkpoint: {unexpected}"
    return unet


def _load_controlnet(unet, model_config):
    """Build a ``SparseControlNetModel`` from ``unet`` and load its checkpoint, on CUDA in eval mode.

    NOTE: mutates ``unet.config`` (attention heads / class-embedding dim) before
    cloning it, as the original script does.
    """
    unet.config.num_attention_heads = 8
    unet.config.projection_class_embeddings_input_dim = None

    controlnet_config = OmegaConf.load(model_config.controlnet_config)
    controlnet = SparseControlNetModel.from_unet(
        unet, controlnet_additional_kwargs=controlnet_config.get("controlnet_additional_kwargs", {}))

    print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...")
    controlnet_state_dict = _load_checkpoint(model_config.controlnet_path, "controlnet", "state_dict")
    controlnet.load_state_dict(_strip_module_prefix(controlnet_state_dict))
    controlnet.cuda()
    controlnet.eval()
    return controlnet


def _load_latent_rectify(model_config):
    """Return a CUDA ``LatentRectify`` module, or None when ``latent_rectify_scale`` is not > 0."""
    if not model_config.get("latent_rectify_scale", 0.0) > 0.0:
        return None
    assert model_config.get("latent_rectify_module_file", "") != "", \
        "latent_rectify_scale > 0 requires latent_rectify_module_file"
    module = LatentRectify(inner_dim=model_config.get("latent_rectify_dim", 256))
    state_dict = _load_checkpoint(model_config.latent_rectify_module_file, "LatentRectify", "state_dict")
    module.load_state_dict(_strip_module_prefix(state_dict))
    module.cuda()
    return module


@torch.no_grad()
def main(args):
    """Sample animations for every model entry listed in ``args.config``.

    For each entry of the config the function:
    - loads the VAE and UNet, plus an optional sparse ControlNet (with its
      condition images/masks) and an optional LatentRectify module;
    - builds an ``AnimationPipeline``;
    - samples one video per prompt.

    Outputs are written to a new ``samples/<name>-<timestamp>`` directory:
    - ``sample/<idx>-<prompt>.gif`` for each prompt;
    - ``control_images/`` previews of any condition inputs;
    - a ``sample.gif`` grid of all samples;
    - ``config.yaml``, which records the seeds actually used.

    Args:
        args: parsed CLI namespace with ``config``, ``pretrained_model_path``,
            ``inference_config`` and fallback ``L``/``W``/``H``.

    Raises:
        IndexError: if ``args.config`` has fewer than three ``/``-separated
            components (the output directory name is built from them).
        FileExistsError: if the output directory already exists.
    """
    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{args.config.split('/')[-3]}-{args.config.split('/')[-2]}-{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)

    config = OmegaConf.load(args.config)
    samples = []

    # The text side of the pipeline is shared by every model entry.
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").cuda()
    text_encoder.eval()

    sample_idx = 0
    for model_idx, model_config in enumerate(config):
        model_config.W = model_config.get("W", args.W)
        model_config.H = model_config.get("H", args.H)
        model_config.L = model_config.get("L", args.L)
        detail_vae_path = model_config.get("detail_vae_path", "")
        use_detail_vae = detail_vae_path != ""

        inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))

        vae = _load_vae(args.pretrained_model_path, detail_vae_path)
        unet = _load_unet(args.pretrained_model_path, inference_config, model_config.get("unet_path", ""))

        # Reset all per-entry conditioning, so nothing leaks in from the previous
        # entry and nothing is referenced before assignment.
        controlnet = controlnet_images = controlnet_masks = ref_pyramid_features = None
        image_transforms = transforms.Compose([
            transforms.Resize((model_config.H, model_config.W)),
            transforms.ToTensor(),
        ])

        if model_config.get("controlnet_path", "") != "":
            assert model_config.get("controlnet_images", "") != "", "controlnet_path requires controlnet_images"
            assert model_config.get("controlnet_config", "") != "", "controlnet_path requires controlnet_config"
            controlnet = _load_controlnet(unet, model_config)

            image_paths = model_config.controlnet_images
            if isinstance(image_paths, str):
                image_paths = [image_paths]
            has_masks = "controlnet_masks" in model_config.keys()
            if has_masks:
                mask_paths = model_config.controlnet_masks
                if isinstance(mask_paths, str):
                    mask_paths = [mask_paths]
                assert len(image_paths) == len(mask_paths)

            print(f"controlnet image paths:")
            for path in image_paths:
                print(path)
            assert len(image_paths) <= model_config.L

            controlnet_images = _read_condition_images(
                image_paths, image_transforms, savedir,
                normalize=model_config.get("normalize_condition_images", False),
            )
            if has_masks:
                mask_frames = [_load_image(path, "L", image_transforms) for path in mask_paths]
                for i, mask in enumerate(mask_frames):
                    _save_preview(mask, f"{savedir}/control_images/{i}_mask.png")

            if use_detail_vae:
                # Reference features come from the first condition frame, in pixel space.
                ref_pyramid_features = vae.encoder.pyramid_feature_forward(controlnet_images[:, :, 0, :, :] * 2. - 1.)

            if controlnet.use_simplified_condition_embedding:
                controlnet_images = _encode_frames(vae, controlnet_images)

            # Masks are resized to the spatial size the condition images now have
            # (latent size if they were just encoded, pixel size otherwise).
            cond_size = (controlnet_images.shape[-2], controlnet_images.shape[-1])
            if has_masks:
                controlnet_masks = F.interpolate(torch.stack(mask_frames).cuda(), cond_size).unsqueeze(0)
                if model_config.get("mask_oracle", False):
                    mask_dir = os.path.dirname(mask_paths[0])
                    num_of_mask = len(os.listdir(mask_dir))
                    if num_of_mask > 1:
                        # Pick L evenly spaced frames from the 1-indexed "%05d.jpg" files in mask_dir.
                        # max(..., 1) avoids a division by zero when L == 1.
                        num_of_sample = model_config.L
                        sample_ind = [int(i / max(num_of_sample - 1, 1) * (num_of_mask - 1)) + 1 for i in range(num_of_sample)]
                        oracle_masks = [
                            _load_image(os.path.join(mask_dir, "%05d.jpg" % max(1, min(ind, num_of_mask))), "L", image_transforms).cuda()
                            for ind in sample_ind
                        ]
                        controlnet_masks = F.interpolate(torch.stack(oracle_masks), cond_size).unsqueeze(0)
                controlnet_masks = rearrange(controlnet_masks, "b f c h w -> b c f h w")

            if "controlnet_rel_video" in model_config.keys():
                from decord import VideoReader
                rel_video_reader = VideoReader(model_config.controlnet_rel_video[0])
                rel_frame_index = np.linspace(0, len(rel_video_reader) - 1, model_config.L, dtype=int)
                rel_frames = rel_video_reader.get_batch(rel_frame_index).asnumpy()
                # RGB -> single grayscale channel using ITU-R BT.601 luma weights.
                rel_gray = np.dot(rel_frames, [0.299, 0.587, 0.114])
                rel_pixel_values = torch.from_numpy(rel_gray).unsqueeze(-1).permute(0, 3, 1, 2).contiguous().to(controlnet_images.dtype)
                # Keep these masks on the same device as every other mask path.
                rel_pixel_values = (rel_pixel_values / 255.).to(controlnet_images.device)
                controlnet_masks = F.interpolate(rel_pixel_values, cond_size).unsqueeze(0)
                controlnet_masks = rearrange(controlnet_masks, "b f c h w -> b c f h w")

        elif model_config.get("cfg_noise_gt_first_frame", False) or model_config.get("cfg_unnoise_first_frame", False):
            # No ControlNet: the condition images are only used (as latents) by the
            # first-frame guidance modes.
            image_paths = model_config.controlnet_images
            if isinstance(image_paths, str):
                image_paths = [image_paths]
            controlnet_images = _read_condition_images(image_paths, image_transforms, savedir)
            if use_detail_vae:
                ref_pyramid_features = vae.encoder.pyramid_feature_forward(controlnet_images[:, :, 0, :, :] * 2. - 1.)
            controlnet_images = _encode_frames(vae, controlnet_images)

        latent_rectify_module = _load_latent_rectify(model_config)

        # NOTE: xformers memory-efficient attention is intentionally not enabled
        # (diffusers version differences), so args.without_xformers is unused.
        pipeline = AnimationPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            controlnet=controlnet, latent_rectify_module=latent_rectify_module
        ).to("cuda")

        pipeline = load_weights(
            pipeline,
            # motion module
            motion_module_path         = model_config.get("motion_module", ""),
            motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
            # domain adapter
            adapter_lora_path          = model_config.get("adapter_lora_path", ""),
            adapter_lora_scale         = model_config.get("adapter_lora_scale", 1.0),
            # image layers
            dreambooth_model_path      = model_config.get("dreambooth_path", ""),
            lora_model_path            = model_config.get("lora_model_path", ""),
            lora_alpha                 = model_config.get("lora_alpha", 1.0),
        ).to("cuda")

        prompts   = model_config.prompt
        n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

        random_seeds = model_config.get("seed", [-1])
        random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
        random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

        # Optionally start sampling from a noised ground-truth video instead of pure noise.
        latents = None
        groundtruth_video_path = model_config.get("groundtruth_video_path", None)
        if groundtruth_video_path is not None:
            from decord import VideoReader
            video_reader = VideoReader(groundtruth_video_path[0])
            batch_index = np.linspace(0, len(video_reader) - 1, model_config.L, dtype=int)
            gt_video = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
            gt_video = transforms.Resize((model_config.H, model_config.W))(gt_video).cuda()
            gt_video_latent = vae.encode(gt_video / 255. * 2. - 1.).latent_dist.sample() * 0.18215
            gt_video_latent = rearrange(gt_video_latent.unsqueeze(0), "b f c h w -> b c f h w")
            noise = torch.randn_like(gt_video_latent)
            # Noise to timesteps[0], presumably the noisiest DDIM step.
            pipeline.scheduler.set_timesteps(model_config.steps)
            latents = pipeline.scheduler.add_noise(gt_video_latent, noise, pipeline.scheduler.timesteps[0])

        config[model_idx].random_seed = []
        for prompt, n_prompt, random_seed in zip(prompts, n_prompts, random_seeds):
            # manually set random seed for reproduction
            if random_seed != -1:
                torch.manual_seed(random_seed)
            else:
                torch.seed()
            config[model_idx].random_seed.append(torch.initial_seed())

            print(f"current seed: {torch.initial_seed()}")
            print(f"sampling {prompt} ...")
            sample = pipeline(
                prompt,
                negative_prompt     = n_prompt,
                num_inference_steps = model_config.steps,
                guidance_scale      = model_config.guidance_scale,
                width               = model_config.W,
                height              = model_config.H,
                video_length        = model_config.L,
                latents             = latents,

                controlnet_images      = controlnet_images,
                controlnet_masks       = controlnet_masks,
                controlnet_image_index = model_config.get("controlnet_image_indexs", [0]),
                latent_rectify_scale   = model_config.get("latent_rectify_scale", 0.0),
                latent_rectify_clamp   = model_config.get("latent_rectify_clamp", 1.0),

                cfg_noise_gt_first_frame = model_config.get("cfg_noise_gt_first_frame", False),
                cfg_unnoise_first_frame  = model_config.get("cfg_unnoise_first_frame", False),
                ref_features             = ref_pyramid_features,
            ).videos
            samples.append(sample)

            prompt = "-".join(prompt.replace("/", "").split(" ")[:10])
            sample_path = f"{savedir}/sample/{sample_idx}-{prompt}.gif"
            save_videos_grid(sample, sample_path)
            print(f"save to {sample_path}")

            sample_idx += 1

    # torch.concat would fail on an empty list (config with no prompts).
    if samples:
        save_videos_grid(torch.concat(samples), f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")


if __name__ == "__main__":
    # Script entry point: collect command-line options and hand them to main().
    cli_parser = argparse.ArgumentParser()

    # Where the base model, the inference settings and the sampling config live.
    cli_parser.add_argument("--pretrained-model-path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5")
    cli_parser.add_argument("--inference-config", type=str, default="configs/inference/inference-v1.yaml")
    cli_parser.add_argument("--config", type=str, required=True)

    # Fallback video length / width / height for config entries that omit L / W / H.
    for dim_flag, dim_default in (("--L", 16), ("--W", 512), ("--H", 512)):
        cli_parser.add_argument(dim_flag, type=int, default=dim_default)

    # Accepted on the command line but currently not read by main().
    cli_parser.add_argument("--without-xformers", action="store_true")

    args = cli_parser.parse_args()
    main(args)
