import torch
import torch.nn as nn
class NormalActor(nn.Module):
    """Gaussian policy head whose hidden feature is perturbed by learned noise.

    The network is ``linear1 -> tanh -> linear2 -> tanh``; the output of the
    second layer is the feature mean ``z``.  A separate linear layer
    (``feature_sigma``), fed from the first hidden activation, predicts the
    log standard deviation of the feature.  The noisy feature
    ``z + exp(log_sigma) * eps`` with ``eps ~ N(0, 1)`` is then mapped by
    ``mu`` to the action mean.  The action standard deviation is a
    state-independent parameter stored in log space (``sigma_param``).

    Args:
        state_dim: Size of the last dimension of the input (in-features of
            ``linear1``).
        action_dim: Number of action components (out-features of ``mu``).
        device: Device inputs are converted to in ``forward`` and on which
            the standard-normal helper ``self.N`` keeps its parameters.
    """

    def __init__(self, state_dim, action_dim, device="cpu"):
        # The original called ``super(Actor, self).__init__()``.  NormalActor
        # is not a subclass of Actor, so that raised TypeError on every
        # construction; the zero-argument form always targets this class.
        super().__init__()
        self.linear1 = nn.Linear(state_dim, 64)
        self.linear2 = nn.Linear(64, 64)
        self.mu = nn.Linear(64, action_dim)
        # Input-dependent log-std of the hidden feature.  (A single learned
        # 64-vector shared by all inputs would be the input-independent
        # alternative.)
        self.feature_sigma = nn.Linear(64, 64)
        # Log of the action standard deviation, one entry per action.
        self.sigma_param = nn.Parameter(torch.zeros(action_dim, 1))
        self.device = device
        # Standard normal used to draw the feature noise.  It is not a
        # registered buffer, so its tensors are moved to ``device`` by hand.
        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.to(device)
        self.N.scale = self.N.scale.to(device)

    def forward(self, x):
        """Compute the action distribution parameters and the noisy feature.

        Args:
            x: Anything ``torch.as_tensor`` accepts; it is converted to a
                float32 tensor on ``self.device``.  The last dimension must
                be ``state_dim``; any number of leading dimensions (including
                none) is accepted.

        Returns:
            ``((mu, sigma), feature)`` where ``mu`` is the action mean,
            ``sigma`` is the action standard deviation broadcast to the shape
            of ``mu``, and ``feature`` is the sampled 64-wide hidden feature
            that ``mu`` was computed from.
        """
        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
        hidden = torch.tanh(self.linear1(x))
        # The feature log-std branches off after the first layer ...
        feature_log_sigma = self.feature_sigma(hidden)
        # ... while the feature mean goes through the second layer.
        z = torch.tanh(self.linear2(hidden))
        z_sigma = feature_log_sigma.exp()
        # Location-scale form: a fresh standard-normal draw per element,
        # scaled by the predicted std and shifted by the mean.
        feature = z + z_sigma * self.N.sample(z.size())
        mu = self.mu(feature)
        # Lay the per-action log-std along the LAST axis of ``mu``.  The
        # original used index 1, which raised IndexError for 1-D input and
        # picked the wrong axis for inputs with more than two dimensions;
        # for 2-D input both indices refer to the same axis.
        shape = [1] * len(mu.shape)
        shape[-1] = -1
        sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        return (mu, sigma), feature

class Actor(nn.Module):
    """Deterministic-feature Gaussian policy network.

    Two tanh hidden layers of width 64 feed a linear head ``mu`` that yields
    the action mean.  The action standard deviation does not depend on the
    input: it is the exponential of the learnable ``sigma_param``.

    Args:
        state_dim: Size of the last dimension of the input (in-features of
            ``linear1``).
        action_dim: Number of action components (out-features of ``mu``).
        device: Device that inputs are converted to in ``forward``.
    """

    def __init__(self, state_dim, action_dim, device="cpu"):
        super().__init__()
        self.linear1 = nn.Linear(state_dim, 64)
        self.linear2 = nn.Linear(64, 64)
        self.mu = nn.Linear(64, action_dim)
        # Log of the action standard deviation, one entry per action.
        self.sigma_param = nn.Parameter(torch.zeros(action_dim, 1))
        self.device = device

    def forward(self, x):
        """Compute the action distribution parameters and the hidden feature.

        Args:
            x: Anything ``torch.as_tensor`` accepts; it is converted to a
                float32 tensor on ``self.device``.  The last dimension must
                be ``state_dim``; any number of leading dimensions (including
                none) is accepted.

        Returns:
            ``((mu, sigma), feature)`` where ``mu`` is the action mean,
            ``sigma`` is the action standard deviation broadcast to the shape
            of ``mu``, and ``feature`` is the output of the second hidden
            layer.
        """
        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
        x = torch.tanh(self.linear1(x))
        feature = torch.tanh(self.linear2(x))
        mu = self.mu(feature)
        # Lay the per-action log-std along the LAST axis of ``mu``.  The
        # original used index 1, which raised IndexError for 1-D input and
        # picked the wrong axis for inputs with more than two dimensions;
        # for 2-D input both indices refer to the same axis.
        shape = [1] * len(mu.shape)
        shape[-1] = -1
        sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        return (mu, sigma), feature

class Critic(nn.Module):
    """State-value network: two tanh hidden layers of width 64, scalar output.

    Args:
        state_dim: Size of the last dimension of the input (in-features of
            ``linear1``).
        device: Device that inputs are converted to in ``forward``.
    """

    def __init__(self, state_dim, device="cpu"):
        super().__init__()
        hidden_width = 64
        self.linear1 = nn.Linear(state_dim, hidden_width)
        self.linear2 = nn.Linear(hidden_width, hidden_width)
        self.linear3 = nn.Linear(hidden_width, 1)
        self.device = device

    def forward(self, x):
        """Return the value estimate for ``x``.

        ``x`` may be anything ``torch.as_tensor`` accepts; it is converted to
        a float32 tensor on ``self.device``.  The result has a trailing
        dimension of size 1 (the out-features of ``linear3``).
        """
        state = torch.as_tensor(x, device=self.device, dtype=torch.float32)
        first_hidden = self.linear1(state).tanh()
        second_hidden = self.linear2(first_hidden).tanh()
        return self.linear3(second_hidden)

class CriticMmd(nn.Module):
    """Value network that also exposes its last hidden activation.

    Same layer layout as a two-hidden-layer tanh MLP with a scalar head, but
    ``forward`` returns the 64-wide output of the second hidden layer
    alongside the head output.

    Args:
        state_dim: Size of the last dimension of the input (in-features of
            ``linear1``).
        device: Device that inputs are converted to in ``forward``.
    """

    def __init__(self, state_dim, device="cpu"):
        super().__init__()
        width = 64
        self.linear1 = nn.Linear(state_dim, width)
        self.linear2 = nn.Linear(width, width)
        self.linear3 = nn.Linear(width, 1)
        self.device = device

    def forward(self, x):
        """Return ``(logits, feature)`` for ``x``.

        ``x`` may be anything ``torch.as_tensor`` accepts; it is converted to
        a float32 tensor on ``self.device``.  ``feature`` is the activation
        after the second tanh layer and ``logits`` is ``linear3(feature)``.
        """
        activation = torch.as_tensor(x, device=self.device, dtype=torch.float32)
        # Both hidden layers share the same linear -> tanh pattern.
        for layer in (self.linear1, self.linear2):
            activation = torch.tanh(layer(activation))
        return self.linear3(activation), activation

class ActorCritic(nn.Module):
    """Thin container that holds an actor and a critic as attributes.

    When the two arguments are ``nn.Module`` instances, attribute assignment
    registers them as submodules named ``actor`` and ``critic`` (in that
    order).  The class defines no ``forward`` of its own.

    Args:
        actor: Object stored as ``self.actor``.
        critic: Object stored as ``self.critic``.
    """

    def __init__(self, actor, critic):
        super().__init__()
        # Targets are assigned left to right, so ``actor`` is set first.
        self.actor, self.critic = actor, critic