{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T18:06:31.243358Z",
     "start_time": "2025-09-03T18:06:24.129521Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import configs\n",
    "import ddpm_class\n",
    "import FM_class\n",
    "import evaluate\n",
    "import plot_func\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32655fc265bed0f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T18:06:32.852065Z",
     "start_time": "2025-09-03T18:06:31.595962Z"
    }
   },
   "outputs": [],
   "source": [
    "# Shared save interval handed to every model (presumably the checkpoint period;\n",
    "# confirm in ddpm_class / FM_class).\n",
    "SAVE_INT = 80000\n",
    "\n",
    "# DDPM-based models.\n",
    "cddpm = ddpm_class.cDDPM(cddpm_name='cddpm', have_rho=False, save_int=SAVE_INT)\n",
    "cddpm_have_rho = ddpm_class.cDDPM(cddpm_name='cddpm_rho', have_rho=True, save_int=SAVE_INT)\n",
    "tddpm = ddpm_class.tDDPM(tddpm_name='tddpm', save_int=SAVE_INT)\n",
    "\n",
    "# Flow-matching models.\n",
    "tFM = FM_class.tFM(tFM_name='tFM', save_int=SAVE_INT)\n",
    "cFM = FM_class.cFM(cFM_name='cFM', have_rho=False, save_int=SAVE_INT)\n",
    "cFM_have_rho = FM_class.cFM(cFM_name='cFM_rho', have_rho=True, save_int=SAVE_INT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45ee7f0faffcb5c7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-07T17:12:31.236103Z",
     "start_time": "2025-07-07T17:09:37.424647Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): this trains for 40k iterations, yet the checkpoint-loading cell\n",
    "# reads ./saved_model/cddpm_it_80000.pth. Confirm that train() resumes a previous\n",
    "# run or that the 80k checkpoint is produced elsewhere.\n",
    "cddpm.train(40_000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af172d02bf977141",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-07T17:15:11.943383Z",
     "start_time": "2025-07-07T17:12:31.507814Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): 40k iterations here, but the loading cell expects\n",
    "# ./saved_model/cddpm_rho_it_80000.pth. Confirm where the 80k checkpoint comes from.\n",
    "cddpm_have_rho.train(40_000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4840c527730d5c27",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-08T17:15:05.562304Z",
     "start_time": "2025-07-08T17:04:54.035780Z"
    }
   },
   "outputs": [],
   "source": [
    "# Train cFM and cFM_have_rho (in that order) with the same budget.\n",
    "epoches = 80000\n",
    "for fm_model in (cFM, cFM_have_rho):\n",
    "    fm_model.train(epoches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "689bc7a95dadaf35",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-08T16:47:31.655345Z",
     "start_time": "2025-07-08T16:33:20.996742Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): cFM is also trained in the previous cell, so running both cells\n",
    "# trains it twice. Confirm whether this cell is a leftover.\n",
    "cFM.train(80_000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9505ef4f19d4e117",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T18:06:32.895652Z",
     "start_time": "2025-09-03T18:06:32.877457Z"
    }
   },
   "outputs": [],
   "source": [
    "# Reload each model from ./saved_model/<name>_it_80000.pth, in this fixed order.\n",
    "_checkpoints = [\n",
    "    (tddpm, 'tddpm'),\n",
    "    (cddpm, 'cddpm'),\n",
    "    (cddpm_have_rho, 'cddpm_rho'),\n",
    "    (tFM, 'tFM'),\n",
    "    (cFM, 'cFM'),\n",
    "    (cFM_have_rho, 'cFM_rho'),\n",
    "]\n",
    "for _model, _ckpt_name in _checkpoints:\n",
    "    _model.load_ckpt(f'./saved_model/{_ckpt_name}_it_80000.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28274df0cc34e59b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-14T18:49:40.842131Z",
     "start_time": "2025-08-14T18:49:40.837692Z"
    }
   },
   "outputs": [],
   "source": [
    "from dataset import generate_data\n",
    "import tqdm\n",
    "\n",
    "\n",
    "def unfold_batch(model, y):\n",
    "    \"\"\"Draw one unfolded sample per row of ``y`` with gradients disabled.\n",
    "\n",
    "    ``model`` must expose ``sampling(batch_size, y)``. ``y`` is a tensor whose first\n",
    "    dimension is the batch size (with the rho column already appended when the\n",
    "    model needs it). Returns whatever ``model.sampling`` returns.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        return model.sampling(y.shape[0], y)\n",
    "\n",
    "\n",
    "def eval_models(model_list, batch_num, require_rho_list, batch_size=None, rho_list=None):\n",
    "    \"\"\"Score every model on a grid of rho values with MSE and SWD.\n",
    "\n",
    "    For each rho, ``batch_num`` batches are drawn with ``generate_data``. Every model\n",
    "    unfolds each batch from ``y``, and the concatenated samples are compared with the\n",
    "    concatenated ground truth (the first array returned by ``generate_data``).\n",
    "\n",
    "    Args:\n",
    "        model_list: models exposing ``sampling(batch_size, y)``.\n",
    "        batch_num: number of batches per rho value; must be >= 1.\n",
    "        require_rho_list: one bool per model. When True, rho is appended to ``y``\n",
    "            as an extra float32 column before sampling.\n",
    "        batch_size: samples per batch; defaults to ``configs.batch_size``.\n",
    "        rho_list: rho values to scan; defaults to ``np.linspace(-1, 1, 201)``.\n",
    "\n",
    "    Returns:\n",
    "        ``(MSE_record, SWD_record)``, each of shape ``(len(model_list), len(rho_list))``.\n",
    "        MSE is the mean squared L2 distance between matching rows of the ground truth\n",
    "        and the samples. SWD comes from ``plot_func.compute_SWD`` (presumably sliced\n",
    "        Wasserstein distance). ``MSE_record`` is also printed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``batch_num < 1`` or ``require_rho_list`` is shorter than\n",
    "            ``model_list``.\n",
    "    \"\"\"\n",
    "    if rho_list is None:\n",
    "        rho_list = np.linspace(-1, 1, 201)\n",
    "    if batch_size is None:\n",
    "        batch_size = configs.batch_size\n",
    "    if batch_num < 1:\n",
    "        # torch.cat on an empty list below would otherwise fail with an opaque error.\n",
    "        raise ValueError(f'batch_num must be >= 1, got {batch_num}')\n",
    "    if len(require_rho_list) < len(model_list):\n",
    "        raise ValueError('require_rho_list needs one entry per model')\n",
    "\n",
    "    MSE_record = np.zeros([len(model_list), len(rho_list)])\n",
    "    SWD_record = np.zeros([len(model_list), len(rho_list)])\n",
    "\n",
    "    for rho_index, rho in enumerate(tqdm.tqdm(rho_list)):\n",
    "        unfolded = [[] for _ in model_list]\n",
    "        x_record = []\n",
    "        for _ in range(batch_num):\n",
    "            # Request the grid value on every batch. The third return value gets its own\n",
    "            # name so it cannot overwrite the loop variable ``rho``.\n",
    "            x_np, y_np, rho_batch = generate_data(batch_size, rho=rho)\n",
    "            x_batch = torch.as_tensor(x_np, dtype=torch.float32, device=configs.device)\n",
    "            y_batch = torch.as_tensor(y_np, dtype=torch.float32, device=configs.device)\n",
    "\n",
    "            for model_index, (model, require_rho) in enumerate(zip(model_list, require_rho_list)):\n",
    "                if require_rho:\n",
    "                    rho_col = torch.full((y_batch.shape[0], 1), rho_batch, dtype=torch.float32, device=configs.device)\n",
    "                    y_input = torch.cat([y_batch, rho_col], dim=1)\n",
    "                else:\n",
    "                    y_input = y_batch\n",
    "                unfolded[model_index].append(unfold_batch(model, y_input))\n",
    "\n",
    "            x_record.append(x_batch)\n",
    "\n",
    "        # Move the ground truth to NumPy once per rho rather than once per model.\n",
    "        x_all = torch.cat(x_record).cpu().numpy()\n",
    "        for model_index, samples in enumerate(unfolded):\n",
    "            unfolded_np = torch.cat(samples).cpu().numpy()\n",
    "            MSE_record[model_index, rho_index] = np.mean(np.linalg.norm(x_all - unfolded_np, axis=1) ** 2)\n",
    "            SWD_record[model_index, rho_index] = plot_func.compute_SWD(x_all, unfolded_np, sample_size=None)\n",
    "\n",
    "    print(MSE_record)\n",
    "    return MSE_record, SWD_record\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bd879bb513b4527",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-14T18:55:41.490588Z",
     "start_time": "2025-08-14T18:55:41.247625Z"
    }
   },
   "outputs": [],
   "source": [
    "# Quick smoke test: evaluate tFM alone at rho = 0 over 10 batches of the default size.\n",
    "eval_models(\n",
    "    [tFM],\n",
    "    10,\n",
    "    [False],\n",
    "    batch_size=None,\n",
    "    rho_list=[0],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8abc051bb945640c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T18:07:49.812910Z",
     "start_time": "2025-09-03T18:06:33.741876Z"
    }
   },
   "outputs": [],
   "source": [
    "# Import eval_models under an alias: a bare ``eval_models`` here would silently\n",
    "# replace the eval_models function defined earlier in this notebook.\n",
    "from evaluate_nosave import eval_models as eval_models_nosave, plot_models\n",
    "\n",
    "# Alternative runs kept for reference:\n",
    "# eval_models_nosave([tddpm, cddpm, cddpm_have_rho], 10, [False, False, True])\n",
    "# plot_models([tddpm, cddpm, cddpm_have_rho, tFM, cFM, cFM_have_rho], 10, [False, False, True, False, False, True], save=True, rho_list=[0.9])\n",
    "# plot_models([tFM], 10, [False], save=False, rho_list=[0.9])\n",
    "\n",
    "# Plot the three flow-matching models at rho = 0.9 (the bool list presumably flags\n",
    "# which models take rho as input, as in eval_models).\n",
    "plot_models([tFM, cFM, cFM_have_rho], 10, [False, False, True], save=True, rho_list=[0.9])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f5a8ef0ff34ebb0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-03T16:39:04.872797Z",
     "start_time": "2025-09-03T16:39:01.941353Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# np.load on an .npz keeps the file open until closed, so use a context manager.\n",
    "with np.load('./output/test_record.npz') as data:\n",
    "    MSE_data = data['MSE_record']\n",
    "    SWD_data = data['SWD_record']\n",
    "model_name = ['tddpm', 'cddpm', 'cddpm_have_rho', 'tFM', 'cFM', 'cFM_have_rho']\n",
    "# Presumably matches the rho grid the records were computed on (eval_models' default).\n",
    "x = np.linspace(-1, 1, 201)\n",
    "\n",
    "SWD_omni = np.load('./output/omnifold_SWD.npy')\n",
    "SWD_omni_combine = np.load('./output/omnifold_combine_SWD.npy')\n",
    "SB_SWD = np.load('./output/SB_SWD.npy')\n",
    "\n",
    "\n",
    "def plot_data(input_data, model_name, x, title, ylabel, save_name=None, SWD_omni=None,\n",
    "              Omni_name=('omnifold_best', 'omnifold_combine', 'SB')):\n",
    "    \"\"\"Plot one metric curve per model against rho, plus optional baseline curves.\n",
    "\n",
    "    Args:\n",
    "        input_data: 2-D array; row ``i`` is drawn against ``x`` with label ``model_name[i]``.\n",
    "        model_name: labels, one per row to plot.\n",
    "        x: 1-D array of rho values for the x-axis.\n",
    "        title: figure title.\n",
    "        ylabel: y-axis label.\n",
    "        save_name: if given, the figure is saved to ./fig/<save_name>.png.\n",
    "        SWD_omni: optional sequence of 1-D baseline curves, each drawn against ``x``.\n",
    "        Omni_name: labels for the ``SWD_omni`` curves; needs at least as many entries.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``Omni_name`` has fewer labels than ``SWD_omni`` has curves.\n",
    "    \"\"\"\n",
    "    for i, name in enumerate(model_name):\n",
    "        plt.plot(x, input_data[i], label=name)\n",
    "    if SWD_omni is not None:\n",
    "        if len(Omni_name) < len(SWD_omni):\n",
    "            raise ValueError('Omni_name needs one label per SWD_omni curve')\n",
    "        for curve, name in zip(SWD_omni, Omni_name):\n",
    "            plt.plot(x, curve, label=name)\n",
    "\n",
    "    # Shade the rho ranges [-0.75, -0.25] and [0.25, 0.75].\n",
    "    plt.axvspan(-0.75, -0.25, color='grey', alpha=0.5)\n",
    "    plt.axvspan(0.25, 0.75, color='grey', alpha=0.5)\n",
    "    plt.grid()\n",
    "    plt.legend()\n",
    "    plt.xlabel(r'$\\rho$')\n",
    "    plt.ylabel(ylabel)\n",
    "    plt.title(title)\n",
    "    if save_name is not None:\n",
    "        # savefig does not create missing directories.\n",
    "        os.makedirs('./fig', exist_ok=True)\n",
    "        plt.savefig('./fig/' + save_name + '.png', bbox_inches='tight', dpi=300)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# NOTE(review): the next cell saves to the same file names and overwrites these figures.\n",
    "plot_data(MSE_data, model_name, x, title='MSE_record', ylabel='MSE', save_name='MSE_record')\n",
    "plot_data(SWD_data, model_name, x, title='SWD_record', ylabel='SWD', save_name='SWD_record',\n",
    "          SWD_omni=(SWD_omni, SWD_omni_combine, SB_SWD))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56a4470-7734-48fb-8098-7c71d195fff0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-04T20:47:16.551679Z",
     "start_time": "2025-09-04T20:47:14.251295Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "color_list = ['b', 'g', 'r', 'black', 'orange', 'brown', 'k']\n",
    "# np.load on an .npz keeps the file open until closed, so use a context manager.\n",
    "with np.load('./output/test_record.npz') as data:\n",
    "    MSE_data = data['MSE_record']\n",
    "    SWD_data = data['SWD_record']\n",
    "model_name = ['tddpm', 'cddpm', 'cddpm_have_rho', 'EI-FM', 'cFM', 'cFM-$\\\\gamma$']\n",
    "x = np.linspace(-1, 1, 201)\n",
    "\n",
    "SWD_omni = np.load('./output/omnifold_SWD.npy')\n",
    "SWD_omni_combine = np.load('./output/omnifold_combine_SWD.npy')\n",
    "SB_SWD = np.load('./output/SB_SWD.npy')\n",
    "\n",
    "\n",
    "def plot_data(input_data, model_name, x, title, ylabel, save_name=None, SWD_omni=None,\n",
    "              Omni_name=('Omnifold-best', 'Omnifold-combine', 'SBUnfold'), ylim=None, fontsize=16,\n",
    "              first_model=3):\n",
    "    \"\"\"Styled metric plot: models from index ``first_model`` onward, plus optional baselines.\n",
    "\n",
    "    NOTE(review): this redefines the ``plot_data`` of the previous cell with a different\n",
    "    signature; whichever cell ran last wins.\n",
    "\n",
    "    Args:\n",
    "        input_data: 2-D array; rows ``first_model`` .. ``len(model_name) - 1`` are drawn.\n",
    "        model_name: labels indexed like the rows of ``input_data``.\n",
    "        x: 1-D array of x-axis values (labelled gamma).\n",
    "        title: accepted for call compatibility; not drawn.\n",
    "        ylabel: y-axis label.\n",
    "        save_name: if given, the figure is saved to ./fig/<save_name>.png.\n",
    "        SWD_omni: optional sequence of 1-D baseline curves, each drawn against ``x``.\n",
    "        Omni_name: labels for the ``SWD_omni`` curves; needs at least as many entries.\n",
    "        ylim: optional y-limits passed to ``plt.ylim``.\n",
    "        fontsize: font size for axis labels and tick labels.\n",
    "        first_model: first row to plot; the default 3 skips the three DDPM rows.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``Omni_name`` has fewer labels than ``SWD_omni`` has curves.\n",
    "    \"\"\"\n",
    "    plotted = range(first_model, len(model_name))\n",
    "    for k, i in enumerate(plotted):\n",
    "        plt.plot(x, input_data[i], label=model_name[i], color=color_list[k])\n",
    "    if SWD_omni is not None:\n",
    "        if len(Omni_name) < len(SWD_omni):\n",
    "            raise ValueError('Omni_name needs one label per SWD_omni curve')\n",
    "        # Baseline colours continue after the model colours so the two never collide.\n",
    "        for j, (curve, name) in enumerate(zip(SWD_omni, Omni_name)):\n",
    "            plt.plot(x, curve, label=name, color=color_list[len(plotted) + j])\n",
    "\n",
    "    # Shade the ranges [-0.75, -0.25] and [0.25, 0.75].\n",
    "    plt.axvspan(-0.75, -0.25, color='grey', alpha=0.5)\n",
    "    plt.axvspan(0.25, 0.75, color='grey', alpha=0.5)\n",
    "    plt.grid()\n",
    "    plt.legend(fontsize=12)\n",
    "    plt.xlabel(r'$\\gamma$', fontsize=fontsize)\n",
    "    plt.ylabel(ylabel, fontsize=fontsize)\n",
    "    plt.xticks([-1, -0.5, 0, 0.5, 1], fontsize=fontsize)\n",
    "    plt.yticks(fontsize=fontsize)\n",
    "    plt.xlim([-1, 1])\n",
    "    if ylim is not None:\n",
    "        plt.ylim(ylim)\n",
    "    if save_name is not None:\n",
    "        # savefig does not create missing directories.\n",
    "        os.makedirs('./fig', exist_ok=True)\n",
    "        plt.savefig('./fig/' + save_name + '.png', bbox_inches='tight', dpi=300)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_data(MSE_data, model_name, x, title='MSE_record', ylabel='MSE', save_name='MSE_record')\n",
    "plot_data(SWD_data, model_name, x, title='SWD_record', ylabel='SWD', save_name='SWD_record',\n",
    "          SWD_omni=(SWD_omni, SWD_omni_combine, SB_SWD), ylim=[0, 0.1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
