import torch
import torch.nn as nn
import  numpy as np
import torch.nn.functional as F

"""
the input x in both networks should be [o, g], where o is the observation and g is the goal.

"""


    
    
    
    
# Goal-conditioned deterministic actor.
class actor(nn.Module):
    """Deterministic policy over the concatenated ``[observation, goal]`` input.

    Two 128-unit ReLU layers feed a linear head. The head's output is
    tanh-squashed and multiplied by ``max_action``.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'`` (scale applied to the tanh output).
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        width = 128
        # Layer creation order is kept fixed so seeded initialisation is stable.
        self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'], width)
        self.fc2 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x):
        """Return actions for ``x`` whose last dimension is obs + goal."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        squashed = torch.tanh(self.action_out(x))
        return self.max_action * squashed


# Observation-only deterministic actor (no goal input).
class actorCounter(nn.Module):
    """Deterministic policy whose input is the observation alone.

    Two 128-unit ReLU layers feed a linear head. The head's output is
    tanh-squashed and multiplied by ``max_action``.

    Args:
        env_params: dict providing ``'obs'``, ``'action'`` (sizes) and
            ``'action_max'`` (scale applied to the tanh output).
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        width = 128
        self.fc1 = nn.Linear(env_params['obs'], width)
        self.fc2 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action'])

    def forward(self, x):
        """Return actions for an observation batch ``x`` (last dim ``obs``)."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        squashed = torch.tanh(self.action_out(x))
        return self.max_action * squashed



class RandomNetwork(nn.Module):
    """Three-layer MLP embedding of an observation.

    Maps ``env_params['obs']`` input features through two 256-unit ReLU layers
    to a linear ``output_dim``-dimensional output.

    NOTE(review): the name suggests this is the fixed random target of a
    random-network-distillation pair (with ``PredictorNetwork``). This class
    does not freeze its own parameters; confirm the caller does.
    """

    def __init__(self, env_params, output_dim=32):
        super().__init__()
        obs_dim = env_params['obs']
        width = 256
        self.fc1 = nn.Linear(obs_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_dim)

    def forward(self, x):
        """Embed ``x`` (last dimension ``env_params['obs']``)."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)


class PredictorNetwork(nn.Module):
    """Three-layer MLP embedding of an observation.

    Maps ``env_params['obs']`` input features through two 256-unit ReLU layers
    to a linear ``output_dim``-dimensional output. The architecture matches
    ``RandomNetwork``.

    NOTE(review): presumably trained to regress ``RandomNetwork``'s output
    (random network distillation); confirm against the training code.
    """

    def __init__(self, env_params, output_dim=32):
        super().__init__()
        obs_dim = env_params['obs']
        width = 256
        self.fc1 = nn.Linear(obs_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_dim)

    def forward(self, x):
        """Embed ``x`` (last dimension ``env_params['obs']``)."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)



# Discriminator over concatenated (observation, goal, action) inputs.
class Discriminator(nn.Module):
    """Score a ``[observation, goal, action]`` vector with a value in (0, 1).

    Two 256-unit tanh layers feed a single sigmoid output unit.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'``. ``'action_max'`` is stored as
            ``self.max_action`` but is not used by ``forward``; the caller
            supplies the action unscaled or pre-scaled as it sees fit.
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        width = 256
        in_features = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, 1)

    def forward(self, x):
        """Return a ``(..., 1)`` tensor of sigmoid scores for ``x``."""
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.tanh(hidden_layer(x))
        return torch.sigmoid(self.action_out(x))


class Discriminator2(nn.Module):
    """Score a goal vector with a value in (0, 1).

    Two 128-unit tanh layers feed a single sigmoid output unit.

    Args:
        env_params: dict providing ``'goal'`` (input size) and
            ``'action_max'``. ``'action_max'`` is stored as ``self.max_action``
            but is not used by ``forward``.
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        width = 128
        self.fc1 = nn.Linear(env_params['goal'], width)
        self.fc2 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, 1)

    def forward(self, x):
        """Return a ``(..., 1)`` tensor of sigmoid scores for goal batch ``x``."""
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.tanh(hidden_layer(x))
        return torch.sigmoid(self.action_out(x))

class actor2(nn.Module):
    """Goal-conditioned policy head returning unnormalized action logits.

    The input ``x`` is expected to be ``obs + goal + extra_dim`` features wide.
    Three 256-unit ReLU layers feed a linear layer with one output per action.

    NOTE(review): what the ``extra_dim`` (default 6) extra input features
    encode is not visible here; confirm against the caller that builds ``x``.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'`` and ``'action_dim'``
            (number of logits). This class reads ``'action_dim'``, not
            ``'action'`` like the continuous-action networks in this file.
        extra_dim: number of extra input features appended after observation
            and goal. Defaults to 6.
    """

    def __init__(self, env_params, *, extra_dim=6):
        super(actor2, self).__init__()
        self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'] + extra_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.action_out = nn.Linear(256, env_params['action_dim'])

    def forward(self, x):
        """Return raw logits of shape ``(..., action_dim)`` (no softmax applied)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        logits = self.action_out(x)

        return logits


class critic(nn.Module):
    """Goal-conditioned Q-network over ``[observation, goal]`` and an action.

    The action is divided by ``max_action`` before being concatenated to the
    input. Three 256-unit ReLU layers then feed a scalar Q output.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'`` (action normalisation divisor).
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        width = 256
        in_features = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, x, actions):
        """Return Q-values of shape ``(batch, 1)``.

        ``x`` is the 2-D ``[obs, goal]`` batch; ``actions`` is the 2-D action
        batch (they are concatenated along dim 1).
        """
        scaled_actions = actions / self.max_action
        h = torch.cat([x, scaled_actions], dim=1)
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(hidden_layer(h))
        return self.q_out(h)


class criticCounter(nn.Module):
    """Q-network over an observation and an action (no goal input).

    The action is divided by ``max_action`` before being concatenated to the
    observation. Three 256-unit ReLU layers then feed a scalar Q output.

    Args:
        env_params: dict providing ``'obs'``, ``'action'`` (sizes) and
            ``'action_max'`` (action normalisation divisor).
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        width = 256
        self.fc1 = nn.Linear(env_params['obs'] + env_params['action'], width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, x, actions):
        """Return Q-values of shape ``(batch, 1)`` for 2-D ``x`` and ``actions``."""
        scaled_actions = actions / self.max_action
        h = torch.cat([x, scaled_actions], dim=1)
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(hidden_layer(h))
        return self.q_out(h)

class TADDcritic(nn.Module):
    """Twin Q-network critic with two independent heads.

    Head 1 is ``fc1 -> fc2 -> fc3 -> q_out1``; head 2 is
    ``fc4 -> fc5 -> fc6 -> q_out2``. Both consume the same input: the state
    concatenated with ``actions / max_action``.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'`` (action normalisation divisor).
    """

    def __init__(self, env_params):
        super(TADDcritic, self).__init__()
        self.max_action = env_params['action_max']
        in_features = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(in_features, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.q_out1 = nn.Linear(256, 1)

        self.fc4 = nn.Linear(in_features, 256)
        self.fc5 = nn.Linear(256, 256)
        self.fc6 = nn.Linear(256, 256)
        self.q_out2 = nn.Linear(256, 1)

    def _state_action(self, state, actions):
        # Shared input of both heads: 2-D state batch joined with normalised actions.
        return torch.cat([state, actions / self.max_action], dim=1)

    def _q1(self, sa):
        # First Q head.
        h = F.relu(self.fc1(sa))
        h = F.relu(self.fc2(h))
        h = F.relu(self.fc3(h))
        return self.q_out1(h)

    def _q2(self, sa):
        # Second Q head (independent parameters).
        h = F.relu(self.fc4(sa))
        h = F.relu(self.fc5(h))
        h = F.relu(self.fc6(h))
        return self.q_out2(h)

    def forward(self, state, actions):
        """Return ``(q_value1, q_value2)``, each of shape ``(batch, 1)``."""
        # Build the concatenated input once and feed it to both heads.
        sa = self._state_action(state, actions)
        return self._q1(sa), self._q2(sa)

    def forwardQ(self, x, actions):
        """Return only the first head's Q-values, shape ``(batch, 1)``."""
        return self._q1(self._state_action(x, actions))

class criticexpert(nn.Module):
    """Goal-conditioned Q-network over ``[observation, goal]`` and an action.

    The action is divided by ``max_action`` before being concatenated to the
    input. Three 256-unit ReLU layers then feed a scalar Q output.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'`` (action normalisation divisor).
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        width = 256
        in_features = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, x, actions):
        """Return Q-values of shape ``(batch, 1)`` for 2-D ``x`` and ``actions``."""
        scaled_actions = actions / self.max_action
        h = torch.cat([x, scaled_actions], dim=1)
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            h = F.relu(hidden_layer(h))
        return self.q_out(h)

class exp_reward(nn.Module):
    """Scalar head over observation, goal and normalised action.

    ``forward`` concatenates ``[obs, goal, actions / max_action]`` and applies
    two 256-unit ReLU layers followed by a linear scalar output.
    NOTE(review): the name suggests a learned reward model; confirm with callers.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'`` (action normalisation divisor).
    """

    def __init__(self, env_params):
        super().__init__()
        self.max_action = env_params['action_max']
        width = 256
        in_features = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(in_features, width)
        self.fc2 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, obs, goal, actions):
        """Return a ``(batch, 1)`` tensor for 2-D ``obs``, ``goal`` and ``actions``."""
        joined = torch.cat([obs, goal, actions / self.max_action], dim=1)
        for hidden_layer in (self.fc1, self.fc2):
            joined = F.relu(hidden_layer(joined))
        return self.q_out(joined)

# Conditional VAE over actions: encodes (state, goal, action) into a latent and
# decodes (state, goal, latent) back into an action.
class VAE3(nn.Module):
    """Goal-conditioned variational auto-encoder over actions.

    The encoder ``e1..e3`` maps ``[state, desgoal, action]`` to a diagonal
    Gaussian (``mean``, ``exp(log_std)``) over a 64-dim latent. The decoder
    ``d1..d4`` maps ``[state, desgoal, z]`` to an action: a tanh output scaled
    by ``max_action``.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'``.
        id: CUDA device index recorded in ``self.device``; a negative value
            records the CPU. The module is not moved to that device here.
    """

    def __init__(self, env_params, id):
        super(VAE3, self).__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.max_action = env_params['action_max']
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 256
        self.e1 = nn.Linear(env_params['action'] + env_params['obs'] + env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + env_params['goal'] + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, dim)
        self.d4 = nn.Linear(dim, env_params['action'])

    def _sample_latent(self, state):
        # One standard-normal latent per row of ``state``, clipped to [-1, 1].
        # It is drawn from the default CPU generator and then moved to
        # ``state.device``, the device torch.cat below requires. ``self.device``
        # only reflects the constructor ``id`` and need not match the inputs.
        return torch.randn((state.shape[0], self.latent_dim)).to(state.device).clamp(-1, 1)

    def forward(self, state, desgoal, action):
        """Encode, sample a latent with the reparameterisation trick, and decode.

        Caches the encoder activations as ``self.z1``, ``self.z2`` and
        ``self.z3``; ``decode2`` reads them.

        Returns:
            ``(u, mean, std)``: reconstructed action, latent mean and latent std.
        """
        self.z1 = torch.relu(self.e1(torch.cat([state, desgoal, action], 1)))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        log_std = self.log_std(self.z3)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode2(state, desgoal, z)
        return u, mean, std

    def decode(self, state, desgoal, z=None):
        """Decode an action without skip connections.

        If ``z`` is None, a clipped standard-normal latent is sampled.
        NOTE(review): ``forward`` trains ``d1..d4`` through ``decode2``, which
        adds encoder activations between layers; this method runs the same
        layers without them. Confirm this mismatch is intended.
        """
        if z is None:
            z = self._sample_latent(state)
        a1 = torch.relu(self.d1(torch.cat([state, desgoal, z], 1)))
        a2 = torch.relu(self.d2(a1))
        a3 = torch.relu(self.d3(a2))
        return self.max_action * torch.tanh(self.d4(a3))

    def decode2(self, state, desgoal, z=None):
        """Decode an action, adding the encoder activations cached by ``forward``.

        Must be called after ``forward`` on a batch of the same size. Before
        any ``forward`` call, ``self.z1``..``self.z3`` do not exist; with a
        different batch size, the additions below cannot line up.
        """
        if z is None:
            z = self._sample_latent(state)
        a1 = torch.relu(self.d1(torch.cat([state, desgoal, z], 1)))
        a2 = torch.relu(self.d2(a1 + self.z3))
        a3 = torch.relu(self.d3(a2 + self.z2))
        return self.max_action * torch.tanh(self.d4(a3 + self.z1))
    
    
# Conditional VAE over actions: encodes (state, action) into a latent and
# decodes (state, latent) back into an action. No goal input.
class VAE2(nn.Module):
    """State-conditioned variational auto-encoder over actions.

    The encoder ``e1..e3`` maps ``[state, action]`` to a diagonal Gaussian
    (``mean``, ``exp(log_std)``) over a 64-dim latent. The decoder ``d1..d4``
    maps ``[state, z]`` to an action: a tanh output scaled by ``max_action``.

    Args:
        env_params: dict providing ``'obs'``, ``'action'`` (sizes) and
            ``'action_max'``.
        id: CUDA device index recorded in ``self.device``; a negative value
            records the CPU. The module is not moved to that device here.
    """

    def __init__(self, env_params, id):
        super(VAE2, self).__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.max_action = env_params['action_max']
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 256
        self.e1 = nn.Linear(env_params['action'] + env_params['obs'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, dim)
        self.d4 = nn.Linear(dim, env_params['action'])

    def _sample_latent(self, state):
        # One standard-normal latent per row of ``state``, clipped to [-1, 1].
        # It is drawn from the default CPU generator and then moved to
        # ``state.device``, the device torch.cat below requires. ``self.device``
        # only reflects the constructor ``id`` and need not match the inputs.
        return torch.randn((state.shape[0], self.latent_dim)).to(state.device).clamp(-1, 1)

    def forward(self, state, action):
        """Encode, sample a latent with the reparameterisation trick, and decode.

        Caches the encoder activations as ``self.z1``, ``self.z2`` and
        ``self.z3``; ``decode2`` reads them.

        Returns:
            ``(u, mean, std)``: reconstructed action, latent mean and latent std.
        """
        self.z1 = torch.relu(self.e1(torch.cat([state, action], 1)))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        log_std = self.log_std(self.z3)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode2(state, z)
        return u, mean, std

    def decode(self, state, z=None):
        """Decode an action without skip connections.

        If ``z`` is None, a clipped standard-normal latent is sampled.
        NOTE(review): ``forward`` trains ``d1..d4`` through ``decode2``, which
        adds encoder activations between layers; this method runs the same
        layers without them. Confirm this mismatch is intended.
        """
        if z is None:
            z = self._sample_latent(state)
        a1 = torch.relu(self.d1(torch.cat([state, z], 1)))
        a2 = torch.relu(self.d2(a1))
        a3 = torch.relu(self.d3(a2))
        return self.max_action * torch.tanh(self.d4(a3))

    def decode2(self, state, z=None):
        """Decode an action, adding the encoder activations cached by ``forward``.

        Must be called after ``forward`` on a batch of the same size. Before
        any ``forward`` call, ``self.z1``..``self.z3`` do not exist; with a
        different batch size, the additions below cannot line up.
        """
        if z is None:
            z = self._sample_latent(state)
        a1 = torch.relu(self.d1(torch.cat([state, z], 1)))
        a2 = torch.relu(self.d2(a1 + self.z3))
        a3 = torch.relu(self.d3(a2 + self.z2))
        return self.max_action * torch.tanh(self.d4(a3 + self.z1))


class ExpPolicyFULL(nn.Module):
    """Goal-conditioned MLP mapping ``[state, desgoal]`` to an action vector.

    Three 256-unit ReLU layers feed a linear output. No tanh or ``max_action``
    scaling is applied to the output. ``latent_dim``, ``max_action`` and
    ``device`` are stored but not used by ``forward``.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'``.
        id: CUDA device index recorded in ``self.device``; negative means CPU.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.max_action = env_params['action_max']
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        width = 256
        self.e1 = nn.Linear(env_params['obs'] + env_params['goal'], width)
        self.e2 = nn.Linear(width, width)
        self.e3 = nn.Linear(width, width)

        self.mean = nn.Linear(width, env_params['action'])

    def forward(self, state, desgoal):
        """Return the linear action output.

        The hidden activations are cached as ``self.z1``..``self.z3``.
        """
        joined = torch.cat([state, desgoal], 1)
        self.z1 = F.relu(self.e1(joined))
        self.z2 = F.relu(self.e2(self.z1))
        self.z3 = F.relu(self.e3(self.z2))
        return self.mean(self.z3)

class ExpPolicy(nn.Module):
    """Goal-conditioned MLP with a residual connection around the third layer.

    Maps ``[state, desgoal]`` through three 256-unit ReLU layers to a linear
    action output. The first layer's activation is added to the third layer's
    pre-activation. No tanh or ``max_action`` scaling is applied to the output.
    ``latent_dim``, ``max_action`` and ``device`` are stored but not used by
    ``forward``.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'``.
        id: CUDA device index recorded in ``self.device``; negative means CPU.
    """

    def __init__(self, env_params, id):
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.max_action = env_params['action_max']
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        width = 256
        self.e1 = nn.Linear(env_params['obs'] + env_params['goal'], width)
        self.e2 = nn.Linear(width, width)
        self.e3 = nn.Linear(width, width)

        self.mean = nn.Linear(width, env_params['action'])

    def forward(self, state, desgoal):
        """Return the linear action output.

        The hidden activations are cached as ``self.z1``..``self.z3``.
        """
        joined = torch.cat([state, desgoal], 1)
        self.z1 = F.relu(self.e1(joined))
        self.z2 = F.relu(self.e2(self.z1))
        # Residual: the first hidden activation skips over e2/e3.
        self.z3 = F.relu(self.e3(self.z2) + self.z1)
        return self.mean(self.z3)

class VAE3GCSL(nn.Module):
    """Small goal-conditioned variational auto-encoder over actions.

    The encoder ``e1..e3`` (64 units) maps ``[state, desgoal, action]`` to a
    diagonal Gaussian over an 8-dim latent; ``log_std`` is clamped to
    ``[-4, 15]``. The decoder ``d1..d4`` maps ``[state, desgoal, z]`` to an
    action: a tanh output scaled by ``max_action``.

    Args:
        env_params: dict providing ``'obs'``, ``'goal'``, ``'action'`` (sizes)
            and ``'action_max'``.
        id: CUDA device index recorded in ``self.device``; a negative value
            records the CPU. The module is not moved to that device here.
    """

    def __init__(self, env_params, id):
        super(VAE3GCSL, self).__init__()
        self.env_params = env_params
        self.latent_dim = 8
        self.max_action = env_params['action_max']
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 64
        self.e1 = nn.Linear(env_params['action'] + env_params['obs'] + env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + env_params['goal'] + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, dim)
        self.d4 = nn.Linear(dim, env_params['action'])

    def forward(self, state, desgoal, action):
        """Encode, sample a latent with the reparameterisation trick, and decode.

        Caches the encoder activations as ``self.z1``, ``self.z2`` and
        ``self.z3``.

        Returns:
            ``(u, mean, std)``: reconstructed action, latent mean and latent std.
        """
        self.z1 = torch.relu(self.e1(torch.cat([state, desgoal, action], 1)))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        # Bounding log_std keeps exp() from under/overflowing.
        log_std = self.log_std(self.z3).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)
        u = self.decode(state, desgoal, z)
        return u, mean, std

    def decode(self, state, desgoal, z=None):
        """Decode an action from ``[state, desgoal, z]``.

        If ``z`` is None, a standard-normal latent is sampled and clipped to
        ``[-1, 1]``. It is drawn from the default CPU generator and moved to
        ``state.device``, the device torch.cat requires. ``self.device`` only
        reflects the constructor ``id`` and need not match the inputs.
        """
        if z is None:
            z = torch.randn((state.shape[0], self.latent_dim)).to(state.device).clamp(-1, 1)
        a1 = torch.relu(self.d1(torch.cat([state, desgoal, z], 1)))
        a2 = torch.relu(self.d2(a1))
        a3 = torch.relu(self.d3(a2))
        return self.max_action * torch.tanh(self.d4(a3))