{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9b4a8c45-a88f-4c6d-93f0-4e250d205758",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JuMP\n",
    "using MosekTools\n",
    "using DynamicPolynomials\n",
    "using MultivariatePolynomials\n",
    "using TSSOS\n",
    "using LinearAlgebra, Random, Plots, Distributions, IterTools, Combinatorics, CSV, Statistics, MLDatasets, DataFrames, Revise, Clustering, Distances\n",
    "includet(\"Functions_Mixtures.jl\")\n",
    "includet(\"SeparationExperiment.jl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47b5cdd3-9a94-40d8-ae14-7c27b26fb740",
   "metadata": {},
   "source": [
    "### local necessary functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "id": "e2da9500-ffa7-40f7-bc95-05e1a2f09271",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "run_em (generic function with 1 method)"
      ]
     },
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using GaussianMixtures \n",
    "using DataFrames\n",
    "using DataFrames: groupby\n",
    "using StatsBase       \n",
    "using StatsPlots \n",
    "\n",
    "function random_means()\n",
    "  # e.g. sample k distinct data points\n",
    "  idx = sample(1:size(X,1), k; replace=false)\n",
    "  return X[idx, :]\n",
    "end\n",
    "\n",
    "function run_em(μ0; X, nIter=100, tol=1e-6, varfloor=0.0001)\n",
    "    # 1) materialize any Adjoint into a real matrix\n",
    "    μ0 = Matrix(μ0)\n",
    "    k, d = size(μ0)\n",
    "\n",
    "    # 2) compute data‐based diag variances\n",
    "    σ2 = vec(var(X; dims=1))              # length-d Vector\n",
    "    Σ0 = repeat(σ2', k, 1)                # k×d Matrix\n",
    "\n",
    "    # 3) build GMM (skip k‑means by nInit=0)\n",
    "    g = GMM(k, X; kind=:diag, nInit=0)\n",
    "\n",
    "    # 4) overwrite init\n",
    "    g.μ .= μ0\n",
    "    g.Σ .= Σ0\n",
    "    g.w .= fill(1/k, k)\n",
    "\n",
    "    # 5) run EM up to nIter, with varfloor control\n",
    "    logl = em!(g, X; nIter=nIter, varfloor=varfloor)\n",
    "\n",
    "    # 6) find when |Δℓ|<tol\n",
    "    idx = findfirst(i -> abs(logl[i] - logl[i-1]) < tol, 2:length(logl))\n",
    "    iters = idx === nothing ? length(logl) : idx\n",
    "\n",
    "    return (\n",
    "      iterations = iters,\n",
    "      final_ll   = logl[end],\n",
    "      logl       = logl,\n",
    "      model      = g,\n",
    "    )\n",
    "end\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4a1ae27d-804e-4363-bef9-c4d500e525ec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "align_labels (generic function with 1 method)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function find_nearest_indices(X::Matrix{Float64}, centers::Matrix{Float64})\n",
    "    indices = Int[]\n",
    "    for c in eachrow(centers')\n",
    "        dists = [sum((X[i, :] .- c).^2) for i in 1:size(X, 1)]\n",
    "        push!(indices, argmin(dists))\n",
    "    end\n",
    "    return indices\n",
    "end\n",
    "    using StatsBase  # for countmap, combinations\n",
    "\n",
    "function adjusted_rand_index(labels_true::Vector{Int}, labels_pred::Vector{Int})\n",
    "    @assert length(labels_true) == length(labels_pred)\n",
    "\n",
    "    n = length(labels_true)\n",
    "    unique_true = sort(unique(labels_true))\n",
    "    unique_pred = sort(unique(labels_pred))\n",
    "    n_true = length(unique_true)\n",
    "    n_pred = length(unique_pred)\n",
    "\n",
    "    # Confusion matrix: n_ij\n",
    "    contingency = zeros(Int, n_true, n_pred)\n",
    "    for (l_true, l_pred) in zip(labels_true, labels_pred)\n",
    "        i = findfirst(isequal(l_true), unique_true)\n",
    "        j = findfirst(isequal(l_pred), unique_pred)\n",
    "        contingency[i, j] += 1\n",
    "    end\n",
    "\n",
    "    # Row and column sums (a_i, b_j)\n",
    "    a = sum(contingency, dims=2)  # true cluster sizes\n",
    "    b = sum(contingency, dims=1)  # predicted cluster sizes\n",
    "\n",
    "    # Helper: binomial(n, 2)\n",
    "    comb2(x) = x < 2 ? 0 : x * (x - 1) ÷ 2\n",
    "\n",
    "    # ∑_ij C(n_ij, 2)\n",
    "    index = sum(comb2(nij) for nij in contingency)\n",
    "\n",
    "    # ∑_i C(a_i, 2), ∑_j C(b_j, 2)\n",
    "    sum_ai = sum(comb2(ai) for ai in a)\n",
    "    sum_bj = sum(comb2(bj) for bj in b)\n",
    "\n",
    "    expected_index = sum_ai * sum_bj / comb2(n)\n",
    "    max_index = (sum_ai + sum_bj) / 2\n",
    "\n",
    "    # Adjusted Rand Index\n",
    "    return (index - expected_index) / (max_index - expected_index)\n",
    "end\n",
    "function align_labels(true_labels::Vector{Int}, pred_labels::Vector{Int})\n",
    "    classes = sort(unique(true_labels))\n",
    "    perms = collect(permutations(classes))\n",
    "\n",
    "    best_score = -1\n",
    "    best_aligned = similar(pred_labels)\n",
    "\n",
    "    for p in perms\n",
    "        mapping = Dict(classes[i] => p[i] for i in eachindex(classes))\n",
    "        aligned = [mapping[l] for l in pred_labels]\n",
    "        score = sum(aligned .== true_labels)\n",
    "\n",
    "        if score > best_score\n",
    "            best_score = score\n",
    "            best_aligned .= aligned\n",
    "        end\n",
    "    end\n",
    "\n",
    "    return best_aligned\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b3f660f-04bf-46fc-af76-d34fc1130240",
   "metadata": {},
   "source": [
    "### K=2, relatively well separated, non-spherical"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3b5fc98-315b-43b3-96f5-842d7e402dd1",
   "metadata": {},
   "source": [
    "#### Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "621356a1-4ae5-41c8-9f2e-08b64491d8e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "c = 5.0\n",
    "ecc= 0.25\n",
    "K=2\n",
    "n=2\n",
    "nb_parameter_choices = 50\n",
    "seed_parameters=1\n",
    "Random.seed!(seed_parameters)\n",
    "gmms_50_025_2 = generate_multiple_gmms_heteroscedastic(nb_parameter_choices, K, n; ecc=ecc, c=c);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "457263de-f62a-493c-837f-dde73cbbb330",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 1000\n",
    "seed_data = 10\n",
    "Random.seed!(seed_data)\n",
    "\n",
    "all_data = Vector{Matrix{Float64}}(undef, length(gmms_50_025_2))\n",
    "all_labels = Vector{Vector{Int}}(undef, length(gmms_50_025_2))\n",
    "\n",
    "for mix_index in 1:length(gmms_50_025_2)\n",
    "    mix = gmms_50_025_2[mix_index]\n",
    "\n",
    "    # Convert means to Vector of Vectors\n",
    "    means = [mix.means[:, i] for i in 1:size(mix.means, 2)]\n",
    "    covariances = mix.covariances\n",
    "    weights = mix.weights\n",
    "    k = length(means)\n",
    "\n",
    "    # Unique seed per configuration (optional)\n",
    "    seed_i = seed_data + mix_index\n",
    "\n",
    "    # Generate data\n",
    "    samples, labels = generate_gaussian_mixtures(\n",
    "        k, means, covariances, weights;\n",
    "        seed=seed_i, n_samples=N\n",
    "    )\n",
    "\n",
    "    all_data[mix_index] = samples'\n",
    "    all_labels[mix_index] = labels\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7a3c8cb0-e9d1-40cb-98db-50dccbbd309e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "n_configs = length(all_data)\n",
    "n_rows = 10\n",
    "n_cols = 5\n",
    "\n",
    "plot_list = []\n",
    "\n",
    "for i in 1:n_configs\n",
    "    samples = all_data[i]'\n",
    "    labels = all_labels[i]\n",
    "    \n",
    "    # Normalize\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "    \n",
    "    # Optional: compute empirical mean per label\n",
    "    unique_labels = sort(unique(labels))\n",
    "    means_per_label = [mean(samples_scaled[labels .== l, :], dims=1) for l in unique_labels]\n",
    "\n",
    "    # Scatter plot\n",
    "    p = scatter(samples_scaled[:, 1], samples_scaled[:, 2],\n",
    "                group=labels,\n",
    "                markersize=2, alpha=0.6, legend=false,\n",
    "                #xlabel=\"x₁\", ylabel=\"x₂\",\n",
    "                title=\"Mixture $i\", xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))\n",
    "\n",
    "    # Overlay cluster means\n",
    "    for m in means_per_label\n",
    "        scatter!(p, [m[1]], [m[2]], color=:yellow, marker=:circle, ms=6)\n",
    "    end\n",
    "\n",
    "    push!(plot_list, p)\n",
    "end\n",
    "\n",
    "# Display grid layout\n",
    "#plot(plot_list..., layout=(n_rows, n_cols), size=(1400, 2400))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "406de706-bf8b-43d2-b771-022c98f8b5a6",
   "metadata": {},
   "source": [
    "#### $S_{\\theta}$ description"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f2ddc39-ab2d-4de8-bd2e-363a2d0c043a",
   "metadata": {},
   "outputs": [],
   "source": [
    "n=2\n",
    "@polyvar m[1:n]\n",
    "@polyvar sigma[1:n]\n",
    "Sm=[(m[1])*(1-m[1]), (m[2])*(1-m[2])]\n",
    "Ssig=[(sigma[1]-0.05)*(0.5-sigma[1]), (sigma[2]-0.05)*(0.5-sigma[2])]\n",
    "S=[vcat(Sm)...,vcat(Ssig)...]\n",
    "println()\n",
    "println(\"Support of the mixing measure\")\n",
    "S_normalized=[S[i]/maximum(abs.(coefficients(S[i]))) for i=1:length(S)]\n",
    "display(S_normalized)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8173a3e-fb38-45e7-8c9d-5c00f5b3a8c0",
   "metadata": {},
   "source": [
    "#### W2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "843d9571-50b3-4c7e-8249-a6f317dfcd3d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trace_penalization=true\n",
    "vareps=1e-3\n",
    "max_order = 4\n",
    "\n",
    "\n",
    "all_results_NS = Vector{Vector}(undef, length(all_data))\n",
    "\n",
    "for idx in 1:length(all_data)\n",
    "    println(\"\\n>>> Running on GMM config $idx\")\n",
    "\n",
    "    samples = Matrix(transpose(all_data[idx]))\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "\n",
    "    res = []\n",
    "\n",
    "    for d = 1:max_order\n",
    "        println(\"  d = $d\")\n",
    "        push!(res, multivariate_Gaussian_W2(n, d, m, sigma, S_normalized, samples_scaled, trace_penalization, vareps))\n",
    "    end\n",
    "\n",
    "    all_results_NS[idx] = res\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5300ba0-f111-4be6-840f-7d4738ebcb25",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "d=max_order\n",
    "RES_NS =[]\n",
    "energy_tol=1e-2\n",
    "\n",
    "for i in  1:length(all_data)\n",
    "    println(\">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flatness check for mixture $i\")\n",
    "    push!(RES_NS,analyse_relaxations_W2orTV(all_results_NS[i], d,n,energy_tol))\n",
    "    println(\"___________________________________________________________________\")\n",
    "    println()\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "691779a3-da86-4ac2-b15b-e30f909a10d9",
   "metadata": {},
   "source": [
    "#### TV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a5ef329-cae5-496e-bbd0-71d619e97855",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trace_penalization=true\n",
    "vareps=1e-3\n",
    "max_order = 4\n",
    "\n",
    "all_resultsTV_NS = Vector{Vector}(undef, length(all_data))\n",
    "\n",
    "for idx in 1:length(all_data)\n",
    "    println(\"\\n>>> Running on GMM config $idx\")\n",
    "\n",
    "    samples = Matrix(transpose(all_data[idx]))\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "\n",
    "    resTV = []\n",
    "\n",
    "    for d = 1:max_order\n",
    "        println(\"  d = $d\")\n",
    "        push!(resTV, multivariate_Gaussian_TV(n, d, m, sigma, S_normalized, samples_scaled, trace_penalization, vareps))\n",
    "    end\n",
    "\n",
    "    all_resultsTV_NS[idx] = resTV\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66808272-e075-4f04-bd13-c891f38cb089",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "d=max_order\n",
    "RESTV_NS =[]\n",
    "energy_tol=1e-2\n",
    "for i in  1:length(all_data)\n",
    "    println(\">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flatness check for mixture $i\")\n",
    "    push!(RESTV_NS,analyse_relaxations_W2orTV(all_resultsTV_NS[i], d,n,energy_tol))\n",
    "    println(\"___________________________________________________________________\")\n",
    "    println()\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17f6ae04-8da8-424e-81f5-f485f4350350",
   "metadata": {},
   "source": [
    "#### Applying extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "ce2ae424-2076-4dc6-92b4-20213ffb4631",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "rank_mm = 2\n",
    "d = max_order  # final relaxation order you used earlier\n",
    "\n",
    "plot_list = []\n",
    "\n",
    "CF_points_NS=[]\n",
    "true_means=[]\n",
    "for i in 1:length(all_data)\n",
    "    \n",
    "    # Get samples and labels\n",
    "    samples = all_data[i]' |> Matrix  # now shape is 1000 × 2\n",
    "    labels = all_labels[i]\n",
    "\n",
    "    # Extract RES for this config\n",
    "    M = RES_NS[i][1]  # moment matrix\n",
    "    L = RES_NS[i][end]  # last submatrix (or adjust as needed)\n",
    "\n",
    "    # Curto-Fialkow flat extension extraction\n",
    "    curto_f_points_NS = extract_CF(M, L, binomial(n + d - 1, d - 1), n, rank_mm)\n",
    "    sorted_cf_NS = sort(curto_f_points_NS, by = p -> p[1])\n",
    "    CF_NS = hcat(sorted_cf_NS...)  # 2 × r matrix\n",
    "    push!(CF_points_NS,CF_NS)\n",
    "\n",
    "    # Normalize samples\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "\n",
    "    # Empirical means from true labels\n",
    "    cluster_means = [vec(mean(samples_scaled[labels .== k, :], dims=1)) for k in sort(unique(labels))]\n",
    "    sorted_means = sort(cluster_means, by = m -> m[1])\n",
    "    trum = hcat(sorted_means...)\n",
    "    push!(true_means, trum)\n",
    "\n",
    "    \n",
    "\n",
    "    # Plotting\n",
    "    p = scatter(samples_scaled[:,1], samples_scaled[:,2],\n",
    "                group=labels, markersize=2, alpha=0.6, legend=false,\n",
    "                title=\"Mixture $i in ℝ²\",\n",
    "                #xlabel=\"X₁\", ylabel=\"X₂\", \n",
    "                xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))\n",
    "    \n",
    "    scatter!(p, trum[1, :], trum[2, :], marker=(:o, 8), color=:yellow, label=\"True Means\")\n",
    "    scatter!(p, CF_NS[1, :], CF_NS[2, :], marker=(:s, 6), color=:white, label=\"CDK Means\")\n",
    "\n",
    "    push!(plot_list, p)\n",
    "end\n",
    "\n",
    "#plot(plot_list..., layout=(10, 5), size=(1400, 2400))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "911634e8-a19e-4cd3-a70a-f4a2e8e62296",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3a862ab-d0bb-41a7-b100-27840c76b3a6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "rank_mm = 2\n",
    "d = max_order  # final relaxation order you used earlier\n",
    "\n",
    "plot_listTV_NS = []\n",
    "\n",
    "CF_pointsTV_NS=[]\n",
    "true_means=[]\n",
    "for i in 1:length(all_data)\n",
    "    \n",
    "    # Get samples and labels\n",
    "    samples = all_data[i]' |> Matrix  # now shape is 1000 × 2\n",
    "    labels = all_labels[i]\n",
    "\n",
    "    # Extract RES for this config\n",
    "    M = RESTV_NS[i][1]  # moment matrix\n",
    "    L = RESTV_NS[i][end]  # last submatrix (or adjust as needed)\n",
    "\n",
    "    # Curto-Fialkow flat extension extraction\n",
    "    curto_f_pointsTV_NS = extract_CF(M, L, binomial(n + d - 1, d - 1), n, rank_mm)\n",
    "    sorted_cfTV_NS = sort(curto_f_pointsTV_NS, by = p -> p[1])\n",
    "    CFTV_NS = hcat(sorted_cfTV_NS...)  # 2 × r matrix\n",
    "    push!(CF_pointsTV_NS,CFTV_NS)\n",
    "\n",
    "    # Normalize samples\n",
    "    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)\n",
    "\n",
    "    # Empirical means from true labels\n",
    "    cluster_means = [vec(mean(samples_scaled[labels .== k, :], dims=1)) for k in sort(unique(labels))]\n",
    "    sorted_means = sort(cluster_means, by = m -> m[1])\n",
    "    trum = hcat(sorted_means...)\n",
    "    push!(true_means, trum)\n",
    "\n",
    "    \n",
    "\n",
    "    # Plotting\n",
    "    p = scatter(samples_scaled[:,1], samples_scaled[:,2],\n",
    "                group=labels, markersize=2, alpha=0.6, legend=false,\n",
    "                title=\"Mixture $i in ℝ²\",\n",
    "                #xlabel=\"X₁\", ylabel=\"X₂\", \n",
    "                xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))\n",
    "    \n",
    "    scatter!(p, trum[1, :], trum[2, :], marker=(:o, 8), color=:yellow, label=\"True Means\")\n",
    "    scatter!(p, CFTV_NS[1, :], CFTV_NS[2, :], marker=(:s, 6), color=:white, label=\"CDK Means\")\n",
    "\n",
    "    push!(plot_listTV_NS, p)\n",
    "end\n",
    "\n",
    "#plot(plot_list..., layout=(10, 5), size=(1400, 2400))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c559e731-b0b5-44ac-824a-e1fdeaeacabd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a6b90fd-efaa-4859-8361-8a97c081def8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a3a7c26b-4093-4ce1-8a67-bd754a5b9d12",
   "metadata": {},
   "source": [
    "#### Impact on $k$-means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "49585b6f-ead9-4740-8bed-ebe3da3883c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_NS = []\n",
    "Random.seed!(123)\n",
    "\n",
    "for i in 1:nb_parameter_choices\n",
    "    # Get mixture data & labels\n",
    "    X = all_data[i]' |> Matrix\n",
    "    X_scaled = hcat(scale_to_minus1_1.(eachcol(X))...)  # Normalization\n",
    "    labels = all_labels[i]\n",
    "    k = length(unique(labels))\n",
    "    \n",
    "    # --- CF Initialization ---\n",
    "    cf_centers_NS = CF_points_NS[i]\n",
    "    cf_centers_NS = size(cf_centers_NS, 1) == 2 ? cf_centers_NS : cf_centers_NS'\n",
    "    cf_indices_NS = find_nearest_indices(X_scaled, cf_centers_NS)\n",
    "    result_cf_NS = kmeans(X_scaled', k; init=cf_indices_NS, maxiter=100, display=:none)\n",
    "    ARI_cf_NS = adjusted_rand_index(labels, result_cf_NS.assignments)\n",
    "    obj_cf_NS = result_cf_NS.totalcost\n",
    "    iter_cf_NS = result_cf_NS.iterations\n",
    "    mis_cf_NS = sum(align_labels(labels, result_cf_NS.assignments) .!= labels)\n",
    "\n",
    "    # --- CFTV Initialization ---\n",
    "    cf_centersTV_NS = CF_pointsTV_NS[i]\n",
    "    cf_centersTV_NS = size(cf_centersTV_NS, 1) == 2 ? cf_centersTV_NS : cf_centersTV_NS'\n",
    "    cf_indicesTV_NS = find_nearest_indices(X_scaled, cf_centersTV_NS)\n",
    "    result_cfTV_NS = kmeans(X_scaled', k; init=cf_indicesTV_NS, maxiter=100, display=:none)\n",
    "    ARI_cfTV_NS = adjusted_rand_index(labels, result_cfTV_NS.assignments)\n",
    "    obj_cfTV_NS = result_cfTV_NS.totalcost\n",
    "    iter_cfTV_NS = result_cfTV_NS.iterations\n",
    "    mis_cfTV_NS = sum(align_labels(labels, result_cfTV_NS.assignments) .!= labels)\n",
    "\n",
    "    # --- Random Initialization (repeat N times) ---\n",
    "    N = 100\n",
    "    objs_rnd = Float64[]\n",
    "    iters_rnd = Int[]\n",
    "    ARIs_rnd = Float64[]\n",
    "    mis_rnd = Int[]\n",
    "    \n",
    "    for rep in 1:N\n",
    "        rand_indices = rand(1:size(X_scaled, 1), k)\n",
    "        result_rnd = kmeans(X_scaled', k; init=rand_indices, maxiter=100, display=:none)\n",
    "        ARI_rnd = adjusted_rand_index(labels, result_rnd.assignments)\n",
    "        push!(objs_rnd, result_rnd.totalcost)\n",
    "        push!(iters_rnd, result_rnd.iterations)\n",
    "        push!(ARIs_rnd, ARI_rnd)\n",
    "        push!(mis_rnd, sum(align_labels(labels, result_rnd.assignments) .!= labels))\n",
    "    end\n",
    "\n",
    "    # Store summary (including full list of iterations)\n",
    "    push!(results_NS, (\n",
    "        i = i,\n",
    "        obj_cf_NS = obj_cf_NS,\n",
    "        iter_cf_NS = iter_cf_NS,\n",
    "        ARI_cf_NS = ARI_cf_NS,\n",
    "        mis_cf_NS = mis_cf_NS,\n",
    "        obj_cfTV_NS = obj_cfTV_NS ,\n",
    "        iter_cfTV_NS  = iter_cfTV_NS ,\n",
    "        ARI_cfTV_NS  = ARI_cfTV_NS ,\n",
    "        mis_cfTV_NS  = mis_cfTV_NS ,\n",
    "        objs_rnd = objs_rnd,\n",
    "        iters_rnd = iters_rnd,\n",
    "        ARIs_rnd = ARIs_rnd,\n",
    "        mis_rnd = mis_rnd,\n",
    "        mean_obj_rnd = mean(objs_rnd),\n",
    "        mean_iter_rnd = mean(iters_rnd),\n",
    "        mean_ARI_rnd = mean(ARIs_rnd),\n",
    "        mean_mis_rnd = mean(mis_rnd),\n",
    "        std_iter_rnd = std(iters_rnd)  \n",
    "    ))\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4671d47-9bb9-4862-867f-b487a2c7e56c",
   "metadata": {},
   "source": [
    "#### Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b285699-a726-402b-b1bf-3890369258b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract misclassification data\n",
    "mis_cf_NS = [r.mis_cf_NS for r in results_NS]\n",
    "mis_cfTV_NS = [r.mis_cfTV_NS for r in results_NS]\n",
    "mis_rnd_mean = [mean(r.mis_rnd) for r in results_NS]\n",
    "mis_rnd_std = [std(r.mis_rnd) for r in results_NS]\n",
    "\n",
    "\n",
    "# Sort by CF misclassification\n",
    "sort_idx_NS = sortperm(mis_cf_NS)\n",
    "mis_cf_sorted_NS = mis_cf_NS[sort_idx_NS]\n",
    "mis_cfTV_sorted_NS = mis_cfTV_NS[sort_idx_NS]\n",
    "mis_rnd_mean_sorted = mis_rnd_mean[sort_idx_NS]\n",
    "mis_rnd_std_sorted = mis_rnd_std[sort_idx_NS]\n",
    "\n",
    "mix_ids_sorted = 1:length(results_NS)\n",
    "\n",
    "# Plot CF misclassifications\n",
    "plot(mix_ids_sorted, mis_cf_sorted_NS;\n",
    "    label=\"Curto_Fialkow\",\n",
    "    lw=2, marker=:circle,\n",
    "    xlabel=\"Mixture (sorted by CF)\", ylabel=\"Misclassifications\",\n",
    "    title=\"Misclassifications: CF vs Random Initialization\",\n",
    "    legend=:topright,\n",
    "    size=(900, 600))\n",
    "\n",
    "plot!(mix_ids_sorted, mis_cfTV_sorted_NS;\n",
    "    label=\"Curto-Fialkow-TV\",\n",
    "    lw=2, marker=:diamond,\n",
    "    color=:red\n",
    "   )\n",
    "\n",
    "\n",
    "# Add Random line with ribbon\n",
    "plot!(mix_ids_sorted, mis_rnd_mean_sorted;\n",
    "    ribbon=mis_rnd_std_sorted,\n",
    "    label=\"Random ± 1 std\",\n",
    "    lw=2, marker=:square,\n",
    "   color=:orange,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2d0ba52-dc47-4bf9-9476-341fcd681cd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract data\n",
    "iter_cf_NS = [r.iter_cf_NS for r in results_NS]\n",
    "iter_cfTV_NS = [r.iter_cfTV_NS for r in results_NS]\n",
    "iter_rnd_mean = [mean(r.iters_rnd) for r in results_NS]\n",
    "iter_rnd_std = [std(r.iters_rnd) for r in results_NS]\n",
    "\n",
    "# Sort by CF iterations\n",
    "sort_idx_NS = sortperm(iter_cf_NS)\n",
    "iter_cf_sorted_NS = iter_cf_NS[sort_idx_NS]\n",
    "iter_rnd_mean_sorted = iter_rnd_mean[sort_idx_NS]\n",
    "iter_rnd_std_sorted = iter_rnd_std[sort_idx_NS]\n",
    "iter_cfTV_sorted_NS=iter_cfTV_NS[sort_idx_NS]\n",
    "\n",
    "mix_ids_sorted = 1:length(results_NS)  # new x-axis = sorted mixture index\n",
    "\n",
    "# Plot CF line\n",
    "plot(mix_ids_sorted, iter_cf_sorted_NS;\n",
    "    label=\"Curto-Fialkow-W2\",\n",
    "    lw=2, marker=:circle,\n",
    "    xlabel=\"Mixture (sorted by CF)\", ylabel=\"Iterations\",\n",
    "    #title=\"K-means Iterations: CF vs Random Initialization\",\n",
    "    legend=:topright,\n",
    "    size=(900, 600))\n",
    "\n",
    "plot!(mix_ids_sorted, iter_cfTV_sorted_NS;\n",
    "    label=\"Curto-Fialkow-TV\",\n",
    "    lw=2, marker=:diamond,\n",
    "    color=:red\n",
    "   )\n",
    "\n",
    "\n",
    "# Add Random line with ribbon\n",
    "plot!(mix_ids_sorted, iter_rnd_mean_sorted;\n",
    "    ribbon=iter_rnd_std_sorted,\n",
    "    label=\"Random ± 1 std\",\n",
    "    lw=2, marker=:square,\n",
    "   color=:orange,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4330a016-8945-4750-b629-40206299ed3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Plots, LaTeXStrings, PGFPlotsX, Statistics\n",
    "pgfplotsx()  # TikZ backend\n",
    "\n",
    "# --- data (NS) ---\n",
    "iter_cf_NS        = [r.iter_cf_NS   for r in results_NS]\n",
    "iter_cfTV_NS      = [r.iter_cfTV_NS for r in results_NS]\n",
    "iter_rnd_mean_NS  = [mean(r.iters_rnd) for r in results_NS]\n",
    "iter_rnd_std_NS   = [std(r.iters_rnd)  for r in results_NS]\n",
    "\n",
    "# Sort by CF iterations (NS)\n",
    "sort_idx_NS              = sortperm(iter_cf_NS)\n",
    "iter_cf_sorted_NS        = iter_cf_NS[sort_idx_NS]\n",
    "iter_cfTV_sorted_NS      = iter_cfTV_NS[sort_idx_NS]\n",
    "iter_rnd_mean_sorted_NS  = iter_rnd_mean_NS[sort_idx_NS]\n",
    "iter_rnd_std_sorted_NS   = iter_rnd_std_NS[sort_idx_NS]\n",
    "mix_ids_sorted_NS        = 1:length(results_NS)\n",
    "\n",
    "default(\n",
    "    size=(420,280),\n",
    "    grid=false,\n",
    "    framestyle=:box,\n",
    "    legend=:topright,\n",
    "    linewidth=2,\n",
    "    markerstrokewidth=0.8,\n",
    ")\n",
    "\n",
    "plt = plot(mix_ids_sorted_NS, iter_cf_sorted_NS;\n",
    "    label=L\"\\text{Curto–Fialkow–W2}\",\n",
    "    marker=:circle,\n",
    "    xlabel=L\"\\text{Mixture (sorted by CF)}\",\n",
    "    ylabel=L\"\\text{Iterations}\",\n",
    ")\n",
    "\n",
    "plot!(plt, mix_ids_sorted_NS, iter_cfTV_sorted_NS;\n",
    "    label=L\"\\text{Curto–Fialkow–TV}\",\n",
    "    marker=:diamond,\n",
    ")\n",
    "\n",
    "plot!(plt, mix_ids_sorted_NS, iter_rnd_mean_sorted_NS;\n",
    "    ribbon=iter_rnd_std_sorted_NS,\n",
    "    label=L\"\\text{Random} \\pm \\text{ 1 std}\",\n",
    "    marker=:square,\n",
    ")\n",
    "\n",
    "#savefig(plt, \"iterations_K2kmeans.tikz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b21656fb-b573-4d7f-894e-dd11b34bdba3",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Plots, LaTeXStrings, PGFPlotsX, Statistics\n",
    "pgfplotsx()  # TikZ backend\n",
    "\n",
    "# --- data (NS) ---\n",
    "mis_cf_NS        = [r.mis_cf_NS    for r in results_NS]\n",
    "mis_cfTV_NS      = [r.mis_cfTV_NS  for r in results_NS]\n",
    "mis_rnd_mean_NS  = [mean(r.mis_rnd) for r in results_NS]\n",
    "mis_rnd_std_NS   = [std(r.mis_rnd)  for r in results_NS]\n",
    "\n",
    "# Sort by CF misclassification (NS)\n",
    "sort_idx_NS               = sortperm(mis_cf_NS)\n",
    "mis_cf_sorted_NS          = mis_cf_NS[sort_idx_NS]\n",
    "mis_cfTV_sorted_NS        = mis_cfTV_NS[sort_idx_NS]\n",
    "mis_rnd_mean_sorted_NS    = mis_rnd_mean_NS[sort_idx_NS]\n",
    "mis_rnd_std_sorted_NS     = mis_rnd_std_NS[sort_idx_NS]\n",
    "mix_ids_sorted_NS         = 1:length(results_NS)\n",
    "\n",
    "# --- styling (same as previous figure) ---\n",
    "default(\n",
    "    size=(420,280),\n",
    "    grid=false,\n",
    "    framestyle=:box,\n",
    "    legend=:topright,\n",
    "    linewidth=2,\n",
    "    markerstrokewidth=0.8,\n",
    "    markersize=3,\n",
    ")\n",
    "\n",
    "plt = plot(mix_ids_sorted_NS, mis_cf_sorted_NS;\n",
    "    label=L\"\\text{Curto–Fialkow–W2}\",\n",
    "    marker=:circle,\n",
    "    xlabel=L\"\\text{Mixture (sorted by CF)}\",\n",
    "    ylabel=L\"\\text{Misclassifications}\",\n",
    ")\n",
    "\n",
    "plot!(plt, mix_ids_sorted_NS, mis_cfTV_sorted_NS;\n",
    "    label=L\"\\text{Curto–Fialkow–TV}\",\n",
    "    marker=:diamond,\n",
    ")\n",
    "\n",
    "plot!(plt, mix_ids_sorted_NS, mis_rnd_mean_sorted_NS;\n",
    "    ribbon=mis_rnd_std_sorted_NS,\n",
    "    label=L\"\\text{Random} \\pm \\text{ 1 std}\",\n",
    "    marker=:square,\n",
    ")\n",
    "\n",
    "#savefig(plt, \"misclassifications_K2kmeans.tikz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af0efdc7-6c05-4321-b6b9-5e2c916d207d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f5afd3c6-452f-4e7d-a19b-892e000db3d9",
   "metadata": {},
   "source": [
    "#### EM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbffa10f-ca20-47b7-ade9-f660dcac90f4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\n",
    "Random.seed!(1)\n",
    "\n",
    "nrand = 100    # # of random restarts per dataset\n",
    "nIter = 100    # max EM iter\n",
    "tol   = 1e-6  # your convergence tol\n",
    "\n",
    "# prepare an empty results table\n",
    "results = DataFrame(\n",
    "  dataset    = Int[],       # which dataset index\n",
    "  init       = String[],    # \"CF\" or \"random\"\n",
    "  iterations = Int[],       \n",
    "  final_ll   = Float64[],\n",
    ")\n",
    "\n",
    "for i in eachindex(all_data)\n",
    "  # --- 1) pull + scale this dataset\n",
    "  X      = Matrix(all_data[i]')                  # N×2\n",
    "  Xs     = hcat(scale_to_minus1_1.(eachcol(X))...)\n",
    "  k      = size(CF_points_NS[i], 2) \n",
    "  μcf    = Matrix(CF_points_NS[i]')\n",
    "  μcf_TV = Matrix(CF_pointsTV_NS[i]')\n",
    "    \n",
    "    # 5×2\n",
    "\n",
    "  # --- 2) CF run\n",
    "  cf_res = run_em(μcf; X = Xs, nIter = nIter, tol = tol)\n",
    "  push!(results, (i, \"CF\",     cf_res.iterations, cf_res.final_ll))\n",
    "\n",
    "  # ------- CF_TV run\n",
    "  cf_res_TV = run_em(μcf_TV; X = Xs, nIter = nIter, tol = tol)\n",
    "  push!(results, (i, \"CF_TV\",     cf_res_TV.iterations, cf_res_TV.final_ll))\n",
    "\n",
    "  # --- 3) random runs\n",
    "  for rep in 1:nrand\n",
    "    # sample k distinct rows as init means\n",
    "    inds = sample(1:size(Xs,1), k; replace=false)\n",
    "    μrand = Xs[inds, :]                  # k×2\n",
    "    rr    = run_em(μrand; X = Xs, nIter = nIter, tol = tol)\n",
    "    push!(results, (i, \"random\", rr.iterations, rr.final_ll))\n",
    "  end\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82a3ca4a-9aa5-426c-a622-632ee16adec9",
   "metadata": {},
   "outputs": [],
   "source": [
    "using DataFrames, Statistics, Plots\n",
    "gr()  # GR backend for the inline preview\n",
    "\n",
    "# Mean and standard deviation of the final log-likelihood for every\n",
    "# (dataset, init) pair recorded in `results`.\n",
    "g = combine(groupby(results, [:dataset, :init]),\n",
    "            :final_ll .=> [mean, std] .=> [:mean_ll, :std_ll])\n",
    "\n",
    "# One narrow table per initialisation; only the random runs keep their std.\n",
    "cf   = select(filter(:init => ==(\"CF\"),     g), :dataset, :mean_ll => :cf)\n",
    "cfTV = select(filter(:init => ==(\"CF_TV\"),  g), :dataset, :mean_ll => :cfTV)\n",
    "rnd  = select(filter(:init => ==(\"random\"), g), :dataset,\n",
    "              :mean_ll => :rnd_mean, :std_ll => :rnd_std)\n",
    "\n",
    "# Keep only datasets present under all three inits, ordered by ascending CF value.\n",
    "T = sort!(innerjoin(cf, cfTV, rnd; on=:dataset), :cf)\n",
    "\n",
    "# Plot series. These globals are read again by the PGFPlotsX export cell below.\n",
    "xs     = 1:nrow(T)\n",
    "y1, y2 = T.cf, T.cfTV\n",
    "ym, ys = T.rnd_mean, T.rnd_std\n",
    "\n",
    "# Shared styling for every subsequent Plots figure in this session.\n",
    "default(size=(560,360), grid=false, framestyle=:box, legend=:topleft,\n",
    "        linewidth=2, markerstrokewidth=0.8, markersize=3)\n",
    "\n",
    "# Three connected marker lines; the ±1 std band is drawn for the random inits only.\n",
    "plt = plot(xs, y1; linestyle=:solid, marker=:circle, label=\"Curto–Fialkow–W2\")\n",
    "plot!(plt, xs, y2; linestyle=:solid, marker=:diamond, label=\"Curto–Fialkow–TV\")\n",
    "plot!(plt, xs, ym; ribbon=ys, fillalpha=0.25,\n",
    "      linestyle=:solid, marker=:square, label=\"Random ± 1 std\")\n",
    "\n",
    "xlabel!(plt, \"Dataset (sorted by CF fit)\")\n",
    "ylabel!(plt, \"Final log-likelihood\")\n",
    "display(plt)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db813c4-ba04-44ab-850b-f0e06a3fda24",
   "metadata": {},
   "outputs": [],
   "source": [
    "using PGFPlotsX, LaTeXStrings\n",
    "pgfplotsx()  # switch Plots to the PGFPlotsX backend for TikZ output\n",
    "\n",
    "# Rebuild the log-likelihood figure from the series left in scope by the\n",
    "# previous cell (xs, y1, y2, ym, ys); run that cell first.\n",
    "plt_tikz = plot(xs, y1; marker=:circle, label=L\"\\text{Curto–Fialkow–W2}\")\n",
    "plot!(plt_tikz, xs, y2; marker=:diamond, label=L\"\\text{Curto–Fialkow–TV}\")\n",
    "plot!(plt_tikz, xs, ym; ribbon=ys, fillalpha=0.25, marker=:square,\n",
    "      label=L\"\\text{Random} \\pm \\text{ 1 std}\")\n",
    "xlabel!(plt_tikz, L\"\\text{Dataset (sorted by CF fit)}\")\n",
    "ylabel!(plt_tikz, L\"\\text{Final log-likelihood}\")\n",
    "\n",
    "# Export is disabled; uncomment to write the TikZ file.\n",
    "#savefig(plt_tikz, \"final_ll_K2_NS.tikz\")   # \\usepgfplotslibrary{fillbetween} in LaTeX\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "374ce0b4-64af-470e-8713-9d6aa0f29c39",
   "metadata": {},
   "outputs": [],
   "source": [
    "using DataFrames, Statistics, Plots\n",
    "\n",
    "# Mean and standard deviation of the EM iteration count for every\n",
    "# (dataset, init) pair recorded in `results`.\n",
    "g = combine(groupby(results, [:dataset, :init]),\n",
    "            :iterations .=> [mean, std] .=> [:mean_it, :std_it])\n",
    "\n",
    "# One narrow table per initialisation; only the random runs keep their std.\n",
    "cf   = select(filter(:init => ==(\"CF\"),     g), :dataset, :mean_it => :cf)\n",
    "cfTV = select(filter(:init => ==(\"CF_TV\"),  g), :dataset, :mean_it => :cfTV)\n",
    "rnd  = select(filter(:init => ==(\"random\"), g), :dataset,\n",
    "              :mean_it => :rnd_mean, :std_it => :rnd_std)\n",
    "\n",
    "# Keep only datasets present under all three inits, ordered by ascending CF iterations.\n",
    "T = sort!(innerjoin(cf, cfTV, rnd; on=:dataset), :cf)\n",
    "\n",
    "# Plot series. These globals are read again by the PGFPlotsX export cell below,\n",
    "# and they replace the log-likelihood series of the same names.\n",
    "xs     = 1:nrow(T)\n",
    "y1, y2 = T.cf, T.cfTV\n",
    "ym, ys = T.rnd_mean, T.rnd_std\n",
    "\n",
    "# Back to the GR backend for the notebook preview.\n",
    "gr()\n",
    "default(size=(560,360), grid=false, framestyle=:box, legend=:topleft,\n",
    "        linewidth=2, markerstrokewidth=0.8, markersize=3)\n",
    "\n",
    "# Three connected marker lines; the ±1 std band is drawn for the random inits only.\n",
    "plt = plot(xs, y1; linestyle=:solid, marker=:circle, label=\"Curto–Fialkow–W2\")\n",
    "plot!(plt, xs, y2; linestyle=:solid, marker=:diamond, label=\"Curto–Fialkow–TV\")\n",
    "plot!(plt, xs, ym; ribbon=ys, fillalpha=0.25,\n",
    "      linestyle=:solid, marker=:square, label=\"Random ± 1 std\")\n",
    "\n",
    "xlabel!(plt, \"Dataset (sorted by CF iterations)\")\n",
    "ylabel!(plt, \"EM iterations to convergence\")\n",
    "display(plt)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b503df50-f04d-405e-80d8-c2e34ff4bfcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "using PGFPlotsX, LaTeXStrings\n",
    "pgfplotsx()  # switch Plots to the PGFPlotsX backend for TikZ output\n",
    "\n",
    "# Rebuild the iterations figure from the series left in scope by the\n",
    "# previous cell (xs, y1, y2, ym, ys); run that cell first.\n",
    "plt_tikz = plot(xs, y1; marker=:circle, label=L\"\\operatorname{W2}\")\n",
    "plot!(plt_tikz, xs, y2; marker=:diamond, label=L\"\\operatorname{TV}\")\n",
    "plot!(plt_tikz, xs, ym; ribbon=ys, fillalpha=0.25, marker=:square,\n",
    "      label=L\"\\text{Random}\")\n",
    "# The x-axis label is intentionally left off in this export; re-enable if needed.\n",
    "#xlabel!(plt_tikz, L\"\\text{Dataset (sorted by CF iterations)}\")\n",
    "ylabel!(plt_tikz, L\"\\text{Iterations}\")\n",
    "# Export is disabled; uncomment to write the TikZ file.\n",
    "#savefig(plt_tikz, \"iterations_K2EM.tikz\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee9c4b0-29bc-497f-90a9-4aa61c6299ea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fa2bbac-0fea-4a8e-a7ad-58859dc47148",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b483fc3b-6a51-4dbc-a287-64b15e4b899c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4de0ec2-e0fe-4c90-9ee5-2fba25c16392",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaa8bb6c-2a18-4179-a819-2c99c9732522",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd04e47d-782b-4e4f-bfb1-c9de8a5903cd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.10.5",
   "language": "julia",
   "name": "julia-1.10"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
