import cv2
import numpy as np
import opensimplex

def edge_detector(image, use_depth = False):
    """Return a float32 Canny edge map of ``image``.

    The input is scaled by 255, smoothed with a 3x3 Gaussian kernel, cast to
    uint8 and passed to ``cv2.Canny`` with thresholds 20 and 40. The 8-bit
    result is divided by 255 before being returned.

    Args:
        image: Array to detect edges in. It is not modified.
            NOTE(review): presumably normalised to [0, 1], since it is
            multiplied by 255 before the uint8 cast -- confirm with callers.
        use_depth: When true, the edge map is multiplied element-wise by the
            original (unscaled) ``image``.

    Returns:
        A float32 array holding the scaled edge map.
    """
    scaled = image * 255
    # Sigmas are passed by keyword: the fourth positional parameter of
    # cv2.GaussianBlur is the output array ``dst``, not ``sigmaY``.
    blurred = cv2.GaussianBlur(scaled, (3, 3), sigmaX=0, sigmaY=0)
    edges = cv2.Canny(image=np.uint8(blurred), threshold1=20, threshold2=40)
    edges = np.float32(edges) / 255
    if use_depth:
        # ``image`` still holds the caller's unscaled values here.
        edges *= image
    return edges

def simplex_noise(image, num=10, scale=24):
    """Overwrite random small patches of ``image`` with clamped simplex noise.

    Between 0 and ``num`` patches (inclusive) are drawn. Each patch has a
    random top-left corner inside the image and a random extent of 2 to 4
    cells along each of the first two axes. Cells that would fall outside the
    image are skipped. Every written cell receives
    ``max(0, opensimplex.noise2(row / scale, col / scale))``.

    Args:
        image: Array indexed as ``image[row, col]``; modified in place. Its
            first two dimensions define the valid region.
        num: Upper bound (inclusive) on the number of patches.
        scale: Divisor applied to the cell coordinates before sampling noise.

    Returns:
        The same ``image`` object that was passed in.
    """
    # Bounds come from the image itself rather than a fixed 84x84 size.
    height, width = image.shape[:2]
    count = np.random.randint(0, num + 1)
    if height == 0 or width == 0:
        # Nothing can be written, and randint(0, 0) would raise.
        return image

    for _ in range(count):
        row0 = np.random.randint(0, height)
        col0 = np.random.randint(0, width)
        patch_rows, patch_cols = (int(v) for v in np.random.randint(2, 5, size=2))

        # Clip the patch to the image instead of testing every cell.
        row_end = min(row0 + patch_rows, height)
        col_end = min(col0 + patch_cols, width)
        for row in range(row0, row_end):
            for col in range(col0, col_end):
                value = opensimplex.noise2(x=row / scale, y=col / scale)
                image[row, col] = max(0, value)
    return image

def close_gaps(image):
    """Apply a morphological closing with a 4x4 kernel of ones to ``image``.

    The input is scaled by 255 and cast to uint8 before the closing; the
    result is converted to float32 and divided by 255 on the way out.
    NOTE(review): presumably ``image`` holds values in [0, 1] -- confirm.
    """
    as_bytes = np.uint8(image * 255)
    structuring_element = np.ones((4, 4), np.uint8)
    closed = cv2.morphologyEx(as_bytes, cv2.MORPH_CLOSE, structuring_element)
    return np.float32(closed) / 255

class ImageSmoother:
    """Replace at-or-below-threshold observation entries with remembered values.

    For every index the smoother remembers the most recent value that was
    strictly above the threshold. ``Smooth`` returns a copy of an observation
    in which entries at or below the threshold are replaced by that
    remembered value.

    ``step`` and ``reset`` forward to ``self.env`` and smooth the returned
    observation. The environment can be supplied through the ``env``
    constructor argument or assigned to the ``env`` attribute afterwards.

    NOTE(review): observations are indexed with boolean masks against a
    length-``flat_size`` memory, so they presumably have shape
    ``(flat_size,)`` -- confirm with callers.
    """

    def __init__(self, flat_size = 7056, img_size = (84,84), threshold = 0.01, env = None):
        """Initialise the per-index memory.

        Args:
            flat_size: Length of the per-index memory array.
            img_size: Stored on the instance; not used by the methods here.
            threshold: Entries at or below this value are replaced by
                ``Smooth``.
            env: Optional object providing ``step(action)`` and
                ``reset(seed=..., options=...)``, used by ``step``/``reset``.
        """
        # Every slot starts at this fill value until a reading above the
        # threshold replaces it.
        self.most_recent_nonzero = np.ones(flat_size)*0.8832035064697266
        self.flat_size = flat_size
        self.img_size = img_size
        self.t = threshold
        # Assigned only when given, so an ``env`` attribute that was set by
        # other means is never overwritten with None.
        if env is not None:
            self.env = env
        print("IMAGE SMOOTHER ACTIVATE")

    def UpdateNonzero(self, obs):
        """Remember every entry of ``obs`` that is strictly above the threshold."""
        above = obs > self.t
        self.most_recent_nonzero[above] = obs[above]

    def step(self, action):
        """Step ``self.env`` and return its 5-tuple with the observation smoothed.

        Raises:
            AttributeError: If no ``env`` has been provided.
        """
        obs, reward, term, trunc, info = self.env.step(action)
        obs = self.Smooth(obs)
        return obs, reward, term, trunc, info

    def reset(self, seed=None, options=None):
        """Reset ``self.env`` and return ``(smoothed_obs, info)``.

        The per-index memory is not cleared by this call.

        Raises:
            AttributeError: If no ``env`` has been provided.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        obs = self.Smooth(obs)
        return obs, info

    def Smooth(self, image):
        """Return a copy of ``image`` with low entries replaced from memory.

        The memory is first updated with the entries of ``image`` that are
        above the threshold; the input itself is left untouched.
        """
        frame = np.copy(image)
        self.UpdateNonzero(frame)
        at_or_below = frame <= self.t
        frame[at_or_below] = self.most_recent_nonzero[at_or_below]
        return frame
        