from k_level_policy_gradients.src.algorithms.value.q_mix import QMIX
import torch.optim as optim
from k_level_policy_gradients.src.utils.replay_memory import EpisodicReplayMemory


def setup_qmix_agent(mdp_info, idx_agent, agent_params, **kwargs):
    """
    Instantiates a QMIX mixing agent from a configuration mapping.

    Args:
        mdp_info: MDP/environment description, forwarded unchanged to ``QMIX``.
        idx_agent: Index of this agent, forwarded unchanged to ``QMIX``.
        agent_params: Mapping that must provide the keys ``batch_size``,
            ``max_replay_size``, ``target_update_frequency``, ``tau``,
            ``warmup_replay_size``, ``target_update_mode`` (``"soft"`` or
            ``"hard"``), ``mixing_embed_dim``, ``lr`` (any value accepted by
            ``float()``), ``optimizer_class`` (``"adam"`` or ``"rmsprop"``),
            ``scale_loss``, ``grad_norm_clip``, ``obs_last_action`` and
            ``use_cuda``.
        **kwargs: Must contain ``host_agents``, which is forwarded to ``QMIX``.

    Returns:
        The constructed ``QMIX`` agent.

    Raises:
        KeyError: If a required key is missing from ``agent_params`` or
            ``kwargs``.
        ValueError: If ``target_update_mode`` or ``optimizer_class`` has an
            unsupported value.
    """
    optimizer_classes = {"adam": optim.Adam, "rmsprop": optim.RMSprop}

    # Validate the configuration before allocating anything (e.g. the replay
    # buffer). Explicit raises are used instead of ``assert`` so the checks
    # still run under ``python -O``.
    target_update_mode = agent_params["target_update_mode"]
    if target_update_mode not in ("soft", "hard"):
        raise ValueError(
            "target_update_mode must be 'soft' or 'hard', "
            f"got {target_update_mode!r}"
        )

    optimizer_name = agent_params["optimizer_class"]
    optimizer_class = optimizer_classes.get(optimizer_name)
    if optimizer_class is None:
        # Without this check an unknown name would leave the optimizer class
        # unbound and fail later with an unhelpful error.
        raise ValueError(
            f"optimizer_class must be one of {sorted(optimizer_classes)}, "
            f"got {optimizer_name!r}"
        )

    batch_size = agent_params["batch_size"]
    max_replay_size = agent_params["max_replay_size"]
    target_update_frequency = agent_params["target_update_frequency"]
    tau = agent_params["tau"]
    warmup_replay_size = agent_params["warmup_replay_size"]
    mixing_embed_dim = agent_params["mixing_embed_dim"]  # e.g. 32
    # float() accepts config values stored as strings (e.g. "5e-4").
    lr = float(agent_params["lr"])  # e.g. 5e-4
    optimizer_params = {"class": optimizer_class, "params": {"lr": lr}}
    scale_loss = agent_params["scale_loss"]
    grad_norm_clip = agent_params["grad_norm_clip"]
    obs_last_action = agent_params["obs_last_action"]
    use_cuda = agent_params["use_cuda"]

    replay_memory = EpisodicReplayMemory(max_replay_size)

    qmix = QMIX(
        mdp_info=mdp_info,
        idx_agent=idx_agent,
        batch_size=batch_size,
        replay_memory=replay_memory,
        target_update_frequency=target_update_frequency,
        tau=tau,
        warmup_replay_size=warmup_replay_size,
        target_update_mode=target_update_mode,
        mixing_embed_dim=mixing_embed_dim,
        optimizer_params=optimizer_params,
        scale_loss=scale_loss,
        grad_norm_clip=grad_norm_clip,
        obs_last_action=obs_last_action,
        host_agents=kwargs["host_agents"],
        use_cuda=use_cuda,
    )
    return qmix
