import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

# Upper / lower bounds used to clamp the policy's predicted log standard
# deviation in GaussianPolicy.forward.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
# Small constant added inside the log of the tanh-squashing correction in
# GaussianPolicy.sample so the argument of the log stays strictly positive.
epsilon = 1e-6

"""
the input x in both networks should be [o, g], where o is the observation and g is the goal.

"""

def weights_init_(m):
    """Initialise an ``nn.Linear`` module in place; ignore any other module.

    Weights receive Xavier-uniform initialisation with gain 1 and the bias,
    when the layer has one, is set to zero. The function is used in this file
    as the callback passed to ``nn.Module.apply``.

    Args:
        m: Any module. Only instances of ``nn.Linear`` are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # nn.Linear(..., bias=False) stores ``bias`` as None; skip it instead of
    # crashing inside the initialiser.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)


class QNetwork(nn.Module):
    """Twin Q-value critic made of two independent 3-layer MLPs.

    Both MLPs read the same input: ``x`` concatenated along dim 1 with the
    actions divided by ``env_params['action_max']``. Each MLP ends in a
    single-unit linear layer.

    Args:
        env_params: Mapping that provides the integer sizes ``'obs'``,
            ``'goal'`` and ``'action'`` plus the divisor ``'action_max'``.
    """

    def __init__(self, env_params):
        super().__init__()

        self.max_action = env_params['action_max']

        hidden = 256

        # First critic: fc1 -> fc2 -> fc3
        in_features = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, 1)

        # Second critic: fc4 -> fc5 -> fc6 (same layout, separate weights)
        in_features = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc4 = nn.Linear(in_features, hidden)
        self.fc5 = nn.Linear(hidden, hidden)
        self.fc6 = nn.Linear(hidden, 1)

    def forward(self, x, actions):
        """Return the pair ``(q1, q2)`` of critic outputs for ``x`` and ``actions``.

        ``x`` and ``actions`` are concatenated along dim 1, so both must be
        2-D with matching first dimension.
        """
        critic_input = torch.cat([x, actions / self.max_action], dim=1)

        hidden_a = F.relu(self.fc1(critic_input))
        hidden_a = F.relu(self.fc2(hidden_a))
        q1 = self.fc3(hidden_a)

        hidden_b = F.relu(self.fc4(critic_input))
        hidden_b = F.relu(self.fc5(hidden_b))
        q2 = self.fc6(hidden_b)

        return q1, q2


class GaussianPolicy(nn.Module):
    """Goal-conditioned Gaussian policy with tanh squashing and rescaling.

    A two-hidden-layer MLP maps the concatenated ``(state, goal)`` input to
    the mean and log standard deviation of a Normal distribution. Samples are
    passed through ``tanh`` and then mapped with
    ``action = tanh(x) * action_scale + action_bias``.

    Args:
        env_params: Mapping that provides the integer sizes ``'obs'``,
            ``'goal'`` and ``'action'`` and the action bounds ``'high'`` and
            ``'low'`` (anything accepted by ``torch.tensor`` after
            ``(high - low) / 2`` and ``(high + low) / 2``).
        device: Device on which the action scale/bias tensors are created.
            The module's own parameters are not moved by the constructor.
    """

    def __init__(self, env_params, device):
        super(GaussianPolicy, self).__init__()
        self.device = device
        self.linear1 = nn.Linear(env_params['obs'] + env_params['goal'], 256)
        self.linear2 = nn.Linear(256, 256)

        self.mean_linear = nn.Linear(256, env_params['action'])
        self.log_std_linear = nn.Linear(256, env_params['action'])

        self.apply(weights_init_)

        # Action rescaling: half-range and midpoint of [low, high].
        action_scale = torch.tensor((env_params['high'] - env_params['low']) / 2., dtype=torch.float32)
        action_bias = torch.tensor((env_params['high'] + env_params['low']) / 2., dtype=torch.float32)
        # Registered as buffers so they follow the module through .to() /
        # .cuda() and replication instead of staying on the construction
        # device. persistent=False keeps them out of state_dict, so existing
        # checkpoints load unchanged.
        self.register_buffer('action_scale', action_scale.to(self.device), persistent=False)
        self.register_buffer('action_bias', action_bias.to(self.device), persistent=False)

    def forward(self, state, goal):
        """Return ``(mean, log_std)`` of the pre-squash Normal distribution.

        ``state`` and ``goal`` are concatenated along dim 1. ``log_std`` is
        clamped to ``[LOG_SIG_MIN, LOG_SIG_MAX]``.
        """
        x = F.relu(self.linear1(torch.cat((state, goal), dim=1)))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state, goal):
        """Draw a reparameterised action and its log-probability.

        Returns:
            A tuple ``(action, log_prob, mean)`` where ``action`` is the
            squashed and rescaled sample, ``log_prob`` is its
            log-probability summed over dim 1 (``keepdim=True``), and
            ``mean`` is the squashed and rescaled distribution mean.
        """
        mean, log_std = self.forward(state, goal)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound: change-of-variables correction for the
        # tanh squash and the affine rescale; epsilon keeps the log finite
        # when tanh saturates.
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean
