import gym
import numpy as np
from gym.spaces import Box
from IPython import embed
import os
import matplotlib.pyplot as plt

class ProcGenMazeEnv(gym.Env):
    """Griddly maze environment whose level can be regenerated on every reset.

    Builds the Griddly environment described by ``envs/maze.yaml`` (resolved
    relative to the current working directory) and, by default, loads a new
    level produced by ``level_generator.generate()`` on each :meth:`reset`.

    Args:
        level_generator: Object with a ``generate()`` method returning a
            Griddly level string.
        config: Mapping with at least ``"width"`` and ``"height"`` keys, used
            to size the declared observation space.
        max_steps: Step limit forwarded to ``gym.make`` for the Griddly env.
        pixel_obs: If truthy, the player observer is ``SPRITE_2D``; otherwise
            it is ``VECTOR``.
        **kwargs: Extra keyword arguments forwarded to ``gym.make``.
    """

    def __init__(self, level_generator, config, max_steps, pixel_obs=False, **kwargs):
        super().__init__()

        # Imported outside the try so a missing Griddly install surfaces as an
        # ImportError instead of a later NameError on ``gd``.
        from griddly import GymWrapperFactory, gd
        from gym.envs.registration import register as gym_register

        try:
            wrapper = GymWrapperFactory()
            wrapper.build_gym_from_yaml('_MazeEnv', f"{os.getcwd()}/envs/maze.yaml")

            gym_register(
                id='GDY-MazeEnv-v0',
                entry_point='envs.maze_proc_gen:ProcGenMazeEnv'
            )
        except Exception:
            # Best effort: presumably these ids are already registered by an
            # earlier instance (some gym versions raise on re-registration).
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            pass

        if pixel_obs:
            player_obs_type = gd.ObserverType.SPRITE_2D
        else:
            player_obs_type = gd.ObserverType.VECTOR

        env = gym.make("GDY-_MazeEnv-v0",
                       level=0,
                       player_observer_type=player_obs_type,
                       global_observer_type=gd.ObserverType.SPRITE_2D,
                       max_steps=max_steps,
                       **kwargs)

        self.pixel_obs = pixel_obs
        self.level_generator = level_generator
        self.config = config
        self.ep_reward = 0  # running sum of rewards in the current episode
        self.env = env

        # NOTE(review): this declared space is the same regardless of
        # pixel_obs; confirm it matches the observer's actual output.
        self.observation_space = Box(0, 3, shape=(1, config["width"], config["height"]))
        self.action_space = env.action_space

    def step(self, action):
        """Advance the wrapped env by one action and accumulate the episode reward.

        Returns:
            The ``(obs, reward, done, info)`` tuple from the wrapped env, unchanged.
        """
        obs, r, d, info = self.env.step(action)
        # np.sum keeps this safe if the reward arrives as a per-player
        # sequence rather than a scalar (not confirmed for this maze).
        self.ep_reward += float(np.sum(r))
        return obs, r, d, info

    def reset(self, random_gen=True, level_string=None):
        """Start a new episode and return the initial observation.

        Args:
            random_gen: If truthy and no ``level_string`` is given, load a
                level freshly produced by ``self.level_generator.generate()``.
            level_string: Explicit Griddly level to load. It takes precedence
                over ``random_gen``; previously it was silently ignored unless
                ``random_gen=False`` was also passed.

        Returns:
            The observation returned by the wrapped env's ``reset``.
        """
        if level_string is not None:
            obs = self.env.reset(level_string=level_string)
        elif random_gen:
            obs = self.env.reset(level_string=self.level_generator.generate())
        else:
            obs = self.env.reset()

        self.ep_reward = 0
        return obs

    def render(self, mode="rgb_array"):
        """Render through the Griddly global observer (created as SPRITE_2D)."""
        return self.env.render(mode, observer="global")
    
class ProcGenPOMazeEnv(gym.Wrapper):
    """Partially observable Griddly maze, implemented as a ``gym.Wrapper``.

    Builds the Griddly environment described by ``envs/maze_po.yaml``
    (resolved relative to the current working directory), wraps it, and by
    default loads a new level from ``level_generator.generate()`` on every
    :meth:`reset`.

    Args:
        level_generator: Object with a ``generate()`` method returning a
            Griddly level string.
        config: Level-generation config; stored on the instance.
        max_steps: Step limit forwarded to ``gym.make`` for the Griddly env.
        pixel_obs: If truthy, the player observer is ``SPRITE_2D``; otherwise
            it is ``VECTOR``.
        **kwargs: Extra keyword arguments forwarded to ``gym.make``.
    """

    def __init__(self, level_generator, config, max_steps, pixel_obs=False, **kwargs):
        # Imported outside the try so a missing Griddly install surfaces as an
        # ImportError instead of a later NameError on ``gd``.
        from griddly import GymWrapperFactory, gd
        from gym.envs.registration import register as gym_register

        try:
            wrapper = GymWrapperFactory()
            wrapper.build_gym_from_yaml('_POMazeEnv', f"{os.getcwd()}/envs/maze_po.yaml")
            gym_register(
                id='GDY-POMazeEnv-v0',
                entry_point='envs.maze_proc_gen:ProcGenPOMazeEnv'
            )
        except Exception:
            # Best effort: presumably these ids are already registered by an
            # earlier instance (some gym versions raise on re-registration).
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            pass

        if pixel_obs:
            player_obs_type = gd.ObserverType.SPRITE_2D
        else:
            player_obs_type = gd.ObserverType.VECTOR

        env = gym.make("GDY-_POMazeEnv-v0",
                       level=0,
                       player_observer_type=player_obs_type,
                       global_observer_type=gd.ObserverType.SPRITE_2D,
                       max_steps=max_steps,
                       **kwargs)

        super(ProcGenPOMazeEnv, self).__init__(env)

        self.pixel_obs = pixel_obs
        self.level_generator = level_generator
        self.config = config
        # NOTE(review): hard-coded (1, 5, 5); presumably the agent's local
        # view size from maze_po.yaml — confirm, including for pixel_obs.
        self.observation_space = Box(0, 3, shape=(1, 5, 5))
        self.ep_reward = 0  # running sum of rewards in the current episode

    def step(self, action):
        """Advance the wrapped env by one action and accumulate the episode reward.

        Returns:
            The ``(obs, reward, done, info)`` tuple from the wrapped env, unchanged.
        """
        obs, r, d, info = super().step(action)
        # np.sum keeps this safe if the reward arrives as a per-player
        # sequence rather than a scalar (not confirmed for this maze).
        self.ep_reward += float(np.sum(r))
        return obs, r, d, info

    def reset(self, random_gen=True, level_string=None):
        """Start a new episode and return the initial observation.

        Args:
            random_gen: If truthy and no ``level_string`` is given, load a
                level freshly produced by ``self.level_generator.generate()``.
            level_string: Explicit Griddly level to load. It takes precedence
                over ``random_gen``; previously it was silently ignored unless
                ``random_gen=False`` was also passed.

        Returns:
            The observation returned by the wrapped env's ``reset``.
        """
        if level_string is not None:
            obs = super().reset(level_string=level_string)
        elif random_gen:
            obs = super().reset(level_string=self.level_generator.generate())
        else:
            obs = super().reset()

        self.ep_reward = 0
        return obs

    def render(self, mode="rgb_array"):
        """Render through the Griddly global observer (created as SPRITE_2D)."""
        return super().render(mode, observer="global")
    
if __name__ == "__main__":
    # Demo: build the partially observable maze and load one generated level.
    from countbased.envs.maze_level_generator import LabyrinthLevelGenerator

    config = {
        'width': 16,
        'height': 16,
        'wall_density': 0.8,
        'num_goals': 0
    }

    # max_steps is a required constructor argument; the original call omitted
    # it and raised TypeError. 500 is an arbitrary demo value.
    max_steps = 500

    level_generator = LabyrinthLevelGenerator(config)
    env = ProcGenPOMazeEnv(level_generator, config, max_steps)
    # random_gen=False makes sure the explicit level_string is the one loaded.
    obs = env.reset(random_gen=False, level_string=level_generator.generate())
    env.close()