{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:41:23.826766Z",
          "start_time": "2025-05-09T17:41:20.456596Z"
        },
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "# Put the flag in the environment before numpy/torch/matplotlib are imported,\n",
        "# so it is already set whenever their OpenMP runtime starts up.\n",
        "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Project-local modules.\n",
        "import FMfuncs\n",
        "import configs\n",
        "import plot_func\n",
        "import PGFMv2 as PGFM\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:41:24.344541Z",
          "start_time": "2025-05-09T17:41:24.159817Z"
        }
      },
      "outputs": [],
      "source": [
        "# Single PGFM instance shared by the sampling, training and evaluation cells below.\n",
        "PGFMclass = PGFM.PGFM()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:41:25.955391Z",
          "start_time": "2025-05-09T17:41:24.419892Z"
        }
      },
      "outputs": [],
      "source": [
        "# Training-data file for every supported value of `configs.dataset`.\n",
        "DATASET_FILES = {\n",
        "    \"two_uniform_box\": \"./data/uniform2.npy\",\n",
        "    \"uniform_box\": \"./data/uniform.npy\",\n",
        "    \"cropped_Gaussian\": \"./data/cropped_Gaussian.npy\",\n",
        "}\n",
        "if configs.dataset not in DATASET_FILES:\n",
        "    # Same exception type as before, but now it says what went wrong.\n",
        "    raise ValueError(\n",
        "        f\"Unknown configs.dataset {configs.dataset!r}; expected one of {sorted(DATASET_FILES)}\"\n",
        "    )\n",
        "\n",
        "# Keep the raw NumPy array and the device tensor under separate names.\n",
        "dataset_np = np.load(DATASET_FILES[configs.dataset])\n",
        "dataset = torch.tensor(dataset_np, dtype=torch.float32, device=configs.device)\n",
        "\n",
        "# Scatter-plot 10000 points obtained from get_samples (distance computation switched off).\n",
        "x_mat = PGFMclass.get_samples(dataset, 10000)\n",
        "plot_func.plot_scatter_with_info(x_mat.cpu().numpy().transpose(), 0, 'unconstrained2', fix_bound = False,numItermax = 200000, distance = False)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:41:43.918866Z",
          "start_time": "2025-05-09T17:41:28.883713Z"
        }
      },
      "outputs": [],
      "source": [
        "# Stage-1 flow-matching checkpoint for every supported value of `configs.dataset`.\n",
        "# The file stems do not mirror the dataset keys, so the mapping is spelled out here.\n",
        "STAGE1_CHECKPOINTS = {\n",
        "    \"two_uniform_box\": \"./saved_model/FM_2uniform_1000000.pth\",\n",
        "    \"uniform_box\": \"./saved_model/FM_uniform2_1000000.pth\",\n",
        "    \"cropped_Gaussian\": \"./saved_model/FM_test2_1000000.pth\",\n",
        "}\n",
        "if configs.dataset not in STAGE1_CHECKPOINTS:\n",
        "    # Same exception type as before, but now it says what went wrong.\n",
        "    raise ValueError(\n",
        "        f\"Unknown configs.dataset {configs.dataset!r}; expected one of {sorted(STAGE1_CHECKPOINTS)}\"\n",
        "    )\n",
        "path = STAGE1_CHECKPOINTS[configs.dataset]\n",
        "\n",
        "# Alternative trainer kept for reference:\n",
        "# trained_model = PGFMclass.train2_2stage(dataset, path)\n",
        "# NOTE(review): presumably `path` is the stage-1 checkpoint the second stage builds on -- confirm in PGFMv2.\n",
        "trained_model = PGFMclass.train2_2stage_distance(dataset, path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:41:45.019161Z",
          "start_time": "2025-05-09T17:41:45.010676Z"
        }
      },
      "outputs": [],
      "source": [
        "# (stage-1, stage-2) checkpoint files for every supported value of `configs.dataset`.\n",
        "CHECKPOINT_PAIRS = {\n",
        "    # Other stage-2 files tried earlier: PGFM_2uniform_20000, PGFM_2uniform_Mar27, PGFM_2uniform_only_terminal\n",
        "    \"two_uniform_box\": ('./saved_model/FM_2uniform_1000000.pth', './saved_model/distanceFM_2uniform_5000.pth'),\n",
        "    # NOTE(review): this entry uses a PGFM_* stage-2 file while the other two use distanceFM_* -- confirm that is intended.\n",
        "    \"uniform_box\": ('./saved_model/FM_uniform2_1000000.pth', './saved_model/PGFM_uniform_20000.pth'),\n",
        "    # Other stage-2 files tried earlier: PGFM_CG_20000, PGFM_CG_Mar27\n",
        "    \"cropped_Gaussian\": ('./saved_model/FM_test2_1000000.pth', './saved_model/distanceFM_CG_5000.pth'),\n",
        "}\n",
        "if configs.dataset not in CHECKPOINT_PAIRS:\n",
        "    # Same exception type as before, but now it says what went wrong.\n",
        "    raise ValueError(\n",
        "        f\"Unknown configs.dataset {configs.dataset!r}; expected one of {sorted(CHECKPOINT_PAIRS)}\"\n",
        "    )\n",
        "\n",
        "stage1_path, stage2_path = CHECKPOINT_PAIRS[configs.dataset]\n",
        "# weights_only=True restricts unpickling to tensors and plain containers.\n",
        "ckpt1 = torch.load(stage1_path, map_location=configs.device, weights_only=True)\n",
        "ckpt2 = torch.load(stage2_path, map_location=configs.device, weights_only=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:41:47.716989Z",
          "start_time": "2025-05-09T17:41:46.328801Z"
        }
      },
      "outputs": [],
      "source": [
        "# Stage 1: a fresh network filled with the first checkpoint's weights.\n",
        "stage1model = PGFMclass.get_untrained_model()\n",
        "stage1model.load_state_dict(ckpt1)\n",
        "\n",
        "# Stage 2: the policy owned by PGFMclass, filled with the second checkpoint's weights.\n",
        "stage2model = PGFMclass.policy\n",
        "stage2model.load_state_dict(ckpt2)\n",
        "\n",
        "# Sample through both stages and plot the result against fresh reference points.\n",
        "res = PGFMclass.PGFMsample(stage1model, stage2model, 10000)\n",
        "reference = PGFMclass.get_samples(dataset, 10000).cpu().numpy()\n",
        "plot_func.plot_scatter_with_info(\n",
        "    res.cpu().numpy().transpose(),\n",
        "    reference,\n",
        "    'Mar27_2uniform_CG',\n",
        "    fix_bound=False,\n",
        "    numItermax=200000,\n",
        "    distance=True,\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-29T14:43:05.871825Z",
          "start_time": "2025-04-29T14:43:05.712125Z"
        }
      },
      "outputs": [],
      "source": [
        "# Density map of 50000 two-stage samples, with out-of-box points marked in red.\n",
        "# The two outlines are always drawn at (+center, +center) and (-center, -center),\n",
        "# whatever configs.dataset is set to.\n",
        "res = PGFMclass.PGFMsample(stage1model, stage2model, 50000).cpu().numpy()\n",
        "# Alternative sample sources for the same figure:\n",
        "# res = PGFMclass.get_samples(dataset, 50000).cpu().numpy()\n",
        "# res = FMfuncs.sampler(stage1model, 50000, stoptime=1).cpu().numpy()\n",
        "xy = res\n",
        "center = configs.uniform_center\n",
        "half_width = configs.bound\n",
        "\n",
        "# A point is flagged when |x| or |y| lies at least `half_width` away from `center`.\n",
        "# NOTE(review): because abs() is applied per coordinate, points inside the mirrored boxes at\n",
        "# (+center, -center) and (-center, +center) are NOT flagged, although those boxes are not drawn -- confirm.\n",
        "outliner_mat = np.logical_or(np.abs(np.abs(xy[:, 0]) - center) >= half_width,\n",
        "                             np.abs(np.abs(xy[:, 1]) - center) >= half_width)\n",
        "xy_out = xy[outliner_mat]\n",
        "\n",
        "\n",
        "def box_outline(cx, cy, half):\n",
        "    \"\"\"Return an unfilled black square patch centred at (cx, cy) with side 2 * half.\"\"\"\n",
        "    return plt.Polygon(\n",
        "        [[cx - half, cy - half],\n",
        "         [cx + half, cy - half],\n",
        "         [cx + half, cy + half],\n",
        "         [cx - half, cy + half]],\n",
        "        closed=True, edgecolor='black', fill=False, linewidth=1.2\n",
        "    )\n",
        "\n",
        "\n",
        "square = box_outline(center, center, half_width)\n",
        "square1 = box_outline(-center, -center, half_width)\n",
        "\n",
        "fig, ax = plt.subplots()\n",
        "H, xedges, yedges = np.histogram2d(xy[:, 0], xy[:, 1], bins=100)\n",
        "\n",
        "# Raising the counts to the power 0.2 compresses their range so sparse bins stay visible.\n",
        "H_transformed = H ** 0.2\n",
        "# Empty bins are masked so they are left uncoloured.\n",
        "H_masked = np.ma.masked_where(H == 0, H_transformed)\n",
        "ax.pcolormesh(xedges, yedges, H_masked.T, cmap='Blues')\n",
        "ax.add_patch(square)\n",
        "ax.add_patch(square1)\n",
        "ax.set_aspect('equal')\n",
        "ax.scatter(xy_out[:, 0], xy_out[:, 1], s=15, alpha=1, color='red')\n",
        "\n",
        "# View window: 1.2 * half_width beyond each box centre, symmetric about the origin.\n",
        "axis_limit = half_width * 1.2 + center\n",
        "ax.set_xlim(-axis_limit, axis_limit)\n",
        "ax.set_ylim(-axis_limit, axis_limit)\n",
        "\n",
        "ax.tick_params(labelsize=30)\n",
        "ax.set_xticks([-5, 0, 5])\n",
        "ax.set_yticks([-5, 0, 5])\n",
        "\n",
        "# savefig does not create missing directories, so make sure the target exists.\n",
        "os.makedirs('./fig', exist_ok=True)\n",
        "fig.savefig('./fig/' + '2uniform_distanceFM' + '.png', bbox_inches='tight', dpi=300)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:42:24.934651Z",
          "start_time": "2025-05-09T17:42:05.698508Z"
        }
      },
      "outputs": [],
      "source": [
        "# Compare stage-1-only sampling (FM) with two-stage sampling (PGFM) over 100 independent rounds.\n",
        "# Each round's (num_out, prob, SWD) triple is collected in a list and converted to arrays once\n",
        "# at the end, instead of re-allocating six arrays with np.append on every iteration.\n",
        "fm_rows = []\n",
        "pgfm_rows = []\n",
        "\n",
        "for i in range(100):\n",
        "    if i % 10 == 0:\n",
        "        print(i)\n",
        "    # Fresh reference points and fresh samples from both samplers for this round.\n",
        "    reference = PGFMclass.get_samples(dataset, 10000).cpu().numpy()\n",
        "    resFM = FMfuncs.sampler(stage1model, 10000, stoptime=1)\n",
        "    resPGFM = PGFMclass.PGFMsample(stage1model, stage2model, 10000)\n",
        "\n",
        "    # only_info returns three values per call; they are unpacked below as (num_out, prob, SWD).\n",
        "    fm_rows.append(tuple(plot_func.only_info(resFM.cpu().numpy().transpose(), reference)))\n",
        "    pgfm_rows.append(tuple(plot_func.only_info(resPGFM.cpu().numpy().transpose(), reference)))\n",
        "\n",
        "# Shape (100, 3) -> transpose -> one float64 vector per metric.\n",
        "FMnum_out_record, FMprob_record, FMSWD_record = np.array(fm_rows, dtype=float).T\n",
        "PGFMnum_out_record, PGFMprob_record, PGFMSWD_record = np.array(pgfm_rows, dtype=float).T\n",
        "\n",
        "# Optional: persist the records.\n",
        "# np.savez( './result_record/' +'uniform2_record.npz', FMnum_out_record=FMnum_out_record, FMprob_record=FMprob_record,\n",
        "#         FMSWD_record=FMSWD_record,\n",
        "#         PGFMnum_out_record=PGFMnum_out_record,\n",
        "#         PGFMprob_record=PGFMprob_record,\n",
        "#         PGFMSWD_record=PGFMSWD_record)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-05-09T17:42:56.089130Z",
          "start_time": "2025-05-09T17:42:56.086270Z"
        }
      },
      "outputs": [],
      "source": [
        "# Print mean and standard deviation (4 decimals) of the SWD and num-out records for both samplers.\n",
        "summary_rows = (\n",
        "    (\" FM SWD: \", FMSWD_record),\n",
        "    (\" FM num out: \", FMnum_out_record),\n",
        "    (\"PGFM SWD: \", PGFMSWD_record),\n",
        "    (\"PGFM num out: \", PGFMnum_out_record),\n",
        ")\n",
        "for label, record in summary_rows:\n",
        "    print(label, f\"{np.mean(record):.4f}\", r\"\\pm\", f\"{np.std(record):.4f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Stand-alone flow-matching trainer from FMfuncs, used by the next cell.\n",
        "FMclass = FMfuncs.OTFlowMatching()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Run FMclass.train on the dataset tensor.\n",
        "# This rebinds `trained_model`, which the two-stage training cell above also assigns.\n",
        "trained_model = FMclass.train(dataset)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}