# Number of bin slots pre-allocated per batch instance: it sizes the second
# axis of the `bins` / `classes` arrays and the per-instance `packed_orders`
# lists created by the environments in this file.
MAX_BIN = 70
class bpp_env:
    """Batched bin-packing environment with a cap on distinct classes per bin.

    Every batch row is an independent instance.  An item is a pair
    ``(weight, class_id)``; a negative ``class_id`` means "no class" and never
    marks a class slot.  A bin can take an item when its load stays within
    ``capacity`` and the number of distinct classes in it stays within ``C``.

    Bin storage is pre-allocated with ``MAX_BIN`` slots per instance; needing
    more bins than that raises ``IndexError`` from NumPy indexing.
    """

    def __init__(self, data, capacity, Q, C, heatmap=None, fit="NF"):
        """Store the problem data and packing parameters.

        Args:
            data: array whose axes 1 and 2 are swapped on construction; after
                the swap it is indexed as ``data[batch, item, field]`` with
                field 0 = weight and field 1 = class id.  Values are cast to
                ``int``.
            capacity: maximum total weight of one bin.
            Q: number of class slots tracked per bin.
            C: maximum number of distinct classes allowed in one bin.
            heatmap: only tested against ``None``; when it is not ``None`` the
                indices of packed items are recorded in ``packed_orders``.
            fit: ``"NF"`` (next fit) or ``"FF"`` (first fit).  Any other value
                makes ``step`` place nothing.
        """
        self.data = np.swapaxes(data, 1, 2).astype(int)
        self.batch_s = self.data.shape[0]

        self.cur_idx = np.zeros(self.batch_s, dtype=np.int32)
        self.group_state = None
        self.num_items = self.data.shape[1]
        self.capacity = capacity
        self.Q = Q
        self.C = C
        self.fit = fit
        self.heatmap = heatmap
        self.packed_orders = [[[] for _ in range(MAX_BIN)] for _ in range(self.batch_s)]

    def reset(self):
        """Clear all bins and masks.

        Returns:
            ``(observation, None, False)`` where ``observation`` is the value
            of :meth:`observe`.
        """
        self.cur_idx = np.zeros(self.batch_s, dtype=np.int32)
        self.bins = np.zeros([self.batch_s, MAX_BIN])
        self.classes = np.zeros([self.batch_s, MAX_BIN, self.Q])
        self.mask_array = np.zeros([self.batch_s, self.num_items], dtype=bool)
        self.packed_orders = [[[] for _ in range(MAX_BIN)] for _ in range(self.batch_s)]
        self.group_state = (self.bins, self.mask_array)
        reward = None
        done = False
        return self.observe(), reward, done

    def step(self, selected_idx):
        """Pack one item per batch row.

        Args:
            selected_idx: integer array with one item index per batch row.

        Returns:
            ``(observation, reward, done)``.  ``reward`` is a zero vector until
            every item of every row is packed; on the final step it is
            ``-(cur_idx + 1)``, i.e. minus the number of bins opened per row.

        Raises:
            AssertionError: if any selected item is already packed.
        """
        rows = np.arange(self.batch_s)
        assert not self.mask_array[rows, selected_idx].any(), 'illegal action'
        item = self.data[rows, selected_idx, :]

        if self.fit == "FF":
            # First fit: the first bin is tested for the whole batch at once,
            # the remaining open bins are scanned row by row.
            target = np.zeros(self.batch_s, dtype=np.int32)
            fits_first = self._check_allocability(item, self.bins[:, 0], self.classes[:, 0, :])
            for b in np.flatnonzero(~fits_first):
                for k in range(1, self.cur_idx[b] + 1):
                    if self._check_allocability(item[b], self.bins[b, k],
                                                self.classes[b, k, :], batch_flag=False):
                        target[b] = k
                        break
                else:
                    # None of the currently open bins can hold the item:
                    # open a new bin.
                    self.cur_idx[b] += 1
                    target[b] = self.cur_idx[b]
            self._place(rows, target, item, selected_idx)

        elif self.fit == "NF":
            # Next fit: only the most recently opened bin is considered.
            fits = self._check_allocability(item, self.bins[rows, self.cur_idx],
                                            self.classes[rows, self.cur_idx, :])
            self.cur_idx[~fits] += 1
            self._place(rows, self.cur_idx, item, selected_idx)

        done = self.mask_array.all()
        if not done:
            reward = np.zeros(self.batch_s)
        else:
            reward = -(self.cur_idx + 1)
        return self.observe(), reward, done

    def _place(self, rows, bin_idx, items, picked):
        """Put ``items[r]`` into bin ``bin_idx[r]`` of batch row ``rows[r]``.

        Updates the bin load, the class flags (only for non-negative class
        ids, so "no class" never wraps around to the last class slot), the
        packed mask and, when ``heatmap`` is set, ``packed_orders``.
        """
        self.bins[rows, bin_idx] += items[:, 0]
        has_class = items[:, 1] >= 0
        self.classes[rows[has_class], bin_idx[has_class], items[has_class, 1]] = 1
        self.mask_array[rows, picked] = True
        if self.heatmap is not None:
            for pos in range(len(rows)):
                self.packed_orders[rows[pos]][bin_idx[pos]].append(picked[pos])

    def _check_allocability(self, item, cur_bin, cur_class, batch_flag=True):
        """Tell whether ``item`` fits into the given bin state.

        Args:
            item: ``(batch, 2)`` array when ``batch_flag`` is true, otherwise a
                single ``(weight, class_id)`` row.
            cur_bin: current load of the candidate bin(s).
            cur_class: class flags of the candidate bin(s); never modified.
            batch_flag: selects the vectorised or the single-row check.

        Returns:
            A boolean array (batch mode) or a single ``bool``.
        """
        class_cp = np.array(cur_class, copy=True)
        if batch_flag:
            weights, classes = item[:, 0], item[:, 1]
            has_class = np.flatnonzero(classes >= 0)
            class_cp[has_class, classes[has_class]] = 1
            fits_weight = cur_bin + weights <= self.capacity
            fits_classes = class_cp.sum(axis=1) <= self.C
            return fits_weight & fits_classes

        weights, classes = item[0], item[1]
        if classes >= 0:
            class_cp[classes] = 1
        return bool(cur_bin + weights <= self.capacity) and bool(class_cp.sum() <= self.C)

    def observe(self):
        """Describe the current bin of every batch row.

        Returns:
            A tuple ``(state, mask_array, latest_orders)``: ``state`` has shape
            ``(batch, 1, 1 + Q)`` and holds the current bin's load followed by
            its class flags; ``latest_orders`` lists, per row, the recorded
            item indices of the current bin, padded to a common length by
            cyclically repeating the row's own entries.
        """
        current_orders = [self.packed_orders[i][self.cur_idx[i]] for i in range(self.batch_s)]
        width = max((len(orders) for orders in current_orders), default=0)
        latest_orders = []
        for orders in current_orders:
            padding = [orders[j % len(orders)] for j in range(width - len(orders))]
            latest_orders.append(list(orders) + padding)

        rows = np.arange(self.batch_s)
        current_bin = self.bins[rows, self.cur_idx]
        current_class = self.classes[rows, self.cur_idx, :]

        state = np.concatenate((current_bin[:, None, None], current_class[:, None, :]), 2)
        return state, self.mask_array, latest_orders


class multi_ccbpp_env:
    """Batched bin-packing environment where each item carries ``M`` classes.

    Every batch row is an independent instance.  An item is a row
    ``(weight, class_1, ..., class_M)``; a negative class id means "no class"
    and never marks a class slot.  A bin can take an item when its load stays
    within ``capacity`` and the number of distinct classes in it stays within
    ``C``.

    Bin storage is pre-allocated with ``MAX_BIN`` slots per instance; needing
    more bins than that raises ``IndexError`` from NumPy indexing.
    """

    def __init__(self, data, capacity, Q, C, M, heatmap=None, fit="NF"):
        """Store the problem data and packing parameters.

        Args:
            data: array whose axes 1 and 2 are swapped on construction; after
                the swap it is indexed as ``data[batch, item, field]`` with
                field 0 = weight and fields ``1 .. M`` = class ids.  Values
                are cast to ``int``.
            capacity: maximum total weight of one bin.
            Q: number of class slots tracked per bin.
            C: maximum number of distinct classes allowed in one bin.
            M: number of class fields read from every item.
            heatmap: only tested against ``None``; when it is not ``None`` the
                indices of packed items are recorded in ``packed_orders``.
            fit: ``"NF"`` (next fit) or ``"FF"`` (first fit).  Any other value
                makes ``step`` place nothing.
        """
        self.data = np.swapaxes(data, 1, 2).astype(int)
        self.batch_s = self.data.shape[0]

        self.cur_idx = np.zeros(self.batch_s, dtype=np.int32)
        self.group_state = None
        self.num_items = self.data.shape[1]
        self.capacity = capacity
        self.fit = fit
        self.heatmap = heatmap
        self.Q = Q
        self.C = C
        self.M = M
        self.packed_orders = [[[] for _ in range(MAX_BIN)] for _ in range(self.batch_s)]

    def reset(self):
        """Clear all bins and masks.

        Returns:
            ``(observation, None, False)`` where ``observation`` is the value
            of :meth:`observe`.
        """
        self.cur_idx = np.zeros(self.batch_s, dtype=np.int32)

        self.bins = np.zeros([self.batch_s, MAX_BIN])
        self.classes = np.zeros([self.batch_s, MAX_BIN, self.Q])
        self.mask_array = np.zeros([self.batch_s, self.num_items], dtype=bool)
        self.packed_orders = [[[] for _ in range(MAX_BIN)] for _ in range(self.batch_s)]
        self.group_state = (self.bins, self.mask_array)
        reward = None
        done = False
        return self.observe(), reward, done

    def step(self, selected_idx):
        """Pack one item per batch row.

        Args:
            selected_idx: integer array with one item index per batch row.

        Returns:
            ``(observation, reward, done)``.  ``reward`` is a zero vector until
            every item of every row is packed; on the final step it is
            ``-(cur_idx + 1)``, i.e. minus the number of bins opened per row.

        Raises:
            AssertionError: if any selected item is already packed.
        """
        rows = np.arange(self.batch_s)
        assert not self.mask_array[rows, selected_idx].any(), 'illegal action'
        item = self.data[rows, selected_idx, :]

        if self.fit == "FF":
            # First fit: the first bin is tested for the whole batch at once,
            # the remaining open bins are scanned row by row.
            target = np.zeros(self.batch_s, dtype=np.int32)
            fits_first = self._check_allocability(item, self.bins[:, 0], self.classes[:, 0, :])
            for b in np.flatnonzero(~fits_first):
                for k in range(1, self.cur_idx[b] + 1):
                    if self._check_allocability(item[b], self.bins[b, k],
                                                self.classes[b, k, :], batch_flag=False):
                        target[b] = k
                        break
                else:
                    # None of the currently open bins can hold the item:
                    # open a new bin.
                    self.cur_idx[b] += 1
                    target[b] = self.cur_idx[b]
            self._place(rows, target, item, selected_idx)

        elif self.fit == "NF":
            # Next fit: only the most recently opened bin is considered.
            fits = self._check_allocability(item, self.bins[rows, self.cur_idx],
                                            self.classes[rows, self.cur_idx, :])
            self.cur_idx[~fits] += 1
            self._place(rows, self.cur_idx, item, selected_idx)

        done = self.mask_array.all()
        if not done:
            reward = np.zeros(self.batch_s)
        else:
            reward = -(self.cur_idx + 1)
        return self.observe(), reward, done

    def _place(self, rows, bin_idx, items, picked):
        """Put ``items[r]`` into bin ``bin_idx[r]`` of batch row ``rows[r]``.

        Updates the bin load, the class flags, the packed mask and, when
        ``heatmap`` is set, ``packed_orders``.  Each of the ``M`` class fields
        is checked on its own: only non-negative ids set a flag, so a "no
        class" entry never wraps around to the last class slot.
        """
        self.bins[rows, bin_idx] += items[:, 0]
        for m in range(self.M):
            class_ids = items[:, 1 + m]
            has_class = class_ids >= 0
            self.classes[rows[has_class], bin_idx[has_class], class_ids[has_class]] = 1
        self.mask_array[rows, picked] = True
        if self.heatmap is not None:
            for pos in range(len(rows)):
                self.packed_orders[rows[pos]][bin_idx[pos]].append(picked[pos])

    def _check_allocability(self, item, cur_bin, cur_class, batch_flag=True):
        """Tell whether ``item`` fits into the given bin state.

        Args:
            item: ``(batch, 1 + M)`` array when ``batch_flag`` is true,
                otherwise a single ``(weight, class_1, ..., class_M)`` row.
            cur_bin: current load of the candidate bin(s).
            cur_class: class flags of the candidate bin(s); never modified.
            batch_flag: selects the vectorised or the single-row check.

        Returns:
            A boolean array (batch mode) or a single ``bool``.
        """
        class_cp = np.array(cur_class, copy=True)
        if batch_flag:
            weights = item[:, 0]
            rows = np.arange(len(item))
            for m in range(self.M):
                class_ids = item[:, 1 + m]
                has_class = class_ids >= 0
                class_cp[rows[has_class], class_ids[has_class]] = 1
            fits_weight = cur_bin + weights <= self.capacity
            fits_classes = class_cp.sum(axis=1) <= self.C
            return fits_weight & fits_classes

        weights = item[0]
        for class_id in item[1:1 + self.M]:
            if class_id >= 0:
                class_cp[class_id] = 1
        return bool(cur_bin + weights <= self.capacity) and bool(class_cp.sum() <= self.C)

    def observe(self):
        """Describe the current bin of every batch row.

        Returns:
            A tuple ``(state, mask_array, latest_orders)``: ``state`` has shape
            ``(batch, 1, 1 + Q)`` and holds the current bin's load followed by
            its class flags; ``latest_orders`` lists, per row, the recorded
            item indices of the current bin, padded to a common length by
            cyclically repeating the row's own entries.
        """
        current_orders = [self.packed_orders[i][self.cur_idx[i]] for i in range(self.batch_s)]
        width = max((len(orders) for orders in current_orders), default=0)
        latest_orders = []
        for orders in current_orders:
            padding = [orders[j % len(orders)] for j in range(width - len(orders))]
            latest_orders.append(list(orders) + padding)

        rows = np.arange(self.batch_s)
        current_bin = self.bins[rows, self.cur_idx]
        current_class = self.classes[rows, self.cur_idx, :]

        state = np.concatenate((current_bin[:, None, None], current_class[:, None, :]), 2)
        return state, self.mask_array, latest_orders
