import torch
import torch.nn.functional as F


class PolicyNetContinuous(torch.nn.Module):
    """Gaussian policy network for continuous action spaces.

    Maps an input ``x`` through two ReLU hidden layers to a per-dimension action
    mean squashed into ``[-bound, bound]`` by ``tanh``. The standard deviation is
    a learned, input-independent parameter stored in log space.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of both hidden layers.
        action_dim: Number of action dimensions.
        bound: Scale applied after ``tanh``; the mean lies in ``[-bound, bound]``.
        use_orthogonal: If True, initialize ``fc1``/``fc_mean`` weights with
            ``torch.nn.init.orthogonal_``; otherwise with ``xavier_uniform_``.
        gain: Gain passed to the chosen weight initializer.
    """

    def __init__(self, input_dim, hidden_dim, action_dim, bound=1.0, use_orthogonal=True, gain=0.01):
        super(PolicyNetContinuous, self).__init__()

        init_method = torch.nn.init.orthogonal_ if use_orthogonal else torch.nn.init.xavier_uniform_

        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc_mean = torch.nn.Linear(hidden_dim, action_dim)
        # log(std) starts at 0, i.e. an initial std of exp(0) = 1 per action dimension.
        self.log_std = torch.nn.Parameter(torch.zeros(action_dim))
        self.bound = bound

        # NOTE(review): fc2 is not in this list, so it keeps PyTorch's default
        # Linear initialization while fc1/fc_mean use `init_method` with `gain`.
        # Confirm whether that is intentional; changing it alters the initial
        # weights and therefore training behavior.
        for layer in (self.fc1, self.fc_mean):
            init_method(layer.weight, gain=gain)
            torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Return ``(action_mean, action_std)`` for input ``x``.

        ``action_mean`` is ``bound * tanh(fc_mean(...))``, so its last dimension is
        ``action_dim`` and every entry lies in ``[-bound, bound]``.
        ``action_std`` is ``exp(log_std)`` with shape ``(action_dim,)``; it is not
        expanded to the batch shape of ``action_mean``.

        Raises:
            ValueError: If ``x`` contains any NaN or +/-Inf entry.
        """
        # One reduction (and a single host sync on GPU) instead of separate
        # isnan/isinf passes; isfinite is False for NaN, +Inf and -Inf alike.
        if not torch.isfinite(x).all():
            raise ValueError("Input contains NaN or Inf values")

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        action_mean = self.bound * torch.tanh(self.fc_mean(x))
        action_std = self.log_std.exp()

        return action_mean, action_std


class ValueNet(torch.nn.Module):
    """Critic network producing one scalar value per input row.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear(hidden_dim, 1).

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Layers are registered in this order; keep the names so state_dict keys stay stable.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the value estimate; the last dimension of the output has size 1."""
        features = x
        for hidden_layer in (self.fc1, self.fc2):
            features = F.relu(hidden_layer(features))
        return self.fc3(features)