import torch
import chess
import chess.syzygy
import config


def tb_res_to_wdl(tb_res):
    """Utility function. Collapse a Syzygy WDL value into a 3-way class index.

    Mapping: 2 -> 0, -2 -> 2, and 0 / 1 / -1 -> 1 (draws and the "cursed"
    results are all treated as the middle class). Any other value fails the
    assertion.
    """
    if tb_res == 2:
        return 0
    if tb_res == -2:
        return 2
    assert tb_res in (-1, 0, 1)
    return 1


def open_egtb(max_fds=512):
    """Open the Syzygy tablebase at ``config.tb_path`` and store it in ``config.tablebase``.

    Args:
        max_fds: forwarded to ``chess.syzygy.open_tablebase`` as its ``max_fds``.

    If a tablebase was already stored in ``config.tablebase``, it is closed
    after the new one is installed. Repeated calls therefore do not leak the
    previous tablebase's open files. If opening the new tablebase raises, the
    previous one is left untouched and still usable.
    """
    new_tablebase = chess.syzygy.open_tablebase(config.tb_path, max_fds=max_fds)
    # getattr: the original never required config.tablebase to pre-exist.
    previous = getattr(config, "tablebase", None)
    config.tablebase = new_tablebase
    if previous is not None:
        previous.close()


def close_egtb():
    """Close the shared tablebase in ``config.tablebase`` (if any) and clear it.

    The attribute is reset to None so the probe helpers, which check
    ``config.tablebase is not None``, fall back to opening a temporary
    tablebase instead of probing a closed one. Calling this more than once
    is harmless.
    """
    tablebase = config.tablebase
    if tablebase is not None:
        # Clear first so the global is reset even if close() raises.
        config.tablebase = None
        tablebase.close()


def tb_probe_wdl_ab(board, alpha=-2, beta=2):
    """Return the first element of ``probe_ab(board, alpha, beta)`` on the tablebase.

    The alpha/beta bounds default to -2 and 2. The shared ``config.tablebase``
    is used when it is set; otherwise the tablebase at ``config.tb_path`` is
    opened just for this probe and closed afterwards.
    """
    shared_tb = config.tablebase
    if shared_tb is None:
        with chess.syzygy.open_tablebase(config.tb_path) as tablebase:
            return tablebase.probe_ab(board, alpha, beta)[0]
    return shared_tb.probe_ab(board, alpha, beta)[0]


def tb_probe_wdl(board):
    """Return the raw Syzygy WDL value of ``board`` via ``probe_wdl``.

    ``board`` must be a ``chess.Board`` (checked by assertion). The shared
    ``config.tablebase`` is used when it is set; otherwise the tablebase at
    ``config.tb_path`` is opened for this single probe and closed afterwards.
    """
    assert isinstance(board, chess.Board)
    shared_tb = config.tablebase
    if shared_tb is None:
        with chess.syzygy.open_tablebase(config.tb_path) as tablebase:
            return tablebase.probe_wdl(board)
    return shared_tb.probe_wdl(board)


def tb_probe_result(board):
    """Probe ``board`` and return its 3-way class (0, 1 or 2) as given by ``tb_res_to_wdl``."""
    raw_wdl = tb_probe_wdl(board)
    return tb_res_to_wdl(raw_wdl)


def tensor_to_board(x):
    """Decode piece-placement feature tensors back into ``chess.Board`` objects.

    Feature ``i`` (only the first ``12 * 64`` entries are read) marks plane
    ``i // 64`` on square ``i % 64``. Planes 0-5 are white piece types 1-6 and
    planes 6-11 the black ones; this is the layout written by
    ``get_board_tensor``. Any nonzero entry counts as a present piece.

    A tensor with two or more dimensions is treated as a batch along dim 0 and
    decoded recursively, returning a (possibly nested) list of boards. Only
    piece placement is restored. Everything else keeps the defaults of
    ``chess.Board(None)``.
    """
    if len(x.shape) >= 2:
        return [tensor_to_board(x[i]) for i in range(x.shape[0])]
    board = chess.Board(None)
    # flatten() (not squeeze()) keeps a 1-D result even when exactly one
    # feature is set. tolist() yields plain ints, so chess.Piece receives an
    # int piece type rather than a 0-d tensor.
    feature_idxs = x[:(12 * 64)].nonzero().flatten().tolist()
    for feature_idx in feature_idxs:
        plane, square_idx = divmod(feature_idx, 64)
        piece_color = chess.WHITE if plane < 6 else chess.BLACK
        piece = chess.Piece((plane % 6) + 1, piece_color)
        board.set_piece_at(chess.SQUARES[square_idx], piece)
    return board


def get_board_tensor(board):
    """Encode ``board``'s piece placement as a ``(1, 768)`` tensor of 0/1 features.

    For each piece, feature ``(piece_type - 1) * 64 + square`` is set to 1.
    Black pieces are additionally offset by ``6 * 64``. The tensor uses
    torch's default float dtype (from ``torch.zeros``).
    """
    features_board = torch.zeros(2 * 6 * 64)
    # piece_map() builds a fresh dict on every call; fetch it once instead of
    # once per piece.
    for square, piece in board.piece_map().items():
        idx = (piece.piece_type - 1) * 64 + square
        if piece.color == chess.BLACK:
            idx += 6 * 64
        features_board[idx] = 1
    return features_board.view(1, 768)


def tb_probe_tensor(x):
    """Probe the tablebase for the position(s) encoded in feature tensor ``x``.

    A 1-D tensor is decoded with ``tensor_to_board`` and returns the result of
    ``tb_probe_result`` for it. A tensor with two or more dimensions is treated
    as a batch along dim 0, mirroring ``tensor_to_board``, and yields a
    (possibly nested) list of results. The old single-call form crashed on
    batched input, because ``tb_probe_wdl`` asserts it receives a
    ``chess.Board`` rather than a list.
    """
    if len(x.shape) >= 2:
        return [tb_probe_tensor(x[i]) for i in range(x.shape[0])]
    return tb_probe_result(tensor_to_board(x))


def softmax(x, dim=-1):
    """Apply ``torch.nn.functional.softmax`` to ``x`` along ``dim`` (last dimension by default)."""
    functional = torch.nn.functional
    normalized = functional.softmax(x, dim=dim)
    return normalized


def get_xy(square):
    """Split a 0-63 square index into ``(square % 8, square // 8)``, i.e. (column, row) on an 8-wide grid."""
    column = square % 8
    row = square // 8
    return column, row
