import logging
import os
import random
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from ruamel.yaml import YAML

import numpy as np
import pandas as pd
import torch
import torch.utils.checkpoint

from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler, PNDMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPProcessor, CLIPModel

from model import model_types
from config import parse_args
from utils_model import save_model, load_model

from PIL import Image
import clip

# Module-level arguments object, built when this file is imported or run.
# NOTE(review): main() calls parse_args() again and works on its own local
# `args`; no function visible in this file reads this global. Confirm whether
# another module imports it before removing.
args = parse_args()    

def unfreeze_layers_unet(unet, condition):
    """Print the number of trainable parameters of ``unet`` and return it.

    Despite the name, no ``requires_grad`` flag is changed here; the module is
    handed back untouched. ``condition`` is accepted but not used.
    """
    trainable = 0
    for param in unet.parameters():
        if param.requires_grad:
            trainable += param.numel()
    print("Num trainable params unet: ", trainable)
    return unet

def cvtImg(img):
    """Convert a 4-D image batch tensor into a min-max normalised NumPy array.

    The axes are reordered with ``permute(0, 2, 3, 1)`` (presumably
    channels-first -> channels-last) and the values are rescaled to ``[0, 1]``
    using the global minimum and maximum of the whole batch.

    Args:
        img: 4-D ``torch.Tensor``. It may live on any device and may require
            grad; it is detached and copied to the CPU before conversion.

    Returns:
        ``np.ndarray`` of dtype ``float32`` with axes ``(0, 2, 3, 1)`` of the
        input. A constant input maps to all zeros instead of NaN.
    """
    # .numpy() only works on CPU tensors that do not require grad.
    img = img.detach().cpu().permute(0, 2, 3, 1)
    img = img - img.min()
    peak = img.max()
    # A constant image has peak == 0 after the shift; skip the division so the
    # result is zeros rather than NaN.
    if peak > 0:
        img = img / peak
    return img.numpy().astype(np.float32)

def show_examples(x, rows=1, cols=8, figsize=(10, 5), dpi=200):
    """Plot the first images of a batch on a ``rows`` x ``cols`` grid.

    This file used to define ``show_examples`` twice; the second definition
    (one row of eight images, 10x5 figure at 200 dpi) shadowed the first, so
    its layout is kept as the default. The shadowed 5x5 layout is available
    via ``show_examples(x, rows=5, cols=5, figsize=(10, 10), dpi=None)``.

    Args:
        x: image batch tensor accepted by ``cvtImg``.
        rows: number of grid rows.
        cols: number of grid columns.
        figsize: figure size forwarded to ``plt.figure``.
        dpi: figure resolution forwarded to ``plt.figure``.

    If the batch holds fewer than ``rows * cols`` images, only the available
    ones are drawn (previously this raised ``IndexError``).
    """
    plt.figure(figsize=figsize, dpi=dpi)
    imgs = cvtImg(x)
    # Slicing clamps to the batch size, so short batches do not overrun.
    for slot, img in enumerate(imgs[:rows * cols], start=1):
        plt.subplot(rows, cols, slot)
        plt.imshow(img)
        plt.axis('off')

def show_images(images):
    """Join several images side by side and return them as one PIL image.

    Each element of ``images`` is converted with ``np.array`` and the results
    are concatenated along axis 1.
    """
    arrays = []
    for picture in images:
        arrays.append(np.array(picture))
    strip = np.concatenate(arrays, axis=1)
    return Image.fromarray(strip)

def show_image(image):
    """Wrap a single array in a PIL image via ``Image.fromarray``."""
    picture = Image.fromarray(image)
    return picture

def prompt_with_template(profession, template):
    """Fill ``template`` with the lower-cased ``profession``.

    Every occurrence of the literal marker ``{{placeholder}}`` in ``template``
    is substituted; a template without the marker is returned unchanged.
    """
    return template.replace("{{placeholder}}", profession.lower())

def get_prompt_embeddings(args, prompt_domain, prompt_class, labels, tokenizer, text_encoder, padding_type="do_not_pad"):
    """Encode one text prompt per label and optionally splice in soft prompts.

    For every class id in ``labels`` the prompt
    ``'a {args.domain} style of a {args.categories[cid]}'`` is tokenized and
    run through ``text_encoder``. Selected token positions of the resulting
    embeddings are then overwritten in place with learned prompt vectors.

    Args:
        args: object exposing ``domain`` and ``categories``.
        prompt_domain: optional tensor written to token position 2 of every
            prompt. Assumes shape ``(1, dim)`` — TODO confirm against the
            checkpoint that produces it.
        prompt_class: optional tensor indexed by ``labels`` and written to the
            second-to-last token position. Presumably ``(num_classes, 1, dim)``
            — TODO confirm.
        labels: 1-D tensor of class ids.
        tokenizer: callable tokenizer returning an object with ``input_ids``.
        text_encoder: module whose first output is the token embeddings.
        padding_type: accepted for backward compatibility but ignored; the
            tokenizer is always called with ``padding=True``.

    Returns:
        The text-encoder output with the soft prompts substituted.
    """
    prompts = [f'a {args.domain} style of a {args.categories[cid]}' for cid in labels]

    # Tokenizer options are fixed; they used to be assigned inside the label
    # loop, which left them unbound when ``labels`` was empty.
    inputs = tokenizer(
        prompts,
        padding=True,
        max_length=None,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = torch.as_tensor(inputs.input_ids, dtype=torch.long)

    # Run on whatever device the encoder lives on instead of assuming CUDA;
    # fall back to the previous hard-coded device for parameter-less encoders.
    first_param = next(text_encoder.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')
    text_f = text_encoder(input_ids.to(device))[0]

    if prompt_domain is not None:
        num_prompt_domain = 1
        # Position 2 follows the start token and the leading "a"
        # (presumably the domain word — verify for multi-token domains).
        text_f[:, 2:2 + num_prompt_domain] = prompt_domain.unsqueeze(0).repeat(labels.shape[0], 1, 1)
    if prompt_class is not None:
        num_prompt_class = 1
        # NOTE(review): with padding=True and prompts of different token
        # lengths in one batch, this slot may be a padding position for the
        # shorter prompts; callers in this file pass a single label.
        text_f[:, -1 - num_prompt_class:-1] = prompt_class[labels]

    return text_f

def main():
    """Generate per-class images for one domain with Stable Diffusion.

    Steps:
      1. Parse the arguments and dump them to ``test_config.yaml`` inside
         ``args.output_dir``.
      2. Load tokenizer, text encoder, VAE, UNet and the scheduler selected by
         ``args.scheduler`` (``ddim``, ``pndm`` or ``ddpm``) and assemble a
         ``StableDiffusionPipeline`` on CUDA, optionally cast to fp16.
      3. For every category, render the samples whose index lies in
         ``[args.start_idx, args.end_idx]`` and save them as JPEG files,
         skipping files that already exist.

    Raises:
        NotImplementedError: if ``args.scheduler`` is not a supported name.
    """
    args = parse_args()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output_dir, exist_ok=True)
    yaml = YAML()
    # Use a context manager so the config file is flushed and closed.
    with open(os.path.join(args.output_dir, 'test_config.yaml'), 'w') as config_file:
        yaml.dump(vars(args), config_file)

    # Load models and create wrapper for stable diffusion
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )
    if args.scheduler == 'ddim':
        scheduler = DDIMScheduler(
            beta_start=0.00085, beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            num_train_timesteps=1000,
            steps_offset=1,
        )
    elif args.scheduler == 'pndm':
        scheduler = PNDMScheduler.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="scheduler"
        )
    elif args.scheduler == 'ddpm':
        scheduler = DDPMScheduler.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="scheduler"
        )
    else:
        raise NotImplementedError(args.scheduler)

    # Inference only: freeze vae, text_encoder and unet
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    device = torch.device('cuda')

    model = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )
    model = model.to(device)
    if args.fp16:
        print('Using fp16')
        model.unet = model.unet.half()
        model.vae = model.vae.half()
        model.text_encoder = model.text_encoder.half()

    # following https://arxiv.org/pdf/2306.16064
    categories = args.categories

    def generate_data_per_domain_prompt(model, categories, unet, device, args, idx=None):
        """Render and save the images of every category for ``args.domain``.

        ``unet`` and ``idx`` are accepted but unused. ``tokenizer`` and
        ``text_encoder`` are taken from the enclosing ``main`` scope.
        """
        domain = args.domain
        domains = args.domains
        if args.dataset == 'officehome':
            sample_per_class_per_domain = 80
        else:
            sample_per_class_per_domain = 160

        # The domain index only feeds the per-image seed below.
        did = domains.index(domain)

        # load latents
        # NOTE(review): the experiment root is hard-coded, and the statistics
        # are loaded even when 'w_latent' is not requested.
        latents_root = f'/root/InterpretDiffusion/exps_{args.dataset}/fedlip_prompt_d_{domain}'
        latents_mean = torch.load(f"{latents_root}/mean.pt")
        latents_std = torch.load(f"{latents_root}/std.pt")
        concept = None

        # load soft prompts
        prompt_class, prompt_domain = None, None
        if 'w_cprompt' in args.test_type:
            prompt_class = torch.load(f"{latents_root}/prompt_class.pth")
        if 'w_dprompt' in args.test_type:
            prompt_domain = torch.load(f'{latents_root}/prompt_domain.pth')

            if "wnoise" in args.test_type and prompt_domain is not None:
                # Perturb the domain prompt with Gaussian noise. The intensity
                # is the three characters following "wnoise_" in test_type.
                intensity = float(args.test_type.split("wnoise_")[1][:3])
                prompt_domain = prompt_domain + torch.randn_like(prompt_domain) * intensity

        for cid, c in enumerate(categories):
            save_image_dir = os.path.join(args.output_dir, c, f"{domain}_{args.test_type}")
            os.makedirs(save_image_dir, exist_ok=True)

            for i in range(sample_per_class_per_domain):
                # Only render the requested index window (inclusive bounds).
                if not (args.start_idx <= i <= args.end_idx):
                    continue
                # Skip images already on disk so interrupted runs can resume.
                if os.path.exists(f"{save_image_dir}/{i}.jpg"):
                    continue
                # Distinct seed per (domain, class, sample) triple.
                seed = did * 1000000 + cid * 1000 + i
                labels = torch.tensor([cid] * 1).to(device)
                prompt_embeds = get_prompt_embeddings(
                    args, prompt_domain, prompt_class, labels,
                    tokenizer, text_encoder,
                    padding_type="max_length")

                if 'w_latent' in args.test_type:
                    slot = i % len(latents_mean[c])
                    # torch.load restores tensors on the device they were
                    # saved from; move them next to the sampled noise so the
                    # arithmetic below cannot hit a device mismatch.
                    mean = latents_mean[c][slot].to(device)
                    std = latents_std[c][slot].to(device)
                    sample = torch.randn(
                        mean.shape,
                        device=device,
                        dtype=mean.dtype,
                    )
                    # NOTE(review): this draw does not use `seed`, so
                    # 'w_latent' runs are presumably not reproducible — confirm.
                    latent = mean + std * sample
                else:
                    latent = None

                image = predict_cond(
                    model=model,
                    prompt=None, prompt_embeds=prompt_embeds,
                    seed=seed, condition=concept,
                    img_size=args.resolution, num_inference_steps=args.num_inference_steps,
                    negative_prompt=args.negative_prompt, latent=latent,
                )
                image.save(f"{save_image_dir}/{i}.jpg")

    generate_data_per_domain_prompt(model=model, categories=categories, unet=unet, device=device, args=args)

def predict_cond(model,
                 prompt,
                 seed,
                 condition,
                 img_size,
                 num_inference_steps=50,
                 interpolator=None,
                 negative_prompt=None,
                 latent=None,
                 prompt_embeds=None,
                 ):
    """Run ``model`` once and return the first item of its first output.

    Args:
        model: callable pipeline; invoked with keyword arguments only.
        prompt: text prompt forwarded as ``prompt``.
        seed: seed for a CUDA ``torch.Generator``; ``None`` passes no generator.
        condition: forwarded as ``controlnet_cond``.
        img_size: used for both ``height`` and ``width``.
        num_inference_steps: forwarded unchanged.
        interpolator: forwarded as ``controlnet_interpolator``.
        negative_prompt: forwarded unchanged.
        latent: optional tensor; a leading axis is added with ``unsqueeze(0)``
            before it is forwarded as ``latents``.
        prompt_embeds: forwarded unchanged.

    Returns:
        ``output[0][0]`` of the model call.
    """
    generator = None
    if seed is not None:
        generator = torch.Generator("cuda").manual_seed(seed)

    batched_latent = None if latent is None else latent.unsqueeze(0)

    call_kwargs = dict(
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        height=img_size,
        width=img_size,
        num_inference_steps=num_inference_steps,
        generator=generator,
        controlnet_cond=condition,
        controlnet_interpolator=interpolator,
        negative_prompt=negative_prompt,
        latents=batched_latent,
    )
    output = model(**call_kwargs)
    return output[0][0]

if __name__ == "__main__":
    main()
