"""Checkpoint analysis script: restores a checkpoint and prints per-layer expert load distributions."""

from vmoe import app
import orbax.checkpoint
from .trainer import (
    ml_collections,
    Mesh,
    metric_writers,
    input_pipeline,
    logging,
    pjit_utils,
    create_checkpoint_manager,
    create_flax_model,
    restore_or_create_train_state,
    ThreadPool,
    get_dataset_iterator,
    functools,
    jax,
    multihost_utils,
    os,
    pjit,
    jnp,
    make_create_train_state_fn
)


def load_imbalancing(
    config: ml_collections.ConfigDict,
    workdir: str,
    mesh: Mesh,
    writer: metric_writers.MetricWriter,
):
    """Restore a checkpoint and print the per-layer expert load distribution.

    The model is built with ``deterministic=True``, the train state is
    restored for the selected checkpoint step, and the model is applied to the
    first batch of the ``"test"`` dataset variant. For that batch the function
    prints a sanity-check accuracy and, for every entry of the returned
    metrics that is not an array, the normalised histogram of its
    ``'clusters'`` values (one bin per column of the matching router kernel).
    No training step is performed.

    Args:
        config: Experiment configuration. Reads ``dataset``, ``model``,
            ``optimizer`` and ``params_axis_resources``, plus the optional
            ``save_checkpoint``, ``extra_rng_keys``, ``seed`` and
            ``initialization`` entries.
        workdir: Passed to ``create_checkpoint_manager`` to locate checkpoints.
        mesh: Not referenced in this function; kept because it is part of the
            signature this entry point is called with.
        writer: Not referenced in this function; kept for the same reason.

    Raises:
        KeyError: If the dataset has no ``"train"`` variant.
        ValueError: If the checkpoint manager reports no checkpoint steps.
    """
    datasets = input_pipeline.get_datasets(config.dataset)
    if "train" not in datasets:
        raise KeyError(
            f'You must have a "train" variant of the dataset. '
            f"Available variants are {sorted(datasets.keys())!r}"
        )

    # Get the global shape of the image array. The "train" variant is only
    # used here, to describe the model input when building the train state.
    datataset_element_shape_dtype = pjit_utils.get_dataset_shape_dtype_struct(
        datasets["train"]
    )

    print(f'{datataset_element_shape_dtype=}')

    ckpt_manager = create_checkpoint_manager(
        workdir=workdir, **config.get("save_checkpoint", {})
    )

    model = create_flax_model(config=config.model, deterministic=True)

    train_state_initialize_fn = make_create_train_state_fn(
        model=model,
        optimizer_config=config.optimizer,
        input_shape_dtypes=(datataset_element_shape_dtype["image"],),
        train_steps=0,
        extra_rng_keys=tuple(config.get("extra_rng_keys", [])),
        seed=config.get("seed", 0),
    )
    # This call yields a fully materialised train state whose leaves carry the
    # dtype/sharding/shape used as the restore template below.
    train_state, last_seen_index = restore_or_create_train_state(
        ckpt_manager=ckpt_manager,
        initialize_fn=train_state_initialize_fn,
        axis_resources_regexes=config.params_axis_resources,
        thread_pool=ThreadPool(),
        initialization_kwargs=config.get("initialization"),
    )

    all_steps = ckpt_manager.all_steps()

    # Checkpoint step to analyse: by default the last one listed by the
    # manager. Edit this line to inspect a different step.
    step = all_steps[-1] if all_steps else None
    print(f'\nAll steps: {all_steps}\nStep: {step}\n')
    if step is None:
        # Explicit error instead of `assert`, which is stripped under -O.
        raise ValueError(
            f"The checkpoint manager reports no checkpoint steps "
            f"(workdir={workdir!r}); there is nothing to analyse."
        )

    def _array_restore_args_fn(x: jax.ShapeDtypeStruct):
        # Restore every leaf with the same dtype, sharding and global shape as
        # the corresponding leaf of the current train state.
        return orbax.checkpoint.ArrayRestoreArgs(
            dtype=x.dtype, sharding=x.sharding, global_shape=x.shape)

    restore_kwargs = {
        'state': {
            'restore_args': jax.tree_util.tree_map(
                _array_restore_args_fn, train_state),
        },
    }
    items = ckpt_manager.restore(
        step,
        items={
            'state': train_state,
            'dataset_iterator': {'last_seen_index': 0},
        },
        restore_kwargs=restore_kwargs)
    # Use the state restored for `step`. Without this assignment the explicit
    # restore above is discarded and choosing a different `step` has no effect.
    train_state = items['state']

    init_step = int(train_state.step)
    logging.info("Selected step = %d", step)
    logging.info("Initial step = %d", init_step)

    apply_fn = pjit.pjit(functools.partial(model.apply, {'params': train_state.params}))

    # NOTE(review): the dataset variant is iterated directly rather than
    # through get_dataset_iterator, so batches are fed to apply_fn exactly as
    # the input pipeline yields them — confirm this is intended on multi-host
    # setups.
    eval_variant = 'test'
    if eval_variant not in datasets:
        logging.warning(
            "Dataset variant %r not found; no load statistics were computed.",
            eval_variant,
        )
    else:
        for batch in datasets[eval_variant]:
            logits, metrics = apply_fn(batch["image"])
            predicted = jnp.argmax(logits, axis=-1)
            expected = jnp.argmax(batch["labels"], axis=-1)
            correct = jnp.count_nonzero(predicted == expected)
            total = batch["labels"].shape[0]

            print(f'Sanity check: accuracy is {correct/total*100:.2f}% ({correct}/{total})')

            for k, v in metrics.items():
                # Plain array metrics carry no cluster assignments; only the
                # nested entries (indexed with 'clusters' below) are analysed.
                if isinstance(v, jnp.ndarray):
                    continue
                clusters = v['clusters']
                router_kernel = train_state.params['Encoder'][k]['Moe']['Router']['dense']['kernel']
                # The second kernel dimension is used as the expert count.
                num_experts = router_kernel.shape[1]
                token_count = jnp.bincount(jnp.ravel(clusters), minlength=num_experts)
                dist = token_count / token_count.sum()
                print(f'{k=}, {dist.tolist()=}')

            # Only the first batch is analysed.
            break

    multihost_utils.sync_devices("training:completed")
    logging.info("Completed.")


# Script entry point: hand load_imbalancing to vmoe's app runner.
# NOTE(review): presumably app.run parses flags/config and then calls
# load_imbalancing(config, workdir, mesh, writer) — confirm against vmoe.app.
if __name__ == "__main__":
    app.run(load_imbalancing)
