from abc import ABC, abstractmethod
import numpy as np
import copy
from collections import defaultdict
from itertools import product, combinations
from typing import List, Union, Sequence, Any, Tuple, Dict, Optional
from numpy import ndarray
from abl.utils import (
    hamming_dist,
    to_hashable,
    hashable_to_list,
    flatten,
    confidence_dist,
    calculate_revision_num,
    print_log,
)
from abl.bridge import BaseBridge
from abl.learning import ABLModel
from abl.reasoning import ReasonerBase
from abl.learning import ABLModel
from abl.reasoning import ReasonerBase
from abl.evaluation import BaseMetric
import os
from functools import lru_cache
import heapq
import pickle
from torch.utils.data import Dataset, DataLoader
from itertools import chain
from sklearn.metrics import confusion_matrix


class Heap:
    """Thin object wrapper around the ``heapq`` min-heap functions."""

    def __init__(self):
        # Plain list kept in heap order by the heapq functions below.
        self._heap = []

    def push(self, item):
        """Insert ``item`` into the heap."""
        heapq.heappush(self._heap, item)

    def pop(self):
        """Remove and return the smallest item currently stored."""
        return heapq.heappop(self._heap)

    def is_empty(self):
        """Return True when the heap holds no items."""
        return not self._heap


class UnsupKBBase(ABC):
    def __init__(
        self,
        pseudo_label_list,
        prebuild_GKB=False,
        GKB_len_list=None,
        max_err=0,
        use_cache=False,
        kb_file_path=None,
        num=1,
        max_times=10000,
        ind=False,
        num_digits=1,
    ):
        """
        Initialize the KBBase instance.

        Args:
        pseudo_label_list (list): List of pseudo labels.
        prebuild_GKB (bool): Whether to prebuild the General Knowledge Base (GKB).
        GKB_len_list (list): List of lengths for the GKB.
        max_err (int): Maximum error threshold.
        use_cache (bool): Whether to use caching.
        kb_file_path (str, optional): Path to the file from which to load the pre-built knowledge base. If None, build a new knowledge base.
        """
        self.num = num
        self.max_times = max_times
        self.pseudo_label_list = pseudo_label_list
        self.prebuild_GKB = prebuild_GKB
        self.GKB_len_list = GKB_len_list
        self.max_err = max_err
        self.use_cache = use_cache
        self.base = {}
        self.ind = ind
        self.num_digits = num_digits
        self.kb_file_path = kb_file_path
        self.count = 0

    def prebuild_kb(self):
        if self.kb_file_path and os.path.exists(self.kb_file_path):
            self.load_kb(self.kb_file_path)
        elif self.prebuild_GKB:
            X, Y = self._get_GKB()
            for x, y in zip(X, Y):
                # print('len(x):',x)
                self.base.setdefault(len(x), defaultdict(list))[y].append(x)
            if self.kb_file_path:
                self.save_kb(self.kb_file_path)

    def sort_mask_ind(self, mask_probability):
        """
        Lazily enumerate label assignments, best-first by joint score.

        Each position's classes are ranked by descending score. An assignment
        is represented by a "mask": ``mask[i]`` is the rank of the class chosen
        at position ``i`` (0 = highest score). Starting from the all-zero mask,
        successors are produced by bumping one position to its next rank; a
        min-heap keyed on the negated score decides which pending successor is
        emitted next.

        Args:
        mask_probability: Indexed as ``mask_probability[i][k]`` = score of
            class ``k`` at position ``i`` (presumably probabilities — confirm
            with the caller).

        Yields:
        tuple: ``(label, score, weights)`` where ``label`` is a list with one
            class index per position, ``score`` is the product of the chosen
            scores for the first item and ``parent score * ratio`` for later
            items, and ``weights`` is ``[mask_probability[i][label[i]] ...]``.

        NOTE(review): if no position has more than one class (including an
        empty input), ``origin_root_state`` is empty and ``state[0][0][0]``
        raises IndexError before anything is yielded.
        NOTE(review): the root ratio below divides by the top score with no
        zero guard, unlike the guarded ratio computed inside the main loop.
        """
        # Masks already yielded or already scheduled on the heap -> their score.
        selected_dict = {}
        max_hard_label = [np.argmax(probability) for probability in mask_probability]
        # Joint score of the arg-max assignment (product over positions).
        max_score = 1
        for i in range(len(max_hard_label)):
            max_score *= mask_probability[i][max_hard_label[i]]
        sorted_V_tuple = []  # unused in this method
        sorted_probs = []
        sorted_indices = []
        max_heap_cls = Heap()  # unused in this method
        origin_root_state = []
        # Root mask: every position at rank 0.
        mask = [0 for _ in range(len(max_hard_label))]
        for i in range(len(max_hard_label)):
            # Rank classes by descending score; ties go to the larger class index.
            sorted_prob_with_indices = sorted(
                enumerate(mask_probability[i]), key=lambda x: (-x[1], -x[0])
            )
            sorted_prob = [x[1] for x in sorted_prob_with_indices]
            sorted_indice = [x[0] for x in sorted_prob_with_indices]
            sorted_probs.append(sorted_prob)
            sorted_indices.append(sorted_indice)
            # A "move" is (position, current rank, score ratio next/current);
            # only positions that still have a lower-ranked class get one.
            if mask[i] + 1 < len(mask_probability[i]):
                origin_root_state.append(
                    (i, 0, sorted_probs[i][1] / sorted_probs[i][0])
                )
        # Largest ratio (smallest score loss) first.
        origin_root_state = sorted(
            origin_root_state, key=lambda x: (-x[2], -x[1], -x[0])
        )
        # Per-node move lists, indexed by emission order (node k is at k - 1):
        # origin_state keeps the full list, state is consumed as moves are used.
        origin_state = [copy.deepcopy(origin_root_state)]
        state = [copy.deepcopy(origin_root_state)]
        # Score of the root's best successor.
        suc_prob = (
            sorted_probs[state[0][0][0]][1]
            / sorted_probs[state[0][0][0]][0]
            * max_score
        )
        suc_mask = copy.deepcopy(mask)
        suc_mask[state[0][0][0]] = suc_mask[state[0][0][0]] + 1
        # Heap entries are (-score, 1-based index of the parent node); Heap is
        # a min-heap, so the highest-score pending successor is popped first.
        max_heap = Heap()
        max_heap.push((-suc_prob, 1))
        labels = []
        masks = []
        probs = []
        label = [sorted_indices[_][mask[_]] for _ in range(len(mask))]
        masks.append(mask)
        yield label, max_score, [
            mask_probability[i][label[i]] for i in range(len(label))
        ]
        labels.append(label)
        probs.append(max_score)
        selected_dict[tuple(mask)] = max_score
        selected_dict[tuple(suc_mask)] = suc_prob
        # Number of nodes emitted so far; the next emitted node gets index
        # cur_budget + 1.
        cur_budget = 1
        # The four counters below are never read in this method.
        sum_conflict = 0
        max_conflict = 0
        max_son_conflict = 0
        max_father_conflict = 0
        while not max_heap.is_empty():
            heaptop = max_heap.pop()
            suc_prob = -heaptop[0]
            father_idx = heaptop[1]
            origin_state_f = origin_state[father_idx - 1]  # unused
            # state_f is the very list stored in ``state`` (no copy), so the
            # pop below consumes the parent's pending move in place.
            state_f = state[father_idx - 1]
            mask_f = masks[father_idx - 1]
            prob_f = probs[father_idx - 1]
            suc_mask = copy.deepcopy(mask_f)
            # The parent's front move is the one that was scheduled on the heap.
            i, j, prob = state_f.pop(0)
            suc_mask[i] = j + 1
            # Build the move list of the node being emitted.
            origin_suc_state = []
            for _ in range(len(max_hard_label)):
                if suc_mask[_] + 1 < len(mask_probability[_]):
                    denominator = sorted_probs[_][suc_mask[_]]
                    if denominator > 0:
                        ratio = sorted_probs[_][suc_mask[_] + 1] / denominator
                    else:
                        # Zero current score: the ratio is treated as infinite.
                        ratio = float('inf')
                    origin_suc_state.append(
                        (
                            _,
                            suc_mask[_],
                            ratio,
                        )
                    )
            origin_suc_state = sorted(
                origin_suc_state, key=lambda x: (-x[2], -x[1], -x[0])
            )
            origin_state.append(origin_suc_state)
            suc_state = copy.deepcopy(origin_suc_state)
            suc_label = [sorted_indices[_][suc_mask[_]] for _ in range(len(suc_mask))]
            yield suc_label, suc_prob, [
                mask_probability[i][suc_label[i]] for i in range(len(suc_label))
            ]
            labels.append(suc_label)
            masks.append(suc_mask)
            probs.append(suc_prob)
            # Schedule the emitted node's best child that is not already
            # selected. The scheduled move stays at the front of suc_state so
            # it can be popped when the heap entry is processed.
            while len(suc_state) != 0:
                suc_suc_mask = copy.deepcopy(suc_mask)
                _i, _j, _prob = suc_state[0]
                suc_suc_mask[_i] = _j + 1
                if np.isinf(_prob) or np.isinf(suc_prob):
                    suc_suc_prob = float('inf')
                else:
                    suc_suc_prob = suc_prob * _prob
                if tuple(suc_suc_mask) in selected_dict:
                    suc_state.pop(0)
                    continue
                else:
                    max_heap.push((-suc_suc_prob, cur_budget + 1))
                    selected_dict[tuple(suc_suc_mask)] = suc_suc_prob
                    break
            state.append(suc_state)
            # Schedule the parent's next not-yet-selected child as well.
            # NOTE(review): no inf guard here, unlike the child loop above.
            while len(state_f) != 0:
                suc_f_mask = copy.deepcopy(mask_f)
                _i, _j, _prob = state_f[0]
                suc_f_mask[_i] = _j + 1
                suc_f_prob = prob_f * _prob
                if tuple(suc_f_mask) in selected_dict:
                    state_f.pop(0)
                    continue
                else:
                    max_heap.push((-suc_f_prob, father_idx))
                    selected_dict[tuple(suc_f_mask)] = suc_f_prob
                    break
            state[father_idx - 1] = state_f
            cur_budget += 1

    def sort_mask(self, mask_probability):
        """
        Lazily enumerate label assignments, fewest revised positions first.

        Works like a best-first search over "masks" (``mask[i]`` is the rank,
        in descending-score order, of the class chosen at position ``i``), but
        the heap key is ``(revision count, -score, parent index)``: pending
        candidates with a smaller revision count are emitted before
        higher-scoring candidates with a larger one. The revision count is
        incremented whenever a move bumps a position that is still at rank 0.

        Args:
        mask_probability: Indexed as ``mask_probability[i][k]`` = score of
            class ``k`` at position ``i`` (presumably probabilities — confirm
            with the caller).

        Yields:
        tuple: ``(label, score, weights)`` where ``label`` is a list with one
            class index per position, ``score`` is the product of the chosen
            scores for the first item and ``parent score * ratio`` for later
            items, and ``weights`` is ``[mask_probability[i][label[i]] ...]``.

        NOTE(review): if no position has more than one class (including an
        empty input), ``origin_root_state`` is empty and ``state[0][0][0]``
        raises IndexError before anything is yielded.
        NOTE(review): ratios may be ``inf`` (zero denominator) but the score
        products below are not guarded against ``inf``, unlike the child loop
        in ``sort_mask_ind``.
        """
        # Masks already yielded or already scheduled on the heap -> their score.
        selected_dict = {}
        max_hard_label = [np.argmax(probability) for probability in mask_probability]
        # Joint score of the arg-max assignment (product over positions).
        max_score = 1
        for i in range(len(max_hard_label)):
            max_score *= mask_probability[i][max_hard_label[i]]
        sorted_V_tuple = []  # unused in this method
        sorted_probs = []
        sorted_indices = []
        max_heap_cls = Heap()  # unused in this method
        origin_root_state = []
        # Root mask: every position at rank 0.
        mask = [0 for _ in range(len(max_hard_label))]
        for i in range(len(max_hard_label)):
            # Rank classes by descending score; ties go to the larger class index.
            sorted_prob_with_indices = sorted(
                enumerate(mask_probability[i]), key=lambda x: (-x[1], -x[0])
            )
            sorted_prob = [x[1] for x in sorted_prob_with_indices]
            sorted_indice = [x[0] for x in sorted_prob_with_indices]
            sorted_probs.append(sorted_prob)
            sorted_indices.append(sorted_indice)
            # A "move" is (position, current rank, score ratio next/current,
            # whether the position is still at rank 0). At the root every
            # position is at rank 0, hence the constant True.
            if mask[i] + 1 < len(mask_probability[i]):
                origin_root_state.append(
                    (i, 0, sorted_probs[i][1] / sorted_probs[i][0], True)
                )
        # Moves on already-revised positions (flag False) sort first, then by
        # descending ratio.
        origin_root_state = sorted(
            origin_root_state, key=lambda x: (x[3], -x[2], -x[1], -x[0])
        )
        # Per-node move lists, indexed by emission order (node k is at k - 1):
        # origin_state keeps the full list, state is consumed as moves are used.
        origin_state = [copy.deepcopy(origin_root_state)]
        state = [copy.deepcopy(origin_root_state)]
        # Score of the root's best successor.
        suc_prob = (
            sorted_probs[state[0][0][0]][1]
            / sorted_probs[state[0][0][0]][0]
            * max_score
        )
        suc_mask = copy.deepcopy(mask)
        suc_mask[state[0][0][0]] = suc_mask[state[0][0][0]] + 1
        # Heap entries are (revision count, -score, 1-based parent index).
        # The root's first successor revises exactly one position.
        max_heap = Heap()
        max_heap.push((1, -suc_prob, 1))
        labels = []
        masks = []
        probs = []
        label = [sorted_indices[_][mask[_]] for _ in range(len(mask))]
        masks.append(mask)
        yield label, max_score, [
            mask_probability[i][label[i]] for i in range(len(label))
        ]
        labels.append(label)
        probs.append(max_score)
        selected_dict[tuple(mask)] = max_score
        selected_dict[tuple(suc_mask)] = suc_prob
        # Number of nodes emitted so far; the next emitted node gets index
        # cur_budget + 1.
        cur_budget = 1
        # The four counters below are never read in this method.
        sum_conflict = 0
        max_conflict = 0
        max_son_conflict = 0
        max_father_conflict = 0
        while not max_heap.is_empty():
            heaptop = max_heap.pop()
            suc_revision = heaptop[0]
            suc_prob = -heaptop[1]
            father_idx = heaptop[2]
            origin_state_f = origin_state[father_idx - 1]  # unused
            # state_f is the very list stored in ``state`` (no copy), so the
            # pop below consumes the parent's pending move in place.
            state_f = state[father_idx - 1]
            mask_f = masks[father_idx - 1]
            prob_f = probs[father_idx - 1]
            suc_mask = copy.deepcopy(mask_f)
            # The parent's front move is the one that was scheduled on the heap.
            i, j, prob, change = state_f.pop(0)
            suc_mask[i] = j + 1
            # Recover the parent's revision count from the popped entry.
            father_revision = suc_revision - 1 if change else suc_revision
            # Build the move list of the node being emitted.
            origin_suc_state = []
            for _ in range(len(max_hard_label)):
                if suc_mask[_] + 1 < len(mask_probability[_]):
                    denominator = sorted_probs[_][suc_mask[_]]
                    if denominator > 0:
                        ratio = sorted_probs[_][suc_mask[_] + 1] / denominator
                    else:
                        # Zero current score: the ratio is treated as infinite.
                        ratio = float('inf')
                    origin_suc_state.append(
                        (
                            _,
                            suc_mask[_],
                            ratio,
                            suc_mask[_] == 0,
                        )
                    )
            origin_suc_state = sorted(
                origin_suc_state, key=lambda x: (x[3], -x[2], -x[1], -x[0])
            )
            origin_state.append(origin_suc_state)
            suc_state = copy.deepcopy(origin_suc_state)
            suc_label = [sorted_indices[_][suc_mask[_]] for _ in range(len(suc_mask))]
            yield suc_label, suc_prob, [
                mask_probability[i][suc_label[i]] for i in range(len(suc_label))
            ]
            labels.append(suc_label)
            masks.append(suc_mask)
            probs.append(suc_prob)
            # Schedule the emitted node's first move that is not already
            # selected. The scheduled move stays at the front of suc_state so
            # it can be popped when the heap entry is processed.
            while len(suc_state) != 0:
                suc_suc_mask = copy.deepcopy(suc_mask)
                _i, _j, _prob, change = suc_state[0]
                suc_suc_mask[_i] = _j + 1
                suc_suc_prob = suc_prob * _prob
                if tuple(suc_suc_mask) in selected_dict:
                    suc_state.pop(0)
                    continue
                else:
                    if change is True:
                        suc_suc_revision = suc_revision + 1
                    else:
                        suc_suc_revision = suc_revision
                    max_heap.push((suc_suc_revision, -suc_suc_prob, cur_budget + 1))
                    selected_dict[tuple(suc_suc_mask)] = suc_suc_prob
                    break
            state.append(suc_state)
            # Schedule the parent's next not-yet-selected move as well.
            while len(state_f) != 0:
                suc_f_mask = copy.deepcopy(mask_f)
                _i, _j, _prob, change = state_f[0]
                suc_f_mask[_i] = _j + 1
                suc_f_prob = prob_f * _prob
                if tuple(suc_f_mask) in selected_dict:
                    state_f.pop(0)
                    continue
                else:
                    # suc_revision is reused here for the sibling's count.
                    if change is True:
                        suc_revision = father_revision + 1
                    else:
                        suc_revision = father_revision
                    max_heap.push((suc_revision, -suc_f_prob, father_idx))
                    selected_dict[tuple(suc_f_mask)] = suc_f_prob
                    break
            state[father_idx - 1] = state_f
            cur_budget += 1

    def val(self, probability):
        # probability = np.clip(probability, 1e-9, 1)
        if self.ind:
            generator = self.sort_mask_ind(probability)
        else:
            generator = self.sort_mask(probability)
        candidates = []
        candidates_prob = []
        candidates_weight = []
        n = 0
        times = 0
        if self.num == 0 or self.max_times == 0:
            return candidates, candidates_prob
        # print('self.max_times:',self.max_times)
        # print('probability:',probability)
        for label, prob, weight in generator:
            res = self.logic_forward(label)
            # print('label:',label)
            # print('res:',res)
            if res:
                candidates.append(label)
                candidates_prob.append(prob)
                candidates_weight.append(weight)
                n += 1
                if n == self.num:
                    break
            times += 1
            if times == self.max_times:
                break
        # origin_label,origin_prob,origin_weight=next(generator)
        # if n == 0:
        #     candidates.append(origin_label)
        #     candidates_prob.append(origin_prob)
        #     candidates_weight.append(origin_weight)
        # candidates=candidates*(self.num//n)+candidates[:self.num%n]
        # candidates_prob=candidates_prob*(self.num//n)+candidates[:self.num%n]
        # candidates_weight=candidates_weight*(self.num//n)+candidates_weight[:self.num%n]
        return candidates, candidates_prob, candidates_weight

    def val_chess(self, probability, pos):
        # probability = np.clip(probability, 1e-9, 1)
        if self.ind:
            generator = self.sort_mask_ind(probability)
        else:
            generator = self.sort_mask(probability)
        candidates = []
        candidates_prob = []
        candidates_weight = []
        n = 0
        times = 0
        if self.num == 0 or self.max_times == 0:
            return candidates, candidates_prob
        # print('self.max_times:',self.max_times)
        # print('probability:',probability)
        for label, prob, weight in generator:
            res = self.logic_forward(label, pos)
            # print('label:',label)
            # print('res:',res)
            if res:
                candidates.append(label)
                candidates_prob.append(prob)
                candidates_weight.append(weight)
                n += 1
                if n == self.num:
                    break
            times += 1
            if times == self.max_times:
                break
        # origin_label,origin_prob,origin_weight=next(generator)
        # if n == 0:
        #     candidates.append(origin_label)
        #     candidates_prob.append(origin_prob)
        #     candidates_weight.append(origin_weight)
        # candidates=candidates*(self.num//n)+candidates[:self.num%n]
        # candidates_prob=candidates_prob*(self.num//n)+candidates[:self.num%n]
        # candidates_weight=candidates_weight*(self.num//n)+candidates_weight[:self.num%n]
        return candidates, candidates_prob, candidates_weight

    # For parallel version of _get_GKB
    # def _get_XY_list(self, args):
    #     pre_x, post_x_it = args[0], args[1]
    #     XY_list = []
    #     for post_x in post_x_it:
    #         x = (pre_x,) + post_x
    #         y = self.logic_forward(x)
    #         if y not in [None, np.inf]:
    #             XY_list.append((x, y))
    #     return XY_list

    # Parallel _get_GKB
    # def _get_GKB(self):
    #     X=[]
    #     Y=[]
    #     X_True=[]
    #     for length in self.GKB_len_list:
    #         X.extend(list(product(self.pseudo_label_list, repeat=length - 1)))
    #     for l in X:
    #         if self.logic_forward(l):
    #             Y.append(True)
    #             X_True.append(l)
    #         else:
    #             Y.append(False)
    #             # Y.append(True)
    #     return X_True#, Y

    def _get_GKB(self):
        X = []
        Y = []
        X_True = []
        for length in self.GKB_len_list:
            X.extend(list(product(self.pseudo_label_list, repeat=length)))
        for l in X:
            if self.logic_forward(l):
                Y.append(True)
                X_True.append(l)
            else:
                Y.append(False)
                # Y.append(True)
        return X, Y

    def save_kb(self, file_path):
        """
        Save the knowledge base to a file.

        Args:
        file_path (str): The path to the file where the knowledge base will be saved.
        """
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "wb") as f:
            pickle.dump(self.base, f)

    def load_kb(self, file_path):
        """
        Load the knowledge base from a file.

        Args:
        file_path (str): The path to the file from which the knowledge base will be loaded.
        """
        if os.path.exists(file_path):
            with open(file_path, "rb") as f:
                self.base = pickle.load(f)
        else:
            print(f"File {file_path} not found. Starting with an empty knowledge base.")
            self.base = {}

    @abstractmethod
    def logic_forward(self, pseudo_labels):
        """
        Decide whether ``pseudo_labels`` is consistent with the knowledge base.

        Subclasses must implement this. Callers in this class only use the
        truthiness of the result. The ``*_chess`` methods call it with an
        extra positional argument (``logic_forward(labels, pos)``), so
        subclasses used with them presumably need to accept that second
        parameter — confirm against the concrete knowledge bases.
        """
        pass

    def abduce_candidates(self, pred_res, max_revision_num, require_more_revision=0):
        if self.prebuild_GKB:
            return self._abduce_by_GKB(
                pred_res, max_revision_num, require_more_revision
            )
        else:
            if not self.use_cache:
                return self._abduce_by_search(
                    pred_res, max_revision_num, require_more_revision
                )
            else:
                # print('pred_res:',pred_res)
                return self._abduce_by_search_cache(
                    to_hashable(pred_res), max_revision_num, require_more_revision
                )

    def abduce_candidates_chess(
        self, pred_res, pos, max_revision_num, require_more_revision=0
    ):
        if self.prebuild_GKB:
            return self._abduce_by_GKB_chess(
                pred_res, pos, max_revision_num, require_more_revision
            )
        else:
            if not self.use_cache:
                return self._abduce_by_search_chess(
                    pred_res, pos, max_revision_num, require_more_revision
                )
            else:
                # print('pred_res:',pred_res)
                return self._abduce_by_search_cache_chess(
                    to_hashable(pred_res),
                    to_hashable(pos),
                    max_revision_num,
                    require_more_revision,
                )

    def _find_candidate_GKB(self, pred_res):
        return self.base[len(pred_res)][True]

    def _abduce_by_GKB(self, pred_res, max_revision_num, require_more_revision):
        # if self.base == {} or len(pred_res) not in self.GKB_len_list:
        #     return []

        all_candidates = self._find_candidate_GKB(pred_res)
        if len(all_candidates) == 0:
            return []
        # if mask:
        #     cost_list = hamming_dist(pred_res, all_candidates)
        # else:
        cost_list = hamming_dist(pred_res, all_candidates)
        min_revision_num = np.min(cost_list)
        revision_num = min(max_revision_num, min_revision_num + require_more_revision)
        idxs = np.where(cost_list <= revision_num)[0]
        candidates = [all_candidates[idx] for idx in idxs]
        return candidates

    def _abduce_by_GKB_chess(
        self, pred_res, pos, max_revision_num, require_more_revision
    ):
        # if self.base == {} or len(pred_res) not in self.GKB_len_list:
        #     return []

        all_candidates = self._find_candidate_GKB(pred_res)
        if len(all_candidates) == 0:
            return []
        # if mask:
        #     cost_list = hamming_dist(pred_res, all_candidates)
        # else:
        cost_list = hamming_dist(pred_res, all_candidates)
        min_revision_num = np.min(cost_list)
        revision_num = min(max_revision_num, min_revision_num + require_more_revision)
        idxs = np.where(cost_list <= revision_num)[0]
        candidates = [all_candidates[idx] for idx in idxs]
        return candidates

    def revise_by_idx(self, pred_res, revision_idx):
        candidates = []
        abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
        for c in abduce_c:
            candidate = pred_res.copy()
            for i, idx in enumerate(revision_idx):
                candidate[idx] = c[i]
            if self.logic_forward(candidate):
                candidates.append(candidate)
        return candidates

    def revise_by_idx_chess(self, pred_res, pos, revision_idx):
        candidates = []
        abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
        for c in abduce_c:
            candidate = pred_res.copy()
            for i, idx in enumerate(revision_idx):
                candidate[idx] = c[i]
            if self.logic_forward(candidate, pos):
                candidates.append(candidate)
        return candidates

    def _revision(self, revision_num, pred_res):
        new_candidates = []
        # if mask:
        # revision_idx_list = combinations(range(len(pred_res))[mask], revision_num)
        # else:
        revision_idx_list = combinations(range(len(pred_res)), revision_num)
        for revision_idx in revision_idx_list:
            candidates = self.revise_by_idx(pred_res, revision_idx)
            new_candidates.extend(candidates)
        return new_candidates

    def _revision_chess(self, revision_num, pred_res, pos):
        new_candidates = []
        # if mask:
        # revision_idx_list = combinations(range(len(pred_res))[mask], revision_num)
        # else:
        revision_idx_list = combinations(range(len(pred_res)), revision_num)
        for revision_idx in revision_idx_list:
            candidates = self.revise_by_idx_chess(pred_res, pos, revision_idx)
            new_candidates.extend(candidates)
        return new_candidates

    def _abduce_by_search(self, pred_res, max_revision_num, require_more_revision):
        candidates = []
        for revision_num in range(len(pred_res) + 1):
            if revision_num == 0 and self.logic_forward(pred_res):
                candidates.append(pred_res)
            elif revision_num > 0:
                candidates.extend(self._revision(revision_num, pred_res))
            if len(candidates) > 0:
                min_revision_num = revision_num
                break
            if revision_num >= max_revision_num:
                return []

        for revision_num in range(
            min_revision_num + 1, min_revision_num + require_more_revision + 1
        ):
            if revision_num > max_revision_num:
                return candidates
            candidates.extend(self._revision(revision_num, pred_res))
        return candidates

    def _abduce_by_search_chess(
        self, pred_res, pos, max_revision_num, require_more_revision
    ):
        candidates = []
        for revision_num in range(len(pred_res) + 1):
            # print('pred_res:',pred_res)
            # print('pos:',pos)
            if revision_num == 0 and self.logic_forward(pred_res, pos):
                candidates.append(pred_res)
            elif revision_num > 0:
                candidates.extend(self._revision_chess(revision_num, pred_res, pos))
            if len(candidates) > 0:
                min_revision_num = revision_num
                break
            if revision_num >= max_revision_num:
                return []

        for revision_num in range(
            min_revision_num + 1, min_revision_num + require_more_revision + 1
        ):
            if revision_num > max_revision_num:
                return candidates
            candidates.extend(self._revision_chess(revision_num, pred_res, pos))
        return candidates

    @lru_cache(maxsize=None)
    def _abduce_by_search_cache(
        self, pred_res, max_revision_num, require_more_revision
    ):
        """
        Memoised front end for ``_abduce_by_search``.

        ``pred_res`` arrives in hashable form (so it can serve as a cache key)
        and is converted back with ``hashable_to_list`` before searching.

        NOTE(review): ``lru_cache`` on a method keys on ``self`` and keeps the
        instance referenced by the (unbounded) cache; the cached list is also
        shared between calls, so callers should not mutate it.
        """
        labels = hashable_to_list(pred_res)
        return self._abduce_by_search(labels, max_revision_num, require_more_revision)

    @lru_cache(maxsize=None)
    def _abduce_by_search_cache_chess(
        self, pred_res, pos, max_revision_num, require_more_revision
    ):
        """
        Memoised front end for ``_abduce_by_search_chess``.

        ``pred_res`` arrives in hashable form and is converted back with
        ``hashable_to_list``; ``pos`` is forwarded exactly as received (still
        in its hashable form).

        NOTE(review): ``lru_cache`` on a method keys on ``self`` and keeps the
        instance referenced by the (unbounded) cache; the cached list is also
        shared between calls, so callers should not mutate it.
        """
        labels = hashable_to_list(pred_res)
        return self._abduce_by_search_chess(
            labels, pos, max_revision_num, require_more_revision
        )

    def _dict_len(self, dic):
        if not self.GKB_flag:
            return 0
        else:
            return sum(len(c) for c in dic.values())

    def __len__(self):
        if not self.GKB_flag:
            return 0
        else:
            return sum(self._dict_len(v) for v in self.base.values())


class UnsupReasonerBase:
    """
    Reasoner that asks a knowledge base for candidate revisions of a
    prediction and keeps the candidate closest to that prediction.

    ``zoopt_get_solution`` and ``_get_one_candidate_sudoku`` are called by
    methods of this class but are not defined in it; presumably a subclass
    provides them -- TODO confirm.
    """

    def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False):
        """
        Root class for all reasoner in the ABL system.

        Parameters
        ----------
        kb : KBBase
            The knowledge base to be used for reasoning.
        dist_func : str, optional
            The distance function to be used. Can be "hamming" or "confidence". Default is "hamming".
        mapping : dict, optional
            A mapping of indices to labels. If None, a default mapping is generated
            by enumerating ``kb.pseudo_label_list``.
        use_zoopt : bool, optional
            Whether to use the Zoopt library for optimization. Default is False.

        Raises
        ------
        NotImplementedError
            If the specified distance function is neither "hamming" nor "confidence".
        """
        if not (dist_func == "hamming" or dist_func == "confidence"):
            raise NotImplementedError(
                f'dist_func must be "hamming" or "confidence", got {dist_func!r}'
            )

        self.kb = kb
        self.dist_func = dist_func
        self.use_zoopt = use_zoopt
        if mapping is None:
            self.mapping = dict(enumerate(self.kb.pseudo_label_list))
        else:
            self.mapping = mapping
        # Inverse lookup: label -> index.
        self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

    def _get_cost_list(self, pred_pseudo_label, pred_prob, candidates):
        """
        Get the list of costs between pseudo label and each candidate.

        Parameters
        ----------
        pred_pseudo_label : list
            The pseudo label to be used for computing costs of candidates.
        pred_prob : list
            Probabilities of the predictions. Used when distance function is "confidence".
        candidates : list
            List of candidate abduction result.

        Returns
        -------
        numpy.ndarray
            Array of computed costs for each candidate.

        Raises
        ------
        NotImplementedError
            If ``self.dist_func`` was changed to an unsupported value after
            construction.
        """
        if self.dist_func == "hamming":
            return hamming_dist(pred_pseudo_label, candidates)

        elif self.dist_func == "confidence":
            # confidence_dist works on label indices, so map labels back.
            candidates = [[self.remapping[x] for x in c] for c in candidates]
            return confidence_dist(pred_prob, candidates)

        raise NotImplementedError(
            f'dist_func must be "hamming" or "confidence", got {self.dist_func!r}'
        )

    def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates):
        """
        Get one candidate. If multiple candidates exist, return the one with minimum cost.

        Parameters
        ----------
        pred_pseudo_label : list
            The pseudo label to be used for selecting a candidate.
        pred_prob : list or numpy.ndarray
            Probabilities of the predictions. Only inspected when more than
            one candidate is available.
        candidates : list
            List of candidate abduction result.

        Returns
        -------
        list
            The chosen candidate based on minimum cost.
            If no candidates, an empty list is returned.
        """
        if len(candidates) == 0:
            return []
        if len(candidates) == 1:
            return candidates[0]

        # np.asarray accepts plain lists as well as arrays; costs are
        # evaluated in float64.
        cost_array = self._get_cost_list(
            pred_pseudo_label, np.asarray(pred_prob, dtype=np.float64), candidates
        )
        return candidates[np.argmin(cost_array)]

    def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, sol):
        """
        Get the revision score for a single solution.

        Parameters
        ----------
        symbol_num : int
            Number of total symbols. Returned as the score when the solution
            produces no candidate.
        pred_pseudo_label : list
            List of predicted pseudo labels.
        pred_prob : list
            List of probabilities for predicted results.
        sol : object
            Solution to evaluate; must expose ``get_x()`` returning an array
            whose non-zero entries mark the positions to revise.

        Returns
        -------
        float
            The minimum candidate cost for the given solution, or
            ``symbol_num`` when there is no candidate.
        """
        revision_idx = np.where(sol.get_x() != 0)[0]
        candidates = self.revise_by_idx(pred_pseudo_label, revision_idx)
        if len(candidates) > 0:
            return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates))
        else:
            return symbol_num

    def _constrain_revision_num(self, solution, max_revision_num):
        """Return ``max_revision_num`` minus the sum of ``solution.get_x()``."""
        x = solution.get_x()
        return max_revision_num - x.sum()

    def revise_by_idx(self, pred_pseudo_label, revision_idx):
        """
        Revise the pseudo label according to the given indices.

        Parameters
        ----------
        pred_pseudo_label : list
            List of predicted pseudo labels.
        revision_idx : array-like
            Indices of the revisions to retrieve.

        Returns
        -------
        list
            Whatever ``self.kb.revise_by_idx`` returns for these arguments.
        """
        return self.kb.revise_by_idx(pred_pseudo_label, revision_idx)

    def _zoopt_candidates(
        self, symbol_num, pred_pseudo_label, pred_prob, max_revision_num
    ):
        """Obtain candidates from the positions selected by ``zoopt_get_solution``."""
        solution = self.zoopt_get_solution(
            symbol_num, pred_pseudo_label, pred_prob, max_revision_num
        )
        revision_idx = np.where(solution != 0)[0]
        return self.revise_by_idx(pred_pseudo_label, revision_idx)

    def abduce(
        self, pred_prob, pred_pseudo_label, max_revision=-1, require_more_revision=0
    ):
        """
        Abduce a single example.

        Parameters
        ----------
        pred_prob : list or numpy.ndarray
            Probabilities of the predictions.
        pred_pseudo_label : list
            Predicted pseudo labels.
        max_revision : int or float, optional
            Revision budget; converted by ``calculate_revision_num`` together
            with the number of symbols. Defaults to -1.
        require_more_revision : int, optional
            Forwarded to ``kb.abduce_candidates``. Defaults to 0.

        Returns
        -------
        list
            The selected candidate, or an empty list when there is none.
        """
        symbol_num = len(flatten(pred_pseudo_label))
        max_revision_num = calculate_revision_num(max_revision, symbol_num)

        if self.use_zoopt:
            candidates = self._zoopt_candidates(
                symbol_num, pred_pseudo_label, pred_prob, max_revision_num
            )
        else:
            candidates = self.kb.abduce_candidates(
                pred_pseudo_label, max_revision_num, require_more_revision
            )

        return self._get_one_candidate(pred_pseudo_label, pred_prob, candidates)

    def batch_abduce(
        self, pred_prob, pred_pseudo_label, max_revision=-1, require_more_revision=0
    ):
        """Apply ``abduce`` to every (probability, pseudo label) pair."""
        return [
            self.abduce(
                _pred_prob, _pred_pseudo_label, max_revision, require_more_revision
            )
            for _pred_prob, _pred_pseudo_label in zip(pred_prob, pred_pseudo_label)
        ]

    def abduce_chess(
        self,
        pred_prob,
        pred_pseudo_label,
        pos=None,
        max_revision=-1,
        require_more_revision=0,
    ):
        """
        Abduce a single example, forwarding ``pos`` to
        ``kb.abduce_candidates_chess``.

        See ``abduce`` for the remaining parameters and the return value.
        """
        symbol_num = len(flatten(pred_pseudo_label))
        max_revision_num = calculate_revision_num(max_revision, symbol_num)

        if self.use_zoopt:
            candidates = self._zoopt_candidates(
                symbol_num, pred_pseudo_label, pred_prob, max_revision_num
            )
        else:
            candidates = self.kb.abduce_candidates_chess(
                pred_pseudo_label, pos, max_revision_num, require_more_revision
            )

        return self._get_one_candidate(pred_pseudo_label, pred_prob, candidates)

    def batch_abduce_chess(
        self,
        pred_prob,
        pred_pseudo_label,
        pos=None,
        max_revision=-1,
        require_more_revision=0,
    ):
        """
        Apply ``abduce_chess`` to every example of a batch.

        When ``pos`` is None every example is abduced with ``pos=None``,
        matching the default of ``abduce_chess``.
        """
        if pos is None:
            # zip() cannot iterate None; expand to one None per example.
            pos = [None] * len(pred_pseudo_label)
        return [
            self.abduce_chess(
                _pred_prob,
                _pred_pseudo_label,
                _pos,
                max_revision,
                require_more_revision,
            )
            for _pred_prob, _pred_pseudo_label, _pos in zip(
                pred_prob, pred_pseudo_label, pos
            )
        ]

    def abduce_sudoku(
        self,
        pred_prob,
        pred_pseudo_label,
        inputs,
        max_revision=-1,
        require_more_revision=0,
    ):
        """
        Abduce a single example, forwarding ``inputs`` to
        ``kb.abduce_candidates_sudoku`` and to ``_get_one_candidate_sudoku``.

        See ``abduce`` for the remaining parameters.
        """
        symbol_num = len(flatten(pred_pseudo_label))
        max_revision_num = calculate_revision_num(max_revision, symbol_num)

        if self.use_zoopt:
            candidates = self._zoopt_candidates(
                symbol_num, pred_pseudo_label, pred_prob, max_revision_num
            )
        else:
            candidates = self.kb.abduce_candidates_sudoku(
                pred_pseudo_label, inputs, max_revision_num, require_more_revision
            )

        # NOTE(review): _get_one_candidate_sudoku is not defined in this
        # class; presumably supplied by a subclass.
        return self._get_one_candidate_sudoku(
            pred_pseudo_label, pred_prob, candidates, inputs
        )

    def val(self, pred_prob: List) -> List:
        """
        Return the first entry of the first element produced by ``kb.val``.

        Parameters
        ----------
        pred_prob : list
            List of probabilities for predicted results.

        Returns
        -------
        list
            ``self.kb.val(pred_prob)[0][0]``.
        """
        candidates = self.kb.val(pred_prob)
        return candidates[0][0]

    def batch_val(self, pred_prob: List) -> List:
        """Apply ``val`` to every element of ``pred_prob``."""
        return [self.val(_pred_prob) for _pred_prob in pred_prob]

    def val_chess(self, pred_prob: List, pos=None) -> List:
        """Return ``self.kb.val_chess(pred_prob, pos)[0][0]``."""
        candidates = self.kb.val_chess(pred_prob, pos)
        return candidates[0][0]

    def batch_val_chess(self, pred_prob: List, pos=None) -> List:
        """
        Apply ``val_chess`` to every example of a batch.

        When ``pos`` is None every example is validated with ``pos=None``,
        matching the default of ``val_chess``.
        """
        if pos is None:
            # zip() cannot iterate None; expand to one None per example.
            pos = [None] * len(pred_prob)
        return [
            self.val_chess(_pred_prob, _pos) for _pred_prob, _pos in zip(pred_prob, pos)
        ]

    def __call__(
        self, pred_prob, pred_pseudo_label, max_revision=-1, require_more_revision=0
    ):
        """Shorthand for ``batch_abduce``."""
        return self.batch_abduce(
            pred_prob, pred_pseudo_label, max_revision, require_more_revision
        )

    def set_remapping(self):
        """Rebuild ``self.remapping`` (label -> index) from ``self.mapping``."""
        self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))


class BridgeDataset_ulb_chess(Dataset):
    """Dataset of aligned ``(X, Z, P)`` triples."""

    def __init__(self, X: List[Any], Z: List[Any], P: List[Any]):
        """Initialize a basic dataset.

        Parameters
        ----------
        X : List[Any]
            A list of objects representing the input data.
        Z : List[Any]
            A list of objects representing the symbol.
        P : List[Any]
            A list of per-sample objects returned alongside ``X`` and ``Z``.

        Raises
        ------
        ValueError
            If ``Z`` or ``P`` does not have the same length as ``X``.
        """
        if len(X) != len(Z):
            raise ValueError("Length of X and Z must be equal.")
        # P is indexed with the same index as X in __getitem__, so a
        # mismatch is rejected here instead of failing mid-iteration.
        if len(X) != len(P):
            raise ValueError("Length of X and P must be equal.")

        self.X = X
        self.Z = Z
        self.P = P

    def __len__(self):
        """Return the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.
        """
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """Get an item from the dataset.

        Parameters
        ----------
        index : int
            The index of the item to retrieve.

        Returns
        -------
        Tuple[Any, Any, Any]
            The ``(X, Z, P)`` entries stored at the specified index.

        Raises
        ------
        ValueError
            If ``index`` is not smaller than the dataset length.
        """
        if index >= len(self):
            raise ValueError("index range error")

        return (self.X[index], self.Z[index], self.P[index])


class PhaseUnsupBridge(BaseBridge):
    """
    Bridge that trains a model on pseudo labels produced by an abducer,
    moving through a sequence of training phases.

    NOTE(review): ``train`` reads ``self.abducer_list`` (one abducer per
    phase), which is not assigned in this class's ``__init__``; presumably
    the caller or the base class sets it before ``train`` is called --
    TODO confirm.
    """

    def __init__(
        self,
        model: ABLModel,
        abducer: ReasonerBase,
        metric_list: List[BaseMetric],
        use_prob=False,
        use_weight=False,
        val=False,
        require_more_revision=0,
        sudoku=False,
    ) -> None:
        """
        Parameters
        ----------
        model : ABLModel
            Model used for prediction and training.
        abducer : ReasonerBase
            Reasoner used to obtain pseudo labels.
        metric_list : List[BaseMetric]
            Metrics fed during validation.
        use_prob, use_weight : bool, optional
            Stored on the instance; not read by the methods of this class.
        val : bool, optional
            When True, ``train`` obtains labels with ``val_pseudo_label_chess``
            instead of ``abduce_pseudo_label_chess``.
        require_more_revision : int, optional
            Forwarded to the abducer during training.
        sudoku : bool, optional
            Selects the sudoku variants of the abducer calls.
        """
        super().__init__(model, abducer)
        self.metric_list = metric_list
        self.use_prob = use_prob
        self.use_weight = use_weight
        self.val = val
        self.sudoku = sudoku
        self.require_more_revision = require_more_revision

    def predict(self, X, prior=None) -> Tuple[List[List[Any]], ndarray]:
        """Return the ``"label"`` and ``"prob"`` entries of ``model.predict``."""
        pred_res = self.model.predict(X, prior=prior)
        pred_idx, pred_prob = pred_res["label"], pred_res["prob"]
        return pred_idx, pred_prob

    def abduce_pseudo_label(
        self,
        pred_prob: ndarray,
        pred_pseudo_label: List[List[Any]],
        inputs=None,
        require_more_revision=10,
    ) -> List[List[Any]]:
        """Batch-abduce pseudo labels, using the sudoku variant when enabled."""
        if self.sudoku:
            return self.abducer.batch_abduce_sudoku(
                pred_prob,
                pred_pseudo_label,
                inputs,
                require_more_revision=require_more_revision,
            )
        else:
            return self.abducer.batch_abduce(
                pred_prob,
                pred_pseudo_label,
                require_more_revision=require_more_revision,
            )

    def val_pseudo_label(
        self,
        pred_prob: ndarray,
        inputs=None
    ) -> List[List[Any]]:
        """Batch-validate predictions, using the sudoku variant when enabled."""
        if self.sudoku:
            return self.abducer.batch_val_sudoku(pred_prob, inputs=inputs)
        else:
            return self.abducer.batch_val(pred_prob)

    def abduce_pseudo_label_chess(
        self,
        pred_prob: ndarray,
        pred_pseudo_label: List[List[Any]],
        pos=None,
        require_more_revision=10,
    ) -> List[List[Any]]:
        """Delegate to ``abducer.batch_abduce_chess``."""
        return self.abducer.batch_abduce_chess(
            pred_prob,
            pred_pseudo_label,
            pos=pos,
            require_more_revision=require_more_revision,
        )

    def val_pseudo_label_chess(
        self,
        pred_prob: ndarray,
        pos=None
    ) -> List[List[Any]]:
        """Delegate to ``abducer.batch_val_chess``."""
        return self.abducer.batch_val_chess(pred_prob, pos=pos)

    def idx_to_pseudo_label(
        self, idx: List[List[Any]], mapping: Dict = None
    ) -> List[List[Any]]:
        """Map each index of a two-level list through ``mapping``
        (``abducer.mapping`` by default)."""
        if mapping is None:
            mapping = self.abducer.mapping
        return [[mapping[_idx] for _idx in sub_list] for sub_list in idx]

    def pseudo_label_to_idx(
        self, pseudo_label: List[List[Any]], mapping: Dict = None
    ) -> List[List[Any]]:
        """Map every leaf of an arbitrarily nested list/tuple through
        ``mapping`` (``abducer.remapping`` by default)."""
        if mapping is None:
            mapping = self.abducer.remapping

        def recursive_map(func, nested_list):
            if isinstance(nested_list, (list, tuple)):
                return [recursive_map(func, x) for x in nested_list]
            else:
                return func(nested_list)

        return recursive_map(lambda x: mapping[x], pseudo_label)

    @staticmethod
    def _collate(data_list):
        """Transpose a list of samples into one list per field."""
        return [list(field) for field in zip(*data_list)]

    def _build_loader(self, data, batch_size):
        """Wrap ``data`` (an ``(X, Z, P)`` triple) in a DataLoader.

        ``batch_size == -1`` means "the whole dataset in one batch";
        DataLoader itself rejects non-positive batch sizes.
        """
        dataset = BridgeDataset_ulb_chess(*data)
        if batch_size == -1:
            batch_size = max(len(dataset), 1)
        return DataLoader(dataset, batch_size=batch_size, collate_fn=self._collate)

    def train(
        self,
        train_data: List[
            Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]]
        ],
        max_iter: int = 5000,
        batch_size: Union[int, float] = -1,
        eval_interval: int = 10,
        test_data: Tuple[
            List[List[Any]], Optional[List[List[Any]]], List[List[Any]]
        ] = None,
        prior=None,
    ):
        """
        Train the model phase by phase on abduced pseudo labels.

        Parameters
        ----------
        train_data : list
            One ``(X, Z, P)`` triple per phase.
        max_iter : int, optional
            Iteration (batch) budget. It is checked between passes over the
            data, so the current pass is always finished.
        batch_size : int, optional
            Batch size; -1 uses the whole phase pool as one batch.
        eval_interval : int, optional
            Every this many iterations the model is evaluated on
            ``test_data`` (when given).
        test_data : tuple, optional
            ``(X, Z, P)`` triple for evaluation. Phase advancement relies on
            the per-class accuracy measured on it, so without test data
            training stays in the first phase.
        prior : any, optional
            Forwarded to ``predict`` when ``self.val`` is set.
        """
        phase_idx = 0
        self.abducer = self.abducer_list[phase_idx]
        data_loader = self._build_loader(train_data[phase_idx], batch_size)
        iter_cnt = 0
        while iter_cnt < max_iter:
            next_flag = False
            saw_batch = False
            for X, Z, P in data_loader:
                saw_batch = True
                if self.val:
                    pred_idx, pred_prob = self.predict(X, prior=prior)
                else:
                    pred_idx, pred_prob = self.predict(X)
                pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
                if self.val:
                    abduced_pseudo_label = self.val_pseudo_label_chess(pred_prob, pos=P)
                else:
                    abduced_pseudo_label = self.abduce_pseudo_label_chess(
                        pred_prob,
                        pred_pseudo_label,
                        pos=P,
                        require_more_revision=self.require_more_revision,
                    )
                abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
                self.model.train(X, abduced_label)
                if (iter_cnt + 1) % eval_interval == 0:
                    print_log(
                        f"Evaluation start: Epoch(val) [{iter_cnt + 1}], Phase {phase_idx}",
                        logger="current",
                    )
                    if test_data:
                        self.test(test_data)
                iter_cnt += 1

                # Advance to the next phase once every class is recognised
                # with accuracy above 0.3; this needs test data to measure.
                if test_data and phase_idx + 1 < len(train_data):
                    char_acc = self.test(test_data, verbose=0)
                    if all(
                        char_acc[i] > 0.3 for i in self.abducer.kb.pseudo_label_list
                    ):
                        next_flag = True
                        break

            if not saw_batch:
                # An empty pool would never increase iter_cnt; stop instead
                # of looping forever.
                break

            if next_flag:
                phase_idx += 1
                self.abducer = self.abducer_list[phase_idx]
                data_loader = self._build_loader(train_data[phase_idx], batch_size)

    def _valid(self, data_loader, tag=""):
        """Feed ``(X, Z)`` batches to the metrics and log the scalar results."""
        for (X, Z) in data_loader:
            pred_idx, pred_prob = self.predict(X)
            pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
            # NOTE(review): presumably X is expected to be an array here so
            # that this yields an element-wise mask -- confirm.
            mask = X == 9
            data_samples = dict(
                pred_idx=pred_idx,
                pred_prob=pred_prob,
                pred_pseudo_label=pred_pseudo_label,
                gt_pseudo_label=Z,
                logic_forward=self.abducer.kb.logic_forward,
                mask=mask if self.sudoku else None,
                X=X
            )
            for metric in self.metric_list:
                metric.process(data_samples)

        res = dict()
        for metric in self.metric_list:
            res.update(metric.evaluate())
        msg = "Evaluation ended, "
        for k, v in res.items():
            try:
                msg += k + f": {v:.3f} "
            except (TypeError, ValueError):
                # Non-scalar metric values cannot be formatted as floats;
                # skip them rather than dropping the whole message.
                continue
        print_log(msg, logger="current")

    def _valid_chess(self, data_loader, tag="", verbose=1):
        """
        Feed ``(X, Z, P)`` batches to the metrics and return per-class accuracy.

        Returns
        -------
        list
            The values of the ``"confusion_matrix"`` metric entry, looked up
            by the keys ``"0"``, ``"1"``, ... and rounded to 4 digits.
        """
        for (X, Z, P) in data_loader:
            pred_idx, pred_prob = self.predict(X)
            pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
            mask = X == 9
            data_samples = dict(
                pred_idx=pred_idx,
                pred_prob=pred_prob,
                pred_pseudo_label=pred_pseudo_label,
                gt_pseudo_label=Z,
                logic_forward=self.abducer.kb.logic_forward,
                mask=mask if self.sudoku else None,
                X=X,
                P=P,
            )
            for metric in self.metric_list:
                metric.process(data_samples)

        res = dict()
        for metric in self.metric_list:
            res.update(metric.evaluate())
        pred_acc = [
            round(res["confusion_matrix"][str(i)], 4)
            for i in range(len(res["confusion_matrix"]))
        ]
        if verbose:
            msg = f"({tag}): "
            msg += ", ".join(
                [f"{k}: {v:.4f}" for k, v in res.items() if "confusion" not in k]
            )
            print_log(msg, logger="current")
            print_log(f"pred_acc: {pred_acc}", logger="current")
        return pred_acc

    def test(self, test_data, batch_size=1000, verbose=1):
        """Run ``valid`` with the tag ``"test"`` and return its result."""
        pred_acc = self.valid(test_data, batch_size, tag="test", verbose=verbose)
        return pred_acc

    def valid(self, valid_data, batch_size=1000, tag="valid", verbose=1):
        """Evaluate on an ``(X, Z, P)`` triple and return per-class accuracy."""
        data_loader = self._build_loader(valid_data, batch_size)
        pred_acc = self._valid_chess(data_loader, tag, verbose=verbose)
        return pred_acc

class ValChessMetric(BaseMetric):
    """
    Metric reporting symbol-level accuracy of (a) the best label sequence
    that satisfies ``logic_forward`` and (b) the raw prediction, together
    with a per-class accuracy taken from a confusion matrix.

    The histories ``y_gt``, ``y_pred``, ``y_val`` and ``pred_results`` receive
    exactly one entry per processed sample, like ``self.results``. They are
    never cleared here, so ``compute_metrics`` only looks at their trailing
    ``len(results)`` entries to stay in step with ``results``.
    """

    def __init__(
        self,
        prefix: Optional[str] = None,
        num=1,
        max_times=100000,
        num_classes=6,
    ) -> None:
        """
        Parameters
        ----------
        prefix : str, optional
            Passed to the base class.
        num : int, optional
            Number of consistent candidates ``val`` collects before stopping.
        max_times : int, optional
            Maximum number of rejected label sequences ``val`` tolerates.
        num_classes : int, optional
            Number of classes (labels ``0 .. num_classes - 1``) of the
            confusion matrix. Defaults to 6.
        """
        super().__init__(prefix)
        self.num = num
        self.max_times = max_times
        self.num_classes = num_classes
        self.logic_forward = None  # set from the data samples in process()
        self.cm = np.zeros((num_classes, num_classes), dtype=int)
        self.y_pred = []
        self.y_gt = []
        self.y_val = []
        self.pred_results = []

    @staticmethod
    def _symbol_accuracy(pred_z, z):
        """Fraction of positions of ``z`` matched by ``pred_z`` (0.0 if ``z`` is empty)."""
        if len(z) == 0:
            return 0.0
        correct_num = sum(
            1 for pred_symbol, symbol in zip(pred_z, z) if pred_symbol == symbol
        )
        return correct_num / len(z)

    @staticmethod
    def _tail(history, n):
        """Return the last ``n`` entries of ``history`` (all of it if shorter)."""
        return history[max(len(history) - n, 0):]

    def process(self, data_samples: Sequence[dict]) -> None:
        """
        Record one batch.

        Reads the keys ``pred_pseudo_label``, ``pred_prob``,
        ``gt_pseudo_label``, ``logic_forward`` and ``P`` from ``data_samples``.

        Raises
        ------
        ValueError
            If the per-sample sequences of the batch differ in length.
        """
        pred_pseudo_label = data_samples["pred_pseudo_label"]
        pred_prob = data_samples["pred_prob"]
        gt_pseudo_label = data_samples["gt_pseudo_label"]
        self.logic_forward = data_samples["logic_forward"]
        pos = data_samples["P"]

        if not len(pred_pseudo_label) == len(gt_pseudo_label):
            raise ValueError(
                "lengthes of pred_pseudo_label and gt_pseudo_label should be equal"
            )
        if not (len(pred_prob) == len(pos) == len(gt_pseudo_label)):
            raise ValueError(
                "lengthes of pred_prob, P and gt_pseudo_label should be equal"
            )

        val_pseudo_label = []
        for _prob, _pos, pred_z in zip(pred_prob, pos, pred_pseudo_label):
            found, _ = self.val(_prob, _pos)
            # val() may find no consistent sequence within its budget; fall
            # back to the raw prediction instead of failing on found[0].
            val_pseudo_label.append(found[0] if found else pred_z)

        self.y_gt.extend(gt_pseudo_label)
        self.y_pred.extend(pred_pseudo_label)
        self.y_val.extend(val_pseudo_label)
        for val_z, pred_z, z in zip(
            val_pseudo_label, pred_pseudo_label, gt_pseudo_label
        ):
            self.results.append(self._symbol_accuracy(val_z, z))
            self.pred_results.append(self._symbol_accuracy(pred_z, z))

    def sort_mask(self, mask_probability):
        """
        Lazily enumerate joint label sequences together with their scores.

        The first item is the per-position arg-max sequence with the product
        of its probabilities. Further sequences are produced by best-first
        expansion: each expansion moves one position to its next most
        probable label and scales the score by the ratio of the two
        probabilities.

        Parameters
        ----------
        mask_probability : sequence of sequences
            One probability row per position.

        Yields
        ------
        tuple
            ``(label, score)`` where ``label`` holds one class index per
            position.
        """
        selected_dict = {}  # rank tuples already scheduled or emitted
        max_hard_label = [np.argmax(probability) for probability in mask_probability]
        num_positions = len(max_hard_label)
        max_score = 1
        for i in range(num_positions):
            max_score *= mask_probability[i][max_hard_label[i]]

        sorted_probs = []
        sorted_indices = []
        root_state = []
        # mask[p] is the rank (0 = most probable) currently chosen at p.
        mask = [0] * num_positions
        for i in range(num_positions):
            ranked = sorted(
                enumerate(mask_probability[i]), key=lambda x: x[1], reverse=True
            )
            sorted_probs.append([x[1] for x in ranked])
            sorted_indices.append([x[0] for x in ranked])
            if mask[i] + 1 < len(mask_probability[i]):
                root_state.append((i, 0, sorted_probs[i][1] / sorted_probs[i][0]))
        # Pending moves (position, current rank, score ratio), best first.
        root_state = sorted(root_state, key=lambda x: x[2], reverse=True)

        label = [sorted_indices[p][mask[p]] for p in range(num_positions)]
        yield label, max_score
        if not root_state:
            # No position has an alternative label: nothing left to emit.
            return

        # Per emitted node (1-based index): pending moves, rank tuple, score.
        state = [root_state]
        masks = [mask]
        probs = [max_score]

        first = root_state[0][0]
        suc_prob = sorted_probs[first][1] / sorted_probs[first][0] * max_score
        suc_mask = list(mask)
        suc_mask[first] = suc_mask[first] + 1
        max_heap = Heap()
        # Heap entries: (negated score, index of the node to expand).
        max_heap.push((-suc_prob, 1))
        selected_dict[tuple(mask)] = max_score
        selected_dict[tuple(suc_mask)] = suc_prob
        cur_budget = 1
        while not max_heap.is_empty():
            neg_prob, father_idx = max_heap.pop()
            suc_prob = -neg_prob
            state_f = state[father_idx - 1]
            mask_f = masks[father_idx - 1]
            prob_f = probs[father_idx - 1]
            # Apply the father's best pending move to obtain the new node.
            i, j, _ratio = state_f.pop(0)
            suc_mask = list(mask_f)
            suc_mask[i] = j + 1

            suc_state = []
            for p in range(num_positions):
                if suc_mask[p] + 1 < len(mask_probability[p]):
                    denominator = sorted_probs[p][suc_mask[p]]
                    if denominator > 0:
                        ratio = sorted_probs[p][suc_mask[p] + 1] / denominator
                    else:
                        ratio = float('inf')
                    suc_state.append((p, suc_mask[p], ratio))
            suc_state = sorted(suc_state, key=lambda x: x[2], reverse=True)

            suc_label = [sorted_indices[p][suc_mask[p]] for p in range(num_positions)]
            yield suc_label, suc_prob
            masks.append(suc_mask)
            probs.append(suc_prob)

            # Schedule the new node's best move that is not yet selected.
            # The scheduled move stays at the head of suc_state; it is popped
            # when the corresponding heap entry is processed.
            while len(suc_state) != 0:
                suc_suc_mask = list(suc_mask)
                _i, _j, _prob = suc_state[0]
                suc_suc_mask[_i] = _j + 1
                if np.isinf(_prob) or np.isinf(suc_prob):
                    suc_suc_prob = float('inf')
                else:
                    suc_suc_prob = suc_prob * _prob
                if tuple(suc_suc_mask) in selected_dict:
                    suc_state.pop(0)
                    continue
                else:
                    max_heap.push((-suc_suc_prob, cur_budget + 1))
                    selected_dict[tuple(suc_suc_mask)] = suc_suc_prob
                    break
            state.append(suc_state)

            # Re-schedule the father with its next unselected move.
            while len(state_f) != 0:
                suc_f_mask = list(mask_f)
                _i, _j, _prob = state_f[0]
                suc_f_mask[_i] = _j + 1
                # NOTE(review): this tests suc_prob although the product
                # below uses prob_f; kept as in the original -- confirm.
                if np.isinf(_prob) or np.isinf(suc_prob):
                    suc_f_prob = float('inf')
                else:
                    suc_f_prob = prob_f * _prob
                if tuple(suc_f_mask) in selected_dict:
                    state_f.pop(0)
                    continue
                else:
                    max_heap.push((-suc_f_prob, father_idx))
                    selected_dict[tuple(suc_f_mask)] = suc_f_prob
                    break
            cur_budget += 1

    def val(self, probability, pos):
        """
        Collect the first ``self.num`` sequences from ``sort_mask`` that
        satisfy ``self.logic_forward(label, pos)``.

        The search stops early once ``self.max_times`` sequences have been
        rejected or accepted without reaching ``self.num``.

        Returns
        -------
        tuple
            ``(candidates, candidates_prob)``; both lists may be empty.
        """
        candidates = []
        candidates_prob = []
        if self.num == 0 or self.max_times == 0:
            return candidates, candidates_prob
        n = 0
        times = 0
        for label, prob in self.sort_mask(probability):
            if self.logic_forward(label, pos):
                candidates.append(label)
                candidates_prob.append(prob)
                n += 1
                if n == self.num:
                    break
            times += 1
            if times == self.max_times:
                break
        return candidates, candidates_prob

    def compute_metrics(self, results: list) -> dict:
        """
        Compute the metrics for the samples behind ``results``.

        Returns
        -------
        dict
            ``character_accuracy`` (mean of ``results``),
            ``pred_character_accuracy`` (mean raw-prediction accuracy of the
            same samples) and ``confusion_matrix`` (per-class accuracy keyed
            by ``"0"``, ``"1"``, ...). Accuracies are 0.0 when there are no
            samples.
        """
        n = len(results)
        # Only the trailing n history entries belong to `results`.
        pred_window = self._tail(self.pred_results, n)
        flat_gt = list(chain.from_iterable(self._tail(self.y_gt, n)))
        flat_pred = list(chain.from_iterable(self._tail(self.y_pred, n)))

        # One matrix over all batches of this evaluation rather than only
        # the most recent batch.
        labels = list(range(self.num_classes))
        self.cm = confusion_matrix(flat_gt, flat_pred, labels=labels)

        metrics = dict()
        metrics["character_accuracy"] = sum(results) / n if n else 0.0
        metrics["pred_character_accuracy"] = (
            sum(pred_window) / len(pred_window) if pred_window else 0.0
        )
        metrics["confusion_matrix"] = {
            f"{i}": self.cm[i][i] / np.sum(self.cm[i]) if np.sum(self.cm[i]) > 0 else 0
            for i in range(self.num_classes)
        }
        return metrics
