#!/bin/bash

#Default values for every option. Each one can be overridden on the command
#line (see usage below); the values here are what a bare invocation uses.

# --- Model and data -------------------------------------------------------
MODEL_ID="runwayml/stable-diffusion-v1-5"
CONCEPT_TO_REMOVE="5-concepts"
UNLEARN_CONCEPTS_DATASET_PATH="unlearning_concepts/combination_unlearning/mixed_concepts/multiple_concepts_annotiations/5_concepts_annotations.csv"
PRESERVED_CONCEPTS_DATASET_PATH="preserved_concepts/annotations.csv"
OUTPUT_DIR="/vol/biomedic3/agk21/diffusion_unlearning/outputs_finetuning"

# --- Optimisation hyperparameters -----------------------------------------
LOSS_FUNCTION="concept_neurons_loss"
BATCH_SIZE="5"
STEPS_PER_EPOCH="200"
NUM_OF_EPOCHS="3"
LEARNING_RATE="1e-4"
DERIVATIVE="first_order"                          # first-order derivative unless -d says otherwise
COMPUTE_CONCEPT_NEURONS_BASED_ON="clip_loss"      # clip loss unless -w says otherwise

# --- Sampling / evaluation ------------------------------------------------
NUM_OF_INFERENCE_STEPS="40"
NUM_OF_SAMPLE_IMAGES_TO_GENERATE="2"
RUN_EVALS_WHILE_FINETUNING="True"                 # evaluate during finetuning by default

# --- Mode switches --------------------------------------------------------
CONCEPT_GUIDED="concept_guided"                   # concept-guided finetuning by default
IMPLEMENTATION="correct_implementation"           # correct implementation by default

# --- Hardware -------------------------------------------------------------
# Inherits whatever CUDA_VISIBLE_DEVICES holds in the caller's environment
# (may be empty if that variable is unset).
GPU_ID="${CUDA_VISIBLE_DEVICES}"


#Print the command-line help to stdout (callers may redirect it to stderr).
#The "Default:" values are read from the globals at call time, so any option
#parsed before -h/--help shows its overridden value rather than the built-in one.
usage() {
    echo "Usage: $0 [options]"
    echo "Options:"
    # printf reuses its format for every (flag, description) pair, which keeps
    # the description column aligned for all options.
    printf '  %-40s %s\n' \
        "-m, --model_id" "Base model ID. Default: $MODEL_ID" \
        "-c, --concept_to_remove" "Concept to remove from the model. Default: $CONCEPT_TO_REMOVE" \
        "-o, --output_dir" "Output directory for the model. Default: $OUTPUT_DIR" \
        "-l, --loss_function" "Loss function to use. Default: $LOSS_FUNCTION" \
        "-b, --batch_size" "Batch size for training. Default: $BATCH_SIZE" \
        "-s, --steps_per_epoch" "Steps per epoch. Default: $STEPS_PER_EPOCH" \
        "-n, --num_of_epochs" "Number of epochs to train. Default: $NUM_OF_EPOCHS" \
        "-i, --num_of_inference_steps" "Number of inference steps. Default: $NUM_OF_INFERENCE_STEPS" \
        "-r, --learning_rate" "Learning rate for training. Default: $LEARNING_RATE" \
        "-g, --gpu_id" "GPU ID to use for training. Default: $GPU_ID" \
        "-q, --num_of_sample_images_to_generate" "Number of sample images to generate. Default: $NUM_OF_SAMPLE_IMAGES_TO_GENERATE" \
        "-k, --concept_guided" "Concept guided finetuning ('concept_guided'; any other value means image guided). Default: $CONCEPT_GUIDED" \
        "-y, --implementation" "Implementation type ('correct_implementation' or 'wrong_implementation'). Default: $IMPLEMENTATION" \
        "-u, --unlearn_concepts_dataset_path" "Path to the annotations of sample images of concepts to be unlearned. Default: $UNLEARN_CONCEPTS_DATASET_PATH" \
        "-p, --preserved_concepts_dataset_path" "Path to the annotations of sample images of concepts to be preserved. Default: $PRESERVED_CONCEPTS_DATASET_PATH" \
        "-w, --compute_concept_neurons_based_on" "Method to compute concept neurons ('clip_loss' or 'noise_loss'). Default: $COMPUTE_CONCEPT_NEURONS_BASED_ON" \
        "-e, --run_evals_while_finetuning" "Whether to run evaluations while finetuning ('True' or 'False'). Default: $RUN_EVALS_WHILE_FINETUNING" \
        "-d, --derivative" "Derivative type for gradient loss ('first_order' or 'second_order'). Default: $DERIVATIVE" \
        "-h, --help" "Show this help message"
}

#Print "Error: <message>" to stderr and exit with status 1.
die() {
    echo "Error: $*" >&2
    exit 1
}

#Parse the command-line options (util-linux getopt; short and long forms) and
#overwrite the matching default globals. Exits 1 on an unknown option or an
#invalid value, and exits 0 after printing help for -h/--help.
parse_args() {
    local parsed
    # Assign inside the `if` so the test is on getopt's own exit status, not on
    # a `$?` read some lines later.
    if ! parsed=$(getopt -o m:c:o:l:b:s:n:i:r:g:q:k:y:u:p:w:e:d:h \
        --long model_id:,concept_to_remove:,output_dir:,loss_function:,batch_size:,steps_per_epoch:,num_of_epochs:,num_of_inference_steps:,learning_rate:,gpu_id:,num_of_sample_images_to_generate:,concept_guided:,implementation:,unlearn_concepts_dataset_path:,preserved_concepts_dataset_path:,compute_concept_neurons_based_on:,run_evals_while_finetuning:,derivative:,help \
        -n "$0" -- "$@"); then
        usage >&2
        exit 1
    fi

    # getopt emits a normalized, shell-quoted argument list terminated by `--`;
    # eval re-splits it into this function's positional parameters.
    eval set -- "$parsed"
    while true; do
        # Every long name below must match the --long list above exactly;
        # anything else falls through to the internal-error arm.
        case "$1" in
            -m|--model_id)
                MODEL_ID="$2"
                shift 2
                ;;
            -c|--concept_to_remove)
                CONCEPT_TO_REMOVE="$2"
                shift 2
                ;;
            -o|--output_dir)
                OUTPUT_DIR="$2"
                shift 2
                ;;
            -l|--loss_function)
                LOSS_FUNCTION="$2"
                shift 2
                ;;
            -b|--batch_size)
                BATCH_SIZE="$2"
                shift 2
                ;;
            -s|--steps_per_epoch)
                STEPS_PER_EPOCH="$2"
                shift 2
                ;;
            -n|--num_of_epochs)
                NUM_OF_EPOCHS="$2"
                shift 2
                ;;
            -i|--num_of_inference_steps)
                NUM_OF_INFERENCE_STEPS="$2"
                shift 2
                ;;
            -r|--learning_rate)
                LEARNING_RATE="$2"
                shift 2
                ;;
            -g|--gpu_id)
                GPU_ID="$2"
                shift 2
                ;;
            -q|--num_of_sample_images_to_generate)
                NUM_OF_SAMPLE_IMAGES_TO_GENERATE="$2"
                shift 2
                ;;
            -k|--concept_guided)
                CONCEPT_GUIDED="$2"
                # Any value is accepted; a non-default value only triggers a
                # notice (per its text, image-guided finetuning is used instead).
                if [[ "$CONCEPT_GUIDED" != "concept_guided" ]]; then
                    echo "Okay, considering image guided otherwise!"
                fi
                shift 2
                ;;
            -y|--implementation)
                IMPLEMENTATION="$2"
                if [[ "$IMPLEMENTATION" != "correct_implementation" && "$IMPLEMENTATION" != "wrong_implementation" ]]; then
                    die "Implementation must be 'correct_implementation' or 'wrong_implementation'"
                fi
                shift 2
                ;;
            -u|--unlearn_concepts_dataset_path)
                UNLEARN_CONCEPTS_DATASET_PATH="$2"
                shift 2
                ;;
            -p|--preserved_concepts_dataset_path)
                # Must be the same variable that is displayed and passed to python.
                PRESERVED_CONCEPTS_DATASET_PATH="$2"
                shift 2
                ;;
            -w|--compute_concept_neurons_based_on)
                COMPUTE_CONCEPT_NEURONS_BASED_ON="$2"
                if [[ "$COMPUTE_CONCEPT_NEURONS_BASED_ON" != "clip_loss" && "$COMPUTE_CONCEPT_NEURONS_BASED_ON" != "noise_loss" ]]; then
                    die "compute_concept_neurons_based_on must be 'clip_loss' or 'noise_loss'"
                fi
                shift 2
                ;;
            -e|--run_evals_while_finetuning)
                RUN_EVALS_WHILE_FINETUNING="$2"
                if [[ "$RUN_EVALS_WHILE_FINETUNING" != "True" && "$RUN_EVALS_WHILE_FINETUNING" != "False" ]]; then
                    die "run_evals_while_finetuning must be 'True' or 'False'"
                fi
                shift 2
                ;;
            -d|--derivative)
                DERIVATIVE="$2"
                if [[ "$DERIVATIVE" != "first_order" && "$DERIVATIVE" != "second_order" ]]; then
                    die "derivative must be 'first_order' or 'second_order'"
                fi
                shift 2
                ;;
            -h|--help)
                usage
                exit 0
                ;;
            --)
                shift
                break
                ;;
            *)
                echo "Internal error: unhandled option '$1'" >&2
                exit 1
                ;;
        esac
    done
}

parse_args "$@"


#Echo the fully resolved configuration (defaults merged with command-line
#overrides) so the console records exactly what this run uses, then announce
#that training is about to start. A single unquoted heredoc expands the
#variables in place.
cat <<EOF
=========================================
Starting finetuning with the following parameters:
Model ID: $MODEL_ID
Concept to remove: $CONCEPT_TO_REMOVE
Output directory: $OUTPUT_DIR
Loss function: $LOSS_FUNCTION
Batch size: $BATCH_SIZE
Steps per epoch: $STEPS_PER_EPOCH
Number of epochs: $NUM_OF_EPOCHS
Number of inference steps: $NUM_OF_INFERENCE_STEPS
Learning rate: $LEARNING_RATE
GPU ID: $GPU_ID
Number of sample images to generate: $NUM_OF_SAMPLE_IMAGES_TO_GENERATE
Concept guided finetuning: $CONCEPT_GUIDED
Implementation: $IMPLEMENTATION
Unlearning concepts dir: $UNLEARN_CONCEPTS_DATASET_PATH
Preserving concepts dir: $PRESERVED_CONCEPTS_DATASET_PATH
Compute concept neurons based on: $COMPUTE_CONCEPT_NEURONS_BASED_ON
Run evaluations while finetuning: $RUN_EVALS_WHILE_FINETUNING
Derivative type for gradient loss: $DERIVATIVE
=========================================
Running finetuning script...
EOF

#Log location for this run. The file name encodes the concept, loss function,
#batch size, epochs, steps per epoch, GPU, concept-neuron method and derivative,
#so runs differing in any of those write to separate log files.
LOG_DIR="/vol/biomedic3/agk21/diffusion_unlearning/logs/sdxl"
LOG_FILE="${LOG_DIR}/run_concept_${CONCEPT_TO_REMOVE}__loss_fun_${LOSS_FUNCTION}__batch_size_${BATCH_SIZE}__epochs_${NUM_OF_EPOCHS}__steps_per_epoch_${STEPS_PER_EPOCH}__GPU_${GPU_ID}__compute_concept_neurons_based_on_${COMPUTE_CONCEPT_NEURONS_BASED_ON}__derivative_${DERIVATIVE}.log"

#Making the logs dir. If this fails, the output redirection below cannot open
#the log file and training would never start, so stop here with a clear error.
if ! mkdir -p "$LOG_DIR"; then
    echo "Error: could not create log directory $LOG_DIR" >&2
    exit 1
fi

# Run the finetuning module with the resolved parameters. Both stdout and
# stderr go to LOG_FILE (an existing file with the same name is truncated).
python3 -m stage_t2i_model \
    --model_name "$MODEL_ID" \
    --concept_to_remove "$CONCEPT_TO_REMOVE" \
    --output_dir "$OUTPUT_DIR" \
    --loss_function "$LOSS_FUNCTION" \
    --batch_size "$BATCH_SIZE" \
    --steps_per_epoch "$STEPS_PER_EPOCH" \
    --num_of_epochs "$NUM_OF_EPOCHS" \
    --num_of_inference_steps "$NUM_OF_INFERENCE_STEPS" \
    --learning_rate "$LEARNING_RATE" \
    --gpu_id "$GPU_ID" \
    --num_of_sample_images_to_generate "$NUM_OF_SAMPLE_IMAGES_TO_GENERATE" \
    --concept_guided "$CONCEPT_GUIDED" \
    --implementation "$IMPLEMENTATION" \
    --unlearn_concepts_dataset_path "$UNLEARN_CONCEPTS_DATASET_PATH" \
    --preserved_concepts_dataset_path "$PRESERVED_CONCEPTS_DATASET_PATH" \
    --compute_concept_neurons_based_on "$COMPUTE_CONCEPT_NEURONS_BASED_ON" \
    --run_evals_while_finetuning "$RUN_EVALS_WHILE_FINETUNING" \
    --derivative "$DERIVATIVE" \
    > "$LOG_FILE" 2>&1
exit_code=$?

# Report the real outcome instead of always claiming success, and propagate
# the training exit code so callers (schedulers, wrapper scripts) see failures.
if [ "$exit_code" -ne 0 ]; then
    echo "Process failed with exit code $exit_code! Please check $LOG_FILE for details." >&2
    exit "$exit_code"
fi
echo "Process completed! Please check $LOG_FILE for details."