import json
import os
import sys
import argparse
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
import requests
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModel
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from diffusers import StableUnCLIPImg2ImgPipeline, DPMSolverMultistepScheduler

import torch
from torchvision.transforms import autoaugment, transforms, InterpolationMode

from llava.datasets.fmri_vit3d_datasets import fMRIViT3dDataset

# Command-line interface. Parsing happens at import time, so `args` and `device`
# are module-level names that the rest of the script reads directly.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--gpu",
    type=int,
    default=0,
    help="index of the CUDA device to use (ignored when CUDA is unavailable)",
)

# NOTE(review): args.seed is not read anywhere in the visible code, so the
# random augmentation is currently not seeded by this option — confirm intent.
parser.add_argument(
    "--seed",
    type=int,
    default=42,
    help="the seed (for reproducible sampling)",
)

# The dataset name selects the dataset directory and is also the prefix of the
# image file names (`<dataset>_image_<idx>.png`); it is not an output directory.
parser.add_argument(
    "--dataset",
    type=str,
    default="nsd",
    help="dataset name (selects the dataset directory and the image file prefix)",
)

parser.add_argument(
    "--slice",
    type=int,
    default=-1,
    help="shard index used to compute the [start, end) sample range "
         "(the dataset is split into 8 shards); -1 selects the full range",
)

parser.add_argument(
    '--augment-num',
    type=int,
    default=0,
    help="0 encodes the original images once; a value N > 0 enables "
         "RandAugment and runs N + 1 passes over the images",
)

parser.add_argument(
    '--augment-begin',
    type=int,
    default=0,
    help="offset added to the pass index in augmented output file names",
)

args = parser.parse_args()

# Fall back to the CPU when no CUDA device is available.
device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"


class ImageDataset(torch.utils.data.Dataset):
    """Images stored as ``<root_dir>/images/<dataset>_image_<idx:06>.png``.

    The ``<dataset>`` prefix is taken from the module-level ``args.dataset``.

    Args:
        root_dir: Dataset directory that contains an ``images`` sub-directory.
        to_clip: Stored on the instance as ``self.to_clip``; not used by the
            methods of this class.
        to_vae: Stored on the instance as ``self.to_vae``; not used by the
            methods of this class.
        transform: Optional callable applied to the loaded RGB PIL image
            before it is converted to a NumPy array.
    """

    def __init__(
        self,
        root_dir,
        to_clip,
        to_vae,
        transform=None,
    ):
        self.root_dir = root_dir
        self.transform = transform
        self.to_clip = to_clip
        self.to_vae = to_vae

    def __len__(self):
        """Return the number of entries in ``<root_dir>/images``."""
        # Use the instance attribute: a bare `root_dir` only resolves when a
        # global of that name happens to exist (e.g. when run as a script).
        return len(os.listdir(f'{self.root_dir}/images'))

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it as a NumPy array.

        The optional ``transform`` is applied to the PIL image before the
        conversion to an array.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = f'{self.root_dir}/images/{args.dataset}_image_{idx:06}.png'
        # The context manager releases the file handle as soon as the pixel
        # data has been converted into a standalone RGB image.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return np.array(image)


# Build the Stable unCLIP image-to-image pipeline once at import time; the rest
# of the script uses its feature extractor, image encoder, image processor and
# VAE through the module-level `pipe`.
_UNCLIP_MODEL_ID = "stabilityai/stable-diffusion-2-1-unclip"

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    _UNCLIP_MODEL_ID,
    torch_dtype=torch.float16,  # weights are loaded in half precision
)

# Replace the pipeline's scheduler with a DPM-Solver multistep scheduler built
# from the existing scheduler's configuration.
_dpm_scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = _dpm_scheduler

# Move the pipeline to the device chosen from the command line.
pipe = pipe.to(device)


if __name__ == '__main__':
    # Script entry point: encode the dataset images with the pipeline's CLIP
    # image encoder and, when not augmenting, also with its VAE, writing one
    # file per image.
    #
    # `root_dir` must stay a module-level name: code outside this block may
    # refer to it as a global.
    root_dir = f'/mnt/NSD_dataset/datasets/{args.dataset}'
    augmenting = args.augment_num > 0

    if augmenting:
        # Augmented runs only produce CLIP embeddings, in a separate directory.
        vision_embeds_dir = f'{root_dir}/vision_embeds_aug'
        os.makedirs(vision_embeds_dir, exist_ok=True)

        transform = transforms.Compose([
            autoaugment.RandAugment(interpolation=InterpolationMode.BILINEAR),
        ])

    else:
        vision_embeds_dir = f'{root_dir}/vision_embeds'
        os.makedirs(vision_embeds_dir, exist_ok=True)
        vae_embeds_dir = f'{root_dir}/vae_embeds'
        os.makedirs(vae_embeds_dir, exist_ok=True)
        transform = None

    total = len(os.listdir(f'{root_dir}/images'))
    batch_size = 10
    print(f'[Total]: {total}')

    # Split the dataset into 8 shards; `--slice` picks one, -1 means all.
    sample_per_slice = total // 8 + 1
    if args.slice >= 0:
        # Clamp both bounds so an out-of-range shard yields an empty range.
        start = min(args.slice * sample_per_slice, total)
        end = min((args.slice + 1) * sample_per_slice, total)
    else:
        start = 0
        end = total

    dataset = ImageDataset(
        root_dir,
        transform=transform,
        to_clip=pipe.feature_extractor,
        to_vae=pipe.image_processor.preprocess,
    )
    # Restrict iteration to the selected shard. Previously `start`/`end` were
    # computed but never applied, so every shard processed the whole dataset.
    shard = torch.utils.data.Subset(dataset, range(start, end))
    dataloader = torch.utils.data.DataLoader(
        shard,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8
    )

    # The pipeline is loaded in half precision; inputs are cast to the dtype
    # of the module that consumes them to avoid dtype-mismatch errors.
    clip_dtype = pipe.image_encoder.dtype
    vae_dtype = pipe.vae.dtype

    for epoch in range(args.augment_num + 1):
        for i, image in enumerate(tqdm(dataloader, desc="Generating Embeds")):
            # Dataset index of the first sample in this batch.
            idx = start + i * batch_size
            with torch.no_grad():
                images_clip = pipe.feature_extractor(images=image, return_tensors="pt").pixel_values
                images_clip = images_clip.to(device=device, dtype=clip_dtype)

                # Keep the batch axis even for a single-sample batch, so that
                # iterating below always walks over samples, never features.
                image_embeds = pipe.image_encoder(images_clip).image_embeds
                image_embeds = image_embeds.cpu().numpy()

                if not augmenting:
                    # `image` is the collated batch tensor, so each sample is
                    # converted back to a PIL image before resizing;
                    # Tensor.resize would reshape, not rescale.
                    # Assumes each sample is an HxWx3 uint8 array, as returned
                    # by ImageDataset for RGB images without a transform.
                    pil_batch = [
                        Image.fromarray(sample.numpy()).resize((768, 768))
                        for sample in image
                    ]
                    images_vae = pipe.image_processor.preprocess(pil_batch)
                    images_vae = images_vae.to(device=device, dtype=vae_dtype)
                    vae_embeds = pipe.vae.encode(images_vae)
                    vae_embeds = vae_embeds['latent_dist'].mean.cpu()

                    for j, (clip_embed, vae_embed) in enumerate(zip(image_embeds, vae_embeds)):
                        np.save(f'{vision_embeds_dir}/vision_{idx + j:06}.npy', clip_embed)
                        torch.save(vae_embed, f'{vae_embeds_dir}/vae_{idx + j:06}.pt')
                else:
                    # The file-name suffix identifies the augmentation pass.
                    for j, clip_embed in enumerate(image_embeds):
                        np.save(f'{vision_embeds_dir}/vision_{idx + j:06}_{args.augment_begin + epoch:03}.npy', clip_embed)
