{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import datasets\n",
    "from privacy_estimates.experiments.aml import JobList\n",
    "from sklearn.metrics import roc_curve, roc_auc_score, auc\n",
    "from datetime import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_performance(scores_members, scores_non_members):\n",
    "    mia_performance = {}\n",
    "    \n",
    "    member_vals = [val for val in scores_members]\n",
    "    non_member_vals = [val for val in scores_non_members]\n",
    "    mia_performance['auc'] = roc_auc_score([1]*len(member_vals) + [0]*len(non_member_vals), member_vals + non_member_vals)\n",
    "    fpr, tpr, thresholds = roc_curve([1]*len(member_vals) + [0]*len(non_member_vals), member_vals + non_member_vals)\n",
    "    for target_fpr in (0.01, 0.05, 0.1):\n",
    "        mia_performance[f'tpr_at_{target_fpr}'] = np.interp(target_fpr, fpr, tpr)\n",
    "    print(f\"AUC: {mia_performance['auc']}, TPR@0.01: {mia_performance['tpr_at_0.01']}, TPR@0.05: {mia_performance['tpr_at_0.05']}, TPR@0.1: {mia_performance['tpr_at_0.1']}\")\n",
    "\n",
    "    # also add the curves\n",
    "    mia_performance['fpr'] = fpr\n",
    "    mia_performance['tpr'] = tpr\n",
    "\n",
    "    return mia_performance\n",
    "\n",
    "def compute_performance_from_url(url, job_name = None):\n",
    "    jobs = JobList.from_urls([url])\n",
    "    if job_name is None:\n",
    "        job_name = str(datetime.now())\n",
    "    \n",
    "    if not os.path.exists(f'./mia_results/{job_name}'):\n",
    "        test = jobs[0].get_node('estimate_privacy').download_input('scores', f'./mia_results/{job_name}/scores')\n",
    "        test = jobs[0].get_node('estimate_privacy').download_input('challenge_bits', f'./mia_results/{job_name}/challenge_bits')\n",
    "    scores = datasets.load_from_disk(f'./mia_results/{job_name}/scores')\n",
    "    bits = datasets.load_from_disk(f'./mia_results/{job_name}/challenge_bits')\n",
    "    \n",
    "    membership_scores = np.array([k['score'] for k in scores])\n",
    "    membership_labels = np.array([k['challenge_bit'] for k in bits])\n",
    "    members = membership_scores[membership_labels == 1]\n",
    "    non_members = membership_scores[membership_labels == 0]\n",
    "    return compute_performance(members, non_members)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is all for syntheticcanary_uniformlabel with n_rep = 12\n",
    "all_urls = {\n",
    "    'sst2': {\n",
    "        'synthetic_2gram': 'https://ml.azure.com/runs/quirky_battery_tvw7wcm4wh?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47'\n",
    "    },\n",
    "    'agnews': {\n",
    "        'synthetic_2gram': 'https://ml.azure.com/runs/quirky_insect_7scfdk028l?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourcegroups/PPML/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47',\n",
    "    },\n",
    "    'sst2-best': {\n",
    "        'synthetic_2gram': 'https://ml.azure.com/experiments/id/cbd45cd3-4fd8-4922-b82a-124527cc98ee/runs/gifted_jewel_w50pc8cd89?wsid=/subscriptions/acc09744-1ee3-4242-b375-93421c63af0c/resourceGroups/PPML/providers/Microsoft.MachineLearningServices/workspaces/M365Research-PPML-EUS&tid=72f988bf-86f1-41af-91ab-2d7cd011db47#'\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET = 'sst2-best'\n",
    "text_key = 'sentence'\n",
    "synthetic_job_name = all_urls[DATASET]['synthetic_2gram']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jobs = JobList.from_urls([synthetic_job_name])\n",
    "synthetic_job = jobs[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Clean up before proceeding.**\n",
    "\n",
    "Before you run the below, I recommend running './clean_for_interpet.sh' in the notebooks directory. This makes sure nothing remains from the previous run and everything can be downloaded for again for the right jobs. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (1) Let's first get the canaries!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the canaries\n",
    "test = jobs[0].get_node('add_index_to_dataset_2').get_node('append_column_incrementing').download_input('data', 'canaries_synthetic')\n",
    "canaries = datasets.load_from_disk('canaries_synthetic')\n",
    "canaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for all reference models, get the canary data that was IN\n",
    "\n",
    "ref_model_to_in_data = {}\n",
    "\n",
    "for i in range(1, 5):\n",
    "    if i == 4:\n",
    "        test = jobs[0].get_node('compute_shadow_model_statistics').get_node('train_shadow_models').get_node('train_final_model_group').get_node(f'train_model_and_predict').get_node('filter_in_data').download_output('filtered', f'in_data_{i}')\n",
    "    else:\n",
    "         jobs[0].get_node('compute_shadow_model_statistics').get_node('train_shadow_models').get_node('train_final_model_group').get_node(f'train_model_and_predict_{i}').get_node('filter_in_data').download_output('filtered', f'in_data_{i}')\n",
    "    in_data_for_model = datasets.load_from_disk(f'in_data_{i}')\n",
    "    ref_model_to_in_data[i] = in_data_for_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for all reference models, get the generated synthetic data\n",
    "\n",
    "ref_model_to_synthetic_data = {}\n",
    "\n",
    "for i in range(1, 5):\n",
    "    if i == 4:\n",
    "        test = jobs[0].get_node('compute_shadow_model_statistics').get_node('train_shadow_models').get_node('train_final_model_group').get_node('train_model_and_predict').get_node('inference_in').get_node('preprocess_synthetic').download_output('prep_synthetic_data_path', f'synthetic_data_{i}')\n",
    "    else:\n",
    "        test =  jobs[0].get_node('compute_shadow_model_statistics').get_node('train_shadow_models').get_node('train_final_model_group').get_node(f'train_model_and_predict_{i}').get_node('inference_in').get_node('preprocess_synthetic').download_output('prep_synthetic_data_path', f'synthetic_data_{i}')\n",
    "    synthetic_data = datasets.Dataset.from_json(f'synthetic_data_{i}/prep_synthetic_data_path')\n",
    "    ref_model_to_synthetic_data[i] = synthetic_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canaries[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's pick a certain canary\n",
    "\n",
    "\n",
    "idx = 0\n",
    "canary = canaries[idx][text_key]\n",
    "print(\"canary: \", canary)\n",
    "\n",
    "for i in range(1, 5):\n",
    "   in_data_for_model = ref_model_to_in_data[i]\n",
    "   if canary in in_data_for_model[text_key]:\n",
    "      print(\"Canary was IN model \", i)\n",
    "   else:\n",
    "      print(\"Canary was OUT model \", i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# actually we want to look at the most vulnerable when it comes to the target model/RMIA scores\n",
    "\n",
    "test = jobs[0].get_node('estimate_privacy').download_input('scores', './mia_results/scores_synthetic/scores')\n",
    "scores_synthetic = datasets.load_from_disk('./mia_results/scores_synthetic/scores')\n",
    "\n",
    "test = jobs[0].get_node('estimate_privacy').download_input('challenge_bits', './mia_results/challenge_bits_synthetic/challenge_bits')\n",
    "challenge_bits_synthetic = datasets.load_from_disk('./mia_results/challenge_bits_synthetic/challenge_bits')\n",
    "\n",
    "test =  jobs[0].get_node('train_many_models').get_node('train_final_model_group').get_node('train_model_and_predict').get_node('data_for_model').download_input('in_out_data', 'in_out_data_synthetic')\n",
    "in_out_data_synthetic = datasets.load_from_disk('in_out_data_synthetic')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# also got the target model synthetic data\n",
    "test = jobs[0].get_node('train_many_models').get_node('train_final_model_group').get_node('train_model_and_predict').get_node('inference_in').get_node('preprocess_synthetic').download_output('prep_synthetic_data_path', f'synthetic_data_target')\n",
    "target_synthetic_data = datasets.Dataset.from_json(f'synthetic_data_target/prep_synthetic_data_path')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "membership_scores_synthetic = [k['score'] for k in scores_synthetic]\n",
    "membership_labels_synthetic = [k['challenge_bit'] for k in challenge_bits_synthetic]\n",
    "\n",
    "top_n = 5\n",
    "\n",
    "high_to_low_indices = np.argsort(membership_scores_synthetic)[::-1]\n",
    "\n",
    "for idx in high_to_low_indices[:top_n]:\n",
    "    print(f\"Top synthetic score: {membership_scores_synthetic[idx]}, challenge bit: {membership_labels_synthetic[idx]}\")\n",
    "    print(in_out_data_synthetic[int(idx)][text_key])\n",
    "    print('---')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "most_vulnerable_canary = in_out_data_synthetic[int(high_to_low_indices[0])][text_key]\n",
    "print(\"most vulnerable canary: \")\n",
    "print(most_vulnerable_canary)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What do we now want to do? \n",
    "\n",
    "- Run through all models\n",
    "- Train an n-gram model on the corresponding synthetic data\n",
    "- Get the log likelihood of the canary for that n-gram model\n",
    "- And also check what is 'extracted' from the sequence\n",
    "- This enables us to compare IN and OUT for this sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "from tqdm import tqdm\n",
    "\n",
    "def generate_ngrams(text, n):\n",
    "    \"\"\"\n",
    "    Generate n-grams from the input text.\n",
    "    \"\"\"\n",
    "    tokens = text.split()\n",
    "    ngrams = zip(*[tokens[i:] for i in range(n)])\n",
    "    return [' '.join(ngram) for ngram in ngrams]\n",
    "\n",
    "def train_ngram_model(all_text, n, smoothing=1):\n",
    "    \"\"\"\n",
    "    Train an n-gram model from the given text using Laplace smoothing.\n",
    "    \"\"\"\n",
    "    all_ngrams = []\n",
    "    vocabulary = set()\n",
    "\n",
    "    for text in tqdm(all_text):\n",
    "        words = text.split()\n",
    "        vocabulary.update(words)\n",
    "        ngrams = generate_ngrams(text, n)\n",
    "        all_ngrams.extend(ngrams)\n",
    "\n",
    "    ngram_counts = collections.Counter(all_ngrams)\n",
    "    total_ngrams = sum(ngram_counts.values()) + smoothing * len(vocabulary) ** n\n",
    "\n",
    "    # Convert counts to probabilities with smoothing\n",
    "    ngram_probabilities = {\n",
    "        ngram: (count + smoothing) / total_ngrams\n",
    "        for ngram, count in ngram_counts.items()\n",
    "    }\n",
    "\n",
    "    return ngram_probabilities, len(vocabulary), total_ngrams\n",
    "\n",
    "def inference(ngram_model, text, n, vocabulary_size, total_ngrams, smoothing=1):\n",
    "    \"\"\"\n",
    "    Compute the loss of the n-gram model on a given piece of text.\n",
    "    The loss is the average negative log likelihood of the n-grams in the text.\n",
    "    \"\"\"\n",
    "    ngrams = generate_ngrams(text, n)\n",
    "    log_likelihood = 0\n",
    "    count = 0\n",
    "    n_gram_counts = []\n",
    "\n",
    "    for ngram in ngrams:\n",
    "        if ngram in ngram_model:\n",
    "            prob = ngram_model[ngram]\n",
    "            n_gram_counts.append((ngram, prob, prob * total_ngrams))\n",
    "        else:\n",
    "            # Apply smoothing for unseen n-grams\n",
    "            prob = smoothing / (sum(ngram_model.values()) + smoothing * vocabulary_size ** n)\n",
    "        log_likelihood += np.log(prob)\n",
    "        count += 1\n",
    "\n",
    "    loss = -log_likelihood / count if count > 0 else float('inf')\n",
    "\n",
    "    return loss, n_gram_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(canary)\n",
    "\n",
    "import difflib\n",
    "\n",
    "def longest_overlapping_substring(target, seq):\n",
    "    max_overlap = \"\"\n",
    "    \n",
    "    s = difflib.SequenceMatcher(None, target, seq, autojunk=False)\n",
    "    match = s.find_longest_match(0, len(target), 0, len(seq))\n",
    "    if match.size > len(max_overlap):\n",
    "        max_overlap = target[match.a: match.a + match.size]\n",
    "    \n",
    "    return max_overlap\n",
    "\n",
    "for text in ref_model_to_synthetic_data[1][text_key]:\n",
    "    if 'Barcelona striker' in text:\n",
    "        print(longest_overlapping_substring(text, canary))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canary = most_vulnerable_canary\n",
    "print(\"canary: \", canary)\n",
    "\n",
    "# first for the reference models\n",
    "for i in range(1, 5):\n",
    "   in_data_for_model = ref_model_to_in_data[i]\n",
    "   if canary in in_data_for_model[text_key]:\n",
    "      print(\"Canary was IN model \", i)\n",
    "   else:\n",
    "      print(\"Canary was OUT model \", i)\n",
    "   synthetic_data = ref_model_to_synthetic_data[i]\n",
    "   # train 2-gram model\n",
    "   ngram_probabilities, vocab_size, total_ngrams = train_ngram_model(synthetic_data[text_key], 2)\n",
    "   ngram_loss, n_gram_counts = inference(ngram_probabilities, canary, 2, vocab_size, total_ngrams)\n",
    "   #n_gram_counts.sort(key=lambda x: x[2], reverse=False)\n",
    "   print(f\"Loss: {ngram_loss}\")\n",
    "   print(n_gram_counts)\n",
    "   print('----')\n",
    "\n",
    "# finally for the target model\n",
    "print(\"Membership label for canary: \", membership_labels_synthetic[int(high_to_low_indices[0])])\n",
    "\n",
    "# train 2-gram model\n",
    "ngram_probabilities, vocab_size, total_ngrams = train_ngram_model(target_synthetic_data[text_key], 2)\n",
    "ngram_loss, n_gram_counts = inference(ngram_probabilities, canary, 2, vocab_size, total_ngrams)\n",
    "\n",
    "print(f\"Loss: {ngram_loss}\")\n",
    "print(n_gram_counts)\n",
    "print('----')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "WHat about the largest overlapping substring? "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import difflib\n",
    "from tqdm import tqdm\n",
    "\n",
    "def rank_sentences_by_overlap(target_sentence, sentence_list):\n",
    "    overlap_list = []\n",
    "    \n",
    "    for i in tqdm(range(len(sentence_list))):\n",
    "        sentence = sentence_list[i]\n",
    "        # Initialize SequenceMatcher\n",
    "        s = difflib.SequenceMatcher(None, target_sentence, sentence, autojunk=False)\n",
    "        \n",
    "        # Find the longest matching block\n",
    "        match = s.find_longest_match(0, len(target_sentence), 0, len(sentence))\n",
    "        max_overlap = target_sentence[match.a: match.a + match.size]\n",
    "        \n",
    "        # Check the length of the matching block\n",
    "        overlap_length = match.size\n",
    "        \n",
    "        # Add the overlap length and sentence to the list\n",
    "        overlap_list.append((overlap_length, max_overlap, sentence, i))\n",
    "    \n",
    "    # Sort the list by overlap length in descending order\n",
    "    overlap_list.sort(key=lambda x: x[0], reverse=True)\n",
    "    \n",
    "    return overlap_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canary = most_vulnerable_canary\n",
    "print(\"canary: \", canary)\n",
    "\n",
    "for i in range(1, 5):\n",
    "   in_data_for_model = ref_model_to_in_data[i]\n",
    "   if canary in in_data_for_model[text_key]:\n",
    "      print(\"Canary was IN model \", i)\n",
    "   else:\n",
    "      print(\"Canary was OUT model \", i)\n",
    "   synthetic_data = ref_model_to_synthetic_data[i]\n",
    "   ranked_sentences = rank_sentences_by_overlap(canary, synthetic_data[text_key])\n",
    "   print(\"Ranked sentences by decreasing max overlap:\")\n",
    "   for overlap_length, overlap, sentence, label in ranked_sentences[:3]:\n",
    "      print(f\"Overlap length: {overlap_length}, overlap: {overlap}, Sentence: {sentence}, Label: {label}\")\n",
    "   print('----')\n",
    "\n",
    "# finally for the target model\n",
    "print(\"Membership label for canary: \", membership_labels_synthetic[int(high_to_low_indices[0])])\n",
    "ranked_sentences = rank_sentences_by_overlap(canary, target_synthetic_data[text_key])\n",
    "print(\"Ranked sentences by decreasing max overlap:\")\n",
    "for overlap_length, overlap, sentence, label in ranked_sentences[:3]:\n",
    "    print(f\"Overlap length: {overlap_length}, overlap: {overlap}, Sentence: {sentence}, Label: {label}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# and what about the canary with the lowest RMIA score?\n",
    "\n",
    "lowest_RMIA_canary = in_out_data_synthetic[int(high_to_low_indices[-1])][text_key]\n",
    "print(\"lowest_RMIA_canary: \")\n",
    "print(lowest_RMIA_canary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canary = lowest_RMIA_canary\n",
    "print(\"canary: \", canary)\n",
    "\n",
    "# first for the reference models\n",
    "for i in range(1, 5):\n",
    "   in_data_for_model = ref_model_to_in_data[i]\n",
    "   if canary in in_data_for_model[text_key]:\n",
    "      print(\"Canary was IN model \", i)\n",
    "   else:\n",
    "      print(\"Canary was OUT model \", i)\n",
    "   synthetic_data = ref_model_to_synthetic_data[i]\n",
    "   # train 2-gram model\n",
    "   ngram_probabilities, vocab_size, total_ngrams = train_ngram_model(synthetic_data[text_key], 2)\n",
    "   ngram_loss, n_gram_counts = inference(ngram_probabilities, canary, 2, vocab_size, total_ngrams)\n",
    "   #n_gram_counts.sort(key=lambda x: x[2], reverse=False)\n",
    "   print(f\"Loss: {ngram_loss}\")\n",
    "   print(n_gram_counts)\n",
    "   print('----')\n",
    "\n",
    "# finally for the target model\n",
    "print(\"Membership label for canary: \", membership_labels_synthetic[int(high_to_low_indices[-1])])\n",
    "\n",
    "# train 2-gram model\n",
    "ngram_probabilities, vocab_size, total_ngrams = train_ngram_model(target_synthetic_data[text_key], 2)\n",
    "ngram_loss, n_gram_counts = inference(ngram_probabilities, canary, 2, vocab_size, total_ngrams)\n",
    "\n",
    "print(f\"Loss: {ngram_loss}\")\n",
    "print(n_gram_counts)\n",
    "print('----')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "canary = lowest_RMIA_canary\n",
    "print(\"canary: \", canary)\n",
    "\n",
    "for i in range(1, 5):\n",
    "   in_data_for_model = ref_model_to_in_data[i]\n",
    "   if canary in in_data_for_model[text_key]:\n",
    "      print(\"Canary was IN model \", i)\n",
    "   else:\n",
    "      print(\"Canary was OUT model \", i)\n",
    "   synthetic_data = ref_model_to_synthetic_data[i]\n",
    "   ranked_sentences = rank_sentences_by_overlap(canary, synthetic_data[text_key])\n",
    "   print(\"Ranked sentences by decreasing max overlap:\")\n",
    "   for overlap_length, overlap, sentence, label in ranked_sentences[:3]:\n",
    "      print(f\"Overlap length: {overlap_length}, overlap: {overlap}, Sentence: {sentence}, Label: {label}\")\n",
    "   print('----')\n",
    "\n",
    "# finally for the target model\n",
    "print(\"Membership label for canary: \", membership_labels_synthetic[int(high_to_low_indices[-1])])\n",
    "ranked_sentences = rank_sentences_by_overlap(canary, target_synthetic_data[text_key])\n",
    "print(\"Ranked sentences by decreasing max overlap:\")\n",
    "for overlap_length, overlap, sentence, label in ranked_sentences[:3]:\n",
    "    print(f\"Overlap length: {overlap_length}, overlap: {overlap}, Sentence: {sentence}, Label: {label}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "privacy-estimates",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
