{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9D3qOolmj0tN"
      },
      "source": [
        "# FinGPT Test: Named Entity Recognition (NER)\n",
        "\n",
        "This notebook demonstrates how to test FinGPT models on the Named Entity Recognition (NER) dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-xPEjCL9j0tP"
      },
      "source": [
        "## 1. Install Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BhuB4Oo1j0tP",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "%pip install transformers==4.32.0 peft==0.5.0 datasets accelerate bitsandbytes sentencepiece tqdm scikit-learn pandas matplotlib seaborn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lRZO3EHyj0tQ"
      },
      "source": [
        "## 2. Import Libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "fZKQafyMj0tQ",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import csv\n",
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from peft import PeftModel, PeftConfig\n",
        "from tqdm import tqdm\n",
        "import pandas as pd"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qAtDJPcFj0tQ"
      },
      "source": [
        "## 3. Login to Hugging Face"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pCB5uU8pj0tR",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "from huggingface_hub import login\n",
        "\n",
        "# Never store the access token in the notebook. Read it from the HF_TOKEN environment\n",
        "# variable; when it is unset, token=None makes login() ask for the token interactively.\n",
        "login(token=os.environ.get(\"HF_TOKEN\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nlDnMgq3ycfs"
      },
      "source": [
        "## 4. Download NER dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qMvAkV2Bybyf",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "!wget -O test.parquet https://huggingface.co/datasets/FinGPT/fingpt-ner-cls/resolve/main/data/test-00000-of-00001-71355aae60cb2b0b.parquet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U9aD95_HzEhV",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# Convert the parquet test split to CSV; the processing cell reads ner.csv with csv.DictReader.\n",
        "df = pd.read_parquet(\"test.parquet\")\n",
        "df.to_csv(\"ner.csv\", index=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L-sBYctIj0tR"
      },
      "source": [
        "## 5. Load Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vcwjyNMvj0tR",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import csv\n",
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "from peft import PeftModel, PeftConfig\n",
        "from tqdm import tqdm\n",
        "import pandas as pd\n",
        "\n",
        "# NOTE(review): neither constant is referenced by the other cells visible in this notebook;\n",
        "# get_entity_response passes a literal 512 to both the tokenizer and generate().\n",
        "batch_size = 4\n",
        "max_length = 512\n",
        "\n",
        "def parse_model_name(base_model, from_remote=False):\n",
        "    \"\"\"Translate a short base-model alias into the model identifier used for loading.\n",
        "\n",
        "    Args:\n",
        "        base_model: Alias to look up, e.g. 'llama2' or 'qwen'.\n",
        "        from_remote: Accepted for call compatibility; this function does not read it.\n",
        "\n",
        "    Returns:\n",
        "        The identifier string registered for the alias.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If the alias is not one of the registered keys.\n",
        "    \"\"\"\n",
        "    repo_ids = {\n",
        "        'chatglm2': 'THUDM/chatglm2-6b',\n",
        "        'llama2': 'meta-llama/Llama-2-7b-hf',\n",
        "        'llama2-13b': 'meta-llama/Llama-2-13b-hf',\n",
        "        'llama2-13b-nr': 'NousResearch/Llama-2-13b-hf',\n",
        "        'baichuan': 'baichuan-inc/Baichuan-7B',\n",
        "        'falcon': 'tiiuae/falcon-7b',\n",
        "        'internlm': 'internlm/internlm-7b',\n",
        "        'qwen': 'Qwen/Qwen-7B',\n",
        "        'mpt': 'mosaicml/mpt-7b',\n",
        "        'bloom': 'bigscience/bloom-7b1',\n",
        "    }\n",
        "    repo_id = repo_ids.get(base_model)\n",
        "    if repo_id is None:\n",
        "        raise ValueError(f\"Unknown base model: {base_model}\")\n",
        "    return repo_id\n",
        "\n",
        "def load_model_and_tokenizer(base_model, peft_model, from_remote=True):\n",
        "    \"\"\"Load a base causal LM with its tokenizer and attach a PEFT adapter.\n",
        "\n",
        "    Args:\n",
        "        base_model: Alias understood by parse_model_name, e.g. 'llama2'.\n",
        "        peft_model: Adapter identifier handed to PeftModel.from_pretrained.\n",
        "        from_remote: When True the resolved identifier is used as is; otherwise it is\n",
        "            prefixed with '../'.\n",
        "\n",
        "    Returns:\n",
        "        A (model, tokenizer) tuple; the model is the adapter-wrapped model in eval mode.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: Propagated from parse_model_name for an unknown alias.\n",
        "    \"\"\"\n",
        "    repo_id = parse_model_name(base_model)\n",
        "    model_name = repo_id if from_remote else '../' + repo_id\n",
        "\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        model_name,\n",
        "        trust_remote_code=True,\n",
        "        device_map=\"auto\",\n",
        "    )\n",
        "    model.model_parallel = True\n",
        "\n",
        "    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
        "    tokenizer.padding_side = \"left\"\n",
        "\n",
        "    # Qwen gets explicit eos/pad ids looked up from its own special tokens.\n",
        "    if base_model == 'qwen':\n",
        "        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')\n",
        "        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|extra_0|>')\n",
        "\n",
        "    # Add a dedicated pad token when there is none or when it coincides with eos;\n",
        "    # the embedding table is resized to match the enlarged vocabulary.\n",
        "    needs_pad_token = (not tokenizer.pad_token) or tokenizer.pad_token_id == tokenizer.eos_token_id\n",
        "    if needs_pad_token:\n",
        "        tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
        "        model.resize_token_embeddings(len(tokenizer))\n",
        "\n",
        "    print(f'pad: {tokenizer.pad_token_id}, eos: {tokenizer.eos_token_id}')\n",
        "\n",
        "    adapted_model = PeftModel.from_pretrained(model, peft_model)\n",
        "    return adapted_model.eval(), tokenizer\n",
        "\n",
        "# Choose the base-model alias and the adapter identifier, then load both.\n",
        "# `model` and `tokenizer` are used as globals by get_entity_response.\n",
        "base_model = 'llama2'\n",
        "peft_model = 'FinGPT/fingpt-mt_llama2-7b_lora'\n",
        "model, tokenizer = load_model_and_tokenizer(base_model, peft_model, from_remote=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Z98KGEOj0tR"
      },
      "source": [
        "## 6. Define NER Function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_cleGl44j0tR",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "def get_entity_response(context, instruction):\n",
        "    \"\"\"Prompt the loaded model with one instruction/input pair and return its answer text.\n",
        "\n",
        "    Uses the notebook-level `model` and `tokenizer`.\n",
        "\n",
        "    Args:\n",
        "        context: Text inserted after 'Input: ' in the prompt.\n",
        "        instruction: Text inserted after 'Instruction: ' in the prompt.\n",
        "\n",
        "    Returns:\n",
        "        The text generated after the prompt, with surrounding whitespace stripped.\n",
        "    \"\"\"\n",
        "    prompt = f\"Instruction: {instruction}\\nInput: {context}\\nAnswer: \"\n",
        "\n",
        "    tokens = tokenizer(prompt, return_tensors='pt', max_length=512, truncation=True, padding=True)\n",
        "    tokens = {k: v.to(model.device) for k, v in tokens.items()}\n",
        "    prompt_length = tokens['input_ids'].shape[1]\n",
        "\n",
        "    with torch.no_grad():\n",
        "        outputs = model.generate(**tokens, max_length=512, eos_token_id=tokenizer.eos_token_id)\n",
        "\n",
        "    # generate() returns the prompt tokens followed by the continuation. Decoding only the\n",
        "    # continuation replaces splitting the full text on the first 'Answer: ' marker, which\n",
        "    # picked the wrong segment when the input itself contained that marker and returned\n",
        "    # the whole prompt when truncation had cut the marker off.\n",
        "    generated = outputs[0][prompt_length:]\n",
        "    answer = tokenizer.decode(generated, skip_special_tokens=True).strip()\n",
        "\n",
        "    return answer\n",
        "\n",
        "# CSV read by the processing cell, and the CSV it writes with the added 'model answer' column.\n",
        "input_file = 'ner.csv'\n",
        "output_file = 'fingpt_ner.csv'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0sACQqJXj0tR"
      },
      "source": [
        "## 7. Process the Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lVtJ5lAsj0tR",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# Read the whole input CSV first so the output header can carry the extra column.\n",
        "with open(input_file, 'r', encoding='utf-8') as infile:\n",
        "    reader = csv.DictReader(infile)\n",
        "    all_rows = [record for record in reader]\n",
        "    fieldnames = reader.fieldnames\n",
        "\n",
        "if 'model answer' not in fieldnames:\n",
        "    fieldnames.append('model answer')\n",
        "\n",
        "# Write each row back out together with the model's answer for it.\n",
        "with open(output_file, 'w', newline='', encoding='utf-8') as outfile:\n",
        "    writer = csv.DictWriter(outfile, fieldnames=fieldnames)\n",
        "    writer.writeheader()\n",
        "\n",
        "    for row in tqdm(all_rows, desc=\"Processing rows\"):\n",
        "        context = row.get('input', '')\n",
        "        instruction = row.get('instruction', '')\n",
        "\n",
        "        try:\n",
        "            prediction = get_entity_response(context, instruction)\n",
        "        except Exception as e:\n",
        "            # Best effort: report the failure and keep processing the remaining rows.\n",
        "            print(f\"Error: {e}\")\n",
        "            prediction = 'ERROR'\n",
        "\n",
        "        row['model answer'] = prediction\n",
        "        writer.writerow(row)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SPPSITDDzYyu"
      },
      "source": [
        "## 8. Evaluation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eTzu2Nz4qKOf"
      },
      "source": [
        "We then manually clean the output file in Google Sheets, normalizing answers such as \"a person\", \"an organization\", or \"industry, organization\" to one of the standard labels: \"location\", \"person\", or \"organization\"."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R_me-dQHrH_1"
      },
      "source": [
        "We run the following Google Apps Script to fill a `scoring` column, marking each correct answer with 1 and each incorrect answer with 0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VAergpeHnlZu",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "\"\"\"\n",
        "/**\n",
        " * Scores every data row of the active sheet.\n",
        " *\n",
        " * Writes '1' in the 'scoring' column when the 'answer' and 'model answer' cells match\n",
        " * (ignoring case and surrounding whitespace) and '0' otherwise; mismatches are\n",
        " * highlighted in yellow. Shows an alert and stops if either input column is missing.\n",
        " */\n",
        "function checkAnswers() {\n",
        "  // Open the active spreadsheet\n",
        "  const sheet = SpreadsheetApp.getActiveSpreadsheet().getActiveSheet();\n",
        "\n",
        "  const dataRange = sheet.getDataRange();\n",
        "  const data = dataRange.getValues();\n",
        "\n",
        "  const headers = data[0];\n",
        "  const answersIndex = headers.indexOf('answer');\n",
        "  const modelAnswersIndex = headers.indexOf('model answer');\n",
        "\n",
        "  if (answersIndex === -1 || modelAnswersIndex === -1) {\n",
        "    SpreadsheetApp.getUi().alert('Error: \"answer\" or \"model answer\" column not found.');\n",
        "    return;\n",
        "  }\n",
        "\n",
        "  // Reuse an existing 'scoring' column so that re-running the script overwrites the old\n",
        "  // scores instead of writing into an unlabeled column to the right of the data.\n",
        "  let scoringColumn = headers.indexOf('scoring') + 1; // sheet columns are 1-based\n",
        "  if (scoringColumn === 0) {\n",
        "    scoringColumn = headers.length + 1;\n",
        "    sheet.getRange(1, scoringColumn).setValue('scoring');\n",
        "  }\n",
        "\n",
        "  // Cells may hold numbers or be empty, so coerce to a string before normalizing.\n",
        "  const normalize = (value) => String(value).trim().toLowerCase();\n",
        "\n",
        "  for (let i = 1; i < data.length; i++) {\n",
        "    const answer = normalize(data[i][answersIndex]);\n",
        "    const modelAnswer = normalize(data[i][modelAnswersIndex]);\n",
        "\n",
        "    const scoringValue = (answer === modelAnswer) ? '1' : '0';\n",
        "    const scoringCell = sheet.getRange(i + 1, scoringColumn);\n",
        "    scoringCell.setValue(scoringValue);\n",
        "    scoringCell.setBackground(scoringValue === '0' ? 'yellow' : null);\n",
        "  }\n",
        "\n",
        "  SpreadsheetApp.getUi().alert('Scoring completed!');\n",
        "}\n",
        "\"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n4qwpmqnrL_f"
      },
      "source": [
        "We then compute the average of the scoring column to get the accuracy. This works since the scoring can only be 0 or 1."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QJiLKDLOrr3I"
      },
      "source": [
        "To get the weighted F1 score, we run the following Google Apps Script."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Eriw_Z9UokeE",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "\"\"\"\n",
        "/**\n",
        " * Computes the support-weighted F1 score of 'model answer' against 'answer' on the\n",
        " * active sheet and shows it in an alert.\n",
        " *\n",
        " * Labels are compared after trimming whitespace and lowercasing. Rows with an empty\n",
        " * 'answer' cell are skipped. Each class's F1 is weighted by its number of gold rows.\n",
        " */\n",
        "function computeWeightedF1() {\n",
        "  const sheet = SpreadsheetApp.getActiveSpreadsheet().getActiveSheet();\n",
        "  const dataRange = sheet.getDataRange();\n",
        "  const data = dataRange.getValues();\n",
        "\n",
        "  const headers = data[0];\n",
        "  const answerIndex = headers.indexOf('answer');\n",
        "  const modelAnswerIndex = headers.indexOf('model answer');\n",
        "\n",
        "  if (answerIndex === -1 || modelAnswerIndex === -1) {\n",
        "    SpreadsheetApp.getUi().alert('Error: \"answer\" or \"model answer\" column not found.');\n",
        "    return;\n",
        "  }\n",
        "\n",
        "  // Normalize labels so that case or stray whitespace does not split one class into\n",
        "  // several; cells may hold non-string values, hence the String() coercion.\n",
        "  const normalize = (value) =>\n",
        "    (value === undefined || value === null) ? '' : String(value).trim().toLowerCase();\n",
        "\n",
        "  // Holds stats for each class.\n",
        "  // true positives (TP), false positives (FP), false negatives (FN), and the count (number of gold instances)\n",
        "  const stats = {};\n",
        "\n",
        "  for (let i = 1; i < data.length; i++) {\n",
        "    const trueLabel = normalize(data[i][answerIndex]);\n",
        "    const predLabel = normalize(data[i][modelAnswerIndex]);\n",
        "\n",
        "    if (trueLabel === '') continue;\n",
        "\n",
        "    if (!stats[trueLabel]) {\n",
        "      stats[trueLabel] = { TP: 0, FP: 0, FN: 0, count: 0 };\n",
        "    }\n",
        "\n",
        "    stats[trueLabel].count++;\n",
        "\n",
        "    if (!stats[predLabel]) {\n",
        "      stats[predLabel] = { TP: 0, FP: 0, FN: 0, count: 0 };\n",
        "    }\n",
        "\n",
        "    if (trueLabel === predLabel) {\n",
        "      stats[trueLabel].TP++;\n",
        "    } else {\n",
        "      stats[trueLabel].FN++;\n",
        "      stats[predLabel].FP++;\n",
        "    }\n",
        "  }\n",
        "\n",
        "  let weightedF1Sum = 0;\n",
        "  let weightSum = 0;\n",
        "\n",
        "  for (let label in stats) {\n",
        "    // Classes that only ever appear as predictions have no gold rows and carry no weight.\n",
        "    if (stats[label].count > 0) {\n",
        "      const TP = stats[label].TP;\n",
        "      const FP = stats[label].FP;\n",
        "      const FN = stats[label].FN;\n",
        "\n",
        "      const precision = (TP + FP) > 0 ? TP / (TP + FP) : 0;\n",
        "      const recall = (TP + FN) > 0 ? TP / (TP + FN) : 0;\n",
        "      const f1 = (precision + recall) > 0 ? (2 * precision * recall) / (precision + recall) : 0;\n",
        "\n",
        "      weightedF1Sum += stats[label].count * f1;\n",
        "      weightSum += stats[label].count;\n",
        "    }\n",
        "  }\n",
        "\n",
        "  const weightedF1 = weightSum > 0 ? weightedF1Sum / weightSum : 0;\n",
        "\n",
        "  SpreadsheetApp.getUi().alert('Weighted F1 Score: ' + weightedF1);\n",
        "}\n",
        "\"\"\""
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Mojo",
      "language": "mojo",
      "name": "mojo-jupyter-kernel"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "mojo",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
