#!/bin/bash

# Environment for the whole run.
# nanochat's default intermediate artifacts directory is ~/.cache/nanochat;
# NANOCHAT_BASE_DIR below overrides it with an explicit path.
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="/root/autodl-tmp/nanochat-master/cache/nanochat"
# Add hyper-connections to PYTHONPATH for HyperConnections support.
# ${PYTHONPATH:+:...} appends the previous value only when it is non-empty, so
# an unset PYTHONPATH does not leave a trailing ":" (an empty search-path entry).
export PYTHONPATH="/root/autodl-tmp/hyper-connections-main${PYTHONPATH:+:$PYTHONPATH}"
# Use HuggingFace mirror for China
export HF_HOME="/root/autodl-tmp/huggingface_cache"
export HF_ENDPOINT="https://hf-mirror.com"
# Use GitHub proxy.
# NOTE: --global writes this URL rewrite into the user's global git config, so
# it stays in effect for later git commands outside this script as well.
git config --global url."https://ghfast.top/https://github.com".insteadOf "https://github.com"
# Make sure the artifacts directory exists (quoted so the path is one argument).
mkdir -p -- "$NANOCHAT_BASE_DIR"

# -----------------------------------------------------------------------------
# Python venv setup with uv

# Put ~/.local/bin on PATH before probing for uv, so a copy that is already
# installed there is found by the check below instead of being downloaded again.
export PATH="$HOME/.local/bin:$PATH"
# install uv (if not already installed)
if ! command -v uv > /dev/null 2>&1; then
  curl -LsSf https://astral.sh/uv/install.sh | sh
fi
# create a .venv local virtual environment (if it doesn't exist)
[[ -d ".venv" ]] || uv venv
# install the repo dependencies (with the "gpu" extra)
uv sync --extra gpu
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate

# Run the report module's "reset" command before any pipeline stage executes.
python -m nanochat.report reset

# -----------------------------------------------------------------------------
# Tokenizer
# Fetch the first 8 dataset shards in the foreground; tokenizer training below
# starts only after this command returns.
python -m nanochat.dataset -n 8

# Launch the larger download (-n 370) as a background job so it overlaps with
# tokenizer training. The job's PID is recorded in DATASET_DOWNLOAD_PID.
# (The rationale for the value 370 is not recorded in this script.)
python -m nanochat.dataset -n 370 &
DATASET_DOWNLOAD_PID=$!

# Tokenizer training settings: vocab size 2**16 = 65536, ~2B characters of data.
tok_train_args=(
  --max-chars=2000000000
  --vocab-size=65536
)
python -m scripts.tok_train "${tok_train_args[@]}"

# Evaluate the trained tokenizer (report compression ratio etc.).
python -m scripts.tok_eval

# -----------------------------------------------------------------------------
# Base model (pretraining)
echo "Waiting for dataset download to complete..."
# `wait` returns the background job's exit status; stop here if the download
# failed rather than starting a long training run on an incomplete dataset.
if ! wait "$DATASET_DOWNLOAD_PID"; then
  echo "Dataset download failed; aborting before pretraining." >&2
  exit 1
fi

# Run name passed to --run. This script never sets WANDB_RUN itself, so without
# a default an unset variable would produce an empty `--run=` argument.
# NOTE(review): "dummy" follows upstream nanochat's convention for skipping
# wandb logging; confirm against scripts.base_train.
WANDB_RUN="${WANDB_RUN:-dummy}"

# Number of processes/GPUs to use
NPROC_PER_NODE=4
# Model configuration
DEPTH=12
MODEL_TAG="d${DEPTH}_KromHC"
# HyperConnections: number of residual streams (1=disabled, 4=recommended per paper)
NUM_RESIDUAL_STREAMS=4
# HyperConnections type: "mHC" for ManifoldConstrainedHyperConnections,
# "mHC-lite" for MHCLite, "KromHC" for Kronecker Low-Rank HyperConnections
HC_TYPE="KromHC"

# Flags passed identically to both evaluation scripts below.
model_args=(
  --model-tag="$MODEL_TAG"
  --num-residual-streams="$NUM_RESIDUAL_STREAMS"
  --hc-type="$HC_TYPE"
)

# pretrain the depth-$DEPTH model for a fixed 7000 iterations
torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" -m scripts.base_train -- \
  --depth="$DEPTH" \
  --model-tag="$MODEL_TAG" \
  --num-iterations=7000 \
  --target-param-data-ratio=-1 \
  --num-residual-streams="$NUM_RESIDUAL_STREAMS" \
  --hc-type="$HC_TYPE" \
  --run="$WANDB_RUN" \
  --device-batch-size=32
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" -m scripts.base_loss -- "${model_args[@]}"
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node="$NPROC_PER_NODE" -m scripts.base_eval -- "${model_args[@]}"

# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections.
# The output is report.md; per the original note it is also copied to the
# current directory for convenience (done inside nanochat.report, not visible
# from this script).
python -m nanochat.report generate
