{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c100e354",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    \"\"\"Configure matplotlib's global fonts for the figures in this notebook.\n",
    "\n",
    "    The same sequence of settings is applied twice. After each pass the\n",
    "    \"default\" style is active with a 15pt base font, the \"Times New Roman\"\n",
    "    family and the \"stix\" math font set.\n",
    "    \"\"\"\n",
    "    import matplotlib\n",
    "    import matplotlib.pyplot as plt\n",
    "\n",
    "    # NOTE(review): the reason for running the configuration twice is not\n",
    "    # visible here; the repetition is kept exactly as before.\n",
    "    for _ in range(2):\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        # NOTE(review): style.use(\"default\") runs after the plt.rc calls above\n",
    "        # and presumably resets them (including text.usetex) - confirm whether\n",
    "        # LaTeX text rendering is actually wanted.\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"]=15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "\n",
    "        try:\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except Exception:\n",
    "            # Best effort only: the 'roman' entry may already be gone and the\n",
    "            # private _rebuild helper may not exist in the installed version.\n",
    "            # Catching Exception (not a bare except) keeps Ctrl-C working.\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09644e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import math\n",
    "import matplotlib.cm as cm\n",
    "import os\n",
    "from typing import Dict, Tuple, List\n",
    "import pickle\n",
    "import warnings\n",
    "\n",
    "from MKLpy.algorithms import EasyMKL\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da0b171b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from matplotlib's built-in style, then enlarge the base font size.\n",
    "plt.style.use(\"default\")\n",
    "plt.rcParams.update({\"font.size\": 15})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a950719d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the global font settings defined by set_latex().\n",
    "set_latex()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "889274f2",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79cd286f",
   "metadata": {},
   "source": [
    "To download the datasets, see https://github.com/LeoYu/neural-tangent-kernel-UCI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dfe4f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = os.path.join(\"./data/\")\n",
    "\n",
    "def get_datasize(dic: Dict) -> Tuple[int, int, int, int, int]:\n",
    "    \"\"\"Read the size fields out of a parsed dataset configuration.\n",
    "\n",
    "    Args:\n",
    "        dic: Mapping of configuration keys to values. Values are converted\n",
    "            with ``int``, so numeric strings are accepted.\n",
    "\n",
    "    Returns:\n",
    "        ``(n_tot, n_train_val, n_test, d, c)`` where ``n_train_val`` is read\n",
    "        from ``\"n_patrons1=\"``, ``n_test`` from ``\"n_patrons2=\"`` (0 when the\n",
    "        key is absent), ``n_tot`` is their sum, ``d`` comes from\n",
    "        ``\"n_entradas=\"`` and ``c`` from ``\"n_clases=\"``.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If one of the three mandatory keys is missing.\n",
    "    \"\"\"\n",
    "    c = int(dic[\"n_clases=\"])\n",
    "    d = int(dic[\"n_entradas=\"])\n",
    "    n_train_val = int(dic[\"n_patrons1=\"])\n",
    "    # The second pattern count is optional in the configuration.\n",
    "    n_test = int(dic.get(\"n_patrons2=\", 0))\n",
    "    n_tot = n_train_val + n_test\n",
    "    return n_tot, n_train_val, n_test, d, c\n",
    "\n",
    "\n",
    "def load_data(dic: Dict) -> Tuple[np.ndarray, np.ndarray]:\n",
    "    \"\"\"Load the feature matrix and labels of one dataset.\n",
    "\n",
    "    The file ``DATA_DIR/<dic[\"dataset\"]>/<dic[\"fich1=\"]>`` is read as\n",
    "    whitespace-separated text. Its first line is skipped. In every remaining\n",
    "    line the first token is dropped, the last token is parsed as the integer\n",
    "    label and the tokens in between are parsed as float features.\n",
    "\n",
    "    Args:\n",
    "        dic: Dataset configuration providing the ``\"dataset\"`` and\n",
    "            ``\"fich1=\"`` entries.\n",
    "\n",
    "    Returns:\n",
    "        ``(X, y)``: the feature array and the label array, one entry per\n",
    "        non-blank data line.\n",
    "    \"\"\"\n",
    "    path = os.path.join(DATA_DIR, dic[\"dataset\"], dic[\"fich1=\"])\n",
    "    # The context manager closes the handle even if parsing fails.\n",
    "    with open(path, \"r\") as f:\n",
    "        rows = [line.split() for line in f.readlines()[1:]]\n",
    "    # Blank lines (e.g. a trailing newline) carry no sample; skip them\n",
    "    # instead of failing on the label lookup.\n",
    "    rows = [tokens for tokens in rows if tokens]\n",
    "    X = np.asarray([[float(value) for value in tokens[1:-1]] for tokens in rows])\n",
    "    y = np.asarray([int(tokens[-1]) for tokens in rows])\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c58c2916",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only small binary problems are kept: upper bounds on samples, features, classes.\n",
    "MAX_TOT = 1000\n",
    "MAX_FEATURES = 10\n",
    "MAX_CLASSES = 2\n",
    "\n",
    "datasets = []\n",
    "\n",
    "n_dataset = 0\n",
    "for idx, dataset in enumerate(sorted(os.listdir(DATA_DIR))):\n",
    "    config_path = os.path.join(DATA_DIR, dataset, f\"{dataset}.txt\")\n",
    "    if not os.path.isfile(config_path):\n",
    "        continue\n",
    "\n",
    "    if dataset == \"tic-tac-toe\":  # remove tic-tac-toe\n",
    "        continue\n",
    "\n",
    "    # load configuration: every non-blank line is a \"key value\" pair\n",
    "    dic = dict()\n",
    "    dic[\"dataset\"] = dataset\n",
    "    with open(config_path, \"r\") as config_file:\n",
    "        for line in config_file:\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            k, v = line.split()\n",
    "            dic[k] = v\n",
    "\n",
    "    # Check skip or not: drop datasets that are too large or ship a separate test set\n",
    "    n_tot, n_train_val, n_test, n_feature, n_class = get_datasize(dic)\n",
    "    if (n_tot > MAX_TOT) or (n_test > 0) or (n_feature > MAX_FEATURES) or (n_class > MAX_CLASSES):\n",
    "        continue\n",
    "\n",
    "    print(f\"-----{idx}, {dataset}, {n_tot}, {n_feature}, {n_class}-----\")\n",
    "    n_dataset += 1\n",
    "\n",
    "    # load dataset\n",
    "    X, y = load_data(dic)\n",
    "    # Each line of the fold file is one list of integer sample indices.\n",
    "    fold_path = os.path.join(DATA_DIR, dic[\"dataset\"], \"conxuntos_kfold.dat\")\n",
    "    with open(fold_path, \"r\") as fold_file:\n",
    "        fold = [list(map(int, line.split())) for line in fold_file]\n",
    "    datasets.append({\"X\": X, \"y\": y, \"name\": dic[\"dataset\"], \"fold\": fold})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c13b3fc9",
   "metadata": {},
   "source": [
    "## Kernels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "788c7c13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_tau(alpha: float, S: np.array, diag_i: np.array, diag_j: np.array) -> np.array:\n",
    "    \"\"\"Evaluate ``1/4 + arcsin(rho) / (2*pi)`` element-wise.\n",
    "\n",
    "    ``rho`` is ``alpha**2 * S`` divided by\n",
    "    ``sqrt((alpha**2 * diag_i + 0.5) * (alpha**2 * diag_j + 0.5))``.\n",
    "    All array arguments are combined element-wise.\n",
    "    \"\"\"\n",
    "    scaled_gram = (alpha ** 2) * S\n",
    "    normalizer = np.sqrt(((alpha ** 2) * diag_i + 0.5) * ((alpha ** 2) * diag_j + 0.5))\n",
    "    return 1 / 4 + 1 / (2 * math.pi) * np.arcsin(scaled_gram / normalizer)\n",
    "\n",
    "\n",
    "def calc_tau_dot(\n",
    "    alpha: float, S: np.array, diag_i: np.array, diag_j: np.array\n",
    ") -> np.array:\n",
    "    \"\"\"Evaluate ``alpha**2 / (pi * sqrt(D))`` element-wise.\n",
    "\n",
    "    ``D`` is ``(2*alpha**2*diag_i + 1) * (2*alpha**2*diag_j + 1)\n",
    "    - 4*alpha**4*S**2``. All array arguments are combined element-wise.\n",
    "    \"\"\"\n",
    "    alpha_sq = alpha ** 2\n",
    "    radicand = (2 * alpha_sq * diag_i + 1) * (2 * alpha_sq * diag_j + 1) - (\n",
    "        4 * (alpha ** 4) * (S ** 2)\n",
    "    )\n",
    "    return alpha_sq / (math.pi) * 1 / np.sqrt(radicand)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1672f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hard_kernel(X: np.array, alpha: float, beta: float, finetune: bool, rulelist: list):\n",
    "    \"\"\"Compute the kernel matrix induced by a list of feature-index rules.\n",
    "\n",
    "    For every feature ``f`` the per-feature matrix\n",
    "    ``S_f = outer(X[:, f], X[:, f]) + beta**2`` is built and turned into\n",
    "    ``tau_f`` / ``tau_dot_f`` with ``calc_tau`` / ``calc_tau_dot``. Each rule\n",
    "    (a list of feature indices) then contributes, scaled by ``2**len(rule)``:\n",
    "\n",
    "    * one internal-node term per position ``s``:\n",
    "      ``base_s * tau_dot_s * prod(tau_t for the other entries t)``\n",
    "    * one leaf term: ``prod(tau_i for i in rule)``\n",
    "\n",
    "    Args:\n",
    "        X: 2-D array with one row per sample and one column per feature.\n",
    "        alpha: Scale forwarded to ``calc_tau`` / ``calc_tau_dot``.\n",
    "        beta: Offset; ``beta**2`` is added to every Gram matrix.\n",
    "        finetune: If True, ``base_s`` is the all-feature matrix\n",
    "            ``X @ X.T + beta**2``; otherwise it is the per-feature ``S_s``.\n",
    "            ``tau`` and ``tau_dot`` always use the per-feature matrices.\n",
    "        rulelist: List of rules, each a list of column indices of ``X``.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape ``(n_samples, n_samples)``: the accumulated kernel\n",
    "        divided by ``len(rulelist)``.\n",
    "    \"\"\"\n",
    "    n_samples, n_features = X.shape\n",
    "\n",
    "    # Independent of the feature index, so computed once instead of per feature.\n",
    "    S_all = np.matmul(X, X.T) + beta**2\n",
    "\n",
    "    S_list = []\n",
    "    tau_list = []\n",
    "    tau_dot_list = []\n",
    "    for feature_index in range(n_features):\n",
    "        column = X[:, feature_index]\n",
    "        S = np.outer(column, column) + beta**2\n",
    "        S_list.append(S_all if finetune else S)\n",
    "\n",
    "        # Every row of diag_i is the diagonal of S; diag_j is its transpose.\n",
    "        diag_i = np.tile(np.diag(S), (n_samples, 1))\n",
    "        diag_j = diag_i.transpose()\n",
    "        tau_list.append(calc_tau(alpha, S, diag_i, diag_j))\n",
    "        tau_dot_list.append(calc_tau_dot(alpha, S, diag_i, diag_j))\n",
    "\n",
    "    K = np.zeros((n_samples, n_samples))\n",
    "\n",
    "    for rules in tqdm(rulelist, leave=False):\n",
    "        scale = 2 ** len(rules)\n",
    "\n",
    "        # Internal nodes\n",
    "        for i, s in enumerate(rules):\n",
    "            ts = rules[0:i] + rules[i+1:]\n",
    "            # The product allocates a fresh array, so the in-place updates\n",
    "            # below never modify the cached matrices.\n",
    "            _H_nodes = S_list[s] * tau_dot_list[s]\n",
    "            for t in ts:\n",
    "                _H_nodes *= tau_list[t]\n",
    "            K += _H_nodes * scale\n",
    "\n",
    "        # Leaves\n",
    "        _H_leaves = np.ones_like(K)\n",
    "        for i in rules:\n",
    "            _H_leaves *= tau_list[i]\n",
    "        K += _H_leaves * scale\n",
    "\n",
    "    return K / len(rulelist)  # normalize"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffedfdda",
   "metadata": {},
   "source": [
    "## Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cef60394",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_kernels(X, alpha, beta, degree):\n",
    "    \"\"\"Build one pair of kernels for every feature subset of size <= ``degree``.\n",
    "\n",
    "    Subsets are enumerated by increasing size (all singles, then all pairs,\n",
    "    then all triples). Each subset becomes a rule list holding that single\n",
    "    rule and is passed to ``hard_kernel`` twice.\n",
    "\n",
    "    Returns:\n",
    "        ``(kernels_aaa, kernels_aai, patterns)``: the kernels computed with\n",
    "        ``finetune=False``, the kernels computed with ``finetune=True``, and\n",
    "        the rule lists in the same order.\n",
    "    \"\"\"\n",
    "    assert degree in (1, 2, 3)\n",
    "    feature_ids = np.arange(X.shape[1])\n",
    "\n",
    "    patterns = []\n",
    "    for subset_size in range(1, degree + 1):\n",
    "        for combo in itertools.combinations(feature_ids, subset_size):\n",
    "            # hard_kernel expects a list of rules; here it is a single rule.\n",
    "            patterns.append([list(combo)])\n",
    "\n",
    "    kernels_aaa, kernels_aai = [], []\n",
    "    for pattern in tqdm(patterns, leave=False):\n",
    "        for finetune, bucket in ((False, kernels_aaa), (True, kernels_aai)):\n",
    "            bucket.append(hard_kernel(X, alpha=alpha, beta=beta, finetune=finetune, rulelist=pattern))\n",
    "\n",
    "    return kernels_aaa, kernels_aai, patterns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bababb91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_weights(ker_matrix_aaa, ker_matrix_aai, patterns, name):\n",
    "    \"\"\"Draw the AAA and AAI kernel weights as overlaid bars on the current axes.\n",
    "\n",
    "    Args:\n",
    "        ker_matrix_aaa: Object whose ``weights`` attribute holds one weight\n",
    "            per kernel (drawn with the label \"AAA\").\n",
    "        ker_matrix_aai: Same, drawn with the label \"AAI\".\n",
    "        patterns: Rule lists in kernel order; the first rule of each entry\n",
    "            is used as the tick label.\n",
    "        name: Title of the axes.\n",
    "    \"\"\"\n",
    "    x = range(len(ker_matrix_aaa.weights))\n",
    "    plt.bar(x, ker_matrix_aaa.weights, alpha=0.5, label=\"AAA\")\n",
    "    plt.bar(x, ker_matrix_aai.weights, alpha=0.5, label=\"AAI\")\n",
    "\n",
    "    # Build labels such as \"{0, 1}\". Indices are converted to plain ints so\n",
    "    # that numpy scalar reprs (e.g. \"np.int64(0)\") never end up in the label.\n",
    "    tick_labels = [\n",
    "        \"{\" + \", \".join(str(int(index)) for index in sorted(set(pattern[0]))) + \"}\"\n",
    "        for pattern in patterns\n",
    "    ]\n",
    "    plt.xticks(x, tick_labels, rotation=75, fontsize=10)\n",
    "\n",
    "    plt.grid(linestyle=\"dotted\")\n",
    "    plt.ylabel(\"Weight\")\n",
    "    plt.title(name)\n",
    "    plt.legend(loc='upper left')\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e99b9f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def kl_divergence(p, q):\n",
    "    \"\"\"Return the Kullback-Leibler divergence ``sum(p * log(p / q))``.\n",
    "\n",
    "    Entries with ``p == 0`` contribute 0 (the usual ``0 * log 0 = 0``\n",
    "    convention) instead of turning the whole sum into NaN.\n",
    "\n",
    "    Args:\n",
    "        p: Array-like of weights.\n",
    "        q: Array-like reference weights, broadcastable against ``p``.\n",
    "    \"\"\"\n",
    "    p, q = np.broadcast_arrays(np.asarray(p, dtype=float), np.asarray(q, dtype=float))\n",
    "    support = p > 0\n",
    "    return np.sum(p[support] * np.log(p[support] / q[support]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0cdba73",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 2.0\n",
    "beta = 0.5\n",
    "degree = 2\n",
    "\n",
    "kl_div_aaa = []\n",
    "kl_div_aai = []\n",
    "\n",
    "# One subplot row per dataset (at least one row so an empty list still works).\n",
    "n_rows = max(len(datasets), 1)\n",
    "# savefig does not create missing directories.\n",
    "os.makedirs(\"./figures\", exist_ok=True)\n",
    "\n",
    "plt.figure(figsize=(20, 25))\n",
    "for i, data in enumerate(datasets):\n",
    "    X = data[\"X\"]\n",
    "    y = data[\"y\"]\n",
    "    name = data[\"name\"]\n",
    "\n",
    "    kernels_aaa, kernels_aai, patterns = extract_kernels(X, alpha=alpha, beta=beta, degree=degree)\n",
    "\n",
    "    mkl = EasyMKL()\n",
    "    ker_matrix_aaa = mkl.combine_kernels(kernels_aaa, y)\n",
    "    ker_matrix_aai = mkl.combine_kernels(kernels_aai, y)\n",
    "\n",
    "    # KL divergence between the learned weights and the uniform weighting.\n",
    "    weights_aaa = np.array(ker_matrix_aaa.weights)\n",
    "    weights_aai = np.array(ker_matrix_aai.weights)\n",
    "    kl_div_aaa.append(kl_divergence(weights_aaa, np.ones_like(weights_aaa) / len(weights_aaa)))\n",
    "    kl_div_aai.append(kl_divergence(weights_aai, np.ones_like(weights_aai) / len(weights_aai)))\n",
    "\n",
    "    plt.subplot(n_rows, 1, i + 1)\n",
    "    plot_weights(ker_matrix_aaa, ker_matrix_aai, patterns, name)\n",
    "plt.savefig(\"./figures/mkl_weight.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3158e5e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the per-dataset KL divergences of both kernel families.\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "ax.scatter(kl_div_aaa, kl_div_aai)\n",
    "# Dotted diagonal as the reference where both divergences are equal.\n",
    "ax.plot([0, 0.8], [0, 0.8], linestyle=\"dotted\", color=\"black\")\n",
    "ax.grid(linestyle=\"dotted\")\n",
    "ax.set_title(\"Kullback–Leibler divergence between obtained weights by MKL and uniform weights\")\n",
    "ax.set_xlabel(\"AAA\")\n",
    "ax.set_ylabel(\"AAI\")\n",
    "fig.savefig(\"./figures/mkl_kldiv.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bafadb2d",
   "metadata": {},
   "source": [
    "## Generalization Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c016ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_kernel(X: np.array, depth: int, alpha: float, beta: float):\n",
    "    \"\"\"Compute the depth-``depth`` kernel built from the all-feature Gram matrix.\n",
    "\n",
    "    With ``S = X @ X.T + beta**2`` and ``tau`` / ``tau_dot`` obtained from\n",
    "    ``calc_tau`` / ``calc_tau_dot``, the result is\n",
    "    ``2 * S * 2**(depth-1) * depth * tau_dot * tau**(depth-1)\n",
    "    + 2**depth * tau**depth``.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D array with one row per sample.\n",
    "        depth: Positive integer depth.\n",
    "        alpha: Scale forwarded to ``calc_tau`` / ``calc_tau_dot``.\n",
    "        beta: Offset; ``beta**2`` is added to the Gram matrix.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape ``(n_samples, n_samples)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``depth`` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if depth < 1:\n",
    "        raise ValueError(\"depth must be a positive integer\")\n",
    "\n",
    "    S = np.matmul(X, X.T) + beta**2\n",
    "    # Every row of diag_i is the diagonal of S; diag_j is its transpose.\n",
    "    diag_i = np.tile(np.diag(S), (len(S), 1))\n",
    "    diag_j = diag_i.transpose()\n",
    "\n",
    "    tau = calc_tau(alpha, S, diag_i, diag_j)\n",
    "    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)\n",
    "\n",
    "    # Only the kernel for the requested depth is returned, so the shallower\n",
    "    # depths are not evaluated.\n",
    "    H = (2 * S * (2 ** (depth - 1)) * depth * tau_dot * tau ** (depth - 1)) + (\n",
    "        (2 ** depth) * (tau ** depth)\n",
    "    )\n",
    "    return H"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "323e5f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import SVC\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "\n",
    "def svm(kernels, y, weights, reg, train_index, test_index):\n",
    "    \"\"\"Fit a precomputed-kernel SVM on a weighted kernel sum and score it.\n",
    "\n",
    "    Args:\n",
    "        kernels: Sequence of square kernel matrices over all samples.\n",
    "        y: Label array indexed by ``train_index`` / ``test_index``.\n",
    "        weights: One weight per kernel; the combined kernel is\n",
    "            ``sum(kernels[j] * weights[j])``.\n",
    "        reg: Regularisation strength, passed to ``SVC`` as ``C``.\n",
    "        train_index: Sample indices used for fitting.\n",
    "        test_index: Sample indices used for scoring.\n",
    "\n",
    "    Returns:\n",
    "        ``(accuracy, auc)`` on the test samples. The AUC uses column 1 of\n",
    "        ``predict_proba``.\n",
    "    \"\"\"\n",
    "    # `reg` used to be ignored in favour of a hard-coded C=1.0.\n",
    "    model = SVC(kernel=\"precomputed\", C=reg, probability=True)\n",
    "\n",
    "    K = np.zeros_like(kernels[0])\n",
    "    for j in range(len(weights)):\n",
    "        K += kernels[j] * weights[j]\n",
    "\n",
    "    # Rows select the samples to evaluate; columns are always the training\n",
    "    # samples the model is fitted on.\n",
    "    K_train = K[train_index][:, train_index]\n",
    "    K_test = K[test_index][:, train_index]\n",
    "\n",
    "    y_train = y[train_index]\n",
    "    y_test = y[test_index]\n",
    "\n",
    "    model.fit(K_train, y_train)\n",
    "    test_pred = model.predict(K_test)\n",
    "    test_pred_proba = model.predict_proba(K_test)[:, 1]\n",
    "\n",
    "    accuracy = accuracy_score(y_test, test_pred)\n",
    "    auc = roc_auc_score(y_test, test_pred_proba)\n",
    "\n",
    "    return accuracy, auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39e3432a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n",
    "\n",
    "def _fit_and_score(model, X: np.array, y: np.array, train_index: list, test_index: list) -> Tuple[float, float]:\n",
    "    \"\"\"Fit ``model`` on the training rows and return (accuracy, AUC) on the test rows.\"\"\"\n",
    "    model.fit(X[train_index], y[train_index])\n",
    "    test_pred = model.predict(X[test_index])\n",
    "    # Column 1 of predict_proba is used as the score for the AUC.\n",
    "    test_pred_proba = model.predict_proba(X[test_index])[:, 1]\n",
    "\n",
    "    accuracy = accuracy_score(y[test_index], test_pred)\n",
    "    auc = roc_auc_score(y[test_index], test_pred_proba)\n",
    "\n",
    "    return accuracy, auc\n",
    "\n",
    "\n",
    "def rf_benchmark(X: np.array, y: np.array, train_index: list, test_index: list, max_depth: int, n_estimators: int) -> Tuple[float, float]:\n",
    "    \"\"\"Train a ``RandomForestClassifier`` and return its test (accuracy, AUC).\n",
    "\n",
    "    ``max_depth`` and ``n_estimators`` are forwarded to the classifier\n",
    "    unchanged (callers in this notebook also pass ``max_depth=None``).\n",
    "    \"\"\"\n",
    "    model = RandomForestClassifier(max_depth=max_depth, n_estimators=n_estimators)\n",
    "    return _fit_and_score(model, X, y, train_index, test_index)\n",
    "\n",
    "\n",
    "def gbdt_benchmark(X: np.array, y: np.array, train_index: list, test_index: list, max_depth: int, n_estimators: int) -> Tuple[float, float]:\n",
    "    \"\"\"Train a ``GradientBoostingClassifier`` and return its test (accuracy, AUC).\n",
    "\n",
    "    ``max_depth`` and ``n_estimators`` are forwarded to the classifier\n",
    "    unchanged.\n",
    "    \"\"\"\n",
    "    model = GradientBoostingClassifier(max_depth=max_depth, n_estimators=n_estimators)\n",
    "    return _fit_and_score(model, X, y, train_index, test_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bd64f67",
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark(kernels_aaa, kernels_aai, kernel_soft, y, train_index, test_index, alpha, beta, repeat):\n",
    "    \"\"\"Score every kernel model and tree baseline on one train/test split.\n",
    "\n",
    "    NOTE: besides its arguments this function reads three notebook globals:\n",
    "    ``X`` (raw features for the tree models) and ``ker_matrix_aaa`` /\n",
    "    ``ker_matrix_aai`` (objects whose ``weights`` attribute holds the learned\n",
    "    kernel weights). They must already be set for the same dataset and split.\n",
    "\n",
    "    Args:\n",
    "        kernels_aaa: Kernel matrices combined for the \"aaa_*\" entries.\n",
    "        kernels_aai: Kernel matrices combined for the \"aai_*\" entries.\n",
    "        kernel_soft: Single kernel matrix used for the \"soft\" entry.\n",
    "        y: Label array.\n",
    "        train_index: Sample indices used for fitting.\n",
    "        test_index: Sample indices used for scoring.\n",
    "        alpha, beta, repeat: Stored verbatim in both result dicts.\n",
    "\n",
    "    Returns:\n",
    "        ``(acc_dict, auc_dict)``: two dicts with identical keys holding the\n",
    "        accuracy and the AUC of every model.\n",
    "    \"\"\"\n",
    "    acc_dict = {\"alpha\": alpha, \"beta\": beta, \"repeat\": repeat}\n",
    "    auc_dict = {\"alpha\": alpha, \"beta\": beta, \"repeat\": repeat}\n",
    "\n",
    "    def record(key, scores):\n",
    "        # Each scorer returns an (accuracy, auc) pair.\n",
    "        acc_dict[key], auc_dict[key] = scores\n",
    "\n",
    "    # Uniform weighting baselines; each family is normalised by its own length.\n",
    "    uniform_aaa = np.ones_like(ker_matrix_aaa.weights) / len(ker_matrix_aaa.weights)\n",
    "    uniform_aai = np.ones_like(ker_matrix_aai.weights) / len(ker_matrix_aai.weights)\n",
    "\n",
    "    # AAA\n",
    "    record(\"aaa_mkl\", svm(kernels_aaa, y, np.array(ker_matrix_aaa.weights), 1.0, train_index, test_index))\n",
    "    record(\"aaa_benchmark\", svm(kernels_aaa, y, uniform_aaa, 1.0, train_index, test_index))\n",
    "\n",
    "    # AAI\n",
    "    record(\"aai_mkl\", svm(kernels_aai, y, np.array(ker_matrix_aai.weights), 1.0, train_index, test_index))\n",
    "    record(\"aai_benchmark\", svm(kernels_aai, y, uniform_aai, 1.0, train_index, test_index))\n",
    "\n",
    "    # Soft: identical copies with uniform weights sum back to kernel_soft.\n",
    "    record(\"soft\", svm([kernel_soft] * len(kernels_aaa), y, uniform_aaa, 1.0, train_index, test_index))\n",
    "\n",
    "    depth_settings = ((\"3\", 3), (\"5\", 5), (\"7\", 7), (\"max\", None))\n",
    "\n",
    "    # RF\n",
    "    for suffix, max_depth in depth_settings:\n",
    "        record(f\"rf{suffix}\", rf_benchmark(X, y, train_index, test_index, max_depth=max_depth, n_estimators=1000))\n",
    "\n",
    "    # GBDT (these entries previously called rf_benchmark by mistake)\n",
    "    for suffix, max_depth in depth_settings:\n",
    "        record(f\"gbdt{suffix}\", gbdt_benchmark(X, y, train_index, test_index, max_depth=max_depth, n_estimators=1000))\n",
    "\n",
    "    return acc_dict, auc_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ac3e061",
   "metadata": {},
   "outputs": [],
   "source": [
    "degree = 2\n",
    "\n",
    "ACC_CACHE = 'acc_dicts_multiple.pkl'\n",
    "AUC_CACHE = 'auc_dicts_multiple.pkl'\n",
    "\n",
    "acc_dicts_multiple, auc_dicts_multiple = [], []\n",
    "\n",
    "# The sweep is expensive, so it only runs when the cached results are missing.\n",
    "# Delete the two .pkl files to force a recomputation.\n",
    "if not (os.path.isfile(ACC_CACHE) and os.path.isfile(AUC_CACHE)):\n",
    "    for data in tqdm(datasets, leave=False):\n",
    "        # benchmark() reads X and the ker_matrix_* objects as globals, so these\n",
    "        # names must be rebound here before every call.\n",
    "        X = data[\"X\"]\n",
    "        y = data[\"y\"]\n",
    "        fold = data[\"fold\"]\n",
    "        acc_dicts, auc_dicts = [], []\n",
    "\n",
    "        for alpha in tqdm([0.5, 1.0, 2.0, 4.0], leave=False):\n",
    "            for beta in tqdm([0.1, 0.5, 1.0], leave=False):\n",
    "                kernel_soft = soft_kernel(X, depth=degree, alpha=alpha, beta=beta)\n",
    "                kernels_aaa, kernels_aai, patterns = extract_kernels(X, alpha=alpha, beta=beta, degree=degree)\n",
    "\n",
    "                for repeat in tqdm(range(4), leave=False):\n",
    "                    # Even fold lines are used as test indices and odd lines as\n",
    "                    # train indices; the assert checks that the test part is\n",
    "                    # the larger one.\n",
    "                    test_index, train_index = fold[repeat * 2], fold[repeat * 2 + 1]\n",
    "                    assert len(test_index) > len(train_index)\n",
    "                    mkl = EasyMKL()\n",
    "\n",
    "                    # Kernel weights are learned on the training samples only.\n",
    "                    train_kernels_aaa = [kernel[train_index][:, train_index] for kernel in kernels_aaa]\n",
    "                    train_kernels_aai = [kernel[train_index][:, train_index] for kernel in kernels_aai]\n",
    "                    ker_matrix_aaa = mkl.combine_kernels(train_kernels_aaa, y[train_index])\n",
    "                    ker_matrix_aai = mkl.combine_kernels(train_kernels_aai, y[train_index])\n",
    "\n",
    "                    acc_dict, auc_dict = benchmark(kernels_aaa, kernels_aai, kernel_soft, y, train_index, test_index, alpha, beta, repeat)\n",
    "\n",
    "                    acc_dicts.append(acc_dict)\n",
    "                    auc_dicts.append(auc_dict)\n",
    "        acc_dicts_multiple.append(acc_dicts)\n",
    "        auc_dicts_multiple.append(auc_dicts)\n",
    "\n",
    "    with open(ACC_CACHE, 'wb') as file:\n",
    "        pickle.dump(acc_dicts_multiple, file)\n",
    "    with open(AUC_CACHE, 'wb') as file:\n",
    "        pickle.dump(auc_dicts_multiple, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "544b387a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached benchmark results. Only unpickle files produced by this\n",
    "# notebook: pickle.load can execute arbitrary code from untrusted files.\n",
    "with open('acc_dicts_multiple.pkl', 'rb') as acc_file, open('auc_dicts_multiple.pkl', 'rb') as auc_file:\n",
    "    acc_dicts_multiple = pickle.load(acc_file)\n",
    "    auc_dicts_multiple = pickle.load(auc_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "013fd17a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "186a9ab9",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def plot_result(beta):\n",
    "    \"\"\"Plot mean test accuracy against alpha for every dataset at a fixed ``beta``.\n",
    "\n",
    "    Reads the notebook globals ``acc_dicts_multiple`` (one list of per-run\n",
    "    accuracy dicts per dataset) and ``datasets`` (for the subplot titles),\n",
    "    draws one panel per dataset on a 2 x 7 grid and saves the figure to\n",
    "    ``./figures/multidata_metrics_<beta>.pdf``.\n",
    "\n",
    "    Args:\n",
    "        beta: Value of the \"beta\" column to select for the kernel curves.\n",
    "    \"\"\"\n",
    "    alphas = [0.5, 1.0, 2.0, 4.0]\n",
    "    x = range(len(alphas))\n",
    "    kernel_columns = [\"aaa_mkl\", \"aaa_benchmark\", \"aai_mkl\", \"aai_benchmark\", \"soft\"]\n",
    "\n",
    "    plt.figure(figsize=(19,6))\n",
    "    for i in range(len(acc_dicts_multiple)):\n",
    "        plt.subplot(2,7,i+1)\n",
    "        df = pd.DataFrame(acc_dicts_multiple[i])\n",
    "\n",
    "        # Mean over the repeats for every alpha at the requested beta.\n",
    "        _df = df[df[\"beta\"]==beta].groupby(by=[\"alpha\", \"beta\"]).mean()[kernel_columns].reset_index()\n",
    "\n",
    "        _df[\"aaa_mkl\"].plot(label=\"AAA (MKL)\", color=\"red\", linestyle=\"solid\", marker=\"o\")\n",
    "        _df[\"aaa_benchmark\"].plot(label=\"AAA (Benchmark)\", color=\"red\", linestyle=\"dotted\", marker=\"v\" )\n",
    "        _df[\"aai_mkl\"].plot(label=\"AAI (MKL)\", color=\"blue\", linestyle=\"solid\", marker=\"o\")\n",
    "        _df[\"aai_benchmark\"].plot(label=\"AAI (Benchmark)\", color=\"blue\", linestyle=\"dotted\", marker=\"v\")\n",
    "        _df[\"soft\"].plot(label=\"Oblique\", color=\"black\", marker=\"s\")\n",
    "\n",
    "        # Random-forest band: mean and standard deviation of the per-(alpha,\n",
    "        # beta) means over the whole sweep, not only the selected beta.\n",
    "        rf3_by_setting = df.groupby(by=[\"alpha\", \"beta\"]).mean()[\"rf3\"]\n",
    "        rf3_mean = rf3_by_setting.mean()\n",
    "        rf3_std = rf3_by_setting.std()\n",
    "\n",
    "        plt.plot(x, [rf3_mean]*len(x), color=\"green\", linestyle=\"dashdot\", alpha=0.7, label=\"Random Forest\", linewidth=1.0)\n",
    "        plt.fill_between(x, rf3_mean-rf3_std, rf3_mean+rf3_std, color='green', alpha=0.1)\n",
    "\n",
    "        plt.title(datasets[i][\"name\"])\n",
    "        # Only the bottom row carries alpha tick labels and the x-axis label.\n",
    "        if i>=7:\n",
    "            plt.xlabel(\"$\\\\alpha$\")\n",
    "            plt.xticks(list(x), alphas)\n",
    "        else:\n",
    "            plt.xticks(list(x), [])\n",
    "        # Only the first column carries the y-axis label.\n",
    "        if i%7 ==0:\n",
    "            plt.ylabel(\"Accuracy\")\n",
    "        plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "    plt.figlegend(labels=[\"AAA (MKL)\", \"AAA (Benchmark)\", \"AAI (MKL)\", \"AAI (Benchmark)\", \"Oblique\", \"Random Forest\"],\n",
    "                  loc=\"lower center\", \n",
    "                  ncol=6,\n",
    "                  bbox_to_anchor=(0.5, -0.05))\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.suptitle(f\"$\\\\beta$={beta}\", y=1.05)\n",
    "    # savefig does not create missing directories.\n",
    "    os.makedirs(\"./figures\", exist_ok=True)\n",
    "    plt.savefig(f\"./figures/multidata_metrics_{beta}.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12a48e09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy curves for beta = 0.1 (saved as ./figures/multidata_metrics_0.1.pdf).\n",
    "plot_result(beta=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5d10c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy curves for beta = 0.5 (saved as ./figures/multidata_metrics_0.5.pdf).\n",
    "plot_result(beta=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b63c589",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy curves for beta = 1.0 (saved as ./figures/multidata_metrics_1.0.pdf).\n",
    "plot_result(beta=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caaa6204",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
