from typing import cast

import flax.linen as nn

from .adapters import FFAdapter
from .protocol import QFixProtocol
from .qfix import AdditiveQFix, QFix
from .qfix_lin import AdditiveQFixLin, QFixLin
from .qmix import QMIX
from .qmix_overcooked import QMIX_Overcooked
from .vdn import VDN
from .cnn import CNN


def make_fixee(config, *, is_overcooked: bool = False) -> nn.Module:
    """Build the base mixer ("fixee") named by ``config["QFIX"]["FIXEE"]``.

    Supported names:
      * ``"vdn"``  -> ``VDN()``
      * ``"qmix"`` -> ``QMIX_Overcooked`` with a ``CNN`` state module when
        ``is_overcooked`` is True, otherwise plain ``QMIX``.

    Args:
        config: Nested config mapping. Always reads ``config["QFIX"]["FIXEE"]``;
            for ``"qmix"`` also reads ``MIXER_EMBEDDING_DIM``,
            ``MIXER_HYPERNET_HIDDEN_DIM`` and ``MIXER_INIT_SCALE``.
        is_overcooked: Selects the Overcooked QMIX variant for ``"qmix"``.

    Returns:
        The constructed flax module.

    Raises:
        ValueError: If the fixee name is not recognised.
    """
    fixee = config["QFIX"]["FIXEE"]

    if fixee == "vdn":
        return VDN()

    if fixee == "qmix":
        mixer_args = (
            config["MIXER_EMBEDDING_DIM"],
            config["MIXER_HYPERNET_HIDDEN_DIM"],
            config["MIXER_INIT_SCALE"],
        )
        if is_overcooked:
            # The Overcooked variant is given a CNN as its state_module.
            return QMIX_Overcooked(*mixer_args, state_module=CNN())
        # NOTE(review): assumes QMIX takes the same three leading fields as
        # QMIX_Overcooked (embedding dim, hypernet hidden dim, init scale) —
        # confirm against the QMIX definition.
        return QMIX(*mixer_args)

    raise ValueError(f"Invalid {fixee=}")


def make_fixer(
    config,
    num_agents: int,
    *,
    wrap_ff_adapter: bool = False,
    is_overcooked: bool = False,
) -> nn.Module:
    """Build the QFix-family mixer ("fixer") named by ``config["QFIX"]["FIXER"]``.

    Supported names:
      * ``"qfix-lin"``  -> ``QFixLin``
      * ``"q+fix-lin"`` -> ``AdditiveQFixLin``
      * ``"qfix"``      -> ``QFix`` wrapping the module from ``make_fixee``
      * ``"q+fix"``     -> ``AdditiveQFix`` wrapping the module from ``make_fixee``

    Args:
        config: Nested config mapping. Reads ``config["QFIX"]["FIXER"]`` and
            ``config["HIDDEN_SIZE"]``, plus the optional ``config["QFIX"]``
            flags ``DETACH_ADVANTAGES`` (default True) and
            ``DEBUG_RECOVER_FIXEE`` / ``DEBUG_RECOVER_FIXEE_W`` /
            ``DEBUG_RECOVER_FIXEE_B`` (default False). ``"qfix"`` and
            ``"q+fix"`` additionally read the keys used by ``make_fixee``.
        num_agents: Number of agents; passed to the linear fixers only.
        wrap_ff_adapter: If True, build the fixer normally and return it
            wrapped in ``FFAdapter``.
        is_overcooked: Forwarded to the fixer (and fixee) constructors.

    Returns:
        The constructed flax module.

    Raises:
        ValueError: If the fixer name, or for ``"qfix"``/``"q+fix"`` the
            fixee name, is not recognised.
    """
    if wrap_ff_adapter:
        # Build the unwrapped fixer through the normal path, then adapt it.
        inner = cast(
            QFixProtocol,
            make_fixer(config, num_agents, is_overcooked=is_overcooked),
        )
        return FFAdapter(inner)

    config_qfix = config["QFIX"]
    fixer_name = config_qfix["FIXER"]

    # Validate the name before building anything, so an unknown fixer is
    # reported as such instead of failing inside make_fixee (e.g. on a
    # missing or unrelated FIXEE entry).
    if fixer_name not in ("qfix-lin", "q+fix-lin", "qfix", "q+fix"):
        raise ValueError(f"Invalid fixer name {fixer_name}")

    detach_advantages = config_qfix.get("DETACH_ADVANTAGES", True)
    # DEBUG_RECOVER_FIXEE is a shorthand that turns on both the W and B flags.
    debug_recover_fixee = config_qfix.get("DEBUG_RECOVER_FIXEE", False)
    debug_recover_fixee_w = debug_recover_fixee or config_qfix.get(
        "DEBUG_RECOVER_FIXEE_W", False
    )
    debug_recover_fixee_b = debug_recover_fixee or config_qfix.get(
        "DEBUG_RECOVER_FIXEE_B", False
    )

    # Keyword arguments shared by every fixer constructor.
    common_kwargs = dict(
        hidden_size=config["HIDDEN_SIZE"],
        debug_recover_fixee_w=debug_recover_fixee_w,
        debug_recover_fixee_b=debug_recover_fixee_b,
        is_overcooked=is_overcooked,
    )

    if fixer_name == "qfix-lin":
        return QFixLin(num_agents=num_agents, **common_kwargs)

    if fixer_name == "q+fix-lin":
        return AdditiveQFixLin(
            num_agents=num_agents,
            detach_advantages=detach_advantages,
            **common_kwargs,
        )

    # Only the non-linear fixers wrap a separately built fixee.
    fixee = make_fixee(config, is_overcooked=is_overcooked)

    if fixer_name == "qfix":
        return QFix(fixee=fixee, **common_kwargs)

    # fixer_name == "q+fix" (validated above).
    return AdditiveQFix(
        fixee=fixee,
        detach_advantages=detach_advantages,
        **common_kwargs,
    )
