import matplotlib.pyplot as plt
import os.path
import numpy as np
import csv


def plot_board(out_dir=None):
    """Draw an empty 19x19 board and save it as ``empty_board.png``.

    Small black dots are drawn at the nine intersections with x and y in
    (3, 9, 15). Column tick labels run ``a``..``t`` skipping ``i``; row tick
    labels run ``1``..``19`` from bottom (y = 0) to top.

    Args:
        out_dir: Directory to write the image into. Defaults to the
            module-level ``save_dir`` (assigned in the ``__main__`` block), so
            the original no-argument call keeps working.
    """
    if out_dir is None:
        out_dir = save_dir

    fig, ax = plt.subplots()
    try:
        ax.set_facecolor("#FFD700")
        for x in range(3, 16, 6):
            for y in range(3, 16, 6):
                ax.add_artist(plt.Circle((x, y), radius=0.15, color='black'))

        ax.set_xticks(range(19))
        ax.set_yticks(range(19))
        # Column letters skip 'i', so every index from 8 onward shifts one letter further.
        ax.set_xticklabels([chr(i + 97) if i < 8 else chr(i + 98) for i in range(19)])
        # Row "1" sits at y = 0 (bottom); the labels are NOT reversed.
        ax.set_yticklabels([str(i + 1) for i in range(19)])
        ax.grid(color='black', linestyle='-', linewidth=1, clip_on=False, zorder=0)
        ax.set_xlim(-1, 19)
        ax.set_ylim(-1, 19)
        fig.savefig(os.path.join(out_dir, "empty_board.png"))
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


# Per stone value on the board (1 = black, -1 = white): fill colour, colour of
# the index label drawn on the stone, and explicit zorder. ``None`` means "do
# not pass zorder", so black stones keep matplotlib's default patch zorder.
_STONE_STYLES = {
    1: ('#203864', 'white', None),
    -1: ("#F4B183", 'black', 2),
}


def _coord_to_index(col_row):
    """Convert a coordinate such as ``"d4"`` into ``(row, col)`` board indices.

    The column letter maps ``a`` -> 0 with ``i`` skipped (letters after ``i``
    shift down by one); the remaining characters are a 1-based row number.
    """
    letter = col_row[0]
    col = ord(letter) - 97 if letter < 'i' else ord(letter) - 98
    row = int(col_row[1:]) - 1
    return row, col


def _build_board(move_str):
    """Parse ``"<cmd> <color> <coord> <color> <coord> ..."`` into a 19x19 grid.

    The leading command token (e.g. ``set_position``) is discarded. A cell is 1
    for colour ``b``, -1 for colour ``w`` and 0 otherwise. The coordinate is
    parsed for every pair, but pairs with any other colour are not placed.

    Raises:
        ValueError: if a colour token has no matching coordinate token.
    """
    board = [[0] * 19 for _ in range(19)]
    tokens = move_str.split()[1:]
    if len(tokens) % 2:
        raise ValueError(f"unpaired colour/coordinate token in move string: {move_str!r}")
    for color, col_row in zip(tokens[0::2], tokens[1::2]):
        row, col = _coord_to_index(col_row)
        if color == 'b':
            board[row][col] = 1
        elif color == 'w':
            board[row][col] = -1
    return board


def _build_label_board(foreground):
    """Return a 19x19 grid holding each foreground stone's player index, -1 elsewhere.

    ``foreground[k]`` is the list of stone strings belonging to player ``k``.
    Only ``stone[2:]`` (the coordinate part) is parsed here.
    """
    board_label = [[-1] * 19 for _ in range(19)]
    for label, players in enumerate(foreground):
        for player in players:
            row, col = _coord_to_index(player[2:])
            board_label[row][col] = label
    return board_label


def _view_window(board_label):
    """Choose the fixed 11-line x and y windows shown in every figure.

    Along each axis the window is lines 0-10 when the bounding box of the
    labelled stones is centred below 9.5 (``min < 19 - max``), otherwise lines
    8-18. With no labelled stones both axes use 0-10.
    NOTE(review): a bounding box wider than 11 lines fits in neither window,
    so some labelled stones may be cropped.
    """
    cells = [(r, c) for r, line in enumerate(board_label)
             for c, value in enumerate(line) if value != -1]
    x_min = min((c for _, c in cells), default=19)
    x_max = max((c for _, c in cells), default=-1)
    y_min = min((r for r, _ in cells), default=19)
    y_max = max((r for r, _ in cells), default=-1)
    x_limits = np.arange(0, 11) if x_min < 19 - x_max else np.arange(8, 19)
    y_limits = np.arange(0, 11) if y_min < 19 - y_max else np.arange(8, 19)
    return x_limits, y_limits


def _draw_background(ax, x_limits, y_limits):
    """Draw the (3, 9, 15) dots inside the window, the grey grid and spines, and clamp the axes."""
    for x in range(3, 16, 6):
        for y in range(3, 16, 6):
            if x in x_limits and y in y_limits:
                ax.add_artist(plt.Circle((x, y), radius=0.15, color='black'))

    ax.set_xticks(x_limits)
    ax.set_yticks(y_limits)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(color='grey', linestyle='-', linewidth=0.5, clip_on=False, zorder=0)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color('grey')
        ax.spines[side].set_linewidth(0.5)
    ax.set_xlim(x_limits[0], x_limits[-1])
    ax.set_ylim(y_limits[0], y_limits[-1])


def _draw_stone(ax, row, col, stone, label, mask):
    """Draw one stone at (x=col, y=row); empty cells (stone 0) are skipped.

    - Background stone (``label == -1``): translucent disc, no number.
    - Labelled stone in the selected coalition (``mask[label]`` true): opaque
      disc with a thick translucent red outline and its 1-based index.
    - Labelled stone outside the coalition: translucent disc with its index.

    Reads the module-level globals ``radius`` and ``FONT``.
    """
    style = _STONE_STYLES.get(stone)
    if style is None:
        return
    face, text_color, zorder = style
    circle_kwargs = dict(radius=radius, color=face, linewidth=2, clip_on=False)
    if zorder is not None:
        circle_kwargs['zorder'] = zorder

    if label == -1:
        ax.add_artist(plt.Circle((col, row), alpha=0.3, **circle_kwargs))
        return

    if mask[label]:
        circle = plt.Circle((col, row), **circle_kwargs)
        circle.set_edgecolor((1, 0, 0, 0.6))  # translucent red outline marks coalition members
        circle.set_linewidth(10)
    else:
        circle = plt.Circle((col, row), alpha=0.3, **circle_kwargs)
    ax.add_artist(circle)
    ax.text(col, row, label + 1, ha="center", va="center", color=text_color, fontsize=FONT)


def visualize(move_str, foreground, sample_id, masks_S, phis_S):
    """Render one board image per coalition mask into ``<save_dir>/<sample_id>/``.

    Args:
        move_str: ``"<cmd> <color> <coord> ..."`` string describing all stones.
        foreground: ``foreground[k]`` lists the stone strings of player ``k``.
        sample_id: Sub-directory name for this sample's images.
        masks_S: Sequence of boolean masks indexed by player number; ``mask[k]``
            selects player ``k`` into the highlighted coalition.
        phis_S: One value per mask; it is embedded in the output file name.

    Each figure is saved as ``f"{mask_idx}_{phis_S[mask_idx]}.png"``. Reads the
    module-level globals ``save_dir``, ``radius`` and ``FONT``.
    """
    save_folder = os.path.join(save_dir, sample_id)
    os.makedirs(save_folder, exist_ok=True)

    board = _build_board(move_str)
    board_label = _build_label_board(foreground)
    # The labelled stones do not depend on the mask, so the window is computed once.
    x_limits, y_limits = _view_window(board_label)
    rows, cols = y_limits.tolist(), x_limits.tolist()

    for mask_idx, mask in enumerate(masks_S):
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            _draw_background(ax, x_limits, y_limits)
            # Row-major order (rows ascending, then columns) keeps the artist
            # stacking order identical for stones that share a zorder.
            for i in rows:
                for j in cols:
                    _draw_stone(ax, i, j, board[i][j], board_label[i][j], mask)
            fig.savefig(os.path.join(save_folder, f"{mask_idx}_{phis_S[mask_idx]}.png"), bbox_inches='tight')
        finally:
            plt.close(fig)


if __name__ == '__main__':
    # These module-level names are read as globals by plot_board() and visualize().
    save_dir = "board_figure/board_figure_background"
    os.makedirs(save_dir, exist_ok=True)

    load_dir = "sgf_label"
    load_coalition_dir = "analysis_coalitions"

    FONT = 24     # font size of the player-index labels drawn on stones
    radius = 0.4  # stone radius in grid units

    # Hex colour gradient from red (start) to blue (end): 11 colours, i.e. 10 equal steps.
    start_color = np.array([239, 67, 83]) / 255.0
    end_color = np.array([0, 0, 255]) / 255.0
    colors = [start_color + (end_color - start_color) * i / 10 for i in range(11)]
    colors = ["#%02X%02X%02X" % (int(color[0] * 255), int(color[1] * 255), int(color[2] * 255)) for color in colors]

    count = 0
    # Files whose names start with "id" are skipped.
    filenames = [filename for filename in os.listdir(load_coalition_dir) if not filename.startswith("id")]
    for filename in filenames:
        # The sample id is the second '_'-separated field of the name before its first '.'.
        sample_id = filename.split(".")[0].split("_")[1]
        print(sample_id)
        count += 1

        # Label file: line 0 holds the ';'-separated stones, line 1 the ';'-separated player groups.
        # Read it once (closing the handle) and strip only the newline; slicing [:-1] would
        # drop a real character when the final line has no trailing newline.
        with open(os.path.join(load_dir, f"{sample_id}.txt"), "r") as label_file:
            label_lines = [line.rstrip("\n") for line in label_file.readlines()]

        all_positions = label_lines[0].split(";")
        print("all_positions: ", all_positions, len(all_positions))
        set_position_N = "set_position " + " ".join(all_positions)
        print(set_position_N)

        # Collect all players from the annotation file.
        players = label_lines[1].split(";")
        n_attributes = len(players)
        print("all_players: ", players, n_attributes)

        # Foreground: one list of ','-separated stones per player. The first and last
        # characters of each group are dropped (presumably enclosing brackets; verify
        # against the label-file format).
        foreground = [player[1:-1].split(",") for player in players]
        print("foreground: ", foreground)

        # Coalition CSV: one header row, then per row: phi, order, and one value per
        # player (> 0.5 means the player is in the coalition). Blank lines are skipped.
        phis_S = []
        masks_S = []
        orders_S = []
        with open(os.path.join(load_coalition_dir, filename), newline='') as csvfile:
            reader = csv.reader(csvfile)
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue
                phis_S.append(float(row[0]))
                orders_S.append(int(float(row[1])))
                masks_S.append([float(item) > 0.5 for item in row[2:]])
        masks_S = np.array(masks_S, dtype=bool)[:10]
        phis_S = np.array(phis_S).flatten()[:10]
        orders_S = np.array(orders_S).flatten()[:10]

        # Render coalitions in descending order of phi.
        sorted_indices = np.argsort(phis_S)[::-1]
        sorted_masks_S = masks_S[sorted_indices]
        sorted_phis_S = phis_S[sorted_indices]

        visualize(set_position_N, foreground, sample_id, sorted_masks_S, sorted_phis_S)

    print(count)

