{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Contamination Detection via Context (CoDeC)\n",
    "\n",
    "This notebook implements a simple in-context data contamination detection method CoDeC for large language models. The method compares the confidence of a model in predicting target text with and without similar context, identifying potentially contaminated data when the model shows higher confidence without context.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Method Description\n",
    "\n",
    "CoDeC works by comparing how well a language model predicts a target text in two scenarios:\n",
    "\n",
    "1. **Without context**: The model sees only the target text\n",
    "2. **With context**: The model sees similar examples before the target text\n",
    "\n",
    "### Key Intuition\n",
    "\n",
    "If a text was seen during training (contaminated), the model should not benefit from the context. However, if the text is novel, providing similar examples as context should help the model, making it more confident.\n",
    "\n",
    "Therefore:\n",
    "- **Higher confidence without context** → Likely contaminated\n",
    "- **Higher confidence with context** → Likely not contaminated\n",
    "\n",
    "### Algorithm Steps\n",
    "\n",
    "1. For each sample in the dataset:\n",
    "   - Get log probabilities for the sample alone\n",
    "   - Get log probabilities for the sample with context examples\n",
    "   - Compare average confidence scores\n",
    "   - Classify as contaminated if confidence is higher without context\n",
    "\n",
    "2. Calculate overall contamination score as the proportion of samples classified as contaminated\n",
    "\n",
    "### Parameters\n",
    "\n",
    "- **num_context_examples**: How many examples to use as context\n",
    "- **model_name**: HuggingFace model identifier\n",
    "\n",
    "### Data Loading Options\n",
    "\n",
    "- **Pickle files (.pkl)**: Should contain a list of strings\n",
    "- **Text files (.txt)**: Automatically split into chunks\n",
    "- **Benchmark datasets**: Download a benchmark dataset from the HuggingFace hub\n",
    "\n",
    "This method is simple, interpretable, and requires only the model's log probabilities."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install required packages\n",
    "%pip install torch transformers numpy -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from typing import List, Dict, Any\n",
    "import random\n",
    "import pickle\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Core Implementation\n",
    "\n",
    "The contamination detection pipeline consists of three main components:\n",
    "1. **Model Handler**: Loads and manages the language model\n",
    "2. **Contamination Detector**: Implements the core detection algorithm\n",
    "3. **Pipeline**: Orchestrates the detection process for a dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelHandler:\n",
    "    \"\"\"Handles model loading and inference for contamination detection.\"\"\"\n",
    "    \n",
    "    def __init__(self, model_name: str, device: str = \"auto\"):\n",
    "        self.model_name = model_name\n",
    "        if device == \"auto\":\n",
    "            self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "        else:\n",
    "            self.device = device\n",
    "        \n",
    "        print(f\"Loading model {model_name} on {self.device}...\")\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name,\n",
    "            torch_dtype=torch.float16 if self.device == \"cuda\" else torch.float32\n",
    "        ).to(self.device)\n",
    "        \n",
    "        # Add padding token if not present\n",
    "        if self.tokenizer.pad_token is None:\n",
    "            self.tokenizer.pad_token = self.tokenizer.eos_token\n",
    "        \n",
    "        print(f\"Model loaded successfully!\")\n",
    "    \n",
    "    def get_logprobs(self, text: str) -> np.ndarray:\n",
    "        \"\"\"Get log probabilities for each token in the text.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            # Tokenize input\n",
    "            inputs = self.tokenizer(text, return_tensors=\"pt\").to(self.device)\n",
    "            \n",
    "            # Get model outputs\n",
    "            outputs = self.model(**inputs)\n",
    "            logits = outputs.logits\n",
    "            \n",
    "            # Convert logits to log probabilities\n",
    "            log_probs = torch.log_softmax(logits, dim=-1)\n",
    "            \n",
    "            # Get the log probability of each actual token\n",
    "            input_ids = inputs[\"input_ids\"][0]\n",
    "            token_log_probs = []\n",
    "            \n",
    "            for i in range(len(input_ids) - 1):\n",
    "                next_token_id = input_ids[i + 1]\n",
    "                token_log_prob = log_probs[0, i, next_token_id].item()\n",
    "                token_log_probs.append(token_log_prob)\n",
    "            \n",
    "            return np.array(token_log_probs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CoDeC:\n",
    "    \"\"\"Detects data contamination using in-context learning differences.\"\"\"\n",
    "    \n",
    "    def __init__(self, token_range: tuple = (10, -1)):\n",
    "        \"\"\"\n",
    "        Initialize the contamination detector.\n",
    "        \n",
    "        Args:\n",
    "            token_range: Range of tokens to consider for contamination detection.\n",
    "        \"\"\"\n",
    "        self.token_range = token_range\n",
    "    \n",
    "    def detect_contamination(self, target_text: str, context_examples: List[str], \n",
    "                           model_handler: ModelHandler) -> float:\n",
    "        \"\"\"\n",
    "        Detect contamination for a single target text.\n",
    "        \n",
    "        Args:\n",
    "            target_text: The text to test for contamination\n",
    "            context_examples: List of similar examples to use as context\n",
    "            model_handler: Model handler for inference\n",
    "        \n",
    "        Returns:\n",
    "            Contamination score (1 = contaminated, 0 = not contaminated)\n",
    "        \"\"\"\n",
    "        # Get log probabilities without context\n",
    "        logprobs_no_context = model_handler.get_logprobs(target_text)\n",
    "        \n",
    "        # Create context by joining examples\n",
    "        if context_examples:\n",
    "            context = \"\\n\\n\".join(context_examples)\n",
    "            text_with_context = context + \"\\n\\n\" + target_text\n",
    "        else:\n",
    "            text_with_context = target_text\n",
    "        \n",
    "        # Get log probabilities with context\n",
    "        logprobs_with_context = model_handler.get_logprobs(text_with_context)\n",
    "        \n",
    "        # Extract target portion from context output\n",
    "        # The target text appears at the end, so we take the last N tokens\n",
    "        target_tokens_no_context = len(logprobs_no_context)\n",
    "        logprobs_target_from_context = logprobs_with_context[-target_tokens_no_context:]\n",
    "        \n",
    "        # Calculate average confidence for the specified token range\n",
    "        start_idx, end_idx = self.token_range\n",
    "        end_idx = min(end_idx, len(logprobs_no_context))\n",
    "        \n",
    "        if start_idx >= len(logprobs_no_context):\n",
    "            # Target text is too short\n",
    "            return 0.0\n",
    "        \n",
    "        # Average confidence without context\n",
    "        confidence_no_context = np.mean(logprobs_no_context[start_idx:end_idx])\n",
    "        \n",
    "        # Average confidence with context\n",
    "        confidence_with_context = np.mean(logprobs_target_from_context[start_idx:end_idx])\n",
    "        \n",
    "        # Calculate difference (higher confidence without context suggests contamination)\n",
    "        confidence_diff = confidence_no_context - confidence_with_context\n",
    "        \n",
    "        # Return binary classification (1 if contaminated, 0 if not)\n",
    "        return 1.0 if confidence_diff > 0 else 0.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def contamination_detection_pipeline(\n",
    "    model: str | ModelHandler,\n",
    "    dataset: List[str],\n",
    "    num_context_examples: int = 1,\n",
    "    max_dataset_size: int = 1000\n",
    ") -> Dict[str, Any]:\n",
    "    \"\"\"\n",
    "    Run contamination detection on a dataset.\n",
    "    \n",
    "    Args:\n",
    "        model: HuggingFace model name/checkpoint, or ModelHandler instance\n",
    "        dataset: List of text samples to test\n",
    "        num_context_examples: Number of context examples to use\n",
    "        max_dataset_size: Maximum number of samples to process\n",
    "    \n",
    "    Returns:\n",
    "        Dictionary with contamination scores and overall score\n",
    "    \"\"\"\n",
    "    # Initialize components\n",
    "    if isinstance(model, str):\n",
    "        model_handler = ModelHandler(model)\n",
    "    else:\n",
    "        model_handler = model\n",
    "    detector = CoDeC()\n",
    "    \n",
    "    # Process each sample\n",
    "    sample_scores = []\n",
    "\n",
    "    if max_dataset_size is not None:\n",
    "        if len(dataset) > max_dataset_size:\n",
    "            dataset = random.sample(dataset, max_dataset_size)\n",
    "    \n",
    "    print(f\"Processing {len(dataset)} samples...\")\n",
    "    \n",
    "    for i, target_text in enumerate(dataset):\n",
    "        # Select context examples (excluding current sample)\n",
    "        available_examples = dataset[:i] + dataset[i+1:]\n",
    "        \n",
    "        if len(available_examples) >= num_context_examples:\n",
    "            # Randomly sample context examples\n",
    "            context_examples = random.sample(available_examples, num_context_examples)\n",
    "        else:\n",
    "            # Use all available examples if not enough\n",
    "            context_examples = available_examples\n",
    "        \n",
    "        # Detect contamination for this sample\n",
    "        score = detector.detect_contamination(target_text, context_examples, model_handler)\n",
    "        sample_scores.append(score)\n",
    "        \n",
    "        # print(f\"Sample {i+1}/{len(dataset)}: Score = {score}\")\n",
    "    \n",
    "    # Calculate overall contamination score\n",
    "    overall_score = np.mean(sample_scores)\n",
    "    \n",
    "    return {\n",
    "        \"model_name\": model_handler.model_name,\n",
    "        \"overall_contamination_score\": overall_score,\n",
    "        \"sample_scores\": sample_scores,\n",
    "        \"num_samples\": len(dataset),\n",
    "        \"num_contaminated\": sum(sample_scores)\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Loading Utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dataset_from_pkl(file_path: str) -> List[str]:\n",
    "    \"\"\"\n",
    "    Load a dataset from a pickle file containing a list of strings.\n",
    "    \n",
    "    Args:\n",
    "        file_path: Path to the .pkl file\n",
    "        \n",
    "    Returns:\n",
    "        List of text samples\n",
    "    \"\"\"\n",
    "    try:\n",
    "        with open(file_path, 'rb') as f:\n",
    "            dataset = pickle.load(f)\n",
    "        \n",
    "        # Ensure dataset is a list of strings\n",
    "        if not isinstance(dataset, list):\n",
    "            raise ValueError(f\"Expected list in pickle file, got {type(dataset)}\")\n",
    "        \n",
    "        # Convert all items to strings if they aren't already\n",
    "        dataset = [str(item) for item in dataset]\n",
    "        \n",
    "        print(f\"Loaded {len(dataset)} samples from {file_path}\")\n",
    "        return dataset\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(f\"Error loading pickle file {file_path}: {e}\")\n",
    "        raise\n",
    "\n",
    "\n",
    "def load_dataset_from_txt(file_path: str, chunk_size: int = 500, min_chunk_size: int = 100) -> List[str]:\n",
    "    \"\"\"\n",
    "    Load a dataset from a text file by splitting it into chunks.\n",
    "    \n",
    "    Args:\n",
    "        file_path: Path to the .txt file\n",
    "        chunk_size: Approximate number of characters per chunk\n",
    "        min_chunk_size: Minimum chunk size (chunks smaller than this are discarded)\n",
    "        \n",
    "    Returns:\n",
    "        List of text chunks\n",
    "    \"\"\"\n",
    "    try:\n",
    "        with open(file_path, 'r', encoding='utf-8') as f:\n",
    "            text = f.read()\n",
    "        \n",
    "            chunks = []\n",
    "\n",
    "            while len(text) > chunk_size:\n",
    "                sample = text[:chunk_size]\n",
    "\n",
    "                # Remove the boundary words that may have been cut off\n",
    "                sample = ' '.join(sample.split(' ')[1:-1])\n",
    "\n",
    "                # Only add chunks that are larger than the minimum chunk size\n",
    "                if len(sample) > min_chunk_size:\n",
    "                    chunks.append(sample.strip())\n",
    "\n",
    "                text = text[chunk_size:]\n",
    "        \n",
    "        print(f\"Split text file {file_path} into {len(chunks)} chunks\")\n",
    "        print(f\"Average chunk size: {np.mean([len(chunk) for chunk in chunks]):.1f} characters\")\n",
    "        \n",
    "        return chunks\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(f\"Error loading text file {file_path}: {e}\")\n",
    "        raise\n",
    "\n",
    "\n",
    "def load_local_dataset(file_path: str, **kwargs) -> List[str]:\n",
    "    \"\"\"\n",
    "    Load a dataset from a file, automatically detecting the format.\n",
    "    \n",
    "    Args:\n",
    "        file_path: Path to the dataset file (.pkl or .txt)\n",
    "        **kwargs: Additional arguments passed to format-specific loaders\n",
    "        \n",
    "    Returns:\n",
    "        List of text samples\n",
    "    \"\"\"\n",
    "    if file_path.endswith('.pkl'):\n",
    "        return load_dataset_from_pkl(file_path)\n",
    "    elif file_path.endswith('.txt'):\n",
    "        return load_dataset_from_txt(file_path, **kwargs)\n",
    "    else:\n",
    "        raise ValueError(f\"Unsupported file format. Expected .pkl or .txt, got: {file_path}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmarks\n",
    "\n",
    "It is straightforward to evaluate the contamination detection pipeline on any benchmark. Here, we show how to evaluate it on the GPQA dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_gpqa_dataset():\n",
    "    df = load_dataset(\"Idavidrein/gpqa\", f\"gpqa_diamond\")[\"train\"].to_pandas()\n",
    "    return df[\"Question\"].tolist()\n",
    "\n",
    "def load_gsm8k_dataset():\n",
    "    df = load_dataset(\"openai/gsm8k\", \"main\")[\"test\"].to_pandas()\n",
    "    return df[\"question\"].tolist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pile training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_pile_wikipedia_dataset():\n",
    "    dataset = load_dataset(\"iamgroot42/mimir\", \"wikipedia_(en)\", split=\"ngram_13_0.8\")\n",
    "    return dataset[\"member\"]\n",
    "\n",
    "def load_pile_github_dataset():\n",
    "    dataset = load_dataset(\"iamgroot42/mimir\", \"github\", split=\"ngram_13_0.8\")\n",
    "    return dataset[\"member\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Playground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model EleutherAI/pythia-410m on cuda...\n",
      "Model loaded successfully!\n",
      "\n",
      " Seen datasets\n",
      "Processing 1000 samples...\n",
      "Contamination score: 0.949\n",
      "Processing 1000 samples...\n",
      "Contamination score: 0.897\n",
      "\n",
      " Unseen datasets\n",
      "Processing 1000 samples...\n",
      "Contamination score: 0.063\n",
      "Processing 198 samples...\n",
      "Contamination score: 0.42424242424242425\n"
     ]
    }
   ],
   "source": [
    "ID_datasets = [load_pile_wikipedia_dataset(), load_pile_github_dataset()]\n",
    "OOD_datasets = [load_gsm8k_dataset(), load_gpqa_dataset()]\n",
    "\n",
    "data = [('Seen datasets', ID_datasets), ('Unseen datasets', OOD_datasets)]\n",
    "\n",
    "model_handler = ModelHandler(\"EleutherAI/pythia-410m\")\n",
    "\n",
    "for name, datasets in data:\n",
    "    print(\"\\n\", name)\n",
    "    for dataset in datasets:\n",
    "        result = contamination_detection_pipeline(model_handler, dataset, num_context_examples=1, max_dataset_size=1000)\n",
    "        print(f\"Contamination score: {result['overall_contamination_score']}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py_icc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
