import os
from datasets import load_dataset
from diffusers import StableDiffusionPipeline
import torch
import argparse
import shutil
import numpy as np
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers.utils import pt_to_pil

import sys
sys.path.append('../')
from train_mlp import CLIPTextEmbeddingLinearProjector, CLIPTextEmbeddingLinearSkipProjector, CLIPTextEmbeddingMLPProjector, WindowAwareLinearProjection


def parse_args(argv=None):
    """Parse and validate the command-line arguments of this script.

    Validation is done with ``parser.error`` (which prints usage and raises
    ``SystemExit`` with status 2) rather than ``assert``, so it still runs
    under ``python -O``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_chunks",
        type=int,
        default=20,
        help="Number of contiguous chunks the index array is split into.",
    )
    parser.add_argument(
        "--chunk_idx",
        type=int,
        required=True,
        help="Zero-based index of the chunk to generate (0 <= chunk_idx < num_chunks).",
    )
    parser.add_argument(
        "--path_to_indices",
        type=str,
        default="/cmlscratch/XXX/t5_analysis/clean_fid_coco/coco_random_indices.npy",
        help="Path to a .npy array of dataset indices; must be an existing file.",
    )
    parser.add_argument(
        "--path_to_save_generated_images",
        type=str,
        required=True,  # e.g. "/cmlscratch/XXX/t5_analysis/clean_fid_coco/sd/0"
        help="Output directory for the generated PNG files (created if missing).",
    )
    parser.add_argument(
        "--early_guidance_timestep_threshold",
        type=int,
        default=-1,
        help=(
            "Timesteps t greater than this value condition on the projected embeddings; "
            "the remaining ones use the raw text-encoder embeddings. "
            "The default -1 uses the projected embeddings at every step."
        ),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Base random seed; each batch is seeded with seed + index of its first image.",
    )
    parser.add_argument(
        "--path_to_projector",
        type=str,
        required=True,
        help="Path to the torch-saved projector module applied to the text embeddings.",
    )

    args = parser.parse_args(argv)

    if not os.path.isfile(args.path_to_indices):
        parser.error(f"--path_to_indices is not an existing file: {args.path_to_indices}")
    if args.num_chunks < 1:
        parser.error(f"--num_chunks must be >= 1, got {args.num_chunks}")
    if not 0 <= args.chunk_idx < args.num_chunks:
        parser.error(f"--chunk_idx must be in [0, {args.num_chunks}), got {args.chunk_idx}")

    return args


def get_list_chunk(arr: np.ndarray, num_chunks: int, chunk_idx: int) -> tuple:
    """Select the ``chunk_idx``-th of ``num_chunks`` contiguous slices of ``arr``.

    Every chunk has ``ceil(len(arr) / num_chunks)`` elements except at the end.
    The last non-empty chunk may be shorter. Trailing chunks may be empty when
    the length is small relative to ``num_chunks`` (e.g. 5 items into 20
    chunks). Empty chunks are returned as empty slices instead of crashing.

    Args:
        arr: Array to split along its first axis.
        num_chunks: Total number of chunks; must be >= 1.
        chunk_idx: Zero-based chunk index; must satisfy
            ``0 <= chunk_idx < num_chunks``.

    Returns:
        ``(start_index, chunk)``: the offset of the chunk inside ``arr`` and the
        (possibly empty) slice ``arr[start_index:end_index]``.

    Raises:
        ValueError: if ``num_chunks`` or ``chunk_idx`` is out of range.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")
    if not 0 <= chunk_idx < num_chunks:
        raise ValueError(f"chunk_idx must be in [0, {num_chunks}), got {chunk_idx}")

    arr_len = arr.shape[0]

    # Ceiling division, so num_chunks chunks always cover the whole array.
    chunk_size = (arr_len + num_chunks - 1) // num_chunks

    start_index = chunk_size * chunk_idx
    end_index = min(start_index + chunk_size, arr_len)

    print(f"Choosing chunk ({start_index}:{end_index})")
    if start_index < end_index:
        print(f"First item of the chunk: {arr[start_index]}")
        print(f"Last item of the chunk: {arr[end_index-1]}", flush=True)
    else:
        # Indexing arr[start_index] here would raise IndexError.
        print("The chunk is empty", flush=True)

    return start_index, arr[start_index:end_index]


def get_text_embeddings(prompts, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, device=None) -> torch.Tensor:
    """Encode ``prompts`` with ``text_encoder`` without tracking gradients.

    Prompts are padded and truncated to ``tokenizer.model_max_length`` tokens.

    Args:
        prompts: A prompt string or a list of prompt strings.
        tokenizer: Tokenizer that matches ``text_encoder``.
        text_encoder: Text model; the first element of its output is returned.
        device: Device the token ids are moved to. Defaults to
            ``text_encoder.device``, so the ids always land next to the model.

    Returns:
        ``text_encoder(...)[0]`` — presumably the last hidden state, shaped
        ``(batch, model_max_length, hidden)``; confirm for other encoders.
    """
    if device is None:
        device = text_encoder.device

    text_input = tokenizer(
        prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    return text_embeddings


# Model repository id passed to every from_pretrained() call in load_models().
# main() generates 512x512 images for "CompVis/stable-diffusion-v1-4" and
# 768x768 images for any other id.
# Alternative: model_id = "CompVis/stable-diffusion-v1-4"
model_id = "stabilityai/stable-diffusion-2-1"

def load_models():
    """Load the Stable Diffusion components of ``model_id`` and move them to the GPU.

    The VAE, text encoder and UNet are loaded from safetensors weights and
    moved to ``'cuda'``. The tokenizer and PNDM scheduler come from the same
    repository. The scheduler's timesteps are set for 25 inference steps;
    callers may re-set them.

    Returns:
        ``(vae, tokenizer, text_encoder, unet, scheduler)``.
    """
    def _load_weights(model_cls, subfolder):
        # Shared loader for the components that ship safetensors weights.
        return model_cls.from_pretrained(model_id, subfolder=subfolder, use_safetensors=True)

    vae = _load_weights(AutoencoderKL, "vae")
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = _load_weights(CLIPTextModel, "text_encoder")
    unet = _load_weights(UNet2DConditionModel, "unet")
    scheduler = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")

    for module in (vae, text_encoder, unet):
        module.to('cuda')

    scheduler.set_timesteps(25)

    return vae, tokenizer, text_encoder, unet, scheduler


def _build_dataloader(args):
    """Return ``(start_global_cnt, dataloader)`` over this run's chunk of COCO captions.

    ``start_global_cnt`` is the offset of the chunk inside the index array
    loaded from ``args.path_to_indices``; it is used to number output files.
    """
    random_indices = np.load(args.path_to_indices)
    start_global_cnt, chunk_indices = get_list_chunk(random_indices, args.num_chunks, args.chunk_idx)
    dataset = load_dataset("HuggingFaceM4/COCO", split='train').select(chunk_indices)

    def transform(examples):
        # One prompt per example: its 'raw' caption with leading/trailing
        # periods stripped, lowercased.
        return {'prompts': [x['raw'].strip('.').lower() for x in examples['sentences']]}

    dataset.set_transform(transform)
    # TODO: make the batch size configurable.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False)
    return start_global_cnt, dataloader


def _encode_prompts(prompts, tokenizer, text_encoder, projector):
    """Return classifier-free-guidance conditioning ``(projected, clean)`` for *prompts*.

    Both tensors are the empty-prompt (unconditional) embeddings concatenated
    in front of the per-prompt embeddings along dim 0. This matches the
    ``noise_pred.chunk(2)`` split into (uncond, text) used during denoising.
    """
    with torch.no_grad():
        # Encode once: the clean embeddings are also the projector's input.
        clean = get_text_embeddings(prompts, tokenizer, text_encoder)
        projected = projector(clean)

        uncond_input = tokenizer(
            [""] * len(prompts), padding="max_length", max_length=projected.shape[1], return_tensors="pt"
        )
        uncond = text_encoder(uncond_input.input_ids.to(text_encoder.device))[0]

        return torch.cat([uncond, projected]), torch.cat([uncond, clean])


def _denoise(latents, scheduler, unet, text_embeddings, text_embeddings_clean,
             guidance_scale, guidance_timesteps_threshold):
    """Run classifier-free-guided denoising over ``scheduler.timesteps``.

    Steps with ``t > guidance_timesteps_threshold`` condition the UNet on
    *text_embeddings*; the remaining steps use *text_embeddings_clean*.
    Returns the final latents.
    """
    for t in tqdm(scheduler.timesteps):
        # Duplicate the latents so one UNet call covers the uncond and text halves.
        latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), timestep=t)

        cond = text_embeddings if t > guidance_timesteps_threshold else text_embeddings_clean
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=cond).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents


def main():
    """Generate one PNG per caption for the selected chunk of COCO captions.

    For each batch of captions:
      1. Encode the captions ("clean" embeddings) and pass them through the
         loaded projector ("projected" embeddings).
      2. Denoise random latents for 50 PNDM steps with classifier-free
         guidance (scale 7.5). Timesteps ``t > --early_guidance_timestep_threshold``
         condition on the projected embeddings; the others condition on the
         clean ones.
      3. Decode with the VAE and save each image as ``{n:06d}.png``. Here ``n``
         is the caption's position in the full index array.
    """
    args = parse_args()

    out_dir = args.path_to_save_generated_images
    if not os.path.exists(out_dir):
        print(f'Creating path \"{out_dir}\"')
    # exist_ok=True: parallel chunk jobs may race to create the same directory.
    os.makedirs(out_dir, exist_ok=True)

    start_global_cnt, dataloader = _build_dataloader(args)

    image_size = 512 if model_id == "CompVis/stable-diffusion-v1-4" else 768
    vae, tokenizer, text_encoder, unet, scheduler = load_models()
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    # NOTE(security): torch.load unpickles arbitrary objects. Here that is a
    # whole projector module, presumably why the train_mlp classes are imported
    # at the top of the file. Only load trusted checkpoints. weights_only=False
    # is explicit because newer PyTorch releases default to weights_only=True,
    # which cannot restore a pickled module.
    text_embedding_projector = torch.load(args.path_to_projector, weights_only=False).to('cuda')
    # Inference only: disable train-mode behavior (e.g. dropout) in case the
    # projector has any.
    text_embedding_projector.eval()

    guidance_scale = 7.5
    # VAE spatial downsampling factor, derived from its number of blocks.
    f = 2 ** (len(vae.config.block_out_channels) - 1)
    latent_size = image_size // f

    global_cnt = start_global_cnt
    for batch in tqdm(dataloader):
        prompts = batch['prompts']
        batch_size = len(prompts)

        text_embeddings, text_embeddings_clean = _encode_prompts(
            prompts, tokenizer, text_encoder, text_embedding_projector
        )

        # Seeding per batch from the global image index keeps runs reproducible
        # regardless of how the indices are chunked into jobs.
        generator = None if args.seed is None else torch.Generator(device='cuda').manual_seed(global_cnt + args.seed)
        latents = torch.randn(
            (batch_size, unet.config.in_channels, latent_size, latent_size),
            device=unet.device,
            generator=generator,
        )
        latents = latents * scheduler.init_noise_sigma

        # Re-set for every batch so each batch starts a fresh 50-step schedule.
        scheduler.set_timesteps(50)

        latents = _denoise(
            latents, scheduler, unet, text_embeddings, text_embeddings_clean,
            guidance_scale, args.early_guidance_timestep_threshold,
        )

        # scaling_factor is a VAE config value; reading it off the model itself
        # relies on a deprecated diffusers fallback (AttributeError on older versions).
        latents = 1 / vae.config.scaling_factor * latents
        with torch.no_grad():
            images = vae.decode(latents).sample

        for img in pt_to_pil(images):
            img.save(os.path.join(out_dir, f"{global_cnt:06d}.png"))
            global_cnt += 1


if __name__ == "__main__":
    # Generate images only when run as a script, not when imported.
    main()
