import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision

import numpy as np

class Imagenet(torch.utils.data.Dataset):
    """Map-style dataset over samples stored in per-batch ``.npy`` files.

    Batch ``b`` of a split is read from two files in the split directory:
    ``{b}.npy`` (images) and ``{b}_label.npy`` (labels). Sample ``i`` of a
    batch is entry ``i`` along the first axis of each array.

    Args:
        path_to_preprocessed: Directory containing the ``train`` and ``val``
            sub-directories. May be a ``str`` or a ``pathlib.Path``.
        train: Read from ``train/`` when true, otherwise from ``val/``.
        n_batches: Number of batch files in the split.
        n_samples_per_batch: Number of samples in every batch file
            (assumed identical for all batches -- TODO confirm for the
            last batch of a split).
        transform: Optional callable applied to the image tensor.
        target_transform: Optional callable applied to the label tensor.
    """

    def __init__(self, path_to_preprocessed, train: bool = False, n_batches=110, n_samples_per_batch=128,
                 transform=None, target_transform=None):
        super().__init__()

        self.n_batches, self.n_samples_per_batch = n_batches, n_samples_per_batch

        # str() lets callers pass a pathlib.Path; the stored value keeps the
        # original "<root>/<split>/" string form with a trailing slash.
        self.path = str(path_to_preprocessed) + ('/train/' if train else '/val/')

        self.transform = transform
        self.target_transform = target_transform

    def _read_sample(self, file_name, image_num):
        """Return entry ``image_num`` of the array stored in ``file_name``.

        The file is memory-mapped so that only the requested entry is read
        from disk instead of the whole batch; the entry is then copied into
        an ordinary in-memory array.
        NOTE(review): memory-mapping requires plain (non-object) dtypes.
        """
        batch = np.load(self.path + file_name, mmap_mode='r')
        return np.array(batch[image_num])

    def load(self, batch_num, image_num):
        """Load one sample from disk.

        Args:
            batch_num: Index of the batch file to read.
            image_num: Index of the sample inside that batch.

        Returns:
            ``(image, label)`` where ``image`` is a ``float32`` tensor and
            ``label`` keeps the dtype stored on disk (presumably a one-hot
            vector, judging by the original variable name).
        """
        one_hot_label = torch.tensor(
            self._read_sample(f'{batch_num}_label.npy', image_num)
        )
        image = torch.tensor(
            self._read_sample(f'{batch_num}.npy', image_num)
        ).to(torch.float32)

        return image, one_hot_label

    def __len__(self):
        """Return the total number of samples across all batches."""
        return self.n_batches * self.n_samples_per_batch

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at flat index ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising ``IndexError`` (rather than failing on a missing
                file) is what lets plain iteration over the dataset stop.
        """
        # Normalise 0-d tensors / numpy integers to a plain int.
        idx = int(idx)
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f'index {idx} is out of range for dataset of size {total}')

        batch_idx, image_idx = divmod(idx, self.n_samples_per_batch)

        image, label = self.load(batch_idx, image_idx)

        # Compare against None: a callable that defines __len__ may be falsy.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    
def load_data(batch_size, data_root="/PATH/TO/DATA", num_workers=18):
    """Build the train and validation dataloaders from image folders.

    Both splits are read with ``torchvision.datasets.ImageFolder`` from
    ``<data_root>/train`` and ``<data_root>/val`` and share the same
    deterministic preprocessing: resize to 256, center-crop to 224,
    convert to a tensor, normalise per channel.

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_root: Directory containing the ``train`` and ``val`` folders.
            The default is the placeholder path that was previously
            hard-coded, so it must be overridden to load real data.
        num_workers: Number of worker processes used by each loader.

    Returns:
        ``(train_dataloader, test_dataloader)``; neither loader shuffles.
    """
    # Per-channel normalisation constants (one value per image channel).
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    test_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean, std)
        ]
    )

    # TODO
    # NOTE(review): the train split reuses the evaluation transform and is
    # not shuffled; the original TODO presumably refers to this -- confirm.
    train_dataset = torchvision.datasets.ImageFolder(root=f"{data_root}/train", transform=test_transform)
    test_dataset = torchvision.datasets.ImageFolder(root=f"{data_root}/val", transform=test_transform)

    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False, num_workers=num_workers)
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=num_workers)

    return train_dataloader, test_dataloader