# Screen
import sys, cv2
import numpy as np
import imageio as imio
import os, copy
from Environment.environment import Environment, Reward, Done
from Environment.Environments.Sokoban.sokoban_objects import *
from Environment.Environments.Sokoban.sokoban_specs import *
from Record.file_management import numpy_factored
from gym import spaces

class Sokoban(Environment):
    def __init__(self, frameskip = 1, variant="", fixed_limits=False):
        """Build a Sokoban environment for the given variant and reset it.

        Args:
            frameskip: accepted by the signature but not used; ``self.frameskip``
                is always set to 1 below.
            variant: key into ``sokoban_variants``; the entry supplies the board
                size, object counts, step limit and preset.
            fixed_limits: when True, ``object_range`` and ``object_dynamics`` are
                taken from ``generate_specs_fixed`` instead of ``generate_specs``.
        """
        super(Sokoban, self).__init__()
        # Sokoban-specific parameters are looked up through the variant key
        self.variant = variant
        self.self_reset = True
        self.fixed_limits = fixed_limits

        # environment properties
        self.num_actions = 4 # this must be defined, -1 for continuous. Only needed for primitive actions
        self.name = "Sokoban" # required for an environment
        self.discrete_actions = True
        self.frameskip = 1 # no frameskip: the constructor argument is ignored

        # unpack the 7-tuple describing this variant
        self.num_rows, self.num_columns, self.num_blocks, self.num_obstacles, self.num_targets, self.step_limit, self.preset = sokoban_variants[self.variant]

        # spaces
        self.action_shape = (1,)
        self.action_space = spaces.Discrete(self.num_actions) # gym.spaces
        # RNG is not defined in this file; presumably it comes from one of the star imports -- TODO confirm
        self.observation_space = spaces.Box(low=0, high=255, shape=(RNG, RNG), dtype=np.uint8) # raw space, gym.spaces
        self.seed_counter = -1

        # state components
        self.frame = None # the image generated by the environment
        self.reward = Reward()
        self.done = Done()
        self.action = np.zeros(self.action_shape) # placeholder; reset() may replace it with an Action object
        self.extracted_state = None
        self.steps = 0

        # running values
        self.itr = 0
        self.total_score = 0

        # factorized state properties
        # NOTE: position_masks and instanced from the fixed specs are overwritten by the next line
        ranges_fixed, dynamics_fixed, position_masks, instanced = generate_specs_fixed(self.num_obstacles, self.num_blocks, self.num_targets)
        ranges, dynamics, position_masks, instanced = generate_specs(self.num_rows, self.num_columns, self.num_obstacles, self.num_blocks, self.num_targets)
        self.object_names = ["Action", "Pusher", "Obstacle", "Block", "Target", "Done", "Reward"]
        self.object_sizes = {"Action": 1, "Pusher": 2, "Obstacle": 2, "Block": 2, "Target": 3, "Reward": 1, "Done": 1}
        self.object_name_dict = dict() # initialized in reset
        self.object_range = ranges if not self.fixed_limits else ranges_fixed # the minimum and maximum values for a given feature of an object
        self.object_dynamics = dynamics if not self.fixed_limits else dynamics_fixed
        self.object_range_true = ranges
        self.object_dynamics_true = dynamics
        self.object_instanced = instanced
        self.position_masks = position_masks
        # classes with more than one instance get numbered names (e.g. "Block0", "Block1"); others keep the bare class name
        self.all_names = sum([([name + str(i) for i in range(instanced[name])] if instanced[name] > 1 else [name]) for name in self.object_names], start = [])
        self.instance_length = len(self.all_names)

        # proximity components
        self.length, self.width = self.num_rows, self.num_columns

        # reset counters (also builds the first board and the object name dictionary)
        self.reset()

    def generate_fill(self, class_type, idx, offset=0, max_adjacent=4):
        obj = None
        for i in range(1000):
            pos = np.array([np.random.randint(0+offset, self.num_rows - offset), np.random.randint(0+offset, self.num_columns -offset)])
            if self.occupancy_matrix[pos[0]][pos[1]] is None:
                if max_adjacent < 4: # check adjacent cells for objects
                    total_adjacent = (int(pos[1] + 1 == self.num_columns or self.occupancy_matrix[pos[0]][pos[1]+1] is not None) + 
                        int(pos[1] - 1 == -1 or self.occupancy_matrix[pos[0]][pos[1]-1] is not None) +
                        int(pos[0] + 1 == self.num_rows or self.occupancy_matrix[pos[0] + 1][pos[1]] is not None) +
                        int(pos[0] - 1 == -1 or self.occupancy_matrix[pos[0]-1][pos[1]] is not None))
                    if total_adjacent > max_adjacent:
                        continue
                obj = class_type(pos, idx, self.bound)
                self.occupancy_matrix[pos[0]][pos[1]] = obj
                # print(class_type, pos)
                break
        return obj

    def reset_occupancy(self): # resets the occupancy matrix assuming that the obstacles, blocks, and pusher and non-overlapping
        self.occupancy_matrix = [[None for i in range(self.num_columns)] for j in range(self.num_rows)]
        self.occupancy_matrix[self.pusher.pos[0]][self.pusher.pos[1]] = self.pusher
        for obs in self.obstacles:
            self.occupancy_matrix[obs.pos[0]][obs.pos[1]] = obs
        for blk in self.blocks:
            self.occupancy_matrix[blk.pos[0]][blk.pos[1]] = blk
        for tar in self.targets:
            if self.occupancy_matrix[tar.pos[0]][tar.pos[1]] != None:
                self.occupancy_matrix[tar.pos[0]][tar.pos[1]] = (self.occupancy_matrix[tar.pos[0]][tar.pos[1]], tar)

    def reset(self):
        """Start a new episode and return its initial state.

        When ``self.preset`` is non-empty the board is loaded through
        ``self.load_sokoban``. Otherwise a random board is generated: the pusher
        is placed first, then obstacles, blocks and targets are placed with
        ``generate_fill``. If any placement fails (``generate_fill`` returned
        None) the whole reset is retried.

        A class with a single instance is generated with index -1 and renamed
        to the bare class name ("Obstacle", "Block", "Target").

        Returns:
            The state dictionary produced by ``get_state()``.
        """
        self.reward, self.done = Reward(), Done()
        if len(self.preset) > 0: # we are selecting resets from a preset group
            self.load_sokoban(self.preset)
        else: # load a random stage based on the inputs
            self.action = Action()
            self.bound = Bound((self.num_rows, self.num_columns))
            self.occupancy_matrix = [[None for i in range(RNG)] for j in range(RNG)]
            self.pusher = Pusher(np.array([np.random.randint(self.num_rows), np.random.randint(self.num_columns)]), self.bound)
            self.occupancy_matrix[self.pusher.pos[0]][self.pusher.pos[1]] = self.pusher

            def fill_group(class_type, count, **placement):
                # at least one instance is always generated; a lone instance uses index -1
                indices = range(count) if count > 1 else (-1,)
                return [self.generate_fill(class_type, i, **placement) for i in indices]

            self.obstacles = fill_group(Obstacle, self.num_obstacles)
            self.blocks = fill_group(Block, self.num_blocks, offset=1, max_adjacent=1)
            self.targets = fill_group(Target, self.num_targets, max_adjacent=3)
            if any(obj is None for obj in self.obstacles + self.blocks + self.targets):
                # we could not generate a functional occupancy; retry before touching
                # any attribute of the (possibly None) generated objects
                return self.reset()
            for group, base_name in ((self.obstacles, "Obstacle"), (self.blocks, "Block"), (self.targets, "Target")):
                if len(group) == 1:
                    group[0].name = base_name
            self.objects = [self.action] + [self.pusher] + self.obstacles + self.blocks + self.targets + [self.reward, self.done]
        self.object_name_dict = {**{"Action": self.action, "Bound": self.bound, "Pusher": self.pusher, "Reward": self.reward, "Done": self.done},
                                **{o.name: o for o in self.obstacles}, **{b.name: b for b in self.blocks}, **{t.name: t for t in self.targets}}
        self.steps = 0
        return self.get_state()

    def clear_interactions(self):
        for obj in self.objects:
            obj.interaction_trace = list()

    def get_state(self, render=False):
        rdset = set(["Reward", "Done"])
        extracted_state = {**{obj.name: obj.get_state() for obj in self.objects if obj.name not in rdset}, **{"Done": [self.done.attribute], "Reward": [self.reward.attribute]}}
        if render: self.frame = self.render()
        return {"raw_state": self.frame, "factored_state": extracted_state}

    def render(self, simple=False):
        self.frame = np.zeros((self.num_rows, self.num_columns))
        self.frame[tuple(self.pusher.pos.astype(int))] = .5
        for target in self.targets:
            self.frame[tuple(target.pos.astype(int))] = .4            
        if type(self.occupancy_matrix[int(self.pusher.pos[0])][int(self.pusher.pos[1])]) == tuple:
            self.frame[tuple(self.pusher.pos.astype(int))] = 0.6
        for obstacle in self.obstacles:
            self.frame[tuple(obstacle.pos.astype(int))] = .2
        for block in self.blocks:
            self.frame[tuple(block.pos.astype(int))] = .8
            if type(self.occupancy_matrix[int(block.pos[0])][int(block.pos[1])]) == tuple:
                self.frame[tuple(block.pos.astype(int))] = 1.0
        return self.frame

    def render_occupancy(self, simple=False):
        """Draw the board from the occupancy matrix into ``self.frame`` and return it.

        Shades: tuple whose first element is a Block 1.0; any other tuple on the
        pusher's cell 0.6; Pusher 0.5; Block 0.8; Target 0.4; Obstacle 0.2.
        Cells matching none of these stay 0. ``simple`` is unused.
        """
        single_shades = ((Pusher, 0.5), (Block, 0.8), (Target, 0.4), (Obstacle, 0.2))
        canvas = self.frame = np.zeros((self.num_rows, self.num_columns))
        for i in range(self.num_rows):
            for j in range(self.num_columns):
                occupant = self.occupancy_matrix[i][j]
                if type(occupant) == tuple:
                    # a block sharing the cell takes precedence over the pusher shade
                    if type(occupant[0]) == Block:
                        canvas[i][j] = 1.0
                    elif self.pusher.pos[0] == i and self.pusher.pos[1] == j:
                        canvas[i][j] = 0.6
                    continue
                for kind, shade in single_shades:
                    if type(occupant) == kind:
                        canvas[i][j] = shade
        return canvas


    def update_occupancy_matrix(self, old_pusher, old_block, new_block):
        """Apply one step's moves to the occupancy matrix in place.

        Args:
            old_pusher: array with the pusher's position before the move.
            old_block: array with the moved block's previous position; only
                read when ``new_block`` is not None.
            new_block: the block that moved this step (its ``pos`` is the new
                position), or None when no block moved.

        The block is relocated first, then the pusher. A vacated cell that held
        a tuple keeps the tuple's second element (the target).
        """
        grid = self.occupancy_matrix

        def vacate(row, col):
            # leaving a shared cell keeps the target behind, otherwise the cell empties
            leaving = grid[row][col]
            grid[row][col] = leaving[1] if type(leaving) == tuple else None

        old_pusher = old_pusher.astype(int)
        if new_block is not None: # move the block if necessary
            old_block = old_block.astype(int)
            vacate(old_block[0], old_block[1])
            dest_row, dest_col = int(new_block.pos[0]), int(new_block.pos[1])
            underneath = grid[dest_row][dest_col]
            # anything already at the destination should be a target
            grid[dest_row][dest_col] = new_block if underneath is None else (new_block, underneath)

        # move the pusher if necessary
        vacate(old_pusher[0], old_pusher[1])
        row, col = int(self.pusher.pos[0]), int(self.pusher.pos[1])
        standing_on = grid[row][col]
        grid[row][col] = (self.pusher, standing_on) if type(standing_on) == Target else self.pusher

    def check_targets(self):
        total = 0
        for target in self.targets:
            if target.attribute == 1:
                total += 1
        return total == len(self.targets)


    def step(self, action, render = False):
        self.reward.attribute, self.done.attribute = 0.0, False
        self.clear_interactions()
        for i in range(self.frameskip):
            self.action.step(action)
            self.pusher.step(self.action, self.occupancy_matrix)
            new_block, old_block = None, None
            for block in self.blocks:
                block.step(self.pusher, self.occupancy_matrix)
                old, moved = block.update()
                if moved is not None:
                    old_block = old
                    new_block = moved
            for target in self.targets:
                target.step(self.occupancy_matrix)
                target.update()
            old_push = self.pusher.update()
            self.update_occupancy_matrix(old_push, old_block, new_block)
            self.steps += 1
        self.itr += 1
        self.done.attribute = self.check_targets()
        self.reward.attribute = int(self.done.attribute) # get one reward if done
        self.done.attribute = self.done.attribute or self.steps == self.step_limit
        trunc = self.steps == self.step_limit
        state = self.get_state(render=render)

        if self.done.attribute: self.reset()
        # print("end", self.occupancy_matrix[old_push[0]][old_push[1]])
        return state, self.reward.attribute, self.done.attribute, {"TimeLimit.truncated": trunc}

    def set_from_factored_state(self, factored_state, seed_counter=-1, render=False):
        """Overwrite object positions and attributes from a factored state.

        Positions of the pusher, obstacles, blocks and targets are read from
        ``factored_state`` by object name and cast to int; a target entry's
        third value becomes its ``attribute``. The "Reward" and "Done" entries
        may be arrays or plain sequences (``get_state`` emits one-element
        lists), so they are converted with ``np.asarray`` before squeezing.
        The occupancy matrix is then rebuilt with ``reset_occupancy``.

        Args:
            factored_state: mapping from object name to its feature values.
            seed_counter: unused.
            render: unused.
        """
        self.pusher.pos = np.array(factored_state["Pusher"]).astype(int)
        for solid in (*self.obstacles, *self.blocks):
            solid.pos = np.array(factored_state[solid.name]).astype(int)
        for target in self.targets:
            features = factored_state[target.name]
            target.pos = np.array(features[:2]).astype(int)
            target.attribute = features[2]
        self.reward.attribute = np.asarray(factored_state["Reward"]).squeeze()
        self.done.attribute = np.asarray(factored_state["Done"]).squeeze()
        self.reset_occupancy()

    def demonstrate(self):
        """Show the occupancy rendering and read one keyboard action.

        The board is upscaled 30x with nearest-neighbour interpolation and
        displayed for up to 5 seconds. ``cv2.resize`` takes its target size as
        (width, height), i.e. (columns, rows), so non-square boards keep their
        aspect ratio.

        Returns:
            -1 for 'q', 0 for 'w', 1 for 's', 2 for 'a', 3 for 'd'; 0 when no
            key or any other key was pressed.
        """
        key_actions = {ord('q'): -1, ord('w'): 0, ord('s'): 1, ord('a'): 2, ord('d'): 3}
        frame2 = self.render_occupancy()
        rows, columns = frame2.shape[:2]
        frame2 = cv2.resize(frame2, (columns * 30, rows * 30), interpolation = cv2.INTER_NEAREST)
        cv2.imshow('frame2',frame2)
        key = cv2.waitKey(5000)
        # keep only the low byte: some backends set higher bits on the key code;
        # the -1 timeout value maps to 255, which is not a bound key
        return key_actions.get(key & 0xFF, 0)


