import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Variable
import torch.nn.functional as F
from utils import loss_edges_bpp, loss_mudularity
from gcn_layer import ResidualGatedGCNLayer
from bpp_env import bpp_env, multi_ccbpp_env

class MLP(nn.Module):
    """Feed-forward prediction head.

    Stacks ``L - 1`` hidden ``Linear(hidden_dim, hidden_dim)`` layers, each
    followed by a ReLU, and finishes with a ``Linear(hidden_dim, output_dim)``
    read-out that has no activation.
    """

    def __init__(self, hidden_dim, output_dim, L=2):
        """
        Args:
            hidden_dim: Size of the input features and of every hidden layer.
            output_dim: Size of the final projection.
            L: Total number of linear layers (hidden layers plus read-out).
        """
        super().__init__()
        self.L = L
        # Hidden layers are created before the read-out so parameter
        # registration (and seeded initialisation) order stays U.0, U.1, ..., V.
        self.U = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim, True) for _ in range(self.L - 1)]
        )
        self.V = nn.Linear(hidden_dim, output_dim, True)

    def forward(self, x):
        """
        Args:
            x: Input features whose last axis has size ``hidden_dim``
               (e.g. ``(batch_size, hidden_dim)``); leading axes are kept.
        Returns:
            y: Output predictions with the last axis of size ``output_dim``.
        """
        hidden = x
        for hidden_layer in self.U:
            hidden = F.relu(hidden_layer(hidden))  # ... x H
        return self.V(hidden)  # ... x O

class GCN(nn.Module):
    """Residual gated graph ConvNet that predicts two logits per item pair.

    Items are embedded as nodes and a scalar pairwise feature as edges. After
    ``gcn_layers`` ``ResidualGatedGCNLayer`` blocks, an ``MLP`` maps every edge
    embedding to 2 output channels. The static helpers ``sample`` and
    ``bin_sample`` decode an edge heat map into item orderings.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 gcn_layers=2,
                 mlp_layers=3,
                 aggregation="mean"):
        """
        Build the embeddings, the GCN stack and the edge classifier.

        :param int input_dim: number of per-item input features given to the
            node embedding (size of the concatenation built in ``forward``)
        :param int hidden_dim: width of node and edge embeddings
        :param int gcn_layers: number of stacked ``ResidualGatedGCNLayer``
        :param int mlp_layers: ``L`` argument of the edge-classifier ``MLP``
        :param str aggregation: forwarded unchanged to every
            ``ResidualGatedGCNLayer``
        """
        super(GCN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.aggregation = aggregation
        self.gcn_layers = gcn_layers
        self.mlp_layers = mlp_layers

        # Sub-modules are created in a fixed order so that parameter
        # registration (state_dict layout, seeded init) stays stable.
        self.item_embedding = nn.Linear(input_dim, self.hidden_dim, bias=False)
        self.edges_embedding = nn.Linear(1, self.hidden_dim, bias=False)
        self.mlp_edges = MLP(self.hidden_dim, 2, self.mlp_layers)

        self.layers = nn.ModuleList(
            [ResidualGatedGCNLayer(self.hidden_dim, self.aggregation)
             for _ in range(self.gcn_layers)]
        )

    def forward(self, weights, classes, capacity, labels, edge_feat, edge_cw):
        """
        Compute edge logits and, in training mode, the training loss.

        :param Tensor weights: item weights; assumed ``(batch_size, num_items)``
            because a trailing axis is appended before concatenating on dim 2
            -- TODO confirm against the caller
        :param Tensor classes: extra per-item features concatenated on dim 2;
            assumed ``(batch_size, num_items, n_class_feats)`` -- TODO confirm
        :param capacity: divisor for ``weights``; must broadcast against it
        :param labels: targets passed to ``loss_edges_bpp`` and
            ``loss_mudularity`` (only read in training mode)
        :param Tensor edge_feat: scalar pairwise feature; assumed
            ``(batch_size, num_items, num_items)`` because a trailing axis of
            size 1 is appended before the ``Linear(1, hidden_dim)`` embedding
        :param edge_cw: third argument of ``loss_edges_bpp`` (only read in
            training mode)

        :return: ``(y_pred_edges, loss)`` when ``self.training`` is true,
            otherwise just ``y_pred_edges``; the last axis of ``y_pred_edges``
            has size 2
        """
        # Two views of the weight: relative to capacity and relative to the
        # largest weight found anywhere in the input tensor.
        rate_weights = (weights / capacity).unsqueeze(2)
        norm_weights = (weights / torch.max(weights)).unsqueeze(2)

        x = self.item_embedding(
            torch.cat((rate_weights, norm_weights, classes), dim=2)
        )
        e = self.edges_embedding(edge_feat.unsqueeze(3))

        for gcn_layer in self.layers:
            x, e = gcn_layer(x, e)

        y_pred_edges = self.mlp_edges(e)  # B x V x V x 2

        if self.training:
            loss = loss_edges_bpp(y_pred_edges, labels, edge_cw) \
                   + 0.3 * loss_mudularity(y_pred_edges, labels)
            return y_pred_edges, loss
        return y_pred_edges

    @staticmethod
    def sample(heat_map, K, seed=0):
        """
        Greedily decode ``K`` item orderings per instance from a heat map.

        Every rollout starts at a random item and then repeatedly moves to the
        not-yet-visited item with the highest score in the current item's row.

        :param heat_map: torch.float, (bs, n, n); row ``i`` of an instance
            scores the move ``i -> j``
        :param K: int, number of rollouts (beam size) per instance
        :param seed: seeds the draw of the start items, so equal
            ``(seed, bs, K, n)`` give equal starts regardless of the global
            torch RNG state; ``None`` falls back to the global torch RNG
        :return: ``(next_nodes, logprob_list)`` where ``next_nodes`` is a
            NumPy array of visited item indices, shape (bs, K, n), and
            ``logprob_list`` is a tensor of shape (bs, K, n) holding the
            log-probability of each chosen item (column 0, the random start,
            is 0)
        """
        # Kept for backward compatibility: callers may rely on NumPy's global
        # RNG being reseeded here.  It does NOT influence torch.randint.
        np.random.seed(seed)

        device = heat_map.device
        batch_size, num_items = heat_map.shape[0], heat_map.shape[1]
        num_rollouts = batch_size * K
        heat_map = heat_map.reshape(-1, num_items)

        # A private generator makes `seed` effective without touching the
        # global torch RNG state.
        generator = None
        if seed is not None:
            generator = torch.Generator().manual_seed(int(seed))
        random_start = torch.randint(
            0, num_items, size=(batch_size, K), generator=generator
        ).to(device)

        rollout_ids = torch.arange(num_rollouts, device=device)
        # Row offset of each rollout's instance inside the flattened
        # (bs * n, n) heat map: rollout r belongs to instance r // K.
        instance_offset = (
            torch.arange(batch_size, device=device).repeat_interleave(K)
            * num_items
        )

        next_ind = random_start.reshape(num_rollouts)
        mask = torch.zeros(num_rollouts, num_items, dtype=torch.bool, device=device)
        mask[rollout_ids, next_ind] = True

        visited = [random_start.unsqueeze(2)]
        logprobs = [torch.zeros(size=(batch_size, K, 1)).to(device)]

        # num_items - 1 decoding steps: the first item is the random start.
        for _ in range(num_items - 1):
            score = heat_map[instance_offset + next_ind]
            # Out-of-place masking so autograd can still flow through score.
            masked_score = score.masked_fill(mask, float('-inf'))
            probs = torch.softmax(masked_score, dim=1)  # (bs * K, n)

            next_ind = torch.argmax(masked_score, dim=1)
            log_probs = probs[rollout_ids, next_ind].log()

            visited.append(next_ind.reshape(batch_size, K, 1))
            logprobs.append(log_probs.reshape(batch_size, K, 1))
            # Mark the chosen items as visited.
            mask[rollout_ids, next_ind] = True

        next_nodes = torch.cat(visited, dim=2)
        logprob_list = torch.cat(logprobs, dim=2)
        return next_nodes.detach().cpu().numpy(), logprob_list

    @staticmethod
    def bin_sample(test_data, B, Q, C, M, heat_map, config, type="cluster", seed=0):
        """
        Roll out ``config.param_K`` packing orders through a packing
        environment and return the environment's final reward.

        :param test_data: tensor handed (as NumPy) to the environment
        :param B: forwarded to the environment constructor
        :param Q: forwarded to the environment constructor
        :param C: forwarded to the environment constructor
        :param M: ``1`` selects ``bpp_env``; any other value selects
            ``multi_ccbpp_env`` (which also receives ``M``)
        :param heat_map: tensor whose first entry ``heat_map[0]`` is used as
            the (n, n) score matrix; ``n`` is ``heat_map.shape[1]``
        :param config: object providing ``param_K`` (number of parallel
            rollouts) and ``fit`` (forwarded to the environment)
        :param str type: next-item rule -- ``"cluster"`` (argmax over summed
            heat-map rows selected by the environment's ``pack_orders``),
            ``"argmax"`` (argmax over the current item's row) or ``"sample"``
            (categorical sample from the current item's row)
        :param seed: accepted for signature compatibility; currently unused,
            start items and sampling use the global torch RNG
        :return: ``(bin_num, logprob_list)`` -- ``bin_num`` is the second
            value returned by the last ``env.step`` call and ``logprob_list``
            is a (K, n) tensor of chosen-item log-probabilities (column 0 is 0)
        :raises ValueError: if ``type`` is not one of the supported rules
        """
        # NOTE: `type` shadows the builtin but is part of the public signature.
        if type not in ("cluster", "argmax", "sample"):
            raise ValueError(
                "type must be 'cluster', 'argmax' or 'sample', got %r" % (type,)
            )

        K = config.param_K
        device = heat_map.device
        num_items = heat_map.shape[1]
        source_heatmap = heat_map[0]
        valid_number = num_items
        random_start = torch.randint(0, valid_number, size=(K,)).to(device)

        rollout_ids = torch.arange(K, device=device)
        mask = torch.zeros(K, num_items, dtype=torch.bool).to(device)
        logprobs = [torch.zeros(size=(K, 1)).to(device)]

        heat_map_np = heat_map.detach().cpu().numpy()
        if M == 1:
            env = bpp_env(test_data.cpu().numpy(), B, Q, C, heat_map_np, config.fit)
        else:
            env = multi_ccbpp_env(test_data.cpu().numpy(), B, Q, C, M,
                                  heat_map_np, config.fit)

        current = random_start
        env.reset()
        s, r, _ = env.step(current.cpu().numpy())
        mask[rollout_ids, current] = True

        # valid_number - 1 decoding steps: the first item is the random start.
        for _ in range(valid_number - 1):
            _, _, pack_orders = s
            if type == "cluster":
                score = source_heatmap[pack_orders].sum(axis=1)
            else:
                score = source_heatmap[current, :]

            # Additive -inf mask, built out-of-place so autograd is unaffected.
            ninf_mask = torch.zeros(K, num_items, device=device) \
                .masked_fill(mask, float('-inf'))
            masked_score = score + ninf_mask
            probs = torch.softmax(masked_score, dim=1)

            if type == "sample":
                item_sampler = torch.distributions.Categorical(probs)
                current = item_sampler.sample()  # (K,)
                log_probs = item_sampler.log_prob(current)  # (K,)
            else:
                current = torch.argmax(masked_score, dim=1)
                log_probs = probs[rollout_ids, current].log()  # (K,)

            logprobs.append(log_probs.reshape(K, 1))
            s, r, _ = env.step(current.cpu().numpy())
            mask[rollout_ids, current] = True

        bin_num = r
        logprob_list = torch.cat(logprobs, dim=1)
        return bin_num, logprob_list
