import torch
import torch.nn as nn
import torch.nn.functional as F

from Encoding.MatNet import MatNetEncoder
from Decoding.MP_decoder import MPDecoder

from torch_geometric.utils import to_dense_batch


class TSPModel(nn.Module):
    """Policy network pairing a ``MatNetEncoder`` with an ``MPDecoder``.

    Usage per batch: call ``pre_forward`` once to encode the instances and
    reset the decoder, then call ``forward`` once per construction step to
    obtain the next node for every (batch, pomo) rollout.
    """

    def __init__(self, encoder_params, decoder_params):
        super().__init__()
        self.encoder_params = encoder_params
        self.decoder_params = decoder_params

        self.encoder = MatNetEncoder(**encoder_params)
        # Row/column embeddings cached by pre_forward() and read by forward().
        # Presumably (batch, problem, embedding) -- TODO confirm against MatNetEncoder.
        self.encoded_row = None
        self.encoded_col = None

        self.decoder = MPDecoder(**decoder_params)

    def pre_forward(self, reset_state, P, pref=None, encode=True):
        """Encode a batch of instances and prime the decoder for a new rollout.

        Args:
            reset_state: object exposing ``problems`` (a 4-d tensor whose dim 1
                is the node count), ``tw_start`` and ``tw_end``.
            P: forwarded unchanged to ``self.decoder.reset``.
            pref: unused; kept for interface compatibility.
            encode: unused (encoding always runs); kept for interface
                compatibility.
        """
        problems = reset_state.problems
        tw_start = reset_state.tw_start
        tw_end = reset_state.tw_end
        # NOTE: We ignore service time as it is always zero anyway.
        # The 4-way unpack also checks that ``problems`` is 4-dimensional.
        _, self.n_nodes, _, _ = problems.shape

        # Time windows go straight to the encoder; no augmented
        # (batch, n, n, D+2) problem tensor is needed here.
        self.encoded_row, self.encoded_col = self.encoder(problems, tw_start, tw_end)

        self.decoder.reset(problems, self.encoded_col, P)

    def forward(self, state):
        """Pick the next node for every (batch, pomo) rollout.

        * Step 0: every rollout goes to the depot (node 0), and that node's row
          embedding is handed to ``self.decoder.set_q1``.
        * Step 1: rollout ``p`` is sent to node ``p + 1`` (POMO start nodes).
        * Later steps: the decoder yields a distribution over nodes. It is
          sampled when training or when ``decoder_params['eval_type']`` is
          ``'softmax'``; otherwise its argmax is taken.

        Args:
            state: step state exposing ``BATCH_IDX`` / ``POMO_IDX`` index
                tensors of shape (batch, pomo), ``selected_count``, and (for
                later steps) ``current_node``, ``ninf_mask``, ``current_time``.

        Returns:
            ``(selected, prob)``. ``selected`` is a (batch, pomo) long tensor
            of node indices. ``prob`` is the (batch, pomo) probability of each
            choice: all ones for the two forced moves, and ``None`` in argmax
            evaluation.
        """
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)
        # Allocate on the state's device so the results can index device
        # tensors without relying on a global default tensor type.
        device = state.BATCH_IDX.device

        if state.selected_count == 0:  # First move: depot
            selected = torch.zeros(size=(batch_size, pomo_size), dtype=torch.long, device=device)
            prob = torch.ones(size=(batch_size, pomo_size), device=device)

            encoded_first_row = _get_encoding(self.encoded_row, selected)
            # shape: (batch, pomo, embedding)
            self.decoder.set_q1(encoded_first_row)

        elif state.selected_count == 1:  # Second move: POMO start nodes
            selected = torch.arange(start=1, end=pomo_size + 1, device=device)[None, :].expand(batch_size, -1)
            prob = torch.ones(size=(batch_size, pomo_size), device=device)

        else:
            probs = self.decoder(state.current_node, state.ninf_mask, state.current_time)
            # shape: (batch, pomo, job)

            if self.training or self.decoder_params['eval_type'] == 'softmax':
                # Sampling needs no gradient; detach once outside the loop.
                flat_probs = probs.detach().reshape(batch_size * pomo_size, -1)
                # Resample until no zero-probability node was drawn: works
                # around torch.multinomial occasionally selecting such entries.
                while True:
                    selected = flat_probs.multinomial(1).squeeze(dim=1).reshape(batch_size, pomo_size)
                    # shape: (batch, pomo)

                    prob = probs[state.BATCH_IDX, state.POMO_IDX, selected].reshape(batch_size, pomo_size)
                    # shape: (batch, pomo)

                    if (prob != 0).all():
                        break
            else:
                selected = probs.argmax(dim=2)
                # shape: (batch, pomo)
                prob = None

        return selected, prob
    
### UTILS ###

def _get_encoding(encoded_nodes, node_index_to_pick):
    """Gather one embedding vector per (batch, pomo) entry.

    Args:
        encoded_nodes: tensor of shape (batch, problem, embedding).
        node_index_to_pick: integer tensor of shape (batch, pomo) holding
            indices into dim 1 of ``encoded_nodes``.

    Returns:
        Tensor of shape (batch, pomo, embedding) whose entry ``[b, p]`` is
        ``encoded_nodes[b, node_index_to_pick[b, p]]``.
    """
    rows = node_index_to_pick.size(0)
    cols = node_index_to_pick.size(1)
    width = encoded_nodes.size(2)

    # Broadcast each index across the embedding axis so gather copies whole
    # embedding vectors.
    index = node_index_to_pick.unsqueeze(2).expand(rows, cols, width)
    return torch.gather(encoded_nodes, dim=1, index=index)
