{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf9b153",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Open-Orca Dataset Preprocessing Notebook\n",
    "\n",
    "This notebook loads the Open-Orca dataset from HuggingFace, concatenates the \n",
    "system_prompt, question, and response fields into a complete_text field,\n",
    "and saves the processed dataset for instruction classification.\n",
    "\n",
    "WORKFLOW:\n",
    "1. Cell 0: Load the Open-Orca dataset from HuggingFace, process it, and save to disk\n",
    "2. Cell 1: Load the processed dataset, sample it, and export to JSONL format\n",
    "\n",
    "IMPORTANT: Run cells in order (0 -> 1) as each cell depends on the previous one.\n",
    "\"\"\"\n",
    "\n",
    "from datasets import load_dataset, DatasetDict, Dataset\n",
    "from typing import Union, Dict, Any, List\n",
    "\n",
    "# Login using e.g. `huggingface-cli login` to access this dataset\n",
    "# Load the Open-Orca dataset for instruction following tasks\n",
    "ds = load_dataset(\"Open-Orca/OpenOrca\")\n",
    "\n",
    "# Print dataset structure and statistics\n",
    "dataset_columns: List[str] = []\n",
    "\n",
    "if isinstance(ds, DatasetDict):\n",
    "    # Handle multi-split datasets\n",
    "    for split_name, split_ds in ds.items():\n",
    "        print(f\"{split_name}: {len(split_ds)} datapoints\")\n",
    "    \n",
    "    # Get column names from the first split\n",
    "    first_split_key: str = list(ds.keys())[0]\n",
    "    dataset_columns = ds[first_split_key].column_names\n",
    "    print(\"Columns:\", dataset_columns)\n",
    "    \n",
    "    # Calculate total datapoints across all splits\n",
    "    total_datapoints: int = sum(len(split_ds) for split_ds in ds.values())\n",
    "    print(\"Total datapoints:\", total_datapoints)\n",
    "    \n",
    "elif isinstance(ds, Dataset):\n",
    "    # Handle single dataset\n",
    "    dataset_columns = ds.column_names\n",
    "    print(\"Columns:\", dataset_columns)\n",
    "    print(\"Total datapoints:\", len(ds))\n",
    "\n",
    "def concatenate_text(example: Dict[str, Any]) -> Dict[str, str]:\n",
    "    \"\"\"\n",
    "    Concatenate system_prompt, question, and response into complete_text.\n",
    "    \n",
    "    Args:\n",
    "        example (Dict[str, Any]): Single dataset example containing the fields\n",
    "                                 'system_prompt', 'question', and 'response'\n",
    "        \n",
    "    Returns:\n",
    "        Dict[str, str]: Dictionary with 'complete_text' field containing\n",
    "                       the concatenated text from all three input fields\n",
    "        \n",
    "    Note:\n",
    "        Fields are concatenated with newlines and stripped of leading/trailing whitespace.\n",
    "    \"\"\"\n",
    "    # Concatenate the three text fields with newlines as separators\n",
    "    complete_text: str = f\"{example['system_prompt']}\\n{example['question']}\\n{example['response']}\".strip()\n",
    "    return {\n",
    "        'complete_text': complete_text\n",
    "    }\n",
    "\n",
    "# Apply concatenation function to all splits in the dataset\n",
    "if isinstance(ds, (DatasetDict, Dataset)):\n",
    "    # Process the dataset by adding the complete_text column\n",
    "    ds = ds.map(concatenate_text)\n",
    "    \n",
    "    print(\"\\nAfter adding complete_text column:\")\n",
    "    updated_columns: List[str] = []\n",
    "    \n",
    "    if isinstance(ds, DatasetDict):\n",
    "        # Show updated columns for multi-split dataset\n",
    "        updated_columns = ds[list(ds.keys())[0]].column_names\n",
    "        print(\"Columns:\", updated_columns)\n",
    "    else:\n",
    "        # Show updated columns for single dataset\n",
    "        updated_columns = ds.column_names\n",
    "        print(\"Columns:\", updated_columns)\n",
    "    \n",
    "    # Save the processed dataset locally for future use\n",
    "    save_path: str = \"../../data/orca\"\n",
    "    ds.save_to_disk(save_path)\n",
    "    print(f\"\\nDataset saved to: {save_path}\")\n",
    "else:\n",
    "    print(\"Cannot process streaming dataset - requires materialized dataset for processing\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87b30351",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Dataset Sampling and Export\n",
    "\n",
    "This cell loads the processed Open-Orca dataset, samples a specified number of \n",
    "datapoints, and exports them to a JSONL file for training the instruction classifier.\n",
    "\"\"\"\n",
    "\n",
    "from datasets import load_from_disk, DatasetDict, Dataset\n",
    "import json\n",
    "import os\n",
    "from typing import Dict, Any, List, Union\n",
    "\n",
    "# Configuration: Number of samples to extract\n",
    "NUM_SAMPLES: int = 2000\n",
    "\n",
    "# Load the processed dataset from disk\n",
    "load_path: str = \"../../data/orca\"\n",
    "\n",
    "# Check if the dataset directory exists\n",
    "if not os.path.exists(load_path):\n",
    "    print(f\"Error: Dataset directory does not exist at {load_path}\")\n",
    "    print(\"Please run Cell 0 first to process and save the dataset.\")\n",
    "    raise FileNotFoundError(f\"Directory {load_path} not found. Run Cell 0 to create the processed dataset.\")\n",
    "\n",
    "try:\n",
    "    ds: Union[DatasetDict, Dataset] = load_from_disk(load_path)\n",
    "    print(f\"Dataset loaded from: {load_path}\")\n",
    "except Exception as e:\n",
    "    print(f\"Error loading dataset from {load_path}: {e}\")\n",
    "    print(\"Please run Cell 0 first to process and save the dataset.\")\n",
    "    raise\n",
    "\n",
    "# Determine which dataset split to use for sampling\n",
    "sample_ds: Dataset\n",
    "if isinstance(ds, DatasetDict):\n",
    "    # Use the first split (usually 'train')\n",
    "    first_split: str = list(ds.keys())[0]\n",
    "    sample_ds = ds[first_split]\n",
    "    print(f\"Using '{first_split}' split for sampling\")\n",
    "else:\n",
    "    sample_ds = ds\n",
    "\n",
    "# Sample the specified number of datapoints and select only relevant columns\n",
    "# Use shuffle with fixed seed for reproducible sampling\n",
    "sample_dataset: Dataset = sample_ds.shuffle(seed=42).select(range(NUM_SAMPLES))\n",
    "sample_dataset = sample_dataset.select_columns(['id', 'complete_text'])\n",
    "\n",
    "# Display the sampled data for verification\n",
    "print(f\"\\nSampled {NUM_SAMPLES} datapoints:\")\n",
    "for i in range(min(10, NUM_SAMPLES)):  # Show first 10 or fewer if NUM_SAMPLES < 10\n",
    "    id_val: str = sample_dataset['id'][i]\n",
    "    complete_text: str = sample_dataset['complete_text'][i]\n",
    "    print(f\"\\n--- Example {i+1} ---\")\n",
    "    print(f\"ID: {id_val}\")\n",
    "    print(f\"Complete Text: {complete_text[:200]}...\")  # Show first 200 chars for brevity\n",
    "\n",
    "# Export sampled data to JSONL format for training\n",
    "output_path: str = f\"../../data/orca/orca_{NUM_SAMPLES}.jsonl\"\n",
    "with open(output_path, 'w', encoding='utf-8') as f:\n",
    "    for i in range(NUM_SAMPLES):\n",
    "        # Create example dictionary with required fields\n",
    "        example: Dict[str, str] = {\n",
    "            \"id\": sample_dataset['id'][i],\n",
    "            \"complete_text\": sample_dataset['complete_text'][i]\n",
    "        }\n",
    "        # Write as JSONL (one JSON object per line)\n",
    "        json.dump(example, f, ensure_ascii=False)\n",
    "        f.write('\\n')\n",
    "\n",
    "print(f\"\\nSample saved to: {output_path}\")\n",
    "print(f\"Ready for instruction classification training with {NUM_SAMPLES} examples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8197cc3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "classifier",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
