from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

from utils.activation_detection import prepare_diffusion_inputs


def load_sd_components(model_path):
    """Load the VAE, UNet and DDIM scheduler of a Stable Diffusion checkpoint.

    ``model_path`` may be one of the shorthands ``'v1-4'`` / ``'v1-5'``,
    which are expanded to their Hugging Face repository ids; any other value
    is handed to ``from_pretrained`` unchanged.

    Returns:
        A ``(vae, unet, scheduler)`` tuple.
    """
    shorthands = (
        ('v1-4', 'CompVis/stable-diffusion-v1-4'),
        ('v1-5', 'runwayml/stable-diffusion-v1-5'),
    )
    for alias, repo_id in shorthands:
        if model_path == alias:
            model_path = repo_id
            break

    # Each component lives in its own subfolder of the checkpoint.
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    scheduler = DDIMScheduler.from_pretrained(model_path, subfolder='scheduler')
    return vae, unet, scheduler


def load_text_components(model_path):
    """Load the CLIP tokenizer and text encoder matching a checkpoint.

    For the shorthands ``'v1-4'`` and ``'v1-5'`` the components come from the
    ``openai/clip-vit-large-patch14`` repository. For any other
    ``model_path`` they are read from the checkpoint's own ``tokenizer`` and
    ``text_encoder`` subfolders.

    Returns:
        A ``(tokenizer, text_encoder)`` tuple.
    """
    if model_path in ('v1-4', 'v1-5'):
        clip_repo = 'openai/clip-vit-large-patch14'
        tokenizer = CLIPTokenizer.from_pretrained(clip_repo)
        text_encoder = CLIPTextModel.from_pretrained(clip_repo)
    else:
        # Load only the two text components instead of instantiating a whole
        # StableDiffusionPipeline (UNet, VAE, scheduler, ...) and discarding
        # everything but its tokenizer and text encoder.
        tokenizer = CLIPTokenizer.from_pretrained(
            model_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(
            model_path, subfolder="text_encoder", torch_dtype=torch.float32)
    return tokenizer, text_encoder


@torch.no_grad()
def generate_images(
    prompts, 
    tokenizer, 
    text_encoder, 
    vae, 
    unet, 
    scheduler, 
    text_embeddings=None, 
    num_inference_steps=50, 
    seed=1, 
    guidance_scale=7, 
    samples_per_prompt=1,
    verbose=False
):
    """Run the denoising loop and decode the result into PIL images.

    Args:
        prompts: A prompt string or a list of prompt strings. Ignored when
            ``text_embeddings`` is given.
        tokenizer: Tokenizer used to encode the empty (unconditional) prompt
            when ``text_embeddings`` is given; otherwise forwarded to
            ``prepare_diffusion_inputs``.
        text_encoder: Text encoder paired with ``tokenizer``.
        vae: Autoencoder whose ``decode`` turns final latents into images.
        unet: Noise-prediction model. Its device decides where the sampling
            loop runs.
        scheduler: Diffusion scheduler providing ``set_timesteps``,
            ``scale_model_input`` and ``step``.
        text_embeddings: Optional precomputed conditional embeddings. A batch
            of size 1 is repeated ``samples_per_prompt`` times.
        num_inference_steps: Number of scheduler timesteps.
        seed: Seed for the initial latents. With ``text_embeddings`` this
            seeds the global torch RNG through ``torch.manual_seed``.
        guidance_scale: Classifier-free guidance weight; ``0`` disables
            guidance, so no unconditional pass is made.
        samples_per_prompt: Number of images per prompt.
        verbose: Show a tqdm progress bar when true.

    Returns:
        A list of ``PIL.Image`` objects, one per generated sample.
    """
    # Local import keeps this block self-contained.
    from contextlib import nullcontext

    # Follow the unet's placement instead of assuming the current CUDA device.
    device = unet.device
    use_guidance = guidance_scale != 0

    if text_embeddings is not None:
        print("Using provided text embedding")
        generator = torch.manual_seed(seed)

        if text_embeddings.shape[0] == 1:
            text_embeddings = torch.repeat_interleave(
                text_embeddings, dim=0, repeats=samples_per_prompt)

        # Initial noise at a 64x64 latent resolution (512 // 8).
        latents = torch.randn(
            (text_embeddings.shape[0], unet.config.in_channels, 512 // 8, 512 // 8),
            generator=generator,
        )

        if use_guidance:
            # Encode one empty prompt per sample, padded to the same sequence
            # length as the provided embeddings, and put it in front.
            max_length = text_embeddings.shape[-2]
            uncond_input = tokenizer([""] * len(text_embeddings),
                                     padding="max_length",
                                     max_length=max_length,
                                     return_tensors="pt")
            uncond_embeddings = text_encoder(
                uncond_input.input_ids.to(text_encoder.device))[0]
            # Align devices so the concatenation cannot fail when the caller's
            # embeddings live somewhere other than the text encoder.
            text_embeddings = torch.cat(
                [uncond_embeddings.to(text_embeddings.device), text_embeddings])
    else:
        if isinstance(prompts, str):
            prompts = [prompts]
        latents, text_embeddings = prepare_diffusion_inputs(prompts, tokenizer, text_encoder, unet, guidance_scale=guidance_scale, samples_per_prompt=samples_per_prompt, seed=seed)

    # Move both inputs to the unet's device once, before the loop.
    latents = latents.to(device)
    text_embeddings = text_embeddings.to(device)
    scheduler.set_timesteps(num_inference_steps)

    # Half-precision autocast only applies on CUDA; elsewhere run as-is.
    if device.type == "cuda":
        precision_ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
    else:
        precision_ctx = nullcontext()

    with precision_ctx:
        for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps), disable=not verbose):
            if use_guidance:
                # duplicate latents for classifier free guidance
                latent_model_input = torch.cat([latents] * 2)
            else:
                # without guidance the latents are used directly
                latent_model_input = latents

            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                return_dict=False
            )[0]

            if use_guidance:
                # First half of the batch is unconditional, second half is
                # text-conditioned (matches the concatenation order above).
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # NOTE(review): releasing cached blocks every step is kept from the
            # original, presumably to limit memory pressure; it costs speed.
            if device.type == "cuda":
                torch.cuda.empty_cache()

        # Undo the VAE latent scaling before decoding.
        latents = 1 / vae.config.scaling_factor * latents
        image = vae.decode(latents).sample
        # Map from [-1, 1] to [0, 1], then to uint8 in channels-last layout.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]

        return pil_images

def compute_text_embedding(prompts, tokenizer, text_encoder):
    """Encode one or more prompts with the given tokenizer and text encoder.

    Args:
        prompts: A single prompt string (including ``str`` subclasses) or a
            list of prompt strings.
        tokenizer: Callable tokenizer exposing ``model_max_length``; prompts
            are padded and truncated to that length.
        text_encoder: Callable encoder exposing ``device``; the token ids are
            moved to that device before encoding.

    Returns:
        The first element of the encoder's output for the tokenized prompts.
    """
    # isinstance (not an exact type match) so str subclasses are also wrapped.
    if isinstance(prompts, str):
        prompts = [prompts]
    text_input = tokenizer(prompts,
                           padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True,
                           return_tensors="pt")
    text_embeddings = text_encoder(
        text_input.input_ids.to(text_encoder.device))[0]

    return text_embeddings
