from hydra_zen import store

from run.hydra_zen import make_config

# =============================================================================
# Hydra Experiment Configs
# =============================================================================

# Store handle that registers every config passed to it under the "experiment"
# group with package "_global_".
experiment_store = store(package="_global_", group="experiment")

# Node of the ungrouped "exec_workflow" store entry; the experiment configs in
# this module list it in `bases`.
# NOTE(review): the entry has to exist before this lookup runs — presumably it
# is registered while `run.hydra_zen` is imported; confirm.
exec_workflow_config = store.get_entry(name="exec_workflow", group=None)["node"]

# ------ Base Model Experiment Config ------

# Defaults list for the plain base-model experiment: one single-entry override
# dict per config group, followed by "_self_".
base_model_hydra_defaults = [
    *(
        {group_override: option}
        for group_override, option in (
            ("override /workflow", "default"),
            ("override /cfg", "default"),
            ("override /run", "default"),
            ("override /run/save_model", "no"),
            ("override /cfg/eval", "none"),
        )
    ),
    "_self_",
]

# "base_model" experiment: only the run identifier is set on the config itself;
# it is interpolated from the model identifier at resolution time.
experiment_store(
    make_config(
        run={"identifier": "${cfg.model.identifier}__base_model"},
        hydra_defaults=base_model_hydra_defaults,
        bases=(exec_workflow_config,),
    ),
    name="base_model",
)

# Defaults list for the int8 base-model experiment: the "quantization" option
# for the workflow and cfg groups, with "bnb_int8" selected as the
# quantization config, followed by "_self_".
base_model_int8_hydra_defaults = [
    *(
        {group_override: option}
        for group_override, option in (
            ("override /workflow", "quantization"),
            ("override /cfg", "quantization"),
            ("override /run", "default"),
            ("override /run/save_model", "no"),
            ("override /cfg/eval", "none"),
            ("override /cfg/quantization/config", "bnb_int8"),
        )
    ),
    "_self_",
]

# "base_model_int8" experiment: sets the run identifier and nulls out
# cfg.super_weights.
# NOTE(review): presumably `super_weights=None` disables super-weight handling
# for this run — confirm against the cfg schema.
experiment_store(
    make_config(
        run={"identifier": "${cfg.model.identifier}__base_model_int8"},
        cfg={"super_weights": None},  # type: ignore
        hydra_defaults=base_model_int8_hydra_defaults,
        bases=(exec_workflow_config,),
    ),
    name="base_model_int8",
)

# ------ EntQuant Experiment Config ------

# Defaults shared by the experiments that use the "entquant" workflow/cfg
# options. Individual experiments append their own overrides to this list.
entquant_hydra_defaults = [
    *(
        {group_override: option}
        for group_override, option in (
            ("override /workflow", "entquant"),
            ("override /cfg", "entquant"),
            ("override /run", "default"),
            ("override /run/save_model", "no"),
            ("override /cfg/eval", "none"),
        )
    ),
    "_self_",
]

# "entquant_int8" experiment: shared EntQuant defaults plus the "int8" config
# and the "symmetric_4bit" optimizer. The run identifier interpolates the
# optimizer's reg_param and lr as well as the super-weights spike_threshold.
experiment_store(
    make_config(
        run={
            "identifier": "${cfg.model.identifier}__entquant_int8__rp${cfg.entquant.config.optimizer.reg_param}_lr${cfg.entquant.config.optimizer.lr}__sw_${cfg.super_weights.spike_threshold}"
        },
        # create base model to be compressed on CPU; the compress step gets its
        # own device map pointing the "" key at "cuda"
        cfg={"model": {"device_map": "cpu"}, "entquant": {"compress": {"device_map": {"": "cuda"}}}},  # type: ignore
        # a fresh list per experiment, so the shared defaults are never mutated
        hydra_defaults=[
            *entquant_hydra_defaults,
            {"override /cfg/entquant/config": "int8"},
            {"override /cfg/entquant/config/optimizer": "symmetric_4bit"},
        ],
        bases=(exec_workflow_config,),
    ),
    name="entquant_int8",
)

# "entquant_fp8" experiment: shared EntQuant defaults plus the "fp8" config and
# the "symmetric_4bit" optimizer. The run identifier interpolates the
# optimizer's reg_param and lr as well as the super-weights spike_threshold.
# NOTE(review): the identifier reads "fp8_rp" (single underscore) whereas the
# "entquant_int8" experiment uses "int8__rp" — confirm whether the difference
# is intentional before relying on identifier parsing.
experiment_store(
    make_config(
        run={
            "identifier": "${cfg.model.identifier}__entquant_fp8_rp${cfg.entquant.config.optimizer.reg_param}_lr${cfg.entquant.config.optimizer.lr}__sw_${cfg.super_weights.spike_threshold}"
        },
        # create base model to be compressed on CPU; the compress step gets its
        # own device map pointing the "" key at "cuda"
        cfg={"model": {"device_map": "cpu"}, "entquant": {"compress": {"device_map": {"": "cuda"}}}},  # type: ignore
        # a fresh list per experiment, so the shared defaults are never mutated
        hydra_defaults=[
            *entquant_hydra_defaults,
            {"override /cfg/entquant/config": "fp8"},
            {"override /cfg/entquant/config/optimizer": "symmetric_4bit"},
        ],
        bases=(exec_workflow_config,),
    ),
    name="entquant_fp8",
)

# "int8" experiment: shared EntQuant defaults with the "int8" config, the
# "absmax" optimizer and the "dispatch" compress option. The identifier carries
# no optimizer hyper-parameters; the base model's device_map is "cpu".
experiment_store(
    make_config(
        run={"identifier": "${cfg.model.identifier}__int8"},
        cfg={"model": {"device_map": "cpu"}},  # type: ignore
        # a fresh list per experiment, so the shared defaults are never mutated
        hydra_defaults=[
            *entquant_hydra_defaults,
            {"override /cfg/entquant/config": "int8"},
            {"override /cfg/entquant/config/optimizer": "absmax"},
            {"override /cfg/entquant/compress": "dispatch"},
        ],
        bases=(exec_workflow_config,),
    ),
    name="int8",
)

# "fp8" experiment: shared EntQuant defaults with the "fp8" config, the
# "absmax" optimizer and the "dispatch" compress option. The identifier carries
# no optimizer hyper-parameters; the base model's device_map is "cpu".
experiment_store(
    make_config(
        run={"identifier": "${cfg.model.identifier}__fp8"},
        cfg={"model": {"device_map": "cpu"}},  # type: ignore
        # a fresh list per experiment, so the shared defaults are never mutated
        hydra_defaults=[
            *entquant_hydra_defaults,
            {"override /cfg/entquant/config": "fp8"},
            {"override /cfg/entquant/config/optimizer": "absmax"},
            {"override /cfg/entquant/compress": "dispatch"},
        ],
        bases=(exec_workflow_config,),
    ),
    name="fp8",
)

# ------ BnB NF4 Experiment Config ------

# Defaults shared by the BnB NF4 experiments: the "quantization" option for the
# workflow and cfg groups, with "bnb_nf4" selected as the quantization config,
# followed by "_self_".
bnb_hydra_defaults = [
    *(
        {group_override: option}
        for group_override, option in (
            ("override /workflow", "quantization"),
            ("override /cfg", "quantization"),
            ("override /run", "default"),
            ("override /run/save_model", "no"),
            ("override /cfg/eval", "none"),
            ("override /cfg/quantization/config", "bnb_nf4"),
        )
    ),
    "_self_",
]

# "bnb_nf4" experiment: BnB NF4 defaults with cfg.super_weights set to None and
# the model's device_map set to "cuda".
experiment_store(
    make_config(
        run={"identifier": "${cfg.model.identifier}__bnb_nf4"},
        cfg={"super_weights": None, "model": {"device_map": "cuda"}},  # type: ignore
        hydra_defaults=bnb_hydra_defaults,
        bases=(exec_workflow_config,),
    ),
    name="bnb_nf4",
)

# "bnb_nf4_sw" experiment: same defaults and device_map as "bnb_nf4", but
# cfg.super_weights is left untouched and the identifier gains a "__sw" suffix.
experiment_store(
    make_config(
        run={"identifier": "${cfg.model.identifier}__bnb_nf4__sw"},
        cfg={"model": {"device_map": "cuda"}},  # type: ignore
        hydra_defaults=bnb_hydra_defaults,
        bases=(exec_workflow_config,),
    ),
    name="bnb_nf4_sw",
)

# ------ HQQ Experiment Config ------

# Defaults shared by the HQQ experiments: the "quantization" option for the
# workflow and cfg groups, with "hqq" selected as the quantization config,
# followed by "_self_".
hqq_hydra_defaults = [
    *(
        {group_override: option}
        for group_override, option in (
            ("override /workflow", "quantization"),
            ("override /cfg", "quantization"),
            ("override /run", "default"),
            ("override /run/save_model", "no"),
            ("override /cfg/eval", "none"),
            ("override /cfg/quantization/config", "hqq"),
        )
    ),
    "_self_",
]

# "hqq" experiment: HQQ defaults with cfg.super_weights set to None and the
# model's device_map set to "cuda". The run identifier interpolates the
# quantization config's nbits and group_size.
experiment_store(
    make_config(
        run={
            "identifier": "${cfg.model.identifier}__hqq_${cfg.quantization.config.nbits}bit_g${cfg.quantization.config.group_size}"
        },
        cfg={"super_weights": None, "model": {"device_map": "cuda"}},  # type: ignore
        hydra_defaults=hqq_hydra_defaults,
        bases=(exec_workflow_config,),
    ),
    name="hqq",
)

# "hqq_sw" experiment: same defaults and device_map as "hqq", but
# cfg.super_weights is left untouched and the identifier gains a "__sw" suffix.
experiment_store(
    make_config(
        run={
            "identifier": "${cfg.model.identifier}__hqq_${cfg.quantization.config.nbits}bit_g${cfg.quantization.config.group_size}__sw"
        },
        cfg={"model": {"device_map": "cuda"}},  # type: ignore
        hydra_defaults=hqq_hydra_defaults,
        bases=(exec_workflow_config,),
    ),
    name="hqq_sw",
)
