{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gzXEp7L7_mxK"
      },
      "source": [
        "# FinGPT Test: Twitter Financial News Sentiment (TFNS)\n",
        "\n",
        "This notebook demonstrates how to test FinGPTon the Twitter Financial News Sentiment dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZuQCI9X0_mxL"
      },
      "source": [
        "## 1. Install Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vlxuSeEl_mxL"
      },
      "outputs": [],
      "source": [
        "!pip install transformers==4.32.0 peft==0.5.0 datasets accelerate bitsandbytes sentencepiece tqdm scikit-learn pandas matplotlib seaborn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WmCujtNe_mxM"
      },
      "source": [
        "## 2. Clone the FinGPT Repository"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qBePNiVP_mxM"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/AI4Finance-Foundation/FinGPT.git\n",
        "%cd FinGPT"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qNBwYi0n_mxM"
      },
      "source": [
        "## 3. Download the TFNS Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KvcoJS50_mxN"
      },
      "outputs": [],
      "source": [
        "import datasets\n",
        "from pathlib import Path\n",
        "\n",
        "data_dir = Path('./fingpt/FinGPT_Benchmark/data')\n",
        "data_dir.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "print(\"Downloading Twitter Financial News Sentiment dataset...\")\n",
        "dataset = datasets.load_dataset('zeroshot/twitter-financial-news-sentiment')\n",
        "\n",
        "# Save the dataset to disk\n",
        "save_path = str(data_dir / \"twitter-financial-news-sentiment\")\n",
        "print(f\"Saving dataset to {save_path}\")\n",
        "dataset.save_to_disk(save_path)\n",
        "print(\"Dataset download complete!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SjkYtGkQ_mxN"
      },
      "source": [
        "## 4. Testing Module for TFNS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zr0Klkp8_mxN"
      },
      "outputs": [],
      "source": [
        "%%writefile fingpt/FinGPT_Benchmark/benchmarks/tfns.py\n",
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "\n",
        "from sklearn.metrics import accuracy_score, f1_score\n",
        "from datasets import load_dataset, load_from_disk\n",
        "from tqdm import tqdm\n",
        "import datasets\n",
        "import torch\n",
        "from pathlib import Path\n",
        "\n",
        "dic = {\n",
        "    0:\"negative\",\n",
        "    1:'positive',\n",
        "    2:'neutral',\n",
        "}\n",
        "\n",
        "def format_example(example: dict) -> dict:\n",
        "    context = f\"Instruction: {example['instruction']}\\n\"\n",
        "    if example.get(\"input\"):\n",
        "        context += f\"Input: {example['input']}\\n\"\n",
        "    context += \"Answer: \"\n",
        "    target = example[\"output\"]\n",
        "    return {\"context\": context, \"target\": target}\n",
        "\n",
        "def change_target(x):\n",
        "    if 'positive' in x or 'Positive' in x:\n",
        "        return 'positive'\n",
        "    elif 'negative' in x or 'Negative' in x:\n",
        "        return 'negative'\n",
        "    else:\n",
        "        return 'neutral'\n",
        "\n",
        "def test_tfns(args, model, tokenizer, prompt_fun=None):\n",
        "    batch_size = args.batch_size\n",
        "    # dataset = load_dataset('zeroshot/twitter-financial-news-sentiment')\n",
        "    dataset = load_from_disk(Path(__file__).parent.parent / 'data/twitter-financial-news-sentiment')\n",
        "    dataset = dataset['validation']\n",
        "    dataset = dataset.to_pandas()\n",
        "    dataset['label'] = dataset['label'].apply(lambda x:dic[x])\n",
        "\n",
        "    if prompt_fun is None:\n",
        "        dataset[\"instruction\"] = 'What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.'\n",
        "    else:\n",
        "        dataset[\"instruction\"] = dataset.apply(prompt_fun, axis = 1)\n",
        "\n",
        "    dataset.columns = ['input', 'output', 'instruction']\n",
        "    dataset[[\"context\",\"target\"]] = dataset.apply(format_example, axis = 1, result_type=\"expand\")\n",
        "\n",
        "    print(f\"\\n\\nPrompt example:\\n{dataset['context'][0]}\\n\\n\")\n",
        "\n",
        "    context = dataset['context'].tolist()\n",
        "\n",
        "    total_steps = dataset.shape[0]//batch_size + 1\n",
        "    print(f\"Total len: {len(context)}. Batchsize: {batch_size}. Total steps: {total_steps}\")\n",
        "\n",
        "\n",
        "    out_text_list = []\n",
        "    for i in tqdm(range(total_steps)):\n",
        "        tmp_context = context[i* batch_size:(i+1)* batch_size]\n",
        "        if not tmp_context:\n",
        "            continue\n",
        "        tokens = tokenizer(tmp_context, return_tensors='pt', padding=True, max_length=512, return_token_type_ids=False)\n",
        "        # tokens.pop('token_type_ids')\n",
        "        for k in tokens.keys():\n",
        "            tokens[k] = tokens[k].cuda()\n",
        "        res = model.generate(**tokens, max_length=512, eos_token_id=tokenizer.eos_token_id)\n",
        "        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]\n",
        "        if i % 20 == 0:\n",
        "            tqdm.write(f'Example {i}: {res_sentences[0]}')\n",
        "        out_text = [o.split(\"Answer: \")[1] if \"Answer: \" in o else o for o in res_sentences]\n",
        "        out_text_list += out_text\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "    dataset[\"out_text\"] = out_text_list\n",
        "    dataset[\"new_target\"] = dataset[\"target\"].apply(change_target)\n",
        "    dataset[\"new_out\"] = dataset[\"out_text\"].apply(change_target)\n",
        "\n",
        "    acc = accuracy_score(dataset[\"new_target\"], dataset[\"new_out\"])\n",
        "    f1_macro = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"macro\")\n",
        "    f1_micro = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"micro\")\n",
        "    f1_weighted = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"weighted\")\n",
        "\n",
        "    print(f\"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. \")\n",
        "\n",
        "    dataset.to_csv('tfns_results.csv', index=False)\n",
        "    print(\"Results saved to tfns_results.csv\")\n",
        "\n",
        "    return dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L5DzncH3_mxN"
      },
      "source": [
        "## 5. Create Benchmarking Script"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "49pCYdPa_mxO"
      },
      "outputs": [],
      "source": [
        "%%writefile fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from peft import PeftModel, get_peft_model, LoraConfig, TaskType\n",
        "import torch\n",
        "import argparse\n",
        "\n",
        "from tfns import test_tfns\n",
        "\n",
        "import sys\n",
        "sys.path.append('../')\n",
        "from utils import *\n",
        "\n",
        "def main(args):\n",
        "    if args.from_remote:\n",
        "        model_name = parse_model_name(args.base_model, args.from_remote)\n",
        "    else:\n",
        "        model_name = '../' + parse_model_name(args.base_model)\n",
        "\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        model_name, trust_remote_code=True,\n",
        "        # load_in_8bit=True\n",
        "        device_map=\"auto\",\n",
        "        # fp16=True\n",
        "    )\n",
        "    model.model_parallel = True\n",
        "\n",
        "    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
        "\n",
        "    # tokenizer.pad_token_id = tokenizer.eos_token_id\n",
        "\n",
        "    tokenizer.padding_side = \"left\"\n",
        "    if args.base_model == 'qwen':\n",
        "        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n",
        "        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n",
        "    if not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id:\n",
        "        tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
        "        model.resize_token_embeddings(len(tokenizer))\n",
        "\n",
        "    print(f'pad: {tokenizer.pad_token_id}, eos: {tokenizer.eos_token_id}')\n",
        "\n",
        "    model = PeftModel.from_pretrained(model, args.peft_model)\n",
        "    model = model.eval()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for data in args.dataset.split(','):\n",
        "            if data == 'tfns':\n",
        "                test_tfns(args, model, tokenizer)\n",
        "            else:\n",
        "                raise ValueError('undefined dataset.')\n",
        "\n",
        "    print('Evaluation Ends.')\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\"--dataset\", required=True, type=str)\n",
        "    parser.add_argument(\"--base_model\", required=True, type=str, choices=['chatglm2', 'llama2', 'llama2-13b', 'llama2-13b-nr', 'baichuan', 'falcon', 'internlm', 'qwen', 'mpt', 'bloom'])\n",
        "    parser.add_argument(\"--peft_model\", required=True, type=str)\n",
        "    parser.add_argument(\"--max_length\", default=512, type=int)\n",
        "    parser.add_argument(\"--batch_size\", default=4, type=int, help=\"The train batch size per device\")\n",
        "    parser.add_argument(\"--instruct_template\", default='default')\n",
        "    parser.add_argument(\"--from_remote\", default=False, type=bool)\n",
        "\n",
        "    args = parser.parse_args()\n",
        "\n",
        "    print(args.base_model)\n",
        "    print(args.peft_model)\n",
        "\n",
        "    main(args)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ne8oJBxV_mxO"
      },
      "source": [
        "## 6. Create Utils Module"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D9-3XebT_mxO"
      },
      "outputs": [],
      "source": [
        "%%writefile fingpt/FinGPT_Benchmark/utils.py\n",
        "def parse_model_name(base_model, from_remote=False):\n",
        "    model_map = {\n",
        "        'chatglm2': 'THUDM/chatglm2-6b',\n",
        "        'llama2': 'meta-llama/Llama-2-7b-hf',\n",
        "        'llama2-13b': 'meta-llama/Llama-2-13b-hf',\n",
        "        'llama2-13b-nr': 'NousResearch/Llama-2-13b-hf',\n",
        "        'baichuan': 'baichuan-inc/Baichuan-7B',\n",
        "        'falcon': 'tiiuae/falcon-7b',\n",
        "        'internlm': 'internlm/internlm-7b',\n",
        "        'qwen': 'Qwen/Qwen-7B',\n",
        "        'mpt': 'mosaicml/mpt-7b',\n",
        "        'bloom': 'bigscience/bloom-7b1',\n",
        "    }\n",
        "    if base_model not in model_map:\n",
        "        raise ValueError(f\"Unknown base model: {base_model}\")\n",
        "    return model_map[base_model]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MKGku6zYBHS1"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import login\n",
        "login(token=\"token\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HDRhXEXZ_mxO"
      },
      "source": [
        "## 7. Run the TFNS Benchmark Test\n",
        "\n",
        "Now run the benchmark test."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3l-4zlmr_mxO",
        "outputId": "0a4565c1-e105-4693-b530-77c6349300ef"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Example 420: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Nasdaq-100 futures fall 20 points, or 0.2%\n",
            "Answer:  negative\n",
            "Example 440: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Market Recap: Friday, Dec. 13\n",
            "Answer:  neutral\n",
            "Example 460: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Tory manifesto leaves door open to no-deal Brexit https://t.co/Qca426gTj9\n",
            "Answer:  neutral\n",
            "Example 480: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Companies Like Juggernaut Exploration (CVE:JUGR) Can Be Considered Quite Risky\n",
            "Answer:  neutral\n",
            "Example 500: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: If You Had Bought Century Communities (NYSE:CCS) Stock Five Years Ago, You Could Pocket A 61% Gain Today\n",
            "Answer:  positive\n",
            "Example 520: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: 3 Stocks to Enhance the Quality of Your Portfolio\n",
            "Answer:  neutral\n",
            "Example 540: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: Is Autolus Therapeutics plc (AUTL) A Good Stock To Buy?\n",
            "Answer:  neutral\n",
            "Example 560: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: $LNN - Lindsay -2% as COVID-19 impact 'remains uncertain' https://t.co/X5HppJ6i0k\n",
            "Answer:  negative\n",
            "Example 580: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
            "Input: $SRNE (+36.3% pre) Sorrento Therapeutics shares soar 25% after it rejects offer from two companies - MW https://t.co/5WB1nxwKVn\n",
            "Answer:  positive\n",
            "100% 598/598 [03:28<00:00,  2.87it/s]\n",
            "Acc: 0.8844221105527639. F1 macro: 0.8522607292145157. F1 micro: 0.8844221105527639. F1 weighted (BloombergGPT): 0.8835713702366452. \n",
            "Results saved to tfns_results.csv\n",
            "Evaluation Ends.\n"
          ]
        }
      ],
      "source": [
        "%cd /content/FinGPT/fingpt/FinGPT_Benchmark/benchmarks\n",
        "\n",
        "base_model = 'llama2'\n",
        "# The FinGPT adapter model\n",
        "peft_model = 'FinGPT/fingpt-mt_llama2-7b_lora'\n",
        "batch_size = 4\n",
        "max_length = 512\n",
        "\n",
        "!python benchmarks.py --dataset tfns --base_model {base_model} --peft_model {peft_model} --batch_size {batch_size} --max_length {max_length} --from_remote True"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
