import matplotlib.pyplot as plt
import os.path


def plot_board():
    """Draw an empty 19x19 board and save it as ``empty_board.png``.

    The image is written into the directory named by the module-level
    ``save_dir`` variable, which must be defined before this is called.
    """
    fig, ax = plt.subplots()
    ax.set_facecolor("#FAE794")

    # Nine marker dots at every combination of lines 3, 9 and 15.
    marker_lines = (3, 9, 15)
    for x in marker_lines:
        for y in marker_lines:
            dot = plt.Circle((x, y), radius=0.15, color='black')
            ax.add_artist(dot)

    line_positions = range(19)
    ax.set_xticks(line_positions)
    ax.set_yticks(line_positions)

    # Column letters a-t with 'i' left out; row numbers 1-19 in ascending
    # order (they are not reversed).
    column_letters = list("abcdefghjklmnopqrst")
    row_numbers = [str(n) for n in range(1, 20)]
    ax.set_xticklabels(column_letters)
    ax.set_yticklabels(row_numbers)

    ax.grid(color='grey', linestyle='-', linewidth=0.5, clip_on=False, zorder=0)

    # Leave one unit of margin around the 0..18 grid.
    ax.set_xlim(-1, 19)
    ax.set_ylim(-1, 19)

    fig.savefig(os.path.join(save_dir, "empty_board.png"))
    plt.close(fig)


def _parse_coord(col_row):
    """Convert a coordinate such as ``"q16"`` into ``(row, col)`` indices.

    The first character is the column letter and the remainder is the
    1-based row number.  Letters before ``'i'`` map to ``ord(letter) - 97``;
    every other letter maps to ``ord(letter) - 98`` because the column
    lettering skips ``'i'``.

    Raises:
        IndexError: if ``col_row`` is empty.
        ValueError: if the part after the letter is not an integer.
    """
    letter = col_row[0]
    offset = 97 if letter < 'i' else 98
    return int(col_row[1:]) - 1, ord(letter) - offset


def visualize(move_str, foreground, filename):
    """Render a board position with numbered foreground stones as an SVG.

    Args:
        move_str: Whitespace-separated tokens.  The first token is skipped
            (the caller in this file passes ``"set_position"``); the rest
            are alternating colour (``'b'`` or ``'w'``) and coordinate
            tokens.  Tokens with any other colour are parsed but not placed.
        foreground: Sequence of groups, each a sequence of strings whose
            first character is a colour, second a separator and the rest a
            coordinate (e.g. ``"b q16"``).  Stones listed in group ``k``
            are drawn opaque and labelled ``k + 1``; all other stones are
            drawn semi-transparent without a label.
        filename: Name of the source file; its extension is replaced by
            ``.svg`` for the output, which is written into the module-level
            ``save_dir`` directory.

    Raises:
        ValueError: if a colour token has no matching coordinate token, or
            a coordinate's row part is not an integer.
        IndexError: if a coordinate falls outside the 19x19 board.
    """
    size = 19

    # Drop the leading command token, then consume (colour, coordinate) pairs.
    tokens = move_str.split()[1:]
    if len(tokens) % 2:
        raise ValueError(
            f"unpaired token in move string: expected colour/coordinate pairs, "
            f"got {len(tokens)} tokens"
        )

    # board[row][col]: 1 = black stone, -1 = white stone, 0 = empty.
    board = [[0] * size for _ in range(size)]
    for color, coord in zip(tokens[::2], tokens[1::2]):
        row, col = _parse_coord(coord)
        if color == 'b':
            board[row][col] = 1
        elif color == 'w':
            board[row][col] = -1

    # board_label[row][col]: index of the foreground group, or -1 for none.
    board_label = [[-1] * size for _ in range(size)]
    for group_index, group in enumerate(foreground):
        for stone in group:
            row, col = _parse_coord(stone[2:])
            board_label[row][col] = group_index

    fig, ax = plt.subplots(figsize=(8, 8))

    # Marker dots at every combination of lines 3, 9 and 15.
    for x in range(3, 16, 6):
        for y in range(3, 16, 6):
            ax.add_artist(plt.Circle((x, y), radius=0.15, color='black'))

    # Per-colour drawing parameters.
    # NOTE(review): only white stones are given an explicit zorder=2; black
    # stones keep the default patch z-order.  Preserved from the original --
    # confirm whether the difference is intended.
    stone_styles = {
        1: {'face': '#203864', 'text': 'white', 'extra': {}},
        -1: {'face': '#F4B183', 'text': 'black', 'extra': {'zorder': 2}},
    }
    for row, stones in enumerate(board):
        for col, stone in enumerate(stones):
            style = stone_styles.get(stone)
            if style is None:
                continue  # empty intersection
            label = board_label[row][col]
            circle_kwargs = dict(radius=0.4, color=style['face'], linewidth=2,
                                 clip_on=False, **style['extra'])
            if label == -1:
                # Background stone: faded, no number.
                circle_kwargs['alpha'] = 0.4
            ax.add_artist(plt.Circle((col, row), **circle_kwargs))
            if label != -1:
                ax.text(col, row, str(label + 1), ha="center", va="center",
                        color=style['text'], fontsize=15)

    # Grid lines on every board line, with tick labels hidden.
    ax.set_xticks(range(size))
    ax.set_yticks(range(size))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(color='grey', linestyle='-', linewidth=0.5, clip_on=False, zorder=0)

    # Thin grey frame so the border matches the grid lines.
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color('grey')
        ax.spines[side].set_linewidth(0.5)

    ax.set_xlim(0, size - 1)
    ax.set_ylim(0, size - 1)

    # splitext handles extensions of any length (the original sliced off
    # exactly four characters).
    stem = os.path.splitext(filename)[0]
    fig.savefig(os.path.join(save_dir, f"{stem}.svg"), bbox_inches='tight', transparent=True)
    plt.close(fig)


if __name__ == '__main__':
    # Output directory; plot_board() and visualize() read this module-level
    # name when saving figures.
    save_dir = "board_figure/board_figure_label"
    os.makedirs(save_dir, exist_ok=True)

    load_dir = "sgf_label"
    # plot_board()

    count = 0
    for filename in os.listdir(load_dir):
        print(filename)
        count += 1

        # Read the file once and close it promptly.  Line 1 holds the
        # ';'-separated moves, line 2 the ';'-separated labelled groups.
        with open(os.path.join(load_dir, filename), "r") as label_file:
            lines = label_file.readlines()

        # Strip only the newline so a final line without one is not truncated.
        position = lines[0].rstrip("\n")
        all_positions = position.split(";")
        print("all_positions: ", all_positions, len(all_positions))
        set_position_N = "set_position " + " ".join(all_positions)
        print(set_position_N)

        # Collect all labelled groups ("players") from the annotation line.
        players = lines[1].rstrip("\n").split(";")
        n_attributes = len(players)
        print("all_players: ", players, n_attributes)

        # Foreground stones per group: drop the enclosing bracket characters
        # and split the comma-separated stones.
        foreground = [player[1:-1].split(",") for player in players]

        # Flattened view, used only for the validity check below.
        flat_foreground = [stone for group in foreground for stone in group]
        print("foreground: ", flat_foreground)

        # Every annotated foreground stone must be one of the moves on the
        # board; stop at the first file that violates this.
        if not set(flat_foreground).issubset(set(all_positions)):
            print("illegal file: ", filename)
            print(set(flat_foreground) - set(all_positions))
            break

        print("foreground: ", foreground)

        visualize(set_position_N, foreground, filename)

    print(count)

