import logging

import numpy as np
import torch
from torch import nn
import torch.utils.data as data
from PIL import Image
from torchvision.datasets import CIFAR10


# Import-time logging setup. ``logger`` is the *root* logger (getLogger() with
# no name). basicConfig() gives the root logger a default stderr handler unless
# it already has handlers; the root level is then raised to INFO regardless.
logger = logging.getLogger()
logging.basicConfig()
logger.setLevel(logging.INFO)


class TruncatedDataset(data.Dataset):
    """Wrap a dataset, optionally restricting it to a subset of its indices.

    The wrapped ``dataset`` must expose ``data`` (indexable by a sequence of
    indices, e.g. a numpy array) and ``targets`` (a sequence of labels).
    Items are returned as ``(img, target, idx)`` where ``idx`` is the position
    of the sample in the *original* ``dataset``.

    Args:
        dataset: Source dataset providing ``data`` and ``targets``.
        dataidxs: Indices into ``dataset`` to keep; ``None`` keeps every sample.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Optional callable applied to each target.
        directional_label: If True, the real targets are replaced by labels
            drawn uniformly from ``[0, num_directional_labels)`` with numpy's
            global RNG (seed via ``np.random.seed`` for reproducibility).
        num_directional_labels: Number of random label values used when
            ``directional_label`` is True. Defaults to 64.
    """

    def __init__(
        self,
        dataset,
        dataidxs=None,
        transform=None,
        target_transform=None,
        directional_label=False,
        num_directional_labels=64,
    ):
        self.dataset = dataset
        # Indices into ``dataset``; set to ``range(len(dataset.data))`` when None.
        self.dataidxs = dataidxs
        self.transform = transform
        self.target_transform = target_transform
        self.directional_label = directional_label
        self.num_directional_labels = num_directional_labels

        self.data, self.targets = self.__build_truncated_dataset__(dataset)

    def __build_truncated_dataset__(self, dataset):
        """Return ``(samples, targets)`` restricted to ``self.dataidxs``.

        ``targets`` is always a plain Python list. When ``self.dataidxs`` is
        None it is replaced by ``range(len(dataset.data))`` so ``__getitem__``
        can report original indices uniformly.
        """
        # NOTE: the ``__name__``-style method name is kept only for backward
        # compatibility with possible overrides; such names are reserved by Python.
        samples = dataset.data
        if self.directional_label:
            # Labels are drawn for the *whole* dataset before slicing, so for a
            # given seed a sample gets the same label whichever subset is kept.
            labels = np.random.randint(
                0, self.num_directional_labels, len(dataset.targets)
            )
        else:
            labels = np.array(dataset.targets)

        if self.dataidxs is not None:
            samples = samples[self.dataidxs]
            labels = labels[self.dataidxs]
        else:
            self.dataidxs = range(len(samples))
        return samples, labels.tolist()

    def __getitem__(self, index):
        """Return one sample of this (possibly truncated) dataset.

        Args:
            index (int): Position within this dataset.

        Returns:
            tuple: ``(img, target, idx)``. ``img`` is the stored sample after
            ``transform``. ``target`` is its label after ``target_transform``.
            ``idx`` is the sample's index in the original ``dataset``.
        """
        img, target = self.data[index], self.targets[index]
        idx = self.dataidxs[index]

        # The image is handed to ``transform`` exactly as stored in
        # ``dataset.data``; no PIL conversion happens here, so ``transform``
        # must accept that type.
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, idx

    def __len__(self):
        """Return the number of samples kept after truncation."""
        return len(self.data)


