import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# from https://github.com/shariqiqbal2810/maddpg-pytorch/
def make_env(scenario_name, benchmark=False):
    """Build a ``MultiAgentEnv`` for the named scenario.

    The scenario module is loaded via ``multiagent.scenarios.load`` using
    ``scenario_name + ".py"``, its ``Scenario`` class is instantiated, and a
    world is created with ``make_world()``.

    Args:
        scenario_name: Scenario module name without the ``.py`` suffix.
        benchmark: When true, ``scenario.benchmark_data`` is passed to the
            environment as an extra (fifth) positional argument.

    Returns:
        The constructed ``MultiAgentEnv`` instance.
    """
    # Imported lazily so the module can be imported without the
    # ``multiagent`` package unless an environment is actually requested.
    from multiagent.environment import MultiAgentEnv
    import multiagent.scenarios as scenarios

    scenario = scenarios.load(scenario_name + ".py").Scenario()
    world = scenario.make_world()

    env_args = [
        world,
        scenario.reset_world,
        scenario.reward,
        scenario.observation,
    ]
    if benchmark:
        # Only touched in benchmark mode, exactly like the two-branch form.
        env_args.append(scenario.benchmark_data)
    return MultiAgentEnv(*env_args)


def soft_update(t_net, s_net, tau):
    """Blend the parameters of ``s_net`` into ``t_net`` in place.

    Each parameter of ``t_net`` becomes
    ``(1 - tau) * target + tau * source``. Parameters are paired in the
    order yielded by ``parameters()``; ``s_net`` is left untouched.

    Args:
        t_net: Module whose parameters are overwritten.
        s_net: Module providing the new values.
        tau: Interpolation weight given to the source parameters.
    """
    keep = 1.0 - tau
    for target, source in zip(t_net.parameters(), s_net.parameters()):
        blended = target.data * keep + source.data * tau
        # Writing through ``.data`` keeps the update out of autograd.
        target.data.copy_(blended)


def hard_update(t_net, s_net):
    """Overwrite every parameter of ``t_net`` with the matching one of ``s_net``.

    Parameters are paired in the order yielded by ``parameters()`` and the
    values are copied in place, so the two networks keep separate storage.

    Args:
        t_net: Module whose parameters are overwritten.
        s_net: Module providing the values to copy.
    """
    for t_param, s_param in zip(t_net.parameters(), s_net.parameters()):
        # Copy directly instead of blending with weight 1: ``t * 0 + s * 1``
        # would keep any NaN/Inf already present in the target (nan * 0 is
        # nan), so a diverged target could never be reset.
        t_param.data.copy_(s_param.data)


def onehot_from_logits(action):
    """Return a float32 one-hot encoding of the arg-max along dimension 1.

    Exactly one entry per slice along dimension 1 is set to 1, even when
    several entries share the maximum value (an equality-with-max mask would
    mark all of them and yield a multi-hot vector, e.g. for all-zero logits).

    Args:
        action: Tensor of scores with the choices laid out along dimension 1.

    Returns:
        A float32 tensor of the same shape and device as ``action``; it is
        not connected to the autograd graph.
    """
    winners = action.argmax(dim=1, keepdim=True)
    one_hot = torch.zeros_like(action, dtype=torch.float32)
    one_hot.scatter_(1, winners, 1.0)
    return one_hot


def gumbel_softmax(action, device, temperature=1.0, hard=False):
    """Draw a Gumbel-softmax sample from the logits in ``action``.

    Uniform noise is sampled on the CPU as float32, moved to ``device``,
    turned into Gumbel noise and added to the logits before a
    temperature-scaled softmax over dimension 1.

    Args:
        action: Logits tensor; the softmax is taken over dimension 1.
        device: Device the noise tensor is moved to before it is combined
            with ``action``.
        temperature: Divisor applied to the perturbed logits before softmax.
        hard: When true, the forward value is the one-hot arg-max of the
            soft sample while gradients still flow through the soft sample
            (straight-through estimator).

    Returns:
        Tensor with the same shape as ``action``.
    """
    tiny = 1e-20  # keeps both logarithms away from log(0)
    uniform = torch.rand(action.shape, dtype=torch.float32).to(device)
    log_neg_log = torch.log(-torch.log(uniform + tiny) + tiny)
    perturbed = action - log_neg_log
    soft = F.softmax(perturbed / temperature, dim=1)
    if not hard:
        return soft
    one_hot = onehot_from_logits(soft)
    # Forward pass sees the one-hot value; backward pass sees ``soft``.
    return (one_hot - soft).detach() + soft


class MLPNetworks(nn.Module):
    """Three-layer fully connected network with optional input normalisation.

    Data flows through ``in_fn -> fc1 -> nonlin -> fc2 -> nonlin -> fc3 ->
    out_fn``.

    Args:
        input_dim: Number of input features.
        out_dim: Number of output features.
        hidden_dim: Width of both hidden layers.
        nonlin: Callable applied after ``fc1`` and ``fc2``.
        constrain_out: Together with ``discrete_action=False``, squashes the
            output with ``tanh`` and initialises the ``fc3`` weights
            uniformly in ``[-3e-3, 3e-3]``.
        norm_in: When true, inputs pass through ``nn.BatchNorm1d`` first.
        discrete_action: When true, the output is never squashed, regardless
            of ``constrain_out``.
    """

    def __init__(self, input_dim, out_dim, hidden_dim=64, nonlin=F.relu,
                 constrain_out=False, norm_in=True, discrete_action=True):
        super(MLPNetworks, self).__init__()

        if norm_in:
            self.in_fn = nn.BatchNorm1d(input_dim)
            self.in_fn.weight.data.fill_(1)
            self.in_fn.bias.data.fill_(0)
        else:
            # nn.Identity rather than a lambda: lambdas stored on a module
            # cannot be pickled, which breaks whole-model torch.save and
            # multiprocessing. Identity holds no parameters or buffers, so
            # state_dict keys are unaffected.
            self.in_fn = nn.Identity()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.nonlin = nonlin
        if constrain_out and not discrete_action:
            # Small final-layer weights keep the initial tanh outputs near 0.
            self.fc3.weight.data.uniform_(-3e-3, 3e-3)
            # torch.tanh replaces the deprecated F.tanh alias.
            self.out_fn = torch.tanh
        else:
            self.out_fn = nn.Identity()

    def forward(self, x):
        """Run the network on ``x`` and return the ``fc3`` output after ``out_fn``."""
        h1 = self.nonlin(self.fc1(self.in_fn(x)))
        h2 = self.nonlin(self.fc2(h1))
        out = self.out_fn(self.fc3(h2))
        return out
