import torch
import torch.nn as nn
import torch.nn.functional as F

# Q value network
class q_network(nn.Module):
    """Q-value network: a three-layer MLP with one output unit per action.

    Args:
        state_dim: Number of input features expected by the first layer.
        hidden_dim: Width of both hidden layers.
        num_action: Number of output units of the final layer.
    """

    def __init__(self, state_dim, hidden_dim, num_action):
        super().__init__()
        self.state_dim, self.hidden_dim = state_dim, hidden_dim
        # Layers are created input-to-output; the attribute names are part of
        # the state_dict layout and must stay as they are.
        self.linear1 = nn.Linear(in_features=state_dim, out_features=hidden_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.linear3 = nn.Linear(in_features=hidden_dim, out_features=num_action)

    def forward(self, state):
        """Map ``state`` to the outputs of the last linear layer.

        Args:
            state: Tensor whose last dimension is ``state_dim``.

        Returns:
            Tensor whose last dimension is ``num_action``. No activation is
            applied after the final layer.
        """
        hidden = state
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        return self.linear3(hidden)
    
# Expected Free Energy
class EFE_network(nn.Module):
    """Expected Free Energy network: a three-layer MLP with one output unit.

    Args:
        state_dim: Number of input features expected by the first layer.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.linear1 = nn.Linear(in_features=state_dim, out_features=hidden_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        # Single output unit: one scalar per input row.
        self.linear3 = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, state):
        """Return the un-activated scalar output for ``state``.

        Args:
            state: Tensor whose last dimension is ``state_dim``.

        Returns:
            Tensor whose last dimension has size 1.
        """
        first_hidden = F.relu(self.linear1(state))
        second_hidden = F.relu(self.linear2(first_hidden))
        return self.linear3(second_hidden)
    
# Policy Network
class Policy_network(nn.Module):
    """Policy network: a three-layer MLP ending in a softmax over actions.

    Args:
        state_dim: Number of input features expected by the first layer.
        hidden_dim: Width of both hidden layers.
        action_num: Number of output units (one probability per action).
    """

    def __init__(self, state_dim, hidden_dim, action_num):
        super().__init__()
        self.state_dim = state_dim
        # Stored for parity with the other networks in this file, which all
        # expose ``hidden_dim`` as an attribute.
        self.hidden_dim = hidden_dim
        self.action_num = action_num
        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, action_num)

    def forward(self, state):
        """Return action probabilities for ``state``.

        Args:
            state: Tensor whose last dimension is ``state_dim``.

        Returns:
            Tensor whose last dimension is ``action_num``; the softmax is
            taken over that last dimension, so its entries sum to 1.
        """
        out = F.relu(self.linear1(state))
        out = F.relu(self.linear2(out))
        out = F.softmax(self.linear3(out), dim=-1)
        return out

# p(o|S) : decoder
class Decoder(nn.Module):
    """p(o|s) decoder: maps a state to two observation-sized heads.

    A shared two-layer ReLU trunk feeds a linear ``mu`` head and a ``rho``
    head that ends in Softplus, so ``rho`` is non-negative.

    Args:
        obs_dim: Size of each output head.
        hidden_dim: Width of both trunk layers.
        state_dim: Number of input features.
    """

    def __init__(self, obs_dim, hidden_dim, state_dim):
        super().__init__()
        self.obs_dim, self.hidden_dim = obs_dim, hidden_dim
        # Shared trunk.
        self.linear1 = nn.Linear(in_features=state_dim, out_features=hidden_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        # Output heads.
        self.mu = nn.Linear(in_features=hidden_dim, out_features=obs_dim)
        rho_linear = nn.Linear(in_features=hidden_dim, out_features=obs_dim)
        self.rho = nn.Sequential(rho_linear, nn.Softplus())

    def forward(self, x_state):
        """Compute both heads for ``x_state``.

        Args:
            x_state: Tensor whose last dimension is ``state_dim``.

        Returns:
            Tuple ``(mu, rho)``, each with last dimension ``obs_dim``.
            ``rho`` is presumably a spread parameter (std or variance) —
            confirm against the caller.
        """
        features = F.relu(self.linear2(F.relu(self.linear1(x_state))))
        return self.mu(features), self.rho(features)
    
# p(s|o) : encoder
class Encoder(nn.Module):
    """p(s|o) encoder: maps an observation to two state-sized heads.

    A shared two-layer ReLU trunk feeds a linear ``mu`` head and a ``rho``
    head that ends in Softplus, so ``rho`` is non-negative.

    Args:
        obs_dim: Number of input features.
        hidden_dim: Width of both trunk layers.
        state_dim: Size of each output head.
    """

    def __init__(self, obs_dim, hidden_dim, state_dim):
        super().__init__()
        self.obs_dim = obs_dim
        self.hidden_dim = hidden_dim

        def make_head():
            # Each head maps the trunk features to a state-sized vector.
            return nn.Linear(hidden_dim, state_dim)

        self.linear1 = nn.Linear(obs_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.mu = make_head()
        self.rho = nn.Sequential(make_head(), nn.Softplus())

    def forward(self, x_obs):
        """Compute both heads for ``x_obs``.

        Args:
            x_obs: Tensor whose last dimension is ``obs_dim``.

        Returns:
            Tuple ``(mu, rho)``, each with last dimension ``state_dim``.
            ``rho`` is presumably a spread parameter (std or variance) —
            confirm against the caller.
        """
        trunk = x_obs
        for layer in (self.linear1, self.linear2):
            trunk = F.relu(layer(trunk))
        mean_head = self.mu(trunk)
        spread_head = self.rho(trunk)
        return mean_head, spread_head
    
# p(s_{t+1} | s_t, a_t) : transition
class Transition(nn.Module):
    """Transition model p(s_{t+1} | s_t, a_t) evaluated for every action.

    A shared two-layer ReLU trunk feeds two heads of width
    ``state_dim * action_num``; each head is reshaped so that its last
    axis has size ``action_num``. The ``rho`` head ends in Softplus and is
    therefore non-negative.

    Args:
        hidden_dim: Width of both trunk layers.
        state_dim: Number of input features, and size of the second output axis.
        action_num: Size of the last output axis.
    """

    def __init__(self, hidden_dim, state_dim, action_num):
        super().__init__()
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.action_num = action_num
        head_width = state_dim * action_num
        self.linear1 = nn.Linear(in_features=state_dim, out_features=hidden_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.mu = nn.Linear(in_features=hidden_dim, out_features=head_width)
        self.rho = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=head_width),
            nn.Softplus(),
        )

    def forward(self, s_prev):
        """Compute both heads for ``s_prev``.

        Args:
            s_prev: Tensor of shape ``(batch, state_dim)``. The leading output
                dimension is taken from ``s_prev.size(0)``.

        Returns:
            Tuple ``(mu, rho)``, each of shape
            ``(batch, state_dim, action_num)``.
        """
        features = F.relu(self.linear2(F.relu(self.linear1(s_prev))))
        target_shape = (s_prev.size(0), self.state_dim, self.action_num)
        return (
            self.mu(features).view(target_shape),
            self.rho(features).view(target_shape),
        )
    
# p(o_{t+1} | o_t)
class Priorobservation(nn.Module):
    """Prior over the next observation, p(o_{t+1} | o_t).

    It predicts mean and std of normal distribution.

    Args:
        obs_dim: Number of observation features (input size and size of each
            returned tensor's last dimension).
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, obs_dim, hidden_dim):
        super().__init__()
        self.obs_dim = obs_dim
        self.hidden_dim = hidden_dim
        self.linear1 = nn.Linear(obs_dim, hidden_dim)
        self.drop1 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.drop2 = nn.Dropout(0.5)
        # One output vector holding the mean half followed by the std half.
        self.linear3 = nn.Linear(hidden_dim, obs_dim * 2)

    def forward(self, o_prev):
        """Predict the mean and std for the next observation.

        Args:
            o_prev: Tensor whose last dimension is ``obs_dim``; any number of
                leading dimensions (including none) is accepted.

        Returns:
            Tuple ``(mean, std)``, each shaped like ``o_prev``. ``std`` is
            passed through softplus and is therefore non-negative.
        """
        # Dropout is applied to the pre-activations, before each ReLU.
        out = self.drop1(self.linear1(o_prev))
        out = F.relu(out)
        out = self.drop2(self.linear2(out))
        out = F.relu(out)
        out = self.linear3(out)
        # Split on the feature (last) axis so unbatched inputs and inputs with
        # extra leading dimensions are handled correctly; slicing axis 1 is
        # only right for strictly 2-D input.
        mean = out[..., :self.obs_dim]
        std = F.softplus(out[..., self.obs_dim:])
        return mean, std

# Behavior Cloning    
class BC(nn.Module):
    """Behavior-cloning network: a three-layer MLP from observation to action.

    Args:
        obs_dim: Number of input features.
        hidden_dim: Width of both hidden layers.
        action_dim: Number of output units of the final layer.
    """

    def __init__(self, obs_dim, hidden_dim, action_dim):
        super().__init__()
        self.obs_dim, self.hidden_dim = obs_dim, hidden_dim
        self.linear1 = nn.Linear(in_features=obs_dim, out_features=hidden_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.linear3 = nn.Linear(in_features=hidden_dim, out_features=action_dim)

    def forward(self, o_prev):
        """Map ``o_prev`` to the outputs of the last linear layer.

        Args:
            o_prev: Tensor whose last dimension is ``obs_dim``.

        Returns:
            Tensor whose last dimension is ``action_dim``. No activation is
            applied after the final layer.
        """
        hidden = F.relu(self.linear1(o_prev))
        hidden = F.relu(self.linear2(hidden))
        return self.linear3(hidden)
    
    
# Generate models
def generate_triple(obs_dim, hidden_dim, state_dim, action_num, device):
    """Build the decoder, encoder and transition models and move them to ``device``.

    Args:
        obs_dim: Observation size passed to the decoder and encoder.
        hidden_dim: Hidden width passed to all three models.
        state_dim: State size passed to all three models.
        action_num: Number of actions passed to the transition model.
        device: Target passed to each model's ``.to(...)`` call.

    Returns:
        Tuple ``(decoder, encoder, transition)``.
    """
    # Tuple elements are evaluated left to right, so the models are still
    # constructed (and moved) in decoder -> encoder -> transition order.
    return (
        Decoder(obs_dim, hidden_dim, state_dim).to(device),
        Encoder(obs_dim, hidden_dim, state_dim).to(device),
        Transition(hidden_dim, state_dim, action_num).to(device),
    )
    