"""
Colored MNIST Dataset
"""


import copy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import LinearSegmentedColormap, to_rgb
from tqdm import tqdm

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
from os.path import join

def train_val_split(dataset, val_split, seed):
    """
    Randomly partition dataset indices into train and validation subsets

    The global NumPy RNG is seeded with `seed` before shuffling, so the same
    (dataset length, val_split, seed) always yields the same partition.

    Args:
    - dataset (torch.utils.data.Dataset): Any object supporting len()
    - val_split (float): Fraction of dataset allocated to validation split
    - seed (int): Reproducibility seed (re-seeds the global NumPy RNG)
    Returns:
    - train_indices, val_indices (np.array, np.array): Dataset indices
    """
    n_total = len(dataset)
    # Number of samples held out for validation
    n_val = int(np.round(val_split * n_total))

    np.random.seed(seed)
    shuffled = np.arange(n_total)
    np.random.shuffle(shuffled)

    # The first n_val shuffled indices form the validation split
    val_indices = shuffled[:n_val]
    train_indices = shuffled[n_val:]
    return train_indices, val_indices

class ColoredMNIST(Dataset):
    """
    Colored MNIST dataset - labels spuriously correlated with color
    - We store the label, the spurious attribute, and subclass labels if applicable
      in `targets_all` (keys: 'target', 'spurious', 'sub_target', 'group_idx')
    Args:
    - data (torch.Tensor): MNIST images; pixel values are expected on a 0-255
                           scale (the stored data is divided by 255)
    - targets (torch.Tensor): MNIST original labels
    - train_classes (list[]): List of lists describing how to organize labels
                                - Each inner list denotes a group, i.e.
                                they all have the same classification label
                                - Any labels left out are excluded from the dataset
                                - Defaults to [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    - train (bool): Training or test dataset
    - p_correlation (float): Strength of spurious correlation, in [0, 1], for the
                             first class; the remaining classes use the fixed
                             values 0.95, 0.9, 0.8, 0.6 (so at most 5 classes)
    - test_shift (str): How to color samples when train=False:
                          - 'iid': same class/color correlation as training
                          - 'shift_<n>': class i always gets color (i + n) % n_classes
                          - anything else (e.g. 'random'): uniformly random color
    - cmap (str): Matplotlib colormap name, or a single color name, used for
                  coloring MNIST digits
    - flipped (bool): If true, color background and keep digit black
    - transform (torchvision.transforms): Image transformations
    Raises:
    - ValueError: If train_classes defines more classes than there are
                  per-class correlation values
    Returns:
    - __getitem__() returns tuple of (image, label); additional info
                    (e.g. subclass label, spurious attribute, group index) can be
                    looked up in `targets_all` with the same index
    """

    def __init__(self, data, targets, train_classes=None,
                 train=True, p_correlation=0.995, test_shift='random', cmap='hsv',
                 flipped=False, transform=None):
        # Build the default here to avoid a shared mutable default argument
        if train_classes is None:
            train_classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

        # Initialize classes
        self.class_map = self._init_class_map(train_classes)
        self.classes = list(self.class_map.keys())
        self.new_classes = np.unique(list(self.class_map.values()))

        # Setup spurious correlation ratios per class
        self.p_correlation = [p_correlation, 0.95, 0.9, 0.8, 0.6]
        if len(self.new_classes) > len(self.p_correlation):
            # Fail early with a clear message instead of an IndexError
            # half-way through coloring the data
            raise ValueError(
                'ColoredMNIST supports at most {} classes, got {}'.format(
                    len(self.p_correlation), len(self.new_classes)))

        self.test_classes = [x for x in np.unique(
            targets) if x not in self.classes]
        self.train = train
        self.test_shift = test_shift
        self.transform = transform

        # Filter for train_classes (boolean-mask indexing copies, so the
        # caller's tensors are not modified by the coloring below)
        class_filter = torch.stack([(targets == i)
                                    for i in self.classes]).sum(dim=0)
        self.targets = targets[class_filter > 0]
        data = data[class_filter > 0]

        self.targets_all = {'spurious': np.zeros(len(self.targets), dtype=int),
                            'sub_target': copy.deepcopy(self.targets)}
        # Update targets: original digit label -> group (class) label
        self.targets = torch.tensor([self.class_map[t.item()] for t in self.targets],
                                    dtype=self.targets.dtype)
        self.targets_all['target'] = self.targets.numpy()

        # Colors + Data
        self.colors = self._init_colors(cmap)
        if flipped:
            data = 255 - data
        if data.shape[1] != 3:   # Add RGB channels
            data = data.unsqueeze(1).repeat(1, 3, 1, 1)
        self.data = self._init_data(data)
        self.spurious_group_names = self.colors

        self.n_classes = len(train_classes)
        self.n_groups = pow(self.n_classes, 2)
        target_spurious_to_group_ix = np.arange(self.n_groups).reshape(
            (self.n_classes, self.n_classes)).astype('int')

        # Access datapoint's subgroup idx, i.e. 1 of 25 diff values if we have 5 classes, 5 colors
        # (row = class label, column = color index), looked up for all samples at once
        group_array = target_spurious_to_group_ix[self.targets_all['target'],
                                                  self.targets_all['spurious']]
        self.group_array = torch.tensor(group_array, dtype=torch.long)
        self.targets_all['group_idx'] = self.group_array.numpy()

        # Human-readable label for each (y, a) group, in group-index order
        self.group_labels = [str((n, m))
                             for n in range(self.n_classes)
                             for m in range(self.n_classes)]

    def __len__(self):
        """Number of samples kept after filtering for train_classes."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return (image, label) for `idx`, applying `transform` if one is set."""
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return (sample, self.targets[idx])

    def _init_class_map(self, classes):
        """Map each original label to the index of the group that contains it."""
        class_map = {}
        for c_ix, targets in enumerate(classes):
            for t in targets:
                class_map[t] = c_ix
        return class_map

    @staticmethod
    def _to_rgb255(rgba):
        """Convert a colormap output (floats in [0, 1]) to a list of 3 ints in [0, 255]."""
        # Truncates towards zero; the alpha channel, if present, is dropped
        return [int(float(x)) for x in np.array(rgba[:3]) * 255]

    def _init_colors(self, cmap):
        """Build one RGB color (list of 3 ints) per class from `cmap`."""
        # Initialize list of RGB color values
        try:
            # plt.get_cmap is used because cm.get_cmap was removed in Matplotlib 3.9
            cmap = plt.get_cmap(cmap)
        except ValueError:  # single color
            cmap = self._get_single_color_cmap(cmap)
        n_colors = len(self.new_classes)
        # Evenly spaced sample points on the colormap
        cmap_vals = np.arange(0, 1, step=1 / n_colors)
        return [self._to_rgb255(cmap(val)) for val in cmap_vals[:n_colors]]

    def _get_single_color_cmap(self, c):
        """Build a constant colormap that returns color `c` for every input."""
        rgb = to_rgb(c)
        r1, g1, b1 = rgb
        cdict = {'red':   ((0, r1, r1),
                           (1, r1, r1)),
                 'green': ((0, g1, g1),
                           (1, g1, g1)),
                 'blue':  ((0, b1, b1),
                           (1, b1, b1))}
        cmap = LinearSegmentedColormap('custom_cmap', cdict)
        return cmap

    def _choose_color_ix(self, class_ix, is_spurious):
        """
        Pick the color index for one sample of class `class_ix`

        - train or test_shift == 'iid': the class's own color if `is_spurious`
          is truthy, otherwise a random color different from it
        - 'shift' in test_shift: class index shifted by the trailing integer
          of `test_shift` (e.g. 'shift_2'), modulo the number of classes
        - otherwise: a uniformly random color
        """
        if self.train or self.test_shift == 'iid':
            if is_spurious:
                return class_ix
            return np.random.choice([
                x for x in np.arange(len(self.colors)) if x != class_ix])
        if 'shift' in self.test_shift:
            n = int(self.test_shift.split('_')[-1])
            return (class_ix + n) % len(self.new_classes)
        return np.random.randint(len(self.colors))

    def _init_data(self, data):
        """
        Color every image and record each sample's color index

        Pixels whose first-channel value is in [120, 255] are overwritten with
        the chosen color; the color index is stored in targets_all['spurious'].
        Returns the colored data as floats scaled by 1 / 255.
        """
        self.selected_indices = []
        # Context manager ensures the progress bar is closed, even on error
        with tqdm(total=len(self.targets), desc='Initializing data') as pbar:
            for ix, c in enumerate(self.new_classes):
                class_ix = np.where(self.targets == c)[0]
                # One Bernoulli draw per sample: 1 -> use the class's own color
                is_spurious = np.random.binomial(1, self.p_correlation[ix],
                                                 size=len(class_ix))
                for cix, spurious_flag in zip(class_ix, is_spurious):
                    # Replace pixels
                    pixels_r = np.where(
                        np.logical_and(data[cix, 0, :, :] >= 120,
                                       data[cix, 0, :, :] <= 255))
                    color_ix = self._choose_color_ix(ix, spurious_flag)
                    color = self.colors[color_ix]
                    data[cix, :, pixels_r[0], pixels_r[1]] = (
                        torch.tensor(color, dtype=torch.uint8).unsqueeze(1).repeat(1, len(pixels_r[0])))
                    self.targets_all['spurious'][cix] = int(color_ix)
                    pbar.update(1)
        return data.float() / 255  # For normalization

    def get_dataloader(self, batch_size, shuffle, num_workers):
        """Wrap this dataset in a torch DataLoader with the given settings."""
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers)


def load_colored_mnist(args, train=True, transform=None):
    """
    Default dataset setup for Colored MNIST
    Args:
    - args: Object exposing a `p_correlation` attribute, forwarded to ColoredMNIST
    - train (bool): Load the MNIST training split if True, else the test split
    - transform (torchvision.transforms): Image transformations; when None, a
                                          resize(40) -> random-crop(32) ->
                                          normalize(0.5, 0.5) pipeline is used
    Returns:
    - ColoredMNIST: Colored dataset built from the requested MNIST split
                    (MNIST is downloaded to ./data if missing)
    """
    if transform is None:
        # Default preprocessing pipeline
        default_steps = [
            transforms.Resize(40),
            transforms.RandomCrop(32, padding=0),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        transform = transforms.Compose(default_steps)

    base = torchvision.datasets.MNIST(root='./data', train=train, download=True)

    return ColoredMNIST(data=base.data,
                        targets=base.targets,
                        train=train,
                        transform=transform,
                        p_correlation=args.p_correlation)
    
    
# Refactor for modularity
def load_dataloaders(args, train=True, transform=None):
    """
    Generic entry point: forwards all arguments to load_colored_mnist
    and returns its result unchanged
    """
    colored_mnist = load_colored_mnist(args, train, transform)
    return colored_mnist

if __name__ == '__main__':
    # Smoke test: download MNIST (to ./data) and build the colored training set
    transform = transforms.Compose([
        transforms.Resize(40),
        transforms.RandomCrop(32, padding=0),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True)
    # Pass the transform built above; previously it was constructed but never used
    ColoredMNIST(data=mnist.data, targets=mnist.targets, transform=transform)