{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU","gpuClass":"standard"},"cells":[{"cell_type":"markdown","source":["In this notebook, you will be able to train the BERT classifiers on both the DWMW17 and FDCL18 datasets. These datasets can be obtained by using the Twitter API and your secret key. Alternatively, they can be obtained with wget from the given URLs (however, we suggest using the first method).\n","\n","At the end of this notebook, we save the models as \"davidson_bert\" and \"founta_bert\". We also provide these models in the supplementary material; they are ready to be used and do not need to be trained from scratch.\n","\n","References we used:\n","1. https://www.tensorflow.org/text/tutorials/classify_text_with_bert\n","2. https://tinyurl.com/bert-github-source-code\n","3. https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"],"metadata":{"id":"WcbvKRRRrMSG"}},{"cell_type":"code","source":["#!wget  https://www.dropbox.com/sh/4mapojr85a6sc76/AACD5J8nKg1mOImMhKCE06Zma/hatespeech_text_label_vote_RESTRICTED_100K.csv?dl=0 -O fdcl18.csv\n","#!wget https://github.com/t-davidson/hate-speech-and-offensive-language/raw/master/data/labeled_data.csv -O Dwmw17.csv"],"metadata":{"id":"LrcnMReMrOSW"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"iTFJ2PQGSGuj"},"outputs":[],"source":["import pandas as pd\n","import os\n","\n","# Load the tab-separated FDCL18 file and name its three columns.\n","# NOTE(review): read_csv treats the first row as a header here; if the file has no header row, pass header=None - confirm.\n","df = pd.read_csv(\"/content/fdcl18.csv\",delimiter=\"\\t\")\n","df.columns=[\"tweet\", \"label\", \"count\"]\n","cur = os.getcwd()\n","train = os.path.join(cur,\"train_fdcl\")\n","test = os.path.join(cur,\"test_fdcl\")\n","# exist_ok=True keeps this cell re-runnable when the folders already exist.\n","os.makedirs(train, exist_ok=True)\n","os.makedirs(test, exist_ok=True)\n","labels = df[\"label\"].unique()\n","\n","# One sub-folder per label under both split directories.\n","for lab in labels:\n","    dir_train = os.path.join(train, lab)\n","    os.makedirs(dir_train, exist_ok=True)\n","    dir_test = os.path.join(test, lab)\n",
"    os.makedirs(dir_test, exist_ok=True)\n","\n","train_ = 70000\n","\n","\n","# Rows from index train_ to the end of the frame form the test split; dic numbers the files per label.\n","dic = {\"abusive\":0, \"hateful\":0, \"normal\":0, \"spam\":0}\n","for i in range(train_,len(df)):\n","    tweet_label = df[\"label\"].iloc[i]\n","    dic[tweet_label]+=1\n","    tweet = df[\"tweet\"].iloc[i]\n","    with open(os.path.join(test, tweet_label) + \"/\" + tweet_label + \"_\" + str(dic[tweet_label])+\".txt\", \"w\", encoding=\"utf-8\") as f:\n","        f.write(tweet)\n","# The first train_ rows form the training split, written one text file per tweet.\n","dic = {\"abusive\":0, \"hateful\":0, \"normal\":0, \"spam\":0}\n","for i in range(train_):\n","    tweet_label = df[\"label\"].iloc[i]\n","    dic[tweet_label]+=1\n","    tweet = df[\"tweet\"].iloc[i]\n","    with open(os.path.join(train, tweet_label) + \"/\" + tweet_label + \"_\" + str(dic[tweet_label])+\".txt\", \"w\", encoding=\"utf-8\") as f:\n","        f.write(tweet)"]},{"cell_type":"code","source":["!pip install -q -U \"tensorflow-text==2.8.*\"\n","!pip install -q tf-models-official==2.7.0"],"metadata":{"id":"7fSTJJWNUDYU"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import os\n","import shutil\n","\n","import tensorflow as tf\n","import tensorflow_hub as hub\n","import tensorflow_text as text\n","from official.nlp import optimization  # to create AdamW optimizer\n","\n","import matplotlib.pyplot as plt\n","\n","tf.get_logger().setLevel('ERROR')"],"metadata":{"id":"yWcHfKLTUQqJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["AUTOTUNE = tf.data.AUTOTUNE\n","batch_size = 32\n","seed = 42\n","\n","raw_train_ds = tf.keras.utils.text_dataset_from_directory(\n","    '/content/train_fdcl',\n","    batch_size=batch_size,\n","    validation_split=0.2,\n","    subset='training',\n","    seed=seed)\n","\n","class_names = raw_train_ds.class_names\n","print(class_names)\n","train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n","\n","val_ds = tf.keras.utils.text_dataset_from_directory(\n","    '/content/train_fdcl',\n","    batch_size=batch_size,\n","    validation_split=0.2,\n","  
  subset='validation',\n","    seed=seed)\n","\n","val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\n","\n","test_ds = tf.keras.utils.text_dataset_from_directory(\n","    '/content/test_fdcl',\n","    batch_size=batch_size)\n","\n","test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)"],"metadata":{"id":"ZsPYzhfiUdiC"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["tfhub_handle_encoder = \"https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3\"\n","tfhub_handle_preprocess = \"https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3\""],"metadata":{"id":"XKXhhiGTFT1Y"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Choose a BERT model to fine-tune\n","\n","bert_model_name = 'bert_en_uncased_L-12_H-768_A-12'  #@param [\"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\", \"small_bert/bert_en_uncased_L-2_H-128_A-2\", \"small_bert/bert_en_uncased_L-2_H-256_A-4\", \"small_bert/bert_en_uncased_L-2_H-512_A-8\", \"small_bert/bert_en_uncased_L-2_H-768_A-12\", \"small_bert/bert_en_uncased_L-4_H-128_A-2\", \"small_bert/bert_en_uncased_L-4_H-256_A-4\", \"small_bert/bert_en_uncased_L-4_H-512_A-8\", \"small_bert/bert_en_uncased_L-4_H-768_A-12\", \"small_bert/bert_en_uncased_L-6_H-128_A-2\", \"small_bert/bert_en_uncased_L-6_H-256_A-4\", \"small_bert/bert_en_uncased_L-6_H-512_A-8\", \"small_bert/bert_en_uncased_L-6_H-768_A-12\", \"small_bert/bert_en_uncased_L-8_H-128_A-2\", \"small_bert/bert_en_uncased_L-8_H-256_A-4\", \"small_bert/bert_en_uncased_L-8_H-512_A-8\", \"small_bert/bert_en_uncased_L-8_H-768_A-12\", \"small_bert/bert_en_uncased_L-10_H-128_A-2\", \"small_bert/bert_en_uncased_L-10_H-256_A-4\", \"small_bert/bert_en_uncased_L-10_H-512_A-8\", \"small_bert/bert_en_uncased_L-10_H-768_A-12\", \"small_bert/bert_en_uncased_L-12_H-128_A-2\", \"small_bert/bert_en_uncased_L-12_H-256_A-4\", \"small_bert/bert_en_uncased_L-12_H-512_A-8\", 
\"small_bert/bert_en_uncased_L-12_H-768_A-12\", \"albert_en_base\", \"electra_small\", \"electra_base\", \"experts_pubmed\", \"experts_wiki_books\", \"talking-heads_base\"]\n","\n","map_name_to_handle = {\n","    'bert_en_uncased_L-12_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',\n","    'bert_en_cased_L-12_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3',\n","    'bert_multi_cased_L-12_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',\n","    'small_bert/bert_en_uncased_L-2_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/1',\n","    'small_bert/bert_en_uncased_L-2_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1',\n","    'small_bert/bert_en_uncased_L-2_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-512_A-8/1',\n","    'small_bert/bert_en_uncased_L-2_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-768_A-12/1',\n","    'small_bert/bert_en_uncased_L-4_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-128_A-2/1',\n","    'small_bert/bert_en_uncased_L-4_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/1',\n","    'small_bert/bert_en_uncased_L-4_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',\n","    'small_bert/bert_en_uncased_L-4_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-768_A-12/1',\n","    'small_bert/bert_en_uncased_L-6_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-128_A-2/1',\n","    'small_bert/bert_en_uncased_L-6_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1',\n","    
'small_bert/bert_en_uncased_L-6_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-512_A-8/1',\n","    'small_bert/bert_en_uncased_L-6_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-768_A-12/1',\n","    'small_bert/bert_en_uncased_L-8_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-128_A-2/1',\n","    'small_bert/bert_en_uncased_L-8_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-256_A-4/1',\n","    'small_bert/bert_en_uncased_L-8_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1',\n","    'small_bert/bert_en_uncased_L-8_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-768_A-12/1',\n","    'small_bert/bert_en_uncased_L-10_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-128_A-2/1',\n","    'small_bert/bert_en_uncased_L-10_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-256_A-4/1',\n","    'small_bert/bert_en_uncased_L-10_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-512_A-8/1',\n","    'small_bert/bert_en_uncased_L-10_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-768_A-12/1',\n","    'small_bert/bert_en_uncased_L-12_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-128_A-2/1',\n","    'small_bert/bert_en_uncased_L-12_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1',\n","    'small_bert/bert_en_uncased_L-12_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-512_A-8/1',\n","    'small_bert/bert_en_uncased_L-12_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-768_A-12/1',\n","    'albert_en_base':\n","        
'https://tfhub.dev/tensorflow/albert_en_base/2',\n","    'electra_small':\n","        'https://tfhub.dev/google/electra_small/2',\n","    'electra_base':\n","        'https://tfhub.dev/google/electra_base/2',\n","    'experts_pubmed':\n","        'https://tfhub.dev/google/experts/bert/pubmed/2',\n","    'experts_wiki_books':\n","        'https://tfhub.dev/google/experts/bert/wiki_books/2',\n","    'talking-heads_base':\n","        'https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1',\n","}\n","\n","map_model_to_preprocess = {\n","    'bert_en_uncased_L-12_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'bert_en_cased_L-12_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_cased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-2_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-2_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-2_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-2_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-4_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-4_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-4_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-4_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-6_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-6_H-256_A-4':\n","        
'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-6_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-6_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-8_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-8_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-8_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-8_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-10_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-10_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-10_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-10_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-12_H-128_A-2':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-12_H-256_A-4':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-12_H-512_A-8':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'small_bert/bert_en_uncased_L-12_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'bert_multi_cased_L-12_H-768_A-12':\n","        'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3',\n","    'albert_en_base':\n","        
'https://tfhub.dev/tensorflow/albert_en_preprocess/3',\n","    'electra_small':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'electra_base':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'experts_pubmed':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'experts_wiki_books':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","    'talking-heads_base':\n","        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n","}\n","\n","tfhub_handle_encoder = map_name_to_handle[bert_model_name]\n","tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]\n","\n","print(f'BERT model selected           : {tfhub_handle_encoder}')\n","print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')"],"metadata":{"cellView":"form","id":"y1cq9UrjWIf3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)"],"metadata":{"id":"m3bL8Z88WcpF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["bert_model = hub.KerasLayer(tfhub_handle_encoder)"],"metadata":{"id":"As6ARz1jWqd_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def build_classifier_model():\n","  text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n","  preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')\n","  encoder_inputs = preprocessing_layer(text_input)\n","  encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')\n","  outputs = encoder(encoder_inputs)\n","  net = outputs['pooled_output']\n","  net = tf.keras.layers.Dropout(0.1)(net)\n","  net = tf.keras.layers.Dense(4, activation=tf.keras.activations.softmax, name='classifier')(net)\n","  return tf.keras.Model(text_input, 
net)"],"metadata":{"id":"y1d9QXzbW-tt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["classifier_model = build_classifier_model()"],"metadata":{"id":"gICdFdlkXaql"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["tf.keras.utils.plot_model(classifier_model)"],"metadata":{"id":"ioPyykG1XhNd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["loss = tf.keras.losses.sparse_categorical_crossentropy\n","metrics = [\"accuracy\"]"],"metadata":{"id":"_wdq8pL-XwHm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["epochs = 30\n","steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()\n","num_train_steps = steps_per_epoch * epochs\n","num_warmup_steps = int(0.1*num_train_steps)\n","\n","init_lr = 3e-5\n","optimizer = optimization.create_optimizer(init_lr=init_lr,\n","                                          num_train_steps=num_train_steps,\n","                                          num_warmup_steps=num_warmup_steps,\n","                                          optimizer_type='adamw')"],"metadata":{"id":"ldH7CvpUYDdU"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["classifier_model.compile(optimizer=optimizer,\n","                         loss=loss,\n","                         metrics=metrics)"],"metadata":{"id":"bF6aded4YH4N"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(f'Training model with {tfhub_handle_encoder}')\n","history = classifier_model.fit(x=train_ds,\n","                               validation_data=val_ds,\n","                               epochs=epochs)"],"metadata":{"id":"L3Xubi1XYKQD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Saved as ./founta_bert, the model name given in the notebook introduction.\n","dataset_name = 'founta'\n","saved_model_path = './{}_bert'.format(dataset_name.replace('/', '_'))\n","\n","classifier_model.save(saved_model_path, 
include_optimizer=False)"],"metadata":{"id":"OGieaexvYNGM"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Training BERT on the DWMW17 dataset"],"metadata":{"id":"TyBGXW7LGLz6"}},{"cell_type":"code","source":["data = pd.read_csv(\"/content/Dwmw17.csv\")\n","cur = os.getcwd()\n","train = os.path.join(cur,\"train_dwmw\")\n","test = os.path.join(cur,\"test_dwmw\")\n","# exist_ok=True keeps this cell re-runnable when the folders already exist.\n","os.makedirs(train, exist_ok=True)\n","os.makedirs(test, exist_ok=True)\n","labels = ['hate_speech', 'offensive_language', 'neither']\n","\n","# One sub-folder per label under both split directories.\n","for lab in labels:\n","    dir_train = os.path.join(train, lab)\n","    os.makedirs(dir_train, exist_ok=True)\n","    dir_test = os.path.join(test, lab)\n","    os.makedirs(dir_test, exist_ok=True)\n","\n","# The first 75% of the rows are used for training, the remainder for testing.\n","train_ = int(len(data)*0.75)\n","\n","# tweet_type maps the integer stored in the class column to its label name.\n","tweet_type=['hate_speech', 'offensive_language', 'neither']\n","# dic numbers the output files per label within each split.\n","dic = {\"hate_speech\":0, \"offensive_language\":0, \"neither\":0}\n","for i in range(train_):\n","    tweet_label = tweet_type[data[\"class\"].iloc[i]]\n","    dic[tweet_label] += 1\n","    tweet = data[\"tweet\"].iloc[i]\n","    with open(os.path.join(train, tweet_label) + \"/\" + tweet_label + \"_\" + str(dic[tweet_label])+\".txt\", \"w\", encoding=\"utf-8\") as f:\n","        f.write(tweet)\n","dic = {\"hate_speech\":0, \"offensive_language\":0, \"neither\":0}\n","for i in range(train_,len(data)):\n","    tweet_label = tweet_type[data[\"class\"].iloc[i]]\n","    dic[tweet_label]+=1\n","    tweet = data[\"tweet\"].iloc[i]\n","    with open(os.path.join(test, tweet_label) + \"/\" + tweet_label + \"_\" + str(dic[tweet_label])+\".txt\", \"w\", encoding=\"utf-8\") as f:\n","        f.write(tweet)"],"metadata":{"id":"VraPBCCryNVK"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["AUTOTUNE = tf.data.AUTOTUNE\n","batch_size = 32\n","seed = 42\n","\n","raw_train_ds = tf.keras.utils.text_dataset_from_directory(\n","    '/content/train_dwmw',\n","    batch_size=batch_size,\n","    validation_split=0.2,\n","    subset='training',\n","    seed=seed)\n","\n","class_names = 
raw_train_ds.class_names\n","print(class_names)\n","train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n","\n","val_ds = tf.keras.utils.text_dataset_from_directory(\n","    '/content/train_dwmw',\n","    batch_size=batch_size,\n","    validation_split=0.2,\n","    subset='validation',\n","    seed=seed)\n","\n","val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\n","\n","test_ds = tf.keras.utils.text_dataset_from_directory(\n","    '/content/test_dwmw',\n","    batch_size=batch_size)\n","\n","test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)"],"metadata":{"id":"klnvfpIsHIcc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["tfhub_handle_encoder = \"https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3\"\n","tfhub_handle_preprocess = \"https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3\""],"metadata":{"id":"udBIruIQJLDc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)\n","bert_model = hub.KerasLayer(tfhub_handle_encoder)"],"metadata":{"id":"dV3_dhE9JR7x"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def build_classifier_model():\n","  text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n","  preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')\n","  encoder_inputs = preprocessing_layer(text_input)\n","  encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')\n","  outputs = encoder(encoder_inputs)\n","  net = outputs['pooled_output']\n","  net = tf.keras.layers.Dropout(0.1)(net)\n","  net = tf.keras.layers.Dense(3, activation=tf.keras.activations.softmax, name='classifier')(net)\n","  return tf.keras.Model(text_input, net)"],"metadata":{"id":"FaXlTw51JckN"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["classifier_model_2 = build_classifier_model()\n","loss = 
tf.keras.losses.sparse_categorical_crossentropy\n","metrics = [\"accuracy\"]"],"metadata":{"id":"HYAbuBrSJhKt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["epochs = 30\n","steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()\n","num_train_steps = steps_per_epoch * epochs\n","num_warmup_steps = int(0.1*num_train_steps)\n","\n","init_lr = 3e-5\n","optimizer = optimization.create_optimizer(init_lr=init_lr,\n","                                          num_train_steps=num_train_steps,\n","                                          num_warmup_steps=num_warmup_steps,\n","                                          optimizer_type='adamw')"],"metadata":{"id":"C_c4KvINJn51"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["classifier_model_2.compile(optimizer=optimizer,\n","                         loss=loss,\n","                         metrics=metrics)"],"metadata":{"id":"JSzHutlOJtig"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["print(f'Training model with {tfhub_handle_encoder}')\n","history = classifier_model_2.fit(x=train_ds,\n","                               validation_data=val_ds,\n","                               epochs=epochs)"],"metadata":{"id":"bN6Ie0RiJzm9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset_name = 'davidson'\n","saved_model_path = './{}_bert'.format(dataset_name.replace('/', '_'))\n","\n","classifier_model_2.save(saved_model_path, include_optimizer=False)"],"metadata":{"id":"B1VXpZoPJ0U3"},"execution_count":null,"outputs":[]}]}