import torch
import torch.nn as nn
import torch.nn.functional as F



class TourLayer(nn.Module):
    """Enrich each customer embedding with a summary of the tour it belongs to.

    For every tour, the embeddings of its customers are summed. Each customer's
    embedding is concatenated with the sum for its own tour. The result goes
    through a ReLU-activated linear layer and then a second linear layer. It is
    then combined with the original customer embedding by
    ``AddAndInstanceNormalization``. The depot embedding (position 0 of ``out``)
    is passed through unchanged.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.embedding_dim = embedding_dim = model_params['embedding_dim']

        # Maps [customer embedding | tour embedding] (2 * dim) back to dim.
        self.tour_combiner = nn.Linear(embedding_dim * 2, embedding_dim)

        # Refines the combined embeddings.
        self.feedforward_layer = nn.Linear(embedding_dim, embedding_dim)

        # Residual add followed by instance normalization.
        self.add_and_normalize = AddAndInstanceNormalization(**model_params)

    def forward(self, batch_size: int, num_customers: int, tour_index: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        """Combine customer embeddings with the embeddings of their tours.

        Args:
            batch_size: number of instances in the batch.
            num_customers: number of customers per instance. ``out`` is expected to
                hold ``num_customers + 1`` nodes, with the depot at position 0.
            tour_index: integer tensor of shape (batch_size, num_customers) giving
                each customer's tour id. A value of -1 marks an unvisited customer.
                Any other negative value is not supported. This tensor is NOT modified.
            out: node embeddings of shape (batch_size, num_customers + 1, embedding_dim).

        Returns:
            Tensor with the same shape as ``out``. The depot row is unchanged and the
            customer rows are replaced by their tour-aware, normalized embeddings.
        """
        # Exclude the depot from customer embeddings.
        customer_embeddings = out[:, 1:]
        embedding_dim = customer_embeddings.shape[2]

        # Customers can be unvisited (tour index -1), e.g. in the PCVRP. Those are
        # routed to a dummy tour placed after all real tours.
        unvisited = tour_index == -1
        all_nodes_visited = not bool(unvisited.any())
        num_tours = int(tour_index.max()) + 1
        if not all_nodes_visited:
            # masked_fill returns a new tensor, so the caller's tour_index is left
            # untouched (the previous in-place assignment leaked to the caller).
            tour_index = tour_index.masked_fill(unvisited, num_tours)
            num_tours += 1

        # One index tensor serves both the scatter and the gather below.
        index = tour_index[:, :, None].expand(-1, -1, embedding_dim)

        # Allocate on the same device as the inputs; a CPU tensor here would make
        # scatter_add_ fail for CUDA inputs.
        tour_embeddings = torch.zeros(batch_size, num_tours, embedding_dim,
                                      dtype=customer_embeddings.dtype,
                                      device=customer_embeddings.device)

        # Sum the customer embeddings of each tour.
        tour_embeddings.scatter_add_(1, index, customer_embeddings)

        # Unvisited customers should see an all-zero tour embedding.
        if not all_nodes_visited:
            tour_embeddings[:, -1] = 0

        # Look up each customer's tour embedding.
        customer_tour_embeddings = torch.gather(tour_embeddings, 1, index)

        # Concatenate customer embeddings with their respective tour embeddings.
        # The view also acts as a shape check against batch_size / num_customers.
        combined_embeddings = torch.cat((customer_embeddings, customer_tour_embeddings), dim=2)
        combined_embeddings = combined_embeddings.view(batch_size, num_customers, self.embedding_dim * 2)

        combined_embeddings = F.relu(self.tour_combiner(combined_embeddings))
        combined_embeddings = self.feedforward_layer(combined_embeddings)

        # Residual connection with the original customer embeddings, then normalize.
        normalized_embeddings = self.add_and_normalize(customer_embeddings, combined_embeddings)

        # Re-add the depot to the embeddings.
        return torch.cat((out[:, [0]], normalized_embeddings), dim=1)


class AddAndInstanceNormalization(nn.Module):
    """Residual addition followed by instance normalization over the node axis.

    ``nn.InstanceNorm1d`` normalizes each (instance, channel) pair over the last
    axis. The embedding axis is therefore moved into the channel position before
    normalizing and moved back afterwards.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.norm = nn.InstanceNorm1d(
            model_params['embedding_dim'],
            affine=True,
            track_running_stats=False,
        )

    def forward(self, input1, input2):
        """Return InstanceNorm(input1 + input2).

        Both inputs have shape (batch, problem, embedding); so does the result.
        """
        # (batch, problem, embedding) -> (batch, embedding, problem)
        channels_first = (input1 + input2).permute(0, 2, 1)
        # Normalize, then restore (batch, problem, embedding).
        return self.norm(channels_first).permute(0, 2, 1)