import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

#==================================================
# 1) Mask function (only used for data area)
#==================================================
def mask_function(r, c, mask_id):
    """Evaluate one of the eight mask predicates at module (r, c).

    Args:
        r: Row index of the module.
        c: Column index of the module.
        mask_id: Which predicate to evaluate (0 through 7).

    Returns:
        True when the predicate holds (drawn as 1 = black), False otherwise
        (drawn as 0 = white).

    Raises:
        ValueError: If ``mask_id`` is not one of 0..7.
    """
    # Predicates that depend on a single coordinate or on the coordinate sum.
    if mask_id == 0:
        return (r + c) % 2 == 0
    if mask_id == 1:
        return r % 2 == 0
    if mask_id == 2:
        return c % 3 == 0
    if mask_id == 3:
        return (r + c) % 3 == 0
    if mask_id == 4:
        return ((r // 2) + (c // 3)) % 2 == 0

    # Predicates 5..7 all reuse the row*column product.
    if mask_id in (5, 6, 7):
        product = r * c
        if mask_id == 5:
            return (product % 2) + (product % 3) == 0
        if mask_id == 6:
            return ((product % 2) + (product % 3)) % 2 == 0
        return (((r + c) % 2) + (product % 3)) % 2 == 0

    raise ValueError("mask_id must be in [0..7].")


#==================================================
# 2) Format information module positions
#==================================================
# Format information is stored twice, 15 modules per copy, in a 25x25
# (version 2) symbol.  Every coordinate is (row, col).
#
# Copy 1 wraps around the top-left finder pattern; index 6 is skipped on both
# arms because that row/column carries the timing pattern.
top_left_format_info = {
    (8, 0), (8, 1), (8, 2), (8, 3), (8, 4), (8, 5), (8, 7), (8, 8),
    (0, 8), (1, 8), (2, 8), (3, 8), (4, 8), (5, 8), (7, 8),
}
# Copy 2 is split in two pieces: 8 modules below the top-right finder
# (row 8, cols 17..24) and 7 modules beside the bottom-left finder
# (col 8, rows 18..24).  (17, 8) is NOT part of it -- that cell is the fixed
# dark module -- and (25, 8) would be outside the 25x25 grid.
top_right_format_info = {
    (8, 24), (8, 23), (8, 22), (8, 21), (8, 20), (8, 19), (8, 18), (8, 17),
    (24, 8), (23, 8), (22, 8), (21, 8), (20, 8), (19, 8), (18, 8),
}

# The always-black "dark module" directly above the bottom-left format bits.
_DARK_MODULE = (17, 8)


def is_format_info_module(r, c):
    """Return True if (r, c) is a format information module.

    There are two copies of 15 modules each (30 modules in total).
    """
    return (r, c) in top_left_format_info or (r, c) in top_right_format_info


#==================================================
# 3) Functional modules (Finder/Timing/Alignment/Dark module)
#==================================================
def get_function_module_color(r, c, size=25):
    """Return the fixed color of a functional module, or None for other cells.

    Handled patterns, all laid out for the 25x25 (version 2) symbol:
    - Finder patterns (7x7 body + 1-module white separator) in the top-left,
      top-right and bottom-left corners.
    - Timing patterns on row 6 and column 6.
    - The 5x5 alignment pattern centered on (18, 18), i.e. rows/cols 16..20.
    - The single dark module at (17, 8).

    Args:
        r: Row index of the module.
        c: Column index of the module.
        size: Accepted for interface compatibility; the coordinates below are
            fixed to the 25x25 layout and this value is not consulted.

    Returns:
        1 (black) or 0 (white) for a functional module, None when (r, c) is
        not covered by any of the patterns above.
    """
    # Dark module: a lone always-black cell next to the bottom-left finder.
    if (r, c) == _DARK_MODULE:
        return 1

    # Top-left finder: separator is the 8th row/column of the local 8x8 area.
    if (0 <= r <= 7 and 0 <= c <= 7):
        return _get_finder_color(r, c)
    # Top-right finder: separator is on its left edge (col 17).
    if (0 <= r <= 7 and 17 <= c <= 24):
        if c == 17:
            return 0
        else:
            local_c = c - 18  # 0..6
            return _get_finder_color(r, local_c)
    # Bottom-left finder: separator is on its top edge (row 17).
    if (17 <= r <= 24 and 0 <= c <= 7):
        if r == 17:
            return 0
        else:
            local_r = r - 18  # 0..6
            return _get_finder_color(local_r, c)

    # Finder regions were handled above, so only the stretch between the
    # finders reaches this check.
    if r == 6 or c == 6:
        return _get_timing_color(r, c)

    if (16 <= r <= 20) and (16 <= c <= 20):
        return _get_alignment_color(r - 16, c - 16)

    return None


def _get_finder_color(local_r, local_c):
    """
    Assumes 7x7 main body + 1 module white separator (total 8x8).
    Returns black/white for the main body (0..6,0..6), and 7th row/column is the separator (white).
    """
    if local_r == 7 or local_c == 7:
        return 0
    if local_r in [0,6] or local_c in [0,6]:
        return 1
    if local_r in [1,5] or local_c in [1,5]:
        return 0
    return 1


def _get_timing_color(r, c):
    """
    Timing pattern:
      row=6 => black/white alternates by column parity
      col=6 => black/white alternates by row parity
    """
    if r == 6:
        return 1 if (c % 2 == 0) else 0
    if c == 6:
        return 1 if (r % 2 == 0) else 0
    return None


def _get_alignment_color(local_r, local_c):
    """
    Alignment pattern (5x5)
      - Outer frame (0,4) => black
      - Inner (1,3) => white
      - Center (2,2) => black
    """
    if local_r in [0,4] or local_c in [0,4]:
        return 1
    if local_r in [1,3] or local_c in [1,3]:
        return 0
    return 1


#==================================================
# 4) Pattern generation (functional + format info + masked data)
#==================================================
def generate_qr_pattern_v2(mask_id=0):
    """Build the 25x25 color-index map for one mask pattern.

    Cell values:
    - 3: format information module (drawn light blue)
    - 2: functional module that is black in the symbol (drawn gray)
    - 0: functional module that is white in the symbol
    - 1 / 0: data-area cell, black or white according to ``mask_function``

    Args:
        mask_id: Mask predicate to apply to the data area (0..7).

    Returns:
        A ``numpy.uint8`` array of shape (25, 25).
    """
    side = 25
    grid = np.zeros((side, side), dtype=np.uint8)

    for row in range(side):
        for col in range(side):
            # Format info takes priority over every other classification.
            if is_format_info_module(row, col):
                code = 3
            else:
                fixed = get_function_module_color(row, col, size=side)
                if fixed is None:
                    # Plain data cell: color comes from the mask predicate.
                    code = 1 if mask_function(row, col, mask_id) else 0
                elif fixed == 1:
                    code = 2
                else:
                    code = 0
            grid[row, col] = code

    return grid


if __name__ == "__main__":
    # Color index -> color: 0 white, 1 black, 2 light gray, 3 alice blue.
    palette = ListedColormap(["white", "black", "lightgray", "aliceblue"])

    # Render one figure per mask id and save each as an SVG file in the
    # current working directory.
    for mask_id in range(8):
        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(
            generate_qr_pattern_v2(mask_id=mask_id),
            cmap=palette,
            vmin=0,
            vmax=3,
            interpolation='nearest',
        )
        # ax.set_title(f"Version2 - Mask {mask_id}")
        ax.axis('off')

        fig.savefig(f"qr_v2_mask_{mask_id}.svg", dpi=150, bbox_inches='tight')

    # Figures are left open so they can be displayed by uncommenting:
    # plt.show()
