{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "41b1c224-dc21-4068-b5b5-a74383742399",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JuMP\n",
    "using MosekTools\n",
    "using DynamicPolynomials\n",
    "using MultivariatePolynomials\n",
    "using TSSOS, GaussianMixtures\n",
    "using LinearAlgebra, Random, Plots, Distributions, IterTools, Combinatorics, CSV, Statistics, MLDatasets, DataFrames, Revise, Clustering, Distances, Colors\n",
    "includet(\"Functions_Mixtures.jl\")\n",
    "includet(\"SeparationExperiment.jl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fcaa5aa-054e-4b6d-a3e4-f9e69ff7fef7",
   "metadata": {},
   "source": [
    "### Locally necessary functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 273,
   "id": "fc884486-6345-4dc3-b08a-35f7d711d8bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "run_em (generic function with 1 method)"
      ]
     },
     "execution_count": 273,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function find_nearest_indices(X::Matrix{Float64}, centers::Matrix{Float64})\n",
    "    indices = Int[]\n",
    "    for c in eachrow(centers')\n",
    "        dists = [sum((X[i, :] .- c).^2) for i in 1:size(X, 1)]\n",
    "        push!(indices, argmin(dists))\n",
    "    end\n",
    "    return indices\n",
    "end\n",
    "    using StatsBase  # for countmap, combinations\n",
    "\n",
    "function adjusted_rand_index(labels_true::Vector{Int}, labels_pred::Vector{Int})\n",
    "    @assert length(labels_true) == length(labels_pred)\n",
    "\n",
    "    n = length(labels_true)\n",
    "    unique_true = sort(unique(labels_true))\n",
    "    unique_pred = sort(unique(labels_pred))\n",
    "    n_true = length(unique_true)\n",
    "    n_pred = length(unique_pred)\n",
    "\n",
    "    # Confusion matrix: n_ij\n",
    "    contingency = zeros(Int, n_true, n_pred)\n",
    "    for (l_true, l_pred) in zip(labels_true, labels_pred)\n",
    "        i = findfirst(isequal(l_true), unique_true)\n",
    "        j = findfirst(isequal(l_pred), unique_pred)\n",
    "        contingency[i, j] += 1\n",
    "    end\n",
    "\n",
    "    # Row and column sums (a_i, b_j)\n",
    "    a = sum(contingency, dims=2)  # true cluster sizes\n",
    "    b = sum(contingency, dims=1)  # predicted cluster sizes\n",
    "\n",
    "    # Helper: binomial(n, 2)\n",
    "    comb2(x) = x < 2 ? 0 : x * (x - 1) ÷ 2\n",
    "\n",
    "    # ∑_ij C(n_ij, 2)\n",
    "    index = sum(comb2(nij) for nij in contingency)\n",
    "\n",
    "    # ∑_i C(a_i, 2), ∑_j C(b_j, 2)\n",
    "    sum_ai = sum(comb2(ai) for ai in a)\n",
    "    sum_bj = sum(comb2(bj) for bj in b)\n",
    "\n",
    "    expected_index = sum_ai * sum_bj / comb2(n)\n",
    "    max_index = (sum_ai + sum_bj) / 2\n",
    "\n",
    "    # Adjusted Rand Index\n",
    "    return (index - expected_index) / (max_index - expected_index)\n",
    "end\n",
    "function align_labels(true_labels::Vector{Int}, pred_labels::Vector{Int})\n",
    "    classes = sort(unique(true_labels))\n",
    "    perms = collect(permutations(classes))\n",
    "\n",
    "    best_score = -1\n",
    "    best_aligned = similar(pred_labels)\n",
    "\n",
    "    for p in perms\n",
    "        mapping = Dict(classes[i] => p[i] for i in eachindex(classes))\n",
    "        aligned = [mapping[l] for l in pred_labels]\n",
    "        score = sum(aligned .== true_labels)\n",
    "\n",
    "        if score > best_score\n",
    "            best_score = score\n",
    "            best_aligned .= aligned\n",
    "        end\n",
    "    end\n",
    "\n",
    "    return best_aligned\n",
    "end\n",
    "\n",
    "using GaussianMixtures \n",
    "using DataFrames\n",
    "using DataFrames: groupby\n",
    "using StatsBase       \n",
    "using StatsPlots \n",
    "\n",
    "function random_means()\n",
    "  # e.g. sample k distinct data points\n",
    "  idx = sample(1:size(X,1), k; replace=false)\n",
    "  return X[idx, :]\n",
    "end\n",
    "\n",
    "function run_em(μ0; X, nIter=100, tol=1e-6, varfloor=0.0001)\n",
    "    # 1) materialize any Adjoint into a real matrix\n",
    "    μ0 = Matrix(μ0)\n",
    "    k, d = size(μ0)\n",
    "\n",
    "    # 2) compute data‐based diag variances\n",
    "    σ2 = vec(var(X; dims=1))              # length-d Vector\n",
    "    Σ0 = repeat(σ2', k, 1)                # k×d Matrix\n",
    "\n",
    "    # 3) build GMM (skip k‑means by nInit=0)\n",
    "    g = GMM(k, X; kind=:diag, nInit=0)\n",
    "\n",
    "    # 4) overwrite init\n",
    "    g.μ .= μ0\n",
    "    g.Σ .= Σ0\n",
    "    g.w .= fill(1/k, k)\n",
    "\n",
    "    # 5) run EM up to nIter, with varfloor control\n",
    "    logl = em!(g, X; nIter=nIter, varfloor=varfloor)\n",
    "\n",
    "    # 6) find when |Δℓ|<tol\n",
    "    idx = findfirst(i -> abs(logl[i] - logl[i-1]) < tol, 2:length(logl))\n",
    "    iters = idx === nothing ? length(logl) : idx\n",
    "\n",
    "    return (\n",
    "      iterations = iters,\n",
    "      final_ll   = logl[end],\n",
    "      logl       = logl,\n",
    "      model      = g,\n",
    "    )\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de1a67b5-1099-469b-bcd9-1f284c316821",
   "metadata": {},
   "source": [
    "## K=5, relatively well separated, non spherical"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e8c82fb-4335-42c3-ac36-e82a11da5155",
   "metadata": {},
   "source": [
    "#### Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "id": "9decdcd1-3b57-4d7b-ab31-91e5fe3cc604",
   "metadata": {},
   "outputs": [],
   "source": [
    "c = 5.0\n",
    "ecc= 0.25\n",
    "K=5\n",
    "n=2\n",
    "nb_parameter_choices = 50\n",
    "seed_parameters=1\n",
    "Random.seed!(seed_parameters)\n",
    "gmms_50_025_5 = generate_multiple_gmms_heteroscedastic(nb_parameter_choices, K, n; ecc=ecc, c=c);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "5e80f24c-78b5-4cca-9e52-ad2ce14a65d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1000\n",
    "seed_data = 10\n",
    "Random.seed!(seed_data)\n",
    "\n",
    "all_data = Vector{Matrix{Float64}}(undef, length(gmms_50_025_5))\n",
    "all_labels = Vector{Vector{Int}}(undef, length(gmms_50_025_5))\n",
    "\n",
    "for mix_index in 1:length(gmms_50_025_5)\n",
    "    mix = gmms_50_025_5[mix_index]\n",
    "\n",
    "    # Convert means to Vector of Vectors\n",
    "    means = [mix.means[:, i] for i in 1:size(mix.means, 2)]\n",
    "    covariances = mix.covariances\n",
    "    weights = mix.weights\n",
    "    k = length(means)\n",
    "\n",
    "    # Unique seed per configuration (optional)\n",
    "    seed_i = seed_data + mix_index\n",
    "\n",
    "    # Generate data\n",
    "    samples, labels = generate_gaussian_mixtures(\n",
    "        k, means, covariances, weights;\n",
    "        seed=seed_i, n_samples=N\n",
    "    )\n",
    "\n",
    "    all_data[mix_index] = samples'\n",
    "    all_labels[mix_index] = labels\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "id": "dd3723ed-fde9-40a7-a81c-f6625784a289",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n_configs = length(all_data)\n",
    "n_rows = 10\n",
    "n_cols = 5\n",
    "\n",
    "plot_list = []\n",
    "\n",
    "for i in 1:n_configs\n",
    "    samples = all_data[i]'\n",
    "    labels = all_labels[i]\n",
    "    \n",
    "    # Normalize\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "    \n",
    "    # Optional: compute empirical mean per label\n",
    "    unique_labels = sort(unique(labels))\n",
    "    means_per_label = [mean(samples_scaled[labels .== l, :], dims=1) for l in unique_labels]\n",
    "\n",
    "    # Scatter plot\n",
    "    p = scatter(samples_scaled[:, 1], samples_scaled[:, 2],\n",
    "                group=labels,\n",
    "                markersize=2, alpha=0.6, legend=false,\n",
    "                #xlabel=\"x₁\", ylabel=\"x₂\",\n",
    "                title=\"Mixture $i\", xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))\n",
    "\n",
    "    # Overlay cluster means\n",
    "    for m in means_per_label\n",
    "        scatter!(p, [m[1]], [m[2]], color=:yellow, marker=:circle, ms=6)\n",
    "    end\n",
    "\n",
    "    push!(plot_list, p)\n",
    "end\n",
    "\n",
    "# Display grid layout\n",
    "#plot(plot_list..., layout=(n_rows, n_cols), size=(1400, 2400))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02e1737d-e6f8-4a12-ada2-cf7ec58e475b",
   "metadata": {},
   "source": [
    "#### $S_{\\theta}$ description"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "920693a1-9f2b-4a1f-93d5-92bb79c8c04b",
   "metadata": {},
   "outputs": [],
   "source": [
    "n=2\n",
    "@polyvar m[1:n]\n",
    "@polyvar sigma[1:n]\n",
    "Sm=[(m[1])*(1-m[1]), (m[2])*(1-m[2])]\n",
    "Ssig=[(sigma[1]-0.05)*(0.5-sigma[1]), (sigma[2]-0.05)*(0.5-sigma[2])]\n",
    "S=[vcat(Sm)...,vcat(Ssig)...]\n",
    "println()\n",
    "println(\"Support of the mixing measure\")\n",
    "S_normalized=[S[i]/maximum(abs.(coefficients(S[i]))) for i=1:length(S)]\n",
    "display(S_normalized)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d6efd24-aaf7-462e-afa8-bab8592991ce",
   "metadata": {},
   "source": [
    "#### W2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e03cee-17c8-4776-ad4d-0c5e794a0725",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trace_penalization=true\n",
    "vareps=1e-3\n",
    "max_order = 4\n",
    "\n",
    "\n",
    "all_results_NS = Vector{Vector}(undef, length(all_data))\n",
    "\n",
    "for idx in 1:length(all_data)\n",
    "    println(\"\\n>>> Running on GMM config $idx\")\n",
    "\n",
    "    samples = Matrix(transpose(all_data[idx]))\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "\n",
    "    res = []\n",
    "\n",
    "    for d = 1:max_order\n",
    "        println(\"  d = $d\")\n",
    "        push!(res, multivariate_Gaussian_W2(n, d, m, sigma, S_normalized, samples_scaled, trace_penalization, vareps))\n",
    "    end\n",
    "\n",
    "    all_results_NS[idx] = res\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6665a42d-e8f6-4161-8f36-04e75488df32",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "d=max_order\n",
    "RES_NS =[]\n",
    "energy_tol=1e-4\n",
    "\n",
    "for i in  1:length(all_data)\n",
    "    println(\">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flatness check for mixture $i\")\n",
    "    push!(RES_NS,analyse_relaxations_W2orTV(all_results_NS[i], d,n,energy_tol))\n",
    "    println(\"___________________________________________________________________\")\n",
    "    println()\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e15c4ee8-506c-49d3-8a4b-b10bcaa72ab8",
   "metadata": {},
   "source": [
    "#### TV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c24fe408-1336-48f3-adfa-d8157e9f2418",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trace_penalization=true\n",
    "vareps=1e-3\n",
    "max_order = 4\n",
    "\n",
    "all_resultsTV_NS = Vector{Vector}(undef, length(all_data))\n",
    "\n",
    "for idx in 1:length(all_data)\n",
    "    println(\"\\n>>> Running on GMM config $idx\")\n",
    "\n",
    "    samples = Matrix(transpose(all_data[idx]))\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "\n",
    "    resTV = []\n",
    "\n",
    "    for d = 1:max_order\n",
    "        println(\"  d = $d\")\n",
    "        push!(resTV, multivariate_Gaussian_TV(n, d, m, sigma, S_normalized, samples_scaled, trace_penalization, vareps))\n",
    "    end\n",
    "\n",
    "    all_resultsTV_NS[idx] = resTV\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a89e04-f1f6-487c-be0c-b9d3ae500f5c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "d=max_order\n",
    "RESTV_NS =[]\n",
    "energy_tol=1e-3\n",
    "for i in  1:length(all_data)\n",
    "    println(\">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flatness check for mixture $i\")\n",
    "    push!(RESTV_NS,analyse_relaxations_W2orTV(all_resultsTV_NS[i], d,n,energy_tol))\n",
    "    println(\"___________________________________________________________________\")\n",
    "    println()\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5706715a-2dc5-43ab-9fc7-96207bed4a9f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "08f14ddc-1caf-4d3c-a32b-4dc56390ab44",
   "metadata": {},
   "source": [
    "#### Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76097ddc-2a8c-47ca-bb85-2aeb3574fea6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "rank_mm = 5\n",
    "d = max_order  # final relaxation order you used earlier\n",
    "\n",
    "plot_list = []\n",
    "\n",
    "CF_points_NS=[]\n",
    "true_means=[]\n",
    "for i in 1:length(all_data)\n",
    "    \n",
    "    # Get samples and labels\n",
    "    samples = all_data[i]' |> Matrix  # now shape is 1000 × 2\n",
    "    labels = all_labels[i]\n",
    "\n",
    "    # Extract RES for this config\n",
    "    M = RES_NS[i][1]  # moment matrix\n",
    "    L = RES_NS[i][end]  # last submatrix (or adjust as needed)\n",
    "\n",
    "    # Curto-Fialkow flat extension extraction\n",
    "    curto_f_points_NS = extract_CF(M, L, binomial(n + d - 1, d - 1), n, rank_mm)\n",
    "    sorted_cf_NS = sort(curto_f_points_NS, by = p -> p[1])\n",
    "    CF_NS = hcat(sorted_cf_NS...)  # 2 × r matrix\n",
    "    push!(CF_points_NS,CF_NS)\n",
    "\n",
    "    # Normalize samples\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "\n",
    "    # Empirical means from true labels\n",
    "    cluster_means = [vec(mean(samples_scaled[labels .== k, :], dims=1)) for k in sort(unique(labels))]\n",
    "    sorted_means = sort(cluster_means, by = m -> m[1])\n",
    "    trum = hcat(sorted_means...)\n",
    "    push!(true_means, trum)\n",
    "\n",
    "    \n",
    "\n",
    "    # Plotting\n",
    "    p = scatter(samples_scaled[:,1], samples_scaled[:,2],\n",
    "                group=labels, markersize=2, alpha=0.6, legend=false,\n",
    "                title=\"Mixture $i in ℝ²\",\n",
    "                #xlabel=\"X₁\", ylabel=\"X₂\", \n",
    "                xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))\n",
    "    \n",
    "    scatter!(p, trum[1, :], trum[2, :], marker=(:o, 8), color=:yellow, label=\"True Means\")\n",
    "    scatter!(p, CF_NS[1, :], CF_NS[2, :], marker=(:s, 6), color=:white, label=\"CDK Means\")\n",
    "\n",
    "    push!(plot_list, p)\n",
    "end\n",
    "\n",
    "#plot(plot_list..., layout=(10, 5), size=(1400, 2400))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 241,
   "id": "9d45f883-7265-4303-bdcd-f3d1f109661e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "rank_mm = 5\n",
    "d = max_order  # final relaxation order you used earlier\n",
    "\n",
    "plot_listTV_NS = []\n",
    "\n",
    "CF_pointsTV_NS=[]\n",
    "true_means=[]\n",
    "for i in 1:length(all_data)\n",
    "    \n",
    "    # Get samples and labels\n",
    "    samples = all_data[i]' |> Matrix  # now shape is 1000 × 2\n",
    "    labels = all_labels[i]\n",
    "\n",
    "    # Extract RES for this config\n",
    "    M = RESTV_NS[i][1]  # moment matrix\n",
    "    L = RESTV_NS[i][end]  # last submatrix (or adjust as needed)\n",
    "\n",
    "    # Curto-Fialkow flat extension extraction\n",
    "    curto_f_pointsTV_NS = extract_CF(M, L, binomial(n + d - 1, d - 1), n, rank_mm)\n",
    "    sorted_cfTV_NS = sort(curto_f_pointsTV_NS, by = p -> p[1])\n",
    "    CFTV_NS = hcat(sorted_cfTV_NS...)  # 2 × r matrix\n",
    "    push!(CF_pointsTV_NS,CFTV_NS)\n",
    "\n",
    "    # Normalize samples\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "\n",
    "    # Empirical means from true labels\n",
    "    cluster_means = [vec(mean(samples_scaled[labels .== k, :], dims=1)) for k in sort(unique(labels))]\n",
    "    sorted_means = sort(cluster_means, by = m -> m[1])\n",
    "    trum = hcat(sorted_means...)\n",
    "    push!(true_means, trum)\n",
    "\n",
    "    \n",
    "\n",
    "    # Plotting\n",
    "    p = scatter(samples_scaled[:,1], samples_scaled[:,2],\n",
    "                group=labels, markersize=2, alpha=0.6, legend=false,\n",
    "                title=\"Mixture $i in ℝ²\",\n",
    "                #xlabel=\"X₁\", ylabel=\"X₂\", \n",
    "                xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))\n",
    "    \n",
    "    scatter!(p, trum[1, :], trum[2, :], marker=(:o, 8), color=:yellow, label=\"True Means\")\n",
    "    scatter!(p, CFTV_NS[1, :], CFTV_NS[2, :], marker=(:s, 6), color=:white, label=\"CDK Means\")\n",
    "\n",
    "    push!(plot_listTV_NS, p)\n",
    "end\n",
    "\n",
    "#plot(plot_list..., layout=(10, 5), size=(1400, 2400))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55e4d3b4-5744-49a3-b8df-c39344a34d47",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "dd1abcac-6b73-42f5-bf15-85be88c9e90e",
   "metadata": {},
   "source": [
    "#### Impact on $k$-means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "id": "ff9c4f50-34b2-479b-9e5f-ce07081c1a57",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_NS = []\n",
    "Random.seed!(123)\n",
    "\n",
    "for i in 1:nb_parameter_choices\n",
    "    # Get mixture data & labels\n",
    "    X = all_data[i]' |> Matrix\n",
    "    X_scaled = hcat(scale_to_minus1_1.(eachcol(X))...)  # Normalization\n",
    "    labels = all_labels[i]\n",
    "    k = length(unique(labels))\n",
    "    \n",
    "    # --- CF Initialization ---\n",
    "    cf_centers_NS = CF_points_NS[i]\n",
    "    cf_centers_NS = size(cf_centers_NS, 1) == 2 ? cf_centers_NS : cf_centers_NS'\n",
    "    cf_indices_NS = find_nearest_indices(X_scaled, cf_centers_NS)\n",
    "    result_cf_NS = kmeans(X_scaled', k; init=cf_indices_NS, maxiter=100, display=:none)\n",
    "    ARI_cf_NS = adjusted_rand_index(labels, result_cf_NS.assignments)\n",
    "    obj_cf_NS = result_cf_NS.totalcost\n",
    "    iter_cf_NS = result_cf_NS.iterations\n",
    "    mis_cf_NS = sum(align_labels(labels, result_cf_NS.assignments) .!= labels)\n",
    "\n",
    "    # --- CFTV Initialization ---\n",
    "    cf_centersTV_NS = CF_pointsTV_NS[i]\n",
    "    cf_centersTV_NS = size(cf_centersTV_NS, 1) == 2 ? cf_centersTV_NS : cf_centersTV_NS'\n",
    "    cf_indicesTV_NS = find_nearest_indices(X_scaled, cf_centersTV_NS)\n",
    "    result_cfTV_NS = kmeans(X_scaled', k; init=cf_indicesTV_NS, maxiter=100, display=:none)\n",
    "    ARI_cfTV_NS = adjusted_rand_index(labels, result_cfTV_NS.assignments)\n",
    "    obj_cfTV_NS = result_cfTV_NS.totalcost\n",
    "    iter_cfTV_NS = result_cfTV_NS.iterations\n",
    "    mis_cfTV_NS = sum(align_labels(labels, result_cfTV_NS.assignments) .!= labels)\n",
    "\n",
    "    # --- Random Initialization (repeat N times) ---\n",
    "    N = 100\n",
    "    objs_rnd = Float64[]\n",
    "    iters_rnd = Int[]\n",
    "    ARIs_rnd = Float64[]\n",
    "    mis_rnd = Int[]\n",
    "    \n",
    "    for rep in 1:N\n",
    "        rand_indices = rand(1:size(X_scaled, 1), k)\n",
    "        result_rnd = kmeans(X_scaled', k; init=rand_indices, maxiter=100, display=:none)\n",
    "        ARI_rnd = adjusted_rand_index(labels, result_rnd.assignments)\n",
    "        push!(objs_rnd, result_rnd.totalcost)\n",
    "        push!(iters_rnd, result_rnd.iterations)\n",
    "        push!(ARIs_rnd, ARI_rnd)\n",
    "        push!(mis_rnd, sum(align_labels(labels, result_rnd.assignments) .!= labels))\n",
    "    end\n",
    "\n",
    "    # Store summary (including full list of iterations!)\n",
    "    push!(results_NS, (\n",
    "        i = i,\n",
    "        obj_cf_NS = obj_cf_NS,\n",
    "        iter_cf_NS = iter_cf_NS,\n",
    "        ARI_cf_NS = ARI_cf_NS,\n",
    "        mis_cf_NS = mis_cf_NS,\n",
    "        obj_cfTV_NS = obj_cfTV_NS ,\n",
    "        iter_cfTV_NS  = iter_cfTV_NS ,\n",
    "        ARI_cfTV_NS  = ARI_cfTV_NS ,\n",
    "        mis_cfTV_NS  = mis_cfTV_NS ,\n",
    "        objs_rnd = objs_rnd,\n",
    "        iters_rnd = iters_rnd,\n",
    "        ARIs_rnd = ARIs_rnd,\n",
    "        mis_rnd = mis_rnd,\n",
    "        mean_obj_rnd = mean(objs_rnd),\n",
    "        mean_iter_rnd = mean(iters_rnd),\n",
    "        mean_ARI_rnd = mean(ARIs_rnd),\n",
    "        mean_mis_rnd = mean(mis_rnd),\n",
    "        std_iter_rnd = std(iters_rnd)  # precompute if needed for plotting\n",
    "    ))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed870cf1-833b-4372-a05c-62d6af8fb03d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract data\n",
    "iter_cf_NS = [r.iter_cf_NS for r in results_NS]\n",
    "iter_cfTV_NS = [r.iter_cfTV_NS for r in results_NS]\n",
    "iter_rnd_mean = [mean(r.iters_rnd) for r in results_NS]\n",
    "iter_rnd_std = [std(r.iters_rnd) for r in results_NS]\n",
    "\n",
    "# Sort by CF iterations\n",
    "sort_idx_NS = sortperm(iter_cf_NS)\n",
    "iter_cf_sorted_NS = iter_cf_NS[sort_idx_NS]\n",
    "iter_rnd_mean_sorted = iter_rnd_mean[sort_idx_NS]\n",
    "iter_rnd_std_sorted = iter_rnd_std[sort_idx_NS]\n",
    "iter_cfTV_sorted_NS=iter_cfTV_NS[sort_idx_NS]\n",
    "\n",
    "mix_ids_sorted = 1:length(results_NS)  # new x-axis = sorted mixture index\n",
    "\n",
    "# Plot CF line\n",
    "plot(mix_ids_sorted, iter_cf_sorted_NS;\n",
    "    label=\"Curto-Fialkow-W2\",\n",
    "    lw=2, marker=:circle,\n",
    "    xlabel=\"Mixture (sorted by CF)\", ylabel=\"Iterations\",\n",
    "    #title=\"K-means Iterations: CF vs Random Initialization\",\n",
    "    legend=:topright,\n",
    "    size=(900, 600))\n",
    "\n",
    "plot!(mix_ids_sorted, iter_cfTV_sorted_NS;\n",
    "    label=\"Curto-Fialkow-TV\",\n",
    "    lw=2, marker=:diamond,\n",
    "    color=:red\n",
    "   )\n",
    "\n",
    "\n",
    "# Add Random line with ribbon\n",
    "plot!(mix_ids_sorted, iter_rnd_mean_sorted;\n",
    "    ribbon=iter_rnd_std_sorted,\n",
    "    label=\"Random ± 1 std\",\n",
    "    lw=2, marker=:square,\n",
    "   color=:orange,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecaa81b3-6d0d-4b04-8a6f-c75a07e134dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract misclassification data\n",
    "mis_cf_NS = [r.mis_cf_NS for r in results_NS]\n",
    "mis_cfTV_NS = [r.mis_cfTV_NS for r in results_NS]\n",
    "mis_rnd_mean = [mean(r.mis_rnd) for r in results_NS]\n",
    "mis_rnd_std = [std(r.mis_rnd) for r in results_NS]\n",
    "\n",
    "\n",
    "# Sort by CF misclassification\n",
    "sort_idx_NS = sortperm(mis_cf_NS)\n",
    "mis_cf_sorted_NS = mis_cf_NS[sort_idx_NS]\n",
    "mis_cfTV_sorted_NS = mis_cfTV_NS[sort_idx_NS]\n",
    "mis_rnd_mean_sorted = mis_rnd_mean[sort_idx_NS]\n",
    "mis_rnd_std_sorted = mis_rnd_std[sort_idx_NS]\n",
    "\n",
    "mix_ids_sorted = 1:length(results_NS)\n",
    "\n",
    "# Plot CF misclassifications\n",
    "plot(mix_ids_sorted, mis_cf_sorted_NS;\n",
    "    label=\"Curto_Fialkow\",\n",
    "    lw=2, marker=:circle,\n",
    "    xlabel=\"Mixture (sorted by CF)\", ylabel=\"Misclassifications\",\n",
    "    title=\"Misclassifications: CF vs Random Initialization\",\n",
    "    legend=:topright,\n",
    "    size=(900, 600))\n",
    "\n",
    "plot!(mix_ids_sorted, mis_cfTV_sorted_NS;\n",
    "    label=\"Curto-Fialkow-TV\",\n",
    "    lw=2, marker=:diamond,\n",
    "    color=:red\n",
    "   )\n",
    "\n",
    "\n",
    "# Add Random line with ribbon\n",
    "plot!(mix_ids_sorted, mis_rnd_mean_sorted;\n",
    "    ribbon=mis_rnd_std_sorted,\n",
    "    label=\"Random ± 1 std\",\n",
    "    lw=2, marker=:square,\n",
    "   color=:orange,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85650b78-f6d8-4536-9e29-53c3aa9ae75b",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Plots, LaTeXStrings, PGFPlotsX, Statistics\n",
    "pgfplotsx()  # TikZ backend\n",
    "\n",
    "# --- data (NS) ---\n",
    "iter_cf_NS        = [r.iter_cf_NS   for r in results_NS]\n",
    "iter_cfTV_NS      = [r.iter_cfTV_NS for r in results_NS]\n",
    "iter_rnd_mean_NS  = [mean(r.iters_rnd) for r in results_NS]\n",
    "iter_rnd_std_NS   = [std(r.iters_rnd)  for r in results_NS]\n",
    "\n",
    "# Sort by CF iterations (NS)\n",
    "sort_idx_NS              = sortperm(iter_cf_NS)\n",
    "iter_cf_sorted_NS        = iter_cf_NS[sort_idx_NS]\n",
    "iter_cfTV_sorted_NS      = iter_cfTV_NS[sort_idx_NS]\n",
    "iter_rnd_mean_sorted_NS  = iter_rnd_mean_NS[sort_idx_NS]\n",
    "iter_rnd_std_sorted_NS   = iter_rnd_std_NS[sort_idx_NS]\n",
    "mix_ids_sorted_NS        = 1:length(results_NS)\n",
    "\n",
    "# --- styling to match papers ---\n",
    "default(\n",
    "    size=(420,280),\n",
    "    grid=false,\n",
    "    framestyle=:box,\n",
    "    legend=:topright,\n",
    "    linewidth=2,\n",
    "    markerstrokewidth=0.8,\n",
    ")\n",
    "\n",
    "plt = plot(mix_ids_sorted_NS, iter_cf_sorted_NS;\n",
    "    label=L\"\\text{Curto–Fialkow–W2}\",\n",
    "    marker=:circle,\n",
    "    xlabel=L\"\\text{Mixture (sorted by CF)}\",\n",
    "    ylabel=L\"\\text{Iterations}\",\n",
    ")\n",
    "\n",
    "plot!(plt, mix_ids_sorted_NS, iter_cfTV_sorted_NS;\n",
    "    label=L\"\\text{Curto–Fialkow–TV}\",\n",
    "    marker=:diamond,\n",
    ")\n",
    "\n",
    "plot!(plt, mix_ids_sorted_NS, iter_rnd_mean_sorted_NS;\n",
    "    ribbon=iter_rnd_std_sorted_NS,\n",
    "    label=L\"\\text{Random} \\pm \\text{ 1 std}\",\n",
    "    marker=:square,\n",
    ")\n",
    "\n",
    "savefig(plt, \"iterations_K5kmeans.tikz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "939eccb8-83ca-47c1-82a3-50b30d7458b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Plots, LaTeXStrings, PGFPlotsX, Statistics\n",
    "pgfplotsx()  # TikZ backend\n",
    "\n",
    "# --- data (NS) ---\n",
    "mis_cf_NS        = [r.mis_cf_NS    for r in results_NS]\n",
    "mis_cfTV_NS      = [r.mis_cfTV_NS  for r in results_NS]\n",
    "mis_rnd_mean_NS  = [mean(r.mis_rnd) for r in results_NS]\n",
    "mis_rnd_std_NS   = [std(r.mis_rnd)  for r in results_NS]\n",
    "\n",
    "# Sort by CF misclassification (NS)\n",
    "sort_idx_NS               = sortperm(mis_cf_NS)\n",
    "mis_cf_sorted_NS          = mis_cf_NS[sort_idx_NS]\n",
    "mis_cfTV_sorted_NS        = mis_cfTV_NS[sort_idx_NS]\n",
    "mis_rnd_mean_sorted_NS    = mis_rnd_mean_NS[sort_idx_NS]\n",
    "mis_rnd_std_sorted_NS     = mis_rnd_std_NS[sort_idx_NS]\n",
    "mix_ids_sorted_NS         = 1:length(results_NS)\n",
    "\n",
    "# --- styling (same as previous figure) ---\n",
    "default(\n",
    "    size=(420,280),\n",
    "    grid=false,\n",
    "    framestyle=:box,\n",
    "    legend=:topright,\n",
    "    linewidth=2,\n",
    "    markerstrokewidth=0.8,\n",
    "    markersize=3,\n",
    ")\n",
    "\n",
    "plt = plot(mix_ids_sorted_NS, mis_cf_sorted_NS;\n",
    "    label=L\"\\text{Curto–Fialkow–W2}\",\n",
    "    marker=:circle,\n",
    "    xlabel=L\"\\text{Mixture (sorted by CF)}\",\n",
    "    ylabel=L\"\\text{Misclassifications}\",\n",
    ")\n",
    "\n",
    "plot!(plt, mix_ids_sorted_NS, mis_cfTV_sorted_NS;\n",
    "    label=L\"\\text{Curto–Fialkow–TV}\",\n",
    "    marker=:diamond,\n",
    ")\n",
    "\n",
    "plot!(plt, mix_ids_sorted_NS, mis_rnd_mean_sorted_NS;\n",
    "    ribbon=mis_rnd_std_sorted_NS,\n",
    "    label=L\"\\text{Random} \\pm \\text{ 1 std}\",\n",
    "    marker=:square,\n",
    ")\n",
    "\n",
    "savefig(plt, \"misclassifications_K5kmeans.tikz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62852e16-a524-46ff-b3a5-8e847a3c1df6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "831faa74-0a5a-4f5d-b707-703c45ac9f37",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3020db20-6fcd-4291-b171-31171517e8b2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3f34c1c-0034-4f70-abe1-15a8ff280eda",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1298f52f-96c5-4b95-96af-c51038c6d968",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "405745a3-b69f-4678-8c71-0b6a49a37fb4",
   "metadata": {},
   "source": [
    "#### EM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3b4922c-8598-4142-804d-a2c1c453e3c5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    " # for results and plotting\n",
    "\n",
    "Random.seed!(1)\n",
    "\n",
    "nrand = 100    # # of random restarts per dataset\n",
    "nIter = 100    # max EM iter\n",
    "tol   = 1e-6  # your convergence tol\n",
    "\n",
    "# prepare an empty results table\n",
    "results = DataFrame(\n",
    "  dataset    = Int[],       # which dataset index\n",
    "  init       = String[],    # \"CF\" or \"random\"\n",
    "  iterations = Int[],       \n",
    "  final_ll   = Float64[],\n",
    ")\n",
    "\n",
    "for i in eachindex(all_data)\n",
    "  # --- 1) pull + scale this dataset\n",
    "  X      = Matrix(all_data[i]')                  # N×2\n",
    "  Xs     = hcat(scale_to_minus1_1.(eachcol(X))...)\n",
    "  k      = size(CF_points_NS[i], 2) \n",
    "  μcf    = Matrix(CF_points_NS[i]')\n",
    "  μcf_TV = Matrix(CF_pointsTV_NS[i]')\n",
    "    \n",
    "    # 5×2\n",
    "\n",
    "  # --- 2) CF run\n",
    "  cf_res = run_em(μcf; X = Xs, nIter = nIter, tol = tol)\n",
    "  push!(results, (i, \"CF\",     cf_res.iterations, cf_res.final_ll))\n",
    "\n",
    "  # ------- CF_TV run\n",
    "  cf_res_TV = run_em(μcf_TV; X = Xs, nIter = nIter, tol = tol)\n",
    "  push!(results, (i, \"CF_TV\",     cf_res_TV.iterations, cf_res_TV.final_ll))\n",
    "\n",
    "  # --- 3) random runs\n",
    "  for rep in 1:nrand\n",
    "    # sample k distinct rows as init means\n",
    "    inds = sample(1:size(Xs,1), k; replace=false)\n",
    "    μrand = Xs[inds, :]                  # k×2\n",
    "    rr    = run_em(μrand; X = Xs, nIter = nIter, tol = tol)\n",
    "    push!(results, (i, \"random\", rr.iterations, rr.final_ll))\n",
    "  end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "7508c275-6dcb-4fbc-8fa3-923a3d0a9ab8",
   "metadata": {},
   "outputs": [
    {
     "data": {},
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per dataset: final log-likelihood of the CF (W2) and CF (TV) runs,\n",
    "# plus mean ± std over the random-init runs\n",
    "datasets = sort(unique(results.dataset))   # robust even if dataset indices are not 1:N\n",
    "N = length(datasets)\n",
    "\n",
    "cf_vals    = Float64[]\n",
    "cf_vals_TV = Float64[]\n",
    "rand_means = Float64[]\n",
    "rand_stds  = Float64[]\n",
    "\n",
    "for d in datasets\n",
    "    sub = results[results.dataset .== d, :]\n",
    "    push!(cf_vals,    first(sub.final_ll[sub.init .== \"CF\"]))\n",
    "    push!(cf_vals_TV, first(sub.final_ll[sub.init .== \"CF_TV\"]))\n",
    "    rands = sub.final_ll[sub.init .== \"random\"]\n",
    "    push!(rand_means, mean(rands))\n",
    "    push!(rand_stds,  std(rands))\n",
    "end\n",
    "\n",
    "# Sort every series by the CF (W2) fit; the x-axis is the rank 1:N after sorting\n",
    "order = sortperm(cf_vals)\n",
    "xs    = 1:N\n",
    "\n",
    "cf_sorted         = cf_vals[order]\n",
    "cf_sorted_TV      = cf_vals_TV[order]\n",
    "rand_means_sorted = rand_means[order]\n",
    "rand_stds_sorted  = rand_stds[order]\n",
    "\n",
    "plot(\n",
    "  xs, cf_sorted;\n",
    "  label   = \"Curto-Fialkow-W2\",\n",
    "  lw      = 2,\n",
    "  marker  = :circle,\n",
    "  xlabel  = \"Dataset (sorted by CF fit)\",\n",
    "  ylabel  = \"Final log‑likelihood\",\n",
    "  title   = \"CF vs random init across datasets\",\n",
    "  legend  = :bottomright,\n",
    "  size    = (900, 600),\n",
    ")\n",
    "\n",
    "plot!(\n",
    "  xs, cf_sorted_TV;\n",
    "  label  = \"Curto-Fialkow-TV\",\n",
    "  lw     = 2,\n",
    "  marker = :diamond,\n",
    "  color  = :red,\n",
    ")\n",
    "\n",
    "# only the random-init series carries a ±1 std ribbon\n",
    "plot!(\n",
    "  xs, rand_means_sorted;\n",
    "  ribbon    = rand_stds_sorted,\n",
    "  fillalpha = 0.2,\n",
    "  label     = \"Random ± 1 std\",\n",
    "  lw        = 2,\n",
    "  marker    = :square,   # distinct from the CF-TV diamonds\n",
    "  color     = :orange,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47956416-32eb-4641-b623-bbd6b742d460",
   "metadata": {},
   "outputs": [],
   "source": [
    "using DataFrames, Plots, LaTeXStrings, PGFPlotsX, Statistics\n",
    "pgfplotsx()  # TikZ backend\n",
    "\n",
    "# Unique datasets (robust even if not 1:N)\n",
    "datasets = sort(unique(results.dataset))\n",
    "\n",
    "cf_vals      = Float64[]\n",
    "cf_vals_TV   = Float64[]\n",
    "rand_means   = Float64[]\n",
    "rand_stds    = Float64[]\n",
    "\n",
    "for d in datasets\n",
    "    sub = results[results.dataset .== d, :]\n",
    "    push!(cf_vals,    first(sub.final_ll[sub.init .== \"CF\"]))\n",
    "    push!(cf_vals_TV, first(sub.final_ll[sub.init .== \"CF_TV\"]))\n",
    "    r = sub.final_ll[sub.init .== \"random\"]\n",
    "    push!(rand_means, mean(r))\n",
    "    push!(rand_stds,  std(r))\n",
    "end\n",
    "\n",
    "# Sort every series by the CF (W2) fit. The x coordinate must be the rank 1:n:\n",
    "# using the permuted dataset indices (collect(1:n)[order]) as x would put the\n",
    "# points back at their original positions and draw a zig-zag line.\n",
    "order                 = sortperm(cf_vals)\n",
    "xs_sorted             = 1:length(datasets)\n",
    "cf_sorted             = cf_vals[order]\n",
    "cf_sorted_TV          = cf_vals_TV[order]\n",
    "rand_means_sorted     = rand_means[order]\n",
    "rand_stds_sorted      = rand_stds[order]\n",
    "\n",
    "# Styling\n",
    "default(size=(420,280), grid=false, framestyle=:box, legend=:bottomright,\n",
    "        linewidth=2, markerstrokewidth=0.8, markersize=3)\n",
    "\n",
    "plt = plot(xs_sorted, cf_sorted;\n",
    "    label=L\"\\text{Curto–Fialkow–W2}\",\n",
    "    marker=:circle,\n",
    "    xlabel=L\"\\text{Dataset (sorted by CF fit)}\",\n",
    "    ylabel=L\"\\text{Final log-likelihood}\",\n",
    "    title=L\"\\text{CF vs.\\ random init across datasets}\",\n",
    ")\n",
    "\n",
    "plot!(plt, xs_sorted, cf_sorted_TV;\n",
    "    label=L\"\\text{Curto–Fialkow–TV}\",\n",
    "    marker=:diamond,\n",
    ")\n",
    "\n",
    "plot!(plt, xs_sorted, rand_means_sorted;\n",
    "    ribbon=rand_stds_sorted,\n",
    "    label=L\"\\text{Random} \\pm \\text{ 1 std}\",\n",
    "    marker=:square,\n",
    "    color=:orange,\n",
    ")\n",
    "\n",
    "# Show or save\n",
    "display(plt)                     # for REPL/VSCode/Jupyter\n",
    "# NOTE(review): a later TikZ-export cell writes the same file name and will overwrite this file\n",
    "savefig(plt, \"final_ll_NS.tikz\") # for LaTeX inclusion\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac4f1f6-6978-4931-adf9-c22455e18ddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "using DataFrames, Statistics, Plots\n",
    "gr()  # inline preview with the GR backend\n",
    "\n",
    "# Mean / std of the final log-likelihood for every (dataset, init) pair\n",
    "g = combine(groupby(results, [:dataset, :init]),\n",
    "            :final_ll => mean => :mean_ll,\n",
    "            :final_ll => std  => :std_ll)\n",
    "\n",
    "# One narrow table per initialisation, renamed so the three can be joined side by side\n",
    "cf   = rename(filter(:init => ==(\"CF\"), g)[!, [:dataset, :mean_ll]], :mean_ll => :cf)\n",
    "cfTV = rename(filter(:init => ==(\"CF_TV\"), g)[!, [:dataset, :mean_ll]], :mean_ll => :cfTV)\n",
    "rnd  = rename(filter(:init => ==(\"random\"), g)[!, [:dataset, :mean_ll, :std_ll]],\n",
    "              :mean_ll => :rnd_mean, :std_ll => :rnd_std)\n",
    "\n",
    "# Align on dataset, then order rows ascending by the CF (W2) curve\n",
    "T = sort!(innerjoin(innerjoin(cf, cfTV; on=:dataset), rnd; on=:dataset), :cf)\n",
    "\n",
    "xs = 1:nrow(T)   # rank after sorting\n",
    "y1, y2, ym, ys = T.cf, T.cfTV, T.rnd_mean, T.rnd_std\n",
    "\n",
    "default(size=(560,360), grid=false, framestyle=:box, legend=:bottomright,\n",
    "        linewidth=2, markerstrokewidth=0.8, markersize=3)\n",
    "\n",
    "# Connected lines with markers; only the random-init series carries the ±1 std ribbon\n",
    "plt = plot(xs, y1; label=\"Curto–Fialkow–W2\", marker=:circle, linestyle=:solid)\n",
    "plot!(plt, xs, y2; label=\"Curto–Fialkow–TV\", marker=:diamond, linestyle=:solid)\n",
    "plot!(plt, xs, ym; ribbon=ys, label=\"Random ± 1 std\",\n",
    "      marker=:square, linestyle=:solid, fillalpha=0.25)\n",
    "xlabel!(plt, \"Dataset (sorted by CF fit)\")\n",
    "ylabel!(plt, \"Final log-likelihood\")\n",
    "display(plt)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdcfa315-5595-4816-8aea-e15531137688",
   "metadata": {},
   "outputs": [],
   "source": [
    "using PGFPlotsX, LaTeXStrings\n",
    "pgfplotsx()\n",
    "# TikZ version of the log-likelihood figure; relies on xs, y1, y2, ym, ys from the GR cell above\n",
    "plt_tikz = plot(xs, y1; label=L\"\\text{Curto–Fialkow–W2}\", marker=:circle,\n",
    "                xlabel=L\"\\text{Dataset (sorted by CF fit)}\",\n",
    "                ylabel=L\"\\text{Final log-likelihood}\")\n",
    "plot!(plt_tikz, xs, y2; label=L\"\\text{Curto–Fialkow–TV}\", marker=:diamond)\n",
    "plot!(plt_tikz, xs, ym; ribbon=ys, fillalpha=0.25, marker=:square,\n",
    "      label=L\"\\text{Random} \\pm \\text{ 1 std}\")\n",
    "savefig(plt_tikz, \"final_ll_NS.tikz\")   # remember \\usepgfplotslibrary{fillbetween} in LaTeX\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e451164f-3d39-4c67-9625-bc16a52b70c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "using DataFrames, Statistics, Plots\n",
    "\n",
    "# --- EM iteration counts: mean / std once per (dataset, init) ---\n",
    "g = combine(groupby(results, [:dataset, :init]),\n",
    "            :iterations => mean => :mean_it,\n",
    "            :iterations => std  => :std_it)\n",
    "\n",
    "# one table per initialisation, renamed for a side-by-side join\n",
    "cf   = rename(filter(:init => ==(\"CF\"), g)[!, [:dataset, :mean_it]], :mean_it => :cf)\n",
    "cfTV = rename(filter(:init => ==(\"CF_TV\"), g)[!, [:dataset, :mean_it]], :mean_it => :cfTV)\n",
    "rnd  = rename(filter(:init => ==(\"random\"), g)[!, [:dataset, :mean_it, :std_it]],\n",
    "              :mean_it => :rnd_mean, :std_it => :rnd_std)\n",
    "\n",
    "# align rows on dataset, then order ascending by CF (W2) iteration count\n",
    "T = sort!(innerjoin(innerjoin(cf, cfTV; on=:dataset), rnd; on=:dataset), :cf)\n",
    "\n",
    "xs = 1:nrow(T)   # rank after sorting\n",
    "y1, y2, ym, ys = T.cf, T.cfTV, T.rnd_mean, T.rnd_std\n",
    "\n",
    "# --- notebook preview (GR backend) ---\n",
    "gr()\n",
    "default(size=(560,360), grid=false, framestyle=:box, legend=:bottomright,\n",
    "        linewidth=2, markerstrokewidth=0.8, markersize=3)\n",
    "\n",
    "# the ±1 std ribbon is drawn for the random-init series only\n",
    "plt = plot(xs, y1; label=\"Curto–Fialkow–W2\", marker=:circle, linestyle=:solid)\n",
    "plot!(plt, xs, y2; label=\"Curto–Fialkow–TV\", marker=:diamond, linestyle=:solid)\n",
    "plot!(plt, xs, ym; ribbon=ys, label=\"Random ± 1 std\",\n",
    "      marker=:square, linestyle=:solid, fillalpha=0.25)\n",
    "xlabel!(plt, \"Dataset (sorted by CF iterations)\")\n",
    "ylabel!(plt, \"EM iterations to convergence\")\n",
    "display(plt)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cd381b2-72ec-491f-9595-40cc4908c3a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "using PGFPlotsX, LaTeXStrings\n",
    "pgfplotsx()\n",
    "# TikZ version of the iteration-count figure; relies on xs, y1, y2, ym, ys from the GR cell above\n",
    "plt_tikz = plot(xs, y1; label=L\"\\operatorname{W2}\", marker=:circle)\n",
    "plot!(plt_tikz, xs, y2; label=L\"\\operatorname{TV}\",  marker=:diamond)\n",
    "plot!(plt_tikz, xs, ym; ribbon=ys, label=L\"\\text{Random}\",\n",
    "      marker=:square, fillalpha=0.25)\n",
    "# the x-axis is the dataset rank (sorted by CF iterations), not an iteration count\n",
    "xlabel!(plt_tikz, L\"\\text{Dataset (sorted by CF iterations)}\")\n",
    "ylabel!(plt_tikz, L\"\\text{Iterations}\")\n",
    "savefig(plt_tikz, \"iterations_K5EM.tikz\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08ba7894-d3c6-4a92-9052-232f4376062c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.10.5",
   "language": "julia",
   "name": "julia-1.10"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
