from collections import OrderedDict
from copy import deepcopy
from math import sqrt

from run_parallel import run_parallel


# Identifiers of this experiment family; both are used to prefix run names.
family_idx, family_name = 47, "tune_spectral_lr_all"


# Default settings for the single optimizer parameter group.
_optimizer_defaults = {
    "name": "nesgd",
    "lr": [0.001],
    "spectral_scale": 1.0,
    "weight_decay": 0.0,
    "lr_schedule": "warmup-constant-linear",
    "warm_up_fraction": 0.05,
    "cool_down_fraction": 0.5,
}

# Training-loop settings.
_training_params = {
    "tokens_processed": 2 ** 19,  # 524288
    "val_tokens_processed": 2 ** 23,  # 8388608
    "batch_size": 64,
    "num_epochs": 1,
    "context_length": 1024,
    "gradnorm": 0.0,
    # Can be "highest", "high", or "medium".
    "tensorcore_precision": "high",
    "autocast": True,
    "mixed_precision": "bfloat16",
    "compile": True,
}

# Logging, validation and checkpointing settings.
_logging_params = {
    "val_tokens_processed": 2 ** 23,  # 8388608
    "log_step": 256,
    "val_step": 256,
    "save_ckpt_step": 512,
    "load_ckpt_step": 0,
    "keep_last": 2,
    "ckpt_dir": "",
}

# Model architecture settings.
_gpt_model = {
    "n_embd": 768,
    "n_layer": 12,
    "n_head": 12,
    "vocab_size": 50257,
    "flash_attention": True,
}

# Dataset selection.
_dataset = {"name": "fineweb1B"}

# Template configuration; every run starts from a deep copy of this dict and
# overrides entries of the first (and only) optimizer parameter group.
base_config = {
    "optimizer_params": [_optimizer_defaults],
    "training_params": _training_params,
    "logging_params": _logging_params,
    "gpt_model": _gpt_model,
    "dataset": _dataset,
}

# Generate configs for this experiment.

# Algorithms to sweep, in launch order.
algs = [
    "nesgd-adam_infty-lmo",  # muon
    "nesgd-lmo",  # scion
    "nesgd-hybrid_prod-adam_2",  # muon-gd
    "nesgd-hybrid_prod",  # scion-gd
    "nesgd-l2_prod-adam_2",  # polargrad
    "nesgd",
    "nesgd-adam_2",
    "nesgd-adam_2-lmo",
    "nesgd-adam_infty",
    "nesgd-l2_prod",
    "nesgd-l2_prod-lmo",
    "nesgd-l2_prod-adam_2-lmo",
    "nesgd-l2_prod-adam_infty",
    "nesgd-l2_prod-adam_infty-lmo",
    "nesgd-hybrid_prod-lmo",
    "nesgd-hybrid_prod-adam_2-lmo",
    "nesgd-hybrid_prod-adam_infty",
    "nesgd-hybrid_prod-adam_infty-lmo",
]

# Candidate learning-rate grids, named after their lowest and highest value.
_GRID_1E6_TO_1E1 = (1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1)
_GRID_1E6_TO_1E0 = (1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0)
_GRID_1E5_TO_1E0 = (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0)
_GRID_1E3_TO_1E0 = (1e-3, 1e-2, 1e-1, 1.0)

# One row per algorithm: (name, downscale flag, learning-rate grid).
# The flag says whether base_lr is multiplied by downscale_factor for that
# algorithm; the grid lists the values swept as `muon_lr`.
_ALG_TABLE = (
    ("nesgd", True, _GRID_1E6_TO_1E1),
    ("nesgd-lmo", False, _GRID_1E3_TO_1E0),  # scion
    ("nesgd-adam_2", True, _GRID_1E6_TO_1E1),
    ("nesgd-adam_2-lmo", False, _GRID_1E3_TO_1E0),
    ("nesgd-adam_infty", True, _GRID_1E6_TO_1E1),
    ("nesgd-adam_infty-lmo", False, _GRID_1E3_TO_1E0),  # muon
    ("nesgd-l2_prod", True, _GRID_1E6_TO_1E0),
    ("nesgd-l2_prod-lmo", False, _GRID_1E3_TO_1E0),
    ("nesgd-l2_prod-adam_2", False, _GRID_1E3_TO_1E0),  # polargrad
    ("nesgd-l2_prod-adam_2-lmo", False, _GRID_1E3_TO_1E0),
    ("nesgd-l2_prod-adam_infty", True, _GRID_1E6_TO_1E0),
    ("nesgd-l2_prod-adam_infty-lmo", False, _GRID_1E3_TO_1E0),
    ("nesgd-hybrid_prod", True, _GRID_1E6_TO_1E1),  # scion-gd
    ("nesgd-hybrid_prod-lmo", False, _GRID_1E3_TO_1E0),
    ("nesgd-hybrid_prod-adam_2", False, _GRID_1E5_TO_1E0),  # muon-gd
    ("nesgd-hybrid_prod-adam_2-lmo", False, _GRID_1E5_TO_1E0),
    ("nesgd-hybrid_prod-adam_infty", True, _GRID_1E6_TO_1E1),
    ("nesgd-hybrid_prod-adam_infty-lmo", False, _GRID_1E3_TO_1E0),
)

# Per-algorithm lookups derived from the table above.
downscale = {name: flag for name, flag, _grid in _ALG_TABLE}
# Each algorithm gets its own list object, so entries never alias each other.
muon_lrs = {name: list(grid) for name, _flag, grid in _ALG_TABLE}

# Learning rate written to the "lr" entry, and the factor applied to it for
# algorithms whose downscale flag is set.
base_lr = 1e-3
downscale_factor = 1e-3

# Sweep spectral lr.
#
# For every algorithm in `algs` and every candidate value in its `muon_lrs`
# grid, derive one complete run config from `base_config` and register it in
# `experiment_configs` under a run name built from the family, the algorithm
# and the two learning rates.
experiment_configs = OrderedDict()
for alg in algs:
    lr_grid = muon_lrs[alg]

    # The value written to "lr" depends only on the algorithm, so compute it
    # once per algorithm instead of once per grid point. float() keeps the
    # run-name formatting stable even if base_lr is ever set to an int.
    other_lr = float(base_lr)
    if downscale[alg]:
        other_lr *= downscale_factor

    # Algorithms whose name contains "lmo" take the lr ratio as-is; all others
    # take its square root.
    # NOTE(review): presumably non-LMO variants apply spectral_scale
    # quadratically, so sqrt() makes the effective spectral lr equal muon_lr —
    # confirm against the optimizer implementation.
    uses_lmo = "lmo" in alg

    for muon_lr in lr_grid:
        ratio = muon_lr / other_lr
        spectral_scale = ratio if uses_lmo else sqrt(ratio)

        # base_config is deep-copied exactly once per run; the copy is private
        # to this run, so it is stored directly without copying it again.
        current_config = deepcopy(base_config)
        current_config["optimizer_params"][0].update({
            "name": alg,
            "lr": [other_lr],
            "spectral_scale": spectral_scale,
        })
        run_name = f"{family_idx}_{family_name}/{alg}_{other_lr}_{muon_lr}"
        experiment_configs[run_name] = current_config

# Launch runs in parallel.
# NOTE(review): this executes at import time. If run_parallel starts worker
# processes that re-import this module, an `if __name__ == "__main__":` guard
# may be needed — confirm against run_parallel's implementation.
run_parallel(experiment_configs)
