from dataclasses import dataclass, field
from functools import partial
from typing import Dict, Optional, List, Tuple

import dlimp
import tensorflow as tf
import tensorflow_datasets as tfds
from dlimp.dataset import DLataset

from multinav.data.transforms import (
    add_concatenated_actions,
    batch_decode_images,
    correct_action_keys,
    normalize_images,
    normalize_actions,
    normalized_pose,
    subsample_multiple_trajectories_dynamic,
    split_goal_obs,
    remove_strings,
    sample_negatives,
)


# Per-dataset waypoint spacing, keyed by sub-dataset name. The value for a
# sub-dataset is attached to its trajectories as "waypoint_spacing" when the
# dataset is loaded.
# NOTE(review): units are not stated in this file -- presumably a distance
# between consecutive waypoints; confirm against the dataset builders.
WAYPOINT_SPACING = dict(
    recon=0.25,
    cory_hall=0.06,
    seattle=0.35,
    tartan_drive=0.72,
    sacson=0.255,
    go_stanford=0.12,
    scand=0.38,
    dep_40k=0.20,
)


def apply_dataset_transforms(
    ds: dlimp.DLataset, action_window_size: int, history_size: int
) -> dlimp.DLataset:
    """Run the shared preprocessing pipeline over a dataset of trajectories.

    Stages, in order: drop trajectories that are too short, fix up action
    keys, attach concatenated action windows, subsample, decode and resize
    images to 128x128, apply the normalization transforms, strip strings, and
    finally split the goal observation off from the trajectory.

    Args:
        ds: Input dataset. Each element must carry a ``"_len"`` entry; the
            size of its leading dimension is compared against
            ``history_size + action_window_size``.
        action_window_size: Forwarded to ``add_concatenated_actions`` as
            ``window_size``.
        history_size: Forwarded to ``subsample_multiple_trajectories_dynamic``
            as ``size``.

    Returns:
        The transformed dataset.
    """
    # Trajectories must be strictly longer than history + action window.
    min_length = history_size + action_window_size
    ds = ds.filter(
        lambda traj: tf.shape(traj["_len"])[0] > min_length,
        name="filter_short_trajectories",
    )
    assert isinstance(ds, DLataset)

    ds = ds.map(
        correct_action_keys,
        name="correct_action_keys",
        num_parallel_calls=None,
    )

    # The action windows are attached *before* subsampling so that they can
    # include actions lying outside the subsampled window.
    attach_action_window = partial(
        add_concatenated_actions, window_size=action_window_size
    )
    ds = ds.map(
        attach_action_window,
        name="add_action_window",
        num_parallel_calls=None,
    )
    assert isinstance(ds, DLataset)

    # Subsampling happens *before* image decoding, for memory reasons.
    subsample = partial(subsample_multiple_trajectories_dynamic, size=history_size)
    ds = ds.interleave(
        subsample,
        name="subsample_trajectories",
        num_parallel_calls=None,
    )
    assert isinstance(ds, DLataset)

    ds = ds.map(batch_decode_images, name="decode_images", num_parallel_calls=None)

    resize_to_128 = partial(
        dlimp.transforms.resize_images, match=["image"], size=(128, 128)
    )
    ds = ds.map(resize_to_128, num_parallel_calls=None)

    # Remaining per-element stages, applied strictly in this order. The last
    # one removes the goal observation from the trajectory and keeps it
    # separate; this takes up additional memory (might want to move it to
    # after the shuffle buffer?).
    trailing_stages = (
        (normalize_images, "normalize_images"),
        (remove_strings, "remove_strings"),
        (normalize_actions, "normalize_actions"),
        (normalized_pose, "normalized_pose"),
        (split_goal_obs, "split_goal_obs"),
    )
    for transform, stage_name in trailing_stages:
        ds = ds.map(transform, name=stage_name, num_parallel_calls=None)

    return ds


def load_single_dataset(
    dataset_name: str,
    sub_dataset_name: str,
    data_dir: Optional[str],
    train: bool,
    end_terminal: bool,
) -> dlimp.DLataset:
    """Load one sub-dataset and tag it with dataset-specific fields.

    Args:
        dataset_name: Base TFDS dataset name.
        sub_dataset_name: Config name under ``dataset_name``; also the key
            used to look up ``WAYPOINT_SPACING``.
        data_dir: Passed to ``tfds.builder`` as ``data_dir``.
        train: Selects the training split when true, the validation split
            otherwise. When true, ``repeat()`` is also called on the dataset.
        end_terminal: Value broadcast into the ``"end_is_terminal"`` field.

    Returns:
        The loaded dataset with ``"waypoint_spacing"`` and
        ``"end_is_terminal"`` added to every element.

    Raises:
        KeyError: If ``sub_dataset_name`` has no entry in ``WAYPOINT_SPACING``.
    """
    builder = tfds.builder(f"{dataset_name}/{sub_dataset_name}", data_dir=data_dir)

    # Prefer a dedicated "val" split; when the builder has none, hold out the
    # last 5% of "train" for validation.
    has_val_split = "val" in builder.info.splits
    if train:
        split = "train" if has_val_split else "train[:95%]"
    else:
        split = "val" if has_val_split else "train[95%:]"

    # No reason _not_ to shuffle now, but we'll do most of it later
    ds = dlimp.DLataset.from_rlds(builder, split=split)

    def attach_dataset_fields(traj):
        # Both fields are broadcast to the shape of "_len" so they line up
        # with the other entries of the element.
        spacing = tf.constant(WAYPOINT_SPACING[sub_dataset_name], dtype=tf.float32)
        traj["waypoint_spacing"] = tf.broadcast_to(spacing, tf.shape(traj["_len"]))

        terminal_flag = tf.constant(end_terminal)
        traj["end_is_terminal"] = tf.broadcast_to(
            terminal_flag, tf.shape(traj["_len"])
        )
        return traj

    ds = ds.map(attach_dataset_fields, num_parallel_calls=None)
    assert isinstance(ds, dlimp.DLataset)

    if train:
        ds = ds.repeat()
        assert isinstance(ds, dlimp.DLataset)

    return ds


@dataclass
class DatasetConfig:
    """Configuration describing which sub-datasets to load and how to mix them."""

    # Passed to ``tfds.builder`` as ``data_dir``; ``None`` leaves the TFDS default.
    dataset_folder: Optional[str] = None
    # Base TFDS dataset name; sub-datasets are loaded as "<base>/<sub_dataset>".
    base_dataset_name: str = "gnm_dataset"
    # Forwarded to ``sample_negatives`` as ``negative_frac``; negatives are
    # only sampled when this is greater than zero.
    negative_fraction: float = 0.0
    # Maps sub-dataset name -> sampling weight used when mixing the training
    # datasets. Declared as a real dataclass field with a default factory so
    # that every config gets its own dict (a bare class-level dict would be
    # shared by, and mutable through, all instances) and so that it can be
    # overridden through the constructor.
    dataset_weights: Dict[str, float] = field(
        default_factory=lambda: {
            "recon": 0.3,
            "cory_hall": 0.15,
            "go_stanford": 0.15,
            "scand": 0.1,
            "sacson": 0.2,
            "tartan_drive": 0.05,
            "seattle": 0.05,
        }
    )


def load_datasets(
    config: DatasetConfig,
    batch_size: int,
    action_window_size: int,
    history_size: int,
    train: bool,
    shuffle_buffer_size: int = 10000,
) -> dlimp.DLataset:
    """Build one shuffled, batched dataset mixing every configured sub-dataset.

    Each sub-dataset named in ``config.dataset_weights`` is loaded, run through
    ``apply_dataset_transforms``, and then sampled from according to its
    weight. The mixture is shuffled, batched (dropping any incomplete final
    batch) and, if requested by the config, passed through ``sample_negatives``.

    Args:
        config: Dataset locations, sampling weights and negative fraction.
        batch_size: Number of elements per batch.
        action_window_size: Forwarded to ``apply_dataset_transforms``.
        history_size: Forwarded to ``apply_dataset_transforms``.
        train: Forwarded to ``load_single_dataset`` to pick the split.
        shuffle_buffer_size: Buffer size of the shuffle applied before batching.

    Returns:
        The mixed, shuffled and batched dataset.
    """
    # Take names and weights from a single snapshot of the mapping so the two
    # sequences are guaranteed to stay aligned.
    weighted_names = list(config.dataset_weights.items())

    datasets = [
        apply_dataset_transforms(
            load_single_dataset(
                config.base_dataset_name,
                dataset_name,
                config.dataset_folder,
                train,
                # Only "recon" is flagged as ending in a terminal state.
                end_terminal=(dataset_name == "recon"),
            ),
            action_window_size,
            history_size,
        )
        for dataset_name, _ in weighted_names
    ]
    # ``sample_from_datasets`` is documented to take a list (or tensor) of
    # weights; a ``dict_values`` view is not reliably convertible to a tensor.
    weights = [weight for _, weight in weighted_names]

    ds = dlimp.DLataset.sample_from_datasets(
        datasets, weights, rerandomize_each_iteration=True
    )

    ds = ds.shuffle(shuffle_buffer_size).batch(
        batch_size,
        drop_remainder=True,
    )

    if config.negative_fraction > 0:
        # Negatives are sampled per batch, hence after batching.
        ds = ds.map(
            partial(sample_negatives, negative_frac=config.negative_fraction),
            num_parallel_calls=None,
        )

    return ds


def make_dataset(
    config: DatasetConfig,
    batch_size: int,
    num_steps_predict: int,
    history_size: int,
) -> Tuple[dlimp.DLataset, Dict[str, dlimp.DLataset]]:
    """Build the training dataset plus one evaluation dataset per sub-dataset.

    Args:
        config: Dataset locations and sampling weights.
        batch_size: Batch size for both training and evaluation datasets.
        num_steps_predict: Used as the action window size.
        history_size: Forwarded to the transform pipeline.

    Returns:
        A pair ``(train_dataset, eval_datasets)``. ``train_dataset`` is the
        mixture produced by ``load_datasets`` with ``train=True``;
        ``eval_datasets`` maps each sub-dataset name in
        ``config.dataset_weights`` to its transformed, batched (remainder
        kept) non-training split.
    """
    # The training mixture is built first, then the per-dataset eval sets.
    train_dataset = load_datasets(
        config=config,
        batch_size=batch_size,
        action_window_size=num_steps_predict,
        history_size=history_size,
        train=True,
    )

    eval_datasets = {}
    for sub_name in config.dataset_weights.keys():
        raw_eval = load_single_dataset(
            dataset_name=config.base_dataset_name,
            sub_dataset_name=sub_name,
            data_dir=config.dataset_folder,
            train=False,
            end_terminal=(sub_name == "recon"),
        )
        transformed_eval = apply_dataset_transforms(
            raw_eval,
            action_window_size=num_steps_predict,
            history_size=history_size,
        )
        eval_datasets[sub_name] = transformed_eval.batch(
            batch_size, drop_remainder=False
        )

    return train_dataset, eval_datasets


if __name__ == "__main__":
    # Smoke test / rough throughput check: build the datasets, print the
    # shapes of one training batch, then time 100 batches after a warm-up.
    import time
    import numpy as np

    # ``make_dataset`` expects a DatasetConfig, not a bare path string; the
    # bucket path is the folder the datasets are read from.
    train_dataset, eval_datasets = make_dataset(
        DatasetConfig(dataset_folder="gs://gnm-data-c2"),
        batch_size=256,
        num_steps_predict=5,
        history_size=8,
    )

    iterator = train_dataset.iterator()
    print(tf.nest.map_structure(np.shape, next(iterator)))

    # Warm-up iterations, excluded from the timing below.
    for _ in range(10):
        _ = next(iterator)

    # perf_counter is monotonic and intended for measuring elapsed intervals.
    t0 = time.perf_counter()
    for _ in range(100):
        _ = next(iterator)
    print(f"Time: {time.perf_counter() - t0}")
