{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gzXEp7L7_mxK"
   },
   "source": [
    "# FinGPT Test: Financial Phrasebank (FPB)\n",
    "\n",
    "This notebook demonstrates how to test FinGPT models on the Financial Phrasebank (FPB) sentiment analysis dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZuQCI9X0_mxL"
   },
   "source": [
    "## 1. Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vlxuSeEl_mxL"
   },
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the environment of the running kernel.\n",
    "%pip install transformers==4.32.0 peft==0.5.0 datasets accelerate bitsandbytes sentencepiece tqdm scikit-learn pandas matplotlib seaborn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WmCujtNe_mxM"
   },
   "source": [
    "## 2. Clone the FinGPT Repository"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qBePNiVP_mxM"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Clone only when the checkout is missing, so re-running this cell does not\n",
    "# ask git to clone into a directory that already exists.\n",
    "if not os.path.isdir('FinGPT'):\n",
    "    !git clone https://github.com/AI4Finance-Foundation/FinGPT.git\n",
    "%cd FinGPT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iVxmcTW7_mxM"
   },
   "source": [
    "## 3. Create Sentiment Templates File\n",
    "\n",
    "This is needed for the multiple-template testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "id": "Iyr3_OF8_mxM"
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Folder that receives the template file (later cells write the benchmark scripts here too).\n",
    "benchmarks_dir = Path('fingpt/FinGPT_Benchmark/benchmarks')\n",
    "benchmarks_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# One prompt template per line; fpb.py fills the {type} placeholder via str.format.\n",
    "templates = \"\"\"What is the sentiment of this {type}?\n",
    "Determine the sentiment of this {type}.\n",
    "How would you describe the sentiment of this {type}?\n",
    "Is the sentiment of this {type} positive or negative?\n",
    "Analyze the sentiment of this {type}.\n",
    "What's the sentiment of this {type}?\"\"\"\n",
    "\n",
    "# Pin the encoding so the written file does not depend on the platform default.\n",
    "with open(benchmarks_dir / 'sentiment_templates.txt', 'w', encoding='utf-8') as f:\n",
    "    f.write(templates)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qNBwYi0n_mxM"
   },
   "source": [
    "## 4. Download the FPB Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "id": "KvcoJS50_mxN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading Financial Phrasebank dataset...\n",
      "Saving dataset to fingpt/FinGPT_Benchmark/data/financial_phrasebank-sentences_50agree\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1109cb453cf74cfd9d893be86f3e06f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/4846 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset download complete!\n"
     ]
    }
   ],
   "source": [
    "import datasets\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def download_fpb(target_dir: Path):\n",
    "    \"\"\"Download the `sentences_50agree` configuration of Financial Phrasebank and save it under `target_dir`.\n",
    "\n",
    "    Returns the loaded dataset together with the path (as a string) it was saved to.\n",
    "    \"\"\"\n",
    "    # Create the data directory if it doesn't exist\n",
    "    target_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    print(\"Downloading Financial Phrasebank dataset...\")\n",
    "    fetched = datasets.load_dataset(\"financial_phrasebank\", \"sentences_50agree\")\n",
    "\n",
    "    destination = str(target_dir / \"financial_phrasebank-sentences_50agree\")\n",
    "    print(f\"Saving dataset to {destination}\")\n",
    "    fetched.save_to_disk(destination)\n",
    "    print(\"Dataset download complete!\")\n",
    "    return fetched, destination\n",
    "\n",
    "\n",
    "data_dir = Path('./fingpt/FinGPT_Benchmark/data')\n",
    "dataset, save_path = download_fpb(data_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SjkYtGkQ_mxN"
   },
   "source": [
    "## 5. Create Testing Module\n",
    "\n",
    "Let's implement the testing functions as defined in the FinGPT repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Zr0Klkp8_mxN",
    "outputId": "8e4f6a9f-5677-495b-b354-93e086473ce9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/benchmarks/fpb.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/benchmarks/fpb.py\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from datasets import load_dataset, load_from_disk, Dataset\n",
    "from tqdm import tqdm\n",
    "import datasets\n",
    "import torch\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "\n",
    "# Maps the integer class ids found in the dataset's label column to label names.\n",
    "dic = {\n",
    "    0: 'negative',\n",
    "    1: 'neutral',\n",
    "    2: 'positive',\n",
    "}\n",
    "\n",
    "# Prompt templates, one per line of sentiment_templates.txt (located next to this module).\n",
    "# Blank lines are skipped so a trailing newline cannot turn into an empty template.\n",
    "with open(Path(__file__).parent / 'sentiment_templates.txt') as f:\n",
    "    templates = [line.strip() for line in f if line.strip()]\n",
    "\n",
    "\n",
    "def format_example(example: dict) -> dict:\n",
    "    context = f\"Instruction: {example['instruction']}\\n\"\n",
    "    if example.get(\"input\"):\n",
    "        context += f\"Input: {example['input']}\\n\"\n",
    "    context += \"Answer: \"\n",
    "    target = example[\"output\"]\n",
    "    return {\"context\": context, \"target\": target}\n",
    "\n",
    "def change_target(x):\n",
    "    if 'positive' in x or 'Positive' in x:\n",
    "        return 'positive'\n",
    "    elif 'negative' in x or 'Negative' in x:\n",
    "        return 'negative'\n",
    "    else:\n",
    "        return 'neutral'\n",
    "\n",
    "\n",
    "def vote_output(x):\n",
    "    \"\"\"Combine the per-template predictions of one row into a single label.\n",
    "\n",
    "    Reads `out_text_<i>` for every template index, normalises each value with\n",
    "    `change_target`, and returns whichever of 'positive' / 'negative' received\n",
    "    strictly more votes; a tie yields 'neutral'.\n",
    "    \"\"\"\n",
    "    votes = [change_target(x[f'out_text_{idx}'].lower()) for idx in range(len(templates))]\n",
    "    margin = votes.count('positive') - votes.count('negative')\n",
    "    if margin > 0:\n",
    "        return 'positive'\n",
    "    if margin < 0:\n",
    "        return 'negative'\n",
    "    return 'neutral'\n",
    "\n",
    "def test_fpb(args, model, tokenizer, prompt_fun=None):\n",
    "    \"\"\"Evaluate `model` on the held-out FPB split using a single instruction prompt.\n",
    "\n",
    "    Args:\n",
    "        args: Namespace providing `batch_size` and, optionally, `max_length`\n",
    "            (512 is used when the attribute is absent).\n",
    "        model: Model exposing `generate` and `device`.\n",
    "        tokenizer: Tokenizer used to encode the prompts and decode the generations.\n",
    "        prompt_fun: Optional callable applied row-wise to produce a per-example\n",
    "            instruction; when None a fixed three-way prompt is used.\n",
    "\n",
    "    Returns:\n",
    "        pandas.DataFrame holding the prompts, the raw generations (`out_text`)\n",
    "        and the normalised labels (`new_target`, `new_out`).\n",
    "    \"\"\"\n",
    "    batch_size = args.batch_size\n",
    "    # Honour --max_length when the caller provides it; 512 was the previously fixed value.\n",
    "    max_length = getattr(args, \"max_length\", 512)\n",
    "\n",
    "    # Evaluate on the test part of a seeded train_test_split of the stored `train` split.\n",
    "    instructions = load_from_disk(Path(__file__).parent.parent / \"data/financial_phrasebank-sentences_50agree/\")\n",
    "    instructions = instructions[\"train\"].train_test_split(seed = 42)['test'].to_pandas()\n",
    "    instructions.columns = [\"input\", \"output\"]\n",
    "    instructions[\"output\"] = instructions[\"output\"].apply(lambda x:dic[x])\n",
    "\n",
    "    if prompt_fun is None:\n",
    "        instructions[\"instruction\"] = \"What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\"\n",
    "    else:\n",
    "        instructions[\"instruction\"] = instructions.apply(prompt_fun, axis = 1)\n",
    "\n",
    "    instructions[[\"context\",\"target\"]] = instructions.apply(format_example, axis = 1, result_type=\"expand\")\n",
    "\n",
    "    # print example\n",
    "    print(f\"\\n\\nPrompt example:\\n{instructions['context'][0]}\\n\\n\")\n",
    "\n",
    "    context = instructions['context'].tolist()\n",
    "\n",
    "    # Ceiling division: exactly one step per non-empty batch, so the reported\n",
    "    # count is not inflated when the batch size divides the example count evenly.\n",
    "    total_steps = (len(context) + batch_size - 1) // batch_size\n",
    "    print(f\"Total len: {len(context)}. Batchsize: {batch_size}. Total steps: {total_steps}\")\n",
    "\n",
    "    out_text_list = []\n",
    "    for step in tqdm(range(total_steps)):\n",
    "        tmp_context = context[step * batch_size:(step + 1) * batch_size]\n",
    "        tokens = tokenizer(tmp_context, return_tensors='pt', padding=True, max_length=max_length, return_token_type_ids=False)\n",
    "        # Follow the model's device instead of assuming CUDA, as test_fpb_mlt does.\n",
    "        tokens = {key: value.to(model.device) for key, value in tokens.items()}\n",
    "        res = model.generate(**tokens, max_length=max_length, eos_token_id=tokenizer.eos_token_id)\n",
    "        res_sentences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in res]\n",
    "        # Keep only the text following the first \"Answer: \" marker; fall back to the full text.\n",
    "        out_text = [o.split(\"Answer: \")[1] if \"Answer: \" in o else o for o in res_sentences]\n",
    "        out_text_list += out_text\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    instructions[\"out_text\"] = out_text_list\n",
    "    instructions[\"new_target\"] = instructions[\"target\"].apply(change_target)\n",
    "    instructions[\"new_out\"] = instructions[\"out_text\"].apply(change_target)\n",
    "\n",
    "    acc = accuracy_score(instructions[\"new_target\"], instructions[\"new_out\"])\n",
    "    f1_macro = f1_score(instructions[\"new_target\"], instructions[\"new_out\"], average = \"macro\")\n",
    "    f1_micro = f1_score(instructions[\"new_target\"], instructions[\"new_out\"], average = \"micro\")\n",
    "    f1_weighted = f1_score(instructions[\"new_target\"], instructions[\"new_out\"], average = \"weighted\")\n",
    "\n",
    "    print(f\"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. \")\n",
    "\n",
    "    return instructions\n",
    "\n",
    "\n",
    "def test_fpb_mlt(args, model, tokenizer):\n",
    "    \"\"\"Evaluate `model` on FPB with every prompt template plus a vote across templates.\n",
    "\n",
    "    Neutral examples are dropped, so the task is positive vs. negative. Each entry of\n",
    "    `templates` is run over the same examples; the normalised per-template predictions\n",
    "    are stored in `out_text_<i>` and combined by `vote_output` into `new_out`.\n",
    "    The resulting frame is also written to `tmp.csv` in the current working directory.\n",
    "\n",
    "    Args:\n",
    "        args: Namespace providing `batch_size` and `max_length`.\n",
    "        model: Model exposing `generate` and `device`.\n",
    "        tokenizer: Tokenizer used to encode the prompts and decode the generations.\n",
    "\n",
    "    Returns:\n",
    "        pandas.DataFrame with the inputs, gold labels (`output` / `target`),\n",
    "        per-template predictions and the voted prediction.\n",
    "    \"\"\"\n",
    "    batch_size = args.batch_size\n",
    "    dataset = load_from_disk(Path(__file__).parent.parent / 'data/financial_phrasebank-sentences_50agree/')\n",
    "    dataset = dataset[\"train\"].train_test_split(seed=42)['test'].to_pandas()\n",
    "    dataset.columns = [\"input\", \"output\"]\n",
    "    dataset[\"output\"] = dataset[\"output\"].apply(lambda x: dic[x])\n",
    "    dataset[\"text_type\"] = 'news'\n",
    "\n",
    "    dataset[\"output\"] = dataset[\"output\"].apply(change_target)\n",
    "    # Copy after filtering so the column assignments below act on an independent\n",
    "    # frame rather than on a slice of the unfiltered one.\n",
    "    dataset = dataset[dataset[\"output\"] != 'neutral'].copy()\n",
    "\n",
    "    out_texts_list = [[] for _ in range(len(templates))]\n",
    "\n",
    "    def collate_fn(batch):\n",
    "        # Tokenise the prompts of one batch into padded tensors.\n",
    "        return tokenizer(\n",
    "            [f[\"context\"] for f in batch], return_tensors='pt',\n",
    "            padding=True, max_length=args.max_length,\n",
    "            return_token_type_ids=False\n",
    "        )\n",
    "\n",
    "    for i, template in enumerate(templates):\n",
    "        dataset_temp = dataset[['input', 'output', \"text_type\"]].copy()\n",
    "        dataset_temp[\"instruction\"] = dataset_temp['text_type'].apply(lambda x: template.format(type=x) + \"\\nOptions: positive, negative\")\n",
    "        dataset_temp[[\"context\", \"target\"]] = dataset_temp.apply(format_example, axis=1, result_type=\"expand\")\n",
    "\n",
    "        dataloader = DataLoader(Dataset.from_pandas(dataset_temp), batch_size=batch_size, collate_fn=collate_fn, shuffle=False)\n",
    "\n",
    "        for idx, inputs in enumerate(tqdm(dataloader)):\n",
    "            inputs = {key: value.to(model.device) for key, value in inputs.items()}\n",
    "            res = model.generate(**inputs, do_sample=False, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id, max_new_tokens=10)\n",
    "            res_sentences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in res]\n",
    "            tqdm.write(f'{idx}: {res_sentences[0]}')\n",
    "            # Keep only the text following the first \"Answer: \" marker; fall back to the full text.\n",
    "            out_text = [o.split(\"Answer: \")[1] if \"Answer: \" in o else o for o in res_sentences]\n",
    "            out_texts_list[i] += out_text\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "    for i in range(len(templates)):\n",
    "        dataset[f\"out_text_{i}\"] = [change_target(text) for text in out_texts_list[i]]\n",
    "\n",
    "    dataset[\"new_out\"] = dataset.apply(vote_output, axis=1, result_type=\"expand\")\n",
    "    # `target` was only ever created on the per-template copies, so reading it from\n",
    "    # `dataset` raised KeyError. format_example sets target to the `output` value,\n",
    "    # so expose that column here for the CSV and for the metrics below.\n",
    "    dataset[\"target\"] = dataset[\"output\"]\n",
    "    dataset.to_csv('tmp.csv')\n",
    "\n",
    "    for k in [f\"out_text_{i}\" for i in range(len(templates))] + [\"new_out\"]:\n",
    "\n",
    "        acc = accuracy_score(dataset[\"target\"], dataset[k])\n",
    "        f1_macro = f1_score(dataset[\"target\"], dataset[k], average=\"macro\")\n",
    "        f1_micro = f1_score(dataset[\"target\"], dataset[k], average=\"micro\")\n",
    "        f1_weighted = f1_score(dataset[\"target\"], dataset[k], average=\"weighted\")\n",
    "\n",
    "        print(f\"Acc: {acc}. F1 macro: {f1_macro}. F1 micro: {f1_micro}. F1 weighted (BloombergGPT): {f1_weighted}. \")\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L5DzncH3_mxN"
   },
   "source": [
    "## 6. Create Benchmarks Runner Script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "49pCYdPa_mxO",
    "outputId": "64c0e60f-4b0d-49b3-fe71-9e86bc537bf9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/benchmarks/benchmarks.py\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from peft import PeftModel, get_peft_model, LoraConfig, TaskType  # 0.4.0\n",
    "import torch\n",
    "import argparse\n",
    "\n",
    "from fpb import test_fpb, test_fpb_mlt\n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from utils import *\n",
    "\n",
    "def main(args):\n",
    "    \"\"\"Load the base model with its PEFT adapter and run every requested benchmark.\n",
    "\n",
    "    `args.dataset` is a comma-separated list; 'fpb' and 'fpb_mlt' are supported and\n",
    "    any other entry raises ValueError.\n",
    "    \"\"\"\n",
    "    model_name = (\n",
    "        parse_model_name(args.base_model, args.from_remote)\n",
    "        if args.from_remote\n",
    "        else '../' + parse_model_name(args.base_model)\n",
    "    )\n",
    "\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_name,\n",
    "        trust_remote_code=True,\n",
    "        device_map=\"auto\",\n",
    "    )\n",
    "    model.model_parallel = True\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "    tokenizer.padding_side = \"left\"\n",
    "\n",
    "    if args.base_model == 'qwen':\n",
    "        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n",
    "        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n",
    "\n",
    "    # Add a dedicated [PAD] token when none is set or when padding shares the EOS id,\n",
    "    # then resize the embeddings to the new tokenizer length.\n",
    "    pad_missing = not tokenizer.pad_token\n",
    "    if pad_missing or tokenizer.pad_token_id == tokenizer.eos_token_id:\n",
    "        tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
    "        model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    print(f'pad: {tokenizer.pad_token_id}, eos: {tokenizer.eos_token_id}')\n",
    "\n",
    "    model = PeftModel.from_pretrained(model, args.peft_model).eval()\n",
    "\n",
    "    runners = {'fpb': test_fpb, 'fpb_mlt': test_fpb_mlt}\n",
    "    with torch.no_grad():\n",
    "        for data in args.dataset.split(','):\n",
    "            runner = runners.get(data)\n",
    "            if runner is None:\n",
    "                raise ValueError('undefined dataset.')\n",
    "            runner(args, model, tokenizer)\n",
    "\n",
    "    print('Evaluation Ends.')\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "\n",
    "    def _str2bool(value):\n",
    "        \"\"\"Parse a command-line boolean.\n",
    "\n",
    "        `type=bool` treats every non-empty string as True, so `--from_remote False`\n",
    "        used to enable remote loading. This parser maps the usual spellings explicitly\n",
    "        and rejects anything it does not recognise.\n",
    "        \"\"\"\n",
    "        if isinstance(value, bool):\n",
    "            return value\n",
    "        lowered = value.strip().lower()\n",
    "        if lowered in ('true', 't', 'yes', 'y', '1'):\n",
    "            return True\n",
    "        if lowered in ('false', 'f', 'no', 'n', '0', ''):\n",
    "            return False\n",
    "        raise argparse.ArgumentTypeError(f\"expected a boolean value, got {value!r}\")\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--dataset\", required=True, type=str)\n",
    "    parser.add_argument(\"--base_model\", required=True, type=str, choices=['chatglm2', 'llama2', 'llama2-13b', 'llama2-13b-nr', 'baichuan', 'falcon', 'internlm', 'qwen', 'mpt', 'bloom'])\n",
    "    parser.add_argument(\"--peft_model\", required=True, type=str)\n",
    "    parser.add_argument(\"--max_length\", default=512, type=int)\n",
    "    parser.add_argument(\"--batch_size\", default=4, type=int, help=\"The train batch size per device\")\n",
    "    parser.add_argument(\"--instruct_template\", default='default')\n",
    "    parser.add_argument(\"--from_remote\", default=False, type=_str2bool)\n",
    "\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    print(args.base_model)\n",
    "    print(args.peft_model)\n",
    "\n",
    "    main(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ne8oJBxV_mxO"
   },
   "source": [
    "## 7. Create Utils Module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "D9-3XebT_mxO",
    "outputId": "1c61b6ef-0351-4d40-dec1-432342a6bb79"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting fingpt/FinGPT_Benchmark/utils.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile fingpt/FinGPT_Benchmark/utils.py\n",
    "def parse_model_name(base_model, from_remote=False):\n",
    "    model_map = {\n",
    "        'chatglm2': 'THUDM/chatglm2-6b',\n",
    "        'llama2': 'meta-llama/Llama-2-7b-hf',\n",
    "        'llama2-13b': 'meta-llama/Llama-2-13b-hf',\n",
    "        'llama2-13b-nr': 'NousResearch/Llama-2-13b-hf',\n",
    "        'baichuan': 'baichuan-inc/Baichuan-7B',\n",
    "        'falcon': 'tiiuae/falcon-7b',\n",
    "        'internlm': 'internlm/internlm-7b',\n",
    "        'qwen': 'Qwen/Qwen-7B',\n",
    "        'mpt': 'mosaicml/mpt-7b',\n",
    "        'bloom': 'bigscience/bloom-7b1',\n",
    "    }\n",
    "    if base_model not in model_map:\n",
    "        raise ValueError(f\"Unknown base model: {base_model}\")\n",
    "    return model_map[base_model]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "id": "MKGku6zYBHS1"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Never store an access token in the notebook. Read it from the HF_TOKEN environment\n",
    "# variable; when that is unset, token=None makes login() ask for it interactively.\n",
    "login(token=os.environ.get(\"HF_TOKEN\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HDRhXEXZ_mxO"
   },
   "source": [
    "## 8. Run the FPB Benchmark Test\n",
    "\n",
    "Now that we have set up all the necessary files, let's run the benchmark test."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3l-4zlmr_mxO",
    "outputId": "fdf3f68e-286f-4dd8-a9e0-4c49a658ee9a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/FinLoRA/test/fingpt_tests/FinGPT/fingpt/FinGPT_Benchmark/benchmarks/FinGPT/fingpt/FinGPT_Benchmark/benchmarks\n",
      "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py:260: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
      "  torch.utils._pytree._register_pytree_node(\n",
      "llama2\n",
      "FinGPT/fingpt-mt_llama2-7b_lora\n",
      "Loading checkpoint shards: 100%|██████████████████| 2/2 [00:07<00:00,  3.98s/it]\n",
      "Using pad_token, but it is not set yet.\n",
      "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embeding dimension will be 32001. This might induce some performance reduction as *Tensor Cores* will not be available. For more details  about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc\n",
      "pad: 32000, eos: 2\n",
      "\n",
      "\n",
      "Prompt example:\n",
      "Instruction: What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\n",
      "Input: L&T has also made a commitment to redeem the remaining shares by the end of 2011 .\n",
      "Answer: \n",
      "\n",
      "\n",
      "Total len: 1212. Batchsize: 4. Total steps: 304\n",
      "100%|█████████████████████████████████████████| 304/304 [02:01<00:00,  2.51it/s]\n",
      "Acc: 0.8646864686468647. F1 macro: 0.8482702552620766. F1 micro: 0.8646864686468647. F1 weighted (BloombergGPT): 0.8631274316532289. \n",
      "Evaluation Ends.\n"
     ]
    }
   ],
   "source": [
    "# Run the single-template FPB benchmark using the pre-trained FinGPT adapter.\n",
    "import os\n",
    "\n",
    "# benchmarks.py appends '../' to sys.path to locate utils.py, so it has to be launched\n",
    "# from the benchmarks directory. The change is guarded so that re-running this cell\n",
    "# from inside that directory does not fail on the relative path.\n",
    "BENCHMARKS_DIR = 'fingpt/FinGPT_Benchmark/benchmarks'\n",
    "if os.path.isdir(BENCHMARKS_DIR):\n",
    "    %cd {BENCHMARKS_DIR}\n",
    "\n",
    "# You can modify these parameters based on your needs\n",
    "base_model = 'llama2'  # Options: chatglm2, llama2, falcon, etc.\n",
    "peft_model = 'FinGPT/fingpt-mt_llama2-7b_lora'  # The FinGPT adapter model\n",
    "batch_size = 4\n",
    "max_length = 512\n",
    "\n",
    "# Single template test\n",
    "!python benchmarks.py --dataset fpb --base_model {base_model} --peft_model {peft_model} --batch_size {batch_size} --max_length {max_length} --from_remote True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gTqF-nXM_mxO"
   },
   "source": [
    "## 9. Run Multiple-Template Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "E678J9LF_mxO"
   },
   "outputs": [],
   "source": [
    "# Run the multiple-template FPB test. This reuses base_model, peft_model, batch_size and\n",
    "# max_length from the previous cell and expects the working directory set there.\n",
    "!python benchmarks.py --dataset fpb_mlt --base_model {base_model} --peft_model {peft_model} --batch_size {batch_size} --max_length {max_length} --from_remote True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uopNBXrc_mxO"
   },
   "source": [
    "## 10. Analyze the Results\n",
    "\n",
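    "The multiple-template test saves its per-example predictions to `tmp.csv`; the cell below reloads that file and reports accuracy and F1 scores, overall and per template."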
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "d6FdBOBA_mxO"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score, f1_score, classification_report\n",
    "\n",
    "\n",
    "def summarize_scores(y_true, y_pred) -> dict:\n",
    "    \"\"\"Return accuracy and macro / micro / weighted F1 for one set of predictions.\"\"\"\n",
    "    return {\n",
    "        'accuracy': accuracy_score(y_true, y_pred),\n",
    "        'f1_macro': f1_score(y_true, y_pred, average='macro'),\n",
    "        'f1_micro': f1_score(y_true, y_pred, average='micro'),\n",
    "        'f1_weighted': f1_score(y_true, y_pred, average='weighted'),\n",
    "    }\n",
    "\n",
    "\n",
    "results = pd.read_csv('tmp.csv')\n",
    "\n",
    "# Gold labels: use `target` when the CSV has it, otherwise fall back to `output`,\n",
    "# which is the column the prompt target is copied from.\n",
    "label_column = 'target' if 'target' in results.columns else 'output'\n",
    "target = results[label_column]  # True labels\n",
    "predictions = results['new_out']  # Predicted labels\n",
    "\n",
    "# Calculate accuracy and F1 scores for the voted prediction\n",
    "overall = summarize_scores(target, predictions)\n",
    "\n",
    "# Print the metrics\n",
    "print(f\"Accuracy: {overall['accuracy']:.4f}\")\n",
    "print(f\"F1 Macro: {overall['f1_macro']:.4f}\")\n",
    "print(f\"F1 Micro: {overall['f1_micro']:.4f}\")\n",
    "print(f\"F1 Weighted (BloombergGPT): {overall['f1_weighted']:.4f}\")\n",
    "\n",
    "print(\"\\nDetailed Classification Report:\")\n",
    "print(classification_report(target, predictions))\n",
    "\n",
    "if 'out_text_0' in results.columns:\n",
    "    print(\"\\nResults by template:\")\n",
    "    template_columns = [col for col in results.columns if col.startswith('out_text_')]\n",
    "    for i, col in enumerate(template_columns):\n",
    "        # Kept in a separate dict so the overall metrics above are not overwritten.\n",
    "        per_template = summarize_scores(target, results[col])\n",
    "        print(f\"Template {i}: Accuracy: {per_template['accuracy']:.4f}, F1 Macro: {per_template['f1_macro']:.4f}, F1 Weighted: {per_template['f1_weighted']:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
