## SSR

### Quick Start

Follow these steps to run the full SSR pipeline:

#### 1. Fine-tune the Base Model

First, fine-tune the base model on your target task (e.g., GSM8K). The later steps load the resulting LoRA adapter from `ckpt/llama-gsm8k-ft`.

```bash
python ft_llama_gsm8k.py
```



#### 2. Automated Self-Correction Data Refinement

Use the fine-tuned model to build the curated self-correction dataset that the localization (step 3) and repair (step 4) steps consume.

```
cd ./safety_eval

python diagnose.py \
    --model_folder /root/model/Llama-3-8B-Instruct \
    --lora_folder ../ckpt/llama-gsm8k-ft \
    --source_file_path ../data/Safety_Diagnostic.json \
    --curated_output_path ../data/llama-gsm8k-curated_for_SSR.jsonl 
```
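
Before moving on, it can help to sanity-check the curated file. The sketch below assumes only that `diagnose.py` writes one JSON object per line (as the `.jsonl` extension suggests); it makes no assumption about field names and simply counts records and tallies the top-level keys they use. The function name `summarize_jsonl` is hypothetical.

```python
import json
from collections import Counter


def summarize_jsonl(lines):
    """Count JSON records in an iterable of JSONL lines and tally their top-level keys."""
    n_records = 0
    key_counts = Counter()
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:  # skip blank lines, e.g. a trailing newline at end of file
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            raise ValueError(f"line {line_no}: invalid JSON ({err})") from err
        n_records += 1
        if isinstance(record, dict):  # only JSON objects have keys to tally
            key_counts.update(record.keys())
    return n_records, key_counts


# Usage (run from ./safety_eval):
# with open("../data/llama-gsm8k-curated_for_SSR.jsonl", encoding="utf-8") as f:
#     n, keys = summarize_jsonl(f)
#     print(n, keys.most_common())
```

Taking an iterable of lines rather than a path keeps the helper easy to test and lets it read directly from an open file handle.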



#### 3. Malicious Pathway Localization

Identify harmful neural pathways by computing gradient attributions.

```
cd ..

python analyze_activations.py \
    --model_folder /root/model/Llama-3-8B-Instruct \
    --lora_adapter_path ckpt/llama-gsm8k-ft \
    --raw_forget_path data/llama-gsm8k-curated_for_SSR.jsonl \
    --top_k_neurons 8 \
    --output_file llama-gsm8k-harmful-neurons.pt
```
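
The README does not spell out the attribution rule used by `analyze_activations.py`. As an illustration only, the sketch below scores each neuron with a common gradient-times-activation attribution, |a · ∂L/∂a|, and keeps the `k` highest-scoring neurons per layer, mirroring `--top_k_neurons 8`. The data layout (a dict mapping a hypothetical layer name to per-neuron activation and gradient lists) and the function name are assumptions, not the script's actual interface.

```python
import heapq


def top_k_neurons(activations, gradients, k=8):
    """Rank neurons in each layer by |activation * gradient| and keep the top k.

    activations, gradients: hypothetical dicts mapping a layer name to
    equal-length lists of per-neuron values (e.g. averaged over the curated
    harmful prompts). Returns {layer: [(neuron_index, score), ...]} sorted
    by descending score.
    """
    selected = {}
    for layer, acts in activations.items():
        grads = gradients[layer]
        if len(acts) != len(grads):
            raise ValueError(f"{layer}: {len(acts)} activations vs {len(grads)} gradients")
        # Gradient-times-activation: first-order estimate of how much the loss
        # would change if this neuron's activation were set to zero.
        scores = ((i, abs(a * g)) for i, (a, g) in enumerate(zip(acts, grads)))
        selected[layer] = heapq.nlargest(k, scores, key=lambda item: item[1])
    return selected


# Example:
# top_k_neurons({"l0": [1.0, -2.0, 0.5]}, {"l0": [0.1, 0.5, 4.0]}, k=2)
# -> {"l0": [(2, 2.0), (1, 1.0)]}
```

`heapq.nlargest` avoids sorting every neuron in a layer when only a handful are kept.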



#### 4. Perform Surgical Safety Repair

Apply the SSR update to repair the model.

```
python ssr.py \
    --model_name /root/model/Llama-3-8B-Instruct \
    --lora_adapter_path ckpt/llama-gsm8k-ft \
    --raw_forget_path data/llama-gsm8k-curated_for_SSR.jsonl \
    --raw_steering_path data/llama-gsm8k-curated_for_SSR.jsonl \
    --harmful_neurons_path ./llama-gsm8k-harmful-neurons.pt \
    --learning_rate 2e-5 \
    --epochs 5 \
    --early_stopping_ratio 0.35 \
    --output_dir ckpt/ssr-llama-gsm8k-0.35
```
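
"Surgical" suggests that the repair changes only the neurons localized in step 3. The README does not document how `ssr.py` enforces this, so the sketch below is a hypothetical illustration: a plain SGD step in which every gradient entry outside the selected neuron set is masked out, leaving all other weights untouched. It simplifies each neuron to a single scalar weight, and the function name and data layout are invented; the real script may use a different optimizer or mechanism, and `--early_stopping_ratio` is not modeled here.

```python
def masked_sgd_step(weights, grads, selected, lr=2e-5):
    """Hypothetical 'surgical' update: an SGD step restricted to selected neurons.

    weights, grads: dicts mapping a layer name to a list of per-neuron values.
    selected: dict mapping a layer name to the set of neuron indices that may change.
    Returns new weight lists; every neuron outside `selected` keeps its old value.
    """
    updated = {}
    for layer, w in weights.items():
        g = grads[layer]
        if len(w) != len(g):
            raise ValueError(f"{layer}: {len(w)} weights vs {len(g)} gradients")
        allowed = selected.get(layer, set())
        updated[layer] = [
            wi - lr * gi if i in allowed else wi  # masked-out entries get no update
            for i, (wi, gi) in enumerate(zip(w, g))
        ]
    return updated


# Example: only neuron 1 of layer "l0" is in the repair set.
# masked_sgd_step({"l0": [1.0, 1.0]}, {"l0": [2.0, 2.0]}, {"l0": {1}}, lr=0.5)
# -> {"l0": [1.0, 0.0]}
```

Returning new lists instead of mutating in place makes it easy to compare the weights before and after a step.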



#### Output

The repaired model is saved to the `--output_dir` directory, `ckpt/ssr-llama-gsm8k-0.35`, ready for evaluation and deployment.

#### Evaluation

##### Harmfulness Score

```
cd ./safety_eval

python llama_pred.py \
    --model_folder /root/model/Llama-3-8B-Instruct \
    --lora_folder ../ckpt/ssr-llama-gsm8k-0.35 \
    --output_path ../eval_results/safety/ssr-llama-gsm8k-0.35_1000.json

python beaver_eval.py \
    --input_path ../eval_results/safety/ssr-llama-gsm8k-0.35_1000.json
```
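
The script name suggests that `beaver_eval.py` scores responses with a BeaverTails-style moderation classifier, but the README does not say how the final number is aggregated. A harmfulness score of this kind is commonly reported as the percentage of responses flagged harmful; the sketch below computes that aggregate from per-response flags. The `flagged` field name and the record layout are hypothetical.

```python
def harmfulness_score(records, flag_key="flagged"):
    """Percentage of responses whose (hypothetical) `flag_key` field is truthy."""
    if not records:
        raise ValueError("no records to score")
    n_harmful = sum(1 for r in records if r.get(flag_key))
    return 100.0 * n_harmful / len(records)


# Usage, assuming a JSON list of dicts with a boolean "flagged" field:
# import json
# with open("../eval_results/safety/scored.json", encoding="utf-8") as f:  # hypothetical path
#     print(harmfulness_score(json.load(f)))
```

Records missing the flag are counted as not harmful, so a lower score is better.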



##### Fine-tuning Accuracy (GSM8K)

```
cd ../gsm8k_eval

python llama_eval.py
```
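
The README does not show how `llama_eval.py` extracts answers. GSM8K reference solutions end with a line of the form `#### <answer>`, and a common heuristic for model outputs is to take the last number in the generation. The sketch below computes exact-match accuracy that way; the heuristic and the helper names are assumptions, not necessarily what the script does.

```python
import re

# Integers or decimals, optionally negative, allowing thousands separators.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def gold_answer(solution):
    """Extract the final answer from a GSM8K reference solution ('... #### 72')."""
    return solution.split("####")[-1].strip().replace(",", "")


def predicted_answer(generation):
    """Heuristic: take the last number in the model's output, commas removed."""
    matches = _NUMBER.findall(generation)
    return matches[-1].replace(",", "") if matches else None


def accuracy(generations, solutions):
    """Exact-match accuracy of predicted vs. gold final answers, in percent."""
    if len(generations) != len(solutions) or not solutions:
        raise ValueError("need equal-length, non-empty lists")
    correct = sum(
        predicted_answer(gen) == gold_answer(sol)
        for gen, sol in zip(generations, solutions)
    )
    return 100.0 * correct / len(solutions)
```

Comparing strings means `18.0` and `18` count as different answers; normalizing both sides to numbers would be a reasonable refinement.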

