import torch
import torch.jit as jit
import torch.nn.functional as F
from network.mixer import QMIX_Net
from network.net import Q_network_RNN


class DQN(jit.ScriptModule):
    """Greedy Q-value policy holding online/target Q networks and mixers.

    Four sub-modules are created: ``eval_Q_net`` / ``target_Q_net``
    (``Q_network_RNN``) and ``eval_mix_net`` / ``target_mix_net``
    (``QMIX_Net``). Each target module is put in eval mode and starts with
    the same weights as its online counterpart. Only ``eval_Q_net`` is used
    by the action-selection methods defined here.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim, activation="elu"):
        """Create the sub-networks and copy online weights into the targets.

        Args:
            ob_dim: Forwarded to ``Q_network_RNN`` (first argument).
            st_dim: Forwarded to ``QMIX_Net`` (second argument).
            ac_dim: Forwarded to ``Q_network_RNN`` (second argument).
            n_agents: Forwarded to ``QMIX_Net`` (first argument).
            h_dim: Forwarded to both network constructors.
            activation: Forwarded to ``QMIX_Net``; defaults to ``"elu"``.
        """
        super().__init__()

        # Construction order is kept fixed (online Q, target Q, online mixer,
        # target mixer) so random initialisation consumes the RNG in the same
        # sequence and sub-modules are registered in the same order.
        self.eval_Q_net = Q_network_RNN(ob_dim, ac_dim, h_dim)
        target_q = Q_network_RNN(ob_dim, ac_dim, h_dim)
        target_q.load_state_dict(self.eval_Q_net.state_dict())
        self.target_Q_net = target_q.eval()

        self.eval_mix_net = QMIX_Net(n_agents, st_dim, h_dim, activation)
        target_mixer = QMIX_Net(n_agents, st_dim, h_dim, activation)
        target_mixer.load_state_dict(self.eval_mix_net.state_dict())
        self.target_mix_net = target_mixer.eval()

    @jit.script_method
    def mode(self, obs, avails):
        """Return the highest-scoring action index along the last axis.

        ``obs`` and ``avails`` are moved to the device of
        ``eval_Q_net.fc1.weight``; the result is returned on the CPU.
        ``log(avails)`` is added to the Q-values, so entries where
        ``avails`` is 0 contribute ``-inf`` and entries where it is 1
        contribute 0 (presumably ``avails`` is a 0/1 availability mask;
        confirm against the caller).
        """
        with torch.no_grad():
            target_device = self.eval_Q_net.fc1.weight.device
            # A singleton axis is inserted at position -2 for the network and
            # removed again afterwards (presumably a sequence/time axis for
            # the RNN; TODO confirm).
            net_input = obs.to(target_device).unsqueeze(-2)
            q_all = self.eval_Q_net.forward(net_input).squeeze(-2)
            log_mask = avails.to(target_device).log()
            masked_q = q_all + log_mask
            return masked_q.argmax(-1).cpu()

    @jit.script_method
    def sample(self, obs, avails):
        """Return the same actions as ``mode``; no randomness is applied here."""
        greedy_actions = self.mode(obs, avails)
        return greedy_actions
        

class SoftQ(DQN):
    """``DQN`` variant whose ``sample`` draws actions from a softmax over Q-values.

    ``mode`` stays greedy (argmax of masked Q-values); ``sample`` divides the
    Q-values by ``alpha`` and samples from the resulting softmax distribution.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim, activation="elu"):
        """Build the parent networks, then set the softmax divisor ``alpha``.

        All arguments are passed unchanged to ``DQN.__init__``.
        """
        super().__init__(ob_dim, st_dim, ac_dim, n_agents, h_dim, activation)
        # Q-values are divided by this before the softmax in ``sample``;
        # smaller values make the sampling distribution sharper.
        self.alpha = 1e-3   # or 1e-5 if overfitting

    @jit.script_method
    def mode(self, obs, avails):
        """Return the argmax action over Q-values plus ``log(avails)``.

        Inputs are moved to the device of ``eval_Q_net.fc1.weight`` and the
        chosen action indices are returned on the CPU.
        """
        with torch.no_grad():
            target_device = self.eval_Q_net.fc1.weight.device
            # Singleton axis at -2 is added for the network call and dropped
            # afterwards (presumably a sequence/time axis; TODO confirm).
            net_input = torch.unsqueeze(obs.to(target_device), -2)
            # log turns 1 -> 0 and 0 -> -inf, so zero entries can never win
            # the argmax (assumes avails is a 0/1 mask; verify with caller).
            log_mask = torch.log(avails.to(target_device))
            q_all = self.eval_Q_net.forward(net_input).squeeze(-2)
            masked_q = q_all + log_mask
            return masked_q.argmax(-1).cpu()

    @jit.script_method
    def sample(self, obs, avails):
        """Draw one action per row from ``softmax(Q / alpha + log(avails))``.

        Inputs are moved to the device of ``eval_Q_net.fc1.weight``; the
        sampled action indices are returned on the CPU.
        """
        with torch.no_grad():
            target_device = self.eval_Q_net.fc1.weight.device
            net_input = torch.unsqueeze(obs.to(target_device), -2)
            log_mask = torch.log(avails.to(target_device))
            q_all = self.eval_Q_net.forward(net_input).squeeze(-2)
            # Only the Q-values are scaled by alpha; the log-mask is added
            # unscaled, so zero-availability entries get probability 0.
            logits = q_all / self.alpha + log_mask
            action_probs = F.softmax(logits, dim=-1)
            drawn = action_probs.multinomial(1)
            return drawn.squeeze(-1).cpu()