import sys
sys.path.append("../")
import numpy as np

from tf_agents.environments import py_environment

from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

from config import REWARDS
from distributed_locking.State import State
from utils import log


class DistributedLockingEnv(py_environment.PyEnvironment):
    """Per-node tf_agents environment for the distributed-locking game.

    Each instance represents one node (``index``) out of ``players`` nodes.
    The environment does not compute observations or episode termination by
    itself: ``states`` and ``_done`` are plain attributes that are expected to
    be assigned from outside (per the original note, by DistributedLocking.py)
    before ``reset()`` / ``step()`` are called.  ``_step`` only packages the
    current ``states`` and the reward computed by :meth:`reward` into a
    tf_agents ``TimeStep``.
    """

    def __init__(self, players, index, num_round, multi_sm, is_training):
        """Build the specs and bind this node to its state machine.

        Args:
            players: Number of nodes taking part in the game.
            index: Index of this node inside ``multi_sm``.
            num_round: Accepted for interface compatibility; not used by the
                code in this class.
            multi_sm: Indexable collection holding one state machine per node.
                ``multi_sm[index]`` becomes this node's own state machine.
            is_training: Accepted for interface compatibility; not used by the
                code in this class.
        """
        # Initialise the PyEnvironment base class so the bookkeeping used by
        # its public reset()/step()/current_time_step() wrappers exists.
        super().__init__()

        self.idx = index
        self.players = players
        self.obs_shape = players + 1
        self.action_space = 2
        # action: 0: enter, 1: no enter
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=1, name="action")
        # observation: 2: need lock, 3: no need lock
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(self.obs_shape,), dtype=np.int32,
            minimum=2, maximum=3, name="observation")
        self._reward_spec = array_spec.ArraySpec(shape=(), dtype=np.float32, name="reward")
        self.initial_states = [0] * self.players
        self.global_state_machine = multi_sm
        self.state_machine = self.global_state_machine[index]
        self.crash_nodes = set()
        self.is_crash = False
        self.current_round = 1
        # internal states that set by DistributedLocking.py, exposed to tf_agents api
        self.states = None

        self._done = False

    def action_spec(self):
        """Return the scalar int32 action spec bounded to [0, 1]."""
        return self._action_spec

    def observation_spec(self):
        """Return the int32 observation spec of shape ``(players + 1,)``."""
        return self._observation_spec

    def reward_spec(self):
        """Return the scalar float32 reward spec."""
        return self._reward_spec

    def my_reset(self):
        """Restore the per-episode bookkeeping of this node.

        Resets this node's own state machine (not the other entries of
        ``global_state_machine``) and clears ``states`` and ``_done``.  This is
        separate from the tf_agents ``reset()`` entry point, which only builds
        the first ``TimeStep`` from ``states``.
        """
        self.initial_states = [0] * self.players
        self.current_round = 1
        self.state_machine.reset()
        self.crash_nodes = set()
        self.is_crash = False
        self.states = None
        self._done = False

    def transit(self, action):
        """Forward ``action`` to this node's state machine."""
        self.state_machine.transit(action)

    def _reset(self):
        """Build the initial ``TimeStep`` from the current ``states``.

        ``states`` starts out as ``None`` and is also cleared by
        :meth:`my_reset`, so it has to be assigned before this is called.
        """
        return ts.restart(np.array(self.states, dtype=np.int32),
                          reward_spec=self._reward_spec)

    def _step(self, action):
        """Wrap the current ``states`` and reward into a ``TimeStep``.

        The ``action`` argument is not used here; state changes are applied
        through :meth:`transit` instead.

        Returns:
            A termination step when ``_done`` is set, otherwise a transition
            step with discount 1.0.
        """
        this_func = "step"
        log(this_func, 2, f"{self.states}")
        observation = np.array(self.states, dtype=np.int32)
        reward = np.array(self.reward(), dtype=np.float32)
        if self._done:
            return ts.termination(observation, reward, outer_dims=())
        return ts.transition(observation, reward, discount=1.0, outer_dims=())

    def reward(self):
        """Compute this node's reward from the local and global state machines.

        Element 0 of a state machine is compared against ``State.Need`` and
        element 1 against ``State.Enter``.  Rules, checked in order:

        * r1: this node does not need the lock but entered the critical
          section -> ``REWARDS["Bad"]``.
        * r2: this node needs the lock and entered, and at least one other
          node that needs the lock also entered -> ``REWARDS["Bad"]``.
        * r3: this node needs the lock but no node that needs the lock
          entered -> ``REWARDS["Bad"]``.
        * otherwise -> ``REWARDS["Good"][0]``.

        Returns:
            ``REWARDS["Bad"]`` or ``REWARDS["Good"][0]``.
        """
        this_func = "DistributedLockingEnv_reward"
        log(this_func, 2, f"Initial states: {self.initial_states}")

        # Only used for logging below.
        transitions = self.state_machine.get_transitions()

        need_value = int(State.Need.value)
        enter_value = int(State.Enter.value)

        need_lock = self.state_machine[0] == need_value
        enter_cs = self.state_machine[1] == enter_value

        global_need_lock = [sm[0] for sm in self.global_state_machine]
        global_enter_cs = [sm[1] for sm in self.global_state_machine]
        need_lock_list = self.find_all_index(global_need_lock, need_value)
        enter_cs_list = self.find_all_index(global_enter_cs, enter_value)

        # Nodes that both need the lock and entered, in ascending index order.
        entered = set(enter_cs_list)
        need_and_enter_list = [node for node in need_lock_list if node in entered]

        min_need = need_lock_list[0] if need_lock_list else -1  # idx of 1st node need lock
        log(this_func, 2, f"Node {min_need} is the minimum node that needs lock")
        log(this_func, 2, f"node {self.idx} transitions: {transitions}")

        # r1: Doesn't need lock, but enter CS, negative reward
        if not need_lock and enter_cs:
            log(this_func, 2, f"node {self.idx} doesn't need lock, but enter CS, negative reward")
            return REWARDS["Bad"]

        # r2: Multiple nodes enter CS, negative reward.
        # Only nodes that need the lock AND entered are counted, and only a
        # node that is itself in that group is penalised.
        if self.idx in need_and_enter_list and len(need_and_enter_list) > 1:
            log(this_func, 2, f"Multiple nodes ({enter_cs_list}) enter CS, negative reward")
            return REWARDS["Bad"]

        # TODO: should r3 fire when *any* node wants the lock, or only when
        # this node does?  Currently: only when this node needs it.
        # NOTE: a disabled earlier variant rewarded only the lowest-index
        # needing node (`min_need`) for entering and penalised any other
        # needing node that entered.
        #
        # r3: If some node wants to enter cs but no one entered cs, negative rewards.
        # Nodes that entered without needing the lock do not count here, so
        # nodes that need the lock are still penalised in that case.
        if need_lock and not need_and_enter_list:
            log(this_func, 2, f"node {need_lock_list} wants the lock "
                           f"but no one entered cs")
            return REWARDS["Bad"]

        # NOTE: a disabled "r4" variant added `need_and_enter_list[0]` as a
        # bonus on top of REWARDS["Good"][0] when this node both needed the
        # lock and entered.

        # Everything else is rewarded: the node neither needs the lock nor
        # entered, or it needs the lock and exactly one needing node entered.
        return REWARDS["Good"][0]

    # Helper function to return all indices of target value in an array
    def find_all_index(self, vector, target):
        """Return every index ``i`` for which ``vector[i] == target``."""
        return [index for index, value in enumerate(vector) if value == target]