{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perplexity of prompted sequences, computed on an interactive GPU\n",
    "\n",
    "This notebook computes the perplexity of a completion conditioned on a label prompt (e.g. `A news article about {label}: `). It has been used to explore how the sampling temperature of generated text affects its perplexity, and to compute the perplexity of in-distribution canaries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "gather": {
     "logged": 1721055934000
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "# use the first CUDA device when one is visible, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "gather": {
     "logged": 1720709196196
    },
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "# load the model\n",
    "# NOTE(security): a Hugging Face access token used to be hard-coded in this cell; it is exposed in the\n",
    "# notebook history and must be revoked. Provide a token via the HF_TOKEN environment variable instead\n",
    "# (or log in once with `huggingface-cli login` and leave HF_TOKEN unset).\n",
    "import os\n",
    "\n",
    "model_name = 'mistralai/Mistral-7B-v0.1'  # You can choose other models like 'gpt2-medium', 'gpt2-large', etc.\n",
    "hf_token = os.environ.get('HF_TOKEN')  # None -> use the locally cached credentials, if any\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_auth_token=hf_token)\n",
    "model.eval()\n",
    "model = model.to(device)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)\n",
    "# reuse EOS as the padding token; the left-padding in main_preprocess_function uses pad_token_id\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (1) Get some example data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": false,
     "source_hidden": false
    },
    "nteract": {
     "transient": {
      "deleting": false
     }
    }
   },
   "outputs": [],
   "source": [
    "import datasets\n",
    "\n",
    "# only the train split is used in this notebook\n",
    "agnews_dataset = datasets.load_dataset('fancyzhx/ag_news', split='train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# integer label of every training article, and the ClassLabel names those integers index into\n",
    "all_label_names = agnews_dataset.features['label'].names\n",
    "all_labels = list(agnews_dataset['label'])\n",
    "all_label_names[all_labels[123]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep only the articles with at least `canary_length` words, truncated to exactly `canary_length` words\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "canary_length = 50\n",
    "valid_dataset_entries = []\n",
    "\n",
    "# iterate over the records once instead of indexing agnews_dataset[idx] several times per row\n",
    "for record in tqdm(agnews_dataset):\n",
    "    words = record['text'].split()\n",
    "    if len(words) >= canary_length:\n",
    "        entry = dict(record)\n",
    "        entry['text'] = \" \".join(words[:canary_length])\n",
    "        valid_dataset_entries.append(entry)\n",
    "\n",
    "# pass the source features so the 'label' column stays a ClassLabel (keeping its label names)\n",
    "valid_dataset = datasets.Dataset.from_dict(\n",
    "    {col: [entry[col] for entry in valid_dataset_entries] for col in agnews_dataset.column_names},\n",
    "    features=agnews_dataset.features,\n",
    ")\n",
    "\n",
    "print(f\"Number of valid samples: {len(valid_dataset)}\")\n",
    "valid_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "\n",
    "N_TO_SELECT = 100\n",
    "SEED = 0  # fixed seed so the same in-distribution canaries are selected on every run\n",
    "\n",
    "rng = np.random.default_rng(SEED)\n",
    "random_indices = rng.choice(len(valid_dataset), size=N_TO_SELECT, replace=False)\n",
    "in_canaries = valid_dataset.select(random_indices)\n",
    "in_canaries[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (2) Compute the perplexity of a completion given a prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's get all functions to compute the perplexity with a prompt\n",
    "# Starting with the data -> prompt -> tokenized dataset\n",
    "\n",
    "from functools import partial\n",
    "import numpy as np\n",
    "\n",
    "def convert_classification_record_to_synthesizer_record(\n",
    "        record, text_name: str, templated_prompt: str, label_name: str, label_int2str):\n",
    "    \"\"\"\n",
    "    Turn a classification record into a prompt/completion pair for the synthesizer.\n",
    "\n",
    "    The placeholder '{<label_name>}' in `templated_prompt` (e.g. '{label}') is replaced by the\n",
    "    string form of the record's label; the record's text becomes the completion.\n",
    "\n",
    "    Args:\n",
    "        record (Dict[str, Any]): Classification record holding the text and the integer label.\n",
    "        text_name (str): Key of the text field in `record`.\n",
    "        templated_prompt (str): Prompt template containing the '{<label_name>}' placeholder.\n",
    "        label_name (str): Key of the label field in `record`; also the placeholder name.\n",
    "        label_int2str (Callable[[int], str]): Maps the integer label to its name.\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, Any]: {\"prompt\": <filled-in prompt>, \"completion\": record[text_name]}.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the placeholder does not occur in `templated_prompt`.\n",
    "    \"\"\"\n",
    "    label_str = label_int2str(record[label_name])\n",
    "    placeholder = \"{\" + label_name + \"}\"\n",
    "    if placeholder not in templated_prompt:\n",
    "        raise ValueError(f\"Label name '{{{label_name}}}' not found in templated prompt: {templated_prompt}\")\n",
    "    return {\n",
    "        \"prompt\": templated_prompt.replace(placeholder, label_str),\n",
    "        \"completion\": record[text_name],\n",
    "    }\n",
    "\n",
    "def prep_dataset(dataset, prompt=\"A news article about {label}: \", label_int2str=None):\n",
    "    \"\"\"\n",
    "    Map every record of `dataset` to a {'prompt', 'completion'} pair (all other columns are dropped).\n",
    "\n",
    "    Args:\n",
    "        dataset (datasets.Dataset): Records with 'text' and integer 'label' columns.\n",
    "        prompt (str): Template with a '{label}' placeholder that is filled with the label name.\n",
    "        label_int2str (Callable[[int], str], optional): Maps integer labels to names. By default the\n",
    "            dataset's own 'label' ClassLabel is used when it has one; otherwise the AG News label\n",
    "            names from `agnews_dataset` are used.\n",
    "\n",
    "    Returns:\n",
    "        datasets.Dataset: Dataset with 'prompt' and 'completion' columns.\n",
    "    \"\"\"\n",
    "    if label_int2str is None:\n",
    "        label_feature = dataset.features.get('label')\n",
    "        if isinstance(label_feature, datasets.ClassLabel):\n",
    "            label_int2str = label_feature.int2str\n",
    "        else:\n",
    "            # label names are not stored on this dataset (plain integer column): fall back to AG News\n",
    "            label_int2str = agnews_dataset.features['label'].int2str\n",
    "    dataset = dataset.map(\n",
    "        partial(\n",
    "            convert_classification_record_to_synthesizer_record,\n",
    "            label_int2str=label_int2str,\n",
    "            text_name='text',\n",
    "            templated_prompt=prompt,\n",
    "            label_name='label'\n",
    "        ),\n",
    "        remove_columns=dataset.column_names\n",
    "    )\n",
    "    return dataset\n",
    "\n",
    "def main_preprocess_function(examples, tokenizer=tokenizer, sequence_len=256):\n",
    "    \"\"\"\n",
    "    Tokenize a batch of prompt/completion pairs into fixed-length causal-LM rows.\n",
    "\n",
    "    Every row is built as `prompt_ids + completion_ids[1:] + [eos]` (the completion's first token\n",
    "    is dropped because the tokenizer prepends <s> to it), left-padded with `tokenizer.pad_token_id`\n",
    "    to `sequence_len` and then cut to its first `sequence_len` tokens. The labels are the same ids\n",
    "    with -100 on the padding and on the prompt, so only the completion and the closing EOS\n",
    "    count towards the loss.\n",
    "\n",
    "    NOTE: if prompt + completion is longer than `sequence_len`, the end of the completion is cut off;\n",
    "    if the prompt alone is that long, all labels are -100 and there is nothing to predict.\n",
    "\n",
    "    Returns:\n",
    "        The tokenizer output with 'input_ids', 'attention_mask' and 'labels' set to lists of\n",
    "        1-D tensors of length `sequence_len`, one per example.\n",
    "    \"\"\"\n",
    "    model_inputs = tokenizer(examples[\"prompt\"])\n",
    "    completions = tokenizer(examples[\"completion\"])\n",
    "    eos_id = tokenizer.eos_token_id\n",
    "    pad_id = tokenizer.pad_token_id\n",
    "\n",
    "    input_rows, mask_rows, label_rows = [], [], []\n",
    "    for prompt_ids, completion_ids in zip(model_inputs[\"input_ids\"], completions[\"input_ids\"]):\n",
    "        target_ids = completion_ids[1:] + [eos_id]\n",
    "        full_ids = prompt_ids + target_ids\n",
    "        n_pad = sequence_len - len(full_ids)  # negative (=> no padding) when the row is too long\n",
    "\n",
    "        padded_ids = [pad_id] * n_pad + full_ids\n",
    "        padded_mask = [0] * n_pad + [1] * len(full_ids)\n",
    "        padded_labels = [-100] * n_pad + [-100] * len(prompt_ids) + target_ids\n",
    "\n",
    "        input_rows.append(torch.tensor(padded_ids[:sequence_len]))\n",
    "        mask_rows.append(torch.tensor(padded_mask[:sequence_len]))\n",
    "        label_rows.append(torch.tensor(padded_labels[:sequence_len]))\n",
    "\n",
    "    model_inputs[\"input_ids\"] = input_rows\n",
    "    model_inputs[\"attention_mask\"] = mask_rows\n",
    "    model_inputs[\"labels\"] = label_rows\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to compute perplexity in batches\n",
    "\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from typing import Dict\n",
    "\n",
    "# wrap every in-distribution canary in its label prompt, then tokenize prompt + completion\n",
    "dataset_prepped = prep_dataset(in_canaries)\n",
    "\n",
    "dataset_tokenized = dataset_prepped.map(\n",
    "    main_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=2,\n",
    "    remove_columns=dataset_prepped.column_names,\n",
    "    desc=\"tokenizing dataset\",\n",
    ")\n",
    "\n",
    "def compute_perplexity_batch(\n",
    "    model: AutoModelForCausalLM,\n",
    "    batch: Dict[str, torch.Tensor],\n",
    "    device:torch.device):\n",
    "    \"\"\"\n",
    "    Perplexity of the completion part of every sequence in a tokenized batch.\n",
    "\n",
    "    Args:\n",
    "        model: Causal LM; called as model(input_ids, attention_mask=...) and expected to return `.logits`.\n",
    "        batch: Mapping with 'input_ids', 'attention_mask' and 'labels', each a rectangular (batch, seq)\n",
    "            list of lists or tensor, as produced by main_preprocess_function (labels are -100 outside\n",
    "            the completion).\n",
    "        device: Device the model lives on; the inputs are moved there.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of shape (batch,): exp of the mean next-token cross-entropy over the completion\n",
    "        tokens of each row. Rows without any completion token give nan (numpy warns about the empty mean).\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        input_ids = torch.as_tensor(batch[\"input_ids\"], device=device)\n",
    "        attention_mask = torch.as_tensor(batch[\"attention_mask\"], device=device)\n",
    "        labels = torch.as_tensor(batch[\"labels\"], device=device)\n",
    "\n",
    "        output = model(input_ids, attention_mask=attention_mask)\n",
    "\n",
    "        # position t predicts token t+1: drop the last prediction and the first label.\n",
    "        # Upcast to float32 on the device: the model runs in float16, and a half-precision\n",
    "        # log-softmax over the whole vocabulary loses accuracy.\n",
    "        logits = output.logits[:, :-1, :].float()\n",
    "        labels = labels[:, 1:]\n",
    "\n",
    "        # CrossEntropyLoss expects the class (vocabulary) dimension at index 1\n",
    "        loss = CrossEntropyLoss(ignore_index=-100, reduction=\"none\")(logits.transpose(1, 2), labels)\n",
    "\n",
    "        # average only over real (non-padding) completion tokens\n",
    "        completion_mask = (attention_mask[:, 1:] != 0) & (labels != -100)\n",
    "\n",
    "    loss_np = loss.cpu().numpy()\n",
    "    completion_mask_np = completion_mask.cpu().numpy()\n",
    "    mean_cross_entropy_loss = loss_np.mean(axis=1, where=completion_mask_np)\n",
    "    ppls = np.exp(mean_cross_entropy_loss)\n",
    "\n",
    "    return ppls\n",
    "\n",
    "# quick check on the first five canaries before scoring all of them\n",
    "ppl = compute_perplexity_batch(model=model, batch=dataset_tokenized[:5], device=torch.device('cuda:0'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now for all of them\n",
    "\n",
    "batch_size = 5\n",
    "perplexities_in_canaries = []\n",
    "\n",
    "# step through every row; the last batch may be smaller than batch_size\n",
    "# (int(len / batch_size) batches silently dropped the remainder)\n",
    "for start in tqdm(range(0, len(dataset_tokenized), batch_size)):\n",
    "    ppl_batch = compute_perplexity_batch(model=model,\n",
    "                                         batch=dataset_tokenized[start: start + batch_size],\n",
    "                                         device=torch.device('cuda:0'))\n",
    "    perplexities_in_canaries.extend(ppl_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (3) Compute the perplexity of the in-distribution canaries actually injected in the experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from privacy_estimates.experiments.aml import JobList\n",
    "\n",
    "# AzureML runs for in-canary, n_rep= 12 (main experiments), keyed by dataset and then by canary type\n",
    "all_urls = dict(\n",
    "    sst2=dict(\n",
    "        synthetic_2gram='https://ml.azure.com/runs/wheat_wolf_85rw8vptjh?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "    ),\n",
    "    agnews=dict(\n",
    "        synthetic_2gram='https://ml.azure.com/runs/coral_celery_6b76r68h1k?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download the canaries\n",
    "\n",
    "DATASET = 'agnews'\n",
    "# one path for both download and load: the download target used to carry a stray trailing '_'\n",
    "# and did not match the path given to load_from_disk.\n",
    "# NOTE(review): assumes download_input('data', path) writes the dataset to `path` - confirm.\n",
    "CANARIES_DIR = f'{DATASET}_canaries'\n",
    "\n",
    "jobs = JobList.from_urls([all_urls[DATASET]['synthetic_2gram']])\n",
    "\n",
    "canary_node = jobs[0].get_node('add_index_to_dataset_2').get_node('append_column_incrementing')\n",
    "canary_node.download_input('data', CANARIES_DIR)\n",
    "in_canaries = datasets.load_from_disk(CANARIES_DIR)\n",
    "in_canaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# label-conditioned prompt template per dataset; prep_dataset fills '{label}' with the label name\n",
    "prompt_dict = dict(\n",
    "    sst2=\"A sentence with a {label} sentiment: \",\n",
    "    agnews=\"A news article about {label}: \",\n",
    ")\n",
    "\n",
    "prompt = prompt_dict[DATASET]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# wrap the downloaded canaries in their dataset-specific label prompt, then tokenize prompt + completion\n",
    "dataset_prepped = prep_dataset(in_canaries, prompt=prompt)\n",
    "\n",
    "dataset_tokenized = dataset_prepped.map(\n",
    "    main_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=2,\n",
    "    remove_columns=dataset_prepped.column_names,\n",
    "    desc=\"tokenizing dataset\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now for all of them\n",
    "\n",
    "batch_size = 5\n",
    "perplexities_in_canaries = []\n",
    "\n",
    "# step through every row; the last batch may be smaller than batch_size\n",
    "# (int(len / batch_size) batches silently dropped the remainder)\n",
    "for start in tqdm(range(0, len(dataset_tokenized), batch_size)):\n",
    "    ppl_batch = compute_perplexity_batch(model=model,\n",
    "                                         batch=dataset_tokenized[start: start + batch_size],\n",
    "                                         device=torch.device('cuda:0'))\n",
    "    perplexities_in_canaries.extend(ppl_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "bins = np.histogram(perplexities_in_canaries, bins=50)[1]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "ax.hist(perplexities_in_canaries, bins=bins, alpha=0.5, density=True, label='In-distribution canaries injected')\n",
    "ax.set_xlabel('Mistral 7B perplexity computed using prompt', fontsize=15)\n",
    "ax.set_ylabel('Frequency', fontsize=15)\n",
    "ax.legend(fontsize=15)\n",
    "ax.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the in-distribution canary perplexities so they can be loaded in a different notebook\n",
    "# (eg to plot the in-distribution boxplot in the perplexity experiment plots)\n",
    "import pickle \n",
    "\n",
    "# open for writing ('rb' made pickle.dump fail) and use an f-string so DATASET is substituted;\n",
    "# pickle.dump returns None, so its result is not kept\n",
    "with open(f'ppl_in_canaries_boxplot_{DATASET}', 'wb') as f:\n",
    "    pickle.dump(perplexities_in_canaries, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (4) Synthetic canaries: how the sampling temperature affects their perplexity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's now also get the synthetic ones\n",
    "\n",
    "import nltk\n",
    "\n",
    "def is_duplicate(seq: str, canaries: list, threshold: float = 0.2):\n",
    "    \"\"\"\n",
    "    Return True if `seq` is a near-duplicate of any string in `canaries`.\n",
    "\n",
    "    Two strings are near-duplicates when the Jaccard distance between their sets of lower-cased\n",
    "    NLTK word tokens is strictly below `threshold`. The canaries are tokenized lazily and the\n",
    "    scan stops at the first match.\n",
    "    \"\"\"\n",
    "    seq_tokens = set(nltk.word_tokenize(seq.lower()))\n",
    "    return any(\n",
    "        nltk.jaccard_distance(seq_tokens, set(nltk.word_tokenize(other.lower()))) < threshold\n",
    "        for other in canaries\n",
    "    )\n",
    "\n",
    "def update_temperature(all_ppls, min_temperature, max_temperature, min_ppl, max_ppl, learning_rate=0.1):\n",
    "    \"\"\"\n",
    "    Nudge the sampling temperature range towards a target perplexity window.\n",
    "\n",
    "    The mean of `all_ppls` is compared with [min_ppl, max_ppl]:\n",
    "      * below min_ppl -> both bounds move up by learning_rate * |min_ppl - mean| / min_ppl\n",
    "        (the lower bound is capped at the old upper bound, the upper bound at twice its old value);\n",
    "      * above max_ppl -> both bounds move down by learning_rate * |mean - max_ppl| / max_ppl\n",
    "        (the upper bound is floored at the old lower bound, the lower bound at half its old value);\n",
    "      * otherwise the range is left unchanged.\n",
    "\n",
    "    Prints and returns the new (min_temperature, max_temperature).\n",
    "    \"\"\"\n",
    "    mean_ppl = np.mean(all_ppls)\n",
    "    new_min, new_max = min_temperature, max_temperature\n",
    "\n",
    "    if mean_ppl < min_ppl:\n",
    "        # too predictable -> sample hotter\n",
    "        delta = learning_rate * abs((min_ppl - mean_ppl) / min_ppl)\n",
    "        new_min = min(min_temperature + delta, max_temperature)\n",
    "        new_max = min(max_temperature + delta, 2 * max_temperature)  # prevent excessive growth\n",
    "    elif mean_ppl > max_ppl:\n",
    "        # too surprising -> sample colder\n",
    "        delta = learning_rate * abs((mean_ppl - max_ppl) / max_ppl)\n",
    "        new_max = max(max_temperature - delta, min_temperature)\n",
    "        new_min = max(min_temperature - delta, 0.5 * min_temperature)  # prevent negative temps\n",
    "\n",
    "    print(f\"New temperature range: {new_min:.4f} - {new_max:.4f}\")\n",
    "\n",
    "    return new_min, new_max\n",
    "\n",
    "def generate_synthetic_canaries_ppl(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, \n",
    "                                n_canaries: int, canary_length: int,\n",
    "                                prompt: str, min_ppl: float, max_ppl: float,\n",
    "                                min_temperature: float, max_temperature: float,\n",
    "                                batch_size: int, device: torch.device):\n",
    "    \"\"\"\n",
    "    Sample completions of `prompt` until `n_canaries` of them have a perplexity within [min_ppl, max_ppl].\n",
    "\n",
    "    Each step samples `batch_size` completions at the midpoint of the current temperature range. A\n",
    "    completion is kept if it is not a near-duplicate (see `is_duplicate`) of an accepted canary or of an\n",
    "    earlier completion of the same batch and has at least `canary_length` whitespace-separated words;\n",
    "    it is truncated to exactly `canary_length` words, scored with `compute_perplexity_batch` (conditioned\n",
    "    on `prompt`) and accepted when its perplexity lies inside the window. After each scored batch the\n",
    "    temperature range is adjusted with `update_temperature`.\n",
    "\n",
    "    Args:\n",
    "        model, tokenizer: Causal LM and its tokenizer.\n",
    "        n_canaries: Number of canaries to return.\n",
    "        canary_length: Canary length in words.\n",
    "        prompt: Prompt the completions are sampled from (and conditioned on for the perplexity).\n",
    "        min_ppl, max_ppl: Accepted perplexity window (inclusive).\n",
    "        min_temperature, max_temperature: Initial sampling temperature range.\n",
    "        batch_size: Completions sampled per step.\n",
    "        device: Device the model lives on.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: `n_canaries` accepted completions, without the prompt.\n",
    "\n",
    "    NOTE: loops until enough canaries are accepted, so an unreachable perplexity window never returns.\n",
    "    \"\"\"\n",
    "    canaries = []\n",
    "    inputs = tokenizer([prompt] * batch_size, return_tensors=\"pt\").to(device)\n",
    "    # all rows hold the same prompt, so they share one (unpadded) prompt length\n",
    "    prompt_len = inputs[\"input_ids\"].shape[1]\n",
    "\n",
    "    total_samples = 0\n",
    "    step = 0\n",
    "    duplicates = 0\n",
    "\n",
    "    while len(canaries) < n_canaries:\n",
    "        if step > 0 and step % 5 == 0:\n",
    "            print(\n",
    "                f\"Step: {step} | total: {total_samples} | accepted: {len(canaries)} | duplicates: {duplicates}\"\n",
    "            )\n",
    "            if len(canaries) > 0:\n",
    "                print('The last canary was: ')\n",
    "                print(canaries[-1])\n",
    "\n",
    "        # one temperature for the entire batch: the midpoint of the current range\n",
    "        temperature = (min_temperature + max_temperature) / 2\n",
    "\n",
    "        generated_ids = model.generate(\n",
    "            **inputs,\n",
    "            max_length=canary_length * 2, # canary length is in words, so generate a bit more (max_length includes the prompt)\n",
    "            do_sample=True,\n",
    "            temperature=temperature,\n",
    "            top_p=1.0,\n",
    "            top_k=0,\n",
    "            pad_token_id=tokenizer.eos_token_id\n",
    "        )\n",
    "\n",
    "        # decode only the newly generated tokens instead of string-replacing the prompt in the full\n",
    "        # decoded text (which fails whenever decoding does not reproduce the prompt text exactly),\n",
    "        # then drop everything from the first EOS on\n",
    "        generated_text = tokenizer.batch_decode(generated_ids[:, prompt_len:])\n",
    "        valid_text = []\n",
    "        for text in generated_text:\n",
    "            text = text.split(tokenizer.eos_token)[0]\n",
    "            text_split = text.split()\n",
    "            # also compare against completions already kept from this batch\n",
    "            if is_duplicate(text, canaries + valid_text):\n",
    "                duplicates += 1\n",
    "                continue\n",
    "            if len(text_split) >= canary_length:\n",
    "                valid_text.append(\" \".join(text_split[:canary_length]))\n",
    "\n",
    "        total_samples += batch_size\n",
    "        step += 1\n",
    "\n",
    "        if not valid_text:\n",
    "            # nothing to score: an empty batch cannot be turned into a (batch, seq) model input\n",
    "            print(f\"Found {len(canaries)} canaries (no valid completion in this batch) - continuing...\")\n",
    "            continue\n",
    "\n",
    "        # now we have to compute the perplexity of these selected texts\n",
    "        valid_text_dataset = datasets.Dataset.from_dict({'prompt': [prompt] * len(valid_text), 'completion': valid_text})\n",
    "        # single process: spawning workers for a handful of rows at every step only adds overhead\n",
    "        tokenized_valid_text = valid_text_dataset.map(lambda x: main_preprocess_function(x, tokenizer=tokenizer), batched=True,\n",
    "                                                      desc=\"tokenizing dataset\",\n",
    "                                                      remove_columns=valid_text_dataset.column_names)\n",
    "        # [:] gives a dict of column lists, the same shape as the slices passed elsewhere\n",
    "        all_ppls = compute_perplexity_batch(model, tokenized_valid_text[:], device)\n",
    "\n",
    "        for text, ppl in zip(valid_text, all_ppls):\n",
    "            print(temperature, ppl, text)\n",
    "            if min_ppl <= ppl <= max_ppl and len(canaries) < n_canaries:\n",
    "                canaries.append(text)\n",
    "\n",
    "        print(f\"Found {len(canaries)} canaries - continuing...\")\n",
    "\n",
    "        # optimize temperature for next batch\n",
    "        min_temperature, max_temperature = update_temperature(all_ppls, min_temperature, max_temperature, min_ppl, max_ppl, learning_rate=0.1)\n",
    "\n",
    "    return canaries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "TEMPERATURES = [0.5, 0.75, 1.0, 1.25]\n",
    "N_CANARIES = 50\n",
    "\n",
    "temp_to_data = {}\n",
    "model.eval()\n",
    "\n",
    "# reuse the label prompts of the first N_CANARIES in-distribution canaries, so the synthetic canaries\n",
    "# are conditioned on the same labels (in the same proportions); the template matches DATASET\n",
    "n_reference = min(N_CANARIES, len(in_canaries))\n",
    "prep_in_canaries = prep_dataset(in_canaries.select(range(n_reference)), prompt=prompt_dict[DATASET])\n",
    "\n",
    "# get the counts per unique prompt\n",
    "prompt_counts = Counter(prep_in_canaries['prompt'])\n",
    "\n",
    "# the function is designed to get canaries of a certain target perplexity\n",
    "# but let's make both min and max temperature the same and the perplexity range very big\n",
    "# (the loop variable is `label_prompt` so the notebook-level `prompt` is not overwritten)\n",
    "\n",
    "for temp in TEMPERATURES:\n",
    "    all_prompts, all_synthetic_canaries = [], []\n",
    "\n",
    "    for label_prompt, n_for_this_prompt in prompt_counts.items():\n",
    "        synthetic_canaries = generate_synthetic_canaries_ppl(model=model, tokenizer=tokenizer,\n",
    "                                    n_canaries=n_for_this_prompt, canary_length=canary_length,\n",
    "                                    prompt=label_prompt, min_ppl=0, max_ppl=100000,\n",
    "                                    min_temperature=temp, max_temperature=temp,\n",
    "                                    batch_size=8, device='cuda:0')\n",
    "\n",
    "        all_synthetic_canaries.extend(synthetic_canaries)\n",
    "        all_prompts.extend([label_prompt] * n_for_this_prompt)\n",
    "\n",
    "    synthetic_dataset = datasets.Dataset.from_dict(\n",
    "        {'prompt': all_prompts,\n",
    "         'completion': all_synthetic_canaries})\n",
    "\n",
    "    temp_to_data[temp] = synthetic_dataset\n",
    "\n",
    "synthetic_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_to_ppl = {}\n",
    "batch_size = 5\n",
    "\n",
    "for temp in TEMPERATURES:\n",
    "    data = temp_to_data[temp]\n",
    "    data_tokenized = data.map(\n",
    "        main_preprocess_function, batched=True, num_proc=1, desc=\"tokenizing dataset\",\n",
    "        remove_columns=data.column_names\n",
    "    )\n",
    "    perplexities = []\n",
    "\n",
    "    # step through every row; the last batch may be smaller than batch_size\n",
    "    for start in tqdm(range(0, len(data_tokenized), batch_size)):\n",
    "        ppl_batch = compute_perplexity_batch(model=model,\n",
    "                                             batch=data_tokenized[start: start + batch_size],\n",
    "                                             device=torch.device('cuda:0'))\n",
    "        perplexities.extend(ppl_batch)\n",
    "    temp_to_ppl[temp] = perplexities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# shared bin edges so that all histograms are directly comparable\n",
    "all_perplexities = np.hstack([perplexities_in_canaries] + [temp_to_ppl[temp] for temp in TEMPERATURES])\n",
    "bins = np.histogram(all_perplexities, bins=500)[1]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "ax.hist(perplexities_in_canaries, bins=bins, alpha=0.5, density=True, label='In distribution canaries')\n",
    "for temp in TEMPERATURES:\n",
    "    ax.hist(temp_to_ppl[temp], bins=bins, alpha=0.5, density=True, label=f'Synthetic canaries - temp={temp}')\n",
    "ax.set_xlim(0, 1000)\n",
    "ax.set_xlabel('Mistral 7B perplexity computed using prompt', fontsize=15)\n",
    "ax.set_ylabel('Frequency', fontsize=15)\n",
    "ax.legend(fontsize=15)\n",
    "ax.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernel_info": {
   "name": "python38-azureml"
  },
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "microsoft": {
   "host": {
    "AzureML": {
     "notebookHasBeenCompleted": true
    }
   },
   "ms_spell_check": {
    "ms_spell_check_language": "en"
   }
  },
  "nteract": {
   "version": "nteract-front-end@1.0.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
