{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generalization\n",
    "\n",
    "In this notebook, we simply measure the generalization ability of our various models. We use models trained with ASVspoof data and test their performance on FakeAVCeleb."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from azureml.fsspec import AzureMachineLearningFileSystem\n",
    "from transformers import AutoModelForAudioClassification, AutoFeatureExtractor\n",
    "import os\n",
    "from transformer import FakeAVCeleb\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AST configuration: Azure ML filesystem handle, checkpoint directory, and the\n",
    "# feature extractor loaded from the pretrained AST checkpoint named below.\n",
    "# NOTE(review): \"azureml:\" looks like a truncated/placeholder datastore URI - confirm.\n",
    "root_dir = \"azureml:\"\n",
    "fs = AzureMachineLearningFileSystem(root_dir)\n",
    "# Checkpoint directory: it is listed on `fs` by the download cell and is also\n",
    "# reused as the local download directory (`local_dir = model_path`).\n",
    "model_path = \"checkpoint/\"\n",
    "feature_extractor = AutoFeatureExtractor.from_pretrained(\"MIT/ast-finetuned-audioset-10-10-0.4593\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## WAV2VEC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# wav2vec 2.0 configuration.\n",
    "# NOTE: this rebinds `model_path` and `feature_extractor`, which the AST cell\n",
    "# also sets - whichever of the two cells ran last decides what the rest of the\n",
    "# notebook uses.\n",
    "model_path = \"/home/azureuser/checkpoint/\"\n",
    "feature_extractor = AutoFeatureExtractor.from_pretrained(\"facebook/wav2vec2-base\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clip length in seconds: preprocess_function keeps at most\n",
    "# sampling_rate * max_duration samples of each clip.\n",
    "max_duration = 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples):\n",
    "    \"\"\"Clip every waveform in a batch and run the batch through the feature extractor.\n",
    "\n",
    "    Each item of ``examples[\"audio\"]`` contributes its ``\"array\"``, cut to at most\n",
    "    ``feature_extractor.sampling_rate * max_duration`` samples. The extractor is\n",
    "    called with ``max_length=None`` and ``truncation=False`` and its output is\n",
    "    returned unchanged.\n",
    "    \"\"\"\n",
    "    sampling_rate = feature_extractor.sampling_rate\n",
    "    sample_limit = int(sampling_rate * max_duration)\n",
    "\n",
    "    clipped_waveforms = []\n",
    "    for audio in examples[\"audio\"]:\n",
    "        clipped_waveforms.append(audio[\"array\"][:sample_limit])\n",
    "\n",
    "    return feature_extractor(\n",
    "        clipped_waveforms,\n",
    "        sampling_rate=sampling_rate,\n",
    "        max_length=None,\n",
    "        truncation=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a small ASVspoof sample (max_size=10).\n",
    "# NOTE(review): `avdata` is not referenced by any other cell shown in this notebook.\n",
    "from transformer import ASVSpoofDataset\n",
    "avdata = ASVSpoofDataset(max_size=10).load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FakeAVCeleb evaluation data; max_size=1000 presumably caps the number of\n",
    "# loaded examples - confirm against the FakeAVCeleb loader.\n",
    "dataset = FakeAVCeleb(max_size=1000).load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply preprocess_function in batches and drop the raw \"audio\" and \"filename\" columns.\n",
    "# NOTE(review): the evaluation loop indexes `dataset` directly, so\n",
    "# `encoded_dataset` is not consumed by any later cell shown in this notebook.\n",
    "encoded_dataset = dataset.map(preprocess_function, remove_columns=[\"audio\", \"filename\"], batched=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class names come from the dataset's \"label\" feature; their count is passed\n",
    "# as num_labels when the model is loaded.\n",
    "labels = dataset.features[\"label\"].names\n",
    "num_labels = len(labels)\n",
    "\n",
    "# Label <-> id lookup tables for the 'C' / 'D' classes (ids kept as strings).\n",
    "label2id = {'C': '0', 'D': '1'}\n",
    "id2label = {'0': 'C', '1': 'D'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label <-> id lookup tables for the 'bonafide' / 'spoof' classes (ids kept as\n",
    "# strings); presumably the ASVspoof counterpart of label2id / id2label.\n",
    "av_label2id = {'bonafide': '0', 'spoof': '1'}\n",
    "av_id2label = {'0': 'bonafide', '1': 'spoof'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local directory the checkpoint is downloaded into and later loaded from.\n",
    "# It is the same string as `model_path`, so the remote and local paths coincide.\n",
    "local_dir = model_path # \"../temp_model\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the local target directory exists before downloading into it.\n",
    "os.makedirs(local_dir, exist_ok=True)\n",
    "\n",
    "# Copy every file found under `model_path` on the Azure ML filesystem into\n",
    "# `local_dir`; directory entries returned by the recursive listing are skipped.\n",
    "for remote_path in fs.ls(model_path, detail=False, recursive=True):\n",
    "    if not fs.isfile(remote_path):\n",
    "        continue\n",
    "    fs.get(remote_path, local_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint from `local_dir`, passing the label count and the\n",
    "# label2id / id2label maps to the model config.\n",
    "# NOTE(review): with ignore_mismatched_sizes=True, weights whose shape differs\n",
    "# from the checkpoint are re-initialised instead of raising - confirm the head\n",
    "# really matches, otherwise the evaluation runs on an untrained classifier.\n",
    "loaded_model = AutoModelForAudioClassification.from_pretrained(\n",
    "                local_dir, \n",
    "                num_labels=num_labels,\n",
    "                label2id=label2id,\n",
    "                id2label=id2label,\n",
    "                ignore_mismatched_sizes=True,\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Evaluate on a reproducible random subset of `dataset`.\n",
    "# Indices are drawn from the dataset's real length rather than a fixed\n",
    "# range(1, 3000): `dataset` is created with max_size=1000, so a hard-coded upper\n",
    "# bound can index past the end, and starting at 1 never selects row 0.\n",
    "n_eval = min(1000, len(dataset))\n",
    "idxs = random.Random(0).sample(range(len(dataset)), n_eval)  # fixed seed -> same subset on every run\n",
    "\n",
    "y_true = []\n",
    "y_pred = []\n",
    "\n",
    "for idx in idxs:\n",
    "    example = dataset[idx]  # index the row once instead of once per field\n",
    "    inputs = feature_extractor(\n",
    "        example[\"audio\"][\"array\"],\n",
    "        sampling_rate=feature_extractor.sampling_rate,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "\n",
    "    with torch.no_grad():\n",
    "        logits = loaded_model(**inputs).logits\n",
    "    predicted_class_id = torch.argmax(logits).item()\n",
    "\n",
    "    # The label is stored inverted (1 - label); the reporting cell applies\n",
    "    # 1 - y again before computing metrics.\n",
    "    y_true.append(1 - example[\"label\"])\n",
    "    y_pred.append(predicted_class_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# y_true was collected as 1 - label. Undo that into a new list instead of\n",
    "# overwriting y_true, so re-running this cell cannot flip the labels again.\n",
    "y_true_report = [1 - y for y in y_true]\n",
    "print(classification_report(y_true_report, y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GBDT Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pickle import load\n",
    "import sklearn\n",
    "\n",
    "# Load the pickled booster used for the GBDT comparison.\n",
    "# SECURITY: unpickling can execute arbitrary code - only load this file if it\n",
    "# comes from a trusted source.\n",
    "with open(\"booster_audio_len_6_max_depth_8_n_est_400.pkl\", \"rb\") as f:\n",
    "    booster = load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load FakeAVCeleb Data for Booster"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, '../')  # make the sibling `xgbooster` package importable\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import librosa\n",
    "from azureml.fsspec import AzureMachineLearningFileSystem\n",
    "from xgbooster.generate import Features\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "# Seconds of audio kept per clip (None keeps the whole clip).\n",
    "slice_size = 6\n",
    "feature_generator = Features()\n",
    "# When True, no train/test split is performed and all loaded data is returned.\n",
    "train_only = True\n",
    "\n",
    "\n",
    "def load_fakeavceleb_data():\n",
    "    \"\"\"Load FakeAVCeleb audio from the Azure ML datastore and build booster features.\n",
    "\n",
    "    Reads ``metadata.csv`` under the datastore root, opens every file listed in\n",
    "    its ``new_filename`` column, truncates the waveform to ``slice_size`` seconds\n",
    "    (unless ``slice_size`` is None) and turns it into features with\n",
    "    ``feature_generator.make_features``.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(X_train, y_train)`` as NumPy arrays - the features and the\n",
    "        values of the ``category`` column. When ``train_only`` is False the data\n",
    "        is split with ``train_test_split`` and only the training part is\n",
    "        returned; the held-out part is discarded.\n",
    "    \"\"\"\n",
    "    train_dir = \"azureml://\"\n",
    "    fs = AzureMachineLearningFileSystem(train_dir)\n",
    "\n",
    "    metadata_file = f\"{train_dir}/metadata.csv\"\n",
    "    metadata = pd.read_csv(metadata_file)\n",
    "    filenames = [f\"{train_dir}/{file}\" for file in metadata['new_filename'].to_list()]\n",
    "    labels = metadata['category'].to_list()\n",
    "\n",
    "    train_features = []\n",
    "    for file in filenames:\n",
    "        # Audio has to be opened in binary mode: text mode ('r') would try to\n",
    "        # decode the raw bytes as text before librosa can parse them.\n",
    "        with fs.open(file, 'rb') as f:\n",
    "            segment, sr = librosa.load(f)\n",
    "\n",
    "        if slice_size is not None:\n",
    "            segment = segment[:int(sr * slice_size)]\n",
    "\n",
    "        train_features.append(feature_generator.make_features(segment, sr))\n",
    "\n",
    "    X_train = np.array(train_features)\n",
    "    y_train = np.array(labels)\n",
    "    print(\"loaded train audio, y_train contains {} samples\".format(len(y_train)))\n",
    "\n",
    "    if not train_only:\n",
    "        # Only the training part of the split is returned to the caller.\n",
    "        X_train, _X_test, y_train, _y_test = train_test_split(\n",
    "            np.array(X_train), np.array(y_train), test_size=0.33, random_state=0\n",
    "        )\n",
    "    else:\n",
    "        print(\"skipping loading test audio\")\n",
    "    return X_train, y_train\n",
    "\n",
    "\n",
    "X_train, y_train = load_fakeavceleb_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict with the booster on the FakeAVCeleb features.\n",
    "# NOTE: this rebinds `y_pred`, replacing the transformer predictions collected earlier.\n",
    "y_pred = booster.predict(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate the booster's 'spoof' / 'bonafide' outputs into the 'D' / 'C'\n",
    "# values so they can be compared with y_train.\n",
    "trans = {'spoof':'D', 'bonafide':'C'}\n",
    "converted_y_pred = [trans[predicted] for predicted in y_pred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# Classification report of the translated booster predictions against y_train.\n",
    "print(classification_report(y_train, converted_y_pred))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
