{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gzXEp7L7_mxK"
      },
      "source": [
        "# FinGPT Test: News with GPT Instructions (NWGI)\n",
        "\n",
        "This notebook demonstrates how to test FinGPT on the News with GPT Instructions sentiment analysis dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZuQCI9X0_mxL"
      },
      "source": [
        "## 1. Install Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vlxuSeEl_mxL"
      },
      "outputs": [],
      "source": [
        "!pip install transformers==4.32.0 peft==0.5.0 datasets accelerate bitsandbytes sentencepiece tqdm scikit-learn pandas matplotlib seaborn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WmCujtNe_mxM"
      },
      "source": [
        "## 2. Clone the FinGPT Repository"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qBePNiVP_mxM"
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "\n",
        "# Running this cell twice in one session used to clone a nested FinGPT/FinGPT and\n",
        "# descend into it, so both steps are skipped when the kernel is already in the repo root.\n",
        "if not Path('fingpt/FinGPT_Benchmark').is_dir():\n",
        "    # Reuse a checkout left by an earlier run instead of letting `git clone` fail.\n",
        "    if not Path('FinGPT').is_dir():\n",
        "        !git clone https://github.com/AI4Finance-Foundation/FinGPT.git\n",
        "    %cd FinGPT"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qNBwYi0n_mxM"
      },
      "source": [
        "## 3. Download the NWGI Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KvcoJS50_mxN"
      },
      "outputs": [],
      "source": [
        "import datasets\n",
        "from pathlib import Path\n",
        "\n",
        "# Benchmark data directory, relative to the current working directory.\n",
        "data_dir = Path('./fingpt/FinGPT_Benchmark/data')\n",
        "data_dir.mkdir(parents=True, exist_ok=True)\n",
        "save_path = str(data_dir / \"news_with_gpt_instructions\")\n",
        "\n",
        "# Fetch the dataset by its Hub identifier.\n",
        "print(\"Downloading News with GPT Instructions dataset...\")\n",
        "dataset = datasets.load_dataset('oliverwang15/news_with_gpt_instructions')\n",
        "\n",
        "# Persist it on disk; nwgi.py reads it back with load_from_disk.\n",
        "print(f\"Saving dataset to {save_path}\")\n",
        "dataset.save_to_disk(save_path)\n",
        "print(\"Dataset download complete!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SjkYtGkQ_mxN"
      },
      "source": [
        "## 4. Testing Module for NWGI"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zr0Klkp8_mxN"
      },
      "outputs": [],
      "source": [
        "%%writefile fingpt/FinGPT_Benchmark/benchmarks/nwgi.py\n",
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "\n",
        "from sklearn.metrics import accuracy_score, f1_score\n",
        "from datasets import load_dataset, load_from_disk\n",
        "from tqdm import tqdm\n",
        "import datasets\n",
        "import torch\n",
        "from pathlib import Path\n",
        "\n",
        "# Collapses the seven fine-grained labels into three classes.\n",
        "# Both 'mildly' labels are scored as neutral.\n",
        "dic = {\n",
        "    'strong negative':\"negative\",\n",
        "    'moderately negative':\"negative\",\n",
        "    'mildly negative':\"neutral\",\n",
        "    'strong positive':\"positive\",\n",
        "    'moderately positive':\"positive\",\n",
        "    'mildly positive':'neutral',\n",
        "    'neutral':'neutral',\n",
        "}\n",
        "\n",
        "def format_example(example: dict) -> dict:\n",
        "    \"\"\"Build the prompt/answer pair for one record.\n",
        "\n",
        "    Args:\n",
        "        example: mapping with 'instruction' and 'output' entries and an optional 'input' entry.\n",
        "\n",
        "    Returns:\n",
        "        dict with 'context' (the prompt, ending in 'Answer: ') and 'target' (the 'output' value).\n",
        "    \"\"\"\n",
        "    context = f\"Instruction: {example['instruction']}\\n\"\n",
        "    # The 'Input:' line is omitted when the record has no (or an empty) input.\n",
        "    if example.get(\"input\"):\n",
        "        context += f\"Input: {example['input']}\\n\"\n",
        "    context += \"Answer: \"\n",
        "    target = example[\"output\"]\n",
        "    return {\"context\": context, \"target\": target}\n",
        "\n",
        "def change_target(x):\n",
        "    \"\"\"Reduce free-form text to 'positive', 'negative' or 'neutral'.\n",
        "\n",
        "    The match is case-insensitive, so 'POSITIVE' counts as positive, and the value is\n",
        "    converted with str() first so a non-string cell does not raise. 'positive' is\n",
        "    checked before 'negative'; text containing neither word is 'neutral'.\n",
        "    \"\"\"\n",
        "    text = str(x).lower()\n",
        "    if 'positive' in text:\n",
        "        return 'positive'\n",
        "    if 'negative' in text:\n",
        "        return 'negative'\n",
        "    return 'neutral'\n",
        "\n",
        "def test_nwgi(args, model, tokenizer, prompt_fun=None):\n",
        "    \"\"\"Generate answers for the saved NWGI data and report accuracy and F1 scores.\n",
        "\n",
        "    Args:\n",
        "        args: namespace providing `batch_size`; `max_length` is used when present (default 512).\n",
        "        model: model exposing `generate`.\n",
        "        tokenizer: tokenizer used to encode the prompts and decode the generations.\n",
        "        prompt_fun: optional callable applied to each row to produce its instruction;\n",
        "            a fixed three-way sentiment instruction is used when None.\n",
        "\n",
        "    Returns:\n",
        "        pandas DataFrame holding prompts, targets, raw generations and the reduced labels.\n",
        "        The same frame is written to nwgi_results.csv in the working directory.\n",
        "    \"\"\"\n",
        "    batch_size = args.batch_size\n",
        "    # Previously a hard-coded 512 was used and the --max_length option had no effect.\n",
        "    max_length = getattr(args, 'max_length', 512)\n",
        "    # dataset = load_dataset('oliverwang15/news_with_gpt_instructions')\n",
        "    dataset = load_from_disk(Path(__file__).parent.parent / 'data/news_with_gpt_instructions/')\n",
        "    # NOTE(review): this scores the 'train' split; if the adapter saw that split during\n",
        "    # fine-tuning the metrics are optimistic - confirm whether 'test' was intended.\n",
        "    dataset = dataset['train']\n",
        "    dataset = dataset.to_pandas()\n",
        "    dataset['output'] = dataset['label'].apply(lambda x:dic[x])\n",
        "\n",
        "    if prompt_fun is None:\n",
        "        dataset[\"instruction\"] = \"What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\"\n",
        "    else:\n",
        "        dataset[\"instruction\"] = dataset.apply(prompt_fun, axis = 1)\n",
        "    dataset[\"input\"] = dataset[\"news\"]\n",
        "\n",
        "    # .copy() so the column assignments below write to an independent frame, not a slice.\n",
        "    dataset = dataset[['input', 'output', 'instruction']].copy()\n",
        "    dataset[[\"context\",\"target\"]] = dataset.apply(format_example, axis = 1, result_type=\"expand\")\n",
        "\n",
        "    print(f\"\\n\\nPrompt example:\\n{dataset['context'][0]}\\n\\n\")\n",
        "\n",
        "    context = dataset['context'].tolist()\n",
        "\n",
        "    # Ceiling division: `// batch_size + 1` scheduled an empty extra step whenever the\n",
        "    # number of prompts was an exact multiple of the batch size.\n",
        "    total_steps = (len(context) + batch_size - 1) // batch_size\n",
        "    print(f\"Total len: {len(context)}. Batchsize: {batch_size}. Total steps: {total_steps}\")\n",
        "\n",
        "    out_text_list = []\n",
        "    for i in tqdm(range(total_steps)):\n",
        "        tmp_context = context[i* batch_size:(i+1)* batch_size]\n",
        "        tokens = tokenizer(tmp_context, return_tensors='pt', padding=True, max_length=max_length, return_token_type_ids=False)\n",
        "        for k in tokens.keys():\n",
        "            tokens[k] = tokens[k].cuda()\n",
        "        res = model.generate(**tokens, max_length=max_length, eos_token_id=tokenizer.eos_token_id)\n",
        "        res_sentences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in res]\n",
        "        # Print the first generation of every 20th batch as a progress sample.\n",
        "        if i % 20 == 0:\n",
        "            tqdm.write(f'Example {i}: {res_sentences[0]}')\n",
        "        # The decoded text echoes the prompt; keep what follows the first 'Answer: ' marker.\n",
        "        out_text = [o.split(\"Answer: \")[1] if \"Answer: \" in o else o for o in res_sentences]\n",
        "        out_text_list += out_text\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "    dataset[\"out_text\"] = out_text_list\n",
        "    dataset[\"new_target\"] = dataset[\"target\"].apply(change_target)\n",
        "    dataset[\"new_out\"] = dataset[\"out_text\"].apply(change_target)\n",
        "\n",
        "    acc = accuracy_score(dataset[\"new_target\"], dataset[\"new_out\"])\n",
        "    f1_macro = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"macro\")\n",
        "    f1_micro = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"micro\")\n",
        "    f1_weighted = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"weighted\")\n",
        "\n",
        "    print(f\"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. \")\n",
        "\n",
        "    # Save results to CSV\n",
        "    dataset.to_csv('nwgi_results.csv', index=False)\n",
        "    print(\"Results saved to nwgi_results.csv\")\n",
        "\n",
        "    return dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L5DzncH3_mxN"
      },
      "source": [
        "## 5. Update Benchmarking Script"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "49pCYdPa_mxO"
      },
      "outputs": [],
      "source": [
        "%%writefile fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from peft import PeftModel, get_peft_model, LoraConfig, TaskType  # 0.4.0\n",
        "import torch\n",
        "import argparse\n",
        "\n",
        "from nwgi import test_nwgi\n",
        "\n",
        "import sys\n",
        "sys.path.append('../')\n",
        "from utils import *\n",
        "\n",
        "def main(args):\n",
        "    \"\"\"Load the base model with its PEFT adapter and run every requested benchmark.\n",
        "\n",
        "    Args:\n",
        "        args: parsed command-line namespace; reads `dataset` (comma-separated names),\n",
        "            `base_model`, `peft_model` and `from_remote`.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if any requested dataset name is not supported.\n",
        "    \"\"\"\n",
        "    # Supported benchmark names mapped to the function that runs them.\n",
        "    runners = {'nwgi': test_nwgi}\n",
        "    requested = args.dataset.split(',')\n",
        "    # Reject unknown names up front rather than after the model has been loaded.\n",
        "    if any(name not in runners for name in requested):\n",
        "        raise ValueError('undefined dataset.')\n",
        "\n",
        "    if args.from_remote:\n",
        "        model_name = parse_model_name(args.base_model, args.from_remote)\n",
        "    else:\n",
        "        # Without --from_remote the name is used as a path one directory up.\n",
        "        model_name = '../' + parse_model_name(args.base_model)\n",
        "\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        model_name, trust_remote_code=True,\n",
        "        # load_in_8bit=True\n",
        "        device_map=\"auto\",\n",
        "        # fp16=True\n",
        "    )\n",
        "    model.model_parallel = True\n",
        "\n",
        "    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
        "\n",
        "    # tokenizer.pad_token_id = tokenizer.eos_token_id\n",
        "\n",
        "    tokenizer.padding_side = \"left\"\n",
        "    if args.base_model == 'qwen':\n",
        "        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n",
        "        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n",
        "    # Add a dedicated pad token when none exists or when it coincides with EOS,\n",
        "    # and grow the embedding matrix to cover the new id.\n",
        "    if not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id:\n",
        "        tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
        "        model.resize_token_embeddings(len(tokenizer))\n",
        "\n",
        "    print(f'pad: {tokenizer.pad_token_id}, eos: {tokenizer.eos_token_id}')\n",
        "\n",
        "    model = PeftModel.from_pretrained(model, args.peft_model)\n",
        "    model = model.eval()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for name in requested:\n",
        "            runners[name](args, model, tokenizer)\n",
        "\n",
        "    print('Evaluation Ends.')\n",
        "\n",
        "\n",
        "def _str2bool(value):\n",
        "    \"\"\"Convert a command-line string into a bool.\n",
        "\n",
        "    argparse's `type=bool` calls bool() on the raw text, so any non-empty string,\n",
        "    including 'False' and '0', was parsed as True. This accepts the usual spellings\n",
        "    explicitly and rejects everything else.\n",
        "\n",
        "    Raises:\n",
        "        argparse.ArgumentTypeError: if the text is not a recognised boolean spelling.\n",
        "    \"\"\"\n",
        "    if isinstance(value, bool):\n",
        "        return value\n",
        "    lowered = value.strip().lower()\n",
        "    if lowered in ('true', 't', 'yes', 'y', '1'):\n",
        "        return True\n",
        "    # The empty string stays False, as it was under bool('').\n",
        "    if lowered in ('false', 'f', 'no', 'n', '0', ''):\n",
        "        return False\n",
        "    raise argparse.ArgumentTypeError(f\"expected a boolean value, got {value!r}\")\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--dataset\", required=True, type=str)\n",
        "    parser.add_argument(\"--base_model\", required=True, type=str, choices=['chatglm2', 'llama2', 'llama2-13b', 'llama2-13b-nr', 'baichuan', 'falcon', 'internlm', 'qwen', 'mpt', 'bloom'])\n",
        "    parser.add_argument(\"--peft_model\", required=True, type=str)\n",
        "    parser.add_argument(\"--max_length\", default=512, type=int)\n",
        "    parser.add_argument(\"--batch_size\", default=4, type=int, help=\"The train batch size per device\")\n",
        "    parser.add_argument(\"--instruct_template\", default='default')\n",
        "    parser.add_argument(\"--from_remote\", default=False, type=_str2bool)\n",
        "\n",
        "    args = parser.parse_args()\n",
        "\n",
        "    print(args.base_model)\n",
        "    print(args.peft_model)\n",
        "\n",
        "    main(args)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ne8oJBxV_mxO"
      },
      "source": [
        "## 6. Create Utils Module"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D9-3XebT_mxO"
      },
      "outputs": [],
      "source": [
        "%%writefile fingpt/FinGPT_Benchmark/utils.py\n",
        "def parse_model_name(base_model, from_remote=False):\n",
        "    \"\"\"Translate a short base-model alias into its 'org/name' model identifier.\n",
        "\n",
        "    Args:\n",
        "        base_model: one of the aliases listed in the table below, e.g. 'llama2'.\n",
        "        from_remote: accepted by the signature but not used; the result does not depend on it.\n",
        "\n",
        "    Returns:\n",
        "        The identifier string registered for the alias.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if `base_model` is not one of the known aliases.\n",
        "    \"\"\"\n",
        "    identifiers = {\n",
        "        'chatglm2': 'THUDM/chatglm2-6b',\n",
        "        'llama2': 'meta-llama/Llama-2-7b-hf',\n",
        "        'llama2-13b': 'meta-llama/Llama-2-13b-hf',\n",
        "        'llama2-13b-nr': 'NousResearch/Llama-2-13b-hf',\n",
        "        'baichuan': 'baichuan-inc/Baichuan-7B',\n",
        "        'falcon': 'tiiuae/falcon-7b',\n",
        "        'internlm': 'internlm/internlm-7b',\n",
        "        'qwen': 'Qwen/Qwen-7B',\n",
        "        'mpt': 'mosaicml/mpt-7b',\n",
        "        'bloom': 'bigscience/bloom-7b1',\n",
        "    }\n",
        "    # No table value is None, so a None lookup result means the alias is unknown.\n",
        "    resolved = identifiers.get(base_model)\n",
        "    if resolved is None:\n",
        "        raise ValueError(f\"Unknown base model: {base_model}\")\n",
        "    return resolved"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MKGku6zYBHS1"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from getpass import getpass\n",
        "\n",
        "from huggingface_hub import login\n",
        "\n",
        "# Keep the access token out of the saved notebook: read it from the HF_TOKEN\n",
        "# environment variable when set, otherwise prompt for it without echoing the input.\n",
        "login(token=os.environ.get(\"HF_TOKEN\") or getpass(\"Hugging Face access token: \"))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HDRhXEXZ_mxO"
      },
      "source": [
        "## 7. Run the NWGI Benchmark Test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3l-4zlmr_mxO",
        "outputId": "7077cfa3-65d5-4d42-df78-8d5fd1a1a042"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Example 3520: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Focus on the business, not the stock split.\n",
            "Answer:  neutral\n",
            "Example 3540: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Long-term investors looking for top-notch growth stocks may want to consider these seven world-class companies as core holdings right now. The post 7 Great Growth Stocks to Add to Your January Buy List appeared first on InvestorPlace.\n",
            "Answer:  positive\n",
            "Example 3560: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Texas Instruments reported solid Q2 results on Tuesday, beating analyst expectations by a significant margin. Cash flow decreased substantially compared to the same period last year. There is an explanation for that.\n",
            "Answer:  negative\n",
            "Example 3580: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Former United Airlines CEO Oscar Munoz joins Yahoo Finance Live to discuss airline earnings, the travel recovery, unruly flight passengers, M&A activity among air carriers, and crew shortages in the industry.\n",
            "Answer:  neutral\n",
            "Example 3600: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: When you think of traditional challenges that Apple often finds itself facing, areas such as innovation and employee relations come to mind, but lately it's encountering a new one. Recently Apple has turned its streaming division into a diamond in the rough and racked up multiple Emmy and Oscar wins, which has raised the bar for the division.\n",
            "Answer:  positive\n",
            "Example 3620: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Ike Boruchow, Wells Fargo senior retail analyst, joins 'Squawk on the Street' to discuss what it'll take for Boruchow to get more constructive on Lululemon's rating, how the holidays will fare for Lululemon, and more.\n",
            "Answer:  neutral\n",
            "Example 3640: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: German engineering and technology firm Siemens has teamed up with chip maker Nvidia to launch a metaverse where companies can reduce the cost of running their plants and accelerate new production designs. As Reuters reported Wednesday (June 29), the agreement is the foundation of Siemens Xcelerator, the company's new open, cloud-based digital platform, aimed at […]\n",
            "Answer:  positive\n",
            "Example 3660: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: The Communications Workers of America is escalating its complaints to the National Labor Relations Board.\n",
            "Answer:  negative\n",
            "Example 3680: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Amazon.com, Inc. is the largest ecommerce company in the world and continues to build out its Prime ecosystem. Its cloud business, AWS, is the market leader and continues to grow at a rapid pace, while driving the majority of Amazon's profits.\n",
            "Answer:  positive\n",
            "Example 3700: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Alphabet Inc's Google was fined 9.36 billion Indian rupees ($113.04 million) on Tuesday as India concluded yet another antitrust probe this month, finding the U.S. tech firm guilty of abusing its market position to promote its payments app and in-app payment system.\n",
            "Answer:  negative\n",
            "Example 3720: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Nvidia (NVDA) delivered earnings and revenue surprises of -8.93% and 0.03%, respectively, for the quarter ended July 2022. Do the numbers hold clues to what lies ahead for the stock?\n",
            "Answer:  neutral\n",
            "Example 3740: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Tech reporter Allie Garfinkle looks at the sports streaming landscape after the NBA announces a new partnership with Microsoft.\n",
            "Answer:  positive\n",
            "Example 3760: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: The Zacks Retail Pharmacy and Drugstore industry players like CVS, WBA and RAD are likely to gain from the growing demand for digital healthcare support. However, reimbursement challenges are hurting overall industry health.\n",
            "Answer:  neutral\n",
            "Example 3780: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: It's been three months since shares of e-commerce giant Amazon (NASDAQ: AMZN) printed fresh all time highs, a long time for investors who were getting used to it being a daily thing for much of last year.\n",
            "Answer:  neutral\n",
            "Example 3800: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: A positive note from an analyst sent this stock higher today.\n",
            "Answer:  positive\n",
            "Example 3820: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Foxconn's Zhengzhou facility may have ended its closed-loop arrangement due to COVID lockdowns. However, China is now facing a surge in COVID cases. Hence, Apple's COVID nightmare in China has likely not ended. Further disruption could impact its iPhone Pro series production significantly.\n",
            "Answer:  negative\n",
            "Example 3840: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Google agreed to a $391.5 million settlement with 40 states over location tracking, Oregon Attorney General Ellen Rosenblum announced Monday.\n",
            "Answer:  neutral\n",
            "Example 3860: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Workers at a Starbucks store in Knoxville, Tennessee voted 8-7 to form a union on Tuesday, becoming the first store in the South to unionize, a spokesperson for SB Workers United said.\n",
            "Answer:  positive\n",
            "Example 3880: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: FORT WORTH, Texas, May 20, 2022 (GLOBE NEWSWIRE) -- American Airlines Group Inc. (NASDAQ: AAL) Chief Commercial Officer Vasu Raja will present at the 2022 Wolfe Research Global Transportation & Industrials Conference on Thursday, May 26, at 12:55 p.m. CT.\n",
            "Answer:  neutral\n",
            "Example 3900: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Microsoft will power Netflix's ads, but the competition will be stiff.\n",
            "Answer:  neutral\n",
            "Example 3920: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: A bear market isn't enough to chase billionaire hedge fund manager Ken Griffin of Citadel to the sideline.\n",
            "Answer:  neutral\n",
            "Example 3940: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: The company was hit hard by disappointing Q3 guidance.\n",
            "Answer:  negative\n",
            "Example 3960: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Software growth stocks are still looking for a catalyst as the sector lags the S&P 500. Estimates could be revised down in the next quarterly earnings reports.\n",
            "Answer:  negative\n",
            "Example 3980: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: There are opportunities to consider buying film companies like Lions Gate. The advertising arm for Netflix has tremendous potential as advertisers now have a new way to reach the demographic of adults aged 18 to 49.\n",
            "Answer:  positive\n",
            "Example 4000: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Tech stocks are falling, but there are many winners in today's market.\n",
            "Answer:  neutral\n",
            "Example 4020: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: It's been over two months since Netflix, Inc. (NASDAQ:NFLX) sold off post-earnings and sent shockwaves through Wall Street.\n",
            "Answer:  negative\n",
            "Example 4040: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: For years there has been speculation that Amazon might be broken into pieces. One reason is that Amazon is worth more than the sum of its pieces to shareholders.\n",
            "Answer:  neutral\n",
            "100% 4047/4047 [29:08<00:00,  2.31it/s]\n",
            "Acc: 0.6557093425605537. F1 macro: 0.6615313458649182. F1 micro: 0.6557093425605537. F1 weighted (BloombergGPT): 0.6513421774743056. \n",
            "Results saved to nwgi_results.csv\n",
            "Evaluation Ends.\n"
          ]
        }
      ],
      "source": [
        "%cd /content/FinGPT/fingpt/FinGPT_Benchmark/benchmarks\n",
        "\n",
        "# Evaluation settings; each value is interpolated into the command line below.\n",
        "base_model = 'llama2'  # one of the --base_model choices accepted by benchmarks.py\n",
        "peft_model = 'FinGPT/fingpt-mt_llama2-7b_lora'  # the FinGPT adapter model\n",
        "batch_size, max_length = 4, 512\n",
        "\n",
        "!python benchmarks.py --dataset nwgi --base_model {base_model} --peft_model {peft_model} --batch_size {batch_size} --max_length {max_length} --from_remote True"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
