import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Optional
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.cluster import KMeans
import ot  # POT (Python Optimal Transport)
from absl import flags, app

# ----------------- Flags -----------------
# Command-line configuration, registered with absl at import time.
# NOTE(review): several help strings below look like they lost non-ASCII text
# (e.g. the help for "sigma" is only whitespace). They are runtime strings and
# are left untouched here — restore the intended wording separately.
FLAGS = flags.FLAGS

# Dataset settings (8-mode Gaussian mixture)
flags.DEFINE_list("counts", ["10240", "4920", "2360", "1280", "640", "320", "160", "80"],
                  "8-GMM     ()")
flags.DEFINE_float("radius", 3.0, "8-GMM  ")
flags.DEFINE_float("sigma", 0.15, "  ")
# NOTE(review): "train_split" is not read anywhere in the visible code.
flags.DEFINE_float("train_split", 0.5, "Train data split ratio")
flags.DEFINE_integer("seed", 71, "Random seed") # good: 71 47 / soso: 66* 42
# Model and optimizer settings
flags.DEFINE_integer("width", 128, "VecField hidden width")
flags.DEFINE_integer("batch_size", 512, "Batch size")
flags.DEFINE_float("lr", 2e-3, "Learning rate")

# Experiment selection and training length
flags.DEFINE_enum("experiment", "all", ["fm", "ot", "uot", "all"], 
                  "Experiment selection: fm | ot | uot | all")
flags.DEFINE_integer("epochs", 10000, "Training epochs for all methods")

# UOT settings ("tau2" and "weight_power" are lists: every combination is run)
flags.DEFINE_float("uot_reg", 0.05, "UOT entropic regularization ε")
flags.DEFINE_float("tau1", float("inf"), "UOT τ1 (source marginal penalty)")
flags.DEFINE_list("tau2", ["10.0"], "UOT τ2 (target marginal penalty) -  ")
flags.DEFINE_list("weight_power", ["2.0"], "Inverse-marginal weighting power (alpha) -  ")
flags.DEFINE_float("eps_marg", 1e-12, "Marginal stabilization epsilon")
flags.DEFINE_float("cap_w", 50.0, "Weight cap (to prevent explosion); negative disables")
flags.DEFINE_enum("reweight_mode", "loss", ["none", "col", "loss", "both"], 
                  "none:uot-fm, loss:uot-wfm, col:uot-fm w/ weighted coupling, both:uot-wfm w/ weighted coupling")

# Sampling settings
flags.DEFINE_integer("n_samples", 10000, "Number of generated samples")
flags.DEFINE_integer("steps", 150, "Number of Euler steps for sampling")

# Output settings
flags.DEFINE_string("results_dir", "", "Results directory (if not specified, use default rule)")
flags.DEFINE_bool("save_plots", True, "Save visualization plots")
flags.DEFINE_bool("verbose", True, "Verbose output")

# Compute device shared by the whole script: CUDA first, then Apple MPS, else CPU.
# The conditional expression is lazy, so MPS is only probed when CUDA is absent.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

def set_seed(seed):
    """Seed the NumPy and PyTorch random number generators with ``seed``.

    When CUDA is available the CUDA generator is seeded explicitly as well.
    """
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

def label_by_nearest_centers(points: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Assign every point the index of its closest center (Euclidean distance).

    points:  (N, 2) array of samples
    centers: (K, 2) array of mode centers (K is 8 for the 8-GMM used here)
    returns: (N,) array with the index of the nearest center for each point;
             ties resolve to the lowest index (``argmin`` semantics)
    """
    pts = points.astype(np.float32)
    ctr = centers.astype(np.float32)

    # Broadcast to (N, K, 2) offsets, then reduce over the coordinate axis.
    offsets = pts[:, None, :] - ctr[None, :, :]
    dist = np.sqrt(np.sum(offsets ** 2, axis=2))  # (N, K)
    return dist.argmin(axis=1)

def cluster_and_count(X_tar, centers):
    """Label samples by nearest center and tally how many fall in each mode.

    X_tar   : (N,2) samples
    centers : (8,2) mode centers

    Returns ``(labels, counts, ratios)`` where ``counts`` has at least 8
    entries and ``ratios`` is ``counts`` divided by their total. When
    ``FLAGS.verbose`` is set, one summary line per mode (0..7) is printed.
    """
    labels = label_by_nearest_centers(X_tar, centers)

    counts = np.bincount(labels, minlength=8)
    ratios = counts / counts.sum()

    if FLAGS.verbose:
        for mode_id, (cnt, frac) in enumerate(zip(counts[:8], ratios[:8])):
            print(f"Mode {mode_id:2d}: {cnt:5d} samples  ({frac*100:5.2f}%)")
    return labels, counts, ratios

def make_8gmm_counts_simple(
    counts=(800, 700, 600, 500, 400, 300, 200, 100),
    radius=3.0,
    sigma=0.15,
    seed=42,
    dtype=np.float32,
):
    """Sample a shuffled 8-component Gaussian mixture with fixed per-mode sizes.

    The modes sit on a circle of the given ``radius`` at angles 0, pi/4, ...,
    7*pi/4 and share the isotropic covariance ``sigma**2 * I``. Mode ``k``
    contributes exactly ``counts[k]`` points. No train/test split is done.

    Returns ``(X, y, centers)``: ``X`` is (N, 2) in ``dtype``, ``y`` is the
    (N,) int64 mode index of each row, and ``centers`` is (8, 2) in ``dtype``.
    """
    rng = np.random.default_rng(seed)

    counts = np.asarray(counts, dtype=int)
    assert counts.size == 8 and (counts > 0).all(), "counts  8    ."

    # Mode centers, evenly spaced on the circle.
    angles = np.arange(8) * (np.pi / 4.0)
    centers = np.column_stack((radius * np.cos(angles), radius * np.sin(angles)))  # (8,2)

    cov = (sigma ** 2) * np.eye(2)

    # One draw per mode, in mode order, so the generator stream is consumed
    # in a fixed sequence.
    per_mode = [
        rng.multivariate_normal(mean=center, cov=cov, size=n_k)
        for center, n_k in zip(centers, counts)
    ]
    X = np.concatenate(per_mode, axis=0)
    y = np.repeat(np.arange(8, dtype=np.int64), counts)

    # Shuffle points and labels with one shared permutation.
    order = rng.permutation(X.shape[0])
    X = X[order].astype(dtype, copy=False)
    y = y[order]

    return X, y, centers.astype(dtype)

def make_8gmm_counts(
    counts=(800, 700, 600, 500, 400, 300, 200, 100),
    radius=3.0,
    sigma=0.15,
    seed=42,
    train_ratio=0.8,
    dtype=np.float32,
):
    """Sample the 8-Gaussian mixture and split it into train and test parts.

    The data come from ``make_8gmm_counts_simple`` (already shuffled); the
    first ``int(N * train_ratio)`` rows form the training part and the rest
    the test part.

    Returns ``(X, y, Y_train, Y_test, y_tr, y_te, centers)``.
    """
    X, y, centers = make_8gmm_counts_simple(counts, radius, sigma, seed, dtype)

    # Split points and labels at the same cut index.
    cut = int(len(X) * train_ratio)
    (Y_train, Y_test), (y_tr, y_te) = ((arr[:cut], arr[cut:]) for arr in (X, y))

    return X, y, Y_train, Y_test, y_tr, y_te, centers

def sample_base(n):
    """Draw ``n`` points from the 2-D standard normal base distribution (float32)."""
    draws = np.random.randn(n, 2)
    return draws.astype(np.float32)

class VecField(nn.Module):
    """Velocity-field network v_theta(x, t).

    A fully connected network with four SiLU hidden layers of size ``width``.
    It maps the concatenation of ``x`` and ``t`` to an ``x_dim`` output.
    """

    def __init__(self, x_dim=2, t_dim=1, width=128):
        super().__init__()
        # Layers are created input-to-output so parameter initialisation
        # consumes the RNG in a fixed order.
        sizes = [x_dim + t_dim] + [width] * 4
        layers = []
        for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.SiLU()]
        layers.append(nn.Linear(width, x_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        # x: (B, x_dim), t: (B, t_dim) -> (B, x_dim)
        features = torch.cat((x, t), dim=-1)
        return self.net(features)

@torch.no_grad()
def sample_from_model(vnet, n_samples=2000, steps=100):
    """Generate samples by integrating dx/dt = v_theta(x, t) from t=0 to t=1.

    Starts from ``x(0) ~ N(0, I)`` in 2-D and takes ``steps`` explicit Euler
    steps of size ``1/steps``. The velocity is queried at the midpoint time
    of each step, ``(k + 0.5) / steps``.

    vnet:      callable ``vnet(x, t)`` with x of shape (n, 2) and t of shape (n, 1)
    n_samples: number of trajectories to integrate
    steps:     number of integration steps
    returns:   (n_samples, 2) NumPy array of the states at t=1
    """
    # Integrate on the device the model actually lives on, so a model that is
    # not on the module-level default device still works. Fall back to the
    # module-level ``device`` only when the model exposes no parameters.
    params_fn = getattr(vnet, "parameters", None)
    first_param = next(iter(params_fn()), None) if callable(params_fn) else None
    run_device = first_param.device if first_param is not None else device

    x = torch.randn((n_samples, 2), device=run_device)  # x(0) ~ N(0,I)
    dt = 1.0 / steps
    for k in range(steps):
        t = torch.full((n_samples, 1), (k + 0.5) / steps, device=run_device)  # midpoint time
        x = x + vnet(x, t) * dt  # Euler step
    return x.cpu().numpy()

def train_fm(Y_train_t, vnet, optimizer, epochs):
    """Train ``vnet`` with plain flow matching (independent noise/data pairs).

    Y_train_t: (N, 2) tensor of target samples on the training device
    vnet:      velocity network called as ``vnet(x_t, t)``
    optimizer: optimizer over ``vnet``'s parameters
    epochs:    number of optimisation steps (one minibatch each)

    Gradients are clipped to norm 1.0; the loss is printed every 100 steps
    when ``FLAGS.verbose`` is set.
    """
    if FLAGS.verbose:
        print("Training Flow Matching...")

    n_data = Y_train_t.shape[0]

    def fm_loss():
        batch = FLAGS.batch_size
        # Target minibatch, drawn with replacement.
        target = Y_train_t[torch.randint(0, n_data, (batch,), device=device)]  # (B,2)
        # Base sample and interpolation time.
        noise = torch.randn_like(target)                 # (B,2)
        tau = torch.rand((batch, 1), device=device)      # (B,1)
        # Straight-line path between the pair and its constant velocity.
        x_t = (1.0 - tau) * noise + tau * target
        velocity = target - noise
        residual = vnet(x_t, tau) - velocity
        return (residual ** 2).mean()

    for step in range(1, epochs + 1):
        loss = fm_loss()
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(vnet.parameters(), 1.0)
        optimizer.step()
        if FLAGS.verbose and step % 100 == 0:
            print(f"[{step:4d}/{epochs}] loss={loss.item():.6f}")

@torch.no_grad()
def ot_pairing_rowwise(x0: torch.Tensor, y: torch.Tensor):
    """Pair each ``x0[i]`` with a target sampled from row ``i`` of the exact OT plan.

    The plan is computed by ``ot.emd`` between uniform weights on ``x0`` and
    ``y`` with squared Euclidean cost. Each row is normalised to a
    categorical distribution and one target index is drawn per row with
    ``np.random.choice``.

    x0, y:   tensors with the same number of rows B
    returns: tensor of the matched targets, ``y[idx]``, one row per ``x0[i]``
    """
    n = x0.shape[0]

    # Squared Euclidean ground cost, moved to NumPy for POT.
    sq_cost = torch.cdist(x0, y, p=2).pow(2).cpu().numpy()   # (B,B)

    # Uniform marginals on both sides.
    mu = ot.unif(n)
    nu = ot.unif(n)

    # Exact transport plan (no entropic regularisation).
    plan = ot.emd(mu, nu, sq_cost)   # (B,B)

    # Row-conditional distributions; the small constant guards empty rows.
    cond = plan / (plan.sum(axis=1, keepdims=True) + 1e-12)
    picks = np.empty(n, dtype=np.int64)
    for i, probs in enumerate(cond):
        picks[i] = np.random.choice(n, p=probs)

    # Gather the matched targets on y's device.
    return y[torch.from_numpy(picks).to(y.device)]

def train_ot(Y_train_t, vnet, optimizer, epochs):
    """Train ``vnet`` with flow matching on minibatch-OT-coupled pairs.

    Same loop as plain flow matching, except each noise sample is paired
    with a target chosen by ``ot_pairing_rowwise`` rather than at random.

    Y_train_t: (N, 2) tensor of target samples on the training device
    vnet:      velocity network called as ``vnet(x_t, t)``
    optimizer: optimizer over ``vnet``'s parameters
    epochs:    number of optimisation steps (one minibatch each)
    """
    if FLAGS.verbose:
        print("Training OT Flow Matching...")

    n_data = Y_train_t.shape[0]

    def ot_fm_loss():
        batch = FLAGS.batch_size
        # Target minibatch and base samples.
        target = Y_train_t[torch.randint(0, n_data, (batch,), device=device)]  # (B,2)
        noise = torch.randn_like(target)                                       # (B,2)
        # Re-pair the targets with the noise through the minibatch OT plan.
        matched = ot_pairing_rowwise(noise, target)                            # (B,2)
        tau = torch.rand((batch, 1), device=device)
        # Straight-line path along the coupled pair and its velocity.
        x_t = (1.0 - tau) * noise + tau * matched
        velocity = matched - noise
        residual = vnet(x_t, tau) - velocity
        return (residual ** 2).mean()

    for step in range(1, epochs + 1):
        loss = ot_fm_loss()
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(vnet.parameters(), 1.0)
        optimizer.step()
        if FLAGS.verbose and step % 100 == 0:
            print(f"[{step:4d}/{epochs}] loss={loss.item():.6f}")

def _finite_tau(t, big=1e6):
    """Return ``t`` as a float, or ``big`` when ``t`` is None or not finite."""
    if t is not None and np.isfinite(t):
        return float(t)
    return big

@torch.no_grad()
def uot_pairing_rowwise(
    x0: torch.Tensor,
    y: torch.Tensor,
    tau2: float,
    alpha: float,
    eps_marg: float = 1e-12,
    cap_w: Optional[float] = 50.0,
    reweight_mode: str = "col",
):
    """Pair noise with targets via an entropic unbalanced-OT plan.

    Also computes inverse-marginal column weights ``(col_mass + eps)^(-alpha)``.

    x0, y:         tensors with the same number of rows B
    tau2:          target-marginal penalty (non-finite values map to 1e6);
                   the source penalty is read from ``FLAGS.tau1``
    alpha:         exponent of the inverse-marginal weighting
    eps_marg:      added to the column mass before taking the power
    cap_w:         upper bound on the raw weights; ``None`` or a non-positive
                   value disables the cap
    reweight_mode: for "col" or "both" the plan columns are multiplied by the
                   weights before sampling; otherwise the plan is used as is

    Returns ``(y_match, tgt_idx, w_col, P, targ_m)``:
      y_match  matched targets ``y[tgt_idx]``
      tgt_idx  (B,) int64 sampled target index per row
      w_col    (B,) float32 column weights, normalised to mean 1
      P        (B,B) float64 transport plan
      targ_m   (B,) column sums of ``P``
    """
    B = x0.shape[0]

    # Ground cost: plain (not squared) Euclidean distance.
    # NOTE(review): ot_pairing_rowwise squares the distance — confirm the
    # difference is intended.
    C = torch.cdist(x0, y, p=2).detach().cpu().numpy().astype(np.float64)  # (B,B)
    a = ot.unif(B).astype(np.float64)
    b = ot.unif(B).astype(np.float64)

    # Infinite penalties are replaced by a large finite value for the solver.
    safe_tau1 = _finite_tau(FLAGS.tau1)
    safe_tau2 = _finite_tau(tau2)

    # Entropic unbalanced OT plan.
    P = ot.unbalanced.sinkhorn_unbalanced(
        a, b, C, reg=FLAGS.uot_reg, reg_m=(safe_tau1, safe_tau2)
    )  # (B,B) float64

    # Target (column) marginal -> inverse-marginal weights.
    targ_m = P.sum(axis=0)                          # (B,)
    w_col  = np.power(targ_m + eps_marg, -alpha)    # (B,) float64
    # A non-positive cap means "no cap". Passing it to np.minimum would set
    # every weight to the same value and silently switch reweighting off.
    if cap_w is not None and cap_w > 0:
        w_col = np.minimum(w_col, cap_w)
    w_col /= (w_col.mean() + 1e-12)

    # float32 keeps the weights usable on MPS, which lacks float64 support.
    w_col = w_col.astype(np.float32)

    # Optionally reweight the plan's columns before sampling.
    R = P * w_col[None, :] if reweight_mode in ("col", "both") else P.copy()

    # Turn each row into a categorical distribution.
    R = np.clip(R, 0.0, None)
    row_sum = R.sum(axis=1, keepdims=True)
    # 1-D row mask: a (B,1) mask cannot be combined with a column slice when
    # indexing the 2-D array R.
    dead = ((row_sum <= 1e-18) | ~np.isfinite(row_sum))[:, 0]   # (B,)
    if np.any(dead):
        # Rows with no (or non-finite) mass fall back to a uniform choice.
        R[dead, :] = 1.0 / B
        row_sum[dead] = 1.0
    R /= row_sum

    # Push the rounding residual into the last column so rows sum to one.
    diff = 1.0 - R.sum(axis=1, keepdims=True)
    R[:, -1] += diff[:, 0]
    R = np.clip(R, 0.0, None)
    R /= (R.sum(axis=1, keepdims=True) + 1e-18)

    # Draw one target index per row; np.random.choice needs sum(p) == 1.
    tgt_idx = np.empty(B, dtype=np.int64)
    for i in range(B):
        p = R[i]
        s = p.sum()
        if (not np.isfinite(s)) or s <= 0:
            p = np.full(B, 1.0 / B, dtype=np.float64)
        else:
            p = p / p.sum()
            p[-1] = max(0.0, 1.0 - p[:-1].sum())
            if p.sum() <= 0:
                p = np.full(B, 1.0 / B, dtype=np.float64)
        tgt_idx[i] = np.random.choice(B, p=p)

    # Gather the matched targets on y's device.
    y_match = y[torch.from_numpy(tgt_idx).to(y.device)]
    return y_match, tgt_idx, w_col, P, targ_m

def train_uot(Y_train_t, vnet, optimizer, epochs, tau2, alpha):
    """Train ``vnet`` with flow matching on unbalanced-OT-coupled pairs.

    Pairs come from ``uot_pairing_rowwise``. When ``FLAGS.reweight_mode`` is
    "loss" or "both", each sample's squared error is scaled by the weight of
    its matched target (normalised to mean 1 and clamped to [0, 1e3]).

    Y_train_t: (N, 2) tensor of target samples on the training device
    vnet:      velocity network called as ``vnet(x_t, t)``
    optimizer: optimizer over ``vnet``'s parameters
    epochs:    number of optimisation steps (one minibatch each)
    tau2:      target-marginal penalty passed to the pairing routine
    alpha:     inverse-marginal weighting exponent passed to the pairing routine
    """
    if FLAGS.verbose:
        print(f"Training UOT Flow Matching (tau2={tau2}, alpha={alpha})...")

    def weighted_fm_loss():
        batch = FLAGS.batch_size
        # Target minibatch and base samples.
        rows = torch.randint(0, Y_train_t.shape[0], (batch,), device=device)
        target = Y_train_t[rows]              # (B,2)
        noise = torch.randn_like(target)      # (B,2)

        # UOT coupling; the plan and column mass are not needed here.
        matched, picked, col_w, _plan, _col_mass = uot_pairing_rowwise(
            noise, target, tau2, alpha,
            eps_marg=FLAGS.eps_marg, cap_w=FLAGS.cap_w,
            reweight_mode=FLAGS.reweight_mode
        )

        # Straight-line path along the coupled pair and its velocity.
        tau = torch.rand((batch, 1), device=device)
        x_t = (1.0 - tau) * noise + tau * matched
        velocity = matched - noise

        per_sample = (vnet(x_t, tau) - velocity).pow(2).mean(dim=1)   # (B,)

        if FLAGS.reweight_mode not in ("loss", "both"):
            return per_sample.mean()

        # Weight of the target each sample was matched to (float32 for MPS).
        weights = torch.tensor(col_w[picked], device=device, dtype=torch.float32)  # (B,)
        weights = (weights / (weights.mean() + 1e-12)).clamp_(0.0, 1e3)
        return (weights * per_sample).mean()

    for step in range(1, epochs + 1):
        loss = weighted_fm_loss()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vnet.parameters(), 1.0)
        optimizer.step()
        if FLAGS.verbose and step % 100 == 0:
            print(f"[UOT] {step:4d}/{epochs}  loss={loss.item():.6f}")

def _subsample_rows(arr, max_rows):
    """Return ``arr`` unchanged, or ``max_rows`` of its rows drawn without replacement."""
    if len(arr) <= max_rows:
        return arr
    return arr[np.random.choice(len(arr), max_rows, replace=False)]


def _euclidean_cost(A, B):
    """Pairwise Euclidean distance matrix between the rows of ``A`` and ``B``."""
    return np.sqrt(((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=2))


def calculate_emd_metrics(X_gen, Y_train, centers):
    """Compute EMD between generated and training samples, overall and per mode.

    X_gen:   (N, 2) generated samples
    Y_train: (M, 2) training samples
    centers: (8, 2) mode centers used to assign samples to modes

    Both sets are randomly subsampled to keep the cost matrices small: at
    most 2000 points each for the overall EMD, 500 per mode.

    Returns ``(emd_overall, emd_per_mode, emd_mean)``:
      emd_overall   EMD between the (subsampled) full sets, uniform weights
      emd_per_mode  list of 8 values; ``inf`` marks a mode that is empty in
                    either set or whose EMD computation failed
      emd_mean      mean of the finite per-mode values, ``inf`` if none
    """
    max_samples = 2000
    max_mode_samples = 500

    # 1) Overall EMD. Generated samples are subsampled first, then training.
    X_gen_sample = _subsample_rows(X_gen, max_samples)
    Y_train_sample = _subsample_rows(Y_train, max_samples)
    C_overall = _euclidean_cost(X_gen_sample, Y_train_sample)
    emd_overall = ot.emd2(
        ot.unif(len(X_gen_sample)), ot.unif(len(Y_train_sample)), C_overall
    )

    # 2) Per-mode EMD, with modes assigned by nearest center.
    gen_labels = label_by_nearest_centers(X_gen, centers)
    train_labels = label_by_nearest_centers(Y_train, centers)

    emd_per_mode = []
    for mode in range(8):
        X_gen_mode = X_gen[gen_labels == mode]
        Y_train_mode = Y_train[train_labels == mode]

        if len(X_gen_mode) == 0 or len(Y_train_mode) == 0:
            # A mode missing on either side has no defined EMD; record inf
            # (not 0.0) so it is excluded from the mean rather than rewarded.
            emd_per_mode.append(float('inf'))
            continue

        X_gen_mode = _subsample_rows(X_gen_mode, max_mode_samples)
        Y_train_mode = _subsample_rows(Y_train_mode, max_mode_samples)
        C_mode = _euclidean_cost(X_gen_mode, Y_train_mode)

        try:
            emd_mode = ot.emd2(
                ot.unif(len(X_gen_mode)), ot.unif(len(Y_train_mode)), C_mode
            )
        except Exception as exc:
            # Best effort: one failing mode must not abort the evaluation.
            # KeyboardInterrupt/SystemExit still propagate.
            print(f"Warning: EMD computation failed for mode {mode}: {exc}",
                  file=sys.stderr)
            emd_per_mode.append(float('inf'))
        else:
            emd_per_mode.append(emd_mode)

    # 3) Mean over the modes with a finite EMD.
    finite_emds = [emd for emd in emd_per_mode if np.isfinite(emd)]
    emd_mean = np.mean(finite_emds) if finite_emds else float('inf')

    return emd_overall, emd_per_mode, emd_mean

def save_results_to_txt(counts, ratios, emd_overall, emd_per_mode, emd_mean, filename, save_dir):
    """Write the per-mode labeling counts and EMD metrics to a text report.

    counts:       (8,) number of generated samples assigned to each mode
    ratios:       (8,) fraction of generated samples per mode
    emd_overall:  overall EMD
    emd_per_mode: (8,) per-mode EMD values (non-finite ones print as Inf/N.A.)
    emd_mean:     mean of the per-mode EMDs
    filename:     PNG file name; '.png' is replaced by '.txt' for the report
    save_dir:     output directory, created if missing; nothing is written
                  when it is empty/None
    """
    if not save_dir:
        return

    os.makedirs(save_dir, exist_ok=True)
    txt_path = os.path.join(save_dir, filename.replace('.png', '.txt'))

    def fmt_emd(value):
        # Non-finite values mean "not available".
        return f"{value:.6f}" if np.isfinite(value) else "Inf/N.A."

    wide_rule = "-" * 25 + "\n"
    narrow_rule = "-" * 15 + "\n"

    with open(txt_path, 'w', encoding='utf-8') as f:
        out = f.write

        # Labeling table
        out("=== LABELING RESULTS ===\n")
        out("Mode\tCount\tRatio\n")
        out(wide_rule)
        for k in range(8):
            out(f"{k}\t{counts[k]}\t{ratios[k]:.4f}\n")
        out(wide_rule)
        out(f"Total\t{counts.sum()}\t{ratios.sum():.4f}\n\n")

        # EMD section
        out("=== EMD METRICS ===\n")
        out(f"1) Overall EMD (Generated vs Training): {emd_overall:.6f}\n\n")

        out("2) Per-mode EMD (Generated vs Training):")
        out("\nMode\tEMD\n")
        out(narrow_rule)
        for k in range(8):
            out(f"{k}\t{fmt_emd(emd_per_mode[k])}\n")
        out(narrow_rule)

        out("\n3) Mean of per-mode EMDs: ")
        out(f"{fmt_emd(emd_mean)}\n")

    if FLAGS.verbose:
        print(f"Results saved: {txt_path}")

def save_visualization(X_base, Y_target, X_gen, Y_train, method_name, centers, tau2=None, alpha=None, save_dir=None):
    """Render base / target / generated scatter plots and save them with metrics.

    Three figures are drawn:
      * the base samples,
      * the target samples coloured by nearest center, each mode annotated
        with its share in percent,
      * the generated samples coloured the same way, each mode annotated with
        the generated share divided by the target share ("x1.00" = matched).

    When ``FLAGS.save_plots`` is set and ``save_dir`` is non-empty, the
    figures are written as ``<prefix>_{base,target,generated}.png``. EMD
    metrics against ``Y_train`` are then computed and written next to them
    as text. ``<prefix>`` is the lower-cased method name with underscores,
    plus the tau2/alpha values when both are given. All figures are closed
    at the end.
    """
    prefix = method_name.lower().replace(' ', '_')
    if tau2 is not None and alpha is not None:
        prefix = f"{prefix}_tau2_{tau2}_alpha_{alpha}"

    def new_axes():
        return plt.subplots(1, 1, figsize=(4.8, 4.5))

    def hide_ticks(ax):
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticks([])

    def annotate_modes(ax, text_for):
        # Place the label below the center, or above it for centers at the
        # very bottom of the plot.
        for k, cent in enumerate(centers):
            dy = -0.8 if cent[1] > -3 else 1.0
            ax.text(cent[0], cent[1]+dy, text_for(k), va='top', ha='center',
                    fontsize=16, bbox=dict(fc='white', ec='none', alpha=0.8))

    # 1) Base distribution
    fig_base, ax_base = new_axes()
    hide_ticks(ax_base)
    ax_base.scatter(X_base[:,0], X_base[:,1], s=4, alpha=0.5, label="Base N(0,I)")
    ax_base.axis('equal')
    fig_base.tight_layout()

    # 2) Target distribution with per-mode share
    target_labels = label_by_nearest_centers(Y_target, centers)
    target_colors = plt.cm.tab10(target_labels)
    fig_tar, ax_tar = new_axes()
    ax_tar.scatter(Y_target[:,0], Y_target[:,1], s=5, alpha=0.8, c=target_colors, label="Target (train)")
    ax_tar.axis('equal')
    hide_ticks(ax_tar)
    target_counts = np.bincount(target_labels, minlength=8)
    target_ratios = target_counts / len(Y_target)
    annotate_modes(ax_tar, lambda k: f"{target_ratios[k]*100:.2f}%")
    fig_tar.tight_layout()

    # 3) Generated distribution with share relative to the target
    labels = label_by_nearest_centers(X_gen, centers)
    colors = plt.cm.tab10(labels)
    fig_gen, ax_gen = new_axes()
    ax_gen.scatter(X_gen[:,0], X_gen[:,1], s=4, alpha=0.6, c=colors, label=f"Generated ({method_name})")
    ax_gen.axis('equal')
    hide_ticks(ax_gen)

    counts = np.bincount(labels, minlength=8)
    ratios = counts / len(X_gen)
    annotate_modes(ax_gen, lambda k: f"x{(ratios[k]/target_ratios[k]):.2f}")
    fig_gen.tight_layout()

    figures = (fig_base, fig_tar, fig_gen)

    # Save plots and metrics
    if FLAGS.save_plots and save_dir:
        os.makedirs(save_dir, exist_ok=True)
        paths = [
            os.path.join(save_dir, f"{prefix}_{suffix}.png")
            for suffix in ("base", "target", "generated")
        ]
        for fig, path in zip(figures, paths):
            fig.savefig(path, dpi=150, bbox_inches='tight')

        if FLAGS.verbose:
            for path in paths:
                print(f"Plot saved: {path}")

        # The EMD report is written alongside the generated-samples plot.
        if FLAGS.verbose:
            print(f"Calculating EMD metrics for {method_name}...")
        metrics = calculate_emd_metrics(X_gen, Y_train, centers)
        save_results_to_txt(counts, ratios, *metrics, f"{prefix}_generated.png", save_dir)

    # Release figure memory.
    for fig in figures:
        plt.close(fig)

def run_experiment(experiment_type, Y_train, centers, save_dir):
    """Train, sample and evaluate one method: "fm", "ot" or "uot".

    A fresh ``VecField`` and Adam optimizer are built for every run. After
    training, as many samples as there are training points are generated,
    per-mode counts are reported, and plots/metrics are saved when
    ``FLAGS.save_plots`` is set. For "uot" one run is made for every
    (tau2, weight_power) combination from the flags. Any other
    ``experiment_type`` does nothing.
    """
    Y_train_t = torch.from_numpy(Y_train).to(device)

    # Generate as many samples as the target set contains.
    n_samples_plot = len(Y_train)

    def fit_and_report(train_fn, extra_train_args, header, method_name, **viz_kwargs):
        # Shared pipeline: build -> train -> sample -> count -> visualise.
        vnet = VecField(width=FLAGS.width).to(device)
        optimizer = optim.Adam(vnet.parameters(), lr=FLAGS.lr)
        train_fn(Y_train_t, vnet, optimizer, FLAGS.epochs, *extra_train_args)

        X_gen = sample_from_model(vnet, n_samples=n_samples_plot, steps=FLAGS.steps)
        X_base = sample_base(n_samples_plot)

        if FLAGS.verbose:
            print(header)
            print(f"Generated {len(X_gen)} samples (matching target size: {len(Y_train)})")
        cluster_and_count(X_gen, centers)

        if FLAGS.save_plots:
            save_visualization(X_base, Y_train, X_gen, Y_train, method_name, centers,
                               save_dir=save_dir, **viz_kwargs)

    if experiment_type == "fm":
        fit_and_report(train_fm, (), "\n=== Flow Matching Results ===", "Flow Matching")
        return

    if experiment_type == "ot":
        fit_and_report(train_ot, (), "\n=== OT Flow Matching Results ===", "OT Flow Matching")
        return

    if experiment_type == "uot":
        # Parse both sweeps up front so a malformed value fails before training.
        tau2_list = [float(x) for x in FLAGS.tau2]
        alpha_list = [float(x) for x in FLAGS.weight_power]

        for tau2, alpha in [(t, a) for t in tau2_list for a in alpha_list]:
            if FLAGS.verbose:
                print(f"\n=== UOT Flow Matching (τ2={tau2}, α={alpha}) ===")
            fit_and_report(
                train_uot, (tau2, alpha),
                f"\nResults for τ2={tau2}, α={alpha}:",
                "UOT Flow Matching",
                tau2=tau2, alpha=alpha,
            )

def main(argv):
    """Entry point: seed, build the 8-GMM target data, run the selected experiment(s)."""
    del argv  # Unused

    # Reproducibility
    set_seed(FLAGS.seed)

    if FLAGS.verbose:
        print(f"Using device: {device}")
        print(f"Experiment type: {FLAGS.experiment}")

    # Target data: no train/test split, the whole set is used for training.
    counts = list(map(int, FLAGS.counts))
    X_tar, _mode_ids, centers = make_8gmm_counts_simple(
        counts=counts,
        radius=FLAGS.radius,
        sigma=FLAGS.sigma,
        seed=FLAGS.seed,
    )

    if FLAGS.verbose:
        print(f"Generated data: {len(X_tar)} samples")
        print(f"Target distribution: {counts}")

    # Output directory: explicit flag, or a name derived from the experiment.
    save_dir = FLAGS.results_dir or f"results_8gmm_{FLAGS.experiment}"

    # "all" expands to the three methods in a fixed order.
    selected = ("fm", "ot", "uot") if FLAGS.experiment == "all" else (FLAGS.experiment,)
    for name in selected:
        run_experiment(name, X_tar, centers, save_dir)

    if FLAGS.verbose:
        print(f"\nExperiment completed! Results saved in: {save_dir}")

# Script entry point: absl parses the command-line flags, then calls main().
if __name__ == "__main__":
    app.run(main)