import copy
import linecache
import select
import time
import os.path
import torch
import subprocess
import string
import copy
from tqdm import tqdm
import math
import numpy as np
import argparse
from interaction_utils import generate_all_masks, generate_subset_masks, generate_reverse_subset_masks, \
    generate_set_with_intersection_masks


def make_dir(path):
    """Create directory *path*, including any missing parents, if it is absent.

    ``exist_ok=True`` replaces the former ``os.path.exists`` pre-check. That
    pre-check raced with any other process creating the same directory
    between the check and the ``makedirs`` call.

    Raises:
        FileExistsError: if *path* exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def get_reward(values, selected_dim):
    """Reduce a model output to the reward used for interaction computation.

    Args:
        values: the raw output. For ``"max"`` and ``"0"`` it must support 2-D
            indexing (``values[:, j]``), e.g. a ``torch.Tensor``. For
            ``"gt-log-odds"`` it must be a single probability accepted by
            ``math.log`` (this script passes a Python float).
        selected_dim: how to reduce ``values``:

            * ``"max"``: the column whose entry in the last row is largest.
            * ``"0"``: column 0.
            * ``"gt-log-odds"``: ``log(p / (1 - p + eps) + eps)`` with
              ``eps = 1e-7``, so p == 0 and p == 1 stay finite.
            * ``None``: ``values`` unchanged.

    Returns:
        The reduced reward.

    Raises:
        ValueError: if ``selected_dim`` is not one of the options above.
            It subclasses ``Exception``, so existing handlers still match.
    """
    if selected_dim is None:
        return values
    if selected_dim == "max":
        # select the predicted dimension, by default
        return values[:, torch.argmax(values[-1])]
    if selected_dim == "0":
        return values[:, 0]
    if selected_dim == "gt-log-odds":
        eps = 1e-7
        return math.log(values / (1 - values + eps) + eps)
    raise ValueError(f"Unknown [selected_dim] {selected_dim}.")


def save_to_file(file_name, contents):
    """Write the string *contents* to *file_name*, replacing any existing file.

    The ``with`` block closes the handle even if ``write`` raises. The former
    open/write/close sequence leaked the handle in that case.
    """
    with open(file_name, 'w') as fh:
        fh.write(contents)


def check_same_sign(lst):
    """Return True when every non-zero element of *lst* has the same sign.

    Zeros are ignored. Sequences with fewer than two elements trivially
    satisfy the condition.
    """
    if len(lst) < 2:
        return True

    reference = None  # sign (+1 / -1) of the first non-zero element seen
    for value in lst:
        if value == 0:
            continue
        current = value / abs(value)  # +1 or -1
        if reference is None:
            reference = current
        elif current != reference:
            return False
    return True


# Reward definitions the evaluation loop implements. argparse rejects anything
# else before KataGo is launched, instead of failing mid-run.
REWARD_WAY_CHOICES = ["gt-log-odds", "gt-log-odds-minus-mean", "gt-log-odds-minus-mean-minus-empty"]

parser = argparse.ArgumentParser(description="sparsify and-or harsanyi")
# NOTE(review): args.device is not read anywhere in this script.
parser.add_argument('--device', default=0, type=int,
                    help="set the device.")
parser.add_argument('--load_dir', default="sgf_label",
                    type=str, help="the path for the labeled chess file.")
parser.add_argument('--save_dir', default="eval_andor",
                    type=str, help="the path for saving the Iand and Ior file.")
# Magnitude of the colour value given to each player (black +, white -).
parser.add_argument('--chessNum_one_player', default=1, type=int)
parser.add_argument('--reward_way', default="gt-log-odds-minus-mean", type=str,
                    choices=REWARD_WAY_CHOICES,
                    help="the way for calculating the rewards. "
                         "choose from: " + ", ".join(REWARD_WAY_CHOICES))
args = parser.parse_args()


# Exact line KataGo's GTP engine prints when a set_position command is rejected.
ILLEGAL_PLACEMENT = "? Illegal stone placements - overlapping stones or stones with no liberties?\n"
# Shell command that starts one KataGo GTP engine per labeled game.
KATAGO_CMD = "./katago gtp -model b18c384nbt-uec.bin.gz -config configs/gtp_chinese.cfg"
# Reward definitions implemented by the loop below.
IMPLEMENTED_REWARD_WAYS = ("gt-log-odds", "gt-log-odds-minus-mean", "gt-log-odds-minus-mean-minus-empty")


def _read_label_file(path):
    """Parse a labeled game file into ``(all_positions, players)``.

    The first line is the ";"-separated list of every stone on the board. The
    second line is the ";"-separated list of players. Only the trailing
    newline is stripped, so a final line without one keeps all its characters.
    """
    with open(path, "r") as fh:
        lines = fh.readlines()
    all_positions = lines[0].rstrip("\n").split(";")
    players = lines[1].rstrip("\n").split(";")
    return all_positions, players


def _parse_nn_value(line):
    """Return the number on a ``kata-raw-nn`` output line of the form ``<key> <value>``."""
    # Split on whitespace rather than slicing. The old slice ended at
    # -len('\\n') == -2, which dropped the value's last digit whenever the
    # line ended in a bare newline.
    return float(line.split()[1])


def _rewards_mean_index(white_extra, n_means):
    """Return the ``rewards_mean`` row for a given colour imbalance of masked-out players.

    Non-negative imbalances index directly. A negative imbalance ``-k`` maps
    to ``k + n_means // 2``.
    """
    if white_extra < 0:
        return -white_extra + n_means // 2
    return white_extra


def _query_gtp(proc, commands):
    """Send GTP commands one at a time and collect their responses.

    Each response is read until a blank line, or until EOF (``readline``
    returns ``""``), so reads stay aligned with the commands.

    Returns:
        A list of ``(command_index, line)`` pairs, one for every non-blank
        response line. Returns ``None`` as soon as a command is answered with
        ``ILLEGAL_PLACEMENT``; that response is still drained first, and the
        remaining commands are not sent.
    """
    collected = []
    for cmd_idx, cmd in enumerate(commands):
        proc.stdin.write(cmd + "\n")
        illegal = False
        while True:
            line = proc.stdout.readline()
            if not line.strip():
                break
            if line == ILLEGAL_PLACEMENT:
                illegal = True
            collected.append((cmd_idx, line))
        if illegal:
            return None
    return collected


if args.reward_way not in IMPLEMENTED_REWARD_WAYS:
    # Fail before any engine is launched, not on the first evaluated position.
    raise NotImplementedError(f"Invalid reward_way: {args.reward_way}")

os.makedirs(args.save_dir, exist_ok=True)
# Values subtracted from the log-odds rewards, indexed via _rewards_mean_index.
rewards_mean = np.load(os.path.join(args.save_dir, "rewards_mean.npy"))

for filename in os.listdir(args.load_dir):
    print(filename[-8:-4])
    save_folder = os.path.join(args.save_dir, filename[:-4], args.reward_way, "before_sparsify")

    # Every stone on the board, and the labeled players (groups of stones).
    all_positions, players = _read_label_file(os.path.join(args.load_dir, filename))
    print("all_positions: ", all_positions, len(all_positions))
    n_attributes = len(players)
    print("all_players: ", players, n_attributes)

    # White players (player[1] == "w") count as -chessNum_one_player and black
    # players as +chessNum_one_player. Only colour-balanced games are analysed.
    players_color = [-args.chessNum_one_player if player[1] == "w" else args.chessNum_one_player
                     for player in players]
    print("players color: ", players_color)
    if sum(players_color) != 0:
        continue

    # Foreground: the stones of all players (player[1:-1] drops one wrapper char per side).
    foreground = []
    for player in players:
        foreground.extend(player[1:-1].split(","))
    print("foreground: ", foreground, len(foreground))

    # Background: stones that belong to no player. They are placed in every query.
    background = list(set(all_positions).difference(set(foreground)))
    print("background: ", background, len(background))

    # GTP commands for v(N) (every stone) and v(empty) (background only).
    commands_N = ["set_position " + " ".join(all_positions), "showboard", "kata-raw-nn 0"]
    commands_empty = ["set_position " + " ".join(background), "showboard", "kata-raw-nn 0"]

    # One boolean row per masking state S of the players (n_masks rows).
    all_masks = torch.BoolTensor(generate_all_masks(n_attributes))
    n_masks, _ = all_masks.shape

    # Row S of Iand_mat / Ior_mat holds the coefficient of each v(L) in I(S),
    # so that I = mat @ rewards.
    Iand_mat = []
    Ior_mat = []
    # Per-mask (whiteWin, whiteLoss, noResult) read from kata-raw-nn.
    outputs = torch.zeros([n_masks, 3])
    rewards = torch.zeros(n_masks)
    a = torch.zeros(n_masks)  # rewards_mean value used per mask ("...-minus-empty" only)
    rewards_mean_ids = torch.zeros(n_masks)  # rewards_mean row used per mask ("...-minus-empty" only)

    # Launch the engine outside the try. If Popen fails, `finally` must not
    # touch an unbound name or the previous game's finished process.
    p = subprocess.Popen(KATAGO_CMD, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                         bufsize=1, text=True)
    try:
        commands = []
        num_white_extra = [0] * n_masks
        for i in tqdm(range(n_masks), ncols=100, desc="Generating mask"):
            mask_S = all_masks[i]

            # Iand_mat
            Iand_row = torch.zeros(n_masks)
            # ===============================================================================================
            # Note: I(S) = \sum_{L\subseteq S} (-1)^{s-l} v(L)
            mask_Ls, L_indices = generate_subset_masks(mask_S, all_masks)
            L_indices = (L_indices == True).nonzero(as_tuple=False)
            assert mask_Ls.shape[0] == L_indices.shape[0]
            Iand_row[L_indices] = torch.pow(-1., mask_S.sum() - mask_Ls.sum(dim=1)).unsqueeze(1)
            # ===============================================================================================
            Iand_mat.append(Iand_row.clone())

            Ior_row = torch.zeros(n_masks)
            # ===============================================================================================
            # Note: I(S) = -\sum_{L\subseteq S} (-1)^{s+(n-l)-n} v(N\L) if S is not empty
            if mask_S.sum() == 0:
                Ior_row[i] = 1.
            else:
                mask_NLs, NL_indices = generate_reverse_subset_masks(mask_S, all_masks)
                NL_indices = (NL_indices == True).nonzero(as_tuple=False)
                assert mask_NLs.shape[0] == NL_indices.shape[0]
                Ior_row[NL_indices] = - torch.pow(-1., mask_S.sum() + mask_NLs.sum(dim=1) + n_attributes).unsqueeze(1)
            # ================================================================================================
            Ior_mat.append(Ior_row.clone())

            # Colour sum (black +, white -) of the masked-out players: the
            # white-stone surplus that masking creates.
            num_white_extra[i] = sum(players_color[j] for j, kept in enumerate(mask_S) if not kept)

            # Board for v(S): the background plus the stones of every kept player.
            S_inputs = list(background)
            for player_id, kept in enumerate(mask_S.tolist()):
                if kept:
                    S_inputs.extend(players[player_id][1:-1].split(","))
            commands.extend(["set_position " + " ".join(S_inputs), "showboard", "kata-raw-nn 0"])

        Iand_mat = torch.stack(Iand_mat).float()
        Ior_mat = torch.stack(Ior_mat).float()

        # v(empty). Under "...-minus-empty", its mean-shifted log-odds is the
        # baseline subtracted from every reward. An illegal position skips the game.
        empty_lines = _query_gtp(p, commands_empty)
        if empty_lines is None:
            continue
        empty_baseline = None
        if args.reward_way == "gt-log-odds-minus-mean-minus-empty":
            for _, line in empty_lines:
                if line.startswith('whiteWin'):
                    empty_baseline = get_reward(_parse_nn_value(line), "gt-log-odds") - rewards_mean[0]
            if empty_baseline is None:
                # Never fall back on a previous game's baseline.
                print(f"no whiteWin for the empty position of {filename}; skipping")
                continue

        # v(N): the full board. Its raw engine output is saved as result.txt.
        n_lines = _query_gtp(p, commands_N)
        if n_lines is None:
            continue
        result_N = "".join(line for _, line in n_lines)
        print("got: {}".format(result_N))

        # v(S) for every mask. Commands come in (set_position, showboard,
        # kata-raw-nn) triples, so command index cmd_idx belongs to mask cmd_idx // 3.
        mask_lines = _query_gtp(p, commands)
        if mask_lines is None:
            continue
        count = 0
        for cmd_idx, line in mask_lines:
            row = cmd_idx // 3
            if line.startswith('whiteWin'):
                white_win = _parse_nn_value(line)
                outputs[row, 0] = white_win
                reward = get_reward(white_win, "gt-log-odds")
                if args.reward_way != "gt-log-odds":
                    mean_id = _rewards_mean_index(num_white_extra[count], len(rewards_mean))
                    reward = reward - rewards_mean[mean_id]
                    if args.reward_way == "gt-log-odds-minus-mean-minus-empty":
                        a[count] = rewards_mean[mean_id]
                        rewards_mean_ids[count] = mean_id
                        reward = reward - empty_baseline
                rewards[count] = reward
                count += 1
            elif line.startswith('whiteLoss'):
                outputs[row, 1] = _parse_nn_value(line)
            elif line.startswith('noResult'):
                outputs[row, 2] = _parse_nn_value(line)

        os.makedirs(save_folder, exist_ok=True)
        save_to_file(os.path.join(save_folder, "result.txt"), result_N)

        Iand = torch.matmul(Iand_mat, rewards)
        Ior = torch.matmul(Ior_mat, rewards)

        masks = all_masks.cpu().numpy()
        rewards = rewards.cpu().numpy()
        a = a.cpu().numpy()
        rewards_mean_ids = rewards_mean_ids.cpu().numpy()

        np.save(os.path.join(save_folder, "rewards.npy"), rewards)
        np.save(os.path.join(save_folder, "masks.npy"), masks)
        np.save(os.path.join(save_folder, "outputs.npy"), outputs.numpy())
        np.save(os.path.join(save_folder, "a.npy"), a)
        np.save(os.path.join(save_folder, "rewards_mean_ids.npy"), rewards_mean_ids)

        Iand = Iand.cpu().numpy()
        Ior = Ior.cpu().numpy()

        np.save(os.path.join(save_folder, "Iand.npy"), Iand)
        np.save(os.path.join(save_folder, "Ior.npy"), Ior)
        np.save(os.path.join(save_folder, "reward2Iand.npy"), Iand_mat.numpy())
        np.save(os.path.join(save_folder, "reward2Ior.npy"), Ior_mat.numpy())

        # NOTE(review): assumes generate_all_masks puts the empty mask first and
        # the full mask last; confirm against interaction_utils.
        np.save(os.path.join(save_folder, "v_N.npy"), rewards[-1])
        np.save(os.path.join(save_folder, "v_empty.npy"), rewards[0])

        with open(os.path.join(save_folder, "info.txt"), 'a') as f:
            f.write("\n[Before Sparsifying]\n")
            f.write(f"\tSum of I^and and I^or: {np.sum(Iand) / 2 + np.sum(Ior) / 2}\n")
            f.write(f"\n[v_N]: \t{rewards[-1]}\n")
            f.write(f"\n[v_empty]: \t{rewards[0]}\n")
            f.write(f"\n[v_N - v_empty]: \t{rewards[-1] - rewards[0]}\n")

        torch.cuda.empty_cache()
    finally:
        # Close stdin, drain any unread stdout (so the engine cannot block on a
        # full pipe), and wait for the engine to exit.
        p.communicate()





