import random
import numpy as np

from config import CRASH_RATIO, RECEIVE_CRASH_RATIO
from primary_backup.State import State
from utils import construct_bits_string, evenly_distributed_crash_with_survival, log


def reset_envs(envs, players, input=None):
    """Reset every environment and set up the first round.

    Each environment is reset via ``my_reset()``, an initial bit assignment is
    chosen (randomly when ``input`` is None, otherwise from
    ``input.get_init()``), every node's state machine is seeded with its own
    initial state, and finally ``insert_failure`` is invoked for round 1.

    Args:
        envs: Per-node environments, indexed by node id.
        players: Number of nodes; one initial bit is produced per node.
        input: Optional scenario object exposing ``get_init()`` and
            ``get_crash_info(round_index)``. When None, the initial bits are
            drawn uniformly from ``[0, 2**players - 1]``.

    Returns:
        Whatever ``insert_failure`` returns for round 1.
    """
    for env in envs:
        env.my_reset()

    if input is None:
        s = random.randint(0, 2**players - 1)
    else:
        s = input.get_init()

    # presumably construct_bits_string yields one '0'/'1' per node -- verify
    # against utils.construct_bits_string.
    bits = list(construct_bits_string(s, players))
    bit_to_state = {
        0: int(State.LocalZero.value),
        1: int(State.LocalOne.value),
    }
    # Any value other than 0/1 keeps the placeholder 0, as before.
    initial_states = [bit_to_state.get(int(bits[i]), 0) for i in range(players)]

    for i, env in enumerate(envs):
        # Every node gets its own copy of the full assignment.
        env.initial_states = initial_states[:]
        env.state_machine.set_initial(initial_states[i])

    # The random-initialisation path has no scenario object to query, so do
    # not dereference it here; forward None and let insert_failure decide how
    # to handle a missing crash description.
    crash_info = None if input is None else input.get_crash_info(0)
    return insert_failure(
        envs,
        initial_states,
        1,
        envs[0].encode_history,
        envs[0].encode_id,
        crash_info,
    )


def step_envs(envs, actions, players, input=None):
    """Advance all environments by one round.

    Actions belonging to crashed nodes are overwritten with
    ``State.Lost.value`` directly in ``actions`` (the caller's list is
    modified). Every environment then transitions on its action. On the last
    round all environments are flagged done; otherwise the failure pattern
    for the next round is injected through ``insert_failure``.

    Args:
        envs: Per-node environments, indexed by node id.
        actions: One action value per node; updated in place.
        players: Number of nodes to process.
        input: Forwarded unchanged to ``insert_failure`` on non-final rounds.
            NOTE(review): ``reset_envs`` forwards ``input.get_crash_info(0)``
            whereas this function forwards ``input`` itself -- confirm the
            caller already passes the per-round crash description here.

    Returns:
        The result of ``insert_failure`` on a non-final round, None on the
        final round.
    """
    first_env = envs[0]
    round_now = first_env.current_round
    is_last_round = round_now == first_env.num_round

    # Mask the actions of crashed nodes.
    for node in range(players):
        node_env = envs[node]
        crashed = node_env.is_crash
        if actions[node] == State.Lost.value:
            continue
        if not crashed:
            continue
        # Before the last round a crashed node can never act. On the last
        # round a node may crash *after* deciding, so its action is dropped
        # only if it was already lost in the previous round.
        if not is_last_round or node_env.get_prev_action() == State.Lost.value:
            actions[node] = State.Lost.value

    for node in range(players):
        envs[node].transit(actions[node])

    if not is_last_round:
        return insert_failure(
            envs,
            actions[:],
            round_now + 1,
            first_env.encode_history,
            first_env.encode_id,
            input,
        )

    for env in envs:
        env._done = True


def insert_failure(envs, input_states, round, history, encode_id, input=None):
    """Apply one round's crash / message-loss pattern and build observations.

    The pattern is read from ``input.get_info()``, which must return
    ``(new_crashed_nodes, alive, r_mat)``. Newly crashed nodes are marked via
    ``env.crash()``. Every receiver gets a copy of ``input_states`` in which
    the entry of each newly crashed sender is replaced by ``State.Lost.value``
    whenever ``r_mat[receiver_position][sender_position]`` is falsy. Receivers
    are the sorted ``alive`` nodes, plus the newly crashed nodes when
    ``round`` equals ``num_round``. All other environments see
    ``input_states`` unchanged.

    Each environment's ``states`` is then prefixed with ``round - 1``
    (preceded by the node index when ``encode_id`` is truthy) and, when
    ``history`` is truthy, interleaved with ``env.prev_messages`` through
    ``insert_history``. Finally the shared ``crash_nodes`` set and each
    environment's ``prev_messages`` are updated.

    Args:
        envs: Per-node environments, indexed by node id.
        input_states: Values sent by every node this round.
        round: 1-based round number being prepared.
        history: Whether to interleave the previous round's messages.
        encode_id: Whether to prepend the node index to the observation.
        input: Object exposing ``get_info()``; required despite the default.

    Returns:
        None. The environments are updated in place.

    Raises:
        ValueError: If ``input`` is None.
    """
    if input is None:
        # Fail with an explicit message instead of an opaque AttributeError
        # on ``None.get_info()``.
        raise ValueError(
            "insert_failure requires an `input` object providing get_info(); got None"
        )

    this_func = f"PB: insert_failure, round{round}"
    sample_env = envs[0]
    crash_nodes = sample_env.crash_nodes

    # The crash pattern is dictated by the caller-supplied input.
    new_crashed_nodes, alive, r_mat = input.get_info()
    if round == sample_env.num_round:  # round is 1-based
        receivers = sorted(alive + new_crashed_nodes)
    else:
        receivers = sorted(alive)

    log(this_func, 2, f"new crashed nodes: {new_crashed_nodes}")
    for i in new_crashed_nodes:
        envs[i].crash()

    # Default view: everybody sees the unmodified messages.
    for env in envs:
        env.states = input_states[:]

    # Receivers get their own view with the lost messages masked out.
    for r_idx, r_node in enumerate(receivers):
        assert r_node not in crash_nodes, f"Receiver {r_node} is already crashed!"
        next_states = input_states[:]

        # Apply receiver's message-loss pattern
        for c_idx, c_node in enumerate(new_crashed_nodes):
            if not r_mat[r_idx][c_idx]:  # Message lost
                next_states[c_node] = State.Lost.value
                log(this_func, 2, f"receiver {r_node} ← sender {c_node}: LOST")
            else:
                log(this_func, 2, f"receiver {r_node} ← sender {c_node}: RECEIVED")

        envs[r_node].states = next_states

    # Snapshot the raw per-node messages before they get decorated below.
    final_messages = [env.states[:] for env in envs]

    # Insert history states of each node and update node state
    for i, env in enumerate(envs):
        prefix = [i, round - 1] if encode_id else [round - 1]
        body = insert_history(env, env.states) if history else env.states
        env.states = prefix + body

    # Update crash_nodes; every environment ends up sharing the same set.
    crash_nodes.update(new_crashed_nodes)
    for env in envs:
        env.crash_nodes = crash_nodes

    # prev_messages must be updated only after insert_history has run,
    # because insert_history reads env.prev_messages.
    for env, messages in zip(envs, final_messages):
        env.prev_messages = messages


def insert_history(env, input_state):
    """Interleave each current value with its previous-round counterpart.

    For every position ``k`` the result holds ``input_state[k]`` followed by
    ``env.prev_messages[k]``, or by ``State.Dummy.value`` when
    ``env.prev_messages`` is None.

    Args:
        env: Environment whose ``prev_messages`` supplies the history.
        input_state: Current-round values.

    Returns:
        A new list twice as long as ``input_state``.
    """
    interleaved = []
    for idx, current in enumerate(input_state):
        past = env.prev_messages
        previous = State.Dummy.value if past is None else past[idx]
        interleaved.extend((current, previous))
    return interleaved
