{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3710jvsc74a57bd0c5e72e1853e3e82b0c26ee06677180d9d6dc5e90f81bb00734beec12adee3b52",
   "display_name": "Python 3.7.10 64-bit ('ssmf': conda)"
  },
  "metadata": {
   "interpreter": {
    "hash": "c5e72e1853e3e82b0c26ee06677180d9d6dc5e90f81bb00734beec12adee3b52"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiprocessing\n",
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "sys.path.append(\"../src\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from sklearn import metrics\n",
    "from tqdm import tqdm, trange\n",
    "import utils\n",
    "from ncp import ncp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide seaborn defaults; plot_nyc_zone and the plot_regime_analysis*\n",
    "# functions below call sns.set_style again with their own style.\n",
    "sns.set_style(\"whitegrid\")\n",
    "sns.set_palette(\"muted\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_outpath(forecast_step, name, params, step, trial_id=0):\n",
    "    \"\"\"Build the output directory path of one run at one step.\n",
    "\n",
    "    The path is ``RESULTDIR/name/forecast_step=<forecast_step>``, followed by\n",
    "    one ``key=value`` directory level per item of ``params`` (in iteration\n",
    "    order), then ``trial_id`` and finally ``step``.\n",
    "    \"\"\"\n",
    "    segments = [RESULTDIR, name, f\"forecast_step={forecast_step}\"]\n",
    "    for key, value in params.items():\n",
    "        segments.append(key + \"=\" + str(value))\n",
    "    segments += [str(trial_id), str(step)]\n",
    "    return os.path.join(*segments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(true, pred):\n",
    "    \"\"\"Return the RMSE between ``pred`` and ``true`` over entries where ``true > 0``.\n",
    "\n",
    "    Returns 0 when ``true`` has no strictly positive entry.\n",
    "    \"\"\"\n",
    "    observed = true > 0\n",
    "    if observed.sum() == 0:\n",
    "        return 0\n",
    "\n",
    "    y_true = true[observed].ravel()\n",
    "    y_pred = pred[observed].ravel()\n",
    "    return np.sqrt(metrics.mean_squared_error(y_true, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_regimes(data, f_step, start_step=0, holdout_step=1e10, trial_id=0,\n",
    "                 k=5, a=0.1, r=100, compression=False):\n",
    "    \"\"\"Load the regime outputs saved at the last evaluated step of a run.\n",
    "\n",
    "    Steps start at ``start_step`` (``f_step`` when left at 0) and advance in\n",
    "    strides of ``f_step`` while they stay below ``data.shape[-1] - f_step``\n",
    "    and do not exceed ``holdout_step``. ``R.txt`` and ``O.txt`` are read from\n",
    "    the output directory of the last such step.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : array-like; only ``data.shape[-1]`` is used.\n",
    "    f_step : stride between steps; also part of the output path.\n",
    "    start_step, holdout_step : first step and largest admissible step.\n",
    "    trial_id, k, a, r, compression : forwarded to ``get_outpath`` as\n",
    "        ``trial_id`` and the ``n_components`` / ``alpha`` / ``n_regimes`` /\n",
    "        ``compression`` path parameters.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (R, O) : the arrays loaded from ``R.txt`` and ``O.txt``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no step satisfies the constraints above.\n",
    "    \"\"\"\n",
    "    params = {\n",
    "        \"n_components\": k,\n",
    "        \"n_regimes\": r,\n",
    "        \"alpha\": a,\n",
    "        \"compression\": compression\n",
    "    }\n",
    "\n",
    "    if start_step == 0:\n",
    "        start_step = f_step\n",
    "\n",
    "    # Only the final step's files are returned, so locate that step first\n",
    "    # instead of re-reading R/O at every intermediate step.\n",
    "    stop = data.shape[-1] - f_step\n",
    "    last_step = None\n",
    "    for step in range(start_step, stop, f_step):\n",
    "        if step > holdout_step:\n",
    "            break\n",
    "        last_step = step\n",
    "\n",
    "    if last_step is None:\n",
    "        raise ValueError(\n",
    "            f\"no step available: range({start_step}, {stop}, {f_step}) \"\n",
    "            f\"with holdout_step={holdout_step} is empty\")\n",
    "\n",
    "    path = get_outpath(f_step, \"assmf\", params, last_step, trial_id)\n",
    "    R = np.loadtxt(path + \"/R.txt\")\n",
    "    O = np.loadtxt(path + \"/O.txt\")\n",
    "\n",
    "    print()\n",
    "    print(\"R:\", R.shape)\n",
    "    print(\"O:\", O.shape)\n",
    "    return R, O"
   ]
  },
  {
   "source": [
    "## Plot seasonal factors"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_factors(f_step, step, start_step=0, holdout_step=1e10, trial_id=0,\n",
    "                 k=5, a=0.1, r=100, compression=False):\n",
    "    \"\"\"Load the two factor matrices saved for a single step of a run.\n",
    "\n",
    "    ``start_step`` and ``holdout_step`` are accepted but not used here.\n",
    "\n",
    "    Returns ``(U, V)``, read from ``U0.txt`` and ``U1.txt`` in the output\n",
    "    directory that ``get_outpath`` gives for this step.\n",
    "    \"\"\"\n",
    "    run_dir = get_outpath(\n",
    "        f_step,\n",
    "        \"assmf\",\n",
    "        dict(n_components=k, n_regimes=r, alpha=a, compression=compression),\n",
    "        step,\n",
    "        trial_id,\n",
    "    )\n",
    "    U, V = (np.loadtxt(run_dir + fname) for fname in (\"/U0.txt\", \"/U1.txt\"))\n",
    "    return U, V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_seasonal_factor(f_step, start_step=0, holdout_step=1e10, trial_id=0,\n",
    "                 k=5, a=0.1, r=100, compression=False, n_steps=None):\n",
    "    \"\"\"Load ``W.npy`` from the output directory of the last evaluated step.\n",
    "\n",
    "    Steps start at ``start_step`` (``f_step`` when left at 0) and advance in\n",
    "    strides of ``f_step`` while they stay below ``n_steps - f_step`` and do\n",
    "    not exceed ``holdout_step``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    f_step : stride between steps; also part of the output path.\n",
    "    start_step, holdout_step : first step and largest admissible step.\n",
    "    trial_id, k, a, r, compression : forwarded to ``get_outpath`` as\n",
    "        ``trial_id`` and the ``n_components`` / ``alpha`` / ``n_regimes`` /\n",
    "        ``compression`` path parameters.\n",
    "    n_steps : total number of time steps. When None, falls back to\n",
    "        ``data.shape[-1]`` of the notebook-global ``data`` (the behaviour\n",
    "        before this parameter existed).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no step satisfies the constraints above.\n",
    "    \"\"\"\n",
    "    params = {\n",
    "        \"n_components\": k,\n",
    "        \"n_regimes\": r,\n",
    "        \"alpha\": a,\n",
    "        \"compression\": compression\n",
    "    }\n",
    "\n",
    "    if n_steps is None:\n",
    "        # Backward-compatible fallback on hidden notebook state; prefer\n",
    "        # passing n_steps explicitly.\n",
    "        n_steps = data.shape[-1]\n",
    "\n",
    "    if start_step == 0:\n",
    "        start_step = f_step\n",
    "\n",
    "    stop = n_steps - f_step\n",
    "    last_step = None\n",
    "    for step in range(start_step, stop, f_step):\n",
    "        if step > holdout_step:\n",
    "            break\n",
    "        last_step = step\n",
    "\n",
    "    if last_step is None:\n",
    "        # Previously this fell through with path == \"\" and tried \"/W.npy\".\n",
    "        raise ValueError(\n",
    "            f\"no step available: range({start_step}, {stop}, {f_step}) \"\n",
    "            f\"with holdout_step={holdout_step} is empty\")\n",
    "\n",
    "    path = get_outpath(f_step, \"assmf\", params, last_step, trial_id)\n",
    "    return np.load(path + \"/W.npy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_seasonal_factor(W, ridx, date_index, week_index, period, cidx,\n",
    "                         freq=\"H\", offset=24*5, fs=14, fn=None, ax=None, smooth=True):\n",
    "    \"\"\"Plot one period of a seasonal component for a single regime.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    W : array indexed as ``W[ridx, time, cidx]``.\n",
    "    ridx : index along the first axis of ``W``.\n",
    "    date_index : indexable of time labels; only used for the printed summary.\n",
    "    week_index, period, offset : the plotted window is\n",
    "        ``[week_index * period + offset, week_index * period + offset + period)``.\n",
    "    cidx : index along the last axis of ``W``.\n",
    "    freq : unused.\n",
    "    fs : base font size; tick labels use ``fs - 2``.\n",
    "    fn : if given, the figure that owns ``ax`` is saved to this path.\n",
    "    ax : axes to draw on; a new 4x4 figure is created when None.\n",
    "    smooth : when truthy, apply a 3-point circular moving average.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    st = week_index * period + offset\n",
    "    ed = st + period\n",
    "    Wi = W[ridx, st:ed]\n",
    "    # `ed` is an exclusive bound and may sit one past the last label.\n",
    "    ed_label = date_index[ed] if ed < len(date_index) else None\n",
    "    print(st, date_index[st], ed, ed_label)\n",
    "\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(4, 4))\n",
    "    else:\n",
    "        # Needed so that `fn` also works with a caller-supplied axes.\n",
    "        fig = ax.figure\n",
    "\n",
    "    target = Wi[:, cidx]\n",
    "\n",
    "    if smooth:\n",
    "        # Tile three copies so the rolling window wraps around the period,\n",
    "        # then keep the middle copy shifted by one to centre the window.\n",
    "        tiled = pd.DataFrame(np.tile(target, 3))\n",
    "        smoothed = tiled.rolling(3).mean()\n",
    "        target = smoothed.values[period+1:period*2+1]\n",
    "\n",
    "    ax.plot(target, color='#072127')\n",
    "    xticks = range(0, len(Wi), 24)\n",
    "    ax.set_xticks(xticks)\n",
    "    # One label per 24 samples; assumes the window spans seven such blocks.\n",
    "    ax.set_xticklabels([\"M\", \"T\", \"W\", \"T\", \"F\", \"S\", \"S\"])\n",
    "    ax.tick_params(axis=\"both\", labelsize=fs-2)\n",
    "\n",
    "    if fn is not None:\n",
    "        fig.savefig(fn)\n",
    "\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_nyc(data, longitude_key, latitude_key, value_key):\n",
    "    \"\"\"Scatter per-station counts over the NYC borough boundaries (EPSG:3857).\n",
    "\n",
    "    Adds ``x``/``y`` columns to ``data`` (mutating it) holding the EPSG:4326 to\n",
    "    EPSG:3857 transform of its longitude/latitude columns, and returns ``ax``.\n",
    "\n",
    "    NOTE(review): as written this cannot run from the cells visible here:\n",
    "    ``Proj``, ``transform`` and ``geopandas`` are never imported, and ``ax``,\n",
    "    ``target`` and ``station_data`` are neither parameters nor locals, so they\n",
    "    would have to exist as notebook globals. ``value_key`` is unused. Confirm\n",
    "    the intent before relying on this function.\n",
    "    \"\"\"\n",
    "    inProj = Proj(init='epsg:4326')\n",
    "    outProj = Proj(init='epsg:3857')\n",
    "    # NYC map data\n",
    "    gdf = geopandas.read_file(geopandas.datasets.get_path(\"nybb\"))\n",
    "    gdf = gdf.to_crs(epsg=3857)\n",
    "    # NOTE(review): `ax` is not defined in this function.\n",
    "    gdf.boundary.plot(ax=ax, alpha=0.5, edgecolor=\"black\")\n",
    "    # around the main area\n",
    "    ax.set_xlim(-8.245*1e6, -8.225*1e6)\n",
    "    ax.set_ylim(4.964*1e6, 4.980*1e6)\n",
    "\n",
    "    x, y = transform(inProj, outProj, data[longitude_key], data[latitude_key])\n",
    "    data[\"x\"] = x\n",
    "    data[\"y\"] = y\n",
    "    # print(station_data.head())\n",
    "    # NOTE(review): x/y were stored on `data`, but the merge reads\n",
    "    # `station_data`; presumably both are meant to be the same frame -- verify.\n",
    "    res = pd.merge(target, station_data, left_on=\"start_station_dim\", right_on=\"index\")\n",
    "    \n",
    "    ax.scatter(x=res.x, y=res.y, c=res[\"count\"], marker='o', cmap=\"Reds\")\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shapefile\n",
    "from shapely.geometry import Polygon\n",
    "from descartes.patch import PolygonPatch\n",
    "from sklearn.preprocessing import minmax_scale\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "def get_lat_lon(sf, shp_dic):\n",
    "    \"\"\"Tabulate the bounding-box centre of every shape record in ``sf``.\n",
    "\n",
    "    ``shp_dic`` maps field names to positions inside a record; only its\n",
    "    ``'LocationID'`` entry is used. Returns a DataFrame with the columns\n",
    "    ``LocationID``, ``longitude`` and ``latitude``, one row per record.\n",
    "    \"\"\"\n",
    "    def centre(bbox):\n",
    "        # bbox[0]/bbox[2] bound the first coordinate, bbox[1]/bbox[3] the second\n",
    "        return (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2\n",
    "\n",
    "    rows = [\n",
    "        (entry.record[shp_dic['LocationID']], *centre(entry.shape.bbox))\n",
    "        for entry in sf.shapeRecords()\n",
    "    ]\n",
    "    return pd.DataFrame(rows, columns=[\"LocationID\", \"longitude\", \"latitude\"])\n",
    "\n",
    "def get_boundaries(sf):\n",
    "    \"\"\"Return the padded extent of all shapes in ``sf``.\n",
    "\n",
    "    The result is ``(min0, max0, min1, max1)``: the first pair spans every\n",
    "    shape's ``bbox[0]``/``bbox[2]`` values, the second pair spans the\n",
    "    ``bbox[1]``/``bbox[3]`` values, each widened by a margin of 0.01.\n",
    "    \"\"\"\n",
    "    margin = 0.01  # buffer to add to the range\n",
    "    boxes = [shape.bbox for shape in sf.iterShapes()]\n",
    "    xs = [value for box in boxes for value in (box[0], box[2])]\n",
    "    ys = [value for box in boxes for value in (box[1], box[3])]\n",
    "    return min(xs) - margin, max(xs) + margin, min(ys) - margin, max(ys) + margin\n",
    " \n",
    "def toPlot_dict(df, k, loc_hash, max_=10):\n",
    "    \"\"\"\n",
    "    Map selected rows of each of the first ``k`` columns to their location.\n",
    "\n",
    "    input\n",
    "    -----\n",
    "    df: table read as ``df[i][ind]`` (column ``i``, row label ``ind``)\n",
    "    k: number of columns to scan\n",
    "    loc_hash: lookup from row label to the key used in the result\n",
    "    max_: if below 1, a threshold (rows with ``df[i] > max_`` are kept);\n",
    "          otherwise passed to ``getmax_rev`` as ``topnum``\n",
    "\n",
    "    output\n",
    "    ------\n",
    "    dict: {loc_hash[ind]: [i, df[i][ind]]}; later columns overwrite earlier\n",
    "          ones when they select the same row\n",
    "\n",
    "    NOTE(review): ``getmax_rev`` is not defined in the cells visible here.\n",
    "    \"\"\"\n",
    "    selected = {}\n",
    "    for col in range(k):\n",
    "        if max_ < 1:\n",
    "            rows = df[df[col] > max_].index\n",
    "        else:\n",
    "            rows = getmax_rev(df[col], topnum=max_, getindex=True)\n",
    "        for row in rows:\n",
    "            selected[loc_hash[row]] = [col, df[col][row]]\n",
    "\n",
    "    return selected\n",
    "\n",
    "# pu_loc_yt=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,105,106,107,108,109,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265]\n",
    "\n",
    "# Drop-off location IDs: 1..265 without 103 and 110.\n",
    "do_loc_yt = [zone for zone in range(1, 266) if zone not in (103, 110)]\n",
    "\n",
    "def plot_nyc_zone(component, hash_list, ax=None, cbar=False):\n",
    "    \"\"\"Draw the taxi-zone shapes, filled according to one value per zone.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    component : array-like with ``.copy()``; entry ``i`` colours the zone\n",
    "        whose LocationID equals ``hash_list[i]``. Values are passed straight\n",
    "        to the ``mako_r`` colormap.\n",
    "    hash_list : list of LocationIDs; zones whose ID is absent are not drawn.\n",
    "    ax : axes to draw on; a new 5x5 figure is created when None.\n",
    "    cbar : accepted for backward compatibility; it has no effect.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (fig, ax) : ``fig`` is None when ``ax`` was supplied by the caller.\n",
    "\n",
    "    Side effects: switches the global seaborn style to \"dark\" and reads\n",
    "    ``../dat/taxi_zones/taxi_zones.shp`` on every call.\n",
    "    \"\"\"\n",
    "    sns.set_style(\"dark\")\n",
    "    cm = sns.cm.mako_r\n",
    "\n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    else:\n",
    "        fig = None\n",
    "\n",
    "    src = shapefile.Reader(\"../dat/taxi_zones/taxi_zones.shp\")\n",
    "    fields_name = [field[0] for field in src.fields[1:]]\n",
    "    shp_dic = dict(zip(fields_name, range(len(fields_name))))\n",
    "\n",
    "    target = component.copy()\n",
    "    edge_rgb = [0.6, 0.6, 0.6]\n",
    "\n",
    "    for shprec in src.iterShapeRecords():\n",
    "        shp = shprec.shape\n",
    "        loc_id = shprec.record[shp_dic['LocationID']]  # station ID\n",
    "        try:\n",
    "            position = hash_list.index(loc_id)\n",
    "        except ValueError:\n",
    "            # Zone has no entry in this component: leave it undrawn. Any\n",
    "            # other error (e.g. a non-list hash_list) is no longer hidden.\n",
    "            continue\n",
    "\n",
    "        face = cm(target[position])\n",
    "\n",
    "        # shp.parts holds the start offset of each part; appending the point\n",
    "        # count turns consecutive pairs into [start, stop) slices. This\n",
    "        # covers the single-part and multi-part cases alike.\n",
    "        bounds = list(shp.parts) + [len(shp.points)]\n",
    "        for i0, i1 in zip(bounds[:-1], bounds[1:]):\n",
    "            polygon = Polygon(shp.points[i0:i1])\n",
    "            patch = PolygonPatch(polygon, fc=face, ec=edge_rgb, lw=0.1, alpha=1, zorder=0)\n",
    "            ax.add_patch(patch)\n",
    "\n",
    "    limits = get_boundaries(src)\n",
    "    ax.axes.xaxis.set_ticklabels([])\n",
    "    ax.axes.yaxis.set_ticklabels([])\n",
    "    # Crop the left/bottom of the full extent by fixed offsets.\n",
    "    ax.set_xlim(limits[0] + 58000, limits[1])\n",
    "    ax.set_ylim(limits[2] + 30000, limits[3])\n",
    "\n",
    "    return fig, ax"
   ]
  },
  {
   "source": [
    "## Plot overview of regimes"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_regime_analysis(data, date_range, cidx=0,\n",
    "        step1=1000, step2=2500,  # Lag of regime1 and 2\n",
    "        k=15, a=0.3, f_step=500, period=168):\n",
    "    \"\"\"Build the single-component regime overview figure.\n",
    "\n",
    "    Layout: a segmentation strip in the leftmost column and, for each of two\n",
    "    regimes, one row holding the seasonal pattern, the pickup map and the\n",
    "    dropoff map of component ``cidx``. ``step1`` / ``step2`` select the saved\n",
    "    factors shown for regime 1 / regime 2.\n",
    "\n",
    "    The regime rows (0 and 5), the plotted weeks (4 and 13) and the seasonal\n",
    "    ``start_step`` (1000) are hard-coded. Reads the notebook-global\n",
    "    ``do_loc_yt``. Returns the figure.\n",
    "    \"\"\"\n",
    "    regime_colors = [\"#1b4f77\", \"#92384D\"]  # regime 1, regime 2\n",
    "\n",
    "    sns.set_style(\"white\")\n",
    "    fig = plt.figure(constrained_layout=False, figsize=(10, 6))\n",
    "    grid = fig.add_gridspec(6, 10)\n",
    "\n",
    "    # Axes for Segmentation\n",
    "    seg_ax = fig.add_subplot(grid[:, 0])\n",
    "    # One row per regime: (seasonal, pickup map, dropoff map)\n",
    "    column_blocks = (slice(1, 4), slice(4, 7), slice(7, 10))\n",
    "    panels = [\n",
    "        [fig.add_subplot(grid[rows, cols]) for cols in column_blocks]\n",
    "        for rows in (slice(None, 3), slice(3, None))\n",
    "    ]\n",
    "\n",
    "    # Segmentation result\n",
    "    R, _ = get_regimes(data, f_step, k=k, a=a)\n",
    "    R[R > 0] = 1\n",
    "    R = R[np.newaxis]\n",
    "    sns.heatmap(R.T, ax=seg_ax, cmap=regime_colors, cbar=False)\n",
    "    tick_positions = list(range(0, len(date_range), 250))\n",
    "    seg_ax.set_yticks(tick_positions)\n",
    "    seg_ax.set_yticklabels([date_range[t].date() for t in tick_positions])\n",
    "    seg_ax.set_xticklabels([])\n",
    "    seg_ax.tick_params(axis='both', labelsize=12)\n",
    "\n",
    "    # Seasonal components: (row of W, week to show) per regime\n",
    "    W = get_seasonal_factor(f_step, start_step=1000, k=k, a=a)\n",
    "    W = W[:, period:]  # remove initialization steps\n",
    "    for (ridx, week_idx), row in zip(((0, 4), (5, 13)), panels):\n",
    "        plot_seasonal_factor(W, ridx, date_range, week_idx, period, cidx,\n",
    "                             ax=row[0], smooth=False)\n",
    "\n",
    "    # Spatial factors at the two chosen steps\n",
    "    factors = [get_factors(f_step, step, k=k, a=a) for step in (step1, step2)]\n",
    "    pu_loc_yt = list(range(1, 266))\n",
    "    for (U, V), row in zip(factors, panels):\n",
    "        plot_nyc_zone(U[:, cidx], pu_loc_yt, ax=row[1], cbar=True)\n",
    "        plot_nyc_zone(V[:, cidx], do_loc_yt, ax=row[2])\n",
    "\n",
    "    # color bar\n",
    "    cmap = sns.cm.mako_r\n",
    "    cbar_axes = [fig.add_axes((0.5, bottom, 0.015, 0.4)) for bottom in (0.71, 0.23)]\n",
    "    images = [\n",
    "        cax.imshow(np.concatenate([U[:, cidx], V[:, cidx]]).reshape(-1, 1),\n",
    "                   aspect=3, cmap=cmap)\n",
    "        for cax, (U, V) in zip(cbar_axes, factors)\n",
    "    ]\n",
    "    for image, cax in zip(images, cbar_axes):\n",
    "        plt.colorbar(image, cax=cax)\n",
    "\n",
    "    # Map captions and tick sizes\n",
    "    for step, row in zip((step1, step2), panels):\n",
    "        day = date_range[step].date()\n",
    "        row[1].set_xlabel(\"Pickup location ({})\".format(day), fontsize=12)\n",
    "        row[2].set_xlabel(\"Dropoff location ({})\".format(day), fontsize=12)\n",
    "        for map_ax in row[1:]:\n",
    "            map_ax.tick_params(axis=\"both\", labelsize=12)\n",
    "\n",
    "    # colored axes: every panel of a regime gets that regime's color\n",
    "    for color, row in zip(regime_colors, panels):\n",
    "        for panel in row:\n",
    "            for side in ('bottom', 'top', 'right', 'left'):\n",
    "                panel.spines[side].set_color(color)\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "source": [
    "## Demo: regime shifts"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_regime_analysis_double(data, date_range, cidx=(0, 1),\n",
    "        step1=1000, step2=2500,  # Lag of regime1 and 2\n",
    "        k=15, a=0.3, f_step=500, period=168):\n",
    "    \"\"\"Build the two-component regime overview figure.\n",
    "\n",
    "    Layout: four 4-column blocks, ordered (regime 1, component 1),\n",
    "    (regime 1, component 2), (regime 2, component 1), (regime 2, component 2).\n",
    "    Each block stacks the pickup map, the dropoff map and the seasonal\n",
    "    pattern; a segmentation strip spans the bottom row.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : forwarded to ``get_regimes``.\n",
    "    date_range : indexable of timestamps (``.date()`` is called on items).\n",
    "    cidx : sequence whose first two entries are the component indices shown\n",
    "        as \"Component 1\" and \"Component 2\". The former default ``0`` could\n",
    "        not be indexed and always failed.\n",
    "    step1, step2 : steps whose saved factors are shown for regime 1 / 2.\n",
    "    k, a, f_step : forwarded to the loaders.\n",
    "    period : length of one seasonal window.\n",
    "\n",
    "    Returns the figure. Reads the notebook-global ``do_loc_yt``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``cidx`` does not provide two indices.\n",
    "    \"\"\"\n",
    "    if np.ndim(cidx) == 0 or len(cidx) < 2:\n",
    "        raise ValueError(\"cidx must hold two component indices, e.g. (0, 3)\")\n",
    "    components = (cidx[0], cidx[1])\n",
    "\n",
    "    regime_colors = [\"#1b4f77\", \"#92384D\"]  # regime 1, regime 2\n",
    "    regime_steps = (step1, step2)\n",
    "    sides = ('bottom', 'top', 'right', 'left')\n",
    "\n",
    "    sns.set_style(\"white\")\n",
    "    fig = plt.figure(constrained_layout=False, figsize=(12, 10))\n",
    "    gs = fig.add_gridspec(13, 16)\n",
    "\n",
    "    # Axes for Segmentation\n",
    "    f_seg = fig.add_subplot(gs[-1, :])  # below\n",
    "\n",
    "    # panels[regime][comp] -> {\"U\": pickup map, \"V\": dropoff map, \"W\": seasonal}\n",
    "    panels = []\n",
    "    for regime in range(2):\n",
    "        blocks = []\n",
    "        for comp in range(2):\n",
    "            left = 4 * (2 * regime + comp)\n",
    "            cols = slice(left, left + 4)\n",
    "            blocks.append({\n",
    "                \"U\": fig.add_subplot(gs[0:4, cols]),\n",
    "                \"V\": fig.add_subplot(gs[4:8, cols]),\n",
    "                \"W\": fig.add_subplot(gs[8:12, cols]),\n",
    "            })\n",
    "        panels.append(blocks)\n",
    "\n",
    "    # Segmentation result\n",
    "    R, _ = get_regimes(data, f_step, k=k, a=a)\n",
    "    R[R > 0] = 1\n",
    "    R = R[np.newaxis]\n",
    "    sns.heatmap(R, ax=f_seg, cmap=regime_colors, cbar=False)\n",
    "    tick_positions = list(range(0, len(date_range), 250))\n",
    "    f_seg.set_xticks(tick_positions)\n",
    "    f_seg.set_xticklabels([date_range[t].date() for t in tick_positions], rotation=45)\n",
    "    f_seg.set_yticklabels([])\n",
    "    f_seg.tick_params(axis='both', labelsize=12)\n",
    "    f_seg.set_xlabel(\"Time (per hour)\", fontsize=12)\n",
    "\n",
    "    # Seasonal components\n",
    "    regime_rows = (0, 5)  # rows of W shown for regime 1 / regime 2\n",
    "    weeks = (4, 13)\n",
    "\n",
    "    W = get_seasonal_factor(f_step, start_step=1000, k=k, a=a)\n",
    "    W = W[:, period:]  # remove initialization steps\n",
    "\n",
    "    # NOTE(review): the plotted week follows the component here, whereas\n",
    "    # plot_regime_analysis pairs it with the regime -- confirm which is meant.\n",
    "    for regime, ridx in enumerate(regime_rows):\n",
    "        for comp, week_idx in enumerate(weeks):\n",
    "            plot_seasonal_factor(W, ridx, date_range, week_idx, period, components[comp],\n",
    "                                 ax=panels[regime][comp][\"W\"], smooth=True)\n",
    "\n",
    "    # Spatial factors at the two chosen steps\n",
    "    factors = [get_factors(f_step, step, k=k, a=a) for step in regime_steps]\n",
    "\n",
    "    # Plot maps\n",
    "    pu_loc_yt = list(range(1, 266))\n",
    "    for regime, (U, V) in enumerate(factors):\n",
    "        for comp, c in enumerate(components):\n",
    "            plot_nyc_zone(U[:, c], pu_loc_yt, ax=panels[regime][comp][\"U\"], cbar=True)\n",
    "            plot_nyc_zone(V[:, c], do_loc_yt, ax=panels[regime][comp][\"V\"])\n",
    "\n",
    "    # color bar: one per block, covering that block's pickup and dropoff values\n",
    "    cmap = sns.cm.mako_r\n",
    "    imw = 0.245\n",
    "    for regime, (U, V) in enumerate(factors):\n",
    "        cbar_axes = [\n",
    "            fig.add_axes((0.055 + imw * (2 * regime + comp), 0.75, 0.015, 0.4))\n",
    "            for comp in range(2)\n",
    "        ]\n",
    "        images = [\n",
    "            cax.imshow(np.concatenate([U[:, c], V[:, c]]).reshape(-1, 1),\n",
    "                       aspect=3, cmap=cmap)\n",
    "            for cax, c in zip(cbar_axes, components)\n",
    "        ]\n",
    "        for image, cax in zip(images, cbar_axes):\n",
    "            plt.colorbar(image, cax=cax)\n",
    "\n",
    "    # Captions, component labels and regime-colored spines\n",
    "    for regime, (step, color) in enumerate(zip(regime_steps, regime_colors)):\n",
    "        day = date_range[step].date()\n",
    "        for comp in range(2):\n",
    "            block = panels[regime][comp]\n",
    "            block[\"U\"].set_xlabel(\"Pickup location ({})\".format(day), fontsize=12)\n",
    "            block[\"V\"].set_xlabel(\"Dropoff location ({})\".format(day), fontsize=12)\n",
    "            for panel in block.values():\n",
    "                panel.set_ylabel(\"Component {}\".format(comp + 1))\n",
    "                for side in sides:\n",
    "                    panel.spines[side].set_color(color)\n",
    "\n",
    "    # ensure y-tick labels to be integers\n",
    "    for regime in range(2):\n",
    "        for comp in range(2):\n",
    "            y_axis = panels[regime][comp][\"W\"].get_yaxis()\n",
    "            y_axis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))\n",
    "\n",
    "    fig.tight_layout()\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RESULTDIR is the root directory read by get_outpath.\n",
    "RESULTDIR = \"../out/nytaxi/\"\n",
    "# NOTE(review): OUTDIR is not used by any cell visible here; the figure below\n",
    "# is written directly under ../out/.\n",
    "OUTDIR    = \"../out/assmf/nytaxi/\"\n",
    "\n",
    "cidx = [0, 3]  # component indices drawn as \"Component 1\" and \"Component 2\"\n",
    "\n",
    "data = utils.load_nytaxi(\"../dat/nytaxi/\")\n",
    "# Hourly timestamps from 2020-01-01 up to, but excluding, 2020-07-01 00:00.\n",
    "date_range = pd.date_range(start=\"2020-01-01\", end=\"2020-07-01\", freq=\"H\")[:-1]\n",
    "fig = plot_regime_analysis_double(data, date_range, k=15, a=0.4, cidx=cidx)\n",
    "fig.savefig(\"../out/regime_overview_double.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}