
import os
import torchvision
from torch.utils.data import TensorDataset
import torchvision.transforms.functional as F

# Directory that contains this module.
THIS_PATH = os.path.dirname(__file__)
# Absolute path four directory levels above this module; presumably the
# repository root (assumes this file sits four levels deep — TODO confirm).
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, '..', '..', '..', '..'))
# Base directory for the raw MNIST download and for every generated
# biased-MNIST variant written by this module.
IMAGE_PATH = os.path.join(ROOT_PATH, 'datasets/disentanglement/mnist')


import numpy as np
from torchvision.transforms import InterpolationMode
import matplotlib.pyplot as plt


def _discrete_exponential_probs(n: int, mu: int, sigma: float) -> np.ndarray:
    sigma = max(float(sigma), 1e-9)
    lmbda = 1.0 / sigma
    idx = np.arange(n, dtype=np.float64)
    logits = -lmbda * idx
    logits -= logits.max()
    probs = np.exp(logits)
    base_probs = probs / probs.sum()
    shifted_probs = np.roll(base_probs, shift=int(mu))

    return shifted_probs


import torch
import numpy as np
import random
from typing import Optional, Sequence
from torch.utils.data import TensorDataset
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.datasets import MNIST



def _discrete_exponential_probs(n: int, mu: int, sigma: float) -> np.ndarray:
    sigma = max(float(sigma), 1e-9)
    lmbda = 1.0 / sigma
    idx = np.arange(n, dtype=np.float64)
    logits = -lmbda * idx
    logits -= logits.max()
    probs = np.exp(logits)
    base_probs = probs / probs.sum()
    return np.roll(base_probs, shift=int(mu))



def _colorize_resize_rotate(
        imgs_gray: np.ndarray,
        chosen_idx: np.ndarray,
        channel: int,
        angle: float,
        image_size: int,
) -> torch.Tensor:
    """Paint grayscale digits into one RGB channel, resize, then rotate them.

    Returns a float32 tensor of shape (len(chosen_idx), 3, image_size,
    image_size). Pixel values keep the scale of ``imgs_gray``; the caller
    normalises them.
    """
    gray = torch.from_numpy(imgs_gray[chosen_idx]).float().unsqueeze(1)
    rgb = torch.zeros((gray.shape[0], 3, gray.shape[-2], gray.shape[-1]), dtype=torch.float32)
    rgb[:, channel:channel + 1, :, :] = gray
    # One shared resize setting for the biased and the balanced split, so both
    # splits get identical preprocessing (antialiasing matters when shrinking).
    rgb = F.interpolate(rgb, size=(image_size, image_size), mode='bilinear',
                        align_corners=False, antialias=True)
    return TF.rotate(rgb, angle) if angle != 0 else rgb


def set_color_rot_mnist_biased(
        dataset,
        samples_per_class: int,
        rotations: int,
        train: bool = True,
        color_std: Optional[float] = None,
        rot_std: Optional[float] = None,
        image_size: int = 64,
        seed: Optional[int] = None,
        replace: bool = True
) -> TensorDataset:
    """Build a colored and rotated MNIST ``TensorDataset``, optionally biased.

    Each grayscale digit is copied into exactly one of the three RGB channels
    (its "color"). It is then resized bilinearly to ``image_size`` x
    ``image_size`` and rotated by ``r * 360 / rotations`` degrees for some
    ``r`` in ``range(rotations)``.

    * ``train=True`` (biased): each digit gets one preferred
      (color, rotation) pair, drawn from discrete exponential priors with
      scales ``color_std`` / ``rot_std``. The digit's ``samples_per_class``
      images are then split over all combinations by a multinomial that is
      sharply concentrated (sigma=0.01) on that preferred pair.
    * ``train=False`` (balanced): each digit's ``samples_per_class`` images
      are split evenly over all ``3 * rotations`` combinations.

    Args:
        dataset: object with ``.data``, a tensor of shape (N, H, W) holding
            grayscale images, and ``.targets``, an (N,) tensor of digit
            labels. Outputs are divided by 255, so pixel values are assumed
            to lie in 0-255.
        samples_per_class: number of images generated per digit 0-9.
        rotations: number of equally spaced rotation angles (>= 1).
        train: choose the biased (True) or balanced (False) protocol.
        color_std: color prior scale; required when ``train`` is True.
        rot_std: rotation prior scale; required when ``train`` is True.
        image_size: output height and width.
        seed: if given, seeds the numpy, ``random`` and torch global RNGs.
        replace: sample source images with replacement.

    Returns:
        TensorDataset ``(images, labels)``. ``images`` is float32 with shape
        (M, 3, image_size, image_size) scaled by 1/255. ``labels`` is int64,
        grouped by digit in ascending order.

    Raises:
        ValueError: if ``rotations < 1``; if ``train`` is True and
            ``color_std`` or ``rot_std`` is missing; or if ``train`` is False
            and ``samples_per_class`` is not divisible by ``3 * rotations``.
    """
    if rotations < 1:
        raise ValueError(f"rotations must be >= 1, got {rotations}.")
    num_combinations = 3 * rotations
    if train:
        if color_std is None or rot_std is None:
            raise ValueError("Need color_std and rot_std.")
    elif samples_per_class % num_combinations != 0:
        raise ValueError(
            f"samples_per_class ({samples_per_class}) must be divisible by "
            f"3 * rotations ({num_combinations})."
        )

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    imgs_gray = dataset.data.numpy()
    labels = dataset.targets.numpy()
    base_angle = 360.0 / float(rotations)
    out_imgs, out_labels = [], []

    if train:
        # One preferred color / rotation index per digit.
        p_global_color = _discrete_exponential_probs(n=3, mu=0, sigma=color_std)
        p_global_rot = _discrete_exponential_probs(n=rotations, mu=0, sigma=rot_std)
        color_means = np.random.choice(3, size=10, p=p_global_color)
        rot_means = np.random.choice(rotations, size=10, p=p_global_rot)

    for digit in range(10):
        idxs = np.where(labels == digit)[0]

        # counts[i] = number of images for combination (color, rot) = divmod(i, rotations).
        if train:
            p_color = _discrete_exponential_probs(n=3, mu=color_means[digit], sigma=0.01)
            p_rot = _discrete_exponential_probs(n=rotations, mu=rot_means[digit], sigma=0.01)
            counts = np.random.multinomial(samples_per_class, np.outer(p_color, p_rot).flatten())
        else:
            counts = np.full(num_combinations, samples_per_class // num_combinations)

        base_choice = np.random.choice(idxs, size=samples_per_class, replace=replace)
        ptr = 0
        for i, n_take in enumerate(counts):
            n_take = int(n_take)
            if n_take == 0:
                continue
            c, r = divmod(i, rotations)
            chosen_idx = base_choice[ptr:ptr + n_take]
            ptr += n_take
            out_imgs.append(
                _colorize_resize_rotate(imgs_gray, chosen_idx, c, base_angle * float(r), image_size)
            )
            out_labels.append(torch.full((n_take,), digit, dtype=torch.long))

    if not out_imgs:
        # torch.cat rejects an empty list; return a correctly shaped empty dataset.
        return TensorDataset(
            torch.zeros((0, 3, image_size, image_size), dtype=torch.float32),
            torch.zeros((0,), dtype=torch.long),
        )

    images = torch.cat(out_imgs, dim=0) / 255.0
    targets = torch.cat(out_labels, dim=0)
    return TensorDataset(images, targets)




def _save_image_grid(images, path):
    """Tile ``images`` into a 32-column grid and save the figure to ``path``."""
    grid_img = torchvision.utils.make_grid(images, nrow=32)
    plt.rcParams.update({"font.size": 7})
    fig = plt.figure(figsize=(8, 16))
    plt.imshow(grid_img.permute(1, 2, 0).numpy())
    plt.xticks([])
    plt.yticks([])
    plt.savefig(path, bbox_inches="tight")
    # Close (not just clear) so repeated calls do not accumulate open figures.
    plt.close(fig)


def download_mnist_biased(samples_per_class: int = 270, option: str = 'both', color_std: float = 0.5, rot_std: float = 0.5):
    """Generate, visualise and save the biased color/rotation MNIST splits.

    Downloads MNIST into ``IMAGE_PATH/MNIST`` and builds two splits with 3
    rotations:

    * a biased training split with ``samples_per_class`` images per digit;
    * a balanced test split with ``5 * samples_per_class`` images per digit.

    Both splits are written as ``train.pt`` / ``test.pt``, together with
    image grids and a per-digit histogram of the training split, to
    ``IMAGE_PATH/color_rot_mnist_biased_<color_std>_<rot_std>``.

    Args:
        samples_per_class: training images per digit (the test split uses 5x).
        option: dataset variant; only ``'both'`` (color + rotation) is implemented.
        color_std: scale of the color prior used for the training bias.
        rot_std: scale of the rotation prior used for the training bias.

    Raises:
        ValueError: if ``option`` is not ``'both'``.
    """
    # Validate before any download or directory creation.
    if option != 'both':
        raise ValueError(f"Unsupported option {option!r}: only 'both' is currently implemented.")

    out_dir = IMAGE_PATH + '/color_rot_mnist_biased_{}_{}'.format(color_std, rot_std)
    os.makedirs(out_dir, exist_ok=True)

    # NOTE: set_color_rot_mnist_biased reads ``dataset.data`` directly, so this
    # transform does not affect the generated images.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((64, 64), interpolation=InterpolationMode.BICUBIC, antialias=True),
        torchvision.transforms.ToTensor(),
    ])
    trainset = MNIST(root=IMAGE_PATH + "/MNIST", train=True, download=True, transform=transform)
    testset = MNIST(root=IMAGE_PATH + "/MNIST", train=False, download=True, transform=transform)

    trainset = set_color_rot_mnist_biased(trainset, samples_per_class, rotations=3, train=True,
                                          color_std=color_std, rot_std=rot_std, replace=False)
    testset = set_color_rot_mnist_biased(testset, samples_per_class * 5, rotations=3, train=False,
                                         color_std=color_std, rot_std=rot_std)

    _save_image_grid(trainset.tensors[0], out_dir + "/biased_train_imagegrid.png")
    _save_image_grid(testset.tensors[0], out_dir + "/biased_test_imagegrid.png")

    torch.save(trainset, out_dir + "/train.pt")
    torch.save(testset, out_dir + "/test.pt")
    print(f'Generated MNIST-biased dataset at {out_dir}')

    # Per-digit sample counts of the training split, largest first. The label
    # tensor holds only digit classes, so each bar is labelled with its digit.
    digits, counts = torch.unique(trainset.tensors[1], return_counts=True)
    order = torch.argsort(counts, descending=True)
    bar_labels = [str(d) for d in digits[order].tolist()]

    fig = plt.figure(figsize=(6, 3))
    plt.bar(bar_labels, counts[order].numpy())
    plt.savefig(out_dir + "/biased_histogram.png", bbox_inches="tight")
    plt.close(fig)


class MNISTBias(TensorDataset):
    """TensorDataset wrapper around an existing two-tensor ``TensorDataset``.

    The wrapped tensors are also kept in ``self.data``. Indexing returns an
    ``(x, y)`` pair taken from the first and second tensor.
    """

    def __init__(self, dataset, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        source_tensors = dataset.tensors
        super().__init__(*source_tensors)
        self.data = source_tensors

    def __getitem__(self, idx):
        # Unpacking fails with ValueError unless there are exactly two tensors.
        sample, label = (t[idx] for t in self.data)
        return sample, label

    def __len__(self):
        leading = self.data[0]
        return len(leading)
