from diffusers import DiffusionPipeline, DDIMScheduler
import torch
#from diffusers.loaders import LoraLoaderMixin
from PIL import Image
from typing import Optional, Union
import math
import argparse
from tqdm import tqdm
import random

import sys 
sys.path.append("..") 
from lora_files_patchall.loaders.lora import LoraLoaderMixin
from lora_files_patchall.text_files.models.clip.modeling_clip import CLIPTextModel
from lora_files_patchall.models.unet_2d_condition import UNet2DConditionModel


class Merger(torch.nn.Module):
    """Learnable per-interval mixing weights for two LoRA adapters.

    Holds one scalar weight per timestep interval for the image LoRA
    (``merger_img``) and one for the reference LoRA (``merger_ref``).
    The attribute names double as the state-dict keys read from
    ``merger.pt``, so they must not be renamed.

    Args:
        interval_num: Number of timestep intervals, i.e. the length of
            each weight vector.
        init_merger_img: Initial value for every entry of ``merger_img``.
            ``None`` is accepted and means the default 1.0.
        init_merger_ref: Initial value for every entry of ``merger_ref``.
            ``None`` is accepted and means the default 1.0.
    """

    def __init__(
        self,
        interval_num: int,
        init_merger_img: Optional[float] = 1.0,
        init_merger_ref: Optional[float] = 1.0,
    ):
        super().__init__()

        # The signature advertises Optional, but multiplying a tensor by None
        # raises TypeError; honour the annotation by falling back to 1.0.
        if init_merger_img is None:
            init_merger_img = 1.0
        if init_merger_ref is None:
            init_merger_ref = 1.0

        self.merger_img = torch.nn.Parameter(torch.ones((interval_num,), dtype=torch.float32) * init_merger_img)
        self.merger_ref = torch.nn.Parameter(torch.ones((interval_num,), dtype=torch.float32) * init_merger_ref)

def parse_args(input_args=None):
    """Parse command-line options for the inference script.

    Args:
        input_args: Optional list of argument strings to parse. When it is
            ``None``, argparse reads ``sys.argv[1:]`` instead.

    Returns:
        An ``argparse.Namespace`` holding the five path options (all
        defaulting to ``None``) and the three integer options
        ``sample_num`` (1), ``sample_index`` (0) and ``interval_num`` (20).
    """
    cli = argparse.ArgumentParser(description="Inference.")

    # String-valued path options, all optional with no default value.
    path_flags = ("--label_path", "--img_path", "--ref_path", "--merge_path", "--save_path")
    for flag in path_flags:
        cli.add_argument(flag, type=str, default=None)

    # Integer options paired with their defaults (registration order kept).
    int_flags = (("--sample_num", 1), ("--sample_index", 0), ("--interval_num", 20))
    for flag, fallback in int_flags:
        cli.add_argument(flag, type=int, default=fallback)

    # argparse treats args=None as "use sys.argv[1:]", matching the no-arg call.
    return cli.parse_args(input_args)

if __name__ == "__main__":
    # Script entry point: generate `args.sample_num` images for one row of the
    # label file, blending two LoRA adapters ("lora_img" and "lora_ref") with
    # per-timestep-interval weights loaded from a Merger checkpoint.
    args = parse_args()

    num_inference_steps = 50
    guidance_scale = 7.5
    interval_num = args.interval_num

    # Read the label file inside a context manager so the handle is closed.
    # Each row is tab-separated: column 0 names the image LoRA, column 1 the
    # reference LoRA, column 2 holds the prompt.
    with open(args.label_path, 'r') as label_file:
        lines = label_file.readlines()
    line = lines[args.sample_index]
    splits = line.split('\t')

    # Paths are built by plain string concatenation, so the *_path arguments
    # presumably already end with a separator (e.g. "loras/") — confirm.
    img_path = args.img_path + splits[0]
    ref_path = args.ref_path + splits[1]
    merge_path = args.merge_path + splits[0] + '-' + splits[1]
    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base",torch_dtype=torch.float32)
    pipe.scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    # Replace the stock UNet / text encoder with the project's patched classes
    # (imported from lora_files_patchall) loaded from the same checkpoint.
    pipe.unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet", revision=None, variant=None)
    pipe.text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder", revision=None)

    # Register both LoRAs as named adapters on the UNet and the text encoder.
    lora_state_dict_img, network_alphas_img = LoraLoaderMixin.lora_state_dict(img_path)
    LoraLoaderMixin.load_lora_into_unet(lora_state_dict_img, network_alphas=network_alphas_img, unet=pipe.unet, adapter_name="lora_img")
    LoraLoaderMixin.load_lora_into_text_encoder(lora_state_dict_img, network_alphas=network_alphas_img, text_encoder=pipe.text_encoder, adapter_name="lora_img")
    lora_state_dict_ref, network_alphas_ref = LoraLoaderMixin.lora_state_dict(ref_path)
    LoraLoaderMixin.load_lora_into_unet(lora_state_dict_ref, network_alphas=network_alphas_ref, unet=pipe.unet, adapter_name="lora_ref")
    LoraLoaderMixin.load_lora_into_text_encoder(lora_state_dict_ref, network_alphas=network_alphas_ref, text_encoder=pipe.text_encoder, adapter_name="lora_ref")

    # Add the "<asset0>" token, grow the embedding table to match, then load
    # the whole input-embedding table (including the new row) saved next to
    # the reference LoRA.
    pipe.tokenizer.add_tokens(['<asset0>'])
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    pipe.text_encoder.get_input_embeddings().load_state_dict(torch.load(ref_path+'/embeddings.pt'))

    pipe.to("cuda")

    # Per-interval adapter weights learned for this (img, ref) pair.
    merger = Merger(interval_num)
    merger.load_state_dict(torch.load(merge_path+'/merger.pt'))
    merger.to("cuda")

    # Extra UNet / text-encoder weights saved alongside the merger.
    # strict=False tolerates missing/unexpected keys, presumably because only
    # a subset of parameters was saved — confirm against the training script.
    pipe.unet.load_state_dict(torch.load(merge_path+'/weight_unet.pt'), strict=False)
    pipe.text_encoder.load_state_dict(torch.load(merge_path+'/weight_text.pt'), strict=False)

    # Strip only the trailing line terminator. A blind `[:-1]` would drop the
    # prompt's last real character when the field has no trailing newline
    # (e.g. the final line of the file) and would leave '\r' on CRLF files.
    prompt = splits[2].rstrip('\r\n')
    text_inputs = pipe.tokenizer(prompt, padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_input_ids = text_inputs.input_ids
    uncond_tokens = ['']
    uncond_input = pipe.tokenizer(uncond_tokens, padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")

    for ii in tqdm(range(args.sample_num)):
        # A fresh random seed per sample; it is also written into the filename.
        generator_seed = random.randint(0, 1000000)
        generator = torch.Generator("cuda").manual_seed(generator_seed)

        # Prepare timesteps
        pipe.scheduler.set_timesteps(num_inference_steps, device=pipe.device)
        timesteps = pipe.scheduler.timesteps

        # Prepare latent variables (already allocated on pipe.device).
        latents = torch.randn([1,4,64,64], generator=generator, device=pipe.device)

        # No extra scheduler kwargs, IP-Adapter embeds, or guidance-scale embedding.
        extra_step_kwargs = {}
        added_cond_kwargs = None
        timestep_cond = None

        # Denoising loop
        with torch.no_grad():
            for i, t in enumerate(timesteps):
                # Duplicate the latents for classifier-free guidance (uncond + cond).
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

                # Map timestep t to its interval index. Because the interval
                # width is ceil(num_train_timesteps / interval_num), the index
                # stays below interval_num, assuming t < num_train_timesteps.
                merger_num = int(t / math.ceil(1.0 * pipe.scheduler.config.num_train_timesteps / interval_num))
                pipe.unet.set_adapters(["lora_img", "lora_ref"], [merger.merger_img[merger_num], merger.merger_ref[merger_num]])
                # Called unbound, passing the class itself as `self`.
                LoraLoaderMixin.set_adapters_for_text_encoder(LoraLoaderMixin, adapter_names=["lora_img", "lora_ref"], text_encoder=pipe.text_encoder, text_encoder_weights=[merger.merger_img[merger_num], merger.merger_ref[merger_num]])

                # Re-encode the prompts every step: the text encoder's adapter
                # weights were just changed, so the embeddings depend on t.
                prompt_embeds = pipe.text_encoder(text_input_ids.to(pipe.device), attention_mask=None)[0]
                negative_prompt_embeds = pipe.text_encoder(uncond_input.input_ids.to(pipe.device), attention_mask=None)[0]
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

                # predict the noise residual
                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=None,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            img = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False, generator=generator)[0]

        # Map the decoded image from [-1, 1] to [0, 255] and save it as an RGB PNG.
        img = (img / 2 + 0.5).clamp(0, 1)
        img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
        img = (img * 255).round().squeeze(0)
        img = Image.fromarray(img.astype('uint8')).convert('RGB')

        img.save(args.save_path+'{}-{}-{}.png'.format(splits[0], splits[1], str(generator_seed)))