#!/bin/bash

# Pre-processing: locate the project root, enter it, and activate its venv.
# SCRIPT_DIR is the directory holding this script; PROJECT_ROOT defaults to its
# parent unless the caller already exported PROJECT_ROOT.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="${PROJECT_ROOT:-$(cd "$SCRIPT_DIR/.." && pwd)}"
if ! cd "$PROJECT_ROOT"; then
    echo "[ERROR] Failed to enter project directory"
    exit 1
fi
echo "[INFO] Working directory: $(pwd)"

# The virtual environment must already exist; this script never creates it.
[[ -d ".venv" ]] || {
    echo "[ERROR] .venv virtual environment does not exist, please create it first"
    exit 1
}

# shellcheck disable=SC1091
if ! source .venv/bin/activate; then
    echo "[ERROR] Failed to activate .venv"
    exit 1
fi
echo "[INFO] Activated virtual environment: .venv"

# Unify temp and cache directories to avoid root partition usage.
# Every variable keeps a value inherited from the environment; only unset or
# empty ones receive the project-local default.
DEFAULT_CACHE_ROOT="$PROJECT_ROOT/.cache"
: "${TEMP_ROOT:=$DEFAULT_CACHE_ROOT}"
: "${TMPDIR:=$TEMP_ROOT}"
: "${TMP:=$TEMP_ROOT}"
: "${TEMP:=$TEMP_ROOT}"
: "${HF_HOME:=$TEMP_ROOT/.hf_home}"
: "${HUGGINGFACE_HUB_CACHE:=$TEMP_ROOT/.huggingface}"
: "${TRANSFORMERS_CACHE:=$TEMP_ROOT/.transformers}"
: "${PYTORCH_CUDA_ALLOC_CONF:=expandable_segments:True}"
export TEMP_ROOT TMPDIR TMP TEMP
export HF_HOME HUGGINGFACE_HUB_CACHE TRANSFORMERS_CACHE
export PYTORCH_CUDA_ALLOC_CONF

# Best effort: a failure to create the directories does not stop the script.
mkdir -p "$TMPDIR" "$HF_HOME" "$HUGGINGFACE_HUB_CACHE" "$TRANSFORMERS_CACHE" || true

# Make sure a virtual environment is active and that it is the project's own.
[[ -n "${VIRTUAL_ENV:-}" ]] || {
    echo "[ERROR] 虚拟环境未激活，请检查 .venv"
    exit 1
}

# The active environment has to be exactly <project root>/.venv (string compare).
EXPECTED_VENV="$PROJECT_ROOT/.venv"
[[ "$VIRTUAL_ENV" == "$EXPECTED_VENV" ]] || {
    echo "[ERROR] 虚拟环境路径不匹配"
    echo "       当前: $VIRTUAL_ENV"
    echo "       预期: $EXPECTED_VENV"
    echo "       建议:"
    echo "         cd $PROJECT_ROOT && python3.10 -m venv .venv"
    echo "         source .venv/bin/activate && pip install -e \".[all]\""
    exit 1
}

echo "[INFO] 使用虚拟环境: $VIRTUAL_ENV"

# Default configuration. Each value can be overridden by the command-line
# options handled in the argument-parsing loop further down.
MODE="Taylor"  # Taylor, Taylor-Scaled, HiCache, original, ToCa, Delta, collect, ClusCa, Hi-ClusCa
MODEL_NAME="flux-dev"  # flux-dev | flux-schnell
INTERVAL="5"
MAX_ORDER="2"
WIDTH=1024
HEIGHT=1024
NUM_STEPS=50
# Becomes true once -s/--num_steps is given; while false, flux-schnell runs
# fall back to 4 steps instead of NUM_STEPS.
NUM_STEPS_SET=false
# Number of leading lines taken from the prompt file.
LIMIT=200
HICACHE_SCALE_FACTOR="0.6"
FIRST_ENHANCE="3"
PROMPT_FILE="./prompt.txt"
BASE_OUTPUT_DIR="$PROJECT_ROOT/results"
# GPU selection: an explicit id list (--gpus) or a count (--num_gpus); both
# empty means neither option is forwarded to the launcher.
GPU_LIST=""
NUM_GPUS=""
RUN_NAME=""
# Set to true later when no run name was supplied.
AUTO_RUN_NAME=false
KEEP_TEMP=false
DRY_RUN=false
# Explicit weights directory (--model_dir); empty triggers auto-detection under models/.
MODEL_DIR=""

# ClusCa default parameters (only forwarded when MODE is ClusCa or Hi-ClusCa)
CLUSCA_FRESH_THRESHOLD=5
CLUSCA_CLUSTER_NUM=16
CLUSCA_CLUSTER_METHOD="kmeans"
CLUSCA_K=1
CLUSCA_PROPAGATION_RATIO=0.005

# Arguments given after "--" on the command line; appended verbatim to the
# arguments passed through to the sampler.
EXTRA_SAMPLE_ARGS=()

show_help() {
    echo "用法: $0 [选项]"
    echo "选项:"
    echo "  -m, --mode MODE             缓存模式 (Taylor, Taylor-Scaled, HiCache, original, ToCa, Delta, collect, ClusCa, Hi-ClusCa)"
    echo "      --model_name NAME       FLUX 模型 (flux-dev|flux-schnell) [默认: flux-dev]"
    echo "  -i, --interval N            采样间隔 [默认: 5]"
    echo "  -o, --max_order N           泰勒最大阶数 [默认: 2]"
    echo "      --first_enhance N        初始增强步数 (前 N 步强制 full) [默认: 3]"
    echo "  -d, --output_dir DIR        结果基础目录 (多卡运行会在其下创建子目录)"
    echo "  -p, --prompt_file FILE      Prompt 列表 [默认: ./prompt.txt]"
    echo "  -w, --width WIDTH           图像宽度 [默认: 1024]"
    echo "  -h, --height HEIGHT         图像高度 [默认: 1024]"
    echo "  -s, --num_steps STEPS       采样步数 [默认: 50]"
    echo "  -l, --limit LIMIT           Prompt 限制数量 [默认: 10]"
    echo "  --gpus IDS                  指定 GPU 列表 (示例: 0,1,3)"
    echo "  --num_gpus N                未指定 --gpus 时自动从 0 开始取 N 张卡"
    echo "  --run-name NAME             自定义运行名 (用于输出目录)"
    echo "  --hicache_scale FACTOR      HiCache 多项式缩放因子 [默认: 0.7]"
    echo "  --model_dir DIR             指定本地 FLUX 权重目录(包含 flow 与 ae)"
    echo "  --fresh_threshold VALUE     ClusCa: fresh 阈值 [默认: 5]"
    echo "  --cluster_num N             ClusCa: 聚类数量 [默认: 16]"
    echo "  --cluster_method NAME       ClusCa: 聚类方法 [默认: kmeans]"
    echo "  --k N                       ClusCa: 每个聚类选择 fresh token 数 [默认: 1]"
    echo "  --propagation_ratio VALUE   ClusCa: 特征传播比例 [默认: 0.005]"
    echo "  --keep-temp                 保留 RUN/ 下的 Prompt 切分文件"
    echo "  --dry-run                   仅打印即将执行的命令"
    echo "  --help                      显示帮助信息"
    echo "  --                          之后的参数原样透传给 sample.sh"
}

#######################################
# Parse the launcher's command-line options into the configuration globals.
#
# Value-taking options are first mapped to the name of the variable they set;
# the value is then checked and stored in one place. This matters because a
# bare `shift 2` with only one argument left fails WITHOUT shifting, so an
# option given as the last word with no value would otherwise loop forever.
#
# Globals:   writes MODE, MODEL_NAME, INTERVAL, MAX_ORDER, FIRST_ENHANCE,
#            BASE_OUTPUT_DIR, PROMPT_FILE, WIDTH, HEIGHT, NUM_STEPS,
#            NUM_STEPS_SET, LIMIT, GPU_LIST, NUM_GPUS, RUN_NAME, MODEL_DIR,
#            HICACHE_SCALE_FACTOR, CLUSCA_*, KEEP_TEMP, DRY_RUN;
#            appends to EXTRA_SAMPLE_ARGS
# Arguments: the script's command-line arguments
# Outputs:   error/help text on stdout for invalid input
# Returns:   0 on success; exits 1 on an unknown option or a missing value,
#            exits 0 after printing help for --help
#######################################
parse_args() {
    local target
    while [[ $# -gt 0 ]]; do
        target=""
        case "$1" in
            -m|--mode)              target=MODE ;;
            --model_name)           target=MODEL_NAME ;;
            -i|--interval)          target=INTERVAL ;;
            -o|--max_order)         target=MAX_ORDER ;;
            --first_enhance)        target=FIRST_ENHANCE ;;
            -d|--output_dir)        target=BASE_OUTPUT_DIR ;;
            -p|--prompt_file)       target=PROMPT_FILE ;;
            -w|--width)             target=WIDTH ;;
            -h|--height)            target=HEIGHT ;;
            -s|--num_steps)         target=NUM_STEPS ;;
            -l|--limit)             target=LIMIT ;;
            --gpus)                 target=GPU_LIST ;;
            --num_gpus|--num-gpus)  target=NUM_GPUS ;;
            --run-name|--run_name)  target=RUN_NAME ;;
            --model_dir)            target=MODEL_DIR ;;
            --hicache_scale)        target=HICACHE_SCALE_FACTOR ;;
            --fresh_threshold)      target=CLUSCA_FRESH_THRESHOLD ;;
            --cluster_num)          target=CLUSCA_CLUSTER_NUM ;;
            --cluster_method)       target=CLUSCA_CLUSTER_METHOD ;;
            --k)                    target=CLUSCA_K ;;
            --propagation_ratio)    target=CLUSCA_PROPAGATION_RATIO ;;
            --keep-temp)
                KEEP_TEMP=true
                shift
                continue
                ;;
            --dry-run)
                DRY_RUN=true
                shift
                continue
                ;;
            --help)
                show_help
                exit 0
                ;;
            --)
                # Everything after "--" is passed through untouched.
                shift
                EXTRA_SAMPLE_ARGS+=("$@")
                break
                ;;
            *)
                echo "未知选项: $1"
                show_help
                exit 1
                ;;
        esac

        # Only value-taking options reach this point.
        if [[ $# -lt 2 ]]; then
            echo "[ERROR] 选项 $1 需要一个参数值"
            exit 1
        fi
        if [[ "$target" == "NUM_STEPS" ]]; then
            # Remember that the step count was chosen explicitly.
            NUM_STEPS_SET=true
        fi
        printf -v "$target" '%s' "$2"
        shift 2
    done
}

parse_args "$@"

# Reject any mode outside the supported set.
case "$MODE" in
    Taylor|Taylor-Scaled|HiCache|original|ToCa|Delta|collect|ClusCa|Hi-ClusCa)
        ;;
    *)
        echo "错误: 不支持的模式 '$MODE'"
        echo "支持的模式: Taylor, Taylor-Scaled, HiCache, original, ToCa, Delta, collect, ClusCa, Hi-ClusCa"
        exit 1
        ;;
esac

# The prompt list must be an existing regular file.
[[ -f "$PROMPT_FILE" ]] || {
    echo "[ERROR] Prompt 文件不存在: $PROMPT_FILE"
    exit 1
}

# LIMIT must consist of decimal digits only.
[[ "$LIMIT" =~ ^[0-9]+$ ]] || {
    echo "[ERROR] limit 必须为非负整数"
    exit 1
}

#######################################
# Locate the local weights directory that matches a FLUX model name.
# Probes a fixed list of directory names under $PROJECT_ROOT/models, in order,
# and reports the first one that exists.
# Globals:   PROJECT_ROOT (read)
# Arguments: $1 - model name; "flux-schnell" selects the schnell directory
#                 names, any other value selects the dev directory names
# Outputs:   the first existing candidate directory on stdout
# Returns:   0 when a directory was found, 1 otherwise (nothing is printed)
#######################################
auto_detect_model_dir() {
    local model_name="$1"
    local variant="dev"
    if [[ "$model_name" == "flux-schnell" ]]; then
        variant="schnell"
    fi

    # `local` keeps the loop variable from leaking into the caller's scope.
    local candidate
    for candidate in \
        "$PROJECT_ROOT/models/FLUX.1-$variant" \
        "$PROJECT_ROOT/models/flux.$variant" \
        "$PROJECT_ROOT/models/flux-$variant" \
        "$PROJECT_ROOT/models/$variant"; do
        if [[ -d "$candidate" ]]; then
            printf '%s\n' "$candidate"
            return 0
        fi
    done

    return 1
}

# Resolve the model directory: an explicit --model_dir is taken as-is,
# otherwise the models/ directory is probed for a name matching MODEL_NAME.
if [[ -z "$MODEL_DIR" ]]; then
    AUTO_MODEL_DIR="$(auto_detect_model_dir "$MODEL_NAME")"
    if [[ -n "$AUTO_MODEL_DIR" ]]; then
        echo "[INFO] 自动检测到模型目录: $AUTO_MODEL_DIR"
    else
        # Nothing matched: list the directory names that would have been accepted.
        echo "[ERROR] 未找到匹配的模型目录，请检查 models/ 目录或使用 --model_dir 指定"
        echo "支持的目录格式:"
        case "$MODEL_NAME" in
            flux-schnell)
                echo "  - models/FLUX.1-schnell"
                echo "  - models/flux.schnell"
                echo "  - models/flux-schnell"
                echo "  - models/schnell"
                ;;
            *)
                echo "  - models/FLUX.1-dev"
                echo "  - models/flux.dev"
                echo "  - models/flux-dev"
                echo "  - models/dev"
                ;;
        esac
        exit 1
    fi
else
    echo "[INFO] 指定了 --model_dir: $MODEL_DIR"
    AUTO_MODEL_DIR="$MODEL_DIR"
fi

# No run name supplied: flag it so the summary can say the default path is used.
[[ -n "$RUN_NAME" ]] || {
    RUN_NAME=""
    AUTO_RUN_NAME=true
}

# Aggregated output layout: <base>/<mode in lower case>[_<run name>]/<param tag>.
MODE_LOWER="${MODE,,}"
MERGED_ROOT="$BASE_OUTPUT_DIR/${MODE_LOWER}${RUN_NAME:+_$RUN_NAME}"

# flux-schnell uses 4 steps unless the step count was given explicitly; this
# has to happen before the step count is baked into PARAM_TAG.
if [[ "$NUM_STEPS_SET" != true && "$MODEL_NAME" == "flux-schnell" ]]; then
    NUM_STEPS=4
fi
PARAM_TAG="mn_${MODEL_NAME}_i_${INTERVAL}_o_${MAX_ORDER}_s_${NUM_STEPS}_hs_${HICACHE_SCALE_FACTOR}"
MERGED_OUTPUT_DIR="$MERGED_ROOT/${PARAM_TAG}"

# Scratch files under RUN/: the truncated prompt list handed to the launcher
# and the path the launcher is told to use for its JSON report. Both names
# start out empty so the cleanup handler can be registered before either file
# exists; from then on every exit path removes whatever has been created.
TEMP_PROMPT_FILE=""
REPORT_PATH=""

#######################################
# Remove the scratch files created for this run.
# Globals:   TEMP_PROMPT_FILE, REPORT_PATH (read)
# Arguments: none
# Returns:   always 0
#######################################
cleanup_tmp_files() {
    local tmp_file
    for tmp_file in "$TEMP_PROMPT_FILE" "$REPORT_PATH"; do
        if [[ -n "$tmp_file" ]]; then
            rm -f -- "$tmp_file"
        fi
    done
    return 0
}
trap cleanup_tmp_files EXIT

TEMP_PROMPT_FILE=$(mktemp "$PROJECT_ROOT/RUN/tmp_multi_gpu_launcher_prompts.XXXXXX.txt") || {
    echo "[ERROR] 创建临时 Prompt 文件失败"
    exit 1
}
REPORT_PATH=$(mktemp "$PROJECT_ROOT/RUN/tmp_multi_gpu_launcher_report.XXXXXX.json") || {
    echo "[ERROR] 创建临时报告文件失败"
    exit 1
}

# Keep only the first LIMIT prompts. Stop here if the copy fails rather than
# launching the run with an incomplete prompt list.
if ! head -n "$LIMIT" "$PROMPT_FILE" > "$TEMP_PROMPT_FILE"; then
    echo "[ERROR] 写入临时 Prompt 文件失败: $TEMP_PROMPT_FILE"
    exit 1
fi

# When the step count was not given explicitly, flux-schnell defaults to 4 steps.
if [[ "$NUM_STEPS_SET" != true && "$MODEL_NAME" == "flux-schnell" ]]; then
    NUM_STEPS=4
fi

# Collect the arguments that are passed through to sample.sh.
SAMPLE_ARGS=()
SAMPLE_ARGS+=(--interval "$INTERVAL" --max_order "$MAX_ORDER" --first_enhance "$FIRST_ENHANCE")
SAMPLE_ARGS+=(--width "$WIDTH" --height "$HEIGHT" --num_steps "$NUM_STEPS")
SAMPLE_ARGS+=(--limit "$LIMIT" --hicache_scale "$HICACHE_SCALE_FACTOR" --model_name "$MODEL_NAME")

# ClusCa tuning knobs are only added for the two ClusCa modes.
case "$MODE" in
    ClusCa|Hi-ClusCa)
        SAMPLE_ARGS+=(--clusca_fresh_threshold "$CLUSCA_FRESH_THRESHOLD")
        SAMPLE_ARGS+=(--clusca_cluster_num "$CLUSCA_CLUSTER_NUM")
        SAMPLE_ARGS+=(--clusca_cluster_method "$CLUSCA_CLUSTER_METHOD")
        SAMPLE_ARGS+=(--clusca_k "$CLUSCA_K")
        SAMPLE_ARGS+=(--clusca_propagation_ratio "$CLUSCA_PROPAGATION_RATIO")
        ;;
esac

# User-supplied pass-through arguments come next, the model directory last.
if (( ${#EXTRA_SAMPLE_ARGS[@]} > 0 )); then
    SAMPLE_ARGS+=("${EXTRA_SAMPLE_ARGS[@]}")
fi

if [[ -n "$AUTO_MODEL_DIR" ]]; then
    SAMPLE_ARGS+=(--model_dir "$AUTO_MODEL_DIR")
fi

# Print a summary of the effective configuration.
printf '%s\n' \
    "=================================" \
    "多卡采样配置:" \
    "模式: $MODE" \
    "FLUX 模型: $MODEL_NAME"

# GPU selection: an explicit list wins over a count; with neither, report auto-detection.
if [[ -z "$GPU_LIST" && -z "$NUM_GPUS" ]]; then
    echo "GPU: 自动检测"
elif [[ -n "$GPU_LIST" ]]; then
    echo "GPU 列表: $GPU_LIST"
else
    echo "GPU 数量: $NUM_GPUS (从 0 开始)"
fi

echo "基础输出目录: $BASE_OUTPUT_DIR"
if [[ "$AUTO_RUN_NAME" != true ]]; then
    echo "运行名: $RUN_NAME"
elif [[ -z "$RUN_NAME" ]]; then
    echo "运行名: (默认，相同参数复用路径)"
else
    echo "运行名: $RUN_NAME (自动生成)"
fi

printf '%s\n' \
    "Prompt 文件: $PROMPT_FILE (限制 $LIMIT 条)" \
    "临时 Prompt: $TEMP_PROMPT_FILE" \
    "尺寸: ${WIDTH}x${HEIGHT}, 步数: $NUM_STEPS"

# Show where the weights come from: user-specified or auto-detected.
if [[ -n "$MODEL_DIR" ]]; then
    echo "模型目录: $MODEL_DIR"
elif [[ -n "$AUTO_MODEL_DIR" ]]; then
    echo "自动检测模型目录: $AUTO_MODEL_DIR"
fi

printf '%s\n' \
    "间隔: $INTERVAL, 阶数: $MAX_ORDER" \
    "HiCache 缩放: $HICACHE_SCALE_FACTOR"

# Mode-specific parameters for the two ClusCa variants.
case "$MODE" in
    ClusCa)
        echo "ClusCa fresh 阈值: $CLUSCA_FRESH_THRESHOLD"
        echo "ClusCa 聚类: $CLUSCA_CLUSTER_NUM ($CLUSCA_CLUSTER_METHOD)"
        echo "ClusCa k: $CLUSCA_K, 传播比例: $CLUSCA_PROPAGATION_RATIO"
        ;;
    Hi-ClusCa)
        echo "Hi-ClusCa fresh 阈值: $CLUSCA_FRESH_THRESHOLD"
        echo "Hi-ClusCa 聚类: $CLUSCA_CLUSTER_NUM ($CLUSCA_CLUSTER_METHOD)"
        echo "Hi-ClusCa k: $CLUSCA_K, 传播比例: $CLUSCA_PROPAGATION_RATIO"
        echo "Hi-ClusCa HiCache 缩放: $HICACHE_SCALE_FACTOR"
        ;;
esac

printf '%s\n' \
    "保留临时分片: $KEEP_TEMP" \
    "干运行: $DRY_RUN"
if (( ${#EXTRA_SAMPLE_ARGS[@]} > 0 )); then
    printf '额外 sample.sh 参数: %s\n' "${EXTRA_SAMPLE_ARGS[*]}"
fi
echo "================================="

# Assemble the launcher invocation as an array so every argument keeps its
# word boundaries, whatever characters it contains.
PYTHON_CMD=(python "RUN/multi_gpu_launcher.py")
PYTHON_CMD+=(--mode "$MODE")
PYTHON_CMD+=(--prompt-file "$TEMP_PROMPT_FILE")
PYTHON_CMD+=(--base-output-dir "$BASE_OUTPUT_DIR")
PYTHON_CMD+=(--run-name "$RUN_NAME")
PYTHON_CMD+=(--report-path "$REPORT_PATH")

# Optional launcher flags, added only when the matching setting is in use.
if [[ -n "$GPU_LIST" ]]; then PYTHON_CMD+=(--gpus "$GPU_LIST"); fi
if [[ -n "$NUM_GPUS" ]]; then PYTHON_CMD+=(--num-gpus "$NUM_GPUS"); fi
if [[ "$KEEP_TEMP" == true ]]; then PYTHON_CMD+=(--keep-temp); fi
if [[ "$DRY_RUN" == true ]]; then PYTHON_CMD+=(--dry-run); fi

# Everything after "--" is the pass-through argument list.
PYTHON_CMD+=(--)
if (( ${#SAMPLE_ARGS[@]} > 0 )); then
    PYTHON_CMD+=("${SAMPLE_ARGS[@]}")
fi

echo "[INFO] 启动多卡采样..."
printf '[CMD] %s\n' "${PYTHON_CMD[*]}"

# Run the launcher and propagate its exit status on failure.
PYTHON_EXIT_CODE=0
"${PYTHON_CMD[@]}" || PYTHON_EXIT_CODE=$?

if (( PYTHON_EXIT_CODE != 0 )); then
    echo "[ERROR] 多卡采样执行失败 (退出码: $PYTHON_EXIT_CODE)"
    exit $PYTHON_EXIT_CODE
fi

# Announce completion and work out which directory holds the final images.
# The default is MERGED_OUTPUT_DIR (derived from the options); when the JSON
# report at REPORT_PATH exists and names a final_output_path, that path is
# used instead.
echo "多卡采样完成！"
FINAL_OUTPUT_DIR="$MERGED_OUTPUT_DIR"
if [[ -f "$REPORT_PATH" ]]; then
    # Start from a clean slate so stale values cannot be mistaken for report data.
    unset report_success report_path
    # The inline Python below (program read from stdin via "-", report path in
    # argv[1]) prints two shell assignments:
    #   report_success=true|false
    #   report_path=<final_output_path, quoted with shlex.quote>
    # Each non-empty line is eval'ed to set those variables in this shell. The
    # path is shell-quoted by the Python side before it reaches eval.
    # NOTE(review): the report is presumably written by the launcher run above;
    # confirm against RUN/multi_gpu_launcher.py.
    while IFS= read -r line; do
        if [[ -n "$line" ]]; then
            eval "$line"
        fi
    done < <(
        python - <<'PY' "$REPORT_PATH"
import json
import shlex
import sys

with open(sys.argv[1], "r", encoding="utf-8") as fh:
    data = json.load(fh)

success = bool(data.get("success"))
path = data.get("final_output_path") or ""

print(f"report_success={str(success).lower()}")
print("report_path=" + shlex.quote(path))
PY
    )

    # A reported path is used whether or not the report says success; the
    # second branch only adds a warning. Without a reported path the default
    # computed above stays in effect.
    if [[ "${report_success:-false}" == "true" && -n "${report_path:-}" ]]; then
        FINAL_OUTPUT_DIR="$report_path"
    elif [[ -n "${report_path:-}" ]]; then
        FINAL_OUTPUT_DIR="$report_path"
        echo "[WARN] 报告文件标记 success=false，仍使用解析出的目录: $FINAL_OUTPUT_DIR"
    else
        echo "[WARN] 未能从报告文件解析最终输出目录，使用默认值: $FINAL_OUTPUT_DIR"
    fi
else
    echo "[WARN] 未找到报告文件: $REPORT_PATH"
fi

echo "结果聚合目录: $MERGED_ROOT"
echo "最终图像目录: $FINAL_OUTPUT_DIR"

printf '%s\n' "================================="
# The suggested ground-truth directory is fixed to the Taylor baseline interval_1/order_2.
GT_SUGGEST="$PROJECT_ROOT/results/taylor/interval_1/order_2"
printf '%s\n' \
    "后续可执行评估命令示例:" \
    "  bash $PROJECT_ROOT/evaluation/run_eval.sh --acc \"multi=$FINAL_OUTPUT_DIR\" --gt \"$GT_SUGGEST\"" \
    "================================="

# Normal completion: disarm the EXIT trap and remove the scratch files now.
trap - EXIT
cleanup_tmp_files
