import sys, argparse, os
from datetime import datetime
from pathlib import Path
import pytorch_lightning as pl

# Directory containing this script, and that directory's parent.
DIRECTORY = Path(__file__).parent
ROOT_DIR = DIRECTORY.parent
# Put ROOT_DIR on the module search path before the project imports below run;
# presumably this is what lets `_abstract_task` and `Pretraining` resolve --
# TODO confirm against the repository layout.
sys.path.insert(1, str(ROOT_DIR))
from _abstract_task.run import get_trainer
from Pretraining.configuration import pretraining_config
from Pretraining.data import PretrainingHFDataModule
from Pretraining.training import (
    Pretraining,
    convert_zero_checkpoint_to_lightning_checkpoint,
)


def _merge_preprocessed_splits(dm, config):
    """Preprocess the data and store the train split as one arrow file."""
    # Using multiprocessing (num_proc > 1) to preprocess data is much faster and recommended,
    # but it results in num_proc splits of data, so the splits are merged into a single arrow file.
    dm.prepare_data()  # preprocess
    dm.setup(None)  # automatically load and concatenate all splits
    # "~" has to be expanded explicitly: os.path.abspath() alone would treat it
    # as a literal directory name relative to the current working directory.
    data_dir = os.path.abspath(
        os.path.expanduser(
            "~/.cache/huggingface/datasets/openwebtext/plain_text/1.0.0/85b3ae7051d2d72e7c5fdf6dfb462603aaa26e9ed506202bf3a24d261c6c40a1"
        )
    )
    save_dir = f"{data_dir}/preprocessed_{config.max_sequence_length}"
    dm.datasets["train"].dataset.save_to_disk(save_dir)
    os.rename(f"{save_dir}/dataset.arrow", f"{data_dir}/preprocessed.arrow")


def _convert_deepspeed_checkpoints(ckpt_dir):
    """Convert every not-yet-converted ``*.deepspeed`` entry in ``ckpt_dir``."""
    for ds_ckpt_dir in Path(ckpt_dir).glob("*.deepspeed"):
        if ds_ckpt_dir.with_suffix(".ckpt").exists():
            continue  # already converted
        convert_zero_checkpoint_to_lightning_checkpoint(ds_ckpt_dir)


def run(config, args):
    """Run pretraining, or only preprocess the data and return.

    When ``config.datasets_num_proc > 1`` and ``config.accelerator == "cpu"``,
    the data is preprocessed, merged into a single arrow file and the function
    returns without training. Otherwise a model and trainer are built, the
    model is fitted, and (when a deepspeed strategy is configured) deepspeed
    checkpoints are converted on global rank 0.

    Args:
        config: Configuration object. The attributes read here are ``seed``,
            ``datasets_num_proc``, ``accelerator``, ``max_sequence_length``,
            ``load_ckpt_path``, ``scale`` and ``strategy``; it is also passed
            on to the data module, the model and ``get_trainer``.
        args: Parsed command line namespace; ``setting`` and
            ``resume_ckpt_path`` are read.
    """
    pl.seed_everything(config.seed)

    # Data
    dm = PretrainingHFDataModule(config)
    if config.datasets_num_proc > 1 and config.accelerator == "cpu":
        _merge_preprocessed_splits(dm, config)
        return

    # Training Module
    if config.load_ckpt_path:
        model = Pretraining.from_pretrained(config)
    else:
        model = Pretraining(config)

    # Trainer
    default_ckpt_name = f"{args.setting}_{config.scale}_seed={config.seed}"
    trainer = get_trainer(
        config,
        logging_project_name="Pretraining",
        log_learning_rate=True,
        suggested_checkpoint_name=default_ckpt_name,
        save_every_epoch=(config.scale == "large"),
    )

    # Fit
    ## Re-seed to fix the data order from the sampler, which is initiated when training starts
    pl.seed_everything(config.seed)
    start_time = datetime.now()
    print(f"Training process {os.getpid()} started at {start_time}.")
    trainer.fit(model, datamodule=dm, ckpt_path=args.resume_ckpt_path)
    # Elapsed time is "now - start"; the reverse order yields a negative timedelta.
    print(f"Total run time: {datetime.now() - start_time}")

    # Convert all unconverted deepspeed style checkpoints into lightning checkpoints, if using deepspeed
    if config.strategy and "deepspeed" in config.strategy and trainer.global_rank == 0:
        _convert_deepspeed_checkpoints(trainer.checkpoint_callback.dirpath)


def build_arg_parser():
    """Build the command line parser of the pretraining script.

    Returns:
        argparse.ArgumentParser: Parser with the optional positional
        ``setting`` (default ``"electra_tsp2"``), the flags ``--dev`` and
        ``--resume_ckpt_path``, and the trailing positional list
        ``config_overwrites``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "setting",
        # A positional without nargs="?" is required, which would make
        # ``default`` unreachable.
        nargs="?",
        default="electra_tsp2",
        type=str,
        help="The name of configuration setting function in Pretraining/configuration.py",
    )
    parser.add_argument(
        "--dev",
        action="store_true",
        help="Automatically setting the basic value for required options for convenience of developing.",
    )
    parser.add_argument(
        "--resume_ckpt_path", type=str, default=None,
    )
    parser.add_argument(
        "config_overwrites",
        nargs="*",
        help="Overwrites for configuration which will be passed to Omegaconf.from_cli",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    config_overwrites = args.config_overwrites
    if args.dev:
        # Development defaults go first, ahead of the overwrites given on the command line.
        config_overwrites = ["scale=small", "seed=1", "devices=1"] + config_overwrites
    config = pretraining_config(
        setting_fn=args.setting, cli_overwrites=config_overwrites,
    )
    run(config, args)
