import os
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from imagenet_embeddings.folder import ImageFolder
from imagenet_embeddings.indexes import idx
import numpy as np


def show(data_loader, max_print=2):
    """Print the running batch number for every batch in ``data_loader``.

    Each batch must unpack into exactly three items
    ``(image, target, embedding)``; only the batch's position is reported.

    Args:
        data_loader: Iterable of 3-item batches (``main`` passes the training
            ``DataLoader`` returned by ``get_imagenet``).
        max_print: Currently ignored -- the early-exit logic that used it is
            disabled, so every batch is visited.
    """
    batch_counter = 0
    for batch in data_loader:
        # Unpacking enforces the (image, target, embedding) batch layout.
        image, target, embedding = batch
        print('batch_nr: ', batch_counter, flush=True)
        batch_counter += 1


def embedding_loader(path: str) -> np.ndarray:
    """Load the embedding array stored at ``path + '.npy'``.

    Args:
        path: Path to the embedding file *without* the ``.npy`` suffix; the
            suffix is appended here.

    Returns:
        The array saved in the ``.npy`` file.

    Raises:
        FileNotFoundError: If ``path + '.npy'`` does not exist.
    """
    # ``np.array`` is a factory function, not a type, so it was not a valid
    # return annotation; ``np.ndarray`` is the array type ``np.load`` returns
    # for a ``.npy`` file.
    return np.load(file=path + '.npy')


def _check_embeddings(samples, embeddings_dir):
    """Ensure every image in ``samples`` has a non-empty embedding file.

    The embedding for an image is looked up as
    ``<embeddings_dir>/<image file name>.npy`` via ``embedding_loader``.

    Args:
        samples: Iterable of pairs whose first element is the image path
            (the caller passes ``ImageFolder.samples``).
        embeddings_dir: Directory holding one ``.npy`` file per image.

    Raises:
        FileNotFoundError: If an image has no embedding file.
        ValueError: If an embedding file holds an empty array.
    """
    for image_path, _ in samples:
        # NOTE(review): embeddings are keyed by file name only, so two images
        # with the same name in different class folders would share a file.
        file_name = os.path.basename(image_path)
        embedding_path = os.path.join(embeddings_dir, file_name)
        npy_path = embedding_path + '.npy'
        try:
            embedding = embedding_loader(path=embedding_path)
        except IOError as err:
            # NOTE(review): a ``label(to_file=..., file_name=...)`` helper was
            # meant to generate the missing embedding here, but no ``label`` is
            # defined or imported in this module (it raised NameError). Fail
            # with an actionable error until that helper is wired in.
            raise FileNotFoundError(
                f"missing embedding for {image_path!r}: expected {npy_path!r}"
            ) from err
        if embedding.size == 0:
            # Re-loading the same empty file can never succeed; the previous
            # retry loop spun forever here.
            raise ValueError(
                f"empty embedding for {image_path!r} in {npy_path!r}")


def get_imagenet(
        data_dir='',
        embeddings_dir='',
        batch_size=1000,
        check_train_embeddings=True):
    """Build train/val ``DataLoader``s over ImageNet images plus embeddings.

    Expects image folders ``<data_dir>/train`` and ``<data_dir>/val`` and
    embedding folders ``<embeddings_dir>/file_embeddings_train`` and
    ``<embeddings_dir>/file_embeddings_val``.

    Args:
        data_dir: Root directory containing the ``train`` and ``val`` folders.
        embeddings_dir: Root directory containing the embedding folders.
        batch_size: Batch size used for both loaders.
        check_train_embeddings: If true (the default, matching the previous
            behaviour), load every training embedding up front and fail fast
            when one is missing or empty. This reads one file per training
            image, so pass ``False`` when the embeddings are known complete.

    Returns:
        ``(train_loader, val_loader)``; neither loader shuffles.

    Raises:
        FileNotFoundError: A training image has no embedding file (only when
            ``check_train_embeddings`` is true).
        ValueError: A training embedding file is empty (same condition).
    """
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')

    train_embeddings_file = os.path.join(embeddings_dir,
                                         'file_embeddings_train')
    val_embeddings_file = os.path.join(embeddings_dir, 'file_embeddings_val')

    # One deterministic pipeline for both splits (no augmentation). The
    # mean/std are the per-channel values torchvision's ImageNet-pretrained
    # models expect.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolder(root=train_dir,
                                embeddings_dir=train_embeddings_file,
                                embeddings_size=-1,
                                transform=transform)
    print('len(traindataset): ', len(train_dataset))

    if check_train_embeddings:
        _check_embeddings(train_dataset.samples, train_embeddings_file)

    # Optional: restrict training to a subset of indices, e.g.
    # train_dataset = torch.utils.data.Subset(train_dataset, idx)
    val_dataset = ImageFolder(root=val_dir,
                              embeddings_size=-1,
                              embeddings_dir=val_embeddings_file,
                              transform=transform)
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def main():
    """Script entry point: build the loaders and walk the training split.

    Only the training loader is iterated; swap in the second loader returned
    by ``get_imagenet`` to walk the validation split instead.
    """
    loaders = get_imagenet()
    train_loader, _ = loaders
    show(max_print=-1, data_loader=train_loader)


if __name__ == "__main__":
    main()
