from k_level_policy_gradients.src.algorithms.actor_critic.kmaddpg import (
    KMADDPG,
)
import torch.optim as optim
from k_level_policy_gradients.src.utils.replay_memory import ReplayMemoryObsMultiAgent
from k_level_policy_gradients.src.networks.ac_networks import MADDPGCriticNetwork


def _resolve_optimizer_class(name):
    """Map an optimizer name from the config to its ``torch.optim`` class.

    Args:
        name: Optimizer identifier, either ``"adam"`` or ``"rmsprop"``.

    Returns:
        The matching ``torch.optim`` optimizer class (the class, not an instance).

    Raises:
        ValueError: If ``name`` is not a supported optimizer.
    """
    optimizers = {"adam": optim.Adam, "rmsprop": optim.RMSprop}
    try:
        return optimizers[name]
    except KeyError:
        raise ValueError(
            f"Unsupported optimizer {name!r}; expected one of {sorted(optimizers)}"
        ) from None


def setup_kmaddpg_agent(mdp_info, idx_agent, agent_params, **kwargs):
    """
    Instantiates a continuous-action KMADDPG agent with a centralised critic.

    Args:
        mdp_info: Environment description. Must expose ``n_agents``,
            ``state_space.shape``, and per-agent ``observation_space[i].shape``
            and ``action_space[i].shape``. ``action_space`` must also be iterable.
        idx_agent: Index of the agent being built; forwarded to ``KMADDPG``.
        agent_params: Configuration mapping. Keys read: ``k_level``,
            ``batch_size``, ``max_replay_size``, ``share_agent_params``,
            ``target_update_frequency``, ``tau``, ``warmup_replay_size``,
            ``target_update_mode`` (``"soft"`` or ``"hard"``),
            ``n_features_critic``, ``lr_actors``, ``lr_critics``,
            ``actor_optimizer_class`` and ``critic_optimizer_class``
            (``"adam"`` or ``"rmsprop"``), ``scale_critic_loss``,
            ``scale_actor_loss``, ``grad_norm_clip``, ``obs_last_action``,
            ``use_cuda``.
        **kwargs: Must contain ``host_agents``, which is forwarded to ``KMADDPG``.

    Returns:
        The constructed ``KMADDPG`` agent.

    Raises:
        ValueError: If ``target_update_mode`` or either optimizer name is not
            one of the supported values.
        KeyError: If a required config key or ``host_agents`` is missing.
    """
    # Validate the enumerated options before allocating anything (notably the
    # replay buffer), so a bad config fails fast with a clear message. An
    # explicit raise is used instead of ``assert`` because asserts are
    # stripped under ``python -O``.
    target_update_mode = agent_params["target_update_mode"]
    if target_update_mode not in ("soft", "hard"):
        raise ValueError(
            "target_update_mode must be 'soft' or 'hard', "
            f"got {target_update_mode!r}"
        )
    actor_optimizer_class = _resolve_optimizer_class(
        agent_params["actor_optimizer_class"]
    )
    critic_optimizer_class = _resolve_optimizer_class(
        agent_params["critic_optimizer_class"]
    )

    n_agents = mdp_info.n_agents
    use_cuda = agent_params["use_cuda"]

    # With shared agent parameters, every observation is widened by n_agents
    # extra features. This is presumably a one-hot agent id so that a single
    # shared network can tell agents apart (TODO confirm against env wrapper).
    obs_extra_dims = n_agents if agent_params["share_agent_params"] else 0
    obs_space = [
        mdp_info.observation_space[i].shape[0] + obs_extra_dims
        for i in range(n_agents)
    ]
    replay_memory = ReplayMemoryObsMultiAgent(
        max_size=agent_params["max_replay_size"],
        state_dim=mdp_info.state_space.shape[0],
        obs_dim=obs_space,
        action_dim=[mdp_info.action_space[i].shape[0] for i in range(n_agents)],
        n_agents=n_agents,
        discrete_actions=False,
    )

    # float() coerces learning rates that may arrive from the config as
    # strings (presumably e.g. YAML "1e-3"; verify against config loader).
    actor_optimizer_params = {
        "class": actor_optimizer_class,
        "params": {"lr": float(agent_params["lr_actors"])},
    }

    # Centralised critic: its input size is the state dimension plus the
    # action dimension of every agent.
    critic_input_shape = (
        mdp_info.state_space.shape[0]
        + sum(action_space.shape[0] for action_space in mdp_info.action_space),
    )
    critic_params = dict(
        input_shape=critic_input_shape,
        output_shape=(1,),
        network=MADDPGCriticNetwork,
        optimizer={
            "class": critic_optimizer_class,
            "params": {"lr": float(agent_params["lr_critics"])},
        },
        n_features=agent_params["n_features_critic"],
        use_cuda=use_cuda,
    )

    return KMADDPG(
        k_level=agent_params["k_level"],
        mdp_info=mdp_info,
        idx_agent=idx_agent,
        batch_size=agent_params["batch_size"],
        replay_memory=replay_memory,
        target_update_frequency=agent_params["target_update_frequency"],
        tau=agent_params["tau"],
        warmup_replay_size=agent_params["warmup_replay_size"],
        target_update_mode=target_update_mode,
        actor_optimizer_params=actor_optimizer_params,
        critic_params=critic_params,
        scale_critic_loss=agent_params["scale_critic_loss"],
        scale_actor_loss=agent_params["scale_actor_loss"],
        grad_norm_clip=agent_params["grad_norm_clip"],
        obs_last_action=agent_params["obs_last_action"],
        host_agents=kwargs["host_agents"],
        use_cuda=use_cuda,
    )
