import numpy as np
import numpy.random as rd
import torch
import torch.nn as nn


# Short alias for `torch.Tensor`, used in the type annotations of this module.
Tensor = torch.Tensor

'''DQN (Q network)'''


class QNet(nn.Module):  # `nn.Module` is a PyTorch module for neural network
    """
    Critic class for **Q-network** (a plain MLP that outputs one Q value per discrete action).

    :param mid_dim[int]: the middle dimension of networks
    :param num_layer[int]: the number of MLP network layer
    :param state_dim[int]: the dimension of state (the number of state vector)
    :param action_dim[int]: the dimension of action (the number of discrete action)
    """

    def __init__(self, mid_dim: int, num_layer: int, state_dim: int, action_dim: int):
        super().__init__()
        self.net = build_mlp(mid_dim, num_layer, input_dim=state_dim, output_dim=action_dim)

        # float ∈ [0, 1]. It starts as `None` and must be assigned before `get_action` is called,
        # because `get_action` compares it with a random number.
        self.explore_rate = None
        self.action_dim = action_dim

    def forward(self, state: Tensor) -> Tensor:
        """
        The forward function for **Q-network**.

        :param state: [tensor] the input state. state.shape == (batch_size, state_dim)
        :return: Q values for multiple actions [tensor]. q_values.shape == (batch_size, action_dim)
        """
        return self.net(state)

    def get_action(self, state: Tensor) -> Tensor:  # return [int], which is the index of discrete action
        """
        return the action for exploration with the epsilon-greedy.

        A single random number is drawn per call, so the whole batch is either greedy or random.

        :param state: [tensor] the input state. state.shape == (batch_size, state_dim)
        :return: action [tensor.int]. action.shape == (batch_size, 1), on the same device as `state`
        """
        if self.explore_rate < rd.rand():
            # exploit: pick the action with the highest Q value
            action = self.net(state).argmax(dim=1, keepdim=True)
        else:
            # explore: uniform random action indices in [0, action_dim).
            # `device=state.device` keeps both branches on the same device; without it the
            # random branch would always return a CPU tensor.
            action = torch.randint(self.action_dim, size=(state.shape[0], 1), device=state.device)
        return action





'''Actor (policy network)'''


class Actor(nn.Module):
    """
    Policy network: an MLP whose output is squashed by `tanh` to produce an action.

    :param mid_dim[int]: the middle dimension of networks
    :param num_layer[int]: the number of MLP network layer
    :param state_dim[int]: the dimension of state (input size of the MLP)
    :param action_dim[int]: the dimension of action (output size of the MLP)
    """

    def __init__(self, mid_dim: int, num_layer: int, state_dim: int, action_dim: int):
        super().__init__()
        self.net = build_mlp(mid_dim, num_layer, input_dim=state_dim, output_dim=action_dim)
        # standard deviation of the Gaussian noise added by `get_action`
        self.explore_noise_std = 0.1
        # constant term log(sqrt(2 * pi)) used by the log-probability methods
        self.log_sqrt_2pi = np.log(np.sqrt(2 * np.pi))

    def forward(self, state: Tensor) -> Tensor:
        """
        Return the noise-free action.

        :param state: [tensor] the input state
        :return: [tensor] `tanh` of the network output, each element in (-1, 1)
        """
        squashed = torch.tanh(self.net(state))
        return squashed

    def get_action(self, state: Tensor) -> Tensor:  # for exploration
        """
        Return a noisy action using the instance's own `explore_noise_std`.

        :param state: [tensor] the input state
        :return: [tensor] noisy action clamped to [-1, 1]
        """
        return self.get_action_noise(state, self.explore_noise_std)

    def get_action_noise(self, state: Tensor, action_std: float) -> Tensor:
        """
        Return a noisy action using a caller-supplied noise scale.

        :param state: [tensor] the input state
        :param action_std: [float] the standard deviation of the added Gaussian noise
        :return: [tensor] noisy action clamped to [-1, 1]
        """
        clean_action = torch.tanh(self.net(state))
        # the noise itself is limited to [-0.5, 0.5] before being added
        perturbation = torch.clamp(torch.randn_like(clean_action) * action_std, -0.5, 0.5)
        return torch.clamp(clean_action + perturbation, -1.0, 1.0)

    def get_logprob(self, state: Tensor, action: Tensor) -> Tensor:
        """
        Element-wise Gaussian log-density of `action`, with std `explore_noise_std`.

        The mean is the raw network output (no `tanh` is applied here).

        :param state: [tensor] the input state
        :param action: [tensor] the action to evaluate, same shape as the network output
        :return: [tensor] log-density per action element (not summed over the action dimension)
        """
        mean = self.net(state)
        std = torch.ones_like(mean) * self.explore_noise_std
        standardized = (mean - action) / std
        return -(std.log() + self.log_sqrt_2pi + standardized.pow(2) * 0.5)

    def get_logprob_fixed(self, state: Tensor, action: Tensor) -> Tensor:
        """
        Variant of the log-probability that works on the `tanh`-squashed mean.

        NOTE(review): unlike `get_logprob`, the leading terms are summed without a negation and
        the squared error is not divided by the std. Confirm the sign convention expected by
        the callers before relying on this value.

        :param state: [tensor] the input state
        :param action: [tensor] the action to evaluate, same shape as the network output
        :return: [tensor] per-element value, including a term based on the derivative of `tanh`
        """
        squashed_mean = torch.tanh(self.net(state))  # the network output itself has no tanh
        log_std = np.log(self.explore_noise_std)  # `explore_noise_std` is a plain float here

        gaussian_part = log_std + self.log_sqrt_2pi + (squashed_mean - action).pow(2) * 0.5
        # log(1 - tanh^2) term; the extra 1e-6 keeps the argument of log positive
        tanh_part = torch.log(1.000001 - squashed_mean.pow(2))
        return gaussian_part + tanh_part




class ActorPPO(nn.Module):
    """
    Stochastic policy network for PPO: an MLP gives the action mean, and a trainable
    state-independent parameter gives the log of the action standard deviation.

    :param mid_dim[int]: the middle dimension of networks
    :param num_layer[int]: the number of MLP network layer
    :param state_dim[int]: the dimension of state (input size of the MLP)
    :param action_dim[int]: the dimension of action (output size of the MLP)
    """

    def __init__(self, mid_dim: int, num_layer: int, state_dim: int, action_dim: int):
        super().__init__()
        self.net = build_mlp(mid_dim, num_layer, input_dim=state_dim, output_dim=action_dim)

        # trainable log-std of the action, shape (1, action_dim), initialised to -0.5
        initial_log_std = torch.zeros((1, action_dim)) - 0.5
        self.action_std_log = nn.Parameter(initial_log_std, requires_grad=True)
        # constant term log(sqrt(2 * pi)) of the Gaussian log-density
        self.log_sqrt_2pi = np.log(np.sqrt(2 * np.pi))

    def forward(self, state: Tensor) -> Tensor:
        """
        Return the noise-free action squashed by `tanh`.

        :param state: [tensor] the input state
        :return: [tensor] `tanh` of the network output
        """
        return torch.tanh(self.net(state))

    def get_action(self, state: Tensor) -> (Tensor, Tensor):
        """
        Sample an action from the Gaussian policy.

        :param state: [tensor] the input state
        :return: (action, noise). `noise` is the standard-normal sample and
                 `action = mean + noise * std`; no `tanh` is applied to `action` here.
        """
        mean = self.net(state)
        std = self.action_std_log.exp()

        unit_noise = torch.randn_like(mean)
        sampled = mean + unit_noise * std
        return sampled, unit_noise

    def get_logprob(self, state: Tensor, action: Tensor) -> Tensor:
        """
        Element-wise Gaussian log-density of `action` under the current policy.

        :param state: [tensor] the input state
        :param action: [tensor] the action to evaluate, same shape as the network output
        :return: [tensor] log-density per action element (not summed over the action dimension)
        """
        mean = self.net(state)
        std = self.action_std_log.exp()

        half_sq = ((mean - action) / std).pow(2) * 0.5
        return -(self.action_std_log + self.log_sqrt_2pi + half_sq)

    def get_logprob_entropy(self, state: Tensor, action: Tensor) -> (Tensor, Tensor):
        """
        Log-probability of `action` summed over dim 1, plus a scalar entropy-related term.

        NOTE(review): the second value is `mean(exp(logprob) * logprob)`, which has the opposite
        sign of an entropy estimate. Confirm how the caller combines it with the loss.

        :param state: [tensor] the input state
        :param action: [tensor] the action to evaluate, same shape as the network output
        :return: (logprob, dist_entropy). `logprob` is summed over dim 1; `dist_entropy` is a scalar.
        """
        mean = self.net(state)
        std = self.action_std_log.exp()

        half_sq = ((mean - action) / std).pow(2) * 0.5
        summed = (self.action_std_log + self.log_sqrt_2pi + half_sq).sum(1)
        logprob = -summed  # new_logprob

        dist_entropy = (logprob.exp() * logprob).mean()
        return logprob, dist_entropy

    def get_old_logprob(self, _action: Tensor, noise: Tensor) -> Tensor:
        """
        Log-probability computed directly from the standard-normal `noise` of a sampled action.

        :param _action: [tensor] unused; only `noise` enters the computation
        :param noise: [tensor] the standard-normal sample, as returned by `get_action`
        :return: [tensor] log-probability summed over dim 1
        """
        half_sq = noise.pow(2) * 0.5
        summed = (self.action_std_log + self.log_sqrt_2pi + half_sq).sum(1)
        return -summed  # old_logprob

    @staticmethod
    def convert_action_for_env(action: Tensor) -> Tensor:
        """
        Squash a raw sampled action with `tanh`.

        :param action: [tensor] the raw action
        :return: [tensor] `tanh(action)`, each element in (-1, 1)
        """
        return torch.tanh(action)





class Critic(nn.Module):
    """
    Critic network that scores a (state, action) pair with a single value.

    :param mid_dim[int]: the middle dimension of networks
    :param num_layer[int]: the number of MLP network layer
    :param state_dim[int]: the dimension of state
    :param action_dim[int]: the dimension of action
    """

    def __init__(self, mid_dim: int, num_layer: int, state_dim: int, action_dim: int):
        super().__init__()
        # the MLP consumes state and action concatenated side by side
        joint_dim = state_dim + action_dim
        self.net = build_mlp(mid_dim, num_layer, input_dim=joint_dim, output_dim=1)

    def forward(self, state: Tensor, action: Tensor) -> Tensor:
        """
        Evaluate the network on the concatenation of `state` and `action` along dim 1.

        :param state: [tensor] the input state
        :param action: [tensor] the input action
        :return: [tensor] q value, one output per row
        """
        state_action = torch.cat([state, action], dim=1)
        q_value = self.net(state_action)
        return q_value


class CriticPPO(nn.Module):
    """
    Critic network for PPO that maps a state to a single value.

    :param mid_dim[int]: the middle dimension of networks
    :param num_layer[int]: the number of MLP network layer
    :param state_dim[int]: the dimension of state
    :param _action_dim[int]: unused; the network input is the state only
    """

    def __init__(self, mid_dim: int, num_layer: int, state_dim: int, _action_dim: int):
        super().__init__()
        self.net = build_mlp(mid_dim, num_layer, input_dim=state_dim, output_dim=1)

    def forward(self, state: Tensor) -> Tensor:
        """
        Evaluate the network on `state`.

        :param state: [tensor] the input state
        :return: [tensor] the network output, one value per row
        """
        value = self.net(state)
        return value





def build_mlp(mid_dim: int, num_layer: int, input_dim: int, output_dim: int):  # MLP (MultiLayer Perceptron)
    """
    Build a fully connected network with ReLU between consecutive linear layers.

    With `num_layer == 1` the result is a single linear layer from `input_dim` to `output_dim`.
    Otherwise every hidden layer has width `mid_dim`. No activation follows the last layer.

    :param mid_dim: [int] the width of the hidden layers
    :param num_layer: [int] the number of linear layers, must be >= 1
    :param input_dim: [int] the size of the network input
    :param output_dim: [int] the size of the network output
    :return: [nn.Sequential] the assembled network
    """
    assert num_layer >= 1

    # sizes at the boundaries of the linear layers: input, (num_layer - 1) hidden widths, output
    widths = [input_dim] + [mid_dim] * (num_layer - 1) + [output_dim]

    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(nn.Linear(fan_in, fan_out))
        layers.append(nn.ReLU())
    layers.pop()  # drop the activation that would follow the output layer
    return nn.Sequential(*layers)
