{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d436ab2-4e3a-42d0-a9dd-735ac2e03568",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import pickle\n",
    "import string\n",
    "\n",
    "from functools import cache\n",
    "from itertools import chain\n",
    "\n",
    "import datasets\n",
    "import guidance\n",
    "import nltk\n",
    "import spacy\n",
    "import tqdm\n",
    "import wandb\n",
    "\n",
    "from datasets import load_dataset\n",
    "from frozendict import frozendict\n",
    "from guidance import assistant, gen, user\n",
    "from guidance.models import OpenAI, Transformers\n",
    "from nltk.sentiment.vader import SentimentIntensityAnalyzer\n",
    "from tqdm import tqdm_notebook\n",
    "from wandb import Artifact"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5155e5b8-d618-48cb-96a5-e13a4e384d4e",
   "metadata": {},
   "source": [
    "### Download lexicon and setup OpenAI and spacy models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "291cc1a6-5ed8-4a2b-b6d3-e448b5814fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['WANDB_API_KEY'] = \"YOUR_KEY_HERE\"\n",
    "os.environ['OPENAI_API_KEY'] = \"YOUR_KEY_HERE\"\n",
    "imdb_docs_filename = \"imdb_docs.pkl\"\n",
    "wandb_project_name = \"utility_reconstruction\"\n",
    "debug=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80dfbd3d-1510-4d2d-a47b-55aa04fcfb30",
   "metadata": {},
   "outputs": [],
   "source": [
    "run = wandb.init(project=wandb_project_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6c2862d-62e6-4cf7-b6b7-c156a63b1e15",
   "metadata": {},
   "outputs": [],
   "source": [
    "openai_model = OpenAI(\"gpt-4\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f633adc-00a3-490e-b81a-a34ec64d4b01",
   "metadata": {},
   "outputs": [],
   "source": [
    "! python -m spacy download en_core_web_md\n",
    "nltk.download('vader_lexicon')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "312d1ffa-e66d-4430-9605-107db0295cbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "triplets = []\n",
    "\n",
    "intensity_analyzer = SentimentIntensityAnalyzer()\n",
    "lexicon = intensity_analyzer.lexicon\n",
    "\n",
    "nlp = spacy.load(\"en_core_web_md\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c5f42b0-8a52-4976-beb2-0ef5c78f0c7b",
   "metadata": {},
   "source": [
    "### Extract topmost positive words from lexicon."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8078380b-a63e-44c5-b9c6-e1fcc49103b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_percentiles(dictionary, desired_percentile = 25):\n",
    "    # Step 1: Extract values from the dictionary\n",
    "    items = list(dictionary.items())\n",
    "\n",
    "    # Step 2: Sort the values\n",
    "    sorted_values = sorted(items, key = lambda x: x[1])\n",
    "\n",
    "    # Step 3: Calculate indices for top 25% and bottom 25%\n",
    "    total_count = len(sorted_values)\n",
    "\n",
    "    percentile = desired_percentile / 100\n",
    "    top_percentile = 1 - percentile\n",
    "    top_index = int(top_percentile * total_count)\n",
    "\n",
    "    bottom_index = int(percentile * total_count)\n",
    "\n",
    "    # Step 4: Retrieve values at calculated indices\n",
    "    positive_lexicon = frozendict(sorted_values[top_index:])\n",
    "    negative_lexicon = frozendict(sorted_values[:bottom_index])\n",
    "\n",
    "    return positive_lexicon, negative_lexicon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c56e49aa-0b65-42f1-98c6-5c5361eb45c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "topmost_positive_lexicon, topmost_negative_lexicon = get_percentiles(lexicon)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fd6c6d1-7897-424a-b836-8739bb2f9574",
   "metadata": {},
   "outputs": [],
   "source": [
    "topmost_positive_words = frozenset(topmost_positive_lexicon.keys())\n",
    "topmost_negative_words = frozenset(topmost_negative_lexicon.keys())\n",
    "\n",
    "all_positive_words = frozenset([key for key, value in lexicon.items() if value > 0])\n",
    "all_negative_words = frozenset([key for key, value in lexicon.items() if value < 0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e179316-be21-4310-aa87-9af12adbf88e",
   "metadata": {},
   "source": [
    "### Compute spacy documents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a98c40f3-6b91-4270-87ad-a4829031a78b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def strip_punctuation(s):\n",
    "    return s.translate(str.maketrans('', '', string.punctuation))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecd4bb64-8a4b-4207-b10e-2654a0347b00",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_top_sentences_and_tokens(spacy_doc, num_sentences=1):\n",
    "    sents = list(spacy_doc.sents)[:num_sentences]\n",
    "    tokens = list(chain(*sents))\n",
    "    tokens = [strip_punctuation(str(token)).strip().lower() for token in tokens]\n",
    "\n",
    "    sents = [str(sent) for sent in sents]\n",
    "    return \" \".join(sents), tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb641be9-e1b7-43bd-9edf-5648c0f08e79",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_spacy_docs():\n",
    "    imdb_dataset = load_dataset('imdb', split='train')\n",
    "    imdb_texts = [item['text'] for item in imdb_dataset]\n",
    "    all_imdb_docs = []\n",
    "    for text in tqdm_notebook(imdb_texts):\n",
    "        all_imdb_docs.append(nlp(text))\n",
    "    return all_imdb_docs\n",
    "\n",
    "def load_spacy_docs(filename: str):\n",
    "    print(f'Loading spacy docs from {filename}')\n",
    "    with open(filename, 'rb') as file:\n",
    "        return pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a507bd7b-6672-48ef-a78c-b269bb53fae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_or_compute_spacy_docs(filename=imdb_docs_filename):\n",
    "    try:\n",
    "        all_imdb_docs = load_spacy_docs(imdb_docs_filename)\n",
    "    except Exception as e:\n",
    "        print(f'Caught exception {e}, recomputing spacy docs.')\n",
    "        all_imdb_docs = compute_spacy_docs()\n",
    "\n",
    "        with open(filename, 'wb') as file:\n",
    "            pickle.dump(all_imdb_docs)\n",
    "    return all_imdb_docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f363d775-1dc3-4248-859b-638c834c8383",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_imdb_docs = load_or_compute_spacy_docs(filename=imdb_docs_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09519534-430a-4649-894c-57429398d8ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(all_imdb_docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a622994c-39b4-4eeb-b0ac-65a979b1181a",
   "metadata": {},
   "source": [
    "### Construct sentiment laden examples picking where the first sentence is positive only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58cc73c4-f4c8-4b16-815b-577798f72d48",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sentiment_laden_examples(spacy_doc):\n",
    "    \"\"\"\n",
    "    Gives score from vader lexicon of number of positive and negative words\n",
    "    \"\"\"\n",
    "    sentences, tokens = extract_top_sentences_and_tokens(spacy_doc)\n",
    "    target_sets = {}\n",
    "    target_sets['positive'] = topmost_positive_words\n",
    "    target_sets['negative'] = topmost_negative_words\n",
    "\n",
    "    hit_tokens = {}\n",
    "    score = {}\n",
    "    for target in ['negative', 'positive']:\n",
    "        hit_tokens[target] = frozenset(tokens).intersection(target_sets[target])\n",
    "        score[target] = len(hit_tokens[target])\n",
    "    \n",
    "    return frozendict(score), frozendict(hit_tokens), sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3404ae23-56fa-4ae4-9e0c-12632c761995",
   "metadata": {},
   "outputs": [],
   "source": [
    "spacy_doc = all_imdb_docs[78]\n",
    "get_sentiment_laden_examples(spacy_doc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "616bfca8-d5b0-4c08-b2e4-e700499363d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_score_tokens_and_sentences = [get_sentiment_laden_examples(spacy_doc) for spacy_doc in all_imdb_docs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fa0aa10-8e33-4961-8e3e-47c516341774",
   "metadata": {},
   "outputs": [],
   "source": [
    "positive_examples = []\n",
    "for score_tokens_and_sentences in all_score_tokens_and_sentences:\n",
    "    score = score_tokens_and_sentences[0]\n",
    "    positive = score['positive']\n",
    "    negative = score['negative']\n",
    "    if positive > 0 and negative == 0:\n",
    "        positive_examples.append(score_tokens_and_sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b743629-86e7-4829-bcda-f3d61820ef7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_positive_examples = positive_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ab20aae-7108-4c79-98ac-f48ebe508399",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_positive_examples[:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "776db144-b668-4e5c-bacd-b5508df697df",
   "metadata": {},
   "outputs": [],
   "source": [
    "triplet = selected_positive_examples[0]\n",
    "triplet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87c22828-72e8-49e7-babb-06591bbf2bbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_result_from_guidance(prompt, openai_model):\n",
    "    with user():\n",
    "        lm = openai_model + prompt\n",
    "\n",
    "    with assistant():\n",
    "        lm += gen('answer')\n",
    "\n",
    "    return lm['answer']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "969d875b-aec0-4842-b2c0-0903c7802013",
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_prompt(triplet, openai_model, debug=False):\n",
    "    text = triplet[2]\n",
    "    positive_words = list(triplet[1]['positive'])\n",
    "\n",
    "    prompt = f\"\"\"You are a helpful agent. Given a text and a collection C of positive words in the text.\n",
    "    first I want you to find the sentiment of this text. Then replace each positive word with\n",
    "    a negative sentiment word, such that the final sentiment becomes negative. Only alter the positive\n",
    "    words, no other text should be changed. Give your response as parseable json strictly in the output format, \n",
    "    including the input sentiment, output sentiment , modified text and new words (as a dictionary of old to replaced words) :\n",
    "    {{\"input_sentiment\": sentiment_value, \"output_sentiment\": new_sentiment_value, \"modified_text\": new_text, \"new_words\": new_words_dict}}.\n",
    "    Sentiment values may be \"positive\" or \"negative\"\n",
    "    Your input is: \n",
    "    {{\"text\": {text}, \"positive_words\": {str(positive_words)}}}\n",
    "    \"\"\"\n",
    "    if debug:\n",
    "        print(prompt)\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f767fd-437e-44fa-92f8-6292e43010ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "@cache\n",
    "def filter_and_modify_text(triplet):\n",
    "    \"\"\"\n",
    "    Returns a contrastive pair of form \"input_text\", \"output_text\", \"positive_words\", \"new_words\".\n",
    "    Where positive words are replaced with new words using ChatGPT.\n",
    "    We only include pairs where sentiment is modified from positive to negative, and the json can be parsed successfully.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        prompt = select_prompt(triplet, openai_model)\n",
    "        input_text = triplet[2]\n",
    "        positive_words = list(triplet[1]['positive'])\n",
    "\n",
    "        result = json.loads(get_result_from_guidance(prompt, openai_model))\n",
    "\n",
    "        output_text = None\n",
    "        if result.get('input_sentiment', 'negative') == 'positive' and result.get('output_sentiment', 'positive') == 'negative':\n",
    "            output_text = result.get('modified_text', None)\n",
    "            new_words = result.get('new_words', None)\n",
    "\n",
    "        if input_text and output_text and positive_words and new_words:\n",
    "            return {\n",
    "                \"input_text\": input_text,\n",
    "                \"output_text\": output_text,\n",
    "                \"positive_words\": positive_words,\n",
    "                \"new_words\": new_words\n",
    "            }\n",
    "        else:\n",
    "            return None\n",
    "    \n",
    "    except Exception as e:\n",
    "        print(f'Error {e} processing triplet {triplet}. Returning none')\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95353ccc-a4ee-4cb3-8239-b374d5bb63c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = filter_and_modify_text(triplet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "869df5be-03bb-470f-ac01-7ef7d2e88dbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'Computing for full dataset of {len(positive_examples)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca435b5b-796e-4b42-abbc-c461526d80d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_pairs = []\n",
    "if debug:\n",
    "    print(f'Restricting to 5 examples for debugging')\n",
    "    selected_positive_examples = positive_examples[:5]\n",
    "else:\n",
    "    print(f'Computing for full dataset of {len(positive_examples)}')\n",
    "    selected_positive_examples = positive_examples\n",
    "\n",
    "for index, example in enumerate(tqdm_notebook(selected_positive_examples)):\n",
    "    result = filter_and_modify_text(example)\n",
    "    if index%5 == 0:\n",
    "        # Progress tracker, essentially.\n",
    "        wandb.log({\"Index\": index})\n",
    "    if result:\n",
    "        final_pairs.append(result)\n",
    "\n",
    "print(f'Generated {len(final_pairs)} pairs from {len(selected_positive_examples)} inputs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "397c43b6-e3f2-4b51-b403-1521f1f9bb9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_pairs[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2d4a5f1-62e2-4384-8f7f-6b67c37dae41",
   "metadata": {},
   "outputs": [],
   "source": [
    "pairs_savefile = \"final_triplets.json\"\n",
    "with open(pairs_savefile, \"w\") as f_out:\n",
    "    json.dump(final_pairs, f_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c078177-945c-430a-9f2a-5ebbf7923cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(pairs_savefile, \"r\") as f_in:\n",
    "    loaded_pairs = json.load(f_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7db24d25-d370-41ff-8db3-4f6b1a8cd700",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_pairs_to_wandb(loaded_pairs):\n",
    "    my_artifact = wandb.Artifact(\"contrastive_sentiment_pairs\", type=\"data\")\n",
    "    \n",
    "    # Add the list to the artifact\n",
    "    my_artifact.add_file(local_path=pairs_savefile, name=\"contrastive_sentiment_pairs\")\n",
    "\n",
    "    metadata_dict = {\n",
    "        \"description\": \"Contrastive pairs from IMDB\",\n",
    "        \"source\": \"Generated by my script\",\n",
    "        \"num_pairs\": len(loaded_pairs),\n",
    "        \"sources\": len(selected_positive_examples),\n",
    "        \"split\": \"train\"\n",
    "    }\n",
    "\n",
    "    my_artifact.metadata.update(metadata_dict)\n",
    "\n",
    "    # Log the artifact to the run\n",
    "    wandb.log_artifact(my_artifact)\n",
    "\n",
    "save_pairs_to_wandb(loaded_pairs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
