import numpy as np
import pygame

from learned_planners import LP_DIR
from learned_planners.environments import BoxobanConfig
from learned_planners.interp.utils import parse_level

# Movement keys -> discrete action ids handed to ``env.step``. The ids must match
# the environment's action space (presumably 0=up, 1=down, 2=left, 3=right given
# the w/s/a/d layout — confirm against the env). Insertion order is w, d, s, a.
key_to_action = {
    pygame.K_w: 0,
    pygame.K_d: 3,
    pygame.K_s: 1,
    pygame.K_a: 2,
}


def _handle_event(event: pygame.event.Event, env, reset_opts) -> bool:
    """Apply a single pygame event to ``env``.

    Args:
        event: Event taken from ``pygame.event.get()``.
        env: Environment created by ``BoxobanConfig.make()``.
        reset_opts: Options produced by ``parse_level``; passed to ``env.reset``
            when ``x`` is pressed.

    Returns:
        False if the game loop should stop (window closed, ``q`` pressed, or a
        move ended the episode), True otherwise.
    """
    if event.type == pygame.QUIT:
        return False
    if event.type != pygame.KEYDOWN:
        return True  # key releases, mouse motion, etc. are ignored

    key = event.key
    if key == pygame.K_q:
        print("Quitting the game.")
        return False
    if key == pygame.K_r:
        print("Resetting the game.")
        # No options: does not load ``text_level`` (presumably samples another
        # level from the dataset — confirm against the env).
        env.reset()
    elif key == pygame.K_x:
        print("Reloading the game.")
        env.reset(options=reset_opts)
    elif key in key_to_action:
        _, _, terminated, truncated, _ = env.step(key_to_action[key])
        # Gymnasium step API: the episode is over on either flag and must not be
        # stepped again without a reset, so both end the game.
        return not (terminated or truncated)
    else:
        print(
            "Invalid key. Use w, a, s, d for movement, r to reset, "
            "x to reload the level, or q to quit."
        )
    return True


def _draw(screen: pygame.Surface, img: np.ndarray, size) -> None:
    """Scale a rendered frame to ``size`` and present it on ``screen``.

    Args:
        screen: Display surface returned by ``pygame.display.set_mode``.
        img: Frame from ``env.render()``, indexed (row, column, channel);
            presumably RGB — confirm against the env.
        size: Target ``(width, height)`` in pixels.
    """
    # numpy images are (row, col, ...) while pygame surfaces are (x, y, ...):
    # swap the first two axes before building the surface.
    surface = pygame.surfarray.make_surface(np.transpose(img, (1, 0, 2)))
    screen.blit(pygame.transform.scale(surface, size), (0, 0))
    pygame.display.flip()


def main(text_level: str, scale_factor: int = 5, fps: int = 10) -> None:
    """Play a Sokoban level interactively in a pygame window.

    Controls: ``w``/``a``/``s``/``d`` step the environment, ``r`` resets it
    without options, ``x`` reloads ``text_level``, and ``q`` (or closing the
    window) quits. The game also ends once a move terminates or truncates the
    episode; the final state is drawn before the window closes.

    Args:
        text_level: Level text; surrounding whitespace is stripped before it is
            passed to ``parse_level``.
        scale_factor: Integer magnification applied to the rendered frame.
        fps: Frame-rate cap for the event/render loop.

    Raises:
        ValueError: If ``scale_factor`` is smaller than 1.
    """
    if scale_factor < 1:
        raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")

    boxo_cfg = BoxobanConfig(cache_path=LP_DIR / "training/.sokoban_cache/")
    env = boxo_cfg.make()
    try:
        reset_opts = parse_level(text_level.strip())
        env.reset(options=reset_opts)

        pygame.init()
        clock = pygame.time.Clock()

        img = env.render()
        height, width = img.shape[0] * scale_factor, img.shape[1] * scale_factor  # type: ignore
        size = (width, height)
        screen = pygame.display.set_mode(size)
        pygame.display.set_caption("Sokoban")

        running = True
        while running:
            for event in pygame.event.get():
                # ``running`` is only ever cleared, so a later event in the same
                # batch (e.g. a move) cannot undo a quit request.
                if not _handle_event(event, env, reset_opts):
                    running = False
                    break  # drop remaining queued events; don't touch a finished env
            # Draw the latest state (including the final one) every frame.
            _draw(screen, env.render(), size)  # type: ignore
            clock.tick(fps)
    finally:
        # Release the env and the window even if something above raised.
        env.close()
        pygame.quit()


if __name__ == "__main__":
    # Demo puzzle for the interactive player. Characters are assumed to follow
    # standard Sokoban notation (# wall, @ player, $ box, . target) — confirm
    # against ``parse_level``. ``main`` strips the surrounding newlines.
    demo_level = """
##########
##########
##########
##########
#### ##  #
# .$  $. #
#  #   $ #
# $###  ##
#@..     #
##########
"""
    main(text_level=demo_level)
