"""Main training script."""

from vmoe import app
import orbax.checkpoint
from .trainer import (
    ml_collections,
    Mesh,
    metric_writers,
    input_pipeline,
    logging,
    pjit_utils,
    create_checkpoint_manager,
    create_flax_model,
    restore_or_create_train_state,
    ThreadPool,
    get_dataset_iterator,
    functools,
    jax,
    multihost_utils,
    os,
    pjit,
    jnp,
    make_create_train_state_fn
)
import datetime
from vmoe.patcher.backup import backup


def _array_restore_args(x: jax.ShapeDtypeStruct):
    """Build orbax restore args matching ``x``'s dtype, sharding and shape."""
    return orbax.checkpoint.ArrayRestoreArgs(
        dtype=x.dtype, sharding=x.sharding, global_shape=x.shape)


def _restore_train_state(ckpt_manager, step, train_state, restore_kwargs):
    """Restore the train state saved at ``step``, using ``train_state`` as template.

    The dataset-iterator item is restored too (the checkpoint layout requires
    it), but only the ``'state'`` item is returned.
    """
    items = ckpt_manager.restore(
        step,
        items={
            'state': train_state,
            'dataset_iterator': {'last_seen_index': 0},
        },
        restore_kwargs=restore_kwargs)
    return items['state']


def _max_abs_diff(tree1, tree2):
    """Return a pytree with, per leaf, the max absolute difference of the inputs."""
    return jax.tree_util.tree_map(
        lambda x, y: jnp.abs(x - y).max(), tree1, tree2)


def check_stable(
    config: ml_collections.ConfigDict,
    workdir: str,
    mesh: Mesh,
    writer: metric_writers.MetricWriter,
):
    """Compares the oldest and newest checkpoints and runs one test batch.

    This function does not train. It:

    1. Backs up the source code under ``workdir/<timestamp>``.
    2. Restores the oldest and the newest checkpoint in ``workdir`` and prints
       the per-leaf maximum absolute difference of their parameters.
    3. Applies the model with the newest parameters to the first batch of the
       "test" dataset and prints the argmax labels, the number of correct
       predictions, the ``'clusters'`` entry of every non-array metric, and
       the logits.

    Args:
      config: Experiment configuration. Reads ``dataset``, ``model``,
        ``optimizer``, ``params_axis_resources`` and, optionally,
        ``save_checkpoint``, ``extra_rng_keys``, ``seed``, ``initialization``.
      workdir: Directory holding the checkpoints; the source backup is also
        written below it.
      mesh: Device mesh. Not used by this function.
      writer: Metric writer. Not used by this function.

    Raises:
      KeyError: If the datasets lack a "train" or a "test" variant.
      ValueError: If ``workdir`` holds no checkpoint, or the "test" dataset
        yields no batch.
    """
    print(config)

    # astimezone() attaches the local UTC offset: on a naive datetime the
    # "%z" directive silently expands to an empty string.
    timestamp = datetime.datetime.now().astimezone().strftime(
        "%Y-%m-%dT%H:%M:%S%z")
    # Backup source code for replication.
    backup(os.path.join(workdir, timestamp))

    datasets = input_pipeline.get_datasets(config.dataset)
    # "train" provides the input shape/dtype, "test" is what gets evaluated.
    # Validate both now so a missing split fails before any checkpoint I/O.
    for variant in ("train", "test"):
        if variant not in datasets:
            raise KeyError(
                f'You must have a "{variant}" variant of the dataset. '
                f"Available variants are {sorted(datasets.keys())!r}"
            )

    # Get the global shape of the image array.
    dataset_element_shape_dtype = pjit_utils.get_dataset_shape_dtype_struct(
        datasets["train"]
    )
    print(f'{dataset_element_shape_dtype=}')

    ckpt_manager = create_checkpoint_manager(
        workdir=workdir, **config.get("save_checkpoint", {})
    )

    model = create_flax_model(config=config.model, deterministic=True)

    train_state_initialize_fn = make_create_train_state_fn(
        model=model,
        optimizer_config=config.optimizer,
        input_shape_dtypes=(dataset_element_shape_dtype["image"],),
        train_steps=0,
        extra_rng_keys=tuple(config.get("extra_rng_keys", [])),
        seed=config.get("seed", 0),
    )
    # Used only as the template (structure, dtypes, shardings) for the explicit
    # restores below; the returned dataset index is not needed.
    train_state, _ = restore_or_create_train_state(
        ckpt_manager=ckpt_manager,
        initialize_fn=train_state_initialize_fn,
        axis_resources_regexes=config.params_axis_resources,
        thread_pool=ThreadPool(),
        initialization_kwargs=config.get("initialization"),
    )

    all_steps = ckpt_manager.all_steps()
    if not all_steps:
        # Explicit raise rather than `assert`, which `python -O` strips.
        raise ValueError(f"No checkpoints found in {workdir!r}.")
    # min/max rather than [0]/[-1]: does not rely on all_steps() being sorted.
    first_step, last_step = min(all_steps), max(all_steps)

    restore_kwargs = {
        'state': {
            'restore_args': jax.tree_util.tree_map(
                _array_restore_args, train_state),
        },
    }

    from copy import deepcopy
    # Defensive copy of the oldest state before the second restore.
    old_train_state = deepcopy(_restore_train_state(
        ckpt_manager, first_step, train_state, restore_kwargs))
    new_train_state = _restore_train_state(
        ckpt_manager, last_step, train_state, restore_kwargs)

    print(jax.tree.leaves(
        _max_abs_diff(old_train_state.params, new_train_state.params)))

    train_state = new_train_state

    init_step = int(train_state.step)
    logging.info("Selected step = %d", last_step)
    logging.info("Initial step = %d", init_step)

    apply_fn = pjit.pjit(
        functools.partial(model.apply, {'params': train_state.params}))

    # Only the first batch of the test split is inspected (note the `break`).
    logits = None
    for it, batch in enumerate(datasets["test"]):
        logits, metrics = apply_fn(batch["image"])
        labels = jnp.argmax(batch["labels"], axis=-1)
        print(labels)
        print(jnp.count_nonzero(jnp.argmax(logits, axis=-1) == labels))
        for k, v in metrics.items():
            # Non-array metrics are indexed by 'clusters' (assumed dict-like).
            if not isinstance(v, jnp.ndarray):
                print(k, v['clusters'])
        print(f'{it=} -- {logits.shape=}')
        break

    # Without this guard an empty split would surface as a NameError below.
    if logits is None:
        raise ValueError('The "test" dataset did not yield any batch.')
    print(logits)

    multihost_utils.sync_devices("training:completed")
    logging.info("Completed.")


if __name__ == "__main__":
    app.run(check_stable)
