import ast
from generate_combs import CrashInfo, Info, generate_all_combs


def is_subset(sub, full):
    """Return True when every element of ``sub`` also occurs in ``full``.

    Both arguments may be any iterable of hashable items; duplicates and
    ordering are ignored because the comparison is done on sets.
    """
    candidates, universe = set(sub), set(full)
    return candidates <= universe


def filter_combs(initials, state, all_combs: list[Info], lost_value):
    """
    Return the indices of the combinations that are consistent with one
    state received during simulation.

    ``state`` is laid out as ``[round, v_0, v_1, ...]``: element 0 is the
    round number and element ``k + 1`` is the value received from node ``k``.
    It may also be passed as the string form of such a literal, in which case
    it is parsed with ``ast.literal_eval``. A value equal to ``lost_value``
    means nothing was received from that node, i.e. the node crashed in this
    round or in any earlier round.

    A combination is kept when its initial value (``comb.get_init()``) is in
    ``initials`` and either

    * the round is 0 (there is no previous round to check against), or
    * every node a value was received from is alive in the crash info of the
      previous round, and every lost node appears in the crash lists of
      rounds ``0..round``.

    E.g., for the state ``[1, 2, 2, 2, 4]`` with ``lost_value == 4`` only
    combinations in which node 3 crashed at round 0 or round 1 are kept.

    Args:
        initials: Container of accepted initial values.
        state: Received state, or its string literal.
        all_combs: Candidate combinations; each must provide ``get_init()``
            and ``get_crash_info(round)`` returning an object with ``crash``
            and ``alive`` collections of node indices.
        lost_value: Marker value meaning "no message received".

    Returns:
        List of positions in ``all_combs`` of the matching combinations,
        in ascending order.

    NOTE(review): crash info is only fetched when it is actually needed;
    this assumes ``get_crash_info`` is a side-effect-free accessor.
    """
    if isinstance(state, str):
        state = ast.literal_eval(state)

    current_round = state[0]

    # Split the node indices once; these sets do not depend on the combination.
    lost_nodes = set()
    heard_nodes = set()
    for node, value in enumerate(state[1:]):
        if value == lost_value:
            lost_nodes.add(node)
        else:
            heard_nodes.add(node)

    filtered_combs = []
    for index, comb in enumerate(all_combs):
        if comb.get_init() not in initials:
            continue

        if current_round == 0:
            # No previous round: the initial value is the only constraint.
            filtered_combs.append(index)
            continue

        # A message received from a node implies that node was still alive
        # in the previous round.
        previous = comb.get_crash_info(current_round - 1)
        if not heard_nodes <= set(previous.alive):
            continue

        if lost_nodes:
            # A lost value at this round means the node may have crashed in
            # any round up to and including this one.
            crashed_so_far = set()
            for r in range(current_round + 1):
                crashed_so_far.update(comb.get_crash_info(r).crash)
            if not lost_nodes <= crashed_so_far:
                continue

        filtered_combs.append(index)

    return filtered_combs


def filter_combs_by_initial_and_round(initials, round, players, all_combs: list[Info]):
    """
    Select combinations by initial value and by liveness before ``round``.

    A combination is selected when ``comb.get_init()`` is in ``initials`` and,
    for any ``round`` other than 0, the crash info of round ``round - 1``
    lists exactly ``players`` alive nodes (i.e. nobody had crashed yet).
    For round 0 the initial value is the only criterion.

    Returns the positions of the selected combinations within ``all_combs``.
    """

    def all_alive_before(candidate):
        # Round 0 has no predecessor, so there is nothing to verify.
        if round == 0:
            return True
        return len(candidate.get_crash_info(round - 1).alive) == players

    return [
        position
        for position, candidate in enumerate(all_combs)
        if candidate.get_init() in initials and all_alive_before(candidate)
    ]


def filter_combs_by_initial_states(initials, all_combs: list[Info]):
    """
    Select combinations whose initial value is one of ``initials``.

    Returns the positions within ``all_combs`` of every combination for which
    ``comb.get_init()`` is contained in ``initials``.
    """
    return [
        position
        for position, candidate in enumerate(all_combs)
        if candidate.get_init() in initials
    ]


if __name__ == "__main__":
    # Manual smoke run: generate every combination, then count how many
    # remain for an observation at round 1 in which node 0's value is lost.
    combos = generate_all_combs(players=4, num_round=3)
    observed_state = [1, 4, 2, 2, 2]
    matching = filter_combs([0], observed_state, combos, lost_value=4)
    # Alternative observation kept for manual experiments:
    # filter_combs([0], [0, 5, 5, 5, 5], combos, lost_value=4)
    print(len(matching))
