import torch
from torchvision import transforms, datasets
from torch.utils.data import SubsetRandomSampler, Dataset
import numpy as np
from data_creation import create_noisy_dataset, logit, get_lending_club, create_synthetic_dataset
from plotting_functions import make_tabular_dataframe, plot_tabular_df


class DummyDataset(Dataset):
    """Minimal in-memory dataset that pairs each sample with its label.

    ``data`` and ``labels`` are stored as given and indexed with the same
    index; the reported length is ``len(data)``. No check is made that the
    two containers have matching lengths.
    """

    def __init__(self, data, labels):
        # Keep references only; nothing is copied or converted.
        self.data, self.labels = data, labels

    def __len__(self):
        """Return the number of samples, taken from ``data`` alone."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``[sample, label]`` (a list, not a tuple) for position ``idx``."""
        sample, target = self.data[idx], self.labels[idx]
        return [sample, target]


def _shuffled_loader(dataset, batch_size):
    """Build a shuffling ``DataLoader`` that drops the last incomplete batch."""
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


def _split_single_digit(dataset, digit, fraction):
    """Restrict an MNIST split to one digit, relabel it, and randomly split it in two.

    ``dataset`` is modified in place: only samples whose target equals
    ``digit`` are kept, and every kept sample is relabelled 1 when ``digit``
    is 9 and 0 otherwise. Returns the two subsets from ``random_split``; the
    first holds ``int(fraction * n_samples)`` samples, the second the rest.
    """
    keep = dataset.targets == digit
    dataset.data = dataset.data[keep]
    n_samples = int(dataset.data.shape[0])
    dataset.targets = [1 if digit == 9 else 0] * n_samples
    n_first = int(fraction * n_samples)
    return torch.utils.data.random_split(dataset, [n_first, n_samples - n_first])


def _load_full_batch(dataset):
    """Collate the whole of ``dataset`` into a single batch (sequential order)."""
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset))
    return next(iter(loader))


def _clip_continuous_features(dataset, n_continuous_features, lower, upper):
    """Return a ``DummyDataset`` copy of ``dataset`` with its leading continuous columns clipped.

    The first ``n_continuous_features`` columns are clipped to ``[lower, upper]``
    (NumPy arrays broadcast against the batch); remaining columns are kept as is.
    """
    batch = _load_full_batch(dataset)
    features, labels = batch[0], batch[1]
    clipped = torch.from_numpy(np.clip(features[:, :n_continuous_features].numpy(), lower, upper))
    combined = torch.cat([clipped, features[:, n_continuous_features:]], dim=1)
    return DummyDataset(combined, labels)


def get_clean_loaders(opt, writer, stage='pretraining'):
    """Build the clean (un-noised) data loaders for one training stage.

    Args:
        opt: options object. Reads ``fraction``, ``batch_size``, ``novel_class``
            and ``task`` ('MNIST' or 'LendingClub'); ``data_join_task`` for
            LendingClub; ``noise_features_directly`` only when ``stage`` is
            'classification'.
        writer: object with an ``add_figure`` method; used only for LendingClub
            to log a plot of the test features.
        stage: 'pretraining' or 'classification'; selects which loaders are
            returned.

    Returns:
        For 'pretraining':
            ``(train_loader, None, test_loader, (recon_images, recon_labels),
            image_dim, tabular, n_continuous_features, ncat_of_cat_features,
            n_categories, max_diff_pix)``.
        For 'classification':
            ``(train_loader, val_loader, test_loader, (recon_images, recon_labels),
            image_dim, clean_join_data_dim, tabular, n_continuous_features,
            ncat_of_cat_features, n_categories, max_diff_pix)``.
        ``recon_images`` / ``recon_labels`` come from the first 32 items of the
        full test set. ``max_diff_pix`` is ``None`` unless features are noised
        directly.

    Raises:
        NotImplementedError: if ``opt.task`` or ``stage`` is not recognised.
            The stage check happens only after all loaders have been built.

    NOTE(review): ``torch.manual_seed(old_seed)`` re-seeds the default generator
    with the value from ``torch.initial_seed()``; it does not restore the
    generator state that existed before the fixed-seed split. Confirm this is
    the intended behaviour.
    """
    frac = opt.fraction
    batch_size = opt.batch_size
    noise_features_directly = opt.noise_features_directly if stage == 'classification' else None
    novel_class = opt.novel_class

    if opt.task == 'MNIST':
        tabular = False
        n_continuous_features = None
        ncat_of_cat_features = None
        clean_join_data_dim = None
        image_dim = (1, 28, 28)
        n_categories = 10

        data_transforms = [transforms.ToTensor(), logit]

        train_dataset = datasets.MNIST('../_datasets/mnist', train=True, download=True,
                                       transform=transforms.Compose(data_transforms))
        test_dataset = datasets.MNIST('../_datasets/mnist', train=False, download=True,
                                      transform=transforms.Compose(data_transforms))

    elif opt.task == 'LendingClub':
        tabular = True
        n_categories = 2

        train_dataset, val_dataset, test_dataset, n_continuous_features, ncat_of_cat_features, clean_join_data_dim\
            = get_lending_club(opt.data_join_task)

        test_data_df = make_tabular_dataframe(test_dataset.data.to_numpy(), ncat_of_cat_features, n_continuous_features)
        image_dim = (n_continuous_features + sum(ncat_of_cat_features),)

        figure = plot_tabular_df(test_data_df, ncat_of_cat_features, n_continuous_features)
        writer.add_figure('plot_features/test_data', figure, global_step=1)

    else:
        raise NotImplementedError('Task name {} not recognised.'.format(opt.task))

    if novel_class:
        # Fixed seed so the per-digit splits are reproducible across runs.
        old_seed = torch.initial_seed()
        torch.manual_seed(42)
        print('Distributional shift experiment; splitting according to the given percentage division')
        # Per-digit fraction sent to pretraining: digit 9 is withheld entirely
        # (hard distributional shift) and becomes the positive class.
        splits = [8. / 9.] * 9 + [0.0]
        all_pretrain_datasets, all_class_datasets = [], []
        all_pretrain_test_datasets, all_class_test_datasets = [], []
        for digit, split in enumerate(splits):
            # A fresh MNIST object per digit is required because the helper
            # filters and relabels the dataset in place.
            digit_train = datasets.MNIST('../_datasets/mnist', train=True, download=True,
                                         transform=transforms.Compose([transforms.ToTensor(), logit]))
            pretrain_train_dataset, class_dataset = _split_single_digit(digit_train, digit, split)
            all_pretrain_datasets.append(pretrain_train_dataset)
            all_class_datasets.append(class_dataset)

            # careful here: based on the application you may want to change the test set balance
            digit_test = datasets.MNIST('../_datasets/mnist', train=False, download=True,
                                        transform=transforms.Compose([transforms.ToTensor(), logit]))
            pretrain_test_dataset, class_test_dataset = _split_single_digit(digit_test, digit, split)
            all_pretrain_test_datasets.append(pretrain_test_dataset)
            all_class_test_datasets.append(class_test_dataset)

        # Concatenate the per-digit pieces back into single datasets.
        pretrain_concat_dataset = torch.utils.data.ConcatDataset(all_pretrain_datasets)
        class_concat_dataset = torch.utils.data.ConcatDataset(all_class_datasets)
        pretrain_test_concat_dataset = torch.utils.data.ConcatDataset(all_pretrain_test_datasets)
        class_test_concat_dataset = torch.utils.data.ConcatDataset(all_class_test_datasets)

        # Hold out 10% of the classification data for validation.
        splitB = int(0.9 * len(class_concat_dataset))
        class_train_dataset, class_val_dataset = torch.utils.data.random_split(
            class_concat_dataset, [splitB, len(class_concat_dataset) - splitB])

        pretrain_train_loader = _shuffled_loader(pretrain_concat_dataset, batch_size)
        class_train_loader = _shuffled_loader(class_train_dataset, batch_size)
        class_val_loader = _shuffled_loader(class_val_dataset, batch_size)
        pretrain_test_loader = _shuffled_loader(pretrain_test_concat_dataset, batch_size)
        class_test_loader = _shuffled_loader(class_test_concat_dataset, batch_size)
        torch.manual_seed(old_seed)  # restore seed to initial value
        # The novel-class task is binary (digit 9 vs. the rest).
        n_categories = 2
    else:
        if frac != 1.:
            print('fractioning into pretrain and classification data')
            old_seed = torch.initial_seed()
            torch.manual_seed(42)
            splitA = int(frac * len(train_dataset))
            pretrain_train_dataset, class_dataset = torch.utils.data.random_split(
                train_dataset, [splitA, len(train_dataset) - splitA])

            if opt.task == 'LendingClub':
                # LendingClub ships its own validation set; there is no
                # pretraining validation set (the test set is used instead).
                class_train_dataset = class_dataset
                class_val_dataset = val_dataset
            else:
                splitB = int(0.9 * len(class_dataset))
                class_train_dataset, class_val_dataset = torch.utils.data.random_split(
                    class_dataset, [splitB, len(class_dataset) - splitB])

            splitC = int(frac * len(test_dataset))
            pretrain_test_dataset, class_test_dataset = torch.utils.data.random_split(
                test_dataset, [splitC, len(test_dataset) - splitC])

            pretrain_train_loader = _shuffled_loader(pretrain_train_dataset, batch_size)
            class_train_loader = _shuffled_loader(class_train_dataset, batch_size)
            class_val_loader = _shuffled_loader(class_val_dataset, batch_size)
            pretrain_test_loader = _shuffled_loader(pretrain_test_dataset, batch_size)
            class_test_loader = _shuffled_loader(class_test_dataset, batch_size)
            torch.manual_seed(old_seed)  # restore seed to initial value
        else:
            print("Warning: using same data for pretrain and classification")
            if opt.task == 'LendingClub':
                split_train_dataset = train_dataset
            else:
                train_len = int(len(train_dataset) * (5 / 6))
                val_len = len(train_dataset) - train_len

                split_train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_len, val_len])

            # Bind the names the feature-clipping code below relies on; without
            # them that code raised NameError when fraction == 1.
            pretrain_train_dataset = class_train_dataset = split_train_dataset
            class_val_dataset = val_dataset
            pretrain_test_dataset = class_test_dataset = test_dataset

            train_loader = _shuffled_loader(split_train_dataset, batch_size)
            val_loader = _shuffled_loader(val_dataset, batch_size)
            test_loader = _shuffled_loader(test_dataset, batch_size)

            pretrain_train_loader = train_loader
            class_train_loader = train_loader
            pretrain_test_loader = test_loader
            class_test_loader = test_loader
            class_val_loader = val_loader

    # Fixed set of 32 test items used for reconstructions (taken from the full
    # test set, not from any split).
    if tabular:
        recon_dataset_images = torch.stack([torch.from_numpy(test_dataset[i][0]).float() for i in range(32)])
    else:
        recon_dataset_images = torch.stack([test_dataset[i][0] for i in range(32)])

    recon_dataset_labels = [test_dataset[i][-1] for i in range(32)]
    recon_dataset_labels = torch.FloatTensor(recon_dataset_labels).unsqueeze(1)

    if noise_features_directly:
        if tabular:
            if opt.task == 'LendingClub':
                if opt.data_join_task:
                    cts_feature_values = np.concatenate((train_dataset.data.iloc[:, :n_continuous_features].values,
                                                         val_dataset.data.iloc[:, :n_continuous_features].values,
                                                         test_dataset.data.iloc[:, :n_continuous_features].values),
                                                        axis=0)
                else:
                    # The feature range is taken from the pretraining data only;
                    # the classification data is then clipped to that range.
                    # The seed is captured here so this path works whichever
                    # split branch ran; where an earlier branch already
                    # captured it, the value is the same.
                    old_seed = torch.initial_seed()
                    torch.manual_seed(42)
                    pretrain_train_data = _load_full_batch(pretrain_train_dataset)[0]
                    pretrain_test_data = _load_full_batch(pretrain_test_dataset)[0]

                    cts_feature_values = np.concatenate(
                        (pretrain_train_data[:, :n_continuous_features].numpy(),
                         pretrain_test_data[:, :n_continuous_features].numpy()),
                        axis=0)
                    # Row vectors so they broadcast over the batch dimension.
                    upper = np.max(cts_feature_values, axis=0)[np.newaxis, :]
                    lower = np.min(cts_feature_values, axis=0)[np.newaxis, :]

                    # Replace the classification datasets/loaders with clipped copies.
                    class_train_dataset = _clip_continuous_features(
                        class_train_dataset, n_continuous_features, lower, upper)
                    class_train_loader = _shuffled_loader(class_train_dataset, batch_size)

                    class_val_dataset = _clip_continuous_features(
                        class_val_dataset, n_continuous_features, lower, upper)
                    class_val_loader = _shuffled_loader(class_val_dataset, batch_size)

                    class_test_dataset = _clip_continuous_features(
                        class_test_dataset, n_continuous_features, lower, upper)
                    class_test_loader = _shuffled_loader(class_test_dataset, batch_size)
                    torch.manual_seed(old_seed)  # restore seed to initial value
            else:
                cts_feature_values = np.concatenate((train_dataset.data.iloc[:, :n_continuous_features].values,
                                                     test_dataset.data.iloc[:, :n_continuous_features].values),
                                                    axis=0)
            # Per-feature spread (max - min) of the continuous features.
            max_diff_pix = torch.from_numpy(np.max(cts_feature_values, axis=0) - np.min(cts_feature_values, axis=0))
        else:
            # Largest possible variation of a pixel, measured after the logit transform.
            max_min_pair = logit(torch.Tensor([torch.max(train_dataset.data.clone().detach()).item(),
                                               torch.min(train_dataset.data.clone().detach()).item()]))
            max_diff_pix = max_min_pair[0] - max_min_pair[1]

    else:
        max_diff_pix = None

    if stage == 'pretraining':
        return pretrain_train_loader, None, pretrain_test_loader, (recon_dataset_images, recon_dataset_labels), \
           image_dim, tabular, n_continuous_features, ncat_of_cat_features, n_categories, max_diff_pix
    elif stage == 'classification':
        return class_train_loader, class_val_loader, class_test_loader, (recon_dataset_images, recon_dataset_labels), \
           image_dim, clean_join_data_dim, tabular, n_continuous_features, ncat_of_cat_features, n_categories, max_diff_pix
    else:
        raise NotImplementedError("Stage is either classification or pretraining")

            
def get_noisy_loaders(opt):
    """Build the classification train/validation loaders over noisy (or synthetic) data.

    Args:
        opt: options object. Reads ``batch_size``, ``fraction``, ``novel_class``
            and ``synthetic_generation``; ``decoder`` when generating synthetic
            data; otherwise ``noise_features_directly``, ``pixel_level`` (only
            when the former is falsy) and ``task``.

    Returns:
        ``(class_train_loader, class_val_loader, norm_mean, norm_std)``.
        ``norm_mean`` / ``norm_std`` come from ``create_noisy_dataset`` for
        LendingClub and are ``0`` / ``1`` otherwise (including the synthetic
        path, which previously left them unbound and crashed on return).

    NOTE(review): with ``synthetic_generation`` and ``task == 'LendingClub'``
    no ``val_dataset`` is created here, yet the non-novel-class branches use
    it — confirm whether that combination is meant to be supported.

    NOTE(review): ``torch.manual_seed(old_seed)`` re-seeds the default generator
    rather than restoring its previous state — confirm this is intended.
    """
    batch_size = opt.batch_size
    frac = opt.fraction
    novel_class = opt.novel_class

    # Identity normalisation unless the LendingClub pipeline supplies statistics.
    norm_mean, norm_std = 0, 1

    if opt.synthetic_generation:
        train_dataset = create_synthetic_dataset(opt.decoder)
    else:
        if opt.noise_features_directly:
            transform = 'logit'
        elif opt.pixel_level:
            transform = 'reconstruction'
        else:
            transform = 'representation'

        if opt.task == 'LendingClub':
            train_dataset, val_dataset, norm_mean, norm_std = create_noisy_dataset(opt=opt, transform=transform)
        else:
            train_dataset = create_noisy_dataset(opt=opt, transform=transform)

    if novel_class:
        old_seed = torch.initial_seed()
        torch.manual_seed(42)
        print('Distributional shift experiment; splitting according to the given percentage division')
        # Per-digit fraction sent to pretraining: digit 9 is withheld entirely
        # (hard distributional shift) and becomes the positive class.
        splits = [8. / 9.] * 9 + [0.0]

        # The noisy samples are selected using the labels of the clean MNIST
        # training set. Only the targets are read, so load it once rather than
        # once per digit.
        clean_targets = datasets.MNIST('../_datasets/mnist', train=True, download=True,
                                       transform=transforms.Compose([transforms.ToTensor(), logit])).targets

        all_class_datasets = []
        for digit, split in enumerate(splits):
            keep = clean_targets == digit
            binary_label = 1 if digit == 9 else 0
            digit_dataset = DummyDataset(train_dataset.data[keep],
                                         [binary_label] * len(train_dataset.labels[keep]))
            n_samples = int(digit_dataset.data.shape[0])
            splitA = int(split * n_samples)
            # The split is still drawn so the random stream (and hence the
            # classification subset) matches the clean loaders; the
            # pretraining half is not needed here.
            _, class_dataset = torch.utils.data.random_split(digit_dataset, [splitA, n_samples - splitA])
            all_class_datasets.append(class_dataset)

        class_concat_dataset = torch.utils.data.ConcatDataset(all_class_datasets)

        # Hold out 10% of the classification data for validation.
        splitB = int(0.9 * len(class_concat_dataset))
        class_train_dataset, class_val_dataset = torch.utils.data.random_split(
            class_concat_dataset, [splitB, len(class_concat_dataset) - splitB])

        torch.manual_seed(old_seed)  # restore seed to initial value
    elif frac != 1.:
        print('fractioning into pretrain and classification data')
        old_seed = torch.initial_seed()
        torch.manual_seed(42)
        splitA = int(frac * len(train_dataset))
        _, class_dataset = torch.utils.data.random_split(train_dataset, [splitA, len(train_dataset) - splitA])

        if opt.task == 'LendingClub':
            # LendingClub provides its own validation set.
            class_train_dataset = class_dataset
            class_val_dataset = val_dataset
        else:
            splitB = int(0.9 * len(class_dataset))
            class_train_dataset, class_val_dataset = torch.utils.data.random_split(
                class_dataset, [splitB, len(class_dataset) - splitB])
        torch.manual_seed(old_seed)
    else:
        print("Warning: using same data for pretrain and classification")
        if opt.task == 'LendingClub':
            class_train_dataset = train_dataset
        else:
            train_len = int(len(train_dataset) * (5 / 6))
            val_len = len(train_dataset) - train_len

            class_train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_len, val_len])
        class_val_dataset = val_dataset

    # Building a DataLoader draws no random numbers, so creating the loaders
    # after the seed has been restored does not change the result.
    class_train_loader = torch.utils.data.DataLoader(class_train_dataset, batch_size=batch_size, shuffle=True,
                                                     drop_last=True)
    class_val_loader = torch.utils.data.DataLoader(class_val_dataset, batch_size=batch_size, shuffle=True,
                                                   drop_last=True)

    return class_train_loader, class_val_loader, norm_mean, norm_std



