import torch.optim as optim
from k_level_policy_gradients.src.utils.parameters import Parameter
from k_level_policy_gradients.src.algorithms.actor_critic.gru_discrete_ddpg import (
    GRUDiscreteDDPG,
)
from k_level_policy_gradients.src.policy.gumbel import GumbelSoftmaxPolicy
from k_level_policy_gradients.src.utils.parameters import DelayedLinearParameter
from k_level_policy_gradients.src.networks.ac_networks import (
    GRUDiscreteActorNetwork,
    DiscreteCriticNetwork,
)
from k_level_policy_gradients.src.utils.replay_memory import EpisodicReplayMemory


def setup_gru_discrete_ddpg_agent(mdp_info, idx_agent, agent_params, **kwargs):
    """
    Instantiates a multiplayer DDPG Discrete agent.
    """
    target_update_frequency = agent_params["target_update_frequency"]
    batch_size = agent_params["batch_size"]
    warmup_replay_size = agent_params["warmup_replay_size"]
    max_replay_size = agent_params["max_replay_size"]
    n_features_actor = agent_params["n_features_actor"]
    lr_actor = float(agent_params["lr_actor"])
    match agent_params["actor_optimizer_class"]:
        case "adam":
            actor_optimizer_class = optim.Adam
        case "rmsprop":
            actor_optimizer_class = optim.RMSprop
    n_features_critic = agent_params["n_features_critic"]
    lr_critic = float(agent_params["lr_critic"])
    match agent_params["critic_optimizer_class"]:
        case "adam":
            critic_optimizer_class = optim.Adam
        case "rmsprop":
            critic_optimizer_class = optim.RMSprop
    centralized_critic = agent_params["centralized_critic"]
    epsilon_decay_start = agent_params["epsilon_decay_start"]
    epsilon_decay_end = agent_params["epsilon_decay_end"]
    epsilon_start_value = agent_params["epsilon_start_value"]
    epsilon_end_value = agent_params["epsilon_end_value"]
    share_agent_params = agent_params["share_agent_params"]
    use_mixer = agent_params.get("use_mixer", False)
    obs_last_action = agent_params["obs_last_action"]
    critic_obs_last_action = agent_params["critic_obs_last_action"]
    critic_agent_encoding = agent_params["critic_agent_encoding"]

    epsilon = DelayedLinearParameter(
        initial_value=epsilon_start_value,
        threshold_value=epsilon_end_value,
        n_start=epsilon_decay_start,
        n_end=epsilon_decay_end,
    )
    pi = GumbelSoftmaxPolicy(epsilon=epsilon, tau=1.0)

    actor_input_dim = mdp_info.observation_space[idx_agent].shape[0]
    critic_input_dim = mdp_info.observation_space[idx_agent].shape[0]

    if share_agent_params:
        actor_input_dim += mdp_info.n_agents
        if critic_agent_encoding:
            critic_input_dim += mdp_info.n_agents
    if obs_last_action:
        actor_input_dim += mdp_info.action_space[idx_agent].n
        if critic_obs_last_action:
            critic_input_dim += mdp_info.action_space[idx_agent].n
    if centralized_critic:
        for i in range(mdp_info.n_agents):
            critic_input_dim += mdp_info.action_space[i].n
    else:
        critic_input_dim += mdp_info.action_space[idx_agent].n

    actor_input_shape = (actor_input_dim,)
    critic_input_shape = (critic_input_dim,)

    actor_params = dict(
        input_shape=actor_input_shape,
        output_shape=(mdp_info.action_space[idx_agent].n,),
        network=GRUDiscreteActorNetwork,
        optimizer={"class": actor_optimizer_class, "params": {"lr": lr_actor}},
        n_features=n_features_actor,
        use_cuda=agent_params["use_cuda"],
    )

    critic_params = dict(
        input_shape=critic_input_shape,
        output_shape=(1,),
        network=DiscreteCriticNetwork,
        optimizer={"class": critic_optimizer_class, "params": {"lr": lr_critic}},
        n_features=n_features_critic,
        use_cuda=agent_params["use_cuda"],
    )

    replay_memory = EpisodicReplayMemory(
        max_replay_size,
    )

    if share_agent_params and idx_agent != 0:
        primary_agent = kwargs["primary_agent"]
    else:
        primary_agent = None

    agent = GRUDiscreteDDPG(
        mdp_info=mdp_info,
        idx_agent=idx_agent,
        policy=pi,
        actor_params=actor_params,
        critic_params=critic_params,
        batch_size=batch_size,
        target_update_frequency=target_update_frequency,
        warmup_replay_size=warmup_replay_size,
        replay_memory=replay_memory,
        epsilon_train=epsilon,
        use_cuda=agent_params["use_cuda"],
        primary_agent=primary_agent,
        use_mixer=use_mixer,
        obs_last_action=obs_last_action,
    )

    return agent
