import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import numpy as np
import pandas as pd


class NoisyDataset(Dataset):
    """In-memory dataset of (already noised) samples.

    Args:
        data: indexable container of samples (e.g. a tensor whose first dim is the sample axis).
        labels: indexable container of labels, aligned with ``data``.
        clean_join_data: optional indexable container of extra per-sample features,
            aligned with ``data``; when given it is returned between sample and label.
    """

    def __init__(self, data, labels, clean_join_data=None):
        self.data = data
        self.clean_join_data = clean_join_data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # [sample, label] or [sample, clean_join_sample, label]
        item = [self.data[idx]]
        if self.clean_join_data is not None:
            item.append(self.clean_join_data[idx])
        item.append(self.labels[idx])
        return item
    

class LendingClub(Dataset):
    """Lending Club dataset backed by pandas objects.

    Args:
        data: pandas DataFrame of features; rows are selected positionally with ``.iloc``.
        labels: labels indexed with ``[]`` (a pandas Series in ``get_lending_club``).
        clean_join_data: optional DataFrame of extra features, selected with ``.iloc`` and
            returned as a tensor between the features and the label.
        transform: optional callable applied to the feature row after it has been
            converted to a torch tensor.
    """

    def __init__(self, data, labels, clean_join_data=None, transform=None):
        self.data, self.labels = data, labels
        self.clean_join_data = clean_join_data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # DataLoader samplers may hand us tensor indices; pandas wants plain Python ints/lists
        if torch.is_tensor(idx):
            idx = idx.tolist()

        features = self.data.iloc[idx].values
        target = self.labels[idx]

        # without a transform the features stay a numpy array
        if self.transform:
            features = self.transform(torch.tensor(features))

        if self.clean_join_data is None:
            return [features, target]
        side_features = torch.tensor(self.clean_join_data.iloc[idx].values)
        return [features, side_features, target]
        

class GetDataRepresentation(object):
    '''
    Data transform: encodes image or tabular data into a latent representation.

    Args:
        model: the representation model, a pre-trained VAE whose
            ``get_data_representatation`` method encodes data into a lower dimensional state
        for_reconstruction: if True, return the model's full output unchanged (e.g. to feed a
            following GetDataReconstruction step); otherwise return only its first element
        clip: passed to the model. If None, no clipping will occur. If a value is specified,
            all data with L1 norm > clip will be reduced back into the taxicab sphere
    '''

    def __init__(self, model, for_reconstruction=False, clip=None):
        self.model, self.clip, self.for_reconstruction = model, clip, for_reconstruction

    def __call__(self, x):
        encoded = self.model.get_data_representatation(x, clip=self.clip)
        return encoded if self.for_reconstruction else encoded[0]


class GetDataReconstruction(object):
    '''
    Data transform: takes a latent representation, adds privacy-inducing noise and maps it
    back to data space via the model's ``get_data_reconstruction`` method.

    Args:
        model: the representation model, a pre-trained VAE (decoder) used to map latent
            codes back to data space
        privacy_inducing_std: noise level forwarded to ``model.get_data_reconstruction``
        clip: passed to the model. If None, no clipping will occur. If a value is specified,
            all data with L1 norm > clip will be reduced back into the taxicab sphere
    '''

    def __init__(self, model, privacy_inducing_std, clip=None):
        self.model = model
        self.privacy_inducing_std = privacy_inducing_std
        self.clip = clip

    def __call__(self, x):
        reconstruction = self.model.get_data_reconstruction(
            x, self.privacy_inducing_std, clip=self.clip)
        # only the first element of the model's output is the reconstructed data
        return reconstruction[0]


def logit(x, eps=1e-2):
    """Return the clamped logit ``log(x) - log(1 - x)`` of tensor ``x``.

    ``x`` is first clamped to ``[eps, 1 - eps]`` so that inputs of exactly 0 or 1
    (e.g. saturated pixels produced by ``transforms.ToTensor``) map to finite values
    instead of -inf / +inf.

    Args:
        x: tensor, presumably with values in [0, 1].
        eps: clamping margin; the default 1e-2 reproduces the previous hard-coded value.

    Returns:
        Tensor of the same shape as ``x``.
    """
    x = torch.clamp(x, eps, 1 - eps)
    return torch.log(x) - torch.log(1 - x)


def add_data_noise(image, noise, noise_type):
    """Return ``image`` perturbed by random noise of type ``noise_type``.

    Args:
        image: float tensor to perturb (image pixels or tabular features).
        noise: noise strength. For 'Binary' it is the per-element flip probability;
            for 'Laplace' and 'Gaussian' it is the standard deviation of the additive
            noise. If ``noise`` is not > 0, ``image`` itself is returned unchanged.
        noise_type: one of 'Binary', 'Laplace', 'Gaussian'.

    Returns:
        The noisy tensor. 'Laplace' output is placed on CUDA when available, regardless
        of ``image``'s device; 'Binary' and 'Gaussian' output keeps ``image``'s device.

    Raises:
        NotImplementedError: ``noise > 0`` and ``noise_type`` is not recognised.
    """
    if not noise > 0:
        return image

    if noise_type == 'Binary':
        # True where the element must be flipped (independently, with probability ``noise``)
        flip = torch.rand_like(image) < noise
        # flipped positions become 1 - x, all others keep their value
        # (the arguments were previously inverted, flipping with probability 1 - noise)
        return torch.where(flip, torch.ones_like(image) - image, image)

    if noise_type == 'Laplace':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        eps = torch.distributions.Laplace(0, 1).rsample(image.size()).to(device)
        # Laplace(0, 1) has variance 2, so scaling by noise / sqrt(2) gives std == noise
        scale = noise / np.sqrt(2)
        return eps.mul(scale).add_(image.to(device))

    if noise_type == 'Gaussian':
        eps = torch.randn_like(image)
        return eps.mul(noise).add_(image)

    raise NotImplementedError("No noise type selected")


def add_categorical_noise(labels, n_categories, flip_prob):
    """
    Randomly resample categorical labels.

    Each label is replaced, independently with probability ``flip_prob``, by a
    different category drawn uniformly from the other ``n_categories - 1`` classes.
    To never redraw the original class, a candidate ``c`` is drawn uniformly from
    ``{0, ..., n_categories - 2}`` and shifted to ``c + 1`` whenever ``c >= original``.

    Args:
        labels: tensor of class indices, presumably integers in ``[0, n_categories)``.
        n_categories: number of classes.
        flip_prob: per-label flip probability.

    Returns:
        A float tensor of (possibly flipped) labels when ``flip_prob > 0``; otherwise
        ``labels`` itself, unchanged and with its original dtype.
    """
    if not flip_prob > 0:
        return labels

    as_float = labels.float()
    # first draw: which positions receive a new label
    flip_mask = torch.rand_like(as_float) < flip_prob
    # second draw: a candidate among the n_categories - 1 "other" slots
    candidate = torch.floor((n_categories - 1) * torch.rand_like(as_float))
    # skip over the original class so it can never be re-selected
    replacement = torch.where(candidate >= as_float, candidate + 1, candidate)
    return torch.where(flip_mask, replacement, as_float)


def add_data_noise_tabular(data, opt):
    """Add per-feature noise to a batch of one-hot-encoded tabular data.

    ``data`` is expected to hold ``opt.n_continuous_features`` continuous columns
    followed by one one-hot block per categorical feature, with block widths
    ``opt.ncat_of_cat_features``.

    Continuous columns receive additive noise via ``add_data_noise`` (type
    ``opt.noise_type``). Each categorical block is decoded to its argmax index,
    resampled with ``add_categorical_noise`` and re-encoded as one-hot.

    Args:
        data: 2-D tensor whose rows are samples.
        opt: options object providing ``x_noise`` (one noise level per feature,
            continuous features first, then categorical ones), ``noise_type``,
            ``n_continuous_features``, ``ncat_of_cat_features`` and ``device``.

    Returns:
        Tensor on ``opt.device`` with the same column layout as ``data``.
    """
    x_noise = opt.x_noise
    noise_type = opt.noise_type
    n_cts = opt.n_continuous_features

    # continuous features: each column gets its own noise level
    cts_columns = [add_data_noise(data[:, i], x_noise[i], noise_type) for i in range(n_cts)]
    if cts_columns:
        parts = [torch.stack(cts_columns, dim=1).to(opt.device)]
    else:
        # no continuous features: previously column 0 (a categorical column) was noised and
        # prepended by mistake; start from an empty (n_samples, 0) block instead
        parts = [data.new_zeros((data.shape[0], 0)).to(opt.device)]

    # categorical features: noise the class index, then re-encode as one-hot
    counter = n_cts
    for i, n_cat in enumerate(opt.ncat_of_cat_features):
        current_categorical = data[:, counter:counter + n_cat]
        inverse_one_hot = torch.max(current_categorical, axis=1)[1]
        noisy_inverse_one_hot = add_categorical_noise(inverse_one_hot, n_cat, x_noise[n_cts + i])
        noisy_one_hot = torch.nn.functional.one_hot(noisy_inverse_one_hot.to(torch.int64),
                                                    n_cat).float().to(opt.device)
        parts.append(noisy_one_hot)
        counter += n_cat

    # one concatenation instead of growing the tensor feature by feature
    return torch.cat(parts, dim=1)
            
################################################################################

def create_synthetic_dataset(decoder, dataset_size=60000, gen_batch_size=5000):
    """Sample a synthetic labelled dataset from a generative model's prior.

    Args:
        decoder: model exposing ``prior_sample(n)``, which (as used here) returns a
            ``(data, labels)`` pair of tensors with ``n`` rows each.
        dataset_size: number of samples in the returned dataset.
        gen_batch_size: samples drawn per ``prior_sample`` call; needs to be small enough
            for one batch to fit into memory.

    Returns:
        NoisyDataset of the first ``dataset_size`` samples, with labels cast to float.
    """
    with torch.no_grad():
        # ceil(dataset_size / gen_batch_size) batches (the old int(...) + 1 drew a wasted
        # extra batch for exact multiples); at least one so the concatenation is non-empty
        n_iters = max(1, int(np.ceil(dataset_size / gen_batch_size)))
        data_batches, label_batches = [], []
        for _ in range(n_iters):
            data, labels = decoder.prior_sample(gen_batch_size)
            data_batches.append(data)
            label_batches.append(labels)

        # a single concatenation instead of re-copying the growing tensors every iteration
        noisy_data = torch.cat(data_batches, dim=0)[:dataset_size]
        noisy_labels = torch.cat(label_batches, dim=0)[:dataset_size].float()

        return NoisyDataset(noisy_data, noisy_labels)

      
def _noisy_lending_club_split(dataset, opt, transform, data_join_task, n_categories):
    """Materialise one Lending Club split (applying its transform) and add label/feature noise.

    Returns:
        (noisy_data, noisy_labels, clean_join_data); clean_join_data is None unless
        ``data_join_task``.
    """
    # a single full-size batch pushes every item through the dataset's transform once
    batch = next(iter(DataLoader(dataset, batch_size=len(dataset))))
    if data_join_task:
        data, clean_join_data, labels = batch[0], batch[1], batch[2]
    else:
        data, labels = batch[0], batch[1]
        clean_join_data = None

    noisy_labels = add_categorical_noise(labels, n_categories, opt.y_noise)

    if opt.noise_features_directly:
        # noise the raw tabular features column by column
        print('Noising features directly')
        noisy_data = add_data_noise_tabular(data, opt)
    elif transform == 'representation' or transform == 'logit':
        noisy_data = add_data_noise(data, opt.x_noise, opt.noise_type)
    else:
        # 'reconstruction': noise is presumably injected by GetDataReconstruction,
        # which receives x_noise as its privacy_inducing_std
        noisy_data = data
    return noisy_data, noisy_labels, clean_join_data


def create_noisy_dataset(opt, transform):
    '''
    Build noisy training data (and, for Lending Club, noisy validation data).

    Args:
        opt: options object; reads task, noise_type, x_noise, y_noise, md, data_join_task,
            encoder, decoder and, for Lending Club, noise_features_directly (plus whatever
            add_data_noise_tabular reads when that flag is set).
        transform: 'logit', 'representation' or 'reconstruction'. For 'logit' and
            'representation' noise is added after the transform; for 'reconstruction' the
            transform pipeline itself receives the noise level.

    Returns:
        MNIST: a NoisyDataset built from the MNIST train set only.
        LendingClub: (noisy_train_set, noisy_val_set, norm_mean, norm_std); the
            normalisation statistics are the per-column mean/std of the noisy training data
            when ``data_join_task`` is set, else 0 and 1.

    Raises:
        NotImplementedError: unknown ``transform`` or ``task``.
    '''
    task = opt.task
    noise_type = opt.noise_type
    x_noise = opt.x_noise
    y_noise = opt.y_noise
    mahalonobis_d = opt.md
    data_join_task = opt.data_join_task
    encoder = opt.encoder
    decoder = opt.decoder
    with torch.no_grad():
        if task == 'MNIST':
            n_categories = 10

            if transform == 'logit':
                data_transforms = [transforms.ToTensor(), logit]
            elif transform == 'representation':
                data_transforms = [transforms.ToTensor(), logit, GetDataRepresentation(encoder, clip=mahalonobis_d)]
            elif transform == 'reconstruction':
                data_transforms = [transforms.ToTensor(), logit,
                                   GetDataRepresentation(encoder, for_reconstruction=True, clip=mahalonobis_d),
                                   GetDataReconstruction(decoder, x_noise, clip=mahalonobis_d)]
            else:
                raise NotImplementedError('Transform {} not recognised.'.format(transform))

            train_dataset = datasets.MNIST('../_datasets/mnist', train=True, download=True,
                                           transform=transforms.Compose(data_transforms))

            # one pass over the loader: previously it was iterated twice (once for the data,
            # once for the labels), running the whole transform pipeline over the set twice
            train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
            train_batch = next(iter(train_loader))
            train_dataset_data, train_dataset_label = train_batch[0], train_batch[1]

            noisy_labels = add_categorical_noise(train_dataset_label, n_categories, y_noise)
            if transform == 'representation' or transform == 'logit':
                noisy_data = add_data_noise(train_dataset_data, x_noise, noise_type)
            else:
                noisy_data = train_dataset_data

            return NoisyDataset(noisy_data, noisy_labels)

        if task == 'LendingClub':
            n_categories = 2
            if transform == 'logit':
                data_transforms = []
            elif transform == 'representation':
                data_transforms = [GetDataRepresentation(encoder, clip=mahalonobis_d)]
            elif transform == 'reconstruction':
                data_transforms = [GetDataRepresentation(encoder, for_reconstruction=True, clip=mahalonobis_d),
                                   GetDataReconstruction(decoder, x_noise, clip=mahalonobis_d)]
            else:
                raise NotImplementedError('Transform {} not recognised.'.format(transform))

            train_dataset, val_dataset, _, _, _, _ = get_lending_club(
                data_join_task, transform=transforms.Compose(data_transforms))

            noisy_train_data, noisy_train_labels, train_clean_join = _noisy_lending_club_split(
                train_dataset, opt, transform, data_join_task, n_categories)
            noisy_val_data, noisy_val_labels, val_clean_join = _noisy_lending_club_split(
                val_dataset, opt, transform, data_join_task, n_categories)

            if data_join_task:
                # it seems to make very little difference between train and train+val, so we just take train
                norm_mean, norm_std = torch.mean(noisy_train_data, dim=0), torch.std(noisy_train_data, dim=0)
            else:
                norm_mean, norm_std = 0, 1

            # clean_join_data is None when not data_join_task, which NoisyDataset treats as absent
            noisy_train_set = NoisyDataset(noisy_train_data, noisy_train_labels,
                                           clean_join_data=train_clean_join)
            noisy_val_set = NoisyDataset(noisy_val_data, noisy_val_labels,
                                         clean_join_data=val_clean_join)

            return noisy_train_set, noisy_val_set, norm_mean, norm_std

        # previously an unknown task silently returned None
        raise NotImplementedError('Task name {} not recognised.'.format(task))


def percentile_from_md(latent_distn, latent_dim, md, n_prior_samples=10000000):
    """Monte-Carlo estimate of the prior probability mass inside an L1 ball of radius ``md``.

    Draws ``n_prior_samples`` latent vectors of dimension ``latent_dim`` from a zero-mean,
    unit-variance prior ('Gaussian' or 'Laplace') and returns the fraction whose L1 norm is
    strictly below ``md``.

    Raises:
        NotImplementedError: ``latent_distn`` is neither 'Laplace' nor 'Gaussian'.
    """
    shape = (latent_dim, n_prior_samples)
    if latent_distn == 'Gaussian':
        samples = np.random.normal(0, 1, shape)
    elif latent_distn == 'Laplace':
        # scale b = 1/sqrt(2) gives variance 2 * b**2 = 1, matching the Gaussian prior
        samples = np.random.laplace(0, 1 / np.sqrt(2), shape)
    else:
        raise NotImplementedError('latent_distn must be Laplace or Gaussian')

    # each column is one latent sample
    inside_ball = np.linalg.norm(samples, ord=1, axis=0) < md
    return np.mean(inside_ball)

  
def split_dataset_features(data, cts_variables_noisy, one_hot_variables_noisy, cts_variables_clean,
                           one_hot_variables_clean):
    """Split a feature DataFrame into a "noisy" and a "clean" float32 DataFrame.

    For every variable name ``var`` the columns named exactly ``var`` or starting with
    ``var__`` (one-hot columns) are selected. Columns are ordered continuous variables
    first, then one-hot variables, each in the order of the given lists.

    Returns:
        (data_noisy, data_clean)
    """
    def _gather(variables):
        pieces = []
        for name in variables:
            pattern = "^{}__|^{}$".format(name, name)
            pieces.append(data.filter(regex=pattern))
        return pd.concat(pieces, axis=1).astype(np.float32)

    noisy_part = _gather(cts_variables_noisy + one_hot_variables_noisy)
    clean_part = _gather(cts_variables_clean + one_hot_variables_clean)
    return noisy_part, clean_part


def get_lending_club(data_join_task, transform=None, data_dir='../_datasets/LendingClub'):
    """Load the preprocessed Lending Club train/val/test splits as ``LendingClub`` datasets.

    Args:
        data_join_task: if True, features are split into a "noisy" part (the dataset's main
            features) and a "clean" part (the hard-coded variable lists below) that each
            dataset returns as ``clean_join_data``.
        transform: optional transform handed to every ``LendingClub`` dataset.
        data_dir: directory holding the preprocessed CSVs and ``variable_type_lengths.txt``;
            the default is the previously hard-coded location.

    Returns:
        (train_dataset, val_dataset, test_dataset, n_continuous_features,
        ncat_of_cat_features, clean_join_data_dim). ``n_continuous_features`` and
        ``ncat_of_cat_features`` (one-hot width per categorical variable) describe the
        main feature set; ``clean_join_data_dim`` is the width of the clean features, or
        None when ``data_join_task`` is False.
    """
    def _load_split(split):
        loans = pd.read_csv('{}/accepted_2007_to_2018Q4_preprocessed_{}.csv.gz'.format(data_dir, split),
                            compression='gzip', low_memory=False)
        loans = loans.drop('Unnamed: 0', axis=1)
        # np.long (an alias of builtin int) was removed in NumPy 1.24; int64 is explicit
        labels = loans['charged_off'].astype(np.int64)
        data = loans.drop('charged_off', axis=1).astype(np.float32)
        return data, labels

    train_data, train_labels = _load_split('train')
    val_data, val_labels = _load_split('val')
    test_data, test_labels = _load_split('test')

    with open('{}/variable_type_lengths.txt'.format(data_dir), 'r') as f:
        variables_info_str = f.read()
    # NOTE(review): eval executes arbitrary code from this file and is only safe because the
    # dataset directory is trusted; ast.literal_eval would do if the file is a plain dict literal.
    variables_info = eval(variables_info_str)

    continuous_variables = variables_info['continuous_variables']
    categorical_variables = variables_info['categorical_variables']

    def _one_hot_width(var):
        # number of one-hot columns ("var__<level>") for a categorical variable
        return train_data.filter(regex="^{}__".format(var)).shape[1]

    if data_join_task:

        continuous_variables_clean = ['earliest_cr_line', 'open_acc']
        categorical_variables_clean = ['initial_list_status', 'application_type', 'addr_state', 'home_ownership',
                                       'emp_length', 'pub_rec_bankruptcies']

        continuous_variables_noisy = [var for var in continuous_variables if var not in continuous_variables_clean]
        categorical_variables_noisy = [var for var in categorical_variables if var not in categorical_variables_clean]

        n_continuous_features = len(continuous_variables_noisy)
        ncat_of_cat_features = [_one_hot_width(var) for var in categorical_variables_noisy]
        clean_join_data_dim = sum(_one_hot_width(var) for var in categorical_variables_clean) \
            + len(continuous_variables_clean)

        def _make_dataset(data, labels):
            data_noisy, data_clean = split_dataset_features(data,
                                                            continuous_variables_noisy,
                                                            categorical_variables_noisy,
                                                            continuous_variables_clean,
                                                            categorical_variables_clean)
            return LendingClub(data_noisy, labels, clean_join_data=data_clean, transform=transform)

        train_dataset = _make_dataset(train_data, train_labels)
        val_dataset = _make_dataset(val_data, val_labels)
        test_dataset = _make_dataset(test_data, test_labels)

    else:
        train_dataset = LendingClub(train_data, train_labels, transform=transform)
        val_dataset = LendingClub(val_data, val_labels, transform=transform)
        test_dataset = LendingClub(test_data, test_labels, transform=transform)

        n_continuous_features = len(continuous_variables)
        ncat_of_cat_features = [_one_hot_width(var) for var in categorical_variables]
        clean_join_data_dim = None

    return train_dataset, val_dataset, test_dataset, n_continuous_features, ncat_of_cat_features, clean_join_data_dim
