import sys
import math
from simple_counter.State import State

sys.path.append("../")
import numpy as np

from tf_agents.environments import py_environment

from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

from config import REWARDS
from utils import log

class ComplexCounterEnv(py_environment.PyEnvironment):
    """tf_agents environment for one node of a multi-node vote-counting task.

    The node at position ``index`` must output how many of the ``players``
    nodes voted. The node's output is read back from its own state machine
    (``multi_sm[index]``). The vote count is taken from ``self.initial_states``.

    NOTE(review): within this class, ``self.states`` (the observation) is
    only ever set to ``None``, and ``self._done`` is only ever set to
    ``False``. Both appear to be driven by external code; confirm against
    the caller.
    """

    def __init__(self, players, index, multi_sm, conf):
        """Build the specs and bind this node's state machine.

        Args:
            players: Number of nodes. Also the observation length and the
                largest valid action value.
            index: Position of this node within ``multi_sm``.
            multi_sm: Indexable collection of per-node state machines.
            conf: Configuration object. Stored but not read by the methods
                shown here.
        """
        # Let the tf_agents base class initialize its own bookkeeping;
        # subclasses are expected to chain to the base initializer.
        super().__init__()
        self.idx = index
        self.players = players
        self.obs_shape = players
        self.conf = conf
        # action: 0 - players, output the number of 1s
        self._action_spec = array_spec.BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=players, name="action")
        # observation: 0: NoVote, 1: Vote
        self._observation_spec = array_spec.BoundedArraySpec(shape=(self.obs_shape,), dtype=np.int32,
                                                             minimum=0, maximum=1, name="observation")
        self._reward_spec = array_spec.ArraySpec(shape=(), dtype=np.float32, name="reward")
        self.initial_states = [0] * self.players
        self.global_state_machine = multi_sm
        self.state_machine = self.global_state_machine[index]

        self.states = None
        self._done = False

    def action_spec(self):
        """Return the spec of a scalar int32 action in ``[0, players]``."""
        return self._action_spec

    def observation_spec(self):
        """Return the spec of the ``(players,)`` int32 observation in ``[0, 1]``."""
        return self._observation_spec

    def reward_spec(self):
        """Return the spec of the scalar float32 reward."""
        return self._reward_spec

    def my_reset(self):
        """Reset the per-episode state of this node.

        Zeroes ``initial_states``, resets this node's state machine, clears
        ``states`` to ``None`` and marks the episode as not done.
        """
        self.initial_states = [0] * self.players
        self.state_machine.reset()
        self.states = None
        self._done = False

    def transit(self, action):
        """Forward ``action`` to this node's state machine."""
        self.state_machine.transit(action)

    def _reset(self):
        """Return the first time step, using ``self.states`` as the observation.

        NOTE(review): ``self.states`` is ``None`` after ``__init__`` and
        ``my_reset``, and ``np.array(None, dtype=np.int32)`` raises
        ``TypeError``. Callers presumably assign ``self.states`` before
        ``reset()`` is invoked; confirm.
        """
        return ts.restart(np.array(self.states, dtype=np.int32),
                          reward_spec=self._reward_spec)

    def _step(self, action):
        """Emit the next time step with the reward for the current output.

        ``action`` is not used here. The reward is computed from the
        transitions already recorded in the state machine (see ``reward``).
        Presumably ``transit(action)`` is called externally beforehand;
        confirm with the caller.

        Returns:
            A terminal time step if ``self._done`` is true, otherwise a
            transition with discount 1.0.
        """
        this_func = "step"
        log(this_func, f"{self.states}")
        observation = np.array(self.states, dtype=np.int32)
        reward = np.array(self.reward(), dtype=np.float32)
        if self._done:
            return ts.termination(observation, reward, outer_dims=())
        return ts.transition(observation, reward, discount=1.0, outer_dims=())

    def reward(self):
        """Score this node's output against the number of voting nodes.

        The vote count is the number of entries in ``self.initial_states``
        equal to ``int(State.Vote.value)``. The output is element ``[1]`` of
        the state machine's ``get_transitions()`` result.

        Returns:
            ``REWARDS["Good"][0]`` when the output equals the vote count,
            otherwise ``REWARDS["Bad"]``.
        """
        this_func = "ComplexCounterEnv_reward"
        log(this_func, f"Initial states: {self.initial_states}")

        transitions = self.state_machine.get_transitions()
        # assumes element 1 of the transitions record is the emitted count — TODO confirm
        action = transitions[1]
        cnt_vote = self.initial_states.count(int(State.Vote.value))
        log(this_func, f"node {self.idx} transitions: {transitions}")
        if action == cnt_vote:
            log(this_func, f"{cnt_vote} nodes vote, and output {action}, GOOD")
            # NOTE(review): "Good" is indexed with [0] but "Bad" is used
            # directly; confirm this matches the REWARDS schema in config.
            return REWARDS["Good"][0]
        log(this_func, f"{cnt_vote} nodes vote, and output {action}, Bad")
        return REWARDS["Bad"]
