import os
import re
import json
import argparse

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline, DDIMScheduler, AutoencoderKL
from guidance.sc_adapter.selective_adapter_norm_detailed_clip import SC_Adapter
from guidance.ip_adapter import IPAdapter
from guidance.sc_adapter.pipeline_adapter_plus import StableDiffusionPipeline
from diffusers.utils import load_image


def resize_for_condition_image(input_image: "Image.Image", resolution: int) -> "Image.Image":
    """Return an RGB copy of *input_image* rescaled for use as a condition image.

    The image is scaled so that its shorter side is approximately
    ``resolution`` pixels (aspect ratio preserved), then each side is snapped
    to the nearest multiple of 64 and resampled with LANCZOS.

    Args:
        input_image: PIL image to rescale; it is converted to RGB first.
        resolution: Target length, in pixels, of the shorter side before
            snapping. Must be positive.

    Returns:
        A new RGB PIL image whose width and height are multiples of 64.

    Raises:
        ValueError: If ``resolution`` is not positive.
    """
    # Checked up front: a non-positive target would otherwise surface as an
    # opaque size error from inside ``resize``.
    if resolution <= 0:
        raise ValueError(f"resolution must be positive, got {resolution!r}")

    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    scale = float(resolution) / min(height, width)

    # Snap to the nearest multiple of 64, but never below a single 64-pixel
    # step: small targets or extreme aspect ratios would otherwise round a
    # side down to 0, which cannot be resized to.
    target_width = max(64, int(round(width * scale / 64.0)) * 64)
    target_height = max(64, int(round(height * scale / 64.0)) * 64)

    return rgb_image.resize((target_width, target_height), resample=Image.LANCZOS)


if __name__ == "__main__":
    # Script entry point. For every view listed in ``sc_adapter.json`` inside
    # the reference directory, load the ``*_ori.png`` source image, optionally
    # pass it through the IP-Adapter model, then run the ControlNet img2img
    # pipeline on it and save the result under the view's key in the same
    # directory. Views whose 'SR' flag is falsy are copied through unchanged.

    parser = argparse.ArgumentParser()
    parser.add_argument('--ip_adapter_ref_img', type=str, default='data/multi_view_test/doll', help="reference image path for ip adapter")
    opt = parser.parse_args()

    # Validate the input and read the per-view configuration BEFORE loading
    # any model: a wrong path should fail immediately with a usage message
    # rather than after the checkpoints have been loaded onto the GPU.
    ref_dir = opt.ip_adapter_ref_img
    if not os.path.isdir(ref_dir):
        parser.error(
            f"--ip_adapter_ref_img must be a directory containing sc_adapter.json, got: {ref_dir!r}"
        )
    with open(os.path.join(ref_dir, 'sc_adapter.json'), 'r', encoding='utf-8') as f:
        sc_view_dict = json.load(f)

    # ControlNet (tile checkpoint) img2img pipeline, fp16 on CUDA.
    controlnet = ControlNetModel.from_pretrained('lllyasviel/control_v11f1e_sd15_tile',
                                                 torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                                                    controlnet=controlnet,
                                                                    torch_dtype=torch.float16,
                                                                    safety_checker=None).to('cuda')
    pipe.enable_xformers_memory_efficient_attention()

    ## ip adapter
    base_model_path = "runwayml/stable-diffusion-v1-5"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    image_encoder_path = "IP-Adapter/models/image_encoder/"
    ip_ckpt = "IP-Adapter/models/ip-adapter_sd15.bin"
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None
    )
    ip_model = IPAdapter(sd_pipe, image_encoder_path, ip_ckpt, "cuda")

    # Prefix shared by the negative prompts of both generation stages.
    base_negative_prompt = "blur, lowres, bad anatomy, bad hands, cropped, worst quality, "

    for key, view_cfg in sc_view_dict.items():
        # '<name>.png' -> '<name>_ori.png' is the input; the result is written
        # back under the unmodified key.
        source_path = os.path.join(ref_dir, re.sub(r'(.*)(\.png)', r'\1_ori.png', key))
        output_path = os.path.join(ref_dir, key)
        source_image = load_image(source_path)

        # 'SR' presumably stands for super-resolution (TODO confirm); when it
        # is falsy the source image is saved as-is and no model is run.
        if not view_cfg['SR']:
            source_image.save(output_path)
            continue

        # Read only after the SR check so pass-through views do not need to
        # define 'negative_prompt_SR'.
        negative_prompt = base_negative_prompt + view_cfg['negative_prompt_SR']

        # Optional first stage: only for views that define 'img2img_prompt'.
        # The first element returned by ``generate`` replaces the source image.
        if "img2img_prompt" in view_cfg:
            source_image = ip_model.generate(
                pil_image=source_image,
                prompt=view_cfg["img2img_prompt"],
                negative_prompt=negative_prompt,
                num_samples=1,
                seed=None,
                scale=view_cfg["ip_attn_scale"],
                guidance_scale=7.5,
                num_inference_steps=50
            )[0]

        condition_image = resize_for_condition_image(source_image, 1024)
        # Kept as a conditional expression (not ``dict.get``) so that 'concept'
        # is only required when 'SR_prompt' is absent.
        SR_text = view_cfg["SR_prompt"] if "SR_prompt" in view_cfg else view_cfg['concept']
        print(SR_text)

        # The same resized image is used both as the img2img input and as the
        # ControlNet control image.
        image = pipe(prompt="best quality, " + SR_text,
                     negative_prompt=negative_prompt,
                     image=condition_image,
                     control_image=condition_image,
                     width=condition_image.size[0],
                     height=condition_image.size[1],
                     strength=view_cfg.get('strength', 0.8),
                     # NOTE: torch.manual_seed(0) re-seeds torch's global RNG
                     # on every iteration, not just this call's generator.
                     generator=torch.manual_seed(0),
                     num_inference_steps=32,
                     ).images[0]

        image.save(output_path)
