# Copyright 2022-present, Lorenzo Bonicelli, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional, Tuple
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


def smart_joint(*paths):
    """Join path components with ``os.path.join`` and force forward slashes.

    Any backslash in the joined result (e.g. Windows separators) is turned
    into ``/`` so the returned path string looks the same on every platform.
    """
    joined = os.path.join(*paths)
    return "/".join(joined.split("\\"))


class TinyImagenet(Dataset):
    """Defines the Tiny Imagenet dataset.

    Samples are read from pre-processed ``.npy`` shards stored under
    ``<root>/processed/`` as ``x_<split>_NN.npy`` (images) and
    ``y_<split>_NN.npy`` (labels), where ``<split>`` is ``train`` or ``val``
    and ``NN`` runs from ``01`` to ``NUM_CHUNKS``.
    """

    # Number of .npy shards each split is stored in.
    NUM_CHUNKS = 20

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[nn.Module] = None,
        target_transform: Optional[nn.Module] = None,
        download: bool = False,
    ) -> None:
        """Load the requested split into memory, downloading it first if asked.

        Args:
            root: directory that holds (or will receive) the processed dataset.
            train: load the ``train`` shards when True, the ``val`` shards otherwise.
            transform: optional callable applied to each PIL image in ``__getitem__``.
            target_transform: optional callable applied to each label in
                ``__getitem__``.
            download: when True and ``root`` is missing or empty, fetch the
                processed archive and unzip it into ``root``.

        Raises:
            FileNotFoundError: if any expected shard is missing under ``root``.
        """
        # Plain tensor conversion without augmentation; not used by this class
        # itself (presumably consumed by subclasses/callers — verify).
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        if download:
            if os.path.isdir(root) and os.listdir(root):
                print("Download not needed, files already on disk.")
            else:
                # Imported lazily so the dependency is only required when a
                # download actually happens; aliased so it does not shadow the
                # ``download`` argument.
                from onedrivedownloader import download as onedrive_download

                print("Downloading dataset")
                ln = "https://unimore365-my.sharepoint.com/:u:/g/personal/263133_unimore_it/EVKugslStrtNpyLGbgrhjaABqRHcE3PB_r2OEaV7Jy94oQ?e=9K29aD"
                onedrive_download(
                    ln,
                    filename=smart_joint(root, "tiny-imagenet-processed.zip"),
                    unzip=True,
                    unzip_path=root,
                    clean=True,
                )

        self.data = self._load_split("x")
        self.targets = self._load_split("y")

    @staticmethod
    def _concat_npy(paths) -> np.ndarray:
        """Load every ``.npy`` file in *paths* (in order) and concatenate along axis 0."""
        # Concatenate the list directly: stacking it with np.array() first made
        # a full extra copy of the data and fails on shards of unequal length.
        return np.concatenate([np.load(path) for path in paths])

    def _load_split(self, prefix: str) -> np.ndarray:
        """Load all ``<prefix>_<split>_NN.npy`` shards of the current split."""
        split = "train" if self.train else "val"
        paths = [
            smart_joint(self.root, "processed/%s_%s_%02d.npy" % (prefix, split, num))
            for num in range(1, self.NUM_CHUNKS + 1)
        ]
        return self._concat_npy(paths)

    def __len__(self) -> int:
        """Return the number of samples in the loaded split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair at *index*.

        The stored image is multiplied by 255 and cast to ``uint8`` before
        being wrapped in a PIL image. This presumes the shards hold floats in
        [0, 1] (TODO confirm). ``transform`` and ``target_transform`` are
        applied when they are set.
        """
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(np.uint8(255 * img))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
