{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gzXEp7L7_mxK"
   },
   "source": [
    "# FinGPT Test: FiQA Sentiment Analysis\n",
    "\n",
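    "This notebook demonstrates how to evaluate a FinGPT LoRA model on the FiQA sentiment analysis dataset: it installs dependencies, downloads FiQA, writes the benchmark scripts (`fiqa.py`, `benchmarks.py`, `utils.py`), and runs the evaluation."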
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZuQCI9X0_mxL"
   },
   "source": [
    "## 1. Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vlxuSeEl_mxL"
   },
   "outputs": [],
   "source": [
    "# `%pip` installs into the running kernel's environment; plain `!pip` may target a different interpreter.\n",
    "%pip install -q transformers==4.32.0 peft==0.5.0 datasets accelerate bitsandbytes sentencepiece tqdm scikit-learn pandas matplotlib seaborn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WmCujtNe_mxM"
   },
   "source": [
    "## 2. Clone the FinGPT Repository"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qBePNiVP_mxM"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Clone only when the checkout is missing, so re-running this cell does not error out.\n",
    "if not Path('FinGPT').is_dir():\n",
    "    !git clone https://github.com/AI4Finance-Foundation/FinGPT.git\n",
    "%cd FinGPT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iVxmcTW7_mxM"
   },
   "source": [
    "## 3. Create Sentiment Templates File"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "Iyr3_OF8_mxM"
   },
   "outputs": [],
   "source": [
    "!mkdir -p fingpt/FinGPT_Benchmark/benchmarks\n",
    "\n",
    "# One instruction template per line; fiqa.py fills `{type}` with 'tweet' or 'news'.\n",
    "template_lines = [\n",
    "    \"What is the sentiment of this {type}?\",\n",
    "    \"Determine the sentiment of this {type}.\",\n",
    "    \"How would you describe the sentiment of this {type}?\",\n",
    "    \"Is the sentiment of this {type} positive or negative?\",\n",
    "    \"Analyze the sentiment of this {type}.\",\n",
    "    \"What's the sentiment of this {type}?\",\n",
    "]\n",
    "templates = \"\\n\".join(template_lines)\n",
    "\n",
    "templates_path = 'fingpt/FinGPT_Benchmark/benchmarks/sentiment_templates.txt'\n",
    "with open(templates_path, 'w') as f:\n",
    "    f.write(templates)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qNBwYi0n_mxM"
   },
   "source": [
    "## 4. Download the FiQA Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "KvcoJS50_mxN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading FiQA dataset...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "863b457d73624d109cb3ad2bde9989cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train.csv: 0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "198eecf402d14e138a1142b4cf75f432",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "validation.csv: 0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27fe1cabcf584d27abeb44e5cf122c36",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "test.csv: 0.00B [00:00, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "97aa7f507bc14bf7993d451bb42eedc1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/961 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e45463c5f73848bba37b903a9295826c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating validation split:   0%|          | 0/102 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "deeb0395f8a2476f9649a6a0b7a24689",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/150 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving dataset to fingpt/FinGPT_Benchmark/data/fiqa-2018\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d731be5e96d04c1b9e4ced843d677c98",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/961 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aebb982d6f32423f8cede677e863fb44",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/102 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da364743b1144be5868cc8180593c488",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/150 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset download complete!\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import datasets\n",
    "\n",
    "# fiqa.py reads this copy back with `load_from_disk` from benchmarks/../data/fiqa-2018.\n",
    "data_dir = Path('./fingpt/FinGPT_Benchmark/data')\n",
    "data_dir.mkdir(parents=True, exist_ok=True)\n",
    "save_path = str(data_dir.joinpath(\"fiqa-2018\"))\n",
    "\n",
    "print(\"Downloading FiQA dataset...\")\n",
    "dataset = datasets.load_dataset('pauri32/fiqa-2018')\n",
    "\n",
    "print(f\"Saving dataset to {save_path}\")\n",
    "dataset.save_to_disk(save_path)\n",
    "print(\"Dataset download complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SjkYtGkQ_mxN"
   },
   "source": [
    "## 5. Write the FiQA Testing Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "Zr0Klkp8_mxN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/benchmarks/fiqa.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/benchmarks/fiqa.py\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from datasets import load_dataset, load_from_disk, Dataset\n",
    "from tqdm import tqdm\n",
    "import datasets\n",
    "import torch\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "\n",
    "# Instruction templates (one per line) used by test_fiqa_mlt and vote_output. Blank lines\n",
    "# are skipped so a stray empty line cannot become an empty prompt template.\n",
    "with open(Path(__file__).parent / 'sentiment_templates.txt', encoding='utf-8') as f:\n",
    "    templates = [line.strip() for line in f if line.strip()]\n",
    "\n",
    "def format_example(example: dict) -> dict:\n",
    "    context = f\"Instruction: {example['instruction']}\\n\"\n",
    "    if example.get(\"input\"):\n",
    "        context += f\"Input: {example['input']}\\n\"\n",
    "    context += \"Answer: \"\n",
    "    target = example[\"output\"]\n",
    "    return {\"context\": context, \"target\": target}\n",
    "\n",
    "def add_instructions(x):\n",
    "    if x.format == \"post\":\n",
    "        return \"What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\"\n",
    "    else:\n",
    "        return \"What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\"\n",
    "\n",
    "def make_label(x):\n",
    "    if x < - 0.1: return \"negative\"\n",
    "    elif x >=-0.1 and x < 0.1: return \"neutral\"\n",
    "    elif x >= 0.1: return \"positive\"\n",
    "\n",
    "def change_target(x):\n",
    "    if 'positive' in x or 'Positive' in x:\n",
    "        return 'positive'\n",
    "    elif 'negative' in x or 'Negative' in x:\n",
    "        return 'negative'\n",
    "    else:\n",
    "        return 'neutral'\n",
    "\n",
    "def vote_output(x):\n",
    "    \"\"\"Majority-vote the per-template predictions `out_text_<i>` of one row.\n",
    "\n",
    "    Only positive and negative votes are compared; neutral votes are ignored, and a\n",
    "    tie (including no polar votes at all) yields 'neutral'.\n",
    "    \"\"\"\n",
    "    votes = [change_target(x[f'out_text_{i}'].lower()) for i in range(len(templates))]\n",
    "    n_positive = votes.count('positive')\n",
    "    n_negative = votes.count('negative')\n",
    "    if n_positive == n_negative:\n",
    "        return 'neutral'\n",
    "    return 'positive' if n_positive > n_negative else 'negative'\n",
    "\n",
    "def test_fiqa(args, model, tokenizer, prompt_fun=add_instructions):\n",
    "    \"\"\"Evaluate `model` on the held-out FiQA split with a single 3-way prompt.\n",
    "\n",
    "    Rebuilds the 22.6% split (seed 42) of train+validation+test that test_fiqa_mlt\n",
    "    also starts from, generates one answer per sentence in batches of\n",
    "    `args.batch_size`, maps answers to positive/negative/neutral, prints accuracy\n",
    "    and F1 scores, writes `fiqa_results.csv`, and returns the results DataFrame.\n",
    "\n",
    "    Args:\n",
    "        args: namespace with `batch_size` and optionally `max_length` (default 512).\n",
    "        model: generation model; input tensors are moved to `model.device`.\n",
    "        tokenizer: tokenizer matching `model`.\n",
    "        prompt_fun: row -> instruction string; None uses a fixed \"news\" instruction.\n",
    "    \"\"\"\n",
    "    batch_size = args.batch_size\n",
    "    # Honour --max_length like test_fiqa_mlt; 512 was the previous hard-coded value.\n",
    "    max_length = getattr(args, 'max_length', 512)\n",
    "    # Local copy saved by the notebook (Hub equivalent: load_dataset('pauri32/fiqa-2018')).\n",
    "    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fiqa-2018/')\n",
    "    dataset = datasets.concatenate_datasets([dataset[\"train\"], dataset[\"validation\"], dataset[\"test\"]])\n",
    "    dataset = dataset.train_test_split(0.226, seed=42)['test']\n",
    "    dataset = dataset.to_pandas()\n",
    "    dataset[\"output\"] = dataset.sentiment_score.apply(make_label)\n",
    "    if prompt_fun is None:\n",
    "        dataset[\"instruction\"] = \"What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\"\n",
    "    else:\n",
    "        dataset[\"instruction\"] = dataset.apply(prompt_fun, axis=1)\n",
    "\n",
    "    # .copy(): the columns added below go into an independent frame, not a slice.\n",
    "    dataset = dataset[['sentence', 'output', \"instruction\"]].copy()\n",
    "    dataset.columns = [\"input\", \"output\", \"instruction\"]\n",
    "    dataset[[\"context\", \"target\"]] = dataset.apply(format_example, axis=1, result_type=\"expand\")\n",
    "\n",
    "    print(f\"\\n\\nPrompt example:\\n{dataset['context'][0]}\\n\\n\")\n",
    "\n",
    "    context = dataset['context'].tolist()\n",
    "    # Ceiling division: `n // batch_size + 1` added a spurious empty step whenever\n",
    "    # n was a multiple of batch_size.\n",
    "    total_steps = (len(context) + batch_size - 1) // batch_size\n",
    "    print(f\"Total len: {len(context)}. Batchsize: {batch_size}. Total steps: {total_steps}\")\n",
    "\n",
    "    # Echo a sample generation about five times per run instead of after every batch.\n",
    "    log_interval = max(1, total_steps // 5)\n",
    "    out_text_list = []\n",
    "\n",
    "    for i in tqdm(range(total_steps)):\n",
    "        tmp_context = context[i * batch_size:(i + 1) * batch_size]\n",
    "        # NOTE(review): without truncation=True the tokenizer pads to the longest prompt\n",
    "        # and does not truncate to max_length - confirm that is intended.\n",
    "        tokens = tokenizer(tmp_context, return_tensors='pt', padding=True, max_length=max_length, return_token_type_ids=False)\n",
    "        tokens = {k: v.to(model.device) for k, v in tokens.items()}\n",
    "\n",
    "        res = model.generate(**tokens, max_length=max_length, eos_token_id=tokenizer.eos_token_id)\n",
    "        res_sentences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in res]\n",
    "        if i % log_interval == 0:\n",
    "            tqdm.write(f'{i}: {res_sentences[0]}')\n",
    "        # Keep only the text after the prompt's \"Answer: \" marker.\n",
    "        out_text = [o.split(\"Answer: \")[1] if \"Answer: \" in o else o for o in res_sentences]\n",
    "        out_text_list += out_text\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    dataset[\"out_text\"] = out_text_list\n",
    "    dataset[\"new_target\"] = dataset[\"target\"].apply(change_target)\n",
    "    dataset[\"new_out\"] = dataset[\"out_text\"].apply(change_target)\n",
    "\n",
    "    acc = accuracy_score(dataset[\"new_target\"], dataset[\"new_out\"])\n",
    "    f1_macro = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"macro\")\n",
    "    f1_micro = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"micro\")\n",
    "    f1_weighted = f1_score(dataset[\"new_target\"], dataset[\"new_out\"], average = \"weighted\")\n",
    "\n",
    "    print(f\"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. \")\n",
    "\n",
    "    dataset.to_csv('fiqa_results.csv', index=False)\n",
    "    print(\"Results saved to fiqa_results.csv\")\n",
    "\n",
    "    return dataset\n",
    "\n",
    "\n",
    "def test_fiqa_mlt(args, model, tokenizer):\n",
    "    \"\"\"Evaluate `model` on the polar (non-neutral) FiQA rows with every template.\n",
    "\n",
    "    Rebuilds the same held-out split as test_fiqa (22.6%, seed 42), drops rows\n",
    "    labelled neutral, and runs one generation pass per entry of `templates` with a\n",
    "    positive/negative option list. Per-template predictions are stored as\n",
    "    `out_text_<i>` and combined by `vote_output` into `new_out`. Accuracy/F1 for each\n",
    "    template and for the vote are printed, results are written to\n",
    "    `fiqa_mlt_results.csv`, and the DataFrame is returned.\n",
    "\n",
    "    Args:\n",
    "        args: namespace with `batch_size` and `max_length`.\n",
    "        model: generation model; batches are moved to `model.device`.\n",
    "        tokenizer: tokenizer matching `model`.\n",
    "    \"\"\"\n",
    "    batch_size = args.batch_size\n",
    "    # Hub alternative to the local copy: dataset = load_dataset('pauri32/fiqa-2018')\n",
    "    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fiqa-2018/')\n",
    "    dataset = datasets.concatenate_datasets([dataset[\"train\"], dataset[\"validation\"] ,dataset[\"test\"] ])\n",
    "    dataset = dataset.train_test_split(0.226, seed=42)['test']\n",
    "    dataset = dataset.to_pandas()\n",
    "    dataset[\"output\"] = dataset.sentiment_score.apply(make_label)\n",
    "    dataset[\"text_type\"] = dataset.apply(lambda x: 'tweet' if x.format == \"post\" else 'news', axis=1)\n",
    "    dataset = dataset[['sentence', 'output', \"text_type\"]]\n",
    "    dataset.columns = [\"input\", \"output\", \"text_type\"]\n",
    "\n",
    "    # Keep only polar rows: the prompts below offer just positive/negative options.\n",
    "    dataset[\"output\"] = dataset[\"output\"].apply(change_target)\n",
    "    dataset = dataset[dataset[\"output\"] != 'neutral']\n",
    "\n",
    "    # One list of generated answers per template, filled in dataset row order.\n",
    "    out_texts_list = [[] for _ in range(len(templates))]\n",
    "\n",
    "    # NOTE(review): without truncation=True, max_length does not truncate here;\n",
    "    # prompts are only padded to the longest one in the batch - confirm intended.\n",
    "    def collate_fn(batch):\n",
    "        inputs = tokenizer(\n",
    "            [f[\"context\"] for f in batch], return_tensors='pt',\n",
    "            padding=True, max_length=args.max_length,\n",
    "            return_token_type_ids=False\n",
    "        )\n",
    "        return inputs\n",
    "\n",
    "    for i, template in enumerate(templates):\n",
    "        print(f\"\\nTesting with template {i+1}/{len(templates)}: '{template}'\")\n",
    "        dataset_temp = dataset[['input', 'output', \"text_type\"]].copy()\n",
    "        dataset_temp[\"instruction\"] = dataset_temp['text_type'].apply(lambda x: template.format(type=x) + \"\\nOptions: positive, negative\")\n",
    "        dataset_temp[[\"context\", \"target\"]] = dataset_temp.apply(format_example, axis=1, result_type=\"expand\")\n",
    "\n",
    "        # shuffle=False keeps generations in row order; the positional column\n",
    "        # assignment after this loop depends on that.\n",
    "        dataloader = DataLoader(Dataset.from_pandas(dataset_temp), batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)\n",
    "\n",
    "        # Print a sample generation roughly five times per template.\n",
    "        log_interval = max(1, len(dataloader) // 5)\n",
    "\n",
    "        for idx, inputs in enumerate(tqdm(dataloader)):\n",
    "            inputs = {key: value.to(model.device) for key, value in inputs.items()}\n",
    "            res = model.generate(**inputs, do_sample=False, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id)\n",
    "            res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]\n",
    "            if idx % log_interval == 0:\n",
    "                tqdm.write(f'Template {i+1}, batch {idx}: {res_sentences[0]}')\n",
    "            # Keep only the text after the prompt's \"Answer: \" marker.\n",
    "            out_text = [o.split(\"Answer: \")[1] if \"Answer: \" in o else o for o in res_sentences]\n",
    "            out_texts_list[i] += out_text\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "    # Each out_texts_list[i] is in row order, so it is assigned as a column positionally.\n",
    "    original_dataset = dataset.copy()\n",
    "    for i in range(len(templates)):\n",
    "        original_dataset[f\"out_text_{i}\"] = out_texts_list[i]\n",
    "        original_dataset[f\"out_text_{i}\"] = original_dataset[f\"out_text_{i}\"].apply(change_target)\n",
    "\n",
    "    # Ensemble prediction: majority vote across templates (ties -> 'neutral').\n",
    "    original_dataset[\"new_out\"] = original_dataset.apply(vote_output, axis=1, result_type=\"expand\")\n",
    "    original_dataset.to_csv('fiqa_mlt_results.csv', index=False)\n",
    "    print(\"Results saved to fiqa_mlt_results.csv\")\n",
    "\n",
    "    # Score every individual template, then the voted ensemble, against the gold labels.\n",
    "    for k in [f\"out_text_{i}\" for i in range(len(templates))] + [\"new_out\"]:\n",
    "        template_name = \"Ensemble (Voting)\" if k == \"new_out\" else f\"Template {k.split('_')[-1]}\"\n",
    "        acc = accuracy_score(original_dataset[\"output\"], original_dataset[k])\n",
    "        f1_macro = f1_score(original_dataset[\"output\"], original_dataset[k], average=\"macro\")\n",
    "        f1_micro = f1_score(original_dataset[\"output\"], original_dataset[k], average=\"micro\")\n",
    "        f1_weighted = f1_score(original_dataset[\"output\"], original_dataset[k], average=\"weighted\")\n",
    "\n",
    "        print(f\"{template_name}: Acc: {acc:.4f}. F1 macro: {f1_macro:.4f}. F1 micro: {f1_micro:.4f}. F1 weighted: {f1_weighted:.4f}\")\n",
    "\n",
    "    return original_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L5DzncH3_mxN"
   },
   "source": [
    "## 6. Create the Benchmark Runner Script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "49pCYdPa_mxO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel, get_peft_model, LoraConfig, TaskType\n",
    "import torch\n",
    "import argparse\n",
    "\n",
    "from fiqa import test_fiqa, test_fiqa_mlt\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from utils import *\n",
    "\n",
    "def main(args):\n",
    "    \"\"\"Load the base model plus PEFT adapter and run each dataset in `args.dataset`.\n",
    "\n",
    "    `args.dataset` is a comma-separated list of 'fiqa' and/or 'fiqa_mlt'; any other\n",
    "    name raises ValueError.\n",
    "    \"\"\"\n",
    "    # Remote: the parsed name as-is. Local: the same name resolved under '../'.\n",
    "    model_name = (\n",
    "        parse_model_name(args.base_model, args.from_remote)\n",
    "        if args.from_remote\n",
    "        else '../' + parse_model_name(args.base_model)\n",
    "    )\n",
    "\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_name,\n",
    "        trust_remote_code=True,\n",
    "        device_map=\"auto\",\n",
    "    )\n",
    "    model.model_parallel = True\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "    # Pad on the left so every prompt in a batch ends at the last position.\n",
    "    tokenizer.padding_side = \"left\"\n",
    "    if args.base_model == 'qwen':\n",
    "        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n",
    "        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n",
    "    # Ensure a dedicated pad token distinct from EOS, growing the embeddings to match.\n",
    "    needs_pad_token = not tokenizer.pad_token or tokenizer.pad_token_id == tokenizer.eos_token_id\n",
    "    if needs_pad_token:\n",
    "        tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
    "        model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    print(f'pad: {tokenizer.pad_token_id}, eos: {tokenizer.eos_token_id}')\n",
    "\n",
    "    model = PeftModel.from_pretrained(model, args.peft_model).eval()\n",
    "\n",
    "    runners = {'fiqa': test_fiqa, 'fiqa_mlt': test_fiqa_mlt}\n",
    "    with torch.no_grad():\n",
    "        for data in args.dataset.split(','):\n",
    "            runner = runners.get(data)\n",
    "            if runner is None:\n",
    "                raise ValueError('undefined dataset.')\n",
    "            runner(args, model, tokenizer)\n",
    "\n",
    "    print('Evaluation Ends.')\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--dataset\", required=True, type=str)\n",
    "    parser.add_argument(\"--base_model\", required=True, type=str, choices=['chatglm2', 'llama2', 'llama2-13b', 'llama2-13b-nr', 'baichuan', 'falcon', 'internlm', 'qwen', 'mpt', 'bloom'])\n",
    "    parser.add_argument(\"--peft_model\", required=True, type=str)\n",
    "    parser.add_argument(\"--max_length\", default=512, type=int)\n",
    "    parser.add_argument(\"--batch_size\", default=4, type=int, help=\"The train batch size per device\")\n",
    "    parser.add_argument(\"--instruct_template\", default='default')\n",
    "    parser.add_argument(\"--from_remote\", default=False, type=bool)\n",
    "\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    print(args.base_model)\n",
    "    print(args.peft_model)\n",
    "\n",
    "    main(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ne8oJBxV_mxO"
   },
   "source": [
    "## 7. Create the Utils Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "D9-3XebT_mxO"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/utils.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/utils.py\n",
    "def parse_model_name(base_model, from_remote=False):\n",
    "    model_map = {\n",
    "        'chatglm2': 'THUDM/chatglm2-6b',\n",
    "        'llama2': 'meta-llama/Llama-2-7b-hf',\n",
    "        'llama2-13b': 'meta-llama/Llama-2-13b-hf',\n",
    "        'llama2-13b-nr': 'NousResearch/Llama-2-13b-hf',\n",
    "        'baichuan': 'baichuan-inc/Baichuan-7B',\n",
    "        'falcon': 'tiiuae/falcon-7b',\n",
    "        'internlm': 'internlm/internlm-7b',\n",
    "        'qwen': 'Qwen/Qwen-7B',\n",
    "        'mpt': 'mosaicml/mpt-7b',\n",
    "        'bloom': 'bigscience/bloom-7b1',\n",
    "    }\n",
    "    if base_model not in model_map:\n",
    "        raise ValueError(f\"Unknown base model: {base_model}\")\n",
    "    return model_map[base_model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "0SjqFgc4HrFu"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Never hard-code a token in the notebook: read HF_TOKEN from the environment or prompt for it.\n",
    "# Needed for gated Hub repos such as meta-llama/Llama-2-7b-hf (see utils.py).\n",
    "hf_token = os.environ.get(\"HF_TOKEN\") or getpass(\"Hugging Face token: \")\n",
    "login(token=hf_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HDRhXEXZ_mxO"
   },
   "source": [
    "## 8. Run the FiQA Benchmark Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3l-4zlmr_mxO",
    "outputId": "ce4929a5-3bac-4a51-82fc-ec414afdf5bb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/FinLoRA/test/fingpt_tests/FinGPT/fingpt/FinGPT_Benchmark/benchmarks\n",
      "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py:260: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
      "  torch.utils._pytree._register_pytree_node(\n",
      "llama2\n",
      "FinGPT/fingpt-mt_llama2-7b_lora\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:09<00:00,  4.77s/it]\n",
      "Using pad_token, but it is not set yet.\n",
      "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embeding dimension will be 32001. This might induce some performance reduction as *Tensor Cores* will not be available. For more details  about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc\n",
      "pad: 32000, eos: 2\n",
      "\n",
      "\n",
      "Prompt example:\n",
      "Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: This $BBBY stock options trade would have more than doubled your money https://t.co/Oa0loiRIJL via @TheStreet\n",
      "Answer: \n",
      "\n",
      "\n",
      "Total len: 275. Batchsize: 4. Total steps: 69\n",
      "0: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: This $BBBY stock options trade would have more than doubled your money https://t.co/Oa0loiRIJL via @TheStreet\n",
      "Answer:  positive\n",
      "1: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Daily Mail owner considering Yahoo bid $yhoo ,up 2,05% https://t.co/extZr1riyP\n",
      "Answer:  positive\n",
      "2: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Ocwen Reaches Settlement With California Regulator\n",
      "Answer:  positive\n",
      "3: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Berkshire Hathaway names Kara Raiguel to lead General Re unit\n",
      "Answer:  neutral\n",
      "4: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Irish housebuilder Cairn Homes plans London listing\n",
      "Answer:  neutral\n",
      "5: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $NFLX small pos, short 180 wkly puts.\n",
      "Answer:  negative\n",
      "6: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: UPDATE 1-BP shareholders back more disclosure on climate change risks\n",
      "Answer:  neutral\n",
      "7: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $NIHD insiders got this one wrong. Looking for bottom. Rsi under 30\n",
      "Answer:  neutral\n",
      "8: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Starbucks' Digital Strategy To Drive Significant Growth With Customer Loyalty $SBUX https://t.co/Xk6lZ3UI3K\n",
      "Answer:  positive\n",
      "9: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Saudi Aramco, Shell plan to break up Motiva, divide up assets\n",
      "Answer:  neutral\n",
      "10: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Britain's FTSE outperforms Europe, Royal Mail and Tesco rise\n",
      "Answer:  positive\n",
      "11: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $NFLX VISION : short term consolidation then movement higher http://stks.co/j05uu\n",
      "Answer:  positive\n",
      "12: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $NUGT longer term bullish\n",
      "Answer:  positive\n",
      "13: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Short interest increases yet again http://stks.co/e19h via @ryandetrick $SPY\n",
      "Answer:  negative\n",
      "14: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Shell and BG Shareholders to Vote on Deal at End of January\n",
      "Answer:  neutral\n",
      "15: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Glencore slumps 30 percent as debt fears grow\n",
      "Answer:  negative\n",
      "16: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Is It Worth Investing In Tesco PLC And Prudential plc Now?\n",
      "Answer:  neutral\n",
      "17: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Morrisons and Debenhams surprise City with Christmas bounce back\n",
      "Answer:  positive\n",
      "18: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: METALS-Zinc surges 12 pct after Glencore cuts output, fuelling metals rally\n",
      "Answer:  positive\n",
      "19: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Double bottom with handle buy point of 56.49 $WMT http://chart.ly/ml857v3\n",
      "Answer:  positive\n",
      "20: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: longed $AMZN 300 @ 189.82\n",
      "Answer:  positive\n",
      "21: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Ackman, in email, says supports Valeant CEO Pearson\n",
      "Answer:  positive\n",
      "22: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: AstraZeneca share price: Acerta deal pays off with orphan drug status\n",
      "Answer:  positive\n",
      "23: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $BAC $ADSK $NFLX long this morning\n",
      "Answer:  positive\n",
      "24: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $TSLA not dipping - testing the 200 day ma https://t.co/jEPLmZQKGW\n",
      "Answer:  neutral\n",
      "25: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Long $TSLA short $MBLY https://t.co/jSpUSzo6na\n",
      "Answer:  negative\n",
      "26: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Morning Agenda: Shire's Deal for NPS\n",
      "Answer:  positive\n",
      "27: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $SPY bull move ended waiting for next setup\n",
      "Answer:  negative\n",
      "28: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $HYG Potential continuation Uptrend on a 60'-15'-4' charts http://stks.co/t1Qp6\n",
      "Answer:  positive\n",
      "29: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Costco: A Premier Retail Dividend Play https://t.co/J3UhTs022M $COST\n",
      "Answer:  positive\n",
      "30: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: I'm liking the price action in $SWKS, currently @21.80 ; my target: 24.00+ before year end.\n",
      "Answer:  positive\n",
      "31: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $UUP at major resistance right now. Can go much higher if it can break above. May take several attempts - few days to a week\n",
      "Answer:  positive\n",
      "32: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $FB still a dog going much lower this week\n",
      "Answer:  negative\n",
      "33: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: RT @stockdemons think shorting here into close might be play, $SPY - high beta reversing..looks worth a shot\n",
      "Answer:  negative\n",
      "34: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $TSLA steaming up again, this stock is relentless at the moment\n",
      "Answer:  positive\n",
      "35: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: AstraZeneca Explores Potential Deal With Acerta for Cancer Drug\n",
      "Answer:  positive\n",
      "36: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Credit Suisse poaches Prudential's Thiam for Asian push\n",
      "Answer:  neutral\n",
      "37: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $AAPL nibbling on a small long via 525.. buying off 13min charts..\n",
      "Answer:  positive\n",
      "38: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $NUGT Gold above 1400...wow\n",
      "Answer:  positive\n",
      "39: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: New recommendation from Carl Kirst of BMO Capital Markets for $WMB is BUY.Price target is $62:http://stks.co/t0S0r\n",
      "Answer:  positive\n",
      "40: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: RT @StockTwits RT @fallondpicks Breadth Consolidates: After weeks of steady gains,advances in mkt breadth slowe... http://stks.co/2TrG $QQQ\n",
      "Answer:  negative\n",
      "41: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: FTSE 100 falls as China devaluation hits Burberry, mining stocks\n",
      "Answer:  negative\n",
      "42: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $PCLN Trying to break daily trend line! Big move could happen https://t.co/gY0aDb2jsQ https://t.co/KPBBCgQ2xy\n",
      "Answer:  neutral\n",
      "43: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $EXPE working on the second leg of its reversal run as it breaks another down trend and continues higher. https://t.co/76DSJysyR0\n",
      "Answer:  positive\n",
      "44: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Hammerson, JV Partner secure ownership of Ireland's Dundrum - Quick Facts\n",
      "Answer:  positive\n",
      "45: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Surprising to see $JWN and $SKS sales numbers still holding up so well. I guess the high end might not be a concern after all.\n",
      "Answer:  positive\n",
      "46: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Goldman Sachs, Barclays, HSBC downplay Brexit threat\n",
      "Answer:  neutral\n",
      "47: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $ISRG PT raised to $700 from $640 at Leerink - keeps Outperform rated\n",
      "Answer:  positive\n",
      "48: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: RBS Pays $1.7 Billion to Scrap U.K. Treasury's Dividend Rights\n",
      "Answer:  negative\n",
      "49: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $GRPN might be selling off ahead of $P earnings...\n",
      "Answer:  negative\n",
      "50: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $renn topping tail 5 min. chart at 7.31 short from here.\n",
      "Answer:  neutral\n",
      "51: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $ELN longs...congrats...another HOD\n",
      "Answer:  positive\n",
      "52: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $ENB.CA {Head&Shoulders} bullish reversal setup and breakout. Oil and Gas Pipleline stock. $USO\n",
      "Answer:  positive\n",
      "53: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: MT @TheAcsMan Amazing seeing everyone suddenly touting $MSFT. Long been favorite covered call & double dip dividend play.\n",
      "Answer:  positive\n",
      "54: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Valeant Said to Name New CEO With Pearson Still Hospitalized\n",
      "Answer:  neutral\n",
      "55: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Crown Castle buys Tower Development Corp for $461 million\n",
      "Answer:  positive\n",
      "56: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Barclays Bonds Rise as Lender Cuts Dividends to Shore Up Capital\n",
      "Answer:  positive\n",
      "57: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: SAB's Chairman Digs In With Board Divided on InBev Offer\n",
      "Answer:  neutral\n",
      "58: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: UPDATE 1-Lloyds to cut 945 jobs as part of 3-year restructuring plan\n",
      "Answer:  negative\n",
      "59: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: potential defect with third-row seat belts. Tesla Motors recalls 2,700 Model X SUVs $TSLA https://t.co/YVYDncdNdi\n",
      "Answer:  negative\n",
      "60: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: @technik I still have a smaller $GILD pos. Being very tender with cash due to volatility. So reluctant to keep many big positions.\n",
      "Answer:  neutral\n",
      "61: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: UK housing market steadies after Brexit dip, Persimmon says\n",
      "Answer:  positive\n",
      "62: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: MarketsBP promotes upstream boss to deputy CEO\n",
      "Answer:  neutral\n",
      "63: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Locked in some $FB puts for nice gain\n",
      "Answer:  positive\n",
      "64: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $CNP Sell Short Position on CNP,...Closed for Profit $ 59,367.00 (7.75%) http://stks.co/ghjc\n",
      "Answer:  positive\n",
      "65: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Exact (EXAS) Flagged As Strong On High Volume $EXAS http://stks.co/r26Ra\n",
      "Answer:  positive\n",
      "66: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: $GOOGL is a short below 740 into the upper BB and is overbought\n",
      "Answer:  negative\n",
      "67: Instruction: What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: Exact (EXAS) Flagged As Strong On High Volume $EXAS http://stks.co/r26Ra\n",
      "Answer:  positive\n",
      "68: Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: CompaniesUnilever sales lifted by ice cream in soft economy\n",
      "Answer:  positive\n",
      "100%|███████████████████████████████████████████| 69/69 [00:24<00:00,  2.80it/s]\n",
      "Acc: 0.8109090909090909. F1 macro: 0.7262509528096003. F1 micro: 0.8109090909090909. F1 weighted (BloombergGPT): 0.8292671229987187. \n",
      "Results saved to fiqa_results.csv\n",
      "Evaluation Ends.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Move into the benchmark directory only if we are not already there: a bare\n",
    "# relative `%cd` fails when this cell is re-run because the cwd already changed.\n",
    "BENCHMARK_DIR = 'fingpt/FinGPT_Benchmark/benchmarks'\n",
    "if not os.path.isfile('benchmarks.py'):\n",
    "    os.chdir(BENCHMARK_DIR)\n",
    "print('Working directory:', os.getcwd())\n",
    "\n",
    "# Base model identifier passed to benchmarks.py via --base_model\n",
    "base_model = 'llama2'\n",
    "# The FinGPT adapter model (LoRA weights; with --from_remote this is presumably a Hugging Face Hub repo id)\n",
    "peft_model = 'FinGPT/fingpt-mt_llama2-7b_lora'\n",
    "batch_size = 4\n",
    "# Passed to benchmarks.py as --max_length (presumably a token limit; confirm in benchmarks.py)\n",
    "max_length = 512\n",
    "\n",
    "!python benchmarks.py --dataset fiqa --base_model {base_model} --peft_model {peft_model} --batch_size {batch_size} --max_length {max_length} --from_remote True"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
