{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c100e354",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_latex():\n",
    "    for i in range(2):\n",
    "        import matplotlib\n",
    "        import matplotlib.pyplot as plt\n",
    "\n",
    "        plt.rc('text', usetex=True)\n",
    "        plt.rc('font', family='serif')\n",
    "\n",
    "        plt.style.use(\"default\")\n",
    "        plt.rcParams[\"font.size\"]=15\n",
    "\n",
    "        plt.rcParams['font.family'] = 'Times New Roman'\n",
    "        plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "\n",
    "        try:\n",
    "            del matplotlib.font_manager.weight_dict['roman']\n",
    "            matplotlib.font_manager._rebuild()\n",
    "        except:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09644e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import math\n",
    "import matplotlib.cm as cm\n",
    "import os\n",
    "from typing import Dict, Tuple, List\n",
    "import pickle\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da0b171b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use(\"default\")\n",
    "plt.rcParams[\"font.size\"]=15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a950719d",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_latex()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "889274f2",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27bb4869",
   "metadata": {},
   "source": [
    "For downloading dataset, see https://github.com/LeoYu/neural-tangent-kernel-UCI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dfe4f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = os.path.join(\"./data/\")\n",
    "\n",
    "def get_datasize(dic: Dict) -> Tuple[int, int, int, int]:\n",
    "    c = int(dic[\"n_clases=\"])\n",
    "    d = int(dic[\"n_entradas=\"])\n",
    "    n_train_val = int(dic[\"n_patrons1=\"])\n",
    "    if \"n_patrons2=\" in dic:\n",
    "        n_test = int(dic[\"n_patrons2=\"])\n",
    "    else:\n",
    "        n_test = 0\n",
    "    n_tot = n_train_val + n_test\n",
    "    return n_tot, n_train_val, n_test, d,  c\n",
    "\n",
    "\n",
    "def load_data(dic: Dict) -> Tuple[np.array, np.array]:\n",
    "    f = open(os.path.join(DATA_DIR, dic[\"dataset\"], dic[\"fich1=\"]), \"r\").readlines()[1:]\n",
    "    X = np.asarray(list(map(lambda x: list(map(float, x.split()[1:-1])), f)))\n",
    "    y = np.asarray(list(map(lambda x: int(x.split()[-1]), f)))\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c58c2916",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_TOT = 1000\n",
    "MAX_FEATURES = 10\n",
    "MAX_CLASSES = 2\n",
    "\n",
    "datasets = []\n",
    "\n",
    "n_dataset = 0\n",
    "for idx, dataset in enumerate(sorted(os.listdir(DATA_DIR))): \n",
    "    if not os.path.isfile(os.path.join(DATA_DIR, dataset, f\"{dataset}.txt\")):\n",
    "        continue\n",
    "\n",
    "    # load configuration\n",
    "    dic = dict()\n",
    "    dic[\"dataset\"] = dataset\n",
    "    if dic[\"dataset\"]!=\"tic-tac-toe\": # use only tic-tac-toe\n",
    "        continue\n",
    "\n",
    "    for k, v in map(\n",
    "        lambda x: x.split(),\n",
    "        open(os.path.join(DATA_DIR, dataset, f\"{dataset}.txt\"), \"r\").readlines(),\n",
    "    ):\n",
    "        dic[k] = v\n",
    "\n",
    "    # Check skip or not\n",
    "    n_tot, n_train_val, n_test, n_feature, n_class = get_datasize(dic)\n",
    "    if (n_tot > MAX_TOT) or (n_test > 0) or (n_feature >  MAX_FEATURES) or (n_class > MAX_CLASSES):\n",
    "        continue\n",
    "    else:\n",
    "        print(f\"-----{idx}, {dataset}, {n_tot}, {n_feature}, {n_class}-----\")\n",
    "        n_dataset += 1\n",
    "\n",
    "    # load dataset\n",
    "    X, y = load_data(dic)\n",
    "    fold = list(\n",
    "        map(\n",
    "            lambda x: list(map(int, x.split())),\n",
    "            open(\n",
    "                os.path.join(DATA_DIR, dic[\"dataset\"], \"conxuntos_kfold.dat\"), \"r\"\n",
    "            ).readlines(),\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c13b3fc9",
   "metadata": {},
   "source": [
    "## Kernels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "788c7c13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_tau(alpha: float, S: np.array, diag_i: np.array, diag_j: np.array) -> np.array:\n",
    "    tau = 1 / 4 + 1 / (2 * math.pi) * np.arcsin(\n",
    "        ((alpha ** 2) * S)\n",
    "        / (np.sqrt(((alpha ** 2) * diag_i + 0.5) * ((alpha ** 2) * diag_j + 0.5)))\n",
    "    )\n",
    "    return tau\n",
    "\n",
    "\n",
    "def calc_tau_dot(\n",
    "    alpha: float, S: np.array, diag_i: np.array, diag_j: np.array\n",
    ") -> np.array:\n",
    "    tau_dot = (\n",
    "        (alpha ** 2)\n",
    "        / (math.pi)\n",
    "        * 1\n",
    "        / np.sqrt(\n",
    "            (2 * (alpha ** 2) * diag_i + 1) * (2 * (alpha ** 2) * diag_j + 1)\n",
    "            - (4 * (alpha ** 4) * (S ** 2))\n",
    "        )\n",
    "    )\n",
    "    return tau_dot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1672f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hard_kernel(X: np.array, alpha: float, beta: float, finetune: bool, rulelist: list):    \n",
    "    S_list = []\n",
    "    tau_list = []\n",
    "    tau_dot_list = []\n",
    "\n",
    "    for feature_index in range(len(X[0])):\n",
    "        S = np.outer(X[:, feature_index], X[:, feature_index].T) + beta**2\n",
    "        S_all = np.matmul(X, X.T) + beta**2\n",
    "        if finetune:\n",
    "            S_list.append(S_all)\n",
    "        else:\n",
    "            S_list.append(S)\n",
    "\n",
    "        _diag = [S[i, i] for i in range(len(S))]\n",
    "        diag_i = np.array(_diag * len(_diag)).reshape(len(_diag), len(_diag))\n",
    "        diag_j = diag_i.transpose()\n",
    "        tau_list.append(calc_tau(alpha, S, diag_i, diag_j))\n",
    "        tau_dot_list.append(calc_tau_dot(alpha, S, diag_i, diag_j))\n",
    "        \n",
    "    K = np.zeros((X.shape[0], X.shape[0]))\n",
    "    \n",
    "    H = np.zeros_like(S_list[0])\n",
    "    for rules in tqdm(rulelist, leave=False):\n",
    "        # Internal nodes\n",
    "        for i, s in enumerate(rules):\n",
    "            ts = rules[0:i]+rules[i+1:]\n",
    "            _H_nodes = S_list[s]* tau_dot_list[s]\n",
    "            for t in ts:\n",
    "                _H_nodes *= tau_list[t]\n",
    "            K+= _H_nodes * (2**len(rules))\n",
    "        _H_leaves = np.ones_like(K)\n",
    "        \n",
    "        # Leaves\n",
    "        for tau in [tau_list[i] for i in rules]:\n",
    "            _H_leaves *= tau\n",
    "        K += _H_leaves * (2**len(rules))\n",
    "    \n",
    "    return K/len(rulelist) # normalize "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3951fdd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_kernel(X: np.array, depth: int, alpha: float, beta: float):\n",
    "    K = np.zeros((depth, X.shape[0], X.shape[0]))\n",
    "    S = np.matmul(X, X.T) + beta**2\n",
    "    _diag = [S[i, i] for i in range(len(S))]\n",
    "    diag_i = np.array(_diag * len(_diag)).reshape(len(_diag), len(_diag))\n",
    "    diag_j = diag_i.transpose()\n",
    "\n",
    "    tau = calc_tau(alpha, S, diag_i, diag_j)\n",
    "    tau_dot = calc_tau_dot(alpha, S, diag_i, diag_j)\n",
    "\n",
    "    for i, depth in enumerate((range(1, depth + 1, 1))):\n",
    "        H = (2 * S * (2 ** (depth - 1)) * depth * tau_dot * tau ** (depth - 1)) + (\n",
    "            (2 ** depth) * (tau ** depth)\n",
    "        )\n",
    "        K[depth - 1] = H\n",
    "\n",
    "    return K[::-1][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cef60394",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_kernels(X, alpha, beta, degree):\n",
    "    assert degree in (1, 2, 3)\n",
    "    patterns = list(itertools.combinations(np.arange(X.shape[1]), 1))\n",
    "\n",
    "    if degree>=2:\n",
    "        patterns.extend(list(itertools.combinations(np.arange(X.shape[1]), 2)))\n",
    "        \n",
    "    if degree>=3:\n",
    "        patterns.extend(list(itertools.combinations(np.arange(X.shape[1]), 3)))\n",
    "\n",
    "    patterns = [list(l) for l in patterns]\n",
    "    patterns = [[pattern] for pattern in patterns]\n",
    "    \n",
    "    kernels_aaa = []\n",
    "    kernels_aai = []\n",
    "\n",
    "    for pattern in tqdm(patterns, leave=False):\n",
    "        kernels_aaa.append(hard_kernel(X, alpha=alpha, beta=beta, finetune=False, rulelist=pattern))\n",
    "        kernels_aai.append(hard_kernel(X, alpha=alpha, beta=beta, finetune=True, rulelist=pattern))  \n",
    "        \n",
    "    return kernels_aaa, kernels_aai, patterns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e233068",
   "metadata": {},
   "source": [
    "## MKL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6b3ca5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from MKLpy.algorithms import EasyMKL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1d4fa74",
   "metadata": {},
   "outputs": [],
   "source": [
    "kernels_aaa, kernels_aai, patterns = extract_kernels(X, alpha=2.0, beta=0.5, degree=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7d2e646",
   "metadata": {},
   "outputs": [],
   "source": [
    "mkl = EasyMKL()\n",
    "ker_matrix_aaa_full = mkl.combine_kernels(kernels_aaa, y)\n",
    "ker_matrix_aai_full = mkl.combine_kernels(kernels_aai, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bababb91",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(25,3))\n",
    "x = range(len(ker_matrix_aaa_full.weights))\n",
    "plt.bar(x, ker_matrix_aaa_full.weights, alpha=0.5, label=\"AAA\")\n",
    "plt.bar(x, ker_matrix_aai_full.weights, alpha=0.5, label=\"AAI\")\n",
    "plt.xticks(\n",
    "    x,\n",
    "    [str(sorted(set(i[0]))).replace(\"[\", \"{\").replace(\"]\", \"}\") for i in patterns],\n",
    "    rotation=75,\n",
    "    fontsize=10\n",
    ")\n",
    "plt.axvline(45, color=\"red\", linestyle=\"dashed\", linewidth=1)\n",
    "plt.axvline(60, color=\"red\", linestyle=\"dashed\", linewidth=1)\n",
    "plt.axvline(66, color=\"red\", linestyle=\"dashed\", linewidth=1)\n",
    "plt.axvline(86, color=\"red\", linestyle=\"dashed\", linewidth=1)\n",
    "plt.axvline(100, color=\"red\", linestyle=\"dashed\", linewidth=1)\n",
    "plt.axvline(105, color=\"red\", linestyle=\"dashed\", linewidth=1)\n",
    "plt.axvline(109, color=\"red\", linestyle=\"dashed\", linewidth=1)\n",
    "plt.axvline(128, color=\"red\", linestyle=\"dashed\", linewidth=1)\n",
    "\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "plt.ylabel(\"Weight\")\n",
    "x = range(len(ker_matrix_aai_full.weights))\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "plt.xlabel(\"Feature Combination\")\n",
    "plt.ylabel(\"Weight\")\n",
    "plt.legend()\n",
    "plt.savefig(\"./figures/tictactoe_weight.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9271ba0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_optimal_weight(size):\n",
    "    optimal = np.zeros_like(range(size))/1.\n",
    "    optimal[45] = 1\n",
    "    optimal[60] = 1\n",
    "    optimal[66] = 1\n",
    "    optimal[86] = 1\n",
    "    optimal[100] = 1\n",
    "    optimal[105] = 1\n",
    "    optimal[109] = 1\n",
    "    optimal[128] = 1\n",
    "    optimal/=sum(optimal)\n",
    "    return optimal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "148db565",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import SVC\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "\n",
    "def svm(kernels, y, weights, reg, train_index, test_index):\n",
    "    model= SVC(kernel=\"precomputed\", C=1.0, probability=True)\n",
    "\n",
    "    K = np.zeros_like(kernels[0])\n",
    "    for j in range(len(weights)):\n",
    "        K+=kernels[j]*weights[j]\n",
    "    \n",
    "    K_train= K[train_index][:, train_index]\n",
    "    K_test = K[test_index][:, train_index]\n",
    "\n",
    "    y_train = y[train_index]\n",
    "    y_test = y[test_index]\n",
    "\n",
    "    model.fit(K_train, y_train)\n",
    "    test_pred = model.predict(K_test)\n",
    "    test_pred_proba = model.predict_proba(K_test)[:, 1]\n",
    "    \n",
    "    accuracy = accuracy_score(y_test, test_pred)\n",
    "    auc = roc_auc_score(y_test, test_pred_proba)\n",
    "    \n",
    "    return accuracy, auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce6459c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier\n",
    "\n",
    "def rf_benchmark(X: np.array, y: np.array, train_index: list, test_index: list, max_depth: int, n_estimators: int) -> Tuple[float, List[float]]:\n",
    "    model = RandomForestClassifier(max_depth=max_depth, n_estimators=n_estimators)\n",
    "    model.fit(X[train_index], y[train_index])\n",
    "    test_pred = model.predict(X[test_index])\n",
    "    test_pred_proba = model.predict_proba(X[test_index])[:, 1]\n",
    "    \n",
    "    accuracy = accuracy_score(y[test_index], test_pred)\n",
    "    auc = roc_auc_score(y[test_index], test_pred_proba)\n",
    "    \n",
    "    return accuracy, auc\n",
    "\n",
    "def gbdt_benchmark(X: np.array, y: np.array, train_index: list, test_index: list, max_depth: int, n_estimators: int) -> Tuple[float, List[float]]:\n",
    "    model = GradientBoostingClassifier(max_depth=max_depth, n_estimators=n_estimators)\n",
    "    model.fit(X[train_index], y[train_index])\n",
    "    test_pred = model.predict(X[test_index])\n",
    "    test_pred_proba = model.predict_proba(X[test_index])[:, 1]\n",
    "    \n",
    "    accuracy = accuracy_score(y[test_index], test_pred)\n",
    "    auc = roc_auc_score(y[test_index], test_pred_proba)\n",
    "    \n",
    "    return accuracy, auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2278ed3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark(kernels_aaa, kernels_aai,kernel_soft,  y, train_index, test_index, alpha, beta, repeat, optimal):\n",
    "    acc_dict, auc_dict = {}, {}\n",
    "    \n",
    "    acc_dict[\"alpha\"] = alpha\n",
    "    acc_dict[\"beta\"] = beta\n",
    "    acc_dict[\"repeat\"] = repeat\n",
    "    auc_dict[\"alpha\"] = alpha\n",
    "    auc_dict[\"beta\"] = beta\n",
    "    auc_dict[\"repeat\"] = repeat\n",
    "    \n",
    "    # AAA\n",
    "    acc_dict[\"aaa_mkl\"], auc_dict[\"aaa_mkl\"] = svm(kernels_aaa, y, np.array(ker_matrix_aaa.weights), 1.0, train_index, test_index)\n",
    "    acc_dict[\"aaa_optimal\"], auc_dict[\"aaa_optimal\"] = svm(kernels_aaa, y, optimal, 1.0, train_index, test_index)\n",
    "    acc_dict[\"aaa_benchmark\"], auc_dict[\"aaa_benchmark\"] = svm(kernels_aaa, y, np.ones_like(ker_matrix_aaa.weights)/len(ker_matrix_aaa.weights), 1.0, train_index, test_index)\n",
    "\n",
    "    # AAI\n",
    "    acc_dict[\"aai_mkl\"], auc_dict[\"aai_mkl\"] = svm(kernels_aai, y, np.array(ker_matrix_aai.weights), 1.0, train_index, test_index)\n",
    "    acc_dict[\"aai_optimal\"], auc_dict[\"aai_optimal\"] = svm(kernels_aai, y, optimal, 1.0, train_index, test_index)\n",
    "    acc_dict[\"aai_benchmark\"], auc_dict[\"aai_benchmark\"] = svm(kernels_aai, y, np.ones_like(ker_matrix_aai.weights)/len(ker_matrix_aaa.weights), 1.0, train_index, test_index)\n",
    "\n",
    "    # Soft\n",
    "    acc_dict[\"soft\"], auc_dict[\"soft\"] = svm([kernel_soft] * len(kernels_aaa), y, np.ones_like(ker_matrix_aaa.weights)/len(ker_matrix_aaa.weights), 1.0, train_index, test_index)\n",
    "\n",
    "    # RF\n",
    "    acc_dict[\"rf3\"], auc_dict[\"rf3\"] = rf_benchmark(X, y, train_index, test_index, max_depth=3, n_estimators=1000)\n",
    "    acc_dict[\"rf5\"], auc_dict[\"rf5\"] = rf_benchmark(X, y, train_index, test_index, max_depth=5, n_estimators=1000)\n",
    "    acc_dict[\"rf7\"], auc_dict[\"rf7\"] = rf_benchmark(X, y, train_index, test_index, max_depth=7, n_estimators=1000)\n",
    "    acc_dict[\"rfmax\"], auc_dict[\"rfmax\"] = rf_benchmark(X, y, train_index, test_index, max_depth=None, n_estimators=1000)\n",
    "\n",
    "    # GBDT\n",
    "    acc_dict[\"gbdt3\"], auc_dict[\"gbdt3\"] = rf_benchmark(X, y, train_index, test_index, max_depth=3, n_estimators=1000)\n",
    "    acc_dict[\"gbdt5\"], auc_dict[\"gbdt5\"] = rf_benchmark(X, y, train_index, test_index, max_depth=5, n_estimators=1000)\n",
    "    acc_dict[\"gbdt7\"], auc_dict[\"gbdt7\"] = rf_benchmark(X, y, train_index, test_index, max_depth=7, n_estimators=1000)\n",
    "    acc_dict[\"gbdtmax\"], auc_dict[\"gbdtmax\"] = rf_benchmark(X, y, train_index, test_index, max_depth=None, n_estimators=1000)\n",
    "\n",
    "    \n",
    "    return acc_dict, auc_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb003fca",
   "metadata": {},
   "source": [
    "## GridSearch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41abc3a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "degree = 3\n",
    "\n",
    "acc_dicts, auc_dicts = [], []\n",
    "\n",
    "if False:\n",
    "    for alpha in tqdm([0.5, 1.0, 2.0, 4.0], leave=False):\n",
    "        for beta in tqdm([0.1, 0.5, 1.0], leave=False):\n",
    "            kernel_soft = soft_kernel(X, depth=degree, alpha=alpha, beta=beta)\n",
    "            kernels_aaa, kernels_aai, patterns = extract_kernels(X, alpha=alpha, beta=beta, degree=degree)\n",
    "\n",
    "            for repeat in tqdm(range(4), leave=False):\n",
    "                test_index, train_index = fold[repeat * 2], fold[repeat * 2 + 1]\n",
    "                assert len(test_index) > len(train_index)\n",
    "                mkl = EasyMKL()\n",
    "\n",
    "                train_kernels_aaa = [i[train_index][:, train_index] for i in kernels_aaa]\n",
    "                train_kernels_aai = [i[train_index][:, train_index] for i in kernels_aai]\n",
    "                ker_matrix_aaa = mkl.combine_kernels(train_kernels_aaa, y[train_index])\n",
    "                ker_matrix_aai = mkl.combine_kernels(train_kernels_aai, y[train_index])                    \n",
    "\n",
    "                optimal = get_optimal_weight(len(ker_matrix_aaa.weights))\n",
    "\n",
    "                acc_dict, auc_dict = benchmark(kernels_aaa, kernels_aai, kernel_soft, y, train_index, test_index, alpha, beta, repeat, optimal)\n",
    "\n",
    "                acc_dicts.append(acc_dict)\n",
    "                auc_dicts.append(auc_dict)\n",
    "\n",
    "    with open('acc_dicts.pkl', 'wb') as file:\n",
    "        pickle.dump(acc_dicts, file)\n",
    "    with open('auc_dicts.pkl', 'wb') as file:\n",
    "        pickle.dump(auc_dicts, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "965089d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('acc_dicts.pkl', 'rb') as file:\n",
    "    acc_dicts= pickle.load(file)\n",
    "with open('auc_dicts.pkl', 'rb') as file:\n",
    "    auc_dicts= pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a18411ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9565653d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(acc_dicts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37ba6fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3928ad8",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.groupby(by=[\"alpha\", \"beta\"]).mean()[\n",
    "    [\"aaa_mkl\", \"aaa_optimal\", \"aaa_benchmark\", \"aai_mkl\", \"aai_optimal\", \"aai_benchmark\", \"soft\"]\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9dbe9dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = 0.5\n",
    "\n",
    "_df = df[df[\"beta\"]==beta].groupby(by=[\"alpha\", \"beta\"]).mean()[\n",
    "    [\"aaa_mkl\", \"aaa_optimal\", \"aaa_benchmark\", \"aai_mkl\", \"aai_optimal\", \"aai_benchmark\", \"soft\", \"rf3\", \"rf5\", \"rf7\", \"gbdt3\", \"gbdt5\", \"gbdt7\"]\n",
    "].reset_index()\n",
    "\n",
    "x = range(4)\n",
    "\n",
    "plt.figure(figsize=(7,4))\n",
    "_df[\"aaa_mkl\"].plot(label=\"AAA (MKL)\", color=\"red\", linestyle=\"solid\", marker=\"o\")\n",
    "_df[\"aaa_optimal\"].plot(label=\"AAA (Optimal)\", color=\"red\", linestyle=\"dashed\", marker=\"^\")\n",
    "_df[\"aaa_benchmark\"].plot(label=\"AAA (Benchmark)\", color=\"red\", linestyle=\"dotted\", marker=\"v\" )\n",
    "_df[\"aai_mkl\"].plot(label=\"AAI (MKL)\", color=\"blue\", linestyle=\"solid\", marker=\"o\")\n",
    "_df[\"aai_optimal\"].plot(label=\"AAI (Optimal)\", color=\"blue\", linestyle=\"dashed\", marker=\"^\")\n",
    "_df[\"aai_benchmark\"].plot(label=\"AAI (Benchmark)\", color=\"blue\", linestyle=\"dotted\", marker=\"v\")\n",
    "_df[\"soft\"].plot(label=\"Oblique\", color=\"black\", marker=\"s\")\n",
    "\n",
    "rf3_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"rf3\"]\n",
    "rf3_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"rf3\"]\n",
    "rf5_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"rf5\"]\n",
    "rf5_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"rf5\"]\n",
    "rf7_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"rf7\"]\n",
    "rf7_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"rf7\"]\n",
    "rfmax_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"rfmax\"]\n",
    "rfmax_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"rfmax\"]\n",
    "\n",
    "gbdt3_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"gbdt3\"]\n",
    "gbdt3_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"gbdt3\"]\n",
    "gbdt5_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"gbdt5\"]\n",
    "gbdt5_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"gbdt5\"]\n",
    "gbdt7_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"gbdt7\"]\n",
    "gbdt7_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"gbdt7\"]\n",
    "gbdtmax_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"gbdtmax\"]\n",
    "gbdtmax_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"gbdtmax\"]\n",
    "\n",
    "plt.plot(x, [rf3_mean]*len(x), color=\"green\", linestyle=\"dashdot\", alpha=0.7, label=\"Random Forest\", linewidth=1.0)\n",
    "plt.fill_between(x, rf3_mean-rf3_std, rf3_mean+rf3_std, color='green', alpha=0.1)\n",
    "\n",
    "# plt.plot(x, [gbdt3_mean]*len(x), color=\"orange\", linestyle=\"dashdot\", alpha=0.7, label=\"Gradient Boosting\", linewidth=1.0)\n",
    "# plt.fill_between(x, gbdt3_mean-gbdt3_std, gbdt3_mean+gbdt3_std, color='green', alpha=0.1)\n",
    "\n",
    "plt.xticks([0, 1, 2, 3], [0.5, 1.0, 2.0, 4.0])\n",
    "plt.xlabel(\"$\\\\alpha$\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "plt.legend(loc=\"upper left\", bbox_to_anchor=(1,0.95))\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"./figures/tictactoe_metrics.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96d7eb90",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,6))\n",
    "for i, beta in enumerate([0.1, 0.5, 1.0]):\n",
    "    plt.subplot(1,3, i+1)\n",
    "    _df = df[df[\"beta\"]==beta].groupby(by=[\"alpha\", \"beta\"]).mean()[\n",
    "        [\"aaa_mkl\", \"aaa_optimal\", \"aaa_benchmark\", \"aai_mkl\", \"aai_optimal\", \"aai_benchmark\", \"soft\", \"rf3\", \"rf5\", \"rf7\", \"gbdt3\", \"gbdt5\", \"gbdt7\"]\n",
    "    ].reset_index()\n",
    "\n",
    "    x = range(4)\n",
    "\n",
    "    _df[\"aaa_mkl\"].plot(label=\"AAA (MKL)\", color=\"red\", linestyle=\"solid\", marker=\"o\")\n",
    "    _df[\"aaa_optimal\"].plot(label=\"AAA (Optimal)\", color=\"red\", linestyle=\"dashed\", marker=\"^\")\n",
    "    _df[\"aaa_benchmark\"].plot(label=\"AAA (Benchmark)\", color=\"red\", linestyle=\"dotted\", marker=\"v\" )\n",
    "    _df[\"aai_mkl\"].plot(label=\"AAI (MKL)\", color=\"blue\", linestyle=\"solid\", marker=\"o\")\n",
    "    _df[\"aai_optimal\"].plot(label=\"AAI (Optimal)\", color=\"blue\", linestyle=\"dashed\", marker=\"^\")\n",
    "    _df[\"aai_benchmark\"].plot(label=\"AAI (Benchmark)\", color=\"blue\", linestyle=\"dotted\", marker=\"v\")\n",
    "    _df[\"soft\"].plot(label=\"Oblique\", color=\"black\", marker=\"s\")\n",
    "\n",
    "    rf3_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"rf3\"]\n",
    "    rf3_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"rf3\"]\n",
    "    rf5_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"rf5\"]\n",
    "    rf5_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"rf5\"]\n",
    "    rf7_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"rf7\"]\n",
    "    rf7_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"rf7\"]\n",
    "    rfmax_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"rfmax\"]\n",
    "    rfmax_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"rfmax\"]\n",
    "\n",
    "    gbdt3_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"gbdt3\"]\n",
    "    gbdt3_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"gbdt3\"]\n",
    "    gbdt5_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"gbdt5\"]\n",
    "    gbdt5_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"gbdt5\"]\n",
    "    gbdt7_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"gbdt7\"]\n",
    "    gbdt7_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"gbdt7\"]\n",
    "    gbdtmax_mean = df.groupby(by=[\"alpha\", \"beta\"]).mean().mean()[\"gbdtmax\"]\n",
    "    gbdtmax_std  = df.groupby(by=[\"alpha\", \"beta\"]).mean().std()[\"gbdtmax\"]\n",
    "\n",
    "    plt.plot(x, [rf3_mean]*len(x), color=\"green\", linestyle=\"dashdot\", alpha=0.7, label=\"Random Forest\", linewidth=1.0)\n",
    "    plt.fill_between(x, rf3_mean-rf3_std, rf3_mean+rf3_std, color='green', alpha=0.1)\n",
    "\n",
    "    # plt.plot(x, [gbdt3_mean]*len(x), color=\"orange\", linestyle=\"dashdot\", alpha=0.7, label=\"Gradient Boosting\", linewidth=1.0)\n",
    "    # plt.fill_between(x, gbdt3_mean-gbdt3_std, gbdt3_mean+gbdt3_std, color='green', alpha=0.1)\n",
    "\n",
    "    plt.xticks([0, 1, 2, 3], [0.5, 1.0, 2.0, 4.0])\n",
    "    plt.xlabel(\"$\\\\alpha$\")\n",
    "    if beta==0.1:\n",
    "        plt.ylabel(\"Accuracy\")\n",
    "    plt.grid(linestyle=\"dotted\")\n",
    "    plt.title(f\"$\\\\beta$={beta}\")\n",
    "\n",
    "plt.figlegend(labels=[\"AAA (MKL)\", \"AAA (Benchmark)\", \"AAA (Optimal)\", \"AAI (MKL)\", \"AAI (Benchmark)\", \"AAI (Optimal)\", \"Oblique\", \"Random Forest\"],\n",
    "    loc=\"lower center\", \n",
    "    ncol=8,\n",
    "    bbox_to_anchor=(0.5, -0.1)\n",
    ")\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"./figures/tictactoe_metrics_beta.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1955ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_means = [rf3_mean, rf5_mean, rf7_mean, rfmax_mean]\n",
    "gbdt_means = [gbdt3_mean, gbdt5_mean, gbdt7_mean, gbdtmax_mean]\n",
    "rf_stds = [rf3_std, rf5_std, rf7_std, rfmax_std]\n",
    "gbdt_stds = [gbdt3_std, gbdt5_std, gbdt7_std, gbdtmax_std]\n",
    "\n",
    "x_pos_rf = np.arange(len(rf_means))\n",
    "x_pos_gbdt = x_pos_rf + 0.4\n",
    "\n",
    "plt.bar(x_pos_rf, rf_means, yerr=rf_stds, capsize=10, color='skyblue', width=0.4, label='Random Forest')\n",
    "\n",
    "plt.bar(x_pos_gbdt, gbdt_means, yerr=gbdt_stds, capsize=10, color='lightgreen', width=0.4, label='Gradient Boosting Decision Tree')\n",
    "\n",
    "plt.xlabel('max_depth')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.ylim([0.7, 1.0])\n",
    "\n",
    "plt.xticks(x_pos_rf + 0.2, ['3', '5', '7', 'inf'])\n",
    "\n",
    "plt.hlines(\n",
    "    _df[\"aaa_benchmark\"].max(), \n",
    "    xmin=-0.4,\n",
    "    xmax=3.8, \n",
    "    color=\"red\", \n",
    "    linestyle=\"dotted\", \n",
    "    alpha=0.7, \n",
    "    label=\"AAA (Benchmark)\",\n",
    "    linewidth=2.0\n",
    ")\n",
    "\n",
    "plt.hlines(\n",
    "    _df[\"aai_benchmark\"].max(), \n",
    "    xmin=-0.4,\n",
    "    xmax=3.8, \n",
    "    color=\"blue\", \n",
    "    linestyle=\"dotted\", \n",
    "    alpha=0.7,\n",
    "    label=\"AAI (Benchmark)\",\n",
    "    linewidth=2.0\n",
    ")\n",
    "\n",
    "plt.xlim([-0.4, 3.8])\n",
    "\n",
    "plt.legend()\n",
    "\n",
    "plt.savefig(\"./figures/rf_gbdt_performance.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc99dab3",
   "metadata": {},
   "source": [
    "## Hard Splitting Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb63a04b",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_dicts_hard, auc_dicts_hard = [], []\n",
    "for alpha in tqdm([1e0, 1e1, 1e2, 1e3], leave=False):\n",
    "    kernel_soft = soft_kernel(X, depth=3, alpha=alpha, beta=0.5)\n",
    "    kernels_aaa, kernels_aai, patterns = extract_kernels(X, alpha=alpha, beta=0.5, degree=3)\n",
    "    for repeat in tqdm(range(4), leave=False):\n",
    "        test_index, train_index = fold[repeat * 2], fold[repeat * 2 + 1]\n",
    "        assert len(test_index) > len(train_index)\n",
    "        mkl = EasyMKL()\n",
    "\n",
    "        train_kernels_aaa = [i[train_index][:, train_index] for i in kernels_aaa]\n",
    "        train_kernels_aai = [i[train_index][:, train_index] for i in kernels_aai]\n",
    "        ker_matrix_aaa = mkl.combine_kernels(train_kernels_aaa, y[train_index])\n",
    "        ker_matrix_aai = mkl.combine_kernels(train_kernels_aai, y[train_index])\n",
    "\n",
    "        optimal = get_optimal_weight(len(ker_matrix_aaa.weights))\n",
    "\n",
    "        acc_dict_hard, auc_dict_hard = benchmark(kernels_aaa, kernels_aai, kernel_soft, y, train_index, test_index, alpha, beta, repeat, optimal)\n",
    "        acc_dicts_hard.append(acc_dict_hard)\n",
    "        auc_dicts_hard.append(auc_dict_hard)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "116a0294",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import special\n",
    "\n",
    "fig = plt.figure(figsize=(10, 4))\n",
    "\n",
    "plt.subplot(1,2,1)\n",
    "df = pd.DataFrame(acc_dicts_hard)\n",
    "results = df.groupby(by=[\"alpha\", \"beta\"]).mean()[\"aaa_benchmark\"].values\n",
    "labels = ['$10^0$', '$10^1$', '$10^2$', '$10^3$']\n",
    "\n",
    "plt.bar(labels, results)\n",
    "x_range = np.linspace(-0.5, len(labels) - 0.5, 100)\n",
    "plt.plot(x_range, [rf3_mean] * len(x_range), color=\"green\", linestyle=\"dashdot\", alpha=0.7, label=\"Random Forest\", linewidth=1.0)\n",
    "plt.fill_between(x_range, [rf3_mean-rf3_std] * len(x_range), [rf3_mean+rf3_std] * len(x_range), color='green', alpha=0.1)\n",
    "\n",
    "plt.xlabel(\"$\\\\alpha$\")\n",
    "plt.ylabel('Accuracy')\n",
    "plt.title(\"AAA (Benchmark)\")\n",
    "plt.ylim([0.5, 1.0])\n",
    "plt.xlim(-0.5, len(labels) - 0.5)\n",
    "plt.legend(loc='lower right')\n",
    "\n",
    "plt.subplot(1,2,2)\n",
    "alpha_values = [1e0, 1e1, 1e2, 1e3]\n",
    "colors = ['blue', 'red', 'green', 'purple']\n",
    "labels = ['$\\\\alpha=10^0$', '$\\\\alpha=10^1$', '$\\\\alpha=10^2$', '$\\\\alpha=10^3$']\n",
    "\n",
    "for alpha, color, label in zip(alpha_values, colors, labels):\n",
    "    x = np.linspace(-0.5, 0.5, 100000)\n",
    "    plt.plot(x, 0.5 * special.erf(alpha * x) + 0.5, color=color, label=label)\n",
    "\n",
    "plt.xlabel('$c$')\n",
    "plt.ylabel('$\\sigma(c)$')\n",
    "plt.grid(linestyle=\"dotted\")\n",
    "\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"./figures/large_alpha.pdf\", bbox_inches=\"tight\", pad_inches=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a208b54",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
