{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ba2e3d7-8589-4d3b-bacf-597760001563",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from typing import Optional\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.datasets import make_moons\n",
    "from sklearn.cluster import KMeans\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:1\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d0f118a-eba8-4d4e-a181-154627bbdb87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------- Reproducibility -----------------\n",
    "seed = 42\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "def cluster_and_count(X_tar, centers):\n",
    "    \"\"\"\n",
    "    X_tar   : (N,2) \n",
    "    centers : (8,2)  \n",
    "    \"\"\"\n",
    "    kmeans = KMeans(n_clusters=8, init=centers, n_init=1, random_state=0)\n",
    "    labels = kmeans.fit_predict(X_tar)\n",
    "\n",
    "    counts = np.bincount(labels, minlength=8)\n",
    "    total = counts.sum()\n",
    "    ratios = counts / total\n",
    "\n",
    "    for k in range(8):\n",
    "        print(f\"Cluster {k:2d}: {counts[k]:5d} samples  ({ratios[k]*100:5.2f}%)\")\n",
    "    return labels, counts, ratios\n",
    "\n",
    "\n",
    "def make_8gmm_counts(\n",
    "    counts=(800, 700, 600, 500, 400, 300, 200, 100),\n",
    "    radius=3.0,\n",
    "    sigma=0.15,\n",
    "    seed=42,\n",
    "    train_ratio=0.8,\n",
    "    dtype=np.float32,\n",
    "):\n",
    "    \"\"\"\n",
    "    8-Gaussian mixture \n",
    "    -  : (radius)   0, π/4, ..., 7π/4\n",
    "    -    : counts  ( )\n",
    "    - : sigma^2 * I_2 ()\n",
    "    :\n",
    "        X_tar      : (N,2)   (shuffle)\n",
    "        y_modes    : (N,)   (0~7)\n",
    "        Y_train    : (N_tr,2)\n",
    "        Y_test     : (N_te,2)\n",
    "        y_tr, y_te :  split   \n",
    "        centers    : (8,2)   \n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    counts = np.asarray(counts, dtype=int)\n",
    "    assert counts.size == 8 and (counts > 0).all(), \"counts  8    .\"\n",
    "\n",
    "    # 8  :  0, π/4, ..., 7π/4\n",
    "    thetas = np.arange(8) * (np.pi / 4.0)\n",
    "    centers = np.stack([radius * np.cos(thetas), radius * np.sin(thetas)], axis=1)  # (8,2)\n",
    "\n",
    "    cov = (sigma ** 2) * np.eye(2)\n",
    "    X_list, y_list = [], []\n",
    "\n",
    "    for k, nk in enumerate(counts):\n",
    "        Xk = rng.multivariate_normal(mean=centers[k], cov=cov, size=nk)\n",
    "        yk = np.full(nk, k, dtype=np.int64)\n",
    "        X_list.append(Xk)\n",
    "        y_list.append(yk)\n",
    "\n",
    "    X = np.vstack(X_list)\n",
    "    y = np.concatenate(y_list)\n",
    "\n",
    "    # shuffle\n",
    "    perm = rng.permutation(X.shape[0])\n",
    "    X = X[perm].astype(dtype, copy=False)\n",
    "    y = y[perm]\n",
    "\n",
    "    # train/test split\n",
    "    n_train = int(len(X) * train_ratio)\n",
    "    Y_train, Y_test = X[:n_train], X[n_train:]\n",
    "    y_tr, y_te = y[:n_train], y[n_train:]\n",
    "\n",
    "    return X, y, Y_train, Y_test, y_tr, y_te, centers.astype(dtype)\n",
    "\n",
    "# ================== Build the imbalanced 8-mode target data ==================\n",
    "gmm_config = dict(\n",
    "    counts=(800, 700, 600, 500, 400, 300, 200, 100),\n",
    "    radius=3.0,\n",
    "    sigma=0.15,\n",
    "    seed=123,\n",
    "    train_ratio=0.8,\n",
    ")\n",
    "X_tar, y_modes, Y_train, Y_test, y_tr, y_te, centers = make_8gmm_counts(**gmm_config)\n",
    "\n",
    "# ----------------- Base distribution: standard normal -----------------\n",
    "def sample_base(n):\n",
    "    return np.random.randn(n, 2).astype(np.float32)\n",
    "\n",
    "# ----------------- Vector field model v_theta(x,t) -----------------\n",
    "class VecField(nn.Module):\n",
    "    def __init__(self, x_dim=2, t_dim=1, width=128):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(x_dim + t_dim, width),\n",
    "            nn.SiLU(),\n",
    "            nn.Linear(width, width),\n",
    "            nn.SiLU(),\n",
    "            nn.Linear(width, width),\n",
    "            nn.SiLU(),\n",
    "            nn.Linear(width, width),\n",
    "            nn.SiLU(),\n",
    "            nn.Linear(width, x_dim)\n",
    "        )\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        # x: (B,2), t: (B,1)\n",
    "        return self.net(torch.cat([x, t], dim=-1))\n",
    "\n",
    "vnet = VecField().to(device)\n",
    "\n",
    "# ----------------- Training setup -----------------\n",
    "batch_size, epochs, lr = 512, 3000, 2e-3\n",
    "opt = optim.Adam(params=vnet.parameters(), lr=lr)\n",
    "\n",
    "# Training targets kept on the compute device for fast minibatch indexing.\n",
    "Y_train_t = torch.from_numpy(Y_train).to(device)  # (n_train,2)\n",
    "\n",
    "def train_iter():\n",
    "    \"\"\"Build one flow-matching minibatch and return its scalar MSE loss.\n",
    "\n",
    "    Reads the module-level Y_train_t, batch_size, device and vnet.\n",
    "    \"\"\"\n",
    "    # Target minibatch, sampled with replacement from the training set.\n",
    "    rows = torch.randint(0, Y_train_t.shape[0], (batch_size,), device=device)\n",
    "    target = Y_train_t[rows]              # (B,2)\n",
    "\n",
    "    # Base samples x0 ~ N(0,I) and times t ~ Uniform(0,1).\n",
    "    noise = torch.randn_like(target)      # (B,2)\n",
    "    t = torch.rand((batch_size, 1), device=device)\n",
    "\n",
    "    # Point on the straight path from x0 to y, and that path's constant velocity.\n",
    "    x_t = (1.0 - t) * noise + t * target\n",
    "    velocity = target - noise\n",
    "\n",
    "    residual = vnet(x_t, t) - velocity\n",
    "    return residual.pow(2).mean()\n",
    "\n",
    "# ----------------- Training loop -----------------\n",
    "print(\"Training...\")\n",
    "for ep in range(1, epochs + 1):\n",
    "    loss = train_iter()\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    # Clip the global gradient norm to 1.0 before the optimiser step.\n",
    "    nn.utils.clip_grad_norm_(vnet.parameters(), 1.0)\n",
    "    opt.step()\n",
    "    if ep % 100:\n",
    "        continue\n",
    "    print(f\"[{ep:4d}/{epochs}] loss={loss.item():.6f}\")\n",
    "\n",
    "# ----------------- Sampling by Euler integration -----------------\n",
    "@torch.no_grad()\n",
    "def sample_from_model(n_samples=2000, steps=100):\n",
    "    x = torch.randn((n_samples, 2), device=device)  # x(0) ~ N(0,I)\n",
    "    dt = 1.0 / steps\n",
    "    for k in range(steps):\n",
    "        t = torch.full((n_samples, 1), (k + 0.5) / steps, device=device)  # midpoint time\n",
    "        v = vnet(x, t)  # (n,2)\n",
    "        x = x + v * dt  # Euler step\n",
    "    return x.cpu().numpy()\n",
    "\n",
    "X_gen = sample_from_model(n_samples=4000, steps=150)\n",
    "\n",
    "# ----------------- Visualization -----------------\n",
    "fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))\n",
    "X_base = sample_base(4000)\n",
    "\n",
    "# One (points, scatter kwargs, title) entry per panel:\n",
    "# base N(0,I) samples, held-out 8-mode GMM target samples, generated samples.\n",
    "panels = [\n",
    "    (X_base, dict(s=4, alpha=0.5, label=\"Base N(0,I)\"), \"Base samples (N(0,I))\"),\n",
    "    (Y_test, dict(s=5, alpha=0.8, color='tab:green', label=\"Target (test)\"), \"Target: Imbalanced 8- GMM\"),\n",
    "    (X_gen, dict(s=4, alpha=0.6, color='tab:blue', label=\"Generated (FM)\"), \"Generated by Flow Matching\"),\n",
    "]\n",
    "for panel_ax, (pts, scatter_kw, panel_title) in zip(axes, panels):\n",
    "    panel_ax.scatter(pts[:, 0], pts[:, 1], **scatter_kw)\n",
    "    panel_ax.set_title(panel_title)\n",
    "    panel_ax.axis('equal')\n",
    "    panel_ax.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "labels, counts, ratios = cluster_and_count(X_gen, centers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b9f0a10-e486-4a2f-bd9c-c6ba2811faee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===  import ===\n",
    "import ot  # POT (Python Optimal Transport)\n",
    "\n",
    "# ===  OT x0 ↔ y   y_match  ( W2) ===\n",
    "@torch.no_grad()\n",
    "def ot_pairing_rowwise(x0: torch.Tensor, y: torch.Tensor):\n",
    "    \"\"\"\n",
    "    x0: (B,2)  base samples\n",
    "    y : (B,2)  target samples (   )\n",
    "    : y_match (B,2)  --  x0 (row)   1  \n",
    "    \"\"\"\n",
    "    B = x0.shape[0]\n",
    "\n",
    "    # --- :  ''  -> W2 ( EMD)\n",
    "    C2 = torch.cdist(x0, y, p=2).pow(2).cpu().numpy()   # (B,B)\n",
    "\n",
    "    #  \n",
    "    a = ot.unif(B)\n",
    "    b = ot.unif(B)\n",
    "\n",
    "    #  ( EMD; entropic regularization )\n",
    "    # : ot.emd  C   C2  W2^2  .\n",
    "    P = ot.emd(a, b, C2)   # (B,B), optimal transport plan for squared cost\n",
    "\n",
    "    # ()  W2^2  :\n",
    "    # w2_squared = float((P * C2).sum())\n",
    "\n",
    "    #  row   → row-wise categorical \n",
    "    row = P / (P.sum(axis=1, keepdims=True) + 1e-12)\n",
    "    tgt_idx = np.array([np.random.choice(B, p=row[i]) for i in range(B)], dtype=np.int64)\n",
    "\n",
    "    # Torch   y \n",
    "    y_match = y[torch.from_numpy(tgt_idx).to(y.device)]\n",
    "    return y_match\n",
    "    \n",
    "# ----------------- Training setup -----------------\n",
    "# NOTE(review): vnet is not re-created in this cell, so training continues from the\n",
    "# weights left by the earlier training cell — confirm this warm start is intended.\n",
    "batch_size, epochs, lr = 512, 2000, 2e-3\n",
    "opt = optim.Adam(params=vnet.parameters(), lr=lr)\n",
    "\n",
    "Y_train_t = torch.from_numpy(Y_train).to(device)  # (n_train,2)\n",
    "\n",
    "# ----------------- train_iter with minibatch OT coupling -----------------\n",
    "def train_iter():\n",
    "    \"\"\"Build one OT-coupled flow-matching minibatch and return its scalar MSE loss.\n",
    "\n",
    "    Redefines the train_iter of the plain flow-matching cell; reads the\n",
    "    module-level Y_train_t, batch_size, device and vnet.\n",
    "    \"\"\"\n",
    "    # Target minibatch and base samples x0 ~ N(0,I).\n",
    "    rows = torch.randint(0, Y_train_t.shape[0], (batch_size,), device=device)\n",
    "    y = Y_train_t[rows]                   # (B,2)\n",
    "    x0 = torch.randn_like(y)              # (B,2)\n",
    "\n",
    "    # Re-pair every x0 with a target chosen through the minibatch OT plan.\n",
    "    y_match = ot_pairing_rowwise(x0, y)   # (B,2)\n",
    "\n",
    "    # t ~ Uniform(0,1), point on the straight path and its velocity.\n",
    "    t = torch.rand((batch_size, 1), device=device)\n",
    "    x_t = (1.0 - t) * x0 + t * y_match\n",
    "    u_star = y_match - x0\n",
    "\n",
    "    residual = vnet(x_t, t) - u_star\n",
    "    return residual.pow(2).mean()\n",
    "\n",
    "# ----------------- Training loop -----------------\n",
    "print(\"Training (with minibatch OT coupling)...\")\n",
    "for ep in range(1, epochs + 1):\n",
    "    loss = train_iter()\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    # Clip the global gradient norm to 1.0 before the optimiser step.\n",
    "    nn.utils.clip_grad_norm_(vnet.parameters(), 1.0)\n",
    "    opt.step()\n",
    "    if ep % 100:\n",
    "        continue\n",
    "    print(f\"[{ep:4d}/{epochs}] loss={loss.item():.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b8c7f73-7dea-49c2-ace6-2840cdecfc38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------- Sampling by Euler integration -----------------\n",
    "@torch.no_grad()\n",
    "def sample_from_model(n_samples=2000, steps=100):\n",
    "    x = torch.randn((n_samples, 2), device=device)  # x(0) ~ N(0,I)\n",
    "    dt = 1.0 / steps\n",
    "    for k in range(steps):\n",
    "        t = torch.full((n_samples, 1), (k + 0.5) / steps, device=device)  # midpoint time\n",
    "        v = vnet(x, t)  # (n,2)\n",
    "        x = x + v * dt  # Euler step\n",
    "    return x.cpu().numpy()\n",
    "\n",
    "X_gen = sample_from_model(n_samples=4000, steps=150)\n",
    "\n",
    "# ----------------- Visualization -----------------\n",
    "fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))\n",
    "X_base = sample_base(4000)\n",
    "\n",
    "# One (points, scatter kwargs, title) entry per panel:\n",
    "# base N(0,I) samples, held-out 8-mode GMM target samples, generated samples.\n",
    "panels = [\n",
    "    (X_base, dict(s=4, alpha=0.5, label=\"Base N(0,I)\"), \"Base samples (N(0,I))\"),\n",
    "    (Y_test, dict(s=5, alpha=0.8, color='tab:green', label=\"Target (test)\"), \"Target: Imbalanced 8- GMM\"),\n",
    "    (X_gen, dict(s=4, alpha=0.6, color='tab:blue', label=\"Generated (OT-CFM)\"), \"Generated by OT Flow Matching\"),\n",
    "]\n",
    "for panel_ax, (pts, scatter_kw, panel_title) in zip(axes, panels):\n",
    "    panel_ax.scatter(pts[:, 0], pts[:, 1], **scatter_kw)\n",
    "    panel_ax.set_title(panel_title)\n",
    "    panel_ax.axis('equal')\n",
    "    panel_ax.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "labels, counts, ratios = cluster_and_count(X_gen, centers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a69eca0-eeeb-430c-9401-07aeb9f45cf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =========================\n",
    "# UOT  +    \n",
    "# =========================\n",
    "import numpy as np\n",
    "import torch\n",
    "import ot  # POT\n",
    "\n",
    "# ---  ---\n",
    "uot_reg   = 0.05           # entropic reg (ε)\n",
    "tau_1     = float(\"inf\")   # source marginal penalty\n",
    "tau_2     = 5.0            # target marginal penalty\n",
    "alpha     = 1            # ^alpha (alpha=1: , 1.5~2:   )\n",
    "eps_marg  = 1e-12          # marginal \n",
    "cap_w     = 50.0          #  ( ); None \n",
    "reweight_mode = \"both\"     # \"none\" | \"col\" | \"loss\" | \"both\"\n",
    "\n",
    "batch_size = 512\n",
    "epochs     = 1000\n",
    "lr         = 2e-3\n",
    "\n",
    "def _finite_tau(t, big=1e6):\n",
    "    return big if (t is None or not np.isfinite(t)) else float(t)\n",
    "\n",
    "safe_tau1 = _finite_tau(tau_1)\n",
    "safe_tau2 = _finite_tau(tau_2)\n",
    "\n",
    "@torch.no_grad()\n",
    "def uot_pairing_rowwise(\n",
    "    x0: torch.Tensor,\n",
    "    y: torch.Tensor,\n",
    "    alpha: float = 1.0,\n",
    "    eps_marg: float = 1e-12,\n",
    "    cap_w: Optional[float] = 50.0,\n",
    "    reweight_mode: str = \"col\",\n",
    "):\n",
    "    \"\"\"\n",
    "    x0: (B,2)  base\n",
    "    y : (B,2)  target\n",
    "    returns:\n",
    "        y_match : (B,2)         — () 1   \n",
    "        tgt_idx : (B,)          —   \n",
    "        w_col   : (B,) np.float64 —  ' ' - \n",
    "        P       : (B,B) np.float64 — UOT  (: )\n",
    "        targ_m  : (B,) np.float64 —  ( )\n",
    "    \"\"\"\n",
    "    B = x0.shape[0]\n",
    "\n",
    "    # :  (Flow Matching  )\n",
    "    C = torch.cdist(x0, y, p=2).detach().cpu().numpy().astype(np.float64)  # (B,B)\n",
    "    a = ot.unif(B).astype(np.float64)\n",
    "    b = ot.unif(B).astype(np.float64)\n",
    "\n",
    "    # UOT \n",
    "    P = ot.unbalanced.sinkhorn_unbalanced(\n",
    "        a, b, C, reg=uot_reg, reg_m=(safe_tau1, safe_tau2)\n",
    "    )  # (B,B) float64\n",
    "\n",
    "    #   → - \n",
    "    targ_m = P.sum(axis=0)                          # (B,)\n",
    "    w_col  = np.power(targ_m + eps_marg, -alpha)    # (B,) float64\n",
    "    if cap_w is not None:\n",
    "        w_col = np.minimum(w_col, cap_w)\n",
    "    w_col /= (w_col.mean() + 1e-12)\n",
    "    \n",
    "    # ★ MPS : float32 \n",
    "    w_col = w_col.astype(np.float32)\n",
    "\n",
    "    \n",
    "    #    ()\n",
    "    R = P * w_col[None, :] if reweight_mode in (\"col\", \"both\") else P.copy()\n",
    "\n",
    "    #   (row-wise categorical )\n",
    "    R = np.clip(R, 0.0, None)\n",
    "    row_sum = R.sum(axis=1, keepdims=True)\n",
    "    dead = (row_sum <= 1e-18) | ~np.isfinite(row_sum)\n",
    "    if np.any(dead):\n",
    "        R[dead, :] = 1.0 / B\n",
    "        row_sum[dead] = 1.0\n",
    "    R /= row_sum\n",
    "\n",
    "    #   ()\n",
    "    diff = 1.0 - R.sum(axis=1, keepdims=True)\n",
    "    R[:, -1] += diff[:, 0]\n",
    "    R = np.clip(R, 0.0, None)\n",
    "    R /= (R.sum(axis=1, keepdims=True) + 1e-18)\n",
    "\n",
    "    #   1   \n",
    "    tgt_idx = np.empty(B, dtype=np.int64)\n",
    "    for i in range(B):\n",
    "        p = R[i]\n",
    "        s = p.sum()\n",
    "        if (not np.isfinite(s)) or s <= 0:\n",
    "            p = np.full(B, 1.0 / B, dtype=np.float64)\n",
    "        else:\n",
    "            p = p / p.sum()\n",
    "            p[-1] = max(0.0, 1.0 - p[:-1].sum())\n",
    "            if p.sum() <= 0:\n",
    "                p = np.full(B, 1.0 / B, dtype=np.float64)\n",
    "        tgt_idx[i] = np.random.choice(B, p=p)\n",
    "\n",
    "    #   \n",
    "    y_match = y[torch.from_numpy(tgt_idx).to(y.device)]\n",
    "    return y_match, tgt_idx, w_col, P, targ_m\n",
    "\n",
    "# ----- Training data and optimiser -----\n",
    "# NOTE(review): vnet is not re-created in this cell, so training continues from the\n",
    "# weights left by the earlier training cells — confirm this warm start is intended.\n",
    "opt = torch.optim.Adam(params=vnet.parameters(), lr=lr)\n",
    "Y_train_t = torch.from_numpy(Y_train).to(device)  # (n,2)\n",
    "\n",
    "def train_iter_uot():\n",
    "    \"\"\"Build one UOT-coupled flow-matching minibatch and return its scalar loss.\n",
    "\n",
    "    With reweight_mode \"loss\" or \"both\" the per-sample MSE is weighted by the\n",
    "    inverse-marginal weight of the matched target; otherwise it is a plain mean.\n",
    "    Reads the module-level Y_train_t, batch_size, device, vnet and UOT settings.\n",
    "    \"\"\"\n",
    "    # Target minibatch and base samples.\n",
    "    rows = torch.randint(0, Y_train_t.shape[0], (batch_size,), device=device)\n",
    "    y = Y_train_t[rows]        # (B,2)\n",
    "    x0 = torch.randn_like(y)   # (B,2)\n",
    "\n",
    "    # UOT pairing; the plan and its marginal are not needed here.\n",
    "    y_match, tgt_idx, w_col, _plan, _marginal = uot_pairing_rowwise(\n",
    "        x0,\n",
    "        y,\n",
    "        alpha=alpha,\n",
    "        eps_marg=eps_marg,\n",
    "        cap_w=cap_w,\n",
    "        reweight_mode=reweight_mode,\n",
    "    )\n",
    "\n",
    "    # Point on the straight path and its velocity.\n",
    "    t = torch.rand((batch_size, 1), device=device)\n",
    "    x_t = (1.0 - t) * x0 + t * y_match\n",
    "    u_star = y_match - x0\n",
    "\n",
    "    per_sample = (vnet(x_t, t) - u_star).pow(2).mean(dim=1)   # (B,)\n",
    "    if reweight_mode not in (\"loss\", \"both\"):\n",
    "        return per_sample.mean()\n",
    "\n",
    "    # Weights are built directly as float32 tensors on the compute device.\n",
    "    w_s = torch.tensor(w_col[tgt_idx], device=device, dtype=torch.float32)  # (B,)\n",
    "    w_s = (w_s / (w_s.mean() + 1e-12)).clamp_(0.0, 1e3)\n",
    "    return (w_s * per_sample).mean()\n",
    "\n",
    "print(f\"[UOT] ε={uot_reg}, τ1={tau_1}, τ2={tau_2}, α={alpha}, mode={reweight_mode}\")\n",
    "for ep in range(1, epochs + 1):\n",
    "    loss = train_iter_uot()\n",
    "    opt.zero_grad()\n",
    "    loss.backward()\n",
    "    # Clip the global gradient norm to 1.0 before the optimiser step.\n",
    "    torch.nn.utils.clip_grad_norm_(vnet.parameters(), 1.0)\n",
    "    opt.step()\n",
    "    if ep % 100:\n",
    "        continue\n",
    "    print(f\"[UOT] {ep:4d}/{epochs}  loss={loss.item():.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76e05595-d5fe-4fe3-ba73-50ab757472c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------- Sampling by Euler integration -----------------\n",
    "@torch.no_grad()\n",
    "def sample_from_model(n_samples=2000, steps=100):\n",
    "    x = torch.randn((n_samples, 2), device=device)  # x(0) ~ N(0,I)\n",
    "    dt = 1.0 / steps\n",
    "    for k in range(steps):\n",
    "        t = torch.full((n_samples, 1), (k + 0.5) / steps, device=device)  # midpoint time\n",
    "        v = vnet(x, t)  # (n,2)\n",
    "        x = x + v * dt  # Euler step\n",
    "    return x.cpu().numpy()\n",
    "\n",
    "X_gen = sample_from_model(n_samples=4000, steps=150)\n",
    "\n",
    "# ----------------- Visualization -----------------\n",
    "fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))\n",
    "\n",
    "# Base (N(0,I))\n",
    "X_base = sample_base(4000)\n",
    "axes[0].scatter(X_base[:,0], X_base[:,1], s=4, alpha=0.5, label=\"Base N(0,I)\")\n",
    "axes[0].set_title(\"Base samples (N(0,I))\")\n",
    "axes[0].axis('equal'); axes[0].legend()\n",
    "\n",
    "# Target (held-out samples of the imbalanced 8-mode GMM)\n",
    "axes[1].scatter(Y_test[:,0], Y_test[:,1], s=5, alpha=0.8, color='tab:green', label=\"Target (test)\")\n",
    "axes[1].set_title(\"Target: Imbalanced 8- GMM\")\n",
    "axes[1].axis('equal'); axes[1].legend()\n",
    "\n",
    "# Generated. The scatter needs a label: legend() on an axes without any labelled\n",
    "# artist only emits a warning and draws an empty legend.\n",
    "axes[2].scatter(X_gen[:,0], X_gen[:,1], s=4, alpha=0.6, color='tab:blue', label=\"Generated (ours)\")\n",
    "axes[2].set_title(\"Generated by Our model with power \" + str(alpha))\n",
    "axes[2].axis('equal'); axes[2].legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "labels, counts, ratios = cluster_and_count(X_gen, centers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92ddbcba-3823-40e7-b4b5-f79513454f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_class = np.array([800, 700, 600, 500, 400, 300, 200, 100])\n",
    "\n",
    "print(num_class/num_class.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a38ca438-8082-471a-a175-06aa501d251d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torchcfm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}