import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class RFQIActor(nn.Module):
    """
    Perturbation actor for BCQ-style action selection.

    Takes a state and a candidate action (e.g. one decoded by the VAE) and returns
    the candidate shifted by a learned residual. The residual is bounded to
    ``phi * max_action`` in magnitude per dimension via ``tanh``. The final action
    is clamped to ``[-max_action, max_action]``.

    Args:
        state_dim: Size of the state feature vector.
        action_dim: Size of the action vector.
        max_action: Action bound used both for scaling the residual and clamping.
        phi: Fraction of ``max_action`` the residual may move the action.
    """

    def __init__(self, state_dim, action_dim, max_action, phi=0.05):
        super().__init__()
        self.l1 = nn.Linear(state_dim + action_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        self.max_action = max_action
        self.phi = phi

    def forward(self, state, action):
        """Return ``action`` plus a bounded learned perturbation, clamped to the action range."""
        hidden = torch.cat([state, action], 1)
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))

        # tanh keeps the residual within +/- phi * max_action.
        scale = self.phi * self.max_action
        perturbation = scale * torch.tanh(self.l3(hidden))
        return torch.clamp(action + perturbation, -self.max_action, self.max_action)

class RFQICritic(nn.Module):
    """
    Twin critic: two independent Q-networks over the concatenated (state, action).

    ``forward`` returns both Q estimates ``(q1, q2)``, each of shape
    ``(batch, 1)``. ``q1`` evaluates only the first head.

    Args:
        state_dim: Size of the state feature vector.
        action_dim: Size of the action vector.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        in_dim = state_dim + action_dim

        # First Q head
        self.l1 = nn.Linear(in_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)

        # Second Q head
        self.l4 = nn.Linear(in_dim, 400)
        self.l5 = nn.Linear(400, 300)
        self.l6 = nn.Linear(300, 1)

    @staticmethod
    def _mlp(x, first, second, out):
        """Apply the shared 2-hidden-layer ReLU MLP shape used by both heads."""
        return out(F.relu(second(F.relu(first(x)))))

    def forward(self, state, action):
        """Return ``(q1, q2)`` for the given state-action batch."""
        sa = torch.cat([state, action], 1)
        return (
            self._mlp(sa, self.l1, self.l2, self.l3),
            self._mlp(sa, self.l4, self.l5, self.l6),
        )

    def q1(self, state, action):
        """Return only the first head's Q estimate."""
        sa = torch.cat([state, action], 1)
        return self._mlp(sa, self.l1, self.l2, self.l3)

class RFQIVAE(nn.Module):
    """
    Conditional VAE: Encodes (state, action) -> latent, Decodes (state, latent) -> action

    Args:
        state_dim: Size of the state feature vector.
        action_dim: Size of the action vector.
        latent_dim: Size of the latent code ``z``.
        max_action: Decoded actions are ``max_action * tanh(...)``.
        device: Stored as ``self.device`` for backward compatibility. ``decode``
            now samples on the input's own device rather than this value.
    """
    def __init__(self, state_dim, action_dim, latent_dim, max_action, device):
        super(RFQIVAE, self).__init__()
        self.e1 = nn.Linear(state_dim + action_dim, 750)
        self.e2 = nn.Linear(750, 750)

        self.mean = nn.Linear(750, latent_dim)
        self.log_std = nn.Linear(750, latent_dim)

        self.d1 = nn.Linear(state_dim + latent_dim, 750)
        self.d2 = nn.Linear(750, 750)
        self.d3 = nn.Linear(750, action_dim)

        self.max_action = max_action
        self.latent_dim = latent_dim
        self.device = device

    def forward(self, state, action):
        """
        Encode ``(state, action)``, sample ``z`` with the reparameterisation trick,
        and decode it.

        Returns:
            ``(u, mean, std)``: the reconstructed action and the latent Gaussian's
            mean and standard deviation.
        """
        z = F.relu(self.e1(torch.cat([state, action], 1)))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # Clamped for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)

        return u, mean, std

    def decode(self, state, z=None):
        """
        Decode an action for each row of ``state``.

        If ``z`` is None, a latent is drawn from a standard normal and clipped to
        [-0.5, 0.5]. The result is scaled to ``[-max_action, max_action]`` via tanh.
        """
        if z is None:
            # Sample directly on the input's device/dtype. This avoids a host
            # allocation plus copy, and stays correct if the module was moved with
            # .to() after construction (self.device would then be stale).
            z = torch.randn(
                (state.shape[0], self.latent_dim),
                device=state.device,
                dtype=state.dtype,
            ).clamp(-0.5, 0.5)

        a = F.relu(self.d1(torch.cat([state, z], 1)))
        a = F.relu(self.d2(a))
        return self.max_action * torch.tanh(self.d3(a))

class RFQIETA(nn.Module):
    """
    Auxiliary network to estimate the dual variable eta for the robust update.

    Applies a small 2-hidden-layer ReLU MLP to the concatenated ``(s, a)``. It
    returns one scalar per row, i.e. a tensor of shape ``(batch,)``.

    Args:
        state_dim: Size of the state feature vector.
        action_dim: Size of the action vector.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 64
        self.fc_1 = nn.Linear(state_dim + action_dim, width)
        self.fc_2 = nn.Linear(width, width)
        self.fc_out = nn.Linear(width, 1)

    def forward(self, s, a):
        """Return eta estimates for each (s, a) row, with the trailing size-1 dim removed."""
        features = torch.cat([s, a], 1)
        features = F.relu(self.fc_2(F.relu(self.fc_1(features))))
        return self.fc_out(features).squeeze(dim=1)