import pybullet
import random
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


class Complex_terrain:
    """Procedurally generated height-field terrain for a pybullet simulation.

    Configure the terrain in ``__init__`` and then access the ``build``
    property to create the ground body. After a build, ``HeightMap`` holds the
    terrain heights, shifted by ``mid_height`` and multiplied by ``scale``,
    with shape ``(land_size[0], land_size[1])``.
    """

    def __init__(self,
                 terrain_type=None,
                 grid_size=None,
                 parameters=None,
                 _range=None,
                 _pybullet_client=None,
                 land_size=None):
        """Store the terrain configuration; nothing is created until ``build``.

        Args:
            terrain_type: one of "Hills", "Slippery Hills", "Steps",
                "Stairs UP", "Stairs DOWN", "Stairs MIX", "Plane", "Block".
            grid_size: (int) size of grid. Stored as ``self.grid_size`` but not
                read by any method of this class.
            parameters: [float, float, float] roughness, frequency, amplitude,
                or [float, float] step width, step height. Stored but not read
                by any method of this class.
            _range: [float, float] range. Stored as ``self.range_`` but not
                read by any method of this class.
            _pybullet_client: pybullet module (or client) used by ``build``.
            land_size: [n_x, n_y] number of height samples along each axis of
                the height field; defaults to [1200, 200].
        """
        if land_size is None:
            land_size = [1200, 200]
        self.terrain_type = terrain_type
        self.grid_size = grid_size
        self.parameters = parameters
        self.range_ = _range
        self.pybullet_client = _pybullet_client
        self.land_size = land_size
        self.scale = 0  # set to the mesh scale (0.01) by build
        self.HeightMap = np.zeros(shape=self.land_size)
        # Height subtracted from HeightMap; only some terrain types set it.
        self.mid_height = 0
        # Base position handed to resetBasePositionAndOrientation in build.
        self.pos_offset = [0, 0, 0]

    @property
    def build(self):
        """Generate the height field for ``terrain_type`` and add it to pybullet.

        This is a property, so it runs on attribute access
        (``terrain.build``); every access creates a new ground body.

        Side effects: sets ``self.scale`` and ``self.HeightMap``; the Hills and
        stair terrains also set ``self.mid_height`` and/or ``self.pos_offset``.
        For the stair and "Block" terrains ``land_size[0]`` must be even (and
        large enough to hold at least one stair sample), otherwise the height
        field cannot be reshaped to ``land_size``.

        Returns:
            The body id returned by ``createMultiBody``.

        Raises:
            ValueError: if ``terrain_type`` is not a supported name.
        """
        mesh_l = mesh_w = mesh_h = 0.01
        self.scale = 0.01  # set before the generators: some use it for pos_offset
        friction = 0.8  # "Steps" and "Plane" draw a random friction instead
        n_x, n_y = self.land_size[0], self.land_size[1]

        terrain_type = self.terrain_type
        if terrain_type in ("Hills", "Slippery Hills"):
            height_field = self._hills_height_field()
        elif terrain_type == "Steps":
            friction = np.random.uniform(0.6, 0.9)
            height_field = self._steps_height_field()
        elif terrain_type in ("Stairs UP", "Stairs DOWN"):
            height_field = self._stairs_height_field()
        elif terrain_type == "Stairs MIX":
            height_field = self._stairs_mix_height_field()
        elif terrain_type == "Plane":
            friction = np.random.uniform(0.6, 0.9)
            height_field = [0] * (n_x * n_y)
        elif terrain_type == "Block":
            height_field = self._block_height_field()
        else:
            raise ValueError(
                "Unsupported terrain_type: {!r}".format(terrain_type))

        height_field = self.set_normal_distribution_noise(height_field)

        # Consecutive samples of the flat field run along land_size[0], so
        # reshape to (n_y, n_x) and transpose to get shape (n_x, n_y).
        height_map = np.array(height_field).reshape(n_y, n_x).T
        self.HeightMap = self.scale * (height_map - self.mid_height)

        client = self.pybullet_client
        terrain_shape = client.createCollisionShape(
            shapeType=client.GEOM_HEIGHTFIELD,
            meshScale=[mesh_l, mesh_w, mesh_h],
            heightfieldTextureScaling=(n_x - 1) / 10,
            # NOTE(review): purpose of the constant +20 offset is not evident
            # from this file; kept unchanged.
            heightfieldData=np.array(height_field) + 20,
            numHeightfieldRows=n_x,
            numHeightfieldColumns=n_y)
        ground_id = client.createMultiBody(0, terrain_shape)
        client.resetBasePositionAndOrientation(ground_id, self.pos_offset, [0, 0, 0, 1])
        client.changeDynamics(ground_id, -1, lateralFriction=friction)
        client.changeVisualShape(ground_id, -1, rgbaColor=[0.8, 0.6, 0.4, 1])
        return ground_id

    def _hills_height_field(self):
        """Return a sinusoidal profile along land_size[0], repeated land_size[1] times.

        Sets ``mid_height`` to the profile sample at index land_size[0] // 2
        and ``pos_offset`` so that sample is lowered by ``mid_height * scale``.
        """
        n_x, n_y = self.land_size[0], self.land_size[1]
        amplitude = -30  # also sets the wavelength (amplitude * 10 * 2 samples)
        # NOTE(review): the phase is anchored at sample 600, the centre of the
        # default land_size[0] == 1200; confirm this is intended for other sizes.
        phase = 600 - amplitude * 7.5
        profile = [amplitude * np.sin(((j - phase) / amplitude / 10) * np.pi)
                   for j in range(int(n_x))]
        centre = profile[int(n_x / 2)]
        self.mid_height = centre
        self.pos_offset = [0, 0, -centre * self.scale]
        return profile * n_y

    def _steps_height_field(self, grid_width=20, max_height=4):
        """Return square blocks of ``grid_width`` samples with random heights.

        Each block height is drawn from ``random.uniform(0.01, max_height)``.
        The result has exactly land_size[0] * land_size[1] samples, also when
        land_size is not a multiple of ``grid_width``.
        """
        n_x, n_y = self.land_size[0], self.land_size[1]
        # Ceiling division so a partial block at the far edge still gets a height.
        blocks_x = -(-n_x // grid_width)
        blocks_y = -(-n_y // grid_width)
        field = []
        for _ in range(blocks_y):
            heights = [random.uniform(0.01, max_height) for _ in range(blocks_x)]
            row = [heights[i // grid_width] for i in range(n_x)]
            field += row * grid_width
        # A partial last band yields surplus rows; trim to the exact size.
        return field[:n_x * n_y]

    @staticmethod
    def _staircase(length, next_rise, first_run=10, run=20):
        """Return ``length`` heights starting at 0.

        The first ``first_run`` samples are flat; after that the height changes
        by ``next_rise()`` and then stays constant for ``run`` samples.
        """
        heights = []
        height = 0
        remaining = first_run
        for _ in range(length):
            if remaining == 0:
                remaining = run
                height += next_rise()
            heights.append(height)
            remaining -= 1
        return heights

    def _stairs_height_field(self):
        """Return the "Stairs UP"/"Stairs DOWN" profile repeated land_size[1] times.

        The profile is a staircase, a 70-sample landing, and the same staircase
        continued above the landing. Sets ``mid_height`` to the landing height.
        """
        n_x, n_y = self.land_size[0], self.land_size[1]
        # NOTE(review): both "Stairs UP" and "Stairs DOWN" use this fixed +4
        # rise, so they currently produce identical terrain.
        stairs = self._staircase(n_x // 2 - 35, lambda: 4)
        landing = stairs[-1]
        self.mid_height = landing
        profile = stairs + [landing] * 70 + (np.array(stairs) + landing).tolist()
        return profile * n_y

    def _stairs_mix_height_field(self):
        """Return a random up/down staircase mirrored about a flat centre.

        Each rise is drawn from a fixed list with ``np.random.randint``. Sets
        ``mid_height`` to the first stair sample and ``pos_offset`` to the mean
        profile height times ``scale``.
        """
        n_x, n_y = self.land_size[0], self.land_size[1]
        rises = [8, 6, -6, 4, -4, 2, -2, 0]
        stairs = self._staircase(
            n_x // 2 - 5, lambda: rises[np.random.randint(0, len(rises))])
        start = stairs[0]
        self.mid_height = start
        profile = stairs[::-1] + [start] * 10 + stairs
        self.pos_offset = [0, 0, np.mean(profile) * self.scale]
        return profile * n_y

    def _block_height_field(self):
        """Return four raised 5-sample blocks in the second half of the profile.

        Block heights come from ``np.random.randint(6, 16)``. The first half is
        the negated copy of the second (pits instead of blocks). The profile is
        repeated land_size[1] times.
        """
        n_x, n_y = self.land_size[0], self.land_size[1]
        half = np.zeros(int(n_x / 2))
        for i in range(1, 5):
            start = int(n_x * i / (2 * 5))
            half[start:start + 5] = np.random.randint(6, 16)
        return ((-half).tolist() + half.tolist()) * n_y

    def get_height_map(self):
        """Return the height map computed by the last ``build``."""
        return self.HeightMap

    def render_map(self):
        """Plot ``HeightMap`` as a 3-D surface with filled contours.

        Both axes span [0, 10) and are sampled to match the height map's
        shape, so a map of any ``land_size`` can be rendered.
        """
        n_rows, n_cols = np.shape(self.HeightMap)
        fig = plt.figure()
        # Axes3D(fig) no longer adds itself to the figure on recent Matplotlib;
        # add_subplot(projection="3d") is the supported way to get 3-D axes.
        ax = fig.add_subplot(projection="3d")
        # meshgrid(x, y) returns arrays of shape (len(y), len(x)), so x spans
        # the columns and y the rows of HeightMap.
        x = np.linspace(0, 10, n_cols, endpoint=False)
        y = np.linspace(0, 10, n_rows, endpoint=False)
        X, Y = np.meshgrid(x, y)
        ax.plot_surface(X, Y, self.HeightMap, rstride=1, cstride=1)
        ax.contourf(X, Y, self.HeightMap, 20)
        plt.show()

    def get_terrain_scale(self):
        """Return the height scale set by ``build`` (0 before the first build)."""
        return self.scale

    def set_normal_distribution_noise(self, HeightField, sigma=0.01):
        """Return ``HeightField`` with clipped Gaussian noise added to each sample.

        Each sample is drawn from N(value, sigma) and then clipped to
        [value - 2*sigma, value + 2*sigma]. This is clipping, not resampling:
        probability mass beyond 2 sigma lands on the bounds.

        Args:
            HeightField: sequence (or array) of heights.
            sigma: standard deviation of the noise; defaults to 0.01.

        Returns:
            A numpy array with the same shape as ``HeightField``.
        """
        mu = np.asarray(HeightField)
        noisy = np.random.normal(loc=mu, scale=sigma, size=np.shape(mu))
        return np.clip(noisy, a_min=mu - 2 * sigma, a_max=mu + 2 * sigma)



if __name__ == '__main__':
    # Demo: open the pybullet GUI and build a single "Hills" terrain.
    import pybullet_data as pd

    p = pybullet
    p.connect(p.GUI, options="--width=1280 --height=720 "
                             "--mp4fps=30")
    p.configureDebugVisualizer(p.COV_ENABLE_GUI)
    p.setAdditionalSearchPath(pd.getDataPath())
    # Other terrain types include "Stairs UP", "Stairs MIX" and "Block".
    _terrain = Complex_terrain(terrain_type='Hills', _pybullet_client=p, land_size=[600, 600])
    ground_id = _terrain.build  # build is a property: attribute access creates the body
    # np.save('terrain_data/blocks.npy', _terrain.get_height_map())
    print()
