"""Main training script."""

from vmoe import app
from .trainer import (
    ml_collections,
    Mesh,
    metric_writers,
    input_pipeline,
    get_train_steps_and_epochs,
    logging,
    pjit_utils,
    create_checkpoint_manager,
    create_flax_model,
    restore_or_create_train_state,
    ThreadPool,
    get_dataset_iterator,
    get_loss_fn,
    create_tree_summarizer,
    train_step,
    wrap_train_step_with_adversarial_attack,
    wrap_train_step_with_mixup,
    create_profile_hook,
    create_progress_hook,
    create_evaluation_hook,
    create_fewshot_hook,
    make_train_cost_fn,
    utils,
    multihost_utils,
    time,
    os,
    jax,
    tf,
    functools,
    pjit,
    make_create_train_state_fn
)
import datetime
import wandb
from functools import wraps
from vmoe.patcher.backup import backup


# Wrap a MetricWriter method so that the scalars it writes are also logged to
# Weights & Biases.
def wandb_wrap(writer, log_fn=None):
    """Wraps a `write_scalars`-style callable so its scalars are also logged.

    Args:
      writer: A `write_scalars(step, scalars)` callable, typically the bound
        method of a metric writer. `step` and `scalars` may be passed either
        positionally or by keyword.
      log_fn: Optional sink, called as `log_fn(scalars, step=step)` before
        `writer`. Defaults to `wandb.log`.

    Returns:
      A callable with the same signature and metadata as `writer` that logs the
      scalars to the sink and then returns whatever `writer` returns.
    """

    @wraps(writer)
    def wrapped(*args, **kwargs):
        # Accept both positional and keyword forms of (step, scalars).
        step = args[0] if args else kwargs["step"]
        scalars = args[1] if len(args) > 1 else kwargs["scalars"]
        if log_fn is None:
            # Look up `wandb.log` on every call rather than capturing it when
            # wrapping, exactly as the original implementation did.
            wandb.log(scalars, step=step)
        else:
            log_fn(scalars, step=step)
        return writer(*args, **kwargs)

    return wrapped


def _enable_wandb_logging(config, writer):
    """Starts a W&B run and mirrors `writer.write_scalars` calls into it.

    Only the first JAX process does this, so a multi-host job produces a single
    W&B run instead of one run per host.

    Returns:
      True if W&B logging was enabled in this process, False otherwise.
    """
    if not config.get("use_wandb", False) or jax.process_index() != 0:
        return False
    wandb.init()
    # Rebinds the method on this writer instance only; the wrapper still
    # forwards every call to the original `write_scalars`.
    writer.write_scalars = wandb_wrap(writer.write_scalars)
    return True


def train_and_evaluate(
    config: ml_collections.ConfigDict,
    workdir: str,
    mesh: Mesh,
    writer: metric_writers.MetricWriter,
):
    """Trains a model and evaluates it periodically.

    Args:
      config: Experiment configuration. Optional boolean `use_wandb` enables
        Weights & Biases logging of everything written via `writer`.
      workdir: Directory for checkpoints, profiles and the source-code backup.
      mesh: Device mesh; passed to the training dataset iterator and, when
        mixup is configured, used to build the batch partition spec.
      writer: Metric writer shared by the training, evaluation and few-shot
        hooks.
    """
    logging.info("Config:\n%s", config)
    # NOTE(review): `datetime.now()` is naive, so `%z` expands to an empty
    # string; the format is kept so backup directory names do not change.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S%z")

    use_wandb = _enable_wandb_logging(config, writer)

    # Back up the source code for replication. Each process computes its own
    # timestamp, so only the first one writes the backup to avoid one copy
    # per host.
    if jax.process_index() == 0:
        backup(os.path.join(workdir, timestamp))

    datasets = input_pipeline.get_datasets(config.dataset)
    if "train" not in datasets:
        raise KeyError(
            f'You must have a "train" variant of the dataset. '
            f"Available variants are {sorted(datasets.keys())!r}"
        )
    train_examples = input_pipeline.get_data_num_examples(config.dataset.train)
    train_batch_size = config.dataset.train.batch_size
    train_steps, train_epochs = get_train_steps_and_epochs(
        train_steps=config.get("train_steps"),
        train_epochs=config.get("train_epochs"),
        train_batch_size=train_batch_size,
        train_examples=train_examples,
    )
    logging.info(
        "Training for %d steps (%g epochs) over %d examples, with a "
        "batch size of %d",
        train_steps,
        train_epochs,
        train_examples,
        train_batch_size,
    )

    # Get the global shape of the image array.
    dataset_element_shape_dtype = pjit_utils.get_dataset_shape_dtype_struct(
        datasets["train"]
    )

    ckpt_manager = create_checkpoint_manager(
        workdir=workdir, **config.get("save_checkpoint", {})
    )
    train_state_initialize_fn = make_create_train_state_fn(
        model=create_flax_model(config=config.model, deterministic=False),
        optimizer_config=config.optimizer,
        input_shape_dtypes=(dataset_element_shape_dtype["image"],),
        train_steps=train_steps,
        extra_rng_keys=tuple(config.get("extra_rng_keys", [])),
        seed=config.get("seed", 0),
    )
    train_state, last_seen_index = restore_or_create_train_state(
        ckpt_manager=ckpt_manager,
        initialize_fn=train_state_initialize_fn,
        axis_resources_regexes=config.params_axis_resources,
        thread_pool=ThreadPool(),
        initialization_kwargs=config.get("initialization"),
    )
    init_step = int(train_state.step)
    logging.info("Initial step = %d", init_step)
    tr_iter = get_dataset_iterator(
        dataset=datasets["train"],
        prefetch_size=config.dataset.train.get("prefetch_device", 1),
        mesh=mesh,
        last_seen_index=last_seen_index,
    )
    train_loss_fn, eval_loss_fn, label_pred_fn = get_loss_fn(**config.loss)
    summarizer = create_tree_summarizer(config.get("summarize_arrays"))
    train_step_fn = functools.partial(
        train_step,
        loss_fn=train_loss_fn,
        microsteps=config.get("microsteps"),
        summarizer=summarizer,
    )
    if config.get("adversarial", {}):
        adversarial_config = config.adversarial.to_dict()
        train_step_fn = wrap_train_step_with_adversarial_attack(
            train_step_fn, train_loss_fn, **adversarial_config
        )
    # If mixup options are defined, wrap the train_step_fn with mixup.
    if config.get("mixup", {}):
        mixup_config = config.mixup.to_dict()
        train_step_fn = wrap_train_step_with_mixup(
            train_step_fn,
            partition_spec=jax.sharding.PartitionSpec(
                mesh.axis_names,
            ),
            **mixup_config,
        )

    train_step_pjit = pjit.pjit(
        fun=train_step_fn,
        out_shardings=(
            jax.tree_util.tree_map(lambda x: x.sharding, train_state),
            None,
        ),
        # The train state, images and labels buffers are reused by the step.
        donate_argnums=(0, 1, 2),
    )

    # Setup hooks.
    profile_hook = create_profile_hook(workdir=workdir, **config.get("profile", {}))
    progress_hook = create_progress_hook(
        writer=writer,
        first_step=init_step + 1,
        train_steps=train_steps,
        **config.get("report_progress", {}),
    )
    evaluation_hook, config_model_eval = create_evaluation_hook(
        base_model_config=config.model.copy_and_resolve_references(),
        writer=writer,
        progress_hook=progress_hook,
        datasets={name: ds for name, ds in datasets.items() if name != "train"},
        loss_fn=eval_loss_fn,
        label_pred_fn=label_pred_fn,
        first_step=init_step + 1,
        train_steps=train_steps,
        extra_rng_keys=config.get("extra_rng_keys", []),
        **config.get("evaluate", {}),
    )
    fewshot_hook, _ = create_fewshot_hook(
        base_model_config=config_model_eval,
        writer=writer,
        progress_hook=progress_hook,
        first_step=init_step + 1,
        train_steps=train_steps,
        extra_rng_keys=config.get("extra_rng_keys", []),
        **config.get("fewshot", {}),
    )

    def _save_checkpoint(step, ts, force=False):
        """Saves `ts` plus the dataset position reached after `step` steps."""
        # Every step consumes exactly one global batch.
        seen_index = step * train_batch_size
        with progress_hook.timed("ckpt", wait_jax_async_dispatch=False):
            ckpt_manager.save(
                step,
                items={
                    "state": ts,
                    "dataset_iterator": {"last_seen_index": seen_index},
                },
                force=force,
            )

    # Run checkpoint hook just before starting the loop. This will save the
    # train state at initialization.
    if init_step == 0 and not tf.io.gfile.exists(os.path.join(workdir, "ckpt/0")):
        multihost_utils.sync_devices("training:ckpt-first")
        _save_checkpoint(init_step, train_state, force=True)
    # Explicitly compile train_step here.
    t0 = time.time()
    train_step_pjit = train_step_pjit.lower(
        train_state,
        dataset_element_shape_dtype["image"],
        dataset_element_shape_dtype["labels"],
    ).compile()
    t1 = time.time()
    # Report compilation time, and flops and optimal seconds per step and device.
    writer.write_scalars(init_step + 1, {"train/compile_secs": t1 - t0})
    train_step_flops_per_device, train_step_seconds_per_device = (
        utils.get_flops_and_seconds_per_device(train_step_pjit)
    )
    if train_step_flops_per_device:
        writer.write_scalars(
            init_step + 1, {"train/step_flops_per_device": train_step_flops_per_device}
        )
    if train_step_seconds_per_device:
        writer.write_scalars(
            init_step + 1,
            {"train/step_seconds_per_device": train_step_seconds_per_device},
        )
    train_cost_fn = make_train_cost_fn(train_step_pjit)
    for step, batch in zip(range(init_step + 1, train_steps + 1), tr_iter):
        profile_hook(step)
        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            train_state, metrics = train_step_pjit(
                train_state, batch["image"], batch["labels"]
            )
        progress_hook(
            step,
            scalar_metrics=(
                train_cost_fn(step) | {f"train/{k}": v for k, v in metrics.items()}
            ),
        )
        _save_checkpoint(step, train_state)
        evaluation_hook(step, params=train_state.params, **train_cost_fn(step))
        fewshot_hook(
            step, variables={"params": train_state.params}, **train_cost_fn(step)
        )
    ckpt_manager.wait_until_finished()
    # Make sure the final state is checkpointed, even if the periodic policy
    # of the checkpoint manager skipped the last step.
    if not tf.io.gfile.exists(os.path.join(workdir, f"ckpt/{train_steps}")):
        multihost_utils.sync_devices("training:ckpt-last")
        _save_checkpoint(train_steps, train_state, force=True)
        ckpt_manager.wait_until_finished()
    multihost_utils.sync_devices("training:completed")
    logging.info("Training completed.")
    if use_wandb:
        # Close the W&B run explicitly once training finished successfully.
        wandb.finish()


if __name__ == "__main__":
    # Hand the training entry point to vmoe's launcher; it presumably builds
    # `config`, `workdir`, `mesh` and `writer` before calling it — confirm in
    # vmoe.app.
    entry_point = train_and_evaluate
    app.run(entry_point)
