import sys
import math
from simple_counter.State import State

sys.path.append("../")
import numpy as np

from tf_agents.environments import py_environment

from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

from config import REWARDS
from utils import log

class CounterEnv(py_environment.PyEnvironment):
    """TF-Agents environment for one node of the majority-vote counter game.

    The node observes a vote vector (one entry per player, 0 = NoVote,
    1 = Vote) and takes a binary action (0 = Selected, 1 = Not Selected).
    The reward is "Good" when the node's recorded action agrees with whether
    a majority of votes is present, and "Bad" otherwise.

    NOTE(review): nothing in this class assigns ``states`` or sets ``_done``
    to True; they are presumably written by the code driving the environment
    before ``reset()`` / ``step()`` are called -- confirm against the caller.
    """

    def __init__(self, players, index, multi_sm, conf):
        """Build the specs and bind this node's state machine.

        Args:
            players: Number of players; also the length of the observation.
            index: Index of this node inside ``multi_sm``.
            multi_sm: Indexable collection of state machines, one per node.
            conf: Configuration object, stored as-is.
        """
        # Initialise the PyEnvironment base class: it creates the bookkeeping
        # attributes (e.g. the cached current time step) that the public
        # step()/current_time_step() methods read. Skipping it makes step()
        # fail with AttributeError when called before reset().
        super().__init__()

        self.idx = index
        self.players = players
        self.obs_shape = players
        self.conf = conf
        # action: 0: Selected, 1: Not Selected
        # Corresponding state value: 2: Selected, 3: Not Selected
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=1, name="action")
        # observation: 0: NoVote, 1: Vote
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(self.obs_shape,), dtype=np.int32,
            minimum=0, maximum=1, name="observation")
        self._reward_spec = array_spec.ArraySpec(shape=(), dtype=np.float32, name="reward")
        self.initial_states = [0] * self.players
        self.global_state_machine = multi_sm
        self.state_machine = self.global_state_machine[index]

        self.states = None
        self._done = False

    def action_spec(self):
        """Return the spec of the scalar action (0 or 1)."""
        return self._action_spec

    def observation_spec(self):
        """Return the spec of the per-player vote observation."""
        return self._observation_spec

    def reward_spec(self):
        """Return the spec of the scalar float32 reward."""
        return self._reward_spec

    def my_reset(self):
        """Restore the episode-local fields and reset the node's state machine."""
        self.initial_states = [0] * self.players
        self.state_machine.reset()
        self.states = None
        self._done = False

    def transit(self, action):
        """Forward ``action`` to this node's state machine."""
        self.state_machine.transit(action)

    def _observation(self):
        """Return ``self.states`` as an int32 array.

        ``self.states`` must have been assigned before this is called; the
        conversion fails while it is still ``None``.
        """
        return np.array(self.states, dtype=np.int32)

    def _reset(self):
        """Return the first time step of an episode, built from ``self.states``."""
        return ts.restart(self._observation(), reward_spec=self._reward_spec)

    def _step(self, action):
        """Return the next time step for the current ``self.states``.

        ``action`` is not used here: the reward is derived from the
        transitions already recorded in the state machine (see ``reward``).
        A termination step is returned when ``self._done`` is set, otherwise
        a transition step with discount 1.0.
        """
        this_func = "step"
        log(this_func, f"{self.states}")
        reward = np.array(self.reward(), dtype=np.float32)
        observation = self._observation()
        if self._done:
            return ts.termination(observation, reward, outer_dims=())
        return ts.transition(observation, reward, discount=1.0, outer_dims=())

    def reward(self):
        """Score the node's recorded action against the vote majority.

        The action is read from index 1 of the state machine's transitions.
        A majority is reached when at least ``ceil(players / 2)`` entries of
        ``initial_states`` equal ``State.Vote``.

        Returns:
            ``REWARDS["Good"][0]`` when the action matches the majority
            outcome (selected with a majority, or not selected without one),
            ``REWARDS["Bad"]`` otherwise.
        """
        this_func = "CounterEnv_reward"
        log(this_func, f"Initial states: {self.initial_states}")

        transitions = self.state_machine.get_transitions()
        action = transitions[1]
        cnt_vote = self.initial_states.count(int(State.Vote.value))
        log(this_func, f"node {self.idx} transitions: {transitions}")
        majority = math.ceil(self.players / 2)
        selected = action == int(State.Selected.value)

        # The two majority cases are mutually exclusive, so every path below
        # returns exactly one reward.
        if cnt_vote >= majority:
            if selected:
                log(this_func, "majority reaches, and selected, GOOD")
                return REWARDS["Good"][0]
            log(this_func, "majority reaches, but not selected, BAD")
            return REWARDS["Bad"]

        if selected:
            log(this_func, "majority not reached, but selected, BAD")
            return REWARDS["Bad"]
        log(this_func, "majority not reached, and not selected, GOOD")
        return REWARDS["Good"][0]
