from enum import Enum

'''
    Digits for Action:
    4: Abort
    5: Commit
    6: Do nothing
    3: Lost

    Digits in state/message:
    0: no information (fill unneeded positions)
    1: abort
    2: commit
    3: Lost
'''


class Act(Enum):
    """Digit codes for the actions a node can take."""
    Abort = 4
    Commit = 5
    DoNothing = 6
    Lost = 3  # same digit as Stat.Lost


class Stat(Enum):
    """Digit codes that appear in states and messages."""
    Dummy = 0  # no information; used to fill unneeded positions
    Abort = 1
    Commit = 2
    # Lost uses the same digit as Act.Lost.
    # NOTE(review): the previous comment here read "skip 2 to differentiate
    # message loss and DoNothing", but 2 is Commit in this enum -- the remark
    # looks stale; confirm the intended numbering.
    Lost = 3


def validate(env, actions):
    """Compute the per-player rewards for the current round.

    Outside round 2 (``env.current_round != 2``) every player receives 0.0.
    In round 2 each non-crashed player is scored by the rules below; a rule
    that assigns a penalty ends the scoring for that player:

    1. Reversing the previous action (Commit -> Abort or Abort -> Commit)
       gives -1.
    2. If every initial state is Commit and no node crashed, a final decision
       of Commit earns 0.5 and any other final decision gives -1.
    3. If any initial state is Abort, a final decision of Commit gives -1 and
       a final decision of Abort earns 0.5.
    4. Choosing DoNothing in both rounds gives -0.5, replacing any bonus
       earned above.

    Finally, if the remaining final decisions (ignoring DoNothing, and ignoring
    ``None`` entries when some node crashed) are not unanimous, every
    non-crashed player gets -1.

    Args:
        env: Environment object. Reads ``players``, ``current_round``,
            ``initial_states``, ``crash_nodes``, ``prev_action`` and
            ``final_decision``.
        actions: Per-player action digits (``Act`` values) for this round.

    Returns:
        A list with one reward per player; crashed players keep 0.0.

    Side effects:
        In round 2, ``env.final_decision`` is replaced by the filtered list
        used for the agreement check, so afterwards its indices no longer
        line up with player indices.
    """
    rewards = [0.] * env.players
    if env.current_round != 2:
        return rewards

    has_crashes = len(env.crash_nodes) > 0
    all_commits = all(s == Stat.Commit.value for s in env.initial_states)
    any_aborts = any(s == Stat.Abort.value for s in env.initial_states)

    for i in range(env.players):
        if i in env.crash_nodes:
            continue

        prev, curr = env.prev_action[i], actions[i]

        # A decision taken in the previous round must not be reversed.
        if prev == Act.Commit.value and curr == Act.Abort.value:
            rewards[i] = -1
            continue
        if prev == Act.Abort.value and curr == Act.Commit.value:
            rewards[i] = -1
            continue

        # Everyone started with Commit and nobody crashed: must commit.
        if all_commits and not has_crashes:
            if env.final_decision[i] != Act.Commit.value:
                rewards[i] = -1
                continue
            rewards[i] = max(rewards[i], 0.5)

        # Somebody started with Abort: committing is wrong, aborting is right.
        if any_aborts:
            if env.final_decision[i] == Act.Commit.value:
                rewards[i] = -1
                continue
            if env.final_decision[i] == Act.Abort.value:
                rewards[i] = max(rewards[i], 0.5)

        # Staying idle in both rounds is penalised.
        if prev == Act.DoNothing.value and curr == Act.DoNothing.value:
            rewards[i] = -0.5
            continue

    # Agreement check: keep only real decisions. None entries are dropped
    # only when some node crashed, exactly as before.
    decisions = env.final_decision
    if has_crashes:
        decisions = [a for a in decisions if a is not None]
    decisions = [a for a in decisions if a != Act.DoNothing.value]
    env.final_decision = decisions

    if len(set(decisions)) > 1:
        for i in range(env.players):
            if i not in env.crash_nodes:
                rewards[i] = -1

    return rewards