{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "67efe077-b061-4b11-9563-9ca8def275d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JuMP\n",
    "using MosekTools\n",
    "using DynamicPolynomials\n",
    "using MultivariatePolynomials\n",
    "using TSSOS\n",
    "using LinearAlgebra, Random, Plots, Distributions, IterTools, Combinatorics, CSV, Statistics, MLDatasets, DataFrames, Revise, Clustering, Distances\n",
    "using GaussianMixtures\n",
    "includet(\"Functions_Mixtures.jl\")\n",
    "includet(\"SeparationExperiment.jl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c68bbf92-31ca-4ae3-a6f7-5ec17dea2eec",
   "metadata": {},
   "source": [
    "### Locally used functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "13868002-a7eb-42df-8386-6817caea2000",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "run_em (generic function with 1 method)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using GaussianMixtures \n",
    "using DataFrames\n",
    "using DataFrames: groupby\n",
    "using StatsBase       \n",
    "using StatsPlots \n",
    "\n",
    "function random_means()\n",
    "  \n",
    "  idx = sample(1:size(X,1), k; replace=false)\n",
    "  return X[idx, :]\n",
    "end\n",
    "\n",
    "function run_em(μ0; X, nIter=100, tol=1e-6, varfloor=1e-3)\n",
    "    \n",
    "    μ0 = Matrix(μ0)\n",
    "    k, d = size(μ0)\n",
    "\n",
    "   \n",
    "    σ2 = vec(var(X; dims=1))              # length-d Vector\n",
    "    Σ0 = repeat(σ2', k, 1)                # k×d Matrix\n",
    "\n",
    "    # 3) build GMM (skip k‑means by nInit=0)\n",
    "    g = GMM(k, X; kind=:diag, nInit=10)\n",
    "\n",
    "    # 4) overwrite init\n",
    "    g.μ .= μ0\n",
    "    g.Σ .= Σ0\n",
    "    g.w .= fill(1/k, k)\n",
    "\n",
    "    # 5) run EM up to nIter, with varfloor control\n",
    "    logl = em!(g, X; nIter=nIter, varfloor=varfloor)\n",
    "\n",
    "    # 6) find when |Δℓ|<tol\n",
    "    idx = findfirst(i -> abs(logl[i] - logl[i-1]) < tol, 2:length(logl))\n",
    "    iters = idx === nothing ? length(logl) : idx\n",
    "\n",
    "    return (\n",
    "      iterations = iters,\n",
    "      final_ll   = logl[end],\n",
    "      logl       = logl,\n",
    "      model      = g,\n",
    "    )\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4849da71-8f0f-4771-8786-86ae2900dff2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "align_labels (generic function with 1 method)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "function find_nearest_indices(X::Matrix{Float64}, centers::Matrix{Float64})\n",
    "    indices = Int[]\n",
    "    for c in eachrow(centers')\n",
    "        dists = [sum((X[i, :] .- c).^2) for i in 1:size(X, 1)]\n",
    "        push!(indices, argmin(dists))\n",
    "    end\n",
    "    return indices\n",
    "end\n",
    "\n",
    "function adjusted_rand_index(labels_true::Vector{Int}, labels_pred::Vector{Int})\n",
    "    @assert length(labels_true) == length(labels_pred)\n",
    "\n",
    "    n = length(labels_true)\n",
    "    unique_true = sort(unique(labels_true))\n",
    "    unique_pred = sort(unique(labels_pred))\n",
    "    n_true = length(unique_true)\n",
    "    n_pred = length(unique_pred)\n",
    "\n",
    "    # Confusion matrix: n_ij\n",
    "    contingency = zeros(Int, n_true, n_pred)\n",
    "    for (l_true, l_pred) in zip(labels_true, labels_pred)\n",
    "        i = findfirst(isequal(l_true), unique_true)\n",
    "        j = findfirst(isequal(l_pred), unique_pred)\n",
    "        contingency[i, j] += 1\n",
    "    end\n",
    "\n",
    "    a = sum(contingency, dims=2)  # true cluster sizes\n",
    "    b = sum(contingency, dims=1)  # predicted cluster sizes\n",
    "\n",
    "    comb2(x) = x < 2 ? 0 : x * (x - 1) ÷ 2\n",
    "\n",
    "    index = sum(comb2(nij) for nij in contingency)\n",
    "\n",
    "    sum_ai = sum(comb2(ai) for ai in a)\n",
    "    sum_bj = sum(comb2(bj) for bj in b)\n",
    "\n",
    "    expected_index = sum_ai * sum_bj / comb2(n)\n",
    "    max_index = (sum_ai + sum_bj) / 2\n",
    "\n",
    "    return (index - expected_index) / (max_index - expected_index)\n",
    "end\n",
    "function align_labels(true_labels::Vector{Int}, pred_labels::Vector{Int})\n",
    "    classes = sort(unique(true_labels))\n",
    "    perms = collect(permutations(classes))\n",
    "\n",
    "    best_score = -1\n",
    "    best_aligned = similar(pred_labels)\n",
    "\n",
    "    for p in perms\n",
    "        mapping = Dict(classes[i] => p[i] for i in eachindex(classes))\n",
    "        aligned = [mapping[l] for l in pred_labels]\n",
    "        score = sum(aligned .== true_labels)\n",
    "\n",
    "        if score > best_score\n",
    "            best_score = score\n",
    "            best_aligned .= aligned\n",
    "        end\n",
    "    end\n",
    "\n",
    "    return best_aligned\n",
    "end"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9613f36f-cb52-4987-b3ae-509292e8a160",
   "metadata": {},
   "source": [
    "## Main part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eab4b502-54ef-4ff4-9caa-2f68e5dca8cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "using MLDatasets\n",
    "\n",
    "mnist = MNIST(split = :train)  \n",
    "# 1) find the indices of 0 and 1\n",
    "inds01 = findall(lbl -> lbl in (0,1), mnist.targets)\n",
    "\n",
    "# 2) subset the images and labels\n",
    "imgs01  = mnist.features[:, :, inds01]   # 28×28×N01 array\n",
    "labels01 = mnist.targets[inds01]          # N01‐vector of 0’s and 1’s\n",
    "\n",
    "inds0 = findall(labels01 .== 0)\n",
    "imgs0 = imgs01[:, :, inds0]    # now 28×28×N0\n",
    "X0 = reshape(imgs0, 28*28, :)'   \n",
    "inds1 = findall(labels01 .== 1)\n",
    "imgs1 = imgs01[:, :, inds1]    # now 28×28×N0\n",
    "X1 = reshape(imgs1, 28*28, :)';"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "306a36db-5605-4ebd-81bd-231ef6958183",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1) materialize as normal matrices\n",
    "X0_mat = Matrix(X0)   # now 5923×784 Matrix{Float32}\n",
    "X1_mat = Matrix(X1)   # now 6742×784 Matrix{Float32}\n",
    "\n",
    "# 2) stack into one big dataset of 12665×784\n",
    "X_all = vcat(X0_mat, X1_mat)   # size(X_all) == (5923+6742, 784)\n",
    "\n",
    "# 3) make the labels (0 for the first block, 1 for the second)\n",
    "y_all = vcat( fill(0, size(X0_mat,1)),\n",
    "              fill(1, size(X1_mat,1)) );  # length == 12665"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8b31b91-0af6-444c-8cfd-c45c13df87e3",
   "metadata": {},
   "source": [
    "### PCA + relaxations for {0,1}  and {0,1,2}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7877c368-09a2-4af9-a7f0-1530d2fa6632",
   "metadata": {},
   "outputs": [],
   "source": [
    "using MLDatasets, LinearAlgebra, Plots\n",
    "\n",
    "# Load and prepare data\n",
    "mnist = MNIST(split=:train)\n",
    "inds01 = findall(lbl -> lbl in (0, 1), mnist.targets)\n",
    "imgs01 = mnist.features[:, :, inds01]\n",
    "labels01 = mnist.targets[inds01]\n",
    "\n",
    "# Extract images for 0 and 1\n",
    "imgs0 = imgs01[:, :, labels01 .== 0]\n",
    "imgs1 = imgs01[:, :, labels01 .== 1]\n",
    "X0 = reshape(imgs0, 28*28, :)'\n",
    "X1 = reshape(imgs1, 28*28, :)'\n",
    "\n",
    "# Combine data and labels\n",
    "X_all = vcat(Matrix(X0), Matrix(X1))\n",
    "y_all = vcat(zeros(size(X0, 1)), ones(size(X1, 1)))\n",
    "\n",
    "# Center the data\n",
    "X_centered = X_all .- mean(X_all, dims=1)\n",
    "\n",
    "# Perform SVD\n",
    "U, S, _ = svd(X_centered)\n",
    "\n",
    "# Project onto first two PCs\n",
    "pc1 = U[:, 1] .* S[1]\n",
    "pc2 = U[:, 2] .* S[2];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "34989a82-4467-4113-aa43-15b431a56c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1) stack into an N×2 matrix\n",
    "X = hcat(pc1, pc2)    # size(X) == (N, 2)\n",
    "\n",
    "# 2) scale each column into [–1,1]\n",
    "X_scaled = hcat( scale_to_minus1_1.(eachcol(X))... );\n",
    "# size(X_scaled) == (N, 2)\n",
    "samples=X_scaled;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc1e7c96-e9b6-49f0-a66c-82c11d2acb2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot with label\n",
    "scatter(X_scaled[:,1][y_all .== 0], X_scaled[:,2][y_all .== 0], \n",
    "        color=:blue, alpha=0.5, marker=:dot, label=\"0\",\n",
    "        xlabel=\"PC1\", ylabel=\"PC2\", title=\"PCA of MNIST Digits 0 and 1\")\n",
    "scatter!(X_scaled[:,1][y_all .== 1], X_scaled[:,2][y_all .== 1], \n",
    "        color=:red, alpha=0.5, marker=:dot, label=\"1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1219cbe-456d-4c02-8fee-c276ef0516f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "using DataFrames, Statistics\n",
    "\n",
    "# 1) Build a DataFrame from your PC coords + labels\n",
    "df2 = DataFrame(\n",
    "  PC1   = X_scaled[:,1],\n",
    "  PC2   = X_scaled[:,2],\n",
    "  label = y_all\n",
    ")\n",
    "\n",
    "# 2) Group by the label and compute mean & var\n",
    "stats = combine(\n",
    "  groupby(df2, :label),\n",
    "  :PC1 => mean => :mean_PC1,\n",
    "  :PC1 => var  => :var_PC1,\n",
    "  :PC2 => mean => :mean_PC2,\n",
    "  :PC2 => var  => :var_PC2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "befe5027-8850-4b32-895e-6f238b0a249f",
   "metadata": {},
   "outputs": [],
   "source": [
    "n=2\n",
    "@polyvar m[1:n]\n",
    "@polyvar sigma[1:n]\n",
    "#Sm=[(m[1]+1)*(1-m[1]), (m[2]+1)*(1-m[2]),(m[3]+1)*(1-m[3]),(m[4]+1)*(1-m[4])]\n",
    "Sm=[(m[1])*(1-m[1]), (m[2])*(1-m[2])]\n",
    "Ssig=[(sigma[1]-0.00005)*(0.15-sigma[1]), (sigma[2]-0.00005)*(0.15-sigma[2])]\n",
    "S=[vcat(Sm)...,vcat(Ssig)...]\n",
    "S_normalized=[S[i]/maximum(abs.(coefficients(S[i]))) for i=1:length(S)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7934d3dc-4b22-4816-b7f4-e83c49e8a456",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_reg=true        \n",
    "res=[]\n",
    "max_order=4\n",
    "eps_eq=1e-4\n",
    "eps_tr=1e-2\n",
    "for i=1:max_order\n",
    "    println(\" d = \", i)\n",
    "    push!(res, slack_create_SOS_model(n, i, m, sigma, S_normalized, samples, eps_eq, eps_tr; tr_reg))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19ad4ce1-ec71-4a56-8252-9234bf297ab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "d=4\n",
    "RES = analyse_relaxations(res, d, n);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aae5d202-f471-47a5-91eb-eb4957bc5398",
   "metadata": {},
   "outputs": [],
   "source": [
    "Curto_F=extract_CF(RES[1], RES[end], binomial(n+d-1,d-1), n, 3)\n",
    "sort(Curto_F, by = p -> p[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcfd3a36-8e7c-4029-b473-3decb88031ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_reg=true\n",
    "eps_tr=1e-1\n",
    "max_order=4\n",
    "\n",
    "resTV=[]\n",
    "for i=1:max_order\n",
    "    println(\" d = \", i)\n",
    "    push!(resTV, TV_SOS_model(n, i, m, sigma, S_normalized, samples; tr_reg, eps_tr))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d83c7a2-8e55-4025-98f8-6fc7c3b151e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "d=4\n",
    "RESTV = analyse_relaxations_TV(resTV, d, n);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b242e839-adcd-40fc-afd3-95abc3c7dd5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "Curto_FTV=extract_CF(RESTV[1], RESTV[end], binomial(n+d-1,d-1),n,3)\n",
    "sort(Curto_FTV, by = p -> p[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "334fc446-1f76-40eb-832a-e16ee7ff62c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Clustering\n",
    "using Statistics\n",
    "using Random\n",
    "using Plots\n",
    "\n",
    "X = samples'  # shape: n×d, here 150×4\n",
    "C_init = Matrix(hcat(Curto_F...)') \n",
    "C_initTV = Matrix(hcat(Curto_FTV...)')  \n",
    "\n",
    "\n",
    "function find_nearest_indices(X, centers)\n",
    "    indices = Int[]\n",
    "    for c in eachrow(centers)\n",
    "        dists = [sum((X[i, :] .- c).^2) for i in 1:size(X, 1)]\n",
    "        push!(indices, argmin(dists))\n",
    "    end\n",
    "    return indices\n",
    "end\n",
    "\n",
    "init_indices = find_nearest_indices(samples, C_init)\n",
    "\n",
    "init_indicesTV = find_nearest_indices(samples, C_initTV)\n",
    "\n",
    "\n",
    "result_custom = kmeans(X, 2; init=init_indices)\n",
    "cost_custom = result_custom.totalcost\n",
    "iters_custom = result_custom.iterations\n",
    "\n",
    "result_customTV = kmeans(X, 2; init=init_indicesTV)\n",
    "cost_customTV = result_customTV.totalcost\n",
    "iters_customTV = result_customTV.iterations\n",
    "\n",
    "println(\"Curto-Fialkow - W2 distance → Cost: \", cost_custom, \", Iterations: \", iters_custom)\n",
    "println(\"Curto-Fialkow - TV distance → Cost: \", cost_customTV, \", Iterations: \", iters_customTV)\n",
    "\n",
    "\n",
    "N = 100\n",
    "random_costs = Float64[]\n",
    "random_iters = Int[]\n",
    "\n",
    "for _ in 1:N\n",
    "    res = kmeans(X, 2; init=:kmpp)\n",
    "    push!(random_costs, res.totalcost)\n",
    "    push!(random_iters, res.iterations)\n",
    "end\n",
    "\n",
    "\n",
    "println(\"Curto-Fialkow - W2 distance → Cost: \", cost_custom, \", Iterations: \", iters_custom)\n",
    "println(\"Curto-Fialkow - TV distance → Cost: \", cost_customTV, \", Iterations: \", iters_customTV)\n",
    "# ---- Statistics ----\n",
    "mean_cost = mean(random_costs)\n",
    "std_cost = std(random_costs)\n",
    "mean_iters = mean(random_iters)\n",
    "std_iters = std(random_iters)\n",
    "\n",
    "println(\"Random → Mean cost: $(round(mean_cost, digits=2)) ± $(round(std_cost, digits=2))\")\n",
    "println(\"Random → Mean iters: $(round(mean_iters, digits=2)) ± $(round(std_iters, digits=2))\")\n",
    "\n",
    "histogram(random_costs;\n",
    "    bins=20,\n",
    "    xlabel=\"Total cost\",\n",
    "    ylabel=\"Frequency\",\n",
    "    alpha=0.6,\n",
    "    label=\"Random\",\n",
    "    title=\"KMeans Cost Comparison\",\n",
    "    legend=:topright)\n",
    "vline!([cost_custom], label=\"Curto-Fialkow\", lw=2, lc=:red)\n",
    "\n",
    "plot()\n",
    "histogram(random_iters;\n",
    "    bins=16,\n",
    "    xlabel=\"Iterations\",D\n",
    "    ylabel=\"Frequency\",\n",
    "    alpha=0.6,\n",
    "    label=\"Random\",\n",
    "    #title=\"KMeans IteSSration Comparison\",\n",
    "    legend=:topright)\n",
    "vline!([iters_custom], label=\"Curto-Fialkow-W2\", lw=4, lc=:red)\n",
    "vline!([iters_customTV], label=\"Curto-Fialkow-TV\", lw=4, lc=:orange)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93abd708-6647-44b6-9979-27eeed83c40f",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Plots, PGFPlotsX, LaTeXStrings, Colors\n",
    "\n",
    "# fixed colors\n",
    "col_w2  = RGB(0.0,    0.6056, 0.9787)  # blue:  CF–W2\n",
    "col_tv  = RGB(0.8889, 0.4356, 0.2781)  # orange: CF–TV\n",
    "col_rnd = RGB(0.2422, 0.6433, 0.3044)  # green:  Random\n",
    "\n",
    "# --- Preview in notebook (GR) ---\n",
    "gr()\n",
    "default(size=(560,360), framestyle=:box, grid=false, legend=:topright)\n",
    "\n",
    "plt_iter = histogram(random_iters; bins=16,\n",
    "    label=\"Random ± 1 std\",\n",
    "    fillcolor=col_rnd, linecolor=col_rnd, alpha=0.45,   # <- force green\n",
    "    xlabel=\"Iterations\", ylabel=\"Frequency\")\n",
    "\n",
    "vline!(plt_iter, [iters_custom];    label=\"Curto–Fialkow–W2\", lw=2, color=col_w2)\n",
    "vline!(plt_iter, [iters_customTV];  label=\"Curto–Fialkow–TV\", lw=2, color=col_tv)\n",
    "display(plt_iter)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8797065-a366-47ed-9b1b-4a43beede462",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Export same look to TikZ ---\n",
    "pgfplotsx()\n",
    "plt_iter_tex = histogram(random_iters; bins=16,\n",
    "    label=L\"\\text{Random} \\pm \\text{ 1 std}\",\n",
    "    fillcolor=col_rnd, linecolor=col_rnd, alpha=0.45,\n",
    "    xlabel=L\"\\text{Iterations}\", ylabel=L\"\\text{Frequency}\")\n",
    "\n",
    "vline!(plt_iter_tex, [iters_custom];   label=L\"\\text{Curto–Fialkow–W2}\", lw=2, color=col_w2)\n",
    "vline!(plt_iter_tex, [iters_customTV]; label=L\"\\text{Curto–Fialkow–TV}\", lw=2, color=col_tv)\n",
    "\n",
    "#savefig(plt_iter_tex, \"kmeans_iter_hist.tikz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0740825-aec0-47f0-8db1-2c8f6458a0a5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cb2032e-dbc6-4296-85bd-fc7c2e4bdf65",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b69da27-3002-42bd-8173-a5a687c03af0",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Clustering\n",
    "using Statistics\n",
    "using Random\n",
    "using Plots\n",
    "\n",
    "X = samples'  # shape: n×d, here 150×4\n",
    "C_init = Matrix(hcat(Curto_F...)') \n",
    "C_initTV = Matrix(hcat(Curto_FTV...)')  \n",
    "\n",
    "\n",
    "function find_nearest_indices(X, centers)\n",
    "    indices = Int[]\n",
    "    for c in eachrow(centers)\n",
    "        dists = [sum((X[i, :] .- c).^2) for i in 1:size(X, 1)]\n",
    "        push!(indices, argmin(dists))\n",
    "    end\n",
    "    return indices\n",
    "end\n",
    "\n",
    "init_indices = find_nearest_indices(samples, C_init)\n",
    "\n",
    "init_indicesTV = find_nearest_indices(samples, C_initTV)\n",
    "\n",
    "\n",
    "result_custom = kmeans(X, 3; init=init_indices)\n",
    "cost_custom = result_custom.totalcost\n",
    "iters_custom = result_custom.iterations\n",
    "\n",
    "result_customTV = kmeans(X, 3; init=init_indicesTV)\n",
    "cost_customTV = result_customTV.totalcost\n",
    "iters_customTV = result_customTV.iterations\n",
    "\n",
    "println(\"Curto-Fialkow - W2 distance → Cost: \", cost_custom, \", Iterations: \", iters_custom)\n",
    "println(\"Curto-Fialkow - TV distance → Cost: \", cost_customTV, \", Iterations: \", iters_customTV)\n",
    "\n",
    "\n",
    "N = 100\n",
    "random_costs = Float64[]\n",
    "random_iters = Int[]\n",
    "\n",
    "for _ in 1:N\n",
    "    res = kmeans(X, 3; init=:kmpp)\n",
    "    push!(random_costs, res.totalcost)\n",
    "    push!(random_iters, res.iterations)\n",
    "end\n",
    "\n",
    "\n",
    "println(\"Curto-Fialkow - W2 distance → Cost: \", cost_custom, \", Iterations: \", iters_custom)\n",
    "println(\"Curto-Fialkow - TV distance → Cost: \", cost_customTV, \", Iterations: \", iters_customTV)\n",
    "# ---- Statistics ----\n",
    "mean_cost = mean(random_costs)\n",
    "std_cost = std(random_costs)\n",
    "mean_iters = mean(random_iters)\n",
    "std_iters = std(random_iters)\n",
    "\n",
    "println(\"Random → Mean cost: $(round(mean_cost, digits=2)) ± $(round(std_cost, digits=2))\")\n",
    "println(\"Random → Mean iters: $(round(mean_iters, digits=2)) ± $(round(std_iters, digits=2))\")\n",
    "\n",
    "histogram(random_costs;\n",
    "    bins=20,\n",
    "    xlabel=\"Total cost\",\n",
    "    ylabel=\"Frequency\",\n",
    "    alpha=0.6,\n",
    "    label=\"Random\",\n",
    "    title=\"KMeans Cost Comparison\",\n",
    "    legend=:topright)\n",
    "vline!([cost_custom], label=\"Curto-Fialkow\", lw=2, lc=:red)\n",
    "\n",
    "plot()\n",
    "histogram(random_iters;\n",
    "    bins=16,\n",
    "    xlabel=\"Iterations\",\n",
    "    ylabel=\"Frequency\",\n",
    "    alpha=0.6,\n",
    "    label=\"Random\",\n",
    "    #title=\"KMeans IteSSration Comparison\",\n",
    "    legend=:topright)\n",
    "vline!([iters_custom], label=\"Curto-Fialkow-W2\", lw=4, lc=:red)\n",
    "vline!([iters_customTV], label=\"Curto-Fialkow-TV\", lw=4, lc=:orange)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8be005b-575b-4b71-bfe4-e25358777a62",
   "metadata": {},
   "outputs": [],
   "source": [
    "using Plots, PGFPlotsX, LaTeXStrings, Colors\n",
    "\n",
    "# fixed colors\n",
    "col_w2  = RGB(0.0,    0.6056, 0.9787)  # blue:  CF–W2\n",
    "col_tv  = RGB(0.8889, 0.4356, 0.2781)  # orange: CF–TV\n",
    "col_rnd = RGB(0.2422, 0.6433, 0.3044)  # green:  Random\n",
    "\n",
    "# --- Preview in notebook (GR) ---\n",
    "gr()\n",
    "default(size=(560,360), framestyle=:box, grid=false, legend=:topright)\n",
    "\n",
    "plt_iter = histogram(random_iters; bins=16,\n",
    "    label=\"Random ± 1 std\",\n",
    "    fillcolor=col_rnd, linecolor=col_rnd, alpha=0.45,   # <- force green\n",
    "    xlabel=\"Iterations\", ylabel=\"Frequency\")\n",
    "\n",
    "vline!(plt_iter, [iters_custom];    label=\"Curto–Fialkow–W2\", lw=2, color=col_w2)\n",
    "vline!(plt_iter, [iters_customTV];  label=\"Curto–Fialkow–TV\", lw=2, color=col_tv)\n",
    "display(plt_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0a7ee3b-3041-4639-b50c-ada33c7c742b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Export same look to TikZ ---\n",
    "pgfplotsx()\n",
    "plt_iter_tex = histogram(random_iters; bins=16,\n",
    "    label=L\"\\text{Random} \\pm \\text{ 1 std}\",\n",
    "    fillcolor=col_rnd, linecolor=col_rnd, alpha=0.45,\n",
    "    xlabel=L\"\\text{Iterations}\", ylabel=L\"\\text{Frequency}\")\n",
    "\n",
    "vline!(plt_iter_tex, [iters_custom];   label=L\"\\text{Curto–Fialkow–W2}\", lw=2, color=col_w2)\n",
    "vline!(plt_iter_tex, [iters_customTV]; label=L\"\\text{Curto–Fialkow–TV}\", lw=2, color=col_tv)\n",
    "\n",
    "#savefig(plt_iter_tex, \"kmeans_iter_hist_012_2comp.tikz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c313370e-06df-491e-b88f-6075407d4912",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3274d9b-9071-42fb-8b9d-704b35ab7a18",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.10.5",
   "language": "julia",
   "name": "julia-1.10"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
