{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15861b73-05c6-4ce9-8c8b-a22eaa7e3676",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JLD2\n",
    "using LinearAlgebra\n",
    "using Plots\n",
    "using StatsPlots\n",
    "using Colors\n",
    "using DataFrames\n",
    "using BenchmarkTools\n",
    "include(\"../methods/TensorRing/NTR.jl\")\n",
    "include(\"../methods/TensorRing/TR.jl\")\n",
    "include(\"../methods/TensorTcuker/NTD.jl\")\n",
    "include(\"../methods/TensorTrain/NNTF.jl\")\n",
    "include(\"../methods/IgTensors/decomp.jl\");\n",
    "include(\"../methods/IgTensors/get_msk.jl\");\n",
    "pyplot()\n",
    "iter_max = 20;        # number of repeated runs per rank for the NTT/NTD baselines\n",
    "plt_size = (500,500); # default size of each fit panel\n",
    "\n",
    "# Result files: rank set 1 = uniform ranks, rank set 2 = non-uniform ranks\n",
    "path2resultsNTT_rankset1 = \"results/results_NTT_1.jld2\"\n",
    "path2resultsNTD_rankset1 = \"results/results_NTD_1.jld2\"\n",
    "path2resultsNTT_rankset2 = \"results/results_NTT_2.jld2\"\n",
    "path2resultsNTR_rankset2 = \"results/results_NTR_2.jld2\"\n",
    "path2resultsNTD_rankset2 = \"results/results_NTD_2.jld2\"\n",
    "path2resultsMBA = \"results/results_MBA.jld2\"\n",
    "\n",
    "# Make sure the output directories used by `save` and `savefig` below exist.\n",
    "mkpath(\"results\")\n",
    "mkpath(\"reconst_imgs\")\n",
    "\n",
    "# Moving average of `vs` over a sliding window of `n` consecutive samples.\n",
    "# Returns length(vs) - n + 1 values (empty when n > length(vs)).\n",
    "function ma(vs, n)\n",
    "    n >= 1 || throw(ArgumentError(\"window size n must be >= 1, got $n\"))\n",
    "    return [sum(@view vs[i:(i+n-1)])/n for i in 1:(length(vs)-(n-1))]\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f8fc690-d397-45db-b505-67547d69c773",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the traffic tensor T from the PeMS mainline VDS 716331 file (fetch only the \"T\" entry).\n",
    "T = load(\"../../data/traffic/PeMS/MainlineVDS716331/MainlineVDS716331.jld2\", \"T\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8081daa1-3e3f-48a4-89d2-4d688218b00f",
   "metadata": {},
   "outputs": [],
   "source": [
    "verbose = false\n",
    "\n",
    "# Many-body approximations of increasing order; each call returns\n",
    "# (reconstruction, theta, eta).\n",
    "T1, theta_1, eta_1 = manybody_app(T, 1; verbose);\n",
    "T2, theta_2, eta_2 = manybody_app(T, 2; verbose);\n",
    "T3, theta_3, eta_3 = manybody_app(T, 3; verbose);\n",
    "\n",
    "# Cyclic two-body approximation (the argument 4 presumably equals ndims(T)).\n",
    "intract_cyc = get_intracts_for_cyc_2_body_approximation(4);\n",
    "Tcyc, theta_cyc, eta_cyc = manybody_app(T, intract_cyc; verbose);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98eb5289-2355-4c80-bb4d-f48bd3dcd250",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstructions shown as (day x hour) heatmaps for one fixed index of modes 3 and 4 (Fig. 12)\n",
    "trg_mn   = 1\n",
    "trg_lane = 1\n",
    "cmax     = 100   # upper end of the colour scale\n",
    "xticks = ([1, 6, 12, 18, 24], [0, 5, 11, 17, 23])   # 1-based positions, 0-based hour labels\n",
    "yticks = [1, 7, 14, 21, 28]\n",
    "\n",
    "h    = heatmap(T[:, :, trg_mn, trg_lane];    clim=(0, cmax), cbar=true,  title=\"Original\")\n",
    "h1   = heatmap(T1[:, :, trg_mn, trg_lane];   clim=(0, cmax), cbar=false, title=\"One-body \\n approximation\")\n",
    "h2   = heatmap(T2[:, :, trg_mn, trg_lane];   clim=(0, cmax), cbar=false, title=\"Two-body \\n approximation\")\n",
    "hcyc = heatmap(Tcyc[:, :, trg_mn, trg_lane]; clim=(0, cmax), cbar=false, title=\"Cyclic two-body \\n approximation\")\n",
    "h3   = heatmap(T3[:, :, trg_mn, trg_lane];   clim=(0, cmax), cbar=false, title=\"Three-body \\n approximation\")\n",
    "\n",
    "pl = plot(h1, hcyc, h2, h3, h;\n",
    "          layout=(1, 5), size=(2*768 + 10, 2*200), frame=:box,\n",
    "          xlabel=\"Hour\", ylabel=\"Day\", yflip=true, xticks=xticks, yticks=yticks,\n",
    "          colorbar_title=\"Speed (mph)\", colorbar_titlefontsize=13, colorbar_tickfontsize=13,\n",
    "          tickfontsize=15, labelfontsize=15)\n",
    "savefig(pl, \"reconst_imgs/traffic_hmap.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6fdc83d-ab72-4296-acc6-82aeefab6440",
   "metadata": {},
   "outputs": [],
   "source": [
    "ms = Any[1, \"cyc\", 2, 3]  # models for run_MBA: one-, cyclic two- (\"cyc\"), two- and three-body\n",
    "\"\"\"\n",
    "    run_MBA(T, ms)\n",
    "\n",
    "Fit one many-body approximation of `T` for every model in `ms` and return a Dict with,\n",
    "per model, a one-element vector of fits (`1 - norm(Tm - T)/norm(T)`), the parameter\n",
    "count and the wall time of the run.\n",
    "\"\"\"\n",
    "function run_MBA(T,ms)\n",
    "    order = ndims(T)\n",
    "    dims  = size(T)\n",
    "    results = Dict(\"fits\"     => Vector{Vector{Float64}}(),\n",
    "                   \"n_params\" => Vector{Float64}(),\n",
    "                   \"runtimes\" => Vector{Float64}())\n",
    "    for m in ms\n",
    "        fits     = []\n",
    "        intracts = get_intracts_for_m_body_approximation(m, order)\n",
    "        n_params = get_n_params_from_intracts(intracts, dims)\n",
    "        # A single run is timed, unlike the baselines below, which are repeated iter_max times.\n",
    "        runtime = @elapsed begin\n",
    "            Tm, _, _ = manybody_app(T, intracts)\n",
    "            push!(fits, 1 - norm(Tm - T) / norm(T))\n",
    "            push!(results[\"fits\"], fits)\n",
    "        end\n",
    "        push!(results[\"runtimes\"], runtime)\n",
    "        push!(results[\"n_params\"], n_params)\n",
    "        @show (m, n_params, runtime, fits[1])\n",
    "    end\n",
    "    return results\n",
    "end\n",
    "\n",
    "# Run and persist the many-body results so the plotting cells can reload them.\n",
    "results_MBA = run_MBA(T, ms)\n",
    "save(path2resultsMBA, results_MBA)\n",
    "println(\"saved in \", path2resultsMBA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa500a28-b888-47c4-b091-8493097caa3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tensor-train rank sets: uniform ranks (Rs1) and non-uniform ranks (Rs2)\n",
    "Rs1 = [fill(r, 4) for r in (1, 3, 4, 5, 6, 7, 8, 9, 11, 15, 18, 20)]\n",
    "Rs2 = [[1,1,1,1], [1,2,1,1], [1,2,2,1], [2,2,2,1], [4,2,2,1], [6,2,2,1], [6,3,2,1], [8,3,2,1],\n",
    "       [10,3,2,1], [10,4,2,1], [15,4,2,1], [20,5,2,1], [20,8,2,1], [30,10,2,1], [30,15,8,1]]\n",
    "\n",
    "\"\"\"\n",
    "    run_NTT(T, Rs; iter_max = 20)\n",
    "\n",
    "Run `NNTF` on `T` `iter_max` times for every rank in `Rs` and return a Dict with the\n",
    "fit of every run (`1 - norm(T_TT - T)/norm(T)`), the rank, the parameter count and the\n",
    "total wall time per rank.\n",
    "\"\"\"\n",
    "function run_NTT(T, Rs; iter_max = 20)\n",
    "    dims = size(T)\n",
    "    results = Dict(\"fit\"      => Vector{Vector{Float64}}(),\n",
    "                   \"rank\"     => Vector{Vector{Int64}}(),\n",
    "                   \"n_params\" => Vector{Float64}(),\n",
    "                   \"runtimes\" => Vector{Float64}())\n",
    "    for R in Rs\n",
    "        n_params = get_n_params_train(R, dims)\n",
    "        fits = []\n",
    "        runtime = @elapsed begin\n",
    "            for _ in 1:iter_max\n",
    "                T_TT = reconst_train(NNTF(T, R))\n",
    "                push!(fits, 1 - norm(T_TT - T) / norm(T))\n",
    "            end\n",
    "            push!(results[\"fit\"], fits)\n",
    "        end\n",
    "        push!(results[\"runtimes\"], runtime)\n",
    "        push!(results[\"rank\"], R)\n",
    "        push!(results[\"n_params\"], n_params)\n",
    "        @show (R, n_params, runtime, maximum(fits))\n",
    "    end\n",
    "    return results\n",
    "end\n",
    "\n",
    "# Uniform rank set\n",
    "results_NTT = run_NTT(T, Rs1; iter_max)\n",
    "save(path2resultsNTT_rankset1, results_NTT)\n",
    "println(\"saved in \", path2resultsNTT_rankset1)\n",
    "\n",
    "# Non-uniform rank set (results_NTT ends up holding this run)\n",
    "results_NTT = run_NTT(T, Rs2; iter_max)\n",
    "save(path2resultsNTT_rankset2, results_NTT)\n",
    "println(\"saved in \", path2resultsNTT_rankset2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "155f48c9-0d70-4a41-b4d1-d1dad8d55ed0",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tucker rank sets: uniform ranks (Rs1) and non-uniform ranks (Rs2)\n",
    "Rs1 = [fill(r, 4) for r in (1, 3, 4, 5, 6, 7, 8, 9, 11)]\n",
    "Rs2 = [[1,1,1,1], [6,2,2,1], [6,5,2,1], [8,5,4,2], [10,7,4,2], [12,9,4,2], [16,12,6,2], [18,15,8,2], [20,18,11,3]]\n",
    "\"\"\"\n",
    "    run_NTD(T, Rs; iter_max = 20, max_iter = 1500, tol = 2.0e-5)\n",
    "\n",
    "Run `NTD` on `T` with random initialisation `iter_max` times for every Tucker rank in\n",
    "`Rs`. Returns a Dict with the fit of every run (`1 - norm(T - Xr)/norm(T)`), the rank,\n",
    "the parameter count and the total wall time per rank. `max_iter` and `tol` are\n",
    "forwarded to `NTD`; their defaults are the values previously hard-coded here.\n",
    "\"\"\"\n",
    "function run_NTD(T, Rs; iter_max = 20, max_iter = 1500, tol = 2.0e-5)\n",
    "    results = Dict(\"fit\" => Vector{Vector{Float64}}() ,\"rank\" => Vector{Vector{Int64}}(), \"n_params\" => Vector{Float64}(), \"runtimes\" => Vector{Float64}() )\n",
    "    J = size(T)  # loop-invariant, so computed once\n",
    "    for R in Rs\n",
    "        n_params = get_n_params_tucker(R,J)\n",
    "        fits = Float64[]\n",
    "        runtime = @elapsed begin\n",
    "            for _ in 1:iter_max\n",
    "                _, _, Xr = NTD(T, R; init_method=\"random\", verbose_interval=250, max_iter, tol)\n",
    "                push!(fits, 1 - norm(T - Xr) / norm(T))\n",
    "            end\n",
    "            push!(results[\"fit\"], fits)\n",
    "        end\n",
    "        push!(results[\"runtimes\"], runtime)\n",
    "        push!(results[\"rank\"], R)\n",
    "        push!(results[\"n_params\"], n_params)\n",
    "        @show (R, n_params, runtime, maximum(fits))\n",
    "    end\n",
    "    return results\n",
    "end\n",
    "# Uniform rank set\n",
    "results_NTD = run_NTD(T, Rs1; iter_max)\n",
    "save(path2resultsNTD_rankset1, results_NTD)\n",
    "println(\"saved in \", path2resultsNTD_rankset1)\n",
    "\n",
    "# Non-uniform rank set (results_NTD ends up holding this run)\n",
    "results_NTD = run_NTD(T, Rs2; iter_max)\n",
    "save(path2resultsNTD_rankset2, results_NTD)\n",
    "println(\"saved in \", path2resultsNTD_rankset2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "366947a6-6f69-4661-91ab-337166cb90f6",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Load results and plot fits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cec498b7-64bb-434e-ab7e-a358fa621498",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved results so the plotting cells do not depend on the compute cells.\n",
    "results_NTT1, results_NTD1 = load(path2resultsNTT_rankset1), load(path2resultsNTD_rankset1)\n",
    "results_NTT2, results_NTD2 = load(path2resultsNTT_rankset2), load(path2resultsNTD_rankset2)\n",
    "results_MBA = load(path2resultsMBA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "601b692d-f384-4aca-a578-67deef850fa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot config: marker, line style and colour per method\n",
    "mc = :transparent   # marker fill; the outline colour is set separately via msc\n",
    "\n",
    "# marker shapes\n",
    "ms_MBA = :circle\n",
    "ms_NTT = :rtriangle\n",
    "ms_NTD = :heptagon\n",
    "\n",
    "# line styles\n",
    "lt_MBA = :solid\n",
    "lt_NTT = :dot\n",
    "lt_NTD = :dashdot\n",
    "\n",
    "# line colours\n",
    "lc_MBA = :red\n",
    "lc_NTD = :yellow\n",
    "lc_NTT = :pink"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64daad3e-dbd6-4b9e-9a51-5a5ff34ec3a0",
   "metadata": {},
   "source": [
    "## Instability of the baselines (Fig. 11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913d94ce-3d58-48a9-ad60-62e8b08e268e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Plots.scalefontsizes(0.5)\n",
    "\"\"\"\n",
    "    get_bar_plt(results, method; bar_width = 5)\n",
    "\n",
    "Box plot of the per-run fits in `results[\"fit\"]`, one box per entry of\n",
    "`results[\"n_params\"]`, on a log-scaled x axis. A box at `k` parameters is `k / bar_width`\n",
    "wide, so boxes look similarly wide on the log axis. Only the last box carries a legend\n",
    "label (`method`, with \"NTT\" displayed as \"NTTF\").\n",
    "\"\"\"\n",
    "function get_bar_plt(results, method; bar_width = 5)\n",
    "    n_plt = length(results[\"n_params\"])\n",
    "    legend_label = method == \"NTT\" ? \"NTTF\" : method\n",
    "\n",
    "    plt_ins = plot(ylim=(0.8,1.0), yticks=[0.8, 0.85, 0.9, 0.95, 1.0], size=plt_size, legend=:topleft, framestyle=:box, xlabel=\"# Parameters\", ylabel=\"Fit\", xscale=:log10)\n",
    "    for i = 1 : n_plt\n",
    "        k = results[\"n_params\"][i]\n",
    "        # Label only the last box so the legend has a single entry per method.\n",
    "        label = i == n_plt ? legend_label : false\n",
    "        boxplot!(plt_ins, [k], [results[\"fit\"][i]], color=\"gray\", label=label, legend=:topleft, xscale=:log10, bar_width=k/bar_width)\n",
    "    end\n",
    "    return plt_ins\n",
    "end\n",
    "\n",
    "# Overlay the proposal (best fit of each many-body model) on a baseline panel.\n",
    "# `idx` selects which models to draw (all of them by default).\n",
    "function add_proposal!(p, idx = Colon())\n",
    "    return plot!(p, results_MBA[\"n_params\"][idx], maximum.(results_MBA[\"fits\"])[idx],\n",
    "                 shape=ms_MBA, mc=mc, ms=10, msc=lc_MBA, label=\"Proposal\", line=(lt_MBA, lc_MBA))\n",
    "end\n",
    "\n",
    "plt_with_cyc = true  # false: omit the cyclic model (2nd entry of `ms`) and show only rank set 2\n",
    "if plt_with_cyc\n",
    "    plt_ins_NTT1 = get_bar_plt(results_NTT1, \"NTT\"); add_proposal!(plt_ins_NTT1)\n",
    "    plt_ins_NTD1 = get_bar_plt(results_NTD1, \"NTD\"); add_proposal!(plt_ins_NTD1)\n",
    "    plt_ins_NTT2 = get_bar_plt(results_NTT2, \"NTT\"); add_proposal!(plt_ins_NTT2)\n",
    "    plt_ins_NTD2 = get_bar_plt(results_NTD2, \"NTD\"); add_proposal!(plt_ins_NTD2)\n",
    "    pl = plot(plt_ins_NTT1, plt_ins_NTD1, plt_ins_NTT2, plt_ins_NTD2,layout=(1,4),\n",
    "        tickfontsize=16, labelfontsize=16, size=(2*780, 2*250), legendfontsize=12)\n",
    "    savefig(pl, \"reconst_imgs/NTTF_NTD.pdf\")\n",
    "else\n",
    "    # Use the reloaded rank-set-2 results instead of whatever results_NTT / results_NTD\n",
    "    # happen to hold in the kernel, so this branch also works without the compute cells.\n",
    "    plt_ins_NTT = get_bar_plt(results_NTT2, \"NTT\"); add_proposal!(plt_ins_NTT, Not(2))\n",
    "    plt_ins_NTD = get_bar_plt(results_NTD2, \"NTD\"); add_proposal!(plt_ins_NTD, Not(2))\n",
    "    plot(plt_ins_NTT, plt_ins_NTD)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ac69449-84f9-4f08-9482-d4f41c5c40ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    get_time_pred(Tm, Tgt, day, lane, label, a)\n",
    "\n",
    "Plot the time series of `Tgt` (legend \"GT\") and `Tm` (legend `label`) for one `day`\n",
    "and `lane`, both smoothed by a moving average with window `a`. The slice\n",
    "`X[day, :, :, lane]` is transposed before `vec`, so the 3rd mode varies fastest in the\n",
    "series. The x-tick labels assume 5 minutes per sample.\n",
    "\"\"\"\n",
    "function get_time_pred(Tm, Tgt, day, lane, label, a)\n",
    "    xaxis = (\"Time (min.)\", ([0:30:288;], [0:30*5:288*5;]))\n",
    "    yaxis = (\"Speed (mph)\", (20,80))\n",
    "\n",
    "    tgt = ma(vec(Tgt[day,:,:,lane]'), a)\n",
    "    tm  = ma(vec(Tm[day,:,:,lane]'), a)\n",
    "    p = plot(legend=:bottomleft, yaxis=yaxis, xaxis=xaxis)\n",
    "    # Draw onto p explicitly rather than relying on the implicit current plot.\n",
    "    plot!(p, tgt, label=\"GT\")\n",
    "    plot!(p, tm,  label=\"$label\")\n",
    "    return p\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5b7d3f0-2626-49af-94a7-e2c8cf94feca",
   "metadata": {},
   "source": [
    "### Plot interactions (Fig. 13)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d854c929-e049-4497-9a2e-704c7e122143",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fig. 13: (d-l) interaction parameters of the cyclic two-, two- and three-body approximations\n",
    "label = [\"Lane 1\" \"Lane 2\" \"Lane 3\" \"Lane 4\"]   # 1x4 row: one label per plotted series\n",
    "day_ticks = [1, 7, 14, 21, 28]\n",
    "# NOTE(review): assumes the 1st mode of theta_* is the day mode, as in the Fig. 12 heatmaps -- confirm.\n",
    "\n",
    "Fdl = get_tensor_from_theta(theta_cyc[:,1,1,:]);\n",
    "pcyc = plot(Fdl, label=label, xlabel=\"Day\", xticks=day_ticks, title=\"(d-l) interaction in \\n cyclic two-body approximation\")\n",
    "\n",
    "Fdl = get_tensor_from_theta(theta_2[:,1,1,:]);\n",
    "p2 = plot(Fdl, label=label, xlabel=\"Day\", xticks=day_ticks, title=\"(d-l) interaction in \\n two-body approximation\")\n",
    "\n",
    "Fdl = get_tensor_from_theta(theta_3[:,1,1,:]);\n",
    "p3 = plot(Fdl, label=label, xlabel=\"Day\", xticks=day_ticks, title=\"(d-l) interaction in \\n three-body approximation\")\n",
    "\n",
    "plot(pcyc,p2,p3, layout=(1,3), size=(1200,400), ylim=(0.5,1.05), legend=:bottomright)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd9dfde-1738-4272-bdd7-134e06b3180a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia (8 threads) 1.8.5",
   "language": "julia",
   "name": "julia-_8-threads_-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
