{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "su8ya_gZNP-0"
      },
      "source": [
        "# Import\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c3qVa1_Bou2o"
      },
      "outputs": [],
      "source": [
        "!pip install  yfinance\n",
        "import yfinance as yf\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_probability as tfp\n",
        "from tensorflow_probability import sts\n",
        "\n",
        "import matplotlib.dates as mdates\n",
        "import math\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from sklearn.decomposition import PCA\n",
        "from sklearn.cluster import KMeans\n",
        "\n",
        "from google.colab import files\n",
        "import csv\n",
        "\n",
        "import os\n",
        "\n",
        "from pandas.plotting import register_matplotlib_converters\n",
        "register_matplotlib_converters()\n",
        "\n",
        "sns.set_context(\"notebook\", font_scale=1.)\n",
        "sns.set_style(\"whitegrid\")\n",
        "%config InlineBackend.figure_format = 'retina'\n",
        "\n",
        "os.path.expanduser('~')\n",
        "sns.reset_defaults()\n",
        "sns.set_context(context='talk',font_scale=0.7)\n",
        "plt.rcParams['image.cmap'] = 'viridis'\n",
        "\n",
        "%matplotlib inline\n",
        "tfd = tfp.distributions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KL2xE9_UPovO"
      },
      "source": [
        "#Data Download Utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EYkAJpSEpBSL"
      },
      "outputs": [],
      "source": [
        "def download_daily_price(ticker_list, start='2015-01-01', end='2020-01-01'):\n",
        "  ticket_str = ' '.join(ticker_list)\n",
        "  data = yf.download(  # or pdr.get_data_yahoo(...\n",
        "    # tickers list or string as well\n",
        "    tickers = ticket_str,\n",
        "\n",
        "    # use \"period\" instead of start/end\n",
        "    # valid periods: 1d,5d,1mo,3mo,6mo,1y,2y,5y,10y,ytd,max\n",
        "    # (optional, default is '1mo')\n",
        "    # period = time_period,\n",
        "    start = start,\n",
        "    end = end,\n",
        "\n",
        "    # fetch data by interval (including intraday if period \u003c 60 days)\n",
        "    # valid intervals: 1m,2m,5m,15m,30m,60m,90m,1h,1d,5d,1wk,1mo,3mo\n",
        "    # (optional, default is '1d')\n",
        "    interval = \"1d\",\n",
        "\n",
        "    # group by ticker (to access via data['SPY'])\n",
        "    # (optional, default is 'column')\n",
        "    group_by = 'ticker',\n",
        "\n",
        "    # adjust all OHLC automatically\n",
        "    # (optional, default is False)\n",
        "    auto_adjust = True,\n",
        "\n",
        "    # download pre/post regular market hours data\n",
        "    # (optional, default is False)\n",
        "    # prepost = True,\n",
        "\n",
        "    # use threads for mass downloading? (True/False/Integer)\n",
        "    # (optional, default is True)\n",
        "    threads = True,\n",
        "\n",
        "    # proxy URL scheme use use when downloading?\n",
        "    # (optional, default is None)\n",
        "    proxy = None\n",
        "  )\n",
        "  return data\n",
        "\n",
        "def get_open_price_dict(ticker_list, start_date='2015-01-01'):\n",
        "  # Returns a dict {ticker: zip([timestamps], [Open_prices])}\n",
        "  data = download_daily_price(ticker_list=ticker_list, start=start_date)\n",
        "  open_price_dict = {\n",
        "      ticker: data[ticker]['Open'].values.tolist() for ticker in ticker_list}\n",
        "  return open_price_dict, data.index.tolist()\n",
        "\n",
        "def get_nasdaq_stock_price(num_stocks=1000):\n",
        "  # upload nasdaq listed stock CSV, which is downloaded from here\n",
        "  # https://www.nasdaq.com/market-activity/stocks/screener\n",
        "  uploaded_nasdaq_csv = files.upload()\n",
        "  for v in uploaded_nasdaq_csv.values():\n",
        "    stock_list = v.decode().split('\\r\\n')\n",
        "  ticker_list = [stock.split(',')[0] for stock in stock_list]\n",
        "  # Get the first num_stocks stocks from the list for study.\n",
        "  open_price_dict, date_range = get_open_price_dict(ticker_list[1:num_stocks], start_date='2019-01-01')\n",
        "  return open_price_dict, date_range"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yc_cKVbRT6p3"
      },
      "outputs": [],
      "source": [
        "def clean_stock_price_list(open_price_dict):\n",
        "  # Remove stocks with nan as price using AAPL as a baseline.\n",
        "  # AAPL also has a few nan prices depending on the date range.\n",
        "  # We remove these dates from the study.\n",
        "  nan_index = []\n",
        "  for i, x in enumerate(open_price_dict['AAPL']):\n",
        "    if math.isnan(float(x)):\n",
        "      nan_index.append(i)\n",
        "  print(nan_index)\n",
        "\n",
        "  cleaned_price_list = []\n",
        "  skip_ticker_list = []\n",
        "  for ticker, price_list in open_price_dict.items():\n",
        "    skip = False\n",
        "    for i, x in enumerate(open_price_dict[ticker]):\n",
        "      if i not in nan_index and math.isnan(float(x)):\n",
        "        # A ticker with more nan than AAPL. skip\n",
        "        skip_ticker_list.append(ticker)\n",
        "        skip = True\n",
        "        break\n",
        "    if not skip:      \n",
        "      cleaned_price_list.append(np.delete(np.array(price_list), nan_index))\n",
        "  print('number of all stocks : ' + str(len(open_price_dict)))\n",
        "  print('number of cleaned stocks: ' + str(len(cleaned_price_list)))\n",
        "  print('number of skipped stocks: ' + str(len(skip_ticker_list)))\n",
        "  return cleaned_price_list, skip_ticker_list, nan_index    "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RPZd2nkGMtWG"
      },
      "source": [
        "# Visualization Utils\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ojSnq4BIMzZv"
      },
      "outputs": [],
      "source": [
        "def plot_coefficient_heatmap(correlation_matrix, ticker_list):\n",
        "  # correlation_matrix is a ndarray with shape (n, 2*n), n being the number of tickers.\n",
        "  print(correlation_matrix.shape)\n",
        "  num_ticker = correlation_matrix.shape[0]\n",
        "  order = correlation_matrix.shape[1]//num_ticker\n",
        "  fig, axes = plt.subplots(nrows=1, ncols=order, figsize=(3*order, 3))\n",
        "\n",
        "  for i in range(order):\n",
        "    df = pd.DataFrame(correlation_matrix[:, i*num_ticker:(i+1)*num_ticker], columns=ticker_list, index=ticker_list)\n",
        "    #df_order1 = pd.DataFrame(correlation_matrix[:, num_ticker:], columns=ticker_list, index=ticker_list)\n",
        "    sns.heatmap(df, ax=axes[i], cbar=True)\n",
        "    #sns.heatmap(df_order1, ax=axes[1], cbar=True)\n",
        "  fig.tight_layout()\n",
        "  plt.show()\n",
        "\n",
        "def plot_hist(list_prices, n_bins=20):\n",
        "  fig, axs = plt.subplots(1, 2, sharey=True, tight_layout=True)\n",
        "  x = open_price_dict['GOOG']\n",
        "  y = open_price_dict['AAPL']\n",
        "  # We can set the number of bins with the `bins` kwarg\n",
        "  axs[0].hist(x, bins=n_bins)\n",
        "  axs[1].hist(y, bins=n_bins)\n",
        "\n",
        "def format_and_print_list(input_list):\n",
        "  print(', '.join('{:.3f}'.format(element) for element in input_list)) \n",
        "\n",
        "def format_and_print_2d_params(input_array, reverse_order):\n",
        "  output_str_list = []\n",
        "  for ticker_list in input_array:\n",
        "    if reverse_order:\n",
        "      for element in reversed(ticker_list):\n",
        "        output_str_list.append(element)\n",
        "    else:\n",
        "      for element in ticker_list:\n",
        "        output_str_list.append(element)\n",
        "\n",
        "  print(', '.join('{:.3f}'.format(element) for element in output_str_list))\n",
        "\n",
        "def plot_ticker_ts(date_range, open_price_dict, ticker, forecast_horizon):\n",
        "  prices = open_price_dict[ticker][:-forecast_horizon]\n",
        "\n",
        "  plot_x_loc = mdates.AutoDateLocator()  # every two months.\n",
        "  plot_x_fmt = mdates.DateFormatter('%Y/%m/%d')\n",
        "\n",
        "\n",
        "  fig = plt.figure(figsize=(12, 6))\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  ax.plot(date_range, prices, lw=2, label=\"training data\")\n",
        "  ax.xaxis.set_major_locator(plot_x_loc)\n",
        "  ax.xaxis.set_major_formatter(plot_x_fmt)\n",
        "  ax.set_ylabel(\"price\")\n",
        "  ax.set_xlabel(\"day\")\n",
        "  fig.suptitle(ticker + \" price\", fontsize=15)\n",
        "  ax.text(0.99, .02, \"Source: Yahoo finance\", transform=ax.transAxes, horizontalalignment=\"right\", alpha=0.5)\n",
        "  fig.autofmt_xdate()\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YGZEo82x7v38"
      },
      "outputs": [],
      "source": [
        "def plot_forecast(x, y, forecast_horizon,\n",
        "                  forecast_mean, forecast_scale, forecast_samples,\n",
        "                  title, ax, x_locator=None, x_formatter=None):\n",
        "  \"\"\"Plot a forecast distribution against the 'true' time series.\"\"\"\n",
        "  colors = sns.color_palette()\n",
        "  c1, c2 = colors[0], colors[1]\n",
        "  fig = plt.figure(figsize=(12, 6))\n",
        "\n",
        "  ax.plot(x, y, lw=2, color=c1, label='ground truth')\n",
        "  ax.plot(forecast_horizon, forecast_samples.T, lw=1, color=c2, alpha=0.1)\n",
        "  ax.plot(forecast_horizon, forecast_mean, lw=2, ls='--', color=c2, label='forecast')\n",
        "  ax.fill_between(forecast_horizon,\n",
        "                  forecast_mean-2*forecast_scale,\n",
        "                  forecast_mean+2*forecast_scale, color=c2, alpha=0.2)\n",
        "\n",
        "  ymin, ymax = min(np.min(forecast_samples), np.min(y)), max(np.max(forecast_samples), np.max(y))\n",
        "  yrange = ymax-ymin\n",
        "  ax.set_ylim([ymin - yrange*0.1, ymax + yrange*0.1])\n",
        "  ax.set_title(\"{}\".format(title))\n",
        "  ax.legend()\n",
        "\n",
        "  if x_locator is not None:\n",
        "    ax.xaxis.set_major_locator(x_locator)\n",
        "    ax.xaxis.set_major_formatter(x_formatter)\n",
        "    fig.autofmt_xdate()\n",
        "\n",
        "  return fig, ax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gRmsGXcCYjtc"
      },
      "source": [
        "# Models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g47CN_rRle_8"
      },
      "outputs": [],
      "source": [
        "def build_model(observed_time_series, ar_order):\n",
        "  #day_of_week_effect = sts.Seasonal(\n",
        "  #    num_seasons=7,\n",
        "  #    observed_time_series=observed_time_series,\n",
        "  #    name='day_of_week_effect')\n",
        "  # trend = sts.LocalLinearTrend(observed_time_series=observed_time_series)  \n",
        "  autoregressive = sts.Autoregressive(\n",
        "      order=ar_order,\n",
        "      observed_time_series=observed_time_series,\n",
        "      name='autoregressive')\n",
        "  #model = sts.Sum([day_of_week_effect,\n",
        "  #                 autoregressive],\n",
        "  #                 observed_time_series=observed_time_series)\n",
        "  model = sts.Sum([autoregressive],\n",
        "                   observed_time_series=observed_time_series)  \n",
        "  return model\n",
        "\n",
        "def train_model(training_price_ts, stock_model, num_train_steps):\n",
        "  # Build the variational surrogate posteriors `qs`.\n",
        "  variational_posteriors = tfp.sts.build_factored_surrogate_posterior(\n",
        "      model=stock_model)\n",
        "  optimizer = tf.optimizers.Adam(learning_rate=.1)\n",
        "  # Using fit_surrogate_posterior to build and optimize the variational loss function.\n",
        "  @tf.function(experimental_compile=True)\n",
        "  def train():\n",
        "    elbo_loss_curve = tfp.vi.fit_surrogate_posterior(\n",
        "      target_log_prob_fn=stock_model.joint_log_prob(\n",
        "          observed_time_series=training_price_ts),\n",
        "      surrogate_posterior=variational_posteriors,\n",
        "      optimizer=optimizer,\n",
        "      num_steps=num_train_steps)\n",
        "    return elbo_loss_curve\n",
        "  elbo_loss_curve = train()\n",
        "  plt.plot(elbo_loss_curve)\n",
        "  plt.show()\n",
        "  return variational_posteriors  \n",
        "\n",
        "def train_for_stocks(open_price_dict, ticker_list, ar_order):\n",
        "  ts_list = []\n",
        "  for ticker in ticker_list:\n",
        "    ts_list.append(np.array(open_price_dict[ticker]))\n",
        "  training_price_ts = np.stack(ts_list, axis=0)\n",
        "\n",
        "  stock_model = build_model(training_price_ts, ar_order=ar_order)\n",
        "  variational_posteriors = train_model(training_price_ts, stock_model, num_train_steps=250)\n",
        "  return training_price_ts, stock_model, variational_posteriors  \n",
        "\n",
        "def train_for_clusters(cleaned_price_list, n_clusters=8, ar_order=5):\n",
        "  X = np.array(cleaned_price_list)  # X.shape (num_stock, num_days)\n",
        "  print(X.shape) # (588, 252)\n",
        "  clustering = KMeans(n_clusters=n_clusters)\n",
        "  clustering.fit(X)\n",
        "  cluster_mean_list = clustering.cluster_centers_\n",
        "  print(cluster_mean_list.shape) # (8, 252)\n",
        "\n",
        "  stock_model = build_model(cluster_mean_list, ar_order=ar_order)\n",
        "  variational_posteriors = train_model(cluster_mean_list, stock_model, num_train_steps=250)\n",
        "  return cluster_mean_list, stock_model, variational_posteriors    "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CmOJeZYqyX4g"
      },
      "source": [
        "# Import Data and Train Model\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VRwSrBhq2S8H"
      },
      "source": [
        "## Train on stocks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i4hunZu4cpLv"
      },
      "outputs": [],
      "source": [
        "short_ticker_list = ['GOOG', 'COKE', 'LOW', 'JPM', 'BBBY']\n",
        "#tech_ticker_list = ['GOOG', 'AAPL', 'MSFT', 'AMZN', 'FB', 'IBM', 'NVDA', 'INTC', 'NFLX', 'TWTR']\n",
        "\n",
        "open_price_dict, date_list = get_open_price_dict(short_ticker_list)\n",
        "\n",
        "training_price_ts, stock_model, variational_posteriors = train_for_stocks(open_price_dict, short_ticker_list, ar_order=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NtNQc08P2gfc"
      },
      "source": [
        "## Train on stock clusters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I11w5XvsRtLB"
      },
      "outputs": [],
      "source": [
        "open_price_dict, date_list = get_nasdaq_stock_price(num_stocks=1000)\n",
        "cleaned_price_list, skip_ticker_list, nan_index = clean_stock_price_list(open_price_dict)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mClKQO6w_KSA"
      },
      "outputs": [],
      "source": [
        "training_price_ts, stock_model, variational_posteriors = train_for_clusters(cleaned_price_list, n_clusters=8, ar_order=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IDNrKIAA5pNI"
      },
      "source": [
        "# Inference and Plot"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wU3v5pvK5UUM"
      },
      "outputs": [],
      "source": [
        "FORECAST_HORIZON = 5\n",
        "\n",
        "# Draw samples from the variational posterior.\n",
        "def get_inferred_params(stock_model, variational_posteriors):\n",
        "  q_samples_ = variational_posteriors.sample(20)\n",
        "  param_dict = {}\n",
        "  print(\"Inferred parameters:\")\n",
        "  for param in stock_model.parameters:\n",
        "    sample_mean = np.mean(q_samples_[param.name], axis=0)\n",
        "    sample_std = np.std(q_samples_[param.name], axis=0)\n",
        "    print(\"{}: {} +- {}\".format(param.name, sample_mean, sample_std))\n",
        "    param_dict[param.name] = (sample_mean, sample_std)\n",
        "  return param_dict, q_samples_ \n",
        "\n",
        "def get_forecast(stock_model, hist_ts, q_samples_, num_sample_traj):\n",
        "  forecast_dist = tfp.sts.forecast(\n",
        "    model=stock_model,\n",
        "    observed_time_series=hist_ts,  # training_price_ts[:, :10],\n",
        "    parameter_samples=q_samples_,\n",
        "    num_steps_forecast=FORECAST_HORIZON)\n",
        "  # print(forecast_dist.mean().numpy().shape) (num_ticker, forecast_horizon, 1)\n",
        "  (forecast_mean, forecast_scale, forecast_samples) = (\n",
        "      forecast_dist.mean().numpy()[:, :, 0],\n",
        "      forecast_dist.stddev().numpy()[:, :, 0],\n",
        "      forecast_dist.sample(num_sample_traj).numpy()[:, :, :, 0])\n",
        "  return (forecast_mean, forecast_scale, forecast_samples)  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7h8bkCOKy-w2"
      },
      "source": [
        "## Get Inferred Params\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YSQXL07KgU6A"
      },
      "outputs": [],
      "source": [
        "param_dict, q_samples_ = get_inferred_params(stock_model, variational_posteriors)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3qW-zOAFziBd"
      },
      "outputs": [],
      "source": [
        "format_and_print_2d_params(param_dict['autoregressive/_coefficients'][0], reverse_order=True)\n",
        "format_and_print_list(param_dict['autoregressive/_level_scale'][0])\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jol99nNOzCAu"
      },
      "source": [
        "## Plot Forecast"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0FoHB9GFDMrq"
      },
      "outputs": [],
      "source": [
        "def plot_all_forecast(series_list, stock_model, training_price_ts, date_list):\n",
        "  forecast_mean, forecast_scale, forecast_samples = get_forecast(stock_model, training_price_ts[:, plot_start_date_index:forecast_start_date_index], q_samples_, num_sample_traj)\n",
        "  # series_list is a list of ticker names or clusters id_strs.\n",
        "  num_series = len(series_list)\n",
        "  fig, axes = plt.subplots(nrows=num_series, ncols=1, figsize=(20, 4*num_series))\n",
        "\n",
        "  for i in range(num_series):\n",
        "    fig, ax = plot_forecast(x=date_list[plot_start_date_index:plot_end_date_index], \n",
        "                            y=tf.squeeze(training_price_ts[i:i+1, plot_start_date_index:plot_end_date_index], axis=0),\n",
        "                            forecast_horizon=date_list[forecast_start_date_index:plot_end_date_index],\n",
        "                            forecast_mean=forecast_mean[i],\n",
        "                            forecast_scale=forecast_scale[i],\n",
        "                            forecast_samples=forecast_samples[:, i, :],\n",
        "                            title=\"forecast \" + series_list[i], ax=axes[i])\n",
        "  fig.tight_layout()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "74fHq0VCp7r1"
      },
      "outputs": [],
      "source": [
        "forecast_start_date_index = 50\n",
        "HISTORY_WINDOW = 5\n",
        "plot_start_date_index =  forecast_start_date_index - HISTORY_WINDOW\n",
        "plot_end_date_index = forecast_start_date_index + FORECAST_HORIZON\n",
        "num_sample_traj = 9\n",
        "format_and_print_2d_params(training_price_ts[:, plot_start_date_index:forecast_start_date_index], reverse_order=False)\n",
        "print(date_list[plot_start_date_index:forecast_start_date_index])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WfHCuLnfAjou"
      },
      "outputs": [],
      "source": [
        "plot_all_forecast(['cluster' + str(i) for i in range(8)], stock_model, training_price_ts, date_list)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Copy of stock-preprocess.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1P4Tme3K0EPcFlUxWttQxqS-kaZYrgPZZ",
          "timestamp": 1622777610316
        },
        {
          "file_id": "1apM2djewu14hBQvG7w8BzayREFZKCcvP",
          "timestamp": 1622777536202
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
