import math
import torch
import torch.nn as nn
import torch.nn.functional as F


from src.pe.base import TourLayer
from src.pe.base import AddAndInstanceNormalization

from src.pe.ape import TourLayer_APE
from src.pe.rpe import TourLayer_RPE
from src.pe.rope import TourLayer_ROPE
from src.pe.sin import TourLayer_SIN
from src.pe.hades import TourLayer_HADES
from src.pe.hades_no_cpe import TourLayer_HADES_NO_CPE
from src.pe.hades_no_ipe import TourLayer_HADES_NO_IPE
from src.pe.hades_norm import TourLayer_HADES_NORM_SIN


class Model(nn.Module):
    """Routing policy: an instance encoder plus an autoregressive decoder.

    All keyword arguments are kept in ``self.model_params`` and forwarded to
    the encoder and decoder. The optional ``encoder_type`` key selects the
    encoder: ``"original"`` (the default) builds :class:`CVRP_Encoder`; any
    other supported value builds :class:`CVRP_Encoder_PE` and replaces its
    ``tour_layer`` with the matching ``TourLayer_*`` implementation.

    Raises:
        NotImplementedError: if ``encoder_type`` is not a supported value.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params

        self.problem = model_params['problem']

        # Default to the original encoder. model_params is the same dict object
        # as self.model_params, so the chosen default is recorded there as well.
        encoder_type = model_params.setdefault('encoder_type', "original")

        if encoder_type == "original":
            self.encoder = CVRP_Encoder(**model_params)
        else:
            tour_layer_classes = {
                "ape": TourLayer_APE,
                "rpe": TourLayer_RPE,
                "sin": TourLayer_SIN,
                "rope": TourLayer_ROPE,
                "hades": TourLayer_HADES,
                "hades_no_cpe": TourLayer_HADES_NO_CPE,
                "hades_no_ipe": TourLayer_HADES_NO_IPE,
                "hades_norm_sin": TourLayer_HADES_NORM_SIN,
            }
            tour_layer_cls = tour_layer_classes.get(encoder_type)
            # Validate before building anything so a typo fails fast and says why.
            if tour_layer_cls is None:
                raise NotImplementedError(
                    f"Unsupported encoder_type {encoder_type!r}; expected 'original' or one of "
                    f"{sorted(tour_layer_classes)}")
            self.encoder = CVRP_Encoder_PE(**model_params)
            # Swap in the positional-encoding-aware tour layer.
            self.encoder.tour_layer = tour_layer_cls(**model_params)

        self.encoder_type = encoder_type

        self.decoder = CVRP_Decoder(**model_params)
        # Cached encoder output, set by pre_forward() and read by forward().
        self.encoded_nodes = None

        # Learned embedding used as the "last node" before any node is selected.
        self.start_last_node = nn.Parameter(torch.zeros(model_params['embedding_dim']), requires_grad=True)

    def pre_forward(self, reset_state, z):
        """Encode the instance held in ``reset_state`` and prime the decoder.

        Args:
            reset_state: environment reset state. Reads ``problem_feat.depot_xy``,
                ``problem_feat.node_xy``, ``problem_feat.node_demand``,
                ``neighbours`` and ``tour_index``; additionally
                ``problem_feat.node_tw`` (vrptw), ``problem_feat.node_prizes``
                (pcvrp), and ``pos_index`` / ``cur_dist`` / ``tour_angle`` for
                non-original encoders.
            z: latent forwarded to ``self.decoder.set_kv``.
        """
        depot_xy = reset_state.problem_feat.depot_xy
        # shape: (batch, 1, 2)
        node_xy = reset_state.problem_feat.node_xy
        # shape: (batch, problem, 2)
        node_demand = reset_state.problem_feat.node_demand
        # shape: (batch, problem)
        node_feat = torch.cat((node_xy, node_demand[:, :, None]), dim=2)
        # shape: (batch, problem, 3)
        solution_neighbours = reset_state.neighbours
        # shape: (batch, problem, 2)

        # Problem-specific extra node features are appended along the feature axis.
        if self.problem == "vrptw":
            node_feat = torch.cat((node_feat, reset_state.problem_feat.node_tw), dim=2)
        elif self.problem == "pcvrp":
            node_feat = torch.cat((node_feat, reset_state.problem_feat.node_prizes[:, :, None]), dim=2)

        # The original encoder does not take pos_index, cur_dist or tour_angle.
        if self.encoder_type == "original":
            self.encoded_nodes = self.encoder(depot_xy, node_feat, solution_neighbours, reset_state.tour_index)
        else:
            self.encoded_nodes = self.encoder(depot_xy, node_feat, solution_neighbours, reset_state.tour_index,
                                              reset_state.pos_index, reset_state.cur_dist, reset_state.tour_angle)

        # shape: (batch, problem+1, embedding)
        self.decoder.set_kv(self.encoded_nodes, z)

    def forward(self, state, temperature=1.0):
        """Choose the next node for every (batch, pomo) rollout.

        Args:
            state: step state; reads ``BATCH_IDX``, ``ROLLOUT_IDX``,
                ``current_node`` and ``ninf_mask``.
            temperature: softmax temperature passed to the decoder.

        Returns:
            selected: (batch, pomo) chosen node indices (sampled when training
                or when ``eval_type == 'softmax'``, otherwise argmax).
            prob: (batch, pomo) probability of each sampled node, or ``None``
                in the argmax case.
            probs[:, :, 1:]: the decoder distribution with index 0 dropped.
        """
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.current_node is None:
            # First step: no node visited yet, use the learned start embedding.
            encoded_last_node = self.start_last_node[None, None].expand(batch_size, pomo_size, -1)
        else:
            encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)
            # shape: (batch, pomo, embedding)

        probs = self.decoder(encoded_last_node, ninf_mask=state.ninf_mask, temperature=temperature)
        # shape: (batch, pomo, problem+1)

        if self.training or self.model_params['eval_type'] == 'softmax':
            # Resample until no zero-probability element was drawn
            # (works around torch.multinomial occasionally selecting one).
            while True:
                with torch.no_grad():
                    selected = probs.reshape(batch_size * pomo_size, -1).multinomial(1) \
                        .squeeze(dim=1).reshape(batch_size, pomo_size)
                # shape: (batch, pomo)
                prob = probs[state.BATCH_IDX, state.ROLLOUT_IDX, selected].reshape(batch_size, pomo_size)
                # shape: (batch, pomo)
                if (prob != 0).all():
                    break

        else:
            selected = probs.argmax(dim=2)
            # shape: (batch, pomo)
            prob = None  # not needed for greedy decoding

        return selected, prob, probs[:, :, 1:]


def _get_encoding(encoded_nodes, node_index_to_pick):
    # encoded_nodes.shape: (batch, problem, embedding)
    # node_index_to_pick.shape: (batch, pomo)

    batch_size = node_index_to_pick.size(0)
    pomo_size = node_index_to_pick.size(1)
    embedding_dim = encoded_nodes.size(2)

    gathering_index = node_index_to_pick[:, :, None].expand(batch_size, pomo_size, embedding_dim)
    # shape: (batch, pomo, embedding)

    picked_nodes = encoded_nodes.gather(dim=1, index=gathering_index)
    # shape: (batch, pomo, embedding)

    return picked_nodes


########################################
# ENCODER
########################################

class CVRP_Encoder(nn.Module):
    """Embed depot and customer features, then refine them with a layer stack.

    Stages, applied in order:
      1. ``layers``     -- ``encoder_layer_num`` EncoderLayers,
      2. ``tour_layer`` -- a TourLayer, only when ``model_params['tour_layer']`` is truthy,
      3. ``mp_layers``  -- ``message_passing_layer_num`` MessagePassingLayers,
      4. ``layers_2``   -- ``encoder_layer_num_2`` EncoderLayers.

    Raises:
        NotImplementedError: if ``problem`` is not ``cvrp``, ``vrptw`` or ``pcvrp``.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        self.problem = model_params['problem']
        self.embedding_dim = self.model_params['embedding_dim']
        width = self.embedding_dim

        self.embedding_depot = nn.Linear(2, width)

        # Width of the per-customer feature vector for each supported problem.
        node_feature_count = {"cvrp": 3, "vrptw": 5, "pcvrp": 4}.get(self.problem)
        if node_feature_count is None:
            raise NotImplementedError
        self.embedding_node = nn.Linear(node_feature_count, width)

        # Sub-modules are created in a fixed order (this affects parameter init RNG draws).
        params = self.model_params
        self.layers = nn.ModuleList(
            [EncoderLayer(**model_params) for _ in range(params['encoder_layer_num'])])
        self.tour_layer = TourLayer(**model_params) if params['tour_layer'] else None
        self.mp_layers = nn.ModuleList(
            [MessagePassingLayer(**model_params) for _ in range(params['message_passing_layer_num'])])
        self.layers_2 = nn.ModuleList(
            [EncoderLayer(**model_params) for _ in range(params['encoder_layer_num_2'])])

    def forward(self, depot_xy, node_feat, solution_neighbours, tour_index):
        """Return node embeddings of shape (batch, problem+1, embedding); row 0 is the depot.

        Args:
            depot_xy: (batch, 1, 2) depot coordinates.
            node_feat: (batch, problem, F) customer features, F matching ``embedding_node``.
            solution_neighbours: (batch, problem, 2) neighbour indices, passed to ``mp_layers``.
            tour_index: passed unchanged to ``tour_layer``.
        """
        batch, problem = node_feat.shape[0], node_feat.shape[1]

        # Depot row first, followed by one row per customer.
        out = torch.cat(
            (self.embedding_depot(depot_xy), self.embedding_node(node_feat)), dim=1)

        for attention_layer in self.layers:
            out = attention_layer(out)

        if self.tour_layer is not None:
            out = self.tour_layer(batch, problem, tour_index, out)

        for mp_layer in self.mp_layers:
            out = mp_layer(batch, problem, solution_neighbours, out)

        for attention_layer in self.layers_2:
            out = attention_layer(out)

        return out


class CVRP_Encoder_PE(CVRP_Encoder):
    """CVRP_Encoder variant whose tour layer also receives positional inputs.

    Construction is inherited unchanged. The only difference in ``forward`` is
    that ``tour_layer`` is called with the extra ``pos_index``, ``cur_dist``
    and ``tour_angle`` arguments. Within this file, ``Model`` replaces
    ``tour_layer`` with one of the ``TourLayer_*`` classes after construction.
    """

    def __init__(self, **model_params):
        super().__init__(**model_params)

    def forward(self, depot_xy, node_feat, solution_neighbours, tour_index, pos_index, cur_dist, tour_angle):
        """Return node embeddings of shape (batch, problem+1, embedding); row 0 is the depot.

        Args:
            depot_xy: (batch, 1, 2) depot coordinates.
            node_feat: (batch, problem, F) customer features.
            solution_neighbours: (batch, problem, 2) neighbour indices, passed to ``mp_layers``.
            tour_index, pos_index, cur_dist, tour_angle: passed unchanged to ``tour_layer``.
        """
        batch, problem = node_feat.shape[0], node_feat.shape[1]

        # Depot row first, followed by one row per customer.
        hidden = torch.cat(
            (self.embedding_depot(depot_xy), self.embedding_node(node_feat)), dim=1)

        for attention_layer in self.layers:
            hidden = attention_layer(hidden)

        if self.tour_layer is not None:
            hidden = self.tour_layer(batch, problem, tour_index, hidden, pos_index, cur_dist, tour_angle)

        for mp_layer in self.mp_layers:
            hidden = mp_layer(batch, problem, solution_neighbours, hidden)

        for attention_layer in self.layers_2:
            hidden = attention_layer(hidden)

        return hidden


class MessagePassingLayer(nn.Module):
    """Refine customer embeddings using their two solution neighbours.

    Row 0 of ``out`` (the depot) passes through unchanged. For each customer
    row, the embeddings at ``solution_neighbours[..., 0]`` (left) and
    ``solution_neighbours[..., 1]`` (right) are projected and summed. The sum is
    concatenated with the customer's own embedding and passed through
    ``neighbour_combiner`` + ReLU and then ``feedforward_layer``. The result is
    merged with the input via ``add_and_normalize``.

    ``vrptw`` uses separate left/right projections (directed graph);
    ``cvrp`` and ``pcvrp`` share a single projection (undirected graph).

    Raises:
        NotImplementedError: if ``problem`` is not ``cvrp``, ``pcvrp`` or ``vrptw``.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.embedding_dim = embedding_dim = model_params['embedding_dim']

        problem = model_params['problem']
        if problem in ("cvrp", "pcvrp"):
            self.directed_graph = False
        elif problem == "vrptw":
            self.directed_graph = True
        else:
            # Fail fast (as CVRP_Encoder does) instead of leaving directed_graph unset.
            raise NotImplementedError(f"MessagePassingLayer does not support problem {problem!r}")

        # Custom normalization layer
        self.add_and_normalize = AddAndInstanceNormalization(**model_params)

        # Layers to transform neighbour embeddings
        if self.directed_graph:
            self.left_neighbour_projector = nn.Linear(embedding_dim, embedding_dim, bias=False)
            self.right_neighbour_projector = nn.Linear(embedding_dim, embedding_dim, bias=False)
        else:
            self.neighbour_projector = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # Layer to combine customer and neighbour embeddings
        self.neighbour_combiner = nn.Linear(embedding_dim * 2, embedding_dim)

        # Feedforward layer for further processing
        self.feedforward_layer = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, batch_size: int, num_customers: int, solution_neighbours: torch.Tensor,
                out: torch.Tensor) -> torch.Tensor:
        """Return ``out`` with customer rows (1..num_customers) updated; row 0 is unchanged.

        Args:
            batch_size: size of dim 0 of ``out``.
            num_customers: number of customer rows; ``out`` is expected to have
                ``num_customers + 1`` rows (depot first).
            solution_neighbours: (batch, num_customers, 2) neighbour indices into
                dim 1 of ``out``.
            out: (batch, num_customers + 1, embedding) embeddings.
        """
        # Clamp indices into [0, num_customers] (0 is the depot) so invalid
        # entries such as -1 cannot trigger device-side asserts in gather.
        # Clamped -1 entries therefore read the depot embedding.
        safe_neighbours = solution_neighbours.clamp(min=0, max=num_customers)

        # Get left neighbour embeddings
        left_neighbour_embeddings = torch.gather(out, 1,
                                                 safe_neighbours[:, :, [0]].expand(batch_size, num_customers,
                                                                                   self.embedding_dim))

        # Get right neighbour embeddings
        right_neighbour_embeddings = torch.gather(out, 1,
                                                  safe_neighbours[:, :, [1]].expand(batch_size, num_customers,
                                                                                    self.embedding_dim))

        # Transform neighbour embeddings
        if self.directed_graph:
            left_neighbour_embeddings = self.left_neighbour_projector(left_neighbour_embeddings)
            right_neighbour_embeddings = self.right_neighbour_projector(right_neighbour_embeddings)
        else:
            left_neighbour_embeddings = self.neighbour_projector(left_neighbour_embeddings)
            right_neighbour_embeddings = self.neighbour_projector(right_neighbour_embeddings)

        # Sum transformed neighbour embeddings
        combined_neighbours = left_neighbour_embeddings + right_neighbour_embeddings

        # Concatenate customer embeddings with combined neighbour embeddings
        combined_embeddings = torch.cat((out[:, 1:], combined_neighbours), dim=2)
        combined_embeddings = F.relu(self.neighbour_combiner(combined_embeddings))

        # Apply feedforward layer
        combined_embeddings = self.feedforward_layer(combined_embeddings)

        # Residual add + normalization on the customer rows only
        normalized_embeddings = self.add_and_normalize(out[:, 1:], combined_embeddings)

        # Re-attach the untouched depot row
        out = torch.cat((out[:, [0]], normalized_embeddings), dim=1)

        return out


class EncoderLayer(nn.Module):
    """Transformer encoder block: multi-head self-attention then feed-forward.

    Each sub-layer is followed by an ``AddAndInstanceNormalization`` that
    combines the sub-layer's input with its output.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        projected_dim = head_num * qkv_dim

        # Bias-free Q/K/V projections, then a projection back to embedding_dim.
        self.Wq = nn.Linear(embedding_dim, projected_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, projected_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, projected_dim, bias=False)
        self.multi_head_combine = nn.Linear(projected_dim, embedding_dim)

        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)

    def forward(self, input1):
        """Map (batch, problem, embedding) to a tensor of the same shape."""
        heads = self.model_params['head_num']

        # Each of q, k, v: (batch, head_num, problem, qkv_dim)
        q, k, v = (reshape_by_heads(projection(input1), head_num=heads)
                   for projection in (self.Wq, self.Wk, self.Wv))

        attended = self.multi_head_combine(fast_multi_head_attention(q, k, v))
        # shape: (batch, problem, embedding)

        after_attention = self.add_n_normalization_1(input1, attended)
        return self.add_n_normalization_2(after_attention, self.feed_forward(after_attention))


########################################
# DECODER
########################################

class CVRP_Decoder(nn.Module):
    """One decoding step: GRU context -> multi-head attention -> z-conditioned MLP -> pointer scores.

    Usage: call ``set_kv`` once per instance (caches keys/values and resets
    the GRU state), then ``forward`` once per decoding step. The GRU hidden
    state is carried across ``forward`` calls.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        embedding_dim = self.model_params['embedding_dim']
        poly_embedding_dim = self.model_params['poly_embedding_dim']
        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        z_dim = model_params['z_dim']

        self.Wq_last = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)
        self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)

        self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)

        self.k = None  # saved key, for multi-head attention
        self.v = None  # saved value, for multi-head attention
        self.single_head_key = None  # saved, for single-head attention
        self.z = None
        # GRU state carried between forward() calls; reset by set_kv().
        self.GRU_hidden = None

        self.GRU = torch.nn.GRUCell(embedding_dim, embedding_dim)

        # Two-layer MLP over [attention output, z], added back residually.
        self.poly_layer_1 = nn.Linear(embedding_dim + z_dim, poly_embedding_dim)
        self.poly_layer_2 = nn.Linear(poly_embedding_dim, embedding_dim)

    def set_kv(self, encoded_nodes, z):
        """Cache attention keys/values for ``encoded_nodes`` and store ``z``.

        Args:
            encoded_nodes: (batch, problem+1, embedding) encoder output.
            z: latent concatenated with the attention output on dim 2 in
                ``forward``, so it must be (batch, pomo, z_dim).
        """
        head_num = self.model_params['head_num']

        self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
        self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
        # shape: (batch, head_num, problem+1, qkv_dim)
        self.single_head_key = encoded_nodes.transpose(1, 2)
        # shape: (batch, embedding, problem+1)
        self.GRU_hidden = None  # start a fresh GRU trajectory for the new instance
        self.z = z

    def set_q1(self, encoded_q1):
        """Cache a first-node query; requires a ``Wq_1`` projection, which this decoder does not define."""
        if not hasattr(self, 'Wq_1'):
            raise NotImplementedError("CVRP_Decoder.set_q1 requires a Wq_1 projection, which is not defined")
        head_num = self.model_params['head_num']
        self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def set_q2(self, encoded_q2):
        """Cache a second query; requires a ``Wq_2`` projection, which this decoder does not define."""
        if not hasattr(self, 'Wq_2'):
            raise NotImplementedError("CVRP_Decoder.set_q2 requires a Wq_2 projection, which is not defined")
        head_num = self.model_params['head_num']
        self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)
        # shape: (batch, head_num, n, qkv_dim)

    def forward(self, encoded_last_node, ninf_mask, temperature=1.0):
        """Return next-node probabilities of shape (batch, pomo, problem+1).

        Args:
            encoded_last_node: (batch, pomo, embedding) embedding of each
                rollout's last selected node.
            ninf_mask: additive mask broadcastable to (batch, pomo, problem+1);
                used both in the attention and on the final scores.
            temperature: softmax temperature (default 1.0).
        """
        batch_size = encoded_last_node.shape[0]
        rollout_size = encoded_last_node.shape[1]
        embedding_dim = encoded_last_node.shape[2]

        head_num = self.model_params['head_num']

        #  Multi-Head Attention
        #######################################################
        # Advance the per-rollout GRU context with the last selected node
        # (GRUCell treats a None hidden state as zeros).
        self.GRU_hidden = self.GRU(encoded_last_node.reshape(batch_size * rollout_size, embedding_dim), self.GRU_hidden)
        context = self.GRU_hidden.reshape(batch_size, rollout_size, embedding_dim)
        q = reshape_by_heads(self.Wq_last(context), head_num=head_num)
        # shape: (batch, head_num, pomo, qkv_dim)

        out_concat = fast_multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
        # shape: (batch, pomo, head_num*qkv_dim)

        mh_atten_out = self.multi_head_combine(out_concat)
        # shape: (batch, pomo, embedding)

        poly_out = self.poly_layer_1(torch.cat((mh_atten_out, self.z), dim=2))
        # shape: (batch, pomo, poly_embedding_dim)
        poly_out = F.relu(poly_out)
        poly_out = self.poly_layer_2(poly_out)
        # shape: (batch, pomo, embedding)

        # Out-of-place residual add: mh_atten_out is already part of the autograd
        # graph (it fed torch.cat above), so avoid mutating it in place.
        mh_atten_out = mh_atten_out + poly_out

        #  Single-Head Attention, for probability calculation
        #######################################################
        score = torch.matmul(mh_atten_out, self.single_head_key)
        # shape: (batch, pomo, problem+1)

        sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
        logit_clipping = self.model_params['logit_clipping']

        score_scaled = score / sqrt_embedding_dim
        # shape: (batch, pomo, problem+1)

        # Bound logits to [-logit_clipping, logit_clipping] before masking.
        score_clipped = logit_clipping * torch.tanh(score_scaled)

        score_masked = score_clipped + ninf_mask

        probs = F.softmax(score_masked / temperature, dim=2)
        # shape: (batch, pomo, problem+1)

        return probs


########################################
# NN SUB CLASS / FUNCTIONS
########################################

def reshape_by_heads(qkv, head_num):
    """Split the feature axis into heads and move the head axis to dim 1.

    (batch, n, head_num * key_dim) -> (batch, head_num, n, key_dim).
    ``n`` can be 1, pomo, or the problem size.
    """
    leading = (qkv.size(0), qkv.size(1))

    per_head = qkv.reshape(*leading, head_num, -1)
    # shape: (batch, n, head_num, key_dim)

    return per_head.permute(0, 2, 1, 3)


def fast_multi_head_attention(q, k, v, rank3_ninf_mask=None):
    """Multi-head scaled dot-product attention with the heads merged in the output.

    Args:
        q: (batch, head_num, n, key_dim) queries.
        k: (batch, head_num, input_s, key_dim) keys.
        v: (batch, head_num, input_s, value_dim) values; ``value_dim`` may
            differ from ``key_dim``.
        rank3_ninf_mask: optional (batch, n, input_s) float mask that is added
            to the attention scores; the same mask is used for every head.

    Returns:
        (batch, n, head_num * value_dim) tensor.
    """
    batch_s = q.size(0)
    head_num = q.size(1)
    n = q.size(2)
    input_s = k.size(2)
    value_dim = v.size(3)

    mask = None
    if rank3_ninf_mask is not None:
        # Insert a head axis and share the mask across all heads.
        mask = rank3_ninf_mask[:, None, :, :]
        mask = mask.expand(batch_s, head_num, n, input_s)

    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    # shape: (batch, head_num, n, value_dim)

    # Merge heads using the value width (not key_dim) so differing q/v widths work.
    out_transposed = out.transpose(1, 2)
    out_concat = out_transposed.reshape(batch_s, n, head_num * value_dim)

    return out_concat


class AddAndReZeroNormalization(nn.Module):
    """ReZero residual connection: ``input1 + alpha * input2``.

    ``alpha`` is a learnable scalar initialised to zero, so the module starts
    out as the identity on ``input1``. ``embedding_dim`` is stored but not
    used in the computation.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.embedding_dim = model_params['embedding_dim']
        # Learnable residual gate, initialised to 0.
        self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, input1, input2):
        gated_residual = self.alpha * input2
        return input1 + gated_residual

class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Linear.

    Maps the last axis embedding_dim -> ff_hidden_dim -> embedding_dim.
    """

    def __init__(self, **model_params):
        super().__init__()
        width = model_params['embedding_dim']
        hidden_width = model_params['ff_hidden_dim']

        self.W1 = nn.Linear(width, hidden_width)
        self.W2 = nn.Linear(hidden_width, width)

    def forward(self, input1):
        """Apply the MLP to (..., embedding) input, returning the same shape."""
        hidden = self.W1(input1)
        hidden = F.relu(hidden)
        return self.W2(hidden)
