import tkinter as tk
from functools import partial
from SM import State, meanings


def read_from_transitions_file(path):
    """Parse a transitions file into a list of ``[inputs, next_state]`` pairs.

    Every non-blank line is expected to look like ``a,b,c:n``: a
    comma-separated list of integers, a colon, then one integer.

    Args:
        path: Path of the text file to read.

    Returns:
        A list with one ``[inputs, next_state]`` entry per non-blank line, in
        file order, where ``inputs`` is a list of ints and ``next_state`` is
        an int.

    Raises:
        ValueError: If a field is not an integer literal.
        IndexError: If a non-blank line contains no ``:`` separator.
    """
    transitions = []
    with open(path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at the end of the
                # file) carry no transition; int("") would raise otherwise.
                continue
            edge = line.split(":")
            comb = [int(token) for token in edge[0].split(",")]
            transitions.append([comb, int(edge[1])])

    return transitions


def get_transition(transitions, input):
    """Look up the next state recorded for ``input``.

    Args:
        transitions: Sequence of entries whose first item is an input
            combination and whose second item is the resulting state.
        input: The combination to search for (compared with ``==``).

    Returns:
        The second item of the first entry whose first item equals ``input``,
        or ``None`` when no entry matches.
    """
    matching_targets = (entry[1] for entry in transitions if entry[0] == input)
    return next(matching_targets, None)


if __name__ == "__main__":
    # Load the transition table that the "get next state" callbacks query.
    transitions = read_from_transitions_file("full_transitions.txt")
    # The root window must exist before any tk variable or widget is created.
    window = tk.Tk()
    window.title("Multiple Boxes Example")

    # Node 0: variable holding the chosen initial state, and the label that
    # print_selection rewrites to echo that choice.
    n0_init = tk.IntVar()
    n0 = tk.Label(window, bg="yellow", width=20, text="Node 0")
    n0.pack()

    # initial_states[i] is the state picked for node i (None until picked).
    # all_states[i] is the list of states passed to get_transition for node i;
    # it stays empty until print_selection fills it.
    initial_states = [None] * 3
    all_states = [[] for _ in range(3)]

    def print_selection(node):
        """Radio-button callback: record and display node ``node``'s initial state.

        Rewrites the node's label with the meaning of the selected value,
        stores the value in ``initial_states[node]`` and then resets every
        entry of ``all_states`` to a fresh copy of ``initial_states``.
        """
        # Label and IntVar belonging to each node index.
        widgets_by_node = {0: (n0, n0_init), 1: (n1, n1_init), 2: (n2, n2_init)}
        if node in widgets_by_node:
            label, selected = widgets_by_node[node]
            label.config(text=f"Node{node}: " + meanings[selected.get()])
            initial_states[node] = selected.get()
        for index in range(3):
            all_states[index] = initial_states.copy()

    # Node 0 initial-state choices. Both buttons write to n0_init and report
    # through print_selection(0).
    r1 = tk.Radiobutton(
        window,
        text="Local Commit",
        variable=n0_init,
        value=int(State.LocalCommit.value),
        command=partial(print_selection, 0),
    )
    r1.pack()
    r2 = tk.Radiobutton(
        window,
        text="Local Abort",
        variable=n0_init,
        value=int(State.LocalAbort.value),
        command=partial(print_selection, 0),
    )
    r2.pack()

    # Node 1: selection variable, echo label and the same two choices,
    # reported through print_selection(1).
    n1_init = tk.IntVar()
    n1 = tk.Label(window, bg="yellow", width=20, text="Node 1")
    n1.pack()

    r3 = tk.Radiobutton(
        window,
        text="Local Commit",
        variable=n1_init,
        value=int(State.LocalCommit.value),
        command=partial(print_selection, 1),
    )
    r3.pack()
    r4 = tk.Radiobutton(
        window,
        text="Local Abort",
        variable=n1_init,
        value=int(State.LocalAbort.value),
        command=partial(print_selection, 1),
    )
    r4.pack()

    # Node 2: selection variable, echo label and the same two choices,
    # reported through print_selection(2).
    n2_init = tk.IntVar()
    n2 = tk.Label(window, bg="yellow", width=20, text="Node 2")
    n2.pack()

    r5 = tk.Radiobutton(
        window,
        text="Local Commit",
        variable=n2_init,
        value=int(State.LocalCommit.value),
        command=partial(print_selection, 2),
    )
    r5.pack()
    r6 = tk.Radiobutton(
        window,
        text="Local Abort",
        variable=n2_init,
        value=int(State.LocalAbort.value),
        command=partial(print_selection, 2),
    )
    r6.pack()

    # Label/button pair that shows the current contents of crash_nodes.
    crash_nodes_var = tk.StringVar()
    show_crash_label = tk.Label(
        window,
        bg="green",
        fg="yellow",
        font=("Arial", 12),
        width=25,
        textvariable=crash_nodes_var,
    )
    show_crash_label.pack()
    show_crash_nodes_button = tk.Button(
        window,
        text="show crash nodes",
        width=15,
        height=2,
        # crash_nodes is assigned further down; the lambda only looks the
        # name up when the button is clicked.
        command=lambda: crash_nodes_var.set(f"Crash Nodes: {crash_nodes}"),
    )
    show_crash_nodes_button.pack()

    # Item list shared by both crash-selection listboxes (lb and lb2).
    node_list_var = tk.StringVar()
    node_list_var.set((0, 1, 2))

    # Status line written by crash_selection.
    var1 = tk.StringVar()
    l = tk.Label(
        window, bg="green", fg="yellow", font=("Arial", 12), width=25, textvariable=var1
    )
    l.pack()

    # Sets shared by the callbacks defined below.
    crash_nodes = set()  # nodes marked as crashed by either crash callback
    alive_nodes = set()  # filled from {0, 1, 2} - crash_nodes by the crash callbacks
    old_crashed_nodes = set()  # crash_nodes as recorded by crash_selection (phase 1)

    def crash_selection():
        """Button callback: mark the node highlighted in ``lb`` as crashed.

        Writes the outcome to ``var1``, schedules the window to close once
        all three nodes have crashed, then refreshes ``alive_nodes`` and
        records the crashed set in ``old_crashed_nodes``.
        """
        selection = lb.curselection()
        if not selection:
            # Nothing is highlighted; lb.get(()) would raise TclError.
            var1.set("no node selected")
            return
        value = lb.get(selection[0])
        crash = int(value)
        if crash in crash_nodes:
            var1.set(f"node {value} already crashed")
        else:
            crash_nodes.add(crash)
            var1.set(f"node {value} crashed")

        if len(crash_nodes) == 3:
            var1.set("ALL node crash!")
            window.after(50, window.destroy)
        # Rebuild instead of merging: a plain update() would leave a node that
        # was alive after an earlier click in alive_nodes after it crashes.
        alive_nodes.clear()
        alive_nodes.update({0, 1, 2} - crash_nodes)
        old_crashed_nodes.update(crash_nodes)

    # Button that applies the crash chosen in the listbox below it.
    select = tk.Button(
        window,
        text="phase 1 crash selection",
        width=20,
        height=2,
        command=crash_selection,
    )
    select.pack()

    # Listbox offering the node ids held in node_list_var.
    lb = tk.Listbox(window, listvariable=node_list_var, height=3)
    lb.pack()

    # Cursors advanced by select_receive on every call.
    check_node_index = 0  # index into list(alive_nodes) of the next node to handle
    crash_node_index = 0  # index into list(crash_nodes) of the next crashed node

    def select_receive():
        """Radio-button callback for the phase 1 receive/not-receive choice.

        Pairs the alive node at ``check_node_index`` with the crashed node at
        ``crash_node_index``. When ``receive_var`` is not 1, the alive node's
        entry for that crashed node in ``all_states`` becomes ``Lost_R1``.
        The outcome is shown in ``rec`` and both cursors are advanced; the
        crashed-node cursor wraps around.
        """
        global check_node_index
        global crash_node_index
        if check_node_index >= len(alive_nodes):
            rec.config(text="all alive nodes have been checked")
            return

        crashed = list(crash_nodes)[crash_node_index]
        receiver = list(alive_nodes)[check_node_index]
        if receive_var.get() != 1:
            all_states[receiver][crashed] = int(State.Lost_R1.value)
            rec.config(text=f"Node{receiver} didn't receive from Node{crashed}")
        else:
            rec.config(text=f"Node{receiver} receive from Node{crashed}")
        check_node_index += 1
        crash_node_index = (crash_node_index + 1) % len(crash_nodes)

    # Phase 1 receive selection: receive_var is 1 for "Receive" and 0 for
    # "Not Receive"; rec is the label select_receive writes its result to.
    receive_var = tk.IntVar()
    rec = tk.Label(window, bg="yellow", width=25, text="Phase 1 Receive Selection")
    rec.pack()
    receive_button = tk.Radiobutton(
        window, text="Receive", variable=receive_var, value=1, command=select_receive
    )
    receive_button.pack()
    not_receive_button = tk.Radiobutton(
        window,
        text="Not Receive",
        variable=receive_var,
        value=0,
        command=select_receive,
    )
    not_receive_button.pack()

    # show the state of the nodes
    def get_next_state():
        """Button callback: compute and display each node's next state.

        A crashed node gets a "crashed" label and contributes state 0. Every
        other node looks up its next state with ``get_transition`` using its
        entry in ``all_states`` and shows the transition on its label.
        Afterwards every entry of ``all_states`` is replaced by a copy of the
        list of next states.
        """
        next_states = []
        for n in [0, 1, 2]:
            if n in crash_nodes:
                node_labels[n].config(text=f"Node {n}: crashed")
                next_states.append(0)
                continue
            next_state = get_transition(transitions, all_states[n])
            next_states.append(next_state)
            view = [meanings[state] for state in all_states[n]]
            # Label the row with the node being shown, not always node 0.
            node_labels[n].config(
                text=f"Node {n}: {view} --> {meanings[next_state]}"
            )
        for index in range(3):
            all_states[index] = next_states.copy()

    # Button that runs get_next_state, followed by one result label per node.
    get_next_state_bt = tk.Button(
        window, text="get next state", width=15, height=2, command=get_next_state
    )
    get_next_state_bt.pack()
    # node_labels[n] is the label get_next_state updates for node n.
    node_labels = []
    node0_label = tk.Label(window, bg="yellow", width=55, text="Node 0")
    node0_label.pack()
    node_labels.append(node0_label)
    # The text is rendered once, at construction time, from the current
    # contents of all_states[1].
    node1_label = tk.Label(window, bg="yellow", width=55, text=str(all_states[1]))
    node1_label.pack()
    node_labels.append(node1_label)
    # Same as above, for all_states[2].
    node2_label = tk.Label(window, bg="yellow", width=55, text=str(all_states[2]))
    node2_label.pack()
    node_labels.append(node2_label)

    # Status line written by phase_2_crash_selection.
    var2 = tk.StringVar()
    l2 = tk.Label(
        window, bg="green", fg="yellow", font=("Arial", 12), width=25, textvariable=var2
    )
    l2.pack()
    # Nodes added to crash_nodes by phase_2_crash_selection.
    new_crash_nodes = set()

    def phase_2_crash_selection():
        """Button callback: mark the node highlighted in ``lb2`` as crashed.

        A node that was not crashed before is added to both
        ``new_crash_nodes`` and ``crash_nodes``. The outcome is written to
        ``var2``, the window is scheduled to close once all three nodes have
        crashed, and ``alive_nodes`` is rebuilt from the remaining nodes.
        """
        selection = lb2.curselection()
        if not selection:
            # Nothing is highlighted; lb2.get(()) would raise TclError.
            var2.set("no node selected")
            return
        value = lb2.get(selection[0])
        crash = int(value)
        if crash in crash_nodes:
            var2.set(f"node {value} already crashed")
        else:
            new_crash_nodes.add(crash)
            crash_nodes.add(crash)
            var2.set(f"node {value} crashed")

        if len(crash_nodes) == 3:
            var2.set("ALL node crash!")
            window.after(50, window.destroy)
        alive_nodes.clear()
        alive_nodes.update({0, 1, 2} - crash_nodes)

    # Button that applies the crash chosen in the listbox below it.
    phase2 = tk.Button(
        window,
        text="phase 2 crash selection",
        width=20,
        height=2,
        command=phase_2_crash_selection,
    )
    phase2.pack()
    # Second listbox over the same node_list_var items.
    lb2 = tk.Listbox(window, listvariable=node_list_var, height=3)
    lb2.pack()

    # Determine whether the alive nodes receive messages from the crashed nodes.
    # Both cursors are advanced by select_receive_2 on every handled call.
    check_node_index_2 = 0  # index into list(alive_nodes)
    crash_node_index_2 = 0  # index into list(new_crash_nodes)

    def select_receive_2():
        """Radio-button callback for the phase 2 receive/not-receive choice.

        Does nothing but report when no node crashed in phase 2. Otherwise
        pairs the alive node at ``check_node_index_2`` with the newly crashed
        node at ``crash_node_index_2``. When ``receive_var_2`` is not 1, the
        alive node's entry for that crashed node in ``all_states`` becomes
        ``Lost_R2``. The outcome is shown in ``rec_2`` and both cursors are
        advanced; the crashed-node cursor wraps around.
        """
        global check_node_index_2
        global crash_node_index_2
        if len(new_crash_nodes) == 0:
            rec_2.config(text="no new crash nodes")
            return
        if check_node_index_2 >= len(alive_nodes):
            rec_2.config(text="all alive nodes have been checked")
            return

        crash_n = list(new_crash_nodes)[crash_node_index_2]
        checking_node = list(alive_nodes)[check_node_index_2]
        # Read the variable the phase 2 radio buttons are bound to; the
        # phase 1 variable (receive_var) reflects a different pair of buttons.
        if receive_var_2.get() == 1:
            rec_2.config(
                text="Node"
                + str(checking_node)
                + " receive from Node"
                + str(crash_n)
            )
        else:
            all_states[checking_node][crash_n] = int(State.Lost_R2.value)
            rec_2.config(
                text="Node"
                + str(checking_node)
                + " didn't receive from Node"
                + str(crash_n)
            )
        check_node_index_2 = check_node_index_2 + 1
        crash_node_index_2 = (crash_node_index_2 + 1) % len(new_crash_nodes)

    # Phase 2 receive selection: receive_var_2 is 1 for "Receive" and 0 for
    # "Not Receive"; rec_2 is the label select_receive_2 writes its result to.
    receive_var_2 = tk.IntVar()
    rec_2 = tk.Label(window, bg="yellow", width=25, text="Phase 2 Receive Selection")
    rec_2.pack()
    receive_button_2 = tk.Radiobutton(
        window,
        text="Receive",
        variable=receive_var_2,
        value=1,
        command=select_receive_2,
    )
    receive_button_2.pack()
    not_receive_button_2 = tk.Radiobutton(
        window,
        text="Not Receive",
        variable=receive_var_2,
        value=0,
        command=select_receive_2,
    )
    not_receive_button_2.pack()

    def get_next_state_2():
        """Button callback: compute and display each node's phase 2 next state.

        First, every alive node's entry for each node in
        ``old_crashed_nodes`` is set to ``Lost_R2``. Then a crashed node gets
        a "crashed" label, and every other node looks up its next state with
        ``get_transition`` and shows the transition on its label.
        ``all_states`` is not overwritten with the results.
        """
        lost_r2 = int(State.Lost_R2.value)
        for n in alive_nodes:
            for old_crashed_node in old_crashed_nodes:
                all_states[n][old_crashed_node] = lost_r2

        next_states = []
        for n in [0, 1, 2]:
            if n in crash_nodes:
                node_labels_2[n].config(text=f"Node {n}: crashed")
                next_states.append(0)
                continue
            next_state = get_transition(transitions, all_states[n])
            next_states.append(next_state)
            view = [meanings[state] for state in all_states[n]]
            # Label the row with the node being shown, not always node 0.
            node_labels_2[n].config(
                text=f"Node {n}: {view} --> {meanings[next_state]}"
            )

    # Button that runs get_next_state_2, followed by one result label per node.
    get_next_state_bt_2 = tk.Button(
        window, text="get next state", width=15, height=2, command=get_next_state_2
    )
    get_next_state_bt_2.pack()
    # node_labels_2[n] is the label get_next_state_2 updates for node n.
    node_labels_2 = []
    node0_label_2 = tk.Label(window, bg="yellow", width=55, text="Node 0")
    node0_label_2.pack()
    node_labels_2.append(node0_label_2)
    # The text is rendered once, at construction time, from the current
    # contents of all_states[1].
    node1_label_2 = tk.Label(window, bg="yellow", width=55, text=str(all_states[1]))
    node1_label_2.pack()
    node_labels_2.append(node1_label_2)
    # Same as above, for all_states[2].
    node2_label_2 = tk.Label(window, bg="yellow", width=55, text=str(all_states[2]))
    node2_label_2.pack()
    node_labels_2.append(node2_label_2)

    # Hand control to Tk; the callbacks above run from inside this loop.
    window.mainloop()
