{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ef5659a1",
   "metadata": {},
   "source": [
    "This notebook can be used to plot scaling laws for all experiments conducted with _Empirical_SL_CS.ipynb_ or _Empirical_SL_CS.py_.\n",
    "\n",
    "It collects the pre-computed performance metrics from all experiment directories whose names start with E001..., E002..., and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba1f50d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import glob\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5545be46",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lin_fit(x,y,start_x,end_x):\n",
    "    \"\"\"Fit a power law ``R = beta * N**alpha`` to a segment of a curve.\n",
    "\n",
    "    The fit is an ordinary least-squares line in log10-log10 space that uses only\n",
    "    the points from the first occurrence of ``start_x`` up to and including the\n",
    "    first occurrence of ``end_x``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array-like\n",
    "        Abscissa values (training-set sizes); must contain ``start_x`` and ``end_x``.\n",
    "    y : array-like\n",
    "        Ordinate values, same length as ``x``.\n",
    "    start_x, end_x :\n",
    "        Values of ``x`` at which the power-law regime begins and ends.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    linfit : np.ndarray\n",
    "        The fitted power law evaluated at every entry of ``x``.\n",
    "    alpha :\n",
    "        Fitted exponent (slope of the line in log-log space).\n",
    "    beta :\n",
    "        Fitted prefactor (10 raised to the intercept of the line).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    IndexError\n",
    "        If ``start_x`` or ``end_x`` does not occur in ``x``.\n",
    "    numpy.linalg.LinAlgError\n",
    "        If the selected segment holds fewer than two points.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    y = np.asarray(y)\n",
    "\n",
    "    def _first_index(value, name):\n",
    "        # position of the first entry of x equal to value, with a readable error\n",
    "        hits = np.where(x == value)[0]\n",
    "        if hits.size == 0:\n",
    "            raise IndexError('%s=%r does not occur in x=%s' % (name, value, x.tolist()))\n",
    "        return hits[0]\n",
    "\n",
    "    # train_size at which power-law regime begins\n",
    "    start_ind = _first_index(start_x, 'start_x')\n",
    "    # train_size at which power-law regime ends\n",
    "    end_ind = _first_index(end_x, 'end_x')\n",
    "\n",
    "    log_x = np.log10(x[start_ind:end_ind+1])\n",
    "    log_y = np.log10(y[start_ind:end_ind+1])\n",
    "    if log_x.size < 2:\n",
    "        # a straight line is not determined by fewer than two points\n",
    "        raise np.linalg.LinAlgError('need at least two points between start_x and end_x to fit a line')\n",
    "\n",
    "    # find linear fit log10(y) = alpha*log10(x) + log10(beta); lstsq avoids\n",
    "    # explicitly inverting the normal-equations matrix\n",
    "    x_fit = np.vstack((log_x,np.ones(log_y.shape))).T\n",
    "    linfit_w = np.linalg.lstsq(x_fit, log_y, rcond=None)[0]\n",
    "\n",
    "    # power-law: R = beta * N**alpha\n",
    "    beta = 10**linfit_w[1]\n",
    "    alpha = linfit_w[0]\n",
    "    print(beta,alpha)\n",
    "    linfit = beta * x**alpha\n",
    "    return linfit, alpha, beta\n",
    "\n",
    "    \n",
    "\n",
    "def plot_trainsize_scaling(ax,best_perf_dict,train_examples_tags,dist_shifts,fontsize,linfit_bounds,colors,\n",
    "                          ylim=None,xlim=None):\n",
    "    \"\"\"Plot SSIM against training-set size, one curve per entry of ``dist_shifts``.\n",
    "\n",
    "    For every distribution shift the 'max' value per training-set size is drawn as a\n",
    "    line and every value in 'all' as a scatter point. If ``linfit_bounds`` is\n",
    "    non-empty, the power laws returned by ``lin_fit`` are overlaid as dashed lines.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax :\n",
    "        Matplotlib axes to draw on.\n",
    "    best_perf_dict : dict\n",
    "        ``best_perf_dict[tag][dist_shift]`` must provide the keys 'max' and 'all'.\n",
    "    train_examples_tags : list of str\n",
    "        Tags for which ``int(tag[1:-1])`` is the training-set size.\n",
    "    dist_shifts : list of str\n",
    "        Keys of the evaluation sets to plot; 'test_' gets the label 'ImgNet Test'.\n",
    "    fontsize : int\n",
    "        Base font size for labels and ticks.\n",
    "    linfit_bounds : sequence\n",
    "        Empty for no fit, ``[start, end]`` for one fit, or a list of such pairs\n",
    "        (lists or tuples) for several fits.\n",
    "    colors : sequence\n",
    "        Colour per curve; indices wrap around when the palette is too short.\n",
    "    ylim, xlim : optional\n",
    "        Axis limits, applied only when truthy.\n",
    "    \"\"\"\n",
    "\n",
    "    def _color(index):\n",
    "        # wrap around instead of raising IndexError when fits need more colours than given\n",
    "        return colors[index % len(colors)]\n",
    "\n",
    "    for j,dist_shift in enumerate(dist_shifts):\n",
    "        x_points = []\n",
    "        y_points = []\n",
    "        x_line = []\n",
    "        y_line = []\n",
    "        for tag in train_examples_tags:\n",
    "            train_size = int(tag[1:-1])\n",
    "            entry = best_perf_dict[tag][dist_shift]\n",
    "            x_line.append(train_size)\n",
    "            y_line.append(entry['max'])\n",
    "\n",
    "            for value in entry['all']:\n",
    "                x_points.append(train_size)\n",
    "                y_points.append(value)\n",
    "        linfits = []\n",
    "        alphas = []\n",
    "        if linfit_bounds:\n",
    "            if not isinstance(linfit_bounds[0],(list,tuple)):\n",
    "                # a single [start, end] pair\n",
    "                linfit, alpha, _ = lin_fit(np.array(x_line),np.array(y_line),linfit_bounds[0],linfit_bounds[1])\n",
    "            else:\n",
    "                # several pairs: one fit per pair; alpha keeps the value of the last fit\n",
    "                for linfit_bound in linfit_bounds:\n",
    "                    linfit, alpha, _ = lin_fit(np.array(x_line),np.array(y_line),linfit_bound[0],linfit_bound[1])\n",
    "                    linfits.append(linfit)\n",
    "                    alphas.append(alpha)\n",
    "        else:\n",
    "            alpha=0\n",
    "\n",
    "        if dist_shift == 'test_':\n",
    "            label = r\"ImgNet Test\"\n",
    "        else:\n",
    "            label = r\"%s: $\\alpha={%.4f}$\"%(dist_shift,np.round(alpha,4))\n",
    "\n",
    "        ax.plot(x_line,y_line,label=label,color=_color(j))\n",
    "        if linfit_bounds:\n",
    "            if linfits:\n",
    "                for k,(linfit,alpha) in enumerate(zip(linfits,alphas)):\n",
    "                    label = r\"$\\alpha={%.4f}$\"%(np.round(alpha,4))\n",
    "                    ax.plot(x_line,linfit,linestyle='--',label=label,color=_color(j+k+1))\n",
    "            else:\n",
    "                label = r\"$\\alpha={%.4f}$\"%(np.round(alpha,4))\n",
    "                ax.plot(x_line,linfit,linestyle='--',label=label,color=_color(j))\n",
    "        ax.scatter(x_points,y_points,color=_color(j))\n",
    "\n",
    "    ax.legend(fontsize=fontsize-3)\n",
    "    ax.set_xlabel(\"Number of examples in the training set $N$\", fontsize=fontsize)\n",
    "    ax.set_ylabel(\"SSIM\", fontsize=fontsize)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=fontsize)\n",
    "    ax.tick_params(axis='both', which='minor', labelsize=fontsize-2)\n",
    "    ax.set_xscale('log')\n",
    "    if ylim:\n",
    "        ax.set_ylim(ylim)\n",
    "    if xlim:\n",
    "        ax.set_xlim(xlim)\n",
    "    ax.grid(True)\n",
    "    \n",
    "def plot_parameter_scaling(ax,perf_dict,train_examples_tags,dist_shifts,fontsize,colors,ylim=None,xlim=None):\n",
    "    \"\"\"Plot the metric against the number of network parameters, one curve per training-set size.\n",
    "\n",
    "    The parameter count of an experiment is looked up from the part of its name\n",
    "    between '_l' and '_bs' (for example 'l4c64'). For every experiment the maximum\n",
    "    over its 'best' values is put on the line and each single value is scattered.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ax :\n",
    "        Matplotlib axes to draw on.\n",
    "    perf_dict : dict\n",
    "        ``perf_dict[tag][experiment][dist_shift]['best']`` is a list of values.\n",
    "    train_examples_tags : list of str\n",
    "        Tags for which ``int(tag[1:-1])`` is the training-set size.\n",
    "    dist_shifts : list of str\n",
    "        Keys of the evaluation sets to plot.\n",
    "    fontsize : int\n",
    "        Base font size for labels and ticks.\n",
    "    colors : sequence\n",
    "        ``colors[c]`` is used for the c-th entry of ``train_examples_tags``.\n",
    "    ylim, xlim : optional\n",
    "        Axis limits, applied only when truthy.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If an experiment name carries a layer/channel tag that is not in the table below.\n",
    "    \"\"\"\n",
    "    layers_channels = ['l4c16', 'l4c32', 'l4c48', 'l4c64', 'l4c96','l4c112','l4c128','l4c144','l4c160','l4c176','l4c192','l4c208','l4c224','l4c256']\n",
    "    num_parameters = list(np.array([2,  8, 18, 31, 70, 95,124,157,193,234,279,327,380,496])*1e6)\n",
    "    params_by_tag = dict(zip(layers_channels,num_parameters))\n",
    "    use_best_or_last = 'best'\n",
    "\n",
    "    for dist_shift in dist_shifts:\n",
    "\n",
    "        for c,tag in enumerate(train_examples_tags):\n",
    "            num_train = int(tag[1:-1])\n",
    "            x_points = []\n",
    "            y_points = []\n",
    "            line_points = []\n",
    "\n",
    "            for exp,exp_results in perf_dict[tag].items():\n",
    "                # Get the number of parameters\n",
    "                lc_tag = exp[exp.find('_l')+1:exp.find('_bs')]\n",
    "                if lc_tag not in params_by_tag:\n",
    "                    # previously the count of the preceding experiment was silently reused\n",
    "                    raise ValueError('unknown layer/channel tag %r in experiment %r' % (lc_tag, exp))\n",
    "                num_par = params_by_tag[lc_tag]\n",
    "                runs = exp_results[dist_shift][use_best_or_last]\n",
    "\n",
    "                line_points.append((num_par,np.max(runs)))\n",
    "\n",
    "                for value in runs:\n",
    "                    y_points.append(value)\n",
    "                    x_points.append(num_par)\n",
    "\n",
    "            # order the line by parameter count so it runs left to right\n",
    "            # regardless of how the experiments are numbered\n",
    "            line_points.sort(key=lambda point: point[0])\n",
    "            x_line = [point[0] for point in line_points]\n",
    "            y_line = [point[1] for point in line_points]\n",
    "\n",
    "            ax.scatter(x_points,y_points,color=colors[c])\n",
    "            label = r\"$N%i$\"%(num_train)\n",
    "            ax.plot(x_line,y_line,label=label,color=colors[c])\n",
    "\n",
    "    ax.legend(fontsize=fontsize-4)\n",
    "    ax.set_xlabel(\"Number of network parameters $P$\", fontsize=fontsize)\n",
    "    ax.tick_params(axis='both', which='major', labelsize=fontsize)\n",
    "    ax.tick_params(axis='both', which='minor', labelsize=fontsize-2)\n",
    "    ax.set_xscale('log')\n",
    "    if ylim:\n",
    "        ax.set_ylim(ylim)\n",
    "    if xlim:\n",
    "        ax.set_xlim(xlim)\n",
    "    ax.grid(True)\n",
    "    \n",
    "def generate_performance_dicts(dist_shifts,train_examples_tags,dist_exps_list):\n",
    "    \"\"\"Load SSIM metrics of all experiments and select the best setup per training-set size.\n",
    "\n",
    "    Metric files are read from ``<experiment>/log_files/metrics*`` relative to the\n",
    "    current working directory. Each file is unpickled and ``.means()['SSIM']`` of the\n",
    "    loaded object is stored. Only files whose path contains 'best' are used, so the\n",
    "    'last' lists of ``perf_dict`` stay empty. Directories that differ only by a\n",
    "    '_run...' suffix are pooled as repeated runs of one setup.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dist_shifts : list of str\n",
    "        Substrings identifying the evaluation sets in the metric file paths;\n",
    "        must contain \"val_\", which is used to rank the setups.\n",
    "    train_examples_tags : list of str\n",
    "        Substrings of the directory names identifying the training-set size.\n",
    "    dist_exps_list : list of str\n",
    "        Experiment directory names to consider.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    best_perf_dict : dict\n",
    "        ``best_perf_dict[tag][dist_shift]`` holds 'mean', 'max', 'std', 'median' and\n",
    "        'all' of the setup with the highest maximum \"val_\" value for that tag.\n",
    "    perf_dict : dict\n",
    "        ``perf_dict[tag][setup][dist_shift]['best']`` lists the loaded values.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If \"val_\" is not in ``dist_shifts``, if a tag has no setup, or if a setup\n",
    "        has no \"val_\" metrics.\n",
    "    \"\"\"\n",
    "    if \"val_\" not in dist_shifts:\n",
    "        raise ValueError(\"val_ metrics must be included to determine best experiment per setup\")\n",
    "\n",
    "    def _setup_name(exp_dir):\n",
    "        # directory name without its '_run...' suffix; used as key everywhere so that\n",
    "        # registration and loading always agree\n",
    "        cut = exp_dir.find('_run')\n",
    "        return exp_dir if cut == -1 else exp_dir[:cut]\n",
    "\n",
    "    # Distinct setups per training-set tag; a setup is registered under the first matching tag only\n",
    "    registered = set()\n",
    "    exp_num_to_train_size = {}\n",
    "    for tag in train_examples_tags:\n",
    "        exp_num_to_train_size[tag] = []\n",
    "        for dist_exp in dist_exps_list:\n",
    "            exp_num = _setup_name(dist_exp)\n",
    "            if tag in dist_exp and exp_num not in registered:\n",
    "                registered.add(exp_num)\n",
    "                exp_num_to_train_size[tag].append(exp_num)\n",
    "\n",
    "    all_exps_list = glob.glob('E*')\n",
    "\n",
    "    perf_dict = {}\n",
    "    for tag in train_examples_tags:\n",
    "        perf_dict[tag] = {}\n",
    "        for exp_num in exp_num_to_train_size[tag]:\n",
    "            perf_dict[tag][exp_num] = {}\n",
    "            for dist_shift in dist_shifts:\n",
    "                perf_dict[tag][exp_num][dist_shift] = {'best': [], 'last': []}\n",
    "\n",
    "    requested = set(dist_exps_list)\n",
    "    ckpt = 'best'\n",
    "    for exp in all_exps_list:\n",
    "        if exp not in requested:\n",
    "            continue\n",
    "        metrics_list = glob.glob(exp+'/log_files/metrics*')\n",
    "        exp_num = _setup_name(exp)\n",
    "\n",
    "        for tag in train_examples_tags:\n",
    "            # skip tags under which this setup was not registered\n",
    "            if tag not in exp or exp_num not in perf_dict[tag]:\n",
    "                continue\n",
    "            for dist_shift in dist_shifts:\n",
    "                for metric in metrics_list:\n",
    "                    if dist_shift in metric and ckpt in metric: # find the correct metric file\n",
    "                        # NOTE(review): pickle.load executes code from the file; only load trusted metrics\n",
    "                        with open(metric, \"rb\") as metric_file:\n",
    "                            loaded = pickle.load(metric_file)\n",
    "                        perf_dict[tag][exp_num][dist_shift][ckpt].append(loaded.means()['SSIM'])\n",
    "\n",
    "    # Get best mean/std or median performance per trainset size\n",
    "    print('Mean/std performance:')\n",
    "    best_perf_dict = {}\n",
    "    for tag in train_examples_tags:\n",
    "        best_perf_dict[tag] = {}\n",
    "        for dist_shift in dist_shifts:\n",
    "            best_perf_dict[tag][dist_shift] = {'mean': 0, 'max': 0, 'std': 0, 'median': 0, 'all': []}\n",
    "\n",
    "    use_best_or_last = 'best'\n",
    "    for dist_shift in dist_shifts:\n",
    "        print('\\n')\n",
    "        for tag in train_examples_tags:\n",
    "            best_per_exp_num = []\n",
    "            exp_nums = []\n",
    "            print('''{} {} all experiments:'''.format(dist_shift,tag))\n",
    "            for exp_num in exp_num_to_train_size[tag]:\n",
    "                val_scores = perf_dict[tag][exp_num][\"val_\"][use_best_or_last]\n",
    "                if len(val_scores) == 0:\n",
    "                    raise ValueError('no val_ metrics found for experiment %r' % (exp_num,))\n",
    "\n",
    "                exp_nums.append(exp_num)\n",
    "                best_per_exp_num.append(np.max(val_scores)) #only use val metric to compare experiments\n",
    "\n",
    "                scores = perf_dict[tag][exp_num][dist_shift][use_best_or_last]\n",
    "                print_all_psnr = [np.round(tt,4) for tt in scores]\n",
    "                print('''{} with SSIM mean {} max {} std {} all {}\\n'''.format(exp_num,\n",
    "                                                                            np.round(np.mean(scores),4),\n",
    "                                                                            np.round(np.max(scores),4),\n",
    "                                                                            np.round(np.std(scores),4),\n",
    "                                                                            print_all_psnr\n",
    "                                                                            ))\n",
    "\n",
    "            if not exp_nums:\n",
    "                raise ValueError('no experiment found for training-set tag %r' % (tag,))\n",
    "\n",
    "            # first setup with the highest validation score\n",
    "            best_exp_num = exp_nums[int(np.argmax(best_per_exp_num))]\n",
    "            best_scores = perf_dict[tag][best_exp_num][dist_shift][use_best_or_last]\n",
    "\n",
    "            summary = best_perf_dict[tag][dist_shift]\n",
    "            summary['mean'] = np.mean(best_scores)\n",
    "            summary['std'] = np.std(best_scores)\n",
    "            summary['median'] = np.median(best_scores)\n",
    "            summary['all'] = best_scores\n",
    "            summary['max'] = np.max(best_scores)\n",
    "            print('best experiment: {} with SSIM mean {} max {} std {} \\n'.format(best_exp_num,\n",
    "                                                                                  np.round(summary['mean'],4),\n",
    "                                                                                  np.round(summary['max'],4),\n",
    "                                                                                  np.round(summary['std'],4)))\n",
    "    return best_perf_dict, perf_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eca6c154",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "077a73ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every directory whose name starts with 'E0' is treated as one experiment\n",
    "dist_exps_list = sorted(glob.glob('E0*'))\n",
    "\n",
    "# The training-set tag is the slice of the name from its first 't' up to and\n",
    "# including the underscore of its first '_l'; dict.fromkeys keeps first-seen order\n",
    "train_examples_tags = list(dict.fromkeys(\n",
    "    name[name.find('t'):name.find('_l')+1] for name in dist_exps_list\n",
    "))\n",
    "\n",
    "dist_shifts = [\"val_\",\"test_\"]\n",
    "best_perf_dict, perf_dict = generate_performance_dicts(dist_shifts,train_examples_tags,dist_exps_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3db22bcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "######\n",
    "# Two-panel figure: performance vs. number of training examples (left)\n",
    "# and vs. number of network parameters (right)\n",
    "######\n",
    "\n",
    "fontsize = 22\n",
    "colors = ['b','r','k','g','m','b','r','k','g','m']\n",
    "dist_shifts = ['test_']\n",
    "\n",
    "# An empty list is falsy, so the plotting functions leave the y-limits untouched\n",
    "ylim = []\n",
    "\n",
    "# Specify from which training set size to which training set size to fit a linear power law\n",
    "# e.g. [300,3000] or a list of start and end points to get several power laws, e.g. [[300,3000],[3000,10000]]\n",
    "linfit_bounds = [] \n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20,7))\n",
    "\n",
    "# Left panel: performance as function of number of training examples\n",
    "plot_trainsize_scaling(ax1,best_perf_dict,train_examples_tags,dist_shifts,fontsize,linfit_bounds=linfit_bounds,\n",
    "                          colors=colors, ylim=ylim,xlim=None)\n",
    "\n",
    "# Right panel: performance as function of number of network parameters\n",
    "plot_parameter_scaling(ax2,perf_dict,train_examples_tags,dist_shifts,fontsize,colors,ylim=ylim,xlim=None)\n",
    "\n",
    "plt.savefig(\"Empirical_SL_CS.png\",dpi=150)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bb05835",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
