from PIL import Image
import os
import os.path
from typing import Any, Callable, List, Optional, Union, Tuple
from glob import glob

import torch
# from .vision import VisionDataset
# from .utils import download_and_extract_archive, verify_str_arg
from torch.utils.data.dataset import Dataset

"""
Adapted from https://github.com/Skuldur/Oxford-IIIT-Pets-Pytorch/blob/master/Pytorch%20Image%20Classification.ipynb
(originally written for the Oxford-IIIT Pets dataset).
"""


def load_image(filename):
    """Load an image from disk and return it as an RGB ``PIL.Image.Image``.

    The file is opened inside a ``with`` block so the underlying handle is
    released deterministically, rather than whenever the garbage collector
    reclaims the image object. ``convert`` produces a new, fully loaded image,
    so the returned object does not depend on the (now closed) file.

    Args:
        filename: Path to the image file (anything ``PIL.Image.open`` accepts).

    Returns:
        A new image in ``'RGB'`` mode.
    """
    with Image.open(filename) as img:
        return img.convert('RGB')


def get_resisc45_data(root: str) -> List[Tuple[Any, torch.Tensor]]:
    """Load every ``*.jpg`` directly under ``root`` and pair it with a class index.

    The class name is taken from the file name: everything before the last
    ``'_'`` in the base name (e.g. ``airplane_001.jpg`` -> ``airplane``).
    Class names are sorted before being assigned indices, so the name->index
    mapping is identical across runs and processes. (Enumerating a ``set`` of
    strings directly depends on ``PYTHONHASHSEED`` and would yield different
    labels in, e.g., separately built train and test splits.)

    Args:
        root: Directory containing the images. Only its top level is searched.

    Returns:
        A list of ``(image, label)`` tuples in sorted file-path order, where
        ``image`` is the result of ``load_image`` and ``label`` is a 0-d
        ``torch.long`` tensor. Empty if no ``*.jpg`` files are found.
    """
    from glob import escape as glob_escape

    # Escape the directory part so characters such as '[' in ``root`` are
    # matched literally; sort because glob's order is filesystem-dependent.
    pattern = os.path.join(glob_escape(str(root)), '*.jpg')
    filenames = sorted(glob(pattern))
    print(f"root: {root}")

    data = []
    class_names = []

    # Load the images and derive each class name from its file name.
    for path in filenames:
        # basename honours the OS path separator, unlike splitting on '/'.
        class_name = os.path.basename(path).rsplit('_', 1)[0]
        data.append(load_image(path))
        class_names.append(class_name)

    # Sorted so index assignment is deterministic.
    classes = sorted(set(class_names))
    print(f"classes: {classes}")

    # Convert class names to indices; build an integer tensor directly
    # instead of going through a float32 tensor.
    class2idx = {cl: idx for idx, cl in enumerate(classes)}
    labels = torch.tensor([class2idx[name] for name in class_names], dtype=torch.long)

    return list(zip(data, labels))


class Resisc45Dataset(Dataset):
    """
    RESISC45 dataset is a publicly available benchmark for Remote Sensing Image Scene Classification (RESISC),
    created by Northwestern Polytechnical University (NWPU).
    This dataset contains 31,500 images, covering 45 scene classes with 700 images in each class.

    Map-style dataset that serves pre-loaded ``(image, label)`` pairs (for
    example the list built by ``get_resisc45_data``), optionally transformed.

    Args:
        data: Indexable sequence of ``(image, label)`` pairs.
        transform: Optional callable applied to each image.
        target_transform: Optional callable applied to each label.
    """
    def __init__(self, data, transform=None, target_transform=None):
        self.data = data
        # Length is captured once, at construction time.
        self.len = len(self.data)
        print(f"RESISC45 Dataset with length: {self.len}")
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index`` with any transforms applied."""
        img, label = self.data[index]

        # Compare against None instead of testing truthiness: a callable that
        # defines __len__/__bool__ could evaluate falsy and be skipped silently.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.len
