#!/usr/bin/env bash
# run_pipeline.sh - drive the three-step pipeline:
#   Step 1: quantization error + original weights  (step1_quantize_error.py)
#   Step 2: randomized SVD                          (step2_randomized_svd.py)
#   Step 3: decode-caching evaluation               (step3_inference.py)
# Run with --help for the full option list.
#
# The script relies on bash features ([[ ]], arrays, pipefail), hence the
# explicit bash shebang. Strict mode: abort on a failing command (-e), on use
# of an unset variable (-u), and on failure of any pipeline stage (pipefail).
set -euo pipefail

# Print the command-line help text to stdout.
# Callers redirect it (`usage >&2`) when it accompanies an error.
usage() {
  cat <<'USAGE'
Usage: run_pipeline.sh --model-name NAME --output-dir DIR [options]

Required parameters:
  --model-name NAME       Hugging Face model identifier (e.g. meta-llama/Llama-3.1-8B)
  --output-dir DIR        Directory to store all pipeline artifacts

Optional parameters:
  --rank INT              Target rank for randomized SVD (default: 64)
  --p INT                 Oversamples for randomized SVD (default: 10)
  --q INT                 Power iteration count for randomized SVD (default: 2)
  --shrink-alpha FLOAT    Covariance shrinkage alpha (default: 0.05)
  --calib-dataset NAME    Calibration dataset name (default: wikitext)
  --calib-config NAME     Optional calibration dataset config (default: unset)
  --nsamples INT          Number of calibration samples (default: 64)
  --seqlen INT            Sequence length for calibration (default: 2048)
  --group-size INT        Group size passed to Step 3 (default: 128)
  --device DEVICE         Torch device string for Steps 1 and 3 (default: cuda:0)
  --python PATH           Python interpreter to invoke (default: $PYTHON, else python)
  --metrics-csv PATH      Optional CSV path for Step 3 metrics output
  --cov-store-device DEV  Device to accumulate covariance stats (default: cpu)
  --cov-stats-path PATH   Custom path for cached covariance stats
  --no-reuse-cov-stats    Disable reuse of cached covariance stats
  --trust-remote-code     Pass --trust_remote_code to python scripts
  --skip-step1            Skip Step 1 (requires existing outputs)
  --skip-step2            Skip Step 2 (requires existing outputs)
  --skip-step3            Skip Step 3 entirely
  --step3-skip-gen        Pass --skip_gen to Step 3 (skip text generation timing)
  --step3-use-cuda-w4a16  Pass --use_cuda_w4a16 to Step 3
  --step3-cache-mode MODE Choose Step 3 cache mode: both, cache, or no_cache (default: both)
  --gen-max-new-tokens N  Max new tokens for Step 3 generation (default: 128)
  --help, -h              Show this help message and exit

Environment variables:
  PYTHON                  Default Python interpreter when --python is not given

USAGE
}

# Invoked with no arguments at all: nothing can run, so this is an error exit.
# Like the other usage-on-error paths, the help goes to stderr.
if [[ $# -eq 0 ]]; then
  usage >&2
  exit 1
fi

# ---------------------------------------------------------------------------
# Defaults; every one of these can be overridden by a CLI flag parsed below.
# ---------------------------------------------------------------------------
# Absolute directory of this script; the step*.py scripts are resolved here.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# The PYTHON environment variable, when set and non-empty, replaces "python".
PYTHON_BIN="${PYTHON:-python}"

# Required values (checked after parsing).
MODEL_NAME=""
OUTPUT_DIR=""

# Randomized SVD: target rank, oversamples (P), power iterations (Q).
RANK=64 P=10 Q=2
SHRINK_ALPHA=0.05

# Calibration data.
CALIB_DATASET="wikitext"
CALIB_CONFIG=""            # empty => --calib_config is not forwarded
NSAMPLES=64 SEQLEN=2048

# Covariance statistics.
COV_STORE_DEVICE="cpu"
COV_STATS_PATH=""          # empty => a default path is derived later
REUSE_COV_STATS=1

# Devices and Step 3 settings.
DEVICE="cuda:0"
GROUP_SIZE=128
METRICS_CSV=""
STEP3_CACHE_MODE="both"
GEN_MAX_NEW_TOKENS=128

# On/off switches (1 = on).
TRUST_REMOTE_CODE=0
SKIP_STEP1=0 SKIP_STEP2=0 SKIP_STEP3=0
STEP3_SKIP_GEN=0
STEP3_USE_CUDA_W4A16=0

# ---------------------------------------------------------------------------
# Parse command-line options.
#
# Each value-taking option is mapped to the name of the global it sets
# (_pipe_var). The value is assigned in one place after the case, so all of
# these options share the same "missing value" check. Without that check, a
# trailing `--rank` would dereference "$2" and abort under `set -u` with a bare
# "unbound variable" message instead of a usage error.
# Flag options set their global directly and consume only themselves.
# ---------------------------------------------------------------------------
while [[ $# -gt 0 ]]; do
  _pipe_opt="$1"
  _pipe_var=""
  case "$_pipe_opt" in
    --model-name)           _pipe_var=MODEL_NAME ;;
    --output-dir)           _pipe_var=OUTPUT_DIR ;;
    --rank)                 _pipe_var=RANK ;;
    --p)                    _pipe_var=P ;;
    --q)                    _pipe_var=Q ;;
    --shrink-alpha)         _pipe_var=SHRINK_ALPHA ;;
    --calib-dataset)        _pipe_var=CALIB_DATASET ;;
    --calib-config)         _pipe_var=CALIB_CONFIG ;;
    --nsamples)             _pipe_var=NSAMPLES ;;
    --seqlen)               _pipe_var=SEQLEN ;;
    --group-size)           _pipe_var=GROUP_SIZE ;;
    --device)               _pipe_var=DEVICE ;;
    --python)               _pipe_var=PYTHON_BIN ;;
    --metrics-csv)          _pipe_var=METRICS_CSV ;;
    --cov-store-device)     _pipe_var=COV_STORE_DEVICE ;;
    --cov-stats-path)       _pipe_var=COV_STATS_PATH ;;
    --step3-cache-mode)     _pipe_var=STEP3_CACHE_MODE ;;
    --gen-max-new-tokens)   _pipe_var=GEN_MAX_NEW_TOKENS ;;
    --no-reuse-cov-stats)   REUSE_COV_STATS=0 ;;
    --trust-remote-code)    TRUST_REMOTE_CODE=1 ;;
    --skip-step1)           SKIP_STEP1=1 ;;
    --skip-step2)           SKIP_STEP2=1 ;;
    --skip-step3)           SKIP_STEP3=1 ;;
    --step3-skip-gen)       STEP3_SKIP_GEN=1 ;;
    --step3-use-cuda-w4a16) STEP3_USE_CUDA_W4A16=1 ;;
    --help|-h)
      usage
      exit 0
      ;;
    *)
      echo "Unknown option: $1" >&2
      usage >&2
      exit 1
      ;;
  esac

  if [[ -z "$_pipe_var" ]]; then
    # Flag option: there is no value to consume.
    shift 1
    continue
  fi

  if [[ $# -lt 2 ]]; then
    echo "Error: $_pipe_opt requires a value" >&2
    usage >&2
    exit 1
  fi
  # printf -v writes into the variable *named* by _pipe_var. That name always
  # comes from the fixed table above, never from user input.
  printf -v "$_pipe_var" '%s' "$2"
  shift 2
done
unset _pipe_opt _pipe_var

# ---------------------------------------------------------------------------
# Validate the parsed options before any work starts. A bad value is then
# reported immediately, not after Step 1/Step 2 have already run.
# ---------------------------------------------------------------------------
if [[ -z "$MODEL_NAME" ]]; then
  echo "Error: --model-name is required" >&2
  usage >&2
  exit 1
fi

if [[ -z "$OUTPUT_DIR" ]]; then
  echo "Error: --output-dir is required" >&2
  usage >&2
  exit 1
fi

# The help text documents these options as integers. They are forwarded to the
# Python steps, and several of them also form part of the Step 2 directory
# name, so reject non-integers up front. Each entry is "option=value"; option
# names contain no '=', so the split is unambiguous.
for _pipe_int in "--rank=$RANK" "--p=$P" "--q=$Q" \
                 "--nsamples=$NSAMPLES" "--seqlen=$SEQLEN" \
                 "--group-size=$GROUP_SIZE" \
                 "--gen-max-new-tokens=$GEN_MAX_NEW_TOKENS"; do
  if [[ ! "${_pipe_int#*=}" =~ ^[0-9]+$ ]]; then
    echo "Error: ${_pipe_int%%=*} expects a non-negative integer, got '${_pipe_int#*=}'" >&2
    usage >&2
    exit 1
  fi
done
unset _pipe_int

# Only the cache modes listed in the help text are accepted.
case "$STEP3_CACHE_MODE" in
  both|cache|no_cache) ;;
  *)
    echo "Error: --step3-cache-mode must be one of both, cache, no_cache (got '$STEP3_CACHE_MODE')" >&2
    usage >&2
    exit 1
    ;;
esac

# ---------------------------------------------------------------------------
# Artifact layout under $OUTPUT_DIR:
#   step1/   quantization error and original weights
#   step2_rank<R>_alpha<A>_ns<N>_seq<S>_p<P>_q<Q>/
#            randomized-SVD outputs, its log and, by default, the
#            covariance-stats cache
#   step3/   created here but not passed to any step by this script
# NOTE(review): neither step1/ nor the default cov-stats filename encodes the
# model name or the calibration dataset/config. Reusing one --output-dir
# across those settings could therefore pick up stale artifacts - confirm.
# ---------------------------------------------------------------------------
mkdir -p "$OUTPUT_DIR"
STEP1_DIR="$OUTPUT_DIR/step1"
STEP2_DIR="$OUTPUT_DIR/step2_rank${RANK}_alpha${SHRINK_ALPHA}_ns${NSAMPLES}_seq${SEQLEN}_p${P}_q${Q}"
STEP3_DIR="$OUTPUT_DIR/step3"
mkdir -p "$STEP1_DIR" "$STEP2_DIR" "$STEP3_DIR"

# Step 1 outputs.
ERR_PATH="$STEP1_DIR/err_quant_asym.pt"
ORIG_WEIGHTS_PATH="$STEP1_DIR/original_weights.pt"
# Step 2 outputs.
SHARED_PATH="$STEP2_DIR/low_rank_shared.pt"
BMAP_PATH="$STEP2_DIR/b_ref_map.json"
LOG_PATH="$STEP2_DIR/randomized_svd.log"

# Without --cov-stats-path (or with an empty one), cache the covariance
# statistics next to the Step 2 outputs.
COV_STATS_PATH="${COV_STATS_PATH:-$STEP2_DIR/cov_stats_ns${NSAMPLES}_seq${SEQLEN}.pt}"

# Ensure the parent directory of the optional metrics CSV exists.
[[ -z "$METRICS_CSV" ]] || mkdir -p "$(dirname "$METRICS_CSV")"

printf '\n[Pipeline] Model: %s\n[Pipeline] Output directory: %s\n' "$MODEL_NAME" "$OUTPUT_DIR"

# ---------------------------------------------------------------------------
# Step 1: compute the quantization error and save the original weights.
# With --skip-step1, the previous outputs must already exist in step1/.
# ---------------------------------------------------------------------------
if [[ $SKIP_STEP1 -ne 1 ]]; then
  step1_cmd=(
    "$PYTHON_BIN" "$SCRIPT_DIR/step1_quantize_error.py"
    --model_name "$MODEL_NAME"
    --out_quant_err "$ERR_PATH"
    --out_original_weights "$ORIG_WEIGHTS_PATH"
    --device "$DEVICE"
  )
  if (( TRUST_REMOTE_CODE == 1 )); then
    step1_cmd+=(--trust_remote_code)
  fi

  # Log the exact shell-quoted command line, then run it.
  printf '\n[Step 1] Calculating quantization error...\n    '
  printf '%q ' "${step1_cmd[@]}"
  printf '\n'
  "${step1_cmd[@]}"
elif [[ ! -f "$ERR_PATH" || ! -f "$ORIG_WEIGHTS_PATH" ]]; then
  echo "Step1 outputs not found at $STEP1_DIR; cannot skip step1." >&2
  exit 1
fi

# ---------------------------------------------------------------------------
# Step 2: randomized SVD, reading the Step 1 quantization error.
# With --skip-step2, the previous outputs must already exist in the step2 dir.
# NOTE(review): --device is not forwarded here (only --cov_store_device is);
# confirm that Step 2 is meant to use its own device default.
# ---------------------------------------------------------------------------
if [[ $SKIP_STEP2 -ne 1 ]]; then
  step2_cmd=(
    "$PYTHON_BIN" "$SCRIPT_DIR/step2_randomized_svd.py"
    --model_name "$MODEL_NAME"
    --err_path "$ERR_PATH"
    --output_path "$STEP2_DIR"
    --log_path "$LOG_PATH"
    --max_rank "$RANK"
    --shrinkage_alpha "$SHRINK_ALPHA"
    --nsamples "$NSAMPLES"
    --seqlen "$SEQLEN"
    --calib_dataset "$CALIB_DATASET"
    --cov_store_device "$COV_STORE_DEVICE"
    --oversamples "$P"
    --power_iters "$Q"
    --cov_stats_path "$COV_STATS_PATH"
  )

  # The calibration config is forwarded only when one was given.
  if [[ -n "$CALIB_CONFIG" ]]; then
    step2_cmd+=(--calib_config "$CALIB_CONFIG")
  fi
  # Step 2 takes an explicit true/false value, not a bare flag.
  if (( REUSE_COV_STATS == 1 )); then
    step2_cmd+=(--reuse_cov_stats true)
  else
    step2_cmd+=(--reuse_cov_stats false)
  fi
  if (( TRUST_REMOTE_CODE == 1 )); then
    step2_cmd+=(--trust_remote_code)
  fi

  # Log the exact shell-quoted command line, then run it.
  printf '\n[Step 2] Running randomized SVD...\n    '
  printf '%q ' "${step2_cmd[@]}"
  printf '\n'
  "${step2_cmd[@]}"
elif [[ ! -f "$SHARED_PATH" || ! -f "$BMAP_PATH" ]]; then
  echo "Step2 outputs not found at $STEP2_DIR; cannot skip step2." >&2
  exit 1
fi

# ---------------------------------------------------------------------------
# Step 3: decode-caching evaluation. It consumes the Step 2 low-rank outputs
# and the Step 1 original weights, so both are re-checked here even when
# those steps just ran.
# ---------------------------------------------------------------------------
if [[ $SKIP_STEP3 -ne 1 ]]; then
  if [[ ! -f "$SHARED_PATH" || ! -f "$BMAP_PATH" ]]; then
    echo "Step3 requires outputs from step2 at $STEP2_DIR" >&2
    exit 1
  fi
  if [[ ! -f "$ORIG_WEIGHTS_PATH" ]]; then
    echo "Step3 requires original weights from step1 at $STEP1_DIR" >&2
    exit 1
  fi

  step3_cmd=(
    "$PYTHON_BIN" "$SCRIPT_DIR/step3_inference.py"
    --model_name "$MODEL_NAME"
    --shared_path "$SHARED_PATH"
    --bmap_path "$BMAP_PATH"
    --original_weights_path "$ORIG_WEIGHTS_PATH"
    --device "$DEVICE"
    --group_size "$GROUP_SIZE"
    --cache_mode "$STEP3_CACHE_MODE"
    --gen_max_new_tokens "$GEN_MAX_NEW_TOKENS"
  )

  # Optional switches, appended in a fixed order.
  if (( TRUST_REMOTE_CODE == 1 )); then
    step3_cmd+=(--trust_remote_code)
  fi
  if (( STEP3_SKIP_GEN == 1 )); then
    step3_cmd+=(--skip_gen)
  fi
  if (( STEP3_USE_CUDA_W4A16 == 1 )); then
    step3_cmd+=(--use_cuda_w4a16)
  fi
  if [[ -n "$METRICS_CSV" ]]; then
    step3_cmd+=(--metrics_csv "$METRICS_CSV")
  fi

  # Log the exact shell-quoted command line, then run it.
  printf '\n[Step 3] Evaluating decode caching...\n    '
  printf '%q ' "${step3_cmd[@]}"
  printf '\n'
  "${step3_cmd[@]}"
else
  printf '\n[Step 3] Skipped.\n'
fi

# Final report: where each step's artifacts ended up.
printf '\n[Pipeline] Completed successfully. Artifacts:\n'
# printf reuses its format for each remaining pair of arguments.
printf '  Step%s: %s\n' 1 "$STEP1_DIR" 2 "$STEP2_DIR"
case "$SKIP_STEP3" in
  0) printf '  Step3 logs/metrics written by script (check stdout and optional CSV).\n' ;;
  *) printf '  Step3: skipped.\n' ;;
esac
