import os
from collections import OrderedDict
from copy import deepcopy
from math import sqrt

from run_parallel import run_parallel


# Identity of this experiment family. The sweep joins the two values as
# f"{family_idx}_{family_name}" to form the prefix of every run name.
family_name = "slim_pajama1B_double_grid"
family_idx = 62


# Defaults shared by every run of this family, assembled section by section.
# The sweep deep-copies `base_config` for each run and then updates entries
# of optimizer_params[0], so the values here act as the fallback.

# Single optimizer group; its "name" and "lr" are overridden per run by the
# sweep, which also adds further keys to this dict.
_optimizer_defaults = {
    "name": "nesgd-adam_infty-lmo",
    "lr": [0.001],
    "weight_decay": 0.0,
    "lr_schedule": "warmup-constant-linear",
    "warm_up_fraction": 0.05,
    "cool_down_fraction": 0.5,
}

_training_defaults = {
    "tokens_processed": 524288,  # 2^19
    "val_tokens_processed": 8388608,  # 2^23
    "batch_size": 16,
    "num_epochs": 1,
    "context_length": 1024,
    "gradnorm": 0.0,
    # Can be highest, high, or medium.
    "tensorcore_precision": "high",
    "autocast": True,
    "mixed_precision": "bfloat16",
    "compile": True,
}

# Note: "val_tokens_processed" is intentionally repeated here with the same
# value as in the training section, exactly as in the original layout.
_logging_defaults = {
    "val_tokens_processed": 8388608,
    "log_step": 256,
    "val_step": 256,
    "save_ckpt_step": 512,
    "load_ckpt_step": 0,
    "keep_last": 2,
    "ckpt_dir": "",
}

_model_defaults = {
    "n_embd": 1280,
    "n_layer": 36,
    "n_head": 20,
    "vocab_size": 50257,
    "flash_attention": True,
}

base_config = {
    "optimizer_params": [_optimizer_defaults],
    "training_params": _training_defaults,
    "logging_params": _logging_defaults,
    "gpt_model": _model_defaults,
    "dataset": {"name": "slim_pajama1B"},
    "seed": 42,
}

# Generate configs for this experiment.
# One row per algorithm under test: (label used in run names, optimizer name
# written into optimizer_params[0]["name"], value written into
# optimizer_params[0]["truncate_loss"]).
_alg_table = (
    ("muon", "nesgd-adam_infty-lmo", None),
    ("scion", "nesgd-lmo", None),
    ("muon-momo-stale", "nesgd-adam_infty-lmo-momo-stale", 2.8),
    ("muonmax-momo-stale", "nesgd-hybrid_prod-adam_2-momo-stale", 2.8),
    ("muonmax-lmo-momo-stale", "nesgd-hybrid_prod-adam_2-lmo-momo-stale", 2.8),
)
alg_settings = {
    label: {"optimizer": optimizer_name, "truncate_loss": loss_floor}
    for label, optimizer_name, loss_floor in _alg_table
}

# The two axes of the learning-rate grid. `other_lr` becomes the optimizer's
# "lr"; the ratio muon_lr / other_lr determines its "spectral_scale".
muon_lrs = [1e-4, 1e-3, 1e-2, 1e-1]
other_lrs = [1e-5, 1e-4, 1e-3, 1e-2]

# Run basenames, formatted as f"{alg}_{muon_lr}_{other_lr}", that should be
# launched; the sweep skips every grid point whose basename is not listed.
rerun = [
    "muonmax-lmo-momo-stale_0.001_0.0001",
    "muonmax-lmo-momo-stale_0.001_0.01",
    "muonmax-lmo-momo-stale_0.01_1e-05",
    "muonmax-lmo-momo-stale_0.1_0.0001",
]

# Sweep lr scale.
# Build one config per (algorithm, muon_lr, other_lr) grid point, keyed by
# run name "<family_idx>_<family_name>/<alg>_<muon_lr>_<other_lr>", in grid
# iteration order.

# temp: avoid rerunning. Only grid points whose run basename appears in
# `rerun` are kept; a set gives constant-time membership tests.
rerun_names = set(rerun)

experiment_configs = OrderedDict()
for alg, settings in alg_settings.items():
    # Whether the optimizer name contains "lmo" is fixed per algorithm, so
    # decide once here instead of at every grid point.
    uses_lmo = "lmo" in settings["optimizer"]

    for muon_lr in muon_lrs:
        for other_lr in other_lrs:
            run_name = f"{family_idx}_{family_name}/{alg}_{muon_lr}_{other_lr}"

            # Filter first so skipped grid points never pay for a deepcopy
            # of the base config.
            if os.path.basename(run_name) not in rerun_names:
                continue

            # "lmo" optimizers take the lr ratio directly as spectral_scale;
            # all other optimizers take its square root.
            s = muon_lr / other_lr
            spectral_scale = s if uses_lmo else sqrt(s)

            # A single deepcopy is enough: the copy is private to this run,
            # so the update below cannot leak into base_config or other runs.
            current_config = deepcopy(base_config)
            current_config["optimizer_params"][0].update({
                "name": settings["optimizer"],
                "lr": [other_lr],
                "spectral_scale": spectral_scale,
                "truncate_loss": settings["truncate_loss"],
            })
            experiment_configs[run_name] = current_config

# Launch runs in parallel.
run_parallel(experiment_configs)
