import chess
import numpy as np


def boards_to_ndarray(boards):
    """Convert a sequence of 64-bit bitboards into stacked 8x8 boolean planes.

    Args:
        boards: sequence of 64-bit bitboards (Python ints or objects NumPy can
            convert to an unsigned 64-bit integer, such as the
            ``chess.SquareSet`` values built by ``get_observation``). Bit ``i``
            is the square with row ``i // 8`` and column ``i % 8``.

    Returns:
        A boolean array of shape ``(8, 8, len(boards))``. Element
        ``[r, c, k]`` is the bit for column ``c`` and row ``7 - r`` of
        ``boards[k]``, i.e. array row 0 is the board's 8th row.
    """
    # Force little-endian storage so the uint8 view below always starts with
    # the least-significant byte. With the native ``uint64`` dtype the byte
    # order (and therefore the square order) was wrong on big-endian hosts.
    arr64 = np.asarray(boards, dtype="<u8")
    arr8 = arr64.view(np.uint8)
    # Little-endian bytes + little bit order => flattened bit i is square i.
    bits = np.unpackbits(arr8, bitorder="little").astype(bool)
    boardstack = bits.reshape([len(boards), 8, 8])
    # Move the board index to the last axis, then flip the rows because the
    # first row of the board image is the 8th row of the bitboard.
    return np.flip(np.transpose(boardstack, [1, 2, 0]), axis=0)


def square_to_coord(s):
    """Split square index ``s`` into ``(column, row)`` = ``(s % 8, s // 8)``."""
    row, column = divmod(s, 8)
    return (column, row)


def diff(c1, c2):
    """Return the offset from coordinate ``c1`` to ``c2`` as ``(dx, dy)``."""
    (from_x, from_y), (to_x, to_y) = c1, c2
    return (to_x - from_x, to_y - from_y)


def sign(v):
    """Return 1 for positive ``v``, -1 for negative ``v`` and 0 otherwise."""
    if v > 0:
        return 1
    if v < 0:
        return -1
    return 0


def mirror_move(move):
    """Return a new ``chess.Move`` whose squares are ``chess.square_mirror`` of
    the originals; the promotion piece is carried over unchanged."""
    src, dst = (chess.square_mirror(sq) for sq in (move.from_square, move.to_square))
    return chess.Move(src, dst, promotion=move.promotion)


def result_to_int(result_str):
    """Map a PGN result string to a score: "1-0" -> 1, "0-1" -> -1, "1/2-1/2" -> 0.

    Any other string fails an assertion ("bad result").
    """
    for text, score in (("1-0", 1), ("0-1", -1), ("1/2-1/2", 0)):
        if result_str == text:
            return score
    assert False, "bad result"


def get_queen_dir(diff):
    """Map a straight or diagonal offset ``(dx, dy)`` to ``(magnitude, direction)``.

    ``magnitude`` is the distance minus one (0-7 accepted). ``direction``
    numbers the eight unit steps in (x, y) lexicographic order:
    (-1,-1)=0, (-1,0)=1, (-1,1)=2, (0,-1)=3, (0,1)=4, (1,-1)=5, (1,0)=6, (1,1)=7.
    """
    dx, dy = diff
    assert dx == 0 or dy == 0 or abs(dx) == abs(dy)
    magnitude = max(abs(dx), abs(dy)) - 1

    assert 0 <= magnitude < 8
    step = (sign(dx), sign(dy))
    compass = [(x, y) for x in (-1, 0, 1) for y in (-1, 0, 1) if x or y]
    for direction, unit in enumerate(compass):
        if unit == step:
            return magnitude, direction
    assert False, "bad queen move inputted"


def get_queen_plane(diff):
    """Return the queen-move plane ``magnitude * 8 + direction`` for offset ``diff``.

    For board moves of distance 1-7 this lies in 0-55.
    """
    DIRECTIONS = 8
    magnitude, direction = get_queen_dir(diff)
    return direction + DIRECTIONS * magnitude


def get_knight_dir(diff):
    """Return the index (0-7) of knight offset ``diff``.

    The eight L-shaped jumps are numbered in (x, y) lexicographic order:
    (-2,-1)=0, (-2,1)=1, (-1,-2)=2, (-1,2)=3, (1,-2)=4, (1,2)=5, (2,-1)=6, (2,1)=7.
    Any other offset fails an assertion.
    """
    dx, dy = diff
    jumps = [
        (x, y)
        for x in range(-2, 3)
        for y in range(-2, 3)
        if abs(x) + abs(y) == 3
    ]
    for index, (jx, jy) in enumerate(jumps):
        if dx == jx and dy == jy:
            return index
    assert False, "bad knight move inputted"


def is_knight_move(diff):
    """Return True when offset ``diff`` is a knight jump (|dx|, |dy|) in {(1, 2), (2, 1)}."""
    dx, dy = diff
    ax, ay = abs(dx), abs(dy)
    return 1 <= ax <= 2 and ax + ay == 3


def get_pawn_promotion_move(diff):
    """Return the underpromotion direction for a one-row forward pawn move.

    0 = capture toward -x, 1 = straight push, 2 = capture toward +x.
    Offsets with ``dy != 1`` or ``|dx| > 1`` fail an assertion.
    """
    dx, dy = diff
    assert dy == 1 and -1 <= dx <= 1
    return 1 + dx


def get_pawn_promotion_num(promotion):
    """Return the underpromotion piece index: knight 0, bishop 1, rook 2.

    Any other piece type (including a queen) fails an assertion.
    """
    assert promotion in (chess.KNIGHT, chess.BISHOP, chess.ROOK)
    if promotion == chess.KNIGHT:
        return 0
    if promotion == chess.BISHOP:
        return 1
    return 2


def move_to_coord(move):
    """Return the ``(column, row)`` coordinate of the square ``move`` starts from."""
    origin = move.from_square
    return square_to_coord(origin)


def get_move_plane(move):
    """Return the action plane (0-72) that encodes ``move``.

    Planes 0-55 are queen-style moves (also used for queen promotions),
    56-63 knight jumps and 64-72 underpromotions
    (``64 + 3 * direction + piece``).
    """
    QUEEN_MOVES = 56
    KNIGHT_MOVES = 8
    QUEEN_OFFSET = 0
    KNIGHT_OFFSET = QUEEN_OFFSET + QUEEN_MOVES
    UNDER_OFFSET = KNIGHT_OFFSET + KNIGHT_MOVES

    delta = diff(square_to_coord(move.from_square), square_to_coord(move.to_square))

    if is_knight_move(delta):
        return KNIGHT_OFFSET + get_knight_dir(delta)

    is_underpromotion = move.promotion is not None and move.promotion != chess.QUEEN
    if is_underpromotion:
        direction = get_pawn_promotion_move(delta)
        piece = get_pawn_promotion_num(move.promotion)
        return UNDER_OFFSET + 3 * direction + piece

    return QUEEN_OFFSET + get_queen_plane(delta)


# Lazily filled two-way cache between UCI move strings (as produced by
# legal_moves) and action indices; entries are added by make_move_mapping().
moves_to_actions, actions_to_moves = {}, {}


def action_to_move(board: chess.Board, action, player: int):
    """Decode ``action`` into a ``chess.Move`` playable on ``board``.

    The UCI string cached in ``actions_to_moves`` is mirrored back with
    ``mirror_move`` when ``player`` is truthy. Any queen promotion is first
    cleared. A pawn (on ``board``) moving from row 6 of the stored,
    un-mirrored move without an underpromotion is then promoted to a queen.
    Raises KeyError if ``action`` has not been registered yet.
    """
    stored = chess.Move.from_uci(actions_to_moves[action])
    from_row = square_to_coord(stored.from_square)[1]
    move = mirror_move(stored) if player else stored

    if move.promotion == chess.QUEEN:
        move.promotion = None
    if (
        move.promotion is None
        and from_row == 6
        and board.piece_type_at(move.from_square) == chess.PAWN
    ):
        move.promotion = chess.QUEEN
    return move


def make_move_mapping(uci_move):
    """Assign ``uci_move`` its action index and record it in both lookup tables.

    The index is ``(column * 8 + row) * 73 + plane``, where ``(column, row)``
    is the move's from-square and ``plane`` comes from ``get_move_plane``.
    """
    PLANES = 73
    move = chess.Move.from_uci(uci_move)
    column, row = square_to_coord(move.from_square)
    action = (column * 8 + row) * PLANES + get_move_plane(move)

    moves_to_actions[uci_move] = action
    actions_to_moves[action] = uci_move


def legal_moves(orig_board: chess.Board):
    """Return the action indices of all legal moves in ``orig_board``.

    When black is to move the board is mirrored first, so moves are always
    expressed as if white were moving. Moves not seen before are registered
    through ``make_move_mapping``.

    The action space is an 8x8x73 array. Each of the 8x8 positions identifies
    the square from which to pick up a piece. The first 56 planes encode
    possible queen moves for any piece: a number of squares [1..7] in which
    the piece will be moved, along one of eight relative compass directions
    {N, NE, E, SE, S, SW, W, NW}. The next 8 planes encode possible knight
    moves for that piece. The final 9 planes encode possible underpromotions
    for pawn moves or captures in two possible diagonals, to knight, bishop or
    rook respectively. Other pawn moves or captures from the seventh rank are
    promoted to a queen.
    """
    board = orig_board.mirror() if orig_board.turn == chess.BLACK else orig_board

    actions = []
    for move in board.legal_moves:
        uci = move.uci()
        if uci not in moves_to_actions:
            make_move_mapping(uci)
        actions.append(moves_to_actions[uci])
    return actions


def get_observation(orig_board: chess.Board, player: int):
    """Return the observation array for ``player``.

    Observation is an 8x8x(L + P) array: L auxiliary (metadata) planes
    followed by P planes with both sides' piece positions and a repetition flag.

    The board is first put into the observer's perspective. When ``player`` is
    truthy the planes are read from ``orig_board.mirror()``, so the observer's
    pieces are always the white pieces of that board.

    Channel layout (last axis, 0-based):

    * 0-3   all ones if the perspective board has castling rights for the rook
            on H1, A1, H8, A8 respectively. NOTE: kingside comes before
            queenside here, the reverse of lc0's legacy plane order.
    * 4     all ones if ``player`` is truthy (lc0: "we are black").
    * 5     a single square set at index ``min(halfmove_clock // 2, 63)``.
    * 6     all ones (lc0 uses this to help the network find board edges).
    * 7-12  our pawns, knights, bishops, rooks, queens, king.
    * 13-18 their pawns, knights, bishops, rooks, queens, king.
    * 19    all ones if the current position has occurred before
            (``is_repetition(2)``).

    The pawn planes also carry an lc0-style en passant marker (see the
    comment block near the end of this function).

    Args:
        orig_board: position to encode. Its move stack is used for the
            repetition plane.
        player: truthy to encode from the mirrored (black) perspective.

    Returns:
        A boolean array of shape ``(8, 8, 20)`` built by ``boards_to_ndarray``.
    """
    board = orig_board.mirror() if player else orig_board

    all_squares = chess.SquareSet(chess.BB_ALL)
    HISTORY_LEN = 1
    PLANES_PER_BOARD = 13
    AUX_SIZE = 7
    RESULT_SIZE = AUX_SIZE + HISTORY_LEN * PLANES_PER_BOARD
    # One empty SquareSet per plane. Planes set to "all ones" below share the
    # single ``all_squares`` object and are never mutated. The in-place edits
    # only touch plane 5 (its own empty set) and the pawn planes (the sets
    # returned by board.pieces()).
    result = [chess.SquareSet(chess.BB_EMPTY) for _ in range(RESULT_SIZE)]
    AUX_OFF = 0
    BASE = AUX_SIZE

    # Auxiliary planes, modelled on LeelaChessZero's "legacy" input planes.
    castling_rooks = (chess.BB_H1, chess.BB_A1, chess.BB_H8, chess.BB_A8)
    for offset, rook_bb in enumerate(castling_rooks):
        if board.castling_rights & rook_bb:
            result[AUX_OFF + offset] = all_squares
    if player:
        result[AUX_OFF + 4] = all_squares
    # SquareSet.add() only accepts squares 0-63, but halfmove_clock is not
    # bounded (e.g. positions loaded from FEN), so clamp instead of raising.
    result[AUX_OFF + 5].add(min(board.halfmove_clock // 2, 63))
    result[AUX_OFF + 6] = all_squares

    # Piece planes. After the optional mirror the observer is always white.
    OURS = chess.WHITE
    THEIRS = chess.BLACK
    piece_types = (
        chess.PAWN,
        chess.KNIGHT,
        chess.BISHOP,
        chess.ROOK,
        chess.QUEEN,
        chess.KING,
    )
    for offset, piece_type in enumerate(piece_types):
        result[BASE + offset] = board.pieces(piece_type, OURS)
        result[BASE + 6 + offset] = board.pieces(piece_type, THEIRS)

    # Repetition plane. Query ``orig_board``, not ``board``: Board.mirror()
    # returns a copy without the move stack, so a mirrored board can never
    # detect a repetition. Using it left this plane empty for the black
    # observer. Repetition does not depend on perspective.
    if orig_board.is_repetition(2):
        result[BASE + 12] = all_squares

    # En passant, LeelaChessZero style.
    # FEN marks en passant with the capture target square, e.g. `e3` in
    # `4k3/8/8/8/4Pp2/8/8/4K3 b - e3 99 50`. Such a sparse, unstructured signal
    # is hard for a network to learn. Following lc0, the pawn that has just
    # made its initial two-square advance is instead drawn on its own back
    # rank: the 1st rank for our pawns, the 8th rank for theirs.
    #
    # Example: a white pawn plays e2e4 and a black pawn on f4 could capture it
    # en passant. The white pawn is shown on e1 instead of e4. The flag is
    # set after any initial two-square advance, whether or not an opposing
    # pawn can capture, and only lasts until the next move.
    #
    #    The board             Channel 7 (our pawns), white observer
    #    8  · · · · ♚ · · ·    8  · · · · · · · ·
    #    7  · · · · · · · ·    7  · · · · · · · ·
    #    6  · · · · · · · ·    6  · · · · · · · ·
    #    5  · · · · · · · ·    5  · · · · · · · ·
    #    4  · · · · ♙ ♟ · ·    4  · · · · · · · ·
    #    3  · · · · · · · ·    3  · · · · · · · ·
    #    2  · · · · · · · ·    2  · · · · · · · ·
    #    1  · · · · ♔ · · ·    1  · · · · 1 · · ·
    #       a b c d e f g h       a b c d e f g h
    #    FEN: 4k3/8/8/8/4Pp2/8/8/4K3 b - e3 99 50
    #
    # More details:
    # https://github.com/Farama-Foundation/PettingZoo/blob/master/pettingzoo/classic/chess/chess.py#L42
    # https://github.com/LeelaChessZero/lc0/blob/master/src/chess/board.cc#L1114
    ep_square = board.ep_square
    if ep_square is not None:
        ep_file = chess.square_file(ep_square)
        ep_rank = chess.square_rank(ep_square)
        if ep_rank == 2:
            # Our pawn advanced two squares and stands one row above the ep square.
            result[BASE + 0].remove(ep_square + 8)
            result[BASE + 0].add(chess.square(ep_file, 0))
        elif ep_rank == 5:
            # Their pawn advanced two squares and stands one row below the ep square.
            result[BASE + 6].remove(ep_square - 8)
            result[BASE + 6].add(chess.square(ep_file, 7))

    return boards_to_ndarray(result)
