from vpython import *
import numpy as np
from .global_variables import *


class FastJet:
    """Simple arcade-style jet flight model with optional VPython rendering.

    State is held as vpython vectors: ``pos``, ``vel``, ``acc``, unit ``axis``
    (forward/nose direction) and unit ``up`` (kept orthogonal to ``axis``).
    ``r_vel`` holds the pitch/roll/yaw rates in its x/y/z components, in
    degrees per step (they go through ``radians`` before each rotation).
    Tuning constants (``PITCH_POWER``, ``HZ``, ``G``, ``DRAG``, ...) come from
    ``global_variables``.
    """

    def __init__(self, np_random, arrow_length=0, show_sphere=False):
        """Create a jet at the origin, level and at rest.

        Args:
            np_random: Random generator, stored as ``self.np_random``.
            arrow_length: If > 0, render fwd/up/right arrows of this length.
            show_sphere: If True (and arrows are shown), also render a
                translucent sphere of radius ``arrow_length`` around the jet.
        """
        # A single tuple assignment: a chained `a = b, c, d = ...` would store
        # the whole tuple in self.np_random.
        self.np_random, self.arrow_length, self.show_sphere = np_random, arrow_length, show_sphere
        self.left_trail, self.right_trail, self.visual_model = None, None, None
        self.reset()

    @staticmethod
    def _vec_or_zero(v):
        """Return a fresh copy of vector ``v``, or a zero vector if ``v`` is None."""
        return vector(0., 0., 0.) if v is None else vector(v.x, v.y, v.z)

    def reset(self, pos=None, vel=None, axis=None, up=None, hdg=0., pitch=0., roll=0., r_vel=None):
        """Reset position, orientation, velocities and per-episode flags.

        Orientation comes from ``axis``/``up`` when ``axis`` is given, otherwise
        from ``hdg``/``pitch``/``roll`` (radians) via ``dir_to_vec``. If ``axis``
        is given without ``up``, the current ``up`` is kept and must be
        orthogonal to it. ``pos``, ``vel`` and ``r_vel`` default to zero vectors.

        All vectors are copied. ``step`` mutates state vectors component-wise,
        so copying keeps it from changing the caller's vectors.
        """
        if axis is None:  # orientation specified as hdg, pitch, roll
            axis, up = dir_to_vec(hdg, pitch, roll)
        self._set(self._vec_or_zero(pos), axis, up)
        # May differ from what was requested (e.g. pitch > 90 is re-expressed
        # as pitch < 90 combined with roll).
        self.hdg, self.pitch, self.roll = vec_to_dir(self.axis, self.up)
        self.vel = self._vec_or_zero(vel)
        self.r_vel = self._vec_or_zero(r_vel)
        # These variables are always initialised in the same way.
        self.acc = vector(0., 0., 0.)
        self.g = 0.
        self.alive = True
        self.ground_contact = False
        self.thrust = 0.0
        self.brake = 0.0

        if self.left_trail is not None:
            self.left_trail.clear()
            self.right_trail.clear()

    def _set(self, pos=None, axis=None, up=None):
        """Directly specify pos, axis, up independent of dynamics. Use with care!

        ``axis`` and ``up`` are normalised; arguments left as None are unchanged.
        """
        if pos is not None:  self.pos = pos
        if axis is not None: self.axis = norm(axis)
        if up is not None:   self.up = norm(up)
        self.check_orthogonal()

    def check_orthogonal(self):
        """Assert that ``axis`` and ``up`` are orthogonal (dot product close to 0)."""
        dp = dot(self.axis, self.up)
        assert np.isclose(dp, 0.0), f"Axis {self.axis} and up {self.up} non-orthogonal (dot product = {dp})"

    def step(self, action):
        """Advance the simulation by one tick.

        Args:
            action: Sequence of 4 demands (pitch, roll, yaw, thrust). Each is
                scaled by the matching ``*_POWER`` constant. While on the
                ground, a thrust demand below -0.8 engages the brake.

        Updates orientation, velocity, position, ``acc`` and ``g``. Sets
        ``alive`` to False on a hard or misaligned landing.
        """
        assert len(action) == 4

        # demands (rotation demands are per-step rate increments)
        pitch_demand = -action[0] * PITCH_POWER / HZ
        roll_demand = action[1] * ROLL_POWER / HZ
        yaw_demand = -action[2] * YAW_POWER / HZ
        thrust_demand = action[3] * THRUST_POWER
        brake_demand = 0.0

        right_axis = cross(self.axis, self.up)
        # Snapshot rather than alias, so `acc` below is a true per-step
        # difference even if self.vel is modified in place.
        prev_vel = vector(self.vel.x, self.vel.y, self.vel.z)

        # Ground contact (from the previous step): reduce control authority and
        # nudge the pitch/roll rates back towards level.
        if self.ground_contact:
            if np.abs(self.axis.y) < 0.02:
                self.r_vel.x = -self.axis.y
                pitch_demand *= 0.0  # clamps to ground - need to change for takeoff
            else:
                self.r_vel.x -= 0.1 * self.axis.y  # pitch
                pitch_demand *= 0.025

            self.r_vel.y += 0.15 * right_axis.y  # roll
            roll_demand *= 0.025
            yaw_demand *= 0.5

            if action[3] < -0.8:
                brake_demand = 0.2
                thrust_demand = -THRUST_POWER
                self.brake = brake_demand

        # rotation velocity (degrees per step)
        self.r_vel.x += pitch_demand
        self.r_vel.y += roll_demand
        self.r_vel.z += yaw_demand
        self.r_vel *= 0.95  # stability damping

        # pitch about the right axis (rotates both axis and up)
        self.axis = self.axis.rotate(angle=radians(self.r_vel.x), axis=right_axis)
        self.up = self.up.rotate(angle=radians(self.r_vel.x), axis=right_axis)
        # roll about the nose
        self.up = self.up.rotate(angle=radians(self.r_vel.y), axis=self.axis)
        # yaw about up
        self.axis = self.axis.rotate(angle=radians(self.r_vel.z), axis=self.up)

        self.check_orthogonal()

        # hdg, pitch, roll representation
        self.hdg, self.pitch, self.roll = vec_to_dir(self.axis, self.up)

        # thrust along the nose; a zero thrust action gives THRUST_POWER
        self.thrust = THRUST_POWER + thrust_demand
        self.vel += self.thrust * self.axis

        # brake opposes the direction of travel
        self.vel -= brake_demand * norm(self.vel)

        # gravity
        self.vel.y -= G

        # lift along `up`, growing linearly with speed until it reaches G at speed 2500
        vel_factor = min(mag(self.vel) / 2500.0, 1.0)
        self.vel += G * self.up * vel_factor

        # drag: DRAG is a vector whose x/y/z components act along the body's
        # forward/up/right axes respectively
        speed = mag(self.vel)
        drag = DRAG * speed

        right_axis = cross(self.axis, self.up)
        vel_norm = norm(self.vel)
        self.vel -= drag.x * self.axis * dot(vel_norm, self.axis)
        self.vel -= drag.z * right_axis * dot(vel_norm, right_axis)
        self.vel -= drag.y * self.up   * dot(vel_norm, self.up)

        # pos
        self.pos += self.vel / HZ * RENDER_SCALE

        # ground (at y = 0.5)
        self.ground_contact = self.pos.y <= 0.5
        if self.ground_contact:
            self.pos.y = 0.5

            up_diff = np.abs(diff_angle(self.up, vector(0.0, 1.0, 0.0)))
            h_vel_diff = np.abs(diff_angle(vector(self.vel.x, 0.0, self.vel.z), vector(self.axis.x, 0.0, self.axis.z)))

            # ground handling: blend 10% of the horizontal velocity towards the nose heading
            h_vel = (self.vel.x**2 + self.vel.z**2)**0.5 * vector(self.axis.x, 0.0, self.axis.z)
            self.vel = 0.1 * vector(h_vel.x, self.vel.y, h_vel.z) + 0.9 * self.vel

            # gentle touchdown stops vertical motion; too fast or misaligned is
            # a crash; otherwise bounce with 75% restitution
            if self.vel.y >= -20.0:
                self.vel.y = 0.0
            elif self.vel.y <= -60.0 or up_diff > radians(15.0) or h_vel_diff > radians(15):
                self.alive = False
            else:
                self.vel.y = -self.vel.y * 0.75

        self.acc = self.vel - prev_vel
        # proper acceleration (acc minus gravity) in units of G
        self.g = mag(self.acc + vector(0.0, G, 0.0)) / G

    def _render_first(self, scene):
        """Create the VPython objects (model, thrust glow, optional arrows) on ``scene``."""
        self.visual_model = compound(
            [
                cylinder(canvas=scene, width=0.1, height=0.1, length=0.8, pos=vector(-0.4, 0.0, 0.0), shininess=0),
                cylinder(canvas=scene, width=0.16, height=0.06, length=0.7, pos=vector(-0.34, 0.0, 0.0), shininess=0),
                pyramid(canvas=scene, width=0.3, height=0.04, length=0.24, pos=vector(-0.25, 0, 0), axis=vector(0, 0, +0.24), shininess=0),
                pyramid(canvas=scene, width=0.3, height=0.04, length=0.24, pos=vector(-0.25, 0, 0), axis=vector(0, 0, -0.24), shininess=0),
                pyramid(canvas=scene, width=0.5, height=0.04, length=0.4, pos=vector(0, 0, 0), axis=vector(0, 0, +0.4), shininess=0),
                pyramid(canvas=scene, width=0.5, height=0.04, length=0.4, pos=vector(0, 0, 0), axis=vector(0, 0, -0.4), shininess=0),
                pyramid(canvas=scene, width=0.24, height=0.02, length=0.24, pos=vector(-0.26, 0, 0), axis=vector(0, 0.2, -0.2), up=vector(0, 0.2, +0.2), shininess=0),
                pyramid(canvas=scene, width=0.24, height=0.02, length=0.24, pos=vector(-0.26, 0, 0), axis=vector(0, 0.2, +0.2), up=vector(0, 0.2, -0.2), shininess=0),
                cone(canvas=scene, width=0.1, height=0.1, length=0.2, pos=vector(+0.4, 0.0, 0.0), color=color.gray(0.4), shininess=0),
                ellipsoid(canvas=scene, width=0.035, height=0.035, length=0.3, pos=vector(0.0, -0.02, 0.3), color=color.gray(0.5), shininess=0),
                ellipsoid(canvas=scene, width=0.035, height=0.035, length=0.3, pos=vector(0.0, -0.02, 0.2), color=color.gray(0.5), shininess=0),
                ellipsoid(canvas=scene, width=0.035, height=0.035, length=0.3, pos=vector(0.0, -0.02, -0.3), color=color.gray(0.5), shininess=0),
                ellipsoid(canvas=scene, width=0.035, height=0.035, length=0.3, pos=vector(0.0, -0.02, -0.2), color=color.gray(0.5), shininess=0),
                ellipsoid(canvas=scene, width=0.05, height=0.07, length=0.2, pos=vector(0.4, 0.022, 0.0), color=color.black)
            ],
            texture='granite_texture.jpg',
            canvas=scene
        )
        self.thrust_sphere = sphere(canvas=scene, pos=vector(-0.4, 0.0, 0.0), radius=0.025, opacity=0.4, shininess=0, emissive=True, color=color.orange)

        # NOTE: Trails cause a problem when have multiple VPython canvases open, which happens when using FastJetInterface.
        # def left_wing():
        #     return self.visual_model.pos - (0.1 * self.visual_model.axis) - (self.up * 0.06) + (0.39 * cross(self.visual_model.axis, self.visual_model.up))
        # def right_wing():
        #     return self.visual_model.pos - (0.1 * self.visual_model.axis) - (self.up * 0.06) - (0.39 * cross(self.visual_model.axis, self.visual_model.up))
        # self.left_trail = attach_trail(left_wing, radius=0.005, color=color.white, retain=200)
        # self.right_trail = attach_trail(right_wing, radius=0.005, color=color.white, retain=200)

        if self.arrow_length > 0:
            self.arrow_fwd = arrow(canvas=scene, color=color.red)    # fwd   = red
            self.arrow_up = arrow(canvas=scene, color=color.green)   # up    = green
            self.arrow_right = arrow(canvas=scene, color=color.blue) # right = blue
            if self.show_sphere:
                self.sphere = sphere(canvas=scene, radius=self.arrow_length, opacity=0.25, shininess=0, emissive=True, color=color.white)

    def render(self, scene):
        """Sync the VPython objects with the current state, creating them on first use."""
        if self.visual_model is None:
            self._render_first(scene)

        self.visual_model.pos = self.pos + self.up * 0.06
        self.visual_model.axis = self.axis
        self.visual_model.up = self.up

        self.thrust_sphere.pos = self.pos - self.axis * 0.49
        # thrust lies in [0, 2*THRUST_POWER] for actions in [-1, 1]; clamp so
        # out-of-range actions still give a valid opacity
        opacity = 0.75 * (self.thrust / (2.0 * THRUST_POWER))
        self.thrust_sphere.opacity = min(max(opacity, 0.0), 1.0)

        if self.arrow_length > 0:
            self.arrow_fwd.pos = self.pos
            self.arrow_fwd.axis = self.axis
            self.arrow_up.pos = self.pos
            self.arrow_up.axis = self.up
            self.arrow_right.pos = self.pos
            self.arrow_right.axis = cross(self.axis, self.up)
            self.arrow_fwd.length = self.arrow_up.length = self.arrow_right.length = self.arrow_length
            if self.show_sphere: self.sphere.pos = self.pos

    def unrender(self):
        """Hide and drop the VPython objects.

        Attributes are reset to None, not deleted. A later ``render`` then
        recreates them, and repeated ``unrender`` calls are no-ops.
        """
        if self.visual_model is None:
            return
        self.visual_model.visible = False
        self.visual_model = None
        self.thrust_sphere.visible = False
        self.thrust_sphere = None
        if self.arrow_length > 0:
            self.arrow_fwd.visible = False
            self.arrow_fwd = None
            self.arrow_up.visible = False
            self.arrow_up = None
            self.arrow_right.visible = False
            self.arrow_right = None
            if self.show_sphere:
                self.sphere.visible = False
                self.sphere = None

# radians
def dir_to_vec(hdg, pitch, roll):
    """Build unit forward and up vectors from heading, pitch and roll (radians).

    Heading turns the nose in the x-z plane, from +x towards +z. Pitch tilts it
    towards +y. Roll turns ``up`` about the forward vector. The two vectors
    returned are orthogonal. ``vec_to_dir`` performs the reverse conversion.

    Returns:
        ``(fwd, up)`` as vpython vectors.
    """
    cos_pitch = cos(pitch)
    fwd = vector(cos(hdg) * cos_pitch, sin(pitch), cos_pitch * sin(hdg))

    s_r, c_r = np.sin(-roll), np.cos(-roll)
    s_h, c_h = np.sin(hdg), np.cos(hdg)
    s_p = np.sin(pitch)
    up = vector(
        s_r * s_h - c_r * s_p * c_h,
        c_r * np.cos(pitch),
        -s_r * c_h - c_r * s_p * s_h,
    )
    return fwd, up
    
# radians
def vec_to_dir(fwd, up):
    """Convert unit, orthogonal forward/up vectors into ``(hdg, pitch, roll)`` in radians.

    Reverse of ``dir_to_vec``. The ranges are hdg in [-pi, pi], pitch in
    [-pi/2, pi/2] and roll in [-pi, pi]. ``fwd`` and ``up`` only need ``.x``,
    ``.y`` and ``.z`` attributes. When the nose points straight up or down,
    the zero-roll reference collapses and roll is not meaningful.
    """
    hdg = np.arctan2(fwd.z, fwd.x)
    # fwd.y can exceed 1 by an ulp after normalisation; arcsin would then give nan.
    pitch = np.arcsin(np.clip(fwd.y, -1.0, 1.0))

    f = np.array([fwd.x, fwd.y, fwd.z])
    u = np.array([up.x, up.y, up.z])
    # Zero-roll reference frame: roll_0 is horizontal and perpendicular to fwd,
    # and up_0 completes it. Both are scaled by cos(pitch), which cancels in arctan2.
    roll_0 = np.array([-fwd.z, 0.0, fwd.x])
    up_0 = np.cross(roll_0, f)
    roll = np.arctan2(np.dot(roll_0, u), np.dot(up_0, u))
    return hdg, pitch, roll