import dataclasses
import functools
import numpy as np
from typing import List

from jax_privacy.experiments.image_data import base
from jax_privacy.experiments.image_data import loader
import tensorflow as tf
import tensorflow_datasets as tfds

from third_party.jax_privacy.jax_privacy.experiments.image_data.mnist_cifar_svhn \
    import _DatasetConfig

import pdb

@dataclasses.dataclass(kw_only=True, slots=True)
class _CanaryDatasetConfig(_DatasetConfig):
    """Dataset config that splits examples into non-canaries, canaries and a holdout set.

    Consumed by `_CanaryDataLoader.load_raw_data`, which asserts that
    `num_samples == int(num_canaries * p_canaries_include) + num_noncanaries`.
    """
    # seed for the NumPy RNG that assigns examples to sets (and for TF when mislabelling)
    seed: int=0
    num_samples: int=2500 # expected number of training samples (n)
    num_canaries: int=5000 # total number of canaries (m)
    num_noncanaries: int=0 # number of non-canary training examples (r)
    # fraction of the canaries that is included in the training set
    p_canaries_include: float=0.5
    # if True, every canary gets a label different from its original one
    mislabel_canaries: bool=False

    filter_include: bool=True # get training examples (includes both canaries and non-canaries)
    filter_canary: bool=False # get only the canaries
    # True: get only the holdout set; False: drop the holdout set
    filter_holdout: bool=False


class _CanaryDataLoader(loader.DataLoader):
    """Data loader that partitions a TFDS dataset into non-canaries, canaries and holdout.

    The train and test splits of `config.name` are pooled, shuffled with a
    seeded NumPy RNG, and divided into `num_noncanaries` ordinary training
    examples, `num_canaries` canaries, and a holdout set made of everything
    that is left. A fixed number of canaries,
    `int(num_canaries * p_canaries_include)`, is marked for inclusion in
    training.
    """
    config: _CanaryDatasetConfig

    def load_raw_data(
        self,
        shuffle_files: bool=True,
    ) -> tf.data.Dataset:
        """Builds the annotated (and optionally filtered) dataset.

        Every example gets boolean `include`, `holdout` and `canary` fields,
        its string `id` is replaced by an integer index (train ids first,
        test ids offset by the size of the train split), and, when
        `config.mislabel_canaries` is set, canaries receive a new `label`
        while the original is kept in `old_label`.

        Args:
            shuffle_files: Forwarded to `tfds.load`.

        Returns:
            The dataset after applying the `filter_*` options of the config,
            mapped through `base.DataInputs.from_dict`.

        Raises:
            AssertionError: If `split_content` is not `'train+test'` or
                `num_samples` disagrees with the other counts.
            ValueError: If more canaries plus non-canaries are requested
                than the dataset contains.
        """
        assert self.config.split_content in ['train+test'], (
            "only split_content='train+test' is supported"
        )

        n = self.config.num_samples
        m = self.config.num_canaries
        r = self.config.num_noncanaries
        m_train = int(m * self.config.p_canaries_include)

        assert n == m_train + r, (
            'num_samples must equal '
            'int(num_canaries * p_canaries_include) + num_noncanaries'
        )

        def get_ids(ds):
            # Sorting makes the result independent of the file/read order,
            # so the seeded shuffle below is reproducible.
            collected = []
            for batch in ds.batch(5000):
                collected.extend(raw_id.decode() for raw_id in batch['id'].numpy())
            return np.sort(collected)

        # load all splits
        ds_all = {
            split: tfds.load(
                name=self.config.name,
                split=split,
                shuffle_files=shuffle_files,
            )
            for split in ['train', 'test']
        }

        # get ids of all training and test set examples
        all_train_ids = get_ids(ds_all['train'])
        all_test_ids = get_ids(ds_all['test'])
        all_ids = np.concatenate((all_train_ids, all_test_ids))
        # Offset added to test ids so that integer ids are unique across
        # both splits (50000 for CIFAR-10).
        num_train = len(all_train_ids)

        if r + m > len(all_ids):
            raise ValueError(
                f'num_noncanaries + num_canaries = {r + m} exceeds the '
                f'{len(all_ids)} available examples'
            )

        rng = np.random.RandomState(self.config.seed)

        # randomly assign examples to sets (via shuffling + slicing)
        rng.shuffle(all_ids)
        train_ids = all_ids[0: r]
        canary_ids = all_ids[r: r + m]
        holdout_ids = all_ids[r + m:]  # use the remaining ids as the holdout set

        # sample exactly m_train canaries to include in training
        mask = np.zeros(m, dtype=bool)
        mask[:m_train] = True
        rng.shuffle(mask)
        canary_ids_include = canary_ids[mask]

        # combine into training set
        all_include_ids = np.concatenate((train_ids, canary_ids_include))

        def mark_membership(example):
            # Flag which of the sets above the example belongs to.
            example_id = example['id']
            example['include'] = tf.reduce_any(tf.equal(example_id, all_include_ids))
            example['holdout'] = tf.reduce_any(tf.equal(example_id, holdout_ids))
            example['canary'] = tf.reduce_any(tf.equal(example_id, canary_ids))
            return example

        ds_all = {split: ds.map(mark_membership) for split, ds in ds_all.items()}

        # for white box attack: we use the example id as the Dirac grad idx
        # train: 0..num_train-1, test: num_train onwards
        def get_dirac_grad_idx(example):
            split_id = example['id']

            id_string = tf.strings.regex_replace(
                split_id, pattern=r'(train_|test_)', rewrite=''
            )
            id_int = tf.strings.to_number(
                id_string, out_type=tf.int32
            )

            in_test = tf.strings.regex_full_match(split_id, ".*test_.*")
            example['id'] = tf.where(in_test, id_int + num_train, id_int)
            return example

        # combine
        ds = ds_all['train'].concatenate(ds_all['test'])
        ds = ds.map(get_dirac_grad_idx)

        # change the labels of canaries
        if self.config.mislabel_canaries:
            tf.random.set_seed(self.config.seed)

            def change_label(example):
                old_label = example['label']
                example['old_label'] = old_label

                # every class except the current one
                valid_labels = tf.range(0, self.config.num_classes, dtype=tf.int64)
                valid_labels = tf.gather(valid_labels, tf.where(valid_labels != old_label)[:, 0])

                # NOTE(review): the replacement comes from a seeded, stateful
                # shuffle, so it presumably depends on the order in which
                # examples are visited -- confirm before relying on it.
                def get_new_label():
                    return tf.random.shuffle(valid_labels, seed=self.config.seed)[0]

                example['label'] = tf.cond(
                    example['canary'],
                    get_new_label,
                    lambda: example['label'],
                )
                return example

            ds = ds.map(change_label)

        # leave out canaries where include == false (i.e., just the examples to train on)
        if self.config.filter_include:
            ds = ds.filter(lambda x: x['include'])
        # return dataset of only canaries
        if self.config.filter_canary:
            ds = ds.filter(lambda x: x['canary'])
        # return only the holdout set, or otherwise drop it entirely
        if self.config.filter_holdout:
            ds = ds.filter(lambda x: x['holdout'])
        else:
            ds = ds.filter(lambda x: tf.logical_not(x['holdout']))

        return ds.map(base.DataInputs.from_dict)

# CIFAR-10 uses the generic canary loader as-is.
Cifar10Loader = _CanaryDataLoader

# Keyword arguments common to every CIFAR-10 config factory below.
_CIFAR10_COMMON_KWARGS = dict(
    name='cifar10',
    image_size=(32, 32),
    num_classes=10,
    split_content='train+test',
)

# Config factory for the training/validation views: the holdout set is dropped.
Cifar10TrainValidConfig = functools.partial(
    _CanaryDatasetConfig,
    **_CIFAR10_COMMON_KWARGS,
    filter_holdout=False,
)
# Config factory for the test view: only the holdout set is kept.
Cifar10TestConfig = functools.partial(
    _CanaryDatasetConfig,
    **_CIFAR10_COMMON_KWARGS,
    filter_include=False,
    filter_canary=False,
    filter_holdout=True,
)

# testing 
if __name__ == '__main__':
    import itertools
    import pandas as pd

    # (m, r) = (num_canaries, num_noncanaries) settings and the seeds that
    # both checks below iterate over.
    SETTINGS = [
        (10000, 0),
        (5000, 0),
        (1000, 0),
        (10000, 1000),
        (5000, 1000),
        (1000, 1000),
    ]
    SEEDS = [0, 123]
    # The configs below rely on the default p_canaries_include of 0.5.
    P_INCLUDE = 0.5
    # CIFAR-10 train (50000) + test (10000) examples.
    NUM_TOTAL = 60000

    # (1) test different sets don't intersect and have consistent labels
    def test1():
        """Checks that holdout is disjoint from the other sets and labels agree."""
        for m, r in SETTINGS:
            n = int(m * P_INCLUDE) + r
            print(f'n:{n}\tm:{m}\tr:{r}')

            for seed in SEEDS:
                print(f'seed:{seed}')
                common = dict(
                    preprocess_name='standardise',
                    num_samples=n,
                    num_canaries=m,
                    num_noncanaries=r,
                    seed=seed,
                )

                # train set (both canaries and non-canaries)
                dataloader_train = Cifar10Loader(
                    config=Cifar10TrainValidConfig(
                        **common,
                        mislabel_canaries=True,
                        filter_canary=False,
                        filter_include=True,
                    )
                )
                # holdout set
                dataloader_holdout = Cifar10Loader(
                    config=Cifar10TestConfig(**common)
                )
                # canaries (all)
                dataloader_canaries = Cifar10Loader(
                    config=Cifar10TrainValidConfig(
                        **common,
                        mislabel_canaries=True,
                        filter_canary=True,
                        filter_include=False,
                    )
                )

                # Build every pipeline before iterating over any of them.
                dict_ds = {
                    'train': dataloader_train.load_raw_data(
                        shuffle_files=True,
                    ),
                    'holdout': dataloader_holdout.load_raw_data(
                        shuffle_files=True,
                    ),
                    'canary': dataloader_canaries.load_raw_data(
                        shuffle_files=True,
                    ),
                }

                dict_df = {}
                for split, ds in dict_ds.items():
                    x = {'id': [], 'label': [], 'old_label': []}
                    for batch in ds.batch(4096):
                        x['id'].extend(batch['metadata']['id'].numpy().tolist())
                        x['label'].extend(batch['label'].numpy().tolist())
                        if 'old_label' in batch['metadata'].keys():
                            x['old_label'].extend(batch['metadata']['old_label'].numpy().tolist())

                    # old_label only exists when canaries were mislabelled
                    if not x['old_label']:
                        del x['old_label']

                    dict_df[split] = pd.DataFrame(x)

                # holdout set shouldn't intersect with other sets
                assert pd.merge(dict_df['train'], dict_df['holdout'], on='id').shape[0] == 0
                assert pd.merge(dict_df['canary'], dict_df['holdout'], on='id').shape[0] == 0
                # check labels are the same
                df_compare = pd.merge(dict_df['train'], dict_df['canary'], on='id', suffixes=['_train', '_canary'])
                assert (df_compare['label_train'] == df_compare['label_canary']).all()
                assert (df_compare['old_label_train'] == df_compare['old_label_canary']).all()
                assert (df_compare['label_train'] != df_compare['old_label_train']).all()

    # (2) test number of examples are correct for different settings
    def test2():
        """Checks the example counts and flags produced by each filter combination."""
        for m, r in SETTINGS:
            # The loader includes exactly this many canaries in training.
            m_train = int(m * P_INCLUDE)
            n = m_train + r
            print(f'n:{n}\tm:{m}\tr:{r}')

            for seed in SEEDS:
                print(f'seed:{seed}')
                common = dict(
                    preprocess_name='standardise',
                    num_samples=n,
                    num_canaries=m,
                    num_noncanaries=r,
                    seed=seed,
                )

                # check test set
                dataloader_test = Cifar10Loader(
                    config=Cifar10TestConfig(**common)
                )
                ds = dataloader_test.load_raw_data(
                    shuffle_files=True,
                )
                num_examples = 0
                for x in ds.batch(4096):
                    num_examples += len(x['image'])
                assert num_examples == NUM_TOTAL - (m + r)
                print('\nTest set\nPass\n')

                # check train set (with canaries)
                for mislabel_canaries in [True, False]:
                    for filter_include, filter_canary in itertools.product([True, False], repeat=2):

                        dataloader = Cifar10Loader(
                            config=Cifar10TrainValidConfig(
                                **common,
                                mislabel_canaries=mislabel_canaries,
                                filter_include=filter_include,
                                filter_canary=filter_canary,
                            )
                        )
                        ds = dataloader.load_raw_data(
                            shuffle_files=True,
                        )

                        num_examples = 0
                        for x in ds.batch(4096):
                            if filter_include:
                                assert tf.reduce_all(x['metadata']['include'])
                            if filter_canary:
                                assert tf.reduce_all(x['metadata']['canary'])

                            num_examples += len(x['image'])

                            # make sure labels were changed iff it's a canary
                            mask_canary = x['metadata']['canary']
                            if mislabel_canaries:
                                assert tf.reduce_all(x['label'][mask_canary] != x['metadata']['old_label'][mask_canary])
                                assert tf.reduce_all(x['label'][~mask_canary] == x['metadata']['old_label'][~mask_canary])

                        # The included canaries are a fixed-size subset, so
                        # every count is exact (no sampling variance).
                        if filter_include and filter_canary:
                            expected = m_train
                        elif filter_include:
                            expected = m_train + r
                        elif filter_canary:
                            expected = m
                        else:
                            expected = m + r
                        assert num_examples == expected, (
                            f'expected {expected} examples, got {num_examples}'
                        )

                        print(f'mislabel={mislabel_canaries}\ninclude={filter_include}\ncanary={filter_canary}\nPass!\n')

    test1()
    test2()