{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4e61326c",
   "metadata": {},
   "source": [
    "# Workbook to generate Factual Board Answering (FBA) samples\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6a35b7a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "from fba.fba_generator import fba_generator\n",
    "from utils.sampling_manager import SamplingManager"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "470d547a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load base data. Must use the behaviorcloning dataset\n",
    "# DATA_FILE = \"chess_data/deepmind_behaviorcloning_1k.csv\"        # Smaller dataset for testing\n",
    "DATA_FILE = \"chess_data/deepmind_behaviorcloning_10k.csv\"        # Medium dataset for testing\n",
    "# DATA_FILE = \"chess_data/deepmind_behaviorcloning_100k.csv\"    # Larger dataset\n",
    "df = pd.read_csv(DATA_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "29721687",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================\n",
    "# Sampling criteria.\n",
    "# These determine the ending distribution of the dataset for each task.\n",
    "# ================================================\n",
    "BASE_SAMPLING_CRITERIA = {\n",
    "    \"movecount\": {\n",
    "        (0, 9): 0.15,\n",
    "        (10, 19): 0.3,\n",
    "        (20, 29): 0.25,\n",
    "        (30, 39): 0.20,\n",
    "        (40, None): 0.10,\n",
    "    },\n",
    "    \"player\": {\"w\": 0.5, \"b\": 0.5},\n",
    "}\n",
    "\n",
    "TASK_SAMPLING_CRITERIA = {\n",
    "    \"is_check\": {\n",
    "        \"is_check\": {\"n\": 0.50, \"w\": 0.25, \"b\": 0.25},\n",
    "        \"is_check_gen\": {\"tp\": 0.40, \"fp\": 0.10, \"tn\": 0.50},\n",
    "    },\n",
    "    \"large_mat_adv\": {\n",
    "        \"large_mat_adv_gen\": {\"tp\": 0.40, \"fp\": 0.10, \"tn\": 0.50},\n",
    "    },\n",
    "    \"mat_bal\": {\n",
    "        \"mat_bal\": {\"y\": 0.50, \"n\": 0.50},\n",
    "    },\n",
    "    \"is_legal\": {\n",
    "        \"is_legal_gen\": {\"tp\": 0.50, \"fp\": 0.10, \"tn\": 0.40},\n",
    "        \"is_legal_piece\": {\"p\": 0.10, \"b\": 0.20, \"n\": 0.20, \"r\": 0.20, \"q\": 0.20, \"k\": 0.10,},\n",
    "        \"is_legal_in_check\": {\"y\":0.1, \"n\":0.9},\n",
    "    },\n",
    "    \"under_attack\": {\n",
    "        \"under_attack_gen\": {\"tp\": 0.40, \"fp\": 0.20, \"tn\": 0.40},\n",
    "    },\n",
    "    \"mat_adv_value\": {\n",
    "        \"mat_adv_abs\": {\"0-100\": 0.4, \"100-300\": 0.3, \"300+\": 0.3},\n",
    "    },\n",
    "    \"win_prob\": {\n",
    "        \"win_prob\": {\"0-0.2\": 0.2, \"0.2-0.4\": 0.2, \"0.4-0.6\": 0.2, \"0.6-0.8\": 0.2, \"0.8-1\": 0.2},\n",
    "    },\n",
    "    \"mobility\": {\n",
    "        \"mobility_piece\": {\"p\": 0.06, \"b\": 0.2, \"n\": 0.2, \"r\": 0.2, \"q\": 0.25, \"k\": 0.09,},\n",
    "        \"mobility_moves\": {\"0-1\": 0.25, \"2-3\": 0.3, \"4-5\": 0.3, \"6+\": 0.15},\n",
    "    }, \n",
    "    \"contrastive_ntp\": {\n",
    "        \"contrastive_ntp\": {\"1\": 0.5, \"2\": 0.5, \"None\": 0},\n",
    "        \"contrastive_ntp_piece\": {\"p\": 0.05, \"b\": 0.2, \"n\": 0.2, \"r\": 0.2, \"q\": 0.25, \"k\": 0.1,},\n",
    "    }, \n",
    "    \"cloze_capture\": {\n",
    "        \"cloze_piece\": {\"p\": 0.1, \"b\": 0.2, \"n\": 0.2, \"r\": 0.2, \"q\": 0.2, \"k\": 0.1, \"None\": 0}\n",
    "    },\n",
    "    \"bestmove\": {},\n",
    "    \"multi_sample\": {},\n",
    "    \"bestline\": {},\n",
    "}\n",
    "\n",
    "\n",
    "PRINT_DISTRIBUTION_COLUMNS = {\n",
    "    \"is_check\": [\"movecount_bucket\", \"player_bucket\", \"is_check_bucket\", \"is_check_gen_bucket\"],\n",
    "    \"large_mat_adv\": [\"movecount_bucket\", \"player_bucket\", \"large_mat_adv_gen_bucket\"],\n",
    "    \"mat_bal\": [\"movecount_bucket\", \"player_bucket\", \"mat_bal_bucket\"],\n",
    "    \"is_legal\": [\"movecount_bucket\", \"player_bucket\", \"is_legal_gen_bucket\", \"is_legal_piece_bucket\", \"is_legal_in_check_bucket\"],\n",
    "    \"under_attack\": [\"movecount_bucket\", \"player_bucket\", \"under_attack_gen_bucket\"],\n",
    "    \"mat_adv_value\": [\"movecount_bucket\", \"player_bucket\", \"mat_adv_abs_bucket\"],\n",
    "    \"win_prob\": [\"movecount_bucket\", \"player_bucket\", \"win_prob_bucket\"],\n",
    "    \"mobility\": [\"movecount_bucket\", \"player_bucket\", \"mobility_piece_bucket\", \"mobility_moves_bucket\"],\n",
    "    \"contrastive_ntp\": [\"movecount_bucket\", \"player_bucket\", \"contrastive_ntp_piece_bucket\", \"contrastive_ntp_bucket\"],\n",
    "    \"cloze_capture\": [\"movecount_bucket\", \"player_bucket\", \"cloze_piece_bucket\"],\n",
    "    \"bestmove\": [\"movecount_bucket\", \"player_bucket\"],\n",
    "    \"multi_sample\": [\"movecount_bucket\", \"player_bucket\"],\n",
    "    \"bestline\": [\"movecount_bucket\", \"player_bucket\"],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "757f9913",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================================================\n",
    "# Generation criteria.\n",
    "# These the arguments we use for generation for each task. \n",
    "# The multi-sample task has 'frequency' which uses a lottery ticket system to determine which FBAs to generate (using their own generation criteria). \n",
    "# ================================================\n",
    "GENERATION_CONFIG = {    # This determines the actual generation parameters for each task. Adjust these if you'd like to alter the generation behavior for a task (piece frequency, etc.)\n",
    "    \"is_check\": {\"tp\": 0.8},\n",
    "    \"large_mat_adv\": {\"tp\": 0.8},\n",
    "    \"mat_bal\": None,\n",
    "    \"is_legal\": {\n",
    "        \"choose_legal\": {\"legal_you\": 0.5, \"legal_opp\": 0.1, \"illegal\": 0.4},\n",
    "        \"piece_freq\": {\"p\": 1, \"n\": 3, \"b\": 3, \"r\": 3, \"q\": 5, \"k\": 3},\n",
    "        \"in_check\": 0.1\n",
    "    },\n",
    "    \"under_attack\": {\n",
    "        \"legal_attack\": {\"attack_you\": 0.4, \"attack_opp\": 0.2, \"safe\": 0.4},\n",
    "        \"piece_freq\": {\"p\": 1, \"n\": 2, \"b\": 2, \"r\": 2, \"q\": 3, \"k\": 1},\n",
    "    },\n",
    "    \"mat_adv_value\": None,\n",
    "    \"win_prob\": None,\n",
    "    \"mobility\": {\n",
    "        \"piece_freq\": {\"p\": 1, \"n\": 3, \"b\": 3, \"r\": 3, \"q\": 5, \"k\": 1},\n",
    "    },\n",
    "    \"contrastive_ntp\": {\n",
    "        \"min_threshold\": 0.25,\n",
    "        \"piece_freq\": {\"p\": 1, \"n\": 5, \"b\": 5, \"r\": 5, \"q\": 8, \"k\": 3},\n",
    "    },\n",
    "    \"cloze_capture\": {\n",
    "        \"piece_freq\": {\"p\": 1, \"n\": 3, \"b\": 3, \"r\": 3, \"q\": 5, \"k\": 1},\n",
    "    },\n",
    "    \"bestmove\": None,\n",
    "    \"multi_sample\": None,\n",
    "    \"bestline\": {\n",
    "        \"plies\": (4, 6),\n",
    "        \"search_depth\": 10,\n",
    "    }\n",
    "}\n",
    "\n",
    "# New technique -- can update 'args' to equal something different if I want to adjust args\n",
    "GENERATION_CONFIG['multi_sample'] = {\n",
    "    \"generation_samples\": (4, 6),   # Dictates (min, max) number of FBAs to generate for each board\n",
    "    \"tasks\": {\n",
    "        \"is_check\": {\n",
    "            \"frequency\": 2,         # Lottery ticket frequency (higher = more likely to be chosen)\n",
    "            \"max_samples\": 1,       # Max number of samples to take from this task\n",
    "            \"args\": GENERATION_CONFIG['is_check']     # Defaults to using generation args previously defined for this task\n",
    "        },\n",
    "        \"is_legal\": {\n",
    "            \"frequency\": 5,\n",
    "            \"max_samples\": 3,\n",
    "            \"args\": GENERATION_CONFIG['is_legal']\n",
    "        },\n",
    "        \"under_attack\": {\n",
    "            \"frequency\": 5,\n",
    "            \"max_samples\": 2,\n",
    "            \"args\": GENERATION_CONFIG['under_attack']\n",
    "        },\n",
    "        \"mat_adv_value\": {\n",
    "            \"frequency\": 5,\n",
    "            \"max_samples\": 1,\n",
    "            \"args\": GENERATION_CONFIG['mat_adv_value']\n",
    "        },\n",
    "        \"mobility\": {\n",
    "            \"frequency\": 5,\n",
    "            \"max_samples\": 2,\n",
    "            \"args\": GENERATION_CONFIG['mobility']\n",
    "        },\n",
    "        \"cloze_capture\": {\n",
    "            \"frequency\": 5,\n",
    "            \"max_samples\": 2,\n",
    "            \"args\": GENERATION_CONFIG['cloze_capture']\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "03f27f80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =====================================\n",
    "# Helper functions\n",
    "# =====================================\n",
    "def generate(task, count, base_df):\n",
    "    # Start by filtering based on base_criteria (more efficient with our function)\n",
    "    base_df = SamplingManager(base_df, BASE_SAMPLING_CRITERIA).get_samples(len(base_df))\n",
    "    base_df.drop(columns=['movecount_bucket', 'player_bucket'])\n",
    "    cfg = GENERATION_CONFIG.get(task)\n",
    "    fba_df = fba_generator(task, base_df, cfg) if cfg else fba_generator(task, base_df)\n",
    "\n",
    "    # Now do final sampling based on combined criteria\n",
    "    sm   = SamplingManager(fba_df, BASE_SAMPLING_CRITERIA)\n",
    "    crit = {**BASE_SAMPLING_CRITERIA, **TASK_SAMPLING_CRITERIA.get(task, {})}\n",
    "    out  = sm.get_samples(count, criteria=crit)\n",
    "    return out\n",
    "\n",
    "def print_distributions(df, cols):\n",
    "    for col in cols:\n",
    "        print(f\"\\n{col}:\")\n",
    "        print(df[col].value_counts(normalize=True).sort_index())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b50461f3",
   "metadata": {},
   "source": [
    "# Generate our FBA Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bdde7df4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define our desired number of samples for each task.\n",
    "# You can do single task generations (i.e., just asks '1' problem) if you choose any the tasks above 'predict_bestmove'.\n",
    "# multi_sample will generate multiple FBA queries per board.\n",
    "NUM_SAMPLES_PER_TASK = {\n",
    "    # \"is_check\": 50,\n",
    "    # \"large_mat_adv\": 50,\n",
    "    # \"mat_bal\": 50,\n",
    "    # \"is_legal\": 50,\n",
    "    # \"under_attack\": 50,\n",
    "    # \"mat_adv_value\": 50,\n",
    "    # \"win_prob\": 50,\n",
    "    # \"mobility\": 50,\n",
    "    # \"contrastive_ntp\": 50,\n",
    "    # \"cloze_capture\": 50,\n",
    "    \"bestmove\": 1_000,\n",
    "    \"multi_sample\": 1_000,\n",
    "    \"bestline\": 1_000,\n",
    "}\n",
    "\n",
    "TASK_TO_LABEL_MAP = {\n",
    "    \"is_check\": \"is-check\",\n",
    "    \"large_mat_adv\": \"large-mat-adv\",\n",
    "    \"mat_bal\": \"mat-bal\",\n",
    "    \"is_legal\": \"is-legal\",\n",
    "    \"under_attack\": \"under-attack\",\n",
    "    \"mat_adv_value\": \"mat-adv-value\",\n",
    "    \"win_prob\": \"win-prob\",\n",
    "    \"mobility\": \"mobility\",\n",
    "    \"contrastive_ntp\": \"contrastive-ntp\",\n",
    "    \"cloze_capture\": \"cloze-capture\",\n",
    "    \"bestmove\": \"bestmove\",\n",
    "    \"multi_sample\": \"multi-fba\",\n",
    "    \"bestline\": \"bestline\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1f8d7692",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== bestmove ===\n",
      "[906/9060] 34384.52 samples/s\n",
      "[1812/9060] 33873.73 samples/s\n",
      "[2718/9060] 34082.29 samples/s\n",
      "[3624/9060] 33814.20 samples/s\n",
      "[4530/9060] 34140.35 samples/s\n",
      "[5436/9060] 29492.92 samples/s\n",
      "[6342/9060] 29289.26 samples/s\n",
      "[7248/9060] 29571.46 samples/s\n",
      "[8154/9060] 29695.46 samples/s\n",
      "[9060/9060] 29613.82 samples/s\n",
      "Total Number of generation errors: 0\n",
      "\n",
      "movecount_bucket:\n",
      "movecount_bucket\n",
      "0-9      0.15\n",
      "10-19    0.30\n",
      "20-29    0.25\n",
      "30-39    0.20\n",
      "40+      0.10\n",
      "Name: proportion, dtype: float64\n",
      "\n",
      "player_bucket:\n",
      "player_bucket\n",
      "b    0.5\n",
      "w    0.5\n",
      "Name: proportion, dtype: float64\n",
      "\n",
      "=== multi_sample ===\n",
      "[906/9060] 1885.44 samples/s\n",
      "[1812/9060] 1867.59 samples/s\n",
      "[2718/9060] 1902.43 samples/s\n",
      "[3624/9060] 1896.71 samples/s\n",
      "[4530/9060] 1899.33 samples/s\n",
      "[5436/9060] 1889.67 samples/s\n",
      "[6342/9060] 1887.79 samples/s\n",
      "[7248/9060] 1893.49 samples/s\n",
      "[8154/9060] 1884.13 samples/s\n",
      "[9060/9060] 1886.67 samples/s\n",
      "Total Number of generation errors: 0\n",
      "\n",
      "movecount_bucket:\n",
      "movecount_bucket\n",
      "0-9      0.15\n",
      "10-19    0.30\n",
      "20-29    0.25\n",
      "30-39    0.20\n",
      "40+      0.10\n",
      "Name: proportion, dtype: float64\n",
      "\n",
      "player_bucket:\n",
      "player_bucket\n",
      "b    0.5\n",
      "w    0.5\n",
      "Name: proportion, dtype: float64\n",
      "\n",
      "=== bestline ===\n",
      "[906/9060] 97.70 samples/s\n",
      "[1812/9060] 99.52 samples/s\n",
      "[2718/9060] 97.87 samples/s\n",
      "[3624/9060] 99.05 samples/s\n",
      "[4530/9060] 99.31 samples/s\n",
      "[5436/9060] 98.56 samples/s\n",
      "[6342/9060] 98.57 samples/s\n",
      "[7248/9060] 98.74 samples/s\n",
      "[8154/9060] 98.78 samples/s\n",
      "[9060/9060] 98.69 samples/s\n",
      "Total Number of generation errors: 714\n",
      "\n",
      "movecount_bucket:\n",
      "movecount_bucket\n",
      "0-9      0.15\n",
      "10-19    0.30\n",
      "20-29    0.25\n",
      "30-39    0.20\n",
      "40+      0.10\n",
      "Name: proportion, dtype: float64\n",
      "\n",
      "player_bucket:\n",
      "player_bucket\n",
      "b    0.5\n",
      "w    0.5\n",
      "Name: proportion, dtype: float64\n"
     ]
    }
   ],
   "source": [
    "for task, count in NUM_SAMPLES_PER_TASK.items():\n",
    "    print(f\"\\n=== {task} ===\")\n",
    "    samples = generate(task, count, df)\n",
    "    print_distributions(samples, PRINT_DISTRIBUTION_COLUMNS[task])\n",
    "    outpath = Path(f\"processed_data/fba_{TASK_TO_LABEL_MAP[task]}_{count}.jsonl\")\n",
    "    samples[f\"{task}_chat\"].to_json(outpath, orient=\"records\", lines=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
