import os.path as osp
import glob
import io
from functools import partial
import numpy as np
from flax import jax_utils
import jax
import json
import tensorflow_io as tfio
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.lib.io import file_io
import skvideo.io

from . import encoded_h5py_dataset
from . import encoded_ae_h5py_dataset


# Passed as ``data_dir`` to ``tfds.builder`` when ``config.download`` is truthy.
# NOTE(review): 'TODO' looks like a placeholder -- presumably a real storage
# path has to be filled in before the download branch can work; confirm.
GCS_PATH = 'TODO'


def load_kinetics(config, train, repeat=True, xmap=True, prefetch=False):
    """Build a per-host iterator over randomly cropped Kinetics clips.

    Files are discovered with the glob ``<config.data_path>/<split>/*/*.mp4``
    where ``split`` is ``'train'`` or ``'test'``. Entries listed in the
    ``bad_paths.json`` ignore file are dropped, and the remaining files are
    divided evenly between the dataset shards so that each shard reads a
    disjoint subset.

    Args:
        config: object exposing ``data_path``, ``batch_size``, ``seq_len``,
            ``image_size``, ``cache``, ``seed`` and (when ``xmap`` is true)
            ``num_shards``.
        train: selects the ``'train'`` split and enables shuffling; otherwise
            the ``'test'`` split is read in order.
        repeat: repeat the dataset indefinitely.
        xmap: derive the shard layout from ``config.num_shards`` instead of
            using one dataset shard per process.
        prefetch: wrap the iterator with ``jax_utils.prefetch_to_device``.

    Returns:
        ``(iterator, dataset_size)``. The iterator yields dicts with a
        ``'video'`` array reshaped to ``(num_data_local, -1, ...)`` and an
        ``'actions'`` entry set to ``None``. ``dataset_size`` is the number
        of files left after filtering, counted across all shards.

    Raises:
        ValueError: if no files are left for this shard after filtering.
    """
    if xmap:
        num_data = jax.device_count() // config.num_shards
        num_data_local = max(1, jax.local_device_count() // config.num_shards)
        if num_data >= jax.process_count():
            num_ds_shards = jax.process_count()
            ds_shard_id = jax.process_index()
        else:
            # Fewer data-parallel groups than processes: several processes
            # share one dataset shard.
            num_ds_shards = num_data
            n_proc_per_shard = jax.process_count() // num_data
            ds_shard_id = jax.process_index() // n_proc_per_shard
    else:
        num_data_local = jax.local_device_count()
        num_ds_shards = jax.process_count()
        ds_shard_id = jax.process_index()

    batch_size = config.batch_size // num_ds_shards
    split = 'train' if train else 'test'
    folder = osp.join(config.data_path, split, '*', '*.mp4')
    if folder.startswith('gs://'):
        fns = tf.io.gfile.glob(folder)
    else:
        fns = glob.glob(folder)

    def get_file_end(name):
        # '<parent dir>/<file name>', the key format used by the ignore list.
        return osp.join(osp.basename(osp.dirname(name)), osp.basename(name))

    # Use a context manager so the remote file handle is always closed.
    with file_io.FileIO('gs://TODO_smae/datasets/kinetics600/bad_paths.json', 'r') as f:
        ignore_files = set(json.load(f))
    original_len = len(fns)
    # Sort so that every process sees the same ordering before sharding;
    # glob order is not guaranteed to be stable across hosts.
    fns = sorted(fn for fn in fns if get_file_end(fn) not in ignore_files)
    print(f'Filtered ignored files {len(fns)} / {original_len}')

    dataset_size = len(fns)

    # Give each dataset shard a disjoint, equally sized, strided subset of
    # the files. Equal sizes keep the number of batches identical on every
    # shard; the stride keeps the per-directory mix similar across shards.
    files_per_shard = dataset_size // num_ds_shards
    shard_fns = fns[ds_shard_id:files_per_shard * num_ds_shards:num_ds_shards]
    if not shard_fns:
        raise ValueError(
            f'No video files available for shard {ds_shard_id} of '
            f'{num_ds_shards} under {folder!r} ({dataset_size} files after filtering)'
        )

    def process(path):
        video = tfio.experimental.ffmpeg.decode_video(tf.io.read_file(path))

        # Axis 0 is treated as time; axes 1 and 2 as height and width.
        T = tf.shape(video)[0]
        # NOTE(review): a clip with fewer than seq_len frames makes the upper
        # bound non-positive, which tf.random.uniform rejects -- presumably
        # such files are listed in bad_paths.json; confirm.
        start_idx = tf.random.uniform((), 0, T - config.seq_len + 1, dtype=tf.int32)

        video = tf.identity(video[start_idx:start_idx + config.seq_len])
        video = tf.cast(video, tf.float32)
        # Maps 0..255 to -1..1 (assumes 8-bit decoded frames -- TODO confirm).
        video = 2 * (video / 255.) - 1

        H, W = tf.shape(video)[1], tf.shape(video)[2]

        # scale: resize so that the shorter side equals config.image_size
        scale = tf.cast(config.image_size / tf.minimum(H, W), tf.float32)
        target_size = tf.cond(
            H < W,
            lambda: (config.image_size, tf.cast(tf.math.ceil(tf.cast(W, tf.float32) * scale), tf.int32)),
            lambda: (tf.cast(tf.math.ceil(tf.cast(H, tf.float32) * scale), tf.int32), config.image_size)
        )
        video = tf.image.resize(video, size=target_size)

        # center crop
        h_start = (target_size[0] - config.image_size) // 2
        w_start = (target_size[1] - config.image_size) // 2
        video = video[:, h_start:h_start + config.image_size, w_start:w_start + config.image_size]

        return dict(video=video)

    dataset = tf.data.Dataset.from_tensor_slices(shard_fns)
    dataset = dataset.map(
        process,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )

    if config.cache:
        dataset = dataset.cache()

    if repeat:
        dataset = dataset.repeat()

    if train:
        dataset = dataset.shuffle(batch_size * 32, seed=config.seed)

    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    def prepare_tf_data(xs):
        def _prepare(x):
            # NOTE(review): `_numpy()` is a private TF accessor -- presumably
            # chosen over `.numpy()` to avoid a copy; confirm.
            x = x._numpy()
            # Split the batch axis into (local data-parallel groups, per-group batch).
            return x.reshape((num_data_local, -1) + x.shape[1:])
        # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable API.
        xs = jax.tree_util.tree_map(_prepare, xs)
        if 'actions' not in xs:
            xs['actions'] = None
        return xs

    iterator = map(prepare_tf_data, dataset)

    if prefetch:
        iterator = jax_utils.prefetch_to_device(iterator, 2)

    return iterator, dataset_size


def get_load_fn(data_path):
    """Return the loader function responsible for ``data_path``.

    Args:
        data_path: dataset location; the loader is chosen by substring match.

    Returns:
        A callable with the signature of ``load_kinetics``.

    Raises:
        ValueError: if no loader is registered for ``data_path``.
    """
    if 'kinetics' in data_path:
        return load_kinetics
    # ValueError is still caught by callers that catch Exception, but unlike
    # a bare Exception() it tells the user what went wrong.
    raise ValueError(
        f"No loader available for data_path {data_path!r}; "
        "expected a path containing 'kinetics'"
    )


class Data:
    """Builds training / evaluation iterators for the configured dataset.

    ``config.data_path`` selects the backend: a ``gs://`` URL or an existing
    local path is handled by the loader returned from ``get_load_fn``;
    anything else is treated as the name of a TFDS builder.

    Attributes set by ``__init__``: ``config``, ``xmap``, ``train_size`` and
    ``test_size`` (number of examples in the respective split).
    """

    def __init__(self, config, xmap=False):
        """Store the configuration and look up the split sizes.

        Args:
            config: object exposing ``data_path``, ``batch_size`` and
                ``download``, plus whatever the selected loader reads.
            xmap: forwarded to the loaders by ``create_iterator`` to choose
                the shard layout.
        """
        self.config = config
        self.xmap = xmap

        if self._is_file_dataset(config.data_path):
            load_fn = get_load_fn(config.data_path)
            # The loaders are called only for their reported size here; the
            # iterators they return are discarded.
            _, train_size = load_fn(config, train=True)
            _, test_size = load_fn(config, train=False)
            self.train_size = train_size
            self.test_size = test_size
        else:
            dataset_builder = self._make_builder()
            self.train_size = dataset_builder.info.splits['train'].num_examples
            self.test_size = dataset_builder.info.splits['test'].num_examples

        print(f'Dataset {config.data_path} of size {self.train_size} / {self.test_size}')

    @staticmethod
    def _is_file_dataset(data_path):
        """Return True when ``data_path`` is a ``gs://`` URL or an existing local path."""
        return data_path.startswith('gs://') or osp.exists(data_path)

    def _make_builder(self):
        """Create the TFDS builder for ``config.data_path`` and prepare its data."""
        dataset_builder = tfds.builder(
            self.config.data_path,
            data_dir=GCS_PATH if self.config.download else None)
        dataset_builder.download_and_prepare()
        return dataset_builder

    def _shard_layout(self):
        """Return ``(num_data_local, num_ds_shards, ds_shard_id)`` for this process."""
        if not self.xmap:
            return jax.local_device_count(), jax.process_count(), jax.process_index()

        num_shards = self.config.num_shards
        num_data = jax.device_count() // num_shards
        num_data_local = max(1, jax.local_device_count() // num_shards)
        if num_data >= jax.process_count():
            return num_data_local, jax.process_count(), jax.process_index()
        # Fewer data-parallel groups than processes: several processes share
        # one dataset shard.
        n_proc_per_shard = jax.process_count() // num_data
        return num_data_local, num_data, jax.process_index() // n_proc_per_shard

    @property
    def train_itr_per_epoch(self):
        """Number of full ``config.batch_size`` batches in the train split."""
        return self.train_size // self.config.batch_size

    @property
    def test_itr_per_epoch(self):
        """Number of full ``config.batch_size`` batches in the test split."""
        return self.test_size // self.config.batch_size

    def create_iterator(self, train, repeat=True, prefetch=True):
        """Create an iterator over batches for the requested split.

        Args:
            train: use the ``'train'`` split and shuffle; otherwise ``'test'``.
            repeat: repeat the dataset indefinitely.
            prefetch: wrap the iterator with ``jax_utils.prefetch_to_device``.

        Returns:
            An iterator of dicts holding ``'video'`` and ``'actions'``; arrays
            are reshaped to ``(num_data_local, -1, ...)``.

        Raises:
            ValueError: if ``config.data_path`` names a TFDS dataset for which
                no preprocessing is defined.
        """
        data_path = self.config.data_path
        if self._is_file_dataset(data_path):
            load_fn = get_load_fn(data_path)
            return load_fn(self.config, train=train, repeat=repeat, xmap=self.xmap, prefetch=prefetch)[0]

        # The two supported TFDS datasets differ only in the dtype the video
        # feature is cast to. Fail early for anything else instead of hitting
        # an unbound `process` after the dataset has already been prepared.
        if 'encoded_ae_h5py_dataset' in data_path:
            video_dtype = tf.float32
        elif 'encoded_h5py_dataset' in data_path:
            video_dtype = tf.int32
        else:
            raise ValueError(
                f"No preprocessing defined for TFDS dataset {data_path!r}; expected a "
                "name containing 'encoded_ae_h5py_dataset' or 'encoded_h5py_dataset'"
            )

        num_data_local, num_ds_shards, ds_shard_id = self._shard_layout()
        batch_size = self.config.batch_size // num_ds_shards
        split_name = 'train' if train else 'test'
        seq_len = self.config.seq_len

        def process(features):
            video = tf.cast(features['video'], video_dtype)
            T = tf.shape(video)[0]
            # One random window of seq_len steps, applied to video and actions alike.
            start_idx = tf.random.uniform((), 0, T - seq_len + 1, dtype=tf.int32)
            video = tf.identity(video[start_idx:start_idx + seq_len])
            actions = tf.cast(features['actions'], tf.int32)
            actions = tf.identity(actions[start_idx:start_idx + seq_len])
            return dict(video=video, actions=actions)

        dataset_builder = self._make_builder()
        num_examples = dataset_builder.info.splits[split_name].num_examples
        # Each dataset shard reads its own contiguous, equally sized slice.
        split_size = num_examples // num_ds_shards
        start = ds_shard_id * split_size
        split = '{}[{}:{}]'.format(split_name, start, start + split_size)
        dataset = dataset_builder.as_dataset(split=split)

        if self.config.cache:
            dataset = dataset.cache()

        options = tf.data.Options()
        options.threading.private_threadpool_size = 48
        options.threading.max_intra_op_parallelism = 1
        dataset = dataset.with_options(options)
        dataset = dataset.map(process)

        if repeat:
            dataset = dataset.repeat()
        if train:
            dataset = dataset.shuffle(batch_size * 32, seed=self.config.seed)

        dataset = dataset.batch(batch_size, drop_remainder=True)
        # After batching, prefetch counts whole batches, so the previous
        # prefetch(batch_size) buffered batch_size batches; let tf.data pick
        # the buffer size instead, as load_kinetics does.
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

        def prepare_tf_data(xs):
            def _prepare(x):
                # NOTE(review): `_numpy()` is a private TF accessor --
                # presumably chosen over `.numpy()` to avoid a copy; confirm.
                x = x._numpy()
                # Split the batch axis into (local groups, per-group batch).
                return x.reshape((num_data_local, -1) + x.shape[1:])
            # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable API.
            xs = jax.tree_util.tree_map(_prepare, xs)
            if 'actions' not in xs:
                xs['actions'] = None
            return xs

        iterator = map(prepare_tf_data, dataset)

        if prefetch:
            iterator = jax_utils.prefetch_to_device(iterator, 2)

        return iterator
