import os
import random

import numpy as np
import torch
from PIL import ImageFilter, Image
from torch.utils.data import Dataset
from torchvision import transforms, datasets


class TwoCropsTransform:
    """Produce two views of one input by applying ``base_transform`` twice.

    With a randomized ``base_transform`` the two views generally differ; the
    first is used as the query and the second as the key.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        """Return ``[query, key]``: two successive ``base_transform(x)`` results."""
        return [self.base_transform(x) for _view in range(2)]


class GaussianBlur:
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Args:
        sigma: ``(low, high)`` bounds; on each call a blur radius is drawn
            uniformly from this range. Defaults to ``[0.1, 2.0]``.
    """

    def __init__(self, sigma=None):
        # A fresh default list per instance (no shared mutable default).
        self.sigma = [.1, 2.] if sigma is None else sigma

    def __call__(self, x):
        """Blur the PIL image ``x`` with a randomly drawn radius and return it."""
        low, high = self.sigma[0], self.sigma[1]
        radius = random.uniform(low, high)
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


class CIFAR10Dataset(Dataset):
    """CIFAR-10 wrapper yielding two augmented views or ``(image, label)`` pairs.

    Args:
        path: root directory handed to ``torchvision.datasets.CIFAR10``
            (``download=True``, so the archive is fetched there if missing).
        train: use the training split when True, the test split otherwise.
        augmented: build the two-view contrastive augmentation pipeline when
            True; otherwise a deterministic resize/center-crop/normalize one.

    ``__getitem__`` returns ``(view1, view2)`` when both ``train`` and
    ``augmented`` are set, and ``(transformed_image, target)`` otherwise.
    """

    def __init__(self, path, train=True, augmented=True):
        self.train = train
        self.augmented = augmented
        self.data = datasets.CIFAR10(root=path, train=train, download=True, transform=None)

        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        if augmented:
            single_view = transforms.Compose([
                transforms.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                GaussianBlur(sigma=[0.1, 2.0]),
                transforms.ToTensor(),
                normalize,
            ])
            self.transform = TwoCropsTransform(single_view)
        else:
            self.transform = transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor(),
                normalize,
            ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, target = self.data[idx]
        if not (self.train and self.augmented):
            # NOTE(review): with train=False and augmented=True the transform is
            # still TwoCropsTransform, so the first element here is a two-view
            # list rather than a single tensor. Confirm this is intended.
            return self.transform(img), target
        view_q, view_k = self.transform(img)
        return view_q, view_k


class FashionMNISTDataset(Dataset):
    """FashionMNIST wrapper yielding two augmented views or ``(image, label)`` pairs.

    Args:
        path: root directory handed to ``torchvision.datasets.FashionMNIST``
            (``download=True``, so the archive is fetched there if missing).
        train: use the training split when True, the test split otherwise.
        augmented: build the two-view augmentation pipeline when True;
            otherwise a deterministic resize/center-crop/normalize one.

    ``__getitem__`` returns ``(view1, view2)`` when both ``train`` and
    ``augmented`` are set, and ``(transformed_image, target)`` otherwise.
    """

    def __init__(self, path, train=True, augmented=True):
        self.train, self.augmented = train, augmented
        self.data = datasets.FashionMNIST(root=path, train=train, download=True, transform=None)

        if not augmented:
            self.transform = transforms.Compose([
                transforms.Resize(28),
                transforms.CenterCrop(28),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ])
            return

        # NOTE(review): saturation jitter is presumably a no-op on these
        # single-channel images. It is kept so the transform set stays unchanged.
        self.transform = TwoCropsTransform(transforms.Compose([
            transforms.RandomResizedCrop(size=28, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),  # one grayscale channel
        ]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, target = self.data[idx]
        two_views = self.train and self.augmented
        if two_views:
            first, second = self.transform(img)
            return first, second
        # NOTE(review): augmented=True with train=False still applies the
        # two-view transform here, so the image slot holds a list. Confirm intent.
        return self.transform(img), target


class DsrpitesDataset(Dataset):
    """Dataset over samples concatenated from one or more ``.npy`` files.

    All arrays are loaded eagerly and joined along axis 0. Each item is passed
    through ``torchvision.transforms.ToTensor`` before being returned.

    Args:
        npy_files: iterable of paths to ``.npy`` arrays with matching trailing
            dimensions.
    """

    def __init__(self, npy_files):
        arrays = [np.load(npy_path) for npy_path in npy_files]
        self.data = np.concatenate(arrays, axis=0)
        # NOTE(review): ToTensor rescales uint8 arrays by 1/255. If these files
        # hold 0/1 uint8 masks, the resulting values are tiny. Verify the dtype.
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # DataLoader samplers may pass tensor indices; numpy needs ints/lists.
        key = idx.tolist() if torch.is_tensor(idx) else idx
        sample = self.data[key]
        return self.transform(sample) if self.transform else sample


# Lookup table of the values each dSprites latent factor can take, one numpy
# array per factor.
# NOTE(review): not referenced elsewhere in the visible code. The key order
# (color, shape, scale, orientation, posX, posY) presumably mirrors the column
# order of `latents_classes`. DSpritesDataset.find_angle_sets treats column 3
# as orientation, which is consistent with that. Presumably a class index c in
# column j maps to table[key_j][c]. Confirm against the dataset.
dsprites_latent_vl = {'color': np.array([1.]),
                      'shape': np.array([1., 2., 3.]),  # square, ellipse, heart
                      'scale': np.array([0.5, 0.6, 0.7, 0.8, 0.9, 1.]),  # 6 scales
                      # 40 evenly spaced angles, step 2*pi/39. The first (0) and
                      # last (2*pi) entries describe the same rotation.
                      'orientation': np.array([0., 0.16110732, 0.32221463, 0.48332195,
                                               0.64442926, 0.80553658, 0.96664389, 1.12775121,
                                               1.28885852, 1.44996584, 1.61107316, 1.77218047,
                                               1.93328779, 2.0943951, 2.25550242, 2.41660973,
                                               2.57771705, 2.73882436, 2.89993168, 3.061039,
                                               3.22214631, 3.38325363, 3.54436094, 3.70546826,
                                               3.86657557, 4.02768289, 4.1887902, 4.34989752,
                                               4.51100484, 4.67211215, 4.83321947, 4.99432678,
                                               5.1554341, 5.31654141, 5.47764873, 5.63875604,
                                               5.79986336, 5.96097068, 6.12207799, 6.28318531]),  # [0, 2 pi]
                      # 32 evenly spaced positions in [0, 1], step 1/31.
                      'posX': np.array([0., 0.03225806, 0.06451613, 0.09677419, 0.12903226,
                                        0.16129032, 0.19354839, 0.22580645, 0.25806452,
                                        0.29032258, 0.32258065, 0.35483871, 0.38709677,
                                        0.41935484, 0.4516129, 0.48387097, 0.51612903,
                                        0.5483871, 0.58064516, 0.61290323, 0.64516129,
                                        0.67741935, 0.70967742, 0.74193548, 0.77419355,
                                        0.80645161, 0.83870968, 0.87096774, 0.90322581,
                                        0.93548387, 0.96774194, 1.]),
                      # Same 32-value grid as posX.
                      'posY': np.array([0., 0.03225806, 0.06451613, 0.09677419, 0.12903226,
                                        0.16129032, 0.19354839, 0.22580645, 0.25806452,
                                        0.29032258, 0.32258065, 0.35483871, 0.38709677,
                                        0.41935484, 0.4516129, 0.48387097, 0.51612903,
                                        0.5483871, 0.58064516, 0.61290323, 0.64516129,
                                        0.67741935, 0.70967742, 0.74193548, 0.77419355,
                                        0.80645161, 0.83870968, 0.87096774, 0.90322581,
                                        0.93548387, 0.96774194, 1.]),
                      }


class DSpritesDataset(Dataset):
    """dSprites images with latent labels, plus orientation-matched triplets.

    The ``.npz`` archive is read fully into memory on construction. Items are
    ``(image, latent_values, latent_classes)``, where ``image`` is a mode-``'L'``
    PIL image, optionally passed through ``transform``.

    Construction also precomputes ``set_indices`` via ``find_angle_sets()``
    (orientation classes 0, 10 and 20 by default).

    Args:
        transform: optional callable applied to the PIL image in ``__getitem__``.
        file_path: path to the dSprites ``.npz`` archive. It must contain the
            keys ``'imgs'``, ``'latents_values'`` and ``'latents_classes'``.
    """

    def __init__(self, transform=None, file_path='../data/dsprites/dsprites_dataset.npz'):
        # np.load on an .npz returns an NpzFile that holds the archive open.
        # The context manager closes it once the arrays are in memory.
        with np.load(file_path) as data:
            # NOTE(review): if 'imgs' is stored as uint8 (presumably), float32
            # quadruples its memory footprint. It is kept because save_sets()
            # writes these arrays out as-is.
            self.imgs = data['imgs'].astype(np.float32)
            self.latents_values = data['latents_values']
            self.latents_classes = data['latents_classes']
        self.transform = transform
        self.set_indices = self.find_angle_sets()
        # Nominal size, equal to the product of the factor-grid sizes in
        # dsprites_latent_vl (1*3*6*40*32*32). __len__ uses the real array length.
        self.size = 737280

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, latent_values, latent_classes)`` for sample ``idx``."""
        # Pixels are multiplied by 255 before the uint8 cast, so they are
        # assumed to lie in [0, 1].
        image = Image.fromarray((self.imgs[idx] * 255).astype(np.uint8), mode='L')
        if self.transform:
            image = self.transform(image)
        return image, self.latents_values[idx], self.latents_classes[idx]

    def find_angle_sets(self, angle1=0, angle2=10, angle3=20):
        """Find sample triplets that differ only in orientation.

        Args:
            angle1, angle2, angle3: orientation *class* indices, i.e. values of
                column 3 of ``latents_classes``, not angles in radians.

        Returns:
            A list of ``(i1, i2, i3)`` sample-index tuples, one for each
            combination of the remaining latent classes that occurs at all
            three orientations. The list is ordered by first appearance at
            ``angle1``. If a combination repeats, its last occurrence is used.
        """
        # orientation class -> {other latent classes (tuple) -> sample index}
        angle_indices = {angle: {} for angle in (angle1, angle2, angle3)}
        # tolist() gives plain Python ints. This avoids a numpy row view and a
        # temporary concatenated array on every iteration of this loop.
        for i, label in enumerate(np.asarray(self.latents_classes).tolist()):
            bucket = angle_indices.get(label[3])
            if bucket is not None:
                bucket[tuple(label[:3] + label[4:])] = i
        first, second, third = (angle_indices[a] for a in (angle1, angle2, angle3))
        return [(idx, second[key], third[key])
                for key, idx in first.items()
                if key in second and key in third]

    def save_sets(self, out_dir='.'):
        """Write the precomputed triplets as ``.npy`` files into ``out_dir``.

        Files written:
            pairs_indices.npy -- (k, 3) sample indices of each triplet
            imgs_a.npy        -- images of the first element of each triplet
            imgs_test.npy     -- images of the second element
            imgs_b.npy        -- images of the third element
            labels.npy        -- ``latents_classes`` rows of the first element
        The number of triplets is printed afterwards.
        """
        # reshape keeps a (0, 3) shape when there are no triplets, so the
        # column indexing below stays valid.
        triplets = np.array(self.set_indices, dtype=int).reshape(-1, 3)
        outputs = {
            'pairs_indices.npy': triplets,
            'imgs_a.npy': self.imgs[triplets[:, 0]],
            'imgs_test.npy': self.imgs[triplets[:, 1]],
            'imgs_b.npy': self.imgs[triplets[:, 2]],
            'labels.npy': self.latents_classes[triplets[:, 0]],
        }
        for name, array in outputs.items():
            np.save(os.path.join(out_dir, name), array)
        print(len(self.set_indices))


if __name__ == "__main__":
    # Build the orientation triplets (default classes 0/10/20) and dump them
    # as .npy files.
    DSpritesDataset(transform=None).save_sets()
