import argparse

import distributed


def none_or_str(value):
    """Translate the literal text ``"None"`` into ``None``.

    Intended as an argparse ``type=`` converter so that an option can be
    explicitly reset to ``None`` from the command line. The comparison is
    exact (case-sensitive); every other value is handed back untouched.
    """
    return None if value == "None" else value


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so such a flag can never be
    switched off explicitly. This converter understands the usual spellings
    (case-insensitively) and rejects anything it cannot interpret.

    Raises:
        argparse.ArgumentTypeError: if the token is neither a known boolean
            spelling nor a number.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "on", "1"):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if token in ("false", "f", "no", "n", "off", "0", ""):
        return False
    try:
        # Numeric fallback keeps values such as "0.0" / "1.0" working for
        # flags that were previously parsed with ``type=float``.
        return bool(float(token))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a boolean value (e.g. true/false), got {value!r}"
        ) from None


def _int_float_or_str(value):
    """Parse a token as ``int``, else ``float``, else return it unchanged.

    Used for options whose default is numeric but which had no ``type=``:
    without a converter the command-line value arrives as a string while the
    default is a number. Non-numeric tokens are passed through as-is so that
    any string value accepted before is still accepted.
    """
    for cast in (int, float):
        try:
            return cast(value)
        except (TypeError, ValueError):
            continue
    return value


def parse_args(base_parser, args, namespace):
    """Register every training option on ``base_parser`` and parse ``args``.

    Args:
        base_parser: Parser object that receives the ``add_argument`` calls
            and is finally asked to ``parse_args``. It is mutated in place.
        args: Argument list forwarded to ``base_parser.parse_args``.
        namespace: Namespace forwarded to ``base_parser.parse_args``.

    Returns:
        Whatever ``base_parser.parse_args(args, namespace)`` returns.

    Notes:
        Boolean-valued options that take an explicit value (for example
        ``--bias false``) are parsed with ``_str_to_bool``; the previous
        ``type=bool`` treated every non-empty string, including ``"False"``,
        as ``True``.
    """
    parser = base_parser

    # General training params
    parser.add_argument("--run_prefix", default=None, type=str)
    parser.add_argument("--experiment_name", default=None, type=str)
    parser.add_argument(
        "--ignore_args_more",
        nargs="+",
        type=str,
        default=[],
        help="Additional argument names to ignore when generating experiment name (appended to default ignore_args)",
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--data_seed", default=1337, type=int)
    parser.add_argument("--eval_interval", default=200, type=int)
    parser.add_argument("--full_eval_at", nargs="+", type=int)
    parser.add_argument("--eval_batches", default=64, type=int)
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument(
        "--distributed_backend",
        default=None,
        type=str,
        required=False,
        choices=distributed.registered_backends(),
    )
    parser.add_argument("--log_interval", default=50, type=int)

    # Checkpointing
    parser.add_argument("--results_base_folder", default="./exps", type=str)
    parser.add_argument("--permanent_ckpt_interval", default=0, type=int)
    parser.add_argument("--latest_ckpt_interval", default=0, type=int)
    parser.add_argument("--resume_from", default=None, type=str)
    parser.add_argument("--resume_from_swa", default=None, type=str)

    # Previously had no type, so "--auto_resume False" produced the truthy
    # string "False".
    parser.add_argument("--auto_resume", default=True, type=_str_to_bool)

    # logging params (WandB)
    parser.add_argument("--wandb", action="store_true")  # whether to use wandb or not
    parser.add_argument("--wandb_project", default="my-project", type=str)
    parser.add_argument(
        "--wandb_run_prefix", default="none", type=str
    )  # is added before the autogenerated experiment name
    parser.add_argument(
        "--eval_seq_prefix", default="none", type=str
    )  # prefix used to generate sequences
    parser.add_argument("--log_dynamics", action="store_true")
    parser.add_argument(
        "--dynamics_logger_cfg", default="./src/logger/rotational_logger.yaml", type=str
    )
    parser.add_argument("--wandb_entity", default=None, type=none_or_str)
    parser.add_argument("--log_parameter_norms", action="store_true")
    # Numeric tokens become int/float so CLI values match the numeric default;
    # anything else is kept as the raw string.
    parser.add_argument("--norm_order", default=2, type=_int_float_or_str)
    parser.add_argument(
        "--log_step_timing",
        action="store_true",
        help="Log detailed per-step timing breakdown (forward, backward, spectral, optimizer)",
    )

    # Schedule
    parser.add_argument(
        "--scheduler",
        default="cos",
        choices=["linear", "cos", "wsd", "none", "cos_inf"],
    )
    parser.add_argument(
        "--final_div_factor", default=1, type=float
    )  # cosine and linear schedulers
    parser.add_argument("--cos_inf_steps", default=0, type=int)
    # parser.add_argument("--cos-final-lr", default=1e-6, type=float)
    parser.add_argument("--iterations", default=15000, type=int)
    parser.add_argument("--warmup_steps", default=3000, type=int)
    parser.add_argument("--lr", default=1e-3, type=float)
    # wsd
    parser.add_argument("--wsd_final_lr_scale", default=0.0, type=float)
    parser.add_argument("--wsd_fract_decay", default=0.1, type=float)
    # parser.add_argument("--wsd-exponential-decay", action="store_true")
    parser.add_argument(
        "--decay_type",
        default="linear",
        choices=["linear", "cosine", "exp", "miror_cosine", "square", "sqrt"],
    )

    # Optimization
    parser.add_argument(
        "--opt",
        default="adamw",
        choices=[
            "adamw",
            "sgd",
            "sgdm",  # SGD with decoupled weight decay
            "muon",
            "soap",
            "ademamix",
            "lion",
            "sf-adamw",
            "sf-sgd",
            "signsgd",
            "signum",
            "prodigy",
            "sophiag",
            "adopt",
            "mars",
            "adafactor",
            "lamb",
            "scion",
            "scion-light",
            "d-muon",
            "muon-pytorch",  # works only with torch>=2.9
        ],
    )
    parser.add_argument("--batch_size", default=50, type=int)
    parser.add_argument("--acc_steps", default=1, type=int)
    parser.add_argument("--weight_decay", default=1e-1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument(
        "--grad_clip", default=1.0, type=float
    )  # default value is 1.0 in NanoGPT
    parser.add_argument(
        "--spectral_grad_clip",
        default="none",
        type=str,
        choices=["none", "clip"],
        help="Pre-clipping mode for raw gradients: 'none' or 'clip' (spectral clipping)",
    )
    parser.add_argument(
        "--spectral_grad_clip_c",
        default=1.0,
        type=float,
        help="Clipping threshold for spectral gradient clipping",
    )
    parser.add_argument("--momentum", default=0.9, type=float)
    parser.add_argument("--shampoo_beta", default=-1.0, type=float)
    parser.add_argument("--precondition_frequency", default=10, type=int)
    parser.add_argument("--max_precond_dim", default=10000, type=int)
    parser.add_argument("--merge_dims", default=False, type=_str_to_bool)
    parser.add_argument("--precondition_1d", default=False, type=_str_to_bool)
    parser.add_argument("--normalize_grads", default=False, type=_str_to_bool)
    parser.add_argument("--soap_data_format", default="channels_first", type=str)
    parser.add_argument("--correct_bias", default=True, type=_str_to_bool)
    parser.add_argument("--nesterov", default=False, type=_str_to_bool)
    parser.add_argument("--muon_ns_steps", default=5, type=int)
    parser.add_argument("--muon_lr_factor", default=1.0, type=float)
    parser.add_argument("--adema_beta3", default=0.9, type=float)
    parser.add_argument("--adema_alpha", default=2.0, type=float)
    parser.add_argument("--adema_beta3_warmup", default=None, type=int)
    parser.add_argument("--adema_alpha_warmup", default=None, type=int)
    parser.add_argument("--schedulefree_r", default=0.0, type=float)
    parser.add_argument("--weight_lr_power", default=2.0, type=float)
    parser.add_argument("--dampening", default=0.0, type=float)
    parser.add_argument("--prodigy_beta3", default=None, type=float)
    parser.add_argument("--prodigy_decouple", default=True, type=_str_to_bool)
    parser.add_argument(
        "--prodigy_use_bias_correction", default=False, type=_str_to_bool
    )
    parser.add_argument(
        "--prodigy_safeguard_warmup", default=False, type=_str_to_bool
    )
    parser.add_argument("--prodigy_fsdp_in_use", default=False, type=_str_to_bool)
    parser.add_argument("--sophia_rho", default=0.04, type=float)
    parser.add_argument("--sophia_bs", default=480, type=int)
    parser.add_argument(
        "--clipping_type", default="no", choices=["no", "local", "elementwise"]
    )
    parser.add_argument("--clip_eta", default=1.0, type=float)
    parser.add_argument(
        "--mars_type",
        default="mars-adamw",
        choices=["mars-adamw", "mars-lion", "mars-shampoo"],
    )
    parser.add_argument("--mars_vr_gamma", default=0.025, type=float)
    # Was type=float with a boolean default; numeric tokens (0 / 1 / 0.0 ...)
    # are still accepted through the converter's numeric fallback.
    parser.add_argument("--mars_is_approx", default=True, type=_str_to_bool)
    parser.add_argument("--mars_lr", default=3e-3, type=float)
    parser.add_argument("--mars_beta1", default=0.95, type=float)
    parser.add_argument("--mars_beta2", default=0.99, type=float)
    parser.add_argument("--adafactor_decay_rate", default=-0.8, type=float)
    parser.add_argument(
        "--lamb_use_bias_correction", default=False, type=_str_to_bool
    )
    parser.add_argument("--adopt_decouple", default=True, type=_str_to_bool)
    parser.add_argument("--adopt_eps", default=1e-6, type=float)
    parser.add_argument("--scion_lmh_scale", default=10.0, type=float)
    parser.add_argument("--scion_emb_scale", default=1.0, type=float)
    parser.add_argument("--scion_tr_scale", default=3.0, type=float)
    parser.add_argument(
        "--weight_average", action="store_true"
    )  # uniform weight averaging (or SWA)
    parser.add_argument(
        "--wa_interval",
        default=5,
        type=int,
        help="How often to take the average (every k steps). Must divide wa-horizon.",
    )
    parser.add_argument(
        "--wa_horizon",
        default=500,
        type=int,
        help="How frequently we save uniform model averages. Should divide "
        "latest-ckpt-interval, otherwise some points may not be saved "
        "correctly.",
    )
    parser.add_argument(
        "--wa_dtype",
        default="float32",
        type=str,
        choices=["float32", "float64"],
    )
    parser.add_argument("--wa_use_temp_dir", action="store_true")
    parser.add_argument("--wa_sweep_horizon", action="store_true")
    parser.add_argument("--max_num_wa_sweeps", default=5, type=int)
    parser.add_argument(
        "--exponential_weight_average", action="store_true"
    )  # EMA of weights
    parser.add_argument(
        "--ewa_interval",
        default=10,
        type=int,
        help="How often to take the EWA average (every k steps).",
    )
    parser.add_argument(
        "--ewa_decay",
        default=0.95,
        type=float,
        help="EWA decay parameter (between 0.9 and 1).",
    )
    parser.add_argument(
        "--ewa_after_warmup",
        action="store_true",
        help="Start EWA after warmup steps.",
    )

    # Dataset params
    parser.add_argument("--datasets_dir", type=str, default="./src/data/datasets/")
    parser.add_argument(
        "--dataset",
        default="slimpajama",
        choices=[
            "wikitext",
            "shakespeare-char",
            "arxiv",
            "arxiv2000",
            "arxiv+wiki",
            "openwebtext2",
            "redpajama",
            "slimpajama",
            "slimpajama_chunk1",
            "redpajamav2",
            "fineweb",
            "finewebedu",
            "c4",
            "arc_easy",  # benchmark tasks below...
            "arc_challenge",
            "hellaswag",
            "logiqa",
            "piqa",
            "sciq",
            "humaneval",
            "gsm8k",
            "kodcode",
            "mathqa",
            "medqa",
        ],
    )
    parser.add_argument(
        "--tokenizer", default="gpt2", type=str, choices=["gpt2", "mistral"]
    )
    parser.add_argument("--vocab_size", default=50304, type=int)
    parser.add_argument(
        "--data_in_ram", action="store_true"
    )  # force the data to RAM, mostly useless except for openwebtext2
    parser.add_argument(
        "--shared_memory",
        action="store_true",
        help="Use shared memory for data loading (only rank 0 loads, others attach)",
    )

    # Model params
    parser.add_argument(
        "--model",
        default="llama",
        choices=[
            "base",
            "llama",
            "mup_gpt",
            "mup_llama",
        ],
    )
    parser.add_argument("--parallel_block", action="store_true")
    parser.add_argument(
        "--use_pretrained", default="none", type=str
    )  # 'none', 'gpt-2' or a path to the pretrained model
    parser.add_argument("--from_dense", action="store_true")
    parser.add_argument("--init_std", default=0.02, type=float)
    parser.add_argument("--dropout", default=0.0, type=float)
    parser.add_argument("--n_head", default=12, type=int)
    parser.add_argument("--n_layer", default=24, type=int)  # depths in att + ff blocks
    parser.add_argument("--sequence_length", default=512, type=int)
    parser.add_argument(
        "--n_embd", default=768, type=int  # embedding size / hidden size ...
    )
    parser.add_argument(
        "--multiple_of",  # make SwiGLU hidden layer size multiple of large power of 2
        default=256,
        type=int,
    )
    parser.add_argument("--n_kv_head", default=None, type=int)  # for Adam-mini
    parser.add_argument("--rmsnorm_eps", default=1e-5, type=float)
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        type=str,
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument("--bias", default=False, type=_str_to_bool)
    parser.add_argument("--compile", action="store_true")
    parser.add_argument(
        "--untied_embeds", action="store_true"
    )  # disables weight tying between lm_head.weight and wte.weight
    parser.add_argument("--mlp_dim_exp_factor", default=1.0, type=float)

    # moe arguments
    parser.add_argument("--moe", action="store_true")
    parser.add_argument(
        "--moe_routing",
        default="standard_gating",
        type=str,
        choices=["standard_gating", "expert_choice"],
    )
    parser.add_argument("--moe_num_experts", default=8, type=int)
    parser.add_argument(  # only used for expert choice routing
        "--capacity_factor", default=2.0, type=float
    )
    parser.add_argument(  # deepseek routing, experts that are always active
        "--moe_num_shared_experts", default=0, type=int
    )
    parser.add_argument(
        "--moe_router_loss",
        default="load_balancing_z_loss",
        type=str,
        choices=["entropy", "load_balancing_only", "load_balancing_z_loss"],
    )
    parser.add_argument("--moe_num_experts_per_tok", default=2, type=int)
    parser.add_argument("--moe_entropy_loss_factor", default=0.01, type=float)
    parser.add_argument("--moe_aux_loss_factor", default=0.1, type=float)
    parser.add_argument("--moe_z_loss_factor", default=0.01, type=float)
    parser.add_argument(
        "--moe_softmax_order",
        type=str,
        default="topk_softmax",
        choices=["softmax_topk", "topk_softmax"],
    )
    parser.add_argument("--plot_router_logits", action="store_true")

    # mup arguments
    # NOTE(review): the original trailing note "the base model width that mup
    # has been configured on" sat next to --scale_emb; it presumably describes
    # --scale_base_model -- confirm against the model code.
    parser.add_argument("--scale_emb", default=10, type=int)
    parser.add_argument("--scale_base_model", default=256, type=int)
    parser.add_argument("--scale_depth", default=1.4, type=float)

    # SVD recording for spectral analysis
    parser.add_argument(
        "--record_svd",
        action="store_true",
        help="Enable recording singular values of gradients and updates",
    )
    parser.add_argument(
        "--svd_record_steps",
        nargs="+",
        type=float,
        default=[0.0, 0.05, 0.5, 0.99],
        help="Fractions of total iterations at which to record SVD (e.g., 0.0 0.05 0.5 0.99)",
    )
    parser.add_argument(
        "--svd_layers",
        nargs="+",
        type=str,
        default=["embedding", "early", "middle", "late"],
        help="Which layer positions to record: embedding, early, middle, late",
    )
    parser.add_argument(
        "--svd_save_dir",
        type=str,
        default=None,
        help="Directory to save SVD recordings (defaults to exp_dir/svd_records)",
    )

    # Noise structure analysis (extends SVD recording)
    parser.add_argument(
        "--record_noise_structure",
        action="store_true",
        help="Enable noise structure analysis at SVD recording steps",
    )
    parser.add_argument(
        "--noise_num_samples",
        type=int,
        default=4096,
        help="Number of samples for estimating true gradient G",
    )
    parser.add_argument(
        "--noise_top_k",
        type=int,
        default=5,
        help="Number of top singular vectors to store and analyze",
    )
    parser.add_argument(
        "--noise_num_repeats",
        type=int,
        default=20,
        help="Number of noise samples to analyze per recording step",
    )
    parser.add_argument(
        "--noise_batch_size",
        type=int,
        default=1,
        help="Batch size for stochastic gradient samples (default: 1)",
    )
    parser.add_argument(
        "--noise_data_seed",
        type=int,
        default=9999,
        help="Seed for noise analysis DataReader (should differ from data_seed)",
    )

    # Update noise structure analysis (analyzes optimizer update noise)
    parser.add_argument(
        "--record_update_noise",
        action="store_true",
        help="Enable update noise structure analysis (analyzes optimizer update noise instead of gradient noise)",
    )

    # Spectral post-processing for optimizer updates
    parser.add_argument(
        "--spectral_post_process",
        type=str,
        default="none",
        choices=["none", "clip", "normalize"],
        help="Apply spectral post-processing to optimizer updates: none, clip, or normalize",
    )
    parser.add_argument(
        "--spectral_clip_c",
        type=float,
        default=10.0,
        help="Clipping threshold for spectral clipping (only used when spectral_post_process=clip)",
    )
    parser.add_argument(
        "--spectral_ns_steps",
        type=int,
        default=10,
        help="Number of Newton-Schulz iterations for spectral post-processing",
    )
    parser.add_argument(
        "--spectral_apply_to",
        type=str,
        default="all",
        choices=["2d", "all"],
        help="Which parameters to apply spectral post-processing to: 2d (matrices only) or all",
    )
    parser.add_argument(
        "--disable_dynamic_clip",
        action="store_true",
        help="Disable dynamic clipping threshold during warmup (use constant clip_c)",
    )
    parser.add_argument(
        "--clip_decay_type",
        type=str,
        default="constant",
        choices=["constant", "linear", "sqrt", "cosine", "exp", "square"],
        help="Decay type for clipping threshold during WSD decay phase: constant (default), linear, sqrt, cosine, exp, square",
    )
    parser.add_argument(
        "--clip_decay_fract",
        type=float,
        default=None,
        help="Fraction of iterations for clip decay phase. If None, uses wsd_fract_decay.",
    )
    parser.add_argument(
        "--clip_final_scale",
        type=float,
        default=0.0,
        help="Final clipping threshold scale (c decays from clip_c to clip_c * clip_final_scale). Default: 0.0",
    )

    return parser.parse_args(args, namespace)
