import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Optional

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from absl import flags

# ----------------- Flags -----------------
FLAGS = flags.FLAGS
# --- Dataset construction ---
flags.DEFINE_float("major_class_ratio", 0.95, "Major class ratio (0~1)")
flags.DEFINE_enum("preprocess", "none", ["none", "zero_mean", "shift"], "Target preprocessing: none | zero_mean | shift (constant)")
flags.DEFINE_integer("n_total", 10000, "Total number of samples")
flags.DEFINE_float("noise", 0.06, "Noise for make_moons")
flags.DEFINE_float("train_split", 0.8, "Train data split ratio")
flags.DEFINE_integer("seed", 42, "Random seed")

# --- Model and optimisation (shared by every training stage) ---
flags.DEFINE_integer("width", 128, "VecField hidden width")
flags.DEFINE_integer("batch_size", 512, "Batch size")
flags.DEFINE_float("lr", 2e-3, "Learning rate")
flags.DEFINE_integer("epochs", 2000, "Training epochs for all stages")
flags.DEFINE_enum("experiment", "all", ["fm", "ot", "uot", "all"], "Experiment selection: fm | ot | uot | all")

# --- Unbalanced-OT stage ---
flags.DEFINE_float("uot_reg", 0.05, "UOT entropic regularization ε")
flags.DEFINE_float("tau1", float("inf"), "UOT τ1 (source marginal penalty)")
flags.DEFINE_float("tau2", 1.0, "UOT τ2 (target marginal penalty)")
flags.DEFINE_float("alpha", 4.0, "Inverse-marginal weighting power")
flags.DEFINE_float("eps_marg", 1e-12, "Marginal stabilization epsilon")
flags.DEFINE_float("cap_w", 50.0, "Weight cap (to prevent explosion); negative disables")
flags.DEFINE_enum("reweight_mode", "loss", ["none", "col", "loss", "both"], "none:uot-fm, loss:uot-wfm, col:uot-fm w/ weighted coupling, both:uot-wfm w/ weighted coupling")

# --- Sampling and visualisation ---
flags.DEFINE_integer("n_samples", 4000, "Number of generated samples")
flags.DEFINE_integer("steps", 150, "Number of Euler steps for sampling")
flags.DEFINE_integer("n_base_samples", 4000, "Number of base samples for visualization")

flags.DEFINE_integer("rep_per_moon", 64, "Number of representative points per moon")
# tri_label_thr is the cut-off for the 3-way labelling of generated points.
# The labelling code compares it against the relative gap
#   |d_upper - d_lower| / (d_upper + d_lower)
# between the squared distances to the nearest upper/lower representative:
# a gap <= threshold is labelled "rest" (2), otherwise the nearer moon (0/1).
flags.DEFINE_float("tri_label_thr", 0.85, "Threshold (rest when |p0 - p1| < threshold, argmax otherwise)")

flags.DEFINE_string("results_dir", "", "Results directory (if not specified, use default rule)")
flags.DEFINE_list("shift_const", ["4.0", "-4.0"], "Constant shift vector for preprocess=shift, as [dx, dy]")

# Parse flags at import time: the module-level code below reads FLAGS.* directly,
# so the values must be available before any of it runs.
if not FLAGS.is_parsed():
    FLAGS(sys.argv)


# ----------------- Device -----------------
device = torch.device("cpu")


# ----------------- Results directory -----------------
# An explicit --results_dir wins; otherwise results go next to this file in a
# folder named after the class ratio and preprocessing mode.
_default_results_subdir = f"results/2moon_ratio{FLAGS.major_class_ratio}_prep{FLAGS.preprocess}"
RESULTS_DIR = FLAGS.results_dir or os.path.join(os.path.dirname(__file__), _default_results_subdir)
os.makedirs(RESULTS_DIR, exist_ok=True)


# ----------------- Reproducibility -----------------
seed = FLAGS.seed
torch.manual_seed(seed)
np.random.seed(seed)


# ----------------- Experiment selection & unified epochs -----------------
# Plots are produced only for the selected stage (or every stage for "all").
# Training is cumulative: a later stage also runs all earlier training stages.
exp = FLAGS.experiment
_run_all = exp == "all"
do_fm_plot = _run_all or exp == "fm"
do_ot_plot = _run_all or exp == "ot"
do_uot_plot = _run_all or exp == "uot"
do_uot_train = do_uot_plot
do_ot_train = do_ot_plot or do_uot_train
do_fm_train = do_fm_plot or do_ot_train

# ----------------- Build imbalanced two-moons -----------------
def zero_mean_moons(X: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Center the points on their per-coordinate mean.

    Returns the centered points, the labels (passed through untouched) and the
    mean vector that was subtracted.
    """
    center = np.mean(X, axis=0)
    return X - center, y, center

def shift_moons(X: np.ndarray, y: np.ndarray, shift_vec: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Translate every point by the constant vector ``shift_vec``.

    Returns the translated points as float32, the labels (passed through) and
    the negated shift as float32.
    """
    moved = (X + shift_vec[np.newaxis, :]).astype(np.float32)
    negated_shift = np.negative(shift_vec.astype(np.float32))
    return moved, y, negated_shift

# Build the imbalanced two-moons target set.
N_total = FLAGS.n_total
if not 0.0 <= FLAGS.major_class_ratio <= 1.0:
    # A ratio outside [0, 1] would make one class count negative, and a negative
    # slice bound silently keeps almost every sample instead of none.
    raise ValueError(f"major_class_ratio must be in [0, 1], got {FLAGS.major_class_ratio}")

# Requested per-class sizes of the final dataset (class 0 = major, class 1 = minor).
n0 = int(FLAGS.major_class_ratio * N_total)
n1 = N_total - n0

# make_moons returns a 50/50 split, so generate enough balanced points for the
# larger class to be fully available; otherwise the slice below would silently
# truncate the major class and the realised ratio would not match the flag.
n_generate = 2 * max(n0, n1)
X, y = make_moons(n_samples=n_generate, noise=FLAGS.noise, random_state=seed)
if FLAGS.preprocess == "zero_mean":
    X, y, ZERO_MEAN_SHIFT = zero_mean_moons(X, y)
elif FLAGS.preprocess == "shift":
    try:
        shift_vals = np.array(list(map(float, FLAGS.shift_const)), dtype=np.float32)
        if shift_vals.shape[0] != 2:
            raise ValueError(f"expected 2 components, got {shift_vals.shape[0]}")
    except (TypeError, ValueError) as err:
        # Keep the original best-effort fallback, but make it visible.
        shift_vals = np.array([0.3, -0.1], dtype=np.float32)
        print(
            f"Warning: invalid --shift_const {FLAGS.shift_const!r} ({err}); "
            f"falling back to {shift_vals.tolist()}",
            file=sys.stderr,
        )
    X, y, ZERO_MEAN_SHIFT = shift_moons(X, y, shift_vals)
else:
    ZERO_MEAN_SHIFT = np.zeros(2, dtype=np.float32)

# Sub-sample each class down to its requested size.
mask0 = (y == 0)
mask1 = (y == 1)
idx0 = np.where(mask0)[0]
idx1 = np.where(mask1)[0]
np.random.shuffle(idx0)
np.random.shuffle(idx1)

pick0 = idx0[:n0]
pick1 = idx1[:n1]
pick = np.concatenate([pick0, pick1])
np.random.shuffle(pick)

X_tar = X[pick].astype(np.float32)  # (N,2)
N = X_tar.shape[0]

# Train/test split controlled by --train_split.
perm = np.random.permutation(N)
n_train = int(FLAGS.train_split * N)
idx_tr, idx_te = perm[:n_train], perm[n_train:]
Y_train = X_tar[idx_tr]     # target samples (train)
Y_test  = X_tar[idx_te]     # target samples (test)


# ----------------- Base distribution: standard normal -----------------
def sample_base(n: int) -> np.ndarray:
    """Draw ``n`` 2-D points from a standard normal (global NumPy RNG), as float32."""
    draws = np.random.standard_normal(size=(n, 2))
    return draws.astype(np.float32)


# ----------------- Labeling utilities (representatives and tri-label) -----------------
# --- 3-way labeling (0: upper, 1: lower, 2: rest) ---
# Module-level copy of --tri_label_thr; used as the default `thr` of
# tri_label_by_nearest_reps and echoed in the tri-label plot annotations.
TRI_LABEL_THR = FLAGS.tri_label_thr


# ----------------- True make_moons representatives -----------------
def sample_true_moon_points(num_per_moon: int = 16, shift: Optional[np.ndarray] = None) -> np.ndarray:
    """Return evenly spaced noise-free points on the two moon curves.

    With t sampled uniformly on [0, pi] (both end points included):
      - upper moon: (cos t, sin t)
      - lower moon: (1 - cos t, 1 - sin t - 0.5)

    Args:
        num_per_moon: number of points per curve.
        shift: optional 2-vector that is *subtracted* from every point.

    Returns:
        float32 array of shape (2, num_per_moon, 2); index 0 is the upper
        moon, index 1 the lower moon.
    """
    angles = np.linspace(0.0, np.pi, num_per_moon, endpoint=True, dtype=np.float32)
    cos_t = np.cos(angles)
    sin_t = np.sin(angles)
    upper_curve = np.stack([cos_t, sin_t], axis=1).astype(np.float32)
    lower_curve = np.stack([1.0 - cos_t, 1.0 - sin_t - 0.5], axis=1).astype(np.float32)
    curves = np.stack([upper_curve, lower_curve], axis=0)
    if shift is None:
        return curves
    return curves - shift.astype(np.float32)[None, None, :]


def label_by_nearest_reps(points: np.ndarray, reps: np.ndarray) -> np.ndarray:
    """Assign each point to the moon whose representative points are nearest.

    Args:
        points: array of shape (N, 2).
        reps: array of shape (2, K, 2); reps[0] holds the upper-moon
            representatives and reps[1] the lower-moon ones.

    Returns:
        int64 array of shape (N,): 1 where the nearest lower-moon
        representative is strictly closer than the nearest upper-moon one,
        else 0 (ties go to 0).
    """
    pts = points.astype(np.float32)

    def _min_sq_dist(group: np.ndarray) -> np.ndarray:
        # Squared Euclidean distance from every point to its nearest rep in `group`.
        delta = pts[:, None, :] - group[None, :, :]
        return (delta ** 2).sum(axis=2).min(axis=1)

    d_upper = _min_sq_dist(reps[0])
    d_lower = _min_sq_dist(reps[1])
    return (d_lower < d_upper).astype(np.int64)


def tri_label_by_nearest_reps(points: np.ndarray, reps: np.ndarray, thr: float = TRI_LABEL_THR) -> np.ndarray:
    """Three-way labelling (0: upper, 1: lower, 2: rest) from nearest representatives.

    For each point the squared distances to the nearest upper-moon (reps[0])
    and lower-moon (reps[1]) representatives are compared. When the relative
    gap |d_upper - d_lower| / (d_upper + d_lower) is at most ``thr`` the point
    is labelled 2; otherwise it gets the label of the nearer moon.
    """
    pts = points.astype(np.float32)
    nearest_sq = [
        ((pts[:, None, :] - group[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        for group in (reps[0], reps[1])
    ]
    d_upper, d_lower = nearest_sq
    # The small constant keeps the ratio finite when both distances are zero.
    rel_gap = np.abs(d_upper - d_lower) / (d_upper + d_lower + 1e-12)
    nearer_moon = (d_lower < d_upper).astype(np.int64)
    return np.where(rel_gap <= thr, 2, nearer_moon)


# Noise-free representative points of both moons, moved by the same offset
# that the preprocessing step reported (ZERO_MEAN_SHIFT is subtracted).
TRUE_REP_POINTS = sample_true_moon_points(FLAGS.rep_per_moon, shift=ZERO_MEAN_SHIFT)


# ----------------- Vector field model v_theta(x,t) -----------------
class VecField(nn.Module):
    """MLP vector field v_theta(x, t).

    The input is the concatenation of x and t; the network has four hidden
    layers of size ``width`` with SiLU activations and a linear output of
    size ``x_dim``.
    """

    def __init__(self, x_dim: int = 2, t_dim: int = 1, width: int = FLAGS.width):
        super().__init__()
        # Layers are created in the same order as a literal listing would be,
        # so parameter initialisation consumes the RNG identically.
        layers = [nn.Linear(x_dim + t_dim, width), nn.SiLU()]
        for _ in range(3):
            layers += [nn.Linear(width, width), nn.SiLU()]
        layers.append(nn.Linear(width, x_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (B,2), t: (B,1)
        joint = torch.cat((x, t), dim=-1)
        return self.net(joint)


vnet = VecField().to(device)


# ----------------- Training setup (fm-flow baseline) -----------------
batch_size, epochs, lr = FLAGS.batch_size, FLAGS.epochs, FLAGS.lr
opt = optim.Adam(vnet.parameters(), lr=lr)

# Training targets as a tensor on the compute device.
Y_train_t = torch.from_numpy(Y_train).to(device)  # (n_train,2)


def train_iter_rf() -> torch.Tensor:
    """Compute one flow-matching loss with independent (x0, y) pairing.

    Draws a target mini-batch y, Gaussian noise x0 and a time t ~ U(0,1),
    evaluates the network on the straight-line interpolant and regresses it
    onto the constant velocity y - x0. Returns the scalar MSE loss.
    """
    # Target mini-batch (sampled with replacement).
    rows = torch.randint(0, Y_train_t.shape[0], (batch_size,), device=device)
    y = Y_train_t[rows]  # (B,2)

    # Base noise and time.
    x0 = torch.randn_like(y)  # (B,2)
    t = torch.rand((batch_size, 1), device=device)

    # Straight-line path and its velocity.
    x_t = (1.0 - t) * x0 + t * y
    target_velocity = y - x0

    residual = vnet(x_t, t) - target_velocity
    return residual.pow(2).mean()


# ----------------- Training loop -----------------
if do_fm_train:
    print("Training (FM baseline)...")
    for ep in range(1, epochs + 1):
        # Clear old gradients, then forward/backward with gradient-norm clipping.
        opt.zero_grad()
        loss = train_iter_rf()
        loss.backward()
        nn.utils.clip_grad_norm_(vnet.parameters(), max_norm=1.0)
        opt.step()
        if not ep % 100:
            print(f"[FM {ep:4d}/{epochs}] loss={loss.item():.6f}")


# ----------------- Sampling by Euler integration -----------------
@torch.no_grad()
def sample_from_model(n_samples: int = 2000, steps: int = 100) -> np.ndarray:
    """Generate samples by integrating the learned field from t=0 to t=1.

    Starts from x(0) ~ N(0, I) and takes ``steps`` fixed-size explicit steps,
    evaluating the network at the midpoint time of each step. Returns an
    (n_samples, 2) NumPy array.
    """
    state = torch.randn((n_samples, 2), device=device)
    step_size = 1.0 / steps
    for step in range(steps):
        mid_time = (step + 0.5) / steps
        t_col = torch.full((n_samples, 1), mid_time, device=device)
        state = state + vnet(state, t_col) * step_size
    return state.cpu().numpy()


# ----------------- Visualization (stage 1) -----------------
if do_fm_plot:
    X_gen = sample_from_model(n_samples=FLAGS.n_samples, steps=FLAGS.steps)
    N_base_samples = FLAGS.n_base_samples
    X_base = sample_base(N_base_samples)
    # Class fractions are fractions of the *generated* points, so normalise by
    # their count rather than by the number of base samples drawn for display.
    n_gen = max(int(X_gen.shape[0]), 1)

    def _reference_figure():
        """Create a 1x3 figure with the base (left) and target (middle) panels drawn."""
        fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))
        axes[0].scatter(X_base[:, 0], X_base[:, 1], s=4, alpha=0.5, label="Base N(0,I)")
        axes[0].set_title("Base samples (N(0,I))")
        axes[0].axis('equal'); axes[0].legend()
        axes[1].scatter(Y_test[:, 0], Y_test[:, 1], s=5, alpha=0.8, color='tab:green', label="Target (test)")
        axes[1].set_title(f"Target: Imbalanced Two-Moons ({FLAGS.major_class_ratio:.0%}/{1-FLAGS.major_class_ratio:.0%})")
        axes[1].axis('equal'); axes[1].legend()
        return fig, axes

    def _finalize_figure(fig, axes, title, filename):
        """Title the generated-samples panel, then save the figure into RESULTS_DIR and close it."""
        axes[2].set_title(title)
        axes[2].axis('equal'); axes[2].legend()
        plt.tight_layout()
        fig.savefig(os.path.join(RESULTS_DIR, filename), dpi=150)
        plt.close(fig)

    # ----------------- Visualization (stage 1, colored by nearest true reps) -----------------
    fig, axes = _reference_figure()
    _labels = label_by_nearest_reps(X_gen, TRUE_REP_POINTS)
    _m0 = _labels == 0
    _m1 = ~_m0
    axes[2].scatter(X_gen[_m0, 0], X_gen[_m0, 1], s=4, alpha=0.6, color='tab:blue', label="Generated (FM) - 0")
    axes[2].scatter(X_gen[_m1, 0], X_gen[_m1, 1], s=4, alpha=0.6, color='tab:orange', label="Generated (FM) - 1")
    _c0 = int(_m0.sum()); _c1 = int(_m1.sum())
    _r0 = _c0 / n_gen
    _r1 = _c1 / n_gen
    axes[2].text(0.02, 0.98, f"upper: {_c0} ({_r0:.2f}), lower: {_c1} ({_r1:.2f})", transform=axes[2].transAxes, va='top', ha='left', fontsize=10, bbox=dict(fc='white', ec='none', alpha=0.6))
    _finalize_figure(fig, axes, "Generated by Flow Matching", f"2moon_stage1_fm.png")

    # ----------------- Visualization (stage 1, tri-label) -----------------
    fig, axes = _reference_figure()
    _labels3 = tri_label_by_nearest_reps(X_gen, TRUE_REP_POINTS)
    _m0 = _labels3 == 0
    _m1 = _labels3 == 1
    _mm = _labels3 == 2
    axes[2].scatter(X_gen[_m0, 0], X_gen[_m0, 1], s=4, alpha=0.6, color='tab:blue', label="upper")
    axes[2].scatter(X_gen[_m1, 0], X_gen[_m1, 1], s=4, alpha=0.6, color='tab:orange', label="lower")
    axes[2].scatter(X_gen[_mm, 0], X_gen[_mm, 1], s=4, alpha=0.7, color='tab:purple', label="rest")
    _c0 = int(_m0.sum()); _c1 = int(_m1.sum()); _cm = int(_mm.sum())
    _r0 = _c0 / n_gen; _r1 = _c1 / n_gen; _rm = _cm / n_gen
    axes[2].text(0.02, 0.98, f"thr={TRI_LABEL_THR:.3f}\nU:{_c0}({_r0:.2f}) L:{_c1}({_r1:.2f}) R:{_cm}({_rm:.2f})", transform=axes[2].transAxes, va='top', ha='left', fontsize=10, bbox=dict(fc='white', ec='none', alpha=0.6))
    _finalize_figure(fig, axes, "Generated (tri-label) FM", f"2moon_stage1_fm_trilabel.png")

    # ----------------- Visualization (stage 1, plain: no labeling) -----------------
    fig, axes = _reference_figure()
    axes[2].scatter(X_gen[:, 0], X_gen[:, 1], s=4, alpha=0.6, color='tab:blue', label="Generated (plain)")
    _finalize_figure(fig, axes, "Generated by Flow Matching (plain)", f"2moon_stage1_fm_plain.png")

# ----------------- Quick imbalance check -----------------
if do_fm_plot:
    print("Target size:", Y_test.shape[0])
    print("Generated size:", X_gen.shape[0])


# ===========================================================
# Stage 2: Mini-batch OT coupling (exact EMD on squared cost)
# ===========================================================
import ot  # POT (Python Optimal Transport)


@torch.no_grad()
def ot_pairing_rowwise(x0: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Re-pair a mini-batch with an exact OT plan on squared Euclidean cost.

    Args:
        x0: (B,2) base samples.
        y:  (B,2) target samples.

    Returns:
        y_match: (B,2) tensor; row i is one target drawn from row i of the
        transport plan (normalised into a categorical distribution).
    """
    n = x0.shape[0]

    # Squared Euclidean cost between every base/target pair.
    sq_cost = torch.cdist(x0, y, p=2).pow(2).cpu().numpy()   # (B,B)

    # Uniform marginals on both sides, solved with ot.emd (no entropic term).
    plan = ot.emd(ot.unif(n), ot.unif(n), sq_cost)   # (B,B)

    # Normalise each row and sample one target index per source row.
    conditional = plan / (plan.sum(axis=1, keepdims=True) + 1e-12)
    picks = np.empty(n, dtype=np.int64)
    for i in range(n):
        picks[i] = np.random.choice(n, p=conditional[i])

    return y[torch.from_numpy(picks).to(y.device)]


# ----------------- Training setup (same hyperparams) -----------------
# A new Adam instance is built for this stage; `vnet` is the same model object
# that any earlier stage trained.
opt = optim.Adam(params=vnet.parameters(), lr=FLAGS.lr)

Y_train_t = torch.from_numpy(Y_train).to(device=device)


def train_iter_ot() -> torch.Tensor:
    """Compute one flow-matching loss with mini-batch OT pairing.

    Same objective as the independent-pairing iteration, except that each x0
    is matched to a target chosen by ``ot_pairing_rowwise`` before the
    straight-line path is built. Returns the scalar MSE loss.
    """
    # Target mini-batch and base noise.
    rows = torch.randint(0, Y_train_t.shape[0], (batch_size,), device=device)
    y = Y_train_t[rows]  # (B,2)
    x0 = torch.randn_like(y)  # (B,2)

    # Re-pair targets to the noise samples via the mini-batch OT plan.
    y_match = ot_pairing_rowwise(x0, y)  # (B,2)

    # Random time, interpolant and its velocity.
    t = torch.rand((batch_size, 1), device=device)
    x_t = (1.0 - t) * x0 + t * y_match
    target_velocity = y_match - x0

    residual = vnet(x_t, t) - target_velocity
    return residual.pow(2).mean()


if do_ot_train:
    print("Training (with minibatch OT coupling)...")
    for ep in range(1, epochs + 1):
        # Clear old gradients, then forward/backward with gradient-norm clipping.
        opt.zero_grad()
        loss = train_iter_ot()
        loss.backward()
        nn.utils.clip_grad_norm_(vnet.parameters(), max_norm=1.0)
        opt.step()
        if not ep % 100:
            print(f"[OT {ep:4d}/{epochs}] loss={loss.item():.6f}")


@torch.no_grad()
def sample_from_model(n_samples: int = 2000, steps: int = 100) -> np.ndarray:  # redefined as in notebook
    """Generate samples by integrating the learned field from t=0 to t=1.

    Starts from x(0) ~ N(0, I) and takes ``steps`` fixed-size explicit steps,
    evaluating the network at the midpoint time of each step. Returns an
    (n_samples, 2) NumPy array.
    """
    points = torch.randn((n_samples, 2), device=device)
    h = 1.0 / steps
    step = 0
    while step < steps:
        t_mid = torch.full((n_samples, 1), (step + 0.5) / steps, device=device)
        velocity = vnet(points, t_mid)
        points = points + velocity * h
        step += 1
    return points.cpu().numpy()


# ----------------- Visualization (stage 2) -----------------
if do_ot_plot:
    X_gen = sample_from_model(n_samples=FLAGS.n_samples, steps=FLAGS.steps)
    N_base_samples = FLAGS.n_base_samples
    X_base = sample_base(N_base_samples)
    # Class fractions are fractions of the *generated* points, so normalise by
    # their count rather than by the number of base samples drawn for display.
    n_gen = max(int(X_gen.shape[0]), 1)

    def _reference_figure():
        """Create a 1x3 figure with the base (left) and target (middle) panels drawn."""
        fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))
        axes[0].scatter(X_base[:, 0], X_base[:, 1], s=4, alpha=0.5, label="Base N(0,I)")
        axes[0].set_title("Base samples (N(0,I))")
        axes[0].axis('equal'); axes[0].legend()
        axes[1].scatter(Y_test[:, 0], Y_test[:, 1], s=5, alpha=0.8, color='tab:green', label="Target (test)")
        axes[1].set_title(f"Target: Imbalanced Two-Moons ({FLAGS.major_class_ratio:.0%}/{1-FLAGS.major_class_ratio:.0%})")
        axes[1].axis('equal'); axes[1].legend()
        return fig, axes

    def _finalize_figure(fig, axes, title, filename):
        """Title the generated-samples panel, then save the figure into RESULTS_DIR and close it."""
        axes[2].set_title(title)
        axes[2].axis('equal'); axes[2].legend()
        plt.tight_layout()
        fig.savefig(os.path.join(RESULTS_DIR, filename), dpi=150)
        plt.close(fig)

    # ----------------- Visualization (stage 2, colored by nearest true reps) -----------------
    fig, axes = _reference_figure()
    _labels = label_by_nearest_reps(X_gen, TRUE_REP_POINTS)
    _m0 = _labels == 0
    _m1 = ~_m0
    axes[2].scatter(X_gen[_m0, 0], X_gen[_m0, 1], s=4, alpha=0.6, color='tab:blue', label="Generated (OT-CFM) - 0")
    axes[2].scatter(X_gen[_m1, 0], X_gen[_m1, 1], s=4, alpha=0.6, color='tab:orange', label="Generated (OT-CFM) - 1")
    _c0 = int(_m0.sum()); _c1 = int(_m1.sum())
    _r0 = _c0 / n_gen
    _r1 = _c1 / n_gen
    axes[2].text(0.02, 0.98, f"upper: {_c0} ({_r0:.2f}), lower: {_c1} ({_r1:.2f})", transform=axes[2].transAxes, va='top', ha='left', fontsize=10, bbox=dict(fc='white', ec='none', alpha=0.6))
    _finalize_figure(fig, axes, "Generated by OT Flow Matching", f"2moon_stage2_ot_cfm.png")

    # ----------------- Visualization (stage 2, tri-label) -----------------
    fig, axes = _reference_figure()
    _labels3 = tri_label_by_nearest_reps(X_gen, TRUE_REP_POINTS)
    _m0 = _labels3 == 0
    _m1 = _labels3 == 1
    _mm = _labels3 == 2
    axes[2].scatter(X_gen[_m0, 0], X_gen[_m0, 1], s=4, alpha=0.6, color='tab:blue', label="upper")
    axes[2].scatter(X_gen[_m1, 0], X_gen[_m1, 1], s=4, alpha=0.6, color='tab:orange', label="lower")
    axes[2].scatter(X_gen[_mm, 0], X_gen[_mm, 1], s=4, alpha=0.7, color='tab:purple', label="rest")
    _c0 = int(_m0.sum()); _c1 = int(_m1.sum()); _cm = int(_mm.sum())
    _r0 = _c0 / n_gen; _r1 = _c1 / n_gen; _rm = _cm / n_gen
    axes[2].text(0.02, 0.98, f"thr={TRI_LABEL_THR:.3f}\nU:{_c0}({_r0:.2f}) L:{_c1}({_r1:.2f}) R:{_cm}({_rm:.2f})", transform=axes[2].transAxes, va='top', ha='left', fontsize=10, bbox=dict(fc='white', ec='none', alpha=0.6))
    _finalize_figure(fig, axes, "Generated (tri-label) OT-CFM", f"2moon_stage2_ot_cfm_trilabel.png")

    # ----------------- Visualization (stage 2, plain: no labeling) -----------------
    fig, axes = _reference_figure()
    axes[2].scatter(X_gen[:, 0], X_gen[:, 1], s=4, alpha=0.6, color='tab:blue', label="Generated (plain)")
    _finalize_figure(fig, axes, "Generated by OT Flow Matching (plain)", f"2moon_stage2_ot_cfm_plain.png")


# =============================================================
# Stage 3: UOT coupling + target-marginal-based reweighted loss
# =============================================================


# --- UOT hyper-parameters (module-level copies of the flags) ---
uot_reg, tau_1, tau_2 = FLAGS.uot_reg, FLAGS.tau1, FLAGS.tau2
alpha, eps_marg = FLAGS.alpha, FLAGS.eps_marg
reweight_mode = FLAGS.reweight_mode

# A negative --cap_w disables weight capping (represented as None).
cap_w = FLAGS.cap_w
if cap_w is not None and cap_w < 0:
    cap_w = None



def _finite_tau(t, big=1e6):
    return big if (t is None or not np.isfinite(t)) else float(t)


# Finite stand-ins for the marginal penalties handed to the UOT solver
# (an infinite or missing tau is replaced by _finite_tau's large constant).
safe_tau1, safe_tau2 = _finite_tau(tau_1), _finite_tau(tau_2)


@torch.no_grad()
def uot_pairing_rowwise(
    x0: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 1.0,
    eps_marg: float = 1e-12,
    cap_w: Optional[float] = 50.0,
    reweight_mode: str = "col",
):
    """Pair a mini-batch with an unbalanced-OT plan and derive per-target weights.

    Args:
        x0: (B,2) base samples.
        y:  (B,2) target samples.
        alpha: power of the inverse target-marginal weighting.
        eps_marg: constant added to the target marginal before the power.
        cap_w: upper bound applied to the raw weights; None disables the cap.
        reweight_mode: for "col" or "both" the plan columns are multiplied by
            the weights before sampling; otherwise the plain plan is used.

    Returns:
        y_match : (B,2) tensor     -- one sampled target per base sample
        tgt_idx : (B,) np.int64    -- indices into ``y`` of the sampled targets
        w_col   : (B,) np.float32  -- mean-normalised inverse-marginal weight per target
        P       : (B,B) np.float64 -- plan returned by the UOT solver
        targ_m  : (B,) np.float64  -- column sums of P (target marginal)

    Reads the module-level ``uot_reg``, ``safe_tau1`` and ``safe_tau2``.
    """
    B = x0.shape[0]

    # Cost: (non-squared) Euclidean distance, uniform marginals on both sides.
    C = torch.cdist(x0, y, p=2).detach().cpu().numpy().astype(np.float64)  # (B,B)
    a = ot.unif(B).astype(np.float64)
    b = ot.unif(B).astype(np.float64)

    # Entropic unbalanced OT plan.
    P = ot.unbalanced.sinkhorn_unbalanced(
        a, b, C, reg=uot_reg, reg_m=(safe_tau1, safe_tau2)
    )  # (B,B) float64

    # Target marginal -> inverse-power weights, capped and normalised to mean 1.
    targ_m = P.sum(axis=0)                          # (B,)
    w_col = np.power(targ_m + eps_marg, -alpha)     # (B,) float64
    if cap_w is not None:
        w_col = np.minimum(w_col, cap_w)
    w_col /= (w_col.mean() + 1e-12)

    # Weights are handed to torch as float32.
    w_col = w_col.astype(np.float32)

    # Optionally re-weight the plan columns before sampling.
    R = P * w_col[None, :] if reweight_mode in ("col", "both") else P.copy()

    # Row-normalise into categorical distributions.
    R = np.clip(R, 0.0, None)
    row_sum = R.sum(axis=1, keepdims=True)
    # 1-D row mask: a (B,1) boolean mask cannot be combined with `:` to index
    # the 2-D array R, so the uniform fallback would raise instead of applying.
    dead_rows = ((row_sum <= 1e-18) | ~np.isfinite(row_sum))[:, 0]
    if np.any(dead_rows):
        R[dead_rows, :] = 1.0 / B
        row_sum[dead_rows] = 1.0
    R /= row_sum

    # Push any rounding residual into the last column, then renormalise.
    diff = 1.0 - R.sum(axis=1, keepdims=True)
    R[:, -1] += diff[:, 0]
    R = np.clip(R, 0.0, None)
    R /= (R.sum(axis=1, keepdims=True) + 1e-18)

    # Sample one target index per row, with a uniform fallback for bad rows.
    tgt_idx = np.empty(B, dtype=np.int64)
    for i in range(B):
        p = R[i]
        s = p.sum()
        if (not np.isfinite(s)) or s <= 0:
            p = np.full(B, 1.0 / B, dtype=np.float64)
        else:
            p = p / p.sum()
            # Make the probabilities sum to one for np.random.choice.
            p[-1] = max(0.0, 1.0 - p[:-1].sum())
            if p.sum() <= 0:
                p = np.full(B, 1.0 / B, dtype=np.float64)
        tgt_idx[i] = np.random.choice(B, p=p)

    # Gather the matched targets.
    y_match = y[torch.from_numpy(tgt_idx).to(y.device)]
    return y_match, tgt_idx, w_col, P, targ_m


# A new Adam instance is built for this stage; `vnet` is the same model object
# that any earlier stage trained.
Y_train_t = torch.from_numpy(Y_train).to(device=device)  # (n,2)
opt = torch.optim.Adam(params=vnet.parameters(), lr=lr)


def train_iter_uot() -> torch.Tensor:
    """Compute one flow-matching loss with UOT pairing and optional loss re-weighting.

    Targets are matched to noise samples by ``uot_pairing_rowwise``. When the
    module-level ``reweight_mode`` is "loss" or "both", each sample's squared
    error is scaled by the mean-normalised weight of the target it was matched
    to; otherwise the plain mean squared error is returned.
    """
    # Target mini-batch and base noise.
    rows = torch.randint(0, Y_train_t.shape[0], (batch_size,), device=device)
    y = Y_train_t[rows]  # (B,2)
    x0 = torch.randn_like(y)  # (B,2)

    # UOT pairing; the plan and marginal are returned but not needed here.
    y_match, tgt_idx, w_col, _plan, _marginal = uot_pairing_rowwise(
        x0,
        y,
        alpha=alpha,
        eps_marg=eps_marg,
        cap_w=cap_w,
        reweight_mode=reweight_mode,
    )

    # Random time, interpolant and its velocity.
    t = torch.rand((batch_size, 1), device=device)
    x_t = (1.0 - t) * x0 + t * y_match
    target_velocity = y_match - x0

    # Per-sample squared error, averaged over the coordinate axis.
    per_sample = (vnet(x_t, t) - target_velocity).pow(2).mean(dim=1)  # (B,)

    if reweight_mode not in ("loss", "both"):
        return per_sample.mean()

    sample_w = torch.tensor(w_col[tgt_idx], device=device, dtype=torch.float32)  # (B,)
    sample_w = (sample_w / (sample_w.mean() + 1e-12)).clamp_(0.0, 1e3)
    return (sample_w * per_sample).mean()


if do_uot_train:
    print(f"[UOT] ε={uot_reg}, τ1={tau_1}, τ2={tau_2}, α={alpha}, mode={reweight_mode}")
    for ep in range(1, epochs + 1):
        # Clear old gradients, then forward/backward with gradient-norm clipping.
        opt.zero_grad()
        loss = train_iter_uot()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vnet.parameters(), max_norm=1.0)
        opt.step()
        if not ep % 100:
            print(f"[UOT] {ep:4d}/{epochs}  loss={loss.item():.6f}")


@torch.no_grad()
def sample_from_model(n_samples: int = 2000, steps: int = 100) -> np.ndarray:  # redefined as in notebook
    """Generate samples by integrating the learned field from t=0 to t=1.

    Starts from x(0) ~ N(0, I) and takes ``steps`` fixed-size explicit steps,
    evaluating the network at the midpoint time of each step. Returns an
    (n_samples, 2) NumPy array.
    """
    cur = torch.randn((n_samples, 2), device=device)  # x(0) ~ N(0,I)
    delta_t = 1.0 / steps
    for step_idx in range(steps):
        t_now = (step_idx + 0.5) / steps  # midpoint of the current step
        t_batch = torch.full((n_samples, 1), t_now, device=device)
        cur = cur + vnet(cur, t_batch) * delta_t
    return cur.cpu().numpy()


# ----------------- Visualization (stage 3) -----------------
if do_uot_plot:
    X_gen = sample_from_model(n_samples=FLAGS.n_samples, steps=FLAGS.steps)
    N_base_samples = FLAGS.n_base_samples
    X_base = sample_base(N_base_samples)
    # Class fractions are fractions of the *generated* points, so normalise by
    # their count rather than by the number of base samples drawn for display.
    n_gen = max(int(X_gen.shape[0]), 1)

    def _reference_figure():
        """Create a 1x3 figure with the base (left) and target (middle) panels drawn."""
        fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))
        axes[0].scatter(X_base[:, 0], X_base[:, 1], s=4, alpha=0.5, label="Base N(0,I)")
        axes[0].set_title("Base samples (N(0,I))")
        axes[0].axis('equal'); axes[0].legend()
        axes[1].scatter(Y_test[:, 0], Y_test[:, 1], s=5, alpha=0.8, color='tab:green', label="Target (test)")
        axes[1].set_title(f"Target: Imbalanced Two-Moons ({FLAGS.major_class_ratio:.0%}/{1-FLAGS.major_class_ratio:.0%})")
        axes[1].axis('equal'); axes[1].legend()
        return fig, axes

    def _finalize_figure(fig, axes, title, filename):
        """Title the generated-samples panel, then save the figure into RESULTS_DIR and close it."""
        axes[2].set_title(title)
        axes[2].axis('equal'); axes[2].legend()
        plt.tight_layout()
        fig.savefig(os.path.join(RESULTS_DIR, filename), dpi=150)
        plt.close(fig)

    # ----------------- Visualization (stage 3, colored by nearest true reps) -----------------
    fig, axes = _reference_figure()
    _labels = label_by_nearest_reps(X_gen, TRUE_REP_POINTS)
    _m0 = _labels == 0
    _m1 = ~_m0
    axes[2].scatter(X_gen[_m0, 0], X_gen[_m0, 1], s=4, alpha=0.6, color='tab:blue', label="Generated (UOT) - 0")
    axes[2].scatter(X_gen[_m1, 0], X_gen[_m1, 1], s=4, alpha=0.6, color='tab:orange', label="Generated (UOT) - 1")
    _c0 = int(_m0.sum()); _c1 = int(_m1.sum())
    _r0 = _c0 / n_gen
    _r1 = _c1 / n_gen
    axes[2].text(0.02, 0.98, f"upper: {_c0} ({_r0:.2f}), lower: {_c1} ({_r1:.2f})", transform=axes[2].transAxes, va='top', ha='left', fontsize=10, bbox=dict(fc='white', ec='none', alpha=0.6))
    _finalize_figure(fig, axes, "Generated by Our model with power " + str(alpha), f"2moon_stage3_uot_{FLAGS.reweight_mode}_taub{FLAGS.tau2}_pow{FLAGS.alpha}.png")

    # ----------------- Visualization (stage 3, tri-label) -----------------
    fig, axes = _reference_figure()
    _labels3 = tri_label_by_nearest_reps(X_gen, TRUE_REP_POINTS)
    _m0 = _labels3 == 0
    _m1 = _labels3 == 1
    _mm = _labels3 == 2
    axes[2].scatter(X_gen[_m0, 0], X_gen[_m0, 1], s=4, alpha=0.6, color='tab:blue', label="upper")
    axes[2].scatter(X_gen[_m1, 0], X_gen[_m1, 1], s=4, alpha=0.6, color='tab:orange', label="lower")
    axes[2].scatter(X_gen[_mm, 0], X_gen[_mm, 1], s=4, alpha=0.7, color='tab:purple', label="rest")
    _c0 = int(_m0.sum()); _c1 = int(_m1.sum()); _cm = int(_mm.sum())
    _r0 = _c0 / n_gen; _r1 = _c1 / n_gen; _rm = _cm / n_gen
    axes[2].text(0.02, 0.98, f"thr={TRI_LABEL_THR:.3f}\nU:{_c0}({_r0:.2f}) L:{_c1}({_r1:.2f}) R:{_cm}({_rm:.2f})", transform=axes[2].transAxes, va='top', ha='left', fontsize=10, bbox=dict(fc='white', ec='none', alpha=0.6))
    _finalize_figure(fig, axes, "Generated (tri-label) UOT", f"2moon_stage3_uot_{FLAGS.reweight_mode}_taub{FLAGS.tau2}_pow{FLAGS.alpha}_trilabel.png")

    # ----------------- Visualization (stage 3, plain: no labeling) -----------------
    fig, axes = _reference_figure()
    axes[2].scatter(X_gen[:, 0], X_gen[:, 1], s=4, alpha=0.6, color='tab:blue', label="Generated (plain)")
    _finalize_figure(fig, axes, "Generated by UOT (plain)", f"2moon_stage3_uot_{FLAGS.reweight_mode}_taub{FLAGS.tau2}_pow{FLAGS.alpha}_plain.png")


# All work in this file happens at module level above (flag parsing, training,
# plotting), so the entry-point guard intentionally does nothing.
if __name__ == "__main__":
    pass


