import sys, re, argparse
from pathlib import Path
import wandb
import pytorch_lightning as pl

# Directory that contains this script.
DIRECTORY = Path(__file__).parent
# Parent of the script directory.
ROOT_DIR = DIRECTORY.parent
# Put ROOT_DIR on the module search path (at index 1) before the project imports below run.
# NOTE(review): presumably `_utils`, `_abstract_task` and `GLUE` live under ROOT_DIR -- confirm.
sys.path.insert(1, str(ROOT_DIR))
from _utils import reduce_checkpoints
from _abstract_task.run import get_trainer
from GLUE.configuration import glue_config, NUM_EPOCHS
from GLUE.data import DATA_MODULES
from GLUE.training import MODEL_MODULES


# Number of repeated runs per task; also the multiplier used to derive per-run seeds.
NUM_RUNS = 10

# Every task identifier that has a registered data module.
ALL_TASKS = list(DATA_MODULES)

# Tasks whose identifier does not begin with "a" are considered trainable.
# NOTE(review): presumably this excludes a diagnostic-only task such as "ax" -- confirm.
TRAINABLE_TASKS = [name for name in ALL_TASKS if name[:1] != "a"]


def _resolve_base_seed(config):
    """Return the base random seed for a task configuration.

    The seed is read from ``config.seed``; if that is missing and a checkpoint
    path is configured, the last ``_seed=<digits>`` marker in that path is used.

    Raises:
        ValueError: If neither source provides a seed.
    """
    base_seed = config.get("seed", None)
    if base_seed is None and config.load_ckpt_path:
        found = re.findall(r"_seed=([0-9]+)", config.load_ckpt_path)
        if found:
            base_seed = int(found[-1])
    if base_seed is None:
        raise ValueError(
            "Cannot determine a base random seed: set `seed` in the configuration or "
            "load a checkpoint whose path contains '_seed=<number>'."
        )
    return base_seed


def run(args, tasks: list[str], run_range: tuple[int, int]):
    """Train (or test) each task over a range of seeded repetitions.

    For every task a configuration is built with ``glue_config`` and the run
    indices ``run_range[0] .. run_range[1] - 1`` are executed, each seeded with
    ``base_seed * NUM_RUNS + run_index``. After the training runs of a task,
    ``reduce_checkpoints`` is called on the checkpoint directory of the last
    trainer that was created for it.

    Args:
        args: Parsed CLI namespace; ``config_overwrites``, ``test`` and
            ``force_new_classifier`` are read from it.
        tasks: Task identifiers (keys of ``DATA_MODULES`` / ``MODEL_MODULES``).
        run_range: Half-open ``(start, end)`` interval of run indices.

    Raises:
        ValueError: If no base seed can be determined for a task.
    """
    # Number of runs kept for the expensive tasks at "large" scale.
    large_scale_run_cap = 5

    # Run different tasks
    for task in tasks:

        # Get configuration
        task_overwrites = [f"task={task}", f"num_epochs={NUM_EPOCHS[task]}"]
        config = glue_config(cli_overwrites=task_overwrites + args.config_overwrites)

        # Base of random seeds
        basic_seed = _resolve_base_seed(config)

        # Save time on low-variance, high-computation tasks when at "large" scale
        capped = config.scale == "large" and task in ("qqp", "mnli")

        # Reset per task so a trainer from a previous task is never reused below
        trainer = None

        # Run the same task with different random seeds
        for th in range(run_range[0], run_range[1]):

            if capped and th >= large_scale_run_cap:
                continue

            # Data/Training module
            random_seed = basic_seed * NUM_RUNS + th
            pl.seed_everything(random_seed)
            dm = DATA_MODULES[task](config)
            glue_model = MODEL_MODULES[task].from_pretrained(config, except_classifier=args.force_new_classifier)

            # Trainer
            # The checkpoint name carries one "{metric:.5f}" placeholder per model metric,
            # e.g. <task>_seed=<seed>_{accuracy:.5f}_{f1:.5f}; the placeholders are left
            # for the checkpointing machinery behind `get_trainer` to fill in.
            metrics_str = "_".join(f"{{{m}:.5f}}" for m in glue_model.METRICS)
            default_ckpt_name = f"{task}_seed={random_seed}_{metrics_str}"
            trainer = get_trainer(
                config,
                logging_project_name="SuperGLUE",
                suggested_checkpoint_name=default_ckpt_name,
                num_sanity_val_steps=0,
            )

            # Fit (re-seed so training does not depend on how much randomness setup consumed)
            pl.seed_everything(random_seed)
            if args.test:
                trainer.test(glue_model, datamodule=dm)
            else:
                trainer.fit(glue_model, datamodule=dm)
                if config.logger == "wandb":
                    # End the current logging run to start a new logging run
                    wandb.join()

        # Keep only the top-scoring checkpoints. Skipped when no run was executed for this
        # task (empty range, or every index was above the large-scale cap), because there is
        # then no trainer to take the checkpoint directory from.
        if not args.test and trainer is not None:
            num_runs = large_scale_run_cap if capped else NUM_RUNS
            reduce_checkpoints(
                checkpoint_directory=trainer.checkpoint_callback.dirpath,
                task_name=task,
                expected_num_checkpoints=num_runs,
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "tasks",
        type=str,
        help="Sub-tasks to train, specified in lowercase identifiers delimited by comma or 'all' to train all tasks",
    )
    parser.add_argument(
        "--range",
        type=str,
        default=f"0-{NUM_RUNS}",
        help="Execute [start,end)th runs, specified in 'start-end'. This option is useful when you want to multi-process against repeated runs.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Do testing if true, otherwise do training and validation.",
    )
    parser.add_argument(
        "--force_new_classifier",
        action="store_true",
        help="Force model to use a newly initialized classifier. This is for finetuning from intermediate task checkpoint.",
    )
    parser.add_argument("config_overwrites", type=str, nargs="+")
    args = parser.parse_args()

    # Infer which tasks to do
    if args.tasks == "all":
        tasks = ALL_TASKS if args.test else TRAINABLE_TASKS
    else:
        tasks = re.split(r",\s?", args.tasks)
        assert all(task in ALL_TASKS for task in tasks)

    # Infer how much runs to do
    run_range = tuple(int(i) for i in args.range.split("-"))
    assert 0 <= run_range[0] and run_range[1] <= NUM_RUNS
    if args.test:
        run_range = (0, 1)

    run(args, tasks=tasks, run_range=run_range)