import torch
import os
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from imagenet21k_embeddings.folder import ImageFolder
import argparse

# Command-line interface for inspecting the ImageNet-21k embedding loader.
parser = argparse.ArgumentParser(description='Label ImageNet')
parser.add_argument('--data_dir',
                    type=str,
                    default='/datadrive1/adam/imagenet21k_resized/',
                    help='path to dataset used for training the model')
parser.add_argument('--embeddings_dir',
                    type=str,
                    default='/datadrive1/adam/ImageNet21k-embeddings/',
                    help='path to stored embeddings')
parser.add_argument('--embeddings_size',
                    type=int,
                    default=250000,
                    help='size of the loaded embeddings (how many samples)')
# Integer default (not the string '1000'): argparse only type-converts
# string defaults as a side effect, so an int default is explicit.
parser.add_argument('--batch_size',
                    type=int,
                    default=1000,
                    help='number of samples per batch')


def show(data_loader, batch_size, max_print=2):
    """Print the length of every embedding yielded by ``data_loader``.

    Each sample produces a line ``<index>  len(embedding):  <n>``. A sample
    whose embedding is empty also produces an
    ``Error bad_embedding for image nr: <index>`` line.

    Args:
        data_loader: iterable of ``(image, target, embeddings)`` batches.
            Only ``embeddings`` is inspected.
        batch_size: nominal batch size. Kept for backward compatibility; the
            sample index is now counted from the samples actually yielded,
            so a value that disagrees with the loader no longer skews it.
        max_print: currently unused; kept so existing callers keep working.
    """
    counter = 0  # global index of the sample being reported
    for _image, _target, embeddings in data_loader:
        for embedding in embeddings:
            print(counter, ' len(embedding): ', len(embedding), flush=True)
            if len(embedding) == 0:
                print(f'Error bad_embedding for image nr: {counter}',
                      flush=True)
            counter += 1


def get_imagenet(
        data_dir, embeddings_dir, embeddings_size, batch_size, is_val=False,
        val_embeddings_size=50000):
    """Build ImageNet-21k data loaders that also yield stored embeddings.

    Args:
        data_dir: root directory containing ``imagenet21k_train`` and
            ``imagenet21k_val``.
        embeddings_dir: root directory containing ``file_embeddings_train``
            and ``file_embeddings_val``. A trailing slash is not required.
        embeddings_size: number of training samples to expose. It is passed
            to ``ImageFolder`` and also limits the training set to indices
            ``0 .. embeddings_size - 1``.
        batch_size: batch size used for both loaders.
        is_val: when True, also build the validation loader.
        val_embeddings_size: ``embeddings_size`` passed to the validation
            ``ImageFolder`` (previously hard-coded to 50000).

    Returns:
        ``(train_loader, val_loader)``. ``val_loader`` is None unless
        ``is_val`` is True.
    """
    train_dir = os.path.join(data_dir, 'imagenet21k_train')
    val_dir = os.path.join(data_dir, 'imagenet21k_val')

    # os.path.join tolerates a missing trailing slash on embeddings_dir. The
    # final '' keeps the trailing separator that the original string
    # concatenation produced.
    # NOTE(review): presumably ImageFolder appends file names directly to
    # this path; confirm before dropping the trailing separator.
    train_embeddings_dir = os.path.join(
        embeddings_dir, 'file_embeddings_train', '')
    val_embeddings_dir = os.path.join(
        embeddings_dir, 'file_embeddings_val', '')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolder(root=train_dir,
                                embeddings_dir=train_embeddings_dir,
                                embeddings_size=embeddings_size,
                                transform=transform)

    # Restrict training to the first `embeddings_size` samples.
    # NOTE(review): indices past len(train_dataset) only fail once the loader
    # reaches them; confirm embeddings_size never exceeds the dataset size.
    train_dataset = torch.utils.data.Subset(
        train_dataset, indices=list(range(embeddings_size)))

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)

    if not is_val:
        return train_loader, None

    val_dataset = ImageFolder(root=val_dir,
                              embeddings_dir=val_embeddings_dir,
                              embeddings_size=val_embeddings_size,
                              transform=transform)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def main():
    """Run the script: parse CLI flags, build the training loader, dump embedding lengths."""
    cli = parser.parse_args()
    loaders = get_imagenet(
        data_dir=cli.data_dir,
        embeddings_dir=cli.embeddings_dir,
        embeddings_size=cli.embeddings_size,
        batch_size=cli.batch_size,
    )
    # Only the training loader is inspected; the validation loader is unused.
    train_loader, _unused_val_loader = loaders
    show(data_loader=train_loader, batch_size=cli.batch_size)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
