import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import numpy as np

# NOTE(review): none of these constants is referenced in this chunk. By name
# they look like clamp bounds for a predicted log standard deviation and a
# small numerical-stability offset; confirm against their users.
LOG_SIG_MAX, LOG_SIG_MIN = 2, -20
epsilon = 1e-6

def weights_init_(m):
    """Xavier-uniform initialise an ``nn.Linear`` module in place.

    Meant to be passed to ``module.apply(weights_init_)``. Every ``nn.Linear``
    gets a Xavier-uniform weight and a zero bias; any other module type is
    left untouched.

    Args:
        m: the module being visited by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have m.bias set to None; zeros_(None)
        # would raise AttributeError.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

def weights_init_ones(m):
    """Set every weight of an ``nn.Linear`` module to 1 and its bias to 0.

    Meant to be passed to ``module.apply(weights_init_ones)``. Non-linear
    modules are left untouched.

    Args:
        m: the module being visited by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.ones_(m.weight)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class EqvarModuleMean(nn.Module):
    """Shared linear layer applied to mean-centred set elements.

    ``forward`` subtracts the mean over dim 1 from ``x`` and then applies one
    ``nn.Linear(num_inputs, num_outputs)`` to the last dimension.

    Args:
        num_inputs: size of the last dimension of ``x``.
        num_outputs: size of the last dimension of the output.
    """

    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.num_inputs, self.num_outputs = num_inputs, num_outputs
        self.gamma = nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        # NOTE(review): dim 1 is hard-coded as the pooled (set) axis; this
        # presumably expects x shaped (batch, set_size, num_inputs).
        set_mean = torch.mean(x, dim=1, keepdim=True)
        return self.gamma(x - set_mean)


class EqvarModuleMax(nn.Module):
    """Shared linear layer applied to max-centred set elements.

    ``forward`` subtracts the element-wise maximum over dim 1 from ``x`` and
    then applies one ``nn.Linear(num_inputs, num_outputs)`` to the last
    dimension.

    Args:
        num_inputs: size of the last dimension of ``x``.
        num_outputs: size of the last dimension of the output.
    """

    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.num_inputs, self.num_outputs = num_inputs, num_outputs
        self.gamma = nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        # NOTE(review): dim 1 is hard-coded as the pooled (set) axis; this
        # presumably expects x shaped (batch, set_size, num_inputs).
        # .values of max(dim=...) (not amax) keeps the original tie gradients.
        set_max = x.max(dim=1, keepdim=True).values
        return self.gamma(x - set_max)


class EmbeddingNetwork(nn.Module):
    """Set-embedding network: equivariant encoder, max pooling, MLP head.

    ``phi`` is three equivariant layers (each followed by ``tanh``). Its
    output is max-pooled over dim 1 and fed through ``rho``: dropout(0.5),
    linear, tanh, dropout(0.5), linear to ``num_outputs``.

    Args:
        num_inputs: feature size of each set element.
        hidden_dim: width of every hidden layer.
        num_outputs: size of the returned embedding.
        pool: ``'mean'`` builds ``phi`` from ``EqvarModuleMean``; any other
            value builds it from ``EqvarModuleMax``. This only affects the
            centring inside ``phi`` -- the final set pooling is always max.
    """

    def __init__(self, num_inputs, hidden_dim, num_outputs, pool='mean'):
        super().__init__()
        self.num_inputs = num_inputs
        self.hidden_dim = hidden_dim
        self.num_outputs = num_outputs

        layer_cls = EqvarModuleMean if pool == 'mean' else EqvarModuleMax
        # Same Sequential layout for both choices, so state_dict keys match.
        self.phi = nn.Sequential(
            layer_cls(self.num_inputs, self.hidden_dim),
            nn.Tanh(),
            layer_cls(self.hidden_dim, self.hidden_dim),
            nn.Tanh(),
            layer_cls(self.hidden_dim, self.hidden_dim),
            nn.Tanh(),
        )

        self.rho = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Tanh(),
            nn.Dropout(p=0.5),
            nn.Linear(self.hidden_dim, self.num_outputs),
        )

    def forward(self, x):
        element_features = self.phi(x)
        # Max over the set axis (dim 1), regardless of `pool`.
        pooled = element_features.max(dim=1).values
        return self.rho(pooled)


class ValueNetwork(nn.Module):
    """Scalar state-value MLP: two ReLU hidden layers and a linear output.

    Args:
        num_inputs: size of the last dimension of ``state``.
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, num_inputs, hidden_dim):
        super().__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        self.apply(weights_init_)

    def forward(self, state):
        hidden = state
        for layer in (self.linear1, self.linear2):
            hidden = F.relu(layer(hidden))
        # Final projection is linear (no activation).
        return self.linear3(hidden)


class QNetwork(nn.Module):
    """Twin Q-value MLPs over the concatenated (state, action) input.

    Each head is two ReLU hidden layers and a linear output passed through
    ``sigmoid``, so every returned value lies in (0, 1).

    Args:
        num_inputs: size of the state vector.
        num_actions: size of the action vector.
        hidden_dim: width of the hidden layers of both heads.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()

        # Head 1 (also used by Q1()).
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Head 2.
        self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

        self.apply(weights_init_)

    @staticmethod
    def _head(sa, fc_in, fc_hidden, fc_out):
        """Run one Q head: relu(fc_in) -> relu(fc_hidden) -> sigmoid(fc_out)."""
        h = F.relu(fc_in(sa))
        h = F.relu(fc_hidden(h))
        return torch.sigmoid(fc_out(h))

    def forward(self, state, action):
        """Return ``(q1, q2)`` from both heads."""
        sa = torch.cat([state, action], 1)
        q1 = self._head(sa, self.linear1, self.linear2, self.linear3)
        q2 = self._head(sa, self.linear4, self.linear5, self.linear6)
        return q1, q2

    def Q1(self, state, action):
        """Return only the first head's value (same as ``forward(...)[0]``)."""
        sa = torch.cat([state, action], 1)
        return self._head(sa, self.linear1, self.linear2, self.linear3)


# class GaussianPolicy(nn.Module):
class PolicyNetworkNorm(nn.Module):
    """Deterministic policy whose pre-tanh action is rescaled per row.

    ``forward`` computes ``a = mean_linear(relu(linear2(relu(linear1(state)))))``.
    It then computes ``g = sum(|a|) / num_actions`` over the last dimension and
    divides ``a`` by ``g`` wherever ``g > 1`` (rows with ``g <= 1`` are left
    unchanged). The result is passed through ``tanh``.

    NOTE(review): ``g`` is reshaped with ``view(-1, 1)``, which assumes a 2-D
    ``(batch, features)`` input. A 1-D state comes back as
    ``(1, num_actions)``, and inputs with more than two dimensions may not
    broadcast as intended.

    Args:
        num_inputs: size of the last dimension of ``state``.
        num_actions: size of the action vector.
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim):
        super(PolicyNetworkNorm, self).__init__()

        self.num_actions = num_actions

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        # Not read by forward(); it is still created, initialised by
        # weights_init_, and part of state_dict().
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        self.apply(weights_init_)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)

        # Mean absolute pre-activation per row, shaped as a column so it
        # broadcasts over the action dimension of a 2-D batch.
        scale = torch.sum(torch.abs(mean), dim=-1).view(-1, 1) / self.num_actions
        # Only shrink rows whose scale exceeds 1. ones_like keeps the fallback
        # on mean's device and dtype; a bare torch.ones(size) is allocated on
        # the default (CPU) device, which breaks torch.where once the model
        # lives on a GPU.
        scale = torch.where(scale > 1, scale, torch.ones_like(scale))
        return torch.tanh(mean / scale)

class PolicyNetwork(nn.Module):
    """Deterministic policy: two ReLU hidden layers, tanh-squashed output.

    ``forward`` returns ``tanh(mean_linear(relu(linear2(relu(linear1(state))))))``,
    so every action component lies in (-1, 1).

    Args:
        num_inputs: size of the last dimension of ``state``.
        num_actions: size of the action vector.
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, num_inputs, num_actions, hidden_dim):
        super().__init__()

        self.num_actions = num_actions

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        # Not read by forward(); it is still created, initialised by
        # weights_init_, and part of state_dict().
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        self.apply(weights_init_)

    def forward(self, state):
        hidden = F.relu(self.linear2(F.relu(self.linear1(state))))
        return torch.tanh(self.mean_linear(hidden))
