import io
import zstandard as zstd
from tqdm import tqdm
import chess
import chess.pgn
import chess.engine
import numpy as np
from contextlib import ExitStack
from pathlib import Path
import gzip
from nets.policy_index import policy_index
import tensorflow as tf
from typing import overload, Iterator
import sys
import h5py
import ast
from functools import partial

# Board-string symbol -> input plane index: white pieces "PNBRQK" -> 0-5,
# black pieces "pnbrqk" -> 6-11, and the en-passant marker "e" -> 12.
PLANE_KEY = dict(zip("PNBRQKpnbrqke", range(13)))
# Root of the shared storage area (hdf5_to_dataset reads result shards from it).
STORAGE_DIR = "/storage1/fs1/XXXX-1/Active/chess"

# Ordered CSV schema: column name -> TensorFlow dtype used as the column default
# when parsing with make_csv_dataset. The ORDER matters: it must match the rows
# written by process_pgn, which are the five board_to_arr fields (pieces + four
# castling flags) followed by fourteen per-move fields.
# Rows are written from the mover's perspective: board_to_arr mirrors
# Black-to-move positions, so "W_*" flags refer to the side to move.
COLUMN_NAMES = {
    "pieces": tf.string,  # 64 chars in chess.SQUARES order; " " empty, "e" en passant
    "W_OOO": tf.int32,  # castling flags as 0/1
    "W_OO": tf.int32,
    "B_OOO": tf.int32,
    "B_OO": tf.int32,
    "starttime": tf.int32,  # base time from the TimeControl header
    "startinc": tf.int32,  # increment from the TimeControl header
    "curtime": tf.float32,  # mover's clock before the move
    "curtimeopp": tf.float32,  # opponent's clock
    "rating": tf.int32,  # mover's Elo
    "ratingopp": tf.int32,  # opponent's Elo
    "user": tf.string,
    "useropp": tf.string,
    "enginebefore": tf.string,  # eval before the move (mover's POV) or "-"
    "moveplayed": tf.int32,  # policy index from board_to_uci
    "result": tf.float32,  # game result for the mover: 1, 0.5 or 0
    "engineafter": tf.string,  # eval after the move (mover's POV) or "-"
    "movesleft": tf.int32,
    "timeleft": tf.float32,  # mover's clock after the move ("0" if unknown)
}


def line_to_board(data: str | dict) -> chess.Board:
    """
    Rebuild a ``chess.Board`` from one CSV row or its parsed column dict.

    Args:
        data: A raw CSV line, or a mapping keyed by the COLUMN_NAMES columns.
            ``data["pieces"]`` has one character per square in
            ``chess.SQUARES`` order: " " for empty, "e" for the en-passant
            square, otherwise a piece symbol. Castling columns are "1"/"0"
            (integer 1/0 values are accepted as well).

    Returns:
        The reconstructed board. The side to move is left at python-chess's
        default (White). This matches board_to_arr, which mirrors
        Black-to-move positions before writing them.
    """
    if isinstance(data, str):
        data = dict(zip(COLUMN_NAMES.keys(), data.split(",")))
    board = chess.Board()
    board.clear_board()
    for square, symbol in zip(chess.SQUARES, data["pieces"]):
        if symbol == " ":
            continue
        if symbol == "e":
            board.ep_square = square
        else:
            board.set_piece_at(square, chess.Piece.from_symbol(symbol))

    castle = "".join(
        flag
        for flag, key in zip("QKqk", ("W_OOO", "W_OO", "B_OOO", "B_OO"))
        if str(data[key]) == "1"
    )
    # Always set castling rights explicitly. clear_board() does not reset
    # them, so rows without any castling rights would otherwise inherit the
    # starting position's full rights.
    board.set_castling_fen(castle or "-")
    return board


def line_to_arr_eager(
    data: str | dict,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Encode one CSV row as NumPy training arrays (eager / non-graph version).

    Args:
        data: A raw CSV line, or a dict keyed by the COLUMN_NAMES columns.

    Returns:
        (X, Y, Z, Q, M):
            X: float32 (112, 8, 8) planes. Planes 0-12 are one-hot masks for
               each PLANE_KEY symbol ("e" lands on plane 12); 104-107 hold the
               castling flags; 111 is all ones.
            Y: float32 (1858,) one-hot of ``moveplayed``.
            Z: float32 (3,) one-hot at index ``2 - 2 * result``.
            Q: zeros shaped like Z.
            M: float32 scalar ``movesleft``.
    """
    row = dict(zip(COLUMN_NAMES.keys(), data.split(","))) if isinstance(data, str) else data

    grid = np.array(list(row["pieces"])).reshape(8, 8)

    X = np.zeros((112, 8, 8), dtype=np.float32)
    # One mask per symbol; PLANE_KEY already maps the en-passant marker to 12.
    for symbol, plane in PLANE_KEY.items():
        X[plane] = (grid == symbol).astype(np.float32)

    # Castling flags fill whole planes 104..107.
    for plane, key in enumerate(("W_OOO", "W_OO", "B_OOO", "B_OO"), start=104):
        X[plane] = float(row[key])

    X[111] = 1.0

    Y = np.zeros(1858, dtype=np.float32)
    Y[int(row["moveplayed"])] = 1.0

    Z = np.zeros(3, dtype=np.float32)
    Z[2 - int(float(row["result"]) * 2)] = 1.0

    M = np.array(float(row["movesleft"]), dtype=np.float32)
    return X, Y, Z, np.zeros_like(Z), M


def line_to_arr(
    data: str | dict, combine_outputs: bool = False
) -> (
    tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
    | tuple[tf.Tensor, tuple[tf.Tensor, tf.Tensor, tf.Tensor]]
):
    """
    Convert a batch of parsed CSV rows to TensorFlow tensors for graph mode.

    Builds the same 112x8x8 encoding as ``line_to_arr_eager``, using batched
    TF ops so it can run inside ``Dataset.map``:
    - planes 0-12 come from the ``pieces`` string via PLANE_KEY;
    - planes 104-107 hold the castling flags;
    - plane 111 is all ones.

    Args:
        data: A dictionary of column tensors, as produced in csv_to_dataset by
            ``make_csv_dataset(..., batch_size=1)`` followed by ``.batch()``.
            The ``[:, 0]`` indexing below requires each value to be rank 2,
            presumably ``(batch, 1)`` — TODO confirm. The signature also
            accepts a raw CSV string, but NOTE(review): the resulting dict of
            Python strings would fail at the ``[:, 0]`` indexing.
        combine_outputs: If True, return ``(X, (Y, Z, M))`` instead of
            ``(X, Y, Z, Q, M)``.

    Returns:
        X: float32 planes of shape ``(batch, 112, 8, 8)``.
        Y: one-hot ``moveplayed`` over 1858 policy indices.
        Z: one-hot result over 3 classes at index ``2 - 2 * result``
            (1.0 -> 0, 0.5 -> 1, 0.0 -> 2).
        Q: zeros shaped like Z (omitted when ``combine_outputs`` is True).
        M: ``movesleft`` cast to float32.
    """
    if isinstance(data, str):
        data = dict(zip(COLUMN_NAMES.keys(), data.split(",")))

    # Split each 64-char board string into an (batch, 1, 8, 8) character grid,
    # then start from an all-zero (batch, 112, 8, 8) float tensor.
    pieces = tf.reshape(
        tf.strings.bytes_split(data["pieces"]).to_tensor(), (-1, 1, 8, 8)
    )
    X = tf.tile(
        tf.reshape(tf.zeros_like(pieces, dtype=tf.float32), (-1, 1, 8, 8)),
        tf.constant([1, 112, 1, 1], dtype=tf.int32),
    )
    # mask[..., p, :, :] == p, so `mask == i` selects plane i.
    mask = tf.ones_like(X, dtype=tf.int32) * tf.reshape(tf.range(112), (1, 112, 1, 1))
    for i, l in enumerate(PLANE_KEY):
        # Turn the character grid into a 0/1 float mask for symbol `l`. PLANE_KEY
        # symbols are single letters, so using them as regex patterns is a
        # literal match.
        tmp = pieces
        tmp = tf.strings.regex_replace(tmp, l, "1")
        tmp = tf.strings.regex_replace(tmp, "[^1]", "0")
        tmp = tf.strings.to_number(tmp, tf.float32)
        X = tf.add(X, tf.multiply(tf.cast(mask == i, tf.float32), tmp))

    # Castling flags are broadcast over the whole board on planes 104-107.
    for i, k in zip([104, 105, 106, 107], ["W_OOO", "W_OO", "B_OOO", "B_OO"]):
        X = tf.add(
            X,
            tf.multiply(
                tf.cast(mask == i, tf.float32),
                tf.cast(tf.reshape(data[k], (-1, 1, 1, 1)), tf.float32),
            ),
        )

    # Plane 111 is constant ones.
    X = tf.add(
        X,
        tf.multiply(
            tf.cast(mask == 111, tf.float32),
            tf.cast(tf.reshape(1, (-1, 1, 1, 1)), tf.float32),
        ),
    )

    X = tf.convert_to_tensor(X)
    Y = tf.convert_to_tensor(tf.one_hot(data["moveplayed"][:, 0], 1858))
    Z = tf.convert_to_tensor(
        tf.one_hot(tf.cast(2 - data["result"][:, 0] * 2, tf.int32), 3)
    )
    Q = tf.convert_to_tensor(tf.zeros_like(Z))
    M = tf.convert_to_tensor(tf.cast(data["movesleft"], tf.float32))

    if combine_outputs:
        return X, (Y, Z, M)
    else:
        return X, Y, Z, Q, M


def compress_dataset(*args):
    """
    Pack a batch of 112x8x8 position planes into 105 bytes per position.

    Accepts either ``(X, (Y, Z, M))`` or ``(X, Y, Z, Q, M)`` and returns the
    same structure. The packed byte layout per position is:

    - bytes 0-103: layers 0-12 (13 * 64 = 832 bits), most significant bit first.
    - byte 104: castling layers 104-107 in bits 3..0, read from square (0, 0)
      because each castling layer is constant across the board.

    Layers 13-103 and 108-110 (always zero) and layer 111 (always one) are not
    stored. One-hot Y and Z are replaced by their argmax indices (int32 and
    int8). Q and M pass through unchanged.

    Raises:
        ValueError: If called with a number of arguments other than 2 or 5.
    """
    if len(args) == 5:
        X, Y, Z, Q, M = args
    elif len(args) == 2:
        X, (Y, Z, M) = args
        Q = None
    else:
        raise ValueError(f"Expected 2 or 5 arguments, got {len(args)}")

    n = tf.shape(X)[0]

    # Layers 0-12 -> (n, 104, 8) bit groups -> one byte per group.
    msb_first = tf.constant([7, 6, 5, 4, 3, 2, 1, 0], dtype=tf.uint8)
    piece_bits = tf.cast(tf.reshape(X[:, 0:13, :, :], [n, -1, 8]), tf.uint8)
    piece_bytes = tf.reduce_sum(tf.bitwise.left_shift(piece_bits, msb_first), axis=2)

    # Layers 104-107 -> a single byte holding four flags.
    castle_bits = tf.cast(X[:, 104:108, 0, 0], tf.uint8)
    castle_byte = tf.reduce_sum(
        tf.bitwise.left_shift(castle_bits, tf.constant([3, 2, 1, 0], dtype=tf.uint8)),
        axis=1,
        keepdims=True,
    )

    packed = tf.concat([piece_bytes, castle_byte], axis=1)  # (n, 105)

    move_idx = tf.cast(tf.argmax(Y, axis=1), tf.int32)
    result_idx = tf.cast(tf.argmax(Z, axis=1), tf.int8)

    if Q is None:
        return packed, (move_idx, result_idx, M)
    return packed, move_idx, result_idx, Q, M


def uncompress_dataset(*args):
    """
    Decompress chess position data back to full 112x8x8 format.

    Inverse of ``compress_dataset``. Accepts
    ``(X_packed, (Y_indices, Z_indices, M))`` or
    ``(X_packed, Y_indices, Z_indices, Q, M)`` and returns the same structure.
    X is expanded to float32 ``(batch, 112, 8, 8)``. Y and Z are turned back
    into one-hot float32 vectors of length 1858 and 3.

    Planes are rebuilt as follows:
    - planes 0-12 come from the first 104 bytes (most significant bit first);
    - planes 104-107 come from bits 3..0 of byte 104, broadcast over the board;
    - plane 111 is all ones;
    - every other plane is zero.

    Raises:
        ValueError: If called with a number of arguments other than 2 or 5.
    """
    if len(args) == 2:
        X_packed, (Y_indices, Z_indices, M) = args
        Q = None
    elif len(args) == 5:
        X_packed, Y_indices, Z_indices, Q, M = args
    else:
        raise ValueError(f"Expected 2 or 5 arguments, got {len(args)}")

    batch_size = tf.shape(X_packed)[0]

    # Planes 0-12: expand each of the 104 piece bytes into its 8 bits.
    piece_tiled = tf.tile(tf.expand_dims(X_packed[:, :104], axis=2), [1, 1, 8])
    bit_masks = tf.constant([128, 64, 32, 16, 8, 4, 2, 1], dtype=tf.uint8)
    piece_bits = tf.greater(tf.bitwise.bitwise_and(piece_tiled, bit_masks), 0)
    piece_layers = tf.reshape(
        tf.cast(piece_bits, tf.float32), [batch_size, 13, 8, 8]
    )

    # Planes 104-107: bits 3..0 of the castling byte, constant across the board.
    castling_tiled = tf.tile(tf.expand_dims(X_packed[:, 104], axis=1), [1, 4])
    castling_masks = tf.constant([8, 4, 2, 1], dtype=tf.uint8)
    castling_bits = tf.cast(
        tf.greater(tf.bitwise.bitwise_and(castling_tiled, castling_masks), 0),
        tf.float32,
    )
    castling_layers = tf.tile(
        tf.reshape(castling_bits, [batch_size, 4, 1, 1]), [1, 1, 8, 8]
    )

    # Assemble all 112 planes in one concat (13 + 91 + 4 + 3 + 1 = 112). This
    # avoids materialising and re-slicing several intermediate full-size tensors.
    X = tf.concat(
        [
            piece_layers,
            tf.zeros([batch_size, 91, 8, 8], dtype=tf.float32),  # planes 13-103
            castling_layers,
            tf.zeros([batch_size, 3, 8, 8], dtype=tf.float32),  # planes 108-110
            tf.ones([batch_size, 1, 8, 8], dtype=tf.float32),  # plane 111
        ],
        axis=1,
    )

    Y = tf.one_hot(Y_indices, 1858, dtype=tf.float32)
    Z = tf.one_hot(Z_indices, 3, dtype=tf.float32)

    if Q is None:
        return X, (Y, Z, M)
    return X, Y, Z, Q, M


def csv_to_dataset(
    data_paths: str | list[str],
    rating_ranges: list[tuple[int, int]] | None = None,
    batch_size: int = 1024,
    shuffle: bool = True,
    shuffle_buffer_size: int = 1000000,
    epochs: int = 1,
    seed: int = 0,
    take: int | list = -1,
    combine_outputs: bool = False,
    cache_binary: str | None = None,
    endgame: int | None = None,
) -> tf.data.Dataset:
    """
    Turn a csv.gz file (or list or glob of csv.gz files) into a tensorflow dataset.

    Each path becomes its own pipeline: parse rows, filter, batch, encode with
    ``line_to_arr`` and take. The pipelines are then interleaved with
    ``sample_from_datasets``.

    Args:
        data_paths: One path/glob or a list of them (``.gz`` -> GZIP,
            ``.zlib`` -> ZLIB, anything else uncompressed).
        rating_ranges: Optional ``[lo, hi)`` rating range per path. Both the
            ``rating`` and ``ratingopp`` columns must fall inside it.
        batch_size: Rows per batch.
        shuffle: Whether to shuffle rows (both at CSV-read time and, when
            caching, after the cache).
        shuffle_buffer_size: Total shuffle buffer size, split evenly across paths.
        epochs: Number of passes over each file.
        seed: Seed for shuffling and for sampling between paths.
        take: Number of batches to take per path (``-1`` for all). Either a
            single int applied to every path or one entry per path.
        combine_outputs: If True, elements are ``(X, (Y, Z, M))`` instead of
            ``(X, Y, Z, Q, M)``.
        cache_binary: If not None, compress and cache the dataset
            (``""`` = in memory, otherwise a file prefix). Cached rows are then
            reshuffled and rebatched.
        endgame: If set, keep only positions with fewer than this many pieces.

    Returns:
        The combined ``tf.data.Dataset``.

    Raises:
        ValueError: If ``data_paths`` is empty, or if ``rating_ranges`` or a
            list ``take`` does not have one entry per path.
    """

    def filter_rating(data, rating_range):
        if rating_range is None:
            return True
        if not (rating_range[0] <= data["rating"] < rating_range[1]):
            return False
        elif not (rating_range[0] <= data["ratingopp"] < rating_range[1]):
            return False
        else:
            return True

    def filter_endgame(data):
        if endgame is None:
            return True
        pieces_chars = tf.strings.bytes_split(data["pieces"])
        is_space = tf.equal(pieces_chars, tf.constant(" "))
        is_en_passant = tf.equal(pieces_chars, tf.constant("e"))
        is_empty = tf.logical_or(is_space, is_en_passant)
        num_pieces = tf.reduce_sum(tf.cast(tf.logical_not(is_empty), tf.int32))
        return tf.less(num_pieces, endgame)

    if not isinstance(data_paths, list):
        data_paths = [data_paths]
    if not data_paths:
        raise ValueError("data_paths must contain at least one path")

    if rating_ranges is None:
        rating_ranges = [None] * len(data_paths)
    if len(rating_ranges) != len(data_paths):
        raise ValueError(
            f"Got {len(rating_ranges)} rating_ranges for {len(data_paths)} data_paths"
        )

    # A scalar take applies to every path. Previously only a list worked,
    # because take is zipped with data_paths below.
    if not isinstance(take, list):
        take = [take] * len(data_paths)
    if len(take) != len(data_paths):
        raise ValueError(f"Got {len(take)} take values for {len(data_paths)} data_paths")

    shuffle_buffer_size = shuffle_buffer_size // len(data_paths)

    def get_compression(data_path):
        if data_path.endswith(".gz"):
            return "GZIP"
        if data_path.endswith(".zlib"):
            return "ZLIB"
        return None  # uncompressed

    line_to_arr_fn = partial(line_to_arr, combine_outputs=combine_outputs)

    datasets = [
        tf.data.experimental.make_csv_dataset(
            data_path,
            1,
            compression_type=get_compression(data_path),
            num_epochs=epochs,
            column_names=list(COLUMN_NAMES.keys()),
            column_defaults=list(COLUMN_NAMES.values()),
            shuffle=shuffle,
            shuffle_seed=seed,
            shuffle_buffer_size=shuffle_buffer_size,
            num_parallel_reads=tf.data.AUTOTUNE,
        )
        .filter(partial(filter_rating, rating_range=rating_range))
        .filter(filter_endgame)
        .batch(batch_size)
        .map(line_to_arr_fn, num_parallel_calls=tf.data.AUTOTUNE)
        .take(t)
        for data_path, rating_range, t in zip(data_paths, rating_ranges, take)
    ]
    dataset = tf.data.Dataset.sample_from_datasets(
        datasets, seed=seed, stop_on_empty_dataset=False
    )

    print(data_paths, rating_ranges, take)
    if cache_binary is not None:
        # Cache the compact encoding ("" caches in memory), then restore
        # row-level shuffling across batches before decoding again.
        dataset = dataset.map(compress_dataset).cache(cache_binary).unbatch()
        if shuffle:
            dataset = dataset.shuffle(
                buffer_size=shuffle_buffer_size * len(data_paths),
                seed=seed,
                reshuffle_each_iteration=True,
            )
        dataset = dataset.batch(batch_size).map(uncompress_dataset)

    return dataset


def puzzle_to_dataset(
    puzzle_path: str = "/storage1/fs1/XXXX-1/Active/chess/data/lichess_db_puzzle.csv.zst",
    batch_size: int = 1024,
    shuffle_buffer_size: int = 1000000,
    seed: int = 0,
    skip: int = 0,
    take: int = -1,
    combine_outputs: bool = False,
    cache_binary: str | None = None,
    end_early=False,
    both_sides=True,
    avg_puzzle_len=1,
) -> tf.data.Dataset:
    """
    Build a dataset of positions from a zstd-compressed Lichess puzzle CSV.

    Each puzzle row is replayed from its FEN (column 1) through its
    space-separated UCI moves (column 2). Before each selected move, the
    position is encoded with ``board_to_arr`` and the move becomes a one-hot
    policy target Y. Z, Q and M are zero-filled. Rows that fail to parse are
    skipped; positions gathered before the failure are kept.

    Args:
        puzzle_path: Path to the ``.csv.zst`` puzzle file (header row skipped).
        batch_size: Positions per output batch.
        shuffle_buffer_size: Buffer size for the puzzle-level shuffle and for
            the position-level shuffle after caching.
        seed: Seed for both shuffles. The puzzle-level shuffle is not
            reshuffled per iteration, so ``skip``/``take`` select a stable
            subset for a given seed.
        skip: Number of shuffled puzzles to skip.
        take: Target number of positions; ``-1`` keeps everything.
        combine_outputs: If True, elements are ``(X, (Y, Z, M))`` instead of
            ``(X, Y, Z, Q, M)``.
        cache_binary: File prefix for ``Dataset.cache``; ``None`` caches in memory.
        end_early: If True, return the per-puzzle dataset (one variable-length
            batch of positions per puzzle) before unbatching and caching.
        both_sides: If False, only odd-indexed moves become targets.
        avg_puzzle_len: Estimated positions per puzzle. Used to convert
            ``take`` (positions) into a number of puzzles to read.

    Returns:
        A ``tf.data.Dataset`` of batches (or of per-puzzle stacks if ``end_early``).
    """

    def load_puzzle_lines():
        """Generator that yields raw puzzle lines from the CSV"""
        with open(puzzle_path, "rb") as f:
            dctx = zstd.ZstdDecompressor()
            decompressed_content = dctx.stream_reader(f)
            text_stream = io.TextIOWrapper(decompressed_content, encoding="utf-8")
            next(text_stream)  # Skip header

            for line in text_stream:
                yield line.strip()

    def process_puzzle_line(line, combine_outputs=False):
        def _process(line_bytes):
            line = line_bytes.numpy().decode("utf-8")
            parts = line.split(",")

            positions = []

            try:
                fen = parts[1]
                moves_str = parts[2]
                board = chess.Board(fen)
                moves = [chess.Move.from_uci(m) for m in moves_str.split()]

                for i, move in enumerate(moves):
                    # moves[0] is the opponent's setup move in the Lichess puzzle
                    # format, so odd indices are the solver's moves. With
                    # both_sides, every move is used.
                    if i % 2 == 1 or both_sides:
                        X, _ = board_to_arr(board)
                        # board_to_uci mirrors Black-to-move positions itself,
                        # matching the orientation board_to_arr used for X.
                        move_idx = board_to_uci(board, move.uci())

                        Y = np.zeros(1858, dtype=np.float32)
                        Y[move_idx] = 1.0

                        Z = np.zeros(3, dtype=np.float32)
                        Q = np.zeros(3, dtype=np.float32)
                        M = np.zeros(1, dtype=np.float32)

                        positions.append((X, Y, Z, Q, M))

                    board.push(move)

            except Exception:
                pass  # Deliberate best-effort: skip the rest of invalid puzzles

            if positions:
                return tuple(
                    np.stack([p[k] for p in positions]) for k in range(5)
                )

            return (
                np.zeros((0, 112, 8, 8), dtype=np.float32),
                np.zeros((0, 1858), dtype=np.float32),
                np.zeros((0, 3), dtype=np.float32),
                np.zeros((0, 3), dtype=np.float32),
                np.zeros((0, 1), dtype=np.float32),
            )

        # Use py_function to wrap the Python processing
        X, Y, Z, Q, M = tf.py_function(
            func=_process,
            inp=[line],
            Tout=[tf.float32, tf.float32, tf.float32, tf.float32, tf.float32],
        )

        # py_function loses static shapes; restore them (leading dim = positions).
        X.set_shape([None, 112, 8, 8])
        Y.set_shape([None, 1858])
        Z.set_shape([None, 3])
        Q.set_shape([None, 3])
        M.set_shape([None, 1])

        if combine_outputs:
            return X, (Y, Z, M)
        else:
            return X, Y, Z, Q, M

    process_fn = partial(process_puzzle_line, combine_outputs=combine_outputs)

    if take < 0:
        num_puzzles = take
    else:
        # Round up: flooring requested zero puzzles whenever take < avg_puzzle_len.
        num_puzzles = int(-(-take // avg_puzzle_len))

    # Start with loading puzzle lines
    dataset = (
        tf.data.Dataset.from_generator(
            load_puzzle_lines, output_signature=tf.TensorSpec(shape=(), dtype=tf.string)
        )
        .shuffle(
            buffer_size=shuffle_buffer_size, seed=seed, reshuffle_each_iteration=False
        )
        .skip(skip)
        .take(num_puzzles)
        .map(process_fn, num_parallel_calls=tf.data.AUTOTUNE)
    )

    if end_early:
        return dataset

    dataset = (
        dataset.unbatch()
        .take(take)
        .batch(batch_size)
        .map(compress_dataset, num_parallel_calls=tf.data.AUTOTUNE)
        .cache(cache_binary if cache_binary is not None else "")
        .unbatch()
        .shuffle(
            buffer_size=shuffle_buffer_size, seed=seed, reshuffle_each_iteration=True
        )
        .batch(batch_size)
        .map(uncompress_dataset, num_parallel_calls=tf.data.AUTOTUNE)
    )

    return dataset


def hdf5_to_dataset(
    mode: str = "gt_y",
    onefile: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Load positions and policy targets from the precomputed HDF5 result shards.

    Reads up to 100 shards named
    ``{STORAGE_DIR}/leela/results_new/1600/{i1}_{i2}.hdf5`` (1,000,000-wide
    index ranges) and collects every group under ``positions``.

    Args:
        mode: ``"gt_y"`` uses each position's stored ``gt_y`` array as Y. A
            name starting with ``"sf17_"`` builds a one-hot Y from
            ``models/<mode>/<pos>.attrs["value_uci"]``, mapped through the
            position's ``legal_moves`` attribute.
        onefile: If True, stop after the first shard.

    Returns:
        (X, Y, Z, Q, M) arrays stacked over all positions. Z is always
        ``[1, 0, 0]``; Q and M are zeros.

    Raises:
        ValueError: If ``mode`` is neither ``"gt_y"`` nor ``"sf17_*"``. Such a
            mode previously left Y shorter than X without any error.
    """
    if mode != "gt_y" and not mode.startswith("sf17_"):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'gt_y' or 'sf17_*'")

    X, Y, Z, Q, M = [], [], [], [], []
    for idx in tqdm(range(100), ncols=80):
        i1, i2 = idx * 1000000, (idx + 1) * 1000000
        with h5py.File(
            f"{STORAGE_DIR}/leela/results_new/1600/{i1}_{i2}.hdf5", "r"
        ) as f:
            positions = f["positions"]  # type: ignore
            for pos in positions:  # type: ignore
                group = positions[pos]  # type: ignore
                X.append(group["gt_x"][()])  # type: ignore
                if mode == "gt_y":
                    Y.append(group["gt_y"][()])  # type: ignore
                else:
                    # legal_moves is stored as a Python-literal string; presumably a
                    # mapping from UCI move to policy index — TODO confirm.
                    legals = ast.literal_eval(group.attrs["legal_moves"])  # type: ignore
                    uci = f["models"][mode][pos].attrs["value_uci"]  # type: ignore
                    y = np.zeros(1858, dtype=np.float32)
                    y[legals[uci]] = 1
                    Y.append(y)
                Z.append(np.array([1, 0, 0], dtype=np.float32))
                Q.append(np.zeros(3, dtype=np.float32))
                M.append(np.zeros(1, dtype=np.float32))
        if onefile:
            break

    return np.array(X), np.array(Y), np.array(Z), np.array(Q), np.array(M)


@overload
def board_to_uci(board: chess.Board, uci: str) -> int: ...


@overload
def board_to_uci(board: chess.Board, uci: list[str]) -> list[int]: ...


def board_to_uci(board: chess.Board, uci: str | list[str]) -> int | list[int]:
    """
    Map UCI move string(s) to indices in ``policy_index``.

    Moves are expressed from the side to move's perspective. When Black is to
    move, ranks are flipped (rank r -> 9 - r) so the move lines up with the
    board board_to_arr produces, which is mirrored to White-to-move. A trailing
    "n" (UCI knight-promotion suffix) is stripped before lookup; presumably
    policy_index encodes knight promotions without a suffix — confirm.

    Args:
        board: Position the move(s) are played from; only ``board.turn`` is read.
        uci: A single UCI move or a list of them.

    Returns:
        The policy index for a single move, or a list of indices for a list.

    Raises:
        ValueError: If a (flipped) move is not present in ``policy_index``.
        TypeError: If ``uci`` is neither a str nor a list.
    """
    flip = board.turn == chess.BLACK

    def _flip_uci(move: str) -> str:
        chars = list(move)
        chars[1] = str(9 - int(move[1]))
        chars[3] = str(9 - int(move[3]))
        return "".join(chars)

    def _lookup(move: str) -> int:
        if flip:
            move = _flip_uci(move)
        return policy_index.index(move.rstrip("n"))

    if isinstance(uci, str):
        return _lookup(uci)
    if isinstance(uci, list):
        return [_lookup(u) for u in uci]
    raise TypeError(f"uci must be a str or list[str], got {type(uci).__name__}")


def board_to_arr(board: chess.Board) -> tuple[np.ndarray, list[str]]:
    """
    Encode a board as a 112x8x8 float32 array plus its CSV fields.

    Positions with Black to move are mirrored first, so the encoding is
    always from White-to-move's perspective.

    Guide to layers:
    0 - White Pawn
    1 - White Knight
    2 - White Bishop
    3 - White Rook
    4 - White Queen
    5 - White King
    6 - Black Pawn
    7 - Black Knight
    8 - Black Bishop
    9 - Black Rook
    10 - Black Queen
    11 - Black King
    12 - en passant square, if there is one
    13 to 103 - Zeros (we are ignoring move history)
    104 - White Queenside Castling
    105 - White Kingside Castling
    106 - Black Queenside Castling
    107 - Black Kingside Castling
    108 - Zeros (We are ignoring move history)
    109 - Zeros (We are ignoring move history)
    110 - Zeros (Same as Leela)
    111 - Ones (Same as Leela)

    Returns:
        (array, csv_fields). ``csv_fields`` is ``[pieces, W_OOO, W_OO, B_OOO,
        B_OO]``: ``pieces`` has one character per square in chess.SQUARES
        order (" " empty, "e" en passant), and the castling flags are "0"/"1".
    """
    if board.turn == chess.BLACK:
        board = board.mirror()

    planes = np.zeros((112, 64), dtype=np.float32)
    squares: list[str] = []

    for sq in chess.SQUARES:
        piece = board.piece_at(sq)
        if piece is None:
            squares.append(" ")
            continue
        sym = piece.symbol()
        squares.append(sym)
        planes[PLANE_KEY[sym], sq] = 1

    ep = board.ep_square
    if ep:
        planes[12, ep] = 1
        squares[ep] = "e"

    castling = (
        (104, board.has_queenside_castling_rights(chess.WHITE)),
        (105, board.has_kingside_castling_rights(chess.WHITE)),
        (106, board.has_queenside_castling_rights(chess.BLACK)),
        (107, board.has_kingside_castling_rights(chess.BLACK)),
    )
    flags = ["0", "0", "0", "0"]
    for slot, (plane, allowed) in enumerate(castling):
        if allowed:
            planes[plane] = 1
            flags[slot] = "1"

    planes[111] = 1

    return planes.reshape(112, 8, 8), ["".join(squares), *flags]


def process_pgn(pgn_lines: list[str], header: dict[str, str]):
    if not pgn_lines or not header:
        return []

    if (
        header["TimeControl"] == "-"
        or header["Termination"] == "Abandoned"
        or "FEN" in header
    ):
        return []

    startTime, startInc = header["TimeControl"].split("+")
    tc = int(startTime) + 40 * int(startInc)
    if not (3 * 60 <= tc < 8 * 60):
        return []

    if header.get("WhiteTitle", "") == "BOT" or header.get("BlackTitle", "") == "BOT":
        return []

    ratings = {
        chess.WHITE: header["WhiteElo"],
        chess.BLACK: header["BlackElo"],
    }

    users = {
        chess.WHITE: header["White"],
        chess.BLACK: header["Black"],
    }
    curTimes = {
        chess.WHITE: startTime,
        chess.BLACK: startTime,
    }

    match header["Result"].split("-")[0]:
        case "1":
            results = 1
        case "0":
            results = 0
        case "1/2":
            results = 0.5
        case _:
            return []
    results = {
        chess.WHITE: str(results),
        chess.BLACK: str(1 - results),
    }

    pgn = chess.pgn.read_game(io.StringIO("".join(pgn_lines)))
    if not pgn:
        return
    if pgn.errors:
        raise ValueError(pgn.errors)

    game_len = len(list(pgn.mainline())) - 1

    isWhite = True
    lines: list[list[str]] = []
    board = chess.Board()
    engineBefore: chess.engine.PovScore | None = None
    for movei, node in enumerate(pgn.mainline()):
        _, line = board_to_arr(board)
        move = board_to_uci(board, node.move.uci())
        rating, ratingOpp = ratings[isWhite], ratings[not isWhite]
        user, userOpp = users[isWhite], users[not isWhite]
        curTime = curTimes[isWhite]
        result = results[isWhite]
        movesLeft = game_len - movei
        timeAfter = node.clock()
        timeAfter = str(timeAfter) if timeAfter is not None else "0"
        engineAfter = node.eval()

        line = line + [
            startTime,
            startInc,
            curTimes[isWhite],
            curTimes[not isWhite],
            rating,
            ratingOpp,
            user,
            userOpp,
            str(engineBefore.pov(isWhite)) if engineBefore else "-",
            str(move),
            result,
            str(engineAfter.pov(isWhite)) if engineAfter else "-",
            str(movesLeft),
            timeAfter,
        ]

        lines.append(line)
        board.push_uci(node.move.uci())

        curTimes[isWhite] = timeAfter
        engineBefore = engineAfter
        isWhite = not isWhite
    return lines


def stream_lines(path: str) -> Iterator[str]:
    """
    Lazily yield decoded UTF-8 text lines from a zstd-compressed file.

    The decompression reader and its text wrapper are closed together with
    the file when iteration finishes or the generator is closed early. The
    original implementation left them for the garbage collector.
    """
    with open(path, "rb") as raw:
        reader = zstd.ZstdDecompressor().stream_reader(raw)
        with io.TextIOWrapper(reader, encoding="utf-8") as text_stream:
            yield from text_stream


def stream_games(path: str) -> Iterator[tuple[list[str], dict[str, str]]]:
    """
    Split a zstd-compressed PGN file into ``(raw_lines, header_tags)`` per game.

    A new game starts at every ``[Event `` tag line. The first item yielded is
    an empty ``([], {})`` pair, emitted when the first ``[Event `` line is
    seen. It is kept for index compatibility, because process_lichess_file uses
    the enumeration index for its start/end/shard bounds. The final game is
    flushed after the last line.
    """
    lines: list[str] = []
    header: dict[str, str] = {}
    for line in stream_lines(path):
        if line.startswith("[Event "):
            yield (lines, header)
            lines, header = [], {}

        if line.startswith("["):
            # '[Key "Value"]' -> header[Key] = Value. Split only once so that
            # values containing ' "' don't break the unpacking.
            key, value = line.rstrip("\r\n")[1:-1].split(' "', 1)
            header[key] = value[:-1]
        lines.append(line)

    # The last game has no following "[Event " line to trigger its yield.
    if lines:
        yield (lines, header)


def process_lichess_file(month: str, start: int, end: int) -> None:
    """
    Convert games [start, end) of one monthly Lichess dump into gzip CSV shards.

    Games are read from ``lichess_db_standard_rated_{month}.pgn.zst``. Rows
    produced by process_pgn go to
    ``csv_data_z3/{month}/{bucket}/{i}_{i + file_size}.csv.gz``, where
    ``bucket`` is the mover's rating rounded down to a multiple of 100 and a
    new shard starts every ``file_size`` games. Shards that receive no rows
    are deleted. Rows whose rating bucket is outside 300-4000 are skipped.

    Returns silently if the input file does not exist.
    """
    # TODO(crystall): make this a "constant"
    data_path = "/storage1/fs1/XXXX-1/Active/chess/data"
    input_file = f"{data_path}/lichess_db_z/lichess_db_standard_rated_{month}.pgn.zst"
    if not Path(input_file).exists():
        return  # TODO(crystall): raise an error
    output_folder = f"{data_path}/csv_data_z3/{month}"
    ratings = range(300, 4100, 100)  # TODO(crystall): make this a constant or arg

    for rating in ratings:
        Path(f"{output_folder}/{rating}").mkdir(parents=True, exist_ok=True)

    # TODO(crystall): make this a "constant" or arg
    file_size = int(1e6)
    batch_size = int(1e4)

    assert file_size % batch_size == 0
    assert start % file_size == 0, f"{start} is not divisible by {file_size}"
    assert end % file_size == 0, f"{end} is not divisible by {file_size}"

    def _close_shard(fouts, out_paths, empty):
        """Close every file of a shard and delete the ones never written to."""
        for r, fout in fouts.items():
            fout.close()
            if empty[r]:
                Path(out_paths[r]).unlink()

    with ExitStack() as exit_stack:
        # Shard state must persist across games. It is only replaced when a new
        # shard starts (previously it was reset on every iteration, so any game
        # after a shard's first one raised KeyError).
        out_paths: dict[int, str] = {}
        fouts: dict = {}
        empty: dict[int, bool] = {}
        for i, (pgn, header) in tqdm(
            enumerate(stream_games(input_file)),
            ncols=80,
            mininterval=600,
            total=end,
        ):
            if i < start:
                continue
            if i >= end:
                break

            # Start a new set of output files every file_size games.
            if i % file_size == 0:
                _close_shard(fouts, out_paths, empty)
                out_paths = {
                    rating: f"{output_folder}/{rating}/{i}_{i + file_size}.csv.gz"
                    for rating in ratings
                }
                fouts = {
                    rating: exit_stack.enter_context(
                        gzip.open(f"{out_paths[rating]}", "wb", compresslevel=9)
                    )
                    for rating in ratings
                }
                empty = {rating: True for rating in ratings}

            lines = process_pgn(pgn, header)
            if not lines:
                continue
            for line in lines:
                r = int(line[9]) // 100 * 100  # line[9] is the mover's rating
                if r not in fouts:
                    continue  # rating outside the configured buckets
                fouts[r].write(f"{','.join(str(field) for field in line)}\n".encode())
                empty[r] = False

        # Close the last shard and remove empty files.
        _close_shard(fouts, out_paths, empty)


if __name__ == "__main__":
    # Usage: <script> MONTH START END
    #   MONTH is the suffix of lichess_db_standard_rated_{MONTH}.pgn.zst;
    #   START and END are game indices in millions, e.g. "0 3" -> games [0, 3e6).
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} MONTH START_MILLIONS END_MILLIONS")
    month = sys.argv[1]
    start = int(sys.argv[2]) * 1000000
    end = int(sys.argv[3]) * 1000000
    process_lichess_file(month, start, end)


# months = ["07", "08", "09", "10", "11"]
# fnames = list(range(120))
# for rating in [800, 1600, 2400]:
#     for month in months:
#         for fname in fnames:
#             path = f"{DATA_DIR}/2024-{month}/{rating}/{fname*1000000}_{(fname+1)*1000000}.csv.gz"
#             if not pathlib.Path(path).exists():
#                 continue
#             pathlib.Path(f"{DATA_DIR}/train/2024-{month}/{rating}/").mkdir(exist_ok=True, parents=True)
#             path2 = f"{DATA_DIR}/train/2024-{month}/{rating}/{fname*1000000}_{(fname+1)*1000000}.csv.gz"
#             os.symlink(path, path2)

# for rating in [1200, 2000]:
#     for year in [2023, 2024]:
#         pattern = f"{DATA_DIR}/{year}-*/{rating}/*"
#         for file in tqdm(glob.glob(pattern), desc=f"{year}/{rating}"):
#             link = f"{DATA_DIR}/train/{'/'.join(file.split('/')[-3:])}"
#             Path('/'.join(link.split('/')[:-1])).mkdir(parents=True, exist_ok=True)
#             if not Path(link).exists():
#                 os.symlink(file, link)
# for rating in [1200, 2000]:
#     for year in [2025]:
#         pattern = f"{DATA_DIR}/{year}-*/{rating}/*"
#         for file in tqdm(glob.glob(pattern), desc=f"{year}/{rating}"):
#             link = f"{DATA_DIR}/test/{'/'.join(file.split('/')[-3:])}"
#             Path('/'.join(link.split('/')[:-1])).mkdir(parents=True, exist_ok=True)
#             if not Path(link).exists():
#                 os.symlink(file, link)
