import datetime
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.gumbel import Gumbel

from attention import FlashAttention


def _get_encoding(encoded_nodes, node_index_to_pick):
    # encoded_nodes.shape: (batch, problem, embedding)
    # node_index_to_pick.shape: (batch, pomo)

    batch_size = node_index_to_pick.size(0)
    pomo_size = node_index_to_pick.size(1)
    embedding_dim = encoded_nodes.size(2)

    gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim)
    # shape: (batch, pomo, embedding)

    picked_nodes = encoded_nodes.gather(dim=1, index=gathering_index)
    # shape: (batch, pomo, embedding)

    return picked_nodes


def _get_distances(distances, node_index_to_pick):
    # encoded_nodes.shape: (batch, problem, embedding)
    # node_index_to_pick.shape: (batch, pomo)

    batch_size = distances.size(0)
    pomo_size = node_index_to_pick.size(1)
    node_size = distances.size(1)

    gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, node_size)
    # shape: (batch, pomo, embedding)

    picked_nodes = distances.gather(dim=1, index=gathering_index)
    # shape: (batch, pomo, embedding)

    return picked_nodes


def _get_encoding2(encoded_nodes, node_index_to_pick):
    # encoded_nodes.shape: (batch, problem, embedding)
    # node_index_to_pick.shape: (batch, pomo)

    batch_size = node_index_to_pick.size(0)
    pomo_size = node_index_to_pick.size(1)
    embedding_dim = encoded_nodes.size(3)

    gathering_index = node_index_to_pick[:, :, None, None].expand(batch_size, pomo_size, 1, embedding_dim)
    # shape: (batch, pomo, embedding)

    picked_nodes = encoded_nodes.gather(dim=2, index=gathering_index)
    # shape: (batch, pomo, embedding)

    return picked_nodes[:, :, 0, :]


def _get_encoding3(heatmap, node_index_to_pick):
    # encoded_nodes.shape: (batch, problem, embedding)
    # node_index_to_pick.shape: (batch, pomo)

    batch_size = node_index_to_pick.size(0)
    pomo_size = node_index_to_pick.size(1)
    embedding_dim = heatmap.size(2)

    gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim)
    # shape: (batch, pomo, embedding)

    picked_nodes = heatmap.gather(dim=1, index=gathering_index)
    # shape: (batch, pomo, embedding)

    return picked_nodes


class CVRPModel(nn.Module):
    """Policy network for CVRP roll-outs: encoder, optional heatmap head, decoder.

    Usage, as far as this class shows: call ``pre_forward(reset_state)`` once
    per batch of instances to encode the nodes (and, when
    ``model_params["heatmap"]`` is truthy, build a sparse edge set and a
    heatmap), then call ``forward(state)`` once per construction step to pick
    the next node for every (batch, pomo) roll-out.
    """

    def __init__(self, **model_params):
        """Build the sub-modules.

        Keys read here: ``embedding_dim`` and ``train_partition``. The full
        dict is also forwarded to ``CVRP_Encoder`` and ``Decoder``.
        """
        super().__init__()
        self.model_params = model_params
        self.embedding_size = model_params["embedding_dim"]
        self.encoder = CVRP_Encoder(**model_params)
        self.partition = ParNet(50, 4)
        self.partition_training = model_params["train_partition"]
        self.decoder = Decoder(**model_params)
        # Per-batch caches, filled by pre_forward().
        self.encoded_nodes = None
        self.encoded_edges = None
        self.encoded_nodes2 = None
        self.heatmap = None

    def gen_cos_sim_matrix(self, shift_coors):
        """Return the pairwise cosine similarity of the rows of ``shift_coors``.

        Args:
            shift_coors: tensor of shape (batch, n, dim).

        Returns:
            Tensor of shape (batch, n, n); entry ``[b, i, j]`` is the cosine of
            the angle between ``shift_coors[b, i]`` and ``shift_coors[b, j]``.
            Zero vectors yield 0 because of the ``1e-10`` added to the
            denominator.
        """
        dot_products = torch.matmul(shift_coors, shift_coors.transpose(1, 2))
        # Per-node vector length. The reduction must run over the coordinate
        # axis (last dim); reducing over dim=1 would sum across nodes and
        # produce a single scale per batch instead of a cosine.
        magnitudes = torch.sqrt(torch.sum(shift_coors ** 2, dim=2, keepdim=True))
        # shape: (batch, n, 1)
        magnitude_matrix = torch.matmul(magnitudes, magnitudes.transpose(1, 2)) + 1e-10
        return dot_products / magnitude_matrix

    def _decode_probs(self, state):
        """Run the decoder for the current node of every roll-out in ``state``."""
        encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)
        # shape: (batch, pomo, embedding)
        if self.model_params["heatmap"]:
            heatmap = _get_encoding3(self.heatmap, state.current_node)
        else:
            heatmap = None
        return self.decoder(encoded_last_node, state.load, ninf_mask=state.ninf_mask, heatmap=heatmap)

    def get_expand_prob(self, state):
        """Return the decoder output for ``state`` without selecting a node.

        The heatmap rows are only gathered when ``model_params["heatmap"]`` is
        truthy; otherwise ``heatmap=None`` is passed, matching ``forward``.
        """
        return self._decode_probs(state)

    def _build_sparse_edges(self, shift_coors, k_sparse=50):
        """Build the k-sparse edge list and its two-channel edge features.

        For every node the ``k_sparse`` entries with the largest rescaled
        cosine score are kept. Requires ``self.distances`` to be set.

        Returns:
            ``(topk_indices, edge_index, edge_attr)`` where ``topk_indices`` is
            (batch, n, k), ``edge_index`` is (3, batch*n*k) holding
            (batch id, source id, target id) and ``edge_attr`` is
            (batch*n*k, 2) holding ``1 + distance`` and the cosine score.
        """
        cos_mat = self.gen_cos_sim_matrix(shift_coors)
        # NOTE(review): this adds the minimum rather than subtracting it, so it
        # is not a min-max rescale. Kept as in the original -- confirm intent.
        cos_mat = (cos_mat + cos_mat.min()) / cos_mat.max()
        euc_aff = 1 + self.distances
        _, topk_indices = torch.topk(cos_mat, k=k_sparse, dim=2, largest=True)

        B, N, K = topk_indices.shape
        device = topk_indices.device
        batch_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, N, K).flatten()
        src_idx = torch.arange(N, device=device).view(1, N, 1).expand(B, N, K).flatten()
        tgt_idx = topk_indices.flatten()
        edge_index = torch.stack([batch_idx, src_idx, tgt_idx], dim=0)
        edge_attr = torch.stack(
            (euc_aff[batch_idx, src_idx, tgt_idx], cos_mat[batch_idx, src_idx, tgt_idx]),
            dim=-1,
        )
        return topk_indices, edge_index, edge_attr

    def pre_forward(self, reset_state):
        """Encode a fresh batch of instances and cache the results on ``self``.

        Reads ``distances``, ``x_depot``, ``x_node``, ``buckets`` and
        ``sorted_indices`` from ``reset_state`` (plus ``shift_coors`` when the
        heatmap is enabled). Sets ``self.distances``, ``self.encoded_nodes``,
        ``self.encoded_edges`` and ``self.heatmap``, and hands the node
        encodings to the decoder via ``set_kv``.
        """
        self.distances = reset_state.distances
        x_depot = reset_state.x_depot
        x_node = reset_state.x_node
        buckets = reset_state.buckets
        sorted_indices = reset_state.sorted_indices
        use_heatmap = self.model_params["heatmap"]

        if use_heatmap:
            topk_indices, edge_index, edge_attr = self._build_sparse_edges(reset_state.shift_coors)
        else:
            topk_indices = edge_index = edge_attr = None

        self.encoded_nodes, self.encoded_edges = self.encoder(
            x_depot, x_node, buckets, sorted_indices, edge_attr, edge_index)
        # encoded_nodes shape: (batch, problem+1, embedding)

        if use_heatmap:
            self.heatmap = self.partition(self.encoded_edges, self.distances, topk_indices)
        else:
            self.heatmap = None
        self.decoder.set_kv(self.encoded_nodes)

    @staticmethod
    def _sample_nodes(probs, batch_size, pomo_size):
        """Draw one node per roll-out from ``probs``; returns (batch, pomo) indices."""
        with torch.no_grad():
            flat = probs.reshape(batch_size * pomo_size, -1).multinomial(1)
        return flat.squeeze(dim=1).reshape(batch_size, pomo_size)

    def forward(self, state, eval_type="greedy"):
        """Choose the next node for every (batch, pomo) roll-out.

        Args:
            state: step state exposing ``BATCH_IDX``, ``POMO_IDX``,
                ``selected_count``, ``current_node``, ``load`` and
                ``ninf_mask``.
            eval_type: accepted for interface compatibility but not read; the
                sampling decision uses ``model_params['eval_type']``.

        Returns:
            ``(selected, prob)``. ``selected`` is a (batch, pomo) long tensor.
            ``prob`` is a (batch, pomo) tensor of ones for the first two moves,
            the probability of the sampled node when sampling, and ``None``
            for greedy decoding.

        Side effects:
            When sampling, the accepted choice is also stored in
            ``state.selected``.
        """
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.selected_count == 0:  # First move: every roll-out starts at index 0.
            # These tensors are created on the default device, as before.
            selected = torch.zeros(size=(batch_size, pomo_size), dtype=torch.long)
            prob = torch.ones(size=(batch_size, pomo_size))
            return selected, prob

        probs = self._decode_probs(state)
        # shape: (batch, pomo, problem+1)

        if state.selected_count == 1:  # Second move: spread roll-outs over the top nodes.
            _, topk_indices = torch.topk(probs[:, 0, 1:], pomo_size)
            # The top-k ran on a slice that starts at column 1, so its indices
            # are offset by one; shift them back to real node indices so that
            # index 0 is never returned here.
            selected = topk_indices + 1
            prob = torch.ones(size=(batch_size, pomo_size))
            return selected, prob

        if self.training or self.partition_training or self.model_params['eval_type'] == 'softmax':
            while True:  # to fix pytorch.multinomial bug on selecting 0 probability elements
                selected = self._sample_nodes(probs, batch_size, pomo_size)
                prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)
                # shape: (batch, pomo)
                # Resample as long as any chosen node has a non-zero mask entry.
                while state.ninf_mask.gather(2, selected.unsqueeze(-1)).data.any():
                    print('Sampled bad values, resampling!')
                    selected = self._sample_nodes(probs, batch_size, pomo_size)
                    prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)
                if (prob != 0).all():
                    state.selected = selected
                    break
        else:
            selected = probs.argmax(dim=2)
            # shape: (batch, pomo)
            prob = None  # value not needed. Can be anything.

        return selected, prob

    def _update_encoding(self, encoded_nodes, top_k, current_nodes, load, charge, mask, charging_station_embedding,
                         charging_station_count, top_cs):
        """Re-encode the selected charging-station embeddings and write them back.

        NOTE(review): this method calls ``self.partial_encoder``, which
        ``__init__`` does not create, so it raises ``AttributeError`` unless
        that attribute is attached elsewhere. ``charging_station_count`` is
        accepted but unused.

        Returns:
            A copy of ``charging_station_embedding`` with the rows addressed by
            ``top_cs`` replaced by the output of ``self.partial_encoder``.
        """
        batch_size, pomo_size, _, emb_size = encoded_nodes.size()
        _, pomo_size, k_size = top_k.size()
        flat_size = batch_size * pomo_size
        cs_k = top_cs.size(-1)

        encoded_last_node = _get_encoding2(encoded_nodes, current_nodes)
        encoded_last_node = torch.cat((encoded_last_node, load.unsqueeze(-1), charge.unsqueeze(-1)), -1)
        selected_nodes = torch.gather(
            encoded_nodes, 2, top_k[..., None].expand(batch_size, pomo_size, k_size, emb_size))
        cs_index = top_cs[..., None].expand(batch_size, pomo_size, cs_k, emb_size) \
            .reshape(flat_size, cs_k, emb_size)
        selected_charging_station_embedding = torch.gather(charging_station_embedding, 1, cs_index)

        updated_cs = self.partial_encoder(encoded_nodes, selected_nodes, encoded_last_node, mask,
                                          selected_charging_station_embedding)

        whole_css = charging_station_embedding.clone()
        scatter_index = top_cs.reshape(flat_size, -1)[:, :, None].expand(flat_size, cs_k, whole_css.size(-1))
        return whole_css.scatter(1, scatter_index, updated_cs)


def _get_encoding(encoded_nodes, node_index_to_pick):
    # encoded_nodes.shape: (batch, problem, embedding)
    # node_index_to_pick.shape: (batch, pomo)

    batch_size = node_index_to_pick.size(0)
    pomo_size = node_index_to_pick.size(1)
    embedding_dim = encoded_nodes.size(2)

    gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim)
    # shape: (batch, pomo, embedding)

    picked_nodes = encoded_nodes.gather(dim=1, index=gathering_index)
    # shape: (batch, pomo, embedding)

    return picked_nodes


########################################
# ENCODER
########################################

class CVRP_Encoder(nn.Module):
    """Embed depot, nodes and (optionally) edges, then run the sparse layers."""

    def __init__(self, partitioner=False, **model_params):
        """Create the input projections and the layer stack.

        Args:
            partitioner: when truthy the stack depth is read from
                ``model_params['second_layer_num']``, otherwise from
                ``model_params['encoder_layer_num']``.
            **model_params: must contain ``embedding_dim`` and the depth key
                selected above; forwarded unchanged to every layer.
        """
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        depth_key = 'second_layer_num' if partitioner else 'encoder_layer_num'
        encoder_layer_num = self.model_params[depth_key]

        # Depot and node inputs both carry 3 features; edges carry 2 and are
        # projected to a fixed width of 32.
        self.embedding_depot = nn.Linear(3, embedding_dim)
        self.embedding_node = nn.Linear(3, embedding_dim)
        self.edge_embedding = nn.Linear(2, 32)
        self.layers = nn.ModuleList(
            SparseEncoderLayer(**model_params) for _ in range(encoder_layer_num)
        )

    def forward(self, depot_xy, node_xy_demand, buckets, sorted_indices, edge_attr, edge_index):
        """Encode one batch.

        Edge attributes are embedded only when ``model_params["heatmap"]`` is
        truthy; otherwise the edge stream is ``None`` throughout.

        Returns:
            ``(node_embedding, edge_embedding)``. The node embedding is the
            depot embedding concatenated with the node embeddings along dim 1
            and passed through every layer.
        """
        depot_tokens = self.embedding_depot(depot_xy)
        # shape: (batch, 1, embedding)
        node_tokens = self.embedding_node(node_xy_demand)
        edge_tokens = self.edge_embedding(edge_attr) if self.model_params["heatmap"] else None

        tokens = torch.cat((depot_tokens, node_tokens), dim=1)
        # shape: (batch, problem+1, embedding)
        for layer in self.layers:
            tokens, edge_tokens = layer(tokens, buckets, sorted_indices, edge_tokens, edge_index)
        return tokens, edge_tokens


class CVRP_Encoder3(nn.Module):
    """Stack of layers applied jointly to two embedding streams.

    NOTE(review): ``forward`` calls every layer with two positional
    arguments, while ``SparseEncoderLayer.forward`` in this file takes five
    (node_embed, buckets, sorted_indices, edge_embedding, edge_index). With a
    non-zero layer count that call raises ``TypeError`` -- confirm whether
    this class is still in use.
    """

    def __init__(self, **model_params):
        """Build ``model_params['second_encoder_layer_num']`` layers."""
        super().__init__()
        self.model_params = model_params
        depth = self.model_params['second_encoder_layer_num']

        self.layers = nn.ModuleList(SparseEncoderLayer(**model_params) for _ in range(depth))

    def forward(self, embedded_node, embedded_cs):
        """Run both streams through the layers and concatenate them on dim 1."""
        node_stream, cs_stream = embedded_node, embedded_cs
        for layer in self.layers:
            node_stream, cs_stream = layer(node_stream, cs_stream)
        return torch.cat((node_stream, cs_stream), dim=1)


class CVRP_Encoder2(nn.Module):
    """Dense encoder: linear input projections followed by SimpleEncoderLayers."""

    def __init__(self, **model_params):
        """Create the projections and ``model_params['encoder_layer_num']`` layers.

        ``embedding_cs`` is registered as a parameterised module but is not
        used by ``forward``.
        """
        super().__init__()
        self.model_params = model_params
        width = self.model_params['embedding_dim']
        depth = self.model_params['encoder_layer_num']
        self.embedding_depot = nn.Linear(3, width)
        self.embedding_cs = nn.Linear(2, width)
        self.embedding_node = nn.Linear(3, width)
        self.layers = nn.ModuleList(SimpleEncoderLayer(**model_params) for _ in range(depth))

    def forward(self, depot_xy, node_xy_demand):
        """Embed depot and nodes, concatenate on dim 1 and apply every layer.

        Returns:
            Tensor of shape (batch, problem+1, embedding).
        """
        tokens = torch.cat(
            (self.embedding_depot(depot_xy), self.embedding_node(node_xy_demand)),
            dim=1,
        )
        for layer in self.layers:
            tokens = layer(tokens)
        return tokens


class SparseEncoderLayer(nn.Module):
    """Encoder layer that attends inside buckets of nodes instead of globally.

    Nodes are gathered into buckets, each bucket is passed through
    ``node_encoder``, and the results are scattered back to one row per node
    and combined with the input through two add-and-normalise steps around a
    feed-forward block. When ``model_params["heatmap"]`` is truthy the edge
    embeddings are updated as well.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        self.node_encoder = EncoderLayer(**model_params)
        self.edge_encoder = EdgeEncoderLayer(**model_params)
        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)

    def forward(self, node_embed, buckets, sorted_indices, edge_embedding, edge_index):
        """Apply the layer.

        Args:
            node_embed: (batch, n, d) node embeddings.
            buckets: 4-D integer tensor (batch, h, u, s) of indices into dim 1
                of ``node_embed``. Position 0 of every bucket is treated as
                the depot (the code averages it separately) -- presumably the
                bucketing always puts the depot first; confirm with the caller.
            sorted_indices: integer tensor (batch, h, u*(s-1)) giving, for
                each non-depot bucket slot, the output row it is scattered to.
            edge_embedding: edge features, or anything when the heatmap is
                disabled (it is then replaced by ``None``).
            edge_index: index triple passed through to the edge encoder.

        Returns:
            ``(node_embed, edge_embedding)`` with the updated node embeddings
            (batch, n, d) and the updated edge embeddings or ``None``.
        """
        b, h, u, s = buckets.size()
        _, n, d = node_embed.size()

        # Gather the embeddings of every bucket member.
        bucket_index = buckets.unsqueeze(-1).expand(b, h, u, s, d)
        expanded_nodes = node_embed[:, None, None, :, :].expand(b, h, u, n, d)
        bucket_tokens = torch.gather(expanded_nodes, 3, bucket_index).reshape(-1, s, d)

        attended = self.node_encoder(bucket_tokens).reshape(b, h, u, s, d)

        # Slot 0 of each bucket: mean over buckets, then sum over dim 1.
        depot = torch.sum(torch.mean(attended[..., 0:1, :], dim=2), dim=1)
        other_nodes = attended[..., 1:, :].reshape(b, h, u * (s - 1), d)

        # Use the actual embedding width here: a hard-coded width makes the
        # scatter fail (or write only part of each row) whenever d differs.
        scatter_index = sorted_indices.unsqueeze(-1).expand(-1, -1, -1, d)
        restored = torch.zeros_like(other_nodes)
        restored.scatter_(dim=2, index=scatter_index, src=other_nodes)
        # Sum the contributions over dim 1.
        nodes = restored.sum(dim=1)

        node_embed2 = torch.cat((depot, nodes), dim=1)
        if self.model_params["heatmap"]:
            edge_embedding = self.edge_encoder(node_embed, edge_embedding, edge_index)
        else:
            edge_embedding = None

        out1 = self.add_n_normalization_1(node_embed, node_embed2)
        out2 = self.feed_forward(out1)
        node_embed = self.add_n_normalization_2(out1, out2)

        return node_embed, edge_embedding


class EdgeEncoderLayer(nn.Module):
    """Residual edge update driven by the two endpoint node embeddings."""

    def __init__(self, **model_params):
        """Create the projections.

        The widths are fixed in code: node embeddings are reduced through
        ``FeedForward3(128, 64, 32)`` and every edge projection is 32 -> 32.
        ``edge_update`` is registered but not used by ``forward``.
        """
        super().__init__()
        self.model_params = model_params
        self.edge_reduce_embed = FeedForward3(128, 64, 32)
        self.edge_update = nn.Linear(32, 32)
        self.W2 = nn.Linear(32, 32)
        self.W3 = nn.Linear(32, 32)
        self.W1 = nn.Linear(32, 32)
        self.e_bns = nn.Linear(32, 32)
        self.act_fn = F.silu

    def forward(self, node_embed, edge_embedding, edge_index):
        """Return ``edge_embedding`` plus an activated message.

        ``edge_index[0]``, ``edge_index[1]`` and ``edge_index[2]`` are used as
        batch, first-endpoint and second-endpoint indices into the reduced
        node embeddings.
        """
        reduced = self.edge_reduce_embed(node_embed)
        batch_ids = edge_index[0]
        first_endpoint = reduced[batch_ids, edge_index[1]]
        second_endpoint = reduced[batch_ids, edge_index[2]]

        message = self.W1(edge_embedding) + self.W2(first_endpoint) + self.W3(second_endpoint)
        return edge_embedding + self.act_fn(self.e_bns(message))


class SimpleEncoderLayer(nn.Module):
    """Self-attention encoder layer with add-and-normalise and feed-forward."""

    def __init__(self, **model_params):
        """Create the projections from ``embedding_dim``, ``head_num`` and ``qkv_dim``.

        ``self.attention`` (a ``FlashAttention`` with fixed sizes) is
        constructed here but ``forward`` uses ``multi_head_attention``.
        """
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        inner_dim = head_num * qkv_dim

        self.Wq = nn.Linear(embedding_dim, inner_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, inner_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, inner_dim, bias=False)
        self.multi_head_combine = nn.Linear(inner_dim, embedding_dim)
        self.attention = FlashAttention(dim=128, heads=8, dim_head=16)

        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)

    def forward(self, input1):
        """Apply self-attention, then the feed-forward block, each with a residual.

        Args:
            input1: tensor of shape (batch, problem+1, embedding).
        """
        head_num = self.model_params['head_num']

        q, k, v = (
            reshape_by_heads(projection(input1), head_num=head_num)
            for projection in (self.Wq, self.Wk, self.Wv)
        )
        attended = multi_head_attention(q, k, v)
        # shape: (batch, problem, head_num*qkv_dim)

        combined = self.multi_head_combine(attended)
        # shape: (batch, problem, embedding)

        hidden = self.add_n_normalization_1(input1, combined)
        return self.add_n_normalization_2(hidden, self.feed_forward(hidden))


class EncoderLayer(nn.Module):
    """Attention-only layer: Q/K/V projections, attention, output projection.

    There is no residual connection, normalisation or feed-forward block
    here; ``forward`` returns the combined attention output directly.
    Commented-out experimental neighbour-selection code that used to sit in
    this class was removed as dead code.
    """

    def __init__(self, **model_params):
        """Create the projections from ``embedding_dim``, ``head_num`` and ``qkv_dim``.

        The attention module is built with fixed sizes
        (``dim=128, heads=8, dim_head=16``) regardless of ``model_params``.
        """
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']

        self.Wq = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)
        self.attention = FlashAttention(dim=128, heads=8, dim_head=16)

    def batched_cosine_similarity(self, a, b):
        """
        Compute batched cosine similarity between tensors a and b.
        a: (B, N, D)
        b: (B, N, D)
        returns: (B, N, N) similarity matrix
        """
        # Norms are clamped so zero vectors do not divide by zero.
        a_norm = a / a.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        b_norm = b / b.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        return torch.bmm(a_norm, b_norm.transpose(1, 2))  # (B, N, N)

    def batched_random_perm(self, x, top_k=10):
        """Pick ``top_k`` random neighbours per node and build the matching mask.

        Indices are drawn independently with replacement, so a row may
        contain duplicates and therefore fewer than ``top_k`` distinct picks.

        Args:
            x: tensor of shape (B, N, D); only its shape and device are used.
            top_k: number of indices drawn per node.

        Returns:
            ``(selected, mask)``: ``selected`` is a (B, N, top_k) long tensor
            on ``x.device``; ``mask`` is a (B, N, N) bool tensor that is
            ``False`` at the selected positions and ``True`` elsewhere.
        """
        B, N, _ = x.shape
        device = x.device

        # Draw on x's device so the result lives next to the mask and the
        # index tensors instead of on the default device.
        selected = torch.randint(0, N, (B, N, top_k), device=device)
        mask = torch.full((B, N, N), True, device=device)

        # Broadcastable batch / row indices so the assignment needs no loop.
        batch_indices = torch.arange(B, device=device).view(B, 1, 1)
        node_indices = torch.arange(N, device=device).view(1, N, 1)

        mask[batch_indices, node_indices, selected] = False
        return selected, mask

    def forward(self, input1):
        """Return the combined multi-head attention output for ``input1``.

        Args:
            input1: tensor of shape (batch, problem+1, embedding).

        Returns:
            Tensor of shape (batch, problem+1, embedding).
        """
        head_num = self.model_params['head_num']

        q = reshape_by_heads(self.Wq(input1), head_num=head_num)
        k = reshape_by_heads(self.Wk(input1), head_num=head_num)
        v = reshape_by_heads(self.Wv(input1), head_num=head_num)
        # qkv shape: (batch, head_num, problem, qkv_dim)
        out_concat = self.attention(q, k, v)
        # shape: (batch, problem, head_num*qkv_dim)

        return self.multi_head_combine(out_concat)
        # shape: (batch, problem, embedding)


class SparseGumbelEncoderLayer(nn.Module):
    """Cross-attention layer that calls ``multi_head_attention`` in Gumbel mode.

    Queries come from ``input1``; keys and values come from ``input2``. The
    combined attention output is returned as-is (no residual, normalisation
    or feed-forward).
    """

    def __init__(self, **model_params):
        """Create the projections from ``embedding_dim``, ``head_num`` and ``qkv_dim``."""
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        inner_dim = head_num * qkv_dim

        self.Wq = nn.Linear(embedding_dim, inner_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, inner_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, inner_dim, bias=False)
        self.multi_head_combine = nn.Linear(inner_dim, embedding_dim)

    def forward(self, input1, input2):
        """Attend from ``input1`` to ``input2`` with temperature ``model_params['tau']``.

        Returns:
            The output of ``multi_head_combine`` applied to the attention
            result.
        """
        head_num = self.model_params['head_num']
        tau = self.model_params['tau']

        queries = reshape_by_heads(self.Wq(input1), head_num=head_num)
        keys = reshape_by_heads(self.Wk(input2), head_num=head_num)
        values = reshape_by_heads(self.Wv(input2), head_num=head_num)
        # qkv shape: (batch, head_num, problem, qkv_dim)

        attended = multi_head_attention(queries, keys, values, gumbel=True, tau=tau)
        # shape: (batch, problem, head_num*qkv_dim)

        return self.multi_head_combine(attended)
        # shape: (batch, problem, embedding)


class CVRP_PARTIAL_Encoder(nn.Module):
    """Re-encode charging-station embeddings against the current roll-out state."""

    def __init__(self, **model_params):
        """Create the state projection and the single partial encoder layer.

        Reads ``pomo_size`` and ``embedding_dim``. The state projection takes
        ``embedding_dim + 2`` input features.
        """
        super().__init__()
        self.model_params = model_params
        self.pomo_size = self.model_params['pomo_size']
        embedding_dim = self.model_params['embedding_dim']

        self.state_encoder = nn.Linear(embedding_dim + 2, embedding_dim)
        self.partial_encoder_layer = PartialEncoderLayer(**model_params)

    def pre_forward(self, charging_station_embedding):
        """Delegate to ``partial_encoder_layer.pre_forward``."""
        self.partial_encoder_layer.pre_forward(charging_station_embedding)

    def forward(self, encoded_embedding, selected_encoded_embedding, current_nodes, visited_mask,
                charging_station_embedding):
        """Run the partial encoder layer over [state token, selected nodes].

        Args:
            encoded_embedding: 4-D tensor; only its batch size (dim 0) and
                embedding width (dim 3) are used.
            selected_encoded_embedding: (batch, pomo, k, embedding) tensor.
            current_nodes: per-roll-out state features fed to
                ``state_encoder``.
            visited_mask: passed through to the partial encoder layer.
            charging_station_embedding: passed through as the layer's first
                input.

        Returns:
            Whatever ``partial_encoder_layer`` returns.
        """
        batch_size, _, _, embedding_size = encoded_embedding.size()
        _, pomo_size, k_size, _ = selected_encoded_embedding.size()
        flat_size = batch_size * pomo_size

        # One state token per roll-out, placed in front of its k selected nodes.
        state_token = self.state_encoder(current_nodes)[:, :, None, :] \
            .reshape(flat_size, 1, embedding_size)
        selected_tokens = selected_encoded_embedding.reshape(flat_size, k_size, embedding_size)
        context = torch.cat((state_token, selected_tokens), dim=1)
        # shape: (batch*pomo, k+1, embedding)

        return self.partial_encoder_layer(charging_station_embedding, context, visited_mask)


class PartialEncoderLayer(nn.Module):
    """Cross-attention layer: ``input1`` queries attend over ``input2``, with a residual add.

    The normalisation and feed-forward sub-modules are constructed (so they
    appear in the state dict) but ``forward`` currently bypasses them.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        emb_dim = self.model_params['embedding_dim']
        heads = self.model_params['head_num']
        per_head_dim = self.model_params['qkv_dim']
        inner_dim = heads * per_head_dim

        # Cache slots; only ``q`` is written in this class (by pre_forward).
        self.q = None
        self.score = None
        self.v = None

        self.Wq = nn.Linear(emb_dim, inner_dim, bias=False)
        self.Wk = nn.Linear(emb_dim, inner_dim, bias=False)
        self.Wv = nn.Linear(emb_dim, inner_dim, bias=False)
        self.multi_head_combine = nn.Linear(inner_dim, emb_dim)
        # Attention logits are divided by sqrt(qkv_dim).
        self.scale = torch.sqrt(torch.tensor(per_head_dim, dtype=torch.float))
        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward2(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)

    def pre_forward(self, input):
        """Project ``input`` with ``Wq`` and cache the per-head result on ``self.q``."""
        self.q = reshape_by_heads(self.Wq(input), head_num=self.model_params['head_num'])

    def forward(self, input1, input2, visited_mask):
        """Return ``input1`` plus the combined multi-head attention of ``input1`` over ``input2``.

        ``visited_mask`` is forwarded as the rank-2 additive mask of
        ``multi_head_attention``.
        """
        attended = self.multi_head_attention(input1, input2, visited_mask)
        # Residual connection only; normalisation / feed-forward are disabled.
        return input1 + self.multi_head_combine(attended)

    def multi_head_attention(self, input1, input2, rank2_ninf_mask=None, rank3_ninf_mask=None):
        """Scaled dot-product attention with queries from ``input1`` and keys/values from ``input2``.

        Args:
            input1: ``(batch, n, embedding)`` query source.
            input2: ``(batch, input_s, embedding)`` key/value source.
            rank2_ninf_mask: optional additive mask, flattened to
                ``(batch, input_s - 1)``; a zero column is prepended so the
                first key position is never masked.
            rank3_ninf_mask: optional additive mask ``(batch, n, input_s)``.

        Returns:
            ``(batch, n, head_num * key_dim)`` tensor of concatenated heads.
        """
        head_num = self.model_params['head_num']
        q = reshape_by_heads(self.Wq(input1), head_num=head_num)
        k = reshape_by_heads(self.Wk(input2), head_num=head_num)
        v = reshape_by_heads(self.Wv(input2), head_num=head_num)

        batch_s, n, key_dim = q.size(0), q.size(2), q.size(3)
        input_s = v.size(2)

        logits = torch.matmul(q, k.transpose(2, 3)) / self.scale
        if rank2_ninf_mask is not None:
            leading_zero = torch.zeros(batch_s, 1, dtype=rank2_ninf_mask.dtype, device=rank2_ninf_mask.device)
            padded_mask = torch.cat((leading_zero, rank2_ninf_mask.reshape(batch_s, -1)), -1)
            logits = logits + padded_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)
        if rank3_ninf_mask is not None:
            logits = logits + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)

        attn = torch.softmax(logits, dim=3)
        # shape: (batch, head_num, n, input_s)
        per_head = torch.matmul(attn, v)
        # shape: (batch, head_num, n, key_dim)
        return per_head.transpose(1, 2).reshape(batch_s, n, head_num * key_dim)


########################################
# DECODER
########################################

class CVRP_Reembedder(nn.Module):
    """Re-embeds every node by attending over all nodes plus one state token per rollout.

    The node-to-node attention logits and values are computed once and cached
    on the instance; ``reset()`` clears that cache. Logits are scaled by
    ``sqrt(embedding_dim)`` (not ``sqrt(qkv_dim)``).
    NOTE(review): ``forward`` reshapes the concatenated heads to the node
    embedding width, so it assumes ``head_num * qkv_dim == embedding_dim``.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        self.station_count = self.model_params['charging_station_count']
        self.k_global = None
        self.v_global = None
        self.q_global = None
        self.Wq = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)
        self.state_encoder = nn.Linear(embedding_dim + 2, embedding_dim, bias=False)
        self.feed_forward = FeedForward(**model_params)
        self.global_query_node_embeddings = None
        self.global_key_value_node_embeddings = None
        self.local_query_node_embedding = None
        self.local_key_value_node_embedding = None
        self.scale = torch.sqrt(torch.tensor(embedding_dim, dtype=torch.float))
        # Create the per-episode cache attributes up front: forward() tests
        # ``self.local_comp is None`` and previously raised AttributeError
        # when reset() had not been called first.
        self.reset()

    def pre_forward(self, encoded_whole_nodes, cs_count, local_node_list):
        """Split the encoded nodes into the global/local query and key-value groups.

        Args:
            encoded_whole_nodes: ``(batch, nodes, embedding)`` tensor.
            cs_count: number of trailing rows along the node axis that form
                the second part of the global query group.
            local_node_list: index tensor for ``torch.gather`` along the node
                axis (must have the same rank as ``encoded_whole_nodes``).
        """
        # Global queries: the first node followed by the last ``cs_count`` nodes.
        self.global_query_node_embeddings = torch.cat(
            (encoded_whole_nodes[:, 0:1, :], encoded_whole_nodes[:, -cs_count:, :]), dim=1)
        self.global_key_value_node_embeddings = encoded_whole_nodes
        self.local_query_node_embedding = encoded_whole_nodes[:, 1:-cs_count, :]
        self.local_key_value_node_embedding = torch.gather(encoded_whole_nodes, 1, local_node_list)

    def reset(self):
        """Drop all cached attention tensors so the next ``forward`` recomputes them."""
        self.k_local = None
        self.v_local = None
        self.q_local = None
        self.local_comp = None
        self.zero_tens = None
        self.current_local_comp = None
        self.current_v_local = None
        self.myinput = None

    def forward(self, encoded_last_node, load, charge, ninf_mask, encoded_whole_nodes, top_k_indices):
        """Return re-embedded nodes of shape ``(batch, n, total_s, dim)``.

        Args:
            encoded_last_node: ``(batch, n, embedding)`` embedding of the current node.
            load: ``(batch, n)`` tensor, appended as one extra feature.
            charge: ``(batch, n)`` tensor, appended as one extra feature.
            ninf_mask: optional additive mask ``(batch, n, total_s)``; the
                state token is never masked.
            encoded_whole_nodes: ``(batch, total_s, dim)`` node embeddings.
            top_k_indices: unused.
        """
        input1 = encoded_whole_nodes
        head_num = self.model_params['head_num']
        total_s = encoded_whole_nodes.size(1)
        dim = encoded_whole_nodes.size(-1)
        batch_s = encoded_whole_nodes.size(0)
        n = encoded_last_node.size(1)

        # State token: current node embedding + load + charge, projected to embedding width.
        input_cat = torch.cat((encoded_last_node, load[:, :, None], charge[:, :, None]), dim=2)
        state_emb = self.state_encoder(input_cat)
        k_state = reshape_by_heads3(self.Wk(state_emb), head_num=head_num)
        v_state = reshape_by_heads3(self.Wv(state_emb), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

        if self.local_comp is None:
            # Node-to-node logits and values do not depend on the state, so
            # they are computed once and reused until reset().
            self.local_query_node_embedding = encoded_whole_nodes
            self.q_local = reshape_by_heads(self.Wq(self.local_query_node_embedding), head_num=head_num)
            self.k_local = reshape_by_heads(self.Wk(self.local_query_node_embedding), head_num=head_num)
            self.v_local = reshape_by_heads(self.Wv(self.local_query_node_embedding), head_num=head_num)
            self.local_comp = torch.matmul(self.q_local, self.k_local.transpose(-2, -1)) / self.scale
            self.current_local_comp = self.local_comp[..., None, :, :].expand(batch_s, head_num, n,
                                                                              self.local_comp.size(-2),
                                                                              self.local_comp.size(-1))
            self.current_v_local = self.v_local[:, :, None, ...].expand(batch_s, head_num, n, self.v_local.size(-2),
                                                                        self.v_local.size(-1))
            self.myinput = input1[:, None, ...].expand(batch_s, n, total_s, dim)

        # Logit of every node query against the state token key.
        local_state_comp = torch.matmul(self.q_local, k_state.transpose(-2, -1)) / self.scale
        local_state_comp = local_state_comp.permute(0, 1, 3, 2)
        # shape: (batch, head_num, n, total_s)

        current_local_comp = torch.cat((self.current_local_comp, local_state_comp.unsqueeze(-1)), -1)
        # shape: (batch, head_num, n, total_s, total_s + 1)
        if ninf_mask is not None:
            if self.zero_tens is None:
                # Built lazily so a call without a mask does not dereference None.
                self.zero_tens = torch.zeros(batch_s, n, 1, dtype=ninf_mask.dtype, device=ninf_mask.device)
            # The extra zero column keeps the state token visible to every query.
            ninf_mask_ = torch.cat((ninf_mask, self.zero_tens), -1)
            current_local_comp = current_local_comp + ninf_mask_[:, None, :, None, :].expand(batch_s, head_num, n,
                                                                                             total_s, total_s + 1)

        current_local_comp = torch.softmax(current_local_comp, -1)

        v_state2 = v_state[..., None, :].expand(batch_s, head_num, n, 1, v_state.size(-1))
        current_v_local = torch.cat((self.current_v_local, v_state2), -2)
        out = torch.matmul(current_local_comp, current_v_local)
        # shape: (batch, head_num, n, total_s, qkv_dim)
        out_transposed = out.permute(0, 2, 3, 1, 4)
        out_concat = out_transposed.reshape(batch_s, n, total_s, dim)
        multi_head_out = self.multi_head_combine(out_concat)
        out1 = self.myinput + multi_head_out
        out2 = self.feed_forward(out1)
        out3 = out1 + out2
        return out3


class Decoder(nn.Module):
    """Wraps ``CVRP_Sparse_Decoder`` and converts its scores into selection probabilities."""

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        self.decoder = CVRP_Sparse_Decoder(**model_params)
        self.logit_clipping = self.model_params['logit_clipping']

    def set_kv(self, encoded_nodes):
        """Delegate key/value caching to the wrapped decoder."""
        self.decoder.set_kv(encoded_nodes)

    def forward(self, encoded_last_node, load, ninf_mask, heatmap=None):
        """Return softmax probabilities over the last axis of the masked, clipped scores.

        Args:
            encoded_last_node: forwarded to the wrapped decoder.
            load: forwarded to the wrapped decoder.
            ninf_mask: additive mask; added to the clipped scores before the softmax.
            heatmap: forwarded to the wrapped decoder and, when
                ``model_params['heatmap']`` is truthy, added in place to its scores.

        When ``model_params['train_partition']`` is truthy the wrapped decoder
        runs under ``torch.no_grad()``.
        """
        params = self.model_params
        if params["train_partition"]:
            with torch.no_grad():
                logits = self.decoder(encoded_last_node, load, ninf_mask, heatmap=heatmap)
        else:
            logits = self.decoder(encoded_last_node, load, ninf_mask, heatmap=heatmap)

        if params['heatmap']:
            logits += heatmap

        # tanh clipping bounds the logits to [-logit_clipping, logit_clipping].
        clipped = self.logit_clipping * torch.tanh(logits)
        return F.softmax(clipped + ninf_mask, dim=2)


class CVRP_Decoder(nn.Module):
    """Decoder whose query is built from the last node embedding plus load and charge.

    Keys/values are cached by ``set_kv``; ``forward`` returns scaled
    compatibility scores against every cached node.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        emb_dim = self.model_params['embedding_dim']
        heads = self.model_params['head_num']
        per_head_dim = self.model_params['qkv_dim']
        inner_dim = heads * per_head_dim
        self.local = self.model_params['joint_training']
        self.cs_training = self.model_params["cs_training"]

        # Query input is the node embedding with two extra scalars (load, charge).
        self.Wq_last = nn.Linear(emb_dim + 2, inner_dim, bias=False)
        self.Wk = nn.Linear(emb_dim, inner_dim, bias=False)
        self.Wv = nn.Linear(emb_dim, inner_dim, bias=False)

        self.multi_head_combine = nn.Linear(inner_dim, emb_dim)
        self.multi_head_combine2 = nn.Linear(inner_dim, emb_dim)

        # Generic cache slots (not written by this class's methods).
        self.k = None
        self.v = None
        self.single_head_key = None
        self.heatmap = None

        # Node caches filled by set_kv().
        self.k_nodes = None
        self.v_nodes = None
        self.single_head_key_nodes = None
        # NOTE(review): attention sizes are hard-coded rather than read from model_params.
        self.my_attention = FlashAttention(dim=128, heads=8, dim_head=16)
        # Charging-station cache slots (not written by this class's methods).
        self.k_cs = None
        self.v_cs = None
        self.single_head_key_cs = None

    def set_kv(self, encoded_nodes):
        """Cache per-head keys/values and the transposed embeddings of ``encoded_nodes``."""
        heads = self.model_params['head_num']
        self.k_nodes = reshape_by_heads3(self.Wk(encoded_nodes), head_num=heads)
        self.v_nodes = reshape_by_heads3(self.Wv(encoded_nodes), head_num=heads)
        # Last two axes swapped so forward() can right-multiply by it.
        self.single_head_key_nodes = encoded_nodes.transpose(-2, -1)

    def forward(self, encoded_last_node, load, charge, ninf_mask):
        """Return scores of each rollout against every cached node.

        Args:
            encoded_last_node: ``(batch, pomo, embedding)``.
            load: ``(batch, pomo)``.
            charge: ``(batch, pomo)``.
            ninf_mask: ``(batch, pomo, problem)``; positions equal to ``-inf``
                are turned into a boolean mask for the attention module
                (presumably "True = excluded" -- confirm against
                ``attention.FlashAttention``).
        """
        heads = self.model_params['head_num']
        state = torch.cat((encoded_last_node, load[:, :, None], charge[:, :, None]), dim=2)
        q = reshape_by_heads(self.Wq_last(state), head_num=heads)

        b, h, n, _ = q.size()
        expanded_mask = ninf_mask[:, None, :, :].expand(b, h, n, ninf_mask.size(-1))
        blocked = torch.isneginf(expanded_mask)

        attended = self.my_attention(q, self.k_nodes, self.v_nodes, mask=blocked)
        glimpse = self.multi_head_combine(attended)
        raw_score = torch.matmul(glimpse, self.single_head_key_nodes)
        return raw_score / self.model_params['sqrt_embedding_dim']


class CVRP_Sparse_Decoder(nn.Module):
    """Decoder whose query is built from the last node embedding plus the load scalar.

    Keys/values are cached by ``set_kv`` (nodes) and ``set_kv_cs`` (charging
    stations); ``forward`` only uses the node cache.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        emb_dim = self.model_params['embedding_dim']
        heads = self.model_params['head_num']
        per_head_dim = self.model_params['qkv_dim']
        inner_dim = heads * per_head_dim

        # Query input is the node embedding with one extra scalar (load).
        self.Wq_last = nn.Linear(emb_dim + 1, inner_dim, bias=False)
        self.Wk = nn.Linear(emb_dim, inner_dim, bias=False)
        self.Wv = nn.Linear(emb_dim, inner_dim, bias=False)

        self.multi_head_combine = nn.Linear(inner_dim, emb_dim)
        self.multi_head_combine2 = nn.Linear(inner_dim, emb_dim)

        # Generic cache slots (not written by this class's methods).
        self.k = None
        self.v = None
        self.single_head_key = None
        self.heatmap = None

        # Node caches filled by set_kv().
        self.k_nodes = None
        self.v_nodes = None
        self.single_head_key_nodes = None
        # NOTE(review): attention sizes are hard-coded rather than read from model_params.
        self.my_attention = FlashAttention(dim=128, heads=8, dim_head=16)
        # Charging-station caches filled by set_kv_cs().
        self.k_cs = None
        self.v_cs = None
        self.single_head_key_cs = None

    def set_kv(self, encoded_nodes):
        """Cache per-head keys/values and the transposed embeddings of ``encoded_nodes``."""
        heads = self.model_params['head_num']
        self.k_nodes = reshape_by_heads3(self.Wk(encoded_nodes), head_num=heads)
        self.v_nodes = reshape_by_heads3(self.Wv(encoded_nodes), head_num=heads)
        # Last two axes swapped so forward() can right-multiply by it.
        self.single_head_key_nodes = encoded_nodes.transpose(-2, -1)

    def set_kv_cs(self, encoded_cs):
        """Cache per-head keys/values for a 4-D charging-station embedding.

        Also sets ``self.scale`` to ``sqrt(embedding_size / head_num)``.
        """
        heads = self.model_params['head_num']
        _, _, _, emb_size = encoded_cs.size()
        self.k_cs = reshape_by_heads2(self.Wk(encoded_cs), head_num=heads)
        self.v_cs = reshape_by_heads2(self.Wv(encoded_cs), head_num=heads)
        self.single_head_key_cs = encoded_cs.transpose(2, 3)
        self.scale = torch.sqrt(torch.tensor(emb_size / heads, dtype=torch.float))

    def set_q1(self, encoded_q1):
        """Cache the per-head projection of ``encoded_q1`` on ``self.q1``.

        NOTE(review): ``Wq_1`` is not created by this class's ``__init__``,
        so this raises AttributeError unless it is attached elsewhere.
        """
        heads = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=heads)

    def set_q2(self, encoded_q2):
        """Cache the per-head projection of ``encoded_q2`` on ``self.q2``.

        NOTE(review): ``Wq_2`` is not created by this class's ``__init__``,
        so this raises AttributeError unless it is attached elsewhere.
        """
        heads = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=heads)

    def forward(self, encoded_last_node, load, ninf_mask, heatmap):
        """Return scores of each rollout against every cached node.

        Args:
            encoded_last_node: ``(batch, pomo, embedding)``.
            load: ``(batch, pomo)``.
            ninf_mask: ``(batch, pomo, problem)``; positions equal to ``-inf``
                are turned into a boolean mask for the attention module
                (presumably "True = excluded" -- confirm against
                ``attention.FlashAttention``).
            heatmap: accepted for interface compatibility; not used here.
        """
        heads = self.model_params['head_num']
        state = torch.cat((encoded_last_node, load[:, :, None]), dim=2)
        q_last = reshape_by_heads(self.Wq_last(state), head_num=heads)

        mask_shape = (ninf_mask.size(0), heads, ninf_mask.size(1), ninf_mask.size(-1))
        blocked = torch.isneginf(ninf_mask[:, None, :, :].expand(*mask_shape))

        attended = self.my_attention(q_last, self.k_nodes, self.v_nodes, mask=blocked)
        glimpse = self.multi_head_combine(attended)
        raw_score = torch.matmul(glimpse, self.single_head_key_nodes)
        return raw_score / self.model_params['sqrt_embedding_dim']


def topk_attention(Q, K, V, rank2_ninf_mask=None, rank3_ninf_mask=None, k=30):
    """Multi-head attention in which each query only attends to its ``k`` best keys.

    Args:
        Q: ``(batch, head_num, n, key_dim)`` queries.
        K: ``(batch, head_num, input_s, key_dim)`` keys.
        V: ``(batch, head_num, input_s, key_dim)`` values.
        rank2_ninf_mask: optional additive mask ``(batch, input_s)``.
        rank3_ninf_mask: optional additive mask ``(batch, n, input_s)``.
        k: number of keys kept per query; clamped to ``input_s`` so short
            key sequences no longer make ``topk`` raise.

    Returns:
        ``(batch, n, head_num * key_dim)`` tensor of concatenated heads.
    """
    batch_s, head_num, n, key_dim = Q.size()
    input_s = K.size(2)

    scores = torch.matmul(Q, K.transpose(-2, -1))
    score_scaled = scores / torch.sqrt(torch.tensor(key_dim, dtype=torch.float))
    # shape: (batch, head_num, n, input_s)

    # The rank-2 mask used to be accepted but silently ignored.
    if rank2_ninf_mask is not None:
        score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)
    if rank3_ninf_mask is not None:
        score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)

    top_count = min(k, input_s)
    topk_scores, topk_indices = score_scaled.topk(k=top_count, dim=-1)

    # Everything outside the top-k gets -inf and therefore zero attention weight.
    sparse_scores = torch.full_like(score_scaled, float('-inf'))
    sparse_scores = sparse_scores.scatter(dim=-1, index=topk_indices, src=topk_scores)
    attn_weights = torch.softmax(sparse_scores, dim=-1)

    output = torch.matmul(attn_weights, V)
    # shape: (batch, head_num, n, key_dim)
    return output.transpose(1, 2).reshape(batch_s, n, head_num * key_dim)


def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None, top_k=30):
    """Scaled dot-product multi-head attention with optional additive masks.

    Two layouts are supported:
      * global: ``q`` is ``(batch, head_num, n, key_dim)`` and ``k``/``v`` are
        ``(batch, head_num, problem, key_dim)``;
      * local (5-D ``k``): ``q`` is ``(batch, head_num, n, 1, key_dim)`` and
        ``k``/``v`` are ``(batch, head_num, n, local, key_dim)``.

    Args:
        rank2_ninf_mask: optional additive mask ``(batch, problem)``.
        rank3_ninf_mask: optional additive mask ``(batch, group, problem)``.
        top_k: unused; kept for signature compatibility.

    Returns:
        ``(batch, n, head_num * key_dim)`` tensor of concatenated heads.

    NOTE(review): a second function with this name is defined later in this
    module and replaces this one at import time.
    """
    batch_s = q.size(0)
    head_num = q.size(1)
    n = q.size(2)
    key_dim = q.size(-1)
    is_local = len(k.shape) == 5

    if is_local:
        score = torch.matmul(q, k.transpose(3, 4)).squeeze(-2)
        # shape: (batch, head_num, n, local)
        input_s = k.size(3)
    else:
        score = torch.matmul(q, k.transpose(2, 3))
        # shape: (batch, head_num, n, problem)
        input_s = k.size(2)

    score_scaled = score / torch.sqrt(torch.tensor(key_dim, dtype=torch.float))

    if rank2_ninf_mask is not None:
        score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)
    if rank3_ninf_mask is not None:
        score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)

    # The softmax/value step must run whether or not a rank-3 mask was given;
    # it used to hang off the mask check as an ``elif`` and left ``out`` unset.
    weights = torch.softmax(score_scaled, dim=3)
    if is_local:
        out = torch.matmul(weights.unsqueeze(3), v).squeeze(3)
    else:
        out = torch.matmul(weights, v)
    # shape: (batch, head_num, n, key_dim)

    out_transposed = out.transpose(1, 2)
    # shape: (batch, n, head_num, key_dim)
    out_concat = out_transposed.reshape(batch_s, n, head_num * key_dim)
    # shape: (batch, n, head_num*key_dim)

    return out_concat


########################################
# NN SUB CLASS / FUNCTIONS
########################################
def reshape_by_heads3(qkv, head_num):
    """Split the last axis into ``head_num`` heads and move the head axis to position 1.

    * 4-D input ``(batch, n, local, head_num * key_dim)`` becomes
      ``(batch, head_num, n, local, key_dim)``.
    * Any other input is treated as ``(batch, n, head_num * key_dim)`` and
      becomes ``(batch, head_num, n, key_dim)``.
    """
    batch_s = qkv.size(0)
    if qkv.dim() != 4:
        per_head = qkv.reshape(batch_s, qkv.size(1), head_num, -1)
        # shape: (batch, n, head_num, key_dim)
        return per_head.transpose(1, 2)

    per_head = qkv.reshape(batch_s, qkv.size(1), qkv.size(2), head_num, -1)
    # shape: (batch, n, local, head_num, key_dim)
    return per_head.permute(0, 3, 1, 2, 4)


def reshape_by_heads(qkv, head_num):
    """Turn ``(batch, n, head_num * key_dim)`` into ``(batch, head_num, n, key_dim)``.

    ``n`` can be either 1 or the problem size.
    """
    per_head = qkv.reshape(qkv.size(0), qkv.size(1), head_num, -1)
    # shape: (batch, n, head_num, key_dim)
    return per_head.transpose(1, 2)


def reshape_by_heads2(qkv, head_num):
    """Turn ``(batch, pomo, n, head_num * key_dim)`` into ``(batch, head_num, pomo, n, key_dim)``."""
    leading = (qkv.size(0), qkv.size(1), qkv.size(2))
    per_head = qkv.reshape(*leading, head_num, -1)
    # shape: (batch, pomo, n, head_num, key_dim)
    return per_head.permute(0, 3, 1, 2, 4)


def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None, gumbel=False, tau=1):
    """Scaled dot-product multi-head attention with optional masks and Gumbel-sigmoid weights.

    Two layouts are supported:
      * global: ``q`` is ``(batch, head_num, n, key_dim)`` and ``k``/``v`` are
        ``(batch, head_num, problem, key_dim)``;
      * local (5-D ``k``): ``q`` is ``(batch, head_num, n, 1, key_dim)`` and
        ``k``/``v`` are ``(batch, head_num, n, local, key_dim)``.

    Args:
        rank2_ninf_mask: optional additive mask ``(batch, problem)``.
        rank3_ninf_mask: optional additive mask ``(batch, group, problem)``.
        gumbel: when true, weights are ``sigmoid((score + g0 - g1) / tau)``
            with ``g0``, ``g1`` drawn from Gumbel(0, 1) instead of a softmax.
        tau: temperature of the Gumbel-sigmoid.

    Returns:
        ``(batch, n, head_num * key_dim)`` tensor of concatenated heads.
    """
    batch_s = q.size(0)
    head_num = q.size(1)
    n = q.size(2)
    key_dim = q.size(-1)
    is_local = len(k.shape) == 5

    if is_local:
        score = torch.matmul(q, k.transpose(3, 4)).squeeze(-2)
        # shape: (batch, head_num, n, local)
        input_s = k.size(3)
    else:
        score = torch.matmul(q, k.transpose(2, 3))
        # shape: (batch, head_num, n, problem)
        input_s = k.size(2)

    score_scaled = score / torch.sqrt(torch.tensor(key_dim, dtype=torch.float))

    if rank2_ninf_mask is not None:
        score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)
    if rank3_ninf_mask is not None:
        score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)

    if gumbel:
        # Sample on the scores' device: Gumbel(0, 1) built from Python numbers
        # always samples on the CPU, which breaks for scores on another device.
        device = score_scaled.device
        noise = Gumbel(torch.zeros((), device=device), torch.ones((), device=device))
        g = noise.sample((2, *score_scaled.shape)).to(score_scaled.dtype)
        weights = torch.sigmoid((score_scaled + g[0] - g[1]) / tau)
    else:
        weights = torch.softmax(score_scaled, dim=3)
    # shape: (batch, head_num, n, problem) or (batch, head_num, n, local)

    if is_local:
        # Each query has its own key/value set, hence the singleton row axis.
        out = torch.matmul(weights.unsqueeze(3), v).squeeze(3)
    else:
        out = torch.matmul(weights, v)
    # shape: (batch, head_num, n, key_dim)

    out_transposed = out.transpose(1, 2)
    # shape: (batch, n, head_num, key_dim)
    out_concat = out_transposed.reshape(batch_s, n, head_num * key_dim)
    # shape: (batch, n, head_num*key_dim)

    return out_concat


# def multi_head_attention(q, k, v, rank2_ninf_mask=None, rank3_ninf_mask=None):
#     # q shape: (batch, head_num, n, key_dim)   : n can be either 1 or PROBLEM_SIZE
#     # k,v shape: (batch, head_num, problem, key_dim)
#     # rank2_ninf_mask.shape: (batch, problem)
#     # rank3_ninf_mask.shape: (batch, group, problem)
#     batch_s = q.size(0)
#     head_num = q.size(1)
#     n = q.size(2)
#     key_dim = q.size(3)
#
#     input_s = k.size(2)
#     score = torch.matmul(q, k.transpose(2, 3))
#     # shape: (batch, head_num, n, problem)
#
#     score_scaled = score / torch.sqrt(torch.tensor(key_dim, dtype=torch.float))
#     if rank2_ninf_mask is not None:
#         score_scaled = score_scaled + rank2_ninf_mask[:, None, None, :].expand(batch_s, head_num, n, input_s)
#     if rank3_ninf_mask is not None:
#         score_scaled = score_scaled + rank3_ninf_mask[:, None, :, :].expand(batch_s, head_num, n, input_s)
#     weights = nn.Softmax(dim=3)(score_scaled)
#     # shape: (batch, head_num, n, problem)
#     out = torch.matmul(weights, v)
#     # shape: (batch, head_num, n, key_dim)
#
#     out_transposed = out.transpose(1, 2)
#     # shape: (batch, n, head_num, key_dim)
#
#     out_concat = out_transposed.reshape(batch_s, n, head_num * key_dim)
#     # shape: (batch, n, head_num*key_dim)
#     return out_concat


class AddAndInstanceNormalization(nn.Module):
    """Residual add followed by instance normalisation over the problem axis."""

    def __init__(self, **model_params):
        super().__init__()
        self.norm = nn.InstanceNorm1d(model_params['embedding_dim'], affine=True, track_running_stats=False)

    def forward(self, input1, input2):
        """Return ``norm(input1 + input2)`` for inputs of shape ``(batch, problem, embedding)``."""
        # InstanceNorm1d expects channels on axis 1, so swap to
        # (batch, embedding, problem), normalise, then swap back.
        channel_first = (input1 + input2).transpose(1, 2)
        return self.norm(channel_first).transpose(1, 2)


class AddAndBatchNormalization(nn.Module):
    """Residual add followed by batch normalisation over the embedding features."""

    def __init__(self, **model_params):
        super().__init__()
        # 'Funny' batch norm: statistics are taken per embedding feature over
        # the flattened (batch * problem) rows.
        self.norm_by_EMB = nn.BatchNorm1d(model_params['embedding_dim'], affine=True)

    def forward(self, input1, input2):
        """Return ``norm(input1 + input2)`` for inputs of shape ``(batch, problem, embedding)``."""
        batch_s, problem_s, embedding_dim = input1.size(0), input1.size(1), input1.size(2)
        flat_rows = (input1 + input2).reshape(batch_s * problem_s, embedding_dim)
        return self.norm_by_EMB(flat_rows).reshape(batch_s, problem_s, embedding_dim)


class FeedForward3(nn.Module):
    """Two-layer perceptron ``embed_dim1 -> embed_dim2 -> embed_dim3`` with a ReLU in between."""

    def __init__(self, embed_dim1, embed_dim2, embed_dim3):
        super().__init__()
        in_dim, hidden_dim, out_dim = embed_dim1, embed_dim2, embed_dim3
        self.W1 = nn.Linear(in_dim, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, input1):
        """Apply ``W2(relu(W1(input1)))`` along the last axis."""
        hidden = F.relu(self.W1(input1))
        return self.W2(hidden)


class FeedForward(nn.Module):
    """Position-wise feed-forward block: ``embedding_dim -> ff_hidden_dim -> embedding_dim``."""

    def __init__(self, **model_params):
        super().__init__()
        emb_dim = model_params['embedding_dim']
        hidden_dim = model_params['ff_hidden_dim']
        self.W1 = nn.Linear(emb_dim, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, emb_dim)

    def forward(self, input1):
        """Apply ``W2(relu(W1(input1)))`` along the last axis."""
        hidden = F.relu(self.W1(input1))
        return self.W2(hidden)


class FeedForward2(nn.Module):
    """Single linear map ``embedding_dim -> embedding_dim`` with no activation."""

    def __init__(self, **model_params):
        super().__init__()
        width = model_params['embedding_dim']
        self.W1 = nn.Linear(width, width)

    def forward(self, input1):
        """Apply ``W1`` along the last axis of ``input1``."""
        projected = self.W1(input1)
        return projected


# general class for MLP
class MLP(nn.Module):
    """Stack of linear layers with a configurable hidden activation and a tanh output.

    ``units_list`` holds the layer widths, so ``len(units_list) - 1`` linear
    layers are created. ``act_fn`` is the name of a function in
    ``torch.nn.functional``.
    """

    def __init__(self, units_list, act_fn):
        super().__init__()
        # Empty, frozen parameter whose only purpose is to expose the module's device.
        self._dummy = nn.Parameter(torch.empty(0), requires_grad=False)
        self.units_list = units_list
        self.depth = len(self.units_list) - 1
        self.act_fn = getattr(F, act_fn)
        layer_dims = zip(self.units_list[:-1], self.units_list[1:])
        self.lins = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in layer_dims])

    @property
    def device(self):
        """Device on which the module's parameters live."""
        return self._dummy.device

    def forward(self, x, k_sparse):
        """Run ``x`` through all layers; ``k_sparse`` is accepted but not used.

        Hidden layers are followed by ``act_fn``. The last layer's output is
        flattened to ``(x.size(0), x.size(1), -1)`` and squashed with tanh.
        With no layers at all, ``x`` is returned untouched.
        """
        if self.depth <= 0:
            return x
        for idx in range(self.depth - 1):
            x = self.act_fn(self.lins[idx](x))
        x = self.lins[self.depth - 1](x)
        return torch.tanh(x.reshape(x.size(0), x.size(1), -1))


# MLP for predicting parameterization theta
class ParNet(MLP):
    """MLP that predicts a sparse ``(B, N, N)`` heatmap with ``k_sparse`` entries per row.

    Layer widths are ``[units] * (depth - 2) + [units // 2, units // 4, preds]``.
    Entries of the heatmap that are not predicted are filled with ``-1``.
    """

    def __init__(self, k_sparse, depth=3, units=32, preds=1, act_fn='silu'):
        self.units = units
        self.preds = preds
        self.k_sparse = k_sparse
        unit_list = [self.units] * (depth - 2) + [int(self.units / 2)] + [int(self.units / 4)] + [self.preds]
        super().__init__(unit_list, act_fn)

    def forward(self, x, distances, topk_indices):
        """Score precomputed edge features and scatter them into a dense heatmap.

        Args:
            x: edge features fed to the MLP; the output must hold
                ``B * N * k_sparse`` values.
            distances: ``(B, N, N)`` tensor; only its shape is used.
            topk_indices: ``(B, N, k_sparse)`` column indices receiving the scores.
        """
        B, _, N = distances.size()
        heatmap = super().forward(x, self.k_sparse).squeeze(dim=-1).reshape(B, N, self.k_sparse)
        # Match the dtype of the scores: scatter_ requires self and src dtypes to agree.
        full_heatmap = torch.full((B, N, N), -1.0, dtype=heatmap.dtype, device=heatmap.device)

        # Fill in the top-k heatmap scores into the correct positions
        full_heatmap.scatter_(dim=2, index=topk_indices, src=heatmap)
        return full_heatmap

    def forward2(self, x, distances):
        """Build edge features from cosine-similar neighbours and predict a heatmap.

        Args:
            x: ``(B, N, D)`` node embeddings.
            distances: ``(B, N, N)`` pairwise distances.

        Each node is paired with its ``k_sparse`` most cosine-similar other
        nodes; the pair feature is ``[center, neighbour, -distance]``
        (width ``2 * D + 1``), which must match the first MLP layer.
        """
        x_norm = F.normalize(x, p=2, dim=-1)  # shape: (B, N, D)

        # Compute cosine similarity: (B, N, N)
        cos_sim = torch.matmul(x_norm, x_norm.transpose(1, 2))
        # Mask self-similarity by setting diagonal to -inf
        B, N, D = x.shape
        mask = torch.eye(N, device=x.device).bool().unsqueeze(0)  # shape: (1, N, N)
        cos_sim.masked_fill_(mask, -float('inf'))
        # Indices of the k_sparse most similar nodes for each node: (B, N, k_sparse)
        topk_sim_vals, topk_indices = torch.topk(cos_sim, k=self.k_sparse, dim=-1)

        # Gather neighbour embeddings: shape (B, N, k_sparse, D).
        # The candidate axis (dim 2) must enumerate the *other* nodes, i.e.
        # source[b, i, j] == x[b, j]. The previous unsqueeze(2) made every
        # candidate equal to the center node, so "neighbors" was a copy of it.
        neighbors = torch.gather(
            x.unsqueeze(1).expand(-1, N, -1, -1),  # shape: (B, N, N, D)
            dim=2,
            index=topk_indices.unsqueeze(-1).expand(-1, -1, -1, D)  # shape: (B, N, k_sparse, D)
        )
        # Negated so that closer neighbours get larger feature values.
        close_node_distances = - torch.gather(distances,
                                              dim=2,
                                              index=topk_indices).unsqueeze(-1)
        # Expand original node embeddings to shape (B, N, k_sparse, D)
        center_nodes = x.unsqueeze(2).expand(-1, -1, self.k_sparse, -1)

        # Pair features: shape (B, N, k_sparse, 2D + 1)
        pair_features = torch.cat([center_nodes, neighbors, close_node_distances], dim=-1)

        heatmap = super().forward(pair_features, self.k_sparse).squeeze(dim=-1)
        # Match the dtype of the scores: scatter_ requires self and src dtypes to agree.
        full_heatmap = torch.full((B, N, N), -1.0, dtype=heatmap.dtype, device=pair_features.device)

        # Fill in the top-k heatmap scores into the correct positions
        full_heatmap.scatter_(dim=2, index=topk_indices, src=heatmap)
        return full_heatmap
