from BaseEnv import ENV
import numpy as np


class Baird(ENV):
    """Seven-state, two-action environment with fixed linear features.

    The layout (six "upper" states plus a seventh state, a "solid" and a
    "dash" action) looks like Baird's counterexample -- NOTE(review): confirm
    against the experiment description this file belongs to.

    Dynamics implemented by ``step``:
      * solid action -> always moves to the seventh state;
      * any other action (dash) -> moves uniformly at random to one of the
        first six states.

    ``features[action, state]`` is the feature vector of a state-action pair.
    """

    NUM_STATES = 7
    NUM_ACTIONS = 2
    NUM_FEATURES = 8 + 7  # 8 features used by the solid action + 7 by the dash action

    SEVENTH_STATE = 6  # zero-based index of the seventh state

    SOLID_ACTION = 0
    DASH_ACTION = 1
    WINDOW_SIZE = float("inf")

    # Chance that an episode ends when the solid action is taken in the
    # seventh state.
    TERMINATION_PROB = 1 / 100

    def __init__(self, normalize_feature, isNoisy, noise):
        """Build the feature table and store the reward-noise settings.

        Args:
            normalize_feature: if truthy, scale every feature by 1/sqrt(5).
            isNoisy: if truthy, ``step`` returns a Gaussian reward instead of 0.
            noise: standard deviation of that Gaussian reward.
        """
        super().__init__("Baird")

        self.reward = 0
        self.current_state = None
        self.features = np.zeros(
            (self.NUM_ACTIONS, self.NUM_STATES, self.NUM_FEATURES)
        )
        self.init_feature(normalize_feature)
        self.isNoisy = isNoisy
        self.noise = noise

    def init_feature(self, normalize_feature):
        """Fill ``self.features`` in place.

        Solid action: each of the first six states ``i`` gets weight 2 on
        feature ``i`` and weight 1 on the shared feature (index 7); the
        seventh state gets weight 1 on its own feature and weight 2 on the
        shared feature.
        Dash action: state ``i`` gets a one-hot feature at index
        ``NUM_STATES + i + 1`` (indices 8..14).

        Args:
            normalize_feature: if truthy, scale the whole table by 1/sqrt(5).
        """
        shared = self.SEVENTH_STATE + 1  # feature shared by all solid-action rows

        for state in range(self.NUM_STATES):
            if state < self.SEVENTH_STATE:
                own_weight, shared_weight = 2, 1
            else:  # the seventh state
                own_weight, shared_weight = 1, 2

            self.features[self.SOLID_ACTION, state, state] = own_weight
            self.features[self.SOLID_ACTION, state, shared] = shared_weight
            self.features[self.DASH_ACTION, state, self.NUM_STATES + state + 1] = 1

        if normalize_feature:
            # Features used in Coupled Q Learning
            self.features = 1.0 / np.sqrt(5) * self.features

    def reset(self):
        """Start a new episode from a state drawn uniformly from all seven states.

        Returns:
            The index of the sampled start state.
        """
        self.current_state = np.random.randint(0, self.NUM_STATES)
        return self.current_state

    def step(self, action):
        """Apply ``action`` and advance the environment by one transition.

        Args:
            action: ``SOLID_ACTION`` for the solid action; every other value
                is treated as the dash action.

        Returns:
            A list ``[next_state, reward, done]``. ``reward`` is ``0`` when
            the environment is not noisy, otherwise a length-1 array sampled
            from a zero-mean normal with standard deviation ``self.noise``.
            ``done`` is True only when the solid action was taken in the
            seventh state and the termination draw succeeded.
        """
        reward = np.random.normal(0, self.noise, 1) if self.isNoisy else 0

        is_solid = action == self.SOLID_ACTION
        if is_solid:
            next_state = self.SEVENTH_STATE
        else:
            # randint's upper bound is exclusive: samples one of states 0..5.
            next_state = np.random.randint(0, self.NUM_STATES - 1)

        # Termination is decided from the state *before* the transition.
        # A uniform draw is required here: the previous standard-normal draw
        # (randn) fell below 1/100 roughly half of the time. The draw stays
        # last in the condition so it is only consumed when it can matter.
        done = bool(
            self.current_state == self.SEVENTH_STATE
            and is_solid
            and np.random.rand() < self.TERMINATION_PROB
        )

        self.current_state = next_state

        return [next_state, reward, done]

    def isNotEnd(self):
        """Always return True."""
        return True