import torch.nn.functional as F
import numpy as np
import torch
from numpy.random import gumbel
from bpp_env import bpp_env, multi_ccbpp_env

def greedy_search(score):
    """Return the indices that sort ``score`` in descending order.

    Sorting is along the last axis, as with ``np.argsort``.

    The original implementation sorted on ``-score``. That is wrong for unsigned
    integers, because ``-0 == 0`` while every other value wraps to a large
    number. It also overflows for the minimum signed integer and raises for
    boolean arrays. For bool/integer dtypes this version sorts on ``~score``
    instead: ``~x`` is strictly decreasing in ``x`` and can never overflow.
    Float and other dtypes still use ``-score``.

    :param score: array-like of scores (lists are accepted too)
    :return: np.ndarray of indices, highest score first
    """
    score = np.asarray(score)
    if score.dtype.kind in 'biu':
        # Bitwise NOT: logical-not for bool, -x - 1 for signed, MAX - x for unsigned.
        key = ~score
    else:
        key = -score
    return np.argsort(key)

def update_mask(mask, new_nodes):
    """Mark ``new_nodes[i]`` as visited in row ``i`` of ``mask``, in place.

    :param mask: 2-D boolean array, one row per beam/rollout
    :param new_nodes: sequence of column indices, one per row (row i uses new_nodes[i])
    :return: the same ``mask`` object, after modification
    """
    row_ids = np.arange(len(new_nodes))
    mask[(row_ids, new_nodes)] = True
    return mask

def beam_search(heat_map, K, seed=0):
    """Beam search for node orders that maximize the product of heat-map scores.

    Each of the K beams starts at a node drawn with ``np.random.choice``, with
    replacement, so several beams may share a start. At every step the beam
    extends as follows:

    - Every pair (beam, unvisited node j) is scored as the beam's running
      score times ``heat_map[last_node, j]``.
    - The K best pairs over all beams are kept. They are ranked by probability
      over the whole sequence, not per beam.

    Note: this reseeds NumPy's *global* RNG with ``seed``.

    :param heat_map: np.float, n*n matrix; row i scores each node as the successor of i
    :param K: int, beam size
    :param seed: int, seed passed to ``np.random.seed`` before drawing start nodes
    :return: np.ndarray of shape (K, n); row k is the node order of beam k.
        When n > 1, rows are in ascending order of final score (the last row is best).
    """
    np.random.seed(seed)
    num_orders = len(heat_map)
    rows = np.arange(K)

    last_nodes = np.random.choice(num_orders, size=K)
    paths = last_nodes.reshape(K, -1)
    visited = np.zeros((K, num_orders), dtype=bool)
    visited[rows, last_nodes] = True
    score = np.ones(K)  # running product of scores for the first t steps of each beam

    # The first node is already placed, so n - 1 extension steps remain.
    for _ in range(num_orders - 1):
        # (K, n) candidate scores. Fancy indexing returns a copy, so the
        # in-place masking below does not touch heat_map.
        candidates = heat_map[last_nodes] * score.reshape(K, 1)
        candidates[visited] = float('-inf')
        flat = candidates.reshape(-1)  # flatten K*n so top-K is taken across all beams

        # Indices of the K best (beam, node) pairs, in ascending score order.
        top = np.argsort(flat)[-K:]
        # Same values as np.sort(flat)[-K:], without a second sort.
        score = flat[top]
        parent = top // num_orders
        last_nodes = top - parent * num_orders

        # Reorder histories (including the already-decoded prefix) to follow
        # their parent beams, then append the newly chosen node.
        paths = np.concatenate((paths[parent], last_nodes.reshape(K, 1)), axis=1)
        visited = visited[parent]
        visited[rows, last_nodes] = True

    return paths


def sample_search(heat_map, K, seed=0):
    """Greedily build K node orders, each from a random start node.

    Start nodes come from ``np.random.choice`` (with replacement) on NumPy's
    global RNG. Every later node is the unvisited node with the largest entry
    in ``heat_map[last_node]``. For a finite heat_map, each row therefore
    visits every node exactly once.

    NOTE(review): ``seed`` is accepted but never used. The start nodes depend
    on whatever state the caller left the global RNG in. Apart from the start
    node, the construction is deterministic.

    :param heat_map: np.float, n*n matrix
    :param K: int, number of orders to build
    :param seed: int, currently unused
    :return: np.ndarray of shape (K, n); row k is the k-th node order
    """
    n = len(heat_map)
    current = np.random.choice(n, size=K)
    visited = update_mask(np.zeros((K, n), dtype=bool), current)
    columns = [current.reshape(K, -1)]

    # The first node is already placed, so n - 1 greedy steps remain.
    for _ in range(n - 1):
        # Fancy indexing copies the rows, so heat_map itself is not modified.
        candidates = heat_map[current]
        candidates[visited] = float('-inf')
        current = np.argmax(candidates, axis=1)
        columns.append(current.reshape(K, 1))
        visited = update_mask(visited, current)

    return np.concatenate(columns, axis=1)

def sample_bin_search(weights, heatmap, fit, K = 64):
    """Run K sequential greedy rollouts in ``bpp_env`` and return the best one.

    Each rollout works as follows:

    - The single environment is reset.
    - The rollout starts from a node drawn with ``np.random.choice`` (with
      replacement, global RNG).
    - It takes ``len(heatmap)`` steps.
    - At each step the next node is the argmax of ``heatmap[current]``, after
      every position the env reports as masked has been set to -inf.

    NOTE(review): ``bpp_env`` is constructed here as
    ``bpp_env(weights, heatmap, fit)``, while ``sample_bin_batch_search`` calls
    it as ``bpp_env(weights, B, Q, C, heatmap, fit)``. One of the two call sites
    may be stale. Verify against ``bpp_env.__init__``.

    :param weights: per-item weights; a leading batch axis of size 1 is added
        (``weights[None, :]``)
    :param heatmap: n*n matrix; ``len(heatmap)`` is taken as the number of items
    :param fit: passed through to ``bpp_env`` (meaning not visible here)
    :param K: int, number of rollouts
    :return: the smallest ``-r`` over all rollouts, where ``r`` is the reward
        from each rollout's last ``env.step``. Presumably this is the packing
        cost (e.g. bins used); confirm against bpp_env.
    """

    num_orders = len(heatmap)

    # NOTE: despite the name, this tracks the minimum of -r (the best cost).
    max_reward = float('inf')
    # mask = torch.zeros(K, num_orders, dtype=torch.bool).to(device)

    random_start = np.random.choice(num_orders, size=K)
    env = bpp_env(weights[None,:], heatmap[None,:,:], fit)

    for i in range(K):
        current = random_start[i]
        env.reset()  # the same env instance is reused for every rollout
        for _ in range(num_orders):
            s, r, done = env.step([current])  # batch of one action
            _, mask, pack_orders = s
            #score = heatmap[tuple(pack_orders)].sum(axis=0)
            score = heatmap[current]
            ninf_mask = np.where(mask == True, float('-inf'), 0.)
            masked_score = score + ninf_mask  # build a new array; avoid in-place ops on heatmap's row here
            # np.argmax without an axis returns an index into the flattened array.
            # The choice made on the final iteration is never used.
            current = np.argmax(masked_score)

        max_reward = min(max_reward, -r)
    return max_reward


def sample_bin_batch_search(weights, B, Q, C, M, heatmap, fit, K=32):
    """Run K greedy packing rollouts in parallel and return the best cost.

    The single problem instance is replicated K times along a new leading axis
    so one environment steps all K rollouts at once. The environment is chosen
    by ``M``: ``bpp_env`` when ``M == 1``, otherwise ``multi_ccbpp_env``.

    Each rollout starts from a node drawn with ``np.random.choice`` (with
    replacement, global RNG). At every step, the next node for each rollout is
    the argmax of the heat-map rows summed over the indices in ``pack_orders``,
    after every position flagged in the env's ``mask`` has been set to -inf.
    (``pack_orders`` is presumably the items in the currently open bin; confirm
    against the env.)

    :param weights: 2-D per-instance weight array (replicated K times)
    :param B: passed through to the env (meaning not visible here)
    :param Q: passed through to the env (meaning not visible here)
    :param C: passed through to the env (meaning not visible here)
    :param M: selects the env; ``M == 1`` uses ``bpp_env``, otherwise
        ``multi_ccbpp_env`` (which also receives ``M``)
    :param heatmap: n*n matrix; ``len(heatmap)`` is taken as the number of items
    :param fit: passed through to the env (meaning not visible here)
    :param K: int, number of parallel rollouts
    :return: ``min(-r)``, where ``r`` is the reward returned by the last
        ``env.step``. Presumably this is the best (lowest) packing cost over the
        K rollouts; confirm against the env.
    """
    num_orders = len(heatmap)
    random_start = np.random.choice(num_orders, size=K)

    # Replicate the single instance K times: (1, ...) -> (K, ...).
    batch_weights = weights[None, :, :].repeat(K, axis=0)
    batch_heatmap = heatmap[None, :, :].repeat(K, axis=0)
    if M == 1:
        env = bpp_env(batch_weights, B, Q, C, batch_heatmap, fit)
    else:
        env = multi_ccbpp_env(batch_weights, B, Q, C, M, batch_heatmap, fit)
    current = random_start

    env.reset()
    for _ in range(num_orders):
        # The former K == 1 / K != 1 branches here made the identical call.
        s, r, _done = env.step(current)
        _, mask, pack_orders = s
        if K != 1:
            score = heatmap[pack_orders].sum(axis=1)
        else:
            score = heatmap[tuple(pack_orders)].sum(axis=0)
        # `mask == True` is kept deliberately: the env's mask dtype is not
        # visible here, and for non-bool masks this differs from plain truthiness.
        ninf_mask = np.where(mask == True, float('-inf'), 0.)
        masked_score = score + ninf_mask  # build a new array; avoid in-place ops here
        # The choice made on the final iteration is never used.
        current = np.argmax(masked_score, axis=1)

    max_reward = min(-r)
    return max_reward