from vpython import *
from numpy import array, float32


# Simulation-rate and rendering settings.
HZ = 50.0            # update frequency (presumably simulation steps per second — TODO confirm)
RENDER_SCALE = 0.01  # scale factor for rendering (how it is applied is not visible here)
WINDOW_W = 1600      # render window width (presumably pixels)
WINDOW_H = 800       # render window height (presumably pixels)

# Lookup table from a discrete action index to a 4-component control command.
# Columns are [pitch, roll, yaw, thrust]; each action nudges at most one
# channel by -1 or +1 (action 0 leaves every channel at 0).
DISCRETE_ACTION_MAP = array([
    [ 0,  0,  0,  0], # 0: no-op
    [-1,  0,  0,  0], # 1: pitch up
    [+1,  0,  0,  0], # 2: pitch down
    [ 0, -1,  0,  0], # 3: roll left
    [ 0, +1,  0,  0], # 4: roll right
    [ 0,  0, -1,  0], # 5: yaw left
    [ 0,  0, +1,  0], # 6: yaw right
    [ 0,  0,  0, -1], # 7: thrust low
    [ 0,  0,  0, +1], # 8: thrust high
])

# Position bounds (used as the pos rows of DEFAULT_OBS_LIMS). Each axis spans
# half of a 200-unit extent; the Y range is shifted up so it starts at Y_MIN.
_HALF_EXTENT = 200.0 / 2
X_MAX = _HALF_EXTENT
Z_MAX = _HALF_EXTENT
Y_MIN = 100.0
Y_MAX = Y_MIN + _HALF_EXTENT

# Control power settings as (pitch, roll, yaw, thrust) tuples. Earlier versions
# are kept here for reference; switch the key below to reproduce them.
_POWER_PRESETS = {
    "v0": (2.0, 15.0, 0.5, 15.0),  # Version sent 21/10/21
    "v1": (2.0, 4.0, 0.5, 25.0),   # Version sent 28/10/21
    "v2": (4.0, 12.0, 1.0, 25.0),  # High power
    "v3": (2.0, 10.0, 0.5, 25.0),  # Low power
}
PITCH_POWER, ROLL_POWER, YAW_POWER, THRUST_POWER = _POWER_PRESETS["v2"]

# Bounds of the starting speed (presumably sampled per episode — TODO confirm).
# V_LIM, their midpoint, is the per-component velocity bound in DEFAULT_OBS_LIMS.
START_SPEED_MIN, START_SPEED_MAX = 2000.0, 3000.0
V_LIM = 0.5 * (START_SPEED_MIN + START_SPEED_MAX)

G = 9.81           # gravitational acceleration (units presumably m/s^2 — TODO confirm)
START_RVEL = 0.0   # starting rotational velocity (meaning inferred from the name)

# Per-axis drag coefficients as a vpython vector (how they are applied is not
# visible in this module).
DRAG = vector(
    0.01,  # x
    0.1,   # y
    0.04,  # z
)

def DEFAULT_OBS(env):
    """Return the default observation for ``env`` as a float32 array of 37 values.

    Layout:
      * env.jets[0]: pos, vel, acc, r_vel, axis, up (x, y, z each), then thrust
        -> 19 values.
      * env.jets[1] (the reference jet): pos, vel, acc, r_vel, axis, up
        (x, y, z each) -> 18 values; its thrust is not included.

    The order must stay aligned with the rows of DEFAULT_OBS_LIMS.
    """
    # NOTE: Cleaner to have reference jet dims as separate.
    vector_fields = ("pos", "vel", "acc", "r_vel", "axis", "up")

    def components(jet):
        # Flatten each listed vector attribute of ``jet`` into its x, y, z values.
        return [
            getattr(getattr(jet, field), comp)
            for field in vector_fields
            for comp in "xyz"
        ]

    own = env.jets[0]
    features = components(own)
    features.append(own.thrust)
    features.extend(components(env.jets[1]))
    return array(features, dtype=float32)

# NOTE: Calibrated using scripts/calibrate_obs_lims.py
# [low, high] bounds for each entry of DEFAULT_OBS, in the same order.
# The 18 per-jet rows cover pos, vel, acc, r_vel, axis, up (x, y, z each)
# and are shared by both jets.
_JET_STATE_LIMS = (
    [[0.0, X_MAX], [Y_MIN, Y_MAX], [0.0, Z_MAX]]  # pos
    + [[-V_LIM, V_LIM]] * 3                       # vel
    + [[-30.0, 30.0]] * 3                         # acc
    + [
        [-PITCH_POWER/4, PITCH_POWER/4],          # r_vel.x
        [-ROLL_POWER/4, ROLL_POWER/4],            # r_vel.y
        [-YAW_POWER/4, YAW_POWER/4],              # r_vel.z
    ]
    + [[-1.0, 1.0]] * 6                           # axis, up
)
DEFAULT_OBS_LIMS = array(
    _JET_STATE_LIMS                   # env.jets[0] vector state
    + [[0.0, 2*THRUST_POWER]]         # env.jets[0].thrust
    + _JET_STATE_LIMS                 # env.jets[1] vector state (reference jet)
)
