import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import sys
# sys.stderr = open(os.devnull, 'w')

import dataclasses
import functools
import numpy as np
import pandas as pd
from typing import List

from jax_privacy.experiments.image_data import base
from jax_privacy.experiments.image_data import loader
import tensorflow as tf
import tensorflow_datasets as tfds

from third_party.jax_privacy.jax_privacy.experiments.image_data.mnist_cifar_svhn \
    import _DatasetConfig

import pdb

# Location of the CINIC-10 metadata CSV read by `_CinicDataLoader.load_raw_data`.
# The default is a relative path, resolved against the current working
# directory; set CINIC10_METADATA_PATH to load the metadata from elsewhere.
METADATA_PATH = os.environ.get(
    'CINIC10_METADATA_PATH', './data/cinic-10/metadata.csv')

def get_cinic(df_metadata):
    """Build a dataset of decoded CINIC-10 images from a metadata table.

    Args:
        df_metadata: pandas DataFrame with (at least) the columns ``path``,
            ``label``, ``id`` and ``imagenet``; each ``path`` is expected to
            name a PNG file readable by ``tf.io.read_file``.

    Returns:
        A ``tf.data.Dataset`` with one dict per row, keyed ``image`` (float32
        RGB image resized to 32x32, values in [0, 1]), ``label``, ``id`` and
        ``from_imagenet`` (the row's ``imagenet`` value).
    """

    def load_image(
        path,
        label,
        image_id,
        imagenet,
        image_size=(32, 32),
    ):
        # Read and decode as 3-channel uint8 RGB.
        image = tf.io.read_file(path)
        image = tf.image.decode_png(image, channels=3)
        # Rescale to float32 in [0, 1] BEFORE resizing: `tf.image.resize`
        # returns float32 without rescaling, and `convert_image_dtype` does
        # nothing to an input that is already float32. Converting after the
        # resize therefore left the pixels in [0, 255].
        image = tf.image.convert_image_dtype(image, tf.float32)
        # Resize to the CIFAR-10 resolution (default bilinear method).
        image = tf.image.resize(image, image_size)
        return {
            'image': image,
            'label': label,
            'id': image_id,
            'from_imagenet': imagenet,
        }

    dataset = tf.data.Dataset.from_tensor_slices((
        df_metadata['path'].values,
        df_metadata['label'].values,
        df_metadata['id'].values,
        df_metadata['imagenet'].values,
    ))
    # Decode in parallel; `map` keeps the input order unless determinism is
    # explicitly disabled via tf.data options.
    return dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

class _CinicDataLoader(loader.DataLoader):
    """Loader for the CINIC-10 images whose metadata ``imagenet`` flag is True.

    Examples are listed in the metadata CSV at ``METADATA_PATH``. Rows whose
    ``imagenet`` column is not True are dropped (presumably the CIFAR-10
    images, per the ``Cinic10WithoutCifar*`` names).
    """

    config: _DatasetConfig

    def load_raw_data(
        self,
        shuffle_files: bool = True,
    ) -> tf.data.Dataset:
        """Load the ImageNet-sourced CINIC-10 examples, unpreprocessed.

        Args:
            shuffle_files: If True, fully shuffle the examples (the shuffle
                buffer spans every kept example).

        Returns:
            A dataset of ``base.DataInputs`` built from dicts with keys
            ``image``, ``label`` and ``id``.
        """
        df_metadata = pd.read_csv(METADATA_PATH)
        df_metadata['id'] = df_metadata['id'].astype(np.int32)
        df_metadata['label'] = df_metadata['label'].astype(np.int32)

        # Select the ImageNet-sourced rows on the metadata, before any image
        # is touched. Excluded files are then never read or decoded, and the
        # shuffle buffer only holds kept examples. With a full-size shuffle,
        # filtering before or after shuffling gives the same (uniform)
        # ordering distribution. `.eq(True)` mirrors the former tf.data
        # `== True` filter.
        df_metadata = df_metadata[df_metadata['imagenet'].eq(True)]

        ds = get_cinic(df_metadata)
        if shuffle_files:
            # NOTE: the buffer holds decoded images, so memory grows with the
            # number of kept examples. `Dataset.shuffle` rejects a zero-size
            # buffer, so an empty selection still yields an empty dataset via
            # `max(..., 1)`.
            ds = ds.shuffle(buffer_size=max(len(df_metadata), 1))

        def clean_up(example):
            # The provenance flag is constant after filtering; drop it so only
            # image/label/id reach `base.DataInputs.from_dict`.
            return {
                key: value
                for key, value in example.items()
                if key != 'from_imagenet'
            }

        ds = ds.map(clean_up)
        return ds.map(base.DataInputs.from_dict)

# Exported name for the ImageNet-only CINIC-10 loader.
Cinic10WithoutCifarLoader = _CinicDataLoader

# Config factory for `Cinic10WithoutCifarLoader`. It pre-binds the dataset
# fields below; the remaining `_DatasetConfig` arguments (e.g.
# `preprocess_name`, as used under `__main__`) are supplied at call time.
# NOTE: `name` and `split_content` are not read by `load_raw_data` in this
# file, which always reads METADATA_PATH.
Cinic10WithoutCifarTrainConfig = functools.partial(
    _DatasetConfig,
    num_samples=210000,  # Kept (ImageNet) examples; checked under __main__.
    num_classes=10,
    image_size=(32, 32),
    name='cinic10',
    split_content='train+valid+test',
)

if __name__ == '__main__':
    # Smoke check: read every example once and verify the loader yields the
    # number of examples the config declares.
    config = Cinic10WithoutCifarTrainConfig(
        preprocess_name='standardise',
    )
    dataloader = Cinic10WithoutCifarLoader(config=config)

    # Batch before prefetching so whole batches are prepared ahead of use.
    ds = dataloader.load_raw_data(
        shuffle_files=True,
    ).batch(1000).prefetch(tf.data.AUTOTUNE)

    ids = []
    for batch in ds:
        ids.extend(batch['metadata']['id'].numpy().tolist())

    # Explicit raise rather than `assert`, which `python -O` strips.
    if len(ids) != config.num_samples:
        raise AssertionError(
            f'Expected {config.num_samples} examples, got {len(ids)}.'
        )
