{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gzXEp7L7_mxK"
   },
   "source": [
    "# FinGPT Test: Financial Headline Analysis\n",
    "\n",
    "This notebook demonstrates how to test FinGPT on the Financial Headline Analysis dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZuQCI9X0_mxL"
   },
   "source": [
    "## 1. Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vlxuSeEl_mxL"
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install transformers==4.32.0 peft==0.5.0 datasets accelerate bitsandbytes sentencepiece tqdm scikit-learn pandas matplotlib seaborn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WmCujtNe_mxM"
   },
   "source": [
    "## 2. Clone the FinGPT Repository"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qBePNiVP_mxM"
   },
   "outputs": [],
   "source": [
    "# Clone the upstream FinGPT repository and make the checkout the working directory;\n",
    "# later cells use paths relative to it (e.g. fingpt/FinGPT_Benchmark/...).\n",
    "!git clone https://github.com/AI4Finance-Foundation/FinGPT.git\n",
    "%cd FinGPT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qNBwYi0n_mxM"
   },
   "source": [
    "## 3. Download the Financial Headline Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "KvcoJS50_mxN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading Financial Headline dataset...\n",
      "Saving dataset to fingpt/FinGPT_Benchmark/data/fingpt-headline-instruct\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a80433f09ed54741add420604b0ac245",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/82161 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6db799f941e74ac9be9d8d4cfa1a4a6f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/20547 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset download complete!\n"
     ]
    }
   ],
   "source": [
    "import datasets\n",
    "from pathlib import Path\n",
    "\n",
    "# Benchmark data folder, relative to the current working directory; created on demand.\n",
    "data_dir = Path('./fingpt/FinGPT_Benchmark/data')\n",
    "data_dir.mkdir(parents=True, exist_ok=True)\n",
    "# headline.py later loads the dataset from this folder name under the data directory.\n",
    "save_path = str(data_dir / \"fingpt-headline-instruct\")\n",
    "\n",
    "print(\"Downloading Financial Headline dataset...\")\n",
    "try:\n",
    "    dataset = datasets.load_dataset(\"FinGPT/fingpt-headline\")\n",
    "    print(f\"Saving dataset to {save_path}\")\n",
    "    dataset.save_to_disk(save_path)\n",
    "except Exception as e:\n",
    "    # Best effort: report the failure and tell the user how to proceed manually.\n",
    "    print(f\"Error loading dataset: {e}\")\n",
    "    print(\"You may need to manually download the Headline dataset and place it in the fingpt/FinGPT_Benchmark/data directory.\")\n",
    "else:\n",
    "    print(\"Dataset download complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SjkYtGkQ_mxN"
   },
   "source": [
    "## 4. Testing Module for Headline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Zr0Klkp8_mxN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/benchmarks/headline.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/benchmarks/headline.py\n",
    "from sklearn.metrics import accuracy_score, f1_score, classification_report\n",
    "from datasets import load_dataset, load_from_disk\n",
    "from tqdm import tqdm\n",
    "import datasets\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "def binary2multi(dataset):\n",
    "    pred, label = [], []\n",
    "    tmp_pred, tmp_label = [], []\n",
    "    for i, row in dataset.iterrows():\n",
    "        tmp_pred.append(row['pred'])\n",
    "        tmp_label.append(row['label'])\n",
    "        if (i + 1) % 9 == 0:\n",
    "            pred.append(tmp_pred)\n",
    "            label.append(tmp_label)\n",
    "            tmp_pred, tmp_label = [], []\n",
    "    return pred, label\n",
    "\n",
    "\n",
    "def map_output(feature):\n",
    "    pred = 1 if 'yes' in feature['out_text'].lower() else 0\n",
    "    label = 1 if 'yes' in feature['output'].lower() else 0\n",
    "    return {'label': label, 'pred': pred}\n",
    "\n",
    "\n",
    "def test_mapping(args, example):\n",
    "    prompt = f\"Instruction: {example['instruction']}\\nInput: {example['input']}\\nAnswer: \"\n",
    "    return {\"prompt\": prompt}\n",
    "\n",
    "\n",
    "def test_headline(args, model, tokenizer):\n",
    "    \"\"\"Evaluate a model on the test split of the saved headline dataset.\n",
    "\n",
    "    Builds one prompt per example, generates answers batch by batch, maps each\n",
    "    generated and reference answer to a binary yes/no value, writes the\n",
    "    per-example results to 'headline_results.csv' in the current working\n",
    "    directory and prints accuracy, binary F1 and a multi-label report.\n",
    "\n",
    "    Args:\n",
    "        args: namespace providing `max_length` and `batch_size`.\n",
    "        model: generation model exposing `.device` and `.generate(...)`.\n",
    "        tokenizer: tokenizer used to batch the prompts and decode the outputs.\n",
    "\n",
    "    Returns:\n",
    "        pandas DataFrame of the test split with the added 'prompt',\n",
    "        'out_text', 'label' and 'pred' columns.\n",
    "    \"\"\"\n",
    "    print(\"Loading Financial Headline dataset...\")\n",
    "    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fingpt-headline-instruct')['test']\n",
    "    dataset = dataset.map(partial(test_mapping, args), load_from_cache_file=False)\n",
    "\n",
    "    def collate_fn(batch):\n",
    "        # truncation=True is needed for max_length to take effect; without it\n",
    "        # the tokenizer ignores max_length and only emits a warning.\n",
    "        inputs = tokenizer(\n",
    "            [f[\"prompt\"] for f in batch], return_tensors='pt',\n",
    "            padding=True, truncation=True, max_length=args.max_length,\n",
    "            return_token_type_ids=False\n",
    "        )\n",
    "        return inputs\n",
    "\n",
    "    # shuffle=False keeps generated answers aligned with the dataset rows.\n",
    "    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)\n",
    "\n",
    "    print(f\"Running inference on {len(dataset)} examples with batch size {args.batch_size}...\")\n",
    "    out_text_list = []\n",
    "    # Print roughly five sample generations over the whole run.\n",
    "    log_interval = max(1, len(dataloader) // 5)\n",
    "\n",
    "    for idx, inputs in enumerate(tqdm(dataloader)):\n",
    "        inputs = {key: value.to(model.device) for key, value in inputs.items()}\n",
    "        res = model.generate(**inputs, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id)\n",
    "        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]\n",
    "        if (idx + 1) % log_interval == 0:\n",
    "            tqdm.write(f'Example {idx}: {res_sentences[0]}')\n",
    "        # Keep the text after the first \"Answer: \" marker (up to any repeated marker);\n",
    "        # fall back to the whole decoded text when the marker is missing.\n",
    "        out_text = [o.split(\"Answer: \")[1] if \"Answer: \" in o else o for o in res_sentences]\n",
    "        out_text_list += out_text\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    print(\"Processing results...\")\n",
    "    dataset = dataset.add_column(\"out_text\", out_text_list)\n",
    "    dataset = dataset.map(map_output, load_from_cache_file=False)\n",
    "    dataset = dataset.to_pandas()\n",
    "\n",
    "    dataset.to_csv('headline_results.csv', index=False)\n",
    "    print(\"Results saved to headline_results.csv\")\n",
    "\n",
    "    acc = accuracy_score(dataset[\"label\"], dataset[\"pred\"])\n",
    "    f1 = f1_score(dataset[\"label\"], dataset[\"pred\"], average=\"binary\")\n",
    "\n",
    "    # Regroup the flat yes/no rows into one multi-label row per group of questions.\n",
    "    pred, label = binary2multi(dataset)\n",
    "\n",
    "    print(f\"\\n|| Accuracy: {acc:.4f} || F1 (binary): {f1:.4f} ||\\n\")\n",
    "\n",
    "    category_names = [\n",
    "        'price or not', 'price up', 'price stable',\n",
    "        'price down', 'price past', 'price future',\n",
    "        'event past', 'event future', 'asset comp'\n",
    "    ]\n",
    "    # zero_division=0 reports 0.0 for categories that were never predicted\n",
    "    # without emitting UndefinedMetricWarning.\n",
    "    print(classification_report(label, pred, digits=4, target_names=category_names, zero_division=0))\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L5DzncH3_mxN"
   },
   "source": [
    "## 5. Update Benchmarking Script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "49pCYdPa_mxO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel, get_peft_model, LoraConfig, TaskType\n",
    "import torch\n",
    "import argparse\n",
    "\n",
    "from headline import test_headline\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from utils import *\n",
    "\n",
    "def main(args):\n",
    "    \"\"\"Load the base model and the FinGPT adapter, then run the requested benchmarks.\n",
    "\n",
    "    Reads `base_model`, `from_remote`, `peft_model` and `dataset` from `args`;\n",
    "    `dataset` is a comma-separated list of benchmark names.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a requested benchmark name is not supported.\n",
    "    \"\"\"\n",
    "    # Without --from_remote the resolved name is prefixed with '../'.\n",
    "    model_name = (\n",
    "        parse_model_name(args.base_model, args.from_remote)\n",
    "        if args.from_remote\n",
    "        else '../' + parse_model_name(args.base_model)\n",
    "    )\n",
    "    print(f\"Loading base model: {model_name}\")\n",
    "\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_name,\n",
    "        trust_remote_code=True,\n",
    "        device_map=\"auto\",\n",
    "    )\n",
    "    model.model_parallel = True\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "    tokenizer.padding_side = \"left\"\n",
    "\n",
    "    if args.base_model == 'qwen':\n",
    "        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n",
    "        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n",
    "\n",
    "    # Add a dedicated [PAD] token when none is set or when it coincides with EOS,\n",
    "    # and grow the embedding matrix to cover the enlarged vocabulary.\n",
    "    pad_missing = not tokenizer.pad_token\n",
    "    if pad_missing or tokenizer.pad_token_id == tokenizer.eos_token_id:\n",
    "        tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
    "        model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    print(f'pad: {tokenizer.pad_token_id}, eos: {tokenizer.eos_token_id}')\n",
    "\n",
    "    print(f\"Loading FinGPT adapter: {args.peft_model}\")\n",
    "    model = PeftModel.from_pretrained(model, args.peft_model).eval()\n",
    "\n",
    "    # Benchmark name -> evaluation function.\n",
    "    runners = {'headline': test_headline}\n",
    "    with torch.no_grad():\n",
    "        for name in args.dataset.split(','):\n",
    "            runner = runners.get(name)\n",
    "            if runner is None:\n",
    "                raise ValueError(f'Undefined dataset: {name}')\n",
    "            runner(args, model, tokenizer)\n",
    "\n",
    "    print('Evaluation Ends.')\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--dataset\", required=True, type=str)\n",
    "    parser.add_argument(\"--base_model\", required=True, type=str, choices=['chatglm2', 'llama2', 'llama2-13b', 'llama2-13b-nr', 'baichuan', 'falcon', 'internlm', 'qwen', 'mpt', 'bloom'])\n",
    "    parser.add_argument(\"--peft_model\", required=True, type=str)\n",
    "    parser.add_argument(\"--max_length\", default=512, type=int)\n",
    "    parser.add_argument(\"--batch_size\", default=4, type=int, help=\"The train batch size per device\")\n",
    "    parser.add_argument(\"--instruct_template\", default='default')\n",
    "    parser.add_argument(\"--from_remote\", default=False, type=bool)\n",
    "\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    print(args.base_model)\n",
    "    print(args.peft_model)\n",
    "\n",
    "    main(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ne8oJBxV_mxO"
   },
   "source": [
    "## 6. Create Utils Module\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "D9-3XebT_mxO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/utils.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/utils.py\n",
    "def parse_model_name(base_model, from_remote=False):\n",
    "    model_map = {\n",
    "        'chatglm2': 'THUDM/chatglm2-6b',\n",
    "        'llama2': 'meta-llama/Llama-2-7b-hf',\n",
    "        'llama2-13b': 'meta-llama/Llama-2-13b-hf',\n",
    "        'llama2-13b-nr': 'NousResearch/Llama-2-13b-hf',\n",
    "        'baichuan': 'baichuan-inc/Baichuan-7B',\n",
    "        'falcon': 'tiiuae/falcon-7b',\n",
    "        'internlm': 'internlm/internlm-7b',\n",
    "        'qwen': 'Qwen/Qwen-7B',\n",
    "        'mpt': 'mosaicml/mpt-7b',\n",
    "        'bloom': 'bigscience/bloom-7b1',\n",
    "    }\n",
    "    if base_model not in model_map:\n",
    "        raise ValueError(f\"Unknown base model: {base_model}\")\n",
    "    return model_map[base_model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "id": "MKGku6zYBHS1"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Never commit an access token inside the notebook. The token is read from the\n",
    "# HF_TOKEN environment variable; when it is unset, login() prompts for one.\n",
    "login(token=os.environ.get(\"HF_TOKEN\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HDRhXEXZ_mxO"
   },
   "source": [
    "## 7. Run the Financial Headline Benchmark Test\n",
    "\n",
    "Now run the benchmark test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3l-4zlmr_mxO",
    "outputId": "33d370e8-1155-41f6-f622-c861b4736cf5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/FinLoRA/test/fingpt_tests/FinGPT/FinGPT/fingpt/FinGPT_Benchmark/benchmarks\n",
      "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.\n",
      "0it [00:00, ?it/s]\n",
      "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py:260: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
      "  torch.utils._pytree._register_pytree_node(\n",
      "llama2\n",
      "fingpt/fingpt-mt_llama2-7b_lora\n",
      "Loading base model: meta-llama/Llama-2-7b-hf\n",
      "/usr/local/lib/python3.11/dist-packages/huggingface_hub/file_download.py:943: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:07<00:00,  3.88s/it]\n",
      "/usr/local/lib/python3.11/dist-packages/huggingface_hub/file_download.py:943: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Using pad_token, but it is not set yet.\n",
      "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embeding dimension will be 32001. This might induce some performance reduction as *Tensor Cores* will not be available. For more details  about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc\n",
      "pad: 32000, eos: 2\n",
      "Loading FinGPT adapter: fingpt/fingpt-mt_llama2-7b_lora\n",
      "adapter_config.json: 100%|█████████████████████| 460/460 [00:00<00:00, 2.47MB/s]\n",
      "adapter_model.bin: 100%|████████████████████| 12.6M/12.6M [00:00<00:00, 269MB/s]\n",
      "Loading Financial Headline dataset...\n",
      "Map: 100%|██████████████████████| 20547/20547 [00:00<00:00, 25403.67 examples/s]\n",
      "Running inference on 20547 examples with batch size 4...\n",
      "  0%|                                                  | 0/5137 [00:00<?, ?it/s]/usr/local/lib/python3.11/dist-packages/transformers/tokenization_utils_base.py:2436: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.\n",
      "  warnings.warn(\n",
      "Example 1026: Instruction: Does the news headline talk about price? Please choose an answer from {Yes/No}.\n",
      "Input: Gold prices to trade lower today: Angel Commodities\n",
      "Answer:  No\n",
      "Example 2053: Instruction: Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.\n",
      "Input: Gold remains lower after third-quarter GDP update\n",
      "Answer:  Yes\n",
      "Example 3080: Instruction: Does the news headline compare gold with any other asset? Please choose an answer from {Yes/No}.\n",
      "Input: expect gold prices to trade higher: angel\n",
      "Answer:  No\n",
      "Example 4107: Instruction: Does the news headline talk about price going down? Please choose an answer from {Yes/No}.\n",
      "Input: gold price falls as dollar strengthens\n",
      "Answer:  Yes\n",
      "Example 5134: Instruction: Does the news headline talk about a general event (apart from prices) in the future? Please choose an answer from {Yes/No}.\n",
      "Input: august gold up $7.60 at $878.80 an ounce on nymex\n",
      "Answer:  No\n",
      "100%|███████████████████████████████████████| 5137/5137 [24:09<00:00,  3.54it/s]\n",
      "Processing results...\n",
      "Map: 100%|██████████████████████| 20547/20547 [00:01<00:00, 20453.93 examples/s]\n",
      "Results saved to headline_results.csv\n",
      "\n",
      "|| Accuracy: 0.9696 || F1 (binary): 0.9333 ||\n",
      "\n",
      "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "price or not     0.9324    0.6699    0.7797       103\n",
      "    price up     0.9334    0.9264    0.9299       938\n",
      "price stable     0.9059    0.7476    0.8191       103\n",
      "  price down     0.9370    0.9160    0.9264       845\n",
      "  price past     0.9718    0.9650    0.9684      1857\n",
      "price future     0.9206    0.6591    0.7682        88\n",
      "  event past     0.7799    0.8913    0.8319       322\n",
      "event future     0.0000    0.0000    0.0000        16\n",
      "  asset comp     0.9888    0.9821    0.9854       448\n",
      "\n",
      "   micro avg     0.9418    0.9250    0.9333      4720\n",
      "   macro avg     0.8189    0.7508    0.7788      4720\n",
      "weighted avg     0.9399    0.9250    0.9311      4720\n",
      " samples avg     0.9257    0.9207    0.9180      4720\n",
      "\n",
      "Evaluation Ends.\n"
     ]
    }
   ],
   "source": [
    "# Run from the benchmarks directory: benchmarks.py imports `headline` as a sibling\n",
    "# module and appends '../' to sys.path before importing utils.\n",
    "# NOTE(review): this relative %cd assumes the working directory is the repository\n",
    "# root, so re-running the cell without resetting the directory will not find the path.\n",
    "%cd fingpt/FinGPT_Benchmark/benchmarks\n",
    "\n",
    "# Short alias of the base model (one of the --base_model choices in benchmarks.py)\n",
    "base_model = 'llama2'\n",
    "# The FinGPT adapter model\n",
    "peft_model = 'FinGPT/fingpt-mt_llama2-7b_lora'\n",
    "# Forwarded to --batch_size and --max_length below\n",
    "batch_size = 4\n",
    "max_length = 512\n",
    "\n",
    "# The {name} placeholders are filled in from the Python variables above.\n",
    "!python benchmarks.py --dataset headline --base_model {base_model} --peft_model {peft_model} --batch_size {batch_size} --max_length {max_length} --from_remote True"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
