from collections import defaultdict
import itertools
import random
import numpy as np
from MultiSM import MultiStateMachine
from atomic_commit.AtomicCommitEnv import AtomicCommitEnv
from atomic_commit.State import State as ACState
from config import CUTOF_INDEX
from generate_combs import CrashInfo, Info
from import_functions import get_action_space, import_reset_envs, import_step_envs
from primary_backup.PrimaryBackupEnv import PrimaryBackupEnv
from primary_backup.State import State as PBState
from distributed_locking.DistributedLockingEnv import DistributedLockingEnv
from distributed_locking.State import State as DLState
from utils import hash_observation, log


class EnvManager:
    """Manage one protocol environment per node and step them in lock-step.

    The per-node environments all share a single ``MultiStateMachine``. This
    class enumerates candidate joint actions, advances every environment one
    round, collects rewards, and snapshots/restores the whole ensemble.
    """

    def __init__(self, players, protocol, num_rounds, is_history, encode_id):
        """Create the per-node environments for ``protocol``.

        Args:
            players: Number of nodes; one environment is created per node id.
            protocol: One of ``"primary_backup"``, ``"distributed_locking"`` or
                ``"atomic_commit"``.
            num_rounds: Number of rounds in an episode.
            is_history: Forwarded to the primary-backup / atomic-commit envs.
            encode_id: Forwarded to the primary-backup / atomic-commit envs.

        Raises:
            ValueError: If ``protocol`` is not a supported protocol name.
        """
        self.players = players
        self.round = 0  # 0-based round number
        # Nodes that crashed in the current round. This must be a set (``{}`` would
        # be an empty dict): get_actions() evaluates ``crash_idx - self.newly_crash``.
        self.newly_crash: set = set()
        self.num_rounds = num_rounds
        self.done = False
        self.action_space = get_action_space(protocol)
        self.multi_sm = MultiStateMachine(players, protocol)
        self.reset_envs = import_reset_envs(protocol)
        self.step_envs = import_step_envs(protocol)
        self.receive_comb = {i: [] for i in range(num_rounds)}
        if protocol == "primary_backup":
            self.envs = [
                PrimaryBackupEnv(players, i, num_rounds, self.multi_sm, is_history, encode_id, is_training=True)
                for i in range(players)
            ]
            self.all_actions = [action.value for action in list(PBState)[: self.action_space]]
            self.lost_action = PBState.Lost.value
            # Every round except the last only allows the actions from CUTOF_INDEX on.
            self.actions_per_round = {
                i: self.all_actions[CUTOF_INDEX:] if i < num_rounds - 1 else self.all_actions for i in range(num_rounds)
            }

        elif protocol == "distributed_locking":
            self.envs = [DistributedLockingEnv(players, i, num_rounds, self.multi_sm, is_training=True) for i in range(players)]
            self.all_actions = [DLState.Enter.value, DLState.NoEnter.value]
            # NOTE(review): only round 0 has an action set and no ``lost_action`` is
            # defined, so get_actions() would fail for round > 0 or when a node is
            # crashed. Confirm this protocol is only run for one crash-free round.
            self.actions_per_round = {0: self.all_actions}

        elif protocol == "atomic_commit":
            self.envs = [
                AtomicCommitEnv(players, i, num_rounds, self.multi_sm, is_history, encode_id, is_training=True)
                for i in range(players)
            ]
            self.all_actions = [action.value for action in list(ACState)[: self.action_space]]
            self.lost_action = ACState.Lost.value
            # Every round except the last only allows the actions from CUTOF_INDEX on.
            self.actions_per_round = {
                i: self.all_actions[CUTOF_INDEX:] if i < num_rounds - 1 else self.all_actions for i in range(num_rounds)
            }

        else:
            # Without this, ``self.envs`` would never be set and later calls would
            # fail with an unrelated AttributeError.
            raise ValueError(f"Unknown protocol: {protocol}")

        self._backup = {}

    def filter_action(self, actions):
        """Filter out joint actions that contradict decisions already made.

        Rule 1: a node that is not crashed and has already decided must keep its
        decision in the following rounds. Only rows whose entry for such a node
        equals that node's ``get_decision()`` are kept.

        Args:
            actions: Iterable of joint actions, each indexable by node id.

        Returns:
            A list of the rows of ``actions`` satisfying every constraint.
        """
        # (node index, value that node is fixed to)
        conditions = [
            (i, self.envs[i].get_decision())
            for i in range(self.players)
            if not self.is_crash(i) and self.envs[i].is_decided()
        ]
        return [row for row in actions if all(row[i] == val for i, val in conditions)]

    def get_actions(self, round):
        """Enumerate every candidate joint action for ``round``, shuffled.

        Alive nodes whose observations hash identically are forced to take the
        same action. The number of combinations is therefore
        ``len(actions) ** num_unique_observations``. Nodes that crashed before this
        round get ``self.lost_action``, and so do newly crashed nodes in any round
        except the last.

        Args:
            round: 0-based round index used to look up ``actions_per_round``.

        Returns:
            A list of joint actions; each is a list of length ``players``.
        """
        func = "get_actions"
        actions = self.actions_per_round[round]
        crash_idx = self.get_crash_nodes()
        is_last_round = round >= self.num_rounds - 1
        # In the last round, newly crashed nodes still get a normal action, so they
        # are not counted among the nodes that crashed before this round.
        crash_prev = crash_idx - self.newly_crash if is_last_round else crash_idx
        alive_idx = [n for n in range(self.players) if n not in crash_prev]
        log(func, 2, f"crash_idx: {crash_idx}")

        # Node states do not change while enumerating, so hash each alive node's
        # observation once rather than once per combination.
        obs_hash_by_node = {n: hash_observation(self.get_states(n)) for n in alive_idx}
        obs_states = set(obs_hash_by_node.values())
        num_unique = len(obs_states)

        combos_per_state = list(itertools.product(actions, repeat=num_unique))
        random.shuffle(combos_per_state)
        all_combs = []
        for comb in combos_per_state:
            # Map each distinct observation to the action chosen for it in this combination.
            state_to_action_map = dict(zip(obs_states, comb))
            full_comb = []
            for i in range(self.players):
                # Crashed before this round, or newly crashed in a non-final round -> Lost
                if i in crash_prev or (i in self.newly_crash and not is_last_round):
                    full_comb.append(self.lost_action)
                else:
                    # Alive, or newly crashing in the last round -> normal mapping
                    full_comb.append(state_to_action_map[obs_hash_by_node[i]])
            all_combs.append(full_comb)
        return all_combs

    def step(self, actions, crash_info: CrashInfo):
        """Advance every environment by one round.

        Args:
            actions: Joint action, forwarded to the protocol's step function.
            crash_info: Crash/receive information for this round, or a falsy
                value when there is none.
        """
        self.step_envs(self.envs, actions, self.players, crash_info)
        # Always keep a set here: get_actions() evaluates ``i in self.newly_crash``
        # and ``crash_idx - self.newly_crash``, both of which fail on None.
        self.newly_crash = set(crash_info.crash) if crash_info else set()
        # The envs normally increase current_round in their own _step(), but this
        # manager does not call _step(), so advance it explicitly.
        for env in self.envs:
            env.current_round += 1
        self.round += 1
        self.receive_comb[self.round] = crash_info.receive if crash_info else None

    def step_back(self, steps):
        """Roll the shared state machine back by ``steps``."""
        self.multi_sm.step_back(steps)

    def init(self, input_setting: Info):
        """Reset all environments from ``input_setting`` and load round-0 crash info."""
        self.round = 0
        self.reset_envs(self.envs, self.players, input_setting)
        self.newly_crash = set(input_setting.get_crash_info(0).crash)
        self.receive_comb[self.round] = input_setting.get_crash_info(0).receive

    def get_rewards(self):
        """Return the per-node rewards.

        Each ``env.reward()`` returns ``(reward, is_global_fault)``; only the reward
        is used. When the episode is done, rewards are made uniform: all 1 if no
        node got a negative reward, otherwise all -1. Crashed nodes may get a 0
        reward, so without this they would not match their peers.
        """
        rewards = [reward for reward, _is_global_fault in (env.reward() for env in self.envs)]
        if self.is_done():
            if all(r >= 0 for r in rewards):
                rewards = [1] * self.players
            elif any(r < 0 for r in rewards):
                rewards = [-1] * self.players

        return rewards

    def is_done(self):
        """Return whether the episode has finished (as reported by env 0)."""
        return self.envs[0]._done

    def store(self):
        """Snapshot the round bookkeeping and every environment's state."""
        self._backup = {"round": self.round, "newly_crash": self.newly_crash, "receive_comb": self.receive_comb.copy()}
        for env in self.envs:
            env.store()

    def restore(self):
        """Restore the snapshot taken by store() and clear it."""
        for name, value in self._backup.items():
            setattr(self, name, value)
        self._backup = {}
        for env in self.envs:
            env.restore()

    def get_alive_nodes(self):
        """Return the ids of nodes not in env 0's ``crash_nodes``."""
        crashed = self.envs[0].crash_nodes
        return [i for i in range(self.players) if i not in crashed]

    def get_crash_nodes(self):
        """Return the crashed-node collection as tracked by env 0."""
        return self.envs[0].crash_nodes

    def is_crash(self, id):
        """Return whether node ``id`` is crashed."""
        assert id < self.players
        return self.envs[id].is_crash

    def get_env(self, id):
        """Return the environment of node ``id``."""
        assert id < self.players
        return self.envs[id]

    def get_states(self, id):
        """Return the ``states`` of node ``id``'s environment."""
        assert id < self.players
        return self.envs[id].states

    def get_received_msg(self, id):
        """Return the messages node ``id`` received in the previous round."""
        assert id < self.players
        return self.envs[id].prev_messages

    def get_zero_based_round(self):
        """Return the current 0-based round number."""
        return self.round

    def get_cur_receive_comb(self):
        """Return the receive combination recorded for the current round."""
        return self.receive_comb[self.round]

    def construct_crash_key(self, crash_info: CrashInfo):
        """Build a string key identifying a crash scenario.

        The key is the union of this round's crashed nodes and the already
        crashed nodes, followed by the receive matrix rendered as 0/1 digits with
        newlines removed. ``crash_info.receive`` must support ``.astype`` (a
        numpy array).
        """
        mat_str = np.array2string(crash_info.receive.astype(int), separator="")
        mat_str = mat_str.replace("\n", "")
        return str(set(crash_info.crash) | self.get_crash_nodes()) + " " + mat_str

    def get_current_states(self):
        """Return ``(states per env, transitions per env)`` and log both."""
        current_state = [env.states for env in self.envs]
        all_transitions = [env.state_machine.get_transitions() for env in self.envs]
        log("EnvManager: get_current_states", 1, f"Current state: {current_state}")
        log("EnvManager: get_current_states", 1, f"Transitions: {all_transitions}")
        return current_state, all_transitions

    def print_current_states(self):
        """Print env 0's initial states followed by each node's transitions."""
        _, all_transitions = self.get_current_states()
        print(f"Initial States: {self.envs[0].initial_states}")
        for i in range(self.players):
            print(f"Node {i} transitions: {all_transitions[i]}")