"""ReBias
Copyright (c) 2020-present NAVER Corp.
MIT license

Python implementation of Biased-MNIST.
"""
import logging
import os
import pickle
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from debias.datasets.utils import TwoCropTransform, get_confusion_matrix, get_unsup_confusion_matrix
from torch.utils import data
from torchvision import transforms
from torchvision.datasets import MNIST
import random 

class BiasedMNIST(MNIST):
    """A base class for Biased-MNIST.

    We manually select ten colours to synthesize a colour bias (see `COLOUR_MAP` for the colour configuration).
    Usage is the same as the torchvision MNIST dataset class, except for the tuple returned by
    ``__getitem__`` (see its docstring).

    Two parameters control the level of bias.

    Parameters
    ----------
    root : str
        path to MNIST dataset.
    data_label_correlation : float, default=1.0
        Here, each class has the pre-defined colour (bias).
        data_label_correlation, or `rho`, controls the level of the dataset bias.

        A sample is coloured with
            - the pre-defined colour with probability `rho`,
            - one of the other colours with probability `1 - rho`.
              The number of ``other colours'' is controlled by `n_confusing_labels` (default: 9).
        Note that the colour is injected into the background of the image (see `_binary_to_colour`).

        Hence, we have
            - Perfectly biased dataset with rho=1.0
            - Perfectly unbiased with rho=0.1 (1/10) ==> our ``unbiased'' setting in the test time.
        In the paper, we explore the high correlations but with small hints, e.g., rho=0.999.

    n_confusing_labels : int, default=9
        In real-world cases, biases are not equally distributed but highly unbalanced.
        We mimic the unbalanced biases by changing the number of confusing colours for each class.
        In the paper, we use n_confusing_labels=9, i.e., during training, the model can observe
        all colours for each class. However, you can make the problem harder by setting a smaller
        n_confusing_labels, e.g., 2. We suggest that researchers adopting this benchmark in
        future work also consider such harder settings.

    Other parameters
    ----------------
    bias_feature_root : str, default='./biased_feats'
        Directory of pre-computed bias features; only read when ``load_bias_feature`` is True.
    split : {'train', 'valid'}
        'train' builds from the MNIST training set, 'valid' from the MNIST test set.
    seed : int, default=1
        Used only to name the cache / bias-feature directories. It does NOT seed any RNG:
        seed ``numpy`` and ``random`` yourself for reproducible builds and re-sampling.
    load_bias_feature : bool, default=False
        Load ``bias_feats.pt`` / ``marginal.pt`` and build an unsupervised confusion matrix.
    train_corr : optional
        If truthy, bias features are read from ``train{train_corr}-corr{rho}-seed{seed}``.
    under_sample : {None, 'bin', 'ce', 'os'}
        Re-sampling applied to the training split only: 'bin' masks one-vs-rest negatives,
        'ce' under-samples and 'os' over-samples every (class, bias) group to a common size.
    """

    COLOUR_MAP = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [225, 225, 0], [225, 0, 225],
                  [0, 255, 255], [255, 128, 0], [255, 0, 128], [128, 0, 255], [128, 128, 128]]

    # Files that make up one cached split; a cache hit requires all of them.
    _CACHE_FILES = ('data.pkl', 'targets.pkl', 'biased_targets.pkl')

    def __init__(self, root, bias_feature_root='./biased_feats', split='train', transform=None, target_transform=None,
                 download=False, data_label_correlation=1.0, n_confusing_labels=9, seed=1, load_bias_feature=False,
                 train_corr=None, under_sample=None):
        assert split in ['train', 'valid']
        super().__init__(root, train=split == 'train', transform=transform,
                         target_transform=target_transform,
                         download=download)

        self.data_label_correlation = data_label_correlation
        self.n_confusing_labels = n_confusing_labels
        self.random = True
        self.load_bias_feature = load_bias_feature
        if self.load_bias_feature:
            if train_corr:
                feature_name = f'train{train_corr}-corr{data_label_correlation}-seed{seed}'
            else:
                feature_name = f'color_mnist-corr{data_label_correlation}-seed{seed}'
            bias_feature_dir = f'{bias_feature_root}/{feature_name}'
            logging.info(f'load bias feature: {bias_feature_dir}')
            self.bias_features = torch.load(f'{bias_feature_dir}/bias_feats.pt')
            self.marginal = torch.load(f'{bias_feature_dir}/marginal.pt')

        # The cache key must cover every argument that changes the generated data. The default
        # n_confusing_labels=9 keeps the historical directory name so existing caches stay valid.
        cache_name = f'color_mnist-corr{data_label_correlation}-seed{seed}'
        if n_confusing_labels != 9:
            cache_name += f'-nconf{n_confusing_labels}'
        save_path = Path(root) / 'pickles' / cache_name / split

        if all((save_path / name).is_file() for name in self._CACHE_FILES):
            logging.info(f'use existing color_mnist from {save_path}')
            self.data = self._load_pickle(save_path / 'data.pkl')
            self.targets = self._load_pickle(save_path / 'targets.pkl')
            self.biased_targets = self._load_pickle(save_path / 'biased_targets.pkl')
        else:
            self.data, self.targets, self.biased_targets = self.build_biased_mnist()

            # build_biased_mnist groups samples by bias label; shuffle them once before caching.
            indices = np.arange(len(self.data))
            self._shuffle(indices)

            self.data = self.data[indices].numpy()
            self.targets = self.targets[indices]
            self.biased_targets = self.biased_targets[indices]

            logging.info(f'save color_mnist to {save_path}')
            save_path.mkdir(parents=True, exist_ok=True)
            self._dump_pickle(self.data, save_path / 'data.pkl')
            self._dump_pickle(self.targets, save_path / 'targets.pkl')
            self._dump_pickle(self.biased_targets, save_path / 'biased_targets.pkl')

        # One-hot (float) labels; resampling in 'bin' mode later masks entries with -1.
        n_samples = len(self.targets)
        self.targets_bin = torch.zeros((n_samples, 10))
        self.targets_bin[torch.arange(n_samples), self.targets] = 1

        self.calculate_bias_weights()

        if under_sample:
            # Pristine copies so each resampling method can restart from the full dataset.
            self.targets_ = torch.clone(self.targets)
            self.biased_targets_ = torch.clone(self.biased_targets)
            self.targets_bin_ = torch.clone(self.targets_bin)
            self.groups_counts_ = torch.clone(self.groups_counts)
            self.data_ = np.copy(self.data)

        if split == 'train':
            if under_sample == 'bin':
                self.under_sample_bin(verbose=True)
                self.remove_unused_labels()
            elif under_sample == 'ce':
                self.under_sample_ce(verbose=True)
            elif under_sample == 'os':
                self.over_sample_ce(verbose=True)

        if load_bias_feature:
            self.confusion_matrix_org, self.confusion_matrix = get_unsup_confusion_matrix(
                num_classes=10,
                targets=self.targets,
                biases=self.biased_targets,
                marginals=self.marginal)
        else:
            self.confusion_matrix_org, self.confusion_matrix, self.confusion_matrix_by = get_confusion_matrix(
                num_classes=10,
                targets=self.targets,
                biases=self.biased_targets)

    @staticmethod
    def _load_pickle(path):
        """Load one cache file written by ``_dump_pickle``.

        NOTE: pickle executes code on load; only point ``root`` at caches this class produced.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _dump_pickle(obj, path):
        """Write ``obj`` to ``path`` with pickle, closing the file handle."""
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    @property
    def raw_folder(self):
        """Keep raw MNIST files directly under ``root/raw`` regardless of the subclass name."""
        return os.path.join(self.root, 'raw')

    @property
    def processed_folder(self):
        """Keep processed MNIST files directly under ``root/processed``."""
        return os.path.join(self.root, 'processed')

    def _shuffle(self, iteratable):
        """Shuffle ``iteratable`` in place with NumPy's global RNG when ``self.random`` is set."""
        if self.random:
            np.random.shuffle(iteratable)

    def _make_biased_mnist(self, indices, label):
        """Return ``(images, targets)`` for ``indices`` rendered with bias ``label``; subclasses implement."""
        raise NotImplementedError

    def _update_bias_indices(self, bias_indices, label):
        """Assign the samples of class ``label`` to bias labels in ``bias_indices`` (mutated in place).

        The first ``int(n * rho)`` shuffled samples get bias ``label``; the rest are split into chunks
        of ``ceil(rest / n_confusing_labels)`` and spread over the next ``n_confusing_labels`` labels
        (mod 10), in shuffled order.

        Raises
        ------
        ValueError
            If ``n_confusing_labels`` is outside ``[1, 9]``.
        """
        if self.n_confusing_labels > 9 or self.n_confusing_labels < 1:
            raise ValueError(self.n_confusing_labels)

        indices = np.where((self.targets == label).numpy())[0]
        self._shuffle(indices)
        indices = torch.LongTensor(indices)

        n_samples = len(indices)
        n_correlated_samples = int(n_samples * self.data_label_correlation)
        n_decorrelated_per_class = int(np.ceil((n_samples - n_correlated_samples) / (self.n_confusing_labels)))

        correlated_indices = indices[:n_correlated_samples]
        bias_indices[label] = torch.cat([bias_indices[label], correlated_indices])

        decorrelated_indices = torch.split(indices[n_correlated_samples:], n_decorrelated_per_class)

        other_labels = [_label % 10 for _label in range(label + 1, label + 1 + self.n_confusing_labels)]
        self._shuffle(other_labels)

        for idx, _indices in enumerate(decorrelated_indices):
            _label = other_labels[idx]
            bias_indices[_label] = torch.cat([bias_indices[_label], _indices])

    def under_sample_bin(self, verbose=False):
        """Mask negatives of every one-vs-rest label so each binary task mirrors the bias ratio.

        Restores ``targets_bin`` from the snapshot ``targets_bin_``. Then, for each column ``idx``,
        let ``maj_idx`` be the most frequent bias among its positives. For every class ``i`` whose own
        majority bias differs, at most ``count(class i, bias maj_idx) * (1 - rho) / rho`` randomly
        chosen class-``i`` samples of each bias other than ``maj_idx`` stay as negatives. The rest are
        set to -1 in column ``idx``.

        Parameters
        ----------
        verbose : bool
            Print per-bias positive counts and per-class negative counts for every column.
        """
        rho = self.data_label_correlation
        self.targets_bin = torch.clone(self.targets_bin_)

        # Masking only turns 0 entries into -1, and each sample is positive solely in its own class
        # column (one-hot), so the positives of every column never change: compute them once.
        positives = [self.get_samples_count(1, self.targets_bin[:, col]) for col in range(10)]
        majority_bias = [int(np.argmax(counts)) for _, counts in positives]

        for idx in range(10):
            maj_idx = majority_bias[idx]
            for i in range(10):
                if majority_bias[i] == maj_idx:
                    continue
                pos_samples_, pos_counts_ = positives[i]
                min_count = pos_counts_[maj_idx] * (1 - rho) / rho
                for j in range(10):
                    if j == maj_idx:
                        continue
                    self.drop_samples_bin(min_count, idx, j, pos_samples_)

            if verbose:
                _, pos_counts = self.get_samples_count(1, self.targets_bin[:, idx])
                print(pos_counts)
                for j in range(10):
                    if j == idx:
                        continue
                    _, neg_counts = self.get_samples_count_conditioned(
                        0, self.targets_bin[:, idx], self.targets, j)
                    print(neg_counts, j)
                print('----------')

    def add_indices(self, indices, class_idx):
        """Append copies of the samples at ``indices`` as extra negatives for column ``class_idx``.

        The copies get binary label 0 in ``class_idx`` and -1 (ignored) in every other column.
        """
        self.targets = torch.cat((self.targets, self.targets[indices]), dim=0)
        self.biased_targets = torch.cat((self.biased_targets, self.biased_targets[indices]), dim=0)
        self.groups_counts = torch.cat((self.groups_counts, self.groups_counts[indices]), dim=0)
        self.data = np.concatenate((self.data, self.data[indices]), axis=0)

        # Match targets_bin's dtype so the concatenation does not depend on type promotion.
        new_targets_bin = torch.full((len(indices), 10), -1, dtype=self.targets_bin.dtype)
        new_targets_bin[:, class_idx] = 0
        self.targets_bin = torch.cat((self.targets_bin, new_targets_bin), dim=0)

    def drop_samples_bin(self, to_keep, idx, bias, samples):
        """Randomly keep ``round(to_keep)`` of ``samples[bias]`` and mask the rest in column ``idx``.

        Masked entries of ``targets_bin[:, idx]`` are set to -1. If ``to_keep`` is at least the pool
        size, every sample is kept.
        """
        idx_pool = samples[bias]
        if not idx_pool:
            return
        # Clamp: random.sample raises ValueError when asked for more items than the pool holds,
        # which happens for low data_label_correlation.
        n_keep = min(int(round(to_keep)), len(idx_pool))
        kept = set(random.sample(idx_pool, n_keep))
        to_drop = [sample for sample in idx_pool if sample not in kept]
        if to_drop:
            self.targets_bin[to_drop, idx] = -1

    def get_samples_count_conditioned(self, pos_neg, targets, targets2, idx2):
        """Like ``get_samples_count`` but restricted to samples where ``targets2 == idx2``."""
        return self._indices_per_bias(torch.logical_and(targets == pos_neg, targets2 == idx2))

    def get_samples_count(self, pos_neg, targets):
        """Group the samples with ``targets == pos_neg`` by bias label.

        Returns
        -------
        (samples, counts)
            ``samples[b]`` lists the dataset indices selected with bias label ``b`` (b in 0..9);
            ``counts[b] == len(samples[b])``.
        """
        return self._indices_per_bias(targets == pos_neg)

    def _indices_per_bias(self, mask):
        """Split the indices selected by boolean tensor ``mask`` into ten lists, one per bias label."""
        samples, counts = [], []
        for bias in range(10):
            selected = np.flatnonzero(np.asarray(torch.logical_and(mask, self.biased_targets == bias)))
            samples.append(list(selected))
            counts.append(len(selected))
        return samples, counts

    def _select_samples(self, indices):
        """Keep (or repeat) the samples at ``indices`` in every per-sample attribute."""
        self.targets = self.targets[indices]
        self.biased_targets = self.biased_targets[indices]
        self.groups_counts = self.groups_counts[indices]
        self.targets_bin = self.targets_bin[indices]
        self.data = self.data[indices]

    def _restore_original_samples(self):
        """Reset every per-sample attribute to the snapshot taken in ``__init__``."""
        self.targets = torch.clone(self.targets_)
        self.biased_targets = torch.clone(self.biased_targets_)
        self.targets_bin = torch.clone(self.targets_bin_)
        self.groups_counts = torch.clone(self.groups_counts_)
        self.data = np.copy(self.data_)

    def _group_samples(self):
        """Return sample-index lists for all (class, bias) groups, class-major."""
        groups = []
        for class_idx in range(10):
            samples, _ = self.get_samples_count(class_idx, self.targets)
            groups.extend(samples)
        return groups

    def _print_group_counts(self):
        """Print, per class, the number of samples for each bias label."""
        for class_idx in range(10):
            _, counts = self.get_samples_count(class_idx, self.targets)
            print(counts)

    def remove_unused_labels(self):
        """Drop samples whose ten binary labels have all been masked with -1."""
        used = (self.targets_bin != -1).any(dim=-1)
        self._select_samples(torch.arange(len(self.targets))[used])

    def under_sample_ce(self, verbose=False):
        """Under-sample every non-empty (class, bias) group, without replacement, to the smallest one.

        Starts from the snapshot taken in ``__init__``. Empty groups cannot be balanced and are
        ignored, rather than collapsing the whole dataset to zero samples.
        """
        self._restore_original_samples()
        groups = [group for group in self._group_samples() if group]
        min_count = min((len(group) for group in groups), default=0)

        to_keep_samples = []
        for group in groups:
            to_keep_samples.extend(random.sample(group, min_count))
        self._select_samples(to_keep_samples)

        if verbose:
            self._print_group_counts()

    def over_sample_ce(self, verbose=False):
        """Over-sample every non-empty (class, bias) group, with replacement, to the largest one.

        Starts from the snapshot taken in ``__init__``. Empty groups have nothing to draw from and
        are skipped, where ``np.random.choice`` would otherwise raise.
        """
        self._restore_original_samples()
        groups = [group for group in self._group_samples() if group]
        max_count = max((len(group) for group in groups), default=0)

        to_keep_samples = []
        for group in groups:
            to_keep_samples.extend(np.random.choice(group, max_count, replace=True))
        self._select_samples(to_keep_samples)

        if verbose:
            self._print_group_counts()

    def calculate_bias_weights(self):
        """Store per-sample inverse group frequencies in ``self.groups_counts`` (float64 tensor).

        A sample's group is ``target + 10 * bias``. Its weight is ``1 / (group_size / n_samples)``.
        Despite its name, the attribute holds these weights, not raw counts.
        """
        targets = np.asarray(self.targets, dtype=np.int64)
        biases = np.asarray(self.biased_targets, dtype=np.int64)
        group_ids = targets + biases * 10

        group_sizes = np.bincount(group_ids, minlength=100)
        total = group_sizes.sum()
        if total == 0:
            self.groups_counts = torch.from_numpy(np.zeros(0))
            return
        group_freq = group_sizes / total
        # Each sample's own group has size >= 1, so no division by zero happens here.
        self.groups_counts = torch.from_numpy(1 / group_freq[group_ids])

    def build_biased_mnist(self):
        """Build biased MNIST.

        Returns
        -------
        (data, targets, biased_targets)
            Concatenated per bias label: the rendered images from ``_make_biased_mnist``, their class
            targets, and a LongTensor of the bias label assigned to each sample.
        """
        n_labels = self.targets.max().item() + 1

        bias_indices = {label: torch.LongTensor() for label in range(n_labels)}
        for label in range(n_labels):
            self._update_bias_indices(bias_indices, label)

        data = torch.ByteTensor()
        targets = torch.LongTensor()
        biased_targets = []

        for bias_label, indices in bias_indices.items():
            _data, _targets = self._make_biased_mnist(indices, bias_label)
            data = torch.cat([data, _data])
            targets = torch.cat([targets, _targets])
            biased_targets.extend([bias_label] * len(indices))

        biased_targets = torch.LongTensor(biased_targets)
        return data, targets, biased_targets

    def __getitem__(self, index):
        """Return one sample.

        The tuple layout depends on ``load_bias_feature``:
            - True:  ``(img, target, bias, index, bias_feat, target_bin)``
            - False: ``(img, target_bin, bias, index, target, gc)``, with ``gc`` the inverse group
              frequency from ``calculate_bias_weights``.
        """
        img, target, bias, target_bin = self.data[index], int(self.targets[index]), int(self.biased_targets[index]), self.targets_bin[index]
        img = Image.fromarray(img.astype(np.uint8), mode='RGB')
        gc = self.groups_counts[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.load_bias_feature:
            bias_feat = self.bias_features[index]
            return img, target, bias, index, bias_feat, target_bin
        else:
            return img, target_bin, bias, index, target, gc


class ColourBiasedMNIST(BiasedMNIST):
    """Biased-MNIST whose bias is the background colour of each digit.

    Digit strokes (non-zero pixels) are painted pure white (255). Every background pixel is filled
    with ``COLOUR_MAP[bias_label]``. All constructor arguments are forwarded unchanged to
    ``BiasedMNIST``.
    """

    def __init__(self, root, bias_feature_root='./biased_feats', split='train', transform=None, target_transform=None,
                 download=False, data_label_correlation=1.0, n_confusing_labels=9, seed=1, load_bias_feature=False, train_corr=None, under_sample=None):
        super().__init__(
            root,
            bias_feature_root=bias_feature_root,
            split=split,
            transform=transform,
            target_transform=target_transform,
            download=download,
            data_label_correlation=data_label_correlation,
            n_confusing_labels=n_confusing_labels,
            seed=seed,
            load_bias_feature=load_bias_feature,
            train_corr=train_corr,
            under_sample=under_sample,
        )

    def _binary_to_colour(self, data, colour):
        """Render grayscale images ``data`` (N, H, W) as (N, H, W, 3) RGB images.

        Any non-zero pixel becomes white; zero pixels take the RGB triple ``colour``.
        """
        # Trailing singleton channel axis so the mask broadcasts against the 3-element colours.
        is_stroke = (data != 0).unsqueeze(-1)
        stroke_rgb = torch.full((3,), 255, dtype=data.dtype)
        background_rgb = torch.tensor(colour, dtype=data.dtype)
        return torch.where(is_stroke, stroke_rgb, background_rgb)

    def _make_biased_mnist(self, indices, label):
        """Colour the images at ``indices`` with bias ``label``; return them with their targets."""
        selected_images = self.data[indices]
        coloured = self._binary_to_colour(selected_images, self.COLOUR_MAP[label])
        return coloured, self.targets[indices]


def get_color_mnist(root, batch_size, data_label_correlation,
                    n_confusing_labels=9, split='train', num_workers=2, seed=1, aug=True,
                    two_crop=False, ratio=0, bias_feature_root='./biased_feats', load_bias_feature=False, given_y=True,
                    train_corr=None, under_sample=None):
    """Build a ``DataLoader`` over ``ColourBiasedMNIST``.

    Parameters
    ----------
    root : str
        Path to the MNIST dataset (and its ``pickles`` cache).
    batch_size : int
        Batch size of the returned loader.
    data_label_correlation : float
        Colour/label correlation ``rho`` of the dataset.
    n_confusing_labels, seed, bias_feature_root, load_bias_feature, train_corr, under_sample
        Forwarded to ``ColourBiasedMNIST``.
    split : {'train', 'valid', 'train_val'}
        'train_val' returns a loader over a random 10% subset of the training split, drawn with
        NumPy's global RNG. No re-weighting is applied in that case.
    num_workers : int
        Worker processes for the loader.
    aug : bool
        Apply random rotation and colour jitter before normalisation.
    two_crop : bool
        Wrap the transform in ``TwoCropTransform``.
    ratio : float
        0 disables re-weighted sampling. Otherwise samples are drawn with replacement by a
        ``WeightedRandomSampler``. A positive ``ratio`` also caps every weight at ``ratio`` times
        the smallest weight.
    given_y : bool
        Without bias features, weight each sample by ``1 / confusion_matrix_by[y, b]`` when True,
        else by ``1 / confusion_matrix[b, y]``.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    logging.info(f'get_color_mnist - split: {split}, aug: {aug}, given_y: {given_y}, ratio: {ratio}')
    normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    if aug:
        train_transform = transforms.Compose([
            transforms.RandomRotation(20),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize])
    if two_crop:
        train_transform = TwoCropTransform(train_transform)

    if split == 'train_val':
        # Forward bias_feature_root like the other branch does, so a custom root is honoured
        # when load_bias_feature is set.
        dataset = ColourBiasedMNIST(
            root, bias_feature_root=bias_feature_root, split='train', transform=train_transform,
            download=True, data_label_correlation=data_label_correlation,
            n_confusing_labels=n_confusing_labels, seed=seed,
            load_bias_feature=load_bias_feature,
            train_corr=train_corr, under_sample=under_sample
        )

        indices = list(range(len(dataset)))
        n_valid = int(np.floor(0.1 * len(dataset)))
        np.random.shuffle(indices)
        valid_sampler = data.sampler.SubsetRandomSampler(indices[:n_valid])

        return data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=valid_sampler,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=False)

    dataset = ColourBiasedMNIST(
        root, bias_feature_root=bias_feature_root, split=split, transform=train_transform,
        download=True, data_label_correlation=data_label_correlation,
        n_confusing_labels=n_confusing_labels, seed=seed,
        load_bias_feature=load_bias_feature,
        train_corr=train_corr, under_sample=under_sample
    )

    sampler = None
    if ratio != 0:
        if load_bias_feature:
            weights = dataset.marginal
        elif given_y:
            weights = [1 / dataset.confusion_matrix_by[c, b] for c, b in zip(dataset.targets, dataset.biased_targets)]
        else:
            weights = [1 / dataset.confusion_matrix[b, c] for c, b in zip(dataset.targets, dataset.biased_targets)]

        if ratio > 0:
            weights = np.array(weights)
            # Cap the largest weights at `ratio` times the smallest one.
            weights = np.clip(weights, None, weights.min() * ratio)
        sampler = data.WeightedRandomSampler(weights, len(weights), replacement=True)

    return data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        # A DataLoader cannot shuffle and use a sampler at the same time.
        shuffle=sampler is None and split == 'train',
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=split == 'train')
