import numpy as np
import gym
from gym.spaces import Discrete, Box
import scipy.linalg
import math

from envs.Lock_batch import LockBatch



'''
Fast weighted sampling: draw one item per row of a (row-stochastic) probability matrix.
Credit: https://stackoverflow.com/questions/34187130/fast-random-weighted-selection-across-all-rows-of-a-stochastic-matrix/34190035
'''
def sample(prob_matrix, items, n):
    """Draw one item per row of a probability matrix.

    Row ``i`` of ``prob_matrix`` holds the weights of ``items`` for the ``i``-th
    draw. The uniform draw for each row is scaled by that row's total weight,
    so rows that do not sum to exactly 1 (float rounding, or unnormalized
    weights) can no longer produce an out-of-range index.

    Args:
        prob_matrix: array-like of shape ``(rows, len(items))`` holding
            non-negative weights.
        items: array-like of candidate values (converted with ``np.asarray``).
        n: number of draws; row ``i`` is used for draw ``i``.

    Returns:
        NumPy array with the ``n`` sampled items.

    Raises:
        ValueError: if ``n`` exceeds the number of rows of ``prob_matrix``.
    """
    cdf = np.cumsum(np.asarray(prob_matrix, dtype=float), axis=1)
    if n > cdf.shape[0]:
        raise ValueError(
            "n={} exceeds the number of rows ({})".format(n, cdf.shape[0]))
    cdf = cdf[:n]
    # One uniform draw per row (all at once), scaled into [0, row_total).
    thresholds = np.random.random(size=n) * cdf[:, -1]
    # Vectorized inverse-CDF lookup: for a non-decreasing row, the number of
    # entries strictly below the threshold equals searchsorted(side='left').
    idx = (cdf < thresholds[:, None]).sum(axis=1)
    # Guard against rounding pushing a threshold up to the row total.
    idx = np.minimum(idx, cdf.shape[1] - 1)
    return np.asarray(items)[idx]


class DiabolicalLockMaze(gym.Env):
    """A batched environment made of two (stochastic) combination locks.

    The first action of an episode picks a lock for every parallel
    environment: actions ``< action_dim // 2`` select lock 1, all others select
    lock 0. Every later action is applied to both underlying ``LockBatch``
    instances, and each environment receives the reward of the lock it chose.
    Lock 0 is configured with ``optimal_reward`` and lock 1 with
    ``sub_optimal_reward``.

    The constructor does no setup; call :meth:`init` to configure the horizon,
    action dimension, switching probability, noise, etc. before :meth:`reset`.
    """

    def __init__(self, env_config=None):
        # ``env_config`` is accepted but currently ignored; configuration
        # happens in ``init``.
        self.initialized = False

    def init(self, horizon=100, action_dim=10, p_switch=0.5, p_anti_r=0.5, anti_r=0.1, noise=0.1, num_envs=10,
             temperature=1, variable_latent=False, dense=False, seed=123, optimal_reward=5, sub_optimal_reward=2):
        """Configure the environment and build the two underlying locks.

        Args:
            horizon: total episode length. The first step chooses the lock and
                each lock is built with ``horizon - 1`` steps.
            action_dim: number of discrete actions.
            p_switch, p_anti_r, anti_r, variable_latent, dense: forwarded
                unchanged to each ``LockBatch.init``.
            noise: std-dev of the Gaussian observation noise (also forwarded).
            num_envs: number of parallel environments simulated as a batch.
            temperature: softmax temperature (stored as ``tau``; also forwarded).
            seed: lock ``i`` and its action space are seeded with ``seed + i``.
            optimal_reward: assigned to ``locks[0].optimal_reward``.
            sub_optimal_reward: assigned to ``locks[1].optimal_reward``.
        """
        self.initialized = True
        self.horizon = horizon
        # 3*2 because there are two locks
        self.state_dim = 3 * 2
        self.action_dim = action_dim
        self.action_space = Discrete(self.action_dim)

        # NOTE(review): rewards can exceed 1 (e.g. optimal_reward=5); confirm
        # whether this range should be widened.
        self.reward_range = (0.0, 1.0)

        # Smallest power of two >= horizon + 3, doubled for the two locks.
        # scipy.linalg.hadamard requires a power-of-two size.
        self.observation_dim = 2 ** int(math.ceil(np.log2(self.horizon + 3)) + 1)

        # np.float was removed in NumPy 1.24; np.float64 is the same type.
        self.observation_space = Box(low=0.0, high=1.0, shape=(self.observation_dim,), dtype=np.float64)

        self.p_switch = p_switch
        self.p_anti_r = p_anti_r
        self.anti_r = anti_r
        self.noise = noise
        self.rotation = scipy.linalg.hadamard(self.observation_space.shape[0])

        # Number of parallel environments. Some may enter lock 0 and others
        # lock 1, so per-env bookkeeping must go through ``self.lock_index``.
        self.num_envs = num_envs
        self.tau = temperature

        self.variable_latent = variable_latent
        self.dense = dense

        self.optimal_reward = optimal_reward
        self.sub_optimal_reward = sub_optimal_reward

        if dense:
            self.step_reward = 0.1

        self.locks = [LockBatch(), LockBatch()]

        for i, lock in enumerate(self.locks):
            print("init lock {}".format(i))
            lock.init(horizon=horizon - 1,
                      action_dim=action_dim,
                      p_switch=p_switch,
                      p_anti_r=p_anti_r,
                      anti_r=anti_r,
                      noise=noise,
                      num_envs=num_envs,
                      temperature=temperature,
                      variable_latent=variable_latent,
                      dense=dense)

            lock.seed(seed + i)
            lock.action_space.seed(seed + i)

        self.locks[0].optimal_reward = self.optimal_reward
        self.locks[1].optimal_reward = self.sub_optimal_reward

    def step(self, action):
        """Advance every parallel environment by one step.

        At ``h == 0`` the action only selects each environment's lock. Both
        locks are reset, and a zero reward of shape ``(num_envs, 1)`` is
        returned. Afterwards the action is applied to both locks, and each
        environment receives the reward of its own lock.

        Args:
            action: batch of integer actions, one per environment (presumably
                an ndarray of shape ``(num_envs,)`` -- TODO confirm).

        Returns:
            ``(obs, reward, done, info)``; ``done`` is taken from lock 0.

        Raises:
            Exception: if the horizon has already been reached.
        """
        if self.h >= self.horizon:
            raise Exception("[LOCK] Exceeded horizon")

        if self.h == 0:
            # The first timestep only decides which lock each env goes to.
            self.lock_index = (action < self.action_dim // 2).astype(int)
            self.locks[0].reset()
            self.locks[1].reset()
            self.h += 1
            return self.make_obs(), np.zeros((self.num_envs, 1)), False, {}

        # Both locks are stepped with the same action batch; each env keeps
        # only the reward of the lock it chose.
        _, r0, done, _ = self.locks[0].step(action)
        _, r1, _, _ = self.locks[1].step(action)
        r = np.zeros_like(r0)
        for i in range(self.num_envs):
            r[i] = (r1 if self.lock_index[i] else r0)[i]

        self.h += 1
        return self.make_obs(), r, done, {}

    def get_state(self):
        """Return, and cache in ``self.state``, the combined latent state per env.

        Lock-0 states are kept as-is. Lock-1 states are shifted by
        ``locks[0].state_dim``. Assuming 3 states per lock (as
        ``state_dim = 3 * 2`` implies), the result lies in 0..5.
        """
        state0 = self.locks[0].get_state()
        state1 = self.locks[1].get_state()
        self.state = self.lock_index * self.locks[0].state_dim + \
            np.where(self.lock_index, state1, state0)

        return self.state

    def get_counts(self):
        """Return how many of the first ``num_envs`` envs are in each combined state."""
        self.get_state()
        counts = np.zeros(self.state_dim, dtype=int)
        # np.add.at accumulates repeated indices (plain fancy-index += would not).
        np.add.at(counts, self.state[:self.num_envs], 1)
        return counts

    def make_obs(self, s=None, h=None):
        """Build a batch of noisy, Hadamard-rotated observations.

        Each row holds Gaussian noise plus a one-hot of the combined state and
        a one-hot of the timestep. The timestep entry gets an extra +1 for
        environments in lock 1. The result is multiplied by ``self.rotation``.

        Args:
            s: per-environment combined state. ``None`` (or empty) means use
                :meth:`get_state`.
            h: timestep to encode. ``None`` means ``self.h``.

        Returns:
            Array of shape ``(num_envs, observation_dim)``.
        """
        if h is None:
            h = self.h
        if s is None or len(s) == 0:
            s = self.get_state()
        gaussian = np.zeros((self.num_envs, self.observation_dim))
        n_noisy = self.horizon + self.state_dim
        gaussian[:, :n_noisy] = np.random.normal(0, self.noise, [self.num_envs, n_noisy])
        gaussian[np.arange(self.num_envs), s] += 1
        # Use the resolved ``h``, not ``self.h``, so explicit timesteps are honored.
        gaussian[:, self.state_dim + h] += 1
        gaussian[:, self.state_dim + h] += self.lock_index

        return np.matmul(self.rotation, gaussian.T).T

    def sample_latent(self, obs):
        """Sample ``self.state`` from a temperature-scaled softmax over ``self.latents``.

        NOTE(review): ``self.latents`` and ``self.all_latents`` are never set in
        this class, and ``obs`` is unused. This raises AttributeError unless a
        caller assigns those attributes first. Confirm it is still needed.
        """
        logits = self.latents / self.tau
        # Subtracting the row max keeps exp() from overflowing; the softmax is unchanged.
        latent_exp = np.exp(logits - logits.max(axis=-1, keepdims=True))

        softmax = latent_exp / latent_exp.sum(axis=-1, keepdims=True)
        self.state = sample(softmax, self.all_latents, self.num_envs)

    def trim_observation(self, o, h):
        """Return ``o`` unchanged (``h`` is ignored)."""
        return o

    def reset(self, bad=False):
        """Start a new episode and return an all-zero observation batch.

        The locks are reset in the first :meth:`step`, once the lock choice is
        known. ``bad`` is accepted but unused.

        Raises:
            Exception: if :meth:`init` has not been called.
        """
        if not self.initialized:
            raise Exception("Environment not initialized")
        self.h = 0
        return np.zeros((self.num_envs, self.observation_space.shape[0]))

    def render(self, mode='human'):
        """Print one ``<letter><h>`` line per environment in states 0, 1, or 2.

        ``self.state`` is batched, so it is iterated element-wise. Comparing the
        whole array to a scalar in an ``if`` raises ValueError when
        ``num_envs > 1``. States other than 0-2 print nothing.
        """
        labels = {0: "A%d", 1: "B%d", 2: "C%d"}
        for state in np.atleast_1d(self.state):
            fmt = labels.get(int(state))
            if fmt is not None:
                print(fmt % (self.h))

