from enum import Enum

import tensorflow as tf
import numpy as np
import random
import time

from tf_agents.environments import py_environment
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

from tf_agents.trajectories import time_step as ts
from utils import construct_bits_string, pad_input_state, map_idx_to_action, log

class Act(Enum):
    """Actions the agent can take; the value is the raw action index."""

    # Action index 0 (lower bound of the action spec): abort.
    Abort = 0
    # Action index 1 (upper bound of the action spec): commit.
    Commit = 1


class Stat(Enum):
    """Per-player status codes placed in the observation vector.

    The value 0 is not a member; it is the cleared observation value.
    """

    Abort = 1   # player voted to abort
    Commit = 2  # player voted to commit
    Lost = 3    # player's vote hidden by an injected failure

class CoordinatorEnv(py_environment.PyEnvironment):
    """Single-step commit/abort coordination environment.

    Each episode assigns every player an initial vote (``Stat.Abort`` or
    ``Stat.Commit``), may replace some observed votes with ``Stat.Lost`` to
    simulate failures, and then asks the agent for one action (``Act.Abort``
    or ``Act.Commit``). The episode terminates after that single step; the
    reward comes from :meth:`validate`, which scores the action against the
    true (unobscured) votes.

    Args:
        players: Number of players; fixes the observation shape.
        seed: Optional seed for this environment's private random generator.
            ``None`` (the default) seeds from system entropy, matching the
            previous non-reproducible behaviour without touching the global
            ``random`` module state.
    """

    def __init__(self, players, seed=None):
        super().__init__()
        self.players = players
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=1, name="action")
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(self.players,), dtype=np.int32,
            minimum=0, maximum=3, name="observation")
        self._reward_spec = array_spec.ArraySpec(
            shape=(), dtype=np.float32, name="reward")

        self._state = [0] * self.players        # current observation
        self._init_states = [0] * self.players  # true votes, used by validate()
        self._done = False
        self._has_crash = False
        # When False, reset() replays the state installed by my_reset()
        # instead of sampling a new episode.
        self._random = True
        # Private generator: the previous code reseeded the global `random`
        # module on every reset, clobbering RNG state for the whole process.
        self._rng = random.Random(seed)

    def action_spec(self):
        """Return the scalar int32 action spec (0 = abort, 1 = commit)."""
        return self._action_spec

    def observation_spec(self):
        """Return the per-player int32 observation spec (values 0..3)."""
        return self._observation_spec

    def reward_spec(self):
        """Return the scalar float32 reward spec."""
        return self._reward_spec

    def _reset(self):
        """Start an episode and return its first time step.

        If :meth:`my_reset` has been called, the stored state is replayed
        instead of sampling a new episode.
        """
        if not self._random:
            # NOTE(review): ``_random`` is never set back to True, so every
            # later reset replays ``self._state`` -- which ``_step`` zeroes
            # after the first step. Confirm this is intended.
            self._done = False
            return ts.restart(np.array(self._state, dtype=np.int32),
                              reward_spec=self._reward_spec)

        self._done = False
        self._has_crash = False
        rng = self._rng

        all_ones = 2 ** self.players - 1
        s = rng.randint(0, all_ones)
        # With probability 0.3 force every bit on -- presumably the
        # all-commit episode, which uniform sampling rarely produces.
        if rng.random() < 0.3:
            s = all_ones
        bits = list(construct_bits_string(s, self.players))

        # Build a fresh list rather than mutating in place, so a list handed
        # in through my_reset() is never modified behind the caller's back.
        init_states = list(self._init_states)
        for i in range(self.players):
            bit = int(bits[i])
            if bit == 0:
                init_states[i] = float(Stat.Abort.value)
            elif bit == 1:
                init_states[i] = float(Stat.Commit.value)
        self._init_states = init_states
        self._state = init_states[:]

        # Failure injection: with probability 0.5, each player is
        # independently marked Lost with probability 0.3. Only the
        # observation changes; validate() still sees the true votes.
        if rng.random() < 0.5:
            for i in range(self.players):
                if rng.random() < 0.3:
                    self._has_crash = True
                    self._state[i] = float(Stat.Lost.value)

        return ts.restart(np.array(self._state, dtype=np.int32),
                          reward_spec=self._reward_spec)

    def my_reset(self, init_states, state, crash):
        """Install a fixed episode to be returned by subsequent resets.

        Args:
            init_states: Per-player true votes, scored by validate().
            state: Per-player observation returned by the next reset.
            crash: Whether this episode contains an injected failure.
        """
        # Store copies so the environment never aliases the caller's lists.
        self._init_states = list(init_states)
        self._state = list(state)
        self._has_crash = crash
        self._random = False

    def _step(self, action):
        """Score ``action`` and terminate the episode.

        If the episode has already ended, a reset is performed instead.
        The terminal observation is all zeros.
        """
        if self._done:
            return self.reset()
        reward = self.validate(action)
        self._state = [0] * self.players
        self._done = True
        return ts.termination(np.array(self._state, dtype=np.int32),
                              np.array(reward, dtype=np.float32),
                              outer_dims=())

    def validate(self, action):
        """Return the reward for ``action`` given the episode's true votes.

        * Every vote is commit and no failure was injected:
          abort -> -1.0, any other action -> 0.5.
        * At least one vote is abort: abort -> 0.5, commit -> -1.0.
        * Anything else (e.g. all commit but a failure occurred): 0.0.
        """
        all_commits = all(s == Stat.Commit.value for s in self._init_states)
        any_aborts = any(s == Stat.Abort.value for s in self._init_states)

        if all_commits and not self._has_crash:
            return -1.0 if action == Act.Abort.value else 0.5

        if any_aborts and action == Act.Abort.value:
            return 0.5

        if any_aborts and action == Act.Commit.value:
            return -1.0

        return 0.

