{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e68efe62b2e4bd686e80b5abc91f75f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/2 shards):   0%|          | 0/71763 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0584884177954b65996d84d3d715e184",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from datasets import load_from_disk, concatenate_datasets, Dataset, DatasetDict\n",
    "\n",
    "# Load: read the 21 saved chunks and stitch them into a single dataset\n",
    "chunk_paths = [\n",
    "    f\"./data/Lichess/chunks/chunk_{chunk_index}\" for chunk_index in range(21)\n",
    "]\n",
    "ds = concatenate_datasets([load_from_disk(path) for path in chunk_paths])\n",
    "df = ds.to_pandas()\n",
    "# Stop here if any cell of the combined frame is missing\n",
    "has_missing_values = df.isna().any().any()\n",
    "assert not has_missing_values\n",
    "\n",
    "# Normalize difficulty: min-max scale the rating, divide the deviation by the\n",
    "# same rating range, and add the percentile rank of the scaled rating\n",
    "lowest_rating = df[\"Rating\"].min()\n",
    "rating_span = df[\"Rating\"].max() - lowest_rating\n",
    "df[\"NormRating\"] = (df[\"Rating\"] - lowest_rating) / rating_span\n",
    "df[\"NormRatingDeviation\"] = df[\"RatingDeviation\"] / rating_span\n",
    "df[\"RatingQuantile\"] = df[\"NormRating\"].rank(pct=True)\n",
    "\n",
    "\n",
    "# Tags\n",
    "# Keep only puzzles tagged \"oneMove\" and not tagged \"short\", then strip the\n",
    "# \"oneMove\" tag from every remaining theme list\n",
    "def _is_one_move_only(themes):\n",
    "    return \"oneMove\" in themes and \"short\" not in themes\n",
    "\n",
    "\n",
    "df = df[df[\"Themes\"].apply(_is_one_move_only)]\n",
    "df[\"Themes\"] = df[\"Themes\"].apply(\n",
    "    lambda themes: [theme for theme in themes if theme != \"oneMove\"]\n",
    ")\n",
    "# Maps raw entries of the Themes column to display labels, one dict per category\n",
    "goal_tag_labels = {\n",
    "    \"equality\": \"Equality\",\n",
    "    \"advantage\": \"Advantage\",\n",
    "    \"crushing\": \"Crushing\",\n",
    "    \"mate\": \"Checkmate\",\n",
    "}\n",
    "motif_tag_labels = {\n",
    "    \"advancedPawn\": \"Advanced Pawn\",\n",
    "    \"attackingF2F7\": \"Attacking f2 or f7\",\n",
    "    \"capturingDefender\": \"Capture the Defender\",\n",
    "    \"discoveredAttack\": \"Discovered Attack\",\n",
    "    \"doubleCheck\": \"Double Check\",\n",
    "    \"exposedKing\": \"Exposed King\",\n",
    "    \"fork\": \"Fork\",\n",
    "    \"hangingPiece\": \"Hanging Piece\",\n",
    "    \"kingsideAttack\": \"Kingside Attack\",\n",
    "    \"pin\": \"Pin\",\n",
    "    \"queensideAttack\": \"Queenside Attack\",\n",
    "    \"sacrifice\": \"Sacrifice\",\n",
    "    \"skewer\": \"Skewer\",\n",
    "    \"trappedPiece\": \"Trapped Piece\",\n",
    "    \"attraction\": \"Attraction\",\n",
    "    \"clearance\": \"Clearance\",\n",
    "    \"defensiveMove\": \"Defensive Move\",\n",
    "    \"deflection\": \"Deflection\",\n",
    "    \"interference\": \"Interference\",\n",
    "    \"intermezzo\": \"Intermezzo\",\n",
    "    \"quietMove\": \"Quiet Move\",\n",
    "    \"xRayAttack\": \"X-ray Attack\",\n",
    "    \"zugzwang\": \"Zugzwang\",\n",
    "}\n",
    "phase_tag_labels = {\n",
    "    \"opening\": \"Opening\",\n",
    "    \"middlegame\": \"Middlegame\",\n",
    "    \"endgame\": \"Endgame\",\n",
    "    \"bishopEndgame\": \"Bishop Endgame\",\n",
    "    \"knightEndgame\": \"Knight Endgame\",\n",
    "    \"pawnEndgame\": \"Pawn Endgame\",\n",
    "    \"queenEndgame\": \"Queen Endgame\",\n",
    "    \"rookEndgame\": \"Rook Endgame\",\n",
    "    \"queenRookEndgame\": \"Queen and Rook Endgame\",\n",
    "}\n",
    "mate_tag_labels = {\n",
    "    \"mateIn1\": \"Mate in 1\",\n",
    "    \"anastasiaMate\": \"Anastasia's Mate\",\n",
    "    \"arabianMate\": \"Arabian Mate\",\n",
    "    \"backRankMate\": \"Back Rank Mate\",\n",
    "    \"bodenMate\": \"Boden's Mate\",\n",
    "    \"doubleBishopMate\": \"Double Bishop Mate\",\n",
    "    \"dovetailMate\": \"Dovetail Mate\",\n",
    "    \"hookMate\": \"Hook Mate\",\n",
    "    \"smotheredMate\": \"Smothered Mate\",\n",
    "}\n",
    "special_move_tag_labels = {\n",
    "    \"castling\": \"Castling\",\n",
    "    \"enPassant\": \"En Passant\",\n",
    "    \"promotion\": \"Promotion\",\n",
    "    \"underPromotion\": \"Underpromotion\",\n",
    "}\n",
    "origin_tag_labels = {\n",
    "    \"master\": \"Master Games\",\n",
    "    \"masterVsMaster\": \"Master Vs Master Games\",\n",
    "    \"superGM\": \"Super GM Games\",\n",
    "}\n",
    "# Output column name -> label mapping for that column (insertion order kept)\n",
    "all_tags_dicts = {\n",
    "    \"GoalTags\": goal_tag_labels,\n",
    "    \"MotifTags\": motif_tag_labels,\n",
    "    \"PhaseTags\": phase_tag_labels,\n",
    "    \"MateTags\": mate_tag_labels,\n",
    "    \"SpecialMoveTags\": special_move_tag_labels,\n",
    "    \"OriginTags\": origin_tag_labels,\n",
    "}\n",
    "# Build one list-valued column per category, holding the labels of the themes\n",
    "# that appear in that category's mapping\n",
    "def _labels_in(themes, label_by_theme):\n",
    "    return [label_by_theme[theme] for theme in themes if theme in label_by_theme]\n",
    "\n",
    "\n",
    "def _has_single_label(labels):\n",
    "    return len(labels) == 1\n",
    "\n",
    "\n",
    "for tag_col_name, tags_dict in all_tags_dicts.items():\n",
    "    df[tag_col_name] = df[\"Themes\"].apply(_labels_in, args=(tags_dict,))\n",
    "# Keep puzzles with exactly one goal label and unwrap it from its list\n",
    "df = df[df[\"GoalTags\"].apply(_has_single_label)]\n",
    "df[\"GoalTags\"] = df[\"GoalTags\"].apply(lambda labels: labels[0])\n",
    "# Themes has been split into the per-category columns, so drop it\n",
    "df = df.drop(columns=[\"Themes\"])\n",
    "\n",
    "\n",
    "# Subsample using tag: cap Checkmate puzzles at 30000, keep every other puzzle\n",
    "is_checkmate = df[\"GoalTags\"] == \"Checkmate\"\n",
    "checkmate_df = df[is_checkmate]\n",
    "# DataFrame.sample raises ValueError when asked for more rows than exist, so\n",
    "# request at most the number of Checkmate puzzles actually available\n",
    "checkmate_df = checkmate_df.sample(min(30000, len(checkmate_df)), random_state=42)\n",
    "df = pd.concat([checkmate_df, df[~is_checkmate]])\n",
    "# Shuffle all rows (seeded) so the two groups are interleaved\n",
    "df = df.sample(frac=1, random_state=42)\n",
    "\n",
    "\n",
    "# Rename and reorder: keys are the current column names, listed in the order\n",
    "# the columns should end up in; values are the new names\n",
    "rename_dict = {\n",
    "    \"PuzzleId\": \"puzzle_id\",\n",
    "    \"NormRating\": \"rating\",\n",
    "    \"NormRatingDeviation\": \"rating_std\",\n",
    "    \"RatingQuantile\": \"rating_quantile\",\n",
    "    \"GoalTags\": \"tag\",\n",
    "    \"CurFEN\": \"fen\",\n",
    "    \"CurSimpPGN\": \"pgn\",\n",
    "    \"CurAnnoPGN\": \"annotated_pgn\",\n",
    "    \"CurUCIs\": \"uci_seq\",\n",
    "    \"CurSANs\": \"san_seq\",\n",
    "    \"AnswerSAN\": \"answer_san\",\n",
    "    \"AnswerUCI\": \"answer_uci\",\n",
    "    \"NumMoves\": \"init_num_moves\",\n",
    "    \"PuzzlePlayer\": \"player\",\n",
    "    \"Popularity\": \"popularity_score\",\n",
    "    \"NbPlays\": \"puzzle_num_plays\",\n",
    "    \"MotifTags\": \"motif_tags\",\n",
    "    \"PhaseTags\": \"phase_tags\",\n",
    "    \"MateTags\": \"mate_tags\",\n",
    "    \"SpecialMoveTags\": \"special_move_tags\",\n",
    "    \"OriginTags\": \"game_origin_tags\",\n",
    "    \"OpeningTags\": \"opening_tags\",\n",
    "    \"GameHash\": \"game_hash\",\n",
    "    \"GameUrl\": \"game_url\",\n",
    "    \"CompSimpPGN\": \"game_pgn\",\n",
    "    \"CompAnnoPGN\": \"game_annotated_pgn\",\n",
    "    \"Rating\": \"unnorm_rating\",\n",
    "    \"RatingDeviation\": \"unnorm_rating_std\",\n",
    "    \"PrevFEN\": \"previous_fen\",\n",
    "    \"LastMoveUCI\": \"last_move_uci\",\n",
    "}\n",
    "# Select exactly the listed columns in that order, then apply the new names\n",
    "df = df[list(rename_dict)].rename(columns=rename_dict)\n",
    "\n",
    "# Convert back to a Dataset, discarding the frame's existing index\n",
    "ds = Dataset.from_pandas(df.reset_index(drop=True))\n",
    "# Hold out 5000 examples with a fixed seed\n",
    "ds = ds.train_test_split(test_size=5000, seed=42)\n",
    "\n",
    "# Save both splits, storing the held-out \"test\" split under the name \"eval\"\n",
    "splits = DatasetDict({\"train\": ds[\"train\"], \"eval\": ds[\"test\"]})\n",
    "splits.save_to_disk(\"./prepub/Lichess\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
