
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.autograd import Variable
import torch.nn.functional as F
from utils import loss_edges_bpp, loss_mudularity
from gcn_layer import ResidualGatedGCNLayer
from bpp_env import bpp_env, multi_ccbpp_env

class MLP(nn.Module):
    """Feed-forward prediction head.

    The network is a stack of ``L - 1`` hidden layers (each a
    ``hidden_dim -> hidden_dim`` linear map followed by ReLU) and a final
    linear projection ``hidden_dim -> output_dim``. Every layer has a bias.
    """

    def __init__(self, hidden_dim, output_dim, L=2):
        """
        Args:
            hidden_dim: Width of the input and of every hidden layer.
            output_dim: Width of the final projection.
            L: Total number of linear layers; ``L - 1`` of them are hidden.
        """
        super().__init__()
        self.L = L
        # Hidden stack: empty when L <= 1, leaving only the output projection.
        self.U = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim, bias=True) for _ in range(self.L - 1)]
        )
        # Output projection.
        self.V = nn.Linear(hidden_dim, output_dim, bias=True)

    def forward(self, x):
        """
        Args:
            x: Input features (batch_size, hidden_dim)
        Returns:
            y: Output predictions (batch_size, output_dim)
        """
        hidden = x
        for hidden_layer in self.U:
            hidden = F.relu(hidden_layer(hidden))  # B x H
        return self.V(hidden)  # B x O

class FC(nn.Module):
    """Fully connected model mapping an edge-feature matrix to an N x N output.

    Each sample is flattened to ``N * N`` values, embedded by a bias-free
    linear layer into ``embedding_dim`` features, passed through an ``MLP``
    with ``mlp_layers`` linear layers, and reshaped back to ``(N, N)``.
    """

    def __init__(self, N, embedding_dim=64, mlp_layers=1):
        """
        Args:
            N: Side length of the square input/output matrix.
            embedding_dim: Width of the embedding and of the MLP hidden layers.
            mlp_layers: Value passed to ``MLP`` as its layer count ``L``.
        """
        super(FC, self).__init__()
        self.N = N
        self.embedding_dim = embedding_dim
        self.embedding = nn.Linear(N * N, embedding_dim, bias=False)
        self.mlp_layers = mlp_layers
        self.mlp_edges = MLP(self.embedding_dim, N * N, self.mlp_layers)

    def forward(self, edge_features):
        """
        Args:
            edge_features: Tensor whose first axis is the batch and whose
                remaining axes hold ``N * N`` values per sample.
        Returns:
            Tensor of shape (batch_size, N, N).
        """
        batch_size = edge_features.size(0)
        flat = edge_features.reshape(batch_size, -1)
        scores = self.mlp_edges(self.embedding(flat))
        return scores.reshape(batch_size, self.N, self.N)

    @staticmethod
    def bin_sample(test_data, B, Q, C, M, heat_map, config, seed=0):
        """Greedily decode ``config.param_K`` item orderings guided by ``heat_map``.

        Each of the K rollouts starts from a random item. At every later step
        the already-chosen items are masked out and the highest-scoring
        remaining item is picked (``argmax``). The log of its softmax
        probability is recorded so gradients can flow back to ``heat_map``.

        Args:
            test_data: Tensor handed to the environment as a NumPy array.
            B, Q, C: Forwarded unchanged to the environment constructor.
            M: Environment selector; ``1`` uses ``bpp_env``, anything else
                uses ``multi_ccbpp_env`` (which also receives ``M``).
            heat_map: Tensor with a leading axis; only ``heat_map[0]`` is used
                for scoring, while the full detached array goes to the
                environment. ``heat_map.shape[1]`` is taken as the item count.
            config: Object providing ``param_K`` (number of rollouts) and
                ``fit`` (forwarded to the environment).
            seed: Currently unused; kept so existing callers keep working.
                The start items come from torch's global RNG.

        Returns:
            bin_num: The second value returned by the last ``env.step`` call
                (presumably the bins used per rollout; confirm against the
                environment).
            logprob_list: Tensor of shape (K, num_items); column 0 is all
                zeros (the random start), and each later column holds the
                log-probabilities of that step's chosen items.
        """
        K = config.param_K
        device = heat_map.device
        num_items = heat_map.shape[1]
        source_heatmap = heat_map[0]

        # Drawn on CPU from the global generator and then moved, so the random
        # stream stays the same regardless of the target device.
        current = torch.randint(0, num_items, size=(K,)).to(device)
        rows = torch.arange(K, device=device)
        mask = torch.zeros(K, num_items, dtype=torch.bool, device=device)

        items_np = test_data.cpu().numpy()
        heat_np = heat_map.detach().cpu().numpy()
        if M == 1:
            env = bpp_env(items_np, B, Q, C, heat_np, config.fit)
        else:
            env = multi_ccbpp_env(items_np, B, Q, C, M, heat_np, config.fit)

        env.reset()
        s, r, _ = env.step(current.cpu().numpy())
        mask[rows, current] = True

        # Collect one (K, 1) column per step and concatenate once at the end,
        # rather than re-allocating the whole matrix on every iteration.
        step_logprobs = [torch.zeros(K, 1, device=device)]

        # Decode the remaining num_items - 1 steps; the first item was placed above.
        for _ in range(num_items - 1):
            _, _, pack_orders = s

            # NOTE(review): assumes pack_orders indexes rows of the heat map so
            # that the sum over axis 1 yields a (K, num_items) score; confirm.
            score = source_heatmap[pack_orders].sum(axis=1)

            # Already-chosen items get -inf so they can never be picked again.
            ninf_mask = torch.where(mask, float('-inf'), 0.)
            masked_score = ninf_mask + score
            probs = torch.softmax(masked_score, dim=1)
            current = torch.argmax(masked_score, dim=1)
            step_logprobs.append(probs[rows, current].log().reshape(K, 1))

            s, r, _ = env.step(current.cpu().numpy())
            mask[rows, current] = True

        bin_num = r
        logprob_list = torch.cat(step_logprobs, dim=1)
        return bin_num, logprob_list
