import gymnasium as gym
import numpy as np
import time
import airsim
from typing import SupportsFloat, Any

import optax
from gymnasium.core import ActType, ObsType
from typing import List, Optional
from itertools import product


# Fixed reference points (x, y, z) as float64 arrays, expressed in the same
# units as ``goal`` below (the environment divides the active goal by 100).
# Within this module only ``g1`` is used, as the initial ``current_goal``.
_LEGACY_GOAL_POINTS = (
    (11868.188, 24324.245, -500),
    (26665.359, 17007.539, -500),
    (30569.709, 1199.599, -500),
    (27045.309, -14896.644, -500),
    (-32063.947, -4752.712, -500),
    (-28793.303, -26310.730, -500),
    (-15940.754, -25149.439, -500),
    (1443.634, 12289.035, -500),
    (15910.926, -1908.472, -500),
)
g1, g2, g3, g4, g5, g6, g7, g8, g9 = (np.asarray(point) for point in _LEGACY_GOAL_POINTS)

# Candidate (x, y) spots; AirSimEnv draws both its spawn points and its goals
# from ``goal``.  ``goal`` scales each spot by 100 and appends a fixed z of
# -10 (also scaled by 100), giving a float64 array of shape (len(spot_list), 3).
spot_list = [
    (15, 140), (143, 0), (140, -50),
    (140, -80), (140, -115), (110, -140),
    (65, -145), (45, -140), (15, -140),
    (-30, -140), (-100, -140), (-140, -140),
    (-140, -105), (-140, -70), (-140, -30),
    (-150, 15), (-140, 50), (-145, 90),
    (-140, 120), (-100, 155), (-10, 140),
    (140, 60), (140, 20), (105, -115),
    (45, -115), (-65, -115), (-80, -115),
    (-105, -115), (-115, -55), (-110, 60),
    (-110, 77), (-115, 100), (-70, 110),
    (80, 105), (110, 120), (115, 70),
    (113, 10), (110, -10), (60, -15),
    (-10, -20), (10, -48), (-65, -10),
    (-79, -8), (0, 0),
]  # (-20, -46) is disabled.

spot = np.array(spot_list)
goal = np.column_stack((spot, np.full(len(spot), -10.0))) * 100


class Wind(object):
    """Smoothed random wind, stored as a speed and a spherical direction.

    ``temp_speed`` is a length-1 array clipped to ``[0, max_speed]``;
    ``direction`` is ``[theta, phi]`` with polar angle ``theta`` kept in
    ``[0, pi]`` (by reflection) and azimuth ``phi`` kept in ``[0, 2 * pi]``
    (by wrapping).  Every :meth:`change_wind` adds a random increment blended
    with the previous increment (``smooth`` is the weight of the previous one),
    so the wind drifts instead of jumping.

    Args:
        max_speed: Upper bound of the wind speed.
        dynamics: Scale of the random increments (fraction of ``max_speed``
            for the speed, fraction of ``pi`` for the angles).
        smooth: Weight of the previous increment, expected in ``[0, 1]``.
        seed: Seed of ``np_rng``; every random draw of the model uses it.
    """

    def __init__(self, max_speed, dynamics, smooth, seed: int = 42):
        self.max_speed = max_speed
        self.smooth = smooth
        self.dynamics = dynamics
        # Create the generator first so that the initial direction is also
        # reproducible from ``seed``.
        self.np_rng = np.random.default_rng(seed)
        self.direction = self._random_direction()
        self.tendency_speed = 0
        # One smoothed increment per angle (theta, phi); the shape must match
        # the (2,) increments drawn in ``change_wind``.
        self.tendency_direction = np.zeros(2, )
        self.temp_speed = self.np_rng.uniform(0, 1, (1,)) * max_speed

    def _random_direction(self):
        """Draw ``[theta, phi]`` uniformly from ``[0, pi) x [0, 2 * pi)``."""
        return self.np_rng.random(2) * np.array([np.pi, 2 * np.pi])

    def reset(self):
        """Re-draw direction, speed and angular tendency; zero the speed tendency."""
        self.direction = self._random_direction()
        self.tendency_speed = 0
        self.tendency_direction = self.np_rng.uniform(0, 1, (2,)) * np.pi
        self.temp_speed = self.np_rng.uniform(0, 1, (1,)) * self.max_speed

    def change_wind(self):
        """Advance speed and direction by one smoothed random increment."""
        # NOTE(review): the random parts are drawn from uniform(0, 1), so the
        # speed increments are never negative (the speed climbs until it is
        # clipped at max_speed) and the angles are always pushed in the positive
        # direction before reflection/wrapping. Confirm this is intended; a
        # symmetric draw would give a two-sided random walk.
        change_speed = self.np_rng.uniform(0, 1, size=(1,))[0] * self.max_speed * self.dynamics * (
                1 - self.smooth) + self.tendency_speed * self.smooth
        self.tendency_speed = change_speed
        self.temp_speed = np.clip(change_speed + self.temp_speed, a_max=self.max_speed, a_min=0)
        change_direction = self.np_rng.uniform(0, 1, (2,)) * self.dynamics * np.pi * (
                1 - self.smooth) + self.tendency_direction * self.smooth
        self.tendency_direction = change_direction
        self.direction[0] = self.angle_add(np.pi, self.direction[0], change_direction[0], mode='theta')
        self.direction[1] = self.angle_add(2 * np.pi, self.direction[1], change_direction[1], mode='phi')

    @staticmethod
    def angle_add(a_max, angle1, angle2, mode='theta'):
        """Add two angles and bring the sum back into ``[0, a_max]``.

        Args:
            a_max: Upper bound of the valid range.
            angle1, angle2: Angles to add.
            mode: ``'theta'`` reflects off the bounds 0 and ``a_max`` (polar
                angle); any other value wraps around modulo ``a_max`` (azimuth).

        Returns:
            The folded angle.
        """
        angle = angle1 + angle2
        if 0 <= angle <= a_max:
            return angle
        if mode == 'theta':
            # Reflection has period 2 * a_max: fold into it, then mirror the
            # upper half back into [0, a_max].  A negative sum reflects off 0.
            angle = angle % (2 * a_max)
            if angle > a_max:
                angle = 2 * a_max - angle
            return angle
        return angle % a_max

    def get_wind(self):
        """Return the current wind vector as a list of three Python floats."""
        w = self.wind
        return [w[0].item(), w[1].item(), w[2].item()]

    def step(self):
        """Advance the wind once and return the new wind vector (shape ``(3, 1)``)."""
        self.change_wind()
        return self.wind.copy()

    @property
    def wind(self):
        """Cartesian wind ``[x, y, z]`` from ``temp_speed`` and ``(theta, phi)``."""
        theta, phi = self.direction
        x = self.temp_speed * np.sin(theta) * np.cos(phi)
        y = self.temp_speed * np.sin(theta) * np.sin(phi)
        # The vertical component depends on the polar angle theta, so the
        # vector's norm equals the wind speed.
        z = self.temp_speed * np.cos(theta)
        return np.array([x, y, z])


class AirSimEnv(gym.Env):
    """Gymnasium environment for goal-reaching flight with one AirSim multirotor.

    Observation (float32, length ``28 + num_grids ** 2``): the 28 values of
    :attr:`context` followed by the flattened lidar grid of :attr:`lidar_data`.

    Action (Box in ``[-1, 1]``):
        * ``hard=False`` (3 values): scaled by 10 and added to the current linear
          velocity, then sent with ``moveByVelocityAsync`` for duration 0.1.
        * ``hard=True`` (4 values): mapped to ``[0, 1]`` and sent as motor PWMs
          for duration 0.1.

    The simulator is paused after every reset/step, so observations are read
    from a frozen world.
    """

    def __init__(self,
                 ip="127.0.0.1",
                 drone_id='Drone1',
                 verbose: bool = True,
                 scale=2,
                 num_grids: int = 12,
                 terminal_if_success: bool = False,
                 hard: bool = False,
                 seed: int = 10003):
        """Connect to AirSim, arm and take off, then set up spaces and rewards.

        Args:
            ip: Address of the AirSim RPC server.
            drone_id: Name of the controlled vehicle.
            verbose: Print resets/successes and draw debug arrows every step.
            scale: Stored as ``self.scale``; not used elsewhere in this class.
            num_grids: Number of angular bins per axis of the lidar grid.
            terminal_if_success: End the episode when the goal is reached
                instead of sampling a new goal.
            hard: Control motor PWMs (4-dim action) instead of velocity
                increments (3-dim action).
            seed: Seed of the environment random generator.
        """
        np_rng = np.random.default_rng(seed)
        self.drone_id = drone_id
        self.verbose = verbose
        self.num_grids = num_grids

        self.client = airsim.MultirotorClient(ip=ip)
        self.client.confirmConnection()
        # Target the configured vehicle, like every other per-vehicle call here.
        self.client.enableApiControl(True, vehicle_name=self.drone_id)
        self.client.simEnableWeather(True)
        if self.client.simIsPause():
            self.client.simPause(False)

        self.client.armDisarm(True, vehicle_name=self.drone_id)
        self.client.takeoffAsync(vehicle_name=self.drone_id).join()

        self.terminal_if_success = terminal_if_success
        self.hard = hard
        self.substep = 0
        self.success_step = [0]

        self.goal_list = goal

        self.current_goal = g1

        self.image_number = 63385
        self.obstacles_info = None
        self.scale = scale

        self.temp_time = time.time()

        self.timesteps = 0
        self.sensor_shape = (1, 20)
        self.np_rng = np_rng
        # 28 context values (see ``context``) + the flattened lidar grid.
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(28 + num_grids ** 2,),
                                                dtype=np.float32)
        action_dim = 4 if self.hard else 3
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(action_dim,), dtype=np.float32)
        self.height_constrain = 20
        self.collision_reward = -100
        self.alive_reward = 0.1
        self.collision_count = 0
        self.wind = Wind(0.95, 0.2, smooth=0.99)
        self.pause_client = False

    def process_point_cloud(self,
                            point_cloud: List,
                            num_grids: int,
                            scaling: float = 100.):
        """Project a point cloud onto a (polar, azimuth) grid of minimum log-ranges.

        Args:
            point_cloud: Flat or ``(N, 3)`` sequence of x, y, z coordinates.
                Non-finite points and an incomplete trailing point (fewer than
                three values) are ignored.
            num_grids: Number of bins for both the polar angle (from +z) and the
                azimuth ``atan2(y, x)``.
            scaling: Range assumed for cells that receive no point; such cells
                hold ``log(scaling)``.

        Returns:
            Flattened ``(num_grids * num_grids,)`` array whose cell ``(i, j)`` is
            the log of the smallest range among points whose polar angle rounds
            to bin ``i`` and whose azimuth rounds to bin ``j``.
        """
        pt = np.asarray(point_cloud, dtype=np.float64).ravel()
        # Drop an incomplete trailing point instead of failing in reshape.
        pt = pt[:pt.size - pt.size % 3].reshape(-1, 3)
        pt = pt[np.isfinite(pt).all(axis=1)]
        r = np.linalg.norm(pt, axis=-1).clip(1e-5)

        # Clip guards arccos against rounding just outside [-1, 1].
        theta = np.arccos(np.clip(pt[:, 2] / r, -1.0, 1.0))
        phi = np.arctan2(pt[:, 1], pt[:, 0])
        # theta in [0, pi] and phi in [-pi, pi] both map onto [0, num_grids - 1].
        int_theta = np.round(theta * (num_grids - 1) / np.pi).astype(np.intp)
        int_phi = np.round((phi + np.pi) / (2 * np.pi) * (num_grids - 1)).astype(np.intp)

        # Per-cell minimum in one vectorized pass; +inf marks empty cells.
        grid = np.full((num_grids, num_grids), np.inf)
        np.minimum.at(grid, (int_theta, int_phi), np.log(r))
        grid[grid == np.inf] = np.log(scaling)
        return grid.flatten()

    @property
    def position(self) -> np.ndarray:
        """Ground-truth position ``[x, y, z]`` of the drone."""
        pos = self.client.simGetGroundTruthKinematics(self.drone_id)
        xyz = np.asarray([pos.position.x_val, pos.position.y_val, pos.position.z_val], dtype=np.float64)
        return xyz

    @property
    def velocity(self) -> np.ndarray:
        """Ground-truth linear then angular velocity (6 values)."""
        pos = self.client.simGetGroundTruthKinematics(self.drone_id)
        xyz_xyz = np.asarray([pos.linear_velocity.x_val, pos.linear_velocity.y_val, pos.linear_velocity.z_val,
                              pos.angular_velocity.x_val, pos.angular_velocity.y_val, pos.angular_velocity.z_val],
                             dtype=np.float64)
        return xyz_xyz

    @property
    def force(self) -> np.ndarray:
        """Ground-truth linear then angular acceleration (6 values)."""
        kinematics = self.client.simGetGroundTruthKinematics(self.drone_id)
        xyz_force = kinematics.linear_acceleration
        torque = kinematics.angular_acceleration
        force = np.asarray(
            [xyz_force.x_val, xyz_force.y_val, xyz_force.z_val, torque.x_val, torque.y_val, torque.z_val],
            dtype=np.float64)
        return force

    @property
    def orientation(self) -> np.ndarray:
        """Ground-truth orientation quaternion ``[w, x, y, z]``."""
        pos = self.client.simGetGroundTruthKinematics(self.drone_id)
        wxyz = np.asarray([pos.orientation.w_val, pos.orientation.x_val, pos.orientation.y_val, pos.orientation.z_val],
                          dtype=np.float64)
        return wxyz

    @staticmethod
    def vector_process(vector, pos_Info):
        """Return the distance to ``vector`` and its (phi, theta) bin indices.

        Args:
            vector: Target point ``[x, y, z]``.
            pos_Info: Kinematics-like object exposing ``position.{x,y,z}_val``.

        Returns:
            ``(distance, phi_bin, theta_bin)`` with ``phi_bin`` in ``0..18`` and
            ``theta_bin`` in ``0..19``; both bins are 0 when the angles are
            undefined (e.g. zero distance).
        """
        vector = vector - np.array([pos_Info.position.x_val, pos_Info.position.y_val, pos_Info.position.z_val])
        distance = np.linalg.norm(vector, 2)
        vector = vector / distance
        theta = np.arccos(vector[2])
        if vector[0] > 0:
            phi = np.arctan(vector[1] / vector[0]) + np.pi / 2
        elif vector[0] < 0:
            phi = np.arctan(vector[1] / vector[0]) + 3 * np.pi / 2
        else:
            if vector[1] < 0:
                phi = np.pi
            else:
                phi = 0
        location_theta = (theta / np.pi - 1e-7) * 20
        location_phi = (phi / (2 * np.pi) - 1e-7) * 19

        try:
            return distance, int(location_phi), int(location_theta)
        except (ValueError, OverflowError):
            # NaN/inf angles cannot be converted to bin indices.
            return distance, 0, 0

    @staticmethod
    def angle_add(a_max, angle1, angle2, mode='theta'):
        """Add two angles and bring the sum back into ``[0, a_max]``.

        ``mode='theta'`` reflects off the bounds 0 and ``a_max``; any other mode
        wraps around modulo ``a_max``.
        """
        angle = angle1 + angle2
        if 0 <= angle <= a_max:
            return angle
        if mode == 'theta':
            # Reflection has period 2 * a_max; a negative sum reflects off 0.
            angle = angle % (2 * a_max)
            if angle > a_max:
                angle = 2 * a_max - angle
            return angle
        return angle % a_max

    @property
    def goal_pos(self):
        """Active goal divided by 100, i.e. in the frame of :attr:`position`."""
        return self.current_goal / 100

    @property
    def context(self) -> np.ndarray:
        """Return the 28 non-lidar observation values.

        Layout: ``(goal_pos - position) / 100`` (3), linear and angular velocity
        (6), orientation ``w, x, y, z`` (4), goal and drone positions divided by
        100 (6), wind (3), linear and angular acceleration (6).
        """
        goal_pos = self.goal_pos
        position = self.position
        positions = np.concatenate([goal_pos, position], axis=-1) / 100
        linear_direction = (goal_pos - position) / 100
        return np.concatenate([linear_direction, self.velocity, self.orientation, positions,
                               self.wind.get_wind(), self.force], axis=-1)

    @property
    def lidar_data(self):
        """Log-range grid (flattened) of the drone's ``'LidarSensor'`` point cloud."""
        lidar = self.client.getLidarData('LidarSensor', self.drone_id)
        # process_point_cloud tolerates frames with fewer than three values.
        return self.process_point_cloud(lidar.point_cloud, self.num_grids, 100)

    @property
    def current_observation(self) -> np.ndarray:
        """Full float32 observation: :attr:`context` followed by :attr:`lidar_data`."""
        return np.concatenate([self.context, self.lidar_data], axis=-1).astype(np.float32)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the simulator, respawn the drone at a random spot and pick a goal.

        Args:
            seed: If given, re-seeds both the environment and the wind generators.
            options: Mapping of attribute names to values, each set on the
                environment before resetting.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        if seed is not None:
            self.np_rng = np.random.default_rng(seed)
            self.wind.np_rng = np.random.default_rng(seed)

        if options is not None:
            # Iterating a dict yields only keys; key/value pairs need .items().
            for k, v in options.items():
                setattr(self, k, v)
        if self.client.simIsPause():
            self.client.simPause(False)

        self.client.reset()
        self.wind.reset()
        self.client.armDisarm(True, self.drone_id)
        self.client.enableApiControl(True, vehicle_name=self.drone_id)
        self.client.simEnableWeather(True)
        # Spawn above a randomly chosen spot at z = -5.
        spawn = self.goal_list[self.np_rng.integers(0, len(self.goal_list))] / 100
        self.client.simSetVehiclePose(
            pose=airsim.Pose(airsim.Vector3r(spawn[0], spawn[1], -5)),
            ignore_collision=True,
            vehicle_name=self.drone_id)
        # The goal depends on the new position, so pick it after the move.
        self.set_current_goal()
        if self.verbose:
            print("RESET!")
        info = {}
        self.collision_count = 0
        self.client.simPause(True)

        observation = self.current_observation
        return observation, info

    def set_current_goal(self):
        """Pick a random goal farther than 100 from the drone and make it active.

        If no goal is that far, the farthest goal(s) are used instead.

        Returns:
            A copy of the new ``current_goal``.
        """
        dists = np.linalg.norm(self.goal_list / 100 - self.position, axis=-1)
        masks = dists > 100
        if not masks.any():
            # Avoid 0/0 probabilities (which make ``choice`` raise).
            masks = dists == dists.max()
        probs = masks / masks.sum()
        index = self.np_rng.choice(np.arange(len(self.goal_list)), p=probs)
        self.current_goal = self.goal_list[index]
        return self.current_goal.copy()

    @property
    def collision(self):
        """``(has_collided, collision_info)`` as reported by AirSim."""
        collision_info = self.client.simGetCollisionInfo(self.drone_id)

        return collision_info.has_collided, collision_info

    def step(
            self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Apply one action (duration 0.1) under a freshly stepped wind.

        Reward:
            * collision: ``collision_reward`` and the episode terminates;
            * otherwise the cosine between velocity and goal direction, or 100
              when the xy L1 distance to the goal is below 12 (then a new goal
              is sampled unless ``terminal_if_success``, which ends the episode);
            * when ``|z|`` exceeds ``height_constrain`` the reward is replaced by
              minus the excess, and the episode terminates when ``|z| > 100``;
            * ``alive_reward`` is added to every non-terminal-success,
              non-collision step.

        Returns:
            ``(observation, reward, terminated, truncated, info)``; ``truncated``
            is always ``False`` and ``info`` has ``collision`` and ``is_success``.
        """
        if self.client.simIsPause():
            self.client.simPause(False)

        self.wind.step()
        w_x, w_y, w_z = self.wind.get_wind()
        # Accept lists/tuples as well as arrays (list * 10 would repeat the list).
        action = np.asarray(action, dtype=np.float64)
        info = {'collision': False, 'is_success': False}
        done = False
        self.client.simSetWind(airsim.Vector3r(w_x, w_y, w_z))
        if self.hard:
            # Map [-1, 1] onto the [0, 1] PWM range of each motor.
            pwm = (action + 1) / 2
            self.client.moveByMotorPWMsAsync(float(pwm[0]), float(pwm[1]), float(pwm[2]), float(pwm[3]),
                                             0.1, self.drone_id).join()
        else:
            # The action is a velocity increment on top of the current velocity.
            delta = action * 10
            v_x, v_y, v_z = self.velocity[:3]
            self.client.moveByVelocityAsync(float(v_x + delta[0]), float(v_y + delta[1]), float(v_z + delta[2]),
                                            0.1, vehicle_name=self.drone_id).join()

        collision, _ = self.collision
        self.client.simPause(True)
        if collision:
            return self.current_observation, self.collision_reward, True, False, {"collision": True,
                                                                                  'is_success': False}

        current_vel = self.velocity[:3]
        goal_pos = self.goal_pos
        current_pos = self.position
        direction = goal_pos - current_pos
        xy_distance = np.abs(direction[:2]).sum()

        # Cosine similarity between velocity and goal direction (0 if undefined).
        norm_product = (direction ** 2).sum() * (current_vel ** 2).sum()
        reward = np.dot(direction, current_vel) / np.sqrt(norm_product) if norm_product > 0 else 0

        if xy_distance < 12:
            reward = 100
            if self.verbose:
                print("Success.")
            info['is_success'] = True
            if self.terminal_if_success:
                return self.current_observation, reward, True, False, info
            self.set_current_goal()

        # The simulator is paused, so current_pos is still the drone position.
        altitude = np.abs(current_pos[-1])
        if altitude > self.height_constrain:
            reward = -(altitude - self.height_constrain)
            if altitude > 100:
                done = True

        if self.verbose:
            self._plot_debug_arrows(current_pos, goal_pos, (w_x, w_y, w_z))

        reward += self.alive_reward
        return self.current_observation, reward, done, False, info

    def _plot_debug_arrows(self, pos, goal_pos, wind):
        """Draw wind, velocity and unit goal-direction arrows starting at ``pos``."""
        w_x, w_y, w_z = wind
        start = airsim.Vector3r(float(pos[0]), float(pos[1]), float(pos[2]))
        end = airsim.Vector3r(float(pos[0] + w_x), float(pos[1] + w_y), float(pos[2] + w_z))
        self.client.simPlotArrows([start], [end], duration=0.25,
                                  color_rgba=[0.25, 0.25, 0.5, 0.8], arrow_size=5., thickness=2, )

        v_x, v_y, v_z = self.velocity[:3]
        end_v = airsim.Vector3r(float(pos[0] + v_x), float(pos[1] + v_y), float(pos[2] + v_z))
        self.client.simPlotArrows([start], [end_v], duration=0.25,
                                  color_rgba=[218 / 255, 112 / 255, 214 / 255, 1], arrow_size=5., thickness=2)

        dr = goal_pos - pos
        dr = dr / (np.linalg.norm(dr) + 1e-4) + pos
        end_g = airsim.Vector3r(float(dr[0]), float(dr[1]), float(dr[2]))
        self.client.simPlotArrows([start], [end_g], duration=0.25,
                                  color_rgba=[0, 0, 0, 1], arrow_size=5., thickness=2, )

