#!/bin/bash
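# Sweep orchestration script for DB-KSVD vs SAE comparison
# Usage: ./scripts/run_sweep.sh [phase] [options]   (see --help for options)
# Phases: all (default), extract, train, train-sae, train-ksvd, eval,
#         eval-recon, eval-cluster, eval-probe, eval-sparse-probe, summary
# Environment: MATRYOSHKA=true and DRY_RUN=true are equivalent to passing
#              --matryoshka and --dry.
#
# Extraction and training are idempotent: a step is skipped when its output
# file already exists (delete that file to re-run the step). Evaluation phases
# always re-run and overwrite their per-model results files.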

set -euo pipefail

# ----------------------------------------------------------------------------
# Sweep grid. main() may narrow MODELS / SPARSITIES from --model / --k.
# ----------------------------------------------------------------------------
MODELS=(vits14 vitb14)       # DINOv2 backbones (extract uses dinov2_<model>)
SPARSITIES=(16 32 64)        # nonzeros per code (SAE --k, KSVD --nnz)
DICT_SIZE=4096               # dictionary size passed to both trainers
SAE_STEPS=50000              # --num-steps for SAE training

# Toggles compared against the literal "true"; may be preset in the environment.
: "${MATRYOSHKA:=false}"     # KSVD matryoshka training + _matryoshka file suffix
: "${DRY_RUN:=false}"        # cheap sparse-probe settings for local runs

# Working directories (relative to the current working directory).
DATA_DIR="data"              # embeddings: imagenet_<model>.h5
MODEL_DIR="models"           # trained SAE / KSVD artifacts, one subdir per model
RESULTS_DIR="results"        # aggregated per-model evaluation reports
CACHE_DIR="cache"            # codes cached by eval-recon, reused by later evals
LOG_DIR="logs/sweep"         # master log plus one log per step

# ANSI colour escapes, kept as literal backslash sequences for `echo -e`.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m'                 # reset ("no colour")

# One timestamped master log per invocation; its directory has to exist
# before log() tees into it.
mkdir -p "$LOG_DIR"
MASTER_LOG="${LOG_DIR}/sweep_$(date +%Y%m%d_%H%M%S).log"

# Print "[YYYY-mm-dd HH:MM:SS] [LEVEL] message" to stdout and append the same
# line to $MASTER_LOG. Backslash escapes (e.g. colour codes) are interpreted.
#   $1 - level tag (may itself contain colour escapes)
#   $2 - message text
log() {
    local level="$1" msg="$2"
    local stamp="$(date '+%Y-%m-%d %H:%M:%S')"
    local line="[$stamp] [$level] $msg"
    echo -e "$line" | tee -a "$MASTER_LOG"
}

# Level-specific wrappers around log(); OK/SKIP/ERROR tags are colourised.
log_info() {
    log "INFO" "$1"
}

log_ok() {
    log "${GREEN}OK${NC}" "$1"
}

log_skip() {
    log "${YELLOW}SKIP${NC}" "$1"
}

log_error() {
    log "${RED}ERROR${NC}" "$1"
}

# Blue section banner preceded by a blank line; no timestamp, unlike log().
log_phase() {
    local banner="\n${BLUE}========== $1 ==========${NC}"
    echo -e "$banner" | tee -a "$MASTER_LOG"
}

# Run a command with logging, return success/failure
# Run a shell command string, sending its stdout+stderr to a per-step log.
#   $1  - human-readable description used in log messages
#   $2  - path of the per-step log file (parent directory is created)
#   $3+ - the command; joined with spaces and passed to eval in the CURRENT
#         shell, so it may use prefix assignments like VAR=x cmd ...
# On failure, the last 20 lines of the step log are copied into $MASTER_LOG.
# Returns the command's exit status.
run_cmd() {
    local desc="$1" logfile="$2"
    shift 2
    local cmd="$*"
    local rc=0

    log_info "Starting: $desc"
    log_info "Log: $logfile"
    log_info "Command: $cmd"

    mkdir -p "$(dirname "$logfile")"

    eval "$cmd" > "$logfile" 2>&1 || rc=$?

    if (( rc == 0 )); then
        log_ok "Completed: $desc"
        return 0
    fi

    log_error "Failed: $desc (exit code $rc)"
    log_error "Check log: $logfile"
    echo "--- Last 20 lines of $logfile ---" >> "$MASTER_LOG"
    tail -20 "$logfile" >> "$MASTER_LOG" 2>/dev/null || true
    return "$rc"
}

# Check if output file exists
# Succeed iff $1 names an existing regular file (directories do not count).
exists() {
    test -f "$1"
}

# ============================================================================
# Phase 1: Extraction
# ============================================================================
# Extract DINOv2 embeddings for every entry in MODELS into
# $DATA_DIR/imagenet_<model>.h5; models whose file already exists are skipped.
# The dataset cache handed to the extractor defaults to
# /nfs/redacted/huggingface and can be overridden with HF_CACHE_DIR.
# Returns the number of models whose extraction failed (0 = all ok).
phase_extract() {
    log_phase "PHASE 1: EXTRACTION"

    mkdir -p "$DATA_DIR" "$LOG_DIR"

    local hf_cache="${HF_CACHE_DIR:-/nfs/redacted/huggingface}"
    # Counted with failed=$((failed + 1)), not ((failed++)): the post-increment
    # form exits with status 1 when failed is 0 and would abort under `set -e`
    # whenever this function is not called as part of a condition.
    local failed=0
    local model output logfile dino_model

    for model in "${MODELS[@]}"; do
        output="$DATA_DIR/imagenet_${model}.h5"
        logfile="$LOG_DIR/extract_${model}.log"

        if exists "$output"; then
            log_skip "Embeddings exist: $output"
            continue
        fi

        dino_model="dinov2_${model}"

        if ! run_cmd "Extract $model embeddings" "$logfile" \
            "uv run python -u -m src.cli extract \
                --output=$output \
                --model=$dino_model \
                --dataset=imagenet-1k \
                --cache-dir=$hf_cache"; then
            failed=$((failed + 1))
        fi
    done

    return "$failed"
}

# ============================================================================
# Phase 2a: Train SAE models
# ============================================================================
# Train one SAE per (model, k) into $MODEL_DIR/<model>/sae_k<k>.pt, skipping
# pairs whose output already exists. Models without extracted embeddings are
# reported once and counted as a single failure.
# Returns the number of failures (missing embeddings + failed trainings).
phase_train_sae() {
    log_phase "PHASE 2a: TRAIN SAE"

    # failed=$((failed + 1)) rather than ((failed++)), which returns status 1
    # when failed is 0 and would trip `set -e` on a direct call.
    local failed=0
    local model embeddings k output logfile

    for model in "${MODELS[@]}"; do
        mkdir -p "$MODEL_DIR/$model"
        embeddings="$DATA_DIR/imagenet_${model}.h5"

        if ! exists "$embeddings"; then
            log_error "Missing embeddings: $embeddings (run extract phase first)"
            failed=$((failed + 1))
            continue
        fi

        for k in "${SPARSITIES[@]}"; do
            output="$MODEL_DIR/$model/sae_k${k}.pt"
            logfile="$LOG_DIR/train_sae_${model}_k${k}.log"

            if exists "$output"; then
                log_skip "SAE model exists: $output"
                continue
            fi

            if ! run_cmd "Train SAE $model k=$k" "$logfile" \
                "uv run python -u -m src.cli train-sae \
                    --embeddings=$embeddings \
                    --output=$output \
                    --dict-size=$DICT_SIZE \
                    --k=$k \
                    --num-steps=$SAE_STEPS"; then
                failed=$((failed + 1))
            fi
        done
    done

    return "$failed"
}

# ============================================================================
# Phase 2b: Train KSVD models
# ============================================================================
# Train one KSVD dictionary per (model, k) with scripts/ksvd_dino.jl into
# $MODEL_DIR/<model>/ksvd_k<k>[_matryoshka].npy, skipping existing outputs.
# With MATRYOSHKA=true the script also receives --matryoshka.
# Returns the number of failures (missing embeddings + failed trainings).
phase_train_ksvd() {
    log_phase "PHASE 2b: TRAIN KSVD"

    # failed=$((failed + 1)) rather than ((failed++)), which returns status 1
    # when failed is 0 and would trip `set -e` on a direct call.
    local failed=0
    local matryoshka_suffix=""
    local matryoshka_arg=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
        matryoshka_arg="--matryoshka"
    fi

    local model embeddings k output logfile

    for model in "${MODELS[@]}"; do
        mkdir -p "$MODEL_DIR/$model"
        embeddings="$DATA_DIR/imagenet_${model}.h5"

        if ! exists "$embeddings"; then
            log_error "Missing embeddings: $embeddings (run extract phase first)"
            failed=$((failed + 1))
            continue
        fi

        for k in "${SPARSITIES[@]}"; do
            output="$MODEL_DIR/$model/ksvd_k${k}${matryoshka_suffix}.npy"
            logfile="$LOG_DIR/train_ksvd_${model}_k${k}${matryoshka_suffix}.log"

            if exists "$output"; then
                log_skip "KSVD model exists: $output"
                continue
            fi

            if ! run_cmd "Train KSVD $model k=$k${matryoshka_suffix}" "$logfile" \
                "JULIA_NUM_THREADS=auto julia --project=. scripts/ksvd_dino.jl \
                    $embeddings $output \
                    --dict-size=$DICT_SIZE --nnz=$k $matryoshka_arg"; then
                failed=$((failed + 1))
            fi
        done
    done

    return "$failed"
}

# ============================================================================
# Phase 3a: Eval reconstruction
# ============================================================================
# Run the reconstruction evaluation for every (model, k) whose embeddings, SAE
# model, SAE dictionary and KSVD dictionary all exist. Each run is passed
# --cache-codes=$CACHE_DIR/<model>_k<k>[_matryoshka]; the later eval phases
# look for cached codes there. Successful step logs are appended to
# $RESULTS_DIR/recon_<model>[_matryoshka].txt.
# Returns the number of (model, k) pairs with missing inputs or failed runs.
phase_eval_recon() {
    log_phase "PHASE 3a: EVAL RECONSTRUCTION"

    mkdir -p "$RESULTS_DIR" "$CACHE_DIR"

    # failed=$((failed + 1)) rather than ((failed++)), which returns status 1
    # when failed is 0 and would trip `set -e` on a direct call.
    local failed=0
    local matryoshka_suffix=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
    fi

    local model embeddings results_file k missing
    local sae_model sae_dict ksvd_dict cache_prefix logfile

    for model in "${MODELS[@]}"; do
        embeddings="$DATA_DIR/imagenet_${model}.h5"
        results_file="$RESULTS_DIR/recon_${model}${matryoshka_suffix}.txt"

        # Clear results file for a fresh run. Only the sparsities selected in
        # this invocation are written back, so a --k run drops the others.
        : > "$results_file"

        for k in "${SPARSITIES[@]}"; do
            sae_model="$MODEL_DIR/$model/sae_k${k}.pt"
            sae_dict="$MODEL_DIR/$model/sae_k${k}.D.npy"
            ksvd_dict="$MODEL_DIR/$model/ksvd_k${k}${matryoshka_suffix}.npy"
            cache_prefix="$CACHE_DIR/${model}_k${k}${matryoshka_suffix}"
            logfile="$LOG_DIR/eval_recon_${model}_k${k}${matryoshka_suffix}.log"

            # Collect every missing input so one error line names all of them
            missing=""
            exists "$embeddings" || missing+=" embeddings"
            exists "$sae_model"  || missing+=" sae_model"
            exists "$sae_dict"   || missing+=" sae_dict"
            exists "$ksvd_dict"  || missing+=" ksvd_dict"

            if [[ -n "$missing" ]]; then
                log_error "Missing for $model k=$k${matryoshka_suffix}:$missing"
                failed=$((failed + 1))
                continue
            fi

            echo "=== ${model} k=${k}${matryoshka_suffix} ===" >> "$results_file"

            if ! run_cmd "Eval recon $model k=$k${matryoshka_suffix}" "$logfile" \
                "uv run python -u -m src.cli evaluate \
                    --embeddings=$embeddings \
                    --sae-model=$sae_model \
                    --sae-dict=$sae_dict \
                    --ksvd-dict=$ksvd_dict \
                    --k=$k \
                    --evals=recon \
                    --cache-codes=$cache_prefix"; then
                failed=$((failed + 1))
            else
                # Append the full step log to the per-model results file
                cat "$logfile" >> "$results_file"
            fi
        done

        log_info "Results written to: $results_file"
    done

    return "$failed"
}

# ============================================================================
# Phase 3b: Eval clustering (uses cached codes)
# ============================================================================
# Run the clustering evaluation for every (model, k) whose cached codes from
# eval-recon are present. Successful step logs are appended to
# $RESULTS_DIR/cluster_<model>[_matryoshka].txt.
# Returns the number of (model, k) pairs with missing codes or failed runs.
phase_eval_cluster() {
    log_phase "PHASE 3b: EVAL CLUSTERING"

    mkdir -p "$RESULTS_DIR"

    # failed=$((failed + 1)) rather than ((failed++)), which returns status 1
    # when failed is 0 and would trip `set -e` on a direct call.
    local failed=0
    local matryoshka_suffix=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
    fi

    local model embeddings results_file k
    local sae_model sae_dict ksvd_dict cache_prefix logfile

    for model in "${MODELS[@]}"; do
        embeddings="$DATA_DIR/imagenet_${model}.h5"
        results_file="$RESULTS_DIR/cluster_${model}${matryoshka_suffix}.txt"

        # Clear results file for a fresh run (only selected sparsities are rewritten)
        : > "$results_file"

        for k in "${SPARSITIES[@]}"; do
            sae_model="$MODEL_DIR/$model/sae_k${k}.pt"
            sae_dict="$MODEL_DIR/$model/sae_k${k}.D.npy"
            ksvd_dict="$MODEL_DIR/$model/ksvd_k${k}${matryoshka_suffix}.npy"
            cache_prefix="$CACHE_DIR/${model}_k${k}${matryoshka_suffix}"
            logfile="$LOG_DIR/eval_cluster_${model}_k${k}${matryoshka_suffix}.log"

            # Check if cached codes exist (from recon phase)
            # Note: evaluate.py creates files at {cache_dir}/codes_k{k}_*.npz
            if ! exists "${cache_prefix}/codes_k${k}_sae_train.npz"; then
                log_error "Missing cached codes: ${cache_prefix}/codes_k${k}_sae_train.npz (run eval-recon first)"
                failed=$((failed + 1))
                continue
            fi

            echo "=== ${model} k=${k}${matryoshka_suffix} ===" >> "$results_file"

            if ! run_cmd "Eval cluster $model k=$k${matryoshka_suffix}" "$logfile" \
                "uv run python -u -m src.cli evaluate \
                    --embeddings=$embeddings \
                    --sae-model=$sae_model \
                    --sae-dict=$sae_dict \
                    --ksvd-dict=$ksvd_dict \
                    --k=$k \
                    --evals=cluster \
                    --cache-codes=$cache_prefix"; then
                failed=$((failed + 1))
            else
                # Append the full step log to the per-model results file
                cat "$logfile" >> "$results_file"
            fi
        done

        log_info "Results written to: $results_file"
    done

    return "$failed"
}

# ============================================================================
# Phase 3c: Eval linear probe (uses cached codes, GPU-accelerated)
# ============================================================================
# Run the linear-probe evaluation for every (model, k) whose cached codes from
# eval-recon are present. Successful step logs are appended to
# $RESULTS_DIR/probe_<model>[_matryoshka].txt.
# Returns the number of (model, k) pairs with missing codes or failed runs.
phase_eval_probe() {
    log_phase "PHASE 3c: EVAL LINEAR PROBE"

    mkdir -p "$RESULTS_DIR"

    # failed=$((failed + 1)) rather than ((failed++)), which returns status 1
    # when failed is 0 and would trip `set -e` on a direct call.
    local failed=0
    local matryoshka_suffix=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
    fi

    local model embeddings results_file k
    local sae_model sae_dict ksvd_dict cache_prefix logfile

    for model in "${MODELS[@]}"; do
        embeddings="$DATA_DIR/imagenet_${model}.h5"
        results_file="$RESULTS_DIR/probe_${model}${matryoshka_suffix}.txt"

        # Clear results file for a fresh run (only selected sparsities are rewritten)
        : > "$results_file"

        for k in "${SPARSITIES[@]}"; do
            sae_model="$MODEL_DIR/$model/sae_k${k}.pt"
            sae_dict="$MODEL_DIR/$model/sae_k${k}.D.npy"
            ksvd_dict="$MODEL_DIR/$model/ksvd_k${k}${matryoshka_suffix}.npy"
            cache_prefix="$CACHE_DIR/${model}_k${k}${matryoshka_suffix}"
            logfile="$LOG_DIR/eval_probe_${model}_k${k}${matryoshka_suffix}.log"

            # Check if cached codes exist (from recon phase)
            # Note: evaluate.py creates files at {cache_dir}/codes_k{k}_*.npz
            if ! exists "${cache_prefix}/codes_k${k}_sae_train.npz"; then
                log_error "Missing cached codes: ${cache_prefix}/codes_k${k}_sae_train.npz (run eval-recon first)"
                failed=$((failed + 1))
                continue
            fi

            echo "=== ${model} k=${k}${matryoshka_suffix} ===" >> "$results_file"

            if ! run_cmd "Eval probe $model k=$k${matryoshka_suffix}" "$logfile" \
                "uv run python -u -m src.cli evaluate \
                    --embeddings=$embeddings \
                    --sae-model=$sae_model \
                    --sae-dict=$sae_dict \
                    --ksvd-dict=$ksvd_dict \
                    --k=$k \
                    --evals=probe \
                    --cache-codes=$cache_prefix"; then
                failed=$((failed + 1))
            else
                # Append the full step log to the per-model results file
                cat "$logfile" >> "$results_file"
            fi
        done

        log_info "Results written to: $results_file"
    done

    return "$failed"
}

# ============================================================================
# Phase 3d: Eval sparse probe (variance-based feature selection)
# ============================================================================
# Run scripts/sparse_probe.py for every (model, k) whose cached codes and
# labels are present. Per-run JSON goes to
# $RESULTS_DIR/sparse_probe_<model>_k<k>[_matryoshka].json, and successful step
# logs are appended to $RESULTS_DIR/sparse_probe_<model>[_matryoshka].txt.
# DRY_RUN=true switches to CPU, 3 epochs and a subsampled dataset.
# Returns the number of (model, k) pairs with missing inputs or failed runs.
phase_eval_sparse_probe() {
    log_phase "PHASE 3d: EVAL SPARSE PROBE"

    mkdir -p "$RESULTS_DIR"

    # failed=$((failed + 1)) rather than ((failed++)), which returns status 1
    # when failed is 0 and would trip `set -e` on a direct call.
    local failed=0
    local matryoshka_suffix=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
    fi

    # Probe settings do not depend on model or k, so resolve them once.
    local dry_args=""
    local device="cuda"
    local epochs=10
    if [[ "$DRY_RUN" == "true" ]]; then
        dry_args="--max-samples=10000 --max-classes=20"
        device="cpu"
        epochs=3
    fi

    local model results_file k cache_prefix logfile json_output

    for model in "${MODELS[@]}"; do
        results_file="$RESULTS_DIR/sparse_probe_${model}${matryoshka_suffix}.txt"

        # Clear results file for a fresh run (only selected sparsities are rewritten)
        : > "$results_file"

        for k in "${SPARSITIES[@]}"; do
            cache_prefix="$CACHE_DIR/${model}_k${k}${matryoshka_suffix}"
            logfile="$LOG_DIR/eval_sparse_probe_${model}_k${k}${matryoshka_suffix}.log"

            # Check if cached codes and labels exist
            if ! exists "${cache_prefix}/codes_k${k}_sae_train.npz"; then
                log_error "Missing cached codes: ${cache_prefix}/codes_k${k}_sae_train.npz (run eval-recon first)"
                failed=$((failed + 1))
                continue
            fi
            if ! exists "${cache_prefix}/labels.npz"; then
                log_error "Missing labels: ${cache_prefix}/labels.npz"
                failed=$((failed + 1))
                continue
            fi

            echo "=== ${model} k=${k}${matryoshka_suffix} ===" >> "$results_file"

            json_output="$RESULTS_DIR/sparse_probe_${model}_k${k}${matryoshka_suffix}.json"

            # NOTE(review): the checks above look in $cache_prefix, which carries
            # the _matryoshka suffix, but sparse_probe.py only receives
            # --cache-dir=$CACHE_DIR plus --model/--sparsity-k. If the script
            # derives "<model>_k<k>" itself, it will read the non-matryoshka
            # codes when MATRYOSHKA=true. Confirm against scripts/sparse_probe.py.
            if ! run_cmd "Eval sparse probe $model k=$k${matryoshka_suffix}" "$logfile" \
                "uv run python -u scripts/sparse_probe.py \
                    --model=${model} \
                    --sparsity-k=${k} \
                    --probe-k 1 2 5 \
                    --device=${device} \
                    --epochs=${epochs} \
                    --cache-dir=$CACHE_DIR \
                    --output=$json_output \
                    $dry_args"; then
                failed=$((failed + 1))
            else
                # Append the full step log to the per-model results file
                cat "$logfile" >> "$results_file"
            fi
        done

        log_info "Results written to: $results_file"
    done

    return "$failed"
}

# ============================================================================
# Summary
# ============================================================================
# Print which embeddings, SAE/KSVD models and results files exist for the
# current MODELS x SPARSITIES selection, followed by the master log path.
print_summary() {
    log_phase "SWEEP SUMMARY"

    local model k f size lines
    local sae ksvd sae_status ksvd_status
    local matryoshka_suffix=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
    fi

    echo ""
    echo "=== Embeddings ==="
    for model in "${MODELS[@]}"; do
        f="$DATA_DIR/imagenet_${model}.h5"
        if exists "$f"; then
            size=$(du -h "$f" | cut -f1)
            echo "  [OK] $f ($size)"
        else
            echo "  [MISSING] $f"
        fi
    done

    echo ""
    echo "=== Models ==="
    for model in "${MODELS[@]}"; do
        for k in "${SPARSITIES[@]}"; do
            sae="$MODEL_DIR/$model/sae_k${k}.pt"
            ksvd="$MODEL_DIR/$model/ksvd_k${k}${matryoshka_suffix}.npy"
            sae_status="MISSING"
            ksvd_status="MISSING"
            if exists "$sae"; then sae_status="OK"; fi
            if exists "$ksvd"; then ksvd_status="OK"; fi
            echo "  $model k=$k: SAE=[$sae_status] KSVD${matryoshka_suffix}=[$ksvd_status]"
        done
    done

    echo ""
    echo "=== Results ==="
    # An unmatched glob expands to the literal pattern and the loop still exits
    # 0, so count real files to decide whether to print the placeholder.
    local found=0
    for f in "$RESULTS_DIR"/*.txt; do
        [[ -f "$f" ]] || continue
        lines=$(wc -l < "$f")
        # Strip the leading padding some wc implementations emit
        echo "  $f (${lines//[[:space:]]/} lines)"
        found=$((found + 1))
    done
    if (( found == 0 )); then
        echo "  (no results yet)"
    fi

    echo ""
    echo "Master log: $MASTER_LOG"
}

# ============================================================================
# Main
# ============================================================================
# Parse CLI arguments, run the selected phase(s) in order, print a summary and
# exit. Arguments are order-independent; a later phase name overrides an
# earlier one.
#   [phase]              one of the names accepted below (default: all)
#   --model=M[,M...]     restrict MODELS (comma-separated list)
#   --k=K[,K...]         restrict SPARSITIES (comma-separated positive integers)
#   --matryoshka         set MATRYOSHKA=true
#   --dry                set DRY_RUN=true
# Exits 0 when every phase reported zero failures; 1 on any failure or bad usage.
main() {
    local phase="all"
    local arg item

    for arg in "$@"; do
        case "$arg" in
            --matryoshka)
                MATRYOSHKA="true"
                ;;
            --model=*)
                IFS=', ' read -r -a MODELS <<< "${arg#*=}"
                ;;
            --k=*)
                IFS=', ' read -r -a SPARSITIES <<< "${arg#*=}"
                ;;
            --dry)
                DRY_RUN="true"
                ;;
            extract|train|train-sae|train-ksvd|eval|eval-recon|eval-cluster|eval-probe|eval-sparse-probe|all|summary)
                phase="$arg"
                ;;
            -h|--help)
                echo "Usage: $0 [phase] [options]"
                echo ""
                echo "Phases: all, extract, train, train-sae, train-ksvd, eval, eval-recon, eval-cluster, eval-probe, eval-sparse-probe, summary"
                echo ""
                echo "Options:"
                echo "  --model=MODEL    Run only for specific model(s), comma-separated (e.g. vits14,vitb14)"
                echo "  --k=K            Run only for specific sparsity(ies), comma-separated (e.g. 16,32)"
                echo "  --matryoshka     Use matryoshka training for KSVD"
                echo "  --dry            Quick local test (reduced samples, fewer epochs)"
                exit 0
                ;;
            *)
                echo "Unknown argument: $arg" >&2
                exit 1
                ;;
        esac
    done

    # Reject empty/invalid selections up front instead of producing paths like
    # data/imagenet_.h5 or passing a non-numeric k to every trainer.
    if (( ${#MODELS[@]} == 0 )); then
        echo "--model requires at least one model name" >&2
        exit 1
    fi
    for item in "${MODELS[@]}"; do
        if [[ -z "$item" ]]; then
            echo "--model contains an empty model name" >&2
            exit 1
        fi
    done
    if (( ${#SPARSITIES[@]} == 0 )); then
        echo "--k requires at least one sparsity value" >&2
        exit 1
    fi
    for item in "${SPARSITIES[@]}"; do
        if [[ ! "$item" =~ ^[1-9][0-9]*$ ]]; then
            echo "--k values must be positive integers, got: '$item'" >&2
            exit 1
        fi
    done

    log_info "Starting sweep with phase: $phase"
    log_info "Models: ${MODELS[*]}, Sparsities: ${SPARSITIES[*]}, Matryoshka: $MATRYOSHKA, Dry: $DRY_RUN"
    log_info "Master log: $MASTER_LOG"

    # Map the requested phase onto the ordered list of phase functions to run.
    local -a steps=()
    case "$phase" in
        extract)           steps=(phase_extract) ;;
        train-sae)         steps=(phase_train_sae) ;;
        train-ksvd)        steps=(phase_train_ksvd) ;;
        train)             steps=(phase_train_sae phase_train_ksvd) ;;
        eval-recon)        steps=(phase_eval_recon) ;;
        eval-cluster)      steps=(phase_eval_cluster) ;;
        eval-probe)        steps=(phase_eval_probe) ;;
        eval-sparse-probe) steps=(phase_eval_sparse_probe) ;;
        eval)
            steps=(phase_eval_recon phase_eval_cluster phase_eval_probe phase_eval_sparse_probe)
            ;;
        all)
            steps=(phase_extract phase_train_sae phase_train_ksvd
                   phase_eval_recon phase_eval_cluster phase_eval_probe phase_eval_sparse_probe)
            ;;
        summary)
            ;;
        *)
            echo "Usage: $0 [phase] [options] (use --help for details)"
            exit 1
            ;;
    esac

    local total_failed=0 step rc
    if (( ${#steps[@]} > 0 )); then
        for step in "${steps[@]}"; do
            # Calling the phase on the left of `||` keeps one failing phase from
            # aborting the sweep under `set -e`; its failure count is summed.
            rc=0
            "$step" || rc=$?
            total_failed=$((total_failed + rc))
        done
    fi

    print_summary

    if (( total_failed > 0 )); then
        log_error "Sweep completed with $total_failed failures"
        exit 1
    fi
    log_ok "Sweep completed successfully"
    exit 0
}

# Entry point: hand every command-line argument to main, which always exits.
main "$@"
