import torch
import torch.nn as nn
import torch.nn.functional as F
import sys

from torch.distributions.categorical import Categorical
from torch_cluster import knn

sys.path.append(f"./utils/")
sys.path.append(f"./scripts/")
from augmentation import Augmentation

import warnings
warnings.filterwarnings("ignore", category=UserWarning)


# ###################
# # Network definition
# # Notation :
# #     bsz : batch size
# #     nb_nodes : number of nodes/cities
# #     dim_emb : embedding/hidden dimension
# #     nb_heads : nb of attention heads
# #     dim_ff : feed-forward dimension
# #     nb_layers : number of encoder/decoder layers
# ###################


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with residual connection and LayerNorm.

    Args:
        n_head: number of attention heads.
        d_model: feature size of the inputs and of the output.
        d_k: per-head size of queries and keys.
        d_v: per-head size of values.
        dropout: dropout probability, applied both to the attention weights and
            to the output projection.

    ``forward(q, k, v, mask=None)`` takes batch-first tensors
    q (b, len_q, d_model), k (b, len_k, d_model), v (b, len_v, d_model) with
    len_k == len_v and returns ``(out, attn)``:
        out  (b, len_q, d_model)       LayerNorm(dropout(fc(attention)) + q)
        attn (b, n_head, len_q, len_k) attention weights after dropout
    If given, ``mask`` gets a head axis inserted at dim 1 and positions where it
    equals 0 are filled with -1e9 before the softmax.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head, self.d_k, self.d_v = n_head, d_k, d_v

        # Created in this order so that parameter initialization consumes the RNG
        # in the same sequence as before.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def _heads(self, x, proj, head_dim):
        """Project x (b, length, d_model) and split it into heads: (b, n_head, length, head_dim)."""
        batch, length = x.size(0), x.size(1)
        return proj(x).view(batch, length, self.n_head, head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        batch, len_q = q.size(0), q.size(1)
        residual = q

        queries = self._heads(q, self.w_qs, self.d_k)
        keys = self._heads(k, self.w_ks, self.d_k)
        values = self._heads(v, self.w_vs, self.d_v)

        # Queries are scaled before the dot product.
        scores = torch.matmul(queries / self.d_k ** 0.5, keys.transpose(2, 3))
        if mask is not None:
            # Insert the head axis so one mask broadcasts over all heads.
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attn = self.dropout(F.softmax(scores, dim=-1))

        # (b, n_head, len_q, d_v) -> (b, len_q, n_head * d_v): concatenate the heads.
        context = torch.matmul(attn, values).transpose(1, 2).contiguous().view(batch, len_q, -1)
        out = self.dropout(self.fc(context)) + residual
        return self.layer_norm(out), attn


class MultiHeadAttentionGraph(nn.Module):
    '''
    Multi-head scaled dot-product attention with residual connection and LayerNorm,
    computationally identical to MultiHeadAttention. It remains a separate class because
    graph-specific features used to live here; all of them have since been removed.

    The constructor argument ``k`` is only stored, and the forward arguments
    ``edge_index``, ``if_global`` and ``if_local_opt`` are accepted for interface
    compatibility; none of them affects the computation.

    forward(q, k, v, ...) expects batch-first q (b, len_q, d_model), k (b, len_k, d_model),
    v (b, len_v, d_model) with len_k == len_v, and returns
        out  (b, len_q, d_model)
        attn (b, n_head, len_q, len_k)
    Entries where ``mask`` (after inserting a head axis at dim 1) equals 0 are blocked.
    '''

    def __init__(self, n_head, d_model, d_k, d_v, k, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        self.k = k  # stored only; not read by forward

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, edge_index=None, mask=None, if_global=True, if_local_opt=False):
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        bsz, len_q = q.shape[0], q.shape[1]
        skip = q

        # (b, len, n_head * d) -> (b, n_head, len, d)
        q_h = self.w_qs(q).reshape(bsz, len_q, n_head, d_k).permute(0, 2, 1, 3)
        k_h = self.w_ks(k).reshape(bsz, k.shape[1], n_head, d_k).permute(0, 2, 1, 3)
        v_h = self.w_vs(v).reshape(bsz, v.shape[1], n_head, d_v).permute(0, 2, 1, 3)

        logits = (q_h / d_k ** 0.5) @ k_h.transpose(-2, -1)
        if mask is not None:
            # a head axis is inserted so the mask broadcasts over heads; zeros are blocked
            logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attn = self.dropout(logits.softmax(dim=-1))

        # back to (b, len_q, n_head * d_v), i.e. heads concatenated
        merged = (attn @ v_h).permute(0, 2, 1, 3).reshape(bsz, len_q, n_head * d_v)
        out = self.dropout(self.fc(merged))
        out = out + skip
        out = self.layer_norm(out)
        return out, attn


class Transformer_encoder_net(nn.Module):
    """
    Encoder network based on self-attention transformer.

    Each layer applies MultiHeadAttentionGraph self-attention (which already adds its
    own residual connection and LayerNorm internally), then an outer residual + norm,
    then a ReLU feed-forward block with residual + norm. The norm is BatchNorm1d over
    the feature axis when ``batchnorm`` is True, LayerNorm otherwise.

    ``k`` is passed to every MultiHeadAttentionGraph layer; ``if_global`` and
    ``if_local_opt`` are stored as attributes but not read by forward.

    Inputs :
      h of size      (bsz, nb_nodes+1, dim_emb)    batch of input cities
      edge_index     forwarded unchanged to every attention layer
    Outputs :
      h of size      (bsz, nb_nodes+1, dim_emb)    batch of encoded cities
      score of size  (bsz, nb_heads, nb_nodes+1, nb_nodes+1) attention weights of the
                     last layer, or None when nb_layers == 0
    """

    def __init__(self, nb_layers, dim_emb, nb_heads, dim_ff, batchnorm, k, if_global=True, if_local_opt=False):
        super(Transformer_encoder_net, self).__init__()
        assert dim_emb == nb_heads * (dim_emb // nb_heads)  # check if dim_emb is divisible by nb_heads
        self.MHA_layers = nn.ModuleList(
            [MultiHeadAttentionGraph(nb_heads, dim_emb, dim_emb, dim_emb, k) for _ in range(nb_layers)])
        self.linear1_layers = nn.ModuleList([nn.Linear(dim_emb, dim_ff) for _ in range(nb_layers)])
        self.linear2_layers = nn.ModuleList([nn.Linear(dim_ff, dim_emb) for _ in range(nb_layers)])
        if batchnorm:
            self.norm1_layers = nn.ModuleList([nn.BatchNorm1d(dim_emb) for _ in range(nb_layers)])
            self.norm2_layers = nn.ModuleList([nn.BatchNorm1d(dim_emb) for _ in range(nb_layers)])
        else:
            self.norm1_layers = nn.ModuleList([nn.LayerNorm(dim_emb) for _ in range(nb_layers)])
            self.norm2_layers = nn.ModuleList([nn.LayerNorm(dim_emb) for _ in range(nb_layers)])
        self.nb_layers = nb_layers
        self.nb_heads = nb_heads
        self.batchnorm = batchnorm
        self.if_global = if_global
        self.if_local_opt = if_local_opt

    def _norm(self, norm_layer, h):
        """Apply norm_layer to h of size (bsz, seq_len, dim_emb); returns the same size."""
        if self.batchnorm:
            # nn.BatchNorm1d normalizes dim 1, so the feature axis must be moved there.
            h = norm_layer(h.permute(0, 2, 1).contiguous())  # size(h)=(bsz, dim_emb, seq_len)
            return h.permute(0, 2, 1).contiguous()  # size(h)=(bsz, seq_len, dim_emb)
        return norm_layer(h)

    def forward(self, h, edge_index=None):
        # The whole computation stays batch-first: the attention layers take (bsz, seq, dim),
        # and Linear/LayerNorm act on the last axis only, so no (seq, bsz) layout is needed.
        score = None  # returned as-is when there are no layers (previously an UnboundLocalError)
        for i in range(self.nb_layers):
            # self-attention sub-layer + residual connection
            h_rc = h
            h, score = self.MHA_layers[i](h, h, h, edge_index=edge_index)
            h = self._norm(self.norm1_layers[i], h_rc + h)  # size(h)=(bsz, nb_nodes, dim_emb)
            # feedforward sub-layer + residual connection
            h_rc = h
            h = self.linear2_layers[i](torch.relu(self.linear1_layers[i](h)))
            h = self._norm(self.norm2_layers[i], h_rc + h)  # size(h)=(bsz, nb_nodes, dim_emb)
        return h, score


def myMHA(Q, K, V, nb_heads, mask=None, clip_value=None):
    """
    Compute multi-head attention (MHA) of one query per batch element against a set of keys/values :
      h = Concat_{k=1}^nb_heads softmax(Q_k.K_k^T / sqrt(d_head)).V_k
    Note : We did not use nn.MultiheadAttention to avoid re-computing all linear transformations at each call.
    Inputs : Q of size (bsz, 1, dim_emb)                batch of queries
             K of size (bsz, nb_nodes+1, dim_emb)       batch of keys
             V of size (bsz, nb_nodes+1, dim_emb)       batch of values
             nb_heads : number of heads; dim_emb must be divisible by nb_heads
             mask of size (bsz, nb_nodes+1), bool       True entries are excluded (logit set to -1e9)
             clip_value : optional scalar; logits become clip_value * tanh(logits) before masking
    Outputs : attn_output of size (bsz, 1, dim_emb)     batch of attention vectors
              attn_weights of size (bsz, 1, nb_nodes+1) batch of attention weights (averaged over heads)
    """
    bsz, nb_nodes, emd_dim = K.size()

    if nb_heads > 1:
        head_dim = emd_dim // nb_heads

        def to_heads(t, length):
            # (bsz, length, dim_emb) -> (bsz*nb_heads, length, head_dim);
            # row b*nb_heads + h holds head h of batch element b.
            split = t.reshape(bsz, length, nb_heads, head_dim).transpose(1, 2)
            return split.reshape(bsz * nb_heads, length, head_dim)

        Q = to_heads(Q, 1)
        K = to_heads(K, nb_nodes)
        V = to_heads(V, nb_nodes)

    # size(scores)=(bsz*nb_heads, 1, nb_nodes+1)
    scores = torch.bmm(Q, K.transpose(1, 2)) / Q.size(-1) ** 0.5
    if clip_value is not None:
        scores = clip_value * torch.tanh(scores)
    if mask is not None:
        if nb_heads > 1:
            # one copy of each batch row per head, matching the row order produced by to_heads
            mask = mask.repeat_interleave(nb_heads, dim=0)
        scores = scores.masked_fill(mask.unsqueeze(1), float('-1e9'))

    attn_weights = torch.softmax(scores, dim=-1)
    attn_output = torch.bmm(attn_weights, V)  # size=(bsz*nb_heads, 1, head_dim)

    if nb_heads > 1:
        # gather the heads back into the feature axis: (bsz, 1, dim_emb)
        attn_output = attn_output.reshape(bsz, nb_heads, 1, head_dim).transpose(1, 2).reshape(bsz, 1, emd_dim)
        attn_weights = attn_weights.view(bsz, nb_heads, 1, nb_nodes).mean(dim=1)
    return attn_output, attn_weights


class Transformer_global_encoder_net(nn.Module):
    """
    Cross-attention encoder: refines a set of query embeddings (local_emb) by attending
    to a set of key/value embeddings (global_emb). global_emb is the same for every layer.

    Each layer applies MultiHeadAttention(local_emb, global_emb, global_emb) (which already
    adds its own residual connection and LayerNorm internally), then an outer residual +
    norm, then a ReLU feed-forward block with residual + norm. The norm is BatchNorm1d over
    the feature axis when ``batchnorm`` is True, LayerNorm otherwise.

    Inputs :
      local_emb of size  (bsz, nb_query, dim_emb)   query embeddings
      global_emb of size (bsz, nb_key, dim_emb)     key/value embeddings
    Outputs :
      local_emb of size  (bsz, nb_query, dim_emb)   encoded query embeddings
      score of size      (bsz, nb_heads, nb_query, nb_key) attention weights of the last
                         layer, or None when nb_layers == 0
    """

    def __init__(self, nb_layers, dim_emb, nb_heads, dim_ff, batchnorm):
        super(Transformer_global_encoder_net, self).__init__()
        assert dim_emb == nb_heads * (dim_emb // nb_heads)  # check if dim_emb is divisible by nb_heads
        self.MHA_layers = nn.ModuleList(
            [MultiHeadAttention(nb_heads, dim_emb, dim_emb, dim_emb) for _ in range(nb_layers)])
        self.linear1_layers = nn.ModuleList([nn.Linear(dim_emb, dim_ff) for _ in range(nb_layers)])
        self.linear2_layers = nn.ModuleList([nn.Linear(dim_ff, dim_emb) for _ in range(nb_layers)])
        if batchnorm:
            self.norm1_layers = nn.ModuleList([nn.BatchNorm1d(dim_emb) for _ in range(nb_layers)])
            self.norm2_layers = nn.ModuleList([nn.BatchNorm1d(dim_emb) for _ in range(nb_layers)])
        else:
            self.norm1_layers = nn.ModuleList([nn.LayerNorm(dim_emb) for _ in range(nb_layers)])
            self.norm2_layers = nn.ModuleList([nn.LayerNorm(dim_emb) for _ in range(nb_layers)])
        self.nb_layers = nb_layers
        self.nb_heads = nb_heads
        self.batchnorm = batchnorm

    def _norm(self, norm_layer, h):
        """Apply norm_layer to h of size (bsz, seq_len, dim_emb); returns the same size."""
        if self.batchnorm:
            # nn.BatchNorm1d normalizes dim 1, so the feature axis must be moved there.
            h = norm_layer(h.permute(0, 2, 1).contiguous())  # size(h)=(bsz, dim_emb, seq_len)
            return h.permute(0, 2, 1).contiguous()  # size(h)=(bsz, seq_len, dim_emb)
        return norm_layer(h)

    def forward(self, local_emb, global_emb):
        # Everything stays batch-first: MultiHeadAttention takes (bsz, seq, dim), and
        # Linear/LayerNorm act on the last axis only.
        score = None  # returned as-is when there are no layers (previously an UnboundLocalError)
        for i in range(self.nb_layers):
            # cross-attention sub-layer + residual connection
            local_emb_rc = local_emb
            local_emb, score = self.MHA_layers[i](local_emb, global_emb, global_emb)
            local_emb = self._norm(self.norm1_layers[i], local_emb_rc + local_emb)
            # feedforward sub-layer + residual connection
            local_emb_rc = local_emb
            local_emb = self.linear2_layers[i](torch.relu(self.linear1_layers[i](local_emb)))
            local_emb = self._norm(self.norm2_layers[i], local_emb_rc + local_emb)
        return local_emb, score


class AutoRegressiveDecoderLayerGeneralTest(nn.Module):
    """
    Single decoder layer: multi-head query-attention of h_t over (K_att, V_att) followed
    by an MLP, each with a residual connection and LayerNorm.
    Inputs :
      h_t of size      (bsz, dim_emb) or (bsz, 1, dim_emb)  batch of input queries
      K_att of size    (bsz, nb_keys, dim_emb)             batch of query-attention keys
      V_att of size    (bsz, nb_keys, dim_emb)             batch of query-attention values
      mask of size     (bsz, nb_keys), bool, optional      True entries are excluded from attention
    Output :
      h_t of size (bsz, dim_emb)                            batch of transformed queries
    """

    def __init__(self, dim_emb, nb_heads):
        super(AutoRegressiveDecoderLayerGeneralTest, self).__init__()
        self.dim_emb = dim_emb
        self.nb_heads = nb_heads
        self.W0_att = nn.Linear(dim_emb, dim_emb)
        self.Wq_att = nn.Linear(dim_emb, dim_emb)
        self.W1_MLP = nn.Linear(dim_emb, dim_emb)
        self.W2_MLP = nn.Linear(dim_emb, dim_emb)
        # Despite the "BN_" prefix these are LayerNorms.
        # BN_selfatt, K_sa and V_sa are not used by forward (there is no self-attention
        # sub-layer); they are kept so attribute names and state_dict keys stay unchanged.
        self.BN_selfatt = nn.LayerNorm(dim_emb)
        self.BN_att = nn.LayerNorm(dim_emb)
        self.BN_MLP = nn.LayerNorm(dim_emb)
        self.K_sa = None
        self.V_sa = None

    def reset_selfatt_keys_values(self):
        """Clear the (unused) cached self-attention keys/values."""
        self.K_sa = None
        self.V_sa = None

    def forward(self, h_t, K_att, V_att, mask=None):
        bsz = h_t.size(0)
        h_t = h_t.view(bsz, 1, self.dim_emb)  # size(h_t)=(bsz, 1, dim_emb)

        # query-attention over the encoded nodes, residual + LayerNorm
        q_a = self.Wq_att(h_t)  # size(q_a)=(bsz, 1, dim_emb)
        h_t = h_t + self.W0_att(myMHA(q_a, K_att, V_att, self.nb_heads, mask)[0])  # size(h_t)=(bsz, 1, dim_emb)
        # LayerNorm normalizes only the last axis, so it is applied directly to the
        # (bsz, 1, dim_emb) tensor. The former h_t.squeeze() removed *every* size-1 axis
        # (including the feature axis when dim_emb == 1, which then failed in LayerNorm).
        h_t = self.BN_att(h_t)  # size(h_t)=(bsz, 1, dim_emb)

        # MLP, residual + LayerNorm
        h_t = h_t + self.W2_MLP(torch.relu(self.W1_MLP(h_t)))
        h_t = self.BN_MLP(h_t.squeeze(1))  # size(h_t)=(bsz, dim_emb)
        return h_t


class Transformer_decoder_net_general_test(nn.Module):
    """
    Decoder network based on query-attention transformers.
    The first nb_layers_decoder - 1 layers are multi-head AutoRegressiveDecoderLayerGeneralTest
    blocks; the final layer is a single-head attention whose weights (logits clipped with
    10 * tanh) are returned as the next-node distribution.
    Inputs :
      h_t of size      (bsz, dim_emb)                               batch of input queries
      K_att of size    (bsz, nb_keys, dim_emb*nb_layers_decoder)    batch of query-attention keys for all decoding layers
      V_att of size    (bsz, nb_keys, dim_emb*nb_layers_decoder)    batch of query-attention values for all decoding layers
                       (layer l uses columns [l*dim_emb, (l+1)*dim_emb))
      mask of size     (bsz, nb_keys), bool, optional               True entries are excluded
    Output :
      prob_next_node of size (bsz, nb_keys)                         batch of probabilities of next node
    Raises :
      ValueError if nb_layers_decoder < 1 (there would be no final layer producing probabilities).
    """

    def __init__(self, dim_emb, nb_heads, nb_layers_decoder):
        super(Transformer_decoder_net_general_test, self).__init__()
        if nb_layers_decoder < 1:
            raise ValueError(f"nb_layers_decoder must be >= 1, got {nb_layers_decoder}")
        self.dim_emb = dim_emb
        self.nb_heads = nb_heads
        self.nb_layers_decoder = nb_layers_decoder
        self.decoder_layers = nn.ModuleList(
            [AutoRegressiveDecoderLayerGeneralTest(dim_emb, nb_heads) for _ in range(nb_layers_decoder - 1)])
        self.Wq_final = nn.Linear(dim_emb, dim_emb)

    # Reset to None self-attention keys and values when decoding starts
    def reset_selfatt_keys_values(self):
        for layer in self.decoder_layers:
            layer.reset_selfatt_keys_values()

    def _layer_slice(self, att, l):
        """Return the columns of att (bsz, nb_keys, dim_emb*nb_layers_decoder) used by layer l."""
        return att[:, :, l * self.dim_emb:(l + 1) * self.dim_emb].contiguous()  # size=(bsz, nb_keys, dim_emb)

    def forward(self, h_t, K_att, V_att, mask=None):
        # intermediate decoder layers with multiple heads
        for l, layer in enumerate(self.decoder_layers):
            h_t = layer(h_t, self._layer_slice(K_att, l), self._layer_slice(V_att, l), mask)
        # final decoder layer with a single head: its attention weights are the probabilities
        last = self.nb_layers_decoder - 1
        bsz = h_t.size(0)
        q_final = self.Wq_final(h_t).view(bsz, 1, self.dim_emb)
        attn_weights = myMHA(q_final, self._layer_slice(K_att, last), self._layer_slice(V_att, last), 1, mask, 10)[1]
        prob_next_node = attn_weights.squeeze(1)
        return prob_next_node


class TSP_net_general_test_version(nn.Module):
    """
    The TSP network is composed of two steps :
      Step 1. Encoder step : Take a set of 2D points representing a fully connected graph
                             and encode the set with self-transformer.
      Step 2. Decoder step : Build the TSP tour recursively/autoregressively,
                             i.e. one node at a time, with a self-transformer and query-transformer.
    At every step, only the k = min(local_k, #unvisited) unvisited nodes nearest to the last
    visited node are candidates for the next node.
    Inputs :
      x, of shape (bsz, nb_nodes, dim_input_nodes) Euclidean coordinates of the nodes/cities
      deterministic is a boolean : If True the salesman will choose the city with highest probability.
                                   If False the salesman will choose the city with Categorical sampling.
    Outputs :
      tours, of shape (bsz, nb_nodes) : batch of tours, i.e. sequences of ordered cities
                                      tours[b,t] contains the idx of the city visited at step t in batch b
      nabla_pi, of shape (bsz,) : batch of sum_t log prob( pi_t | pi_(t-1),...,pi_0 )
    """

    def __init__(self, dim_input_nodes, dim_emb, dim_ff, nb_layers_global_encoder, nb_layers_local_encoder,
                 nb_layers_decoder, nb_heads, local_k, batchnorm=True):
        super(TSP_net_general_test_version, self).__init__()

        self.dim_input = dim_input_nodes
        self.dim_emb = dim_emb

        # input embedding layer
        self.global_input_emb = nn.Linear(dim_input_nodes, dim_emb)
        self.local_input_emb = nn.Linear(dim_input_nodes, dim_emb)

        # encoder layer
        self.local_encoder = Transformer_encoder_net(nb_layers_local_encoder, dim_emb, nb_heads,
                                                     dim_ff, batchnorm, local_k)
        self.global_encoder = Transformer_global_encoder_net(nb_layers_global_encoder, dim_emb,
                                                             nb_heads, dim_ff, batchnorm)

        # vector to start decoding
        self.start_placeholder = nn.Parameter(torch.randn(dim_input_nodes))

        # decoder layer
        self.decoder = Transformer_decoder_net_general_test(dim_emb, nb_heads, nb_layers_decoder)
        self.WK_att_decoder = nn.Linear(2 * dim_emb, nb_layers_decoder * dim_emb)
        self.WV_att_decoder = nn.Linear(2 * dim_emb, nb_layers_decoder * dim_emb)
        self.query_mlp = nn.Linear(3 * dim_emb, dim_emb)

        self.local_k = local_k
        self.aug_module = Augmentation()

    def forward(self, x, deterministic=False):
        # some parameters
        bsz, nb_nodes = x.shape[0], x.shape[1]
        device = x.device
        zero_to_bsz = torch.arange(bsz, device=device)  # [0,1,...,bsz-1]

        # Placeholder coordinates used as the "last"/"first" visited node before the first
        # choice, size (bsz, dim_input_nodes). Built directly in 2-D: the former
        # repeat(bsz, 1, 1).squeeze() also dropped the batch axis when bsz == 1, which
        # broke the knn query and the torch.cat calls below.
        start_vec = self.start_placeholder.unsqueeze(0).repeat(bsz, 1)
        first_visited_node = start_vec
        last_visited_node = start_vec

        tours = []
        nabla_pi = []

        mask_global = torch.ones((bsz, nb_nodes), dtype=torch.bool, device=device)  # True = not visited yet
        all_idx = torch.arange(nb_nodes, device=device).repeat((bsz, 1))

        for t in range(nb_nodes):
            # indices of unvisited nodes, size (bsz, num_nodes); every row has the same count
            unvisited_matrix = torch.reshape(all_idx[mask_global], (bsz, -1))
            num_nodes = unvisited_matrix.size(1)

            # knn process: graph is the flattened set of unvisited nodes, b_local its batch vector
            b_local = zero_to_bsz.repeat_interleave(num_nodes)  # [0]*num_nodes + [1]*num_nodes + ...
            unvisited_matrix_idx = unvisited_matrix.view((-1,))
            graph = x[b_local, unvisited_matrix_idx]
            k = min(self.local_k, num_nodes)
            # NOTE(review): assumes knn returns [query idx; graph idx] grouped per query in batch
            # order, so after flip row 0 holds the flattened graph indices — confirm with torch_cluster.
            next_node_idx = knn(graph, last_visited_node, k, b_local, zero_to_bsz).flip([0]).detach()

            # Node gather
            next_node_idx = next_node_idx[0, :] % num_nodes  # flattened graph index -> index within instance
            next_node_bsz = zero_to_bsz.repeat_interleave(k)
            true_next_idx = unvisited_matrix[next_node_bsz, next_node_idx]  # indices into x
            next_node_cluster = x[next_node_bsz, true_next_idx, :]
            next_node_cluster = next_node_cluster.view((bsz, k, self.dim_input))
            # size(next_node_cluster)=(bsz, k+1, dim_input): k candidates + last visited node
            next_node_cluster = torch.cat((next_node_cluster, last_visited_node.unsqueeze(dim=1)), dim=1)

            # Local embedding
            next_node_cluster_scaled, _ = self.aug_module.scale(next_node_cluster)
            h_local_emb = self.local_input_emb(next_node_cluster_scaled)
            h_local_encoder, _ = self.local_encoder(h_local_emb)

            # Global embedding over [unvisited nodes, last visited, first visited]
            graph = graph.view((bsz, -1, self.dim_input))
            graph = torch.cat((graph, last_visited_node.unsqueeze(dim=1)), dim=1)
            graph = torch.cat((graph, first_visited_node.unsqueeze(dim=1)), dim=1)
            scaled_graph, _ = self.aug_module.scale(graph)
            h_global_emb = self.global_input_emb(scaled_graph)

            # Gather global embedding for next node group
            next_node_global_emb = h_global_emb[next_node_bsz, next_node_idx]
            next_node_global_emb = next_node_global_emb.view((bsz, -1, self.dim_emb))

            # Global embedding of last and first visited nodes
            last_and_first_node_emb = h_global_emb[:, num_nodes:, :]

            # Concat the global embedding: [k candidates, last visited, first visited]
            next_node_global_emb = torch.cat((next_node_global_emb, last_and_first_node_emb), dim=1)

            # MHA for global embedding
            h_global_encoder, _ = self.global_encoder(next_node_global_emb, h_global_emb)

            # Concat local embedding and global embedding
            h_local_encoder_group = h_local_encoder[:, :k, :]
            h_global_encoder_group = h_global_encoder[:, :k, :]
            h_all_group = torch.cat((h_local_encoder_group, h_global_encoder_group), dim=2)
            h_local_t = h_local_encoder[:, -1, :]
            # NOTE(review): given the concatenation order above, index k is the last visited node
            # and index -1 (== k+1) is the first visited node, so the names h_global_t / h_global_0
            # look swapped. The concatenation order is kept because trained query_mlp weights
            # would depend on it.
            h_global_t = h_global_encoder[:, -1, :]
            h_global_0 = h_global_encoder[:, k, :]
            h_t = torch.cat((h_local_t, h_global_t, h_global_0), dim=1)

            # q_att, K_att and V_att
            h_t = self.query_mlp(h_t)
            # shape(K_att) = shape(V_att) = (bsz, k, dim_emb * nb_layers_decoder)
            K_att_decoder = self.WK_att_decoder(h_all_group)
            V_att_decoder = self.WV_att_decoder(h_all_group)

            # Decode, shape(prob_next_node) = (bsz, k)
            prob_next_node = self.decoder(h_t, K_att_decoder, V_att_decoder)

            # Next node choosing, shape(idx) = (bsz,)
            if deterministic:
                idx = torch.argmax(prob_next_node, dim=1)
            else:
                idx = Categorical(prob_next_node).sample()

            # Next node info
            last_visited_node = next_node_cluster[zero_to_bsz, idx]
            true_next_idx = true_next_idx.view(bsz, -1)
            last_visited_idx = true_next_idx[zero_to_bsz, idx]

            if t == 0:
                first_visited_node = last_visited_node

            # Other info updated
            mask_global[zero_to_bsz, last_visited_idx] = False

            # Update the current tour
            probs_choices = prob_next_node[zero_to_bsz, idx]
            nabla_pi.append(torch.log(probs_choices))
            tours.append(last_visited_idx)

        # logprob_of_choices = sum_t log prob( pi_t | pi_(t-1),...,pi_0 )
        nabla_pi = torch.stack(nabla_pi, dim=1).sum(dim=1)

        # convert the list of nodes into a tensor of shape (bsz,num_cities)
        tours = torch.stack(tours, dim=1)

        return tours, nabla_pi