# Screen
import sys, cv2
import numpy as np
import imageio as imio
import os, copy
from Environment.environment import Environment
from Environment.Environments.TaxiCar.taxicar_objects import *
from Environment.Environments.TaxiCar.taxicar_specs import *
from Record.file_management import numpy_factored
from gym import spaces

# Four unit directions; generate_fill picks one at random for objects created with has_vel=True.
DIRS = np.array(((1, 0), (-1, 0), (0, 1), (0, -1)))

class TaxiCar(Environment):
    """Grid-world taxi environment with passengers, pedestrians, vehicles and targets.

    Grid size and object counts come from ``taxicar_variants[variant]``. Each
    reset scatters the objects over a ``num_rows x num_columns`` grid, and
    ``self.occupancy_matrix`` records which objects sit in which cell.
    """

    def __init__(self, frameskip = 1, variant="", fixed_limits=False):
        """Build the environment for ``variant`` and perform an initial reset.

        Args:
            frameskip: accepted for interface compatibility but ignored; the
                environment always advances one tick per step.
            variant: key into ``taxicar_variants`` selecting grid size, object
                counts, step limit and preset.
            fixed_limits: if True, expose the ranges/dynamics returned by
                ``generate_specs_fixed`` instead of the grid-dependent ones.
        """
        super(TaxiCar, self).__init__()
        # taxicar-specific parameters are stored in the variant
        self.variant = variant
        self.self_reset = True
        self.fixed_limits = fixed_limits

        # environment properties
        self.num_actions = 4 # this must be defined, -1 for continuous. Only needed for primitive actions
        self.name = "Taxicar" # required for an environment
        self.discrete_actions = True
        self.frameskip = 1 # no frameskip: the frameskip argument is intentionally ignored

        (self.num_rows, self.num_columns, self.num_pedestrians, self.num_vehicles,
         self.num_targets, self.step_limit, self.preset) = taxicar_variants[self.variant]
        self.num_passengers = self.num_targets # one passenger per target
        # spaces
        self.action_shape = (1,)
        self.action_space = spaces.Discrete(self.num_actions) # gym.spaces
        self.observation_space = spaces.Box(low=0, high=255, shape=(RNG, RNG), dtype=np.uint8) # raw space, gym.spaces
        self.seed_counter = -1

        # state components
        self.frame = None # the image generated by the environment
        self.reward = Reward()
        self.done = Done()
        self.action = np.zeros(self.action_shape)
        self.extracted_state = None
        self.steps = 0

        # running values
        self.itr = 0
        self.total_score = 0

        # factorized state properties (position masks / instancing come from the grid-dependent specs)
        ranges_fixed, dynamics_fixed, _, _ = generate_specs_fixed(self.num_pedestrians, self.num_vehicles, self.num_targets)
        ranges, dynamics, position_masks, instanced = generate_specs(self.num_rows, self.num_columns, self.num_pedestrians, self.num_vehicles, self.num_targets)
        self.object_names = ["Action", "Taxi", "Passenger", "Pedestrian", "Vehicle", "Target", "Done", "Reward"]
        self.object_sizes = {"Action": 1, "Taxi": 5, "Passenger": 6, "Pedestrian": 7, 'Vehicle': 6, "Target": 3, "Reward": 1, "Done": 1}
        self.object_name_dict = dict() # initialized in reset
        self.object_range = ranges if not self.fixed_limits else ranges_fixed # the minimum and maximum values for a given feature of an object
        self.object_dynamics = dynamics if not self.fixed_limits else dynamics_fixed
        self.object_range_true = ranges
        self.object_dynamics_true = dynamics
        self.object_instanced = instanced
        self.position_masks = position_masks
        # Objects with more than one instance are listed as <name>0 .. <name>{n-1};
        # single-instance objects keep the bare name.
        self.all_names = list()
        for name in self.object_names:
            if instanced[name] > 1:
                self.all_names += [name + str(i) for i in range(instanced[name])]
            else:
                self.all_names.append(name)
        self.instance_length = len(self.all_names)

        # reset counters
        self.reset()

    def generate_fill(self, class_type, idx, offset=0, max_adjacent=4, has_vel = False):
        """Place a new ``class_type`` object in a random empty cell.

        Up to 1000 cells are sampled uniformly from the sub-grid that stays
        ``offset`` cells away from every border. A cell is accepted if it is
        empty and, when ``max_adjacent < 4``, no more than ``max_adjacent`` of
        its four neighbours are occupied (grid borders count as occupied).
        Requires ``self.occupancy_matrix`` and ``self.bound`` to exist.

        Args:
            class_type: object class, called as ``class_type(pos, idx, bound)``,
                or ``class_type(pos, idx, direction, bound)`` when ``has_vel``.
            idx: instance index forwarded to the constructor.
            offset: border margin used when sampling positions.
            max_adjacent: maximum number of occupied/border neighbours allowed.
            has_vel: if True, pass a random direction from ``DIRS``.

        Returns:
            The new object (also appended to its occupancy cell), or None if
            no acceptable cell was found in 1000 attempts.
        """
        obj = None
        occ = self.occupancy_matrix
        for _ in range(1000):
            pos = np.array([np.random.randint(0 + offset, self.num_rows - offset),
                            np.random.randint(0 + offset, self.num_columns - offset)])
            row, col = pos[0], pos[1]
            if len(occ[row][col]) != 0:
                continue
            if max_adjacent < 4: # check adjacent cells for objects
                # borders count as occupied; `or` short-circuits before out-of-range indexing
                total_adjacent = (int(col + 1 == self.num_columns or len(occ[row][col + 1]) > 0) +
                    int(col - 1 == -1 or len(occ[row][col - 1]) > 0) +
                    int(row + 1 == self.num_rows or len(occ[row + 1][col]) > 0) +
                    int(row - 1 == -1 or len(occ[row - 1][col]) > 0))
                if total_adjacent > max_adjacent:
                    continue
            if has_vel: obj = class_type(pos, idx, DIRS[np.random.randint(4)], self.bound)
            else: obj = class_type(pos, idx, self.bound)
            occ[row][col].append(obj)
            break
        return obj

    def _empty_occupancy(self):
        """Return a (num_rows, num_columns) object array whose cells are distinct empty lists."""
        # np.array over nested empty lists would build a numeric (rows, cols, 0)
        # array, so the grid is an explicit object array of independent lists.
        grid = np.empty((self.num_rows, self.num_columns), dtype=object)
        for i in range(self.num_rows):
            for j in range(self.num_columns):
                grid[i, j] = list()
        return grid

    def _occupy(self, obj):
        """Append ``obj`` to the occupancy cell at its ``pos``."""
        self.occupancy_matrix[int(obj.pos[0]), int(obj.pos[1])].append(obj)

    def reset_occupancy(self):
        """Rebuild ``self.occupancy_matrix`` from the current object positions.

        Each cell holds the list of objects whose ``pos`` is that cell, so
        several objects may share one cell.
        """
        self.occupancy_matrix = self._empty_occupancy()
        for obj in [self.taxi] + self.passengers + self.pedestrians + self.vehicles + self.targets:
            self._occupy(obj)

    def reset(self):
        """Start a new episode by re-creating every object at random positions.

        Single-instance object types are created with idx -1; multi-instance
        types get idx 0..n-1.

        Returns:
            The initial state dict produced by ``get_state``.
        """
        self.reward, self.done = Reward(), Done()
        self.action = Action()
        self.bound = Bound((self.num_rows, self.num_columns))
        self.taxi = Taxi(np.array([np.random.randint(self.num_rows), np.random.randint(self.num_columns)]))
        # generate_fill reads and writes the occupancy grid, so it must exist
        # (with the taxi in it) before any other object is placed.
        self.occupancy_matrix = self._empty_occupancy()
        self._occupy(self.taxi)
        self.targets = [self.generate_fill(Target, i) for i in range(self.num_targets)] if self.num_targets > 1 else [self.generate_fill(Target, -1)]
        # NOTE(review): the chosen target's position is passed positionally into
        # generate_fill's ``offset`` (border margin) parameter, which looks
        # unintended; confirm how Passenger is meant to receive its target.
        self.passengers = [self.generate_fill(Passenger, i, self.targets[np.random.randint(self.num_targets)].pos) for i in range(self.num_passengers)] if self.num_passengers > 1 else [self.generate_fill(Passenger, -1, self.targets[np.random.randint(self.num_targets)].pos)]
        self.pedestrians = [self.generate_fill(Pedestrian, i, has_vel=True) for i in range(self.num_pedestrians)] if self.num_pedestrians > 1 else [self.generate_fill(Pedestrian, -1, has_vel=True)]
        self.vehicles = [self.generate_fill(Vehicle, i, has_vel=True) for i in range(self.num_vehicles)] if self.num_vehicles > 1 else [self.generate_fill(Vehicle, -1, has_vel=True)]
        self.objects = [self.action] + [self.taxi] + self.passengers + self.pedestrians + self.vehicles + self.targets + [self.reward, self.done]
        self.reset_occupancy()
        self.object_name_dict = {"Action": self.action, "Taxi": self.taxi, "Reward": self.reward, "Done": self.done,
                                 **{p.name: p for p in self.passengers},
                                 **{p.name: p for p in self.pedestrians},
                                 **{v.name: v for v in self.vehicles},
                                 **{t.name: t for t in self.targets}}
        self.steps = 0
        return self.get_state()

    def clear_interactions(self):
        """Reset every object's ``interaction_trace`` to an empty list."""
        for obj in self.objects:
            obj.interaction_trace = list()

    def get_state(self, render=False):
        """Return ``{"raw_state": frame, "factored_state": {name: state}}``.

        Reward and Done are reported as one-element lists of their attribute.
        When ``render`` is True the frame is refreshed via ``self.render()``.
        """
        rdset = set(["Reward", "Done"])
        extracted_state = {**{obj.name: obj.get_state() for obj in self.objects if obj.name not in rdset}, **{"Done": [self.done.attribute], "Reward": [self.reward.attribute]}}
        if render: self.frame = self.render()
        return {"raw_state": self.frame, "factored_state": extracted_state}

    def render_occupancy(self, simple=False):
        """Render the occupancy grid as a (num_rows, num_columns, 3) float image.

        Channel 0: passenger 0.5 / taxi 1.0; channel 1: vehicle 0.5 /
        pedestrian 1.0; channel 2: target 1.0. When objects in one cell write
        the same channel, the later object in the cell's list wins.
        ``simple`` is currently unused.
        """
        self.frame = np.zeros((self.num_rows, self.num_columns, 3))
        for i in range(self.num_rows):
            for j in range(self.num_columns):
                # only the objects occupying this cell contribute to its colour
                for obj in self.occupancy_matrix[i, j]:
                    if type(obj) is Passenger:
                        self.frame[i, j, 0] = 0.5
                    elif type(obj) is Taxi:
                        self.frame[i, j, 0] = 1.0
                    elif type(obj) is Vehicle:
                        self.frame[i, j, 1] = 0.5
                    elif type(obj) is Pedestrian:
                        self.frame[i, j, 1] = 1.0
                    elif type(obj) is Target:
                        self.frame[i, j, 2] = 1.0
        return self.frame

    def step(self, action, render = False):
        """Advance the environment by one step with ``action``.

        Returns:
            ``(state, reward, done, info)`` where ``info["TimeLimit.truncated"]``
            is True when ``step_limit`` was reached. On a terminal step the
            environment resets itself; the returned reward/done still describe
            the step that just finished.
        """
        self.reward.attribute, self.done.attribute = 0.0, False
        self.clear_interactions()
        for _ in range(self.frameskip):
            self.action.step(action)
            self.taxi.prestep(self.action)
            for target in self.targets:
                target.step()
            for vehicle in self.vehicles:
                vehicle.step()
            self.taxi.step(self.vehicles, self.passengers, self.reward)
            for ped in self.pedestrians:
                ped.step(self.vehicles, self.taxi, self.reward)
            for passenger in self.passengers:
                passenger.step(self.taxi, self.targets, self.reward)
            for o in self.objects:
                o.update()
            self.reset_occupancy()
            self.steps += 1
        self.itr += 1
        trunc = self.steps == self.step_limit
        # NOTE(review): check_targets is not defined in this class; confirm the
        # Environment base class provides it.
        self.done.attribute = self.check_targets() or trunc
        state = self.get_state(render=render)

        # reset() replaces self.reward / self.done, so capture them first
        reward, done = self.reward.attribute, self.done.attribute
        if done: self.reset()
        return state, reward, done, {"TimeLimit.truncated": trunc}

    # TODO: set_from_factored_state is not implemented for this environment.

    def demonstrate(self):
        """Display the occupancy rendering and map a keypress to an action.

        Keys: w -> 0, s -> 1, a -> 2, d -> 3, q -> -1 (quit). Any other key,
        or no key within the 5000 ms wait, yields 0.
        """
        self.render()
        frame2 = self.render_occupancy()
        # cv2.resize takes dsize as (width, height), i.e. (columns, rows)
        frame2 = cv2.resize(frame2, (frame2.shape[1] * 30, frame2.shape[0] * 30), interpolation = cv2.INTER_NEAREST)
        cv2.imshow('frame2',frame2)
        key = cv2.waitKey(5000)
        key_actions = {ord('q'): -1, ord('a'): 2, ord('w'): 0, ord('s'): 1, ord('d'): 3}
        return key_actions.get(key, 0)


