#!/bin/bash
# Autointerp evaluation script for SAE vs KSVD comparison
# Usage: ./scripts/run_autointerp.sh [phase] [options]
#
# Phases: all (default), embeddings, codes, eval, summary
# Options: --model=vits14, --k=32, --dry-run, --n-samples=N, --n-features=N,
#          --matryoshka, -h|--help
#
# Architecture:
# 1. Embeddings are extracted in sorted order → embedding[i] = dataset[i]
# 2. Codes are computed per dictionary and cached (hash-based)
# 3. VLM + CLIP evaluation runs on cached codes
#
# This script is idempotent - skips steps whose outputs exist.

# Strict mode: abort on unhandled command failures, on use of unset variables,
# and when any stage of a pipeline fails.
set -euo pipefail

# Julia setup for juliacall. A value already present in the environment wins;
# otherwise fall back to the julia found on PATH and the current directory.
# `command -v` is the shell's builtin lookup, so no external `which` binary is
# needed. The assignment deliberately stays inside `export`: export's own zero
# status hides a failed lookup, so a missing julia leaves the variable empty
# instead of aborting the whole script under `set -e`.
export PYTHON_JULIACALL_EXE="${PYTHON_JULIACALL_EXE:-$(command -v julia)}"
export PYTHON_JULIACALL_PROJECT="${PYTHON_JULIACALL_PROJECT:-$(pwd)}"
export PYTHON_JULIACALL_HANDLE_SIGNALS=yes
export PYTHON_JULIACALL_THREADS=auto

# Source secrets (optional file in the current directory)
if [[ -f .secrets ]]; then
    source .secrets
fi

# ============================================================================
# Configuration (can be overridden by env vars or command line)
# ============================================================================
# Every setting below keeps a value already present in the environment and only
# falls back to the default shown when it is unset or empty.
: "${DRY_RUN:=false}"

# Default to single model/sparsity for targeted runs
: "${MODEL:=vits14}"
: "${K:=32}"
: "${MATRYOSHKA:=false}"

# Directories
: "${HF_CACHE:=/scratch/redacted/huggingface}"
: "${WORK_DIR:=cache/autointerp}"
: "${MODEL_DIR:=models}"
: "${RESULTS_DIR:=results/autointerp}"
: "${LOG_DIR:=logs/autointerp}"

# Eval settings
: "${N_FEATURES:=100}"

# Dry run pins a small fixed configuration; otherwise dataset and sample count
# stay overridable from the environment.
case "$DRY_RUN" in
    true)
        DATASET="cifar100"
        N_SAMPLES=1000
        N_FEATURES=5
        HF_CACHE=""  # use default cache for dry run
        ;;
    *)
        : "${DATASET:=imagenet-1k}"
        N_SAMPLES="${N_SAMPLES-}"  # empty = full dataset
        ;;
esac

# ============================================================================
# Colors and logging
# ============================================================================
# ANSI color sequences, stored as literal backslash text; they are turned into
# real escapes only when printed through `echo -e` by the logging helpers.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m'  # reset to the terminal's default color

# Each invocation appends to its own timestamped master log under $LOG_DIR.
mkdir -p "$LOG_DIR"
MASTER_LOG="$LOG_DIR/autointerp_$(date +%Y%m%d_%H%M%S).log"

# Print one timestamped line to stdout and append the same line to $MASTER_LOG.
#   $1 - level tag (may contain color escape text)
#   $2 - message
# Backslash escapes in the tag and message are interpreted (echo -e).
log() {
    local tag="$1" text="$2"
    local stamp
    # A failing `date` must not abort logging; fall through with what we have.
    stamp=$(date '+%Y-%m-%d %H:%M:%S') || true
    local line="[$stamp] [$tag] $text"
    echo -e "$line" | tee -a "$MASTER_LOG"
}

# Level-specific wrappers around log(); each takes the message as $1.
log_info() {
    log "INFO" "$1"
}

log_ok() {
    local tag="${GREEN}OK${NC}"
    log "$tag" "$1"
}

log_skip() {
    local tag="${YELLOW}SKIP${NC}"
    log "$tag" "$1"
}

log_error() {
    local tag="${RED}ERROR${NC}"
    log "$tag" "$1"
}

# Print a colored section banner, preceded by an empty line, to stdout and
# append it to $MASTER_LOG. $1 is the banner title.
log_phase() {
    local banner="\n${BLUE}========== $1 ==========${NC}"
    echo -e "$banner" | tee -a "$MASTER_LOG"
}

# Run a shell command string, mirroring its combined stdout/stderr into a
# per-step log file while showing only the last 20 lines on the terminal.
# Arguments:
#   $1   - human-readable description used in the log messages
#   $2   - log file path (its parent directory is created; the file is
#          overwritten)
#   $3.. - command text; remaining arguments are joined with spaces and handed
#          to `eval`, so callers must pre-quote anything containing whitespace
#          or shell metacharacters
# Returns:
#   0 on success, otherwise the failing exit status of the pipeline.
run_cmd() {
    local desc="$1"
    local logfile="$2"
    shift 2
    local cmd="$*"

    log_info "Starting: $desc"
    log_info "Log: $logfile"

    mkdir -p "$(dirname "$logfile")"

    # pipefail is enabled inside a subshell so a failing command is detected
    # even when the caller's shell does not have it set (otherwise the status
    # of `tail` would mask the failure), without touching the caller's options.
    local exit_code=0
    (
        set -o pipefail
        eval "$cmd" 2>&1 | tee "$logfile" | tail -20
    ) || exit_code=$?

    if [[ $exit_code -eq 0 ]]; then
        log_ok "Completed: $desc"
        return 0
    fi

    log_error "Failed: $desc (exit code $exit_code)"
    return "$exit_code"
}

# Succeed (status 0) only when $1 names an existing regular file.
exists() {
    test -f "$1"
}

# ============================================================================
# Derived paths
# ============================================================================
# Map the short model name in $MODEL to its DINOv2 model identifier.
# Outputs: the identifier on stdout.
# Exits with status 1 for an unsupported $MODEL.
# Callers in this script capture stdout with $(...), so the diagnostic is sent
# to stderr; on stdout it would be swallowed into the captured value and the
# user would see the script stop without any message.
get_dino_model() {
    case "$MODEL" in
        vits14) echo "dinov2_vits14" ;;
        vitb14) echo "dinov2_vitb14" ;;
        *)
            log_error "Unknown model: $MODEL" >&2
            exit 1
            ;;
    esac
}

# Initial values derived from the environment/default configuration.
# parse_args recomputes all four after command-line overrides are applied.
SAE_DICT="$MODEL_DIR/$MODEL/sae_k${K}.D.npy"
KSVD_DICT="$MODEL_DIR/$MODEL/ksvd_k${K}.npy"
# get_dino_model exits non-zero for an unsupported $MODEL; as a plain
# assignment this line then fails and `set -e` stops the script here.
DINO_MODEL=$(get_dino_model)
# "all" stands in for an empty N_SAMPLES in the cache file name.
EMB_CACHE="$WORK_DIR/embeddings_${DATASET}_${DINO_MODEL}_${N_SAMPLES:-all}.npz"

# ============================================================================
# Phase 1: Extract embeddings (sorted order)
# ============================================================================
# Phase 1: run `src.autointerp --embeddings-only` unless $EMB_CACHE already
# exists.
# Globals read: WORK_DIR, LOG_DIR, EMB_CACHE, SAE_DICT, MODEL, K, DATASET,
#               N_SAMPLES (empty = flag omitted), HF_CACHE (empty = flag omitted)
# Returns: 0 when the cache exists or the command succeeds, 1 when the SAE
#          dictionary is missing, otherwise run_cmd's failure status.
phase_embeddings() {
    log_phase "PHASE 1: EXTRACT EMBEDDINGS"

    mkdir -p "$WORK_DIR" "$LOG_DIR"

    if exists "$EMB_CACHE"; then
        log_skip "Embeddings cached: $EMB_CACHE"
        return 0
    fi

    # Need a dictionary to run (just to trigger extraction)
    if ! exists "$SAE_DICT"; then
        log_error "Dictionary not found: $SAE_DICT"
        return 1
    fi

    local samples_desc="${N_SAMPLES:-all}"
    local logfile="$LOG_DIR/embeddings_${MODEL}_${samples_desc}.log"

    # Build the command as an array so every value stays a single argument,
    # then shell-quote it: run_cmd evaluates a command *string*, and raw
    # interpolation would split paths containing spaces and would execute any
    # shell metacharacters found in the values.
    local -a cmd=(
        uv run python -u -m src.autointerp
        "--dictionary=$SAE_DICT"
        "--k=$K"
    )
    if [[ -n "$N_SAMPLES" ]]; then
        cmd+=("--n-samples=$N_SAMPLES")
    fi
    cmd+=("--dataset=$DATASET" "--work-dir=$WORK_DIR")
    if [[ -n "$HF_CACHE" ]]; then
        cmd+=("--cache-dir=$HF_CACHE")
    fi
    cmd+=(--embeddings-only)

    local quoted
    printf -v quoted '%q ' "${cmd[@]}"

    run_cmd "Extract embeddings $MODEL ($samples_desc samples)" "$logfile" "$quoted"
}

# ============================================================================
# Phase 2: Compute sparse codes (cached per dictionary+k)
# ============================================================================
# Phase 2: run `src.autointerp --codes-only` once per available dictionary
# (SAE first, then KSVD). A missing dictionary is logged and skipped.
# Globals read: EMB_CACHE, SAE_DICT, KSVD_DICT, LOG_DIR, MODEL, K, DATASET,
#               WORK_DIR, N_SAMPLES, HF_CACHE, MATRYOSHKA
# Returns: 1 when $EMB_CACHE is missing, otherwise the number of dictionary
#          runs that failed (0 = all succeeded or were skipped).
phase_codes() {
    log_phase "PHASE 2: COMPUTE SPARSE CODES"

    local failed=0

    if ! exists "$EMB_CACHE"; then
        log_error "Embeddings not found: $EMB_CACHE (run embeddings phase first)"
        return 1
    fi

    local matryoshka_suffix=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
    fi

    local method label dict logfile quoted
    local -a cmd
    for method in sae ksvd; do
        case "$method" in
            sae)  label="SAE";  dict="$SAE_DICT" ;;
            ksvd) label="KSVD"; dict="$KSVD_DICT" ;;
        esac

        if ! exists "$dict"; then
            log_skip "$label dictionary not found: $dict"
            continue
        fi

        logfile="$LOG_DIR/codes_${method}_${MODEL}_k${K}${matryoshka_suffix}.log"

        # Array + %q quoting keeps each value a single argument when run_cmd
        # evaluates the command string (paths with spaces, metacharacters).
        cmd=(
            uv run python -u -m src.autointerp
            "--dictionary=$dict"
            "--k=$K"
        )
        if [[ -n "$N_SAMPLES" ]]; then
            cmd+=("--n-samples=$N_SAMPLES")
        fi
        cmd+=("--dataset=$DATASET" "--work-dir=$WORK_DIR")
        if [[ -n "$HF_CACHE" ]]; then
            cmd+=("--cache-dir=$HF_CACHE")
        fi
        if [[ "$MATRYOSHKA" == "true" ]]; then
            cmd+=(--matryoshka)
        fi
        cmd+=(--codes-only)
        printf -v quoted '%q ' "${cmd[@]}"

        # Plain arithmetic assignment instead of ((failed++)): the latter
        # returns status 1 when the old value is 0, which trips `set -e`
        # whenever this function is not called from an `||`/`if` context.
        run_cmd "Compute $label codes $MODEL k=$K${matryoshka_suffix}" "$logfile" "$quoted" \
            || failed=$((failed + 1))
    done

    return "$failed"
}

# ============================================================================
# Phase 3: VLM + CLIP evaluation
# ============================================================================
# Phase 3: run the `src.autointerp` evaluation once per available dictionary
# (SAE first, then KSVD), writing one JSON result file per method under
# $RESULTS_DIR. A method is skipped when its dictionary is missing or its
# result file already exists.
# Globals read: RESULTS_DIR, LOG_DIR, OPENAI_API_KEY, SAE_DICT, KSVD_DICT,
#               MODEL, K, DATASET, WORK_DIR, N_SAMPLES, N_FEATURES, HF_CACHE,
#               MATRYOSHKA
# Returns: 1 when OPENAI_API_KEY is unset or empty, otherwise the number of
#          evaluation runs that failed.
phase_eval() {
    log_phase "PHASE 3: VLM + CLIP EVALUATION"

    mkdir -p "$RESULTS_DIR"
    local failed=0

    if [[ -z "${OPENAI_API_KEY:-}" ]]; then
        log_error "OPENAI_API_KEY not set"
        return 1
    fi

    local matryoshka_suffix=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
    fi

    local method label dict output logfile quoted
    local -a cmd
    for method in sae ksvd; do
        case "$method" in
            sae)  label="SAE";  dict="$SAE_DICT" ;;
            ksvd) label="KSVD"; dict="$KSVD_DICT" ;;
        esac

        # Report a missing dictionary the same way phase_codes does instead of
        # silently doing nothing.
        if ! exists "$dict"; then
            log_skip "$label dictionary not found: $dict"
            continue
        fi

        output="$RESULTS_DIR/${method}_${MODEL}_k${K}${matryoshka_suffix}.json"
        logfile="$LOG_DIR/eval_${method}_${MODEL}_k${K}${matryoshka_suffix}.log"

        if exists "$output"; then
            log_skip "$label results exist: $output"
            continue
        fi

        # Array + %q quoting keeps each value a single argument when run_cmd
        # evaluates the command string (paths with spaces, metacharacters).
        cmd=(
            uv run python -u -m src.autointerp
            "--dictionary=$dict"
            "--k=$K"
        )
        if [[ -n "$N_SAMPLES" ]]; then
            cmd+=("--n-samples=$N_SAMPLES")
        fi
        cmd+=(
            "--n-features=$N_FEATURES"
            "--dataset=$DATASET"
            "--work-dir=$WORK_DIR"
        )
        if [[ -n "$HF_CACHE" ]]; then
            cmd+=("--cache-dir=$HF_CACHE")
        fi
        if [[ "$MATRYOSHKA" == "true" ]]; then
            cmd+=(--matryoshka)
        fi
        cmd+=("--output=$output")
        printf -v quoted '%q ' "${cmd[@]}"

        run_cmd "Eval $label $MODEL k=$K${matryoshka_suffix}" "$logfile" "$quoted" \
            || failed=$((failed + 1))
    done

    return "$failed"
}

# ============================================================================
# Summary
# ============================================================================
# Print the effective configuration and the presence of the embeddings cache,
# both dictionaries and both result files. For an existing result file the
# "mean_correlation" field is read from the JSON and shown with three decimals;
# "?" is shown when it cannot be read.
# Output goes to stdout only (just the banner is also written to the master
# log, via log_phase).
print_summary() {
    log_phase "AUTOINTERP SUMMARY"

    echo ""
    echo "Configuration:"
    echo "  Model: $MODEL"
    echo "  Sparsity: $K"
    echo "  Dataset: $DATASET"
    # An empty N_SAMPLES means the full dataset; show it the same way the
    # embeddings cache file name does rather than printing a blank value.
    echo "  N samples: ${N_SAMPLES:-all}"
    echo "  N features: $N_FEATURES"
    echo "  Matryoshka: $MATRYOSHKA"
    echo "  Dry run: $DRY_RUN"

    echo ""
    echo "=== Embeddings ==="
    if exists "$EMB_CACHE"; then
        local size
        size=$(du -h "$EMB_CACHE" | cut -f1) || size="?"
        echo "  [OK] $EMB_CACHE ($size)"
    else
        echo "  [MISSING] $EMB_CACHE"
    fi

    echo ""
    echo "=== Dictionaries ==="
    local dict
    for dict in "$SAE_DICT" "$KSVD_DICT"; do
        if exists "$dict"; then
            echo "  [OK] $dict"
        else
            echo "  [MISSING] $dict"
        fi
    done

    echo ""
    echo "=== Results ==="
    local matryoshka_suffix=""
    if [[ "$MATRYOSHKA" == "true" ]]; then
        matryoshka_suffix="_matryoshka"
    fi

    local method result_file corr
    for method in sae ksvd; do
        result_file="$RESULTS_DIR/${method}_${MODEL}_k${K}${matryoshka_suffix}.json"
        if exists "$result_file"; then
            # The path is passed as an argument, never spliced into the Python
            # source: a quote or backslash in the path would otherwise break
            # (or alter) the program text.
            corr=$(python3 -c 'import json, sys; print("%.3f" % json.load(open(sys.argv[1]))["mean_correlation"])' \
                "$result_file" 2>/dev/null) || corr="?"
            echo "  [OK] $result_file (mean_r=$corr)"
        else
            echo "  [MISSING] $result_file"
        fi
    done

    echo ""
    echo "Master log: $MASTER_LOG"
}

# ============================================================================
# Parse command line
# ============================================================================
# Phase to run when none is named on the command line.
PHASE="all"

# Apply command-line arguments on top of the environment/default configuration.
# Arguments are processed left to right, so a later option overrides an earlier
# one (--dry-run resets dataset, sample count, feature count and HF cache).
# Globals written: PHASE, MODEL, K, DRY_RUN, DATASET, N_SAMPLES, N_FEATURES,
#                  HF_CACHE, MATRYOSHKA, SAE_DICT, KSVD_DICT, DINO_MODEL,
#                  EMB_CACHE
# Exits: 0 after printing help; 1 on an unknown argument or a non-numeric
#        K / N_FEATURES / N_SAMPLES.
parse_args() {
    local arg
    for arg in "$@"; do
        case "$arg" in
            -h|--help)
                cat << EOF
Usage: $0 [phase] [options]

Phases:
  embeddings  Extract DINO embeddings (sorted order)
  codes       Compute sparse codes for dictionaries
  eval        Run VLM + CLIP evaluation
  all         Run all phases (default)
  summary     Show status summary

Options:
  --model=MODEL     vits14 or vitb14 (default: vits14)
  --k=K             Sparsity level (default: 32)
  --dry-run         Use cifar100, 1k samples, 5 features
  --n-samples=N     Number of samples
  --n-features=N    Features to evaluate
  --matryoshka      Use matryoshka sparse coding (log2min=8)

Environment variables:
  OPENAI_API_KEY    Required for eval phase
  HF_CACHE          HuggingFace cache (default: /scratch/redacted/huggingface)
  WORK_DIR          Working directory (default: cache/autointerp)
EOF
                exit 0
                ;;
            --model=*)
                MODEL="${arg#*=}"
                ;;
            --k=*)
                K="${arg#*=}"
                ;;
            --dry-run)
                DRY_RUN="true"
                DATASET="cifar100"
                N_SAMPLES=1000
                N_FEATURES=5
                HF_CACHE=""
                ;;
            --n-samples=*)
                N_SAMPLES="${arg#*=}"
                ;;
            --n-features=*)
                N_FEATURES="${arg#*=}"
                ;;
            --matryoshka)
                MATRYOSHKA="true"
                ;;
            embeddings|codes|eval|all|summary)
                PHASE="$arg"
                ;;
            *)
                echo "Unknown argument: $arg" >&2
                exit 1
                ;;
        esac
    done

    # These values are later embedded in file names and in command strings that
    # are evaluated by the shell, so accept digits only. N_SAMPLES may be empty
    # (= full dataset). Checking after the loop also covers values that came
    # from the environment rather than the command line.
    local name value
    for name in K N_FEATURES N_SAMPLES; do
        value="${!name-}"
        if [[ "$name" == "N_SAMPLES" && -z "$value" ]]; then
            continue
        fi
        if [[ ! "$value" =~ ^[0-9]+$ ]]; then
            echo "Invalid value for $name: '$value' (expected an integer)" >&2
            exit 1
        fi
    done

    # Recompute derived paths after parsing
    SAE_DICT="$MODEL_DIR/$MODEL/sae_k${K}.D.npy"
    KSVD_DICT="$MODEL_DIR/$MODEL/ksvd_k${K}.npy"
    DINO_MODEL=$(get_dino_model)
    EMB_CACHE="$WORK_DIR/embeddings_${DATASET}_${DINO_MODEL}_${N_SAMPLES:-all}.npz"
}

# ============================================================================
# Main
# ============================================================================
# Script driver: parse arguments, run the phases selected by $PHASE, print the
# summary, then exit 0 when nothing failed and 1 otherwise.
# Every selected phase is attempted even if an earlier one failed; the phases'
# non-zero return values are summed into the reported failure count.
main() {
    parse_args "$@"

    log_info "Starting autointerp with phase: $PHASE"
    log_info "Model: $MODEL, K: $K, Samples: $N_SAMPLES, Matryoshka: $MATRYOSHKA, Dry run: $DRY_RUN"
    log_info "Master log: $MASTER_LOG"

    # Translate the phase name into the list of phase functions to call.
    local steps=""
    case "$PHASE" in
        embeddings|codes|eval)
            steps="phase_$PHASE"
            ;;
        all)
            steps="phase_embeddings phase_codes phase_eval"
            ;;
        summary)
            ;;
        *)
            log_error "Unknown phase: $PHASE"
            exit 1
            ;;
    esac

    local total_failed=0
    local step
    # Intentionally unquoted: $steps is a space-separated list of function names.
    for step in $steps; do
        "$step" || total_failed=$((total_failed + $?))
    done

    print_summary

    if (( total_failed == 0 )); then
        log_ok "Autointerp completed successfully"
        exit 0
    fi

    log_error "Autointerp completed with $total_failed failures"
    exit 1
}

# Entry point: forward the script's command-line arguments to main, which
# always terminates the script via exit.
main "$@"
