from config import REWARDS
from primary_backup.State import State
from utils import log


class PrimaryBackupEnv:
    """Per-node environment state for a primary-backup protocol run.

    Each instance is bound to one node (``index``) and to that node's entry in
    the shared list of state machines (``multi_sm``).  It tracks the round
    counter, crash status and the bookkeeping needed to compute the node's
    reward once the protocol is done.
    """

    def __init__(self, players, index, num_round, multi_sm, is_history, encode_id, is_training=True):
        """Initialise the environment for node ``index``.

        Args:
            players: Number of participating nodes.
            index: Position of this node in ``multi_sm``.
            num_round: Round count that ``reward`` compares ``current_round``
                against when the node has crashed.
            multi_sm: Sequence of state machines, one per node; the entry at
                ``index`` becomes this node's own state machine.
            is_history: Stored as ``encode_history`` (not used in this class).
            encode_id: Stored as ``encode_id`` (not used in this class).
            is_training: Stored as ``is_training`` (not used in this class).
        """
        self.idx = index
        self.is_crash = False
        self.players = players
        self.obs_shape = players * 2
        self.initial_states = [0] * self.players
        self.current_round = 1
        self.num_round = num_round
        self.global_state_machine = multi_sm
        self.state_machine = self.global_state_machine[index]
        self.crash_nodes = set()

        # Populated by callers; None until then.
        self.states = None
        self.prev_messages = None
        self.is_training = is_training
        self.encode_history = is_history
        self.encode_id = encode_id
        self._done = False

        self._backup = {}  # backup the state variables when store() is called

    def action_state_map(self, action):
        """Return the ``State`` value corresponding to ``action``.

        ``State(action)`` is evaluated, so an action that ``State`` does not
        accept fails here rather than further downstream.
        """
        assert State(action).value == action
        return State(action).value

    # NOTE(review): `_action_spec`, `_observation_spec` and `_reward_spec` are
    # not assigned anywhere in this class; presumably a subclass or external
    # code sets them -- confirm before relying on these accessors.
    def action_spec(self):
        """Return the stored action spec."""
        return self._action_spec

    def observation_spec(self):
        """Return the stored observation spec."""
        return self._observation_spec

    def reward_spec(self):
        """Return the stored reward spec."""
        return self._reward_spec

    def my_reset(self) -> None:
        """Reset per-episode fields and the node's state machine."""
        self.is_crash = False
        self.initial_states = [0] * self.players
        self.current_round = 1
        self.state_machine.reset()
        self.crash_nodes = set()
        self.states = None
        self.prev_messages = None
        self._done = False

    def crash(self) -> None:
        """Mark this node as crashed."""
        self.is_crash = True

    def transit(self, action) -> None:
        """Forward ``action`` to this node's state machine."""
        self.state_machine.transit(action)

    def is_decided(self) -> bool:
        """Check if this node already made a decision."""
        return self.state_machine.final_decision is not None

    def get_decision(self):
        """Get the decision of this node (``None`` if not decided yet)."""
        return self.state_machine.final_decision

    def get_prev_action(self):
        """Get the previous action of this node."""
        return self.state_machine.get_prev_state()

    def store(self) -> None:
        """Snapshot the fields that ``restore`` can later bring back.

        ``states`` and ``prev_messages`` are ``None`` after construction and
        after ``my_reset``; in that case ``None`` is stored as-is instead of
        being sliced (which would raise ``TypeError``).
        """
        self._backup = {
            "is_crash": self.is_crash,
            "current_round": self.current_round,
            "crash_nodes": self.crash_nodes.copy(),
            "states": None if self.states is None else self.states[:],
            "prev_messages": None if self.prev_messages is None else self.prev_messages[:],
            "_done": self._done,
        }

    def restore(self) -> None:
        """Restore the values saved by the last ``store`` call.

        The snapshot is consumed: a second ``restore`` without an intervening
        ``store`` raises.

        Raises:
            ValueError: If there is no snapshot to restore.
        """
        if not self._backup:
            raise ValueError("No backup to restore")
        for name, value in self._backup.items():
            setattr(self, name, value)
        self._backup = {}

    def reward(self):
        """Compute this node's reward once the protocol is done.

        Returns:
            A ``(reward, flag)`` tuple.  ``reward`` is ``0`` when nothing is
            evaluated (protocol not done, or the node crashed early),
            ``REWARDS["Bad"]`` when one of the rules R1-R5 is violated, and
            ``REWARDS["Good"][0]`` otherwise.  ``flag`` is the "is global
            fault" indicator: ``True`` for R2-R4, ``False`` for R1 and R5,
            ``None`` when no violation is reported.
        """
        this_func = f"PrimaryBackupEnv_reward: Node {self.idx}"
        if not self._done:
            log(this_func, 2, f"Protocol is not done. Current round is {self.current_round}, no rewards")
            return 0, None

        # Skip reward evaluation if the node is crashed and:
        # - it is not in the last round (crashed earlier), or
        # - it is in the last round but found a 'Lost' state (also implies earlier crash)
        if self.is_crash:
            if self.current_round < self.num_round:
                return 0, None
            if State.Lost.value in self.state_machine.get_transitions():
                return 0, None

        log(this_func, 2, f"Current round is {self.current_round}")
        log(this_func, 2, f"Initial states: {self.initial_states}")
        transitions = self.state_machine.get_transitions()
        my_decision = self.state_machine.final_decision
        all_decisions = [sm.final_decision for sm in self.global_state_machine]
        log(this_func, 2, f"Transitions: {transitions}")
        log(this_func, 2, f"My Decision: {my_decision}")

        # R1: Cannot change decision (the first transition entry is excluded).
        later_transitions = transitions[1:]
        if State.Zero.value in later_transitions and State.One.value in later_transitions:
            log(this_func, 2, "Changed decisions, negative rewards")
            return REWARDS["Bad"], False

        # R2: The decision doesn't exist in the initial states
        if my_decision is not None:
            all_zero = self.initial_states.count(State.LocalZero.value) == self.players
            if all_zero and my_decision == State.One.value:
                log(this_func, 2, "Decision doesn't exist in the initial states, negative rewards")
                return REWARDS["Bad"], True
            all_one = self.initial_states.count(State.LocalOne.value) == self.players
            if all_one and my_decision == State.Zero.value:
                log(this_func, 2, "Decision doesn't exist in the initial states, negative rewards")
                return REWARDS["Bad"], True

        # R3: There is no decision at the end
        if my_decision is None:
            log(this_func, 2, "No decision at the end, negative rewards")
            return REWARDS["Bad"], True

        # R4: Decision conflicts
        # my_decision is not None here; nodes without a decision are ignored.
        if any(decision is not None and decision != my_decision for decision in all_decisions):
            return REWARDS["Bad"], True

        # R5: Only make decision at one round
        if transitions.count(State.Zero.value) > 1 or transitions.count(State.One.value) > 1:
            log(this_func, 2, "Make decision at more than one round, negative rewards")
            return REWARDS["Bad"], False

        return REWARDS["Good"][0], None
