 
using Pkg 
Pkg.add(["CairoMakie","LaTeXStrings","PyFormattedStrings","JLD2","MacroTools","MathTeXEngine","Statistics","ProgressLogging"])
  

using LinearAlgebra
using Random
using CairoMakie 
using LaTeXStrings
using PyFormattedStrings
using JLD2
using MacroTools
using MathTeXEngine 
using Statistics
using ProgressLogging 
 
include("./include/utils/MTUtils.jl")

# md"""
# ### Minimal model for outliers below MP bulk

# Teacher student setup, normalized vectors $u\in\mathbb{R}^{N}$, $v\in\mathbb{R}^{K}$, weight matrix $W\in\mathbb{R}^{K\times N}$

# ```math 
# \begin{align}
# \sigma_T(\xi) &= \lambda u^T\xi\\
# \sigma_S(\xi) &= v^T W \xi\\
# \Rightarrow W_0 &= v\lambda u^T 
# \end{align}
# ```
# Probability distribution for weight matrix $W$:
# ```math 
# \begin{align}
#  P(W) &= e^{-b \varepsilon_g(W) - \frac{1}{2 \alpha} \mathrm{Tr}(W^TW)}
# \end{align}
# ```
# Assume MSE loss $\mathcal{L}$ and $\varepsilon_g = \langle \mathcal{L}(\xi) \rangle_\xi$ averaged over $\mathcal{N}(0,1)$ examples $\xi$.

# Under these conditions, we can write 
# ```math 
# \begin{align}
#  P(W) &= e^{-\frac{N}{2} \mathrm{Tr}[(W-W_0)^T \Sigma^{-1} (W-W_0)]}
# \end{align}
# ```
# with row covariance of the form $\Sigma = \alpha \mathbb{1} - \beta v v^T$. Compared to the paper, in the code we denote $\beta(\alpha)=\frac{b\alpha^2}{1+\alpha b}$ to shorten the code definitions below. With these variables, $W_0=\frac{\beta \lambda}{\alpha}\, v u^T$.


# Observed weight
# ```math 
# \begin{align}
#  W = W_0 + X
# \end{align}
# ```
# where $\langle X \rangle = 0$, the row covariance of $X$ is $\Sigma$, and the column covariance is $\mathbb{1}$.
# """
 
# Apply the project's default plot settings before any figure is built.
# NOTE(review): the helper lives in the included MTUtils.jl, which is not
# visible here — confirm exactly which settings it changes.
MTUtils.set_default_plot_params!()
 
"""
    get_s_and_overlaps(W, u, v)

Compute the thin SVD of `W` and pick out its smallest singular value.

Returns the tuple `(s, o_u, o_v)`, where `s` is the smallest singular value,
`o_u` is the inner product of `u` with the matching right singular vector and
`o_v` is the inner product of `v` with the matching left singular vector.
Both overlaps are returned with their sign.
"""
function get_s_and_overlaps(W::AbstractMatrix,
                            u::AbstractVector,
                            v::AbstractVector)
    F = svd(W; full = false)

    # `findmin` gives the first smallest entry together with its position
    s, k = findmin(F.S)

    left = F.U[:, k]                # column k of U  (length size(W, 1))
    right = adjoint(F.Vt[k, :])     # row k of Vt, adjointed (length size(W, 2))

    return s, dot(u, right), dot(v, left)
end
 
"""
    draw_weight_pre(; W₀, L, K, N, rng = Random.default_rng())

Draw one `K × N` matrix `L * X + W₀`, where `X` has i.i.d. `𝒩(0, 1/N)`
entries.

# Keywords
- `W₀`: matrix added to the noise (the mean of the draw).
- `L`: factor applied to the white noise from the left, so every column of
  the noise has covariance `L * L' / N`.
- `K`, `N`: number of rows and columns of the white-noise matrix.
- `rng`: random number generator to draw from. It defaults to the global
  default generator, so calls that omit it behave as before; pass a seeded
  generator to make a draw reproducible.

# Returns
The `K × N` matrix `L * X + W₀`.
"""
function draw_weight_pre(; W₀, L, K, N, rng::AbstractRNG = Random.default_rng())
    # i.i.d. 𝒩(0, 1/N) white noise
    X = randn(rng, K, N) ./ √N
    # mix the rows with L, then shift by the mean
    return L * X + W₀
end
 
"""
    sout(; α, β, λ, N, K)

Evaluate the closed-form expression used for the theoretical outlier position.

Returns `NaN` when neither of the two conditions checked in the body holds;
otherwise returns the square root of the closed-form expression.
"""
function sout(; α, β, λ, N, K)
    q = K / N
    shift = λ^2 - β

    # first condition: β > λ² and (β - λ²)/α > √(K/N)
    lower = β > λ^2 && (β - λ^2) / α > sqrt(q)

    # no value is reported when the first condition fails and λ is also
    # below √α (1 + √(K/N))
    if !lower && λ < √α * (1 + sqrt(q))
        return NaN
    end

    base = α + shift
    return sqrt(base + q * α / (1 - α / base))
end
 

 
"""
    sample_n(n; u, v, α, β, λ = 3.0)

Draw `n` independent weight matrices and collect, for each one, the smallest
singular value and its singular vectors' overlaps with `u` and `v`.

# Arguments
- `n`: number of matrices to draw.

# Keywords
- `u`: vector of length `N` (number of columns of the weight matrix).
- `v`: vector of length `K` (number of rows of the weight matrix).
- `α`, `β`: define the covariance `Σ = α * I - β * v * v'`, which is
  Cholesky-factorised once; the factorisation requires `Σ` to be positive
  definite.
- `λ`: scale of the mean `W₀ = λ * v * u'`.

# Returns
`(s_s, o_us, o_vs)`: three `Vector{Float64}` of length `n` holding the
smallest singular values, the overlaps with `u` and the overlaps with `v`.
"""
function sample_n(n::Integer; u, v, α, β, λ = 3.0)
    # Take the matrix dimensions from the vectors themselves instead of
    # relying on same-named global variables being defined and consistent.
    K = length(v)
    N = length(u)

    s_s  = Float64[]      # minimal σ
    o_us = Float64[]      # overlaps with u
    o_vs = Float64[]      # overlaps with v

    Σ = α * I(K) - β * (v * v')      # K × K
    W₀ = λ .* (v * u')               # K × N
    L = cholesky(Hermitian(Σ)).L     # factor once, reuse for every draw

    for _ in 1:n
        W = draw_weight_pre(; W₀, L, K, N)
        s, o_u, o_v = get_s_and_overlaps(W, u, v)
        push!(s_s, s)
        push!(o_us, o_u)
        push!(o_vs, o_v)
    end
    return s_s, o_us, o_vs
end
 
"""
    plot_beta_dependence(βs, sout_means, sout_stds, ovl_u_means, ovl_u_stds,
                         ovl_v_means, ovl_v_stds; λ, α, n_samples)

Build a two-panel figure: the left panel shows the mean smallest singular value
together with the theory curve from `sout`, the right panel shows the mean
absolute overlaps. Everything is drawn against the abscissa `1 .- βs/α`.

# Arguments
- `βs`: grid of β values; the remaining positional arguments are indexed like it.
- `sout_means`, `sout_stds`: mean and standard deviation of the smallest
  singular value per β.
- `ovl_u_means`, `ovl_u_stds`, `ovl_v_means`, `ovl_v_stds`: mean and standard
  deviation of the overlaps per β.

# Keywords
- `λ`, `α`: model parameters entering the theory curve and the reference lines.
- `n_samples`: number of samples per β; bands are `mean ± std / sqrt(n_samples)`.

# Returns
`(fig, axs)` with `axs == [ax_s, ax_uv]`.

Note: the body reads `K` and `N`, which are not arguments of this function;
they must be defined in the enclosing (global) scope before it is called.
"""
function plot_beta_dependence(βs,sout_means,sout_stds,ovl_u_means,ovl_u_stds,ovl_v_means,ovl_v_stds;λ,α,n_samples)

	colors = MTUtils.get_colors()
	
	# set plot theme environment
	fig, axs = with_theme(MTUtils.get_theme(;width=2*8.6, height_width_ratio=0.68/2.2 )) do

		# MP bulk lower bound
		y_max = √α * (1-sqrt(K/N))
		# threshold value of beta; points with β >= βmin2 are the ones selected by `idx` below
		# NOTE(review): this was previously described as a "maximum value of beta",
		# but the name and the selection below suggest a minimum — confirm
		βmin2(;α,λ,N,K) = (α^2-α^(3/2)*sqrt(α-4*sqrt(K/N)*λ^2))/(2*λ^2)
		# corresponding value of 1-β/α at which outlier fuses with bulk
		x_max = 1-βmin2(;α,λ,N,K)/α

		# setup figure and axes and label them
		fig = Figure()
		ax_s = Axis(fig[1,1];ylabel=L"\langle\nu_{\mathrm{min}}\rangle",yticks=([0.0,0.2,0.4,y_max],[L"0.0",L"0.2",L"0.4",L"\nu_{-}"]))
		ax_uv = Axis(fig[1,2];ylabel=L"\mathrm{overlap}") 
		axs = [ax_s,ax_uv] 

		# NOTE(review): the label reads 1/(1+αβ) while the plotted abscissa is
		# 1 - β/α; presumably the label uses the paper's notation — confirm
		for ax in axs; ax.xlabel=L"1/(1+\alpha\beta)"; end

		# select x-axis values where outlier exists
		idx = 1 .- βs/α .<= x_max 

		# plot outlier location
		lines!(ax_s, (1 .- βs/α)[idx], sout_means[idx]; linewidth=2) 
		# plot overlaps
		scatter!(ax_uv, 1 .- βs/α, ovl_u_means; label=L"\langle|u\cdot \tilde{u}|\rangle", MTUtils.nice_points(colors[1])...) 
		scatter!(ax_uv, 1 .- βs/α, ovl_v_means; label=L"\langle|v\cdot \tilde{v}|\rangle",  MTUtils.nice_points(colors[2])...)
  
		# errorbars: bands of half-width std/sqrt(n_samples) around each mean,
		# drawn over the full β grid (not restricted to `idx`)
		band!(ax_s,1 .- βs/α, sout_means - sout_stds/sqrt(n_samples), sout_means + sout_stds/sqrt(n_samples);alpha=0.5)
		band!(ax_uv,1 .- βs/α, ovl_u_means - ovl_u_stds/sqrt(n_samples), ovl_u_means + ovl_u_stds/sqrt(n_samples);alpha=0.5,color=colors[1], label=nothing)
		band!(ax_uv,1 .- βs/α, ovl_v_means - ovl_v_stds/sqrt(n_samples), ovl_v_means + ovl_v_stds/sqrt(n_samples);alpha=0.5,color=colors[2], label=nothing)

		# add legend
		axislegend(ax_uv; position=:lb, pad=0, margin = (0, 0, 0, 0), patchlabelgap = 0, rowgap=-7 )

		# set limits
		ylims!(ax_s,0.0,√α * (1-sqrt(K/N))+0.1)
		ylims!(ax_uv,-0.1,1.1) 

		# add theory curve for outlier location 
		# (`sout` is evaluated with λ rescaled by β/α at every grid point; it
		# returns NaN where it reports no outlier)
		_sout = [sout(;α,β,λ=λ*β/α,N,K) for β in βs] 
		plt_kwargs = (;color=:red, linestyle=:dash)
		lines!(ax_s, 1 .- βs/α, _sout; plt_kwargs...)  

		# add vertical line where outlier fuses with bulk  
		for ax in axs; lines!(ax,[x_max,x_max],[-0.1,1.1];color=:black,linestyle=:dot,linewidth=0.5); end 

		# add MP bulk gray bands
		# (same expression as the `y_max` computed at the top of this block)
		y_max = √α * (1-sqrt(K/N))  
		band!(ax_s,[-10,10],[y_max,y_max],[10,10], color=:gray, alpha=0.4)
		text!(ax_s,0.2,y_max+0.02;text=L"\mathrm{MP\, bulk}",rotation =0, fontsize=10) 
		
		# shade the region to the right of x_max on the overlap panel
		for ax in [ax_uv]; band!(ax, [x_max,0.6], [-10,-10], [10,10], color=:gray, alpha=0.4) ;end
		for ax in axs; xlims!(ax,-0.02,0.53) ;end

		# add minimum of outlier for the given lambda
		# (`/` binds tighter than `-`: this is λ*sqrt(1 - (K/N)/(1 - λ^2/α)),
		# which matches the text label below with q = K/N)
		ymin = λ * sqrt(1-K/N  / (1-λ^2/α)) 
		lines!(ax_s,[-0.1,0.04],[ymin,ymin];color=:black,linestyle=:dot,linewidth=0.5)
		text!(ax_s,0.05,ymin, align=(:left,:center),text=L"\lambda\sqrt{1-\frac{q\,\alpha}{\alpha-\lambda^2}}",fontsize = 10)

		# add random overlap line
		# NOTE(review): sqrt(2/π)/sqrt(N) is presumably the expected |overlap| of
		# independent random unit vectors in N dimensions — confirm
		ovp_rnd = sqrt(2/pi) * 1/sqrt(N)
		lines!(ax_uv,[0.3,0.51],[ovp_rnd,ovp_rnd];linestyle=:dash)
	
		# return figure and axes from this scope
		fig, axs
	end
	fig, axs
end
 
# Driver: draw the fixed vectors u and v, scan β over a grid, collect (or load)
# the sampled statistics, then build and save the figure.
begin
	# Draw parameters.
	# `N`, `K`, `u` and `v` are assigned without `local`, so they stay visible
	# outside this block; `plot_beta_dependence` reads `K` and `N` from there.
	# No seed is set in this block, so `u` and `v` differ between runs.
	N, K = 2048, 512
	u = randn(N) 
	u /= norm(u)
	v = randn(K)
	v /= norm(v) 

	# set plot parameters
	local λ = 0.2 
	local α = 1.0
	# β grid from α*sqrt(K/N) up to α in steps of 0.001
	# NOTE(review): at the last grid point β = α the covariance α*I - β*v*v'
	# built in `sample_n` is singular for unit `v`; whether its Cholesky
	# factorisation succeeds there depends on rounding — confirm
	local βs = α*sqrt(K/N):0.001:α

	# macro computes inside if result cannot be loaded, otherwise load saved result 
	# (behaviour as described by the original author; the macro is defined in
	# MTUtils.jl, which is not visible here). The block's value is a 7-tuple
	# whose last entry, the named tuple of parameters, is discarded below.
	local (sout_means,sout_stds,ovl_u_means,ovl_u_stds,ovl_v_means,ovl_v_stds,_) = 
	MTUtils.@run_or_load f"pluto_out/s_and_overlaps_alp{α:3.5f}_lambda{λ:3.5f}_tilde_manuscript.jld2" begin
		# initialize outlier locations, overlaps, and errors
		sout_means = zeros(length(βs))
		sout_stds = zeros(length(βs))
		ovl_u_means = zeros(length(βs))
		ovl_u_stds = zeros(length(βs))
		ovl_v_means = zeros(length(βs))
		ovl_v_stds = zeros(length(βs))

		# sample 100 matrices and average quantities
		# (the mean passed to `sample_n` is rescaled: λ -> β*λ/α; this sample
		# count must match `n_samples` in the plot call below)
		@progress for (i,β) in enumerate(βs)
			souts,ovl_us,ovl_vs = sample_n(100; u, v, α, β, λ=β*λ/α)

			sout_means[i] = mean(souts)
			sout_stds[i] = std(souts)
			# overlaps are signed, so statistics are taken over absolute values
			ovl_u_means[i] = mean(abs.(ovl_us))
			ovl_u_stds[i] = std(abs.(ovl_us))
			ovl_v_means[i] = mean(abs.(ovl_vs))
			ovl_v_stds[i] = std(abs.(ovl_vs))
		end
		
		# value of the block: the six statistics plus the parameters used
		sout_means,sout_stds,ovl_u_means,ovl_u_stds,ovl_v_means,ovl_v_stds,(;βs,α,λ)
	end 

	# plot data
	local fig, axs = plot_beta_dependence(βs,sout_means,sout_stds,ovl_u_means,ovl_u_stds,ovl_v_means,ovl_v_stds;α,λ,n_samples=100)
	# save to file
	save("Neurips_plot7.pdf",fig)
	
	# the figure is the value of the whole block
	fig
end
 