{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8751c42e-1f58-4665-ba5a-f4ff80b596e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from PIL import Image\n",
    "import torch\n",
    "import sys\n",
    "import utils_exp as ue\n",
    "import importlib\n",
    "sys.path.append(\"methods/Grad/\")\n",
    "sys.path.append(\"methods/emmix/\")\n",
    "\n",
    "import dense_em_mix_all\n",
    "import gradcp\n",
    "import gradmix\n",
    "importlib.reload(gradcp)\n",
    "importlib.reload(gradmix)\n",
    "importlib.reload(dense_em_mix_all)\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a312142-d20c-4c42-968d-a590acde4839",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_2dplot(tensor):\n",
    "    tensor_shape = np.shape(tensor)\n",
    "    gamma = 0.1  # 0 < gamma < 1\n",
    "    \n",
    "    rgb_image = np.zeros((tensor_shape[0], tensor_shape[1], 3), dtype=np.uint8)\n",
    "    \n",
    "    if tensor[:, :, 0].max() > 0:\n",
    "        red = (tensor[:, :, 0] / tensor[:, :, 0].max()) ** gamma\n",
    "        rgb_image[:, :, 0] = (red * 255).astype(np.uint8)\n",
    "    \n",
    "    if tensor[:, :, 1].max() > 0:\n",
    "        blue = (tensor[:, :, 1] / tensor[:, :, 1].max()) ** gamma\n",
    "        rgb_image[:, :, 2] = (blue * 255).astype(np.uint8)\n",
    "    \n",
    "    plt.figure(figsize=(6, 6))\n",
    "    plt.imshow(rgb_image)\n",
    "    plt.title(\"Moons with Label-Flip and Corner Outliers\")\n",
    "    #plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "def get_2dplot_fixed_scaler(tensor, scaler_tensor, gamma=0.1):\n",
    "    tensor_shape = np.shape(tensor)\n",
    "    rgb_image = np.zeros((tensor_shape[0], tensor_shape[1], 3), dtype=np.uint8)\n",
    "    \n",
    "    if tensor[:, :, 0].max() > 0:\n",
    "        red = (tensor[:, :, 0] / scaler_tensor[:, :, 0].max()) ** gamma  # ガンマ補正あり\n",
    "        rgb_image[:, :, 0] = (red * 255).astype(np.uint8)\n",
    "    \n",
    "    if tensor[:, :, 1].max() > 0:\n",
    "        blue = (tensor[:, :, 1] / scaler_tensor[:, :, 1].max()) ** gamma\n",
    "        rgb_image[:, :, 2] = (blue * 255).astype(np.uint8)\n",
    "    \n",
    "    plt.figure(figsize=(6, 6))\n",
    "    plt.imshow(rgb_image)\n",
    "    plt.title(\"Moons with Label-Flip and Corner Outliers\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73d79120-1f45-41fc-84f0-89f00b9c5e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "T_no_noise = np.load(\"../syn_data/anoma/emperical_dist_90x90x2_N5000_out0.npy\")\n",
    "T_noise50 = np.load(\"../syn_data/anoma/emperical_dist_90x90x2_N5000_out50.npy\")\n",
    "T_noise100 = np.load(\"../syn_data/anoma/emperical_dist_90x90x2_N5000_out100.npy\")\n",
    "T_noise150 = np.load(\"../syn_data/anoma/emperical_dist_90x90x2_N5000_out150.npy\")\n",
    "\n",
    "T_no_noise = T_no_noise / np.sum(T_no_noise)\n",
    "T_noise150 = T_noise150 / np.sum(T_noise150)\n",
    "T_noise100 = T_noise100 / np.sum(T_noise100)\n",
    "T_noise50  = T_noise50  / np.sum(T_noise50)\n",
    "\n",
    "noises = [0, 50, 100, 150]\n",
    "αs = [0.1, 0.3, 0.5, 0.7, 0.9]\n",
    "Ts = { 0:T_no_noise, 50:T_noise50, 100:T_noise100, 150:T_noise150}\n",
    "\n",
    "Ps = { noise:{} for noise in noises }\n",
    "#Ps[noise][alpha]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f050f6f1-38b9-4f09-89fa-d651bdb62bb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "ranktrain = [22,2]\n",
    "rankcp = 22\n",
    "ranktucker = 0\n",
    "max_iter = 20\n",
    "\n",
    "title_fontsize = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "770eaf0a-7620-4399-9306-a95f70e1c0e9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for noise in noises:\n",
    "    for α in αs:\n",
    "        _, _, P, _ = dense_em_mix_all.EMCPTuckerTrain(Ts[noise], [rankcp, ranktucker, ranktrain], alpha=α, model=[1,0,1,1], max_iter=max_iter, tol=0, verbose_interval=1);\n",
    "        Ps[noise][α] = P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ede9be8-9391-4095-8fa0-7cec1f0f3af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['pdf.fonttype'] = 42 \n",
    "plt.rcParams['ps.fonttype'] = 42\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "def get_2dplot_fixed_scaler_on_ax(tensor, scaler_tensor, gamma=0.1, title=\"\", ax=None,\n",
    "                                 fontsize=25, fontname='DejaVu Sans',\n",
    "                                 highlight_rect=None):  # highlight_rect=(x, y, width, height)\n",
    "    tensor_shape = np.shape(tensor)\n",
    "    rgb_image = np.zeros((tensor_shape[0], tensor_shape[1], 3), dtype=np.uint8)\n",
    "    \n",
    "    if scaler_tensor[:, :, 0].max() > 0:\n",
    "        red = (tensor[:, :, 0] / scaler_tensor[:, :, 0].max()) ** gamma\n",
    "        rgb_image[:, :, 0] = (red * 255).astype(np.uint8)\n",
    "    \n",
    "    if scaler_tensor[:, :, 1].max() > 0:\n",
    "        blue = (tensor[:, :, 1] / scaler_tensor[:, :, 1].max()) ** gamma\n",
    "        rgb_image[:, :, 2] = (blue * 255).astype(np.uint8)\n",
    "    \n",
    "    if ax is None:\n",
    "        fig, ax = plt.subplots(figsize=(6, 6))\n",
    "    ax.imshow(rgb_image)\n",
    "    ax.set_title(title, fontsize=fontsize, fontname=fontname)\n",
    "    ax.axis(\"off\")\n",
    "    \n",
    "    if highlight_rect is not None:\n",
    "        x, y, w, h = highlight_rect\n",
    "        rect = patches.Rectangle(\n",
    "            (x, y), w, h,\n",
    "            linewidth=2, edgecolor='yellow', facecolor='none'\n",
    "        )\n",
    "        ax.add_patch(rect)\n",
    "        \n",
    "    if highlight_rect is not None:\n",
    "        rect = patches.Rectangle(\n",
    "            (0, 70), 19, 19,\n",
    "            linewidth=2, edgecolor='cyan', facecolor='none', linestyle='--'\n",
    "        )\n",
    "        ax.add_patch(rect)\n",
    "\n",
    "    if highlight_rect is not None:\n",
    "        rect = patches.Rectangle(\n",
    "            (70, 0), 19, 19,\n",
    "            linewidth=2, edgecolor='cyan', facecolor='none', linestyle='--'\n",
    "        )\n",
    "        ax.add_patch(rect)\n",
    "\n",
    "\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda45530-e7bb-4f88-ac0e-a441dee26911",
   "metadata": {},
   "outputs": [],
   "source": [
    "noise = 150\n",
    "st = Ts[noise]\n",
    "fig, axs = plt.subplots(1, 3, figsize=(12, 6))\n",
    "highlight_rect = (28, 50, 25, 28)\n",
    "get_2dplot_fixed_scaler_on_ax(Ts[noise], st, title=\"Noisy empirical dist. \", ax=axs[0], highlight_rect=highlight_rect)\n",
    "alpha = 0.1\n",
    "get_2dplot_fixed_scaler_on_ax(Ps[noise][alpha], st, title=\"α:0.1\", ax=axs[1], highlight_rect=highlight_rect)\n",
    "alpha = 0.9\n",
    "get_2dplot_fixed_scaler_on_ax(Ps[noise][alpha], st, title=\"α:0.9\", ax=axs[2], highlight_rect=highlight_rect)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figs/exp_loss/tensors_with_highlight.pdf\", bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
