import gym
from gym import spaces
from gym.utils import seeding
import random
import numpy as np

import cv2
# Module-level side effect that runs on import: turns off OpenCV's OpenCL
# acceleration for the whole process. NOTE(review): the reason is not
# recorded in this file — confirm before removing.
cv2.ocl.setUseOpenCL(False)

class CustomNChainEnv(gym.Env):
    """N-chain environment with image observations of shape (210, 160, 3).

    The agent occupies one of ``n`` cells laid out left to right.  Action 0
    moves one cell forward (towards cell ``n - 1``), action 1 moves one cell
    backward (towards cell 0), and actions 2-5 are no-ops.  Pushing forward
    from the last cell, or backward from the first cell, ends the episode
    with a bimodal noisy reward: one of two means is chosen with probability
    0.5 and ``std * N(0, 1)`` noise is added.  Every other transition gives
    reward 0.  With probability ``slip`` the chosen action is replaced by a
    uniformly random one.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 30
    }

    # Side length in pixels of one chain cell in the observation image.
    CELL_SIZE = 20

    def __init__(self, n=5, slip=0, small=2, large=10, std=10):
        """Create the environment.

        Args:
            n: number of chain states; must satisfy
                ``1 <= n <= 160 // CELL_SIZE`` so every cell is visible.
            slip: probability of replacing the chosen action by a uniformly
                random one.
            small: reward mean(s) for pushing backward from state 0 — a
                scalar or a pair ``(mode_a, mode_b)``.
            large: reward mean(s) for pushing forward from state ``n - 1`` —
                a scalar or a pair ``(mode_a, mode_b)``.
            std: standard deviation of the Gaussian noise added to the
                terminal rewards.

        Raises:
            ValueError: if ``n`` does not fit in the observation image.
        """
        self.shape = (210, 160, 3)
        max_n = self.shape[1] // self.CELL_SIZE
        if not 1 <= n <= max_n:
            raise ValueError(
                "n must be between 1 and {} to fit in the observation, "
                "got {}".format(max_n, n))
        self.n = n
        self.slip = slip  # probability of 'slipping' an action
        # Terminal reward means, always stored as (mode_a, mode_b).
        self.small = self._as_pair(small)  # payout for 'backwards' action at state 0
        self.large = self._as_pair(large)  # payout at end of chain for 'forwards' action
        self.std = std
        self.state = 0  # state before the first reset(); reset() overrides it
        self.action_space = spaces.Discrete(6)  # 0: forward, 1: backward, 2-5: no-op
        self.observation_space = spaces.Box(
            low=0, high=255, shape=self.shape, dtype=np.uint8)
        self.observation = None  # last image produced by get_obs()
        self.seed()

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a 2-tuple of reward means (a scalar is duplicated)."""
        if np.ndim(value) == 0:
            return (value, value)
        return (value[0], value[1])

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _terminal_reward(self, means):
        """Sample a terminal reward: pick one of ``means`` with p=0.5 and add noise."""
        mean = means[0] if self.np_random.rand() < 0.5 else means[1]
        return mean + self.std * self.np_random.randn()

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``info`` is always an empty dict.

        Raises:
            AssertionError: if ``action`` is outside ``[0, action_space.n)``.
        """
        # Validate the requested action before a slip can mask an invalid one.
        assert 0 <= action < self.action_space.n, \
            "invalid action: {!r}".format(action)
        if self.np_random.rand() < self.slip:
            # Agent slipped: use a uniformly random action instead.
            # Drawn from the seeded RNG so episodes are reproducible.
            action = self.np_random.randint(self.action_space.n)

        reward = 0
        done = False
        if action == 0:  # 'forwards'
            if self.state < self.n - 1:
                self.state += 1
            else:
                reward = self._terminal_reward(self.large)
                done = True
        elif action == 1:  # 'backwards'
            if self.state > 0:
                self.state -= 1
            else:
                reward = self._terminal_reward(self.small)
                done = True
        # Actions 2-5 are no-ops: state unchanged, reward 0.
        return self.get_obs(), reward, done, {}

    def reset(self, start_state=None):
        """Start a new episode and return its first observation.

        Args:
            start_state: initial state; defaults to state 2, clamped to
                ``n - 1`` for chains shorter than 3.

        Raises:
            ValueError: if ``start_state`` is outside ``[0, n - 1]``.
        """
        if start_state is None:
            start_state = min(2, self.n - 1)
        if not 0 <= start_state < self.n:
            raise ValueError(
                "start_state must be in [0, {}], got {}".format(
                    self.n - 1, start_state))
        self.state = start_state
        return self.get_obs()

    def get_obs(self):
        """Draw the current state into a new uint8 image, cache it and return it.

        Only the first ``CELL_SIZE`` rows are used, with one ``CELL_SIZE``-wide
        cell per state.  Channel order is BGR, matching ``cv2.imshow`` in
        ``render``: channel 1 (green) marks state 0, channel 2 (red) marks
        state ``n - 1``, and all channels (white) mark the current state.
        """
        cell = self.CELL_SIZE
        obs = np.zeros(self.shape, dtype=np.uint8)
        # Semi-goal (GREEN)
        obs[0:cell, 0:cell, 1] = 255
        # Goal (RED)
        obs[0:cell, (self.n - 1) * cell:self.n * cell, 2] = 255
        # Current state (WHITE); drawn last so it overwrites the goal markers.
        obs[0:cell, self.state * cell:(self.state + 1) * cell, :] = 255
        self.observation = obs
        return obs

    def render(self, mode='human', close=False):
        """Show or return the current frame.

        Args:
            mode: ``'rgb_array'`` returns an RGB copy of the frame.  Any other
                mode shows the frame in an OpenCV window, waits up to 5 seconds
                for a key press, closes the window and returns None.
            close: unused; kept for gym API compatibility.
        """
        # Rebuild from the current state so render() also works before reset().
        frame = self.get_obs()
        if mode == 'rgb_array':
            # The frame is laid out BGR for cv2.imshow; flip the channels to RGB.
            return frame[:, :, ::-1].copy()
        cv2.imshow("frame", frame)
        cv2.waitKey(5000)
        cv2.destroyAllWindows()
        return None