import torch
import torch.nn as nn
import torch.nn.functional as F


class PolicyNetContinuous(nn.Module):
    """Tanh MLP that maps an input to a scaled mean and a learned std.

    The mean is ``bound * tanh(fc_mean(h))`` where ``h`` is the output of two
    tanh-activated linear layers. The std is ``exp(log_std)``; ``log_std`` is a
    free parameter of size ``action_dim`` that does not depend on the input.

    Args:
        input_dim: Number of input features of the first linear layer.
        hidden_dim: Width of both hidden linear layers.
        action_dim: Size of the returned mean (last dim) and of the std vector.
        bound: Multiplier applied to the tanh-squashed mean.
        use_orthogonal: If true, weights are initialised with
            ``nn.init.orthogonal_``; otherwise ``nn.init.xavier_uniform_``.
        gain: Gain forwarded to the selected weight initialiser.
    """

    def __init__(self, input_dim, hidden_dim, action_dim,
                 bound=1.0, use_orthogonal=True, gain=0.01):
        super().__init__()
        if use_orthogonal:
            weight_init = nn.init.orthogonal_
        else:
            weight_init = nn.init.xavier_uniform_

        # Layers are created in a fixed order so that seeded construction
        # stays reproducible.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, action_dim)
        # log_std starts at zero, i.e. an initial std of one per action dim.
        self.log_std = nn.Parameter(torch.zeros(action_dim))
        self.bound = bound

        # NOTE(review): only fc1 and fc_mean receive the custom init; fc2 keeps
        # the nn.Linear default. Confirm whether that omission is intended.
        for linear in (self.fc1, self.fc_mean):
            weight_init(linear.weight, gain=gain)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return ``(mean, std)`` for input ``x``.

        ``mean`` has the shape of ``x`` with its last dim replaced by
        ``action_dim``; ``std`` is a vector of length ``action_dim``.
        """
        hidden = self.fc2(self.fc1(x).tanh()).tanh()
        mean = self.bound * self.fc_mean(hidden).tanh()
        std = torch.exp(self.log_std)
        return mean, std


class LocalValueNet(nn.Module):
    """Tanh MLP with two hidden layers and a single scalar output per input.

    Args:
        input_dim: Number of input features of the first linear layer.
        hidden_dim: Width of both hidden linear layers.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the linear head's output; the last dim of the result is 1."""
        hidden = x
        # Both hidden layers share the same tanh activation.
        for layer in (self.fc1, self.fc2):
            hidden = layer(hidden).tanh()
        return self.v(hidden)


class VDNMixer(nn.Module):
    """Parameter-free mixer that adds up the per-agent values."""

    def forward(self, local_v, state=None):
        """Sum ``local_v`` over dim 1.

        Args:
            local_v: Per-agent values, expected as ``[batch, n_agents, 1]``.
            state: Accepted but not used by this mixer.

        Returns:
            ``local_v`` reduced over dim 1 (``[batch, 1]`` for the expected
            input layout).
        """
        return torch.sum(local_v, dim=1)


class MLPMixer(nn.Module):
    """One-hidden-layer tanh MLP that mixes per-agent values into one value.

    Args:
        n_agents: Number of per-agent values fed to the first linear layer.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, n_agents, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(n_agents, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, local_v, state=None):
        """Mix ``local_v`` into a single value per batch row.

        Args:
            local_v: Per-agent values with a trailing singleton dim, which is
                squeezed away before the first linear layer.
            state: Accepted but not used by this mixer.

        Returns:
            Output of the final linear layer; its last dim is 1.
        """
        # Drop the trailing singleton so agents become the feature axis.
        per_agent = torch.squeeze(local_v, dim=-1)   # [batch, n_agents]
        hidden = self.fc1(per_agent).tanh()
        mixed = self.fc2(hidden)                     # [batch, 1]
        return mixed
