from pathlib import Path
from typing import Tuple

import numpy as np
import pygame
from PIL import Image


class Game:
    """Interactive pygame front-end that lets a human play an environment.

    The environment is stepped once per frame with the action chosen from the
    keyboard. Extra game-control keys reset the environment, pause or resume,
    single-step while paused, and forward mode/axis switches to the
    environment.
    """

    def __init__(self, play_env, keymap, action_names, size: Tuple[int, int], fps: int, verbose: bool, record_mode: bool) -> None:
        """Store the configuration and print the key bindings.

        Args:
            play_env: Environment used by ``run``. It must provide ``reset()``,
                ``step(action)``, ``next_mode()``, ``next_axis_1()``,
                ``prev_axis_1()``, ``next_axis_2()`` and ``prev_axis_2()``.
                The mode/axis methods return a truthy value when a reset is
                required.
            keymap: Mapping from pygame key code to action index.
            action_names: Indexable by action index; used only for the printed
                key legend.
            size: ``(height, width)`` of the game area in pixels.
            fps: Target frame rate passed to ``pygame.time.Clock.tick``.
            verbose: If true, ``run`` reserves a header strip above the game
                and draws ``info['header']`` into it.
            record_mode: Stored as ``self.record_mode``; not read by ``run``.
        """
        self.env = play_env
        self.keymap = keymap
        self.height, self.width = size  # note the (height, width) order
        self.fps = fps
        self.verbose = verbose
        self.record_mode = record_mode

        # Not read by ``run``; presumably used by recording code elsewhere.
        self.record_dir = Path('media') / 'recordings'

        print('\nActions (env):\n')
        for key, idx in keymap.items():
            print(f'{pygame.key.name(key)}: {action_names[idx]}')

        print('\nActions (game control):\n')
        print('⏎ : reset env')
        print('. : pause/unpause')
        print('e : step-by-step (when paused)')
        print('m : next mode')
        print('↑ : next axis 1')
        print('↓ : prev axis 1')
        print('→ : next axis 2')
        print('← : prev axis 2')

    def run(self) -> None:
        """Open the window and play until it is closed.

        Each frame does the following:

        1. Process pygame events.
        2. Pick an action: a mapped key pressed this frame, otherwise any
           mapped key currently held, otherwise 0 (noop).
        3. Step the environment.
        4. Draw the observation (and the header when ``verbose``).
        5. Throttle to ``self.fps``.

        The environment is reset automatically when an episode ends or is
        truncated. pygame is always shut down on exit, including when an
        exception propagates.
        """
        pygame.init()
        try:
            header_height = 150 if self.verbose else 0
            font_size = 18
            screen = pygame.display.set_mode((self.width, self.height + header_height))
            clock = pygame.time.Clock()
            font = pygame.font.SysFont('mono', font_size)
            header_rect = pygame.Rect(0, 0, self.width, header_height)

            def clear_header():
                """Fill the header strip with black and draw a white 1px border."""
                pygame.draw.rect(screen, pygame.Color('black'), header_rect)
                pygame.draw.rect(screen, pygame.Color('white'), header_rect, 1)

            def draw_text(text, idx_line, idx_column=0):
                """Render one line of header text.

                The header is laid out as a grid of four columns of width
                ``width // 4`` and rows of ``font_size`` pixels.
                """
                pos = (5 + idx_column * int(self.width // 4), 5 + idx_line * font_size)
                assert (0 <= pos[0] <= self.width) and (0 <= pos[1] <= header_height)
                screen.blit(font.render(text, True, pygame.Color('white')), pos)

            def draw_game(obs):
                """Blit a single observation below the header, scaled to the game area.

                Assumptions (TODO confirm):
                - ``obs`` is a torch-style tensor of shape (1, C, H, W).
                - Its values are presumably in [-1, 1]; they are mapped
                  linearly to [0, 255].
                """
                assert obs.ndim == 4 and obs.size(0) == 1
                # Clamp before the uint8 cast so out-of-range values saturate
                # instead of overflowing; in-range values are unaffected.
                pixels = obs[0].clamp(-1, 1).add(1).div(2).mul(255).byte()
                img = Image.fromarray(pixels.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC
                resized = img.resize((self.width, self.height), resample=Image.NEAREST)
                # pygame.surfarray indexes arrays as (x, y, channel), so swap H and W.
                pygame_image = np.array(resized).transpose((1, 0, 2))
                surface = pygame.surfarray.make_surface(pygame_image)
                screen.blit(surface, (0, header_height))

            obs, info = self.env.reset()

            do_reset, do_wait = False, False
            should_stop = False

            while not should_stop:

                do_one_step = False
                action = 0  # noop
                pygame.event.pump()

                for event in pygame.event.get():

                    if event.type == pygame.QUIT:
                        should_stop = True

                    if event.type != pygame.KEYDOWN:
                        continue

                    if event.key == pygame.K_RETURN:
                        do_reset = True

                    if event.key == pygame.K_PERIOD:
                        do_wait = not do_wait
                        print('Game paused.' if do_wait else 'Game resumed.')

                    if event.key == pygame.K_e:
                        do_one_step = True

                    # The env controls report whether a reset is needed. Only ever
                    # *set* the flag, so a reset already requested this frame
                    # (e.g. by Enter) is not cancelled by a falsy return value.
                    if event.key == pygame.K_m and self.env.next_mode():
                        do_reset = True

                    if event.key == pygame.K_UP and self.env.next_axis_1():
                        do_reset = True

                    if event.key == pygame.K_DOWN and self.env.prev_axis_1():
                        do_reset = True

                    if event.key == pygame.K_RIGHT and self.env.next_axis_2():
                        do_reset = True

                    if event.key == pygame.K_LEFT and self.env.prev_axis_2():
                        do_reset = True

                    if event.key in self.keymap:
                        action = self.keymap[event.key]

                if should_stop:
                    break  # window closed: don't step or draw another frame

                if action == 0:
                    # No mapped key went down this frame: fall back to the first
                    # mapped key that is currently held, else noop.
                    pressed = pygame.key.get_pressed()
                    action = next((act for key, act in self.keymap.items() if pressed[key]), 0)

                if do_reset:
                    obs, info = self.env.reset()
                    do_reset = False

                if do_wait and not do_one_step:
                    # Keep throttling while paused; otherwise this loop spins at 100% CPU.
                    clock.tick(self.fps)
                    continue

                next_obs, _, end, trunc, info = self.env.step(action)

                # Draw the observation the action was applied to; next_obs is
                # shown on the following frame.
                draw_game(obs)

                if self.verbose and info is not None:
                    clear_header()
                    assert isinstance(info, dict) and 'header' in info
                    header = info['header']
                    # header is a sequence of columns, each a sequence of text rows.
                    for j, col in enumerate(header):
                        for i, row in enumerate(col):
                            draw_text(row, idx_line=i, idx_column=j)

                pygame.display.flip()   # update screen
                clock.tick(self.fps)    # ensures game maintains the given frame rate

                if end or trunc:
                    print('Dead.' if end else 'Time limit reached.')
                    obs, info = self.env.reset()
                else:
                    obs = next_obs
        finally:
            # Always release the display, even if the env or drawing raises.
            pygame.quit()
