{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Imports used throughout the notebook\n",
    "import os\n",
    "from datasets import load_dataset, concatenate_datasets, DatasetDict\n",
    "import numpy as np\n",
    "\n",
    "# Source data, keyed by language code. Accessing it requires a prior\n",
    "# Hugging Face login, e.g. `huggingface-cli login`.\n",
    "ds = load_dataset(\"textdetox/multilingual_toxicity_dataset\")\n",
    "\n",
    "# Language codes to keep when building the splits\n",
    "languages = \"en de es fr ja ar uk hi\".split()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "train split:\n",
      "Total samples: 24000\n",
      "Samples per language:\n",
      "ar: 3000\n",
      "de: 3000\n",
      "en: 3000\n",
      "es: 3000\n",
      "fr: 3000\n",
      "hi: 3000\n",
      "ja: 3000\n",
      "uk: 3000\n",
      "Toxic ratio: 0.500\n",
      "\n",
      "val split:\n",
      "Total samples: 3200\n",
      "Samples per language:\n",
      "ar: 400\n",
      "de: 400\n",
      "en: 400\n",
      "es: 400\n",
      "fr: 400\n",
      "hi: 400\n",
      "ja: 400\n",
      "uk: 400\n",
      "Toxic ratio: 0.500\n",
      "\n",
      "test split:\n",
      "Total samples: 3200\n",
      "Samples per language:\n",
      "ar: 400\n",
      "de: 400\n",
      "en: 400\n",
      "es: 400\n",
      "fr: 400\n",
      "hi: 400\n",
      "ja: 400\n",
      "uk: 400\n",
      "Toxic ratio: 0.500\n"
     ]
    }
   ],
   "source": [
    "def _print_split_stats(final_splits):\n",
    "    \"\"\"\n",
    "    Print size, per-language sample counts and toxic ratio for each split.\n",
    "\n",
    "    Args:\n",
    "        final_splits: Dictionary mapping split names to datasets that have\n",
    "            'lang' and 'toxic' columns\n",
    "    \"\"\"\n",
    "    for split_name, split_dataset in final_splits.items():\n",
    "        print(f\"\\n{split_name} split:\")\n",
    "        print(f\"Total samples: {len(split_dataset)}\")\n",
    "        print(\"Samples per language:\")\n",
    "\n",
    "        # Count all languages in a single pass over the column;\n",
    "        # np.unique returns the values in sorted order\n",
    "        langs, counts = np.unique(split_dataset['lang'], return_counts=True)\n",
    "        for lang, count in zip(langs, counts):\n",
    "            print(f\"{lang}: {count}\")\n",
    "\n",
    "        toxic_ratio = np.mean(split_dataset['toxic'])\n",
    "        print(f\"Toxic ratio: {toxic_ratio:.3f}\")\n",
    "\n",
    "\n",
    "def create_multilingual_splits(\n",
    "    dataset_dict,\n",
    "    languages,\n",
    "    train_size=0.7,\n",
    "    val_size=0.1,\n",
    "    test_size=0.2,\n",
    "    seed=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Create stratified train/val/test splits from multilingual dataset\n",
    "\n",
    "    Each requested language is tagged with a 'lang' column, trimmed to the\n",
    "    'text', 'toxic' and 'lang' columns and concatenated. The result is split\n",
    "    twice with `train_test_split`, stratified on the language/toxic pair, so\n",
    "    every split keeps the same language and label mix. Every example ends up\n",
    "    in exactly one split.\n",
    "\n",
    "    Args:\n",
    "        dataset_dict: HuggingFace dataset dictionary keyed by language code\n",
    "        languages: List of language codes to include; codes missing from\n",
    "            dataset_dict are skipped with a warning\n",
    "        train_size: Proportion for training set (default: 0.7)\n",
    "        val_size: Proportion for validation set (default: 0.1)\n",
    "        test_size: Proportion for test set (default: 0.2)\n",
    "        seed: Random seed for reproducibility (default: 0)\n",
    "        The proportions are expected to be positive; a zero-sized split is\n",
    "        not supported.\n",
    "\n",
    "    Returns:\n",
    "        Dictionary containing 'train', 'val' and 'test' splits\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the three proportions do not sum to 1\n",
    "        ValueError: if none of the requested languages is in dataset_dict\n",
    "    \"\"\"\n",
    "    assert abs(train_size + val_size + test_size - 1.0) < 1e-6, \"Split sizes must sum to 1\"\n",
    "\n",
    "    # 1. Process each language dataset and add language column\n",
    "    processed_datasets = []\n",
    "    for lang in languages:\n",
    "        if lang not in dataset_dict:\n",
    "            print(f\"Warning: Language {lang} not found in dataset\")\n",
    "            continue\n",
    "\n",
    "        # Get dataset for this language and add language column\n",
    "        lang_dataset = dataset_dict[lang]\n",
    "        lang_dataset = lang_dataset.add_column('lang', [lang] * len(lang_dataset))\n",
    "\n",
    "        # Keep only required columns\n",
    "        lang_dataset = lang_dataset.remove_columns(\n",
    "            [col for col in lang_dataset.column_names if col not in ['text', 'toxic', 'lang']]\n",
    "        )\n",
    "        processed_datasets.append(lang_dataset)\n",
    "\n",
    "    # Fail early with a message that names the requested languages\n",
    "    if not processed_datasets:\n",
    "        raise ValueError(f\"None of the requested languages were found in dataset: {languages}\")\n",
    "\n",
    "    # 2. Combine all language datasets\n",
    "    combined_dataset = concatenate_datasets(processed_datasets)\n",
    "\n",
    "    # Create a combined stratification column; class_encode_column turns it\n",
    "    # into a ClassLabel, which stratify_by_column requires\n",
    "    def create_stratification_key(example):\n",
    "        return {'strat': f\"{example['lang']}_{example['toxic']}\"}\n",
    "\n",
    "    combined_dataset = combined_dataset.map(create_stratification_key)\n",
    "    combined_dataset = combined_dataset.class_encode_column('strat')\n",
    "\n",
    "    # 3. Split off the training set. The holdout receives every remaining\n",
    "    # example, so truncating two products separately cannot drop a row.\n",
    "    n = len(combined_dataset)\n",
    "    train_samples = round(train_size * n)\n",
    "    holdout_samples = n - train_samples\n",
    "\n",
    "    splits_dict = combined_dataset.train_test_split(\n",
    "        train_size=train_samples,\n",
    "        test_size=holdout_samples,\n",
    "        seed=seed,\n",
    "        stratify_by_column='strat'\n",
    "    )\n",
    "\n",
    "    # 4. Divide the holdout into val and test. Sizes must be relative to the\n",
    "    # holdout: passing val_size/test_size directly would keep only\n",
    "    # (val_size + test_size) of the holdout and silently discard the rest.\n",
    "    holdout = splits_dict['test']\n",
    "    val_samples = round(val_size / (val_size + test_size) * len(holdout))\n",
    "    test_samples = len(holdout) - val_samples\n",
    "\n",
    "    val_test_splits = holdout.train_test_split(\n",
    "        train_size=val_samples,\n",
    "        test_size=test_samples,\n",
    "        seed=seed,\n",
    "        stratify_by_column='strat'\n",
    "    )\n",
    "\n",
    "    # Remove stratification column and create final splits\n",
    "    final_splits = {\n",
    "        'train': splits_dict['train'].remove_columns('strat'),\n",
    "        'val': val_test_splits['train'].remove_columns('strat'),\n",
    "        'test': val_test_splits['test'].remove_columns('strat')\n",
    "    }\n",
    "\n",
    "    _print_split_stats(final_splits)\n",
    "\n",
    "    return final_splits\n",
    "\n",
    "# Requested proportions: 60% train / 20% val / 20% test, fixed seed\n",
    "splits = create_multilingual_splits(\n",
    "    ds, languages,\n",
    "    seed=0,\n",
    "    train_size=0.6, val_size=0.2, test_size=0.2,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving the dataset (1/1 shards): 100%|██████████| 24000/24000 [00:00<00:00, 141810.45 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 3200/3200 [00:00<00:00, 388305.30 examples/s]\n",
      "Saving the dataset (1/1 shards): 100%|██████████| 3200/3200 [00:00<00:00, 372361.57 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved datasets to data/multilingual\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def save_splits(splits, output_dir='data/processed', repo_id=None):\n",
    "    \"\"\"\n",
    "    Save dataset splits to disk and optionally push to HuggingFace Hub\n",
    "\n",
    "    The splits are wrapped in a single DatasetDict and written with\n",
    "    `save_to_disk`; reload them with `datasets.load_from_disk(output_dir)`.\n",
    "\n",
    "    Args:\n",
    "        splits: Dictionary (or DatasetDict) mapping split names such as\n",
    "            'train'/'val'/'test' to datasets\n",
    "        output_dir: Local directory to save splits (default: 'data/processed')\n",
    "        repo_id: Optional Hub dataset repository id (e.g. 'user/name'). When\n",
    "            given, the splits are also uploaded with `push_to_hub`, which\n",
    "            needs write access (e.g. via `huggingface-cli login`).\n",
    "            Default None saves locally only.\n",
    "    \"\"\"\n",
    "    # Create output directory if it doesn't exist\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    # Wrap the splits in a DatasetDict so they are saved (and reloaded) together\n",
    "    dataset_dict = DatasetDict(splits)\n",
    "\n",
    "    # Save locally\n",
    "    dataset_dict.save_to_disk(output_dir)\n",
    "    print(f\"Saved datasets to {output_dir}\")\n",
    "\n",
    "    # Optional upload, only when a repository id is supplied\n",
    "    if repo_id is not None:\n",
    "        dataset_dict.push_to_hub(repo_id)\n",
    "        print(f\"Pushed datasets to {repo_id}\")\n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "# Save the splits\n",
    "save_splits(splits, output_dir='data/multilingual')\n",
    "\n",
    "# Load them back (in another script/session):\n",
    "# from datasets import load_from_disk\n",
    "# loaded_splits = load_from_disk('data/multilingual')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
