import torch
from hydra_zen import store, ZenField
from hydra_zen.typing import Builds
from omegaconf import MISSING

from entquant.compress.backend import nvCOMPBackend
from entquant.compress.manager import CompressionManager
from entquant.entquant.manager import QuantoConfig
from entquant.entquant.optimizer import SymmetricEntropyOptimizer, WrappedAbsmaxOptimizer
from run.hydra_zen import fbuilds, make_config, pbuilds

# -------- Entropy optimizer --------

# Config group under which every optimizer variant below is registered.
_OPTIMIZER_GROUP = "cfg/entquant/config/optimizer"

# Absmax optimizer config without any argument overrides.
absmax_optimizer = fbuilds(WrappedAbsmaxOptimizer)

# Base symmetric entropy optimizer config. The 3-bit and 2-bit variants are
# derived from it by overriding individual fields (reg_param, and lr for 2-bit).
entquant_optimizer_4bit = fbuilds(SymmetricEntropyOptimizer, lr=1.0, reg_param=4.0, device_compute="cuda")
entquant_optimizer_3bit = entquant_optimizer_4bit(reg_param=16.0)
entquant_optimizer_2bit = entquant_optimizer_4bit(lr=0.25, reg_param=64.0)

# Register each optimizer config under the name it is selected by.
for _optimizer_name, _optimizer_cfg in (
    ("absmax", absmax_optimizer),
    ("symmetric_4bit", entquant_optimizer_4bit),
    ("symmetric_3bit", entquant_optimizer_3bit),
    ("symmetric_2bit", entquant_optimizer_2bit),
):
    store(_optimizer_cfg, group=_OPTIMIZER_GROUP, name=_optimizer_name)

# -------- Entquant config --------

# Config group holding the QuantoConfig variants.
_QUANTO_GROUP = "cfg/entquant/config"

# Glob-style name patterns passed to QuantoConfig as ``include``.
_QUANTO_INCLUDE_PATTERNS = [
    "*mlp*.up_proj*",
    "*mlp*.down_proj*",
    "*mlp*.gate_proj*",
    "*self_attn*.q_proj*",
    "*self_attn*.k_proj*",
    "*self_attn*.v_proj*",
    "*self_attn*.o_proj*",
]

# int8 variant. ``optimizer`` is left as MISSING, so it has to be supplied when
# the config is composed (e.g. from the optimizer group); the super-weight
# optimizer is fixed to the absmax one.
quanto_config_int8 = fbuilds(
    QuantoConfig,
    include=_QUANTO_INCLUDE_PATTERNS,
    weight_qtype="qint8",
    optimizer=MISSING,
    optimizer_super_weights=fbuilds(WrappedAbsmaxOptimizer),
)

# fp8 variant: identical to the int8 config except for the weight qtype.
quanto_config_fp8 = quanto_config_int8(weight_qtype="qfloat8")

store(quanto_config_int8, group=_QUANTO_GROUP, name="int8")
store(quanto_config_fp8, group=_QUANTO_GROUP, name="fp8")

# -------- Compression --------

# Config group holding the compression options.
_COMPRESS_GROUP = "cfg/entquant/compress"

# Weight-name patterns passed to CompressionManager as ``weight_include``.
_COMPRESSED_WEIGHT_PATTERNS = [
    "*mlp*.down_proj.weight",
    "*mlp*.up_proj.weight",
    "*mlp*.gate_proj.weight",
    "*self_attn*.q_proj.weight",
    "*self_attn*.k_proj.weight",
    "*self_attn*.v_proj.weight",
    "*self_attn*.o_proj.weight",
]

# Type hint of the ``dispatch_device_map`` field in the "dispatch" option.
_DispatchDeviceMap = str | dict[str, str | torch.device | int] | None

# ``target_layer_include`` is left as MISSING and must be provided elsewhere.
compression_manager = pbuilds(
    CompressionManager,
    target_layer_include=MISSING,
    weight_include=_COMPRESSED_WEIGHT_PATTERNS,
    device_map="auto",
    backend=fbuilds(nvCOMPBackend),
)

# "default": the compression manager config defined above.
store(compression_manager, group=_COMPRESS_GROUP, name="default")
# "none": an empty config.
store(make_config(), group=_COMPRESS_GROUP, name="none")
# "dispatch": only carries a device map, defaulting to {"": "cuda"}.
store(
    make_config(dispatch_device_map=ZenField(_DispatchDeviceMap, {"": "cuda"})),
    group=_COMPRESS_GROUP,
    name="dispatch",
)

# -------- Overall config --------

# Typed fields of the top-level config; each one accepts either a builds-config
# for the named class or None, and has no default.
_entquant_fields = {
    "config": ZenField(Builds[QuantoConfig] | None),
    "compress": ZenField(Builds[CompressionManager] | None),
}

# hydra_convert="all" is essential here: without it, the union type hints in
# QuantoConfig make hydra fail. The underlying cause is that omegaconf/hydra
# treat QuantoConfig as a container, both before and after instantiation.
entquant = make_config(**_entquant_fields, hydra_convert="all")

store(entquant, group="cfg/entquant", name="default")
