import numpy as np
import gym
from gym.spaces import MultiBinary, Discrete, Box
import scipy.linalg
import math

class Lock(gym.Env):
    """A (stochastic) combination lock environment.

    The latent state is one of three values: 0 ("A"), 1 ("B") or 2 ("C").
    A and B are the "good" states; C is absorbing.  At every time step ``h``
    there is one correct action for state A (``opt_a[h]``) and one for state B
    (``opt_b[h]``).  Taking the correct action keeps the agent in a good state
    (switching between A and B with probability ``p_switch``); any other
    action moves it to C and, with probability ``p_anti_r``, pays the small
    distracting reward ``anti_r``.  Taking the correct action at the final
    step pays reward 1.

    Observations are a noisy one-hot encoding of (state, time step) rotated by
    a Hadamard matrix.

    Construction is two-phase: ``Lock()`` creates an unconfigured object and
    ``init(...)`` configures it (horizon, number of actions, switching
    probability, ...).  ``reset()`` raises until ``init`` has been called.
    """

    def __init__(self, env_config=None):
        """Create an unconfigured environment.

        Args:
            env_config: accepted for interface compatibility only; it is not
                read.  Use ``init`` to configure the environment.
        """
        self.initialized = False
        # Episode state; populated by reset().
        self.h = None
        self.state = None

    def init(self, horizon=100, action_dim=10, p_switch=0.5, p_anti_r=0.5, anti_r=0.1, noise=0.1):
        """Configure the environment and sample the two optimal action sequences.

        Args:
            horizon: number of steps per episode.
            action_dim: number of discrete actions.
            p_switch: probability of switching between A and B after a correct
                action; also the probability that an episode starts in A.
            p_anti_r: probability that a wrong action pays ``anti_r``.
            anti_r: value of the distracting reward.
            noise: standard deviation of the Gaussian observation noise.
        """
        self.initialized = True
        self.max_reward = 1
        self.horizon = horizon
        self.state_dim = 3
        self.action_dim = action_dim
        self.action_space = Discrete(self.action_dim)

        self.reward_range = (0.0, 1.0)

        # The Hadamard rotation needs a power-of-two size.  horizon + 4 leaves
        # room for 3 state slots plus time indices 0..horizon inclusive (the
        # observation returned by the final step is built with h == horizon).
        self.observation_dim = 2 ** int(math.ceil(np.log2(self.horizon + 4)))

        # np.float was removed in NumPy 1.24; np.float64 is what it resolved to.
        # NOTE(review): observations are Hadamard-rotated and can fall outside
        # the declared [0, 1] bounds -- bounds kept as in the original.
        self.observation_space = Box(low=0.0, high=1.0, shape=(self.observation_dim,), dtype=np.float64)

        self.p_switch = p_switch
        self.p_anti_r = p_anti_r
        self.anti_r = anti_r
        self.noise = noise
        self.rotation = scipy.linalg.hadamard(self.observation_space.shape[0])

        # Correct action at each time step for state A and state B.
        self.opt_a = np.random.randint(low=0, high=self.action_space.n, size=self.horizon)
        self.opt_b = np.random.randint(low=0, high=self.action_space.n, size=self.horizon)

        print("[LOCK] Initializing Combination Lock Environment")
        print("[LOCK] A sequence: ", end="")
        print(self.opt_a.tolist())
        print("[LOCK] B sequence: ", end="")
        print(self.opt_b.tolist())

    def step(self, action):
        """Advance the environment by one time step.

        Args:
            action: index of the chosen action.

        Returns:
            Tuple ``(obs, reward, done, info)``; ``info`` is always an empty
            dict and ``done`` is True only for the last step of the horizon.

        Raises:
            Exception: if ``reset()`` has not been called, or if the episode
                has already run for ``horizon`` steps.
        """
        if self.h is None:
            raise Exception("[LOCK] reset() must be called before step()")
        if self.h >= self.horizon:
            raise Exception("[LOCK] Exceeded horizon")

        # Both draws are made unconditionally, in this order, so the amount of
        # randomness consumed per step does not depend on the branch taken.
        ber = np.random.binomial(1, self.p_switch)
        ber_r = np.random.binomial(1, self.p_anti_r)

        done = self.h == self.horizon - 1

        if self.state == 0:
            correct = action == self.opt_a[self.h]
        elif self.state == 1:
            correct = action == self.opt_b[self.h]
        else:
            correct = False

        r = 0
        if correct:
            if done:
                # Unlocked: full reward, latent state is left unchanged.
                r = 1
                next_state = self.state
            else:
                # Stay in a good state, flipping A <-> B with prob. p_switch.
                next_state = 1 - self.state if ber else self.state
        else:
            # A wrong action from A/B pays the distracting reward at any step;
            # from C it is paid only on the final step.
            if done or self.state in (0, 1):
                r = self.anti_r if ber_r else 0
            next_state = 2

        self.h += 1
        self.state = next_state
        obs = self.make_obs(self.state)
        # The anti-reward earned on intermediate steps is returned here; it
        # used to be computed and then dropped in favour of a literal 0.
        return obs, r, done, {}

    def get_state(self):
        """Return the current latent state (0 = A, 1 = B, 2 = C)."""
        return self.state

    def make_obs(self, s):
        """Build the observation vector for latent state ``s`` at the current step.

        The first ``horizon + state_dim`` coordinates receive Gaussian noise,
        then 1 is added at index ``s`` and at index ``state_dim + h``.  The
        vector is finally multiplied by the Hadamard matrix.

        Args:
            s: latent state index (0, 1 or 2).

        Returns:
            1-D float array of length ``observation_dim``.
        """
        encoding = np.zeros(self.observation_space.shape)
        n_noisy = self.horizon + self.state_dim
        encoding[:n_noisy] = np.random.normal(0, self.noise, [n_noisy])
        encoding[s] += 1
        encoding[self.state_dim + self.h] += 1

        # Plain ndarray product instead of the discouraged np.matrix.
        return self.rotation @ encoding

    def trim_observation(self, o, h):
        """Return the observation ``o`` unchanged (``h`` is ignored)."""
        return o

    def reset(self):
        """Start a new episode and return the first observation.

        The initial state is A with probability ``p_switch``, otherwise B.

        Raises:
            Exception: if ``init`` has not been called.
        """
        if not self.initialized:
            raise Exception("Environment not initialized")
        self.h = 0
        ber = np.random.binomial(1, self.p_switch)
        self.state = 0 if ber else 1
        return self.make_obs(self.state)

    def render(self, mode='human'):
        """Print the latent state letter followed by the time step, e.g. ``A3``."""
        if self.state in (0, 1, 2):
            print("%s%d" % ("ABC"[self.state], self.h))

    def close(self):
        """Nothing to release."""
        pass

    #def seed(self, seed=None):
    #    gym.spaces.prng.seed(seed)

if __name__ == "__main__":
    # Smoke test: build a lock with the default settings and show the shape
    # of its action space.
    demo_env = Lock()
    demo_env.init()
    banner = "!!!!!!!!"
    print(banner)
    print(demo_env.action_space.shape)
    # Example random rollout, kept disabled:
    # for t in range(20):
    #     o = demo_env.reset()
    #     done = False
    #     h = 0
    #     while not done:
    #         demo_env.render()
    #         #print(demo_env.trim_observation(o,h))
    #         (o,r,done,blah) = demo_env.step(demo_env.action_space.sample())
    #         print(o)
    #         h += 1
    #     print("End of episode: r=%d" % (r))
