{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c76b8620",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import warnings\n",
    "import h5py as h5\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from copy import deepcopy\n",
    "from tqdm.auto import tqdm\n",
    "from IPython.core.display import display, HTML\n",
    "\n",
    "from gensit.config import Config\n",
    "from gensit.inputs import Inputs\n",
    "from gensit.outputs import Outputs\n",
    "from gensit.utils.misc_utils import *\n",
    "from gensit.utils.math_utils import *\n",
    "\n",
    "from gensit.utils.probability_utils import *\n",
    "from gensit.contingency_table import instantiate_ct\n",
    "from gensit.contingency_table.ContingencyTable_MCMC import ContingencyTableMarkovChainMonteCarlo\n",
    "\n",
    "from gensit.config import Config\n",
    "from gensit.inputs import Inputs\n",
    "from gensit.utils.misc_utils import *\n",
    "from gensit.static.plot_variables import *\n",
    "from gensit.static.global_variables import *\n",
    "from gensit.outputs import Outputs,OutputSummary\n",
    "\n",
    "\n",
    "# LaTeX font configuration\n",
    "mpl.rcParams.update(LATEX_RC_PARAMETERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b718b815",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# AUTO RELOAD EXTERNAL MODULES\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8785d438",
   "metadata": {},
   "source": [
    "# Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ddac5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "datapath = \"../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/paper_figures/figure2/cumulative_srmse_and_cp_by_method_label_title&sigma_marker_sigma_markersize_table_coverage_probability_size_linewidth_1.0_colour_title_opacity_1.0_hatchopacity_1.0\"\n",
    "srmseoutputpath =\"../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/paper_figures/figure2_v2/cumulative_srmse_by_method_label_title&sigma_marker_sigma_markersize_table_coverage_probability_size_linewidth_1.0_colour_title_opacity_1.0_hatchopacity_1.0\"\n",
    "cpoutputpath =\"../data/outputs/cambridge_work_commuter_lsoas_to_msoas/exp1/paper_figures/figure2_v2/cumulative_cp_by_method_label_title&sigma_marker_sigma_markersize_table_coverage_probability_size_linewidth_1.0_colour_title_opacity_1.0_hatchopacity_1.0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63679bfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = read_json(datapath+'_data.json')\n",
    "settings = read_json(datapath+'_settings.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "315a4f84",
   "metadata": {},
   "outputs": [],
   "source": [
    "slice_vals = ['$\\\\mytablecolsums$, $\\\\sigma = 0.014$',\n",
    "'$\\\\mytablecolsums,\\\\mytablerowsums,\\\\mytablecells{_2}$, $\\\\sigma = 0.014$']\n",
    "slice_key = 'label'\n",
    "slice_index = []\n",
    "for i,v in enumerate(data[slice_key]):\n",
    "    if v in slice_vals and ('Disjoint' not in data['label'][i]):\n",
    "        slice_index.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b595bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# index2 = []\n",
    "# for i,v in enumerate(data['label']):\n",
    "#     if v == '$\\\\mathbf{T}_{+\\\\cdot},\\\\mathbf{T}_{\\\\cdot +}$, $\\\\sigma = 0.014$' and data['x'][i][0] == '\\zachosframeworktag':\n",
    "#         index2.append(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20a41ca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_slice = deepcopy(data)\n",
    "for k in [j for j in data.keys() if j not in ['outputs','x_group','y_group','z_group','annotate','hatch']]:\n",
    "    data_slice[k] = np.array(data_slice[k])[slice_index].tolist()\n",
    "data_slice['newlabel'] = np.array(data_slice['x'])[:,0].tolist()\n",
    "data_slice['newx'] = np.array(data_slice['x'])[:,1].tolist()\n",
    "for k,v in data_slice.items():\n",
    "    print(k,np.shape(v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b56b102",
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69de2600",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig,ax = plt.subplots(figsize=settings['figure_size'])\n",
    "cs = ['#8ebeda', '#8ebeda', '#a6c858', '#a6c858']#'#8ebeda' '#a6c858', '#ca4a58', '#e0ad41'\n",
    "markers = ['o','^','o','^']\n",
    "for jindx, j in enumerate(list(range(0,len(cs)*6,6))):\n",
    "    # print(jindx,(j,j+6))\n",
    "    for i in range(j,j+6):\n",
    "        print(i,(data_slice['label'][i]+', \\;'+data_slice['newlabel'][i]).replace(\", $\\\\sigma = 0.014$\",\"\"))\n",
    "        _ = ax.scatter(\n",
    "            list(map(int,data_slice['newx']))[i],\n",
    "            np.array(data_slice['y'][i])[0],\n",
    "            linewidth = 1.0,\n",
    "            alpha=1.0,\n",
    "            c = cs[jindx],\n",
    "            marker = markers[jindx],\n",
    "            label = (data_slice['label'][i]+', \\;'+data_slice['newlabel'][i]).replace(\", $\\\\sigma = 0.014$\",\"\"),\n",
    "        )\n",
    "    _ = ax.plot(\n",
    "        list(map(int,data_slice['newx']))[slice(j,j+6)],\n",
    "        np.array(data_slice['y'][slice(j,j+6)])[:,0],\n",
    "        linewidth = 1.0,\n",
    "        c = cs[jindx],\n",
    "        marker = markers[jindx]\n",
    "    )\n",
    "ax.tick_params(labelsize=fontsize)\n",
    "plt.xlabel(r'$N$',fontsize=fontsize)\n",
    "plt.ylabel(r'SRMSE',fontsize=fontsize)\n",
    "handles, labels = plt.gca().get_legend_handles_labels()\n",
    "by_label = dict(zip(labels, handles))\n",
    "_ = plt.legend(by_label.values(), by_label.keys(),fontsize=20)\n",
    "# fig.tight_layout(rect=(0, 0, 0.7, 1.1))\n",
    "plt.show()\n",
    "write_figure(fig,srmseoutputpath,filename_ending=settings['filename_ending'],figure_format='pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fd7fe7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig,ax = plt.subplots(figsize=settings['figure_size'])\n",
    "cs = ['#8ebeda', '#8ebeda', '#a6c858', '#a6c858']#'#8ebeda' '#a6c858', '#ca4a58', '#e0ad41'\n",
    "markers = ['o','^','o','^']\n",
    "for jindx, j in enumerate(list(range(0,len(cs)*6,6))):\n",
    "    # print(jindx,(j,j+6))\n",
    "    for i in range(j,j+6):\n",
    "        print(i,(data_slice['label'][i]+', \\;'+data_slice['newlabel'][i]).replace(\", $\\\\sigma = 0.014$\",\"\"))\n",
    "        _ = ax.scatter(\n",
    "            list(map(int,data_slice['newx']))[i],\n",
    "            100*(np.log(data_slice['marker_size'][i])+2)/8,\n",
    "            linewidth = 1.0,\n",
    "            alpha=1.0,\n",
    "            c = cs[jindx],\n",
    "            marker = markers[jindx],\n",
    "            label = (data_slice['label'][i]+', \\;'+data_slice['newlabel'][i]).replace(\", $\\\\sigma = 0.014$\",\"\"),\n",
    "        )\n",
    "    _ = ax.plot(\n",
    "        list(map(int,data_slice['newx']))[slice(j,j+6)],\n",
    "        100*(np.log(data_slice['marker_size'][slice(j,j+6)])+2)/8,\n",
    "        linewidth = 1.0,\n",
    "        c = cs[jindx],\n",
    "        marker = markers[jindx]\n",
    "    )\n",
    "ax.tick_params(labelsize=fontsize)\n",
    "plt.xlabel(r'$\\color{red}{N}$',fontsize=fontsize)\n",
    "plt.ylabel(r'CP',fontsize=fontsize)\n",
    "handles, labels = plt.gca().get_legend_handles_labels()\n",
    "by_label = dict(zip(labels, handles))\n",
    "_ = plt.legend(by_label.values(), by_label.keys(),fontsize=20)\n",
    "# fig.tight_layout(rect=(0, 0, 0.7, 1.1))\n",
    "plt.show()\n",
    "write_figure(fig,cpoutputpath,filename_ending=settings['filename_ending'],figure_format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "880cdfcb",
   "metadata": {},
   "source": [
    "## Import samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e21669cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify experiment id\n",
    "experiment_id = \"GBRT_Comparison_UnsetNoise__doubly_and_cell_constrained_25_04_2024_18_41_52\"\n",
    "# Specify experiment group id\n",
    "dataset = 'DC'\n",
    "#'DC'\n",
    "#'cambridge_work_commuter_lsoas_to_msoas'\n",
    "#'sioux_falls'\n",
    "experiment_group_id = 'comparisons'\n",
    "# 'r_squared'\n",
    "# 'exp1'\n",
    "# 'comparisons'\n",
    "experiment_dir = f'../data/outputs/{dataset}/{experiment_group_id}/{experiment_id}/'\n",
    "relative_experiment_dir = os.path.relpath(experiment_dir,os.getcwd())\n",
    "\n",
    "# Create new logging object\n",
    "logger = setup_logger(\n",
    "    __name__,\n",
    "    console_level = 'PROGRESS',\n",
    "    file_level = 'EMPTY'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e02907",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output processing settings\n",
    "settings = {\n",
    "    \"logging_mode\": \"PROGRESS\",\n",
    "    \"coordinate_slice\": [\n",
    "        # \"da.loss_name.isin([str(['dest_attraction_ts_likelihood_loss']),str(['dest_attraction_ts_likelihood_loss', 'table_likelihood_loss']),str(['table_likelihood_loss'])])\"\n",
    "        # \"da.loss_name == str(['dest_attraction_ts_likelihood_loss'])\",\n",
    "        # \"da.cost_matrix == 'cost_matrix_max_normalised.txt'\"\n",
    "    ],\n",
    "    # \"coordinate_slice\": [\n",
    "    #     \"da.destination_attraction_ts == 'destination_attraction_housing_units_ts_sum_normalised.txt'\",\n",
    "    #     \"da.cost_matrix == 'cost_matrix_sum_normalised.txt'\",\n",
    "    #     \"da.title == '_row_constrained'\",\n",
    "    #     \"da.bmax == 1.0\"\n",
    "    #     # \"da.loss_name == str(['dest_attraction_ts_likelihood_loss'])\",\n",
    "    #     # \"~da.title.isin([str('_unconstrained'), str('_total_constrained')])\"\n",
    "    # ],\n",
    "    \"metadata_keys\":[],\n",
    "    \"burnin_thinning_trimming\": [],\n",
    "    # \"burnin_thinning_trimming\": [{'iter': {\"burnin\":10000, \"thinning\":90, \"trimming\":1000}}],\n",
    "    \"n_workers\": 1,\n",
    "    \"group_by\":[],\n",
    "    \"filename_ending\":\"test\",\n",
    "    \"sample\":[\"intensity\"],\n",
    "    \"force_reload\":False\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "833a9fad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise outputs\n",
    "current_sweep_outputs = Outputs(\n",
    "    config = os.path.join(relative_experiment_dir,\"config.json\"),\n",
    "    settings = settings,\n",
    "    inputs = None,\n",
    "    slice = True,\n",
    "    level = 'INFO'\n",
    ")\n",
    "# Silence outputs\n",
    "# current_sweep_outputs.logger.setLevels(console_level='EMPTY')\n",
    "# Load all data\n",
    "current_sweep_outputs.load()\n",
    "\n",
    "print(len(current_sweep_outputs.data),'experiments matched')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "109abcb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conf = Config(\n",
    "#     path = os.path.join(relative_experiment_dir,\"config.json\")\n",
    "# )\n",
    "# ins = Inputs(\n",
    "#     config = conf\n",
    "# )\n",
    "# ins.cast_to_xarray()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "567bc38f",
   "metadata": {},
   "source": [
    "# $R^2$ analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "594f7384",
   "metadata": {},
   "outputs": [],
   "source": [
    "sweep_outputs_slices = []\n",
    "for i in tqdm(range(len(current_sweep_outputs.data)),leave=False,desc='Finding best R2 experiments'):\n",
    "    current_sweep_outputs_slice = current_sweep_outputs.get(i)\n",
    "    current_r2 = current_sweep_outputs_slice.data.r2\n",
    "    if np.max(current_r2) > 0.6:\n",
    "        sweep_outputs_slices.append(current_sweep_outputs_slice)\n",
    "print(len(sweep_outputs_slices),'experiments kept')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adc5111c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# index = 15\n",
    "# sweep_outputs_slice = sweep_outputs_slices[index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f469501",
   "metadata": {},
   "outputs": [],
   "source": [
    "# data_index = 0\n",
    "# sweep_outputs_slice = current_sweep_outputs.get(data_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eb1df35",
   "metadata": {},
   "outputs": [],
   "source": [
    "r2 = sweep_outputs_slice.data.r2\n",
    "alpha_range = current_sweep_outputs.config['experiments'][0]['grid_ranges']['alpha']\n",
    "r2['alpha_range'] = np.linspace(alpha_range['min'],alpha_range['max'],alpha_range['n'],endpoint=True)\n",
    "r2['alpha_range'] = r2['alpha_range'].values\n",
    "beta_range = current_sweep_outputs.config['experiments'][0]['grid_ranges']['beta']\n",
    "r2['beta_range'] = np.linspace(beta_range['min'],beta_range['max'],beta_range['n'],endpoint=True)\n",
    "r2['beta_range'] = r2['beta_range'].values\n",
    "r2.coords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac9d884a",
   "metadata": {},
   "outputs": [],
   "source": [
    "argmax_index = np.unravel_index(np.argmax(r2.values.squeeze()), np.shape(r2.values.squeeze()))\n",
    "plt.figure(figsize=(20,20))\n",
    "plt.imshow(r2, cmap='RdYlGn', interpolation='nearest')\n",
    "plt.scatter(argmax_index[1],argmax_index[0],marker='x',color='black',s=500)\n",
    "plt.yticks(ticks=range(len(r2['alpha_range'])),labels=np.round(r2['alpha_range'].values,2))\n",
    "plt.ylabel('alpha')\n",
    "plt.xticks(ticks=range(len(r2['beta_range'])),labels=np.round(r2['beta_range'].values,2))\n",
    "plt.xlabel('beta')\n",
    "for i in range(len(r2['alpha_range'])):\n",
    "    for j in range(len(r2['beta_range'])):\n",
    "        plt.text(j,i,s=np.round(r2.squeeze().values[i,j],2),fontsize=8)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a140648",
   "metadata": {},
   "source": [
    "# SIM Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5eb796d",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = 0\n",
    "current_data = current_sweep_outputs.get(index)\n",
    "print('# Sweeps:',len(current_sweep_outputs.data))\n",
    "print(current_data.data.intensity.coords.items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa8fb841",
   "metadata": {},
   "outputs": [],
   "source": [
    "ins = Inputs(\n",
    "    config = current_data.config\n",
    ")\n",
    "ins.cast_to_xarray()\n",
    "test_cells = read_file('../data/inputs/DC/test_cells.txt').astype('int32')\n",
    "train_cells = read_file('../data/inputs/DC/train_cells.txt').astype('int32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a1e911f",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_table_error = srmse(\n",
    "    prediction = current_data.data.table.mean('id',dtype='float64'),\n",
    "    ground_truth = ins.data.ground_truth_table\n",
    ")\n",
    "train_table_error = srmse(\n",
    "    prediction = current_data.data.table.mean('id',dtype='float64'),\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    cells = train_cells\n",
    ")\n",
    "test_table_error = srmse(\n",
    "    prediction = current_data.data.table.mean('id',dtype='float64'),\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    cells = test_cells\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc905be7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    all_table_error.values.squeeze().item(),\n",
    "    train_table_error.values.squeeze().item(),\n",
    "    test_table_error.values.squeeze().item()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7090ffd",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_intensity_error = srmse(\n",
    "    prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n",
    "    ground_truth = ins.data.ground_truth_table\n",
    ")\n",
    "train_intensity_error = srmse(\n",
    "    prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    cells = train_cells\n",
    ")\n",
    "test_intensity_error = srmse(\n",
    "    prediction = current_data.get_sample('intensity').mean('id',dtype='float64'),\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    cells = test_cells\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcf96b7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    all_intensity_error.values.squeeze().item(),\n",
    "    train_intensity_error.values.squeeze().item(),\n",
    "    test_intensity_error.values.squeeze().item()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca0eecab",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_table_cp = coverage_probability(\n",
    "    prediction = current_data.data.table,\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    region_mass = 0.95\n",
    ")\n",
    "train_table_cp = coverage_probability(\n",
    "    prediction = current_data.data.table,\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    region_mass = 0.95,\n",
    "    cells = train_cells\n",
    ")\n",
    "test_table_cp = coverage_probability(\n",
    "    prediction = current_data.data.table,\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    region_mass = 0.95,\n",
    "    cells = test_cells\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed196756",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_cp = all_table_cp\n",
    "test_cp = train_table_cp\n",
    "test_cp = test_table_cp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6404451",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    all_table_cp.mean(['origin','destination'],skipna=True).values.item(),\n",
    "    train_table_cp.mean(['origin','destination'],skipna=True).values.item(),\n",
    "    test_table_cp.mean(['origin','destination'],skipna=True).values.item()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c60450c",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_intensity_cp = coverage_probability(\n",
    "    prediction = current_data.get_sample('intensity'),\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    region_mass = 0.95\n",
    ")\n",
    "train_intensity_cp = coverage_probability(\n",
    "    prediction = current_data.get_sample('intensity'),\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    region_mass = 0.95,\n",
    "    cells = train_cells\n",
    ")\n",
    "test_intensity_cp = coverage_probability(\n",
    "    prediction = current_data.get_sample('intensity'),\n",
    "    ground_truth = ins.data.ground_truth_table,\n",
    "    region_mass = 0.95,\n",
    "    cells = test_cells\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30e55924",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    all_intensity_cp.mean(['origin','destination'],skipna=True).values.item(),\n",
    "    train_intensity_cp.mean(['origin','destination'],skipna=True).values.item(),\n",
    "    test_intensity_cp.mean(['origin','destination'],skipna=True).values.item()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45057ba4",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.scatter(\n",
    "    np.exp(current_data.data.log_destination_attraction).mean('id').values.squeeze(),\n",
    "    ins.data.destination_attraction_ts.squeeze()\n",
    ")\n",
    "plt.xlabel(\"Predictions\")\n",
    "plt.ylabel(\"Data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29ae319d",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist(current_data.data.alpha.squeeze().values,bins=30)\n",
    "plt.xlabel('alpha')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48a46a95",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist(current_data.data.beta.squeeze().values,bins=30)\n",
    "plt.xlabel('beta')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a45db478",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plt.hist2d(\n",
    "    current_data.data.beta.squeeze().values,\n",
    "    current_data.data.alpha.squeeze().values,\n",
    "    bins = 30\n",
    ")\n",
    "plt.ylabel('alpha')\n",
    "plt.xlabel('beta')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gensit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
