import distributed
import json

def parse_args(base_parser, args, namespace):
    parser = base_parser
    # General training params
    parser.add_argument("--experiment_name", default=None, type=str)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--data_seed", default=1337, type=int)
    parser.add_argument("--eval_interval", default=200, type=int)
    parser.add_argument("--full_eval_at", nargs="+", type=int)
    parser.add_argument("--eval_batches", default=32, type=int)
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument(
        "--distributed_backend",
        default=None,
        type=str,
        required=False,
        choices=distributed.registered_backends(),
    )
    parser.add_argument("--log_interval", default=50, type=int)

    # Checkpointing
    parser.add_argument("--results_base_folder", default="./exps", type=str)
    parser.add_argument("--permanent_ckpt_interval", default=0, type=int)
    parser.add_argument("--latest_ckpt_interval", default=0, type=int)
    parser.add_argument("--resume_from", default=None, type=str)
    parser.add_argument("--resume_from_swa", default=None, type=str)

    parser.add_argument("--auto_resume", default=False)

    # logging params (WandB)
    parser.add_argument("--wandb", action="store_true")  # whether to use wandb or not
    parser.add_argument("--wandb_project", default="my-project", type=str)
    parser.add_argument("--wandb_group", required=True, type=str)
    parser.add_argument("--wandb_job_type", required=True, type=str)
    parser.add_argument("--wandb_run_prefix", default="none", type=str)  # is added before the autogenerated experiment name
    parser.add_argument("--eval_seq_prefix", default="none", type=str)  # prefix used to generate sequences
    parser.add_argument("--log_dynamics", action="store_true")
    parser.add_argument("--dynamics_logger_cfg", default="./src/logger/rotational_logger.yaml", type=str)

    # Schedule
    parser.add_argument("--scheduler", default="cos", choices=["linear", "cos", "wsd", "none", "cos_inf"])
    parser.add_argument("--cos_inf_steps", default=0, type=int)
    # parser.add_argument("--cos_final_lr", default=1e-6, type=float)
    parser.add_argument("--iterations", default=15000, type=int)
    parser.add_argument("--warmup_steps", default=300, type=int)
    parser.add_argument("--lr", default=1e-3, type=float)
    # wsd
    parser.add_argument("--wsd_final_lr_scale", default=0.0, type=float)
    parser.add_argument("--wsd_fract_decay", default=0.1, type=float)
    # parser.add_argument("--wsd_exponential_decay", action="store_true")
    parser.add_argument("--decay_type", default="linear", choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"])
    # Optimization
    parser.add_argument("--opt", default="adamw", choices=[
        "adamw", "sgd", "SFAdamW", "dct-adamw", "ldadamw", "galoreadamw",
        "frugal", "fira", "apollo", "trion",
    ])
    # parser.add_argument("--use_sparse_grad", default=0, type=int, help="whether to sparsify gradients")

    parser.add_argument("--batch_size", default=50, type=int)
    parser.add_argument("--acc_steps", default=4, type=int)
    parser.add_argument("--weight_decay", default=1e-1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--grad_clip", default=1.0, type=float)  # default value is 1.0 in NanoGPT

    # Weight Averaging
    parser.add_argument("--weight_average", action="store_true")
    parser.add_argument("--wa_interval",
        default=5,
        type=int,
        help="How often to take the average (every k steps). Must divide wa-horizon.",
    )
    parser.add_argument(
        "--wa_horizon",
        default=500,
        type=int,
        help="How frequently we save uniform model averages. Should divide "
        + "latest-ckpt-interval, otherwise some points may not be saved "
        + "correctly.",
    )
    parser.add_argument(
        "--wa_dtype",
        default="float32",
        type=str,
        choices=["float32", "float64"],
    )

    parser.add_argument("--wa_use_temp_dir", action="store_true")
    parser.add_argument("--wa_sweep_horizon", action="store_true")
    parser.add_argument("--max_num_wa_sweeps", default=5, type=int)

    parser.add_argument("--exponential_moving_average", action="store_true")
    parser.add_argument("--ema_interval", default=10, type=int, help="How often to take the EMA average (every k steps).")
    parser.add_argument("--ema_decay", default=0.95, type=float, help="EMA decay parameter (between 0.9 and 1).", )
    parser.add_argument("--ema_after_warmup", action="store_true", help="Start EMA after warmup steps.", )

    # Dataset params
    parser.add_argument("--datasets_dir", type=str, default="./datasets/")
    parser.add_argument(
        "--dataset",
        default="slimpajama",
        choices=[
            "wikitext",
            "shakespeare-char",
            "arxiv",
            "arxiv2000",
            "arxiv+wiki",
            "openwebtext2",
            "redpajama",
            "slimpajama",
            "slimpajama_chunk1",
            "redpajamav2",
            "c4",
        ],
    )
    parser.add_argument("--tokenizer", default="gpt2", type=str, choices=["gpt2", "mistral"])
    parser.add_argument("--vocab_size", default=50304, type=int)
    parser.add_argument("--data_in_ram", action="store_true")  # force the data to RAM, mostly useless except for openwebtext2

    # Model params
    parser.add_argument("--model", default="llama", choices=["base", "llama",])
    parser.add_argument("--parallel_block", action="store_true")
    parser.add_argument("--use_pretrained", default="none", type=str)  # 'none', 'gpt-2' or a path to the pretraind model
    parser.add_argument("--from_dense", action="store_true")
    parser.add_argument("--init_std", default=0.02, type=float)
    parser.add_argument("--dropout", default=0.0, type=float)
    parser.add_argument("--n_head", default=12, type=int)
    parser.add_argument("--n_layer", default=24, type=int)  # depths in att + ff blocks
    parser.add_argument("--sequence_length", default=512, type=int)
    parser.add_argument("--n_embd", default=768, type=int) # embedding size / hidden size ...
    parser.add_argument("--multiple_of", default=256, type=int) # make SwiGLU hidden layer size multiple of large power of 2
    parser.add_argument("--rmsnorm_eps", default=1e-5, type=float)
    parser.add_argument("--dtype", default="bfloat16", type=str, choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--bias", default=False, type=bool)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--mlp_dim_exp_factor", default=1.0, type=float)

    # LowRank
    parser.add_argument("--lowrank_rank", type=int, default=128)
    parser.add_argument("--lowrank_proj", type=str, default='dct', choices=['dct', 'hdm', "svd", "random", "randperm", "randn-qr"])
    parser.add_argument("--lowrank_use_ef", type=int, choices=[0, 1], default=0)
    parser.add_argument("--lowrank_q_ef", type=int, choices=[0, 4, 8], default=0)
    parser.add_argument("--lowrank_max_shape", type=int, default=32_000)
    parser.add_argument("--lowrank_distributed", type=int, choices=[0, 1], default=0)
    parser.add_argument("--lowrank_rotate_states", type=int, choices=[0, 1], default=0)
    parser.add_argument("--lowrank_upd_gap", type=int, default=200)

    # ### Muon
    parser.add_argument("--muon_ns_type", type=str, default='torch', choices=['torch', 'triton'])
    parser.add_argument("--scaling_type", type=str, default='kj', choices=['kj', 'none', 'kimi', 'dion'])
    parser.add_argument("--use_makhoul", type=int, default=0, choices=[0, 1])

    return parser.parse_args(args, namespace)
