import os
import dataclasses
import functools
import numpy as np
from typing import List

from jax_privacy.experiments.image_data import base
from jax_privacy.experiments.image_data import loader
import tensorflow as tf
import tensorflow_datasets as tfds

from third_party.jax_privacy.jax_privacy.experiments.image_data.mnist_cifar_svhn \
    import _DatasetConfig

import pdb

def get_ids(ds):
    """Return the 'id' field of every example in `ds`, sorted.

    The dataset is consumed in batches of 5000 and each batch's 'id' entry is
    converted with `.numpy()`. The individual ids are gathered into a plain
    Python list before being handed to `np.sort`, so the dtype of the result
    is whatever NumPy infers from that list.
    """
    collected = [
        example_id
        for chunk in ds.batch(5000)
        for example_id in chunk['id'].numpy()
    ]
    return np.sort(collected)

def load_crafted(crafted_dir='data/crafted/sgd'):
    """Load crafted canary images and labels from `crafted_dir` as a dataset.

    Reads `images.npy` and `labels.npy` from `crafted_dir`. The stored images
    are multiplied by the CIFAR-10 per-channel std and shifted by the mean,
    clipped to [0, 1], scaled to [0, 255] and rounded to uint8. The constants
    broadcast against the last axis, so images are presumably channel-last
    RGB -- TODO confirm against the crafting code. Labels are reduced with
    an argmax over axis 1 (presumably one-hot or per-class scores).

    Args:
        crafted_dir: directory containing `images.npy` and `labels.npy`.

    Returns:
        A `tf.data.Dataset` of dicts with keys 'id' ('crafted_<index>'),
        'image' and 'label'.
    """
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2023, 0.1994, 0.2010]

    # Undo the normalisation, then quantise to 8-bit pixel values.
    stored_images = np.load(os.path.join(crafted_dir, 'images.npy'))
    unit_range = np.clip(stored_images * channel_std + channel_mean, 0, 1)
    images = np.round(unit_range * 255).astype(np.uint8)

    stored_labels = np.load(os.path.join(crafted_dir, 'labels.npy'))
    labels = np.argmax(stored_labels, axis=1)

    # Ids are positional so they can be mapped back to rows of the .npy files.
    ids = [f'crafted_{index}' for index in range(len(images))]

    return tf.data.Dataset.from_tensor_slices(
        {'id': ids, 'image': images, 'label': labels}
    )

@dataclasses.dataclass(kw_only=True, slots=True)
class _CanaryDatasetConfig(_DatasetConfig):
    """Dataset config extended with canary sampling and filtering options.

    `_CanaryDataLoader.load_raw_data` asserts that
    `num_samples == int(num_canaries * p_canaries_include) + num_noncanaries`.
    """
    # seeds the NumPy RNG that splits the ids and the TF RNG that relabels canaries
    seed: int=0
    num_samples: int=2500 # expected number of training samples (n)
    num_canaries: int=5000 # total number of canaries (m)
    num_noncanaries: int=0 # number of non-canary training examples (r)
    # fraction of the canaries included in training: int(m * p) of them
    p_canaries_include: float=0.5
    # if True, every canary is given a random label different from its original one
    mislabel_canaries: bool=False

    filter_include: bool=True # get training examples (includes both canaries and non-canaries)
    filter_canary: bool=False # get only the canaries
    filter_holdout: bool=False # get only the holdout set (when False, holdout examples are dropped)

    # 'sgd' loads crafted canaries from data/crafted/sgd; 'none' draws canaries from the train split
    craft_type: str='none'


class _CanaryDataLoader(loader.DataLoader):
    """Loader that partitions a TFDS image dataset into training, canary and holdout examples.

    Every example is tagged with boolean `include`, `holdout` and `canary`
    fields; the `filter_*` flags of the config then select which examples
    are returned.
    """
    config: _CanaryDatasetConfig

    def load_raw_data(
        self,
        shuffle_files: bool=True,
    ) -> tf.data.Dataset:
        """Build the raw dataset with canary bookkeeping applied.

        Steps, in order:
          1. Load the 'train' and 'test' splits of `config.name` through TFDS,
             plus the crafted canaries when `config.craft_type == 'sgd'`.
          2. Shuffle the train ids with a NumPy RNG seeded by `config.seed`
             and slice them into non-canary training ids, canary ids and
             holdout ids (canaries come from the crafted set when present).
          3. Add boolean `include`, `holdout` and `canary` fields to every
             example of every loaded split.
          4. Select the test split, or the train split followed by the
             crafted canaries, depending on `config.split_content`.
          5. Replace the string id by an integer index.
          6. Optionally give canaries a different label.
          7. Apply the `filter_include` / `filter_canary` / `filter_holdout`
             flags.

        Args:
            shuffle_files: forwarded to `tfds.load`.

        Returns:
            The resulting dataset mapped through `base.DataInputs.from_dict`.

        Raises:
            AssertionError: if `num_samples` differs from
                `int(num_canaries * p_canaries_include) + num_noncanaries`,
                or if `craft_type` is neither 'sgd' nor 'none'.
        """
        n = self.config.num_samples
        m = self.config.num_canaries
        r = self.config.num_noncanaries
        # number of canaries that actually end up in the training set
        m_train = int(m * self.config.p_canaries_include)

        assert n == m_train + r
        
        # load all CIFAR-10 splits
        ds_all = {
            split: tfds.load(
                name=self.config.name,
                split=split,
                shuffle_files=shuffle_files,
            )
            for split in ['train', 'test']
        }

        # load crafted canaries
        # (None means "no crafted canaries"; such entries are skipped below)
        ds_all['crafted'] = None
        if self.config.craft_type in ['sgd']:
            crafted_dir = f'data/crafted/{self.config.craft_type}' 
            ds_all['crafted'] = load_crafted(crafted_dir=crafted_dir)
        else:
            assert self.config.craft_type == 'none'

        # get ids
        # cifar10
        # get_ids returns the ids sorted, so the seeded shuffle below always
        # starts from the same order.
        all_train_ids = get_ids(ds_all['train'])
        # assumes the ids come back as bytes -- TODO confirm
        all_train_ids = [x.decode() for x in all_train_ids]

        rng = np.random.RandomState(self.config.seed)

        # randomly sample example into sets (via sorting + indexing)
        rng.shuffle(all_train_ids)
        train_ids = all_train_ids[:r]
        holdout_ids = all_train_ids[r:] # use remaining as holdout ids

        # sample canaries to include in training
        if ds_all['crafted'] is not None:
            # canaries come from the crafted set; the holdout ids computed
            # above (every train id after the first r) are kept as is.
            # NOTE(review): unlike the train ids, these ids are not decoded.
            all_crafted_ids = get_ids(ds_all['crafted'])
            rng.shuffle(all_crafted_ids)
            canary_ids = all_crafted_ids[:m]
        else:
            # canaries are the m train ids following the r non-canaries; the
            # holdout ids assigned above are overwritten with what is left.
            canary_ids = all_train_ids[r: r + m]
            holdout_ids = all_train_ids[r + m:]
        # only the first m_train canaries are trained on
        canary_ids_include = canary_ids[:m_train]

        # combine into training set
        all_include_ids = np.concatenate((train_ids, canary_ids_include))

        # tag every example of every loaded split with its set memberships
        for split, ds in ds_all.items():
            if ds is None:
                continue

            # each flag is True when the example id equals any id of the set
            def mark_include_in_training(example):
                example['include'] = tf.reduce_any(tf.equal(example['id'], all_include_ids))
                return example
            ds = ds.map(mark_include_in_training)

            def mark_holdout(example):
                example['holdout'] = tf.reduce_any(tf.equal(example['id'], holdout_ids))
                return example
            ds = ds.map(mark_holdout)

            def mark_canary(example):
                example['canary'] = tf.reduce_any(tf.equal(example['id'], canary_ids))
                return example
            ds = ds.map(mark_canary)

            ds_all[split] = ds
        
        # combine
        # any split_content other than 'test' yields the train split (plus
        # the crafted canaries, when there are any)
        if self.config.split_content == 'test':
            ds = ds_all['test']
        else:
            ds = ds_all['train']
            if ds_all['crafted'] is not None:
                # remove crafted images that aren't used as canaries
                ds_all['crafted'] = ds_all['crafted'].filter(
                    lambda x: x['canary'] == True,
                )
                ds = ds.concatenate(ds_all['crafted'])
        
        # for white box attack: we use the example id as the Dirac grad idx
        # train: 0-49999, test: 50000-59999, crafted: 60000 and above
        def get_dirac_grad_idx(example):
            split_id = example['id']

            # strip the split prefix so that only the numeric part is left
            id_string = tf.strings.regex_replace(
                split_id, pattern=r'(train_|test_|crafted_)', rewrite=''
            )
            id_int = tf.strings.to_number(
                id_string, out_type=tf.int32
            )

            # NOTE(review): `is_test` / `is_crafted` are tensors; these Python
            # `if`s presumably rely on AutoGraph conversion inside
            # `Dataset.map` -- confirm.
            is_test = tf.strings.regex_full_match(split_id, ".*test_.*")
            if is_test:
                id_int += 50000
            
            is_crafted = tf.strings.regex_full_match(split_id, ".*crafted_.*")
            if is_crafted:
                id_int += 60000

            # the string id is replaced by the integer index
            example['id'] = id_int 
            return example
        ds = ds.map(get_dirac_grad_idx)

        # change the labels of canaries
        if self.config.mislabel_canaries:
            tf.random.set_seed(self.config.seed)

            def change_label(example):
                # the original label is recorded for every example, canary or not
                old_label = example['label']
                example['old_label'] = old_label
                
                # all class indices except the original label
                valid_labels = tf.range(0, self.config.num_classes, dtype=tf.int64)
                valid_labels = tf.gather(valid_labels, tf.where(valid_labels != old_label)[:, 0])

                def get_new_label():
                    # first element of a shuffle of the remaining labels
                    new_label = tf.random.shuffle(valid_labels, seed=self.config.seed)[0]
                    return new_label
                
                # only canaries are relabelled; other examples keep their label
                example['label'] = tf.cond(
                    example['canary'],
                    lambda: get_new_label(),
                    lambda: example['label'],
                )
                return example
            
            ds = ds.map(change_label)

        # leave out canaries where include == false (i.e., just the examples to train on)
        if self.config.filter_include:
            ds = ds.filter(
                lambda x: x['include'] == True,
            )
        # return dataset of only canaries
        if self.config.filter_canary:
            ds = ds.filter(
                lambda x: x['canary'] == True,
            )
        # return dataset of only holdout set
        if self.config.filter_holdout:
            ds = ds.filter(
                lambda x: x['holdout'] == True,
            )
        else:
            # holdout examples are always dropped unless explicitly requested
            ds = ds.filter(
                lambda x: x['holdout'] == False,
            )

        return ds.map(base.DataInputs.from_dict)

# Public name of the canary-aware loader.
Cifar10CraftedLoader = _CanaryDataLoader

# Keyword arguments common to both CIFAR-10 config factories below.
_CIFAR10_CRAFTED_KWARGS = dict(
    name='cifar10',
    image_size=(32, 32),
    num_classes=10,
    filter_holdout=False,
)

# Config factory for the train split (plus crafted canaries, if any); the
# remaining options (canary counts, seed, filters, ...) are given by the caller.
Cifar10CraftedTrainValidConfig = functools.partial(
    _CanaryDatasetConfig,
    split_content='train+test',
    **_CIFAR10_CRAFTED_KWARGS,
)

# Config factory for the test split with the include/canary filters disabled.
Cifar10CraftedTestConfig = functools.partial(
    _CanaryDatasetConfig,
    split_content='test',
    filter_include=False,
    filter_canary=False,
    **_CIFAR10_CRAFTED_KWARGS,
)

# testing
if __name__ == '__main__':
    import itertools
    import pandas as pd

    # (m, r) = (num_canaries, num_noncanaries) settings exercised by both tests.
    TEST_SETTINGS = [
        (1000, 49000),
        (1000, 0),
        (500, 0),
        (100, 0),
        (1000, 1000),
        (500, 1000),
        (100, 1000),
    ]
    TEST_SEEDS = [0, 123]
    # Fraction of canaries placed in the training set. The configs built below
    # do not override `p_canaries_include`, so this must equal its default.
    P_INCLUDE = 0.5

    # (1) test different sets don't intersect and have consistent labels
    def test1():
        """Check the train / holdout / canary views of one split against each other.

        The holdout view must share no id with the other two, and every
        canary that is also in the train view must carry the same new label
        and the same original label in both views, with the two differing.
        """
        # Filter flags that select each view of the same underlying data.
        view_filters = {
            # train set (both canaries and non-canaries)
            'train': dict(filter_canary=False, filter_include=True),
            # holdout set
            'holdout': dict(
                filter_canary=False, filter_include=False, filter_holdout=True,
            ),
            # canaries (all)
            'canary': dict(filter_canary=True, filter_include=False),
        }

        for m, r in TEST_SETTINGS:
            n = int(m * P_INCLUDE) + r
            print(f'n:{n}\tm:{m}\tr:{r}')

            for seed in TEST_SEEDS:
                print(f'seed:{seed}')

                loaders = {
                    view: Cifar10CraftedLoader(
                        config=Cifar10CraftedTrainValidConfig(
                            preprocess_name='standardise',
                            num_samples=n,
                            num_canaries=m,
                            num_noncanaries=r,
                            seed=seed,
                            mislabel_canaries=True,
                            **flags,
                        )
                    )
                    for view, flags in view_filters.items()
                }
                dict_ds = {
                    view: view_loader.load_raw_data(shuffle_files=True)
                    for view, view_loader in loaders.items()
                }

                # Collect ids and labels of each view into a DataFrame.
                dict_df = {}
                for view, ds in dict_ds.items():
                    columns = {'id': [], 'label': [], 'old_label': []}
                    for batch in ds.batch(4096):
                        metadata = batch['metadata']
                        columns['id'].extend(metadata['id'].numpy().tolist())
                        columns['label'].extend(batch['label'].numpy().tolist())
                        if 'old_label' in metadata:
                            columns['old_label'].extend(
                                metadata['old_label'].numpy().tolist()
                            )

                    # Drop the column when no original labels were recorded,
                    # so that all columns have the same length.
                    if not columns['old_label']:
                        del columns['old_label']

                    dict_df[view] = pd.DataFrame(columns)

                # holdout set shouldn't intersect with other sets
                assert pd.merge(dict_df['train'], dict_df['holdout'], on='id').empty
                assert pd.merge(dict_df['canary'], dict_df['holdout'], on='id').empty
                # check labels are the same
                df_compare = pd.merge(
                    dict_df['train'], dict_df['canary'],
                    on='id', suffixes=['_train', '_canary'],
                )
                assert (df_compare['label_train'] == df_compare['label_canary']).all()
                assert (df_compare['old_label_train'] == df_compare['old_label_canary']).all()
                assert (df_compare['label_train'] != df_compare['old_label_train']).all()

    # (2) test number of examples is correct for different settings
    def test2():
        """Check the number of returned examples for every filter combination.

        Also checks that the test split has 10000 examples and, when canaries
        are mislabelled, that labels changed if and only if the example is a
        canary.
        """
        for m, r in TEST_SETTINGS:
            m_train = int(m * P_INCLUDE)
            n = m_train + r
            print(f'n:{n}\tm:{m}\tr:{r}')

            for seed in TEST_SEEDS:
                print(f'seed:{seed}')

                # check test set
                dataloader_test = Cifar10CraftedLoader(
                    config=Cifar10CraftedTestConfig(
                        preprocess_name='standardise',
                        num_samples=n,
                        num_canaries=m,
                        num_noncanaries=r,
                        seed=seed,
                    )
                )
                ds = dataloader_test.load_raw_data(shuffle_files=True)
                num_examples = sum(len(x['image']) for x in ds.batch(4096))
                assert num_examples == 10000
                print('\nTest set\nPass\n')

                # check train set (with canaries)
                for mislabel_canaries in [True, False]:
                    for filter_include, filter_canary in itertools.product(
                        [True, False], repeat=2
                    ):
                        dataloader = Cifar10CraftedLoader(
                            config=Cifar10CraftedTrainValidConfig(
                                preprocess_name='standardise',
                                num_samples=n,
                                num_canaries=m,
                                num_noncanaries=r,
                                mislabel_canaries=mislabel_canaries,
                                seed=seed,
                                filter_include=filter_include,
                                filter_canary=filter_canary,
                            )
                        )
                        ds = dataloader.load_raw_data(shuffle_files=True)

                        num_examples = 0
                        for x in ds.batch(4096):
                            metadata = x['metadata']
                            mask_canary = metadata['canary']

                            if filter_include:
                                assert tf.reduce_all(metadata['include'])
                            if filter_canary:
                                assert tf.reduce_all(mask_canary)

                            num_examples += len(x['image'])

                            # make sure labels were changed iff it's a canary
                            if mislabel_canaries:
                                old_label = metadata['old_label']
                                assert tf.reduce_all(
                                    x['label'][mask_canary] != old_label[mask_canary]
                                )
                                assert tf.reduce_all(
                                    x['label'][~mask_canary] == old_label[~mask_canary]
                                )

                        # The loader keeps exactly the first int(m * p)
                        # shuffled canaries for training, so these counts are
                        # exact rather than approximate.
                        if filter_include:
                            expected = m_train if filter_canary else m_train + r
                        else:
                            expected = m if filter_canary else m + r
                        assert num_examples == expected, (
                            f'expected {expected} examples, got {num_examples}'
                        )

                        print(f'mislabel={mislabel_canaries}\ninclude={filter_include}\ncanary={filter_canary}\nPass!\n')

    test1()
    test2()