import re
import ast
import itertools
import copy
from typing import List
import numpy as np


class CrashInfo:
    """
    Information for one round: the nodes that newly crash, the nodes that
    remain alive, and which messages of the newly crashed nodes were received.
    """

    def __init__(self, crash, alive, receive, round):
        # Nodes that newly crash in this round.
        self.crash: list = crash
        # Nodes still alive after this round's crashes.
        self.alive: list = alive
        # Boolean receive data, always stored as a private copy so later changes
        # to the caller's array cannot leak in. As built by dfs_create_crash_info,
        # rows are receivers and columns are the newly crashed senders.
        as_bool = np.asarray(receive, dtype=bool)
        self.receive: np.ndarray = as_bool.copy()
        # 0-based round index.
        self.round: int = round

    def __str__(self):
        # Render the matrix as 0/1 digits on a single line.
        flat = np.array2string(self.receive.astype(int), separator="").replace("\n", "")
        return f"round {self.round}: crash={self.crash}, alive={self.alive}, receive={flat} "

    def get_info(self):
        """Return the ``(crash, alive, receive)`` triple for this round."""
        triple = (self.crash, self.alive, self.receive)
        return triple


class Info:
    """
    Information for one possible input combination: an initial state plus
    the per-round CrashInfo records appended via add_crash_info.
    """

    def __init__(self, initial: int, num_round):
        self.initial = initial  # initial input value for this combination
        self.num_round = num_round  # number of rounds this Info is expected to hold
        self.crash_infos: List[CrashInfo] = []  # one CrashInfo per round, in insertion order

    def __str__(self):
        rounds_text = "".join(str(record) for record in self.crash_infos)
        return f"initial: {self.initial}, " + rounds_text

    def add_crash_info(self, crash_info: CrashInfo):
        """Append the CrashInfo of the next round."""
        self.crash_infos.append(crash_info)

    # This function will be exposed to MCTS.py
    def get_crash_info(self, round) -> CrashInfo:
        """Return the CrashInfo for 0-based ``round``; None once ``round == num_round``."""
        return None if round == self.num_round else self.crash_infos[round]

    def get_init(self):
        """Return the initial state; asserts that every round's CrashInfo was added."""
        assert len(self.crash_infos) == self.num_round
        return self.initial


def parse_info_from_string(data: str) -> Info:
    """
    Rebuild an Info object (and its CrashInfo rounds) from a textual description.

    The text must contain ``initial: <digits>`` and may contain any number of
    round fragments of the form
    ``round: <digits>, crash: [...], alive: [...], receive: ((...), ...)``.
    Rounds are added in the order they appear in the text. Note that this
    layout differs from what ``CrashInfo.__str__`` produces
    (``round N: crash=..., ...``).

    Args:
        data: the text to parse.

    Returns:
        An Info whose ``num_round`` equals the number of round fragments found.

    Raises:
        ValueError: if no ``initial: <digits>`` is present.
    """
    head = re.search(r"initial: (\d+)", data)
    if head is None:
        raise ValueError("Invalid data format: 'initial' not found.")
    initial_value = int(head.group(1))

    # Each match yields (round, crash list, alive list, receive tuple-of-tuples) texts.
    fragment_re = r"round: (\d+), crash: (\[.*?\]), alive: (\[.*?\]), receive: (\([^)]*\)(?:,\s*\([^)]*\))*\,?\))"
    fragments = re.findall(fragment_re, data)

    result = Info(initial_value, len(fragments))
    for round_txt, crash_txt, alive_txt, receive_txt in fragments:
        round_idx = int(round_txt)
        # literal_eval only accepts Python literals, so no code is executed.
        crashed = ast.literal_eval(crash_txt)
        survivors = ast.literal_eval(alive_txt)
        received = ast.literal_eval(receive_txt)
        result.add_crash_info(CrashInfo(crashed, survivors, received, round_idx))

    return result


def recursive_get_combintaions(players, remaining_nodes, round, num_round, cur_result, all_results):
    """
    This function is used to generate all possible combinations of crashed nodes and alive nodes
    The output is stored in all_results, formatted as:
    [[[crash_r1], [alive_r1], [crash_r2], [alive_r2], ...], [...], [...], ...]
    """
    # pick r elements from remaining_nodes
    if cur_result is None:
        cur_result = []

    crashed_so_far = players - len(remaining_nodes)
    # We can allow at most num_round - 1 crashes in total
    max_crash_this_round = min(len(remaining_nodes), num_round - 1 - crashed_so_far)

    for r in range(max_crash_this_round + 1):
        for selection in itertools.combinations(remaining_nodes, r):  # make the selection
            new_remaining_nodes = [node for node in remaining_nodes if node not in selection]
            if round == num_round:
                all_results.append(cur_result + [list(selection)] + [new_remaining_nodes])
            else:
                recursive_get_combintaions(
                    players,
                    new_remaining_nodes,
                    round + 1,
                    num_round,
                    cur_result + [list(selection)] + [new_remaining_nodes],
                    all_results,
                )


def _get_receiver_row_patterns(senders: List[int], receiver: int) -> List[np.ndarray]:
    """
    Return a *list* of NumPy bool arrays, each of length |senders|,
    representing every legal bit-row for this receiver.

    If the receiver itself is one of the just-crashed senders, its
    own bit (diagonal) is forced to 1.

    Eaxmple:
    _get_receiver_row_patterns(senders = [0, 1], receiver = 0):
    [array([ True, False]), array([ True,  True])]
    Meaning:
    Node 0 and 1 newly crash, node 0 receives has its own message, and can or can not receive node 1's message.
    """
    n = len(senders)

    # Receiver not a sender → every 0/1 vector allowed
    if receiver not in senders:
        return [
            np.array(bits, dtype=bool)  # -> shape (n,)
            for bits in itertools.product([0, 1], repeat=n)
        ]

    # Receiver *is* a sender → diagonal bit must be 1
    self_idx = senders.index(receiver)
    patterns = []
    for bits in itertools.product([0, 1], repeat=n - 1):
        vec = list(bits)
        vec.insert(self_idx, 1)  # force diag = 1
        patterns.append(np.array(vec, dtype=bool))
    return patterns


def dfs_create_crash_info(
    crash_comb: List[List[int]],
    initial_state: int,
    round: int,  # 0-based round index
    num_round: int,
    pb_info: Info = None,
    all_combs: List[Info] = None,
):
    """
    Enumerate every message-loss pattern compatible with one crash schedule.

    Args:
        crash_comb: flat schedule ``[crash_r1, alive_r1, crash_r2, alive_r2, ...]``
            as produced by recursive_get_combintaions.
        initial_state: initial state stored in each produced Info.
        round: 0-based round currently being expanded (callers start at 0).
        num_round: total number of rounds.
        pb_info: partially built Info for rounds < ``round`` (None when round == 0).
        all_combs: list that receives every completed Info. A new list is
            created when None.

    Returns:
        all_combs (the same list object when one was supplied).
    """
    if all_combs is None:
        all_combs = []
    if round == num_round:
        all_combs.append(pb_info)
        return all_combs

    crash = crash_comb[round * 2]
    alive = crash_comb[round * 2 + 1]
    senders = sorted(crash)  # Only the nodes that just failed can be "unreliable senders" for this round.
    if round == num_round - 1:
        receivers = sorted(crash + alive)  # Still-alive nodes and the just-crashed nodes act as potential receivers.
    else:
        receivers = sorted(alive)

    row_options = [_get_receiver_row_patterns(senders, rcv) for rcv in receivers]

    for row_choice in itertools.product(*row_options):
        # Shape of receive_np: (len(receivers), len(senders)).
        # Row i tells what the i-th receiver heard; column j tells how far the
        # message of the j-th newly crashed node propagated.
        if row_choice:
            receive_np = np.stack(row_choice, axis=0)
        else:
            # No receivers (every node has crashed): np.stack rejects an empty
            # sequence, so build the empty matrix explicitly.
            receive_np = np.zeros((0, len(senders)), dtype=bool)
        # Deep-copy the partial Info so sibling branches never share state.
        new_pb_info = Info(initial_state, num_round) if round == 0 else copy.deepcopy(pb_info)
        new_pb_info.add_crash_info(CrashInfo(crash, alive, receive_np, round))

        dfs_create_crash_info(crash_comb, initial_state, round + 1, num_round, new_pb_info, all_combs)

    return all_combs


# This function is exposed to other code to generate all Info
def generate_all_combs(initials, players, num_round):
    """
    Build every Info for the given initial states and crash schedules.

    Args:
        initials: iterable of initial states.
        players: number of nodes (node ids are 0 .. players - 1).
        num_round: number of rounds.

    Returns:
        A list of Info objects, ordered by initial state, then crash schedule,
        then message-loss pattern.
    """
    schedules = []
    recursive_get_combintaions(players, range(players), 1, num_round, None, schedules)

    results = []
    for initial_state in initials:
        for schedule in schedules:
            # schedule layout: [[crash_r1], [alive_r1], [crash_r2], [alive_r2], ...]
            assert len(schedule) == num_round * 2
            dfs_create_crash_info(schedule, initial_state, 0, num_round, None, results)
    return results


# Scenario with no crashes and, by default, a single round (dedicated to distributed locking)
def generate_all_combs_nocrash(players, num_round=1):
    """
    Build one Info per initial state for a scenario with no crashes.

    Every round records an empty crash list, all nodes alive, and an empty
    receive matrix. The previous version always added exactly one round, so
    ``get_init()`` / ``get_crash_info(1)`` failed for ``num_round > 1``.

    Args:
        players: number of nodes (node ids are 0 .. players - 1).
        num_round: number of rounds (default 1).

    Returns:
        A list of ``2**players`` Info objects, one per initial state.
    """
    all_combs = []
    for init_states in range(2**players):
        new_info = Info(init_states, num_round)
        for rnd in range(num_round):
            # Fresh alive list per round so CrashInfo objects never alias each other's lists.
            new_info.add_crash_info(CrashInfo([], list(range(players)), [], rnd))
        all_combs.append(new_info)

    return all_combs


def read_combs(all_combs):
    """
    Print every Info in all_combs round by round (debugging helper).

    The previous version called ``comb.get_input(...)``, which Info does not
    define. Rounds are now read through ``get_crash_info(r).get_info()``
    (0-based) and printed with 1-based labels (r1, r2, ...).

    Args:
        all_combs: iterable of Info objects.
    """
    for comb in all_combs:
        print("-------------------------------------")
        print(f"initial: {comb.get_init()}")
        for rnd in range(comb.num_round):
            crash, alive, receive = comb.get_crash_info(rnd).get_info()
            label = f"r{rnd + 1}"
            print(f"crash_{label}: {crash}")
            print(f"alive_{label}: {alive}")
            if rnd < comb.num_round - 1 and receive.size:
                # In non-final rounds only alive nodes receive, so expect one row per alive node.
                assert len(receive) == len(alive)
            print(f"{label}_comb: {receive}")


def hash_combs(initial, crash_r1, receive_r1, crash_r2, receive_r2):
    """
    Placeholder for hashing a two-round combination; not implemented yet.

    All arguments are currently ignored and the function returns None.
    """
    return None


if __name__ == "__main__":
    # Demo: enumerate every combination for 2 players over 2 rounds,
    # using every initial state in range(2**players).
    num_players = 2
    demo_combs = generate_all_combs(range(2**num_players), players=num_players, num_round=2)
    for idx, info in enumerate(demo_combs):
        print(idx)
        print(info)
