import warnings

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from scipy.spatial import Voronoi

# ----------------------------
#  Generate GMM dataset
# ----------------------------
def generate_gmm_data(N=2000, K=4, seed=42):
    """Sample N points from an equally weighted, unit-variance 2-D Gaussian mixture.

    Component ``k`` (0-based, ``k < K``) is centred at ``(pi * (k + 1), 0)``.
    Points drawn from even-numbered components (k = 0, 2, ...) are labelled 1,
    points from odd-numbered components are labelled 0.

    Note: this re-seeds NumPy's *global* random state with ``seed``.

    Returns
    -------
    x_samples : ndarray, shape (N, 2)
        Sampled points.
    y_samples : ndarray, shape (N,)
        Integer class label (0/1) of each point.
    means : ndarray, shape (K, 2)
        Component centres.
    """
    np.random.seed(seed)
    weights = [1 / K] * K
    components = np.random.choice(K, size=N, p=weights)
    centres = [(np.pi * (k + 1), 0) for k in range(K)]

    # One draw per point, in sample order, so the random stream is consumed
    # point by point (keeps results reproducible for a given seed).
    points = [
        np.random.normal(loc=centres[k], scale=1.0, size=(2,))
        for k in components
    ]
    labels = [int(k % 2 == 0) for k in components]

    return np.array(points), np.array(labels), np.array(centres)

# Dataset shared by every panel below (seed spelled out for reproducibility).
X, y, means = generate_gmm_data(N=2000, K=4, seed=42)

# ----------------------------
#  Partition T1: grid midpoints
# ----------------------------
def partition_T1_range(X):
    """Split the bounding box of ``X`` into a 2x2 grid of equal cells.

    Parameters
    ----------
    X : ndarray, shape (n, 2)
        Data points; only their min/max per axis are used.

    Returns
    -------
    x_mid, y_mid : float
        Positions of the vertical and horizontal split lines (box midpoints).
    centroids : ndarray, shape (4, 2)
        Centres of the four cells, ordered bottom-left, bottom-right,
        top-left, top-right.
    """
    lo_x, hi_x = X[:, 0].min(), X[:, 0].max()
    lo_y, hi_y = X[:, 1].min(), X[:, 1].max()

    x_mid = 0.5 * (lo_x + hi_x)
    y_mid = 0.5 * (lo_y + hi_y)

    left, right = (lo_x + x_mid) / 2, (x_mid + hi_x) / 2
    bottom, top = (lo_y + y_mid) / 2, (y_mid + hi_y) / 2

    centroids = np.array([
        [left, bottom],
        [right, bottom],
        [left, top],
        [right, top],
    ])
    return x_mid, y_mid, centroids

# ----------------------------
#  Partition T2: random spread centroids
# ----------------------------
def partition_T2_spread(X, K=4, seed=0, min_dist=2.0, max_attempts=10000):
    """Draw up to ``K`` random centroids inside the bounding box of ``X``.

    Rejection sampling: candidates are drawn uniformly from the axis-aligned
    bounding box of ``X`` and kept only if they lie at least ``min_dist`` from
    every centroid accepted so far.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data points defining the sampling box.
    K : int
        Number of centroids wanted.
    seed : int
        Passed to ``np.random.seed``; this re-seeds NumPy's global RNG.
    min_dist : float
        Minimum pairwise Euclidean distance between accepted centroids.
    max_attempts : int
        Total number of candidates tried before giving up.

    Returns
    -------
    ndarray, shape (m, d)
        Accepted centroids, ``m <= K``. ``m < K`` only when ``max_attempts``
        ran out, in which case a ``RuntimeWarning`` is emitted.
    """
    np.random.seed(seed)
    # The sampling box does not change between attempts.
    low, high = X.min(axis=0), X.max(axis=0)

    centroids = []
    attempts = 0
    while len(centroids) < K and attempts < max_attempts:
        candidate = np.random.uniform(low=low, high=high)
        if all(np.linalg.norm(candidate - c) >= min_dist for c in centroids):
            centroids.append(candidate)
        attempts += 1

    if len(centroids) < K:
        # Surface the shortfall here instead of letting a later consumer
        # (e.g. Voronoi) fail on too few points with an unrelated error.
        warnings.warn(
            f"partition_T2_spread: only {len(centroids)} of {K} centroids "
            f"found with min_dist={min_dist} after {max_attempts} attempts",
            RuntimeWarning,
            stacklevel=2,
        )

    # reshape keeps the result 2-D, (0, d) when nothing was accepted.
    return np.array(centroids).reshape(-1, X.shape[1])

# ----------------------------
#  Partition T3: GMM means
# ----------------------------
def partition_T3(means):
    """Use the GMM component means directly as the partition centroids.

    The argument is handed back unchanged (same object, no copy).
    """
    centroids = means
    return centroids

# ----------------------------
#  Voronoi helper (dashed lines)
# ----------------------------
def plot_voronoi_with_dashes(vor, ax, bound=None):
    """Draw every ridge of a 2-D Voronoi diagram on ``ax`` as a dashed black line.

    Finite ridges are drawn between their two Voronoi vertices. An unbounded
    ridge is drawn as a segment of length ``bound`` that starts at its finite
    vertex and runs outwards along the perpendicular bisector of its two
    generating points (the construction used by
    ``scipy.spatial.voronoi_plot_2d``).

    Parameters
    ----------
    vor : scipy.spatial.Voronoi
        Voronoi diagram of 2-D points.
    ax : matplotlib.axes.Axes
        Axes the ridges are added to, as a single LineCollection.
    bound : float, optional
        Length of the segment drawn for each unbounded ridge. Defaults to the
        largest of the current x/y view spans of ``ax`` and the x/y spans of
        the generating points. Pass a larger value if the axes limits will be
        widened afterwards.
    """
    points = vor.points
    center = points.mean(axis=0)

    if bound is None:
        # Size the rays from the diagram and the current view, so the helper
        # does not depend on any module-level data.
        x_lo, x_hi = ax.get_xlim()
        y_lo, y_hi = ax.get_ylim()
        point_span = points.max(axis=0) - points.min(axis=0)
        bound = max(x_hi - x_lo, y_hi - y_lo, point_span.max())

    segments = []
    for (i, j), simplex in zip(vor.ridge_points, vor.ridge_vertices):
        simplex = np.asarray(simplex)
        if np.all(simplex >= 0):
            segments.append(vor.vertices[simplex])
            continue

        # Unbounded ridge: it lies on the perpendicular bisector of points i, j.
        t = points[j] - points[i]
        t = t / np.linalg.norm(t)
        n = np.array([-t[1], t[0]])
        midpoint = points[[i, j]].mean(axis=0)

        # Orient the bisector away from the point cloud; an unoriented normal
        # can send the ray inward, through neighbouring cells.
        side = np.sign(np.dot(midpoint - center, n))
        direction = (side if side != 0 else 1.0) * n
        if vor.furthest_site:
            direction = -direction

        finite = simplex[simplex >= 0]
        if finite.size:
            start = vor.vertices[finite[0]]
            segments.append([start, start + direction * bound])
        else:
            # No finite vertex at all: the ridge is the whole bisector line.
            segments.append([midpoint - n * bound, midpoint + n * bound])

    lc = LineCollection(segments, colors="black", linewidths=1.2,
                        linestyles=(0, (5, 5)))  # dashed pattern: 5 on, 5 off
    ax.add_collection(lc)

# ----------------------------
#  Centroid plotting helper
# ----------------------------
def plot_centroids(ax, centroids):
    """Mark centroids on ``ax`` as filled black circles with a white "x" on top.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on. Anything with a compatible ``scatter`` method works.
    centroids : array-like, shape (k, 2)
        Centroid coordinates. Lists and tuples are accepted as well as arrays.
    """
    centroids = np.asarray(centroids)
    xs, ys = centroids[:, 0], centroids[:, 1]
    # Filled black disc (face and edge both black). zorder 5 lifts it above
    # default-zorder artists such as the data scatter.
    ax.scatter(
        xs, ys,
        s=180, facecolors="black", edgecolors="black",
        marker="o", linewidths=1, zorder=5
    )
    # White "x" drawn over the disc (higher zorder).
    ax.scatter(
        xs, ys,
        s=72, c="white", marker="x", linewidths=2, zorder=6
    )

# ----------------------------
#  Plotting
# ----------------------------
# Plot window: data range padded by 1 in x and by 3 in y.
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 3, X[:, 1].max() + 3

fig, axes = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)
ax_T1, ax_T2, ax_T3 = axes

# Shared styling: translucent class-coloured data points, dashed black splits.
data_kw = dict(c=y, cmap="bwr", alpha=0.3, s=10)
split_kw = dict(color="black", linestyle="--", linewidth=1.5)

# T1: 2x2 grid through the midpoints of the data range
ax_T1.scatter(X[:, 0], X[:, 1], **data_kw)
x_mid, y_mid, centroids_T1 = partition_T1_range(X)
ax_T1.axvline(x_mid, **split_kw)
ax_T1.axhline(y_mid, **split_kw)
plot_centroids(ax_T1, centroids_T1)

# T2: Voronoi cells of randomly spread centroids
centroids_T2 = partition_T2_spread(X, K=4, seed=13, min_dist=3.0)
ax_T2.scatter(X[:, 0], X[:, 1], **data_kw)
plot_centroids(ax_T2, centroids_T2)
vor = Voronoi(centroids_T2)
plot_voronoi_with_dashes(vor, ax_T2)

# T3: GMM means, split halfway between consecutive means along x
centroids_T3 = partition_T3(means)
ax_T3.scatter(X[:, 0], X[:, 1], **data_kw)
plot_centroids(ax_T3, centroids_T3)
boundaries = [
    (left + right) / 2
    for left, right in zip(centroids_T3[:-1, 0], centroids_T3[1:, 0])
]
for b in boundaries:
    ax_T3.axvline(b, **split_kw)

# Mathtext axis labels, common limits and equal aspect on every panel
for ax in axes:
    ax.set_xlabel(r"$x_1$", fontsize=24, labelpad=10)
    ax.set_ylabel(r"$x_2$", fontsize=24, labelpad=10)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_aspect("equal", adjustable="box")
    ax.grid(False)

plt.tight_layout()
plt.savefig("clustering_partitions.pdf", dpi=300)
