import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
import time


class Water():
    def __init__(self, init_x, init_y, init_max_velocity, max_y, min_y, max_x):
        self.init_x, self.init_y = init_x, init_y
        self.x, self.y = init_x, init_y
        self.max_y, self.min_y, self.max_x = max_y, min_y, max_x
        self.velocity = init_max_velocity * (self.x / self.max_x - (self.x / self.max_x) ** 2)

    def update_velocity(self, max_velocity):
        self.velocity = max_velocity * (self.x / self.max_x - (self.x / self.max_x) ** 2)

    def move(self, delta_t=1.):
        self.y += self.velocity * delta_t
        if self.y > self.max_y:
            self.y = self.min_y


class ShipAcrossRiver(gym.Env):
    def __init__(self, state_normalize=True):
        self.state_normalize = state_normalize
        # 设置状态空间，动作空间
        self.min_x, self.max_x = 0, 50
        self.min_y, self.max_y = 0, 100
        self.min_angle, self.max_angle = -np.pi / 3, np.pi / 3
        self.min_velocity, self.max_velocity = 2, 5
        self.min_angle_velocity, self.max_angle_velocity = -1, 1
        observation_space_low = np.array(
            [self.min_x, self.min_y, self.min_angle, self.min_velocity, self.min_angle_velocity])
        observation_space_high = np.array(
            [self.max_x, self.max_y, self.max_angle, self.max_velocity, self.max_angle_velocity])
        self.observation_space = spaces.Box(observation_space_low, observation_space_high)

        self.min_acc_velocity, self.max_acc_velocity = -2, 2
        self.min_angle_velocity, self.max_angle_velocity = -1, 1
        action_space_low = np.array([self.min_acc_velocity, self.min_angle_velocity])
        action_space_high = np.array([self.max_acc_velocity, self.max_angle_velocity])
        self.action_space = spaces.Box(action_space_low, action_space_high)


        # 设置小船及水流初始状态
        self.reset()
        self._init_water()

        # 绘制屏幕
        self.viewer = None

        # 设置 goal 位置
        self.optimal_y, self.optimal_range = 30, 5
        self.sub_optimal_y, self.sub_optimal_range = 70, 10

    def _init_water(self):
        self.waters = [Water(init_x, init_y, self.max_water_speed, self.max_y, self.min_y, self.max_x) for init_x in
                       range(2, self.max_x, 5) for init_y in range(0, self.max_y, 20)]

    def seed(self, seed=None):
        # 产生一个随机化时需要的种子，同时返回一个np_random对象，支持后续的随机化生成操作
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _set_max_water_speed(self):
        '''
        水流速度是二次函数，中间高，两边低
        :return:
        '''
        self.max_water_speed = 2 + 0.5 * np.random.standard_normal()

    def reset(self):
        # 设置水流速度
        self._set_max_water_speed()

        # 设置小船初始状态
        self.x = 0
        self.y = 50
        self.angle = 0
        self.velocity = 2
        self.angle_velocity = 0

        return self._get_state()

    def _get_state(self):
        if self.state_normalize:
            return np.array([(self.x - self.min_x) / (self.max_x - self.min_x),
                             (self.y - self.min_y) / (self.max_y - self.min_y),
                             (self.angle - self.min_angle) / (self.max_angle - self.min_angle),
                             (self.velocity - self.min_velocity) / (self.max_velocity - self.min_velocity),
                             (self.angle_velocity - self.min_angle_velocity) / (
                                     self.max_angle_velocity - self.min_angle_velocity)])
        else:
            return [self.x, self.y, self.angle, self.velocity, self.angle_velocity]

    def step(self, action):
        # 求和代替积分，描绘小船变化
        time_interval = 1.
        n_split = 100
        delta_t = time_interval / n_split
        acc_velocity = action[0]
        acc_angle_velocity = action[1]

        for n in range(n_split):
            # 小船线速度变化
            new_velocity = np.clip(self.velocity + delta_t * acc_velocity, self.min_velocity, self.max_velocity)
            # 小船角速度变化
            new_angle_velocity = np.clip(self.angle_velocity + delta_t * acc_angle_velocity, self.min_angle_velocity,
                                         self.max_angle_velocity)
            # 小船角度变化
            new_angle = np.clip(self.angle + self.angle_velocity * delta_t + 1 / 2 * acc_angle_velocity * delta_t ** 2,
                                self.min_angle, self.max_angle)
            # 当前位置水流速度
            self.water_speed = self.max_water_speed * (self.x / self.max_x - (self.x / self.max_x) ** 2)
            for water in self.waters:
                water.update_velocity(self.max_water_speed)
                water.move(delta_t=delta_t)
            # 小船 x 坐标变化
            cos_theta = np.cos((self.angle + new_angle) / 2)
            vx = self.velocity * cos_theta
            new_vx = new_velocity * cos_theta
            self.x = self.x + (vx + new_vx) / 2 * delta_t
            # 小船 y 坐标变化
            sin_theta = np.sin((self.angle + new_angle) / 2)
            vy = self.velocity * sin_theta
            new_vy = new_velocity * sin_theta
            self.y = self.y + (vy + new_vy) / 2 * delta_t + self.water_speed * delta_t
            # 更新小船速度，角度，角速度
            self.velocity = new_velocity
            self.angle = new_angle
            self.angle_velocity = new_angle_velocity

        # 计算reward
        reward, done = self._cal_reward()
        info = {}

        # 每一步都改变水流状态
        self._set_max_water_speed()

        return self._get_state(), reward, done, info

    def _cal_reward(self):
        # step cost
        if self.x > self.min_x and self.x < self.max_x and self.y > self.min_y and self.y < self.max_y:
            return -0.1, False
        # global optimal is at position 30
        elif self.x >= self.max_x and abs(self.y - self.optimal_y) <= self.optimal_range:
            return 4 * (self.optimal_range - abs(self.y - self.optimal_y)), True
        # sub optimal is at position 70
        elif self.x >= self.max_x and abs(self.y - self.sub_optimal_y) <= self.sub_optimal_range:
            return self.sub_optimal_range - abs(self.y - self.sub_optimal_y), True
        # out of bounds
        elif self.y <= self.min_y or self.y >= self.max_y or self.x <= self.min_x:
            return -1, True
        else:  # 正常通过右侧
            assert self.y < self.max_y and self.y > self.min_y and self.x >= self.max_x
            return 0, True

    def _convert_xy2location(self, x, y, scale, margin):
        location_y = (x - self.min_x) * scale + margin
        location_x = (self.max_y - y) * scale + margin
        return [location_x, location_y]

    def render(self, mode='human'):
        screen_width = 800
        screen_height = 400
        margin = 10
        scale = screen_height / (self.max_x - self.min_x)
        ship_width = 20
        ship_height = 40

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width + margin * 2, screen_height + margin * 2)

            # 绘制河流
            self.lower_boundary = rendering.make_polyline(
                [self._convert_xy2location(self.min_x, self.min_y, scale=scale, margin=margin),
                 self._convert_xy2location(self.min_x, self.max_y, scale=scale, margin=margin)])
            self.lower_boundary.set_linewidth(4)
            self.lower_boundary.set_color(0xD3 / 255, 0xD3 / 255, 0xD3 / 255)
            self.upper_boundary = rendering.make_polyline(
                [self._convert_xy2location(self.max_x, self.min_y, scale=scale, margin=margin),
                 self._convert_xy2location(self.max_x, self.max_y, scale=scale, margin=margin)])
            self.upper_boundary.set_linewidth(8)
            self.upper_boundary.set_color(0xD3 / 255, 0xD3 / 255, 0xD3 / 255)
            self.viewer.add_geom(self.lower_boundary)
            self.viewer.add_geom(self.upper_boundary)

            # 绘制河水
            l, r, t, b = -20, 10, 2, 0
            self.water_trans = []
            for water_data in self.waters:
                water = rendering.FilledPolygon([(l, b), (l + 2, t), (r - 2, t), (r, b)])
                water.set_color(0.63529412, 0.81176471, 0.99607843)
                tran = rendering.Transform()
                self.water_trans.append(tran)
                water.add_attr(tran)
                self.viewer.add_geom(water)

            # 绘制两个goal
            # optimal goal
            self.optimal_goal = rendering.make_polyline(
                [self._convert_xy2location(self.max_x, self.optimal_y - self.optimal_range, scale=scale, margin=margin),
                 self._convert_xy2location(self.max_x, self.optimal_y + self.optimal_range, scale=scale,
                                           margin=margin)])
            self.optimal_goal.set_linewidth(8)
            self.optimal_goal.set_color(*[0.35, 0.85, 0.35])
            self.viewer.add_geom(self.optimal_goal)
            # sub optimal goal
            self.sub_optimal_goal = rendering.make_polyline(
                [self._convert_xy2location(self.max_x, self.sub_optimal_y - self.sub_optimal_range, scale=scale,
                                           margin=margin),
                 self._convert_xy2location(self.max_x, self.sub_optimal_y + self.sub_optimal_range, scale=scale,
                                           margin=margin)])
            self.sub_optimal_goal.set_linewidth(8)
            # self.sub_optimal_goal.set_color(*[0.85, 0.35, 0.35])
            self.sub_optimal_goal.set_color(*[0xfd / 255, 0xdc / 255, 0x5c / 255])
            self.viewer.add_geom(self.sub_optimal_goal)
            # 绘制船
            l, r, t, b = -ship_width / 2, ship_width / 2, ship_height, 0
            ship = rendering.FilledPolygon([(l, b), (l + 2, t), (0, t + r / 1.6), (r - 2, t), (r, b)])
            ship.set_color(*[0x49 / 255, 0x75 / 255, 0x9c / 255])
            self.ship_trans = rendering.Transform()
            ship.add_attr(self.ship_trans)
            self.viewer.add_geom(ship)

        self.ship_trans.set_translation(*self._convert_xy2location(self.x, self.y, scale=scale, margin=margin))
        self.ship_trans.set_rotation(self.angle)
        for water, water_tran in zip(self.waters, self.water_trans):
            water_tran.set_translation(*self._convert_xy2location(water.x, water.y, scale=scale, margin=margin))
        time.sleep(0.2)

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def close(self):
        pass


def test_env():
    env = ShipAcrossRiver()
    agent = lambda ob: env.action_space.sample()
    ob = env.reset()
    for _ in range(1000):
        a = agent(ob)
        env.render()
        assert env.action_space.contains(a)
        (ob, reward, done, _info) = env.step(a)
        print(ob)

        if done:
            ob = env.reset()
    env.close()


if __name__ == "__main__":
    test_env()
