{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b25536c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import spacy\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Ask spaCy for the GPU before any pipeline is created; it stays on CPU when no GPU is available.\n",
    "spacy.prefer_gpu()\n",
    "\n",
    "# Dataset root: holds the JSON metadata files and split.json read by the cells below.\n",
    "root_dir = Path('/data/FOLDER/FOLDER/data/MM-IMDB/mmimdb')\n",
    "\n",
    "# Load the transformer pipeline exactly once. Previously spacy.load('en_core_web_trf') built the\n",
    "# same pipeline first, only for it to be overwritten by this second load.\n",
    "import en_core_web_trf\n",
    "nlp = en_core_web_trf.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b3eb8bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_nouns(x):\n",
    "    \"\"\"Return the lower-cased, stripped text of every token in `x` whose `pos_` is 'NOUN'.\n",
    "\n",
    "    `x` is any iterable of token-like objects exposing `.text` and `.pos_`\n",
    "    (the notebook passes parsed spaCy documents). Order and duplicates are preserved.\n",
    "    \"\"\"\n",
    "    nouns = []\n",
    "    for token in x:\n",
    "        if token.pos_ != 'NOUN':\n",
    "            continue\n",
    "        nouns.append(token.text.lower().strip())\n",
    "    return nouns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6746a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect one record per metadata JSON file found under root_dir.\n",
    "keep_keys = ['genres', 'title', 'plot', 'plot outline']\n",
    "raw = []\n",
    "# sorted(): glob order depends on the filesystem, so sort to get the same row order on every run.\n",
    "for json_path in sorted(root_dir.glob('**/*.json')):\n",
    "    if json_path.name == 'split.json':\n",
    "        # split.json holds the train/dev/test id lists, not a movie record.\n",
    "        continue\n",
    "    # Context manager closes the handle; the bare .open() used before leaked one handle per file.\n",
    "    with json_path.open('r') as fh:\n",
    "        data = json.load(fh)\n",
    "    # Keep only the fields of interest; keys absent from a file are simply omitted.\n",
    "    record = {k: data[k] for k in keep_keys if k in data}\n",
    "    # The image path is the JSON path with its suffix swapped to .jpeg.\n",
    "    record['path'] = str(json_path.with_suffix('.jpeg'))\n",
    "    # File name without the '.json' suffix serves as the record id.\n",
    "    record['id'] = json_path.stem\n",
    "    raw.append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f2d8504",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per dict in raw; the dict keys (kept metadata fields, 'path', 'id') become the columns.\n",
    "df = pd.DataFrame(raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f9f1efb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of records loaded (rows in the frame).\n",
    "len(df.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f772e77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the split definition (indexed below by 'train', 'dev' and 'test').\n",
    "# The with-block closes the file; the previous bare .open() left the handle open.\n",
    "with (root_dir / 'split.json').open('r') as fh:\n",
    "    split = json.load(fh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eb7a1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag each row with the split its id belongs to; rows listed in no split keep None.\n",
    "df['split'] = None\n",
    "# (key in split.json, label written to the frame) -- note 'dev' is stored as 'val'.\n",
    "# Applied in this order, so a later split wins if an id were listed twice.\n",
    "for source_key, split_label in (('train', 'train'), ('dev', 'val'), ('test', 'test')):\n",
    "    in_split = df['id'].isin(split[source_key])\n",
    "    df.loc[in_split, 'split'] = split_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "282fd42c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count of missing values in each column.\n",
    "df.isna().sum(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb1e8b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the frame as this cell's output for a visual sanity check.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2089842",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _join_plot(plot):\n",
    "    \"\"\"Collapse a list of plot paragraphs into a single newline-separated string.\n",
    "\n",
    "    Strings are returned unchanged so that re-running this cell is a no-op (joining a\n",
    "    string would put a newline between every character). Anything else -- i.e. a\n",
    "    missing value -- becomes '' to mirror how a missing 'plot outline' is filled.\n",
    "    \"\"\"\n",
    "    if isinstance(plot, str):\n",
    "        return plot\n",
    "    if isinstance(plot, (list, tuple)):\n",
    "        return '\\n'.join(plot)\n",
    "    return ''\n",
    "\n",
    "\n",
    "# A missing outline becomes an empty string so the concatenation below never sees NaN.\n",
    "df['plot outline'] = df['plot outline'].fillna('')\n",
    "df['plot'] = df['plot'].apply(_join_plot)\n",
    "# Outline first, then the full plot text, separated by a newline (vectorised string concat).\n",
    "df['all_text'] = df['plot outline'] + '\\n' + df['plot']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1398240",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten every row's genre list, then keep the distinct genre names (np.unique returns them sorted).\n",
    "all_genres = [genre for genre_list in df['genres'].values for genre in genre_list]\n",
    "labels = np.unique(all_genres)\n",
    "# Genre name -> integer id, where the id is the genre's position in labels.\n",
    "label_mapping = {genre: idx for idx, genre in enumerate(labels)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6929f6c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode_genres(genre_list):\n",
    "    \"\"\"Translate a list of genre names into their integer ids from label_mapping.\"\"\"\n",
    "    return [label_mapping[genre] for genre in genre_list]\n",
    "\n",
    "\n",
    "# Per-row list of integer genre ids, parallel to the 'genres' column.\n",
    "df['cat_labels'] = df['genres'].apply(_encode_genres)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64033473",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse every plot outline with the spaCy pipeline; tqdm reports progress over the stream.\n",
    "df['spacy_doc'] = list(tqdm(nlp.pipe(df['plot outline'], n_process=1), total=len(df)))\n",
    "df['nouns'] = df['spacy_doc'].apply(extract_nouns)\n",
    "# Sort the distinct nouns so each noun gets the same integer id on every run: the iteration\n",
    "# order of a bare set of strings changes between interpreter sessions (hash randomisation),\n",
    "# which made the ids written to the pickle irreproducible.\n",
    "noun_vocab = tuple(sorted({noun for nouns in df['nouns'] for noun in nouns}))\n",
    "# Noun -> integer id (its position in noun_vocab).\n",
    "noun_vocab_mapping = {noun: idx for idx, noun in enumerate(noun_vocab)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d841ee3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode_nouns(nouns):\n",
    "    \"\"\"Translate a list of nouns into their integer ids from noun_vocab_mapping.\"\"\"\n",
    "    return [noun_vocab_mapping[noun] for noun in nouns]\n",
    "\n",
    "\n",
    "# Per-row list of integer noun ids, parallel to the 'nouns' column.\n",
    "df['nouns_int'] = df['nouns'].apply(_encode_nouns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec99b575",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Duplicate the plot outline under the column name 'sentence'\n",
    "# (presumably the name downstream consumers expect -- TODO confirm).\n",
    "df['sentence'] = df['plot outline']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24ae4706",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the frame next to the raw data, leaving out the parsed spaCy documents.\n",
    "output_path = root_dir / 'multimodal_mislabel_split.pkl'\n",
    "export_df = df.drop(columns=['spacy_doc'])\n",
    "export_df.to_pickle(output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cb33679",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
