# Virtual environments for different attack types
NLP_TRAIN_ENV    := envs/nlp_training
TEXTFOOLER_ENV   := envs/textfooler
A2T_ENV          := envs/a2t

# Python interpreters (absolute paths so recipes work regardless of cwd).
# Use := with $(CURDIR) instead of `= $(shell pwd)`: recursive `=` would
# fork a `pwd` subshell on every single expansion of these variables.
NLP_TRAIN_PYTHON  := $(CURDIR)/$(NLP_TRAIN_ENV)/bin/python
TEXTFOOLER_PYTHON := $(CURDIR)/$(TEXTFOOLER_ENV)/bin/python
A2T_PYTHON        := $(CURDIR)/$(A2T_ENV)/bin/python

# A2T attack parameters (can be overridden)
A2T_MAX_WORDS_CHANGED ?= 1 3 5 7
A2T_NUM_SAMPLES ?= 500
A2T_QUERY_BUDGET ?= 50

# TextFooler attack parameters (can be overridden)
TEXTFOOLER_ATTACK_SAMPLE_SIZE ?= 500
TEXTFOOLER_MAX_ATTACK_CHANGES ?= 10

# Parallel execution parameters (can be overridden)
# Please override GPUS on the command line if needed, e.g., make GPUS=0,1 ft_parallel
GPUS ?= 0,1,2,3

# Output dir for training runs; also scanned for checkpoints by the
# attack targets (passed as --search-root). Can be overridden.
SEARCH_ROOT ?= data_files/nlp_training


# pyenv's python3, used to create the virtualenvs below.
# := so `pyenv which` runs once at parse time, not on every expansion.
PYENV_PYTHON := $(shell pyenv which python3)

# -----------------------------------------------------------------------------
# Create Virtual Environments targets
# -----------------------------------------------------------------------------
# Set pyenv local version (writes .python-version file).
# Declared .PHONY so a stray file named `set_pyenv` can never mask it.
.PHONY: set_pyenv
set_pyenv:
	pyenv local 3.11.0

# Create virtualenv using pyenv's python3 for pretraining, then install
# (or upgrade) its requirements. The venv is only created if missing.
.PHONY: setup_nlp_training
setup_nlp_training: set_pyenv
	if [ ! -d "$(NLP_TRAIN_ENV)" ]; then \
		$(PYENV_PYTHON) -m venv $(NLP_TRAIN_ENV); \
	fi
	$(NLP_TRAIN_ENV)/bin/pip install --upgrade -r nlp_training/requirements.txt

# Setup TextFooler environment: create the venv if missing, then install
# (or upgrade) its requirements.
.PHONY: setup_textfooler
setup_textfooler: set_pyenv
	if [ ! -d "$(TEXTFOOLER_ENV)" ]; then \
		$(PYENV_PYTHON) -m venv $(TEXTFOOLER_ENV); \
	fi
	$(TEXTFOOLER_ENV)/bin/pip install --upgrade -r TextFooler/requirements.txt

# Setup TextAttack-A2T environment: create the venv if missing, then install
# (or upgrade) its requirements.
.PHONY: setup_a2t
setup_a2t: set_pyenv
	if [ ! -d "$(A2T_ENV)" ]; then \
		$(PYENV_PYTHON) -m venv $(A2T_ENV); \
	fi
	$(A2T_ENV)/bin/pip install --upgrade -r TextAttack-A2T/requirements.txt

# Setup all environments (aggregate target; builds nothing itself).
.PHONY: setup_all
setup_all: setup_nlp_training setup_textfooler setup_a2t


# -----------------------------------------------------------------------------
# Finetuning Targets
# -----------------------------------------------------------------------------

# model tested: bert-base-uncased, roberta-base, gpt2, mistralai/Mistral-7B-v0.1
# dataset tested: ag_news, yelp_polarity, imdb

# ft:
# 	$(NLP_TRAIN_PYTHON) experiment.py -m \
#         model=gpt2 \
#         dataset=ag_news,yelp_polarity,banking77 \
#         train=head_only \
#         seed=42 \
#         device=cuda \

# -----------------------------------------------------------------------------
# New: parallel fine-tuning across all GPUs on this VM
# -----------------------------------------------------------------------------

# Which script to run for the parallel GPU-queue sweep.
# := (simple expansion) — the value is a plain literal; matches the
# file's convention for derived/non-overridable variables.
PARALLEL_SCRIPT     := experiment_parallel.py

# Sweeps (Python-list syntax) – override on CLI if needed:
PARALLEL_STRATEGIES ?= [head_only,lora]
PARALLEL_DATASETS   ?= [ag_news,banking77,imdb]
# Alternative model sets kept for reference:
# PARALLEL_MODELS     ?= [gemma-2B,llama3_2_1B,tinyllama_v1_1,mistral7B]
# PARALLEL_MODELS     ?= [bert-base-uncased,distilbert-base-uncased,gpt2]
PARALLEL_MODELS     ?= [distilbert-base-uncased,gpt2]
PARALLEL_SEEDS      ?= [42,43,44]


# Launch the parallel GPU-queue experiment. Results are written under
# $(SEARCH_ROOT), where the attack targets later look for checkpoints.
# Fix: the last recipe line previously ended with a stray `\`, which
# continued the command into the following (blank) line.
.PHONY: ft_parallel
ft_parallel:
	$(NLP_TRAIN_PYTHON) $(PARALLEL_SCRIPT) \
	  train.strategy=$(PARALLEL_STRATEGIES) \
	  dataset.name=$(PARALLEL_DATASETS) \
	  model.name=$(PARALLEL_MODELS) \
	  seed=$(PARALLEL_SEEDS) \
	  output_dir=$(SEARCH_ROOT)


# -----------------------------------------------------------------------------
# TextFooler attack target
# -----------------------------------------------------------------------------

# textfooler_attack:
# 	VIRTUAL_ENV=$(TEXTFOOLER_ENV) PATH=$(TEXTFOOLER_ENV)/bin:$(PATH) $(TEXTFOOLER_PYTHON) run_textfooler_attacks.py

# Run TextFooler attacks in parallel on multiple checkpoints found under
# $(SEARCH_ROOT); work is distributed across $(GPUS) by parallel_executor.py.
.PHONY: textfooler_parallel
textfooler_parallel:
	$(NLP_TRAIN_PYTHON) parallel_executor.py textfooler \
		--search-root $(SEARCH_ROOT) \
		--textfooler-python $(TEXTFOOLER_PYTHON) \
		--attack-sample-size $(TEXTFOOLER_ATTACK_SAMPLE_SIZE) \
		--max-attack-changes $(TEXTFOOLER_MAX_ATTACK_CHANGES) \
		--gpus $(GPUS)


# -----------------------------------------------------------------------------
# TextAttack-A2T targets
# -----------------------------------------------------------------------------
# Run A2T attacks in parallel on multiple checkpoints found under
# $(SEARCH_ROOT). A2T_MAX_WORDS_CHANGED is quoted because it is a
# space-separated list that must reach the script as one argument.
.PHONY: a2t_parallel
a2t_parallel:
	$(NLP_TRAIN_PYTHON) parallel_executor.py a2t \
		--search-root $(SEARCH_ROOT) \
		--a2t-python $(A2T_PYTHON) \
		--max-words-changed "$(A2T_MAX_WORDS_CHANGED)" \
		--num-samples $(A2T_NUM_SAMPLES) \
		--query-budget $(A2T_QUERY_BUDGET) \
		--gpus $(GPUS)

# Run both attack suites. Prerequisites run sequentially unless make is
# invoked with -j, in which case they may run concurrently.
.PHONY: all_attacks_parallel
all_attacks_parallel: textfooler_parallel a2t_parallel

# Combined target: fine-tune, then run all attacks on the discovered checkpoints.
# NOTE(review): under `make -j` these two prerequisites may run concurrently,
# so attacks could start before training finishes — run this target without -j.
.PHONY: ft_parallel_and_all_attacks_parallel
ft_parallel_and_all_attacks_parallel: ft_parallel all_attacks_parallel
	@echo "All training and attacks completed."
