import numpy as np
import torch
import os
from torchvision import transforms
import torch.utils.data
import PIL
import torchvision.transforms.functional as FT
from PIL import Image
import bisect 

import torchvision
import torch.distributed

import pdb

   
"""
Transformation functions
"""
def pad(img, size, mode):
    """Pad the two leading (spatial) axes of an image by ``size`` on each side.

    Args:
        img (PIL.Image.Image or array-like): image laid out as (H, W) or
            (H, W, C). PIL images are converted through their array interface.
        size (int): number of elements added before and after each of the
            first two axes.
        mode (str): any ``np.pad`` mode, e.g. ``'reflect'`` or ``'constant'``.
    Returns:
        np.ndarray: the padded array; any trailing axes (channels) are left
        unpadded.
    """
    img = np.asarray(img)
    # Pad only the first two axes so both grayscale (H, W) and
    # multi-channel (H, W, C) images are supported.
    pad_width = [(size, size), (size, size)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode)

class UniformNoising(object):
    """
    Add uniform noise to input images.

    Args:
        min (float): inclusive lower bound of the noise.
        max (float): exclusive upper bound of the noise.
    """
    def __init__(self, min, max):
        self.min = min
        self.max = max

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): image tensor, e.g. of size (C, H, W).
        Returns:
            Tensor: a new tensor equal to ``tensor`` plus i.i.d. noise drawn
            from U[min, max); the input is not modified in place.
        """
        # Allocate the noise on the input's device so CUDA tensors work; keep
        # float32 noise (as torch.FloatTensor did) so dtype promotion is unchanged.
        noise = torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)
        noise.uniform_(self.min, self.max)
        return tensor + noise

    def __repr__(self):
        return self.__class__.__name__ + '(min={0}, max={1})'.format(self.min, self.max)
    
class GaussianBlur(object):
    """
    Randomly apply a separable Gaussian blur to a (C, H, W) tensor image.

    PyTorch version of
    https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311

    Args:
        kernel_size (int): nominal kernel size; the effective size is
            ``2 * int(kernel_size / 2) + 1``, so it is always odd.
        p (float): probability of applying the blur on each call.
    """
    def __init__(self, kernel_size, p=0.5):
        self.kernel_size = kernel_size
        self.p = p

    def gaussian_blur(self, image, sigma):
        """Blur ``image`` with a Gaussian of standard deviation ``sigma`` pixels.

        Borders are zero-padded, so the output has the same shape as the input.

        Args:
            image (Tensor): image of shape (C, H, W).
            sigma (float): standard deviation of the Gaussian.
        Returns:
            Tensor: blurred image of shape (C, H, W).
        """
        channels, height, width = image.shape[0], image.shape[1], image.shape[2]
        batch = image.reshape(1, channels, height, width)

        radius = int(self.kernel_size / 2)
        kernel_size = radius * 2 + 1
        x = np.arange(-radius, radius + 1)
        # Built-in int/float replace the np.int/np.float aliases removed in NumPy 1.24.
        blur_filter = np.exp(-np.power(x, 2.0) / (2.0 * float(sigma) ** 2))
        blur_filter /= np.sum(blur_filter)

        # Weights follow the image's dtype/device so GPU/float64 inputs work.
        taps = torch.as_tensor(blur_filter, dtype=batch.dtype, device=batch.device)
        # Depthwise (groups=channels) 1-D kernels: a vertical pass followed by a
        # horizontal pass is equivalent to a 2-D Gaussian blur.
        vertical = taps.view(1, 1, kernel_size, 1).repeat(channels, 1, 1, 1)
        horizontal = taps.view(1, 1, 1, kernel_size).repeat(channels, 1, 1, 1)

        res = torch.nn.functional.conv2d(batch, vertical, padding=(radius, 0), groups=channels)
        res = torch.nn.functional.conv2d(res, horizontal, padding=(0, radius), groups=channels)
        assert res.shape == batch.shape
        return res[0]

    def __call__(self, img):
        """Blur ``img`` with probability ``p`` using sigma ~ U[0.2, 2); else return it unchanged."""
        with torch.no_grad():
            assert isinstance(img, torch.Tensor)
            if np.random.uniform() < self.p:
                return self.gaussian_blur(img, sigma=np.random.uniform(0.2, 2))
            return img

    def __repr__(self):
        return self.__class__.__name__ + '(kernel_size={0}, p={1})'.format(self.kernel_size, self.p)

class CenterCropAndResize(object):
    """Center-crop a fixed proportion of a PIL image, then resize the crop.

    Args:
        proportion (float): fraction of the image's width and height kept by
            the center crop (crop dimensions are truncated to integers).
        size (int or sequence): output size. A scalar gives a square
            ``(size, size)`` output; a sequence is used as ``(h, w)``.
    """

    def __init__(self, proportion, size):
        self.proportion = proportion
        self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.
        Returns:
            PIL Image: the center crop, resized to the output size with
            bicubic interpolation.
        """
        # PIL reports size as (width, height).
        w, h = (np.array(img.size) * self.proportion).astype(int)
        if np.isscalar(self.size):
            out_size = (self.size, self.size)
        else:
            out_size = tuple(self.size)
        img = FT.resize(
            FT.center_crop(img, (h, w)),
            out_size,
            interpolation=PIL.Image.BICUBIC
        )
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(proportion={0}, size={1})'.format(self.proportion, self.size)


class Clip(object):
    """Transform that clamps every tensor element into the interval [0, 1]."""

    def __call__(self, x):
        lower, upper = 0, 1
        return torch.clamp(x, min=lower, max=upper)

def get_color_distortion(s=1.0):
    """Return the SimCLR color-distortion transform of strength ``s``.

    Following https://arxiv.org/pdf/2002.05709.pdf: with probability 0.8 apply
    ColorJitter(brightness=0.8s, contrast=0.8s, saturation=0.8s, hue=0.2s),
    then convert to grayscale with probability 0.2.
    """
    jitter_strength = 0.8 * s
    jitter = transforms.ColorJitter(
        brightness=jitter_strength,
        contrast=jitter_strength,
        saturation=jitter_strength,
        hue=0.2 * s,
    )
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])


"""
Dataset info functions
"""    
    
def get_classes(dataset):
    """Return the number of classes for a known dataset name, or None if unknown."""
    classes_per_dataset = {
        'cifar10': 10,
        'cifar100': 100,
        'stl10': 10,
        'dtd': 47,
        'mnist': 10,
        'svhn': 10,
        'tiny_imagenet': 200,
        'imagenet': 1000,
        'imagenet100': 100,
    }
    return classes_per_dataset.get(dataset)

"""
Dataloader preparation functions
"""  
class CustomConcatDataset(torch.utils.data.Dataset):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets. Each
    sub-dataset item must be a ``(sample, label)`` pair; ``__getitem__`` returns
    ``(sample, idx, label, dataset_idx)`` where ``idx`` is the (non-negative)
    global index and ``dataset_idx`` the position of the source dataset.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` over ``sequence``."""
        totals, running = [], 0
        for e in sequence:
            running += len(e)
            totals.append(running)
        return totals

    def __init__(self, datasets):
        super(CustomConcatDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        self.datasets = list(datasets)
        for d in self.datasets:
            assert not isinstance(d, torch.utils.data.IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        """Return ``(sample, idx, label, dataset_idx)`` for global index ``idx``.

        Raises:
            ValueError: if a negative ``idx`` exceeds the dataset length.
            IndexError: if a non-negative ``idx`` is >= ``len(self)``.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        elif idx >= len(self):
            raise IndexError(f"index {idx} is out of range for dataset of length {len(self)}")
        # bisect_right skips zero-length sub-datasets (repeated cumulative sizes).
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]

        sample, label = self.datasets[dataset_idx][sample_idx]
        return sample, idx, label, dataset_idx

    @property
    def cummulative_sizes(self):
        """Deprecated misspelled alias of ``cumulative_sizes``."""
        import warnings  # was used here without being imported anywhere
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
       
class ContrastiveLearningViewGenerator(object):
    """Produce ``multiplier`` independently augmented views of one input.

    Args:
        base_transform (callable): augmentation applied once per view.
        multiplier (int): number of views. If greater than 1 a list of views is
            returned; otherwise the single transformed input is returned
            unwrapped.
        clean_transform (callable, optional): transform used to produce the
            extra "clean" image when ``return_clean_image`` is True.
        return_clean_image (bool): if True, return ``[clean_transform(x), views]``.

    Raises:
        ValueError: if ``return_clean_image`` is True but no ``clean_transform``
            is given (previously this only failed later, at call time).
    """

    def __init__(self, base_transform, multiplier, clean_transform = None, return_clean_image = False):
        if return_clean_image and clean_transform is None:
            raise ValueError("clean_transform must be provided when return_clean_image=True")
        self.base_transform = base_transform
        self.multiplier = multiplier
        self.clean_transform = clean_transform
        self.return_clean_image = return_clean_image

    def __call__(self, x):
        # Augmented views are produced before the clean image (keeps RNG order).
        if self.multiplier > 1:
            out = [self.base_transform(x) for _ in range(self.multiplier)]
        else:
            out = self.base_transform(x)

        if self.return_clean_image:
            out = [self.clean_transform(x), out]

        return out

def get_simclr_transform(dataset_name, image_size, color_dist_s, scale_lower, use_color_dist = False, use_rotation = False):
    """Build the random augmentation pipeline that produces one SimCLR view.

    Args:
        dataset_name (str): 'imagenet' / 'imagenet100' select the ImageNet
            pipeline; any other name selects the small-image pipeline.
        image_size (int): output side length.
        color_dist_s (float): strength passed to get_color_distortion.
        scale_lower (float): lower bound of the RandomResizedCrop area scale.
        use_color_dist (bool): add color distortion (non-ImageNet branch only).
        use_rotation (bool): add a random rotation in [-180, 180] degrees
            (non-ImageNet branch only).
    Returns:
        transforms.Compose: pipeline ending in ToTensor followed by Clip.
    """
    if dataset_name in ('imagenet', 'imagenet100'):
        # ImageNet: color distortion and Gaussian blur are always applied;
        # use_color_dist / use_rotation are not consulted in this branch.
        return transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.RandomResizedCrop(
                image_size,
                scale=(scale_lower, 1.0),
                interpolation=PIL.Image.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(0.5),
            get_color_distortion(s=color_dist_s),
            transforms.ToTensor(),
            GaussianBlur(image_size // 10, 0.5),
            Clip(),
        ])

    pipeline = [
        transforms.RandomResizedCrop(
            image_size,
            scale=(scale_lower, 1.0),
            interpolation=PIL.Image.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(),
    ]
    if use_color_dist:
        pipeline.append(get_color_distortion(s=color_dist_s))
    if use_rotation:
        pipeline.append(transforms.RandomRotation(degrees=(-180, 180)))
    # Unlike the ImageNet branch, no Gaussian blur is applied here.
    pipeline.extend([transforms.ToTensor(), Clip()])
    return transforms.Compose(pipeline)

def get_clean_transform(dataset_name, image_size):
    """Return the non-random transform: resize, center-crop to ``image_size``, to tensor.

    ``dataset_name`` is accepted for signature symmetry but not used.
    """
    steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def get_linear_model_transforms(dataset_name, image_size, color_dist_s, scale_lower):
    """Return ``(train_transform, test_transform)`` for linear evaluation.

    ImageNet-style datasets use resize + center-crop + RandomResizedCrop for
    training and resize + center-crop for testing. Other datasets use a
    reflect-padded RandomCrop + flip for training and ToTensor only for
    testing. ``color_dist_s`` is not used, and ``scale_lower`` only affects
    the ImageNet branch.
    """
    if dataset_name not in ('imagenet', 'imagenet100'):
        train_transform = transforms.Compose([
            transforms.RandomCrop(image_size, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            Clip(),
        ])
        test_transform = transforms.Compose([transforms.ToTensor()])
        return train_transform, test_transform

    def center_view():
        # Fresh instances per pipeline, so train and test share no objects.
        return [transforms.Resize(image_size), transforms.CenterCrop(image_size)]

    train_transform = transforms.Compose(center_view() + [
        transforms.RandomResizedCrop(
            image_size,
            scale=(scale_lower, 1.0),
            interpolation=PIL.Image.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        Clip(),
    ])
    test_transform = transforms.Compose(center_view() + [
        transforms.ToTensor(),
        Clip(),
    ])
    return train_transform, test_transform
    
def _subsample_k_shot(trainset, n_classes, k_shot):
    """Reduce ``trainset`` in place to at most ``k_shot`` random samples per class."""
    # torchvision CIFAR/ImageFolder expose labels as `targets`; SVHN/STL10 use `labels`.
    label_attr = 'targets' if hasattr(trainset, 'targets') else 'labels'
    labels = getattr(trainset, label_attr)
    label_tensor = torch.as_tensor(np.asarray(labels), dtype=torch.long)

    picked = []
    for c in range(n_classes):
        ids = torch.where(label_tensor == c)[0]
        picked.append(ids[torch.randperm(len(ids))[:k_shot]])
    indices = torch.cat(picked).tolist()

    if hasattr(trainset, 'samples'):
        # ImageFolder-style datasets load images through `samples` / `imgs`.
        trainset.samples = [trainset.samples[i] for i in indices]
        if hasattr(trainset, 'imgs'):
            trainset.imgs = [trainset.imgs[i] for i in indices]
    else:
        # In-memory datasets (CIFAR, SVHN, STL10) keep images in a `data` array.
        trainset.data = trainset.data[indices]

    # Keep the original container type so __getitem__ still yields integer
    # labels (converting to torch.Tensor previously turned them into floats).
    if isinstance(labels, list):
        setattr(trainset, label_attr, [labels[i] for i in indices])
    else:
        setattr(trainset, label_attr, labels[indices])


def get_dataset(dataset_name, data_root, train_transform, test_transform, k_shot = None):
    """Build the ``(trainset, testset)`` pair for ``dataset_name``.

    Data lives under ``os.path.join(data_root, dataset_name)``; the torchvision
    built-in datasets are downloaded there if missing.

    Args:
        dataset_name (str): one of 'cifar10', 'svhn', 'cifar100', 'stl10',
            'dtd', 'tiny_imagenet', 'imagenet', 'imagenet100'.
        data_root (str): parent directory with one sub-directory per dataset.
        train_transform: transform applied to training samples.
        test_transform: transform applied to test samples.
        k_shot (int, optional): if given, keep at most ``k_shot`` randomly
            chosen training samples per class.
    Returns:
        tuple: ``(trainset, testset)``.
    Raises:
        NotImplementedError: for an unsupported ``dataset_name``.
    """
    data_root = os.path.join(data_root, dataset_name)

    if dataset_name == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_transform)
        testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_transform)

    elif dataset_name == 'svhn':
        trainset = torchvision.datasets.SVHN(root=data_root, split='train', download=True, transform=train_transform)
        testset = torchvision.datasets.SVHN(root=data_root, split='test', download=True, transform=test_transform)

    elif dataset_name == 'cifar100':
        trainset = torchvision.datasets.CIFAR100(root=data_root, train=True, download=True, transform=train_transform)
        testset = torchvision.datasets.CIFAR100(root=data_root, train=False, download=True, transform=test_transform)

    elif dataset_name == 'stl10':
        trainset = torchvision.datasets.STL10(root=data_root, split='train', download=True, transform=train_transform)
        testset = torchvision.datasets.STL10(root=data_root, split='test', download=True, transform=test_transform)

    elif dataset_name == 'dtd':
        # NOTE(review): train and test both read the same 'images' directory,
        # so the test split overlaps the training data — confirm intended.
        traindir = os.path.join(data_root, 'images')
        valdir = os.path.join(data_root, 'images')
        trainset = torchvision.datasets.ImageFolder(root=traindir, transform=train_transform)
        testset = torchvision.datasets.ImageFolder(root=valdir, transform=test_transform)

    elif dataset_name == 'tiny_imagenet':
        traindir = os.path.join(data_root, 'train')
        valdir = os.path.join(data_root, 'val')
        trainset = torchvision.datasets.ImageFolder(root=traindir, transform=train_transform)
        testset = torchvision.datasets.ImageFolder(root=valdir, transform=test_transform)

    elif dataset_name in ('imagenet', 'imagenet100'):
        traindir = os.path.join(data_root, 'train')
        valdir = os.path.join(data_root, 'val')
        trainset = torchvision.datasets.ImageFolder(traindir, transform=train_transform)
        testset = torchvision.datasets.ImageFolder(valdir, transform=test_transform)

    else:
        raise NotImplementedError

    if k_shot is not None:
        _subsample_k_shot(trainset, get_classes(dataset_name), k_shot)

    return trainset, testset

def get_sampler(dataset, dist = 'dp'):
    """Return a DistributedSampler over ``dataset`` when ``dist == 'ddp'``, else None.

    In the 'ddp' case the calling process's rank and per-epoch sample count
    are printed; this requires an initialized torch.distributed process group.
    """
    if dist != 'ddp':
        return None
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    print(f'Process {torch.distributed.get_rank()}: {len(sampler)} training samples per epoch')
    return sampler

def get_dataloader(dataset, sampler, batch_size, workers):
    """Wrap ``dataset`` in a DataLoader that drops the last incomplete batch.

    Shuffling is enabled only when no ``sampler`` is supplied (DataLoader does
    not allow both); memory pinning is disabled.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=sampler is None,
        num_workers=workers,
        pin_memory=False,
        sampler=sampler,
        drop_last=True,
    )
    return torch.utils.data.DataLoader(dataset, **loader_kwargs)