"""
DB-KSVD training on DINOv2 embeddings.

Usage:
    julia --project=KSVD.jl scripts/ksvd_dino.jl embeddings.h5 output.npy [--no-wandb] [--dict-size=4096] [--nnz=16]
"""

using HDF5
using KSVD
using NPZ
using LinearAlgebra
using Random
using ProgressBars
using StatsBase
using SparseArrays
using ArgParse

"""
    parse_args()

Declare the command-line interface of this script and parse the process arguments.

Two required positionals (`embeddings`, `output`) are followed by integer-valued
options and two boolean flags. The value returned is whatever
`ArgParse.parse_args` produces for these settings; `main` indexes it by argument
name without the leading dashes (e.g. `"dict-size"`, `"no-wandb"`).
"""
function parse_args()
    settings = ArgParseSettings()

    # Function form of the argument table: each name (or vector of aliases) is
    # followed by a Dict holding that argument's options.
    add_arg_table!(
        settings,
        "embeddings",
        Dict(:help => "Path to HDF5 file with embeddings", :required => true),
        "output",
        Dict(:help => "Output path for dictionary (.npy)", :required => true),
        ["--dict-size", "-m"],
        Dict(:help => "Dictionary size", :arg_type => Int, :default => 4096),
        ["--nnz", "-k"],
        Dict(
            :help => "Non-zeros per column (sparsity)",
            :arg_type => Int,
            :default => 16,
        ),
        "--batch-size",
        Dict(:help => "Batch size for KSVD", :arg_type => Int, :default => 65536),
        "--iters-per-batch",
        Dict(:help => "KSVD iterations per batch", :arg_type => Int, :default => 1),
        "--num-repeats",
        Dict(:help => "Number of passes over data", :arg_type => Int, :default => 3),
        "--val-size",
        Dict(:help => "Validation set size", :arg_type => Int, :default => 16384),
        "--no-wandb",
        Dict(:help => "Disable wandb logging", :action => :store_true),
        "--matryoshka",
        Dict(
            :help => "Use Matryoshka training loop (log2min=8)",
            :action => :store_true,
        ),
    )

    return ArgParse.parse_args(settings)
end

# Metrics
"""
    explainedsignal(Y, D, X; E=(Y - D * X))

Return the mean, over columns, of `norm(E[:, j]) / norm(Y[:, j])`.

`E` defaults to the reconstruction residual `Y - D * X`; pass it explicitly to reuse
an already computed residual. Despite the name, the value is the average *relative
residual* norm, so smaller is better. A column of `Y` with zero norm makes its
ratio `Inf` or `NaN`.
"""
function explainedsignal(Y, D, X; E=(Y - D * X))
    residual_norms = map(norm, eachcol(E))
    signal_norms = map(norm, eachcol(Y))
    return mean(residual_norms ./ signal_norms)
end
"""
    explainedvariance(Y, D, X; E=(Y - D * X))

Return `1 - sum(var(E; dims=2)) / sum(var(Y; dims=2))`: one minus the ratio of the
summed per-row variance of the residual `E` to the summed per-row variance of `Y`.

`E` defaults to the reconstruction residual `Y - D * X`; pass it explicitly to reuse
an already computed residual.
"""
function explainedvariance(Y, D, X; E=(Y - D * X))
    residual_variance = sum(var(E; dims=2))
    total_variance = sum(var(Y; dims=2))
    return 1 - residual_variance / total_variance
end

"""
    _val_codes_path(output)

Derive the file path for the validation sparse codes from the dictionary path.

A trailing `.npy` is stripped before `_val_codes.npz` is appended, so the result
always differs from `output` and directory components are never rewritten.
"""
function _val_codes_path(output::AbstractString)
    # `chop` works on characters, so non-ASCII stems are handled safely.
    stem = endswith(output, ".npy") ? chop(output; tail=4) : output
    return stem * "_val_codes.npz"
end

"""
    main()

Run the training script end to end:

1. parse the command line with `parse_args`;
2. read the `"embeddings"` dataset from the HDF5 file and hold out the trailing
   columns as a validation set;
3. for each repeat, run `KSVD.ksvd` on every full-size batch of training columns,
   warm-starting each call from the previous dictionary and printing validation
   metrics from the callback;
4. write the dictionary to the `output` path and the dense validation sparse codes
   next to it (see `_val_codes_path`).

Throws an `ArgumentError` when the effective batch size is below 1, and an error
when no batch was trained (e.g. `--num-repeats 0`).
"""
function main()
    args = parse_args()

    # Load embeddings. The transpose makes columns the samples.
    # NOTE(review): this assumes the dataset has shape (d, n) as seen from the
    # writer (row-major), so HDF5.jl returns (n, d) - confirm against the producer.
    println("Loading embeddings from $(args["embeddings"])...")
    data = h5open(args["embeddings"], "r") do f
        permutedims(read(f["embeddings"]))
    end
    d, n = size(data)
    println("Loaded $n embeddings of dimension $d")

    # Hold out the trailing columns for validation: at most 10% of the data, and
    # never a negative count (a negative --val-size would push train_end past n).
    val_size = clamp(args["val-size"], 0, n ÷ 10)
    val_data = data[:, (end-(val_size-1)):end]
    train_end = n - val_size
    println("Using $val_size samples for validation")

    # Config
    dict_size = args["dict-size"]
    nnz_per_col = args["nnz"]
    batch_size = min(args["batch-size"], train_end)  # Don't exceed available data
    iters_per_batch = args["iters-per-batch"]
    num_repeats = args["num-repeats"]
    use_wandb = !args["no-wandb"]

    if batch_size < args["batch-size"]
        println("Adjusted batch_size to $batch_size (data size: $train_end)")
    end
    batch_size >= 1 || throw(ArgumentError(
        "batch size must be at least 1 (got $batch_size; " *
        "$train_end training samples available)",
    ))

    use_matryoshka = args["matryoshka"]
    ksvd_loop_type = use_matryoshka ? KSVD.MatryoshkaLoop(; log2min=8) : KSVD.NormalLoop()

    # Run configuration intended for the experiment logger; currently unused
    # because wandb logging is stubbed out below.
    config = Dict(
        "dict_size" => dict_size,
        "nnz_per_col" => nnz_per_col,
        "batch_size" => batch_size,
        "iters_per_batch" => iters_per_batch,
        "num_repeats" => num_repeats,
        "embedding_dim" => d,
        "num_samples" => train_end,
        "matryoshka" => use_matryoshka,
    )

    # Wandb is not wired up: `lg` stays `nothing`, so metrics are only printed.
    lg = nothing
    if use_wandb
        @warn "Wandb support requires adding Wandb.jl to Project.toml. Running without logging."
        use_wandb = false
    end

    sparse_coding_method = KSVD.ParallelMatchingPursuit(; max_nnz=nnz_per_col, refit_coeffs=false)

    # Called by KSVD.ksvd with the current training state; evaluates the
    # dictionary on the held-out set and reports the metrics.
    function callback_fn((; iter, Y, D, X, norm_val, nnz_per_col_val))
        variance_expl = explainedvariance(Y, D, X)
        # Named differently from `val_X` in the enclosing function on purpose:
        # sharing the name would make the closure capture a boxed variable.
        val_codes = KSVD.sparse_coding(sparse_coding_method, val_data, D)
        val_norm_val = explainedsignal(val_data, D, val_codes)
        val_variance_expl = explainedvariance(val_data, D, val_codes)
        val_nnz = nnz(val_codes) / size(val_codes, 2)

        # Usage count per atom that appears at least once in X; atoms that are
        # never used are absent from the map, so the minimum is over used atoms.
        usagecounts = countmap(X.rowval)
        minusage = isempty(usagecounts) ? 0 : minimum(values(usagecounts))
        numunused = size(X, 1) - length(usagecounts)

        metrics = Dict(
            "num_unused_dicts" => numunused,
            "dict_min_usage" => minusage,
            "train_norm_val" => norm_val,
            "train_variance_expl" => variance_expl,
            "train_nnz_per_col" => nnz_per_col_val,
            "val_norm_val" => val_norm_val,
            "val_variance_expl" => val_variance_expl,
            "val_nnz_per_col" => val_nnz,
        )

        if !isnothing(lg)
            Wandb.log(lg, metrics)
        else
            println("  val_variance_expl=$(round(val_variance_expl, digits=4)), val_nnz=$(round(val_nnz, digits=1))")
        end
        return nothing
    end

    loop_str = use_matryoshka ? "Matryoshka(log2min=8)" : "Normal"
    println("Training DB-KSVD with dict_size=$dict_size, nnz=$nnz_per_col, loop=$loop_str...")

    # The batch partition is the same for every repeat, so build it once.
    batch_indices = collect(Iterators.partition(1:train_end, batch_size))

    D = nothing
    for rep in 1:num_repeats
        println("Repeat $rep/$num_repeats")

        for (i, batch_idx) in enumerate(batch_indices)
            # Skip the trailing incomplete batch, if any.
            length(batch_idx) == batch_size || continue

            # Range indexing already allocates a fresh matrix; no extra copy needed.
            Y = data[:, batch_idx]
            # Warm-start from the previous dictionary (a copy, in case ksvd mutates it).
            D_init = isnothing(D) ? nothing : copy(D)

            println("  Batch $i/$(length(batch_indices)), samples $(first(batch_idx))-$(last(batch_idx))")
            res = KSVD.ksvd(Y, dict_size;
                sparse_coding_method,
                ksvd_loop_type,
                verbose=false,
                show_trace=false,
                D_init,
                maxiters=iters_per_batch,
                callback_fn,
                abstol=nothing,
                reltol=nothing,
            )
            D = res.D
        end
    end

    isnothing(D) &&
        error("No batch was trained (num_repeats=$num_repeats); nothing to save.")

    # Save dictionary
    println("Saving dictionary to $(args["output"])...")
    npzwrite(args["output"], D)

    # Also save sparse codes for the validation set, stored dense.
    # NOTE(review): NPZ.jl appears to write a single array in .npy format
    # regardless of the file extension - confirm; the ".npz" name is kept so
    # existing consumers still find the file.
    val_X = KSVD.sparse_coding(sparse_coding_method, val_data, D)
    val_X_path = _val_codes_path(args["output"])
    npzwrite(val_X_path, Matrix(val_X))
    println("Saved validation codes to $val_X_path")

    if !isnothing(lg)
        close(lg)
    end

    println("Done!")
end

# Script entry point: the call is unconditional, so the training pipeline runs as
# soon as this file is loaded (executed directly or `include`d).
main()
