import torch.nn.functional as F
import torch.optim as optim
from k_level_policy_gradients.src.algorithms.value.dqn import DQN
from k_level_policy_gradients.src.policy.td_policy import EpsGreedy
from k_level_policy_gradients.src.utils.parameters import DelayedLinearParameter
from k_level_policy_gradients.src.networks.q_network import QNetwork
from k_level_policy_gradients.src.utils.replay_memory import ReplayMemoryObsMasks


def setup_dqn_agent(mdp_info, idx_agent, agent_params, **kwargs):
    """
    Instantiates a multiplayer DQN agent.

    Args:
        mdp_info: environment description. Must expose
            ``observation_space[idx_agent].shape`` and
            ``action_space[idx_agent].n`` (the action count is read via ``.n``),
            plus ``n_agents`` when ``share_agent_params`` is enabled.
        idx_agent: index of the agent being built; selects this agent's
            observation/action space in ``mdp_info``.
        agent_params: configuration mapping. Required keys: ``batch_size``,
            ``max_replay_size``, ``target_update_frequency``, ``tau``,
            ``warmup_replay_size``, ``target_update_mode`` ("soft" or
            "hard"), ``n_features``, ``lr``, ``epsilon_decay_start``,
            ``epsilon_decay_end``, ``epsilon_start_value``,
            ``epsilon_end_value``, ``scale_loss``, ``grad_norm_clip``,
            ``obs_last_action``, ``share_agent_params``, ``use_cuda``.
            Optional: ``use_mixer`` (defaults to ``False``).
        **kwargs: ``primary_agent`` is required when ``share_agent_params``
            is true and ``idx_agent != 0``; it is forwarded to ``DQN``.

    Returns:
        The constructed ``DQN`` agent.

    Raises:
        KeyError: if a required key is missing from ``agent_params``, or if
            ``primary_agent`` is missing from ``kwargs`` when it is needed.
        ValueError: if ``target_update_mode`` is not "soft" or "hard".
    """
    batch_size = agent_params["batch_size"]
    max_replay_size = agent_params["max_replay_size"]
    target_update_frequency = agent_params["target_update_frequency"]
    tau = agent_params["tau"]
    warmup_replay_size = agent_params["warmup_replay_size"]
    target_update_mode = agent_params["target_update_mode"]
    # Explicit check rather than ``assert`` so the validation is not stripped
    # when Python runs with -O.
    if target_update_mode not in ("soft", "hard"):
        raise ValueError(
            "target_update_mode must be 'soft' or 'hard', "
            f"got {target_update_mode!r}"
        )
    n_features = agent_params["n_features"]
    # float() also accepts numeric strings such as "1e-3"; presumably some
    # configs deliver lr as a string — confirm against the config loader.
    lr = float(agent_params["lr"])
    epsilon_decay_start = agent_params["epsilon_decay_start"]
    epsilon_decay_end = agent_params["epsilon_decay_end"]
    epsilon_start_value = agent_params["epsilon_start_value"]
    epsilon_end_value = agent_params["epsilon_end_value"]
    # Exploration rate: starts at epsilon_start_value and moves toward
    # epsilon_end_value between steps n_start and n_end (exact schedule is
    # defined by DelayedLinearParameter).
    epsilon = DelayedLinearParameter(
        initial_value=epsilon_start_value,
        threshold_value=epsilon_end_value,
        n_start=epsilon_decay_start,
        n_end=epsilon_decay_end,
    )
    scale_loss = agent_params["scale_loss"]
    grad_norm_clip = agent_params["grad_norm_clip"]
    obs_last_action = agent_params["obs_last_action"]
    share_agent_params = agent_params["share_agent_params"]
    use_mixer = agent_params.get("use_mixer", False)
    use_cuda = agent_params["use_cuda"]

    pi = EpsGreedy(epsilon=epsilon)

    n_actions = mdp_info.action_space[idx_agent].n

    # Network input = raw observation, optionally extended with extra
    # features: n_actions entries when the last action is observed
    # (presumably a one-hot encoding — confirm), and n_agents entries when
    # parameters are shared (presumably a one-hot agent id — confirm).
    input_dim = mdp_info.observation_space[idx_agent].shape[0]
    if obs_last_action:
        input_dim += n_actions
    if share_agent_params:
        input_dim += mdp_info.n_agents
    input_shape = (input_dim,)

    approximator_params = dict(
        input_shape=input_shape,
        output_shape=(n_actions,),
        network=QNetwork,
        optimizer={"class": optim.Adam, "params": {"lr": lr}},
        loss=F.smooth_l1_loss,
        n_features=n_features,
        use_cuda=use_cuda,
    )

    replay_memory = ReplayMemoryObsMasks(
        max_replay_size,
        input_shape[0],
        n_actions,
        discrete_actions=True,
    )

    # With shared parameters, every agent except index 0 receives a
    # primary_agent reference (presumably the agent owning the shared
    # weights — confirm in DQN).
    primary_agent = None
    if share_agent_params and idx_agent != 0:
        try:
            primary_agent = kwargs["primary_agent"]
        except KeyError as err:
            raise KeyError(
                "primary_agent must be passed when share_agent_params is "
                f"enabled and idx_agent != 0 (got idx_agent={idx_agent})"
            ) from err

    agent = DQN(
        mdp_info=mdp_info,
        idx_agent=idx_agent,
        policy=pi,
        batch_size=batch_size,
        replay_memory=replay_memory,
        target_update_frequency=target_update_frequency,
        tau=tau,
        warmup_replay_size=warmup_replay_size,
        target_update_mode=target_update_mode,
        approximator_params=approximator_params,
        scale_loss=scale_loss,
        grad_norm_clip=grad_norm_clip,
        obs_last_action=obs_last_action,
        primary_agent=primary_agent,
        use_mixer=use_mixer,
        use_cuda=use_cuda,
    )

    return agent
