import os
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, SchedulerMixin
from diffusers.utils import pt_to_pil
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
import shutil

from blip_vqa_eval import evaluate_direcotry_using_blip_vqa


def generate_samples_and_evaluate_blip_vqa(
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    prompt: str,
    fixed_text_embeddings: torch.Tensor,
    evaluation_path: str,
    batch_size: int = 10,
    num_evaluation_images: int = 30,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 25,
    image_size: int = 512,
    clean_fixed_text_embeddings: torch.Tensor = None,
    early_guidance_timestep_threshold: int = -1,
    seed: int = None,
):
    """Sample images from fixed text embeddings and score them with BLIP-VQA.

    Images are generated in batches with a classifier-free-guidance denoising
    loop, written as PNG files into ``evaluation_path`` and then scored by
    ``evaluate_direcotry_using_blip_vqa``. While the evaluator runs, ``vae``,
    ``text_encoder`` and ``unet`` are moved to the CPU; each one is moved back
    to the device it was on afterwards, even if the evaluator raises.

    Args:
        vae: Autoencoder used to decode the final latents into images.
        unet: Noise-prediction network; its device is used for all sampling
            tensors.
        scheduler: Diffusion scheduler driving the denoising loop.
        tokenizer: Tokenizer used to build the empty (unconditional) prompt.
        text_encoder: Encoder producing the unconditional embeddings.
        prompt: Only ``prompt[0]`` is used, as the prefix of the saved file
            names. NOTE(review): for a plain ``str`` this is its first
            character; confirm whether a sequence of prompts is expected.
        fixed_text_embeddings: Conditional embeddings, repeated ``batch_size``
            times along dim 0; ``shape[1]`` is used as the tokenizer
            ``max_length``. Presumably ``(1, seq_len, dim)`` -- TODO confirm.
        evaluation_path: Output directory. It is deleted (if it exists) and
            recreated on every call.
        batch_size: Number of images generated per batch.
        num_evaluation_images: Total number of images; must be a multiple of
            ``batch_size``.
        guidance_scale: Classifier-free-guidance scale.
        num_inference_steps: Number of scheduler steps per batch.
        image_size: Side length of the generated images, in pixels.
        clean_fixed_text_embeddings: Alternative conditional embeddings used
            for timesteps ``t <= early_guidance_timestep_threshold``. Required
            when the threshold is not ``-1``.
        early_guidance_timestep_threshold: Timesteps strictly greater than
            this value use ``fixed_text_embeddings``; the remaining ones use
            ``clean_fixed_text_embeddings``. ``-1`` disables the switch.
        seed: If given, batch ``b`` draws its initial noise from a generator
            seeded with ``seed * 100 + b``; otherwise the global RNG is used.

    Returns:
        A tuple ``(image_scores_dict, average_score)``: the mapping returned by
        the evaluator and the mean of its values converted with ``float``.

    Raises:
        AssertionError: If ``num_evaluation_images`` is not a multiple of
            ``batch_size``.
        ValueError: If the early-guidance switch is enabled but
            ``clean_fixed_text_embeddings`` is ``None``.
    """
    assert num_evaluation_images % batch_size == 0, "just for now!!!"

    use_clean_embeddings = early_guidance_timestep_threshold != -1
    if use_clean_embeddings and clean_fixed_text_embeddings is None:
        raise ValueError(
            "clean_fixed_text_embeddings must be provided when "
            "early_guidance_timestep_threshold != -1"
        )

    if os.path.exists(evaluation_path):
        print("Removing previous evaluation path ...")
        shutil.rmtree(evaluation_path)
    os.makedirs(evaluation_path)

    device = unet.device

    # Classifier-free guidance: stack [unconditional, conditional] embeddings
    # so a single UNet forward pass yields both noise predictions.
    cond_embeddings = fixed_text_embeddings.repeat(batch_size, 1, 1).clone()
    max_length = cond_embeddings.shape[1]
    uncond_input = tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    text_embeddings_clean = None
    if use_clean_embeddings:
        clean_cond = clean_fixed_text_embeddings.repeat(batch_size, 1, 1).clone()
        text_embeddings_clean = torch.cat([uncond_embeddings, clean_cond])

    torch.cuda.empty_cache()  # TODO: ?

    # Spatial downsampling factor between image space and latent space.
    f = 2 ** (len(vae.config.block_out_channels) - 1)
    latent_shape = (batch_size, unet.config.in_channels, image_size // f, image_size // f)

    for b_idx in range(num_evaluation_images // batch_size):
        # The generator must live on the same device as the sampled tensor,
        # so build it on the UNet's device rather than hard-coding CUDA.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed * 100 + b_idx)
        latents = torch.randn(latent_shape, device=device, generator=generator)

        # Configure the schedule *before* reading init_noise_sigma: some
        # schedulers derive that value from the timesteps that were set.
        scheduler.set_timesteps(num_inference_steps)
        latents = latents * scheduler.init_noise_sigma

        with torch.no_grad():
            for t in tqdm(scheduler.timesteps):
                if text_embeddings_clean is None or t > early_guidance_timestep_threshold:
                    step_embeddings = text_embeddings
                else:
                    step_embeddings = text_embeddings_clean

                # Duplicate latents to match the [uncond, cond] embedding stack.
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

                noise_pred = unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=step_embeddings,
                ).sample

                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = scheduler.step(noise_pred, t, latents).prev_sample

            # Undo the latent scaling before decoding.
            latents = 1 / vae.config.scaling_factor * latents
            images = vae.decode(latents).sample

        for idx, pil_img in enumerate(pt_to_pil(images)):
            file_name = f'{prompt[0]}_{(b_idx*batch_size + idx):06d}.png'
            pil_img.save(os.path.join(evaluation_path, file_name))

    # Move the generation models to the CPU while the evaluator runs, and
    # remember where each one lived so it can be restored individually.
    modules_and_devices = [(module, module.device) for module in (vae, text_encoder, unet)]
    for module, _ in modules_and_devices:
        module.to('cpu')

    torch.cuda.empty_cache()  # TODO: ?

    try:
        image_scores_dict = evaluate_direcotry_using_blip_vqa(image_folder_path=evaluation_path)
    finally:
        torch.cuda.empty_cache()  # TODO: ?

        # Restore the models even if evaluation failed, so the caller's
        # objects are not silently left on the CPU.
        for module, original_device in modules_and_devices:
            module.to(original_device)

    average_score = sum(map(float, image_scores_dict.values())) / len(image_scores_dict)

    return image_scores_dict, average_score
