import itertools
import time

import tqdm
from multinav.cli_config import TrainingConfig

# import jax

# from jax.experimental.compilation_cache import compilation_cache as cc

from multinav.data.dataset import (
    apply_dataset_transforms,
    load_single_dataset,
    make_dataset,
)

import flax
import tensorflow as tf

import tyro

# from multinav.utils.jax_utils import get_devices, split_and_prefetch, split_to_devices
from absl import flags, app, logging as absl_logging

from multinav.model.model_base import MultiNavModel


def main():
    """Benchmark the training data pipeline and print action statistics.

    Parses a ``TrainingConfig`` from the command line with ``tyro``, builds
    the training dataset with ``make_dataset``, discards a few warmup
    batches, then fetches ``num_batches`` batches while a TensorFlow profiler
    trace is recorded to ``/tmp/profiler``. For each timed batch it prints
    the shape of ``batch["action_chunked"]`` and the mean/min/max of the
    first four entries along its last axis. Finally it prints the average
    wall-clock time spent inside ``next(ds_iter)`` per batch.
    """
    flax.linen.enable_named_call()

    args = tyro.cli(TrainingConfig)
    model_config: MultiNavModel.Config = args.config

    # NOTE(review): the device count is hard-coded while the JAX imports are
    # commented out; presumably this should become jax.device_count() once
    # JAX is re-enabled.
    num_devices = 4
    batch_size = args.batch_size_per_device * num_devices

    train_dataset, _ = make_dataset(
        args.data_path,
        batch_size=batch_size,
        num_steps_predict=model_config.actor_head_config.required_batch_predict_horizon,
        history_size=model_config.backbone_config.max_seq_len
        + model_config.actor_head_config.required_extra_batch_seq_len,
    )

    ds_iter = train_dataset.as_numpy_iterator()

    # Discard a few batches so one-off pipeline start-up cost is not counted
    # in the timed section below.
    for _ in tqdm.trange(10, desc="Warmup"):
        next(ds_iter)

    num_batches = 10
    # Only time spent fetching batches is accumulated; printing and the
    # per-component reductions below would otherwise inflate the number.
    fetch_time = 0.0

    # Enable TF profiler trace
    tf.profiler.experimental.start("/tmp/profiler")
    try:
        for _ in tqdm.trange(num_batches):
            start = time.perf_counter()
            batch = next(ds_iter)
            fetch_time += time.perf_counter() - start

            actions = batch["action_chunked"]
            print(actions.shape)
            for i in range(4):
                component = actions[..., i]  # slice once, reuse for all stats
                print(
                    f"a{i}",
                    component.mean(),
                    component.min(),
                    component.max(),
                )
    finally:
        # Always close the trace, even if the iterator runs out or a batch
        # is missing the expected key; otherwise the profile is never stopped.
        tf.profiler.experimental.stop()

    # Each next() yields a whole batch, so report per-batch (not per-sample).
    print(f"Time: {fetch_time / num_batches:.2f}s/batch")


if __name__ == "__main__":
    # Run the data-pipeline benchmark only when executed as a script,
    # not when this module is imported.
    main()
