"""Pretrain a Transformer on language modeling."""

import os
import sys

# Append the parent of this script's directory to the import path
# (presumably so the absolute `plainlm` imports below resolve when the file
# is launched directly -- confirm against the repository layout).
_this_file = os.path.abspath(__file__)
script_dir = os.path.dirname(_this_file)
project_root = os.path.dirname(script_dir)
sys.path.append(project_root)

from collections import defaultdict

from absl import app, flags

from plainlm.engine import TorchEngine
from plainlm.engine_utils import (
    destroy_ddp,
    init_wandb,
    load_config,
    log,
    maybe_load_checkpoint,
    maybe_make_dir,
    print_master,
    pytorch_setup,
    save_checkpoint,
)

from .construct import construct_model
from .data import get_dataloaders
from .test_utils import evaluation

# Command-line interface: which config file to load and, for job-array
# sweeps, which job index this process handles.
flags.DEFINE_string(
    name="config",
    default="config/config.yaml",
    help="Path to config.yaml file.",
)
flags.DEFINE_integer(
    name="job_idx",
    default=None,
    help="Job idx for job-array sweeps. From 0 to n-1.",
)
FLAGS = flags.FLAGS


def main(_):
    """Train a model from the configured YAML file, then evaluate it.

    Loads the config selected by the ``--config`` / ``--job_idx`` flags,
    derives a run name from a few config fields, runs the training loop
    through ``TorchEngine``, optionally saves a final checkpoint, runs
    ``evaluation`` and finally tears down DDP.

    Args:
        _: Positional arguments forwarded by ``absl.app.run``; unused.
    """
    cfg_path, job_idx = FLAGS.config, FLAGS.job_idx
    cfg, _ = load_config(cfg_path, job_idx)

    # An f-string formats every field with str(), so non-string config values
    # (e.g. a boolean gate flag or None) no longer raise TypeError the way
    # `"..." + value` did. Output is unchanged for string fields.
    run_postfix = (
        f"type:{cfg.mixer_type}"
        f"-pos:{cfg.pos_emb}"
        f"-gate:{cfg.forget_gate}"
        f"-kernel:{cfg.kernel}"
        f"-lr:{cfg.lr}"
    )

    # The same string names both the wandb run and the experiment.
    cfg = cfg._replace(wandb_run_name=run_postfix, exp_name=run_postfix)

    local_rank, world_size, device, master_process = pytorch_setup(cfg)

    if master_process:
        maybe_make_dir(cfg, job_idx)

    if cfg.use_wandb and master_process:
        init_wandb(cfg)

    # Load checkpoint and starting step
    ckpt, micro_step_start = maybe_load_checkpoint(cfg, device)

    # Dataset
    trainloader, validloader, tokenizer, to_tokens = get_dataloaders(cfg)

    # Model (the second return value is not needed here)
    model, _ = construct_model(cfg)

    # Engine
    engine = TorchEngine(model, cfg, device, local_rank, ckpt, TO_TOKENS=to_tokens)

    # Training
    print_master("=== Start Training! ===")
    metrics = defaultdict(list)
    train_losses = []

    # Index of the last micro-step that actually ran. Tracked explicitly so the
    # final checkpoint is labelled correctly whether the loop ends via `break`,
    # by exhausting the loader, or without running at all (empty loader).
    last_micro_step = micro_step_start

    for micro_step, micro_batch in enumerate(trainloader, micro_step_start + 1):
        step = micro_step // cfg.grad_accumulation_steps
        if step > cfg.steps_budget:
            break

        # Train
        train_losses.append(engine.step(micro_batch))
        last_micro_step = micro_step

        # Eval: no validation is run inside the loop; `log` still takes the slot.
        valid_loss = None

        # Log
        # NOTE(review): with grad_accumulation_steps > 1 this condition holds
        # for every micro-step of a logging step -- confirm that is intended.
        if step % cfg.log_every_steps == 0:
            if master_process:
                log(
                    cfg,
                    metrics,
                    micro_step,
                    train_losses,
                    valid_loss,
                    engine.optimizer,
                    world_size,
                )
            # Start a new window only after a logging point, so `log` receives
            # every loss since the previous one. Done on all ranks to keep the
            # buffer bounded on non-master processes too.
            train_losses = []

    # End of training: log and save checkpoint
    print_master("=== Training Completed! ===")
    if master_process and cfg.save_last_checkpoint:
        save_checkpoint(last_micro_step, model, engine, cfg, job_idx)

    # Final evaluation; its return values are not used by this script.
    evaluation(cfg, model, tokenizer, to_tokens)

    # DDP teardown
    destroy_ddp()


# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls main().
if __name__ == "__main__":
    app.run(main)
