import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from utils import gat_loss
from bpp_env import bpp_env, multi_ccbpp_env
class MLP(nn.Module):
    """Stack of ReLU-activated linear layers followed by a linear read-out.

    ``L`` counts every linear layer:
    - ``L - 1`` hidden layers (``hidden_dim -> hidden_dim``), stored in ``self.U``;
    - one output layer (``hidden_dim -> output_dim``), stored in ``self.V``.
    """

    def __init__(self, hidden_dim, output_dim, L=2):
        super().__init__()
        self.L = L
        # Hidden layers are registered before the read-out layer, as in the
        # original, so parameter order and state_dict keys are unchanged.
        self.U = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim, bias=True) for _ in range(L - 1)]
        )
        self.V = nn.Linear(hidden_dim, output_dim, bias=True)

    def forward(self, x):
        """
        Run the hidden layers (each followed by ReLU), then the read-out layer.

        Args:
            x: Input features whose last dimension is hidden_dim,
               e.g. (batch_size, hidden_dim)
        Returns:
            Predictions whose last dimension is output_dim,
            e.g. (batch_size, output_dim)
        """
        h = x
        for layer in self.U:
            h = F.relu(layer(h))
        return self.V(h)

class Encoder(nn.Module):
    """
    Encoder class for Pointer-Net: an LSTM run over batch-first embedded inputs.
    """

    def __init__(self, embedding_dim,
                 hidden_dim,
                 n_layers,
                 dropout,
                 bidir):
        """
        Initiate Encoder

        :param int embedding_dim: Number of embedding channels
        :param int hidden_dim: Number of hidden units for the LSTM; when ``bidir``
            is True it is split evenly between the two directions
        :param int n_layers: Number of layers for LSTMs
        :param float dropout: Float between 0-1
        :param bool bidir: Bidirectional
        """

        super(Encoder, self).__init__()
        # Per-direction hidden size, so the concatenated bidirectional output
        # is (roughly) hidden_dim wide.
        self.hidden_dim = hidden_dim // 2 if bidir else hidden_dim
        # nn.LSTM expects (h, c) with leading dim num_layers * num_directions.
        self.n_layers = n_layers * 2 if bidir else n_layers
        self.bidir = bidir
        self.lstm = nn.LSTM(embedding_dim,
                            self.hidden_dim,
                            n_layers,
                            dropout=dropout,
                            bidirectional=bidir)

        # Non-trainable 1-element parameters used as the source of the initial
        # (h, c). Because they are Parameters they follow the module across
        # .cuda()/.to() calls, so init_hidden lands on the right device.
        self.h0 = Parameter(torch.zeros(1), requires_grad=False)
        self.c0 = Parameter(torch.zeros(1), requires_grad=False)

    def forward(self, embedded_inputs,
                hidden):
        """
        Encoder - Forward-pass

        :param Tensor embedded_inputs: Embedded inputs, (batch, seq_len, embedding_dim)
        :param Tensor hidden: Initial hidden units for the LSTMs (h, c)
        :return: LSTM outputs as (batch, seq_len, num_directions * hidden_dim)
            and the final hidden units (h, c)
        """

        # nn.LSTM is constructed without batch_first, so feed it (seq, batch, emb).
        embedded_inputs = embedded_inputs.permute(1, 0, 2)

        outputs, hidden = self.lstm(embedded_inputs, hidden)

        # Back to batch-first for the caller.
        return outputs.permute(1, 0, 2), hidden

    def init_hidden(self, embedded_inputs):
        """
        Initiate hidden units

        :param Tensor embedded_inputs: The embedded input of Pointer-Net; only
            its first dimension (batch size) is used
        :return: Initial (h, c), each of shape (n_layers, batch, hidden_dim),
            filled from ``self.h0`` and ``self.c0`` respectively
        """

        batch_size = embedded_inputs.size(0)

        # Tile each 1-element parameter to (n_layers, batch, hidden_dim).
        # repeat (not expand) yields a contiguous tensor, as the LSTM expects.
        h0 = self.h0.unsqueeze(0).unsqueeze(0).repeat(self.n_layers,
                                                      batch_size,
                                                      self.hidden_dim)
        c0 = self.c0.unsqueeze(0).unsqueeze(0).repeat(self.n_layers,
                                                      batch_size,
                                                      self.hidden_dim)

        return h0, c0


class Attention(nn.Module):
    """
    Additive attention for Pointer-Net: score_j = V . tanh(W_in h + W_ctx c_j).

    ``init_inf`` must be called before ``forward`` (it creates ``self.inf``,
    the -inf fill used for masked positions).
    """

    def __init__(self, input_dim,
                 hidden_dim):
        """
        Initiate Attention

        :param int input_dim: Input's dimension (also the context's feature dimension)
        :param int hidden_dim: Number of hidden units in the attention
        """

        super(Attention, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.input_linear = nn.Linear(input_dim, hidden_dim)
        # A kernel-size-1 Conv1d applies the same linear map at every context position.
        self.context_linear = nn.Conv1d(input_dim, hidden_dim, 1, 1)
        self.V = Parameter(torch.FloatTensor(hidden_dim), requires_grad=True)
        self._inf = Parameter(torch.FloatTensor([float('-inf')]), requires_grad=False)
        self.tanh = nn.Tanh()
        # Scores are (batch, seq_len); normalize over seq_len explicitly instead
        # of relying on the deprecated implicit-dim behaviour (which also picks 1).
        self.softmax = nn.Softmax(dim=1)

        # torch.FloatTensor(n) is uninitialized memory, so V must be initialized.
        nn.init.uniform_(self.V, -1, 1)

    def forward(self, input,
                context,
                mask):
        """
        Attention - Forward-pass

        :param Tensor input: Hidden state h, (batch, input_dim)
        :param Tensor context: Attention context, (batch, seq_len, input_dim)
        :param Tensor mask: Selection mask, (batch, seq_len); nonzero/True
            positions are excluded from the softmax
        :return: tuple of - (attention-weighted context of shape
            (batch, hidden_dim), alphas of shape (batch, seq_len))
        """

        # (batch, hidden_dim, seq_len)
        inp = self.input_linear(input).unsqueeze(2).expand(-1, -1, context.size(1))

        # (batch, hidden_dim, seq_len)
        context = context.permute(0, 2, 1)
        ctx = self.context_linear(context)

        # (batch, 1, hidden_dim)
        V = self.V.unsqueeze(0).expand(context.size(0), -1).unsqueeze(1)

        # (batch, seq_len)
        att = torch.bmm(V, self.tanh(inp + ctx)).squeeze(1)
        # Indexing with uint8 masks is deprecated; a bool mask selects the same positions.
        mask = mask.bool()
        if mask.any():
            att[mask] = self.inf[mask]
        alpha = self.softmax(att)

        # (batch, hidden_dim)
        hidden_state = torch.bmm(ctx, alpha.unsqueeze(2)).squeeze(2)

        return hidden_state, alpha

    def init_inf(self, mask_size):
        """
        Build ``self.inf``, a -inf tensor broadcast to ``mask_size``.

        :param mask_size: Shape of the masks later passed to ``forward``;
            expected to be 2-D, e.g. (batch, seq_len)
        """
        self.inf = self._inf.unsqueeze(1).expand(*mask_size)


class PointerNet(nn.Module):
    """
    Pointer-Net style item-pair heat-map model.

    Items are embedded, encoded with an LSTM, projected per item to a
    2-dimensional vector by an MLP. Pairwise dot products of those vectors
    are then passed through a sigmoid, giving an item-by-item heat map.
    """

    def __init__(self, input_dim,
                 embedding_dim,
                 hidden_dim=64,
                 lstm_layers=2,
                 dropout=0.5,
                 mlp_layers=2,
                 bidir=False):
        """
        Initiate Pointer-Net

        :param int input_dim: Number of per-item features fed to the embedding
        :param int embedding_dim: Number of embedding channels
        :param int hidden_dim: Encoder hidden units (split across directions when bidir)
        :param int lstm_layers: Number of layers for LSTMs
        :param float dropout: Float between 0-1
        :param int mlp_layers: Number of linear layers in the edge MLP
        :param bool bidir: Bidirectional
        """

        super(PointerNet, self).__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.bidir = bidir
        self.embedding = nn.Linear(self.input_dim, embedding_dim, bias=False)
        self.encoder = Encoder(embedding_dim,
                               hidden_dim,
                               lstm_layers,
                               dropout,
                               bidir)
        self.mlp_layers = mlp_layers
        # The edge MLP consumes the encoder's per-item outputs, whose width is
        # the per-direction LSTM hidden size times the number of directions,
        # not embedding_dim. The two only coincide when embedding_dim ==
        # hidden_dim; otherwise the old sizing failed in forward().
        encoder_out_dim = self.encoder.hidden_dim * (2 if bidir else 1)
        self.mlp_edges = MLP(encoder_out_dim, 2, self.mlp_layers)

    def forward(self, weights, classes, capacity, labels, edge_feat, edge_cw):
        """
        PointerNet - Forward-pass

        :param Tensor weights: Item weights, indexed as (batch, num_items)
        :param Tensor classes: Per-item class features, concatenated on dim 2
            after the two weight features (assumed (batch, num_items, k) -- confirm)
        :param capacity: Bin capacity; weights are divided by it
        :param Tensor labels: Edge targets, only used (via ``gat_loss``) in training mode
        :param edge_feat: Unused
        :param edge_cw: Unused
        :return: ``y_pred_edges`` of shape (batch, num_items, num_items) with
            values in (0, 1); in training mode ``(y_pred_edges, loss)``
        """

        batch_size = weights.size(0)
        input_length = weights.size(1)

        # Two weight features per item: relative to capacity, and relative to
        # the largest weight in the whole batch (torch.max with no dim).
        rate_weights = (weights / capacity).unsqueeze(2)
        norm_weights = (weights / torch.max(weights)).unsqueeze(2)

        embedded_inputs = self.embedding(torch.cat((rate_weights, norm_weights, classes), dim=2)) \
            .view(batch_size, input_length, -1)

        encoder_hidden0 = self.encoder.init_hidden(embedded_inputs)
        encoder_outputs, _ = self.encoder(embedded_inputs, encoder_hidden0)

        # (B, V, H) -> (B, V, 2)
        node_vectors = self.mlp_edges(encoder_outputs)
        # Pairwise dot products: (B, V, 2) @ (B, 2, V) -> (B, V, V)
        edge_scores = node_vectors @ node_vectors.transpose(1, 2)
        y_pred_edges = edge_scores.transpose(1, 2).sigmoid()

        if self.training:
            loss = gat_loss(y_pred_edges, labels)
            return y_pred_edges, loss
        return y_pred_edges

    @staticmethod
    def bin_sample(test_data, B, Q, C, M, heat_map, config, seed=0):
        """
        Greedily decode K item orders guided by a heat map and score them in an env.

        K rollouts each start from a random item. At every later step, each
        rollout picks the not-yet-chosen item with the highest score. The score
        is the sum of ``heat_map[0]`` rows selected by the env's ``pack_orders``.

        :param test_data: Instance data; passed to the environment as numpy
        :param B: Forwarded unchanged to the environment constructor
        :param Q: Forwarded unchanged to the environment constructor
        :param C: Forwarded unchanged to the environment constructor
        :param int M: 1 selects ``bpp_env``; any other value selects ``multi_ccbpp_env``
        :param Tensor heat_map: Heat map; ``heat_map.shape[1]`` is taken as the
            number of items and only ``heat_map[0]`` is used for scoring
        :param config: Must provide ``param_K`` (number of rollouts) and ``fit``
        :param int seed: NOTE(review): currently unused -- start items come from
            torch's global RNG
        :return: ``(bin_num, logprob_list)``.
            ``bin_num`` is the reward ``r`` from the final ``env.step``.
            ``logprob_list`` has shape (K, num_items): a leading zero column for
            the random first pick, then the log-probability of each greedy choice.
        """

        K = config.param_K
        device = heat_map.device
        num_items = heat_map.shape[1]
        source_heatmap = heat_map[0]
        random_start = torch.randint(0, num_items, size=(K,)).to(device)
        logprob_list = torch.zeros(size=(K, 1)).to(device)
        # mask[k, i] is True once rollout k has chosen item i.
        mask = torch.zeros(K, num_items, dtype=torch.bool).to(device)
        if M == 1:
            env = bpp_env(test_data.cpu().numpy(), B, Q, C, heat_map.detach().cpu().numpy(), config.fit)
        else:
            env = multi_ccbpp_env(test_data.cpu().numpy(), B, Q, C, M, heat_map.detach().cpu().numpy(), config.fit)
        current = random_start
        env.reset()
        s, r, done = env.step(current.cpu().numpy())
        mask[torch.arange(len(current)), current] = True

        # Decode the remaining num_items - 1 steps; the first was taken above.
        for step in range(num_items - 1):
            _, _, pack_orders = s

            score = source_heatmap[pack_orders].sum(axis=1)

            # Already-chosen items get -inf so they can never be picked again.
            ninf_mask = torch.where(mask, float('-inf'), 0.)
            masked_score = ninf_mask + score
            probs = torch.softmax(masked_score, dim=1)
            current = torch.argmax(masked_score, dim=1)
            log_probs = probs[torch.arange(K), current].log()  # (K,)

            logprob_list = torch.cat([logprob_list, log_probs.reshape(K, 1)], dim=1)
            s, r, done = env.step(current.cpu().numpy())
            mask[torch.arange(len(current)), current] = True
        bin_num = r

        return bin_num, logprob_list
