{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from importlib import reload\n",
    "\n",
    "import torch\n",
    "from datasets import load_dataset, load_from_disk\n",
    "from transformers import LogitsProcessorList, MinLengthLogitsProcessor, LogitsProcessor\n",
    "\n",
    "import watermarking.watermark_processor\n",
    "\n",
    "watermarking.watermark_processor = reload(watermarking.watermark_processor)\n",
    "from watermarking.watermark_processor import RepetitionPenaltyLogitsProcessor\n",
    "from watermarking.utils.text_tools import truncate\n",
    "from watermarking.utils.load_local import load_local_model_or_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Load the generation model and its tokenizer from one shared checkpoint name.\n",
    "generator_checkpoint = \"facebook/opt-1.3b\"\n",
    "tokenizer = load_local_model_or_tokenizer(generator_checkpoint, 'tokenizer')\n",
    "model = load_local_model_or_tokenizer(generator_checkpoint, 'model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "# GPT-2 tokenizer/model pair, kept under the `lm_` prefix.\n",
    "# NOTE(review): no later cell visible in this notebook reads `lm_tokenizer` or `lm_model`\n",
    "# (LMMessageModel is built with `model`/`tokenizer` instead) -- confirm whether that is intended.\n",
    "lm_tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "\n",
    "# Use the first GPU when CUDA is available; otherwise stay on CPU instead of raising.\n",
    "lm_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "lm_model = AutoModelForCausalLM.from_pretrained('gpt2').to(lm_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Move the generation model to the first GPU when CUDA is available; otherwise keep it on\n",
    "# CPU so this cell does not raise on a machine without CUDA.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Read the pre-sliced C4 shard from disk, then keep the first 100 rows of its 'train'\n",
    "# split after a shuffle with a fixed seed (42).\n",
    "c4_on_disk = load_from_disk('./c4-train.00000-of-00512_sliced')\n",
    "c4_train_shuffled = c4_on_disk['train'].shuffle(seed=42)\n",
    "c4_sliced_and_filted = c4_train_shuffled.select(range(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 396,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Pick one sampled document to use as the generation prompt.\n",
    "sample_idx = 98\n",
    "input_text = c4_sliced_and_filted[sample_idx]['text']\n",
    "\n",
    "# Tokenize onto the model's device, then pass the result through `truncate` (max_length=300).\n",
    "tokenized_input = truncate(\n",
    "    tokenizer(input_text, return_tensors='pt').to(model.device),\n",
    "    max_length=300,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Logits processors shared by every generate() call in this notebook.\n",
    "\n",
    "# Repetition-penalty processor from watermarking.watermark_processor, configured with penalty=1.5.\n",
    "repetition_processor = RepetitionPenaltyLogitsProcessor(penalty=1.5)\n",
    "\n",
    "# NOTE(review): min_length=1000 is well above the max_new_tokens=200 passed to generate() in\n",
    "# this notebook; presumably this keeps EOS blocked for the whole generation -- confirm intended.\n",
    "min_length_processor = MinLengthLogitsProcessor(\n",
    "    min_length=1000,\n",
    "    eos_token_id=tokenizer.eos_token_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 253,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from importlib import reload\n",
    "import watermarking.watermark_processors.message_models.lm_message_model\n",
    "watermarking.watermark_processors.message_models.lm_message_model = reload(watermarking.watermark_processors.message_models.lm_message_model)\n",
    "import watermarking.watermark_processors.message_model_processor\n",
    "watermarking.watermark_processors.message_model_processor = reload(watermarking.watermark_processors.message_model_processor)\n",
    "\n",
    "\n",
    "from watermarking.watermark_processors.message_models.lm_message_model import LMMessageModel\n",
    "from watermarking.watermark_processors.message_model_processor import WmProcessorMessageModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 449,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Message model used by the watermark processor below.\n",
    "# NOTE(review): `lm_model`/`lm_tokenizer` are given the main `model`/`tokenizer` here, not the\n",
    "# GPT-2 `lm_model`/`lm_tokenizer` objects loaded in an earlier cell -- confirm this is intended.\n",
    "lm_message_model = LMMessageModel(\n",
    "    tokenizer=tokenizer,\n",
    "    lm_model=model,\n",
    "    lm_tokenizer=tokenizer,\n",
    "    delta=1.1,\n",
    "    lm_prefix_len=10,\n",
    "    lm_topk=-1,\n",
    "    message_code_len=20,\n",
    "    random_permutation_num=50,\n",
    ")\n",
    "\n",
    "# Logits processor built around the message model and the integer `message` list.\n",
    "wm_precessor_message_model = WmProcessorMessageModel(\n",
    "    message_model=lm_message_model,\n",
    "    tokenizer=tokenizer,\n",
    "    encode_ratio=5,\n",
    "    max_confidence_lbd=0.5,\n",
    "    strategy='max_confidence',\n",
    "    message=[42, 34, 54, 665, 226, 329],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 450,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Record the prompt length (in tokens) on the watermark processor before generating.\n",
    "start_length = tokenized_input['input_ids'].shape[-1]\n",
    "wm_precessor_message_model.start_length = start_length\n",
    "\n",
    "wm_logits_processors = LogitsProcessorList([\n",
    "    min_length_processor,\n",
    "    repetition_processor,\n",
    "    wm_precessor_message_model,\n",
    "])\n",
    "output_tokens = model.generate(\n",
    "    **tokenized_input,\n",
    "    max_new_tokens=200,\n",
    "    num_beams=4,\n",
    "    logits_processor=wm_logits_processors,\n",
    ")\n",
    "\n",
    "# Decode the first returned sequence twice: continuation only, and prompt + continuation.\n",
    "first_sequence = output_tokens[0]\n",
    "output_text = tokenizer.decode(first_sequence[start_length:], skip_special_tokens=True)\n",
    "prefix_and_output_text = tokenizer.decode(first_sequence, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Run the message-model watermark processor's decode() on the generated continuation only\n",
    "# (output_text excludes the prompt tokens).\n",
    "log_probs = wm_precessor_message_model.decode(output_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Display the first element of the decode() result.\n",
    "# NOTE(review): presumably the recovered message -- confirm against WmProcessorMessageModel.decode.\n",
    "log_probs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 431,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import watermarking.watermark_processors.random_processor\n",
    "\n",
    "# Reload the module before importing from it, so the class below comes from the re-executed module.\n",
    "watermarking.watermark_processors.random_processor = reload(\n",
    "    watermarking.watermark_processors.random_processor)\n",
    "from watermarking.watermark_processors.random_processor import WmProcessorRandom\n",
    "\n",
    "# Second watermark processor, used by the `t_*` generation cell further down.\n",
    "random_processor = WmProcessorRandom(\n",
    "    message=[42, 34, 5465, 34],\n",
    "    tokenizer=tokenizer,\n",
    "    delta=1.5,\n",
    "    message_code_len=20,\n",
    "    top_k=100,\n",
    "    encode_ratio=5.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 456,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Record the prompt length (in tokens) on the random watermark processor before generating.\n",
    "t_prompt_length = tokenized_input['input_ids'].shape[-1]\n",
    "random_processor.start_length = t_prompt_length\n",
    "\n",
    "t_logits_processors = LogitsProcessorList([\n",
    "    min_length_processor,\n",
    "    repetition_processor,\n",
    "    random_processor,\n",
    "])\n",
    "t_output_tokens = model.generate(\n",
    "    **tokenized_input,\n",
    "    max_new_tokens=200,\n",
    "    num_beams=4,\n",
    "    logits_processor=t_logits_processors,\n",
    ")\n",
    "\n",
    "# Decode the first returned sequence twice: continuation only, and prompt + continuation.\n",
    "t_first_sequence = t_output_tokens[0]\n",
    "t_output_text = tokenizer.decode(t_first_sequence[t_prompt_length:], skip_special_tokens=True)\n",
    "t_prefix_and_output_text = tokenizer.decode(t_first_sequence, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Run the random watermark processor's decode() on the generated continuation only\n",
    "# (t_output_text excludes the prompt tokens).\n",
    "t_log_probs = random_processor.decode(t_output_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Display the first element of the decode() result.\n",
    "# NOTE(review): presumably the recovered message -- confirm against WmProcessorRandom.decode.\n",
    "t_log_probs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 430,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Baseline: same prompt and decoding settings, but no watermark processor in the list.\n",
    "no_wm_logits_processors = LogitsProcessorList([\n",
    "    min_length_processor,\n",
    "    repetition_processor,\n",
    "])\n",
    "no_wm_output_tokens = model.generate(\n",
    "    **tokenized_input,\n",
    "    max_new_tokens=200,\n",
    "    num_beams=4,\n",
    "    logits_processor=no_wm_logits_processors,\n",
    ")\n",
    "\n",
    "# Decode the first returned sequence twice: continuation only, and prompt + continuation.\n",
    "no_wm_prompt_length = tokenized_input['input_ids'].shape[-1]\n",
    "no_wm_first_sequence = no_wm_output_tokens[0]\n",
    "no_wm_output_text = tokenizer.decode(no_wm_first_sequence[no_wm_prompt_length:],\n",
    "                                     skip_special_tokens=True)\n",
    "no_wm_prefix_and_output_text = tokenizer.decode(no_wm_first_sequence, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# OPT-2.7B tokenizer/model pair passed as the oracle to the perplexity cells below.\n",
    "oracle_tokenizer = load_local_model_or_tokenizer('facebook/opt-2.7b', 'tokenizer')\n",
    "oracle_model = load_local_model_or_tokenizer('facebook/opt-2.7b', 'model')\n",
    "\n",
    "# Use the first GPU when CUDA is available; otherwise stay on CPU instead of raising.\n",
    "oracle_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "oracle_model = oracle_model.to(oracle_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "from watermarking.experiments.watermark import compute_ppl_single\n",
    "\n",
    "# Loss / perplexity of the message-model watermarked continuation under the OPT-2.7B oracle.\n",
    "loss, ppl = compute_ppl_single(\n",
    "    prefix_and_output_text=prefix_and_output_text,\n",
    "    oracle_model_name='facebook/opt-2.7b',\n",
    "    output_text=output_text,\n",
    "    oracle_model=oracle_model,\n",
    "    oracle_tokenizer=oracle_tokenizer,\n",
    ")\n",
    "print(loss, ppl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "from watermarking.experiments.watermark import compute_ppl_single\n",
    "\n",
    "# Loss / perplexity of the random-processor watermarked continuation under the OPT-2.7B oracle.\n",
    "loss, ppl = compute_ppl_single(\n",
    "    prefix_and_output_text=t_prefix_and_output_text,\n",
    "    oracle_model_name='facebook/opt-2.7b',\n",
    "    output_text=t_output_text,\n",
    "    oracle_model=oracle_model,\n",
    "    oracle_tokenizer=oracle_tokenizer,\n",
    ")\n",
    "print(loss, ppl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "# Loss / perplexity of the unwatermarked baseline continuation under the OPT-2.7B oracle.\n",
    "# `compute_ppl_single` is imported by the earlier perplexity cells.\n",
    "loss, ppl = compute_ppl_single(\n",
    "    prefix_and_output_text=no_wm_prefix_and_output_text,\n",
    "    oracle_model_name='facebook/opt-2.7b',\n",
    "    output_text=no_wm_output_text,\n",
    "    oracle_model=oracle_model,\n",
    "    oracle_tokenizer=oracle_tokenizer,\n",
    ")\n",
    "print(loss, ppl)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.10 ('wm')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "69f52fabb15766d39c6bf90ba53c555c905cb082f5a671ecb5c4487727b3f015"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}