{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ca198c30",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "import glob\n",
    "import traceback\n",
    "\n",
    "from scipy import stats as sc_stats\n",
    "from scipy.special import gamma,loggamma\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.metrics import mean_absolute_percentage_error\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dd8ad989",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-15 10:00:34.619533: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e9ec7453",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "from os.path import dirname\n",
    "sys.path.append(dirname(\"../../\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "551b9722",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.edl import dense_layers,dense_loss\n",
    "from src.weibull_edl import loss_and_layers\n",
    "from src.exp_utils import synthetic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "931113c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8788f3e",
   "metadata": {},
   "source": [
    "### experiment variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c7b0cd96",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_NAME = \"exp3\" #experiment name\n",
    "N_EPS = 5 #number of eps values to try\n",
    "N_REGUL = 10 #number of regularisation values to try"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e27e4e79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../../exp_results/exp3 is not empty, Removing files\n"
     ]
    }
   ],
   "source": [
    "base_dir = f\"../../exp_results/{EXP_NAME}\"\n",
    "\n",
    "if len(os.listdir(base_dir)) == 0:\n",
    "    print(f\"{base_dir} is empty\")\n",
    "else:    \n",
    "    print(f\"{base_dir} is not empty, Removing files\")\n",
    "    [os.remove(f) for f in glob.glob(base_dir+\"/*\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "027dde0a",
   "metadata": {},
   "source": [
    "### Helper funcs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6ed12c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_results(ax, mu, var, model_type, n): \n",
    "    ax.plot(x_test,y_test,label=\"True Test\",alpha = 1.0,color=\"black\")\n",
    "    ax.scatter(x_train,y_train,label=\"Train\")\n",
    "    ax.plot(x_test.reshape(n), mu, zorder=3, label=\"Mean Pred\",color=\"red\", alpha = 0.5)\n",
    "    if model_type == \"proposed\":\n",
    "        a = 0.5\n",
    "        b = 1\n",
    "        var_check = np.sqrt(var)\n",
    "    else:\n",
    "        a = 2\n",
    "        b = 8\n",
    "        var_check = var\n",
    "    ax.fill_between(x=x_test.reshape(n),\\\n",
    "                 y1=(mu - a * var_check).reshape(n), \\\n",
    "                 y2=(mu + a * var_check).reshape(n),\\\n",
    "                 label=f\"{a} std PI\",color=\"grey\",alpha=0.7)\n",
    "    ax.fill_between(x=x_test.reshape(n),\\\n",
    "                 y1=(mu - b * var_check).reshape(n), \\\n",
    "                 y2=(mu + b * var_check).reshape(n),\\\n",
    "                 label=f\"{b} std PI\",color=\"pink\",alpha=0.2)\n",
    "#     ax.set_ylim(-10,200)\n",
    "    ax.legend()\n",
    "    ax.set_title(f\"{model_type} Model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2754319",
   "metadata": {},
   "source": [
    "### Experiment for synthetic data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56166814",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dict={}\n",
    "model_dict={}\n",
    "data_dict={}\n",
    "\n",
    "eps_list = np.linspace(0.2,0.4,N_EPS)\n",
    "for eps in eps_list:\n",
    "    #a.gen synthetic data\n",
    "    x_train, y_train = synthetic.gen_data_weibulll(-4, 4, 3400, eps)\n",
    "    x_test, y_test = synthetic.gen_data_weibulll(-5, 5, 3400, eps, train=False)\n",
    "    data_dict[eps] = (x_train, y_train, x_test, y_test)\n",
    "    \n",
    "    #b.get k from train data\n",
    "    rv = sc_stats.weibull_min.fit(y_train, floc=0.0)\n",
    "    k=float(rv[0])\n",
    "    print (f\"For eps={eps}, fitted k ={k}\")\n",
    "    \n",
    "    #c.plot data and save fig\n",
    "    fig,ax = plt.subplots(1,1)\n",
    "    ax.scatter(x_train,y_train,label=\"Train\")\n",
    "    ax.plot(x_test,y_test,label=\"Test\",alpha = 0.5, color=\"black\")\n",
    "    ax.legend()\n",
    "    ax.set_title(f\"$y \\sim x^2 + {eps:.3f}*weibull(shape=1.6)$\")\n",
    "    fig.savefig(base_dir+f\"/data_gen_eps={eps}.png\")\n",
    "    plt.close()\n",
    "    \n",
    "    #d.initialise and run benchmark model\n",
    "    results_benchmark_eps = {}\n",
    "    for c_i in np.logspace(-6,-1,N_REGUL):\n",
    "        try:\n",
    "            print (f\"Fitting for c={c_i}\")\n",
    "            print(\"fitting benchmark\")\n",
    "            #run benchmark model\n",
    "            mu_i, var_i, y_pred_train_i, y_pred_test_i,\\\n",
    "            benchmark_model_i, hist_i = synthetic.results_benchmark_model(c_i,x_train,y_train,x_test)\n",
    "            a,b = synthetic.metrics_benchmark(y_train,y_pred_train_i)\n",
    "            results_dict[(eps,c_i,\"benchmark\",\"train\")] = {\n",
    "                \"mse\":a, \"nll\":b, \"loss\": hist_i.history[\"loss\"][-1],\n",
    "            }\n",
    "            c,d = synthetic.metrics_benchmark(y_test,y_pred_test_i)\n",
    "            results_dict[(eps,c_i,\"benchmark\",\"test\")] = {\n",
    "                \"mse\":c, \"nll\":d\n",
    "            }\n",
    "            model_dict[(eps,c_i,\"benchmark\")] = benchmark_model_i\n",
    "            \n",
    "            print(\"fitting proposed\")\n",
    "            #run proposed model\n",
    "            mu_prop_i, var_prop_i, y_pred_train_prop_i,\\\n",
    "            y_pred_test_prop_i, proposed_model_i, hist_prop_i = synthetic.results_weibull_model(c_i,x_train,y_train,x_test,k,0)\n",
    "            a1,b1 = synthetic.metrics_proposed(y_train,y_pred_train_prop_i,k)\n",
    "            results_dict[(eps,c_i,\"proposed\",\"train\")] = {\n",
    "                \"mse\":a1, \"nll\":b1, \"loss\": hist_prop_i.history[\"loss\"][-1],\n",
    "            }\n",
    "            c1,d1 = synthetic.metrics_proposed(y_test,y_pred_test_prop_i,k)\n",
    "            results_dict[(eps,c_i,\"proposed\",\"test\")] = {\n",
    "                \"mse\":c1, \"nll\":d1\n",
    "            }\n",
    "            model_dict[(eps,c_i,\"proposed\")] = proposed_model_i\n",
    "            \n",
    "            #plot the results\n",
    "            fig,ax = plt.subplots(1,2,figsize=(12,6))\n",
    "            plot_results(ax[0],mu_i,var_i,\"benchmark\",n=3400)\n",
    "            plot_results(ax[1],mu_prop_i,var_prop_i,\"proposed\",n=3400)\n",
    "            fig.suptitle(f\"eps={eps}, c= {c_i}\")\n",
    "            fig.savefig(base_dir+f\"/results_eps={eps},c={c_i}.png\")\n",
    "            plt.close()\n",
    "        except Exception as e:\n",
    "            print (f\"Error for eps={eps}, c= {c_i}\")\n",
    "            traceback.print_exc()\n",
    "        print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0f8b4df",
   "metadata": {},
   "source": [
    "### Save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf4b2be3",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_df = pd.DataFrame.from_dict(results_dict,orient=\"index\")\n",
    "result_df[['eps','c',\"model_type\"]] = pd.DataFrame(result_df.index.tolist(), index= result_df.index)\n",
    "result_df = result_df.rename(columns={0:\"mse\",1:\"NLL\"})\n",
    "result_df.to_csv(base_dir+\"/mse_nll_results.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d453a30",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_metadata = {\"data\":data_dict, \"models\": model_dict}\n",
    "\n",
    "with open(base_dir + f'/exp_metadata.pickle', 'wb') as handle:\n",
    "    pickle.dump(exp_metadata, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e0881ff",
   "metadata": {},
   "source": [
    "### Junk below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c658ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark = result_df[result_df.model_type==\"benchmark\"]\n",
    "benchmark[\"rnk_\"] = benchmark.groupby([\"eps\",\"model_type\"])[\"mse\"].rank()\n",
    "best_benchmark = benchmark[benchmark[\"rnk_\"]==1]\n",
    "\n",
    "proposed = result_df[result_df.model_type==\"proposed\"]\n",
    "proposed[\"rnk_\"] = proposed.groupby([\"eps\",\"model_type\"])[\"mse\"].rank()\n",
    "best_proposed = proposed[proposed[\"rnk_\"]==1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3361520d",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_results = best_benchmark.merge(best_proposed,on=\"eps\",suffixes=(\"_bench\",\"_prop\"))\n",
    "best_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "625ed2d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_results[[\"eps\",\"mse_bench\",\"mse_prop\",\"NLL_bench\",\"NLL_prop\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77260ab5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5a9fa53",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0611e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "proposed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0099f711",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddc59462",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "317939ec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23cc61c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd979e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark[benchmark[\"rnk_\"]==1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "275b4b98",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb48b8c1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb10af03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mu_prop_i, var_prop_i,\\\n",
    "#             y_pred_prop_i, proposed_model_i = synthetic.results_weibull_model(c_i,x_train,y_train,x_test,k,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "961d1ef9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ef76083",
   "metadata": {},
   "outputs": [],
   "source": [
    "# synthetic.results_weibull_model(c_i,x_train,y_train,x_test,k,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4495ffe1",
   "metadata": {},
   "source": [
    "from scipy.special import loggamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fca4f7f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_func(x):\n",
    "    a = gamma(x-(2/1.3))\n",
    "    b = gamma(x)\n",
    "    c = gamma(x-(1/1.3))\n",
    "    d = gamma(x)\n",
    "    return (a/b)-(c/d)**2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60150e55",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.linspace(2.0,10,100)\n",
    "plt.plot(x,my_func(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70ca73a4",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4c554f3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3157b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8174eb3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
