import os
import importlib
import types
from vpython import *
import numpy as np
np.set_printoptions(precision=3, suppress=True, edgeitems=30, linewidth=100000)   
import gym
from .fast_jet import FastJet
from .global_variables import *


def is_notebook():
    """Return True when running inside a Jupyter kernel (notebook or qtconsole).

    ``get_ipython`` only exists when IPython is active, so a NameError means a
    plain Python interpreter. Among IPython shells, only the
    ``ZMQInteractiveShell`` class name counts as a notebook; the terminal shell
    (``TerminalInteractiveShell``) and any other shell type yield False.
    """
    try:
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # No IPython at all: probably the standard Python interpreter.
        return False
    return shell_name == 'ZMQInteractiveShell'
        
# Evaluated once at import time: whether this module was loaded inside a Jupyter kernel.
JUPYTER: bool = is_notebook()


class FastJetEnv(gym.Env):
    """Gym environment built around one or more ``FastJet`` models.

    ``self.jets[0]`` is the ego jet driven by the agent's actions. Everything
    task-specific (observation function and limits, reset, extra dynamics,
    reward, phases, extra rendered entities, an optional reference jet) comes
    from a *task* dictionary; see ``setup_task``. Rendering uses VPython.
    """

    metadata = {"render.modes": ["human", "rgb_array"]}

    def __init__(self, task="outside_parallel_skew", continuous=True, skip_frames=1, render_mode=False, camera_angle="outside", show_bbox=False):
        """
        Args:
            task: Name of a module in the ``tasks`` subpackage exposing a ``task``
                dict, or such a dict directly (see ``setup_task``).
            continuous: If True, actions are 4-vectors in [-1, 1]; otherwise
                actions are one of 9 discrete choices.
            skip_frames: Number of simulation frames advanced per ``step()``.
            render_mode: False disables rendering; otherwise the only mode
                ``render()`` will accept.
            camera_angle: "outside", a string containing "outside_parallel"
                (optionally with "_skew"), "outside_perpendicular", or "bbox"
                (the last requires ``show_bbox``).
            show_bbox: Draw the initialisation bounding box when rendering.
        """
        self.skip_frames, self.render_mode, self.camera_angle, self.show_bbox = skip_frames, render_mode, camera_angle, show_bbox
        self.num_phases, self.scene, self.jets, self.rendered = 2, None, [], False
        self.seed()
        self.add_jet()  # Add ego jet
        if self.render_mode is not False:
            self.t = 0
            # Build the scene before setup_task() so a task's "render" hook can add to it.
            self._render_first()
        self.setup_task(task)
        self.continuous = continuous
        if self.continuous:
            # demanded_pitch, demanded_roll, demanded_yaw, demanded_thrust
            self.action_space = gym.spaces.Box(np.float32(-1), np.float32(1), shape=(4,))
        else:
            # no-op, pitch up, pitch down, roll left, roll right, yaw left, yaw right, thrust low, thrust high
            self.action_space = gym.spaces.Discrete(9)

    def seed(self, seed=None):
        """Create a new RNG from ``seed`` and share it with every existing jet.

        Returns:
            A one-element list holding the seed reported by gym's seeding helper.
        """
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        for jet in self.jets:
            jet.np_random = self.np_random
        return [seed]

    def setup_task(self, task):
        """
        A task is defined by:
            (1) An observation function and observation space limits
            (2) A reset function
            (3) [Optional] A dynamics function, called after calling step() on self.jets[0]
            (4) [Optional] A reward function
            (5) [Optional] A set of phases and phase change function (NOTE: final phase change = termination)
            (6) [Optional] A set of additional entities to be rendered
            (7) [Optional] A specification for a second "reference jet" to be added to the environment

        ``task`` is either a module name inside the ``tasks`` subpackage (whose
        ``task`` attribute is used) or a dict with keys "obs", "obs_lims",
        "reset" and optionally "dynamics", "reward", "phases" (a pair of
        ``(num_phases, phase_step_fn)``), "render" and "reference_jet" (kwargs
        for ``add_jet``). Functions are bound to this environment as methods.
        "obs_lims" is indexed as ``[:, 0]`` / ``[:, 1]`` for the lower / upper
        observation bounds.
        """
        if isinstance(task, str):
            task = importlib.import_module(f".tasks.{task}", package=__package__).task
        assert isinstance(task, dict), f"task must be a module name or a dict, got {type(task).__name__}"
        self.task = task
        self.obs = types.MethodType(self.task["obs"], self)
        self.observation_space = gym.spaces.Box(np.float32(self.task["obs_lims"][:, 0]), np.float32(self.task["obs_lims"][:, 1]))
        self._reset = types.MethodType(self.task["reset"], self)
        if "dynamics" in self.task:
            self.dynamics = types.MethodType(self.task["dynamics"], self)
        if "reward" in self.task:
            self.reward = types.MethodType(self.task["reward"], self)
        if "phases" in self.task:
            self.num_phases, self.phase_step = self.task["phases"][0], types.MethodType(self.task["phases"][1], self)
        if "render" in self.task and self.render_mode is not False:
            self.task["render"](self)
        if "reference_jet" in self.task:
            self.add_jet(**self.task["reference_jet"])

    def add_jet(self, **kwargs):
        """Append a new ``FastJet`` (sharing this env's RNG) built from ``kwargs``."""
        self.jets.append(FastJet(np_random=self.np_random, **kwargs))

    def remove_jet(self, num):
        """Pop the jet at index ``num`` from ``self.jets`` and call its ``unrender()``."""
        self.jets.pop(num).unrender()

    def reset(self):
        """Start a new episode and return the initial observation.

        Rendering is switched off until ``render()`` is called again.
        """
        self.t = 0
        self.phase = 0
        # Clear termination left over from the previous episode; step() reads it
        # even if no frame runs (skip_frames <= 0).
        self.is_done = False
        self.render_on = False
        self._reset()
        return self.obs()

    def step(self, action):
        """Advance ``skip_frames`` frames with a fixed action.

        Returns:
            (obs, reward summed over the frames, done flag,
             {"phase": number of frames spent in each phase}).

        Raises:
            ValueError: if ``action`` is not in ``self.action_space``.
        """
        if action not in self.action_space:
            raise ValueError(f"Action {action} not in {self.action_space}")
        if not self.continuous:
            action = DISCRETE_ACTION_MAP[action]
        reward, phase_counts = 0., [0] * self.num_phases
        for _ in range(self.skip_frames):
            self.frame(action)
            reward += self.reward()  # NOTE: Reward summed over frames
            phase_counts[self.phase] += 1
            if self.render_on:
                self.render()
            if self.is_done:
                break  # Early break if terminate at *any point* during frame skipping
        return self.obs(), reward, self.is_done, {"phase": phase_counts}

    def frame(self, action):
        """Run one simulation frame: ego jet, task dynamics, then phase update."""
        self.jets[0].step(action)
        self.dynamics(action)
        self.phase_step()
        self.is_done = self.phase == self.num_phases - 1  # Final phase is termination
        self.t += 1

    # Defaults overwritten by setup_task()
    def _reset(self):      pass
    def dynamics(self, _): pass
    def reward(self):      return 0.
    def phase_step(self):  pass

    def _render_first(self, scene=None):
        """Create (or adopt) the VPython scene and its static entities."""
        if scene is not None:
            self.scene = scene  # Pre-created scene object
        else:
            self.scene = canvas(width=WINDOW_W, height=WINDOW_H)
        self.scene.background = vector(132/255, 172/255, 217/255)
        self.scene.ambient = color.gray(0.5)
        self.scene.center = vector(X_MAX / 2, (Y_MIN + Y_MAX) / 2, Z_MAX / 2)
        self.jets[0].render(self.scene)
        ground_pos = vector(0, 0, 0)
        self.ground = box(canvas=self.scene, pos=vector(ground_pos.x, ground_pos.y-0.05, ground_pos.z), size=vector(2000, 0.1, 2000), shininess=0, texture='earth_texture.jpg')
        self.label = label(canvas=self.scene, text="", font="monospace", line=False, pixel_pos=True, pos=vector(20, 190, 0), align="left")
        if self.show_bbox:  # Bounding box for initialisation
            curve(canvas=self.scene, pos=[vector(0,Y_MIN,0), vector(0,Y_MAX,0), vector(X_MAX,Y_MAX,0), vector(X_MAX,Y_MIN,0), vector(0,Y_MIN,0)])
            curve(canvas=self.scene, pos=[vector(0,Y_MIN,Z_MAX), vector(0,Y_MAX,Z_MAX), vector(X_MAX,Y_MAX,Z_MAX), vector(X_MAX,Y_MIN,Z_MAX), vector(0,Y_MIN,Z_MAX)])
            curve(canvas=self.scene, pos=[vector(0,Y_MIN,0), vector(0,Y_MIN,Z_MAX)])
            curve(canvas=self.scene, pos=[vector(0,Y_MAX,0), vector(0,Y_MAX,Z_MAX)])
            curve(canvas=self.scene, pos=[vector(X_MAX,Y_MAX,0), vector(X_MAX,Y_MAX,Z_MAX)])
            curve(canvas=self.scene, pos=[vector(X_MAX,Y_MIN,0), vector(X_MAX,Y_MIN,Z_MAX)])
            # Fixed camera looking at the box centre from the +z side.
            cam_offset = vector(X_MAX * 0, (Y_MAX - Y_MIN) * 0, Z_MAX * np.sqrt(2))
            cam_pos = self.scene.center + cam_offset
            self.scene.camera.axis = -norm(cam_offset) * 100
            self.scene.camera.pos = cam_pos
        self.rendered = True

    def render(self, mode='human', scene=None):
        """Update jets, camera and head-up display in the VPython scene.

        Args:
            mode: Must equal ``self.render_mode``. "rgb_array" is not supported.
            scene: Optional pre-created canvas, used only on the first render.

        Raises:
            NotImplementedError: for ``mode == "rgb_array"`` (after drawing).
            NotImplementedError: for an unknown ``camera_angle``.
        """
        assert mode == self.render_mode, f"Render mode is {self.render_mode}, so cannot use {mode}"
        if not self.rendered:
            self._render_first(scene)
        self.render_on = True
        # Update jet positions and attitudes
        for jet in self.jets:
            jet.render(self.scene)
        ego = self.jets[0]
        # Update camera angle ("bbox" keeps the fixed camera set in _render_first)
        if self.camera_angle == "bbox":
            assert self.show_bbox
        else:
            if self.camera_angle == "outside":
                # Pull the camera further back as the jet's axis diverges from the camera's.
                axis_delta = (diff_angle(ego.axis, self.scene.camera.axis) / np.pi)**2
                cam_offset = vector(-5.0, 0.5, 0.0) + axis_delta * vector(-5.0, 0.5, 0.0)
                cam_pos = ego.pos + cam_offset
            elif "outside_parallel" in self.camera_angle:
                cam_offset = 2.0 * norm(ego.pos - self.jets[1].pos)
                if "_skew" in self.camera_angle:
                    cam_offset = cam_offset.rotate(angle=0.5, axis=vector(0,1,0))
                cam_pos = ego.pos + cam_offset
            elif self.camera_angle == "outside_perpendicular":
                dist = ego.pos - self.jets[1].pos
                cam_offset = norm(cross(dist, vector(0,1,0))) * max(mag(dist), 50) * 0.8
                cam_pos = ((ego.pos + self.jets[1].pos) / 2) + cam_offset
            else:
                raise NotImplementedError()
            self.scene.camera.axis = -norm(cam_offset) * 100  # Magnitude *does* make a difference to view distance
            self.scene.camera.pos = cam_pos
        # Update head-up display
        self.label.text = '\n'.join((
            'timestep: ' + str(int(self.t)),
            'airspeed: ' + str(int(mag(ego.vel))),
            ' v speed: ' + str(int(ego.vel.y)),
            ' heading: ' + str(int(degrees(ego.hdg) + 180)),
            '   pitch: ' + str(int(degrees(ego.pitch))),
            '    roll: ' + str(int(degrees(ego.roll))),
            '     alt: ' + str(int(ego.pos.y)),
            '       g: ' + str(int(10.0 * ego.g) / 10.0),
            '  thrust: ' + str(int(10.0 * ego.thrust) / 10.0),
        ))
        rate(HZ)
        if mode == "rgb_array":
            # Frame capture used to shell out to ImageMagick `import`
            # (http://www.eg.bucknell.edu/~mligare/captureVPython.html), which was too slow.
            raise NotImplementedError("Super slow!")

    def close(self):
        """Delete the VPython canvas, if any.

        Safe to call more than once; a later ``render()`` builds a fresh scene.
        """
        if self.scene is not None:
            self.scene.delete()
            self.scene = None
            self.rendered = False
