from pathlib import Path

import torch
from torch.utils.data import Dataset
from torchvision import datasets as torch_dataset, transforms
from torch.distributions import Normal, Categorical, MultivariateNormal
from PIL import Image
from sklearn.decomposition import PCA


class MNIST(torch_dataset.MNIST):
    """MNIST dataset with optional noise annealing and a kernel ``log_prob``.

    On top of the torchvision MNIST dataset this class provides:

    * an ``is_image`` flag and a ``shape`` attribute read from the first
      sample,
    * optional annealing: while annealing is active ``__getitem__`` returns
      standard-normal noise instead of a digit with probability
      ``1 - beta``, where ``beta`` is the class-wide annealing counter
      divided by ``annealing_iterations``,
    * an unnormalised kernel log-density over a stored subset of the data
      (``initialize_log_prob`` / ``log_prob``), optionally in PCA space.

    Args:
        *args: forwarded to ``torchvision.datasets.MNIST``.
        kde_width: kernel width used by ``log_prob``.
        PCA: if True, ``initialize_log_prob`` fits a PCA projection and
            ``log_prob`` is evaluated in the projected space.
        annealing: enable annealed sampling in ``__getitem__``.
        annealing_iterations: value of the annealing counter at which
            annealing switches itself off; required when ``annealing``
            is True.
        **kwargs: forwarded to ``torchvision.datasets.MNIST``.

    Raises:
        ValueError: if ``annealing`` is set without ``annealing_iterations``.
    """

    # Annealing counter stored on the class, so it is shared by instances.
    _annealing_iteration = 0

    def __init__(self, *args, kde_width=None, PCA=False, annealing=False,
                 annealing_iterations=None, **kwargs):
        # Validate before the parent constructor does any data loading.
        if annealing and annealing_iterations is None:
            raise ValueError("When annealing always specify number of annealing iterations")

        super().__init__(*args, **kwargs)
        self.is_image = True
        self._kde_width = kde_width
        self._PCA = PCA
        self.known_marginal = False
        self._anneal_Normal = Normal(0, 1)
        self.annealing = annealing
        self._max_annealing_iterations = annealing_iterations
        self._beta = 0

        # Shape of one sample; the label is ignored.
        # NOTE(review): assumes the configured transform yields an object
        # with a ``.shape`` attribute (e.g. a tensor) - confirm with callers.
        self.shape = super().__getitem__(0)[0].shape

    def __getitem__(self, index):
        """Return ``(data, target)`` for ``index``.

        While annealing is active the sample is replaced, with probability
        ``1 - beta``, by standard-normal noise of shape ``self.shape`` and a
        target of 0 (passed through ``target_transform`` when one is set).
        """
        if self.annealing:
            anneal_categorical = self._update_annealing_state()
            if anneal_categorical.sample().item() == 1:
                # Draw noise with the shape of a real sample so noise and
                # data items can be collated into the same batch.
                data = self._anneal_Normal.sample(self.shape)
                target = torch.tensor(0)
                if self.target_transform is not None:
                    target = self.target_transform(target)
                return data, target
        return super().__getitem__(index)

    def _update_annealing_state(self):
        """Refresh ``beta`` from the class counter and return the selector.

        Returns:
            A ``Categorical`` over ``[beta, 1 - beta]``; outcome 1 selects
            noise, outcome 0 selects real data. Once the counter reaches
            ``annealing_iterations`` annealing is switched off and
            ``beta`` is fixed to 1.
        """
        annealing_iteration = self._get_annealing_iteration()
        if annealing_iteration < self._max_annealing_iterations:
            self._beta = (1/self._max_annealing_iterations)*annealing_iteration
        else:
            # turn off annealing once we have finished annealing
            self.annealing = False
            self._beta = 1
        probs = torch.Tensor([self._beta, 1-self._beta])
        return Categorical(probs=probs)

    @classmethod
    def update_annealing(cls):
        """Advance the class-level annealing counter by one."""
        cls._annealing_iteration += 1

    @classmethod
    def _get_annealing_iteration(cls):
        """Return the current class-level annealing counter."""
        return cls._annealing_iteration

    def initialize_log_prob(self, X):
        """Store the kernel support points used by ``log_prob``.

        Only the first fifth of ``X`` (along dim 0) is kept. With PCA
        enabled the kept samples are flattened, a PCA with
        ``n_components=0.28`` is fitted to them, and the projection, mean
        and components are stored as tensors; otherwise the flattened
        samples are stored directly. Sets ``known_marginal`` to True.

        Args:
            X: tensor whose first dimension indexes samples.
        """
        X = X[:int(X.size(0)/5)]
        if self._PCA:
            X = X.flatten(start_dim=1).numpy()
            pca = PCA(n_components=0.28)
            self._transformed_X = torch.from_numpy(pca.fit_transform(X))
            self._data_mean = torch.from_numpy(pca.mean_)
            self._components = torch.from_numpy(pca.components_)
        else:
            self._transformed_X = X.flatten(start_dim=1)

        # NOTE: Scott's factor (see
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html)
        # is not computed here; the bandwidth is the user-supplied kde_width.
        self._M = self._transformed_X.shape[0]
        self.known_marginal = True

    def log_prob(self, X):
        """Return the unnormalised kernel log-density of each row of ``X``.

        Args:
            X: batch of samples; flattened to ``(batch, -1)`` and, with PCA
                enabled, centred and projected with the stored components.

        Returns:
            Tensor of shape ``(batch,)``: ``logsumexp`` over the stored
            support points of ``-0.5 * distance / kde_width**2``.

        Raises:
            AttributeError: if annealing is still active.
        """
        if self.annealing:
            raise AttributeError("Annealing is used - should never call log_prob!!")
        X = X.flatten(start_dim=1)
        batch_size = X.size(0)
        if self._PCA:
            X = X - self._data_mean
            X = torch.matmul(X, self._components.T)
        # Pair every query row with every support point: row i of X is
        # repeated M times, the support set is tiled batch_size times.
        X = X.repeat_interleave(self._M, dim=0)
        transformed_X_repeated = self._transformed_X.repeat(batch_size, 1)
        distances = torch.pairwise_distance(X, transformed_X_repeated)
        pair_distances = distances.view(batch_size, -1)
        # NOTE(review): pairwise_distance is the (unsquared) L2 norm; a
        # Gaussian kernel would use the squared distance here - confirm
        # which kernel is intended before changing the numerics.
        unnorm_logexp = -0.5*pair_distances/(self._kde_width**2)
        unnorm_logprob = torch.logsumexp(unnorm_logexp, dim=1)
        return unnorm_logprob

    def to(self, device, *args, **kwargs):
        """Move the tensors used by ``log_prob`` to ``device``.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            AttributeError: if annealing is still active.
        """
        if self.annealing:
            raise AttributeError("Annealing is used - should never need to place anything on a device!!")
        if self._PCA:
            self._data_mean = self._data_mean.to(device=device)
            self._components = self._components.to(device=device)
        # The support points are needed by log_prob with or without PCA;
        # they only exist once initialize_log_prob has been called.
        if hasattr(self, "_transformed_X"):
            self._transformed_X = self._transformed_X.to(device=device)
        return self


class LatentData(Dataset):
    """Base class for synthetic datasets that draw a fresh sample per item.

    Subclasses implement ``_sample`` (one data item), ``log_prob`` and
    ``_parameters_to_device``. The class also carries the annealing state
    (a class-level counter and ``beta``) used by subclasses.

    Args:
        size: value reported by ``__len__``.
        target_transform: optional callable applied to the target.
        known_marginal: stored as the ``known_marginal`` attribute.
        annealing: stored as the ``annealing`` attribute.
        annealing_iterations: value of the annealing counter at which
            annealing switches itself off; required when ``annealing``
            is True.

    Raises:
        ValueError: if ``annealing`` is set without ``annealing_iterations``.
    """
    # Annealing counter stored on the class, so it is shared by instances.
    _annealing_iteration = 0

    def __init__(self, size=100, target_transform=None, known_marginal=True,
                 annealing=False, annealing_iterations=None):
        super().__init__()
        if annealing and annealing_iterations is None:
            raise ValueError("When annealing always specify number of annealing iterations")
        self.size = size
        self.target_transform = target_transform
        self.known_marginal = known_marginal
        self._anneal_Normal = Normal(0, 1)
        self.annealing = annealing
        self._max_annealing_iterations = annealing_iterations
        self._beta = 0

    def __getitem__(self, index):
        """Return a freshly drawn ``(data, target)`` pair; ``index`` is unused."""
        # randint's upper bound is exclusive, so the target is always 0.
        target = torch.randint(0, 1, size=(1,), dtype=torch.long)[0]
        data = self._sample()
        if self.target_transform is not None:
            target = self.target_transform(target)

        return data, target

    def __len__(self):
        """Return the nominal dataset size given at construction."""
        return self.size

    def _sample(self):
        """Draw one data item; must be implemented by subclasses."""
        raise NotImplementedError()

    def log_prob(self, X):
        """Log-density of a batch ``X``; must be implemented by subclasses."""
        raise NotImplementedError()

    def _update_annealing_state(self):
        """Refresh ``beta`` from the class counter and return the selector.

        Returns:
            A ``Categorical`` over ``[beta, 1 - beta]``. Once the counter
            reaches ``annealing_iterations`` annealing is switched off and
            ``beta`` is fixed to 1.
        """
        annealing_iteration = self._get_annealing_iteration()
        if annealing_iteration < self._max_annealing_iterations:
            self._beta = (1/self._max_annealing_iterations)*annealing_iteration
        else:
            # turn off annealing once we have finished annealing
            self.annealing = False
            self._beta = 1
        probs = torch.Tensor([self._beta, 1-self._beta])
        return Categorical(probs=probs)

    @classmethod
    def update_annealing(cls):
        """Advance the class-level annealing counter by one."""
        cls._annealing_iteration += 1

    @classmethod
    def _get_annealing_iteration(cls):
        """Return the current class-level annealing counter."""
        return cls._annealing_iteration

    def to(self, device, *args, **kwargs):
        """Delegate to ``_parameters_to_device`` and return its result."""
        return self._parameters_to_device(device, *args, **kwargs)

    def _parameters_to_device(self, device, *args, **kwargs):
        """Place distribution parameters on ``device``; subclass hook."""
        raise NotImplementedError()


class LatentSimpleData(LatentData):
    """Synthetic dataset whose items are drawn from an element-wise Normal.

    Every element of a sample of shape ``shape`` is drawn independently
    from ``Normal(loc, std)``.

    Args:
        loc: mean added to a zero tensor of shape ``shape``.
        std: standard deviation multiplied into a ones tensor of that shape.
        shape: shape of one sample.
        *args: forwarded to ``LatentData``.
        **kwargs: forwarded to ``LatentData``.
    """

    def __init__(self, loc=0, std=1, shape=(3, 128, 128), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.shape = torch.Size(shape)
        _loc = torch.zeros(self.shape)+loc
        _std = _loc.new_ones(self.shape)*std
        self._Normal = Normal(_loc, _std)
        # Until ``to`` is called, evaluate log_prob with the CPU parameters.
        self._Normal_device = self._Normal

    def _sample(self):
        """Draw one sample of shape ``self.shape``."""
        return self._Normal.sample()

    def __getitem__(self, index):
        """Return ``(data, target)``.

        While annealing is active the sample is replaced, with probability
        ``1 - beta``, by standard-normal noise of shape ``self.shape`` and a
        target of 0 (passed through ``target_transform`` when one is set).
        """
        if self.annealing:
            anneal_categorical = self._update_annealing_state()
            if anneal_categorical.sample().item() == 1:
                # Draw noise with the shape of a regular sample so both
                # kinds of item can be collated into the same batch.
                data = self._anneal_Normal.sample(self.shape)
                target = torch.tensor(0)
                if self.target_transform is not None:
                    target = self.target_transform(target)
                return data, target
        return super().__getitem__(index)

    def log_prob(self, X):
        """Return the log-density of each sample in the batch ``X``.

        Args:
            X: tensor with ``X.size(0)`` samples, either already shaped
                ``(batch, *shape)`` or flattened to ``(batch, -1)``.

        Returns:
            Tensor of shape ``(batch,)`` with the element-wise
            log-probabilities summed per sample.
        """
        batch_size = X.size(0)
        # Evaluate in the sample shape so X broadcasts against loc/scale
        # for any ``shape``, then sum over all non-batch elements.
        X = X.reshape(batch_size, *self.shape)
        log_p = self._Normal_device.log_prob(X)
        return log_p.reshape(batch_size, self.shape.numel()).sum(dim=1)

    def _parameters_to_device(self, device, *args, **kwargs):
        """Build the distribution used by ``log_prob`` on ``device``.

        Returns:
            ``self``.
        """
        loc = self._Normal.loc.to(device=device, *args, **kwargs)
        scale = self._Normal.scale.to(device=device, *args, **kwargs)
        self._Normal_device = Normal(loc, scale)
        return self


class LatentMixtureData(LatentData):
    """Synthetic dataset whose items are drawn from a mixture of Normals.

    For each item one component is picked according to ``probs``; every
    element of the sample is then drawn independently from that
    component's ``Normal(loc, std)``.

    Args:
        locs: one mean per mixture component.
        stds: one standard deviation per mixture component.
        probs: mixture probabilities, one per component.
        shape: shape of one sample.
        *args: forwarded to ``LatentData``.
        **kwargs: forwarded to ``LatentData``.

    Raises:
        ValueError: if ``locs`` and ``probs`` differ in length, or if
            there are fewer ``stds`` than ``locs``.
    """

    def __init__(self, locs=(-1, 1), stds=(1, 1), probs=(0.5, 0.5),
                 shape=(3, 128, 128), *args, **kwargs):
        super().__init__(*args, **kwargs)
        if len(locs) != len(probs):
            msg = "Number of means must equal number of mixture probabilities"
            raise ValueError(msg)
        if len(stds) < len(locs):
            # zip() below would otherwise silently drop components.
            msg = "Need at least one standard deviation per mean"
            raise ValueError(msg)
        self.shape = torch.Size(shape)
        # The scale is ones * std; a zeros tensor here would make every
        # component's standard deviation 0.
        self._normals = [Normal(torch.zeros(self.shape)+loc,
                                torch.ones(self.shape)*std)
                         for loc, std in zip(locs, stds)]
        self._categorical = Categorical(probs=torch.Tensor(probs))
        # Until ``to`` is called, evaluate log_prob with the CPU parameters.
        self._normals_device = self._normals
        self._categorical_device = self._categorical

    def __getitem__(self, index):
        """Return ``(data, target)``.

        While annealing is active the sample is replaced, with probability
        ``1 - beta``, by standard-normal noise of shape ``self.shape`` and a
        target of 0 (passed through ``target_transform`` when one is set).
        """
        if self.annealing:
            anneal_categorical = self._update_annealing_state()
            if anneal_categorical.sample().item() == 1:
                # Draw noise with the shape of a regular sample so both
                # kinds of item can be collated into the same batch.
                data = self._anneal_Normal.sample(self.shape)
                target = torch.tensor(0)
                if self.target_transform is not None:
                    target = self.target_transform(target)
                return data, target
        return super().__getitem__(index)

    def _sample(self):
        """Pick a component at random and draw one sample from it."""
        component = self._normals[self._categorical.sample().item()]
        return component.sample()

    def log_prob(self, X):
        """Return the mixture log-density of each sample in the batch ``X``.

        Args:
            X: tensor with ``X.size(0)`` samples, either already shaped
                ``(batch, *shape)`` or flattened to ``(batch, -1)``.

        Returns:
            Tensor of shape ``(batch,)``: ``logsumexp`` over components of
            the component log-weight plus the per-sample log-density.
        """
        batch_size = X.size(0)
        n_elements = self.shape.numel()

        # Evaluate in the sample shape so X broadcasts against loc/scale
        # for any ``shape``, then sum over all non-batch elements.
        X = X.reshape(batch_size, *self.shape)
        lognormals = torch.stack(
            [N.log_prob(X).reshape(batch_size, n_elements).sum(-1)
             for N in self._normals_device], dim=1)
        return torch.logsumexp(self._categorical_device.logits + lognormals,
                               dim=1)

    def _parameters_to_device(self, device, *args, **kwargs):
        """Build the distributions used by ``log_prob`` on ``device``.

        Returns:
            ``self``.
        """
        parameters = [(N.loc.to(device=device, *args, **kwargs),
                       N.scale.to(device=device, *args, **kwargs))
                      for N in self._normals]
        self._normals_device = [Normal(loc, scale)
                                for loc, scale in parameters]
        logits = self._categorical.logits.to(device=device, *args, **kwargs)
        self._categorical_device = Categorical(logits=logits)
        return self


# horse/zebra dataset
class HorseZebra(Dataset):
    """Dataset used in cycleGAN
    For more see https://taesung.me/
    And for more data: https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/

    All ``*.jpg`` files directly inside ``data_dir`` are decoded into
    tensors at construction time; files that do not decode to three
    channels are skipped and counted in ``n_cleanups``.
    """

    # Maps the trailing letter of the data directory name to a class index.
    _STR_TO_LABEL = {'A': 0, 'B': 1}

    def __init__(self, data_dir, transform=None):
        """ Horse to Zebra datasets, this dataset (including training and validation data)
        is small enough that we can load all of it into memory (~120 M)
        Args:
        data_dir (string): Path to location of data; the last character of
            the directory name selects the label ('A' or 'B')
        transform (callable, optional): Optional transform applied to each sample
        Raises:
        IndexError: if no three-channel image is found in data_dir
        """
        super().__init__()

        self.data_dir = Path(data_dir)
        self.transform = transform
        # Materialise the glob so the attribute stays usable after the loop
        # below, and sort it so the sample order does not depend on the OS.
        self.file_names = sorted(self.data_dir.glob('*.jpg'))
        self.img2tensor = transforms.ToTensor()
        self.images = []
        self.n_cleanups = 0
        # a little bit of data cleanup
        for f in self.file_names:
            # The context manager closes the file handle once the pixel
            # data has been converted.
            with Image.open(f) as pil_img:
                img = self.img2tensor(pil_img)
            if img.size(0) == 3:
                self.images.append(img)
            else:
                # image does not have three channels
                self.n_cleanups += 1

        if not self.images:
            # Same exception type as indexing the empty list, with a
            # message that names the directory.
            raise IndexError(
                "No three-channel '*.jpg' images found in {}".format(self.data_dir))

        # either A (horse) or B (zebra); taken from the directory name so a
        # trailing path separator in data_dir does not end up as the label
        self.label = self.data_dir.name[-1:]
        self._datashape = self.images[0].shape

    def _str2lab_fn(self, x):
        """Return the class index for label string ``x`` as a 1-element long tensor."""
        # A method rather than a lambda stored on the instance, so the
        # dataset object can be pickled.
        return torch.Tensor([self._STR_TO_LABEL[x]]).long()

    def __len__(self):
        """Return the number of images kept in memory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` with ``transform`` applied when set."""
        img = self.images[idx]

        if self.transform:
            sample = self.transform(img)
        else:
            sample = img
        return sample, self._str2lab_fn(self.label)

    @property
    def data_shape(self):
        """Shape of the first loaded image tensor."""
        return self._datashape


class LSUN(torch_dataset.LSUN):
    """LSUN dataset wrapper that adds ``is_image`` and ``shape`` attributes."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and record the sample shape."""
        super().__init__(*args, **kwargs)

        # Read item 0 through the parent implementation and keep only the
        # data part of the returned pair; the label is not needed.
        first_sample = super().__getitem__(0)[0]
        self.shape = first_sample.shape
        self.is_image = True
