import torch
import torch.nn as nn
import  numpy as np
import torch.nn.functional as F

"""
The input x to the goal-conditioned actor and critic networks should be the concatenation [o, g], where o is the observation and g is the goal.

"""

# define the actor network
class actor(nn.Module):
    """Deterministic policy MLP over the concatenated [observation, goal] input.

    Two 128-unit ReLU hidden layers feed a tanh output head that is scaled by
    ``env_params['action_max']``, so each output component is bounded in
    magnitude by that value.
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'``, ``'goal'`` and ``'action'``.
        """
        super().__init__()
        self.max_action = env_params['action_max']
        input_size = env_params['obs'] + env_params['goal']
        hidden_size = 128
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.action_out = nn.Linear(hidden_size, env_params['action'])

    def forward(self, x):
        """Map ``x`` (last dimension ``obs + goal``) to scaled actions."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        # tanh keeps the raw output in [-1, 1] before scaling.
        return self.max_action * torch.tanh(self.action_out(hidden))


# define the goal-free actor network (its input is the observation only)
class actorCounter(nn.Module):
    """Deterministic policy MLP that takes only the observation as input.

    Same layout as the ``actor`` class in this file (two 128-unit ReLU layers
    and a tanh head scaled by ``env_params['action_max']``), but the first
    layer is sized by ``env_params['obs']`` alone.
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'`` and ``'action'``.
        """
        super().__init__()
        self.max_action = env_params['action_max']
        hidden_size = 128
        self.fc1 = nn.Linear(env_params['obs'], hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.action_out = nn.Linear(hidden_size, env_params['action'])

    def forward(self, x):
        """Map ``x`` (last dimension ``obs``) to scaled actions."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        squashed = torch.tanh(self.action_out(hidden))
        return self.max_action * squashed



class RandomNetwork(nn.Module):
    """Three-layer MLP mapping an observation to an ``output_dim`` embedding.

    NOTE(review): presumably the fixed target of a random-network-distillation
    pair; nothing in this class freezes its weights, so the caller must do so
    if that is the intent.
    """

    def __init__(self, env_params, output_dim=32):
        """Create the layers.

        Args:
            env_params: mapping that provides the key ``'obs'``.
            output_dim: size of the produced embedding.
        """
        super().__init__()
        width = 256
        self.fc1 = nn.Linear(env_params['obs'], width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_dim)

    def forward(self, x):
        """Return the embedding of ``x``; the last layer has no activation."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


class PredictorNetwork(nn.Module):
    """Three-layer MLP mapping an observation to an ``output_dim`` embedding.

    Architecturally identical to ``RandomNetwork`` in this file.
    NOTE(review): presumably the trainable predictor of a
    random-network-distillation pair; confirm against the training code.
    """

    def __init__(self, env_params, output_dim=32):
        """Create the layers.

        Args:
            env_params: mapping that provides the key ``'obs'``.
            output_dim: size of the produced embedding.
        """
        super().__init__()
        width = 256
        self.fc1 = nn.Linear(env_params['obs'], width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_dim)

    def forward(self, x):
        """Return the embedding of ``x``; the last layer has no activation."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)



# define the discriminator network
class Discriminator(nn.Module):
    """Binary discriminator over concatenated (observation, goal, action) input.

    Two 256-unit tanh hidden layers followed by a sigmoid head, so the output
    is a single value in (0, 1) per input row.
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'``, ``'goal'`` and ``'action'``.
        """
        super(Discriminator, self).__init__()
        # Stored but not used by forward(); kept for callers that read it.
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(env_params['obs'] + env_params['goal'] + env_params['action'], 256)
        self.fc2 = nn.Linear(256, 256)
        # The head keeps the name ``action_out`` so existing checkpoints load.
        self.action_out = nn.Linear(256, 1)

    def forward(self, x):
        """Return the sigmoid score for ``x`` (last dim ``obs + goal + action``)."""
        # torch.tanh replaces the deprecated F.tanh and matches the rest of the file.
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = torch.sigmoid(self.action_out(x))
        return x


class Discriminator2(nn.Module):
    """Binary discriminator over a goal-sized input.

    Two 128-unit tanh hidden layers followed by a sigmoid head, so the output
    is a single value in (0, 1) per input row.
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'`` and
                ``'goal'``.
        """
        super(Discriminator2, self).__init__()
        # Stored but not used by forward(); kept for callers that read it.
        self.max_action = env_params['action_max']
        self.fc1 = nn.Linear(env_params['goal'], 128)
        self.fc2 = nn.Linear(128, 128)
        # The head keeps the name ``action_out`` so existing checkpoints load.
        self.action_out = nn.Linear(128, 1)

    def forward(self, x):
        """Return the sigmoid score for ``x`` (last dimension ``goal``)."""
        # torch.tanh replaces the deprecated F.tanh and matches the rest of the file.
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = torch.sigmoid(self.action_out(x))
        return x

class actor2(nn.Module):
    """MLP that returns unnormalised logits over ``env_params['action_dim']`` outputs.

    Three 256-unit ReLU hidden layers; the output layer has no activation.
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'obs'``, ``'goal'``
                and ``'action_dim'`` (note: not ``'action'`` as in the other
                networks of this file).
        """
        super().__init__()
        width = 256
        # NOTE(review): the 6 extra input features are not explained in this
        # file; confirm with the caller what is appended to [obs, goal].
        input_size = env_params['obs'] + env_params['goal'] + 6
        self.fc1 = nn.Linear(input_size, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.action_out = nn.Linear(width, env_params['action_dim'])

    def forward(self, x):
        """Return logits for ``x`` (last dimension ``obs + goal + 6``)."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.action_out(hidden)


class critic(nn.Module):
    """State-action value MLP over [observation, goal] plus a normalised action.

    Three 256-unit ReLU hidden layers and a linear scalar head.
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'``, ``'goal'`` and ``'action'``.
        """
        super().__init__()
        self.max_action = env_params['action_max']
        input_size = env_params['obs'] + env_params['goal'] + env_params['action']
        width = 256
        self.fc1 = nn.Linear(input_size, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, x, actions):
        """Return the value estimate for each (``x``, ``actions``) row.

        ``actions`` is divided by ``max_action`` and concatenated to ``x``
        along dim 1, so both inputs must be 2-D with matching first dimension.
        """
        hidden = torch.cat([x, actions / self.max_action], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.q_out(hidden)


class criticCounter(nn.Module):
    """State-action value MLP over the observation plus a normalised action.

    Same layout as ``critic`` in this file, but the first layer is sized by
    ``obs + action`` (no goal component).
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'`` and ``'action'``.
        """
        super().__init__()
        self.max_action = env_params['action_max']
        input_size = env_params['obs'] + env_params['action']
        width = 256
        self.fc1 = nn.Linear(input_size, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, x, actions):
        """Return the value estimate for each (``x``, ``actions``) row.

        ``actions`` is divided by ``max_action`` and concatenated to ``x``
        along dim 1.
        """
        scaled_actions = actions / self.max_action
        hidden = torch.cat([x, scaled_actions], dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.q_out(hidden)

class TADDcritic(nn.Module):
    """Pair of independent value MLPs that share one input.

    Each branch has three 256-unit ReLU layers and a linear scalar head:
    ``fc1..fc3 -> q_out1`` and ``fc4..fc6 -> q_out2``.
    """

    def __init__(self, env_params):
        """Create both branches.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'``, ``'goal'`` and ``'action'``.
        """
        super().__init__()
        self.max_action = env_params['action_max']
        input_size = env_params['obs'] + env_params['goal'] + env_params['action']
        width = 256

        # First branch.
        self.fc1 = nn.Linear(input_size, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out1 = nn.Linear(width, 1)

        # Second branch, same shape, separate parameters.
        self.fc4 = nn.Linear(input_size, width)
        self.fc5 = nn.Linear(width, width)
        self.fc6 = nn.Linear(width, width)
        self.q_out2 = nn.Linear(width, 1)

    def _joint_input(self, state, actions):
        """Concatenate ``state`` with ``actions / max_action`` along dim 1."""
        return torch.cat([state, actions / self.max_action], dim=1)

    def _first_branch(self, joint):
        """Evaluate the fc1..fc3 -> q_out1 branch on the joint input."""
        hidden = joint
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.q_out1(hidden)

    def forward(self, state, actions):
        """Return ``(q_value1, q_value2)``, one estimate from each branch."""
        joint = self._joint_input(state, actions)
        q_value1 = self._first_branch(joint)

        hidden = joint
        for layer in (self.fc4, self.fc5, self.fc6):
            hidden = F.relu(layer(hidden))
        q_value2 = self.q_out2(hidden)

        return q_value1, q_value2

    def forwardQ(self, x, actions):
        """Return only the first branch's estimate for (``x``, ``actions``)."""
        return self._first_branch(self._joint_input(x, actions))

class criticexpert(nn.Module):
    """State-action value MLP over [observation, goal] plus a normalised action.

    Architecturally identical to ``critic`` in this file: three 256-unit ReLU
    hidden layers and a linear scalar head.
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'``, ``'goal'`` and ``'action'``.
        """
        super().__init__()
        self.max_action = env_params['action_max']
        width = 256
        input_size = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(input_size, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, x, actions):
        """Return the value estimate for each (``x``, ``actions``) row."""
        normalised = actions / self.max_action
        hidden = F.relu(self.fc1(torch.cat([x, normalised], dim=1)))
        hidden = F.relu(self.fc2(hidden))
        hidden = F.relu(self.fc3(hidden))
        return self.q_out(hidden)

class exp_reward(nn.Module):
    """Scalar-output MLP over (observation, goal, normalised action).

    Two 256-unit ReLU hidden layers and a linear head named ``q_out``.
    Unlike the critics in this file, the observation and goal are passed as
    separate arguments and concatenated inside ``forward``.
    """

    def __init__(self, env_params):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'``, ``'goal'`` and ``'action'``.
        """
        super().__init__()
        self.max_action = env_params['action_max']
        width = 256
        input_size = env_params['obs'] + env_params['goal'] + env_params['action']
        self.fc1 = nn.Linear(input_size, width)
        self.fc2 = nn.Linear(width, width)
        self.q_out = nn.Linear(width, 1)

    def forward(self, obs, goal, actions):
        """Return one scalar per row of the concatenated inputs."""
        hidden = torch.cat([obs, goal, actions / self.max_action], dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.q_out(hidden)

# Vanilla Variational Auto-Encoder
class VAE3(nn.Module):
    """Conditional VAE that reconstructs an action given a state and desired goal.

    The encoder maps (state, goal, action) to the mean and log-std of a
    64-dimensional latent; the decoder maps (state, goal, latent) to an action
    squashed by tanh and scaled by ``env_params['action_max']``.
    """

    def __init__(self, env_params, id):
        """Create encoder and decoder layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'action'``, ``'obs'`` and ``'goal'``.
            id: CUDA device index; a negative value selects the CPU. It is
                only recorded in ``self.device`` -- the module is not moved.
        """
        super(VAE3, self).__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.max_action = env_params['action_max']
        # Kept for callers that read it; latent sampling follows the device of
        # the input tensors instead (see _sample_latent).
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 256
        self.e1 = nn.Linear(env_params['action'] + env_params['obs'] + env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + env_params['goal'] + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, dim)
        self.d4 = nn.Linear(dim, env_params['action'])

    def _sample_latent(self, state):
        """Draw one standard-normal latent per row of ``state``, clamped to [-1, 1].

        The noise is sampled on the CPU (same RNG stream as before) and then
        moved to ``state.device``. Using the input's device rather than
        ``self.device`` keeps decoding working when the module lives on a
        different device than the ``id`` given at construction.
        """
        noise = torch.randn((state.shape[0], self.latent_dim))
        return noise.to(state.device).clamp(-1, 1)

    def forward(self, state, desgoal, action):
        """Encode, sample a latent with the reparameterisation trick, decode.

        Returns:
            ``(u, mean, std)``: reconstructed action, latent mean, latent std.

        Side effect: caches the encoder activations in ``self.z1``,
        ``self.z2`` and ``self.z3`` (read by ``decode2``).
        """
        self.z1 = torch.relu(self.e1(torch.cat([state, desgoal, action], 1)))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        # NOTE(review): log_std is not clamped here (VAE3GCSL in this file
        # clamps to [-4, 15]); exp() can overflow for large values.
        log_std = self.log_std(self.z3)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, desgoal, z)
        return u, mean, std

    def decode(self, state, desgoal, z=None):
        """Decode an action from ``z``; sample a clamped prior latent if ``z`` is None."""
        if z is None:
            z = self._sample_latent(state)
        a1 = torch.relu(self.d1(torch.cat([state, desgoal, z], 1)))
        a2 = torch.relu(self.d2(a1))
        a3 = torch.relu(self.d3(a2))
        return self.max_action * torch.tanh(self.d4(a3))

    def decode2(self, state, desgoal, z=None):
        """Decode with the cached encoder activations added as skip inputs.

        Reads ``self.z1``/``self.z2``/``self.z3`` written by the most recent
        ``forward`` call, so ``forward`` must have run first and the cached
        tensors must be shape-compatible with this batch.
        """
        if z is None:
            z = self._sample_latent(state)
        a1 = torch.relu(self.d1(torch.cat([state, desgoal, z], 1)))
        a2 = torch.relu(self.d2(a1 + self.z3))
        a3 = torch.relu(self.d3(a2 + self.z2))
        return self.max_action * torch.tanh(self.d4(a3 + self.z1))
    
# --- BEGIN RND Network Definition ---
class RNDNetwork(nn.Module):
    """Three-layer MLP with configurable input, hidden and output sizes.

    Two ReLU hidden layers of ``hidden_dim`` units and a linear output layer
    of ``output_dim`` units.
    """

    def __init__(self, input_dim, output_dim=128, hidden_dim=128):
        """Create the layers for the given sizes."""
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the ``output_dim`` features for ``x`` (no final activation)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)



class RewardNormalizer:
    """Standardise rewards with exponentially weighted running mean and variance."""

    def __init__(self, alpha=0.1, epsilon=1e-8):
        """Start from mean 0 and variance 1.

        Args:
            alpha: EMA update rate applied to each new reward.
            epsilon: constant added to the standard deviation in
                ``normalize`` to avoid division by zero.
        """
        self.mean = 0.0
        self.var = 1.0
        self.alpha = alpha  # EMA update rate
        self.epsilon = epsilon
        self.count = 0

    def update(self, reward):
        """Fold one reward into the running mean and variance."""
        self.count += 1
        rate = self.alpha
        # Deviation from the mean before and after the mean is moved.
        before = reward - self.mean
        self.mean += rate * before
        after = reward - self.mean
        self.var = (1 - rate) * (self.var + rate * before * after)

    def normalize(self, reward):
        """Return ``reward`` centred on the running mean and scaled by the running std."""
        scale = (self.var ** 0.5) + self.epsilon
        return (reward - self.mean) / scale

# Vanilla Variational Auto-Encoder
class VAE2(nn.Module):
    """Conditional VAE that reconstructs an action given a state (no goal input).

    The encoder maps (state, action) to the mean and log-std of a
    64-dimensional latent; the decoder maps (state, latent) to an action
    squashed by tanh and scaled by ``env_params['action_max']``.
    """

    def __init__(self, env_params, id):
        """Create encoder and decoder layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'action'`` and ``'obs'``.
            id: CUDA device index; a negative value selects the CPU. It is
                only recorded in ``self.device`` -- the module is not moved.
        """
        super(VAE2, self).__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.max_action = env_params['action_max']
        # Kept for callers that read it; latent sampling follows the device of
        # the input tensors instead (see _sample_latent).
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 256
        self.e1 = nn.Linear(env_params['action'] + env_params['obs'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, dim)
        self.d4 = nn.Linear(dim, env_params['action'])

    def _sample_latent(self, state):
        """Draw one standard-normal latent per row of ``state``, clamped to [-1, 1].

        The noise is sampled on the CPU (same RNG stream as before) and then
        moved to ``state.device``. Using the input's device rather than
        ``self.device`` keeps decoding working when the module lives on a
        different device than the ``id`` given at construction.
        """
        noise = torch.randn((state.shape[0], self.latent_dim))
        return noise.to(state.device).clamp(-1, 1)

    def forward(self, state, action):
        """Encode, sample a latent with the reparameterisation trick, decode.

        Returns:
            ``(u, mean, std)``: reconstructed action, latent mean, latent std.

        Side effect: stores the encoder activations in ``self.z1``,
        ``self.z2`` and ``self.z3``.
        """
        self.z1 = torch.relu(self.e1(torch.cat([state, action], 1)))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        # NOTE(review): log_std is not clamped here (VAE3GCSL in this file
        # clamps to [-4, 15]); exp() can overflow for large values.
        log_std = self.log_std(self.z3)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)

        u = self.decode2(state, z)
        return u, mean, std

    def decode(self, state, z=None):
        """Decode an action from ``z``; sample a clamped prior latent if ``z`` is None."""
        if z is None:
            z = self._sample_latent(state)
        a1 = torch.relu(self.d1(torch.cat([state, z], 1)))
        a2 = torch.relu(self.d2(a1))
        a3 = torch.relu(self.d3(a2))
        return self.max_action * torch.tanh(self.d4(a3))

    def decode2(self, state, z=None):
        """Alias of ``decode``, kept because ``forward`` and callers use this name.

        In this class the two methods compute exactly the same thing.
        """
        return self.decode(state, z)


class ExpPolicyFULL(nn.Module):
    """Plain MLP policy over (state, desired goal) with a linear action head.

    Three 256-unit ReLU layers (``e1``..``e3``) feed the ``mean`` layer, whose
    raw output is returned: no tanh and no ``max_action`` scaling is applied.
    """

    def __init__(self, env_params, id):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'``, ``'goal'`` and ``'action'``.
            id: CUDA device index; a negative value selects the CPU. It is
                only recorded in ``self.device`` -- the module is not moved.
        """
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.max_action = env_params['action_max']
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        width = 256
        self.e1 = nn.Linear(env_params['obs'] + env_params['goal'], width)
        self.e2 = nn.Linear(width, width)
        self.e3 = nn.Linear(width, width)

        self.mean = nn.Linear(width, env_params['action'])

    def forward(self, state, desgoal):
        """Return the action head output for each (``state``, ``desgoal``) row.

        Side effect: stores the hidden activations in ``self.z1``,
        ``self.z2`` and ``self.z3``.
        """
        joint = torch.cat([state, desgoal], 1)
        self.z1 = F.relu(self.e1(joint))
        self.z2 = F.relu(self.e2(self.z1))
        self.z3 = F.relu(self.e3(self.z2))
        return self.mean(self.z3)

class ExpPolicy(nn.Module):
    """MLP policy over (state, desired goal) with a skip connection.

    Like ``ExpPolicyFULL`` in this file, except that the first hidden
    activation is added to the third layer's pre-activation before the ReLU.
    The ``mean`` layer's raw output is returned.
    """

    def __init__(self, env_params, id):
        """Create the layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'obs'``, ``'goal'`` and ``'action'``.
            id: CUDA device index; a negative value selects the CPU. It is
                only recorded in ``self.device`` -- the module is not moved.
        """
        super().__init__()
        self.env_params = env_params
        self.latent_dim = 64
        self.max_action = env_params['action_max']
        self.device = torch.device('cuda:%d' % id) if id >= 0 else torch.device('cpu')
        width = 256
        self.e1 = nn.Linear(env_params['obs'] + env_params['goal'], width)
        self.e2 = nn.Linear(width, width)
        self.e3 = nn.Linear(width, width)

        self.mean = nn.Linear(width, env_params['action'])

    def forward(self, state, desgoal ):
        """Return the action head output for each (``state``, ``desgoal``) row.

        Side effect: stores the hidden activations in ``self.z1``,
        ``self.z2`` and ``self.z3``.
        """
        joint = torch.cat([state, desgoal], 1)
        self.z1 = F.relu(self.e1(joint))
        self.z2 = F.relu(self.e2(self.z1))
        # Skip connection from the first hidden layer.
        residual_sum = self.e3(self.z2) + self.z1
        self.z3 = F.relu(residual_sum)
        return self.mean(self.z3)

class VAE3GCSL(nn.Module):
    """Small conditional VAE over actions given a state and desired goal.

    Same structure as ``VAE3`` in this file but with 64-unit layers, an
    8-dimensional latent, and a log-std clamped to [-4, 15] before
    exponentiation.
    """

    def __init__(self, env_params, id):
        """Create encoder and decoder layers.

        Args:
            env_params: mapping that provides the keys ``'action_max'``,
                ``'action'``, ``'obs'`` and ``'goal'``.
            id: CUDA device index; a negative value selects the CPU. It is
                only recorded in ``self.device`` -- the module is not moved.
        """
        super(VAE3GCSL, self).__init__()
        self.env_params = env_params
        self.latent_dim = 8
        self.max_action = env_params['action_max']
        # Kept for callers that read it; latent sampling in decode() follows
        # the device of the input tensors instead.
        if id >= 0:
            self.device = torch.device('cuda:%d' % id)
        else:
            self.device = torch.device('cpu')
        dim = 64
        self.e1 = nn.Linear(env_params['action'] + env_params['obs'] + env_params['goal'], dim)
        self.e2 = nn.Linear(dim, dim)
        self.e3 = nn.Linear(dim, dim)

        self.mean = nn.Linear(dim, self.latent_dim)
        self.log_std = nn.Linear(dim, self.latent_dim)

        self.d1 = nn.Linear(env_params['obs'] + env_params['goal'] + self.latent_dim, dim)
        self.d2 = nn.Linear(dim, dim)
        self.d3 = nn.Linear(dim, dim)
        self.d4 = nn.Linear(dim, env_params['action'])

    def forward(self, state, desgoal, action):
        """Encode, sample a latent with the reparameterisation trick, decode.

        Returns:
            ``(u, mean, std)``: reconstructed action, latent mean, latent std.

        Side effect: stores the encoder activations in ``self.z1``,
        ``self.z2`` and ``self.z3``.
        """
        self.z1 = torch.relu(self.e1(torch.cat([state, desgoal, action], 1)))
        self.z2 = torch.relu(self.e2(self.z1))
        self.z3 = torch.relu(self.e3(self.z2))

        mean = self.mean(self.z3)
        # Clamp before exp() to keep the standard deviation finite.
        log_std = self.log_std(self.z3).clamp(-4, 15)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)
        u = self.decode(state, desgoal, z)
        return u, mean, std

    def decode(self, state, desgoal, z=None):
        """Decode an action from ``z``; sample a prior latent if ``z`` is None."""
        # When sampling from the VAE, the latent vector is clipped to [-1, 1].
        if z is None:
            # Sample on the CPU (same RNG stream as before), then follow the
            # input's device so this works wherever the module actually lives.
            noise = torch.randn((state.shape[0], self.latent_dim))
            z = noise.to(state.device).clamp(-1, 1)
        a1 = torch.relu(self.d1(torch.cat([state, desgoal, z], 1)))
        a2 = torch.relu(self.d2(a1))
        a3 = torch.relu(self.d3(a2))
        return self.max_action * torch.tanh(self.d4(a3))
    

class TransformerForwardModel(nn.Module):
    """Forward-dynamics model: predicts the next state from (state, action).

    Each (state, action) pair is projected to ``hidden_dim``, passed through a
    transformer encoder as a length-1 sequence, and projected back to the
    observation size. Samples in a batch are processed independently.
    """

    def __init__(self, env_params, hidden_dim=256, nhead=8, num_layers=3):
        """Create the projections and the encoder stack.

        Args:
            env_params: mapping that provides the keys ``'obs'``,
                ``'action'`` and ``'action_max'``.
            hidden_dim: transformer model width; must be divisible by ``nhead``.
            nhead: number of attention heads per encoder layer.
            num_layers: number of stacked encoder layers.
        """
        super(TransformerForwardModel, self).__init__()
        self.input_dim = env_params['obs']  + env_params['action']
        self.output_dim = env_params['obs']
        self.max_action = env_params['action_max']

        # Input projection
        self.input_proj = nn.Linear(self.input_dim, hidden_dim)

        # Transformer encoder. batch_first is left at its default (False), so
        # the encoder expects inputs shaped (seq_len, batch, hidden_dim).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output projection for next state prediction
        self.output_proj = nn.Linear(hidden_dim, self.output_dim)

    def forward(self, state, action):
        """Return the predicted next state for each (``state``, ``action``) row.

        ``action`` is divided by ``max_action`` before being concatenated to
        ``state`` along the last dimension.
        """
        # Concatenate inputs
        action = action / self.max_action
        x = torch.cat([state, action], dim=-1)

        # Project to transformer dimension
        x = self.input_proj(x)
        x = F.relu(x)

        # Reshape to (seq_len=1, batch_size, hidden_dim). The sequence axis
        # must be dim 0 because the encoder is not batch_first; putting it at
        # dim 1 made the encoder treat the batch as one sequence and attend
        # across unrelated samples.
        x = x.unsqueeze(0)

        # Pass through transformer
        x = self.transformer(x)

        # Remove sequence dimension
        x = x.squeeze(0)

        # Predict next state with a single linear head
        next_state = self.output_proj(x)

        return next_state