{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:06:08.220474Z",
          "start_time": "2025-05-09T17:06:04.615356Z"
        }
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "\n",
        "\n",
        "\n",
        "import FMfuncs\n",
        "import configs\n",
        "import plot_func\n",
        "# import FM2\n",
        "\n",
        "import PGFMv2 as PGFM\n",
        "# from torchdyn.core import NeuralODE\n",
        "# from torchcfm.utils import torch_wrapper\n",
        "# from torchcfm.utils import plot_trajectories\n",
        "import time\n",
        "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:06:08.468836Z",
          "start_time": "2025-05-09T17:06:08.465477Z"
        }
      },
      "outputs": [],
      "source": [
        "2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:06:08.774982Z",
          "start_time": "2025-05-09T17:06:08.537130Z"
        }
      },
      "outputs": [],
      "source": [
        "FMclass = FMfuncs.OTFlowMatching()\n",
        "PGFMclass = PGFM.PGFM()\n",
        "# FM2class = FM2.FM2()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:06:10.366323Z",
          "start_time": "2025-05-09T17:06:08.792125Z"
        }
      },
      "outputs": [],
      "source": [
        "if configs.d_model == 8:\n",
        "    dataset = np.load(r'./data/MDMl2ball_dim8.npy')\n",
        "elif configs.d_model == 20:\n",
        "    dataset = np.load(r'./data/MDMl2ball_dim20.npy')\n",
        "else:\n",
        "    raise ValueError\n",
        "\n",
        "dataset = torch.tensor(dataset,dtype=torch.float32, device=configs.device)\n",
        "x_mat = PGFMclass.get_samples(dataset, 10000)\n",
        "plot_func.plot_scatter_with_info(x_mat.cpu().numpy().transpose(), 0, 'unconstrained2',numItermax = 200000, distance = False)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:10:12.996579Z",
          "start_time": "2025-05-09T17:06:15.713557Z"
        }
      },
      "outputs": [],
      "source": [
        "trained_model = FMclass.train(dataset, reflect=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-15T22:05:08.022738Z",
          "start_time": "2025-04-15T21:54:37.566300Z"
        }
      },
      "outputs": [],
      "source": [
        "if configs.d_model == 8:\n",
        "    ckpt_path = './saved_model/FM_l2ball_dim8_iter_1000000.pth'\n",
        "elif configs.d_model == 20:\n",
        "    ckpt_path = './saved_model/FM_l2ball_dim20_iter_1000000.pth'\n",
        "else:\n",
        "    raise ValueError\n",
        "\n",
        "trained_model = PGFMclass.train2_2stage(dataset, ckpt_path)\n",
        "# trained_model = FM2class.train2_2stage(dataset, ckpt_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "stage1_t_list = [0.0,0.2,0.4,0.6,0.8]\n",
        "stage1_steps_list = [1, 15, 30, 45, 60]\n",
        "stage2_steps_list = [75, 60, 45, 30, 15]\n",
        "lambda_list = [2, 5, 10, 20, 30]\n",
        "if configs.d_model == 8:\n",
        "    ckpt_path = './saved_model/FM_l2ball_dim8_iter_1000000.pth'\n",
        "elif configs.d_model == 20:\n",
        "    ckpt_path = './saved_model/FM_l2ball_dim20_iter_1000000.pth'\n",
        "\n",
        "for i in range(len(stage1_t_list)):\n",
        "    for j in range(len(lambda_list)):\n",
        "        stage1_t = stage1_t_list[i]\n",
        "        stage2_steps = stage2_steps_list[i]\n",
        "        stage1_steps = stage1_steps_list[i]\n",
        "        lambda_val = lambda_list[j]\n",
        "        PGFMclass = PGFM.PGFM(sig_min=0, stage1_t=stage1_t,\n",
        "                         RL_Steps_S=stage2_steps, d_model=configs.d_model\n",
        "                         , device=configs.device, default_generation_step= stage1_steps,\n",
        "                         constraint_reward=lambda_val)\n",
        "        save_name = 'PGFM_' + str(stage1_t).replace('.','p') +'_1s'+str(stage1_steps) + '_2s'+ str(stage2_steps) + '_lambda' + str(lambda_val)\n",
        "        trained_model = PGFMclass.train2_2stage_batch(dataset, ckpt_path,\n",
        "                                                    save_name, epoches=40000)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "stage1_t_list = [0.0,0.2,0.4,0.6,0.8]\n",
        "stage1_steps_list = [1, 15, 30, 45, 60]\n",
        "stage2_steps_list = [75, 60, 45, 30, 15]\n",
        "lambda_list = [2, 5, 10, 20, 30]\n",
        "if configs.d_model == 8:\n",
        "    ckpt_path = './saved_model/FM_l2ball_dim8_iter_1000000.pth'\n",
        "elif configs.d_model == 20:\n",
        "    ckpt_path = './saved_model/FM_l2ball_dim20_iter_1000000.pth'\n",
        "\n",
        "PGFMnum_out_mat = np.zeros([len(stage1_t_list), len(lambda_list)])\n",
        "PGFMprob_mat = np.zeros([len(stage1_t_list), len(lambda_list)])\n",
        "PGFMSWD_mat = np.zeros([len(stage1_t_list), len(lambda_list)])\n",
        "\n",
        "PGFMnum_out_std_mat = np.zeros([len(stage1_t_list), len(lambda_list)])\n",
        "PGFMprob_std_mat = np.zeros([len(stage1_t_list), len(lambda_list)])\n",
        "PGFMSWD_std_mat = np.zeros([len(stage1_t_list), len(lambda_list)])\n",
        "\n",
        "if configs.d_model == 8:\n",
        "    ckpt1 = torch.load('./saved_model/FM_l2ball_dim8_iter_1000000.pth', map_location=configs.device, weights_only=True)\n",
        "elif configs.d_model == 20:\n",
        "    ckpt1 = torch.load('./saved_model/FM_l2ball_dim20_iter_1000000.pth', map_location=configs.device, weights_only=True)\n",
        "\n",
        "stage1model = PGFMclass.get_untrained_model()\n",
        "stage1model.load_state_dict(ckpt1)\n",
        "\n",
        "\n",
        "\n",
        "for i in range(len(stage1_t_list)):\n",
        "    for j in range(len(lambda_list)):\n",
        "        stage1_t = stage1_t_list[i]\n",
        "        stage2_steps = stage2_steps_list[i]\n",
        "        stage1_steps = stage1_steps_list[i]\n",
        "        lambda_val = lambda_list[j]\n",
        "        PGFMclass = PGFM.PGFM(sig_min=0, stage1_t=stage1_t,\n",
        "                         RL_Steps_S=stage2_steps, d_model=configs.d_model\n",
        "                         , device=configs.device, default_generation_step= stage1_steps,\n",
        "                         constraint_reward=lambda_val)\n",
        "        save_name = 'PGFM_' + str(stage1_t).replace('.','p') +'_1s'+str(stage1_steps) + '_2s'+ str(stage2_steps) + '_lambda' + str(lambda_val)\n",
        "        print(save_name)\n",
        "        ckpt2 = torch.load('./saved_model_2/'+save_name+'.pth', map_location=configs.device, weights_only=True)\n",
        "        stage2model = PGFMclass.policy\n",
        "        stage2model.load_state_dict(ckpt2)\n",
        "\n",
        "        PGFMnum_out_record = np.array([])\n",
        "        PGFMprob_record = np.array([])\n",
        "        PGFMSWD_record = np.array([])\n",
        "\n",
        "        for k in range(100):\n",
        "            reference = PGFMclass.get_samples(dataset, 10000).cpu().numpy()\n",
        "            # resPGFM = PGFMclass.PGFMsample(stage1model, stage2model, 10000)\n",
        "        \n",
        "            resPGFM = PGFMclass.PGFMsample_train2(stage1model, stage2model, 10000)\n",
        "            PGFMnum_out, PGFMprob, PGFMSWD = plot_func.only_info(resPGFM.cpu().numpy().transpose(), reference)\n",
        "            \n",
        "            PGFMnum_out_record = np.append(PGFMnum_out_record, PGFMnum_out)\n",
        "            PGFMprob_record = np.append(PGFMprob_record, PGFMprob)\n",
        "            PGFMSWD_record = np.append(PGFMSWD_record, PGFMSWD)\n",
        "\n",
        "        PGFMnum_out_mat[i,j] = np.mean(PGFMnum_out_record)\n",
        "        PGFMnum_out_std_mat[i,j] = np.std(PGFMnum_out_record)\n",
        "        PGFMprob_mat[i,j] = np.mean(PGFMprob_record)\n",
        "        PGFMprob_std_mat[i,j] = np.std(PGFMprob_record)\n",
        "        PGFMSWD_mat[i,j] = np.mean(PGFMSWD_record)\n",
        "        PGFMSWD_std_mat[i,j] = np.std(PGFMSWD_record)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "stage1_t_list = [0.0,0.2,0.4,0.6,0.8]\n",
        "stage1_steps_list = [1, 15, 30, 45, 60]\n",
        "stage2_steps_list = [75, 60, 45, 30, 15]\n",
        "lambda_list = [2, 5, 10, 20, 30]\n",
        "PGFMprob_mat = np.array([[0.014286, 0.007084, 0.004498, 0.003909, 0.001362],\n",
        "       [0.010152, 0.007765, 0.005415, 0.002874, 0.002453],\n",
        "       [0.011823, 0.006616, 0.004146, 0.002457, 0.002099],\n",
        "       [0.013152, 0.008599, 0.00463 , 0.002513, 0.002408],\n",
        "       [0.013447, 0.008579, 0.005559, 0.002513, 0.001926]])\n",
        "PGFMSWD_mat = np.array([[0.00875843, 0.00980626, 0.01302159, 0.01279321, 0.01661936],\n",
        "       [0.00920035, 0.01024121, 0.01225555, 0.01613421, 0.01548791],\n",
        "       [0.00924858, 0.00983376, 0.01085441, 0.0143669 , 0.01424318],\n",
        "       [0.00835777, 0.00936468, 0.01087226, 0.01205686, 0.01297447],\n",
        "       [0.0090576 , 0.0098378 , 0.01095702, 0.01324996, 0.01324961]])\n",
        "\n",
        "DDFMprob_mat = np.array([0.004158,0.002158,0.001365, 0.000767, 0.000675])\n",
        "DDFMSWD_mat = np.array([0.0075,0.0080,0.0084, 0.0097, 0.0100])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "for i in range(PGFMprob_mat.shape[0]):\n",
        "    plt.plot(PGFMSWD_mat[i], PGFMprob_mat[i], label = \"FM-PG, $t_0=$\"+str(stage1_t_list[i]), marker = '.')\n",
        "\n",
        "plt.plot(DDFMSWD_mat, DDFMprob_mat, label = \"FM-DD\", marker = '.')\n",
        "\n",
        "plt.scatter([0.0087], [0.0908], label = 'FM')\n",
        "# plt.scatter([0.0086], [0.000502], label = 'FM-DD')\n",
        "plt.legend()\n",
        "plt.xlabel(\"SWD\",fontsize=15)\n",
        "plt.ylabel(\"$\\\\mathbb{P}(X_1\\\\notin C)$\",fontsize=15)\n",
        "plt.semilogy()\n",
        "plt.grid()\n",
        "# plt.title(\"$\\\\mathbb{P}(\\\\hat{x}_1\\\\notin C)$ vs. SWD for different $t_0$ by sweeping $\\\\lambda$\")\n",
        "plt.xticks(fontsize=15) \n",
        "plt.yticks(fontsize=15)\n",
        "plt.savefig('./fig/d20compare.png',dpi=300 ,bbox_inches='tight')\n",
        "plt.show()\n",
        "    "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-06T19:27:28.010388Z",
          "start_time": "2025-05-06T19:27:26.471727Z"
        }
      },
      "outputs": [],
      "source": [
        "if configs.d_model == 8:\n",
        "    ckpt_path = './saved_model/FM_l2ball_dim8_iter_1000000.pth'\n",
        "elif configs.d_model == 20:\n",
        "    ckpt_path = './saved_model/FM_l2ball_dim20_iter_100000.pth'\n",
        "else:\n",
        "    raise ValueError\n",
        "\n",
        "trained_model = PGFMclass.train2_2stage_w_diff_distance(dataset, ckpt_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:11:28.304903Z",
          "start_time": "2025-05-09T17:11:28.296191Z"
        }
      },
      "outputs": [],
      "source": [
        "if configs.d_model == 8:\n",
        "    ckpt1 = torch.load('./saved_model/FM_l2ball_dim8_iter_1000000.pth', map_location=configs.device, weights_only=True)\n",
        "    # ckpt1 = torch.load('./saved_model/RFM_l2ball_dim8_iter_200000.pth', map_location=configs.device, weights_only=True)\n",
        "    # ckpt2 = torch.load('./saved_model/RFM_l2ball_dim8_iter_100000.pth', map_location=configs.device, weights_only=True)\n",
        "    ckpt2 = torch.load('./saved_model/Apr17FM_l2balls2_dim8_iter_diffd_10000.pth', map_location=configs.device, weights_only=True)#FM2_l2balls2_dim8_iter_10000 FM2t0_l2balls2_dim8_iter_6000\n",
        "elif configs.d_model == 20:\n",
        "    ckpt1 = torch.load('./saved_model/FM_l2ball_dim20_iter_1000000.pth', map_location=configs.device, weights_only=True)\n",
        "    # ckpt1 = torch.load('./saved_model/RFM_l2ball_dim20_iter_200000.pth', map_location=configs.device, weights_only=True)\n",
        "    # ckpt2 = torch.load('./saved_model/DDFM_l2balls2_20_dim20_iter_diffd_5000.pth', map_location=configs.device, weights_only=True)#Apr17FM2_l2balls2_dim20_iter_40000\n",
        "    ckpt2 = torch.load('./saved_model/Apr17FM_l2balls2_dim20_iter_diffd_10000.pth', map_location=configs.device, weights_only=True)\n",
        "    # ckpt2 = torch.load('./saved_model/Apr8FM_l2balls2_dim20_iter_40000.pth', map_location=configs.device, weights_only=True)\n",
        "    \n",
        "\n",
        "else:\n",
        "    raise ValueError"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:11:31.461943Z",
          "start_time": "2025-05-09T17:11:29.810123Z"
        }
      },
      "outputs": [],
      "source": [
        "# PGFMclass = FM2class\n",
        "\n",
        "stage1model = PGFMclass.get_untrained_model()\n",
        "stage1model.load_state_dict(ckpt1)\n",
        "\n",
        "stage2model = PGFMclass.policy\n",
        "# stage2model = PGFMclass.get_untrained_model()\n",
        "stage2model.load_state_dict(ckpt2)\n",
        "\n",
        "res = PGFMclass.PGFMsample_train2(stage1model, stage2model, 10000)\n",
        "\n",
        "\n",
        "reference = PGFMclass.get_samples(dataset, 10000).cpu().numpy()\n",
        "plot_func.plot_scatter_with_info(res.cpu().numpy().transpose(), reference, 'l2d8PGFM',numItermax = 200000, distance = True)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:11:34.232747Z",
          "start_time": "2025-05-09T17:11:32.445787Z"
        }
      },
      "outputs": [],
      "source": [
        "stage1model = FMclass.get_untrained_model()\n",
        "stage1model.load_state_dict(ckpt1)\n",
        "\n",
        "\n",
        "res = FMfuncs.sampler(stage1model, 10000, stoptime=1, reflect=False)\n",
        "reference = FMclass.get_samples(dataset, 10000).cpu().numpy()\n",
        "plot_func.plot_scatter_with_info(res.cpu().numpy().transpose(), reference, 'l2d20FM',numItermax = 200000, distance = True)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:12:12.232173Z",
          "start_time": "2025-05-09T17:11:44.522646Z"
        }
      },
      "outputs": [],
      "source": [
        "# PGFMclass = FM2class\n",
        "\n",
        "FMnum_out_record = np.array([])\n",
        "FMprob_record = np.array([])\n",
        "FMSWD_record = np.array([])\n",
        "PGFMnum_out_record = np.array([])\n",
        "PGFMprob_record = np.array([])\n",
        "PGFMSWD_record = np.array([])\n",
        "\n",
        "for i in range(100):\n",
        "    if i%10 == 0:\n",
        "        print(i)\n",
        "    reference = PGFMclass.get_samples(dataset, 10000).cpu().numpy()\n",
        "    resFM = FMfuncs.sampler(stage1model, 10000, stoptime=1, reflect=False)\n",
        "    # resPGFM = PGFMclass.PGFMsample(stage1model, stage2model, 10000)\n",
        "\n",
        "    resPGFM = PGFMclass.PGFMsample_train2(stage1model, stage2model, 10000)\n",
        "    FMnum_out, FMprob, FMSWD = plot_func.only_info(resFM.cpu().numpy().transpose(), reference)\n",
        "    PGFMnum_out, PGFMprob, PGFMSWD = plot_func.only_info(resPGFM.cpu().numpy().transpose(), reference)\n",
        "    \n",
        "    FMnum_out_record = np.append(FMnum_out_record, FMnum_out)\n",
        "    FMprob_record = np.append(FMprob_record, FMprob)\n",
        "    FMSWD_record = np.append(FMSWD_record, FMSWD)\n",
        "    PGFMnum_out_record = np.append(PGFMnum_out_record, PGFMnum_out)\n",
        "    PGFMprob_record = np.append(PGFMprob_record, PGFMprob)\n",
        "    PGFMSWD_record = np.append(PGFMSWD_record, PGFMSWD)\n",
        "    \n",
        "\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "print(\" FM SWD: \", f\"{np.mean( FMSWD_record):.4f}\", r\"\\pm\", f\"{np.std( FMSWD_record):.4f}\")\n",
        "print(\" FM num out: \", f\"{np.mean( FMnum_out_record):.4f}\", r\"\\pm\", f\"{np.std( FMnum_out_record):.4f}\")\n",
        "\n",
        "print(\"PGFM SWD: \", f\"{np.mean(PGFMSWD_record):.4f}\", r\"\\pm\", f\"{np.std(PGFMSWD_record):.4f}\")\n",
        "print(\"PGFM num out: \", f\"{np.mean(PGFMnum_out_record):.4f}\", r\"\\pm\", f\"{np.std(PGFMnum_out_record):.4f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "np.savez( './result_record/' +'train2_l2d20_record.npz', FMnum_out_record=FMnum_out_record, FMprob_record=FMprob_record,\n",
        "        FMSWD_record=FMSWD_record,\n",
        "        PGFMnum_out_record=PGFMnum_out_record,\n",
        "        PGFMprob_record=PGFMprob_record,\n",
        "        PGFMSWD_record=PGFMSWD_record)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}