import jax

from jax.experimental.compilation_cache import compilation_cache as cc

# Set up JAX's persistent compilation cache before the remaining imports run.
# NOTE(review): `<CACHE DIR>` is a placeholder and is not valid Python as
# written — replace it with the cache directory path before running this script.
cc.initialize_cache(<CACHE DIR>)

import flax
import tensorflow as tf

from multinav.training.args import register_args

from multinav.training.eval_loop import do_validation_loop
from multinav.training.setup import make_model_and_dataset
from multinav.training.train_loop import do_train_loop
from multinav.utils.jax_utils import get_devices, split_and_prefetch, split_to_devices
from multinav.model.model_base import MultiNavModel
from multinav.cli_config import config_dict_to_dataclass

from ml_collections.config_dict import ConfigDict
from absl import flags, app, logging as absl_logging


# Directory passed to jax.profiler.trace for the captured trace.
_TRACE_DIR = "/tmp/jax-trace-multinav"
# Training steps run before tracing starts (kept out of the trace).
_NUM_WARMUP_STEPS = 10
# Training steps captured inside the profiler trace.
_NUM_PROFILE_STEPS = 10


def _run_train_steps(model, step, sharded_rngs, train_data, device_list, num_steps):
    """Run ``num_steps`` steps of ``do_train_loop`` with checkpointing disabled.

    Log and save intervals are set to ``None`` and no checkpoint manager is
    passed, so the loop only trains. Returns the ``(model, step, sharded_rngs)``
    tuple from ``do_train_loop`` once the dispatched device work has finished.
    """
    model, step, sharded_rngs = do_train_loop(
        step=step,
        num_steps=num_steps,
        checkpoint_manager=None,
        model=model,
        sharded_rngs=sharded_rngs,
        train_data=train_data,
        device_list=device_list,
        epoch=0,
        total_epochs=0,
        log_interval=None,
        save_interval=None,
    )
    # JAX dispatches work asynchronously. Wait for the results so warmup work
    # does not leak into the trace and profiled work does not finish after the
    # trace has closed. (Assumes the returned model/rngs are, or contain, the
    # outputs of the last step — TODO confirm against do_train_loop.)
    jax.block_until_ready((model, sharded_rngs))
    return model, step, sharded_rngs


def main(_):
    """Profile a short run of the MultiNav training loop.

    Builds the model and training data from the command-line flags, runs
    ``_NUM_WARMUP_STEPS`` steps outside the profiler (so one-time work such as
    JIT compilation is not captured), then records a ``jax.profiler`` trace of
    the next ``_NUM_PROFILE_STEPS`` steps into ``_TRACE_DIR``.
    """
    tf.get_logger().setLevel("WARNING")
    absl_logging.set_verbosity("WARNING")

    flax.linen.enable_named_call()

    args = flags.FLAGS

    # Set up devices: all local devices unless --device narrows the selection.
    device_list = jax.local_devices()
    if args.device is not None:
        device_list = get_devices(device_list, args.device)
    num_devices = len(device_list)

    # Split the seed key rather than reusing it: one subkey goes to model /
    # dataset setup, the remainder is split into per-device step keys.
    rng = jax.random.PRNGKey(args.seed)
    rng, init_rng = jax.random.split(rng)

    # The global batch is spread evenly across the selected devices.
    batch_size = args.batch_size_per_device * num_devices

    # Set up model
    model_config_dict = args.model_config.to_dict()
    model_config: MultiNavModel.Config = config_dict_to_dataclass(
        model_config_dict, MultiNavModel.Config
    )

    model, train_dataset, _ = make_model_and_dataset(
        model_config=model_config,
        data_path=args.data_path,
        batch_size=batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        epochs=args.epochs,
        eval_interval=args.eval_interval,
        device_list=device_list,
        rng=init_rng,
    )

    # One PRNG key per device, placed on its device.
    sharded_rngs = jax.random.split(rng, num_devices)
    sharded_rngs = jax.device_put_sharded(tuple(sharded_rngs), devices=device_list)
    train_data = split_and_prefetch(train_dataset, device_list)

    # Warmup (outside the trace).
    model, step, sharded_rngs = _run_train_steps(
        model=model,
        step=0,
        sharded_rngs=sharded_rngs,
        train_data=train_data,
        device_list=device_list,
        num_steps=_NUM_WARMUP_STEPS,
    )

    # Profiling
    with jax.profiler.trace(_TRACE_DIR):
        _run_train_steps(
            model=model,
            step=step,
            sharded_rngs=sharded_rngs,
            train_data=train_data,
            device_list=device_list,
            num_steps=_NUM_PROFILE_STEPS,
        )


if __name__ == "__main__":
    # Register flags before app.run: absl parses argv against the flags defined
    # at that point and then calls main, which reads them from flags.FLAGS.
    # (Presumably register_args defines --seed, --device, etc. — see its module.)
    register_args()
    app.run(main)
