


import torch
import torch.nn as nn
import torch.nn.functional as F
import math


import torch
import torch.nn as nn


class sacactor(nn.Module):
    """Gaussian policy network producing a bounded mean and a bounded log-std.

    The input is a single vector of width ``env_params['obs'] +
    env_params['goal']``. It passes through four hidden blocks
    (linear -> optional LayerNorm -> activation) with optional residual
    additions between consecutive blocks, followed by two linear heads.

    Args:
        env_params: mapping providing ``'obs'``, ``'goal'``, ``'action'`` and
            ``'action_max'``.
        network_width: width of every hidden layer.
        network_depth: stored on the instance only; the trunk always has four
            hidden layers regardless of this value.
        skip_connections: used purely as a truthy/falsy flag; when truthy the
            output of each hidden block (from the second on) is added to the
            output of the previous block.
        use_relu: ReLU activations when true, SiLU otherwise.
        use_ln: LayerNorm after each hidden linear layer when true, identity
            otherwise.
        log_std_min: lower bound of the returned log-std.
        log_std_max: upper bound of the returned log-std.
    """

    def __init__(
        self,
        env_params,
        network_width=256,
        network_depth=4,
        skip_connections=4,
        use_relu=True,
        use_ln=True,
        log_std_min=-5,
        log_std_max=2
    ):
        super().__init__()
        # Environment-derived sizes.
        self.max_action = env_params['action_max']
        self.action_dim = env_params['action']
        self.input_dim = env_params['obs'] + env_params['goal']

        # Keep the hyper-parameters around for inspection by callers.
        self.network_width = network_width
        self.network_depth = network_depth
        self.skip_connections = skip_connections
        self.use_relu = use_relu
        self.use_ln = use_ln
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        # Hidden linear layers (registered before the norms so that the
        # module order, and therefore seeded initialisation, is stable).
        self.fc1 = nn.Linear(self.input_dim, network_width)
        self.fc2 = nn.Linear(network_width, network_width)
        self.fc3 = nn.Linear(network_width, network_width)
        self.fc4 = nn.Linear(network_width, network_width)

        # One normalisation (or pass-through) module per hidden layer.
        def make_norm():
            return nn.LayerNorm(network_width) if self.use_ln else nn.Identity()

        self.ln1 = make_norm()
        self.ln2 = make_norm()
        self.ln3 = make_norm()
        self.ln4 = make_norm()

        # Output heads.
        self.mean_out = nn.Linear(network_width, self.action_dim)
        self.log_std_out = nn.Linear(network_width, self.action_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-draw every linear weight uniformly in +/- sqrt(1 / (3 * fan_in)) and zero the biases."""
        linears = (m for m in self.modules() if isinstance(m, nn.Linear))
        for linear in linears:
            bound = (1.0 / (3 * linear.weight.size(1))) ** 0.5
            nn.init.uniform_(linear.weight, -bound, bound)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return ``(mean, log_std)`` for the input batch ``x``.

        ``mean`` is ``max_action * tanh(.)``; ``log_std`` is a tanh output
        rescaled linearly into ``[log_std_min, log_std_max]``.
        """
        activation = F.relu if self.use_relu else F.silu

        hidden = activation(self.ln1(self.fc1(x)))
        later_blocks = (
            (self.fc2, self.ln2),
            (self.fc3, self.ln3),
            (self.fc4, self.ln4),
        )
        for linear, norm in later_blocks:
            previous = hidden
            hidden = activation(norm(linear(hidden)))
            if self.skip_connections:
                # Residual link from the preceding block's output.
                hidden = hidden + previous

        mean = self.max_action * torch.tanh(self.mean_out(hidden))

        squashed = torch.tanh(self.log_std_out(hidden))
        # Map tanh's (-1, 1) range onto (log_std_min, log_std_max).
        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (squashed + 1)

        return mean, log_std





class SAEncoder(nn.Module):
    """MLP that encodes a vector of width ``obs + action`` into ``repr_dim`` features.

    Two hidden layers of width 256 with SiLU (swish) activations are followed
    by a linear output layer with no activation.

    Args:
        env_params: mapping providing ``'obs'`` and ``'action'`` sizes.
        repr_dim: size of the returned representation.
    """

    def __init__(self, env_params, repr_dim=64):
        super().__init__()
        # Input width: observation size plus action size.
        input_dim = env_params['obs'] + env_params['action']
        network_width = 256
        network_depth = 2

        # Hidden stack: one input projection, then (depth - 1) square layers.
        hidden = [nn.Linear(input_dim, network_width)]
        hidden += [nn.Linear(network_width, network_width) for _ in range(network_depth - 1)]
        self.layers = nn.ModuleList(hidden)
        self.output_layer = nn.Linear(network_width, repr_dim)

        # LeCun-uniform initialisation: weights in +/- sqrt(3 / fan_in), zero biases.
        for linear in list(self.layers) + [self.output_layer]:
            if not isinstance(linear, nn.Linear):
                continue
            bound = math.sqrt(3.0 / linear.in_features)
            nn.init.uniform_(linear.weight, -bound, bound)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Apply the SiLU-activated hidden layers, then the linear output layer."""
        features = x
        for hidden_layer in self.layers:
            features = F.silu(hidden_layer(features))
        return self.output_layer(features)

class GEncoder(nn.Module):
    """MLP that encodes a goal vector into ``repr_dim`` features.

    Two hidden layers of width 256 with SiLU (swish) activations are followed
    by a linear output layer with no activation.

    Args:
        env_params: mapping providing the ``'goal'`` size.
        repr_dim: size of the returned representation.
    """

    def __init__(self, env_params, repr_dim=64):
        super().__init__()
        # Input width: goal size only.
        input_dim = env_params['goal']
        network_width = 256
        network_depth = 2

        # Hidden stack: one input projection, then (depth - 1) square layers.
        hidden = [nn.Linear(input_dim, network_width)]
        hidden += [nn.Linear(network_width, network_width) for _ in range(network_depth - 1)]
        self.layers = nn.ModuleList(hidden)
        self.output_layer = nn.Linear(network_width, repr_dim)

        # LeCun-uniform initialisation: weights in +/- sqrt(3 / fan_in), zero biases.
        for linear in list(self.layers) + [self.output_layer]:
            if not isinstance(linear, nn.Linear):
                continue
            bound = math.sqrt(3.0 / linear.in_features)
            nn.init.uniform_(linear.weight, -bound, bound)
            nn.init.zeros_(linear.bias)

    def forward(self, goals):
        """Apply the SiLU-activated hidden layers, then the linear output layer."""
        features = goals
        for hidden_layer in self.layers:
            features = F.silu(hidden_layer(features))
        return self.output_layer(features)

class expert(nn.Module):
    """Two-hidden-layer (512 wide) MLP mapping two inputs to bounded actions.

    The two inputs are concatenated along dimension 1 into a vector of width
    ``obs + goal``; the output is ``max_action * tanh(.)``.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 512
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'], width)
        self.fc2 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on dim 1 and return the scaled tanh actions."""
        hidden = torch.cat((x, y), 1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.max_action * torch.tanh(self.action_out(hidden))

class Policyactor(nn.Module):
    """Three-hidden-layer (256 wide) ReLU MLP from an ``obs``-wide input to bounded actions.

    The output is ``max_action * tanh(.)`` with ``env_params['action']`` components.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 256
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(env_params['obs'], width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x):
        """Return the scaled tanh actions for input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.max_action * torch.tanh(self.action_out(hidden))

class actor(nn.Module):
    """Three-hidden-layer (256 wide) ReLU MLP from an ``obs + goal``-wide input to bounded actions.

    The output is ``max_action * tanh(.)`` with ``env_params['action']`` components.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 256
        in_features = env_params['obs'] + env_params['goal']
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x):
        """Return the scaled tanh actions for the already-concatenated input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.max_action * torch.tanh(self.action_out(hidden))
    
    
class actor_nogoal(nn.Module):
    """Three-hidden-layer (256 wide) ReLU MLP from an ``obs``-wide input to bounded actions.

    No goal is part of the input; the output is ``max_action * tanh(.)``.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 256
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(env_params['obs'], width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x):
        """Return the scaled tanh actions for input ``x``."""
        hidden = F.relu(self.fc3(F.relu(self.fc2(F.relu(self.fc1(x))))))
        squashed = torch.tanh(self.action_out(hidden))
        return self.max_action * squashed

class expactor(nn.Module):
    """Three-hidden-layer (256 wide) ReLU MLP over two concatenated inputs.

    Unlike the other actors in this file, the output of the final linear
    layer is returned as-is: no tanh and no scaling by ``max_action``.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 256
        in_features = env_params['obs'] + env_params['goal']
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on dim 1 and return the raw output-layer values."""
        hidden = torch.cat((x, y), dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.action_out(hidden)

class critic(nn.Module):
    """Q-network: three-hidden-layer (256 wide) ReLU MLP with a scalar output.

    The first layer expects width ``obs + goal + action``. Actions are divided
    by ``max_action`` before being concatenated to the other input.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 256
        in_features = env_params['obs'] + env_params['goal'] + env_params['action']
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, x, actions):
        """Return the Q-value for input ``x`` paired with ``actions``."""
        scaled_actions = actions / self.max_action
        hidden = torch.cat([x, scaled_actions], dim=-1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.q_out(hidden)


class critic_nogoal(nn.Module):
    """Q-network without a goal input: three 256-wide ReLU layers and a scalar output.

    The first layer expects width ``obs + action``. Actions are divided by
    ``max_action`` before being concatenated to the other input on dim 1.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 256
        in_features = env_params['obs'] + env_params['action']
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, x, actions):
        """Return the Q-value for input ``x`` paired with ``actions``."""
        scaled_actions = actions / self.max_action
        hidden = torch.cat([x, scaled_actions], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.q_out(hidden)


class critic2(nn.Module):
    """Twin Q-network: two independent three-hidden-layer MLPs over the same input.

    Both branches take a vector of width ``obs + goal + action`` built by
    concatenating ``x`` with ``actions / max_action`` on dim 1.

    The first branch is ``fc1 -> fc2 -> fc3 -> q_out``; the second is
    ``fc11 -> fc21 -> fc31 -> q_out1``. Previously the second branch was
    routed through ``q_out`` as well, which tied the two estimates to one
    output layer and left ``q_out1`` unused; each branch now uses its own head.
    """

    def __init__(self, env_params):
        super(critic2, self).__init__()
        self.max_action = env_params['action_max']
        # First Q branch.
        self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['action'], 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.q_out = nn.Linear(256, 1)

        # Second Q branch.
        self.fc11 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['action'], 256)
        self.fc21 = nn.Linear(256, 256)
        self.fc31 = nn.Linear(256, 256)
        self.q_out1 = nn.Linear(256, 1)

    def forward(self, x, actions):
        """Return ``(q_value1, q_value2)``, one estimate per branch."""
        xin = torch.cat([x, actions / self.max_action], dim=1)

        x = F.relu(self.fc1(xin))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        q_value1 = self.q_out(x)

        x = F.relu(self.fc11(xin))
        x = F.relu(self.fc21(x))
        x = F.relu(self.fc31(x))
        # Use the second branch's own head (was self.q_out by mistake).
        q_value2 = self.q_out1(x)

        return q_value1, q_value2

    def Q1(self, x, actions):
        """Return only the first branch's estimate (same computation as ``forward``'s ``q_value1``)."""
        xin = torch.cat([x, actions / self.max_action], dim=1)
        x = F.relu(self.fc1(xin))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        q_value1 = self.q_out(x)

        return q_value1



class Subactor(nn.Module):
    """Three-hidden-layer (256 wide) ReLU MLP from an ``obs + 2 * goal``-wide input to bounded actions.

    The output is ``max_action * tanh(.)`` with ``env_params['action']`` components.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 256
        in_features = env_params['obs'] + 2*env_params['goal']
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x):
        """Return the scaled tanh actions for the already-concatenated input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.max_action * torch.tanh(self.action_out(hidden))


# Variational auto-encoder whose encoder and decoder are both 2-layer LSTMs.
class VAERNN(nn.Module):
    """Sequence VAE: an LSTM encoder producing a per-step latent and an LSTM decoder.

    The encoder consumes ``state``, ``desgoal`` and ``hergoal`` concatenated on
    the last dimension (width ``obs + 2 * goal``); the decoder consumes
    ``state``, ``desgoal`` and the latent (width ``obs + goal + latent_dim``)
    and emits ``goal``-wide outputs.

    Args:
        env_params: mapping providing ``'obs'`` and ``'goal'`` sizes.
        id: CUDA device index; a negative value selects the CPU.

    NOTE(review): the initial LSTM states (``h0``/``c0``/``h1``/``c1``) are
    plain tensor attributes drawn once from ``torch.randn`` for a fixed batch
    size of 256. They are not parameters or buffers, so they are not part of
    ``state_dict()`` and are not moved by ``Module.to``. Confirm this is
    intended.
    """

    def __init__(self, env_params, id):
        super(VAERNN, self).__init__()
        self.env_params = env_params
        self.latent_dim = 32
        # Batch size baked into the initial LSTM states and into forward()'s reshape.
        self.batchsize = 256
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 256

        # Encoder LSTM over [state, desgoal, hergoal].
        self.lstm = torch.nn.LSTM(input_size=env_params['obs'] + 2 * env_params['goal'],
                                  hidden_size=dim, num_layers=2)
        # Only the LSTMs are moved to self.device here; the linear layers are not.
        self.lstm.to(self.device)
        self.h0 = torch.randn(2, self.batchsize, dim).to(self.device)  # num_layers * num_directions, batch, hidden_size
        self.c0 = torch.randn(2, self.batchsize, dim).to(self.device)

        # Heads mapping encoder outputs to the latent mean and log-std.
        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        # Decoder LSTM over [state, desgoal, z].
        self.lstm2 = torch.nn.LSTM(
            input_size=env_params['goal'] + env_params['obs'] + self.latent_dim,
            hidden_size=dim, num_layers=2)
        self.lstm2.to(self.device)
        self.h1 = torch.randn(2, self.batchsize, dim).to(self.device)  # num_layers * num_directions, batch, hidden_size
        self.c1 = torch.randn(2, self.batchsize, dim).to(self.device)

        # Projects decoder LSTM outputs to goal-sized vectors.
        self.d = nn.Linear(dim, env_params['goal'])

    def forward(self, state, desgoal, hergoal):
        """Encode the inputs, sample a latent via reparameterisation and decode it.

        Returns:
            ``(u, mean, std)`` where ``u`` is the decoder output and ``mean`` /
            ``std`` are the latent distribution parameters.

        Assumes the inputs are 3-D with the sequence length on dim 1
        (``state.shape[1]`` is used as the length) and a leading size equal to
        ``self.batchsize`` — TODO confirm against callers.
        """
        # context seq_len, batch, input_size
        lenght = state.shape[1]
        context = torch.cat((state, desgoal, hergoal), 2)
        # NOTE(review): reshape reinterprets the element order as
        # (length, batch, features); it does not swap the first two axes the way
        # transpose/permute would. Confirm that reshape is what is intended.
        context = torch.reshape(context, (lenght, self.batchsize,
                                          2 * self.env_params['goal'] + self.env_params['obs']))
        # The final hidden/cell states are kept on the instance.
        output, (self.hn, self.cn) = self.lstm(context, (self.h0, self.c0))

        # Swap the first two axes of the LSTM output before applying the heads.
        mean = self.mean(output.transpose(1,0))
        log_std = self.log_std(output.transpose(1,0))
        std = torch.exp(log_std)

        # Reparameterisation: z = mean + std * eps, eps ~ N(0, I).
        z = mean + std * torch.randn_like(std).to(self.device)

        u = self.decode(state, desgoal, z)
        return u, mean, std

    def decode(self, state, desgoal, z=None):
        """Decode ``z`` (or a clamped prior sample when ``z`` is None) given ``state`` and ``desgoal``.

        When ``z`` is None, one standard-normal sample per leading index of
        ``state`` is drawn, clamped to ``[-0.5, 0.5]`` and repeated along the
        length dimension.

        NOTE(review): the decoder's initial states ``h1``/``c1`` are sized for
        a batch of 256, while ``batchsize`` here comes from ``state.size(0)``;
        other batch sizes presumably fail inside the LSTM — verify.
        """
        lenght = state.shape[1]
        if z is None:
            dis = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = dis.sample((state.shape[0],)).clamp(-0.5, 0.5).to(self.device)
            # Broadcast the single latent across every step: (batch, 1, latent) -> (batch, length, latent).
            z = torch.tile(torch.unsqueeze(z, 1), (1, lenght, 1))
        batchsize = state.size(0)
        context = torch.cat((state, desgoal, z), -1)
        # NOTE(review): as in forward(), these are reshapes rather than axis
        # swaps, both into and out of the LSTM layout — confirm intended.
        context = torch.reshape(context, (lenght, batchsize,
                                          self.latent_dim + self.env_params['goal'] +
                                          self.env_params['obs']))
        output, (hn, cn) = self.lstm2(context, (self.h1, self.c1))
        output = torch.reshape(output, (batchsize, lenght, -1))

        d = self.d(output)
        return d


class VAE(nn.Module):
    """Conditional VAE with a 64-dimensional latent and 512-wide MLPs.

    The encoder sees ``[aggoal, desgoal, sampledgoal]`` (width
    ``obs + 2 * goal``); the decoder sees ``[aggoal, desgoal, z]`` and outputs
    a ``goal``-wide vector.

    Args:
        env_params: mapping providing ``'obs'`` and ``'goal'`` sizes.
        id: CUDA device index; a negative value selects the CPU.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 512
        obs_dim, goal_dim = env_params['obs'], env_params['goal']

        # Encoder trunk.
        self.e1 = nn.Linear(obs_dim + 2*goal_dim, dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        # Latent distribution heads.
        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        # Decoder.
        self.d1 = nn.Linear(obs_dim + goal_dim + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, goal_dim)

    def _sample_prior(self, count):
        """Draw ``count`` standard-normal latents, clamp to [-0.5, 0.5], move to ``self.device``."""
        prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
        return prior.sample((count,)).clamp(-0.5, 0.5).to(self.device)

    def forward(self, aggoal, desgoal, sampledgoal):
        """Encode, sample a latent by reparameterisation and decode.

        The three encoder activations are kept on the instance as ``z1``,
        ``z2`` and ``z3`` (``decode2`` reads them).

        Returns:
            ``(u, mean, std)``: decoder output and latent mean / std.
        """
        encoder_input = torch.cat([aggoal, desgoal, sampledgoal], 1)
        self.z1 = F.relu(self.e1(encoder_input))
        self.z2 = F.relu(self.e2(self.z1))
        self.z3 = F.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        std = torch.exp(self.log_std(self.z3))
        noise = torch.randn_like(std).to(self.device)
        latent = mean + std * noise

        return self.decode(aggoal, desgoal, latent), mean, std

    def decode(self, aggoal, desgoal,z=None):
        """Decode ``z``; when ``z`` is None a clamped prior sample per row of ``aggoal`` is used."""
        if z is None:
            z = self._sample_prior(aggoal.shape[0])
        hidden = F.relu(self.d1(torch.cat([aggoal, desgoal,z], 1)))
        hidden = F.relu(self.d2(hidden))
        return self.d3(hidden)

    def decode2(self, aggoal, desgoal,z=None):
        """Like ``decode`` but adds the encoder activations stored by the last ``forward`` call.

        ``z3`` is added before the first ReLU, ``z2`` before the second and
        ``z1`` before the output layer, so ``forward`` must have been called first.
        """
        if z is None:
            z = self._sample_prior(aggoal.shape[0])
        hidden = F.relu(self.d1(torch.cat([aggoal, desgoal,z], 1)) + self.z3)
        hidden = F.relu(self.d2(hidden) + self.z2)
        return self.d3(hidden + self.z1)

class VAE_NoGoal(nn.Module):
    """Conditional VAE without a desired-goal input (64-dim latent, 256-wide MLPs).

    The encoder sees ``[aggoal, sampledgoal]`` (width ``obs + goal``); the
    decoder sees ``[aggoal, z]`` and additionally adds the encoder activations
    saved by the most recent ``forward`` call.

    Args:
        env_params: mapping providing ``'obs'`` and ``'goal'`` sizes.
        id: CUDA device index; a negative value selects the CPU.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 256
        obs_dim, goal_dim = env_params['obs'], env_params['goal']

        # Encoder trunk.
        self.e1 = nn.Linear(obs_dim + goal_dim, dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        # Latent distribution heads.
        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        # Decoder.
        self.d1 = nn.Linear(obs_dim + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, goal_dim)

    def forward(self, aggoal, sampledgoal):
        """Encode, sample a latent by reparameterisation and decode.

        The encoder activations are stored as ``z1``, ``z2`` and ``z3`` for
        use inside ``decode``.

        Returns:
            ``(u, mean, std)``: decoder output and latent mean / std.
        """
        encoder_input = torch.cat([aggoal, sampledgoal], 1)
        self.z1 = F.relu(self.e1(encoder_input))
        self.z2 = F.relu(self.e2(self.z1))
        self.z3 = F.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        std = torch.exp(self.log_std(self.z3))
        noise = torch.randn_like(std).to(self.device)
        latent = mean + std * noise

        return self.decode(aggoal, latent), mean, std

    def decode(self, aggoal,z=None):
        """Decode ``z`` (or a prior sample clamped to [-0.5, 0.5] when ``z`` is None).

        Reads ``self.z1``/``z2``/``z3``, so ``forward`` must have run at least once.
        """
        if z is None:
            prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = prior.sample((aggoal.shape[0],)).clamp(-0.5, 0.5).to(self.device)
        hidden = F.relu(self.d1(torch.cat([aggoal,z], 1)) + self.z3)
        hidden = F.relu(self.d2(hidden) + self.z2)
        return self.d3(hidden + self.z1)


class GeneratorFull(nn.Module):
    """Plain MLP mapping ``[aggoal, desgoal]`` (width ``obs + goal``) to a ``goal``-wide vector.

    Three 128-wide ReLU layers are followed by a linear output layer.
    ``latent_dim`` and ``device`` are set on the instance but not used by
    ``forward``.

    Args:
        env_params: mapping providing ``'obs'`` and ``'goal'`` sizes.
        id: CUDA device index; a negative value selects the CPU.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 8
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 128
        self.e1 = nn.Linear(env_params['obs'] + env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)
        self.e4 = nn.Linear(dim, env_params['goal'])

    def forward(self, aggoal, desgoal):
        """Return the generated goal; hidden activations are kept as ``z1``, ``z2``, ``z3``."""
        joined = torch.cat([aggoal, desgoal], 1)
        self.z1 = F.relu(self.e1(joined))
        self.z2 = F.relu(self.e2(self.z1))
        self.z3 = F.relu(self.e3(self.z2))
        return self.e4(self.z3)


class VAERIS(nn.Module):
    """Conditional VAE with a 16-dimensional latent and 128-wide MLPs.

    The encoder sees ``[aggoal, desgoal, sampledgoal]`` (width
    ``obs + 2 * goal``); the decoder sees ``[aggoal, desgoal, z]`` and outputs
    a ``goal``-wide vector.

    Args:
        env_params: mapping providing ``'obs'`` and ``'goal'`` sizes.
        id: CUDA device index; a negative value selects the CPU.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 16
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 128
        obs_dim, goal_dim = env_params['obs'], env_params['goal']

        # Encoder trunk.
        self.e1 = nn.Linear(obs_dim + 2 * goal_dim, dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        # Latent distribution heads.
        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        # Decoder.
        self.d1 = nn.Linear(obs_dim + goal_dim + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, goal_dim)

    def forward(self, aggoal, desgoal, sampledgoal):
        """Encode, sample a latent by reparameterisation and decode.

        The encoder activations are kept on the instance as ``z1``, ``z2``, ``z3``.

        Returns:
            ``(u, mean, std)``: decoder output and latent mean / std.
        """
        encoder_input = torch.cat([aggoal, desgoal, sampledgoal], 1)
        self.z1 = F.relu(self.e1(encoder_input))
        self.z2 = F.relu(self.e2(self.z1))
        self.z3 = F.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        std = torch.exp(self.log_std(self.z3))
        noise = torch.randn_like(std).to(self.device)
        latent = mean + std * noise

        return self.decode(aggoal, desgoal, latent), mean, std

    def decode(self, aggoal, desgoal, z=None):
        """Decode ``z``; when ``z`` is None use one prior sample per row, clamped to [-0.5, 0.5]."""
        if z is None:
            prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = prior.sample((aggoal.shape[0],)).clamp(-0.5, 0.5).to(self.device)
        hidden = F.relu(self.d1(torch.cat([aggoal, desgoal, z], 1)))
        hidden = F.relu(self.d2(hidden))
        return self.d3(hidden)


class RIGVAE(nn.Module):
    """Unconditional VAE over ``goal``-wide vectors (16-dim latent, 128-wide MLPs).

    Args:
        env_params: mapping providing the ``'goal'`` size.
        id: CUDA device index; a negative value selects the CPU.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 16
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 128
        goal_dim = env_params['goal']

        # Encoder trunk.
        self.e1 = nn.Linear(goal_dim, dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        # Latent distribution heads.
        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        # Decoder.
        self.d1 = nn.Linear(self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, goal_dim)

    def _sample_prior(self, count):
        """Draw ``count`` standard-normal latents, clamp to [-0.5, 0.5], move to ``self.device``."""
        prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
        return prior.sample((count,)).clamp(-0.5, 0.5).to(self.device)

    def _run_decoder(self, latent):
        """Two ReLU layers followed by the linear output layer."""
        hidden = F.relu(self.d1(latent))
        hidden = F.relu(self.d2(hidden))
        return self.d3(hidden)

    def forward(self, state):
        """Encode ``state``, sample a latent by reparameterisation and decode.

        The encoder activations are kept on the instance as ``z1``, ``z2``, ``z3``.

        Returns:
            ``(u, mean, std)``: reconstruction and latent mean / std.
        """
        self.z1 = F.relu(self.e1(state))
        self.z2 = F.relu(self.e2(self.z1))
        self.z3 = F.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        std = torch.exp(self.log_std(self.z3))
        latent = mean + std * torch.randn_like(std)

        return self.decode(latent), mean, std

    def decode(self, z=None):
        """Decode ``z``; with no ``z``, decode 256 prior samples clipped to [-0.5, 0.5]."""
        if z is None:
            z = self._sample_prior(256)
        return self._run_decoder(z)

    def decode2(self, z=None):
        """Decode ``z``; with no ``z``, decode 30 prior samples clipped to [-0.5, 0.5]."""
        if z is None:
            z = self._sample_prior(30)
        return self._run_decoder(z)


class VAELEAP(nn.Module):
    """Unconditional VAE over ``goal``-wide vectors (16-dim latent, 128-wide MLPs).

    Structurally the same as a plain goal VAE; the difference visible here is
    that ``decode`` samples the prior without clipping, while ``decode2``
    clips the samples to [-0.5, 0.5].

    Args:
        env_params: mapping providing the ``'goal'`` size.
        id: CUDA device index; a negative value selects the CPU.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 16
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 128
        goal_dim = env_params['goal']

        # Encoder trunk.
        self.e1 = nn.Linear(goal_dim, dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        # Latent distribution heads.
        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        # Decoder.
        self.d1 = nn.Linear(self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, goal_dim)

    def _run_decoder(self, latent):
        """Two ReLU layers followed by the linear output layer."""
        hidden = F.relu(self.d1(latent))
        hidden = F.relu(self.d2(hidden))
        return self.d3(hidden)

    def forward(self, state):
        """Encode ``state``, sample a latent by reparameterisation and decode.

        The encoder activations are kept on the instance as ``z1``, ``z2``, ``z3``.

        Returns:
            ``(u, mean, std)``: reconstruction and latent mean / std.
        """
        self.z1 = F.relu(self.e1(state))
        self.z2 = F.relu(self.e2(self.z1))
        self.z3 = F.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        std = torch.exp(self.log_std(self.z3))
        latent = mean + std * torch.randn_like(std)

        return self.decode(latent), mean, std

    def decode(self, z=None):
        """Decode ``z``; with no ``z``, decode 256 standard-normal samples (not clipped)."""
        if z is None:
            prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = prior.sample((256,)).to(self.device)
        return self._run_decoder(z)

    def decode2(self, z=None):
        """Decode ``z``; with no ``z``, decode 30 prior samples clipped to [-0.5, 0.5]."""
        if z is None:
            prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = prior.sample((30,)).clamp(-0.5, 0.5).to(self.device)
        return self._run_decoder(z)