{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "initial_id",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T21:09:01.415973Z",
          "start_time": "2025-04-25T21:09:01.412976Z"
        },
        "jupyter": {
          "outputs_hidden": true
        }
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "import cv2\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import load\n",
        "from compute_thickness import thickness_batch\n",
        "from FM_funcs import OTFlowMatching\n",
        "import configs\n",
        "from PGFMv2_funcs import PGFM\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "68d347db5e476040",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T21:09:01.498991Z",
          "start_time": "2025-04-25T21:09:01.482373Z"
        }
      },
      "outputs": [],
      "source": [
        "# Seed the torch and NumPy RNGs before any model is constructed, so weight init and sampling are repeatable.\n",
        "# NOTE: some GPU kernels can still be non-deterministic (see the PyTorch reproducibility notes).\n",
        "SEED = 0\n",
        "torch.manual_seed(SEED)\n",
        "np.random.seed(SEED)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9b178d3d573b08cf",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T21:09:01.387081Z",
          "start_time": "2025-04-25T21:09:01.135487Z"
        }
      },
      "outputs": [],
      "source": [
        "# Build the data wrapper and the two models used in the rest of the notebook.\n",
        "# NOTE(review): the loader class is spelled `myMMIST` in the project-local `load` module.\n",
        "mnist_dataset = load.myMMIST()\n",
        "\n",
        "FM_class = OTFlowMatching()  # flow-matching model (trained in the FM training cell)\n",
        "PGFM_class = PGFM()  # PGFM model (trained from a saved checkpoint in the last cell)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "38e76aa42e970d28",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Load the training set for the constraint being studied (one load, no silent overwrite).\n",
        "# 'brightness' -> get_data_brightness (matches the FM_MNIST_bright checkpoint used in the PGFM cell).\n",
        "# 'thickness'  -> plain get_data.\n",
        "# NOTE(review): the thickness pairing is inferred from the checkpoint notes in the PGFM cell - confirm.\n",
        "DATA_CONSTRAINT = 'brightness'\n",
        "if DATA_CONSTRAINT == 'brightness':\n",
        "    train_set = mnist_dataset.get_data_brightness(plot_hist=True).to(configs.device)\n",
        "elif DATA_CONSTRAINT == 'thickness':\n",
        "    train_set = mnist_dataset.get_data(plot_hist=True).to(configs.device)\n",
        "else:\n",
        "    raise ValueError(f'unknown DATA_CONSTRAINT: {DATA_CONSTRAINT!r}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "28df92cc46fad63c",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-20T15:15:48.464811Z",
          "start_time": "2025-03-20T14:10:43.372958Z"
        }
      },
      "outputs": [],
      "source": [
        "%%time\n",
        "# Train the flow-matching model on `train_set`. Slow: the recorded run of this cell took ~65 min.\n",
        "# NOTE(review): presumably this produces the FM checkpoint loaded by the PGFM cell below;\n",
        "# if that checkpoint already exists this cell may be skippable - confirm where FM_class.train saves.\n",
        "FM_class.train(train_set)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2005e85355c48693",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Train PGFM starting from a saved flow-matching checkpoint.\n",
        "# The constraint, the checkpoint, and the data in `train_set` must all match.\n",
        "FM_CHECKPOINTS = {\n",
        "    'brightness': './saved_model/FM_MNIST_bright_iter_200000.pth',\n",
        "    # TODO(review): directory/extension inferred from the old note 'FM_MNIST_iter_200000_23 : for thickness';\n",
        "    # also confirm PGFM accepts constraint_type='thickness'.\n",
        "    'thickness': './saved_model/FM_MNIST_iter_200000_23.pth',\n",
        "}\n",
        "CONSTRAINT_TYPE = 'brightness'\n",
        "PGFM_class.train2_2stage(train_set, FM_CHECKPOINTS[CONSTRAINT_TYPE], constraint_type=CONSTRAINT_TYPE)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}