{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "312f383b-d778-48b0-a7e3-4dd00e29fcc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JLD2\n",
    "using DataStructures\n",
    "using Images\n",
    "using StatsBase\n",
    "using Plots\n",
    "using Colors\n",
    "pyplot()\n",
    "using NearestNeighbors\n",
    "include(\"../methods/IgTensors/decomp.jl\");\n",
    "include(\"../methods/IgTensors/get_msk.jl\");\n",
    "include(\"../methods/IgTensors/get_params.jl\");\n",
    "io = IOContext(stdout, :limit => false);\n",
    "\n",
    "# Create the dataset: one slice per selected COIL-100 object, stacked into a 4th-order\n",
    "# tensor T. Axis 3 holds the 3 colour channels (it is permuted to the front for colorview\n",
    "# in the plotting cells) and axis 4 indexes the images; axes 1-2 are the spatial axes.\n",
    "img_deck = [4,5,10,23,17,28,38,42,78, 7]   # COIL-100 object ids, in display order\n",
    "pix_size = (40, 40)                         # spatial size of each stored image tensor\n",
    "n_imgs = length(img_deck)\n",
    "T = zeros(pix_size..., 3, n_imgs)\n",
    "for (i, obj_id) in enumerate(img_deck)\n",
    "    key = \"tensor_COIL100_obj$obj_id\"   # file name stem and the dataset key inside the file\n",
    "    img_s = load(\"../../data/cimg/COIL100/$key.jld2\")[key]\n",
    "    # keep index 1 of axis 4 (presumably the first viewing angle -- TODO confirm)\n",
    "    T[:,:,:,i] = img_s[:,:,:,1]\n",
    "end\n",
    "\n",
    "# Plot configuration shared by the figure cells below (img_size is the figure size in pixels).\n",
    "layout = (2,5);\n",
    "img_size   = (450, 250);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c9cb739-f62f-485a-b2f6-8979b1904d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the input images, i.e. the full tensor T before any approximation.\n",
    "# Each T[:,:,:,i] is permuted to channel-first so colorview can read axis 1 as RGB,\n",
    "# and clamp01nan clamps values to [0, 1] (NaN -> 0) before plotting.\n",
    "img_plts = []\n",
    "for img_id in 1:n_imgs\n",
    "    imgc = colorview(RGB, permutedims(T[:,:,:,img_id], [3,1,2]))\n",
    "    push!(img_plts, plot(map(clamp01nan, imgc), frame=:none, title=\"$img_id\"))\n",
    "end\n",
    "# Presumably labelled \"Four-body\" because T keeps every interaction up to order 4.\n",
    "title = \"Four-body\"\n",
    "mkpath(\"reconst_imgs\")   # so save() does not fail on a fresh checkout\n",
    "p = plot(img_plts..., layout=layout, size=img_size, plot_title=title)\n",
    "save(\"reconst_imgs/$title.pdf\", p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "287afd53-1c5c-4fb0-b21a-7e3ec385118d",
   "metadata": {},
   "source": [
    "# Decomposition #"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1df99ead-f582-445a-aff9-88a388077c4f",
   "metadata": {},
   "source": [
    "### Run many-body approximation and reconstruct inputs ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39cd2666-1ca7-4dfc-8eee-8f3ed34a1c0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fig. 3(a): three-body approximation.\n",
    "# intracts presumably holds one 0/1 flag vector per interaction order: its lengths 4, 6, 4, 1 equal\n",
    "# the number of 1-, 2-, 3- and 4-axis subsets of T's 4 axes (subset order presumably lexicographic --\n",
    "# TODO confirm in manybody_app). Here all 1-, 2- and 3-body flags are on and the 4-body flag is off.\n",
    "intracts = [\n",
    "    [1, 1, 1, 1],\n",
    "    [1, 1, 1, 1, 1, 1],\n",
    "    [1, 1, 1, 1],\n",
    "    [0],\n",
    "]\n",
    "Tm_a, theta_m_a, eta_m_a = manybody_app(T, intracts; verbose=true);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9406ffc0-89a4-4557-bd4e-72b9d3fb9828",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fig. 3(b): two-body approximation.\n",
    "# intracts presumably holds one 0/1 flag vector per interaction order (lengths 4, 6, 4, 1 = number of\n",
    "# 1-, 2-, 3- and 4-axis subsets of T's 4 axes; TODO confirm in manybody_app).\n",
    "# Here all 1- and 2-body flags are on, while every 3-body flag and the 4-body flag are off.\n",
    "intracts = [\n",
    "    [1, 1, 1, 1],\n",
    "    [1, 1, 1, 1, 1, 1],\n",
    "    [0, 0, 0, 0],\n",
    "    [0],\n",
    "]\n",
    "Tm_b, theta_m_b, eta_m_b = manybody_app(T, intracts; verbose=true);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67b714c8-3002-4f8b-ab12-0db480b9605e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fig. 3(c): restricted model with selected 2-body terms and a single 3-body term.\n",
    "# intracts presumably holds one 0/1 flag vector per interaction order (lengths 4, 6, 4, 1 = number of\n",
    "# 1-, 2-, 3- and 4-axis subsets of T's axes). If subsets are ordered lexicographically (TODO confirm in\n",
    "# manybody_app), the 3-body term is (w,h,i) and the last 2-body term is (c,i) -- the two terms that\n",
    "# the visualisation cells below extract from theta_m_mono_c.\n",
    "intracts = [\n",
    "    [1, 1, 1, 1],\n",
    "    [1, 0, 1, 0, 1, 1],\n",
    "    [0, 1, 0, 0],\n",
    "    [0],\n",
    "]\n",
    "Tm_mono_c, theta_m_mono_c, eta_m_mono_c = manybody_app(T, intracts; verbose=true);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa0d54bb-189f-457f-9ee8-224591951847",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fig. 3(d): same flags as the Fig. 3(c) model except that the last 2-body flag is off.\n",
    "# intracts presumably holds one 0/1 flag vector per interaction order (lengths 4, 6, 4, 1 = number of\n",
    "# 1-, 2-, 3- and 4-axis subsets of T's axes); under a lexicographic subset order (TODO confirm in\n",
    "# manybody_app) the dropped pair would be (c,i).\n",
    "intracts = [\n",
    "    [1, 1, 1, 1],\n",
    "    [1, 0, 1, 0, 1, 0],\n",
    "    [0, 1, 0, 0],\n",
    "    [0],\n",
    "]\n",
    "Tm_mono, theta_m_mono, eta_m_mono = manybody_app(T, intracts; verbose=true);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaa56ad2-b399-4b8f-a237-c041358f0f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "    show_img_reconst(imgs, n_imgs=size(imgs, 4))\n",
    "\n",
    "Return a vector with one image plot per slice `imgs[:,:,:,i]` for `i = 1:n_imgs`, each titled\n",
    "with its index. Axis 3 of `imgs` is permuted to the front and read as RGB by `colorview`, and\n",
    "values go through `clamp01nan` before plotting. `n_imgs` defaults to the length of axis 4.\n",
    "\"\"\"\n",
    "function show_img_reconst(imgs, n_imgs=size(imgs, 4))\n",
    "    img_plts = []\n",
    "    for img_id in 1:n_imgs\n",
    "        imgc = colorview(RGB, permutedims(imgs[:,:,:,img_id], [3,1,2]))\n",
    "        push!(img_plts, plot(map(clamp01nan, imgc), frame=:none, title=\"$img_id\"))\n",
    "    end\n",
    "    return img_plts\n",
    "end\n",
    "\n",
    "# (figure title, reconstructed tensor) for each approximation fitted above.\n",
    "reconst_figs = [\n",
    "    (\"Three-body\", Tm_a),\n",
    "    (\"Two-body\", Tm_b),\n",
    "    (\"Monotone_img\", Tm_mono_c),\n",
    "    (\"Monotones\", Tm_mono),\n",
    "]\n",
    "img_size = (450, 250)   # figure size in pixels (same value as the setup cell)\n",
    "mkpath(\"reconst_imgs\")   # so save() does not fail on a fresh checkout\n",
    "for (title, Tm) in reconst_figs\n",
    "    fig = plot(show_img_reconst(Tm, n_imgs)..., layout=layout, size=img_size, plot_title=title)\n",
    "    save(\"reconst_imgs/$title.pdf\", fig)\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "089ce072-4d83-4634-86ac-2da971b4562f",
   "metadata": {},
   "source": [
    "## Visualize three-body interaction (w,h,i)  ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b99ceedf-5307-4aad-b672-d5b4870f8061",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Get three-body interaction\n",
    "# Build the (w, h, i) factor from the Fig. 3(c) parameters, taking index 1 of axis 3 (colour).\n",
    "Fwhi = get_tensor_from_theta(theta_m_mono_c[:,:,1,:]);\n",
    "\n",
    "# One greyscale heatmap of Fwhi[:,:,i] per image i (yflip so row 1 is drawn at the top).\n",
    "hmaps = [heatmap(Fwhi[:,:,img_id], color = :greys, yflip=true, cbar=false, ticks=:none, frame=:none, title=\"$img_id\", size=(200,200))\n",
    "         for img_id in 1:n_imgs]\n",
    "mkpath(\"obtained_intra\")   # so save() does not fail on a fresh checkout\n",
    "p = plot(hmaps..., layout=layout,size=img_size)\n",
    "save(\"obtained_intra/gray_scale.pdf\", p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edda8fee-ad4f-411e-abd2-77753f953f93",
   "metadata": {},
   "source": [
    "## Visualize two-body interaction (c,i) ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ad882b-8929-43e7-8448-5fd13c472ef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "using LinearAlgebra: norm   # stdlib; makes the source of norm explicit (harmless if already in scope)\n",
    "\n",
    "\"\"\"\n",
    "    get_fci(theta_m, title)\n",
    "\n",
    "Plot the two-body (colour, image) factor `get_tensor_from_theta(theta_m)` as an annotated heatmap\n",
    "and save it to `obtained_intra/<title>.pdf`. Each column (one per image) is scaled to unit 2-norm;\n",
    "all-zero columns are left as zeros. Rows are labelled R, G, B, so the factor is assumed to have 3 rows.\n",
    "Returns the vector of `(column, row, text)` annotations written on the heatmap.\n",
    "\"\"\"\n",
    "function get_fci(theta_m, title)\n",
    "    Fci = get_tensor_from_theta(theta_m)\n",
    "    unit_col(x) = (nx = norm(x); iszero(nx) ? zero(x) : x / nx)   # avoid 0/0 = NaN\n",
    "    Fci_normalized = mapslices(unit_col, Fci, dims=1)\n",
    "    m, n = size(Fci)\n",
    "\n",
    "    fontsize = 16\n",
    "    clim = (0,1)\n",
    "    xaxis = (\"Image\", 1:1:n)\n",
    "    yaxis = (\"Color\", ([1,2,3],[\"R\",\"G\",\"B\"]))\n",
    "    p = heatmap(Fci_normalized, yflip=false, cbar=true, size=(800,250), xaxis=xaxis, yaxis=yaxis, frame=:box, c=:bilbao, clim=clim, tickfontsize=fontsize, xguidefontsize=fontsize, yguidefontsize=fontsize, colorbar_fontsize=fontsize)\n",
    "    # Black lines on the cell borders.\n",
    "    vline!(p, 0.5:(n+0.5), c=:black, label=:false)\n",
    "    hline!(p, 0.5:(m+0.5), c=:black, label=:false)\n",
    "\n",
    "    # Each normalised value (2 decimals) written in white at the centre of its cell.\n",
    "    ann = [(j,i, text(round(Fci_normalized[i,j], digits=2), fontsize, :white, :center))\n",
    "                for i in 1:m for j in 1:n]\n",
    "    annotate!(p, ann, linecolor=:white)\n",
    "    mkpath(\"obtained_intra\")   # so save() does not fail on a fresh checkout\n",
    "    save(\"obtained_intra/$title.pdf\", p)\n",
    "    return ann\n",
    "end\n",
    "\n",
    "# The figure is saved to disk; the trailing ; keeps the returned annotation list from being dumped.\n",
    "fci_ann = get_fci(theta_m_mono_c[1,1,:,1:n_imgs], \"mono_img\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c282eef-4ca7-4903-a594-922dbe6bad7f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25a801b-aecf-4a6f-9479-f16ab699325b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia (8 threads) 1.8.5",
   "language": "julia",
   "name": "julia-_8-threads_-1.8"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
