import os
import argparse
import numpy as np
import torch
from tqdm import tqdm
from PIL import Image
from .libs.autoencoder import get_model
from .libs.clip import FrozenCLIPEmbedder
from .datasets import MSCOCODatabase, CC3MDataset
from torchvision.utils import save_image

# Both the MSCOCO image folders and the annotation folder sit under this path.
_MSCOCO_ROOT = '/gpfs/projects/bsc70/hpai/storage/data/datasets/raw/MSCOCO'

# Registry of the datasets this script can preprocess. Each entry holds the
# class used to build the dataset and the filesystem locations passed to it.
DATASET = {
    'mscoco': {
        'dataset_class': MSCOCODatabase,
        'root_dir': _MSCOCO_ROOT,
        'annFile': _MSCOCO_ROOT + '/annotations',
    },
    'cc3m': {
        'dataset_class': CC3MDataset,
        'root_dir': '/gpfs/projects/bsc70/bsc131047/data/cc3m',
    },
}

def _find_resume_index(save_dir):
    """Return the dataset index at which preprocessing should (re)start.

    Only latent files named ``<idx>.npy`` are considered. Caption files
    (``<idx>_captions.npy``) and CLIP feature files (``<idx>_<i>.npy``) have
    non-numeric stems and are ignored, as is any other stray file.

    Returns:
        0 when no latent exists yet, otherwise the highest stored index minus
        one, never below 0.
    """
    finished = []
    for entry in os.listdir(save_dir):
        stem, ext = os.path.splitext(entry)
        if ext == ".npy" and stem.isdecimal():
            finished.append(int(stem))
    if not finished:
        return 0
    # The newest sample may be incomplete because its captions and CLIP
    # features are written after its latent. Step back one extra sample as a
    # safety margin, but never produce a negative index.
    return max(max(finished) - 1, 0)


def main(resolution=256):
    """Extract and store image latents and CLIP caption features.

    Parses ``--split`` and ``--dataset`` from the command line, builds the
    matching dataset, and for every sample writes into the output directory:

    * ``<idx>.png``           - the image as returned by the dataset
    * ``<idx>.npy``           - autoencoder output for ``fn='encode_moments'``
    * ``<idx>_captions.npy``  - the raw captions (object array)
    * ``<idx>_<i>.npy``       - the CLIP encoding of caption ``i``

    Samples below the resume index derived from the files already present in
    the output directory are skipped.

    Args:
        resolution: Passed to the dataset as ``size`` and embedded in the
            output directory name.

    Raises:
        NotImplementedError: For an unsupported split, or a split/dataset
            combination that has no data source configured.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', default='train')
    parser.add_argument('--dataset', choices=['mscoco', 'cc3m'], default='mscoco')
    args = parser.parse_args()
    print(args)

    if args.split == "train":
        if args.dataset == 'cc3m':
            cc3m = DATASET['cc3m']
            datas = cc3m['dataset_class'](
                root_dir=f"{cc3m['root_dir']}/train",
                size=resolution,
                augmentation=False
            )
            save_dir = f'/gpfs/projects/bsc70/bsc193242/Data/cc3m{resolution}_features/train'
        elif args.dataset == 'mscoco':
            coco = DATASET['mscoco']
            datas = coco['dataset_class'](
                root=f"{coco['root_dir']}/train2017",
                annFile=f"{coco['annFile']}/captions_train2017.json",
                size=resolution
            )
            save_dir = f'/gpfs/projects/bsc70/bsc193242/Data/coco{resolution}_features/train'
        else:
            raise NotImplementedError("ERROR!")
    elif args.split == "val":
        # Only MSCOCO has a validation source configured. Fail loudly rather
        # than silently processing MSCOCO when another dataset was requested.
        if args.dataset != 'mscoco':
            raise NotImplementedError(
                f"The 'val' split is only available for 'mscoco', not '{args.dataset}'"
            )
        coco = DATASET['mscoco']
        datas = coco['dataset_class'](
            root=f"{coco['root_dir']}/val2014",
            annFile=f"{coco['annFile']}/captions_val2014.json",
            size=resolution
        )
        save_dir = f'/gpfs/projects/bsc70/bsc193242/Data/coco{resolution}_features/val'
    else:
        raise NotImplementedError("ERROR!")

    device = "cuda"
    os.makedirs(save_dir, exist_ok=True)

    autoencoder = get_model('/gpfs/projects/bsc70/bsc193242/Models/stable-diffusion/autoencoder_kl.pth')
    autoencoder.to(device)
    clip = FrozenCLIPEmbedder()
    clip.eval()
    clip.to(device)

    # Work out where a previous, interrupted run left off.
    start_idx = _find_resume_index(save_dir)
    print(f"Resuming preprocessing from index {start_idx}...")

    with torch.no_grad():
        for idx, data in tqdm(enumerate(datas)):
            if idx < start_idx:
                continue

            x, captions = data

            # Add a leading batch dimension to a single (3-D) image.
            if len(x.shape) == 3:
                x = x[None, ...]

            # Convert once; as_tensor accepts arrays and tensors alike and
            # avoids a copy where possible.
            image = torch.as_tensor(x, device='cpu')

            # Save image
            # NOTE(review): save_image expects values in [0, 1] unless
            # normalize=True is passed - confirm the dataset's output range.
            save_image(image, os.path.join(save_dir, f'{idx}.png'))

            # Save image latent
            moments = autoencoder(image.to(device), fn='encode_moments').squeeze(0)
            # Blocking device-to-host copy: a non_blocking transfer could be
            # read by numpy before the data has actually arrived.
            moments = moments.detach().cpu().numpy()
            np.save(os.path.join(save_dir, f'{idx}.npy'), moments)

            # Save original captions
            np.save(os.path.join(save_dir, f"{idx}_captions.npy"), np.array(captions, dtype=object))

            # Save CLIP features, one file per caption
            latent = clip.encode(captions)
            for i, caption_latent in enumerate(latent):
                c = caption_latent.detach().cpu().numpy()
                np.save(os.path.join(save_dir, f'{idx}_{i}.npy'), c)


# Script entry point. This module uses relative imports, so it has to be
# executed as part of its package (e.g. with ``python -m``), not as a bare file.
if __name__ == '__main__':
    main()