{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-17T14:21:59.821551Z",
     "start_time": "2025-07-17T14:21:59.819141Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import configs\n",
    "import ddpm_class\n",
    "import FM_class\n",
    "import evaluate\n",
    "import plot_func\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32655fc265bed0f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-08T20:54:05.078285Z",
     "start_time": "2025-07-08T20:54:03.274465Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the six models compared below: three from ddpm_class, three from FM_class.\n",
    "# The c* classes take have_rho; the *_rho instances are built with have_rho=True.\n",
    "# NOTE(review): save_int presumably is the checkpoint interval; the load cell reads\n",
    "# *_it_80000.pth files. tddpm and tFM are never trained in this notebook, so their\n",
    "# checkpoints must already exist on disk for a fresh Restart & Run All.\n",
    "cddpm = ddpm_class.cDDPM(save_int=80000, have_rho=False, cddpm_name=\"cddpm\")\n",
    "cddpm_have_rho = ddpm_class.cDDPM(save_int=80000, have_rho=True, cddpm_name=\"cddpm_rho\")\n",
    "tddpm = ddpm_class.tDDPM(save_int=80000, tddpm_name=\"tddpm\")\n",
    "\n",
    "tFM = FM_class.tFM(save_int=80000, tFM_name=\"tFM\")\n",
    "cFM = FM_class.cFM(save_int=80000, have_rho=False, cFM_name=\"cFM\")\n",
    "cFM_have_rho = FM_class.cFM(save_int=80000, have_rho=True, cFM_name=\"cFM_rho\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45ee7f0faffcb5c7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-07T17:12:31.236103Z",
     "start_time": "2025-07-07T17:09:37.424647Z"
    }
   },
   "outputs": [],
   "source": [
    "# Train to 80000, like the cFM cells, so that a fresh Run All can write the\n",
    "# ./saved_model/cddpm_it_80000.pth checkpoint the load cell reads.\n",
    "# NOTE(review): assumes train(n) runs n iterations and saves every save_int=80000;\n",
    "# confirm against ddpm_class.\n",
    "cddpm.train(80000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af172d02bf977141",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-07T17:15:11.943383Z",
     "start_time": "2025-07-07T17:12:31.507814Z"
    }
   },
   "outputs": [],
   "source": [
    "# Train to 80000, like the cFM cells, so that a fresh Run All can write the\n",
    "# ./saved_model/cddpm_rho_it_80000.pth checkpoint the load cell reads.\n",
    "# NOTE(review): assumes train(n) runs n iterations and saves every save_int=80000;\n",
    "# confirm against ddpm_class.\n",
    "cddpm_have_rho.train(80000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4840c527730d5c27",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-08T17:15:05.562304Z",
     "start_time": "2025-07-08T17:04:54.035780Z"
    }
   },
   "outputs": [],
   "source": [
    "# Train both cFM variants with the same count, 80000, which matches the\n",
    "# *_it_80000.pth checkpoints read by the load cell.\n",
    "# NOTE(review): `epoches` is a typo for epochs; given the `_it_` checkpoint names the\n",
    "# value is presumably an iteration count rather than epochs. Confirm in FM_class.\n",
    "epoches = 80000\n",
    "cFM.train(epoches)\n",
    "cFM_have_rho.train(epoches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "689bc7a95dadaf35",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-08T16:47:31.655345Z",
     "start_time": "2025-07-08T16:33:20.996742Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): this trains cFM a second time; the previous cell already calls\n",
    "# cFM.train(80000). On Restart & Run All this repeats that training (and presumably\n",
    "# rewrites its checkpoint). Remove this cell if it is a leftover manual re-run.\n",
    "cFM.train( 80000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9505ef4f19d4e117",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-08T20:54:08.041994Z",
     "start_time": "2025-07-08T20:54:08.005885Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load every model from its iteration-80000 checkpoint under ./saved_model/.\n",
    "# ddpm_class models\n",
    "tddpm.load_ckpt('./saved_model/tddpm_it_80000.pth')\n",
    "cddpm.load_ckpt('./saved_model/cddpm_it_80000.pth')\n",
    "cddpm_have_rho.load_ckpt('./saved_model/cddpm_rho_it_80000.pth')\n",
    "\n",
    "# FM_class models\n",
    "tFM.load_ckpt('./saved_model/tFM_it_80000.pth')\n",
    "cFM.load_ckpt('./saved_model/cFM_it_80000.pth')\n",
    "cFM_have_rho.load_ckpt('./saved_model/cFM_rho_it_80000.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8abc051bb945640c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-08T20:58:03.482773Z",
     "start_time": "2025-07-08T20:56:23.016082Z"
    }
   },
   "outputs": [],
   "source": [
    "# NOTE(review): this import belongs in the top import cell; eval_models is only used\n",
    "# by the commented-out call below.\n",
    "from evaluate_nosave import eval_models, plot_models\n",
    "# eval_models([tddpm, cddpm, cddpm_have_rho], 10, [False, False, True])\n",
    "# The bool list is parallel to the model list: True marks the two *_have_rho models,\n",
    "# so keep both lists in the same order when editing either one.\n",
    "plot_models([tddpm, cddpm, cddpm_have_rho, tFM, cFM, cFM_have_rho], 10, [False, False, True, False, False, True], save=True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d9704e2-735d-452c-9cfd-ed58d3aee170",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-17T14:22:27.196392Z",
     "start_time": "2025-07-17T14:22:27.035335Z"
    }
   },
   "outputs": [],
   "source": [
    "# Select the interactive ipympl backend exactly once. A preceding\n",
    "# `%matplotlib notebook` would be overridden by this line immediately, and that\n",
    "# nbagg backend targets the classic Notebook UI only.\n",
    "%matplotlib widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f5a8ef0ff34ebb0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-17T14:22:28.211622Z",
     "start_time": "2025-07-17T14:22:28.202287Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "\n",
    "\n",
    "# np.load on an .npz returns an NpzFile that keeps the file open; the with-block\n",
    "# closes it once both arrays have been read into memory.\n",
    "with np.load('./output/test_record.npz') as data:\n",
    "    MSE_data = data['MSE_record']\n",
    "    SWD_data = data['SWD_record']\n",
    "# model_name = ['tddpm', 'cddpm', 'cddpm_have_rho', 'tFM', 'cFM', 'cFM_have_rho']\n",
    "# NOTE(review): the plot functions draw models 0..len(model_name)-1 along axis 0\n",
    "# of the record, so these labels must follow the record's model order; confirm.\n",
    "model_name = ['EI-FM', 'cFM', 'cFM-$\\\\gamma$']\n",
    "\n",
    "# NOTE(review): these globals are not read by plot_3D / plot_data (each takes\n",
    "# its own rho / mu values); kept in case other code relies on them.\n",
    "rho = None\n",
    "mu1 = 0\n",
    "mu2 = 2.5\n",
    "\n",
    "\n",
    "def _grid_index(axis, value, name):\n",
    "    \"\"\"Return the first index of the 1-D grid `axis` within 1e-5 of `value`.\n",
    "\n",
    "    A tolerance is required because np.linspace grid points are not always\n",
    "    bit-identical to the decimal literal: np.linspace(-1, 1, 21) holds no\n",
    "    element exactly equal to 0.9, so an exact == lookup for rho=0.9 fails.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no point of `axis` is within 1e-5 of `value`.\n",
    "    \"\"\"\n",
    "    hits = np.flatnonzero(np.abs(axis - value) < 1e-5)\n",
    "    if hits.size == 0:\n",
    "        raise ValueError(f'{name}={value} is not a point of the evaluation grid')\n",
    "    return hits[0]\n",
    "\n",
    "\n",
    "def plot_3D(input_data, model_name, title, ylabel, rho, save_name=None, zlim=(0, 0.6)):\n",
    "    \"\"\"Draw one surface per model of a metric over the (mu1, mu2) grid at fixed rho.\n",
    "\n",
    "    Args:\n",
    "        input_data: metric record indexed [model, rho, mu1, mu2] (the layout\n",
    "            plot_data's slicing uses), with rho on np.linspace(-1, 1, 21) and\n",
    "            mu1, mu2 on np.linspace(-2.5, 2.5, 41).\n",
    "        model_name: legend labels; models 0..len(model_name)-1 are drawn.\n",
    "        title: currently not drawn (the title call below is commented out).\n",
    "        ylabel: label of the vertical (metric) axis, e.g. \"SWD\" or \"MSE\".\n",
    "        rho: value of rho to slice at; must lie on the rho grid.\n",
    "        save_name: if given, the figure is also saved to ./fig/<save_name>.png.\n",
    "        zlim: limits of the vertical axis.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `rho` is not on the rho grid.\n",
    "    \"\"\"\n",
    "    mu_axis = np.linspace(-2.5, 2.5, 41)\n",
    "    rho_axis = np.linspace(-1, 1, 21)\n",
    "    # indexing='ij' gives mu1_grid[a, b] = mu_axis[a] and mu2_grid[a, b] = mu_axis[b],\n",
    "    # matching plotdata[i][a, b] = metric at (mu1[a], mu2[b]). The default 'xy'\n",
    "    # indexing would transpose every surface, i.e. swap the mu1 and mu2 axes.\n",
    "    mu1_grid, mu2_grid = np.meshgrid(mu_axis, mu_axis, indexing='ij')\n",
    "    plotdata = input_data[:, _grid_index(rho_axis, rho, 'rho'), :, :]\n",
    "    fig = plt.figure()\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "    for i, name in enumerate(model_name):\n",
    "        # Keep r/g/b for the first three models; beyond that use the default\n",
    "        # colour cycle ('C3', 'C4', ...) instead of raising IndexError.\n",
    "        color = ('r', 'g', 'b')[i] if i < 3 else f'C{i}'\n",
    "        ax.plot_surface(mu1_grid, mu2_grid, plotdata[i], color=color, label=name)\n",
    "    ax.set_xlabel(r'$\\mu_1$')\n",
    "    ax.set_ylabel(r'$\\mu_2$')\n",
    "    ax.set_zlabel(ylabel, rotation=90)\n",
    "    ax.set_zlim(zlim)\n",
    "    ax.legend()\n",
    "    if save_name is not None:\n",
    "        # plt.tight_layout()\n",
    "        fig.savefig('./fig/' + save_name + '.png', bbox_inches='tight', dpi=300, pad_inches=0.21)\n",
    "    # plt.title(title)\n",
    "    plt.show()\n",
    "\n",
    "def plot_data(input_data, model_name, title, ylabel, z, save_name = None):\n",
    "    \"\"\"Plot a metric for every model along one parameter, the other two held fixed.\n",
    "\n",
    "    Args:\n",
    "        input_data: metric record indexed [model, rho, mu1, mu2], with rho on\n",
    "            np.linspace(-1, 1, 21) and mu1, mu2 on np.linspace(-2.5, 2.5, 41).\n",
    "        model_name: legend labels; models 0..len(model_name)-1 are drawn.\n",
    "        title: figure title.\n",
    "        ylabel: y-axis label.\n",
    "        z: (rho, mu1, mu2). Exactly one entry must be None: that parameter becomes\n",
    "            the x-axis and the other two select the slice.\n",
    "        save_name: if given, the figure is also saved to ./fig/<save_name>.png.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if z does not contain exactly one None, or a fixed value is\n",
    "            not on its grid.\n",
    "    \"\"\"\n",
    "    rho = z[0]\n",
    "    mu1 = z[1]\n",
    "    mu2 = z[2]\n",
    "    rho_axis = np.linspace(-1, 1, 21)\n",
    "    mu_axis = np.linspace(-2.5, 2.5, 41)\n",
    "\n",
    "    # Grid lookups go through _grid_index (tolerance 1e-5) instead of exact ==,\n",
    "    # which raised IndexError for rho values such as 0.9.\n",
    "    if rho is None and mu1 is not None and mu2 is not None:\n",
    "        x, xlabel = rho_axis, r'$\\rho$'\n",
    "        plotdata = input_data[:, :, _grid_index(mu_axis, mu1, 'mu1'), _grid_index(mu_axis, mu2, 'mu2')]\n",
    "    elif rho is not None and mu1 is None and mu2 is not None:\n",
    "        x, xlabel = mu_axis, r'$\\mu_1$'\n",
    "        plotdata = input_data[:, _grid_index(rho_axis, rho, 'rho'), :, _grid_index(mu_axis, mu2, 'mu2')]\n",
    "    elif rho is not None and mu1 is not None and mu2 is None:\n",
    "        x, xlabel = mu_axis, r'$\\mu_2$'\n",
    "        plotdata = input_data[:, _grid_index(rho_axis, rho, 'rho'), _grid_index(mu_axis, mu1, 'mu1'), :]\n",
    "    else:\n",
    "        raise ValueError('only 1 None')\n",
    "\n",
    "    # Start a fresh figure: with an interactive backend (%matplotlib widget) the\n",
    "    # previous figure may still be current after plt.show(), and a bare plt.plot\n",
    "    # would draw into it.\n",
    "    fig, ax = plt.subplots()\n",
    "    for i, name in enumerate(model_name):\n",
    "        ax.plot(x, plotdata[i], label=name)\n",
    "    # plt.axvspan(-0.75, -0.25, color='grey', alpha=0.5)\n",
    "    # plt.axvspan(0.25, 0.75, color='grey', alpha=0.5)\n",
    "    ax.grid()\n",
    "    ax.legend()\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_title(title)\n",
    "    if save_name is not None:\n",
    "        fig.savefig('./fig/' + save_name + '.png', bbox_inches='tight', dpi=300)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "z = (None, 1.5, -1.5)\n",
    "# plot_data(MSE_data, model_name, title='MSE_record', ylabel=\"MSE\", z=z, save_name='MSE_record')\n",
    "# plot_data(SWD_data, model_name, title='SWD_record', ylabel=\"SWD\", z=z, save_name='SWD_record')\n",
    "\n",
    "# plot_3D(MSE_data, model_name, title='MSE_record', ylabel=\"MSE\", rho=0, save_name='MSE_record')\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9a72cad16fab42",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-07-17T14:22:31.058946Z",
     "start_time": "2025-07-17T14:22:30.451927Z"
    }
   },
   "outputs": [],
   "source": [
    "# One 3D SWD figure per rho value. The file-name suffix spells '-' as 'n' and '.'\n",
    "# as 'p', e.g. rho=-0.5 is saved as ./fig/SWD_recordn0p5.png.\n",
    "rho_list = [-0.9, -0.5, 0.0, 0.5, 0.9]\n",
    "for rho in rho_list:\n",
    "    rhoname = str(rho).translate(str.maketrans('-.', 'np'))\n",
    "    plot_3D(SWD_data, model_name, title=f'SWD_record_{rho}', ylabel=\"SWD\", rho=rho,\n",
    "            save_name=f'SWD_record{rhoname}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
