import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import os.path
import torch
import math
import numpy as np
import subprocess
import tqdm
import copy
from interaction_utils import *


def make_dir(path):
    """Create directory *path* (including missing parents) if nothing exists there yet.

    If something already exists at *path*, the call is a no-op.  The directory is
    created directly and ``FileExistsError`` is tolerated, instead of checking
    ``os.path.exists`` first, so a concurrent creator between the check and the
    ``makedirs`` call cannot make this raise.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        pass


def get_reward(values, selected_dim):
    """Reduce *values* to the reward selected by *selected_dim*.

    Args:
        values: for ``"max"`` / ``"0"``, a tensor indexed as ``values[:, k]``
            (presumably rows x output dimensions -- confirm against callers);
            for ``"gt-log-odds"``, a scalar probability ``p``;
            for ``None``, any object.
        selected_dim: one of
            ``"max"``         -> the column whose entry in the last row is largest,
            ``"0"``           -> column 0,
            ``"gt-log-odds"`` -> ``log(p / (1 - p + eps) + eps)`` with ``eps = 1e-7``,
            ``None``          -> *values* returned unchanged.

    Returns:
        The selected column, the scalar log-odds, or *values* itself.

    Raises:
        ValueError: for any other *selected_dim* (a subclass of ``Exception``,
            so existing ``except Exception`` handlers still catch it).
    """
    if selected_dim is None:
        return values
    if selected_dim == "max":
        # Select the dimension predicted (argmax) by the last row, by default.
        return values[:, torch.argmax(values[-1])]
    if selected_dim == "0":
        return values[:, 0]
    if selected_dim == "gt-log-odds":
        # eps keeps the ratio finite at p == 1 and the log finite at p == 0.
        eps = 1e-7
        return math.log(values / (1 - values + eps) + eps)
    raise ValueError(f"Unknown [selected_dim] {selected_dim}.")


def save_to_file(file_name, contents):
    """Write the string *contents* to *file_name*, replacing any existing file.

    The ``with`` block guarantees the handle is closed even if ``write`` raises.
    """
    with open(file_name, 'w') as fh:
        fh.write(contents)


def check_same_sign(lst):
    """Return True when every non-zero entry of *lst* has the same sign.

    Zeros are ignored, so an empty list, a single value, or a list with at most
    one non-zero value all count as "same sign".
    """
    if len(lst) < 2:
        return True

    # Lazily map each non-zero value to its direction (+1.0 / -1.0) so the scan
    # stops at the first value whose direction differs from the first one seen.
    directions = (value / abs(value) for value in lst if value != 0)
    reference = next(directions, None)
    if reference is None:
        return True
    return all(direction == reference for direction in directions)


def plot_reward_mean(reward_mean, save_path, save_name, reward_std, standard=None):
    """Plot per-slot mean rewards with a +/- one-std band and save it as SVG.

    The inputs are ordered by slot as ``0, 1, ..., k, -1, -2, ...`` (non-negative
    keys first, then negative keys in decreasing order).  They are reordered to
    ascending key order before plotting.  The figure is written to
    ``<save_path>/<save_name>.svg``.

    Args:
        reward_mean: per-slot mean rewards (sequence or 1-D array).
        save_path: output directory; created if missing.
        save_name: file-name stem of the SVG.
        reward_std: per-slot standard deviations, same length and order as
            *reward_mean*.
        standard: unused; kept only for backward compatibility with callers.
    """
    os.makedirs(save_path, exist_ok=True)
    reward_mean = np.asarray(reward_mean)
    reward_std = np.asarray(reward_std)
    length = len(reward_mean)
    # Slots [0, split) hold keys 0 .. split-1; slots [split, length) hold -1, -2, ...
    split = length // 2 + 1

    # 0, 1, 2, 3, 4, 5, -1, -2, -3, -4, -5  -->  -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5
    reward_mean = np.concatenate((reward_mean[split:][::-1], reward_mean[:split]))
    reward_std = np.concatenate((reward_std[split:][::-1], reward_std[:split]))
    print("mean: ", reward_mean)
    print("std: ", reward_std)

    reward_upper = reward_mean + reward_std
    reward_lower = reward_mean - reward_std

    # Integer x positions matching the reordering above: -(length - split) .. split - 1.
    # (The former linspace(...).astype(int) only lined up for odd lengths.)
    X = np.arange(-(length - split), split)

    fig = plt.figure(figsize=(5, 4))
    try:
        plt.plot(X, reward_mean, '-o', markersize=3)
        plt.fill_between(X, reward_lower, reward_upper, color='lightblue', alpha=0.6)

        plt.yticks(fontproperties='Times New Roman', size=20)
        plt.xticks(fontproperties='Times New Roman', size=20)
        ax = plt.gca()  # handle to the current axes
        # Thicken all four axis spines.
        for side in ('bottom', 'left', 'right', 'top'):
            ax.spines[side].set_linewidth(2)
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f"{save_name}.svg"), bbox_inches='tight', transparent=True)
    finally:
        # Close only the figure created here, even if plotting/saving fails.
        plt.close(fig)


# From the labelled sgf files, find the background stones and the player stones
# whose contribution we want to evaluate, query KataGo for every subset of players,
# and aggregate rewards by the net colour of the removed players.

# Exact GTP error line KataGo emits when set_position is rejected (without "\n").
_ILLEGAL_PLACEMENT = "? Illegal stone placements - overlapping stones or stones with no liberties?"


def _read_label_file(path):
    """Return ``(all_positions, players)`` from the first two lines of a label file.

    Both lines are ';'-separated.  Only the trailing newline is stripped, so a
    last line without a newline is not truncated.
    """
    with open(path, "r") as fh:
        position_line = fh.readline().rstrip("\n")
        players_line = fh.readline().rstrip("\n")
    return position_line.split(";"), players_line.split(";")


def _gtp_exchange(proc, command):
    """Send one GTP *command* to *proc*; return ``(legal, response_lines)``.

    Reads until the blank line that terminates a GTP response (or EOF, so a dead
    engine cannot hang the loop), always consuming the whole response so the
    next command stays in sync.  ``legal`` is False if the engine rejected a
    stone placement.
    """
    proc.stdin.write(command + "\n")
    proc.stdin.flush()
    legal = True
    lines = []
    while True:
        line = proc.stdout.readline()
        if not line.strip():
            return legal, lines
        if line.rstrip("\n") == _ILLEGAL_PLACEMENT:
            legal = False
        lines.append(line)


def _parse_white_win(line):
    """Parse the probability from a ``whiteWin <p>`` response line."""
    return float(line.split()[1])


def _reward_slot(color_sum, n_slots):
    """Map a removed-colour sum to its slot: ``s >= 0 -> s``, ``s < 0 -> -s + n_slots // 2``."""
    return -color_sum + n_slots // 2 if color_sum < 0 else color_sum


def _build_mask_commands(background, players, players_color, all_masks):
    """Build GTP commands for every mask plus per-mask statistics of the removed players.

    Returns ``(commands, masks_same_sign)``:
      * ``commands`` holds three commands per mask (set_position / showboard /
        kata-raw-nn 0), in mask order;
      * ``masks_same_sign[i] == [all removed players share a colour, sum of their colours]``.
        Because the full player set is colour-balanced, the sum equals the net
        white-minus-black player surplus left on the board.
    A True mask entry keeps that player's stones; False removes them.
    """
    # Imported here so the progress bar works regardless of whether the
    # module-level name ``tqdm`` refers to the module or the function.
    from tqdm import tqdm as progress_bar

    commands = []
    masks_same_sign = []
    for i in progress_bar(range(len(all_masks)), ncols=100, desc="Generating mask"):
        keep = all_masks[i].tolist()

        removed_colors = [color for color, kept in zip(players_color, keep) if not kept]
        masks_same_sign.append([check_same_sign(removed_colors), sum(removed_colors)])

        s_inputs = list(background)
        for player, kept in zip(players, keep):
            if kept:
                s_inputs.extend(player[1:-1].split(","))
        commands.extend(["set_position " + " ".join(s_inputs), "showboard", "kata-raw-nn 0"])
    return commands, masks_same_sign


chessNum_one_player = 1
playersNum = 10
load_dir = "sgf_label"
filenames = sorted(os.listdir(load_dir))  # sorted for reproducible processing order

save_dir = "eval_andor"
os.makedirs(save_dir, exist_ok=True)

KATAGO_CMD = ["./katago", "gtp", "-model", "b18c384nbt-uec.bin.gz", "-config", "configs/gtp_chinese.cfg"]

# Slots ordered by removed-colour sum: 0, 1, 2, 3, 4, 5, -1, -2, -3, -4, -5
rewards_mean = [[] for _ in range(playersNum + 1)]

for idx, filename in enumerate(filenames):
    print(idx, filename[-8:-4])

    # Positions of all stones, and the annotated players.
    all_positions, players = _read_label_file(os.path.join(load_dir, filename))
    print("all_positions: ", all_positions, len(all_positions))
    n_attributes = len(players)
    print("all_players: ", players, n_attributes)

    # Black players are encoded as +1, white players as -1; only colour-balanced files are used.
    players_color = [-chessNum_one_player if player[1] == "w" else chessNum_one_player for player in players]
    print("players color: ", players_color)
    if sum(players_color) != 0:
        continue

    # Removing every player of one colour gives the largest |sum|; it must fit the slots,
    # otherwise positive sums collide with negative slots (or index past the end).
    max_removed_sum = sum(color for color in players_color if color > 0)
    if max_removed_sum > len(rewards_mean) // 2:
        print("too many players for the reward slots: ", filename, n_attributes)
        continue

    # Foreground: stones belonging to annotated players.
    foreground = []
    for player in players:
        foreground.extend(player[1:-1].split(","))
    print("foreground: ", foreground)

    # Background: every other stone.
    background = list(set(all_positions).difference(set(foreground)))
    print("background: ", background)

    # Every annotated stone must exist in the position (background is a subset by construction).
    if not set(foreground).issubset(set(all_positions)):
        print("illegal file: ", filename)
        print(set(foreground) - set(all_positions))
        print(set(background) - set(all_positions))
        break

    # set_position commands for v(N) and v(empty).
    set_position_N = "set_position " + " ".join(all_positions)
    set_position_empty = "set_position " + " ".join(background)
    commands_N_empty = [set_position_N, "showboard", "kata-raw-nn 0", set_position_empty, "showboard", "kata-raw-nn 0"]

    # One mask per subset of players: 2 ** n_attributes masks.
    all_masks = torch.BoolTensor(generate_all_masks(n_attributes))
    commands, masks_same_sign = _build_mask_commands(background, players, players_color, all_masks)

    # Start the KataGo engine subprocess (created before the try so cleanup never sees an unbound handle).
    proc = subprocess.Popen(KATAGO_CMD, stdin=subprocess.PIPE, stdout=subprocess.PIPE, bufsize=1, text=True)
    try:
        # Skip the file if the full or the empty position is illegal.
        if not all(_gtp_exchange(proc, command)[0] for command in commands_N_empty):
            continue

        # Buffer this file's rewards; merge them only if every position was legal.
        file_rewards = []
        count = 0
        legal = True
        for command in commands:
            ok, lines = _gtp_exchange(proc, command)
            if not ok:
                legal = False
                break
            for line in lines:
                if line.startswith('whiteWin'):
                    same_sign, removed_sum = masks_same_sign[count]
                    # Only masks whose removed players all share one colour are recorded.
                    if same_sign:
                        slot = _reward_slot(removed_sum, len(rewards_mean))
                        file_rewards.append((slot, get_reward(_parse_white_win(line), "gt-log-odds")))
                    count += 1
        if not legal:
            continue

        for slot, reward in file_rewards:
            rewards_mean[slot].append(reward)

        torch.cuda.empty_cache()
    finally:
        proc.stdin.close()
        proc.wait()

print("check: ", [len(rewards) for rewards in rewards_mean])
rewards_std = [np.std(np.array(rewards)) for rewards in rewards_mean]
rewards_mean = [np.mean(np.array(rewards)) for rewards in rewards_mean]
print("mean: ", rewards_mean)
print("std: ", rewards_std)
rewards_mean = np.array(rewards_mean)
rewards_std = np.array(rewards_std)
np.save(os.path.join(save_dir, "rewards_mean.npy"), rewards_mean)
np.save(os.path.join(save_dir, "rewards_std.npy"), rewards_std)

plot_reward_mean(rewards_mean, save_dir, save_name="rewards_mean_std", reward_std=rewards_std, standard=None)

