{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f12cb6da-23f7-4dbd-93d1-23b40e50f02f",
   "metadata": {},
   "outputs": [],
   "source": [
    "using JuMP\n",
    "using MosekTools\n",
    "using DynamicPolynomials\n",
    "using MultivariatePolynomials\n",
    "using TSSOS\n",
    "using LinearAlgebra, Random, Plots, Distributions, IterTools, Combinatorics, CSV, Statistics, MLDatasets, DataFrames, Revise, Clustering, Distances\n",
    "includet(\"UnivariateModels.jl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5eb375-4afb-4527-b9cb-36512e0f1f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "using MLDatasets\n",
    "\n",
    "\"\"\"\n",
    "    load_mnist_subset(digits; normalize=true)\n",
    "\n",
    "Load MNIST *training* images for the given `digits` (e.g. `[1,4,7]`).\n",
    "\n",
    "Returns `(X, y)` where:\n",
    "- `X` is a dense `N × 784` `Matrix{Float32}` (N = total # of selected images),\n",
    "  each row is a flattened 28×28 image (column-major pixel order).\n",
    "- `y` is a `Vector{Int}` of the corresponding labels.\n",
    "\n",
    "Keyword:\n",
    "- `normalize=true` guarantees pixel values in [0,1]: raw integer pixels\n",
    "  (0–255) are divided by 255, while pixels the loader already returns as\n",
    "  normalized (fixed-point/float) values are not rescaled a second time.\n",
    "  With `normalize=false` the values are returned as the loader provides them.\n",
    "\n",
    "Throws an `AssertionError` if any entry of `digits` lies outside 0–9.\n",
    "\"\"\"\n",
    "function load_mnist_subset(digits::AbstractVector{<:Integer}; normalize::Bool=true)\n",
    "    @assert all(d -> 0 ≤ d ≤ 9, digits) \"digits must be between 0 and 9\"\n",
    "\n",
    "    # NOTE(review): the pixel eltype depends on the MLDatasets version (older releases\n",
    "    # appear to return normalized N0f8, newer ones Float32) — confirm for the installed one.\n",
    "    imgs, labs = MNIST.traindata()          # imgs: 28×28×60000, labs: 60000 labels\n",
    "\n",
    "    # keep images whose label is one of the requested digits\n",
    "    mask = in.(labs, Ref(digits))\n",
    "\n",
    "    # select, convert and flatten: 28×28×N → 784×N → dense N×784 (not a lazy adjoint)\n",
    "    sel = imgs[:, :, mask]\n",
    "    h, w, N = size(sel)\n",
    "    X = permutedims(reshape(Float32.(sel), h * w, N))\n",
    "\n",
    "    # only raw integer pixels need rescaling; dividing already-normalized\n",
    "    # values by 255 would squash them into [0, 1/255]\n",
    "    if normalize && eltype(imgs) <: Integer\n",
    "        X ./= 255f0\n",
    "    end\n",
    "    y = Int.(labs[mask])\n",
    "\n",
    "    return X, y\n",
    "end\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc6715db-a206-46c3-8958-dcd29efe10c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "digs = [0, 1, 2]\n",
    "X, y = load_mnist_subset(digs; normalize=false)\n",
    "@show size(X)\n",
    "@show extrema(X), count(!=(0f0), X) / length(X)  # pixel range, fraction of non-zero pixels\n",
    "# number of images per digit; named `digit_counts` (not `counts`) so this global\n",
    "# does not shadow the `counts` function exported by StatsBase\n",
    "digit_counts = map(d -> count(==(d), y), digs)\n",
    "@show digit_counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c72dbbea-173a-4655-a355-8080ffd5b0a6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3758f83-9549-465b-8bca-3f6858d840ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "@polyvar m\n",
    "@polyvar sigma\n",
    "# Each constraint is a quadratic that is nonnegative exactly on an interval:\n",
    "#   m ∈ [0.0001, 1.0001],  sigma ∈ [0.0001, 1.000]\n",
    "# NOTE(review): the two upper bounds differ (1.0001 vs 1.000) — confirm intended.\n",
    "Sm   = [(m - 0.0001) * (1.0001 - m)]\n",
    "Ssig = [(sigma - 0.0001) * (1.000 - sigma)]\n",
    "S    = [Sm; Ssig]\n",
    "println()\n",
    "println(\"Support of the mixing measure\")\n",
    "# rescale every constraint so that its largest |coefficient| equals 1\n",
    "S_normalized = [g / maximum(abs, coefficients(g)) for g in S]\n",
    "display(S_normalized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c90fba69-ff2f-4eb8-85dd-9c6d29d46653",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# W2 model: relaxations of orders 1..max_order for every pixel (column of X),\n",
    "# with last argument 1e-5 (NOTE(review): presumably a small epsilon; its meaning is\n",
    "# defined in UnivariateModels.jl — confirm). RESW2[dim][d] is order d for pixel dim.\n",
    "max_order = 4\n",
    "RESW2 = []\n",
    "for dim in axes(X, 2)   # one entry per pixel column, no hard-coded 784\n",
    "    println(\"dimension = \", dim)\n",
    "    relax = []\n",
    "    for d in 1:max_order\n",
    "        println(\"  d = $d\")\n",
    "        push!(relax, univariate_SOS_model_Gaussian_W2(d, m, sigma, S_normalized, X[:, dim], true, 0.00001))\n",
    "    end\n",
    "    push!(RESW2, relax)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffc23459-bc0d-41cf-995e-1a41d90c8873",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# W2 model with the larger last argument 0.1 (NOTE(review): presumably epsilon —\n",
    "# confirm in UnivariateModels.jl). RESW2bige[dim][d] is order d for pixel dim.\n",
    "max_order = 4\n",
    "RESW2bige = []\n",
    "for dim in axes(X, 2)   # one entry per pixel column, no hard-coded 784\n",
    "    println(\"dimension = \", dim)\n",
    "    relax = []\n",
    "    for d in 1:max_order\n",
    "        println(\"  d = $d\")\n",
    "        push!(relax, univariate_SOS_model_Gaussian_W2(d, m, sigma, S_normalized, X[:, dim], true, 0.1))\n",
    "    end\n",
    "    push!(RESW2bige, relax)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b194962-9ec5-4e2b-aa6b-1888a6eda282",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# analyse the W2 relaxations of every pixel; A[i] is the result for pixel i\n",
    "# NOTE(review): the literal 4 presumably matches max_order — confirm.\n",
    "A = []\n",
    "for i in eachindex(RESW2)\n",
    "    println(\"dimension = \", i)\n",
    "    push!(A, analyse_relaxations(RESW2[i], 4, 1))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "121aaac2-9ffd-4962-a9ca-37247030a3ce",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# analyse the W2 relaxations (last argument 0.1) of every pixel; Abige[i] is pixel i\n",
    "# NOTE(review): the literal 4 presumably matches max_order — confirm.\n",
    "Abige = []\n",
    "for i in eachindex(RESW2bige)\n",
    "    println(\"dimension = \", i)\n",
    "    push!(Abige, analyse_relaxations(RESW2bige[i], 4, 1))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1752dd0c-3c5f-4c5a-a664-ec2f12128821",
   "metadata": {},
   "outputs": [],
   "source": [
    "using LinearAlgebra\n",
    "using Plots\n",
    "import StatsBase                 # binds the module name itself, needed for `StatsBase.mode`\n",
    "import StatsBase: countmap, mode\n",
    "\n",
    "\"\"\"\n",
    "    numeric_rank(M; reltol=1e-6)\n",
    "\n",
    "Number of eigenvalues of `Symmetric(Matrix(M))` whose absolute value exceeds\n",
    "`reltol` times the largest absolute eigenvalue.\n",
    "\"\"\"\n",
    "function numeric_rank(M; reltol=1e-6)\n",
    "    vals = eigvals(Symmetric(Matrix(M)))\n",
    "    thr = reltol * maximum(abs, vals)\n",
    "    return count(>(thr), abs.(vals))\n",
    "end\n",
    "\n",
    "# estimated order per pixel, from the first component of each W2 analysis result\n",
    "ranks = [rank_by_energy(A[i][1]; energy_tol=1e-6) for i in eachindex(A)]\n",
    "\n",
    "# summary: range, most frequent rank and the rank => #pixels table\n",
    "minr, maxr = extrema(ranks)\n",
    "mrank = StatsBase.mode(ranks)\n",
    "cm = countmap(ranks)\n",
    "xs = sort!(collect(keys(cm)))\n",
    "ys = [cm[x] for x in xs]\n",
    "\n",
    "println(\"min rank = $minr, max rank = $maxr, mode = $mrank\")\n",
    "for (k, c) in zip(xs, ys)\n",
    "    println(\"rank $k : \", c)\n",
    "end\n",
    "\n",
    "# plot the sticks\n",
    "plot(xs, ys;\n",
    "     seriestype = :sticks,\n",
    "     marker = :circle, ms = 5, lw = 3,\n",
    "     xticks = xs,\n",
    "     xlims = (minimum(xs)-0.5, maximum(xs)+0.5),\n",
    "     xlabel = \"Estimated order\", ylabel = \"Number of dimensions\",\n",
    "     #title = \"Rank frequencies across 784 pixels\",\n",
    "     legend = false)\n",
    "\n",
    "# count labels offset above each stick (3% of the max, at least 5) plus headroom\n",
    "dy = max(5, 0.03*maximum(ys))\n",
    "annotate!([(x, y + dy, text(string(y), 9, :center, :bottom)) for (x, y) in zip(xs, ys)]...)\n",
    "ylims!(0, maximum(ys) + 4dy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24cdb4b5-c5f6-4058-994c-5d893fd8d3e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "using LinearAlgebra\n",
    "using Plots\n",
    "import StatsBase                 # binds the module name itself, needed for `StatsBase.mode`\n",
    "import StatsBase: countmap, mode\n",
    "\n",
    "# estimated order per pixel for the W2 runs with last argument 0.1\n",
    "ranks = [rank_by_energy(Abige[i][1]; energy_tol=1e-6) for i in eachindex(Abige)]\n",
    "\n",
    "# summary: range, most frequent rank and the rank => #pixels table\n",
    "minr, maxr = extrema(ranks)\n",
    "mrank = StatsBase.mode(ranks)    # qualified to avoid conflicts\n",
    "cm = countmap(ranks)             # Dict: rank => frequency\n",
    "xs = sort!(collect(keys(cm)))\n",
    "ys = [cm[x] for x in xs]\n",
    "\n",
    "println(\"min rank = $minr, max rank = $maxr, mode = $mrank\")\n",
    "for (k, c) in zip(xs, ys)\n",
    "    println(\"rank $k : \", c)\n",
    "end\n",
    "\n",
    "# plot the sticks\n",
    "plot(xs, ys;\n",
    "     seriestype = :sticks,\n",
    "     marker = :circle, ms = 5, lw = 3,\n",
    "     xticks = xs,\n",
    "     xlims = (minimum(xs)-0.5, maximum(xs)+0.5),\n",
    "     xlabel = \"Estimated order\", ylabel = \"Number of dimensions\",\n",
    "     #title = \"Rank frequencies across 784 pixels\",\n",
    "     legend = false)\n",
    "\n",
    "# count labels offset above each stick (3% of the max, at least 5) plus headroom\n",
    "dy = max(5, 0.03*maximum(ys))\n",
    "annotate!([(x, y + dy, text(string(y), 9, :center, :bottom)) for (x, y) in zip(xs, ys)]...)\n",
    "ylims!(0, maximum(ys) + 4dy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "8985bab8-0589-46a3-8665-4a22802fb3f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "using LinearAlgebra\n",
    "using Plots\n",
    "using StatsBase\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ea8f6b2-87e1-4636-bd11-29c0358d860d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# TV model: relaxations of orders 1..max_order for every pixel (column of X),\n",
    "# with last argument 1e-5 (NOTE(review): presumably a small epsilon; its meaning is\n",
    "# defined in UnivariateModels.jl — confirm). RESTV[dim][d] is order d for pixel dim.\n",
    "max_order = 4\n",
    "RESTV = []\n",
    "for dim in axes(X, 2)   # one entry per pixel column, no hard-coded 784\n",
    "    println(\"dimension = \", dim)\n",
    "    relax = []\n",
    "    for d in 1:max_order\n",
    "        println(\"  d = $d\")\n",
    "        push!(relax, univariate_SOS_model_Gaussian_TV(d, m, sigma, S_normalized, X[:, dim], true, 0.00001))\n",
    "    end\n",
    "    push!(RESTV, relax)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed751f87-1e14-4d1e-8270-4c91dfd20cad",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# TV model with the larger last argument 0.1 (NOTE(review): presumably epsilon —\n",
    "# confirm in UnivariateModels.jl). RESTVbige[dim][d] is order d for pixel dim.\n",
    "max_order = 4\n",
    "RESTVbige = []\n",
    "for dim in axes(X, 2)   # one entry per pixel column, no hard-coded 784\n",
    "    println(\"dimension = \", dim)\n",
    "    relax = []\n",
    "    for d in 1:max_order\n",
    "        println(\"  d = $d\")\n",
    "        push!(relax, univariate_SOS_model_Gaussian_TV(d, m, sigma, S_normalized, X[:, dim], true, 0.1))\n",
    "    end\n",
    "    push!(RESTVbige, relax)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "079d7ac4-a487-4daa-b000-82afb8e4bae0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# analyse the TV relaxations of every pixel; B[i] is the result for pixel i\n",
    "# NOTE(review): the literal 4 presumably matches max_order — confirm.\n",
    "B = []\n",
    "for i in eachindex(RESTV)\n",
    "    println(\"dimension = \", i)\n",
    "    push!(B, analyse_relaxations(RESTV[i], 4, 1))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff2c4232-1b3e-4f41-af62-51fc92f53221",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# analyse the TV relaxations (last argument 0.1) of every pixel; Bbige[i] is pixel i\n",
    "# NOTE(review): the literal 4 presumably matches max_order — confirm.\n",
    "Bbige = []\n",
    "for i in eachindex(RESTVbige)\n",
    "    println(\"dimension = \", i)\n",
    "    push!(Bbige, analyse_relaxations(RESTVbige[i], 4, 1))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92be3d1c-0dc6-4291-95c5-2b037216d2a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimated order per pixel for the TV runs with last argument 1e-5\n",
    "# (eigenvalue-based alternative: ranks = [numeric_rank(B[i][1]; reltol=1e-3) for i in eachindex(B)])\n",
    "ranks = [rank_by_energy(B[i][1]; energy_tol=1e-6) for i in eachindex(B)]\n",
    "\n",
    "# summary: range, most frequent rank and the rank => #pixels table\n",
    "minr, maxr = extrema(ranks)\n",
    "mrank = StatsBase.mode(ranks)\n",
    "cm = countmap(ranks)\n",
    "xs = sort!(collect(keys(cm)))\n",
    "ys = [cm[x] for x in xs]\n",
    "println(\"min rank = $minr, max rank = $maxr, mode = $mrank\")\n",
    "for (rk, freq) in zip(xs, ys)\n",
    "    println(\"rank $rk : \", freq)\n",
    "end\n",
    "\n",
    "# stick plot of rank frequencies\n",
    "plot(xs, ys;\n",
    "     seriestype = :sticks,\n",
    "     marker = :circle, ms = 5, lw = 3,\n",
    "     xticks = xs,\n",
    "     xlims = (minimum(xs)-0.5, maximum(xs)+0.5),\n",
    "     xlabel = \"Estimated order\", ylabel = \"Number of dimensions\",\n",
    "     #title = \"Rank frequencies across 784 pixels\",\n",
    "     legend = false)\n",
    "\n",
    "# count labels offset above each stick (3% of the max, at least 5) plus headroom\n",
    "dy = max(5, 0.03*maximum(ys))\n",
    "annotate!([(x, y + dy, text(string(y), 9, :center, :bottom)) for (x, y) in zip(xs, ys)]...)\n",
    "ylims!(0, maximum(ys) + 4dy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2866369c-e64d-4dee-a7aa-f20ad0fda331",
   "metadata": {},
   "outputs": [],
   "source": [
    "# estimated order per pixel for the TV runs with last argument 0.1\n",
    "# (eigenvalue-based alternative: ranks = [numeric_rank(Bbige[i][1]; reltol=1e-3) for i in eachindex(Bbige)])\n",
    "ranks = [rank_by_energy(Bbige[i][1]; energy_tol=1e-6) for i in eachindex(Bbige)]\n",
    "\n",
    "# summary: range, most frequent rank and the rank => #pixels table\n",
    "minr, maxr = extrema(ranks)\n",
    "mrank = StatsBase.mode(ranks)    # qualified to avoid conflicts\n",
    "cm = countmap(ranks)             # Dict: rank => frequency\n",
    "xs = sort!(collect(keys(cm)))\n",
    "ys = [cm[x] for x in xs]\n",
    "\n",
    "println(\"min rank = $minr, max rank = $maxr, mode = $mrank\")\n",
    "for (k, c) in zip(xs, ys)\n",
    "    println(\"rank $k : \", c)\n",
    "end\n",
    "\n",
    "# plot the sticks\n",
    "plot(xs, ys;\n",
    "     seriestype = :sticks,\n",
    "     marker = :circle, ms = 5, lw = 3,\n",
    "     xticks = xs,\n",
    "     xlims = (minimum(xs)-0.5, maximum(xs)+0.5),\n",
    "     xlabel = \"Estimated order\", ylabel = \"Number of dimensions\",\n",
    "     #title = \"Rank frequencies across 784 pixels\",\n",
    "     legend = false)\n",
    "\n",
    "# count labels offset above each stick (3% of the max, at least 5) plus headroom\n",
    "dy = max(5, 0.03*maximum(ys))\n",
    "annotate!([(x, y + dy, text(string(y), 9, :center, :bottom)) for (x, y) in zip(xs, ys)]...)\n",
    "ylims!(0, maximum(ys) + 4dy)   # give headroom\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fedeb102-3fc6-4d79-b9dd-f58837a1d9d8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.10.5",
   "language": "julia",
   "name": "julia-1.10"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
