from PIL import Image
import numpy as np

import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

from utils.logger import Logger


class DAMNISTClient(Dataset):
    """
    View of a federated dataset restricted to one client, or to all samples.

    Each sample is passed through the transform pipeline assigned to the
    selected client, so every client sees its own augmentation domain.

    Args:
        fl_dataset: the federated dataset (e.g. ``DAMNIST``). It must expose
            ``data``, ``targets``, ``num_clients``, ``partition``,
            ``client_labels``, ``transforms`` and ``target_transform``.
        client_id: client to select, or ``None`` for the whole dataset.
    """

    def __init__(self, fl_dataset, client_id=None):

        self.fl_dataset = fl_dataset
        self.set_client(client_id)

    def set_client(self, index=None):
        """
        Select which samples this view exposes.

        Sets ``client_id``, ``data``, ``targets``, ``labels`` and ``length``.

        Args:
            index: client id in ``[0, fl_dataset.num_clients)``, or ``None`` to
                expose every sample of the underlying dataset.

        Raises:
            ValueError: if ``index`` is outside ``[0, num_clients)``.
        """
        fl = self.fl_dataset
        if index is None:
            self.client_id = None
            self.data = fl.data
            self.targets = fl.targets
            # Every label present in the dataset. Assign it in this branch too,
            # so ``labels`` always exists and never keeps the value of a
            # previously selected client.
            self.labels = torch.unique(torch.as_tensor(fl.targets)).int()
        else:
            if index < 0 or index >= fl.num_clients:
                raise ValueError('Number of clients is out of bounds.')
            self.client_id = index
            indices = fl.partition[index]
            self.data = fl.data[indices]
            # List comprehension so that list-typed targets also work.
            self.targets = [fl.targets[i] for i in indices]
            self.labels = torch.Tensor(fl.client_labels[index]).int()
        self.length = len(self.data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        fl = self.fl_dataset
        img, target = self.data[index], self.targets[index]

        # The raw image is handed to the pipeline as-is. The DAMNIST pipelines
        # start with ToPILImage, so no explicit PIL conversion is needed here.
        # The whole-dataset view uses the first pipeline; clients cycle through them.
        if self.client_id is None:
            transform = fl.transforms[0]
        else:
            transform = fl.transforms[self.client_id % len(fl.transforms)]
        img = transform(img)

        # Apply the target transform in both modes, so targets agree between
        # the whole-dataset view and the per-client views.
        if fl.target_transform is not None:
            target = fl.target_transform(target)

        return img, target

    def __len__(self):
        return self.length


class DAMNIST(MNIST):
    """
    Data Augmentation MNIST Dataset.

    Generate different domains per client based on different data augmentations.
    In training mode each client gets one transform pipeline (its domain). In
    test mode there is a single client that uses the plain pipeline.

    Args:
        root, train, transform, target_transform, download: forwarded to
            ``torchvision.datasets.MNIST``.
        transform_indices: use a subset of the training transforms by passing a
            slice object (ignored when ``train`` is False).
        classes_per_client: fraction of the ``NUM_CLASSES`` labels kept by each
            client, rounded to a whole number of classes.

    Attributes:
        client_transforms: per-client augmentations inserted into the PIL stage.
        special_transforms: full pipelines whose augmentation acts on the
            normalized tensor.
        transforms: the per-client pipelines actually used.
        num_clients: ``len(self.transforms)``.
        client_labels: per client, a sorted list of the (0-d tensor) labels it keeps.
        client_indices: per client, a boolean mask over the dataset that selects
            samples whose label is in ``client_labels``.
        partition: per client, a sorted tensor of sample indices. It is the
            client's disjoint random chunk of ``len(data) // num_clients``
            samples, intersected with ``client_indices``.
    """

    NUM_CLASSES = 10

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, transform_indices=slice(None), classes_per_client=1.0):

        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform,
                         download=download)
        self._build_transforms(train, transform_indices)
        self.num_clients = len(self.transforms)
        # RNG order is unchanged: per-client label draws first, then the data permutation.
        self._assign_client_labels(round(classes_per_client * self.NUM_CLASSES))
        self._build_partition()

    def _build_transforms(self, train, transform_indices):
        """Create ``client_transforms``, ``special_transforms`` and ``transforms``."""
        # Image normalization: maps ToTensor's [0, 1] range to [-1, 1].
        normalize = transforms.Normalize(mean=[0.5], std=[0.5])

        def full_transform(client_transform=None):
            # Base pipeline, with an optional PIL-level augmentation inserted
            # before resizing.
            transform_list = [
                transforms.ToPILImage(),
                transforms.Grayscale(3),
                client_transform,
                transforms.Resize(32),
                transforms.ToTensor(),
                normalize,
            ]
            return transforms.Compose([t for t in transform_list if t is not None])

        def tensor_transform(*tensor_ops):
            # Base pipeline followed by augmentations applied to the normalized tensor.
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Grayscale(3),
                transforms.Resize(32),
                transforms.ToTensor(),
                normalize,
                *tensor_ops,
            ])

        self.client_transforms = [
            transforms.CenterCrop(22),
            transforms.Pad(14),
            transforms.RandomInvert(p=1.0),
            transforms.GaussianBlur(5, sigma=(0.1, 2.0)),
            #
            transforms.RandomHorizontalFlip(p=1.0),
            transforms.RandomVerticalFlip(1.0),
            transforms.RandomRotation(40),
            transforms.ColorJitter(brightness=0.8),
            #
            transforms.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75)),
            transforms.RandomSolarize(threshold=192.0),
        ]

        self.special_transforms = [
            tensor_transform(transforms.RandomErasing(p=1.0, scale=(0.02, 0.1))),
            # Random per-channel shift and scale. The *result* is clipped back to
            # [-1, 1]. Clipping only the random factor would be a no-op, because
            # torch.rand already lies in [0, 1).
            tensor_transform(transforms.Lambda(
                lambda x: ((x + 0.5 * torch.rand(3, 1, 1)) * torch.rand(3, 1, 1)).clip(-1, 1))),
            # The same random shift and scale, renormalized before clipping.
            tensor_transform(
                transforms.Lambda(lambda x: (x + 0.5 * torch.rand(3, 1, 1)) * torch.rand(3, 1, 1)),
                normalize,
                transforms.Lambda(lambda x: x.clip(-1, 1)),
            ),
        ]

        if train:
            pipelines = [full_transform(t) for t in self.client_transforms]
            pipelines += self.special_transforms
            self.transforms = pipelines[transform_indices]
        else:
            self.transforms = [full_transform(None)]

    def _assign_client_labels(self, classes_per_client):
        """Draw ``classes_per_client`` random labels per client and build its label mask."""
        self.client_labels = [None] * self.num_clients
        self.client_indices = [None] * self.num_clients
        # label_indices[c] is a boolean mask over the dataset selecting class c.
        targets = torch.as_tensor(self.targets)
        label_indices = targets.view(1, -1) == torch.arange(self.NUM_CLASSES).view(-1, 1)
        for client_id in range(self.num_clients):
            # Shuffle the labels and keep the first `classes_per_client` classes.
            shuffled_labels = torch.randperm(self.NUM_CLASSES)
            self.client_labels[client_id] = sorted(shuffled_labels[:classes_per_client])
            # Union of the masks of the client's labels.
            label_mask = torch.zeros(len(self.targets), dtype=torch.bool)
            for label in self.client_labels[client_id]:
                label_mask |= label_indices[label]
            self.client_indices[client_id] = label_mask

    def _build_partition(self):
        """Give each client a disjoint random chunk, filtered by its label mask."""
        self.partition = [None] * self.num_clients
        rand_indices = torch.randperm(len(self.data))
        # Any remainder of len(data) / num_clients samples is left unassigned.
        samples_per_client = len(rand_indices) // self.num_clients
        for client_id in range(self.num_clients):
            start = client_id * samples_per_client
            chunk = torch.zeros(len(self.targets), dtype=torch.bool)
            # Single vectorized assignment. Building a length-N one-hot per
            # sample would be O(N^2) over the whole dataset.
            chunk[rand_indices[start:start + samples_per_client]] = True
            self.partition[client_id] = torch.where(chunk & self.client_indices[client_id])[0]
            Logger.get().debug(f"Partition[{client_id}]: Partition indices count = {chunk.sum()}")
            Logger.get().debug(f"Partition[{client_id}]: Sub-label indices count = {self.client_indices[client_id].sum()}")
            Logger.get().debug(f"Partition[{client_id}]: Joint indices count = {len(self.partition[client_id])}")
            Logger.get().debug(f"Partition[{client_id}]: Labels = {[i.item() for i in self.client_labels[client_id]]}")



