{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "635a1ed9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data from sdam-only_main-n150.json:\n",
      "MSE Mean (DNN): 14.37, Std: 1.03\n",
      "MSE Mean (SDAM): 2.84, Std: 3.82\n",
      "-----------------------------------------------------------\n",
      "Data from sdam-inter_no_overlap-n150.json:\n",
      "MSE Mean (DNN): 5.78, Std: 0.43\n",
      "MSE Mean (SDAM): 1.44, Std: 1.27\n",
      "-----------------------------------------------------------\n",
      "Data from sdam-inter_mild_overlap-n150.json:\n",
      "MSE Mean (DNN): 7.11, Std: 0.51\n",
      "MSE Mean (SDAM): 1.48, Std: 1.31\n",
      "-----------------------------------------------------------\n",
      "Data from sdam-inter_strong_overlap-n150.json:\n",
      "MSE Mean (DNN): 5.90, Std: 0.43\n",
      "MSE Mean (SDAM): 1.21, Std: 1.23\n",
      "-----------------------------------------------------------\n",
      "Data from sdam-only_inter-n150.json:\n",
      "MSE Mean (DNN): 1.05, Std: 0.11\n",
      "MSE Mean (SDAM): 0.31, Std: 0.20\n",
      "-----------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "\n",
    "# Directory holding one JSON results file per scenario listed in file_type.\n",
    "file_path = '/home/users/yhung7/SDAM/src/output/'  # Replace with the actual path to your file\n",
    "file_type = ['sdam-only_main-n150', 'sdam-inter_no_overlap-n150',\n",
    "             'sdam-inter_mild_overlap-n150', 'sdam-inter_strong_overlap-n150', 'sdam-only_inter-n150']\n",
    "\n",
    "# (printed label, JSON key) per model; SDAM's errors are read from the 'MSE_ADB' key.\n",
    "# np.std is used with its default ddof=0, i.e. the population standard deviation.\n",
    "# TODO: runtimes were once printed from 'Runtime_DNN' / 'Runtime_ADN' (note 'ADN' vs 'ADB');\n",
    "# confirm those key names before adding runtime summaries back.\n",
    "models = (('DNN', 'MSE_DNN'), ('SDAM', 'MSE_ADB'))\n",
    "\n",
    "for scenario in file_type:\n",
    "    with open(f\"{file_path}{scenario}.json\", 'r') as fh:\n",
    "        results = json.load(fh)\n",
    "\n",
    "    print(f\"Data from {scenario}.json:\")\n",
    "    for label, key in models:\n",
    "        errors = results[key]\n",
    "        print(f\"MSE Mean ({label}): {np.mean(errors):.2f}, Std: {np.std(errors):.2f}\")\n",
    "\n",
    "    print('-----------------------------------------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52abd602-5014-4a95-aa5d-fdcd3d9c3e92",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b498eb4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3e5fb26b",
   "metadata": {},
   "source": [
    "## Comparison: sparse GAM baseline (fastsparsegams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "25e12f8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from pygam import LinearGAM, s, te\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import fastsparsegams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb3d263a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "only_main_data 5.845926563792651 0.5191938219646623\n",
      "weak_main_data 3.2287610940876745 0.2478477695514835\n",
      "inter_no_overlap_data 3.5505507515420813 0.2458181083672007\n",
      "inter_mild_overlap_data 3.5346354127574027 0.27443186104794975\n",
      "inter_strong_overlap_data 3.6535376179392345 0.2734541231983192\n",
      "only_inter_data 0.6264446516415378 0.0458019011782489\n"
     ]
    }
   ],
   "source": [
    "# Sparse-GAM baseline: fit an L0-penalized fastsparsegams model on every replication\n",
    "# and report the mean / std (np.std default ddof=0) of its test MSE for each dataset.\n",
    "# NOTE(review): machine-specific absolute path - point DATA_DIR at your local data copy.\n",
    "DATA_DIR = '/Users/a080528/Downloads/data/'\n",
    "# Regularization level used for prediction. TODO: document how this value was chosen\n",
    "# (and confirm it lies on the fitted lambda path for every replication).\n",
    "LAMBDA_0 = 0.032715\n",
    "MAX_SUPPORT_SIZE = 20\n",
    "\n",
    "name = ['only_main_data', 'weak_main_data', 'inter_no_overlap_data', 'inter_mild_overlap_data', 'inter_strong_overlap_data', 'only_inter_data']\n",
    "\n",
    "for n in name:\n",
    "    _dict = torch.load(DATA_DIR + n + '.pt', weights_only=True)\n",
    "    X_train = np.array(_dict['X_train'])\n",
    "    y_train = np.array(_dict['y_train'])\n",
    "    X_test = np.array(_dict['X_test'])\n",
    "    y_test = np.array(_dict['y_test'])\n",
    "\n",
    "    mse = []\n",
    "    # Index r selects one replication (X_train[r] is passed as a full design matrix).\n",
    "    for r in range(X_train.shape[0]):\n",
    "        fit_model = fastsparsegams.fit(X_train[r].astype(np.float64), y_train[r].astype(np.float64), penalty=\"L0\", max_support_size=MAX_SUPPORT_SIZE)\n",
    "        # Predict on float64 features too, matching the dtype used for fitting.\n",
    "        y_pred = fit_model.predict(x=X_test[r].astype(np.float64), lambda_0=LAMBDA_0, gamma=0)\n",
    "        # sklearn's signature is mean_squared_error(y_true, y_pred).\n",
    "        mse.append(mean_squared_error(y_test[r], y_pred))\n",
    "\n",
    "    print(n, np.mean(mse), np.std(mse))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8840c798",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "7b21dd93-1ecb-452b-aede-8d20e5609052",
   "metadata": {},
   "source": [
    "## Variable Selection Coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "8a1b1743",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Case: only_main\n",
      "main coverage 100/100\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "import ast\n",
    "\n",
    "# Coverage counts the 'Active' log lines (presumably one per fitted replication) whose\n",
    "# parsed active set contains every true term.\n",
    "file_path = '/home/users/yhung7/SDAM/src/scripts/'  # Replace with the actual path to your file\n",
    "file_type = ['only_main']\n",
    "true_m = [[0], [1], [2]]\n",
    "\n",
    "for file in file_type:\n",
    "    with open(file_path + file + '.txt', 'r') as f:\n",
    "        lines = f.readlines()\n",
    "\n",
    "    count_list_m = []\n",
    "    n_active_lines = 0\n",
    "    for line in lines:\n",
    "        if 'Active' in line:\n",
    "            n_active_lines += 1\n",
    "            # The selected terms follow the first '[[' as a Python list literal, e.g.\n",
    "            # [[0], [1], [3, 4]] (format inferred from this parser; confirm in the log writer).\n",
    "            start_index = line.find(\"[[\")\n",
    "            if start_index == -1:\n",
    "                raise ValueError(f\"no '[[' active-set literal in line: {line!r}\")\n",
    "            result_array = ast.literal_eval(line[start_index:].strip())\n",
    "            # Compare as a set of sorted tuples so duplicate entries cannot inflate the\n",
    "            # match count and [4, 3] is recognised as the same term as [3, 4].\n",
    "            selected = {tuple(sorted(term)) for term in result_array}\n",
    "            count_list_m.append(all(tuple(sorted(term)) in selected for term in true_m))\n",
    "\n",
    "    print(f\"Case: {file}\\nmain coverage {sum(count_list_m)}/{n_active_lines}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "d28a15f7-2f86-4b5c-ad45-d5dd7de855c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Case: inter_no_overlap\n",
      "main coverage 100/100\n",
      "interaction coverage 74/100\n"
     ]
    }
   ],
   "source": [
    "# Relies on `ast` and `file_path` defined in the only_main coverage cell.\n",
    "file_type = ['inter_no_overlap']\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[3, 4]]\n",
    "\n",
    "for file in file_type:\n",
    "    with open(file_path + file + '.txt', 'r') as f:\n",
    "        lines = f.readlines()\n",
    "\n",
    "    count_list_m = []\n",
    "    count_list_i = []\n",
    "    n_active_lines = 0\n",
    "    for line in lines:\n",
    "        if 'Active' in line:\n",
    "            n_active_lines += 1\n",
    "            start_index = line.find(\"[[\")\n",
    "            if start_index == -1:\n",
    "                raise ValueError(f\"no '[[' active-set literal in line: {line!r}\")\n",
    "            result_array = ast.literal_eval(line[start_index:].strip())\n",
    "            # Set of sorted tuples: ignores duplicate entries and index order within a term.\n",
    "            selected = {tuple(sorted(term)) for term in result_array}\n",
    "            count_list_m.append(all(tuple(sorted(term)) in selected for term in true_m))\n",
    "            count_list_i.append(all(tuple(sorted(term)) in selected for term in true_i))\n",
    "\n",
    "    print(f\"Case: {file}\\nmain coverage {sum(count_list_m)}/{n_active_lines}\")\n",
    "    print(f\"interaction coverage {sum(count_list_i)}/{n_active_lines}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "084d8f23-0c19-402c-8955-4100b961e2a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Case: inter_mild_overlap\n",
      "main coverage 100/100\n",
      "interaction coverage 85/100\n"
     ]
    }
   ],
   "source": [
    "# Relies on `ast` and `file_path` defined in the only_main coverage cell.\n",
    "file_type = ['inter_mild_overlap']\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[2, 3]]\n",
    "\n",
    "for file in file_type:\n",
    "    with open(file_path + file + '.txt', 'r') as f:\n",
    "        lines = f.readlines()\n",
    "\n",
    "    count_list_m = []\n",
    "    count_list_i = []\n",
    "    n_active_lines = 0\n",
    "    for line in lines:\n",
    "        if 'Active' in line:\n",
    "            n_active_lines += 1\n",
    "            start_index = line.find(\"[[\")\n",
    "            if start_index == -1:\n",
    "                raise ValueError(f\"no '[[' active-set literal in line: {line!r}\")\n",
    "            result_array = ast.literal_eval(line[start_index:].strip())\n",
    "            # Set of sorted tuples: ignores duplicate entries and index order within a term.\n",
    "            selected = {tuple(sorted(term)) for term in result_array}\n",
    "            count_list_m.append(all(tuple(sorted(term)) in selected for term in true_m))\n",
    "            count_list_i.append(all(tuple(sorted(term)) in selected for term in true_i))\n",
    "\n",
    "    print(f\"Case: {file}\\nmain coverage {sum(count_list_m)}/{n_active_lines}\")\n",
    "    print(f\"interaction coverage {sum(count_list_i)}/{n_active_lines}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "eb3d9530-fe38-4c39-9bcb-66adc480330a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Case: inter_strong_overlap\n",
      "main coverage 100/100\n",
      "interaction coverage 94/100\n"
     ]
    }
   ],
   "source": [
    "# Relies on `ast` and `file_path` defined in the only_main coverage cell.\n",
    "file_type = ['inter_strong_overlap']\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[1, 2]]\n",
    "\n",
    "for file in file_type:\n",
    "    with open(file_path + file + '.txt', 'r') as f:\n",
    "        lines = f.readlines()\n",
    "\n",
    "    count_list_m = []\n",
    "    count_list_i = []\n",
    "    n_active_lines = 0\n",
    "    for line in lines:\n",
    "        if 'Active' in line:\n",
    "            n_active_lines += 1\n",
    "            start_index = line.find(\"[[\")\n",
    "            if start_index == -1:\n",
    "                raise ValueError(f\"no '[[' active-set literal in line: {line!r}\")\n",
    "            result_array = ast.literal_eval(line[start_index:].strip())\n",
    "            # Set of sorted tuples: ignores duplicate entries and index order within a term.\n",
    "            selected = {tuple(sorted(term)) for term in result_array}\n",
    "            count_list_m.append(all(tuple(sorted(term)) in selected for term in true_m))\n",
    "            count_list_i.append(all(tuple(sorted(term)) in selected for term in true_i))\n",
    "\n",
    "    print(f\"Case: {file}\\nmain coverage {sum(count_list_m)}/{n_active_lines}\")\n",
    "    print(f\"interaction coverage {sum(count_list_i)}/{n_active_lines}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "0239c222-3010-4eee-99d7-cd1cf1d46235",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Case: weak_main\n",
      "strong main coverage 100/100\n",
      "Weak main coverage 0/100\n"
     ]
    }
   ],
   "source": [
    "# Relies on `ast` and `file_path` defined in the only_main coverage cell.\n",
    "file_type = ['weak_main']\n",
    "true_m = [[0], [1], [2], [3]]\n",
    "# Assumption (confirm against the data generator): every main effect except the last\n",
    "# is a strong signal, and the last one is the weak signal.\n",
    "true_strong_m = true_m[:-1]\n",
    "\n",
    "for file in file_type:\n",
    "    with open(file_path + file + '.txt', 'r') as f:\n",
    "        lines = f.readlines()\n",
    "\n",
    "    count_list_sm = []\n",
    "    count_list_wm = []\n",
    "    n_active_lines = 0\n",
    "    for line in lines:\n",
    "        if 'Active' in line:\n",
    "            n_active_lines += 1\n",
    "            start_index = line.find(\"[[\")\n",
    "            if start_index == -1:\n",
    "                raise ValueError(f\"no '[[' active-set literal in line: {line!r}\")\n",
    "            result_array = ast.literal_eval(line[start_index:].strip())\n",
    "            # Set of sorted tuples: ignores duplicate entries and index order within a term.\n",
    "            selected = {tuple(sorted(term)) for term in result_array}\n",
    "            # Strong coverage: every strong main effect selected, whether or not the weak one is.\n",
    "            count_list_sm.append(all(tuple(sorted(term)) in selected for term in true_strong_m))\n",
    "            # Weak coverage: every main effect, including the weak one, selected.\n",
    "            count_list_wm.append(all(tuple(sorted(term)) in selected for term in true_m))\n",
    "\n",
    "    print(f\"Case: {file}\\nstrong main coverage {sum(count_list_sm)}/{n_active_lines}\")\n",
    "    print(f\"Weak main coverage {sum(count_list_wm)}/{n_active_lines}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "id": "9c46c3b3-916c-4917-9418-5a9a600b444c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Case: only_inter\n",
      "interaction coverage 95/100\n"
     ]
    }
   ],
   "source": [
    "# Relies on `ast` and `file_path` defined in the only_main coverage cell.\n",
    "file_type = ['only_inter']\n",
    "true_i = [[0, 1], [2, 3]]\n",
    "\n",
    "for file in file_type:\n",
    "    with open(file_path + file + '.txt', 'r') as f:\n",
    "        lines = f.readlines()\n",
    "\n",
    "    count_list_i = []\n",
    "    n_active_lines = 0\n",
    "    for line in lines:\n",
    "        if 'Active' in line:\n",
    "            n_active_lines += 1\n",
    "            start_index = line.find(\"[[\")\n",
    "            if start_index == -1:\n",
    "                raise ValueError(f\"no '[[' active-set literal in line: {line!r}\")\n",
    "            result_array = ast.literal_eval(line[start_index:].strip())\n",
    "            # Set of sorted tuples: ignores duplicate entries and index order within a term.\n",
    "            selected = {tuple(sorted(term)) for term in result_array}\n",
    "            count_list_i.append(all(tuple(sorted(term)) in selected for term in true_i))\n",
    "\n",
    "    print(f\"Case: {file}\")\n",
    "    print(f\"interaction coverage {sum(count_list_i)}/{n_active_lines}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a980fb88-26fe-4952-9a18-d8ae76857052",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
