import sys
import math

sys.path.append("../")
import numpy as np

from tf_agents.environments import py_environment

from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts
from math_func.State import State
from config import REWARDS
from utils import log

'''
  Environment used to model the sum / min_idx / max_idx functions
  (min and max are placeholders that always give 0 reward).
'''
class MathEnv(py_environment.PyEnvironment):
    """Per-node TF-Agents environment that rewards computing a math function.

    Each instance is bound to one node (``index``) of a multi-node state
    machine. The reward compares that node's output (element ``[1]`` of its
    state machine's transitions) with the correct value of ``conf.math_fn``
    evaluated over ``initial_states``.

    Recognised ``conf.math_fn`` values: ``"sum"``, ``"min_idx"``,
    ``"max_idx"``; ``"min"`` and ``"max"`` are unimplemented and reward 0.
    """

    def __init__(self, players, index, multi_sm, conf):
        """Build the specs and bind this environment to one node.

        Args:
            players: Number of nodes; also the observation length and the
                largest valid action value.
            index: Position of this node's state machine inside ``multi_sm``.
            multi_sm: Indexable collection of per-node state machines.
            conf: Config object; ``conf.math_fn`` selects the reward function.
        """
        # Initialise PyEnvironment's own bookkeeping (current time step,
        # auto-reset flag) before we add our attributes.
        super().__init__()
        self.idx = index
        self.players = players
        self.obs_shape = players
        self.conf = conf
        self.fn = conf.math_fn
        # Actions 0..players-1 name a node index; action == players means
        # "no node picked".
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=players, name="action")
        # One entry per node: 0 = NoVote, 1 = Vote.
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(self.obs_shape,), dtype=np.int32,
            minimum=0, maximum=1, name="observation")
        self._reward_spec = array_spec.ArraySpec(
            shape=(), dtype=np.float32, name="reward")
        self.initial_states = [0] * self.players
        self.global_state_machine = multi_sm
        self.state_machine = self.global_state_machine[index]

        # Both are only reset here and in my_reset(); presumably the driver
        # assigns ``states`` and sets ``_done`` — confirm against callers.
        self.states = None
        self._done = False

    def action_spec(self):
        """Return the scalar int32 action spec in ``[0, players]``."""
        return self._action_spec

    def observation_spec(self):
        """Return the ``(players,)`` int32 observation spec with values 0/1."""
        return self._observation_spec

    def reward_spec(self):
        """Return the scalar float32 reward spec."""
        return self._reward_spec

    def my_reset(self):
        """Restore this node's per-episode state.

        Zeroes ``initial_states``, resets only this node's state machine (not
        the other entries of ``global_state_machine``), and clears ``states``
        and the done flag.
        """
        self.initial_states = [0] * self.players
        self.state_machine.reset()
        self.states = None
        self._done = False

    def transit(self, action):
        """Forward ``action`` to this node's state machine."""
        self.state_machine.transit(action)

    def _reset(self):
        """Return the first TimeStep, built from the current ``states``."""
        # NOTE(review): ``states`` is None after __init__/my_reset and is not
        # assigned anywhere in this class; np.array(None, dtype=np.int32)
        # raises TypeError, so the caller must set it before reset().
        return ts.restart(np.array(self.states, dtype=np.int32),
                          reward_spec=self._reward_spec)

    def _step(self, action):
        """Score this node's current output and emit the next TimeStep.

        ``action`` is not used here: the reward is computed from this node's
        state machine transitions (see ``reward``). A termination step is
        returned when ``_done`` is True, otherwise a transition with
        discount 1.0.
        """
        this_func = "step"
        log(this_func, f"{self.states}")
        reward = self.reward()
        observation = np.array(self.states, dtype=np.int32)
        reward = np.array(reward, dtype=np.float32)
        if self._done:
            return ts.termination(observation, reward, outer_dims=())
        return ts.transition(observation, reward, discount=1.0, outer_dims=())

    def reward(self):
        """Return the reward for this node's current output.

        The output is ``transitions[1]`` from this node's state machine and is
        scored by the function named in ``self.fn``. An unrecognised name
        earns 0 and is logged.

        Raises:
            AssertionError: if ``self.fn`` is None, or if ``initial_states``
                differs from element ``[0]`` of every node's state machine.
        """
        this_func = "MathEnv_reward"
        log(this_func, f"Initial states: {self.initial_states}")
        assert self.fn is not None, "conf.math_fn must be set"
        log(this_func, self.fn)
        # Our copy of the inputs must agree with element [0] of each node's
        # state machine.
        global_value = [sm[0] for sm in self.global_state_machine]
        assert self.initial_states == global_value
        indices = self.find_all_index(self.initial_states, int(State.One.value))
        transitions = self.state_machine.get_transitions()
        log(this_func, f"node {self.idx} transitions: {transitions}")
        log(this_func, f"indices of 1s: {indices}")
        action = transitions[1]
        if self.fn == "min":
            return self.reward_min()
        if self.fn == "max":
            return self.reward_max()
        if self.fn == "sum":
            return self.reward_sum(action)
        if self.fn == "min_idx":
            return self.reward_min_idx(action, indices)
        if self.fn == "max_idx":
            return self.reward_max_idx(action, indices)
        # Previously a silent 0; surface misconfiguration in the log.
        log(this_func, f"unknown math_fn {self.fn!r}, reward 0")
        return 0

    def reward_min(self):
        """Placeholder for ``min``: logs and always returns 0."""
        this_func = "reward_min"
        log(this_func, "no implementation of reward_min")
        return 0

    def reward_max(self):
        """Placeholder for ``max``: logs and always returns 0."""
        this_func = "reward_max"
        log(this_func, "no implementation of reward_max")
        return 0

    def reward_sum(self, action):
        """Reward ``action`` for equalling the number of 1s in ``initial_states``."""
        this_func = "reward_sum"
        ones = self.initial_states.count(int(State.One.value))
        return self._grade(this_func, action == ones,
                           f"sum == {ones}, and output {action}")

    def reward_min_idx(self, action, indices):
        """Reward ``action`` for equalling ``min(indices)`` (``players`` if empty)."""
        return self._reward_index("reward_min_idx", action, indices, min)

    def reward_max_idx(self, action, indices):
        """Reward ``action`` for equalling ``max(indices)`` (``players`` if empty)."""
        return self._reward_index("reward_max_idx", action, indices, max)

    def _reward_index(self, this_func, action, indices, pick):
        """Shared body of the *_idx rewards; ``pick`` chooses the target index."""
        # No node holds a 1, so the "no pick" action (``players``) is correct.
        correct = self.players if len(indices) == 0 else pick(indices)
        return self._grade(this_func, action == correct,
                           f"indices = {indices}, output {action}")

    def _grade(self, this_func, is_correct, detail):
        """Log ``detail`` with a GOOD/Bad verdict and return the matching reward."""
        if is_correct:
            log(this_func, f"{detail}, GOOD")
            # NOTE(review): "Good" is indexed with [0] while "Bad" is used
            # as-is; presumably REWARDS["Good"] is a sequence — confirm.
            return REWARDS["Good"][0]
        log(this_func, f"{detail}, Bad")
        return REWARDS["Bad"]

    def find_all_index(self, vector, target):
        """Return every position in ``vector`` whose value equals ``target``."""
        return [index for index, value in enumerate(vector) if value == target]