import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D


# Range [LOG_STD_MIN, LOG_STD_MAX] onto which StochasticActor maps its tanh-squashed log-std output.
LOG_STD_MIN, LOG_STD_MAX = -5, 2
# Small numerical-stability constant; none of the actor classes in this module read it.
epsilon = 1e-6


class DeterministicActor(nn.Module):
    """MLP policy mapping an observation to a single bounded action.

    Architecture: Linear -> ReLU, then ``hidden_layer`` x (Linear -> ReLU), then
    Linear -> Tanh. The tanh output is scaled by ``max_action``, so every action
    component lies in ``[-max_action, max_action]``.
    """

    def __init__(self, observation_dim, action_dim, hidden_dim=256, hidden_layer=1, max_action=1.):
        super(DeterministicActor, self).__init__()
        # Layers are created in the same order as a sequence of appends would,
        # so parameter initialisation and module indices are unchanged.
        layers = [nn.Linear(observation_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_layer):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
        layers += [nn.Linear(hidden_dim, action_dim), nn.Tanh()]
        self.net = nn.Sequential(*layers)

        self.max_action = max_action

    def forward(self, state):
        """Return ``max_action * tanh(MLP(state))``."""
        squashed = self.net(state)
        return self.max_action * squashed

    def act(self, state):
        """Run the policy and return the action(s) as a flat numpy array."""
        prediction = self.forward(state)
        return prediction.detach().cpu().numpy().flatten()

    def loss(self, state, action):
        """Mean-squared error between the given ``action`` and the predicted action."""
        return F.mse_loss(action, self(state))
        

class StochasticActor(nn.Module):
    """Diagonal-Gaussian MLP policy trained by maximum likelihood.

    The network emits ``2 * action_dim`` tanh-squashed values. The first half is
    the action mean, scaled by ``max_action``. The second half is mapped affinely
    from [-1, 1] onto [LOG_STD_MIN, LOG_STD_MAX] and exponentiated to give the
    per-dimension standard deviation.
    """

    def __init__(self, observation_dim, action_dim, hidden_dim=256, hidden_layer=1, max_action=1.):
        super(StochasticActor, self).__init__()

        self.net = nn.Sequential()
        self.net.append(nn.Linear(observation_dim, hidden_dim))
        self.net.append(nn.ReLU(inplace=True))
        for _ in range(hidden_layer):
            self.net.append(nn.Linear(hidden_dim, hidden_dim))
            self.net.append(nn.ReLU(inplace=True))
            self.net.append(nn.Dropout(p=0.1))
        self.net.append(nn.Linear(hidden_dim, action_dim * 2))
        self.net.append(nn.Tanh())

        self.max_action = max_action

        self._init_net()

    def _init_net(self):
        """Orthogonally initialise the weight of every Linear layer (biases keep their defaults)."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.orthogonal_(layer.weight)

    def forward(self, x):
        """Return ``(mu, std)`` of the Gaussian policy for observations ``x``.

        ``mu`` lies in [-max_action, max_action]. ``std`` lies in
        [exp(LOG_STD_MIN), exp(LOG_STD_MAX)].
        """
        out = self.net(x)
        mu, log_std = torch.chunk(out, 2, dim=-1)
        # Only the mean is an action and gets scaled. log_std must stay in the tanh
        # range [-1, 1] so the affine map below keeps it inside its bounds.
        mu = mu * self.max_action
        log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
        std = torch.exp(log_std)

        return mu, std

    @torch.no_grad()
    def act(self, state):
        """Return the mean action as a flat numpy array.

        Dropout is active unless the module has been put in ``eval()`` mode.
        """
        mu, _ = self.forward(state)
        return mu.detach().cpu().numpy().flatten()

    def log_prob(self, state, action):
        """Per-sample log-likelihood of ``action``, averaged (not summed) over action dims.

        Returns a tensor with the last dimension kept as size 1.
        """
        mu, std = self.forward(state)

        dist = D.Normal(mu, std)
        return dist.log_prob(action).mean(dim=-1, keepdim=True)

    def orthogonal_regularization(self, reg_coef=1e-5):
        """Penalty on the off-diagonal entries of ``W^T W`` for every Linear weight ``W``.

        Returns ``reg_coef`` times the sum of squared off-diagonal entries over all layers.
        """
        reg = 0
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                gram = torch.matmul(layer.weight.t(), layer.weight)
                off_diag = 1 - torch.eye(gram.shape[0], dtype=gram.dtype, device=gram.device)
                reg += torch.sum(torch.square(gram * off_diag))

        return reg * reg_coef

    def loss(self, state, action, coef=1e-4):
        """Negative mean log-likelihood plus orthogonal regularisation weighted by ``coef``."""
        reg = self.orthogonal_regularization(coef)
        return -self.log_prob(state, action).mean() + reg


class GMMActor(nn.Module):
    """MLP policy whose output is a Gaussian mixture over actions.

    Each of ``num_modes`` components is a diagonal Gaussian over ``action_dim``
    dimensions, with mean ``max_action * tanh(.)``. Mixture weights come from
    unnormalised logits.
    """

    def __init__(self,
                 observation_dim,
                 action_dim,
                 hidden_dim=256,
                 hidden_layer=1,
                 num_modes=5,
                 max_action=1.,
                 min_std=1e-4):
        super(GMMActor, self).__init__()

        self.observation_dim = observation_dim
        self.action_dim = action_dim

        self.net = nn.Sequential()
        self.net.append(nn.Linear(observation_dim, hidden_dim))
        self.net.append(nn.ReLU(inplace=True))
        for _ in range(hidden_layer):
            self.net.append(nn.Linear(hidden_dim, hidden_dim))
            self.net.append(nn.ReLU(inplace=True))

        self.decoder_mu = nn.Linear(hidden_dim, action_dim * num_modes, bias=True)
        self.decoder_scale = nn.Linear(hidden_dim, action_dim * num_modes, bias=True)
        self.decoder_logits = nn.Linear(hidden_dim, num_modes, bias=True)

        self.num_modes = num_modes
        self.max_action = max_action
        self.min_std = min_std

    def forward(self, state, training=False):
        """Build the action distribution for ``state``.

        Args:
            state: observation tensor whose last dim is ``observation_dim``.
                Leading dims are treated as batch dims. A 1-D state is treated
                as a batch of one.
            training: if True, component scales are ``softplus(.) + min_std``.
                Otherwise every component gets scale 1e-8, so a sample is
                (almost exactly) the mean of a randomly drawn mode.

        Returns:
            ``D.MixtureSameFamily`` whose ``batch_shape`` is the leading dims of
            the (possibly unsqueezed) state and whose ``event_shape`` is
            ``(action_dim,)``.
        """
        embed = self.net(state)
        if embed.dim() == 1:
            # Historical behaviour: an unbatched state becomes a batch of one.
            embed = embed.unsqueeze(0)
        mode_shape = (*embed.shape[:-1], self.num_modes, self.action_dim)

        # Scale the means so they can reach the full action range [-max_action, max_action].
        means = (self.max_action * torch.tanh(self.decoder_mu(embed))).reshape(mode_shape)
        logits = self.decoder_logits(embed)  # [..., num_modes]

        if training:
            # post-process the scale so it is strictly positive
            scales = F.softplus(self.decoder_scale(embed).reshape(mode_shape)) + self.min_std
        else:
            # low-noise for all Gaussian dists
            scales = torch.ones_like(means) * 1e-8

        # Independent(..., 1) moves the action dim into the event shape, leaving
        # batch_shape == (..., num_modes) as MixtureSameFamily expects.
        component_distribution = D.Independent(D.Normal(loc=means, scale=scales), 1)
        mixture_distribution = D.Categorical(logits=logits)

        return D.MixtureSameFamily(
            mixture_distribution=mixture_distribution,
            component_distribution=component_distribution,
        )

    @torch.no_grad()
    def act(self, state):
        """Sample an action from the low-noise mixture and return it as a flat numpy array."""
        dists = self.forward(state, training=False)
        return dists.sample().detach().cpu().numpy().flatten()

    def loss(self, state, action):
        """Negative mean log-likelihood of ``action`` under the training-mode mixture."""
        dists = self.forward(state, training=True)
        log_probs = dists.log_prob(action)
        return -log_probs.mean()
        

class RNNGMMActor(nn.Module):
    """LSTM-based Gaussian-mixture policy, structured largely after robomimic's RNN-GMM actor.

    Pipeline: optional MLP on observations -> LSTM -> linear decoders. Per
    timestep, the decoders produce ``num_modes`` mixture logits plus per-mode
    means and scales over the action.
    """

    def __init__(self, observation_dim,
                 action_dim,
                 mlp_layer_dims=(),
                 rnn_hidden_dim=400,
                 rnn_num_layers=2,
                 num_modes=5,
                 seq_len=10,
                 min_std=0.0001):
        super(RNNGMMActor, self).__init__()
        self.min_std = min_std
        self.act_dim = action_dim
        self.num_modes = num_modes
        self.seq_len = seq_len
        self.rnn_hidden_dim = rnn_hidden_dim
        self.rnn_num_layers = rnn_num_layers

        # Not populated anywhere in this class; kept so existing references to `self.nets` still work.
        self.nets = nn.ModuleDict()

        # Accept any sequence of ints. Normalising to a tuple avoids a TypeError
        # from `list + tuple` when the last layer has to be appended below.
        mlp_layer_dims = tuple(mlp_layer_dims)
        if len(mlp_layer_dims) > 0 and mlp_layer_dims[-1] != rnn_hidden_dim:
            print("\nWarning: the last layer of mlp should be the same as rnn_hidden_dim. Add automatically!\n")
            mlp_layer_dims = mlp_layer_dims + (rnn_hidden_dim,)

        # mlp after observation encoder (Identity alone when no MLP layers are requested)
        nets = [nn.Identity()]
        in_dim = observation_dim
        for out_dim in mlp_layer_dims:
            nets.append(nn.Linear(in_dim, out_dim))
            nets.append(nn.ReLU())
            in_dim = out_dim
        self.mlp = nn.Sequential(*nets)

        # RNN: its input size is whatever the MLP emits (observation_dim if no MLP)
        self.rnn_network = nn.LSTM(
                input_size=in_dim,
                hidden_size=rnn_hidden_dim,
                num_layers=rnn_num_layers,
                batch_first=True,
                bidirectional=False)

        # decoder
        self.decoder_mu = nn.Linear(rnn_hidden_dim, action_dim * num_modes, bias=True)
        self.decoder_scale = nn.Linear(rnn_hidden_dim, action_dim * num_modes, bias=True)
        self.decoder_logits = nn.Linear(rnn_hidden_dim, num_modes, bias=True)

    def get_rnn_init_state(self, batch_size, device):
        """Return a zero ``(h_0, c_0)`` pair, each shaped ``[rnn_num_layers, batch_size, rnn_hidden_dim]``."""
        h_0 = torch.zeros(self.rnn_num_layers, batch_size, self.rnn_hidden_dim, device=device)
        c_0 = torch.zeros(self.rnn_num_layers, batch_size, self.rnn_hidden_dim, device=device)
        return (h_0, c_0)

    def forward(self, state, training=True, rnn_hidden_state=None, return_state=False):
        """
        Forward pass through the network.
        state: [B, T, D]
        training (bool): if True, the wide dist. is used for exploration, otherwise the narrow dist. is used for testing
        rnn_hidden_state: optional (h, c) LSTM state; zeros are used when None
        return_state (bool): if True, also return the new (h, c) LSTM state

        return: MixtureSameFamily with batch_shape [B, T] and event_shape [action_dim]
        (plus the new LSTM state when return_state is True)
        """
        B, T, _ = state.shape
        rnn_input = self.mlp(state)

        if rnn_hidden_state is None:
            rnn_hidden_state = self.get_rnn_init_state(B, state.device)
        rnn_outputs, rnn_new_state = self.rnn_network(rnn_input, rnn_hidden_state)  # [B, T, H], (h, c)

        means = self.decoder_mu(rnn_outputs).reshape(B, T, self.num_modes, self.act_dim)  # [B, T, M, A]
        scales = self.decoder_scale(rnn_outputs).reshape(B, T, self.num_modes, self.act_dim)  # [B, T, M, A]
        logits = self.decoder_logits(rnn_outputs).reshape(B, T, self.num_modes)  # [B, T, M]

        # apply tanh squashing to the means so they lie in [-1, 1]
        means = torch.tanh(means)

        if not training:
            # low-noise for all Gaussian dists
            scales = torch.ones_like(means) * 1e-8
        else:
            # post-process the scale so it is strictly positive
            scales = F.softplus(scales) + self.min_std

        # mixture components - make sure that `batch_shape` for the distribution is equal
        # to (batch_size, timesteps, num_modes) since MixtureSameFamily expects this shape
        component_distribution = D.Normal(loc=means, scale=scales)
        component_distribution = D.Independent(component_distribution, 1)  # shift action dim to event shape

        # unnormalized logits to categorical distribution for mixing the modes
        mixture_distribution = D.Categorical(logits=logits)

        dists = D.MixtureSameFamily(
            mixture_distribution=mixture_distribution,
            component_distribution=component_distribution,
        )

        if return_state:
            return dists, rnn_new_state
        return dists

    @torch.no_grad()
    def act(self, state, rnn_hidden_state=None, return_state=False):
        '''
        Sample an action for the last timestep of ``state`` from the low-noise mixture.

        state: [B, T, D]   shape as (1, T, state_dim), with T <= seq_len
        rnn_hidden_state: (h_lstm, c_lstm)

        return:
        action: np.ndarray, (action_dim, ) when B == 1 (the last-step actions of all B sequences, flattened)
        rnn_new_state: optional tuple(h_lstm, c_lstm)
        '''
        assert len(state.shape) == 3  # [B, T, state_dim] (1, T, state_dim)
        assert state.shape[1] <= self.seq_len

        outputs = self.forward(state, training=False, rnn_hidden_state=rnn_hidden_state, return_state=return_state)

        if return_state:
            dists, rnn_new_state = outputs
        else:
            dists = outputs

        action = dists.sample()[:, -1, :].detach().cpu().numpy().flatten()
        if return_state:
            return action, rnn_new_state
        return action

    def loss(self, state, action):
        """Negative mean log-likelihood of ``action`` ([B, T, A]) given ``state`` ([B, T, D])."""
        dists = self.forward(state, training=True, rnn_hidden_state=None, return_state=False)
        assert len(dists.batch_shape) == 2  # [B, T]
        log_probs = dists.log_prob(action)
        return -log_probs.mean()
