from config import REWARDS
from atomic_commit.State import State
from utils import log


class AtomicCommitEnv:
    """Per-node environment wrapper around one state machine of an atomic-commit run.

    Each instance tracks the bookkeeping for a single node (round counter,
    crash flag, known crashed nodes, last observed states/messages) and
    evaluates that node's reward once the protocol is marked as done.
    """

    def __init__(self, players, index, num_round, multi_sm, is_history, encode_id, is_training=True):
        """Create the environment for node ``index``.

        Args:
            players: Number of participating nodes.
            index: Position of this node in ``multi_sm``.
            num_round: Total number of rounds; compared against
                ``current_round`` in :meth:`reward`.
            multi_sm: Indexable collection holding the state machine of every
                node; ``multi_sm[index]`` becomes this node's state machine.
            is_history: Stored as ``encode_history`` (not used by this class).
            encode_id: Stored as ``encode_id`` (not used by this class).
            is_training: Stored as ``is_training`` (not used by this class).
        """
        self.idx = index
        self.is_crash = False
        self.players = players
        self.obs_shape = players * 2
        self.initial_states = [0] * self.players
        self.current_round = 1
        self.num_round = num_round
        self.global_state_machine = multi_sm
        self.state_machine = self.global_state_machine[index]
        self.crash_nodes = set()

        # Both stay None until something outside this class assigns them.
        self.states = None
        self.prev_messages = None
        self.is_training = is_training
        self.encode_history = is_history
        self.encode_id = encode_id
        self._done = False

        self._backup = {}  # backup the state variables when store() is called

    def my_reset(self) -> None:
        """Reset per-episode fields and the node's state machine.

        Note that any snapshot taken with :meth:`store` is left untouched.
        """
        self.is_crash = False
        self.initial_states = [0] * self.players
        self.current_round = 1
        self.state_machine.reset()
        self.crash_nodes = set()
        self.states = None
        self.prev_messages = None
        self._done = False

    def crash(self) -> None:
        """Mark this node as crashed."""
        self.is_crash = True

    def transit(self, action):
        """Forward ``action`` to this node's state machine."""
        self.state_machine.transit(action)

    def is_decided(self) -> bool:
        """Check if this node already made a decision"""
        return self.state_machine.final_decision is not None

    def get_decision(self):
        """Get the decision of this node"""
        return self.state_machine.final_decision

    def get_prev_action(self):
        """Get the previous action of this node"""
        return self.state_machine.get_prev_state()

    def store(self) -> None:
        """Snapshot the fields that :meth:`restore` will put back.

        Saved fields: ``is_crash``, ``current_round``, ``crash_nodes``,
        ``states``, ``prev_messages`` and ``_done``. ``initial_states`` and the
        state machine itself are not part of the snapshot.

        ``states`` and ``prev_messages`` are ``None`` right after construction
        or :meth:`my_reset`; in that case ``None`` is stored instead of
        slicing, which would raise ``TypeError``.
        """
        states = self.states
        prev_messages = self.prev_messages
        self._backup = {
            "is_crash": self.is_crash,
            "current_round": self.current_round,
            "crash_nodes": self.crash_nodes.copy(),
            # [:] is a one-level slice copy; nested elements are shared.
            "states": None if states is None else states[:],
            "prev_messages": None if prev_messages is None else prev_messages[:],
            "_done": self._done,
        }

    def restore(self) -> None:
        """Restore the values saved by the last :meth:`store` call.

        The snapshot is consumed: a second ``restore()`` without a new
        ``store()`` raises.

        Raises:
            ValueError: If there is no snapshot to restore.
        """
        if not self._backup:
            raise ValueError("No backup to restore")
        for name, value in self._backup.items():
            setattr(self, name, value)
        self._backup = {}

    def reward(self):
        """Evaluate this node's reward at the end of the protocol.

        Returns:
            A ``(reward, flag)`` tuple. ``reward`` is ``0`` when evaluation is
            skipped, ``REWARDS["Bad"]`` when a rule is violated and
            ``REWARDS["Good"][0]`` otherwise. ``flag`` is ``True`` for
            violations of R2-R5, ``False`` for R6 and ``None`` in every other
            case (the original docstring calls it "is global fault").
        """
        this_func = "AtomicCommitSMEnv_reward"
        if not self._done:
            log(this_func, 2, f"Protocol is not done. Current round is {self.current_round}, no rewards")
            return 0, None

        # Skip reward evaluation if the node is crashed and:
        # - it is not in the last round (crashed earlier), or
        # - it is in the last round but found a 'Lost' state (also implies earlier crash)
        if self.is_crash and self.current_round < self.num_round:
            return 0, None
        transitions = self.state_machine.get_transitions()
        if self.is_crash and State.Lost.value in transitions:
            return 0, None

        log(this_func, 2, f"Current round is {self.current_round}")
        log(this_func, 2, f"Initial states: {self.initial_states}")
        all_commits = all(s == State.LocalCommit.value for s in self.initial_states)
        any_abort = any(s == State.LocalAbort.value for s in self.initial_states)
        my_decision = self.state_machine.final_decision
        all_decisions = [sm.final_decision for sm in self.global_state_machine]
        log(this_func, 2, f"Transitions: {transitions}")
        log(this_func, 2, f"My Decision: {my_decision}")

        has_abort = State.Abort.value in transitions
        has_commit = State.Commit.value in transitions

        # R1: Cannot change action from abort->commit or commit -> abort
        if has_abort and has_commit:
            log(this_func, 2, "change action, negative rewards")
            return REWARDS["Bad"], None

        # R2: If all initial states are LocalCommit and no crash, should COMMIT
        if all_commits and has_abort and not self.crash_nodes:
            log(this_func, 2, "all localcommits and no crash: Abort, negative rewards")
            return REWARDS["Bad"], True

        # R3: If there is localabort in initial states, should not commit
        if any_abort and has_commit:
            log(this_func, 2, "There is localabort but Commit: negative rewards")
            return REWARDS["Bad"], True

        # R4: every node that reaches this point must have decided. Crash
        # state is not re-checked here; the only crash cases excluded are the
        # ones skipped at the top of this method.
        if my_decision is None:
            log(this_func, 2, "no decision at the end and no crash, negative rewards")
            return REWARDS["Bad"], True

        # R5: Decision conflicts (my_decision is known to be set after R4;
        # undecided peers are ignored)
        for i, decision in enumerate(all_decisions):
            if decision is None:
                continue
            if my_decision != decision:
                log(
                    this_func,
                    2,
                    f"My Decision: {my_decision}, Node {i} decision: {decision}.Decision conflicts, negative rewards",
                )
                return REWARDS["Bad"], True

        # R6: Only make decision at one round
        if transitions.count(State.Abort.value) > 1 or transitions.count(State.Commit.value) > 1:
            log(this_func, 2, "Make decision at more than one round, negative rewards")
            return REWARDS["Bad"], False

        return REWARDS["Good"][0], None
