import argparse
import os
from pathlib import Path
import json

import torch.distributed as dist
from open_lm.distributed import world_info_from_env

from training.hyperparameters import available_scales
from training.file_utils import download_val_data, load_ppl_yaml, tok_mult_paths, de_en_paths


def parse_scrub_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--scale",
        type=str,
        required=False,
        default=None,
        choices=available_scales(),
        help="Competition scale.",
    )
    parser.add_argument(
        "--data-config",
        type=str,
        required=False,
        default=None,
        help="Competition scale.",
    )
    parser.add_argument(
        "--logs",
        type=str,
        required=False,
        default=None,
        help="Place to write logs and checkpoints.",
    )
    parser.add_argument(
        "--remote-sync",
        type=str,
        required=False,
        help="If set, experiments will be sync'd to s3.",
        default=None,
    )
    parser.add_argument(
        "--clean-exp",
        default=False,
        action="store_true",
        help="If set, local exp dir will be cleared, only supported if --remote-sync specified.",
    )
    parser.add_argument("--git-db", default="exp_data/models", type=str, help="place to save output model.json")

    parser.add_argument("--workers", type=int, default=2, help="Number of workers for open_lm.")
    parser.add_argument(
        "--precision",
        type=str,
        choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="amp_bfloat16",
        help="Floating point precision.",
    )
    parser.add_argument(
        "--num-checkpoints",
        type=int,
        default=5,
        help="Number of times we save checkpoints during training.",
    )
    parser.add_argument("--seed", type=int, default=124, help="Random seed.")
    parser.add_argument(
        "--report-to-wandb",
        default=False,
        action="store_true",
        help="If True, report to wandb.",
    )
    parser.add_argument(
        "--val-frequency",
        type=int,
        default=5,
        help="How often to run evaluation with val-data (in epochs). Last epoch validated if val-data provided.",
    )
    parser.add_argument(
        "--manifest-prefix-override",
        type=str,
        required=False,
        default=None,
        help="Overide the manifest prefix for the target dataset.json",
    )
    parser.add_argument(
        "--remote-sync-override",
        type=str,
        required=False,
        default=None,
        help="Overide the manifest prefix for the target dataset.json",
    )
    parser.add_argument(
        "--chinchilla-multiplier",
        type=float,
        required=False,
        help="Support token multiplier.",
    )
    parser.add_argument(
        "--re-evaluate",
        required=False,
        type=str,
        default=None,
        help="pass a model json",
    )
    parser.add_argument(
        "--downstream-eval",
        required=False,
        action="store_true",
        help="Eval on llm-foundry as a loss evaluation.",
    )
    parser.add_argument(
        "--tokmult-eval",
        required=False,
        action="store_true",
        help="Eval on eval sets for tokmult paper.",
    )
    parser.add_argument(
        "--tokmult-eval-ge",
        required=False,
        action="store_true",
        help="Eval on eval sets for tokmult paper.",
    )

    # override hparams in the hparam config provided in --scale (good for grid searches)
    parser.add_argument(
        "--warmup",
        required=False,
        type=int,
        default=None,
        help="number of warmup steps",
    )
    parser.add_argument(
        "--lr",
        required=False,
        type=float,
        default=None,
        help="max lr for training",
    )
    parser.add_argument(
        "--wd",
        required=False,
        type=float,
        default=None,
        help="weight-decay for training",
    )
    parser.add_argument(
        "--cd",
        required=False,
        type=float,
        default=None,
        help="cool down ending lr",
    )
    parser.add_argument(
        "--global-bs",
        required=False,
        type=int,
        default=None,
        help="global batch size",
    )
    parser.add_argument(
        "--acc",
        required=False,
        type=int,
        default=None,
        help="gradient acc steps",
    )

    args = parser.parse_args()

    if args.re_evaluate is not None:
        # in the case of re-evaluation set based on the input model.json
        model_json = None
        with open(args.re_evaluate, "r") as f:
            model_json = json.load(f)

        args.scale = Path(model_json["hyperparameters"]["model"]).stem
        args.data_config = f"exp_data/datasets/tokenized/{model_json['dataset_name']}.json"

        for i in range(len(model_json["open_lm_args"])):
            if model_json["open_lm_args"][i] == "--logs":
                args.logs = model_json["open_lm_args"][i + 1]
            if model_json["open_lm_args"][i] == "--remote-sync":
                if args.remote_sync_override:
                    args.remote_sync = args.remote_sync_override
                else:
                    args.remote_sync = model_json["open_lm_args"][i + 1]

    assert args.scale is not None and args.data_config is not None and args.logs is not None
    assert not (args.tokmult_eval and args.downstream_eval)

    return args


def get_open_lm_args(args, hparams, dr):
    if args.manifest_prefix_override is not None:
        manifest_name = Path(dr.manifest_url).name
        dr.manifest_url = os.path.join(args.manifest_prefix_override, f"{manifest_name}")

    local_rank, _, _ = world_info_from_env()

    open_lm_args = [
        "--workers",
        f"{args.workers}",
        "--precision",
        args.precision,
        "--global-batch-size",
        f"{hparams.global_bs}",
        # "32",
        "--log-every-n-steps",
        "20",
        "--grad-clip-norm",
        "1",
        "--lr",
        f"{hparams.lr}",
        "--warmup",
        f"{hparams.warmup}",
        "--model",
        f"{hparams.model}",
        "--wd",
        f"{hparams.wd}",
        "--beta2",
        "0.95",
        "--epochs",
        f"{args.num_checkpoints}",
        "--resume",
        "latest",
        "--seed",
        f"{args.seed}",
        "--accum-freq",
        f"{hparams.acc}",
        "--model-norm",
        "gain_only_lp_layer_norm",
        "--delete-previous-checkpoint",
        "--lr-cooldown-end",
        f"{hparams.cd}",
        "--logs",
        f"{args.logs}",
    ]

    name = None
    if args.re_evaluate is None:
        # case where we are training
        name = hparams.get_friendly_name(dr)

        open_lm_args.extend(
            [
                "--train-num-samples",
                f"{hparams.tokens // args.num_checkpoints}",
                "--dataset-manifest",
                dr.manifest_url,
                "--data-key",
                dr.data_key,
                "--name",
                name,
            ]
        )
        # add fsdp flags that are different for "smaller" and "larger" configs

        open_lm_args.extend(hparams.fsdp_flags)
    else:
        # get name from the passed model.json
        model_json = None
        with open(args.re_evaluate, "r") as f:
            model_json = json.load(f)

        name = model_json["name"]
        open_lm_args.extend(
            [
                "--name",
                name,
            ]
        )

    openlm_val_data = download_val_data("open_lm_val", skip_download=local_rank != 0)
    c4_val_data = download_val_data("c4_val", skip_download=local_rank != 0)
    paloma_val_data = download_val_data("paloma_val", skip_download=local_rank != 0)

    if args.downstream_eval:
        tasks = load_ppl_yaml()
        downstream_datas = [download_val_data(task_name, skip_download=local_rank != 0) for task_name in tasks]

        open_lm_args.extend(
            [
                "--val-data",
                openlm_val_data,
                c4_val_data,
                # paloma_val_data,
                *downstream_datas,
                "--val-frequency",
                f"{args.val_frequency}",
                "--val-data-key",
                "json",
                "txt",
                # "json.gz",
                *(["txt"] * len(downstream_datas)),
                "--squash-mask-left",
                "--target-mask-individual",
                "50400",
                "--target-mask-left",
                "50300",
                "--val-tok-ci",
                "--val-seq-ci",
                "--val-max-pop-ci",
                "300000",
            ]
        )
    elif args.tokmult_eval:
        tasks = load_ppl_yaml()
        downstream_datas = [download_val_data(task_name, skip_download=local_rank != 0) for task_name in tasks]

        open_lm_args.extend(
            [
                "--val-data",
                *tok_mult_paths,
                *downstream_datas,
                "--val-frequency",
                f"{args.val_frequency}",
                "--val-data-key",
                "json",  # double open_lm, c4_val are the two entry in tok_mult_paths
                "txt",
                # "json.gz",
                *(["json.gz"] * (len(tok_mult_paths) - 2)),
                *(["txt"] * len(downstream_datas)),
                "--squash-mask-left",
                "--target-mask-individual",
                "50400",
                "--target-mask-left",
                "50300",
                "--val-tok-ci",
                "--val-seq-ci",
                "--val-max-pop-ci",
                "300000",
            ]
        )

    if args.report_to_wandb:
        open_lm_args.extend(["--report-to", "wandb", "--wandb-project-name", "overtraining"])

    if hparams.qk_norm:
        open_lm_args.append("--qk-norm")
    if hparams.grad_checkpointing:
        open_lm_args.append("--grad-checkpointing")
    if hparams.z_loss > 0:
        open_lm_args.extend(["--z-loss", f"{hparams.z_loss}"])
    if args.remote_sync:
        open_lm_args.extend(
            [
                "--remote-sync",
                f"{args.remote_sync}",
            ]
        )

    return open_lm_args, name
