import pdb
import os
import random
import numpy as np
import torch
import torchvision
from torchvision.datasets import MNIST
import torchvision.transforms.functional as F
import torch.nn.functional as Func
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt
from torchvision.transforms import InterpolationMode

# Directory that contains this module.
THIS_PATH = os.path.dirname(__file__)
# Five directory levels above this module, treated as the project root.
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *(['..'] * 5)))
# Base directory for the raw MNIST download and the generated datasets.
IMAGE_PATH = os.path.join(ROOT_PATH, 'datasets/disentanglement/mnist')


def set_color_mnist_longtailed(dataset: MNIST, samples_per_class: int, train: bool) -> TensorDataset:
    """
    Generate a colour-coded, long-tailed dataset from the MNIST dataset.

    For every digit ``i`` and every colour channel ``j`` (0=R, 1=G, 2=B), a
    consecutive, non-overlapping slice of that digit's images is written into
    channel ``j`` of an otherwise black RGB image. Each image is then
    bilinearly resized to 64x64 and scaled by 1/255.

    Args:
        dataset (MNIST): The MNIST dataset. Only ``dataset.data`` and
            ``dataset.targets`` are read, so any ``transform`` on it is not
            applied.
        samples_per_class (int): Upper bound on the number of images taken per
            chunk when ``train`` is True; counts are ``ceil(power(0.3) *
            samples_per_class)``. Ignored when ``train`` is False, where 250
            images per chunk are requested.
        train (bool): Whether to generate the training or test set.
    Returns:
        TensorDataset: ``(images, targets)`` with float32 images of shape
        ``(N, 3, 64, 64)`` and int64 targets.
    """
    img, labels = dataset.data.numpy(), dataset.targets.numpy()

    if train:
        # Power-law sample counts, one per (digit, colour) slot (30 slots).
        samples = np.random.power(0.3, size=30) * samples_per_class
        samples = np.ceil(samples).astype(int)
    else:
        # Uniform counts for the test split: 250 samples per slot.
        samples_per_class = 250
        samples = (np.ones(30) * samples_per_class).astype(int)

    imgs_rgb = []
    labels_rgb = []
    for i in range(10):
        # Select this digit's images once rather than once per colour.
        digit_imgs = img[labels == i]
        samples_added = 0
        for j in range(3):
            # NOTE(review): ``(j * 10 + i) % 10`` always equals ``i``. As a
            # result, the target is the digit alone and only samples[0:10] are
            # used, even though 30 slots are drawn above. Confirm whether a
            # (digit, colour) label such as ``i * 3 + j`` was intended.
            class_idx = (j * 10 + i) % 10
            # Consecutive slices, so each colour gets different images.
            data_tmp = digit_imgs[samples_added: samples_added + samples[class_idx]]
            samples_added += samples[class_idx]
            n_taken = data_tmp.shape[0]
            if n_taken == 0:
                # This digit has no images left for this colour.
                continue
            data_tmp = torch.from_numpy(data_tmp)
            data = torch.zeros(data_tmp.shape + (3,))
            data[:, :, :, j] = data_tmp  # digit goes into colour channel j
            data = data.permute(0, 3, 1, 2)
            data = Func.interpolate(data, 64, mode='bilinear')
            imgs_rgb.append(data.numpy().transpose(0, 2, 3, 1))
            # One label per image actually taken. The slice can be shorter
            # than requested, and labels must stay aligned with images.
            labels_rgb.extend([class_idx] * n_taken)

    # Concatenate all data and scale pixel values from [0, 255] to [0, 1].
    imgs_rgb = np.concatenate(imgs_rgb) / 255
    labels_rgb = np.asarray(labels_rgb)

    # Convert to tensors (NHWC -> NCHW for the images).
    imgs_rgb = torch.from_numpy(imgs_rgb).permute(0, 3, 1, 2).float()
    targets = torch.from_numpy(labels_rgb).long()

    return TensorDataset(imgs_rgb, targets)


def set_color_rot_mnist_longtailed(dataset: MNIST,
                                   samples_per_class: int,
                                   rotations: int,
                                   train: bool) -> TensorDataset:
    """
    Generate a colour-coded, randomly rotated, long-tailed dataset from MNIST.

    For every digit ``i`` and every colour channel ``j`` (0=R, 1=G, 2=B), a
    consecutive, non-overlapping slice of that digit's images is written into
    channel ``j`` of an otherwise black RGB image. The slice is bilinearly
    resized to 64x64 and rotated as one batch by a single random angle drawn
    (via ``random.randint``) from {0, 12, 24, ..., 360} degrees. Pixel values
    are scaled by 1/255.

    Args:
        dataset (MNIST): The MNIST dataset. Only ``dataset.data`` and
            ``dataset.targets`` are read, so any ``transform`` on it is not
            applied.
        samples_per_class (int): Upper bound on the number of images taken per
            chunk when ``train`` is True. Ignored when ``train`` is False.
        rotations (int): Must be >= 1. It sets the number of drawn sample slots
            (``30 * rotations``) and divides the per-slot test count
            (``250 // rotations``). NOTE(review): it does not control the
            rotation angles and is not encoded in the targets. Confirm this is
            intended.
        train (bool): Whether to generate the training or test set.
    Returns:
        TensorDataset: ``(images, targets)`` with float32 images of shape
        ``(N, 3, 64, 64)`` and int64 targets.
    Raises:
        ValueError: If ``rotations`` is smaller than 1.
    """
    if rotations < 1:
        raise ValueError(f"rotations must be a positive integer, got {rotations}")

    img, labels = dataset.data.numpy(), dataset.targets.numpy()

    if train:
        # Power-law sample counts for 30 * rotations slots.
        samples = np.random.power(0.3, size=30 * rotations) * samples_per_class
        samples = np.ceil(samples).astype(int)
    else:
        # Uniform counts for the test split: 250 // rotations per slot.
        samples_per_class = 250 // rotations
        samples = (np.ones(30 * rotations) * samples_per_class).astype(int)

    imgs_rgb = []
    labels_rgb = []
    for i in range(10):
        # Select this digit's images once rather than once per colour.
        digit_imgs = img[labels == i]
        samples_added = 0
        for j in range(3):
            # NOTE(review): ``(j * 10 + i) % 10`` always equals ``i``. The
            # target is the digit alone and only samples[0:10] are used.
            class_idx = (j * 10 + i) % 10
            # Consecutive slices, so each colour gets different images.
            data_tmp = digit_imgs[samples_added: samples_added + samples[class_idx]]
            samples_added += samples[class_idx]
            n_taken = data_tmp.shape[0]
            if n_taken == 0:
                # This digit has no images left for this colour.
                continue
            data_tmp = torch.from_numpy(data_tmp)
            data = torch.zeros(data_tmp.shape + (3,))
            data[:, :, :, j] = data_tmp  # digit goes into colour channel j
            data = data.permute(0, 3, 1, 2)
            data = Func.interpolate(data, 64, mode='bilinear')
            # One random angle shared by the whole chunk.
            data = F.rotate(data, angle=12*random.randint(0, 30))
            imgs_rgb.append(data.numpy().transpose(0, 2, 3, 1))
            # One label per image actually taken, keeping labels aligned.
            labels_rgb.extend([class_idx] * n_taken)

    # Concatenate all data and scale pixel values from [0, 255] to [0, 1].
    imgs_rgb = np.concatenate(imgs_rgb) / 255
    labels_rgb = np.asarray(labels_rgb)

    # Convert to tensors (NHWC -> NCHW for the images).
    imgs_rgb = torch.from_numpy(imgs_rgb).permute(0, 3, 1, 2).float()
    targets = torch.from_numpy(labels_rgb).long()

    return TensorDataset(imgs_rgb, targets)



def download_mnist_longtailed(samples_per_class: int = 500, option: str = 'both'):
    """
    Download MNIST, build a long-tailed coloured variant and save it to disk.

    Writes ``train.pt``, ``test.pt``, ``longtailed_imagegrid.png`` and
    ``longtailed_histogram.png`` into ``IMAGE_PATH + '/color_rot_mnist_longtailed'``
    (option ``'both'``) or ``IMAGE_PATH + '/color_mnist_longtailed'``
    (option ``'color'``). Raw MNIST is downloaded to ``IMAGE_PATH + '/MNIST'``.

    Args:
        samples_per_class: Forwarded to the generator function. For the
            training split it is the upper bound of the per-slot sample count.
        option: ``'both'`` (colour + rotation, with ``rotations=3``) or
            ``'color'``.
    Raises:
        ValueError: If ``option`` is neither ``'both'`` nor ``'color'``.
    """
    # Validate the option once; everything below relies on it.
    if option == 'both':
        filename = '/color_rot_mnist_longtailed'
    elif option == 'color':
        filename = '/color_mnist_longtailed'
    else:
        # There is no rotation-only generator in this module.
        raise ValueError("Option must be 'color' or 'both'")

    out_dir = IMAGE_PATH + filename
    os.makedirs(out_dir, exist_ok=True)

    # NOTE: the generator functions read ``dataset.data`` directly, so this
    # transform is never applied to the images stored in the output files.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((64, 64), interpolation=InterpolationMode.BICUBIC, antialias=True),
        torchvision.transforms.ToTensor(),
    ])
    trainset = MNIST(root=IMAGE_PATH + "/MNIST", train=True, download=True, transform=transform)
    testset = MNIST(root=IMAGE_PATH + "/MNIST", train=False, download=True, transform=transform)

    if option == 'both':
        trainset = set_color_rot_mnist_longtailed(trainset, samples_per_class, rotations=3, train=True)
        testset = set_color_rot_mnist_longtailed(testset, samples_per_class, rotations=3, train=False)
        names = [str(360/3*k) + j + str(i) for i in range(10) for j in ['R', 'G', 'B'] for k in range(3)]
    else:  # option == 'color' (validated above)
        trainset = set_color_mnist_longtailed(trainset, samples_per_class, True)
        testset = set_color_mnist_longtailed(testset, samples_per_class, False)
        names = [j + str(i) for i in range(10) for j in ['R', 'G', 'B']]

    # Save an image grid of the training set.
    grid_img = torchvision.utils.make_grid(trainset.tensors[0], nrow=32)
    plt.rcParams.update({"font.size": 7})
    fig = plt.figure(figsize=(8, 16))
    plt.imshow(grid_img.permute(1, 2, 0).numpy())
    plt.xticks([])
    plt.yticks([])
    plt.savefig(out_dir + "/longtailed_imagegrid.png", bbox_inches="tight")
    plt.close(fig)  # release the figure instead of only clearing it

    # Save the datasets.
    torch.save(trainset, out_dir + "/train.pt")
    torch.save(testset, out_dir + "/test.pt")
    print(f'Generated MNIST-longtailed dataset at {out_dir}')

    # Histogram of training samples per target value, largest first.
    classes, counts = torch.unique(trainset.tensors[1], return_counts=True)
    order = torch.argsort(counts, descending=True)
    classes, counts = classes[order], counts[order]
    # Look names up by label value (not by position in the unique() output),
    # falling back to the raw value when no name exists.
    # NOTE(review): the generators currently emit digit targets 0-9, so these
    # names (e.g. 'R0', 'G0', ...) do not describe those targets. Confirm.
    tick_labels = [names[c] if 0 <= c < len(names) else str(c) for c in classes.tolist()]

    fig = plt.figure(figsize=(6, 3))
    plt.bar(tick_labels, counts.numpy())
    plt.savefig(out_dir + "/longtailed_histogram.png", bbox_inches="tight")
    plt.close(fig)


class MNISTLongTailDataset(TensorDataset):
    """Wrap an existing ``TensorDataset`` and return ``(x, y)`` pairs.

    The wrapped dataset's tensor tuple is kept on ``self.data``. Any extra
    positional or keyword arguments are accepted and ignored.
    """

    def __init__(self, dataset, *args, **kwargs):
        source_tensors = dataset.tensors
        super().__init__(*source_tensors)
        # Same tuple object as the wrapped dataset's ``tensors``.
        self.data = source_tensors

    def __getitem__(self, idx):
        # Index every stored tensor, then require exactly two (input, target).
        picked = [t[idx] for t in self.data]
        x_item, y_item = picked
        return x_item, y_item

    def __len__(self):
        head = self.data[0]
        return len(head)