import torch
import torch.nn as nn
import torch.nn.functional as F

"""
the input x in both networks should be [o, g], where o is the observation and g is the goal.

"""

# define the expert actor network (observation and goal are passed as separate inputs)
class expert(nn.Module):
    """Deterministic actor with two 512-unit ReLU hidden layers.

    The observation and goal are passed as separate tensors and concatenated
    along dimension 1 inside ``forward``. The output layer is squashed with
    ``tanh`` and multiplied by ``env_params['action_max']``.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 512
        in_features = env_params['obs'] + env_params['goal']
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x, y):
        """Return actions for observations ``x`` and goals ``y``."""
        hidden = torch.cat((x, y), 1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.max_action * torch.tanh(self.action_out(hidden))

class Policyactor(nn.Module):
    """Observation-only deterministic policy.

    Three 256-unit ReLU hidden layers feed a ``tanh`` output head whose result
    is multiplied by ``env_params['action_max']``.
    """

    def __init__(self, env_params):
        super().__init__()
        width = 256
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(env_params['obs'], width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x):
        """Return actions for a batch of observations ``x``."""
        features = x
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            features = F.relu(hidden_layer(features))
        return self.max_action * torch.tanh(self.action_out(features))

# define the actor network
class actor(nn.Module):
    """Goal-conditioned deterministic actor.

    ``forward`` takes a single tensor whose last dimension holds
    ``obs + goal`` features (the input size of ``fc1``). Three 256-unit ReLU
    layers feed a ``tanh`` head scaled by ``env_params['action_max']``.
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        obs_goal_dim = env_params['obs'] + env_params['goal']
        self.fc1 = nn.Linear(obs_goal_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.action_out = nn.Linear(256, env_params['action'])

    def forward(self, x):
        """Return actions for the concatenated observation/goal input ``x``."""
        h = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(h))
        h = torch.relu(self.fc3(h))
        squashed = torch.tanh(self.action_out(h))
        return self.max_action * squashed
    
    
class actor_nogoal(nn.Module):
    """Deterministic actor that sees only the observation (no goal input).

    Architecture: ``obs -> 256 -> 256 -> 256 -> action`` with ReLU hidden
    activations and a ``tanh`` output multiplied by ``env_params['action_max']``.
    """

    def __init__(self, env_params):
        super().__init__()
        hidden = 256
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(env_params['obs'], hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.action_out = nn.Linear(hidden, env_params['action'])

    def forward(self, x):
        """Return actions for a batch of observations ``x``."""
        out = x
        for fc in (self.fc1, self.fc2, self.fc3):
            out = F.relu(fc(out))
        return self.max_action * torch.tanh(self.action_out(out))

class expactor(nn.Module):
    """Actor-shaped MLP over the concatenated (observation, goal) pair.

    Unlike the other actors in this file, ``forward`` returns the raw output of
    ``action_out``: no ``tanh`` squashing and no ``max_action`` scaling.
    ``self.max_action`` is stored but not used by this class.
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        in_features = env_params['obs'] + env_params['goal']
        self.fc1 = nn.Linear(in_features, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.action_out = nn.Linear(256, env_params['action'])

    def forward(self, x, y):
        """Return unbounded outputs for observations ``x`` and goals ``y``."""
        joined = torch.cat((x, y), dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            joined = F.relu(layer(joined))
        return self.action_out(joined)

class critic(nn.Module):
    """Goal-conditioned Q-network over ``[x, actions / max_action]``.

    ``x`` should carry ``obs + goal`` features (what ``fc1`` is sized for).
    Actions are rescaled by ``1 / action_max`` and concatenated on the last
    dimension before three 256-unit ReLU layers and a scalar output.
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        n_inputs = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(n_inputs, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.q_out = nn.Linear(256, 1)

    def forward(self, x, actions):
        """Return Q-values with a trailing dimension of size 1."""
        scaled_actions = actions / self.max_action
        h = torch.cat([x, scaled_actions], dim=-1)
        for layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(layer(h))
        return self.q_out(h)


class critic_nogoal(nn.Module):
    """Q-network over ``[x, actions / max_action]`` without a goal input.

    ``x`` should carry ``obs`` features; actions are rescaled by
    ``1 / action_max`` and concatenated on dimension 1.
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        n_inputs = env_params['obs'] + env_params['action']
        self.fc1 = nn.Linear(n_inputs, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.q_out = nn.Linear(256, 1)

    def forward(self, x, actions):
        """Return Q-values of shape ``(batch, 1)``."""
        features = torch.cat([x, actions / self.max_action], dim=1)
        features = F.relu(self.fc1(features))
        features = F.relu(self.fc2(features))
        features = F.relu(self.fc3(features))
        return self.q_out(features)


class critic2(nn.Module):
    """Twin Q-network over ``[x, actions / max_action]``.

    Two heads with separate weights read the same input:

    * head 1: ``fc1 -> fc2 -> fc3 -> q_out``
    * head 2: ``fc11 -> fc21 -> fc31 -> q_out1``

    ``forward`` returns both Q-values; ``Q1`` returns only head 1's value.
    ``x`` should carry ``obs + goal`` features (what ``fc1``/``fc11`` are sized
    for). Actions are rescaled by ``1 / action_max`` before concatenation.
    """

    def __init__(self, env_params):
        super(critic2, self).__init__()
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['action'], 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.q_out = nn.Linear(256, 1)

        self.fc11 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['action'], 256)
        self.fc21 = nn.Linear(256, 256)
        self.fc31 = nn.Linear(256, 256)
        self.q_out1 = nn.Linear(256, 1)

    def _input(self, x, actions):
        """Concatenate ``x`` with actions rescaled by ``1 / max_action`` on dim 1."""
        return torch.cat([x, actions / self.max_action], dim=1)

    @staticmethod
    def _head(xin, fc_a, fc_b, fc_c, out):
        """Run one three-layer ReLU MLP head followed by its linear output layer."""
        h = F.relu(fc_a(xin))
        h = F.relu(fc_b(h))
        h = F.relu(fc_c(h))
        return out(h)

    def forward(self, x, actions):
        """Return ``(q_value1, q_value2)``, each of shape ``(batch, 1)``."""
        xin = self._input(x, actions)
        q_value1 = self._head(xin, self.fc1, self.fc2, self.fc3, self.q_out)
        # Head 2 ends in its own output layer (q_out1), so the two Q estimates
        # share no weights; reusing q_out here would leave q_out1 untrained.
        q_value2 = self._head(xin, self.fc11, self.fc21, self.fc31, self.q_out1)
        return q_value1, q_value2

    def Q1(self, x, actions):
        """Return only head 1's Q-value, identical to ``forward(...)[0]``."""
        xin = self._input(x, actions)
        return self._head(xin, self.fc1, self.fc2, self.fc3, self.q_out)



class Subactor(nn.Module):
    """Deterministic actor over an input with ``obs + 2 * goal`` features.

    Three 256-unit ReLU layers feed a ``tanh`` head scaled by
    ``env_params['action_max']``.
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        in_features = env_params['obs'] + 2 * env_params['goal']
        self.fc1 = nn.Linear(in_features, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.action_out = nn.Linear(256, env_params['action'])

    def forward(self, x):
        """Return actions for the concatenated input ``x``."""
        activations = x
        for fc in (self.fc1, self.fc2, self.fc3):
            activations = F.relu(fc(activations))
        raw = self.action_out(activations)
        return self.max_action * torch.tanh(raw)


# LSTM-based (recurrent) variational auto-encoder
class VAERNN(nn.Module):
    """LSTM-based conditional VAE.

    The encoder LSTM reads ``[state, desgoal, hergoal]`` sequences and yields a
    per-step Gaussian latent (``latent_dim`` = 32). The decoder LSTM reads
    ``[state, desgoal, z]`` and outputs a ``goal``-sized vector per step.

    All sequence arguments are batch-major, ``(batch, seq_len, features)``;
    ``decode`` builds its default ``z`` in that layout. ``nn.LSTM`` is used with
    its default time-major layout, so inputs are *transposed* to
    ``(seq_len, batch, features)`` and outputs transposed back. A reshape must
    not be used here, because it would mix data across batch elements and
    time steps.

    NOTE: ``h0``/``c0`` (encoder) and ``h1``/``c1`` (decoder) are random
    initial states drawn once in ``__init__`` with a batch dimension of
    ``self.batchsize`` (256), so every call must use that batch size. They are
    plain attributes, not registered buffers, so they are not saved in
    ``state_dict`` and are not moved by ``module.to(...)``.
    """

    def __init__(self, env_params, id):
        super(VAERNN, self).__init__()
        self.env_params = env_params
        self.latent_dim = 32
        self.batchsize = 256
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 256

        # Encoder: [state, desgoal, hergoal] -> hidden -> (mean, log_std).
        self.lstm = torch.nn.LSTM(input_size=env_params['obs'] + 2 * env_params['goal'],
                                  hidden_size=dim, num_layers=2)
        self.lstm.to(self.device)
        # Initial states: (num_layers * num_directions, batch, hidden_size).
        self.h0 = torch.randn(2, self.batchsize, dim).to(self.device)
        self.c0 = torch.randn(2, self.batchsize, dim).to(self.device)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        # Decoder: [state, desgoal, z] -> hidden -> goal-sized output.
        self.lstm2 = torch.nn.LSTM(
            input_size=env_params['goal'] + env_params['obs'] + self.latent_dim,
            hidden_size=dim, num_layers=2)
        self.lstm2.to(self.device)
        self.h1 = torch.randn(2, self.batchsize, dim).to(self.device)
        self.c1 = torch.randn(2, self.batchsize, dim).to(self.device)

        self.d = nn.Linear(dim, env_params['goal'])

    def forward(self, state, desgoal, hergoal):
        """Encode, sample a latent with the reparameterization trick, and decode.

        Args:
            state, desgoal, hergoal: batch-major tensors
                ``(batch, seq_len, *)``, concatenated on the feature axis.

        Returns:
            ``(u, mean, std)``: ``u`` is the decoder output
            ``(batch, seq_len, goal)``; ``mean``/``std`` are
            ``(batch, seq_len, latent_dim)``.

        Side effect: stores the encoder's final ``(h_n, c_n)`` in
        ``self.hn``/``self.cn``.
        """
        context = torch.cat((state, desgoal, hergoal), 2)
        # (batch, seq_len, F) -> (seq_len, batch, F) for the time-major LSTM.
        output, (self.hn, self.cn) = self.lstm(context.transpose(0, 1), (self.h0, self.c0))
        # (seq_len, batch, hidden) -> (batch, seq_len, hidden)
        output = output.transpose(0, 1)

        mean = self.mean(output)
        log_std = self.log_std(output)
        std = torch.exp(log_std)

        z = mean + std * torch.randn_like(std).to(self.device)

        u = self.decode(state, desgoal, z)
        return u, mean, std

    def decode(self, state, desgoal, z=None):
        """Decode a ``goal``-sized vector per time step.

        Args:
            state, desgoal: batch-major tensors ``(batch, seq_len, *)``.
            z: latent ``(batch, seq_len, latent_dim)``. If None, one latent per
                batch element is drawn from N(0, I), clamped to [-0.5, 0.5],
                and repeated across all time steps.

        Returns:
            Tensor of shape ``(batch, seq_len, goal)``.
        """
        seq_len = state.shape[1]
        if z is None:
            dis = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = dis.sample((state.shape[0],)).clamp(-0.5, 0.5).to(self.device)
            z = torch.tile(torch.unsqueeze(z, 1), (1, seq_len, 1))
        context = torch.cat((state, desgoal, z), -1)
        # (batch, seq_len, F) -> (seq_len, batch, F) for the time-major LSTM.
        output, _ = self.lstm2(context.transpose(0, 1), (self.h1, self.c1))
        # (seq_len, batch, hidden) -> (batch, seq_len, hidden)
        return self.d(output.transpose(0, 1))


class VAE(nn.Module):
    """Conditional VAE whose decoder outputs a ``goal``-sized vector.

    Encoder: three 512-unit ReLU layers over ``[aggoal, desgoal, sampledgoal]``
    (``obs + 2 * goal`` features), then a Gaussian latent of size 64.
    Decoder: MLP over ``[aggoal, desgoal, z]`` (``obs + goal + 64`` features).

    ``forward`` caches the encoder activations in ``self.z1``/``self.z2``/
    ``self.z3``. ``decode2`` adds them as skip connections, so it must follow a
    ``forward`` call with a compatible batch size.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 512
        self.e1 = nn.Linear(env_params['obs'] + 2 * env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + env_params['goal'] + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, env_params['goal'])

    def _sample_prior(self, batch_size):
        """Draw ``batch_size`` latents from N(0, I), clamped to [-0.5, 0.5]."""
        prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
        return prior.sample((batch_size,)).clamp(-0.5, 0.5).to(self.device)

    def forward(self, aggoal, desgoal, sampledgoal):
        """Encode, reparameterize, decode; return ``(u, mean, std)``."""
        h = torch.cat([aggoal, desgoal, sampledgoal], 1)
        cached = []
        for layer in (self.e1, self.e2, self.e3):
            h = torch.relu(layer(h))
            cached.append(h)
        self.z1, self.z2, self.z3 = cached

        mean = self.mean(h)
        std = torch.exp(self.log_std(h))
        noise = torch.randn_like(std).to(self.device)
        z = mean + std * noise
        return self.decode(aggoal, desgoal, z), mean, std

    def decode(self, aggoal, desgoal, z=None):
        """Plain decoder; samples a clamped prior latent when ``z`` is None."""
        if z is None:
            z = self._sample_prior(aggoal.shape[0])
        hidden = torch.relu(self.d1(torch.cat([aggoal, desgoal, z], 1)))
        hidden = torch.relu(self.d2(hidden))
        return self.d3(hidden)

    def decode2(self, aggoal, desgoal, z=None):
        """Decoder with skip connections from the last ``forward``'s activations."""
        if z is None:
            z = self._sample_prior(aggoal.shape[0])
        hidden = torch.relu(self.d1(torch.cat([aggoal, desgoal, z], 1)) + self.z3)
        hidden = torch.relu(self.d2(hidden) + self.z2)
        return self.d3(hidden + self.z1)

class VAE_NoGoal(nn.Module):
    """VAE variant without a desired-goal input, with decoder skip connections.

    Encoder: three 256-unit ReLU layers over ``[aggoal, sampledgoal]``
    (``obs + goal`` features), then a Gaussian latent of size 64.
    Decoder: MLP over ``[aggoal, z]`` (``obs + 64`` features) that adds the
    encoder activations cached by the last ``forward`` (``self.z1``/``self.z2``/
    ``self.z3``) as skip connections.

    NOTE(review): ``decode`` always reads the cached activations, so it raises
    ``AttributeError`` if called before any ``forward``. Its batch size must also
    broadcast with that of the last ``forward`` call.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 256
        self.e1 = nn.Linear(env_params['obs'] + env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, env_params['goal'])

    def forward(self, aggoal, sampledgoal):
        """Encode, reparameterize, decode; return ``(u, mean, std)``."""
        h = torch.cat([aggoal, sampledgoal], 1)
        cached = []
        for layer in (self.e1, self.e2, self.e3):
            h = torch.relu(layer(h))
            cached.append(h)
        self.z1, self.z2, self.z3 = cached

        mean = self.mean(h)
        std = torch.exp(self.log_std(h))
        noise = torch.randn_like(std).to(self.device)
        return self.decode(aggoal, mean + std * noise), mean, std

    def decode(self, aggoal, z=None):
        """Skip-connected decoder; samples a clamped prior latent when ``z`` is None."""
        if z is None:
            prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = prior.sample((aggoal.shape[0],)).clamp(-0.5, 0.5).to(self.device)
        hidden = torch.relu(self.d1(torch.cat([aggoal, z], 1)) + self.z3)
        hidden = torch.relu(self.d2(hidden) + self.z2)
        return self.d3(hidden + self.z1)


class GeneratorFull(nn.Module):
    """Deterministic generator: MLP over ``[aggoal, desgoal]`` -> goal-sized vector.

    Three 128-unit ReLU layers feed a linear output. The hidden activations of
    the last ``forward`` are kept in ``self.z1``/``self.z2``/``self.z3``.
    ``latent_dim`` and ``device`` are stored but not used by this class's own
    methods.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 8
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        dim = 128
        self.e1 = nn.Linear(env_params['obs'] + env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)
        self.e4 = nn.Linear(dim, env_params['goal'])

    def forward(self, aggoal, desgoal):
        """Return the generated goal for each row of ``[aggoal, desgoal]``."""
        h = torch.cat([aggoal, desgoal], 1)
        activations = []
        for layer in (self.e1, self.e2, self.e3):
            h = torch.relu(layer(h))
            activations.append(h)
        self.z1, self.z2, self.z3 = activations
        return self.e4(h)


class VAERIS(nn.Module):
    """Small conditional VAE (hidden 128, latent 16) producing goal-sized outputs.

    Encoder input: ``[aggoal, desgoal, sampledgoal]`` (``obs + 2 * goal``
    features). Decoder input: ``[aggoal, desgoal, z]`` (``obs + goal + 16``
    features). ``forward`` stores its encoder activations in
    ``self.z1``/``self.z2``/``self.z3``.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 16
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        hidden = 128
        self.e1 = nn.Linear(env_params['obs'] + 2 * env_params['goal'], hidden)
        self.e2 = nn.Linear(hidden, hidden)
        self.e3 = nn.Linear(hidden, hidden)

        self.mean = nn.Linear(hidden, self.latent_dim)
        self.log_std = nn.Linear(hidden, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + env_params['goal'] + self.latent_dim, hidden)
        self.d2 = nn.Linear(hidden, hidden)
        self.d3 = nn.Linear(hidden, env_params['goal'])

    def forward(self, aggoal, desgoal, sampledgoal):
        """Encode, reparameterize, decode; return ``(u, mean, std)``."""
        encoder_in = torch.cat([aggoal, desgoal, sampledgoal], 1)
        self.z1 = torch.relu(self.e1(encoder_in))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean, log_std = self.mean(self.z3), self.log_std(self.z3)
        std = log_std.exp()
        eps = torch.randn_like(std).to(self.device)
        latent = mean + std * eps
        return self.decode(aggoal, desgoal, latent), mean, std

    def decode(self, aggoal, desgoal, z=None):
        """Decode; samples N(0, I) latents clamped to [-0.5, 0.5] when ``z`` is None."""
        if z is None:
            prior = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = prior.sample((aggoal.shape[0],)).clamp(-0.5, 0.5).to(self.device)
        out = torch.cat([aggoal, desgoal, z], 1)
        for layer in (self.d1, self.d2):
            out = torch.relu(layer(out))
        return self.d3(out)


class RIGVAE(nn.Module):
    """Unconditional VAE over goal-sized vectors (hidden 128, latent 16).

    ``decode`` and ``decode2`` share the same decoder MLP. When no ``z`` is
    given, they draw latents from N(0, I) clamped to [-0.5, 0.5]. They differ
    only in the default number of samples (256 vs. 30); either default can be
    overridden with ``num_samples``.
    """

    def __init__(self, env_params, id):
        super(RIGVAE, self).__init__()
        self.env_params = env_params
        self.latent_dim = 16
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 128
        self.e1 = nn.Linear(env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, env_params['goal'])

    def forward(self, state):
        """Encode ``state``, reparameterize, decode; return ``(u, mean, std)``.

        Encoder activations are stored in ``self.z1``/``self.z2``/``self.z3``.
        """
        self.z1 = torch.relu(self.e1(state))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        log_std = self.log_std(self.z3)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode(z)

        return u, mean, std

    def _sample_prior(self, num_samples):
        """Draw ``num_samples`` latents from N(0, I), clamped to [-0.5, 0.5]."""
        dis = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
        return dis.sample((num_samples,)).clamp(-0.5, 0.5).to(self.device)

    def _decode_latent(self, z):
        """Run the decoder MLP on latents ``z``."""
        a1 = torch.relu(self.d1(z))
        a2 = torch.relu(self.d2(a1))
        return self.d3(a2)

    def decode(self, z=None, num_samples=256):
        """Decode ``z``; if None, decode ``num_samples`` clamped prior samples."""
        if z is None:
            z = self._sample_prior(num_samples)
        return self._decode_latent(z)

    def decode2(self, z=None, num_samples=30):
        """Same as ``decode`` but with a default of 30 prior samples."""
        if z is None:
            z = self._sample_prior(num_samples)
        return self._decode_latent(z)


class VAELEAP(nn.Module):
    """Unconditional VAE over goal-sized vectors (hidden 128, latent 16).

    ``decode`` and ``decode2`` share the same decoder MLP but sample the prior
    differently when no ``z`` is given:

    * ``decode`` draws ``num_samples`` (default 256) latents from N(0, I)
      without clamping;
    * ``decode2`` draws ``num_samples`` (default 30) latents from N(0, I)
      clamped to [-0.5, 0.5].
    """

    def __init__(self, env_params, id):
        super(VAELEAP, self).__init__()
        self.env_params = env_params
        self.latent_dim = 16
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 128
        self.e1 = nn.Linear(env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, env_params['goal'])

    def forward(self, state):
        """Encode ``state``, reparameterize, decode; return ``(u, mean, std)``.

        Encoder activations are stored in ``self.z1``/``self.z2``/``self.z3``.
        """
        self.z1 = torch.relu(self.e1(state))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        log_std = self.log_std(self.z3)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode(z)

        return u, mean, std

    def _decode_latent(self, z):
        """Run the decoder MLP on latents ``z``."""
        a1 = torch.relu(self.d1(z))
        a2 = torch.relu(self.d2(a1))
        return self.d3(a2)

    def decode(self, z=None, num_samples=256):
        """Decode ``z``; if None, decode ``num_samples`` *unclamped* N(0, I) samples."""
        if z is None:
            dis = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            # Unlike decode2, the prior sample is deliberately not clamped here.
            z = dis.sample((num_samples,)).to(self.device)
        return self._decode_latent(z)

    def decode2(self, z=None, num_samples=30):
        """Decode ``z``; if None, decode ``num_samples`` samples clamped to [-0.5, 0.5]."""
        if z is None:
            dis = torch.distributions.Normal(torch.zeros(self.latent_dim), torch.ones(self.latent_dim))
            z = dis.sample((num_samples,)).clamp(-0.5, 0.5).to(self.device)
        return self._decode_latent(z)