import torch
import os
import glob
import argparse
import numpy as np
from PIL import Image
import torch.nn.functional as F
from tqdm import tqdm


from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor, CLIPVisionModel

from pipelines.pipeline_spider import StableDiffusionControlNetPipeline
from utils.misc import load_dreambooth_lora
from utils.wavelet_color_fix import wavelet_color_fix, adain_color_fix

from models.mapper import Mapper, Remover
from models.controlnet_inj import ControlNetModel
from models.unet_2d_condition_inj import UNet2DConditionModel
from llava.llm_agent import LLavaAgent

from gfpgan import GFPGANer
import cv2
import torch
import numpy as np
from PIL import Image


def load_gfpgan(model_path=None, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
    """Construct a ``GFPGANer`` face restorer.

    Args:
        model_path: Forwarded to ``GFPGANer(model_path=...)``.
        upscale: Forwarded to ``GFPGANer(upscale=...)``.
        arch: Forwarded to ``GFPGANer(arch=...)``.
        channel_multiplier: Forwarded to ``GFPGANer(channel_multiplier=...)``.
        bg_upsampler: Forwarded to ``GFPGANer(bg_upsampler=...)``.
        device: Device to run on. When ``None``, CUDA is selected if
            ``torch.cuda.is_available()`` and CPU otherwise.

    Returns:
        The constructed ``GFPGANer`` instance.
    """
    target_device = device
    if target_device is None:
        backend = 'cuda' if torch.cuda.is_available() else 'cpu'
        target_device = torch.device(backend)

    restorer_options = dict(
        model_path=model_path,
        upscale=upscale,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=bg_upsampler,
        device=target_device,
    )
    return GFPGANer(**restorer_options)


def process_image_with_gfpgan(gfpgan, img, only_center_face=False, paste_back=True):
    """Restore faces in ``img`` with GFPGAN and return an RGB PIL image.

    ``GFPGANer.enhance`` follows the OpenCV convention (BGR channel order for
    both its input and its pasted-back output), so this function converts to
    BGR before the call and back to RGB afterwards.

    Args:
        gfpgan: Restorer exposing ``enhance(img, has_aligned=...,
            only_center_face=..., paste_back=...)`` and returning a 3-tuple
            whose last element is the restored full image.
        img: A file path, a ``PIL.Image.Image``, or a numpy array. Arrays
            with three channels are assumed to be RGB (NOTE(review): confirm
            against callers); other array shapes are passed through as-is.
        only_center_face: Forwarded to ``gfpgan.enhance``.
        paste_back: Forwarded to ``gfpgan.enhance``.

    Returns:
        The restored image as a ``PIL.Image.Image`` in RGB order.

    Raises:
        ValueError: If the path cannot be decoded by OpenCV, or if
            ``enhance`` returns no full image (e.g. ``paste_back=False``).
    """
    if isinstance(img, str):
        # cv2.imread already yields BGR, which is what enhance() expects.
        bgr_img = cv2.imread(img)
        if bgr_img is None:
            raise ValueError(f"Could not read image: {img}")
    else:
        if isinstance(img, Image.Image):
            # Normalise palette / grayscale / RGBA inputs to 3-channel RGB.
            img = np.array(img.convert("RGB"))
        if isinstance(img, np.ndarray) and img.ndim == 3 and img.shape[2] == 3:
            # RGB -> BGR; make it contiguous because OpenCV routines need that.
            bgr_img = np.ascontiguousarray(img[:, :, ::-1])
        else:
            bgr_img = img

    _, _, restored_bgr = gfpgan.enhance(
        bgr_img,
        has_aligned=False,
        only_center_face=only_center_face,
        paste_back=paste_back
    )

    if restored_bgr is None:
        raise ValueError(
            "GFPGAN returned no restored image; a full image is only "
            "produced when paste_back=True")

    if restored_bgr.ndim == 3 and restored_bgr.shape[2] == 3:
        # BGR -> RGB for PIL.
        restored_rgb = np.ascontiguousarray(restored_bgr[:, :, ::-1])
    else:
        restored_rgb = restored_bgr

    return Image.fromarray(restored_rgb)


def get_inj_embedding(image, mapper, Remover, clip_image_processor, image_encoder_without_proj, reshape=False):
    """Compute the injected embedding for an image (or a batch of images).

    The image is preprocessed with ``clip_image_processor``, encoded with
    ``image_encoder_without_proj``, and the first element of the encoder
    output is passed through ``mapper`` and then ``Remover``.

    Args:
        image: Either a ``torch.Tensor`` batch indexed as ``image[i]`` with
            channel-first layout (it is permuted with ``(1, 2, 0)``), or any
            single image object accepted by ``clip_image_processor``.
            Tensor values are assumed to lie in [0, 1] — TODO confirm; values
            outside that range are clamped before the uint8 conversion.
        mapper: Module called with a one-element list holding the encoder
            features.
        Remover: Module called on the mapper output.
        clip_image_processor: Callable accepting ``images=`` and
            ``return_tensors='pt'`` and returning an object with
            ``pixel_values``.
        image_encoder_without_proj: Encoder module exposing ``.device`` and
            ``.parameters()``; its device and parameter dtype are used for
            the inputs and for ``mapper`` / ``Remover``.
        reshape: When true, the result is reshaped to ``(batch, -1, 512)``.

    Returns:
        The embedding tensor produced by ``Remover(mapper([features]))``.
    """
    if isinstance(image, torch.Tensor):
        # Convert to float32 on CPU first: numpy has no bfloat16, and clamping
        # avoids uint8 wrap-around for slightly out-of-range values.
        frames = image.detach().float().cpu().clamp(0, 1)
        processed_images = [
            Image.fromarray(
                (frame.permute(1, 2, 0).numpy() * 255).astype(np.uint8))
            for frame in frames
        ]
        gt_clip_image = clip_image_processor(
            images=processed_images, return_tensors='pt').pixel_values
    else:
        gt_clip_image = clip_image_processor(
            images=[image], return_tensors='pt').pixel_values

    encoder_device = image_encoder_without_proj.device
    encoder_dtype = next(image_encoder_without_proj.parameters()).dtype
    gt_clip_image = gt_clip_image.to(device=encoder_device, dtype=encoder_dtype)

    # The encoder features are detached below anyway, so skip building an
    # autograd graph through the encoder at all.
    with torch.no_grad():
        gt_clip_image = F.interpolate(
            gt_clip_image, (224, 224), mode='bilinear')
        image_features = image_encoder_without_proj(
            gt_clip_image, output_hidden_states=True)
    image_embeddings = [image_features[0].detach()]

    mapper = mapper.to(device=encoder_device, dtype=encoder_dtype)
    Remover = Remover.to(device=encoder_device, dtype=encoder_dtype)

    inj_embedding = Remover(mapper(image_embeddings))

    if reshape:
        inj_embedding = inj_embedding.reshape(inj_embedding.shape[0], -1, 512)
    return inj_embedding


def load_seesr_pipeline(args, accelerator, enable_xformers_memory_efficient_attention, vae, text_encoder, tokenizer, feature_extractor, unet, controlnet, mapper, Remover, scheduler):
    """Assemble the ControlNet validation pipeline from preloaded components.

    All passed modules are frozen, the pipeline is built with tiled VAE
    enabled, and every module is moved to ``accelerator.device`` in the dtype
    implied by ``accelerator.mixed_precision``.

    Args:
        args: Namespace providing ``vae_encoder_tiled_size`` and
            ``vae_decoder_tiled_size``.
        accelerator: Object providing ``mixed_precision`` and ``device``.
        enable_xformers_memory_efficient_attention: When true, xformers
            attention is enabled on ``unet`` and ``controlnet``.
        vae, text_encoder, tokenizer, feature_extractor, unet, controlnet,
        scheduler: Components handed to ``StableDiffusionControlNetPipeline``.
        mapper, Remover: Extra modules that are frozen and moved alongside
            the pipeline components (they are not passed to the pipeline).

    Returns:
        The constructed ``StableDiffusionControlNetPipeline``.

    Raises:
        ValueError: If xformers is requested but not available.
    """
    # Inference only: switch off gradient tracking on every module.
    for frozen_module in (vae, text_encoder, unet, controlnet, mapper, Remover):
        frozen_module.requires_grad_(False)

    if enable_xformers_memory_efficient_attention:
        if not is_xformers_available():
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly")
        unet.enable_xformers_memory_efficient_attention()
        controlnet.enable_xformers_memory_efficient_attention()

    # Build the validation pipeline (no safety checker).
    pipeline_components = dict(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=None,
        requires_safety_checker=False,
    )
    validation_pipeline = StableDiffusionControlNetPipeline(**pipeline_components)

    validation_pipeline._init_tiled_vae(
        encoder_tile_size=args.vae_encoder_tiled_size,
        decoder_tile_size=args.vae_decoder_tiled_size,
    )

    # Anything other than fp16 / bf16 falls back to full precision.
    dtype_by_precision = {"fp16": torch.float16, "bf16": torch.bfloat16}
    weight_dtype = dtype_by_precision.get(
        accelerator.mixed_precision, torch.float32)

    for module in (text_encoder, vae, unet, controlnet, mapper, Remover):
        module.to(accelerator.device, dtype=weight_dtype)

    return validation_pipeline


def load_llava_model(args, device='cuda'):
    """Create the ``LLavaAgent`` used for image captioning.

    Args:
        args: Namespace providing ``llava_model_path``, ``load_8bit`` and
            ``load_4bit``.
        device: Device handed to ``LLavaAgent``.

    Returns:
        The constructed ``LLavaAgent``.
    """
    return LLavaAgent(
        args.llava_model_path,
        device=device,
        load_8bit=args.load_8bit,
        load_4bit=args.load_4bit,
    )


def get_text_prompt(args, image, llava, device='cuda'):
    """Caption one or more images with the LLaVA agent.

    Args:
        args: Namespace providing ``prompt``, the question passed to the
            captioner as ``qs``.
        image: A file path, an already-loaded image object, or a list whose
            items are file paths and/or already-loaded images. Paths are
            opened with PIL and converted to RGB; other items are passed
            through unchanged.
        llava: Object exposing ``gen_image_caption(images, qs=...)``.
        device: Unused; kept for backward compatibility with existing callers.

    Returns:
        Whatever ``llava.gen_image_caption`` returns for the image list
        (callers in this file index it with ``[0]``).
    """
    # Normalise to a list so single items and batches share one code path.
    items = image if isinstance(image, list) else [image]

    # Previously a list of already-loaded images was rejected even though a
    # single loaded image was accepted; handle each element individually.
    images = [
        Image.open(item).convert("RGB") if isinstance(item, str) else item
        for item in items
    ]

    return llava.gen_image_caption(images, qs=args.prompt)


def main(args,
         vae=None,
         text_encoder=None,
         tokenizer=None,
         unet=None,
         controlnet=None,
         mapper=None,
         Remover=None,
         image_encoder_without_proj=None,
         clip_image_processor=None,
         scheduler=None
         ):
    """Run restoration over one image or a directory of images.

    For every input image: optionally restore faces with GFPGAN (used only
    for captioning), caption it with LLaVA, upscale and resize the input,
    compute the injected embedding, run the diffusion pipeline
    ``args.sample_times`` times, apply the selected color fix and save the
    result as PNG in ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. Fields read here:
            ``output_dir``, ``mixed_precision``, ``seed``, ``use_gfpgan``,
            ``gfpgan_model_path``, ``image_path``, ``resume_image_num``,
            ``negative_prompt``, ``save_prompts``, ``upscale``,
            ``process_size``, ``sample_times``, ``num_inference_steps``,
            ``guidance_scale``, ``conditioning_scale``, ``start_point``,
            ``latent_tiled_size``, ``latent_tiled_overlap`` and
            ``align_method``; the namespace is also forwarded to the helpers
            and to the pipeline call.
        vae, text_encoder, tokenizer, unet, controlnet, scheduler:
            Preloaded pipeline components.
        mapper, Remover: Modules used to build the injected embedding.
        image_encoder_without_proj: Image encoder used for the embedding;
            moved to the accelerator device.
        clip_image_processor: Image preprocessor; also passed to the
            pipeline as its feature extractor.

    Output files:
        ``{output_dir}/{stem}.png`` when ``sample_times`` is 1, otherwise
        ``{output_dir}/{stem}_sampleNN.png`` per sample so that samples do
        not overwrite each other; ``{output_dir}/txt/{stem}.txt`` holds the
        caption when ``save_prompts`` is set.
    """
    txt_path = os.path.join(args.output_dir, 'txt')
    os.makedirs(txt_path, exist_ok=True)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )
    image_encoder_without_proj = image_encoder_without_proj.to(
        accelerator.device)
    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
        accelerator.init_trackers("SeeSR")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    pipeline = load_seesr_pipeline(
        args, accelerator, False, vae, text_encoder, tokenizer, clip_image_processor, unet, controlnet, mapper, Remover, scheduler)
    llava = load_llava_model(args, accelerator.device)
    if args.use_gfpgan:
        gfpgan = load_gfpgan(args.gfpgan_model_path, device=accelerator.device)
    else:
        gfpgan = None

    # Only the main process performs inference.
    if not accelerator.is_main_process:
        return

    generator = torch.Generator(device=accelerator.device)
    if args.seed is not None:
        generator.manual_seed(args.seed)

    if os.path.isdir(args.image_path):
        image_names = sorted(glob.glob(f'{args.image_path}/*.*'))
    else:
        image_names = [args.image_path]

    # Skip the first `resume_image_num` images to resume an interrupted run.
    pending_images = image_names[args.resume_image_num:]

    for image_idx, image_name in enumerate(tqdm(pending_images, desc="Processing images")):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(
            f'================== process {image_idx} imgs... ===================')
        print(image_name)

        # One stem for both the caption file and the output image, so that
        # dotted names such as "a.b.png" are not truncated to "a".
        stem = os.path.splitext(os.path.basename(image_name))[0]

        validation_image = Image.open(image_name).convert("RGB")
        if gfpgan is not None:
            # The face-restored image is only used to obtain a better caption.
            restored_image = process_image_with_gfpgan(
                gfpgan, validation_image)
        else:
            restored_image = validation_image
        text_prompt = get_text_prompt(args, restored_image, llava)[0]
        negative_prompt = args.negative_prompt

        if args.save_prompts:
            txt_save_path = os.path.join(txt_path, f"{stem}.txt")
            with open(txt_save_path, "w", encoding="utf-8") as prompt_file:
                prompt_file.write(text_prompt)
        print(f'{text_prompt}')

        ori_width, ori_height = validation_image.size
        rscale = args.upscale

        # Enlarge small inputs so the shorter side reaches
        # process_size // upscale before upscaling.
        min_side = args.process_size // rscale
        if ori_width < min_side or ori_height < min_side:
            scale = min_side / min(ori_width, ori_height)
            validation_image = validation_image.resize(
                (int(scale * ori_width), int(scale * ori_height)))

        validation_image = validation_image.resize(
            (validation_image.size[0] * rscale, validation_image.size[1] * rscale))
        # Round both sides down to a multiple of 8.
        validation_image = validation_image.resize(
            (validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8))
        width, height = validation_image.size

        print(f'input size: {height}x{width}')

        inj_embedding = get_inj_embedding(
            validation_image, mapper, Remover, clip_image_processor, image_encoder_without_proj)

        for sample_idx in range(args.sample_times):
            print(
                f'   Generating sample {sample_idx+1}/{args.sample_times}...')
            with torch.autocast("cuda"):
                image = pipeline(
                    text_prompt, validation_image, num_inference_steps=args.num_inference_steps, generator=generator, height=height, width=width,
                    guidance_scale=args.guidance_scale, negative_prompt=negative_prompt, conditioning_scale=args.conditioning_scale,
                    start_point=args.start_point, ram_encoder_hidden_states=inj_embedding,
                    latent_tiled_size=args.latent_tiled_size, latent_tiled_overlap=args.latent_tiled_overlap,
                    args=args,
                ).images[0]

            if args.align_method == 'wavelet':
                corrected_image = wavelet_color_fix(image, validation_image)
            elif args.align_method == 'adain':
                corrected_image = adain_color_fix(image, validation_image)
            else:
                # 'nofix' and any unknown value: no color correction.
                corrected_image = image

            # The working image was resized above (upscale, multiple-of-8
            # rounding), so always map back to the exact target size.
            final_image = corrected_image.resize(
                (ori_width * rscale, ori_height * rscale))

            # With several samples, give each its own file instead of
            # overwriting `{stem}.png` on every iteration.
            if args.sample_times > 1:
                file_name = f'{stem}_sample{sample_idx:02d}.png'
            else:
                file_name = f'{stem}.png'
            save_path = f'{args.output_dir}/{file_name}'
            final_image.save(save_path)
            print(f'   Saved sample {sample_idx+1} to {save_path}')


if __name__ == "__main__":
    # Script entry point: parse the command line, load every model component
    # from disk, and hand them to main().

    def _str2bool(value):
        """Parse a command-line boolean such as ``True`` / ``false`` / ``0``.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        ``True`` because any non-empty string is truthy.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "on", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "off", "0", ""):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--spider_model_path", type=str,
                        required=True)
    parser.add_argument("--image_encoder_without_proj_path",
                        type=str, required=True)
    parser.add_argument("--pretrained_sd2_1_path", type=str,
                        required=True)
    parser.add_argument("--mapper_model_path", type=str,
                        required=True)
    parser.add_argument("--remover_model_path", type=str,
                        required=True)
    parser.add_argument("--llava_model_path", type=str, required=True)
    parser.add_argument("--negative_prompt", type=str,
                        default="dotted, noise, blur, lowres, smooth")
    parser.add_argument("--image_path", type=str,
                        required=True)
    parser.add_argument("--sample_times", type=int, default=1)
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--guidance_scale", type=float, default=5.5)
    parser.add_argument("--conditioning_scale", type=float, default=1.0)
    parser.add_argument("--start_point", type=str,
                        choices=['lr', 'noise'], default='lr')
    parser.add_argument("--align_method", type=str,
                        choices=['wavelet', 'adain', 'nofix'], default='adain')
    parser.add_argument("--latent_tiled_size", type=int, default=96)
    parser.add_argument("--latent_tiled_overlap", type=int, default=32)
    parser.add_argument("--vae_encoder_tiled_size", type=int, default=1024)
    parser.add_argument("--vae_decoder_tiled_size", type=int, default=224)
    parser.add_argument("--process_size", type=int, default=512)
    parser.add_argument("--upscale", type=int, default=1)
    parser.add_argument("--output_dir", type=str,
                        required=True)
    parser.add_argument("--save_prompts", default=True, type=_str2bool)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--mixed_precision", type=str, default="fp16")
    # The stray backtick that used to precede "texture" was a typo.
    parser.add_argument("--prompt", type=str,
                        default="Provide a detailed yet concise description of this person's face. Include their face shape, eyes, nose, mouth, eyebrows, skin texture and tone, expression, and any notable features like moles, freckles, or wrinkles.")
    parser.add_argument("--load_8bit", action='store_true')
    parser.add_argument("--load_4bit", action='store_true')
    parser.add_argument("--inj_num_token", type=int, default=30)
    # NOTE(review): the original comment here read "scface 110" — presumably
    # a resume index used for one particular dataset run; confirm.
    parser.add_argument("--resume_image_num", type=int, default=0)
    parser.add_argument("--use_gfpgan", type=_str2bool, default=True)
    # Only needed when GFPGAN is enabled; validated after parsing.
    parser.add_argument("--gfpgan_model_path", type=str, default=None)

    args = parser.parse_args()

    if args.use_gfpgan and not args.gfpgan_model_path:
        parser.error("--gfpgan_model_path is required when --use_gfpgan is true")

    # Base Stable Diffusion components come from --pretrained_sd2_1_path.
    scheduler = DDPMScheduler.from_pretrained(
        args.pretrained_sd2_1_path, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_sd2_1_path, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_sd2_1_path, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_sd2_1_path, subfolder="vae")

    # Fine-tuned UNet and ControlNet come from --spider_model_path.
    unet = UNet2DConditionModel.from_pretrained(
        args.spider_model_path, subfolder="unet")
    controlnet = ControlNetModel.from_pretrained(
        args.spider_model_path, subfolder="controlnet")

    # SECURITY NOTE: torch.load unpickles the checkpoint; only load mapper /
    # remover checkpoints from trusted sources.
    mapper = Mapper(input_dim=1024, output_dim=1024,
                    num_words=args.inj_num_token)
    mapper_ckpt = torch.load(
        args.mapper_model_path, map_location="cpu")
    mapper.load_state_dict(mapper_ckpt, strict=False)

    remover = Remover(
        input_dim=1024, output_dim=1024, num_words=args.inj_num_token)
    remover_ckpt = torch.load(
        args.remover_model_path, map_location="cpu")
    remover.load_state_dict(remover_ckpt, strict=False)

    image_encoder_without_proj = CLIPVisionModel.from_pretrained(
        args.image_encoder_without_proj_path)
    clip_image_processor = CLIPImageProcessor()

    main(args,
         vae=vae,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
         unet=unet,
         controlnet=controlnet,
         mapper=mapper,
         Remover=remover,
         image_encoder_without_proj=image_encoder_without_proj,
         clip_image_processor=clip_image_processor,
         scheduler=scheduler)
