import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as tds


class DDPG_critic_network(nn.Module):
    """Critic network mapping a (state, action) batch to one value per sample.

    The state is passed through the first linear layer on its own; the action
    is concatenated to the resulting hidden features before the second layer.
    """

    def __init__(self, state_dim, action_dim, action_low, action_high):
        """Build the layers.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Size of the last dimension of the action input.
            action_low: Unused by this network; kept in the signature.
            action_high: Unused by this network; kept in the signature.
        """
        super(DDPG_critic_network, self).__init__()
        hidden_size = 128

        self.linear1 = nn.Linear(state_dim, hidden_size)
        self.ln1 = nn.BatchNorm1d(hidden_size)

        # The action joins the hidden state features here, hence the wider input.
        self.linear2 = nn.Linear(hidden_size + action_dim, hidden_size)
        self.ln2 = nn.BatchNorm1d(hidden_size)

        self.V = nn.Linear(hidden_size, 1)
        # Scale the freshly initialised output layer down by a factor of ten.
        self.V.weight.data.mul_(0.1)
        self.V.bias.data.mul_(0.1)

    def forward(self, s, a):
        """Return the value estimate for each (state, action) pair.

        Args:
            s: State batch of shape (batch, state_dim).
            a: Action batch of shape (batch, action_dim).

        Returns:
            Tensor of shape (batch, 1).
        """
        x = self.linear1(s)
        x = self.ln1(x)
        x = F.relu(x)

        # Concatenate along the feature axis so linear2 sees state features + action.
        x = torch.cat((x, a), 1)
        x = self.linear2(x)
        x = self.ln2(x)
        x = F.relu(x)

        return self.V(x)
    
class DDPG_actor_network(nn.Module):
    """Actor network mapping a state batch to a bounded action batch.

    Each action dimension is clamped to its own ``[low, high]`` interval.
    """

    def __init__(self, state_dim, action_dim, action_low, action_high):
        """Build the layers.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Number of action dimensions produced.
            action_low: Indexable of per-dimension lower bounds
                (``action_low[i]`` bounds action column ``i``).
            action_high: Indexable of per-dimension upper bounds.
        """
        super(DDPG_actor_network, self).__init__()

        layer_size = 128

        self.fc1 = nn.Linear(state_dim, layer_size)
        self.bn1 = nn.BatchNorm1d(layer_size)
        self.fc2 = nn.Linear(layer_size, layer_size)
        self.bn2 = nn.BatchNorm1d(layer_size)

        self.mu = nn.Linear(layer_size, action_dim)
        # Scale the freshly initialised output layer down by a factor of ten.
        self.mu.weight.data.mul_(0.1)
        self.mu.bias.data.mul_(0.1)

        self.action_low, self.action_high = action_low, action_high

    def forward(self, s):
        """Return the clamped action for each state.

        Args:
            s: State batch of shape (batch, state_dim).

        Returns:
            Tensor of shape (batch, action_dim) whose column ``i`` lies in
            ``[action_low[i], action_high[i]]``.
        """
        s = F.relu(self.bn1(self.fc1(s)))
        s = F.relu(self.bn2(self.fc2(s)))

        mu = self.mu(s)

        # Clamp column by column so every action dimension gets its own bounds,
        # for any number of action dimensions (not just two).
        columns = torch.split(mu, 1, dim=1)
        clamped = [
            torch.clamp(column, self.action_low[i], self.action_high[i])
            for i, column in enumerate(columns)
        ]

        return torch.cat(clamped, 1)


class NAF_network(nn.Module):
    """Normalized Advantage Function network: Q(s, a) = V(s) + A(s, a).

    A shared trunk feeds three heads: the state value ``v``, the action mean
    ``miu`` and the entries of a lower-triangular matrix ``L``. The advantage
    is the quadratic form ``-0.5 * (a - miu)^T (L L^T) (a - miu)``.

    The action mean is unpacked into exactly two columns (``miu_iv`` and
    ``miu_vaso``), so ``forward`` only works with ``action_dim == 2``.
    """

    def __init__(self, state_dim, action_dim, action_low, action_high):
        """Build the layers.

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Number of action dimensions (see class note).
            action_low: Indexable of per-dimension lower bounds; entries 0
                and 1 are used.
            action_high: Indexable of per-dimension upper bounds; entries 0
                and 1 are used.
        """
        super(NAF_network, self).__init__()

        layer_size = 128

        self.sharefc1 = nn.Linear(state_dim, layer_size)
        self.bn1 = nn.BatchNorm1d(layer_size)

        self.sharefc2 = nn.Linear(layer_size, layer_size)
        self.bn2 = nn.BatchNorm1d(layer_size)

        self.v_fc1 = nn.Linear(layer_size, 1)

        self.miu_fc1 = nn.Linear(layer_size, action_dim)

        self.L_fc1 = nn.Linear(layer_size, action_dim ** 2)

        self.action_dim = action_dim
        self.action_low, self.action_high = action_low, action_high

    def forward(self, s, a=None):
        """Evaluate the network.

        Args:
            s: State batch of shape (batch, state_dim).
            a: Optional action batch of shape (batch, action_dim).

        Returns:
            If ``a`` is None: the tuple ``(v, miu, miu_iv, miu_vaso)`` where
            ``miu`` is the unclamped action mean and ``miu_iv`` / ``miu_vaso``
            are its two columns clamped to their own bounds.
            Otherwise: ``q`` of shape (batch, 1).

        Side effects:
            Stores the clamped columns on ``self.miu_iv`` and
            ``self.miu_vaso``.
        """
        # Here batch norm is applied after the activation.
        s = F.relu(self.sharefc1(s))
        s = self.bn1(s)
        s = F.relu(self.sharefc2(s))
        s = self.bn2(s)

        v = self.v_fc1(s)

        miu = self.miu_fc1(s)

        # Action dimensions can have different bounds, so each column is
        # clamped against its own (low, high) pair.
        self.miu_iv, self.miu_vaso = torch.split(miu, 1, dim=1)

        self.miu_iv = torch.clamp(
            self.miu_iv, self.action_low[0], self.action_high[0])
        self.miu_vaso = torch.clamp(
            self.miu_vaso, self.action_low[1], self.action_high[1])

        if a is None:
            return v, miu, self.miu_iv, self.miu_vaso

        L = torch.tanh(self.L_fc1(s))
        L = L.view(-1, self.action_dim, self.action_dim)

        # Build the masks with L's dtype and device so the module also works
        # after being moved off the CPU.
        ones = torch.ones(
            self.action_dim, self.action_dim, dtype=L.dtype, device=L.device)
        tril_mask = torch.tril(ones, diagonal=-1).unsqueeze(0)
        diag_mask = torch.diag(torch.diag(ones)).unsqueeze(0)

        # Keep the strictly-lower entries as they are and exponentiate the
        # diagonal (strictly positive), zeroing everything above it.
        L = L * tril_mask.expand_as(L) + torch.exp(L) * diag_mask.expand_as(L)

        P = torch.bmm(L, L.transpose(2, 1))

        # NOTE(review): the advantage is centred on the unclamped mean `miu`,
        # not on the clamped columns returned above -- confirm this is intended.
        u_mu = (a - miu).unsqueeze(2)
        A = -0.5 * \
            torch.bmm(torch.bmm(u_mu.transpose(2, 1), P), u_mu)[:, :, 0]

        q = A + v

        return q
        
        
class DQN_fc_network(nn.Module):
    """Fully connected Q-network with a configurable number of hidden layers."""

    def __init__(self, input_dim, output_dim, hidden_layers):
        """Build the layers.

        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Number of outputs.
            hidden_layers: Number of 32 -> 32 layers between the input and
                output layers.
        """
        super(DQN_fc_network, self).__init__()

        self.fc_in = nn.Linear(input_dim, 32)
        # ModuleList (not a plain list) so the hidden layers are registered:
        # they show up in parameters()/state_dict() and follow .to()/.cuda().
        # This adds `fc_hiddens.*` entries to the state dict.
        self.fc_hiddens = nn.ModuleList(
            [nn.Linear(32, 32) for _ in range(hidden_layers)])
        self.fc_out = nn.Linear(32, output_dim)

    def forward(self, x):
        """Return the output of the last linear layer (no final activation)."""
        x = F.relu(self.fc_in(x))
        for layer in self.fc_hiddens:
            x = F.relu(layer(x))
        x = self.fc_out(x)
        return x
        
class DQN_dueling_network(nn.Module):
    """Dueling Q-network: a shared trunk followed by advantage and value heads.

    The output is ``q = (a - mean(a)) + v`` where the mean is taken over the
    action axis of each sample.
    """

    def __init__(self, input_dim, output_dim, hidden_layers):
        """Build the layers.

        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Number of advantage outputs.
            hidden_layers: The trunk gets ``hidden_layers - 1`` layers of
                size 32 -> 32 after the input layer.
        """
        super(DQN_dueling_network, self).__init__()
        self.fc_in = nn.Linear(input_dim, 32)
        # ModuleList (not a plain list) so the hidden layers are registered:
        # they show up in parameters()/state_dict() and follow .to()/.cuda().
        # This adds `fc_hiddens.*` entries to the state dict.
        self.fc_hiddens = nn.ModuleList(
            [nn.Linear(32, 32) for _ in range(hidden_layers - 1)])

        self.fca_before = nn.Linear(32, 16)
        self.fcv_before = nn.Linear(32, 16)
        self.fca = nn.Linear(16, output_dim)
        self.fcv = nn.Linear(16, 1)

    def forward(self, x):
        """Return the Q-values, one per advantage output, for each input."""
        x = F.relu(self.fc_in(x))

        for layer in self.fc_hiddens:
            x = F.relu(layer(x))

        a = F.relu(self.fca_before(x))
        a = self.fca(a)
        # Centre the advantages per sample (over the last axis) so that one
        # sample's Q-values do not depend on the other samples in the batch.
        a = a - a.mean(dim=-1, keepdim=True)

        v = F.relu(self.fcv_before(x))
        v = self.fcv(v)

        q = a + v
        return q


    
class AC_v_fc_network(nn.Module):
    """Three-layer fully connected network mapping a state to one value."""

    def __init__(self, state_dim):
        """Create the layers; ``state_dim`` is the input feature size."""
        super().__init__()

        width = 30
        self.fc1 = nn.Linear(state_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 1)

    def forward(self, s):
        """Return the single-output value ``v`` for the state input ``s``."""
        hidden = s
        # Two ReLU-activated layers, then a linear output with no activation.
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)
    
class AC_a_fc_network(nn.Module):
    """Three-layer fully connected network ending in a softmax over dim 1."""

    def __init__(self, input_dim, output_dim):
        """Create the layers for ``input_dim`` inputs and ``output_dim`` outputs."""
        super().__init__()

        width = 30
        self.fc1 = nn.Linear(input_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, output_dim)

    def forward(self, x):
        """Return softmax probabilities, normalised along dimension 1."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        logits = self.fc3(hidden)
        return F.softmax(logits, dim=1)
        
class CAC_a_fc_network(nn.Module):
    """Network producing a Normal distribution with a clamped mean.

    The mean comes from a three-layer fully connected network; the scale is
    the fixed all-ones tensor ``sigma``.
    """

    def __init__(self, input_dim, output_dim, action_low, action_high):
        """Build the layers.

        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Number of dimensions of the distribution mean.
            action_low: Lower bound for the mean: a single number (or None),
                or an indexable of per-dimension bounds.
            action_high: Upper bound for the mean, in the same form as
                ``action_low``.
        """
        super(CAC_a_fc_network, self).__init__()
        self.fc1 = nn.Linear(input_dim, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, output_dim)

        # Registered as a buffer so it follows .to()/.cuda()/.double() together
        # with the parameters; non-persistent so state_dict() is unchanged.
        self.register_buffer("sigma", torch.ones(output_dim), persistent=False)
        self.action_low, self.action_high = action_low, action_high

    def _clamp_mean(self, mu):
        """Clamp ``mu`` with scalar bounds or with per-dimension bounds."""
        low, high = self.action_low, self.action_high
        if all(b is None or isinstance(b, (int, float)) for b in (low, high)):
            return torch.clamp(mu, low, high)

        # Per-dimension bounds: clamp each column of the last axis on its own.
        columns = torch.split(mu, 1, dim=-1)
        clamped = [
            torch.clamp(column, low[i], high[i])
            for i, column in enumerate(columns)
        ]
        return torch.cat(clamped, dim=-1)

    def forward(self, s):
        """Return ``Normal(loc=clamped mean, scale=sigma)`` for the input ``s``."""
        s = F.relu(self.fc1(s))
        s = F.relu(self.fc2(s))
        mu = self._clamp_mean(self.fc3(s))

        return tds.normal.Normal(loc=mu, scale=self.sigma)