import numpy as np
import torch as th
from torch import nn
import math

from torch.distributions import Normal

class StochasticActor(nn.Module):
    """Gaussian policy network conditioned on a pair of states.

    The two state tensors are concatenated along the feature axis, passed
    through a two-layer ReLU MLP trunk (64 hidden units each), and then fed
    to two linear heads producing the mean and the log standard deviation
    of a diagonal Gaussian over a ``dim``-dimensional output.

    The module owns its own Adam optimizer (``self.optimizer``), which
    covers the trunk *and* both heads.

    Args:
        dim: Size of the feature axis of each state tensor, and of the output.
        learning_rate: Learning rate passed to Adam.
    """

    def __init__(self, dim: int, learning_rate: float = 1e-6) -> None:
        super().__init__()
        self.dim = dim
        input_dim = 2 * dim  # the two states are concatenated feature-wise
        self.seq = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.mu = nn.Linear(64, dim)
        self.log_std = nn.Linear(64, dim)
        # The optimizer must be built only after every submodule has been
        # registered; otherwise ``self.parameters()`` misses the heads and
        # they would never receive an update from ``optimizer.step()``.
        self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate)

    def forward(
        self,
        states_0: th.Tensor,
        states_1: th.Tensor,
        deterministic: bool = False,
    ) -> th.Tensor:
        """Return the policy mean or a sample drawn around it.

        Args:
            states_0: First state tensor; its last axis must have size ``dim``.
            states_1: Second state tensor, same shape as ``states_0``.
            deterministic: If True, return the mean without adding noise.

        Returns:
            A tensor shaped like the inputs. When ``deterministic`` is False
            the result is ``mean + exp(log_std) * eps`` with ``eps`` drawn
            from a standard normal, so gradients flow through both heads.
        """
        # Concatenating on the last axis is identical to axis 1 for the usual
        # (batch, dim) input and also accepts unbatched / extra leading dims.
        inputs = th.cat((states_0, states_1), dim=-1)
        latent_pi = self.seq(inputs)
        mean = self.mu(latent_pi)
        if deterministic:
            return mean
        # NOTE(review): log_std is unbounded, so exp() can overflow for
        # extreme head outputs; consider clamping if training diverges.
        std = self.log_std(latent_pi).exp()
        # randn_like matches mean's shape, device and dtype in one call.
        noise = th.randn_like(mean)
        return mean + std * noise
