import numpy as np

class traversability(object):
    """A sensor that reads the terrain height around a robot and scores it.

    A local grid of heights is sampled from a global height map, each interior
    cell is scored from the roughness of its 3x3 neighbourhood, and the per-cell
    scores are reduced to one scalar with ``tanh(mean(...))``. For finite heights
    every per-cell term is non-negative, so the scalar lies in ``[0, 1)``.
    """

    def __init__(self, height_map, terrain_scale):
        """Create a sensor over ``height_map``.

        Args:
            height_map: 2-D array of terrain heights indexed ``[row, col]``.
            terrain_scale: size of one height-map cell in query coordinates;
                coordinates are multiplied by ``1 / terrain_scale`` to obtain
                cell offsets (see ``score_assessment``).
        """
        # Local sampling grid: (channel, rows sampled along +x, columns centred on y = 0).
        self.shape = (1, 150, 50)
        self.height_local_map = np.zeros(shape=self.shape)
        self.score = np.zeros(shape=(self.shape[1], self.shape[2]))
        # Per-cell score map from the most recent heightmap2score call; plotted by render_map.
        self.score_map = np.zeros(shape=(self.shape[1], self.shape[2]))
        self.map_score = 0
        # Cleared by height_map_converter when a query is NaN or falls outside the map.
        self.is_safe = True

        self.height_map = height_map
        self.terrain_scale = terrain_scale

    def height_map_converter(self, height_map, scale, x, y):
        """Convert a query point to a ``(row, col)`` index into ``height_map``.

        The map centre is treated as the origin:
        ``row = floor(x * scale + H / 2)`` and ``col = floor(y * scale + W / 2)``.

        Args:
            height_map: 2-D array the index is for (only its shape is used).
            scale: cells per unit of ``x`` / ``y``.
            x, y: query coordinates.

        Returns:
            ``(row, col)``. If either coordinate is NaN or the point lies outside
            the map, ``(0, 0)`` is returned and ``self.is_safe`` is set to False.
        """
        rows, cols = height_map.shape[0], height_map.shape[1]
        # NaN is the only value unequal to itself; floor/int would raise on it.
        if x != x or y != y:
            self.is_safe = False
            return 0, 0
        # floor (not int truncation) so a point just below the lower edge,
        # e.g. -0.5, is rejected instead of being silently snapped to index 0.
        map_x = int(np.floor((x * scale) + rows / 2))
        map_y = int(np.floor((y * scale) + cols / 2))
        if not (0 <= map_x < rows and 0 <= map_y < cols):
            self.is_safe = False
            return 0, 0
        return map_x, map_y

    def score_assessment(self):
        """Sample the local height grid from ``self.height_map`` and score it.

        Row ``i`` is sampled at ``x = i * terrain_scale * 4`` and column ``j`` at
        ``y = (j - n_cols // 2) * terrain_scale * 4``, i.e. every fourth map cell.
        Out-of-map or NaN samples read ``height_map[0, 0]`` and clear
        ``self.is_safe``.

        Returns:
            The scalar score from ``heightmap2score`` (also stored in ``self.score``).
        """
        n_rows, n_cols = self.shape[1], self.shape[2]
        inv_scale = 1 / self.terrain_scale
        for i in range(n_rows):
            body_center_x = i * self.terrain_scale * 4
            for j in range(n_cols):
                body_center_y = (-n_cols // 2 + j) * self.terrain_scale * 4
                row, col = self.height_map_converter(self.height_map, inv_scale,
                                                     body_center_x, body_center_y)
                self.height_local_map[0, i, j] = self.height_map[row, col]

        confidence_interval = self.get_confidence_bound()
        self.score = self.heightmap2score(self.height_local_map[0], n_rows, n_cols,
                                          confidence_interval)
        return self.score

    def get_confidence_bound(self, sigma=0.01):
        """Return a random per-cell width of a confidence band around the heights.

        For each local height ``mu`` two independent samples ``N(mu, sigma)`` are
        drawn. The first is reflected about ``mu`` so it lies at or below ``mu``.
        The second is reflected so it lies at or above ``mu``. Their difference is
        returned, so every entry is non-negative and (up to rounding) equal to
        ``|e1| + |e2|`` for the two noise draws, independent of ``mu``. Uses the
        global ``np.random`` state.

        Args:
            sigma: standard deviation of the height noise.

        Returns:
            Array shaped like ``self.height_local_map[0]``.
        """
        mu = self.height_local_map[0]
        sample = np.random.normal(loc=mu, scale=sigma, size=np.shape(mu))
        # Reflect about mu (2*mu - X). Negating (-X) would reflect about 0 and
        # throw the sample to roughly -mu, making the width depend on terrain height.
        x_lower = np.where(sample > mu, 2 * mu - sample, sample)

        sample = np.random.normal(loc=mu, scale=sigma, size=np.shape(mu))
        x_upper = np.where(sample < mu, 2 * mu - sample, sample)
        return x_upper - x_lower

    def heightmap2score(self, height, x_size, y_size, confidence_interval):
        """Reduce a local height grid to a scalar traversability score.

        Each interior cell (not on the border of the ``x_size`` by ``y_size`` grid)
        is scored from its 3x3 neighbourhood as::

            50 * sd + 20 * slope + 10 * (hmax - h) + 10 * (h - hmin)

        Here ``sd`` (population std), ``hmax`` and ``hmin`` are taken over all 9
        heights. ``slope`` is the mean over the 8 neighbours of
        ``|h - h_n| / dist``, where ``dist`` is 1 or sqrt(2) cells. Border cells
        score 0. Next, ``10 * confidence_interval`` is added to every cell; the
        resulting per-cell map is stored in ``self.score_map``. Finally,
        ``tanh`` of its mean is returned.

        Args:
            height: 2-D height grid; only ``height[:x_size, :y_size]`` is read.
            x_size: number of rows to score.
            y_size: number of columns to score.
            confidence_interval: per-cell band widths, added (with broadcasting)
                to the ``(x_size, y_size)`` score map.

        Returns:
            The scalar score.
        """
        w_sd = 50
        w_sl = 20
        w_max = 10
        w_min = 10
        w_ci = 10

        h = np.asarray(height, dtype=float)[:x_size, :y_size]
        score = np.zeros((x_size, y_size))
        if x_size >= 3 and y_size >= 3:
            offsets = [(di, dj) for di in (-1, 0, 1) for dj in (-1, 0, 1)]
            # windows[k][r, c] is the neighbour at offsets[k] of interior cell (r+1, c+1).
            windows = np.stack([
                h[1 + di:x_size - 1 + di, 1 + dj:y_size - 1 + dj] for di, dj in offsets
            ])
            center = h[1:-1, 1:-1]
            slope = sum(
                np.abs(center - window) / np.sqrt(di * di + dj * dj)
                for (di, dj), window in zip(offsets, windows)
                if (di, dj) != (0, 0)
            ) / 8.0
            sd = windows.std(axis=0)  # ddof=0 over the 9 heights
            hmax = windows.max(axis=0)
            hmin = windows.min(axis=0)
            score[1:-1, 1:-1] = (w_sd * sd + w_sl * slope
                                 + w_max * (hmax - center) + w_min * (center - hmin))

        score += w_ci * confidence_interval
        self.score_map = score
        return np.tanh(np.mean(score))

    def render_map(self):
        """Plot ``self.score_map`` as a 3-D surface with filled contours and show it.

        Before any scoring, ``self.score_map`` is all zeros.
        """
        import matplotlib.pyplot as plt
        # Imported for its side effect: registers the "3d" projection on older matplotlib.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
        rows = np.arange(self.score_map.shape[0])
        cols = np.arange(self.score_map.shape[1])
        # meshgrid(rows, cols) yields arrays shaped (n_cols, n_rows); transpose Z to match.
        X, Y = np.meshgrid(rows, cols)
        Z = self.score_map.T
        ax.plot_surface(X, Y, Z, rstride=1, cstride=1)
        ax.contourf(X, Y, Z, 20)
        plt.show()

