"""CHIP-8 emulator constants and static data."""

import jax.numpy as jnp

# --- Memory layout ----------------------------------------------------------
PROGRAM_START = 0x200  # first address of the program image
FONT_START = 0x50  # base address for FONT_DATA (presumably copied there at reset; confirm)
ADDRESS_MASK = (1 << 12) - 1  # == 0xFFF; clamps addresses to 12 bits

# --- Display ----------------------------------------------------------------
SCREEN_WIDTH, SCREEN_HEIGHT = 64, 32  # framebuffer size in pixels

# --- Subroutines ------------------------------------------------------------
STACK_SIZE = 16  # number of call-stack slots (assumed to hold return addresses)

# Built-in hexadecimal digit sprites, glyphs 0 through F. Each glyph is five
# bytes (one per row). Every byte has a zero low nibble, so the visible pixels
# occupy only the high four bits of each row.
FONT_DATA = jnp.array(
    [
        [0xF0, 0x90, 0x90, 0x90, 0xF0],  # 0
        [0x20, 0x60, 0x20, 0x20, 0x70],  # 1
        [0xF0, 0x10, 0xF0, 0x80, 0xF0],  # 2
        [0xF0, 0x10, 0xF0, 0x10, 0xF0],  # 3
        [0x90, 0x90, 0xF0, 0x10, 0x10],  # 4
        [0xF0, 0x80, 0xF0, 0x10, 0xF0],  # 5
        [0xF0, 0x80, 0xF0, 0x90, 0xF0],  # 6
        [0xF0, 0x10, 0x20, 0x40, 0x40],  # 7
        [0xF0, 0x90, 0xF0, 0x90, 0xF0],  # 8
        [0xF0, 0x90, 0xF0, 0x10, 0xF0],  # 9
        [0xF0, 0x90, 0xF0, 0x90, 0x90],  # A
        [0xE0, 0x90, 0xE0, 0x90, 0xE0],  # B
        [0xF0, 0x80, 0x80, 0x80, 0xF0],  # C
        [0xE0, 0x90, 0x90, 0x90, 0xE0],  # D
        [0xF0, 0x80, 0xF0, 0x80, 0xF0],  # E
        [0xF0, 0x80, 0xF0, 0x80, 0x80],  # F
    ],
    dtype=jnp.uint8,
).reshape(-1)  # flatten the (16, 5) glyph table into one contiguous 80-byte run

class InstructionType:
    """Opcode groups for CHIP-8 instructions, keyed by the first nibble.

    Each attribute holds the value of the most significant nibble of an
    instruction word. The sixteen attributes cover 0x0 through 0xF, each
    exactly once. The trailing comments give each group's operand layout in
    the usual CHIP-8 notation: nnn is a 12-bit address, kk is a byte, x and y
    are register indices, and n is a nibble.
    """

    SYS: int = 0x0  # 0nnn (also 00E0 / 00EE)
    JP: int = 0x1  # 1nnn
    CALL: int = 0x2  # 2nnn
    SE_BYTE: int = 0x3  # 3xkk
    SNE_BYTE: int = 0x4  # 4xkk
    SE_REG: int = 0x5  # 5xy0
    LD_BYTE: int = 0x6  # 6xkk
    ADD_BYTE: int = 0x7  # 7xkk
    ALU: int = 0x8  # 8xyn; low nibble selects the register-register operation
    SNE_REG: int = 0x9  # 9xy0
    LD_I: int = 0xA  # Annn
    JP_V0: int = 0xB  # Bnnn
    RND: int = 0xC  # Cxkk
    DRW: int = 0xD  # Dxyn
    KEY: int = 0xE  # Ex9E / ExA1
    MISC: int = 0xF  # Fxkk; low byte selects the operation