{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3aefb278",
   "metadata": {},
   "source": [
    "# Eval Generation\n",
    "---\n",
    "Notebook takes in our eval dataset and generates our various desired eval datasets we can load in.\n",
    "\n",
    "These will be saved as parquets in the evals folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2ba9b81f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import ast\n",
    "import json\n",
    "import random\n",
    "import pandas as pd\n",
    "\n",
    "from utils.board import get_piece_name_at_location, convert_board"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "df1d7f8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =========================================\n",
    "# Config\n",
    "# =========================================\n",
    "PARAMS = {\n",
    "    'output_dir': \"processed_data\",\n",
    "    'max_samples': 5,\n",
    "    'task': 'eval',      # {'rejsampling', 'eval', 'verl', 'verl_eval'}\n",
    "    'board_notation': 'uniform_visual',     # {fen, spaced_fen, visual, uniform_visual}\n",
    "    'use_fixed_ids': False,\n",
    "    'final_samples_fixed_ids': 10,\n",
    "    'legalmoves_minmoves': 3,       # If you set = 2, way too many samples are for pawns (which are very easy anyway)\n",
    "    'worstmove_movethresh': 0.2,\n",
    "    'worstmove_providedmoves': 5,\n",
    "    'bestmove_movethresh': 0.2,\n",
    "    'bestmove_providedmoves': 5,\n",
    "    'predictmove_minpossiblemoves': 3\n",
    "}\n",
    "FILE_MAPPING = {\n",
    "    \"eval\": './chess_data/deepmind62k_evals_1k.csv',\n",
    "    \"verl_eval\": './chess_data/deepmind62k_evals_1k.csv',\n",
    "    \"verl\": './chess_data/deepmind62k_train_50k.csv',\n",
    "    \"rejsampling\": './chess_data/deepmind62k_train_50k.csv',\n",
    "}\n",
    "SAVED_FILES = []\n",
    "\n",
    "# =========================================\n",
    "# Import / process our data\n",
    "# =========================================\n",
    "input_filename = FILE_MAPPING[PARAMS['task']]\n",
    "df = pd.read_csv(input_filename)\n",
    "df['Move'] = df['Move'].apply(ast.literal_eval)\n",
    "df['Win Probability'] = df['Win Probability'].apply(ast.literal_eval)\n",
    "\n",
    "fixed_ids = None\n",
    "if PARAMS['use_fixed_ids']:\n",
    "    fixed_ids = random.sample(\n",
    "        df[\"FEN ID\"].unique().tolist(),\n",
    "        PARAMS['max_samples']\n",
    "    )\n",
    "    pool = df[df[\"FEN ID\"].isin(fixed_ids)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21aca1b7",
   "metadata": {},
   "source": [
    "# Generate Data   \n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "626a4b22",
   "metadata": {},
   "source": [
    "### Legal Moves   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4af81aed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved JSONL to ./processed_data/legalmoves_eval_5.jsonl\n"
     ]
    }
   ],
   "source": [
    "# Start by shuffling if not using fixed ids\n",
    "if not PARAMS['use_fixed_ids']:\n",
    "    pool = df.sample(frac=1, ignore_index=True)\n",
    "\n",
    "# Now generate our samples\n",
    "outputs = []\n",
    "for index, row in pool.iterrows():\n",
    "    if not PARAMS['use_fixed_ids'] and len(outputs) >= PARAMS['max_samples']:\n",
    "        break\n",
    "    \n",
    "    board    = row[\"FEN\"]\n",
    "    moveset  = row[\"Move\"]\n",
    "    winprobs = row[\"Win Probability\"]\n",
    "    fen_id   = row[\"FEN ID\"]\n",
    "\n",
    "    # Get the counts of the various pieces in the moveset\n",
    "    piece_counts = {}\n",
    "    for move in moveset:\n",
    "        piece_pos = move[:2]\n",
    "        piece_counts[piece_pos] = piece_counts.get(piece_pos, 0) + 1\n",
    "    valid_pieces = [k for k, v in piece_counts.items() if v >= PARAMS['legalmoves_minmoves']]\n",
    "    if not valid_pieces:\n",
    "        continue  # Skip if no valid pieces w/ enough moves\n",
    "\n",
    "    # Sample a piece from the valid pieces\n",
    "    piece = random.choice(valid_pieces)\n",
    "    piece_name = get_piece_name_at_location(board, piece)\n",
    "    if piece_name is None:\n",
    "        print(f\"Piece not found at {piece} in FEN: {board}\")\n",
    "        continue\n",
    "    \n",
    "    user = f\"\"\"Below is a board in a game you're currently playing.\n",
    "\n",
    "{convert_board(board, PARAMS['board_notation'])}\n",
    "\n",
    "You must provide a list of all legal moves for the {piece_name} at {piece}.\n",
    "\n",
    "You may want to think out loud to help finalize your answer. However, you must provide your answer within answer tags (e.g., <answer> list_of_moves </answer>).\n",
    "\n",
    "The moves must be provided as a list, in UCI notation, and within answer tags in order to be accepted.\"\"\"\n",
    "    \n",
    "    chat_history = [\n",
    "        ['system', \"chess_task_sysprompt.txt\"], \n",
    "        ['user', user], \n",
    "        ['assistant', \"\"], \n",
    "    ]\n",
    "    \n",
    "    legal_moves = [move for move in moveset if move.startswith(piece)]\n",
    "    info = {\n",
    "        'board_id': fen_id,\n",
    "        'board': board,\n",
    "        'answer': legal_moves,\n",
    "        'task_data': (piece_name, piece)\n",
    "    }\n",
    "    \n",
    "    outputs.append({\n",
    "        'chat': chat_history,\n",
    "        'info': info\n",
    "    })\n",
    "\n",
    "# Export as jsonl\n",
    "jsonl_path = f\"./{PARAMS['output_dir']}/legalmoves_{PARAMS['task']}_{len(outputs)}.jsonl\"\n",
    "with open(jsonl_path, 'w', encoding='utf-8') as f:\n",
    "    for obj in outputs:\n",
    "        f.write(json.dumps(obj) + '\\n')\n",
    "print(f\"Saved JSONL to {jsonl_path}\")\n",
    "SAVED_FILES.append(jsonl_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d6b3e29",
   "metadata": {},
   "source": [
    "### Worst Move"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d24a7100",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved JSONL to ./processed_data/worstmove_eval_5.jsonl\n"
     ]
    }
   ],
   "source": [
    "# Start by shuffling if not using fixed ids\n",
    "if not PARAMS['use_fixed_ids']:\n",
    "    pool = df.sample(frac=1, ignore_index=True)\n",
    "\n",
    "# Iterate through and get samples\n",
    "outputs = []\n",
    "for index, row in pool.iterrows():\n",
    "    if not PARAMS['use_fixed_ids'] and len(outputs) >= PARAMS['max_samples']:\n",
    "        break\n",
    "\n",
    "    board = row['FEN']\n",
    "    moveset = row['Move']\n",
    "    win_probs = row['Win Probability']\n",
    "    id = row['FEN ID']\n",
    "\n",
    "    move_prob_pairs = list(zip(moveset, win_probs))\n",
    "    worst_move, worst_move_win_prob = min(move_prob_pairs, key=lambda x: x[1])\n",
    "\n",
    "    filtered_moves = [\n",
    "        move for move, prob in move_prob_pairs\n",
    "        if prob > PARAMS['worstmove_movethresh'] + worst_move_win_prob\n",
    "    ]\n",
    "\n",
    "    if len(filtered_moves) < PARAMS['worstmove_providedmoves'] - 1:\n",
    "        continue\n",
    "\n",
    "    sampled_moves = random.sample(filtered_moves, PARAMS['worstmove_providedmoves'] - 1)\n",
    "    sampled_moves.append(worst_move)\n",
    "    random.shuffle(sampled_moves)\n",
    "\n",
    "    user_prompt = f\"\"\"Below is a board in a game you're currently playing.\n",
    "\n",
    "{convert_board(board, PARAMS['board_notation'])}\n",
    "    \n",
    "You must choose the worst move from the following moves: {sampled_moves}. \n",
    "\n",
    "You may want to think out loud to help finalize your answer. However, you must provide your answer within answer tags (e.g., <answer> worst_move </answer>).\n",
    "\n",
    "The move must be provided in UCI notation and within answer tags in order to be accepted.\"\"\"\n",
    "\n",
    "    chat_history = [\n",
    "        ['system', \"chess_task_sysprompt.txt\"],\n",
    "        ['user', user_prompt],\n",
    "        ['assistant', \"\"], \n",
    "    ]\n",
    "    info = {\n",
    "        'board_id': id,\n",
    "        'board': board,\n",
    "        'answer': {'answer': worst_move, 'candidates': sampled_moves}\n",
    "    }\n",
    "\n",
    "    outputs.append({\n",
    "        'chat': chat_history,\n",
    "        'info': info\n",
    "    })\n",
    "\n",
    "# Export as jsonl\n",
    "jsonl_path = f\"./{PARAMS['output_dir']}/worstmove_{PARAMS['task']}_{len(outputs)}.jsonl\"\n",
    "with open(jsonl_path, 'w', encoding='utf-8') as f:\n",
    "    for obj in outputs:\n",
    "        f.write(json.dumps(obj) + '\\n')\n",
    "print(f\"Saved JSONL to {jsonl_path}\")\n",
    "SAVED_FILES.append(jsonl_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fedf059",
   "metadata": {},
   "source": [
    "### Best Move   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e10201a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved JSONL to ./processed_data/bestmove_eval_5.jsonl\n"
     ]
    }
   ],
   "source": [
    "# Start by shuffling if not using fixed ids\n",
    "if not PARAMS['use_fixed_ids']:\n",
    "    pool = df.sample(frac=1, ignore_index=True)\n",
    "\n",
    "# Iterate through and get samples\n",
    "outputs = []\n",
    "for index, row in pool.iterrows():\n",
    "    if not PARAMS['use_fixed_ids'] and len(outputs) >= PARAMS['max_samples']:\n",
    "        break\n",
    "\n",
    "    board = row['FEN']\n",
    "    moveset = row['Move']\n",
    "    win_probs = row['Win Probability']\n",
    "    id = row['FEN ID']\n",
    "\n",
    "    move_prob_pairs = list(zip(moveset, win_probs))\n",
    "\n",
    "    # Find the best move (highest win probability)\n",
    "    best_move, best_move_win_prob = max(move_prob_pairs, key=lambda x: x[1])\n",
    "\n",
    "    # Filter out moves with win probability < best_move_win_prob - best_move_thresh\n",
    "    filtered_moves = [\n",
    "        move for move, prob in move_prob_pairs\n",
    "        if prob < best_move_win_prob - PARAMS['bestmove_movethresh']\n",
    "    ]\n",
    "\n",
    "    if len(filtered_moves) < PARAMS['bestmove_providedmoves'] - 1:\n",
    "        continue\n",
    "\n",
    "    sampled_moves = random.sample(filtered_moves, PARAMS['bestmove_providedmoves'] - 1)\n",
    "    sampled_moves.append(best_move)\n",
    "    random.shuffle(sampled_moves)\n",
    "\n",
    "    user_prompt = f\"\"\"Below is a board in a game you're currently playing.\n",
    "\n",
    "{convert_board(board, PARAMS['board_notation'])}\n",
    "    \n",
    "You must choose the best move from the following moves: {sampled_moves}. \n",
    "\n",
    "You may want to think out loud to help finalize your answer. However, you must provide your answer within answer tags (e.g., <answer> best_move </answer>).\n",
    "\n",
    "The move must be provided in UCI notation and within answer tags in order to be accepted.\"\"\"\n",
    "\n",
    "    chat_history = [\n",
    "        ['system', \"chess_task_sysprompt.txt\"],\n",
    "        ['user', user_prompt],\n",
    "        ['assistant', \"\"], \n",
    "    ]\n",
    "    info = {\n",
    "        'board_id': id,\n",
    "        'board': board,\n",
    "        'answer': {'answer': best_move, 'candidates': sampled_moves}\n",
    "    }\n",
    "\n",
    "    outputs.append({\n",
    "        'chat': chat_history,\n",
    "        'info': info\n",
    "    })\n",
    "\n",
    "# Export as jsonl\n",
    "jsonl_path = f\"./{PARAMS['output_dir']}/bestmove_{PARAMS['task']}_{len(outputs)}.jsonl\"\n",
    "with open(jsonl_path, 'w', encoding='utf-8') as f:\n",
    "    for obj in outputs:\n",
    "        f.write(json.dumps(obj) + '\\n')\n",
    "print(f\"Saved JSONL to {jsonl_path}\")\n",
    "SAVED_FILES.append(jsonl_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1eeaeb39",
   "metadata": {},
   "source": [
    "### Predict Move   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "833c8b79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved JSONL to ./processed_data/predictmove_eval_5.jsonl\n"
     ]
    }
   ],
   "source": [
    "# Start by shuffling if not using fixed ids\n",
    "if not PARAMS['use_fixed_ids']:\n",
    "    pool = df.sample(frac=1, ignore_index=True)\n",
    "\n",
    "# Iterate through and get samples\n",
    "outputs = []\n",
    "for index, row in pool.iterrows():\n",
    "    if not PARAMS['use_fixed_ids'] and len(outputs) >= PARAMS['max_samples']:\n",
    "        break\n",
    "\n",
    "    board = row['FEN']\n",
    "    moveset = row['Move']\n",
    "    win_probs = row['Win Probability']\n",
    "    id = row['FEN ID']\n",
    "\n",
    "    if len(moveset) < PARAMS['predictmove_minpossiblemoves']:\n",
    "        continue\n",
    "\n",
    "    move_prob_dict = dict(zip(moveset, win_probs))\n",
    "\n",
    "    user_prompt = f\"\"\"Below is a chess board from your current game.\n",
    "\n",
    "{convert_board(board, PARAMS['board_notation'])}\n",
    "\n",
    "You must select the best move from this position and return it within answer tags. Your answer must be formatted as <answer> my_move </answer>, where my_move is a legal move in UCI notation.\n",
    "\n",
    "Think step by step if necessary, but do not omit the answer tags or UCI format. Only answers in the correct format will be accepted.\"\"\"\n",
    "\n",
    "    chat_history = [\n",
    "        ['system', \"chess_task_sysprompt.txt\"],\n",
    "        ['user', user_prompt],\n",
    "        ['assistant', \"\"], \n",
    "    ]\n",
    "\n",
    "    info = {\n",
    "        'board_id': id,\n",
    "        'board': board,\n",
    "        'answer': move_prob_dict\n",
    "    }\n",
    "\n",
    "    outputs.append({\n",
    "        'chat': chat_history,\n",
    "        'info': info\n",
    "    })\n",
    "\n",
    "# Export as jsonl\n",
    "jsonl_path = f\"./{PARAMS['output_dir']}/predictmove_{PARAMS['task']}_{len(outputs)}.jsonl\"\n",
    "with open(jsonl_path, 'w', encoding='utf-8') as f:\n",
    "    for obj in outputs:\n",
    "        f.write(json.dumps(obj) + '\\n')\n",
    "print(f\"Saved JSONL to {jsonl_path}\")\n",
    "SAVED_FILES.append(jsonl_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lang-chess",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
