{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acba136d-c9e7-4dfc-9b4d-5060234c72be",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import torch\n",
    "import sys\n",
    "import utils_exp as ue\n",
    "import importlib\n",
    "sys.path.append(\"methods/Grad/\")\n",
    "sys.path.append(\"methods/emmix/\")\n",
    "\n",
    "import dense_em_mix_all\n",
    "import gradcp\n",
    "import gradmix\n",
    "importlib.reload(gradcp)\n",
    "importlib.reload(gradmix)\n",
    "importlib.reload(dense_em_mix_all)\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d195d2a1-9d79-4704-88f7-8dc96c769363",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration shared by the cells below.\n",
    "\n",
    "# α values swept by the experiment loops.\n",
    "αs = [0.3, 0.5, 0.7, 0.9]\n",
    "\n",
    "# Learning-rate grid for each optimizer, looked up by optimizer name.\n",
    "lrs = dict(\n",
    "    SGD=[1000, 100, 10, 1, 0.1],\n",
    "    Adam=[1, 0.1, 0.01, 0.001],\n",
    "    RMSprop=[0.1, 0.01, 0.001, 0.0001],\n",
    "    Adagrad=[0.1, 0.01, 0.001, 0.0001],\n",
    ")\n",
    "\n",
    "# Optimizers to run, and the number of repetitions per setting.\n",
    "methods = [\"Adagrad\", \"RMSprop\", \"Adam\", \"SGD\"]\n",
    "rep_times = 5\n",
    "\n",
    "# Base name of the input image (no directory, no extension).\n",
    "photo_name = \"house\"\n",
    "\n",
    "# Ranks handed to the decomposition routines.\n",
    "# Alternative setting: rankcp = 50, ranktucker = [10,10,3], ranktrain = [10,3]\n",
    "rankcp, ranktucker, ranktrain = 10, [0, 0, 0], [5, 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca42be91-c7f9-4c6a-ba6a-f8754c4ac0cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_photo(photo_name, scale=3):\n",
    "    \"\"\"Load an image and return it as a sum-normalized float32 array.\n",
    "\n",
    "    Reads ``../real_data/images/{photo_name}.tiff``, shrinks it by ``scale``\n",
    "    along both axes (integer division of width and height), divides the\n",
    "    pixel values by 255, adds 1e-8, and finally divides by the total sum so\n",
    "    that the returned array sums to one (up to rounding).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    photo_name : str\n",
    "        Base name of the TIFF file, without directory or extension.\n",
    "    scale : int, optional\n",
    "        Downscaling factor applied to both width and height. Defaults to 3,\n",
    "        the value that used to be hard-coded.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    array : numpy.ndarray\n",
    "        The normalized image as float32.\n",
    "    total_sum : numpy.float32\n",
    "        Sum of the array before the final normalization; multiplying\n",
    "        ``array`` by it restores the pre-normalization scale.\n",
    "    \"\"\"\n",
    "    # Image.open() keeps the underlying file open; the context manager makes\n",
    "    # sure the handle is released once the pixel data has been copied out.\n",
    "    with Image.open(f\"../real_data/images/{photo_name}.tiff\") as image:\n",
    "        width, height = image.size\n",
    "        resized = image.resize((width // scale, height // scale))\n",
    "        array = np.array(resized).astype(np.float32)\n",
    "\n",
    "    array /= 255.0\n",
    "    # Small offset, presumably to keep every entry strictly positive.\n",
    "    array += 1.0e-8\n",
    "    total_sum = array.sum()\n",
    "    array /= total_sum\n",
    "    return array, total_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08cea503-752a-444e-9cf8-32aa91b14434",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the working image; T is the normalized array, total_sum the factor it was divided by.\n",
    "photo_name = \"house\"\n",
    "T, total_sum = get_photo(photo_name)\n",
    "# Multiply by total_sum to undo the final sum-normalization before displaying.\n",
    "plt.imshow(T * total_sum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683a4fd8-fd1f-43ad-903c-849bac6ad7c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _save_normalized_img(img, save_path):\n",
    "    \"\"\"Scale ``img`` so that its maximum is 1, draw it without axes, and save it.\"\"\"\n",
    "    # Divide out-of-place so the caller's buffer is never modified.\n",
    "    img = img / img.max()\n",
    "\n",
    "    # Start from an empty figure. plt.imshow() on a reused figure keeps the\n",
    "    # earlier images on the axes, so in a loop each saved file would also\n",
    "    # carry every image drawn before it.\n",
    "    fig = plt.gcf()\n",
    "    fig.clf()\n",
    "    ax = fig.add_subplot(111)\n",
    "    ax.set_axis_off()\n",
    "    ax.imshow(img)\n",
    "    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)\n",
    "    print(\"saved in\", save_path)\n",
    "\n",
    "\n",
    "def save_recover_img_from_res(res, save_path):\n",
    "    \"\"\"Save the reconstruction contained in a gradient-method result.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    res : tuple\n",
    "        Pair whose first entry is the reconstructed tensor ``P`` (a torch\n",
    "        tensor); the second entry is not used here.\n",
    "    save_path : str\n",
    "        Destination passed to ``savefig``.\n",
    "    \"\"\"\n",
    "    P, _ = res\n",
    "    # Tensor.numpy() shares memory with the tensor, so the image must not be\n",
    "    # normalized in place; the helper divides out-of-place and leaves P as is.\n",
    "    # .cpu() is a no-op for CPU tensors and makes GPU tensors convertible.\n",
    "    _save_normalized_img(P.detach().cpu().numpy(), save_path)\n",
    "\n",
    "\n",
    "def save_recover_img_from_res_eem(res_eem, save_path):\n",
    "    \"\"\"Save the reconstruction contained in an EMCPTuckerTrain result.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    res_eem : tuple\n",
    "        4-tuple whose third entry is the reconstruction ``P``; the other\n",
    "        entries are not used here.\n",
    "    save_path : str\n",
    "        Destination passed to ``savefig``.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    Reads the notebook-level ``total_sum``, so the cell that calls\n",
    "    ``get_photo`` must have been run first.\n",
    "    \"\"\"\n",
    "    _, _, P, _ = res_eem\n",
    "    _save_normalized_img(P * total_sum, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "739bc11d-a35b-4346-b095-728de91adddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-off trial run with Adagrad before launching the full sweep.\n",
    "res = gradmix.mix_alpha_grad_ctt(\n",
    "    T,\n",
    "    rankcp,\n",
    "    ranktucker,\n",
    "    ranktrain,\n",
    "    0.6,\n",
    "    lr=0.0001,\n",
    "    max_iter=500,\n",
    "    verbose_interval=10,\n",
    "    optim_method=\"Adagrad\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50d8589c-2eb9-4d7a-ac9f-c87f627e77fa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Sweep every optimizer over all α values and its own learning-rate grid,\n",
    "# repeating each setting rep_times times.\n",
    "# NOTE(review): output goes under results/ here, while the EMCPTuckerTrain\n",
    "# sweep writes under ../results/ -- confirm which root is intended.\n",
    "for method in methods:\n",
    "    for idα, α in enumerate(αs):\n",
    "        for idlr, lr in enumerate(lrs[method]):\n",
    "            for rep in range(rep_times):\n",
    "                res = gradmix.mix_alpha_grad_ctt(\n",
    "                    T, rankcp, ranktucker, ranktrain, α,\n",
    "                    lr=lr,\n",
    "                    max_iter=700,\n",
    "                    verbose_interval=10,\n",
    "                    optim_method=method,\n",
    "                )\n",
    "\n",
    "                # Only the second entry of the result is pickled.\n",
    "                save_path = f\"results/exp_loss/{photo_name}/{method}/a{idα}_lr{idlr}_{rep}.pkl\"\n",
    "                ue.pickle_dump(res[1], save_path)\n",
    "                print(\"saved in\", save_path)\n",
    "\n",
    "                # Image goes next to the pickle: \".pkl\" is swapped for \"_img.pdf\".\n",
    "                save_path_img = save_path.removesuffix(\".pkl\") + \"_img.pdf\"\n",
    "                save_recover_img_from_res(res, save_path_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c45e44a8-d520-440b-9198-9da9135b402f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# EMCPTuckerTrain sweep over all α values, repeated rep_times times.\n",
    "for idα, α in enumerate(αs):\n",
    "    for rep in range(rep_times):\n",
    "        res_eem = dense_em_mix_all.EMCPTuckerTrain(\n",
    "            T,\n",
    "            [rankcp, ranktucker, ranktrain],\n",
    "            alpha=α,\n",
    "            model=[1, 0, 1, 0],\n",
    "            max_iter=700,\n",
    "            tol=0,\n",
    "            verbose_interval=1,\n",
    "        )\n",
    "\n",
    "        # Only the second entry of the result is pickled.\n",
    "        save_path = f\"../results/exp_loss/{photo_name}_full/eem/a{idα}_{rep}.pkl\"\n",
    "        ue.pickle_dump(res_eem[1], save_path)\n",
    "        print(\"saved in\", save_path)\n",
    "\n",
    "        # Image goes next to the pickle: \".pkl\" is swapped for \"_img.pdf\".\n",
    "        save_path_img = save_path.removesuffix(\".pkl\") + \"_img.pdf\"\n",
    "        save_recover_img_from_res_eem(res_eem, save_path_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a22f0d-f440-4858-af90-752725f539ee",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Short two-iteration EMCPTuckerTrain run with larger ranks and model=[1,1,1,1].\n",
    "# NOTE: this rebinds the notebook-level rankcp / ranktucker / ranktrain / α, so\n",
    "# cells executed afterwards see these values instead of the config cell's.\n",
    "rankcp, ranktucker, ranktrain = 50, [10, 10, 3], [10, 3]\n",
    "α = 1\n",
    "res_eem = dense_em_mix_all.EMCPTuckerTrain(\n",
    "    T,\n",
    "    [rankcp, ranktucker, ranktrain],\n",
    "    alpha=α,\n",
    "    model=[1, 1, 1, 1],\n",
    "    max_iter=2,\n",
    "    tol=0,\n",
    "    verbose_interval=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db017e4d-0188-45fa-b929-66ffab20b9fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the most recent EMCPTuckerTrain result; only P is used below.\n",
    "_, hist, P, details = res_eem\n",
    "# Rescale by total_sum, then normalize so the largest value is 1 before display.\n",
    "recoverd_img = P * total_sum\n",
    "recoverd_img /= recoverd_img.max()\n",
    "plt.axis(\"off\")\n",
    "plt.imshow(recoverd_img)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
