# Screen
import sys, cv2
import numpy as np
import imageio as imio
import os, copy
from Environment.environment import Environment, Reward, Done
from Environment.Environments.Asteroids.asteroid_objects import *
from Environment.Environments.Asteroids.asteroid_specs import *
from Record.file_management import numpy_factored
from State.angle_calculator import sincos_to_angle
from gym import spaces



def rand_sample(variance, zero=False):
    """Draw a random 2-D vector whose per-axis magnitude is scaled by ``variance``.

    Args:
        variance: indexable with at least two entries; ``variance[i]`` scales
            axis ``i``. Only the first two entries are used.
        zero: if True, center the sample on 0 so each component lies in
            ``[-variance[i], variance[i])``; otherwise each component lies in
            ``[0, variance[i])`` (for non-negative ``variance[i]``).

    Returns:
        np.ndarray with two entries.
    """
    if zero:
        # (U[0,1) - 0.5) * 2 maps to [-1, 1), then scale per axis. This branch
        # must return; otherwise it would fall through to the non-centered case.
        return np.array([(np.random.rand() - 0.5) * 2 * variance[i] for i in range(2)])
    return np.array([np.random.rand() * variance[i] for i in range(2)])

class Asteroids(Environment):
    """Small Asteroids-style environment rendered on an 84x84 uint8 frame.

    The scene consists of a ship, a single laser, ``num_asteroids`` asteroids
    and the Action / Reward / Done pseudo-objects. Each state is exposed both
    as the rendered frame (``"raw_state"``) and as a per-object factored dict
    (``"factored_state"``). All gameplay parameters are read from
    ``asteroid_variants[variant]``.
    """

    def __init__(self, frameskip = 1, variant="", fixed_limits=False):
        """Configure the environment from a variant and perform an initial reset.

        Args:
            frameskip: accepted for interface compatibility but currently
                ignored; ``self.frameskip`` is always 1.
            variant: key into ``asteroid_variants`` selecting the counts, sizes,
                speeds, rewards/penalties and ``max_steps``.
            fixed_limits: if True, ``object_range`` / ``object_dynamics`` use the
                limits from ``generate_specs_fixed``; otherwise those from
                ``generate_specs``.
        """
        super(Asteroids, self).__init__()
        # Asteroids specialized parameters are stored in the variant
        self.variant = variant
        self.fixed_limits = fixed_limits
        self.self_reset = True
        self.transpose = False

        # environment properties
        self.num_actions = 6 # noop, forward, backward, right, left, fire
        self.name = "Asteroids" # required for an environment
        self.discrete_actions = True
        # NOTE(review): the ``frameskip`` argument is deliberately not honored;
        # each step() advances exactly one sub-frame.
        self.frameskip = 1

        # spaces
        self.action_shape = (1,)
        self.action_space = spaces.Discrete(self.num_actions) # gym.spaces
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8) # raw space, gym.spaces
        self.seed_counter = -1

        # state components
        self.frame = None # the image generated by the environment
        self.reward = Reward()
        self.done = Done()
        self.action = np.zeros(self.action_shape)
        self.extracted_state = None

        # running values
        self.itr = 0
        self.total_score = 0
        self.reset_counter = 0

        # assign variant values
        (self.num_asteroids, self.asteroid_size, self.asteroid_speed, self.asteroid_size_variance,
         self.asteroid_variance, self.ship_variance, self.ship_speed, self.movement_type,
         self.laser_speed, self.hit_reward, self.shot_penalty, self.crash_penalty,
         self.completion_reward, self.max_steps) = asteroid_variants[variant]

        # factorized state properties
        self.object_names = ["Action", "Ship", "Laser", "Asteroid", 'Done', "Reward"]
        self.object_sizes = {"Action": 1, "Ship": 5, "Laser": 5, "Asteroid": 5, 'Done': 1, "Reward": 1}
        self.object_name_dict = dict() # initialized in reset

        # spec ranges: both spec sets are computed and fixed_limits selects which
        # ranges/dynamics are exposed; masks and instance counts always come
        # from generate_specs.
        ranges_fixed, dynamics_fixed, _, _ = generate_specs_fixed(
            self.asteroid_size, self.asteroid_size_variance, self.num_asteroids)
        ranges, dynamics, position_masks, instanced = generate_specs(
            self.asteroid_speed, self.ship_speed[0], self.laser_speed,
            self.asteroid_size, self.asteroid_size_variance, self.num_asteroids)
        self.position_masks = position_masks
        self.object_range_true = ranges
        self.object_dynamics_true = dynamics
        # the minimum and maximum values for a given feature of an object
        self.object_range = ranges_fixed if self.fixed_limits else ranges
        self.object_dynamics = dynamics_fixed if self.fixed_limits else dynamics
        self.object_instanced = instanced

        # reset counters
        self.hit_counter = 0
        self.shot_counter = 0
        self.lives = 0 # lives is not implemented in the current version
        # one entry per object instance: name + index, with counts from ``instanced``
        self.all_names = sum([[name + str(i) for i in range(instanced[name])] for name in self.object_names], start = [])
        self.instance_length = len(self.all_names)
        self.reset()

    def get_asteroid_position(self, variance):
        """Rejection-sample an asteroid position away from the ship.

        Draws ``rand_sample(variance)`` until the point is farther than
        ``asteroid_size + 4`` (Euclidean) from ``self.ship.pos``.
        NOTE(review): loops forever if no such point exists within ``variance``.
        """
        while True:
            pos = rand_sample(variance)
            if np.linalg.norm(pos - self.ship.pos, ord=2) > self.asteroid_size + 4:
                return pos

    def reset(self):
        """Recreate all objects for a new episode and return the rendered state.

        The ship is placed first so that asteroid positions can be sampled
        away from it. Counters for hits, shots and steps are zeroed.
        """
        self.action_obj = Action()
        self.ship = Ship(np.array([np.random.rand() * self.ship_variance * 84 for i in range(2)]), np.array(0), self.ship_speed, self.movement_type)
        self.asteroids = [
            Asteroid(
                self.get_asteroid_position([self.asteroid_variance * 84, self.asteroid_variance * 84]),
                rand_sample([self.asteroid_speed, self.asteroid_speed]),
                1,
                self.asteroid_size + np.round(np.random.rand() * self.asteroid_size_variance),
                i,
            )
            for i in range(self.num_asteroids)
        ]
        # if only one asteroid, use "Asteroid" instead of "Asteroid0"
        if len(self.asteroids) == 1: self.asteroids[0].name = "Asteroid"
        self.laser = Laser(np.zeros(2), self.laser_speed, 0)
        self.reward = Reward()
        self.done = Done()
        self.objects = [self.action_obj, self.ship, self.laser] + self.asteroids + [self.reward, self.done]
        self.object_name_dict = {**{"Action": self.action_obj, "Ship":self.ship, "Laser": self.laser, "Done": self.done, "Reward": self.reward}, **{self.asteroids[i].name: self.asteroids[i] for i in range(len(self.asteroids))}}
        self.hit_counter = 0
        self.shot_counter = 0
        self.reset_counter = 0
        return self.get_state(render=True)

    def render(self):
        """Draw the scene into a fresh 84x84 uint8 frame, store it on self.frame and return it.

        Ship triangle outline: intensity 128. Existing asteroids: circle
        outlines at intensity 64. Laser (when it exists): 2-px line at 255.
        """
        self.frame = np.zeros((84,84), np.uint8)

        # render a triangle
        cv2.line(self.frame, np.round(self.ship.tip).astype(int), np.round(self.ship.right).astype(int), 128,1)
        cv2.line(self.frame, np.round(self.ship.tip).astype(int), np.round(self.ship.left).astype(int), 128,1)
        cv2.line(self.frame, np.round(self.ship.left).astype(int), np.round(self.ship.right).astype(int), 128,1)

        # render asteroids
        for asteroid in self.asteroids:
            if asteroid.exist == 1: cv2.circle(self.frame, asteroid.pos.astype(int), asteroid.size.astype(int), 64, 1)

        # render laser
        if self.laser.exist: cv2.line(self.frame, np.round(self.laser.bottom).astype(int), np.round(self.laser.top).astype(int), 255, 2)
        return self.frame

    def clear_interactions(self):
        """Empty every object's ``interaction_trace`` list."""
        for obj in self.objects:
            obj.interaction_trace = list()

    def count_asteroids(self):
        """Return the sum of ``exist`` over all asteroids (number still alive)."""
        return np.sum([a.exist for a in self.asteroids])

    def get_state(self, render=False):
        """Return ``{"raw_state": frame, "factored_state": factored}``.

        ``factored`` is built by ``numpy_factored`` from each object's
        ``get_state()`` (except Reward/Done), plus Done and Reward wrapped in
        one-element lists. The frame is re-rendered only if ``render`` is True;
        otherwise the last rendered ``self.frame`` is returned.
        """
        rdset = set(["Reward", "Done"])
        extracted_state = numpy_factored({**{obj.name: obj.get_state() for obj in self.objects if obj.name not in rdset}, **{"Done": [self.done.attribute], "Reward": [self.reward.attribute]}})
        if render: self.frame = self.render()
        return {"raw_state": self.frame, "factored_state": extracted_state}

    def step(self, action, render=False):
        """Advance the game by one action.

        Reward for the step accumulates ``shot_penalty`` per laser shot,
        ``hit_reward`` per asteroid hit, ``crash_penalty`` per ship/asteroid
        intersection and ``completion_reward`` when no asteroids remain. The
        episode is done when all asteroids are gone or after ``max_steps``
        steps (reported as truncation), and the environment resets itself.

        Returns:
            (full_state, reward, done, info), where ``full_state`` and the
            reward/done values describe the step that was just taken, not the
            post-reset state.
        """
        self.reward.attribute, self.done.attribute = 0.0, False
        self.clear_interactions()
        self.reset_counter += 1
        for i in range(self.frameskip):
            self.action_obj.step(action)
            self.ship.step(self.action_obj)
            shot = self.laser.step(self.ship, self.action_obj)
            self.shot_counter += float(shot)
            self.reward.attribute += float(shot) * self.shot_penalty
            if shot: self.reward.interact(self.laser)
            for asteroid in self.asteroids:
                asteroid.step(self.laser)
                hit = asteroid.update()
                self.hit_counter += int(hit)
                if hit:
                    self.reward.interact(asteroid)
                    self.reward.interact(self.laser)
                self.reward.attribute += float(hit) * self.hit_reward
            self.ship.update(self.laser)
            for asteroid in self.asteroids:
                crash = self.ship.intersect(asteroid)
                if crash:
                    self.reward.interact(asteroid)
                    self.reward.interact(self.ship)
                self.reward.attribute += crash * self.crash_penalty
            self.laser.update()
        asteroid_count = self.count_asteroids()
        if asteroid_count == 0:
            self.reward.attribute += self.completion_reward
            self.done.attribute = True
        trunc = False
        if self.reset_counter == self.max_steps:
            self.done.attribute = True
            trunc = True
        info = {"lives": self.lives, "TimeLimit.truncated": trunc, "total_score": self.num_asteroids - asteroid_count}
        full_state = self.get_state(render)
        self.itr += 1
        # Read this step's reward/done BEFORE reset(): reset() replaces
        # self.reward and self.done with fresh objects, so reading them
        # afterwards would drop the terminal reward and done flag.
        reward, done = self.reward.attribute, self.done.attribute
        if done: self.reset()

        return full_state, reward, done, info

    def set_from_factored_state(self, factored_state, seed_counter=-1, render=False):
        """Overwrite ship, asteroid, laser, reward and done values from a factored state.

        Ship: ``[:2]`` is the position; ``[2]`` and ``[3]`` are passed to
        ``sincos_to_angle`` to recover the angle. Asteroids and Laser:
        ``[:2]`` position, ``[2:4]`` velocity, ``[4]`` exist flag.
        ``seed_counter`` and ``render`` are accepted but unused, and the step
        counter is reset to 0.
        NOTE(review): positions/velocities are assigned as slices of the
        input, which for numpy arrays are views that alias ``factored_state``.
        """
        self.ship.pos = factored_state["Ship"][:2]
        self.ship.angle = sincos_to_angle(factored_state["Ship"][2], factored_state["Ship"][3])
        self.ship.update_tips()

        for asteroid in self.asteroids:
            asteroid.pos = factored_state[asteroid.name][:2]
            asteroid.vel = factored_state[asteroid.name][2:4]
            asteroid.exist = factored_state[asteroid.name][4]

        self.laser.pos = factored_state["Laser"][:2]
        self.laser.vel = factored_state["Laser"][2:4]
        self.laser.exist = factored_state["Laser"][4]
        self.laser.update_bottom_top()
        self.reward.attribute = factored_state["Reward"]
        self.done.attribute = factored_state["Done"]
        self.reset_counter = 0

    def demonstrate(self):
        """Show the current frame (4x upscaled) and map a keypress to an action.

        Keys: q -> -1, z -> 0, s -> 1, w -> 2, d -> 3, a -> 4, space -> 5.
        Any other key (or no key within 100 ms) returns 6, which is outside
        the 0..5 action range; presumably a "no input" sentinel for the caller.
        """
        frame = self.render()
        # cv2.resize takes dsize as (width, height)
        frame = cv2.resize(frame, (frame.shape[1] * 4, frame.shape[0] * 4))
        cv2.imshow('frame',frame)
        key = cv2.waitKey(100)
        key_to_action = {
            ord('q'): -1,
            ord('z'): 0,
            ord('d'): 3,
            ord('w'): 2,
            ord('s'): 1,
            ord('a'): 4,
            ord(' '): 5,
        }
        return key_to_action.get(key, 6)
