from typing import Optional
import jax
import flax
from flax.struct import dataclass
import dataclasses
import tensorflow as tf
import tyro
import wandb
import os
from orbax.checkpoint import (
    CheckpointManager,
    CheckpointManagerOptions,
    StandardCheckpointer,
)
from multinav import cli_config

from multinav.training.eval_loop import do_validation_loop
from multinav.training.setup import make_model_and_dataset
from multinav.training.train_loop import do_train_loop
from multinav.utils.jax_utils import get_devices, split_and_prefetch

from absl import logging as absl_logging

from multinav.utils.sys_utils import input_with_timeout

import jax.experimental.compilation_cache.compilation_cache as cc
# Enable JAX's persistent compilation cache. The original argument was an
# unfilled placeholder (a syntax error), so the directory is now resolved at
# import time: JAX_COMPILATION_CACHE_DIR when set, otherwise a per-user
# default under the home directory.
cc.initialize_cache(
    os.environ.get(
        "JAX_COMPILATION_CACHE_DIR",
        os.path.join(os.path.expanduser("~"), ".cache", "jax_compilation_cache"),
    )
)

@dataclass
class TrainingConfig(cli_config.TrainingConfig):
    """Command-line configuration for this training script.

    Extends ``cli_config.TrainingConfig`` with a single extra field and is
    the type parsed by ``tyro.cli`` in ``main``.
    """
    # Free-form notes attached to the Weights & Biases run. When left as
    # None, ``main`` asks for them on the terminal via ``input_with_timeout``
    # (called with a timeout argument of 60).
    notes: Optional[str] = None


def main():
    """Parse the CLI configuration, build model and data, and run training.

    Steps, in order:

    1. Quiet TensorFlow and absl logging down to WARNING.
    2. Parse a ``TrainingConfig`` from the command line with ``tyro``.
    3. Obtain run notes from ``args.notes`` or, if unset, from an interactive
       prompt with a timeout.
    4. Pick the local JAX devices (optionally filtered by ``args.devices``).
    5. Build the model and datasets via ``make_model_and_dataset`` using a
       global batch size of ``batch_size_per_device * num_devices``.
    6. Open a Weights & Biases run in project ``multinav_spatiotemporal``.
    7. If ``args.model_path`` is set, create an Orbax ``CheckpointManager``
       rooted at ``<model_path>/<wandb run name>``.
    8. For ``args.num_epochs`` iterations: run ``do_validation_loop`` on every
       validation dataset with an unreplicated copy of the model, then call
       ``do_train_loop`` with ``num_steps=args.eval_interval``.
    9. Close the Weights & Biases run.
    """
    tf.get_logger().setLevel("WARNING")
    absl_logging.set_verbosity("WARNING")

    args = tyro.cli(TrainingConfig)

    if args.notes is None:
        # Get notes from user, with timeout after 1 minute
        notes = input_with_timeout("Notes: ", 60)
    else:
        notes = args.notes

    # Set up devices
    device_list = jax.local_devices()
    if args.devices is not None:
        device_list = get_devices(device_list, args.devices)
    num_devices = len(device_list)

    # A JAX PRNG key must be consumed only once. Derive two independent keys
    # from the seed: one for model/dataset construction and one for the
    # per-device training streams, instead of reusing a single key for both.
    init_rng, train_rng = jax.random.split(jax.random.PRNGKey(args.seed))

    # Set up model
    model_config = args.config
    batch_size = args.batch_size_per_device * num_devices
    model, train_dataset, val_datasets = make_model_and_dataset(
        model_config=model_config,
        data_config=args.data_config,
        batch_size=batch_size,
        epochs=args.num_epochs,
        eval_interval=args.eval_interval,
        device_list=device_list,
        rng=init_rng,
    )

    wandb.init(
        project="multinav_spatiotemporal",
        config=dataclasses.asdict(args),
        notes=notes,
    )

    # Checkpointing is optional; the directory is namespaced by the wandb run
    # name, which is why the manager is created only after wandb.init().
    checkpoint_manager = None
    if args.model_path is not None:
        checkpoint_dir = os.path.join(args.model_path, wandb.run.name)
        checkpoint_manager = CheckpointManager(
            directory=checkpoint_dir,
            checkpointers=StandardCheckpointer(),
            options=CheckpointManagerOptions(
                save_interval_steps=args.save_interval,
                max_to_keep=args.models_to_keep,
            ),
        )

    # One PRNG key per device, each placed on its own device.
    sharded_rngs = jax.random.split(train_rng, num_devices)
    sharded_rngs = jax.device_put_sharded(tuple(sharded_rngs), devices=device_list)
    train_data = split_and_prefetch(train_dataset, device_list)

    step = 0
    for epoch in range(args.num_epochs):
        # Validation runs on a single device with an unreplicated model.
        # NOTE(review): validation happens *before* each training block, so
        # the model returned by the final do_train_loop call is not validated
        # here — confirm this is intended.
        eval_model = flax.jax_utils.unreplicate(model)
        for val_dataset_name, val_dataset in val_datasets.items():
            do_validation_loop(
                val_dataset_name=val_dataset_name,
                val_data_iter=val_dataset,
                model=eval_model,
                model_config=model_config,
                device=device_list[0],
                step=step,
            )

        # The training loop returns the updated model, the advanced global
        # step and the advanced per-device RNGs, which carry over to the
        # next iteration.
        model, step, sharded_rngs = do_train_loop(
            step=step,
            num_steps=args.eval_interval,
            checkpoint_manager=checkpoint_manager,
            model=model,
            sharded_rngs=sharded_rngs,
            train_data=train_data,
            device_list=device_list,
            epoch=epoch,
            total_epochs=args.num_epochs,
            log_interval=args.log_interval,
            save_interval=args.save_interval,
        )

    wandb.finish()


if __name__ == "__main__":
    main()
