import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torchvision.datasets import CIFAR10, STL10, ImageFolder
from torch.utils.data import DataLoader, ConcatDataset, sampler
from PIL import Image
import pathlib
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
                    'tif', 'tiff', 'webp'}


class Crop(object):
    """Callable transform that cuts a fixed rectangle out of an image.

    The rectangle spans ``x1:x2`` on the axis passed to ``F.crop`` as
    ``top``/``height`` and ``y1:y2`` on the axis passed as ``left``/``width``,
    i.e. the call is ``F.crop(img, x1, y1, x2 - x1, y2 - y1)``.
    """

    def __init__(self, x1, x2, y1, y2):
        self.x1, self.x2 = x1, x2
        self.y1, self.y2 = y1, y2

    def __call__(self, img):
        top, left = self.x1, self.y1
        height = self.x2 - self.x1
        width = self.y2 - self.y1
        return F.crop(img, top, left, height, width)

    def __repr__(self):
        return "{}(x1={}, x2={}, y1={}, y2={})".format(
            type(self).__name__, self.x1, self.x2, self.y1, self.y2
        )


def get_dataset(args, config):
    """
    Build a shuffled training DataLoader for CIFAR10 or STL10.

    Images are resized to ``config.data.image_size``, horizontally flipped
    with p=0.5 unless ``config.data.random_flip`` is exactly ``False``, and
    converted to tensors. For STL10 the "train" and "test" splits are
    concatenated because the dataset is small.

    Args:
        args: Unused; kept so existing call sites keep working.
        config: Config object; reads ``data.dataset``, ``data.path``,
            ``data.image_size``, ``data.random_flip``, ``data.num_workers``
            and ``training.batch_size``.

    Returns:
        A ``DataLoader`` over the selected dataset with ``shuffle=True``.

    Raises:
        ValueError: If ``config.data.dataset`` is neither "CIFAR10" nor "STL10".
    """
    dataset_name = config.data.dataset
    if dataset_name not in ("CIFAR10", "STL10"):
        # Fail fast with a clear message; otherwise ``dataset`` would never
        # be assigned below.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CIFAR10' or 'STL10'.".format(
                dataset_name
            )
        )

    # Identity check: only an explicit ``False`` disables flipping.
    steps = [transforms.Resize(config.data.image_size)]
    if config.data.random_flip is not False:
        steps.append(transforms.RandomHorizontalFlip(p=0.5))
    steps.append(transforms.ToTensor())
    tran_transform = transforms.Compose(steps)

    if dataset_name == "CIFAR10":
        dataset = CIFAR10(
            config.data.path,
            train=True,
            download=True,
            transform=tran_transform,
        )
    else:  # "STL10"
        # For STL10 use both train and test sets due to its small size.
        train_dataset = STL10(
            config.data.path,
            split="train",
            download=True,
            transform=tran_transform,
        )
        test_dataset = STL10(
            config.data.path,
            split="test",
            download=True,
            transform=tran_transform,
        )
        dataset = ConcatDataset([train_dataset, test_dataset])

    return DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
    )


def all_but_one_class_path_dataset(config, data_path, label_to_drop):
    """
    Build a shuffled DataLoader over an image folder, excluding one class.

    ``data_path`` is read with ``ImageFolder`` and is expected to contain
    one sub-folder per class, e.g.::

        ./folder
            /0
            /1
            /2
            ...

    Every sample whose ``ImageFolder`` target differs from ``label_to_drop``
    is kept. NOTE(review): targets are ImageFolder's class indices, which
    presumably follow sorted folder order; confirm they line up with the
    folder names when those are not 0..K-1.

    Args:
        config: Reads ``data.image_size``, ``data.random_flip``,
            ``training.batch_size`` and, if present, ``data.num_workers``.
        data_path: Root directory passed to ``ImageFolder``.
        label_to_drop: Target value whose samples are excluded.

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and ``drop_last=False``.
    """
    # Identity check: only an explicit ``False`` disables flipping.
    steps = [transforms.Resize(config.data.image_size)]
    if config.data.random_flip is not False:
        steps.append(transforms.RandomHorizontalFlip(p=0.5))
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    train_dataset = ImageFolder(data_path, transform=transform)
    train_idx = find_indices(train_dataset.targets, label_to_drop)
    train_subset = torch.utils.data.Subset(train_dataset, train_idx)

    # Honour the configured worker count like the other loaders in this
    # module; fall back to DataLoader's default (0) if it is not set.
    return DataLoader(
        train_subset,
        batch_size=config.training.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=getattr(config.data, "num_workers", 0),
    )


def all_but_one_class_dataset(config, label_to_drop):
    """
    Build a shuffled DataLoader for CIFAR10/STL10 with one label removed.

    CIFAR10: the training split, minus samples whose label equals
    ``label_to_drop``.
    STL10: the "train" and "test" splits concatenated (the dataset is
    small), each filtered the same way.

    Args:
        config: Reads ``data.dataset``, ``data.path``, ``data.image_size``,
            ``data.random_flip``, ``data.num_workers`` and
            ``training.batch_size``.
        label_to_drop: Label value whose samples are excluded.

    Returns:
        A ``DataLoader`` with ``shuffle=True``.

    Raises:
        ValueError: If ``config.data.dataset`` is neither "CIFAR10" nor "STL10".
    """
    dataset_name = config.data.dataset
    if dataset_name not in ("CIFAR10", "STL10"):
        # Fail fast with a clear message; otherwise ``dataset`` would never
        # be assigned below.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CIFAR10' or 'STL10'.".format(
                dataset_name
            )
        )

    # Identity check: only an explicit ``False`` disables flipping.
    steps = [transforms.Resize(config.data.image_size)]
    if config.data.random_flip is not False:
        steps.append(transforms.RandomHorizontalFlip(p=0.5))
    steps.append(transforms.ToTensor())
    tran_transform = transforms.Compose(steps)

    if dataset_name == "CIFAR10":
        train_dataset = CIFAR10(
            config.data.path,
            train=True,
            download=True,
            transform=tran_transform,
        )
        train_idx = find_indices(train_dataset.targets, label_to_drop)
        dataset = torch.utils.data.Subset(train_dataset, train_idx)
    else:  # "STL10"
        # For STL10 use both train and test sets due to its small size.
        train_dataset = STL10(
            config.data.path,
            split="train",
            download=True,
            transform=tran_transform,
        )
        test_dataset = STL10(
            config.data.path,
            split="test",
            download=True,
            transform=tran_transform,
        )
        # Both splits use the same all-but-one filter (find_indices).
        train_idx = find_indices(train_dataset.labels, label_to_drop)
        train_subset = torch.utils.data.Subset(train_dataset, train_idx)
        test_idx = find_indices(test_dataset.labels, label_to_drop)
        test_subset = torch.utils.data.Subset(test_dataset, test_idx)
        dataset = ConcatDataset([train_subset, test_subset])

    return DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
        num_workers=config.data.num_workers,
    )

class SubsetSampler(sampler.Sampler):
    r"""Yields a fixed list of indices in the order given (no shuffling).

    Arguments:
        indices (sequence): the indices to yield, in iteration order.
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # Sequential, deterministic order; a fresh iterator on every call so
        # the sampler can be reused across epochs.
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

def groups_for_one_path_dataset(config, data_path, label_to_drop):
    """
    Build an unshuffled DataLoader over an image folder, ordered by label.

    For each label in ``config.grouping.label_list`` (skipping
    ``label_to_drop``), up to ``num_samples * batch_num`` sample indices
    with that label are taken from ``ImageFolder(data_path).targets``,
    label after label. The loader yields batches of
    ``config.grouping.batch_num`` samples in that order.
    """
    # Identity check: only an explicit ``False`` disables flipping.
    steps = [transforms.Resize(config.data.image_size)]
    if config.data.random_flip is not False:
        steps.append(transforms.RandomHorizontalFlip(p=0.5))
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    image_folder = ImageFolder(data_path, transform=transform)

    grouping = config.grouping
    ordered_idx = get_sequential_indices(
        image_folder.targets,
        label_to_drop,
        grouping.label_list,
        grouping.num_samples * grouping.batch_num,
    )
    grouped_subset = torch.utils.data.Subset(image_folder, ordered_idx)

    # shuffle=False keeps the per-label grouping produced above.
    return DataLoader(
        grouped_subset,
        batch_size=grouping.batch_num,
        shuffle=False,
        num_workers=config.data.num_workers,
    )

def groups_for_one_dataset(config, label_to_drop):
    """
    Build an unshuffled DataLoader for CIFAR10/STL10 with label-driven order.

    CIFAR10: every training sample whose label != ``label_to_drop``, in
    dataset order.
    STL10: for each label in ``config.grouping.label_list`` (skipping
    ``label_to_drop``), up to ``num_samples * batch_num`` train-split
    indices with that label, label after label.

    Batches contain ``config.grouping.batch_num`` samples.

    Args:
        config: Reads ``data.dataset``, ``data.path``, ``data.image_size``,
            ``data.random_flip``, ``data.num_workers`` and, for grouping,
            ``grouping.label_list``, ``grouping.num_samples``,
            ``grouping.batch_num``.
        label_to_drop: Label value whose samples are excluded.

    Returns:
        A ``DataLoader`` with ``shuffle=False``.

    Raises:
        ValueError: If ``config.data.dataset`` is neither "CIFAR10" nor "STL10".
    """
    dataset_name = config.data.dataset
    if dataset_name not in ("CIFAR10", "STL10"):
        # Fail fast with a clear message; otherwise ``dataset`` would never
        # be assigned below.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CIFAR10' or 'STL10'.".format(
                dataset_name
            )
        )

    # Identity check: only an explicit ``False`` disables flipping.
    steps = [transforms.Resize(config.data.image_size)]
    if config.data.random_flip is not False:
        steps.append(transforms.RandomHorizontalFlip(p=0.5))
    steps.append(transforms.ToTensor())
    tran_transform = transforms.Compose(steps)

    if dataset_name == "CIFAR10":
        train_dataset = CIFAR10(
            config.data.path,
            train=True,
            download=True,
            transform=tran_transform,
        )
        # NOTE(review): unlike the STL10 branch, this keeps all non-dropped
        # samples (find_indices) rather than per-label sequential groups;
        # confirm this is intended.
        train_idx = find_indices(train_dataset.targets, label_to_drop)
        dataset = torch.utils.data.Subset(train_dataset, train_idx)
    else:  # "STL10"
        # Only the train split is used: grouping depends on the per-label
        # order produced by get_sequential_indices.
        train_dataset = STL10(
            config.data.path,
            split="train",
            download=True,
            transform=tran_transform,
        )
        grouping = config.grouping
        train_idx = get_sequential_indices(
            train_dataset.labels,
            label_to_drop,
            grouping.label_list,
            grouping.num_samples * grouping.batch_num,
        )
        dataset = torch.utils.data.Subset(train_dataset, train_idx)

    # shuffle=False preserves the ordering chosen above.
    return DataLoader(
        dataset,
        batch_size=config.grouping.batch_num,
        shuffle=False,
        num_workers=config.data.num_workers,
    )


def logit_transform(image, lam=1e-6):
    """
    Return ``log(p) - log(1 - p)`` where ``p = lam + (1 - 2 * lam) * image``.

    The affine squeeze maps [0, 1] onto [lam, 1 - lam], which keeps both
    logarithms finite at the endpoints.
    """
    squeezed = lam + (1 - 2 * lam) * image
    log_p = torch.log(squeezed)
    log_one_minus_p = torch.log1p(-squeezed)
    return log_p - log_one_minus_p


def data_transform(config, X):
    """
    Map image tensors into model space.

    Each step is gated by a ``config.data`` flag:
      1. ``uniform_dequantization``: scale by 255/256 and add
         ``torch.rand_like(X) / 256`` noise.
      2. ``gaussian_dequantization``: add ``0.01 * torch.randn_like(X)``.
      3. ``rescaled``: ``2 * X - 1``; otherwise, if ``logit_transform``,
         apply ``logit_transform``.
    Finally, if ``config`` itself has an ``image_mean`` attribute, it is
    subtracted after adding a leading axis (``[None, ...]``).
    """
    flags = config.data
    if flags.uniform_dequantization:
        X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0
    if flags.gaussian_dequantization:
        X = X + torch.randn_like(X) * 0.01

    if flags.rescaled:
        X = 2 * X - 1.0
    elif flags.logit_transform:
        X = logit_transform(X)

    if not hasattr(config, "image_mean"):
        return X
    mean = config.image_mean.to(X.device)
    return X - mean[None, ...]


def inverse_data_transform(config, X):
    """
    Map model-space tensors back to image space and clamp to [0, 1].

    Re-adds ``config.image_mean`` (with a leading axis) when present, then
    undoes either the ``rescaled`` or the ``logit_transform`` mapping. The
    precedence matches ``data_transform``, which applies ``rescaled`` when
    set and only otherwise the logit transform, so configs with both flags
    set round-trip correctly. Dequantization noise is not undone.
    """
    if hasattr(config, "image_mean"):
        X = X + config.image_mean.to(X.device)[None, ...]

    # Same branch order as data_transform: rescaled wins over logit_transform.
    if config.data.rescaled:
        X = (X + 1.0) / 2.0
    elif config.data.logit_transform:
        X = torch.sigmoid(X)

    return torch.clamp(X, 0.0, 1.0)


class ImagePathDataset(torch.utils.data.Dataset):
    """
    Dataset of RGB images read from the top level of a folder.

    Files are collected with ``glob('*.<ext>')`` for every extension in
    ``IMAGE_EXTENSIONS`` and sorted by path. Each item is the image converted
    to RGB, passed through ``transforms`` when one is given.

    Args:
        img_folder: Directory to scan (not recursive).
        transforms: Optional callable applied to each PIL image.
        n: Optional cap; only the first ``n`` sorted files are exposed.
    """

    def __init__(self, img_folder, transforms=None, n=None):
        self.transforms = transforms

        path = pathlib.Path(img_folder)
        self.files = sorted(
            file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))
        )

        assert n is None or n <= len(self.files)
        self.n = len(self.files) if n is None else n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # Enforce the ``n`` cap here too; otherwise indexing past ``n`` (or
        # plain iteration, which stops only on IndexError) would reach files
        # beyond the cap. Negative indices count back from ``n``.
        index = i + self.n if i < 0 else i
        if not 0 <= index < self.n:
            raise IndexError("index {} out of range for {} images".format(i, self.n))
        path = self.files[index]
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)
        return img
    
    
def find_indices(lst, condition):
    """Return the positions in ``lst`` whose element is ``!= condition``."""
    kept = []
    for position, value in enumerate(lst):
        if value != condition:
            kept.append(position)
    return kept

def find_group_indices(lst, condition):
    """Return the positions in ``lst`` whose element is ``in condition``."""
    members = []
    for position, value in enumerate(lst):
        if value in condition:
            members.append(position)
    return members

def get_sequential_indices(lst, condition, label_list, num_samples):
    """
    Collect up to ``num_samples`` positions per label, grouped label by label.

    For each ``label`` in ``label_list`` (in order), scan ``lst`` from the
    start and take the positions of the first ``num_samples`` elements that
    equal ``label`` and differ from ``condition``. All positions for one
    label are appended before moving to the next, so the output is ordered
    by ``label_list``; a label equal to ``condition`` contributes nothing.

    Args:
        lst: Sequence of per-sample labels.
        condition: Label value to exclude.
        label_list: Labels to collect, in output order.
        num_samples: Maximum positions per label; values <= 0 select none.

    Returns:
        List of int positions into ``lst``.
    """
    selected = []
    for label in label_list:
        taken = 0
        for position, elem in enumerate(lst):
            # Check the quota before taking anything, so num_samples <= 0
            # selects nothing for this label.
            if taken >= num_samples:
                break
            if elem != condition and elem == label:
                selected.append(position)
                taken += 1
    return selected