import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

def nonzero_averaging(x):
    """
        average the node vectors of every batch element while ignoring
        all-zero vectors (an all-zero vector stands for a deleted node, so it
        must not be counted in the denominator)
    :param x: feature vectors with shape [sz_b, node_num, d]
    :return:  mean over the non-zero nodes with shape [sz_b, d]; a batch
              element without any non-zero node yields a zero vector
    """
    # Zero rows contribute nothing to the sum, so only the divisor needs care.
    feature_sum = x.sum(dim=-2)

    # A node counts as present when at least one of its features is non-zero.
    node_present = torch.count_nonzero(x, dim=-1) != 0
    present_count = node_present.sum(dim=-1, keepdim=True)

    # 1 / 0 evaluates to inf for batch elements without any present node;
    # overwrite those entries with 0 so the product below is 0 instead of nan.
    inv_count = 1 / present_count
    inv_count.masked_fill_(present_count == 0, 0)

    return inv_count * feature_sum

class MLP(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, dropout=0, use_layer_norm=False, new_init=False,
                 activation=nn.ReLU, activation_last_layer=False):
        """
            the implementation of multi layer perceptrons (refer to L2D)

            The layers are stored in ``self.net`` (an ``nn.Sequential``) as
            ``num_layers - 1`` hidden blocks of
            [Linear, activation, LayerNorm (optional), Dropout (optional)]
            followed by the last layer.
        :param num_layers: number of layers in the neural networks (EXCLUDING the input layer).
                            If num_layers=1, this reduces to linear model.
        :param input_dim: dimensionality of input features
        :param hidden_dim: dimensionality of hidden units at ALL layers
        :param output_dim:  number of classes for prediction
        :param dropout: dropout probability; a Dropout layer is appended to each hidden block only when > 0
        :param use_layer_norm: if True, a LayerNorm is placed after the activation of each hidden block
        :param new_init: accepted for backward compatibility; currently has no effect
        :param activation: zero-argument callable returning the activation module (default nn.ReLU)
        :param activation_last_layer: if True, the last layer is Linear(., hidden_dim) followed by
                            the activation, so the output width is hidden_dim instead of output_dim
        :raises ValueError: if num_layers < 1
        """

        super(MLP, self).__init__()

        # Same contract as Critic: reject a non-positive depth instead of
        # silently building a network of a different depth.
        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        # This flag is never updated in this class; it is kept so that
        # existing reads of the attribute keep working.
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.layers = []
        self.use_layer_norm = use_layer_norm

        # Hidden blocks. The first one maps input_dim -> hidden_dim, every
        # later one hidden_dim -> hidden_dim. The creation order is the same
        # as before, so state_dict keys and seeded initial weights are unchanged.
        in_features = input_dim
        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(in_features, hidden_dim))
            self.layers.append(activation())
            if use_layer_norm:
                self.layers.append(nn.LayerNorm(hidden_dim, elementwise_affine=True))
            if dropout > 0:
                self.layers.append(nn.Dropout(dropout))
            in_features = hidden_dim

        # Last layer: either a plain projection to output_dim, or one more
        # activated hidden-width layer for callers that attach their own heads.
        if activation_last_layer:
            self.layers.append(nn.Linear(in_features, hidden_dim))
            self.layers.append(activation())
        else:
            self.layers.append(nn.Linear(in_features, output_dim))

        self.net = nn.Sequential(*self.layers)
        # NOTE: the custom initialisation that new_init used to toggle
        # (uniform(-3e-3, 3e-3) on the weight and bias of the last layer) is
        # disabled, so PyTorch's default initialisation is always used.

    def forward(self, x):
        """
            run the input through the stacked layers
        :param x: input tensor whose last dimension is input_dim
        :return: output of self.net applied to x
        """
        return self.net(x)


class Actor(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, dropout=0, use_layer_norm=False, new_init=False,
                 activation=nn.ReLU):
        """
            the implementation of the Actor network as a multi layer perceptron (refer to L2D)

            The layers are stored in ``self.net`` (an ``nn.Sequential``) as
            ``num_layers - 1`` hidden blocks of [Linear, activation, Dropout (optional)]
            followed by a final Linear(., output_dim).
        :param num_layers: number of layers in the neural networks (EXCLUDING the input layer).
                            If num_layers=1, this reduces to linear model.
        :param input_dim: dimensionality of input features
        :param hidden_dim: dimensionality of hidden units at ALL layers
        :param output_dim:  number of classes for prediction
        :param dropout: dropout probability; a Dropout layer is appended to each hidden block only when > 0
        :param use_layer_norm: stored on the instance; no LayerNorm is built by this class
        :param new_init: accepted for backward compatibility; currently has no effect
        :param activation: zero-argument callable returning the activation module (default nn.ReLU)
        :raises ValueError: if num_layers < 1
        """

        super(Actor, self).__init__()

        # Reject a non-positive depth instead of silently building a
        # network of a different depth.
        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        # This flag is never updated in this class; it is kept so that
        # existing reads of the attribute keep working.
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.layers = []
        self.use_layer_norm = use_layer_norm

        # Hidden blocks. The first one maps input_dim -> hidden_dim, every
        # later one hidden_dim -> hidden_dim. The creation order is the same
        # as before, so state_dict keys and seeded initial weights are unchanged.
        in_features = input_dim
        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(in_features, hidden_dim))
            self.layers.append(activation())
            if dropout > 0:
                self.layers.append(nn.Dropout(dropout))
            in_features = hidden_dim

        # Output projection (the only layer when num_layers == 1).
        self.layers.append(nn.Linear(in_features, output_dim))
        self.net = nn.Sequential(*self.layers)

    def forward(self, x):
        """
            run the input through the stacked layers
        :param x: input tensor whose last dimension is input_dim
        :return: output of self.net applied to x
        """
        return self.net(x)

# class Actor(nn.Module):
#     def __init__(self, num_layers, input_dim, hidden_dim, output_dim, dropout=0):
#         """
#             the implementation of Actor network (refer to L2D)
#         :param num_layers: number of layers in the neural networks (EXCLUDING the input layer).
#                             If num_layers=1, this reduces to linear model.
#         :param input_dim: dimensionality of input features
#         :param hidden_dim: dimensionality of hidden units at ALL layers
#         :param output_dim:  number of classes for prediction
#         """
#         super(Actor, self).__init__()
#
#         self.linear_or_not = True  # default is linear model
#         self.num_layers = num_layers
#
#         # self.activative = torch.tanh
#         self.activative = torch.relu
#         self.dropout = nn.Dropout(dropout)
#
#         if num_layers < 1:
#             raise ValueError("number of layers should be positive!")
#         elif num_layers == 1:
#             # Linear model
#             self.linear = nn.Linear(input_dim, output_dim)
#         else:
#             # Multi-layer model
#             self.linear_or_not = False
#             self.linears = torch.nn.ModuleList()
#
#             self.linears.append(nn.Linear(input_dim, hidden_dim))
#             for layer in range(num_layers - 2):
#                 self.linears.append(nn.Linear(hidden_dim, hidden_dim))
#             self.linears.append(nn.Linear(hidden_dim, output_dim))
#
#     def forward(self, x):
#         if self.linear_or_not:
#             # If linear model
#             return self.linear(x)
#         else:
#             # If MLP
#             h = x
#             for layer in range(self.num_layers - 1):
#                 h = self.dropout(self.activative((self.linears[layer](h))))
#             return self.linears[self.num_layers - 1](h)




class QRDQN(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, dropout=0):
        """
            container holding two identically configured MLPs, q1 and q2 (refer to L2D)

            NOTE(review): no forward() is defined for this class in the visible source.
            :param num_layers: number of layers of each MLP (EXCLUDING the input layer).
                                    If num_layers=1, each MLP reduces to a linear model.
            :param input_dim: dimensionality of input features
            :param hidden_dim: dimensionality of hidden units at ALL layers
            :param output_dim: output width of each MLP
            :param dropout: dropout probability forwarded to both MLPs
        """
        super().__init__()
        mlp_dims = (num_layers, input_dim, hidden_dim, output_dim)
        # Two separate constructions (q1 first), so the heads do not share parameters.
        self.q1 = MLP(*mlp_dims, dropout=dropout)
        self.q2 = MLP(*mlp_dims, dropout=dropout)

class QRDQN_single(nn.Module):
    def __init__(self, num_layers, input_dim_q, hidden_dim, output_dim, dropout=0, use_layer_norm=True):
        """
            container holding a single MLP, q, whose output width is the number of quantiles
            (refer to L2D)

            NOTE(review): no forward() is defined for this class in the visible source.
            :param num_layers: number of layers of the MLP (EXCLUDING the input layer).
                                    If num_layers=1, the MLP reduces to a linear model.
            :param input_dim_q: dimensionality of input features of q
            :param hidden_dim: dimensionality of hidden units at ALL layers
            :param output_dim: output width of q; also stored as num_quantiles
            :param dropout: dropout probability forwarded to the MLP
            :param use_layer_norm: forwarded to the MLP (enabled by default here)
        """
        super().__init__()
        # output_dim is exposed under both names.
        self.output_dim = output_dim
        self.num_quantiles = output_dim
        self.q = MLP(
            num_layers,
            input_dim_q,
            hidden_dim,
            output_dim,
            dropout=dropout,
            use_layer_norm=use_layer_norm,
            new_init=True,
        )


class IQN(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, num_cos=64, dropout=0):
        """
            container holding one CosineEmbeddingNetwork (cos_net1) and two identically
            configured MLPs, q1 and q2 (refer to L2D)

            NOTE(review): no forward() is defined for this class in the visible source.
            :param num_layers: number of layers of each MLP (EXCLUDING the input layer).
                                    If num_layers=1, each MLP reduces to a linear model.
            :param input_dim: dimensionality of input features; also passed to
                                    CosineEmbeddingNetwork as its emb_dim
            :param hidden_dim: dimensionality of hidden units at ALL layers; also stored as emb_dim
            :param output_dim: output width of each MLP
            :param num_cos: first argument of CosineEmbeddingNetwork
            :param dropout: dropout probability forwarded to both MLPs
        """
        super().__init__()
        self.num_cos, self.emb_dim = num_cos, hidden_dim

        # Sub-modules are created in the order cos_net1, q1, q2.
        self.cos_net1 = CosineEmbeddingNetwork(num_cos, input_dim)
        mlp_dims = (num_layers, input_dim, hidden_dim, output_dim)
        self.q1 = MLP(*mlp_dims, dropout=dropout)
        self.q2 = MLP(*mlp_dims, dropout=dropout)


# class QRDQN_advantage(nn.Module):
#     def __init__(self, num_layers, input_dim_q, input_v, hidden_dim, output_dim, dropout=0):
#         """
#             the implementation of Critic network (refer to L2D)
#             :param num_layers: number of layers in the neural networks (EXCLUDING the input layer).
#                                     If num_layers=1, this reduces to linear model.
#             :param input_dim: dimensionality of input features
#             :param hidden_dim: dimensionality of hidden units at ALL layers
#             :param output_dim:  number of classes for prediction
#         """
#         super(QRDQN_advantage, self).__init__()
#         # self.q1 = MLP(num_layers, input_dim_q, hidden_dim, output_dim, dropout=dropout)
#         self.v1 = MLP(num_layers, input_v, hidden_dim, output_dim * 50, dropout=dropout)
#         # self.q2 = MLP(num_layers, input_dim_q, hidden_dim, output_dim, dropout=dropout)
#         self.v2 = MLP(num_layers, input_v, hidden_dim, output_dim * 50, dropout=dropout)
#         self.output_dim = output_dim
#
#     def forward(self, x, x_global, mask, sz_b):
#         # print(x.shape, x_global.shape)
#         # print(self.q1)
#         # print(self.v1)
#         # exit()
#         # q1 = self.q1(x)
#         v1 = self.v1(x_global)
#         # q2 = self.q2(x)
#         v2 = self.v2(x_global)
#         v1 = v1.view(-1, 50, self.output_dim)
#         v2 = v2.view(-1, 50, self.output_dim)
#
#
#         #
#         # q1 = q1 - v1
#         # q2 = q2 - v2
#
#
#         return v1, v2

class QRDQN_advantage(nn.Module):
    def __init__(self, num_layers, input_dim_q, input_v, hidden_dim, output_dim, dropout=0):
        """
            container holding two (q, v) pairs of MLPs: q1/v1 and q2/v2 (refer to L2D)

            NOTE(review): no forward() is defined for this class in the visible source.
            :param num_layers: number of layers of each MLP (EXCLUDING the input layer).
                                    If num_layers=1, each MLP reduces to a linear model.
            :param input_dim_q: dimensionality of input features of the q MLPs
            :param input_v: dimensionality of input features of the v MLPs
            :param hidden_dim: dimensionality of hidden units at ALL layers
            :param output_dim: output width of every MLP; also stored as num_quantiles
            :param dropout: dropout probability forwarded to every MLP
        """
        super().__init__()

        def build(in_dim):
            # All four heads differ only in their input width.
            return MLP(num_layers, in_dim, hidden_dim, output_dim, dropout=dropout)

        # Creation order q1, v1, q2, v2 is kept as is.
        self.q1 = build(input_dim_q)
        self.v1 = build(input_v)
        self.q2 = build(input_dim_q)
        self.v2 = build(input_v)

        self.output_dim = output_dim
        self.num_quantiles = output_dim

# class QRDQN_advantage_single(nn.Module):
#     def __init__(self, num_layers, input_dim_q, input_v, hidden_dim, output_dim, dropout=0, use_layer_norm=True):
#         """
#             the implementation of Critic network (refer to L2D)
#             :param num_layers: number of layers in the neural networks (EXCLUDING the input layer).
#                                     If num_layers=1, this reduces to linear model.
#             :param input_dim: dimensionality of input features
#             :param hidden_dim: dimensionality of hidden units at ALL layers
#             :param output_dim:  number of classes for prediction
#         """
#         super(QRDQN_advantage_single, self).__init__()
#         # self.q = MLP(num_layers, input_dim_q, hidden_dim, output_dim, dropout=dropout, use_layer_norm=use_layer_norm,
#         #              new_init=True)
#         # self.v = MLP(num_layers, input_v, hidden_dim, output_dim, dropout=dropout, use_layer_norm=use_layer_norm,
#         #              new_init=True)
#         self.hidden_dim = hidden_dim
#         self.shared = MLP(num_layers - 1, input_dim_q, hidden_dim, output_dim, dropout=dropout,
#                           use_layer_norm=use_layer_norm, new_init=True, activation_last_layer=True)
#         print("shared", self.shared)
#         self.value_head = nn.Linear(hidden_dim, output_dim)
#         self.adv_head = nn.Linear(hidden_dim, output_dim)
#
#         self.output_dim = output_dim
#         self.num_quantiles = output_dim
#
#     def forward(self, x, x_global, mask, sz_b):
#         out = self.shared(x)
#         # out[mask.reshape(sz_b, -1, 1).expand(-1, -1, self.hidden_dim)] = 0
#         out_masked = out.masked_fill(mask.reshape(sz_b, -1, 1).expand(-1, -1, self.hidden_dim), 0)
#         out_mean = nonzero_averaging(out_masked)
#
#         allow_actions = (~mask).sum(dim=-1).sum(dim=-1).view(-1, 1, 1)
#         # out_mean = out.sum(dim=1, keepdim=True) / allow_actions
#         v_out = self.value_head(out_mean).unsqueeze(1)
#         adv_out = self.adv_head(out)
#         adv_out[mask.reshape(sz_b, -1, 1).expand(-1, -1, self.hidden_dim)] = 0
#         adv_mean = adv_out.sum(dim=1, keepdim=True) / allow_actions
#         # adv_mean = nonzero_averaging(adv_out).unsqueeze(1)
#
#         q = v_out + (adv_out - adv_mean)
#
#         q[mask.reshape(sz_b, -1, 1).expand(-1, -1, self.num_quantiles)] = float('-inf')
#
#
#         return q


class QRDQN_advantage_single(nn.Module):
    def __init__(self, num_layers, input_dim_q, input_v, hidden_dim, output_dim, dropout=0, use_layer_norm=True):
        """
            container holding one (q, v) pair of MLPs (refer to L2D)

            NOTE(review): no forward() is defined for this class in the visible source.
            :param num_layers: number of layers of each MLP (EXCLUDING the input layer).
                                    If num_layers=1, each MLP reduces to a linear model.
            :param input_dim_q: dimensionality of input features of q
            :param input_v: dimensionality of input features of v
            :param hidden_dim: dimensionality of hidden units at ALL layers
            :param output_dim: output width of both MLPs; also stored as num_quantiles
            :param dropout: dropout probability forwarded to both MLPs
            :param use_layer_norm: forwarded to both MLPs (enabled by default here)
        """
        super().__init__()
        # Settings shared by q and v; only the input width differs.
        mlp_options = dict(dropout=dropout, use_layer_norm=use_layer_norm, new_init=True)
        self.q = MLP(num_layers, input_dim_q, hidden_dim, output_dim, **mlp_options)
        self.v = MLP(num_layers, input_v, hidden_dim, output_dim, **mlp_options)

        self.output_dim = output_dim
        self.num_quantiles = output_dim

class DQN_advantage(nn.Module):
    def __init__(self, num_layers, input_dim_q, input_v, hidden_dim, output_dim, dropout=0):
        """
            container holding two (q, v) pairs of MLPs, q1/v1 and q2/v2, each with a single
            output unit (refer to L2D)

            NOTE(review): no forward() is defined for this class in the visible source.
            :param num_layers: number of layers of each MLP (EXCLUDING the input layer).
                                    If num_layers=1, each MLP reduces to a linear model.
            :param input_dim_q: dimensionality of input features of the q MLPs
            :param input_v: dimensionality of input features of the v MLPs
            :param hidden_dim: dimensionality of hidden units at ALL layers
            :param output_dim: only stored (as output_dim and num_quantiles); the MLPs themselves
                                    always have output width 1
            :param dropout: dropout probability forwarded to every MLP
        """
        super().__init__()

        def scalar_head(in_dim):
            # The output width is fixed to 1 regardless of output_dim.
            return MLP(num_layers, in_dim, hidden_dim, 1, dropout=dropout)

        # Creation order q1, v1, q2, v2 is kept as is.
        self.q1 = scalar_head(input_dim_q)
        self.v1 = scalar_head(input_v)
        self.q2 = scalar_head(input_dim_q)
        self.v2 = scalar_head(input_v)

        self.output_dim = output_dim
        self.num_quantiles = output_dim

# class QRDQN_advantage(nn.Module):
#     def __init__(self, num_layers, input_dim_q, input_v, hidden_dim, output_dim, dropout=0):
#         """
#             the implementation of Critic network (refer to L2D)
#             :param num_layers: number of layers in the neural networks (EXCLUDING the input layer).
#                                     If num_layers=1, this reduces to linear model.
#             :param input_dim: dimensionality of input features
#             :param hidden_dim: dimensionality of hidden units at ALL layers
#             :param output_dim:  number of classes for prediction
#         """
#         super(QRDQN_advantage, self).__init__()
#         self.q1 = MLP(num_layers, input_dim_q, hidden_dim, output_dim, dropout=dropout)
#         self.v1 = MLP(num_layers, input_v, hidden_dim, output_dim, dropout=dropout)
#         self.q2 = MLP(num_layers, input_dim_q, hidden_dim, output_dim, dropout=dropout)
#         self.v2 = MLP(num_layers, input_v, hidden_dim, output_dim, dropout=dropout)
#         self.output_dim = output_dim
#
#     def forward(self, x, x_global, mask, sz_b):
#         # print(x.shape, x_global.shape)
#         # print(self.q1)
#         # print(self.v1)
#         # exit()
#         q1 = self.q1(x)
#         v1 = self.v1(x_global).unsqueeze(1)
#         q2 = self.q2(x)
#         v2 = self.v2(x_global).unsqueeze(1)
#         # print(q1.shape, v1.shape, q2.shape, v2.shape)
#         # print((q1 - v1).shape, (q2 - v2).shape)
#         # exit()
#         # q1_masked = q1.clone()
#         # q2_masked = q2.clone()
#         # q1_masked[mask.reshape(sz_b, -1, 1).expand(-1, -1, self.output_dim)] = 0
#         # q2_masked[mask.reshape(sz_b, -1, 1).expand(-1, -1, self.output_dim)] = 0
#         # mask_sum = (~mask.reshape(sz_b, -1)).sum(dim=1, keepdim=True)
#         #
#         #
#         # q1_mean = q1_masked.sum(dim=1, keepdim=True) / mask_sum.unsqueeze(-1)
#         # q2_mean = q2_masked.sum(dim=1, keepdim=True) / mask_sum.unsqueeze(-1)
#
#
#         q1 = q1 - v1
#         q2 = q2 - v2
#
#
#         # print(q1.shape, v1.shape, q2.shape, v2.shape)
#         # exit()
#         return q1, q2


class Critic(nn.Module):
    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_layer_norm=False):
        """
            the implementation of Critic network (refer to L2D)

            For num_layers == 1 a single ``self.linear`` is created. Otherwise ``self.linears``
            holds num_layers Linear modules and, when use_layer_norm is set,
            ``self.layer_norms`` holds num_layers - 1 LayerNorm modules of width hidden_dim.

            NOTE(review): no forward() is defined for this class in the visible source.
        :param num_layers: number of layers in the neural networks (EXCLUDING the input layer).
                            If num_layers=1, this reduces to linear model.
        :param input_dim: dimensionality of input features
        :param hidden_dim: dimensionality of hidden units at ALL layers
        :param output_dim:  number of classes for prediction
        :param use_layer_norm: whether to create the LayerNorm modules (multi-layer case only)
        :raises ValueError: if num_layers < 1
        """
        super().__init__()

        self.num_layers = num_layers
        self.use_layer_norm = use_layer_norm
        self.activative = torch.relu
        self.linear_or_not = True  # switched off below for the multi-layer case

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        if num_layers == 1:
            # Linear model: a single projection and nothing else.
            self.linear = nn.Linear(input_dim, output_dim)
            return

        # Multi-layer model
        self.linear_or_not = False

        # Input widths of the layers that produce hidden_dim features: the
        # first one reads the input, the remaining ones read hidden features.
        hidden_layer_inputs = [input_dim] + [hidden_dim] * (num_layers - 2)
        self.linears = nn.ModuleList([nn.Linear(width, hidden_dim) for width in hidden_layer_inputs])

        if use_layer_norm:
            # One LayerNorm per layer that produces hidden_dim features.
            self.layer_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers - 1)])

        # The output projection is created last.
        self.linears.append(nn.Linear(hidden_dim, output_dim))


class CosineEmbeddingNetwork(nn.Module):
    def __init__(self, num_cos: int = 64, emb_dim: int = 64):
        """
            container holding a single linear layer ``net`` that maps num_cos features to
            emb_dim features

            NOTE(review): no forward() is defined for this class in the visible source.
        :param num_cos: input width of the linear layer
        :param emb_dim: output width of the linear layer
        """
        super().__init__()
        self.num_cos, self.emb_dim = num_cos, emb_dim
        self.net = nn.Linear(in_features=num_cos, out_features=emb_dim)
