from org.hipparchus.geometry.euclidean.threed import Vector3D
from org.hipparchus.geometry import Vector
from org.orekit.propagation.numerical import NumericalPropagator
from org.hipparchus.ode.nonstiff import DormandPrince853Integrator
from org.orekit.forces.gravity import NewtonianAttraction
from org.orekit.orbits import KeplerianOrbit, OrbitType, PositionAngleType
from org.orekit.propagation import SpacecraftState
from org.orekit.time import AbsoluteDate

from constants import EARTH_RADIUS, MU, ATTITUDE, INERTIAL_FRAME
from orekit import JArray_double
import numpy as np
from math import radians
from bodies import Satellite
import sys
import pygame
from play3d.models import Grid, CircleChain, Plot
from interface.interface_models import Earth, Trail
from interface.pygame_utils import handle_camera_with_keys
from play3d.three_d import Device, Camera
import play3d.three_d as three_d
from play3d.matrix import Matrix

class Interface:
    """Real-time pygame/play3d view of the simulated bodies.

    The constructor opens a 1024x768 window titled "OrbitZoo" and builds the
    static scene (optional Earth, equator grid and reference orbits) plus one
    marker per body.  ``frame`` redraws the whole scene once per call.

    Positions read from Orekit are multiplied by ``zoom / EARTH_RADIUS`` and
    have their Y and Z components swapped before being handed to play3d
    (presumably because the renderer treats Y as "up" -- confirm against
    play3d).

    Keys read from ``params``: ``"zoom"`` (optional, defaults to 1.0),
    ``"earth"``, ``"equator_grid"``, ``"orbits"`` (optional),
    ``"satellites"``, ``"drifters"``, ``"timestamp"`` and ``"delay_ms"``.
    """

    def __init__(self, bodies, params):
        """Create the window, the renderer hooks and every drawable.

        Args:
            bodies: sequence of simulation bodies.  Each one is read through
                ``current_state``, ``current_date``, ``name``, ``color``,
                ``has_thrust`` and (for satellites) ``thrust``.
            params: interface configuration mapping (see the class docstring).
        """
        self.bodies = bodies
        self.params = params

        scale = params.get("zoom", 1.0)
        self.distance_scale = scale / EARTH_RADIUS

        Device.viewport(1024, 768)
        pygame.init()
        pygame.display.set_caption("OrbitZoo")
        # Forward slashes avoid the invalid "\i" / "\l" escape sequences of the
        # former Windows-style literal and resolve on every platform.
        pygame.display.set_icon(pygame.image.load("./interface/logo_orbitzoo.png"))
        self.screen = pygame.display.set_mode(Device.get_resolution())
        self.font = pygame.font.SysFont("Consolas", 12)

        def line_adapter(p1, p2, color):
            # Only the first two components of each point are drawn.
            return pygame.draw.line(self.screen, color, (p1[0], p1[1]), (p2[0], p2[1]), 1)

        def put_pixel(x, y, color):
            return pygame.draw.circle(self.screen, color, (x, y), 1)

        Device.set_renderer(put_pixel, line_renderer=line_adapter)

        if self.params["earth"]["show"]:
            self.earth = Earth(
                resolution=self.params["earth"]["resolution"],
                position=(0, 0, 0),
                color=self.params["earth"]["color"],
                scale=scale,
            )
        if self.params["equator_grid"]["show"]:
            side = 6 * int(scale)
            self.grid = Grid(
                color=self.params["equator_grid"]["color"],
                dimensions=(side, side),
                dot_precision=self.params["equator_grid"]["resolution"],
            )
        if "orbits" in params:
            # propagate each orbit a full period to get the trail
            self.orbits = [self._build_orbit_trail(orbit) for orbit in self.params["orbits"]]

        self.spheres = []
        for body in self.bodies:
            pos = body.current_state.getPVCoordinates().getPosition()
            pos = pos.scalarMultiply(float(self.distance_scale))
            group = "satellites" if isinstance(body, Satellite) else "drifters"
            color = params[group]["color_body"] if body.color is None else body.color
            self.spheres.append(
                CircleChain(position=tuple(self._swap_yz(pos)), color=color, scale=scale / 300)
            )

        camera = Camera.get_instance()
        camera.move(y=1, z=5 + 2 * scale)

        self.reset()

    @staticmethod
    def _swap_yz(vector):
        """Return ``[X, Z, Y]`` of a Hipparchus vector (render axis order)."""
        return [vector.getX(), vector.getZ(), vector.getY()]

    @staticmethod
    def _format_timestamp(date):
        """Format ``date`` for the on-screen clock.

        Takes ``date.getDate().toString()``, keeps only what precedes the last
        ``'.'`` and replaces ``'T'`` with a space.  A string without any
        ``'.'`` therefore yields an empty result.
        """
        return date.getDate().toString().rpartition('.')[0].replace('T', ' ')

    @staticmethod
    def _extend_trail(trail, pos, max_steps):
        """Append ``pos`` to ``trail`` in place, dropping the oldest offset.

        ``trail[0]`` is the absolute anchor position; every later entry is an
        offset relative to that anchor.  When the list already holds more than
        ``max_steps`` entries the oldest offset is removed first.  The anchor
        is never removed, so ``max_steps <= 0`` keeps exactly one offset
        instead of raising ``IndexError`` on ``pop(1)``.
        """
        if len(trail) > max(max_steps, 1):
            trail.pop(1)
        if len(trail) == 0:
            trail.append(list(pos))
        else:
            trail.append(list(np.array(pos) - np.array(trail[0])))

    def _build_orbit_trail(self, orbit, step_seconds=30.0):
        """Build the ``Trail`` of one configured reference orbit.

        ``orbit`` provides ``"a"`` (added to ``EARTH_RADIUS`` to form the
        semi-major axis), ``"e"``, ``"i"``, ``"pa"`` and ``"raan"`` (passed
        through ``radians``) and ``"color"``.  The orbit is propagated
        numerically with Newtonian attraction only and sampled every
        ``step_seconds`` for one Keplerian period plus two extra samples.
        """
        current_date = AbsoluteDate()
        keplerian_orbit = KeplerianOrbit(
            orbit["a"] + EARTH_RADIUS,
            orbit["e"],
            radians(orbit["i"]),
            radians(orbit["pa"]),
            radians(orbit["raan"]),
            radians(0.0),  # initial true anomaly
            PositionAngleType.TRUE,
            INERTIAL_FRAME,
            current_date,
            MU,
        )
        state = SpacecraftState(keplerian_orbit, 1.0)
        tolerances = NumericalPropagator.tolerances(60.0, state.getOrbit(), state.getOrbit().getType())
        integrator = DormandPrince853Integrator(
            1e-3,
            500.0,
            JArray_double.cast_(tolerances[0]),
            JArray_double.cast_(tolerances[1]),
        )
        integrator.setInitialStepSize(10.0)
        propagator = NumericalPropagator(integrator)
        propagator.setMu(MU)
        propagator.setOrbitType(OrbitType.KEPLERIAN)
        propagator.setAttitudeProvider(ATTITUDE)
        propagator.addForceModel(NewtonianAttraction(MU))
        propagator.setInitialState(state)

        # First sample is the absolute anchor, the rest are offsets from it.
        samples = []
        sample_count = int(state.getKeplerianPeriod() / step_seconds) + 2
        for _ in range(sample_count):
            current_date = current_date.shiftedBy(step_seconds)
            state = propagator.propagate(current_date)
            point = self._swap_yz(state.getPVCoordinates().getPosition())
            if len(samples) == 0:
                samples.append(point)
            else:
                samples.append(list(np.array(point) - np.array(samples[0])))

        anchor = list(np.array(samples[0]) * self.distance_scale)
        return Trail(trail=samples[1:], position=anchor, color=orbit["color"], scale=self.distance_scale)

    def reset(self):
        """Clear every body trail and re-read the start timestamp.

        ``trails`` and ``initial_timestamp`` are always (re)assigned, so they
        exist even when there are no bodies; in that case they are ``{}`` and
        ``None`` respectively.
        """
        self.trails = {body.name: [] for body in self.bodies}
        if len(self.bodies) > 0:
            self.initial_timestamp = self._format_timestamp(self.bodies[0].current_date)
        else:
            self.initial_timestamp = None

    def save_screenshot(self, filename = 'screenshot.jpg'):
        """Save the current window contents to ``filename``."""
        pygame.image.save(self.screen, filename)

    def frame(self):
        """Draw one frame: scene, bodies, timestamps; then update and delay.

        Exits the process through ``sys.exit(0)`` when a ``pygame.QUIT`` event
        is pending.
        """
        if pygame.event.get(pygame.QUIT):
            sys.exit(0)

        self.screen.fill((20, 20, 20))

        # draw grid
        if self.params["equator_grid"]["show"]:
            self.grid.draw()

        # draw Earth
        if self.params["earth"]["show"]:
            self.earth.draw()

        # draw orbits
        if "orbits" in self.params:
            for orbit in self.orbits:
                orbit.draw()

        # draw bodies
        for index, body in enumerate(self.bodies):
            self._draw_body(index, body)

        # draw timestamp
        if self.params["timestamp"]["show"] and len(self.bodies) > 0:
            self._draw_text(f'Start:   {self.initial_timestamp}', (110, 20))
            current_timestamp = self._format_timestamp(self.bodies[0].current_date)
            self._draw_text(f'Current: {current_timestamp}', (110, 35))

        handle_camera_with_keys()
        pygame.display.update()
        pygame.time.delay(self.params["delay_ms"])

    def _draw_text(self, text, center, color=(255, 255, 255)):
        """Render ``text`` with the interface font, centred on ``center``."""
        surface = self.font.render(text, True, color)
        self.screen.blit(surface, surface.get_rect(center=center))

    @staticmethod
    def _local_axes(pos, vel):
        """Return the unit vectors ``(r, s, w)`` built from position/velocity.

        ``r`` is the normalised position, ``w`` the normalised ``r x vel`` and
        ``s`` the normalised ``w x r``.  The ``cast_`` round-trips mirror the
        way the JCC wrappers are used throughout this file.
        """
        r = Vector3D.cast_(Vector.cast_(pos).normalize())
        v = Vector3D.cast_(Vector.cast_(vel))
        w = Vector3D.cast_(Vector.cast_(r.crossProduct(v)).normalize())
        s = Vector3D.cast_(Vector.cast_(w.crossProduct(r)).normalize())
        return r, s, w

    def _draw_body(self, index, body):
        """Draw marker, arrows, trail and label of ``self.bodies[index]``."""
        body_type = "satellites" if body.has_thrust else "drifters"
        settings = self.params[body_type]
        sphere = self.spheres[index]

        pv = body.current_state.getPVCoordinates()
        pos = pv.getPosition()
        vel = pv.getVelocity()

        x, y, z = self._swap_yz(pos.scalarMultiply(float(self.distance_scale)))

        # draw body
        if settings["show"]:
            sphere.set_position(x, y, z)
            sphere.draw()

        show_velocity = settings["show_velocity"]
        show_thrust = body_type == "satellites" and settings["show_thrust"] and body.thrust
        if show_velocity or show_thrust:
            # The local axes are only needed by the two arrows.
            r, s, w = self._local_axes(pos, vel)

        # draw body velocity (the arrow follows the unit ``s`` axis)
        if show_velocity:
            vel_vector = self._swap_yz(s)
            # Prefer the colour of the body's own group; fall back to the
            # satellites colour when that group does not define one.
            velocity_color = settings.get("color_velocity", self.params["satellites"]["color_velocity"])
            velocity_arrow = Plot(
                func=lambda t: self.arrow_func(t, vel_vector, scale=0.3),
                allrange=[0, 1],
                position=(x, y, z),
                color=velocity_color,
                interpolate=50,
            )
            velocity_arrow.draw()

        # draw body thrust (components of ``body.thrust`` weight r, s and w)
        if show_thrust:
            combined = (
                r.scalarMultiply(float(body.thrust[0]))
                .add(s.scalarMultiply(float(body.thrust[1])))
                .add(w.scalarMultiply(float(body.thrust[2])))
            )
            thrust = Vector3D.cast_(Vector.cast_(combined).normalize())
            thrust_vector = self._swap_yz(thrust)
            thrust_arrow = Plot(
                func=lambda t: self.arrow_func(t, thrust_vector, scale=0.2),
                allrange=[0, 1],
                position=(x, y, z),
                color=self.params["satellites"]["color_thrust"],
                interpolate=50,
            )
            thrust_arrow.draw()

        # draw trail
        if settings["show_trail"]:
            points = self.trails.setdefault(body.name, [])
            self._extend_trail(points, self._swap_yz(pos), settings["trail_last_steps"])
            anchor = list(np.array(points[0]) * self.distance_scale)
            trail = Trail(
                trail=points[1:],
                position=anchor,
                color=settings["color_trail"] if body.color is None else body.color,
                scale=self.distance_scale,
            )
            trail.draw()

        # draw body label
        if settings["show_label"]:
            VP = three_d.Camera.View_Projection_matrix()
            # NOTE(review): the 1/40 factor is an empirical-looking constant
            # whose origin is not visible in this file -- confirm.
            data = Matrix([[x / 40, y / 40, z / 40, 1]])
            points_2d = data @ sphere.matrix @ VP
            center = sphere._perspective_divide(points=points_2d)[0][:2]
            # The label sits 20 pixels above the projected body position.
            self._draw_text(body.name, (center[0], center[1] - 20), settings["color_label"])

    def arrow_func(self, t, vector, scale=1):
        """Return ``vector`` scaled by ``t * scale`` as a list of components."""
        return [t * component * scale for component in vector]
    