#!/bin/bash

# Ensure the root directory for unlearning logs exists before anything else runs
mkdir -p logs/unlearning

# ANSI color escape sequences (kept as literal backslash text; rendered by `echo -e`)
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # reset / no color

# Known-good "<model> <dataset>" pairs, one dataset per row.
# The order matters: it drives both the numbered listing and the run order.
COMBINATIONS=(
    "resnet9 mnist"     "lenet mnist"
    "lenet svhn"        "resnet9 svhn"
    "lenet cifar10"     "resnet9 cifar10"   "resnet18 cifar10"
    "resnet9 cifar100"  "resnet18 cifar100"
    "resnet18 tinyimagenet"
)
TOTAL_TASKS=${#COMBINATIONS[@]}

# Accepted forget types:
#   single   - forget a single class
#   multiple - forget multiple classes
#   all      - forget all classes
FORGET_OPTIONS=("single" "multiple" "all")

# Accepted sensitivity calculation methods:
#   noise  - use synthetic noise
#   sample - use sample data
#   hybrid - use a mix of both
SENS_OPTIONS=("noise" "sample" "hybrid")

# Defaults for everything the command line can override
MODEL=""
DATASET=""
CLASSES=""
TARGET_LAYERS=""
FORGET_TYPE="single"
SENS_METHOD="noise"
LAMBDA_VALUE="10.0"
RUN_ALL=false
MULTI_CLASS_MODE=false
# Print usage information and available options
function print_usage {
    echo -e "${YELLOW}Usage:${NC} $0 [options]"
    echo "Options:"
    echo "  -m, --model MODEL      Specify model (resnet9, resnet18, lenet, allcnn)"
    echo "  -d, --dataset DATASET  Specify dataset (mnist, svhn, cifar10, cifar100)"
    echo "  -f, --forget TYPE      Specify forget type (single, multiple, all) Default: single"
    echo "  -c, --classes CLASSES  Specify class indices to forget, comma-separated (e.g.: 0,1,2)"
    echo "                         Note: For multiple type, will forget sequentially"
    echo "  -s, --sens METHOD      Specify sensitivity calculation method (noise, sample, hybrid) Default: noise"
    echo "  -t, --target LAYERS    Specify target layers, comma-separated (Default: auto-select)"
    echo "  --multi-class-mode     Force multi-class one-shot forgetting mode (recommended for multi-class)"
    echo "  -a, --all-combinations Run all available model-dataset combinations"
    echo "  -l, --list             List all available combinations"
    echo "  -h, --help             Show this help information"
    echo "  --lambda VALUE         Specify lambda parameter value (Default: 10.0)"
    echo ""
    echo "Examples:"
    echo "  $0 -m resnet9 -d cifar10 -f single -c 0       # Forget class 0 in CIFAR-10"
    echo "  $0 -m resnet18 -d cifar100 -f multiple -c 0,10,20,30  # Sequentially forget multiple classes in CIFAR-100"
    echo "  $0 -m resnet18 -d cifar100 -f multiple -c 0,10,20,30 --multi-class-mode  # Use one-shot multi-class forgetting"
    echo "  $0 -m lenet -d mnist -f all --multi-class-mode # Forget all classes in MNIST (one-shot)"
    echo "  $0 -a -f single -c 0                          # Forget class 0 for all combinations"
}

# List all available model-dataset combinations, numbered from 1.
# Globals:   COMBINATIONS (read), YELLOW, NC (read)
# Arguments: none
# Outputs:   a heading followed by one numbered line per combination on stdout
function list_combinations {
    # Keep the index local so calling this function never overwrites a
    # variable named `i` in the caller's scope.
    local i
    echo -e "${YELLOW}Available model-dataset combinations:${NC}"
    for i in "${!COMBINATIONS[@]}"; do
        echo "  $((i+1)). ${COMBINATIONS[$i]}"
    done
}

# Process command line arguments

# Abort unless the option currently being parsed is followed by a value.
# Without this check, `shift 2` fails when the option is the last argument,
# leaves the positional parameters untouched, and the parsing loop never ends.
#   $1 - the option as typed by the user
#   $2 - number of positional parameters remaining (option included)
function _require_option_value {
    if [[ "$2" -lt 2 ]]; then
        echo "Error: Option $1 requires a value" >&2
        print_usage >&2
        exit 1
    fi
}

# Succeed when $1 is exactly equal to one of the remaining arguments.
# An element-wise comparison avoids the substring matches that a
# space-joined regex test would accept (e.g. "single multiple").
function _is_valid_choice {
    local wanted=$1
    local candidate
    shift
    for candidate in "$@"; do
        if [[ "$candidate" == "$wanted" ]]; then
            return 0
        fi
    done
    return 1
}

while [[ $# -gt 0 ]]; do
    case "$1" in
        -m|--model)
            _require_option_value "$1" "$#"
            MODEL="$2"
            shift 2
            ;;
        -d|--dataset)
            _require_option_value "$1" "$#"
            DATASET="$2"
            shift 2
            ;;
        -f|--forget)
            _require_option_value "$1" "$#"
            FORGET_TYPE="$2"
            shift 2
            ;;
        -c|--classes)
            _require_option_value "$1" "$#"
            CLASSES="$2"
            shift 2
            ;;
        -s|--sens)
            _require_option_value "$1" "$#"
            SENS_METHOD="$2"
            shift 2
            ;;
        -t|--target)
            _require_option_value "$1" "$#"
            TARGET_LAYERS="$2"
            shift 2
            ;;
        --multi-class-mode)
            MULTI_CLASS_MODE=true
            shift
            ;;
        -a|--all-combinations)
            RUN_ALL=true
            shift
            ;;
        -l|--list)
            list_combinations
            exit 0
            ;;
        -h|--help)
            print_usage
            exit 0
            ;;
        --lambda)
            _require_option_value "$1" "$#"
            LAMBDA_VALUE="$2"
            shift 2
            ;;
        *)
            echo "Error: Unknown parameter $1" >&2
            print_usage >&2
            exit 1
            ;;
    esac
done

# Validate forget type
if ! _is_valid_choice "$FORGET_TYPE" "${FORGET_OPTIONS[@]}"; then
    echo "Error: Invalid forget type '$FORGET_TYPE'" >&2
    echo "Valid options: ${FORGET_OPTIONS[*]}" >&2
    exit 1
fi

# Validate sensitivity calculation method
if ! _is_valid_choice "$SENS_METHOD" "${SENS_OPTIONS[@]}"; then
    echo "Error: Invalid sensitivity calculation method '$SENS_METHOD'" >&2
    echo "Valid options: ${SENS_OPTIONS[*]}" >&2
    exit 1
fi

# Validate input for "all combinations" mode
if [[ "$RUN_ALL" == "true" ]]; then
    if [[ -n "$MODEL" || -n "$DATASET" ]]; then
        echo -e "${YELLOW}Warning:${NC} Specified model or dataset will be ignored in all-combinations mode"
    fi
else
    # Validate input for single combination
    if [[ -z "$MODEL" || -z "$DATASET" ]]; then
        echo "Error: Must specify model and dataset, or use -a/--all-combinations option to run all combinations" >&2
        print_usage >&2
        exit 1
    fi
fi

# Run a single experiment with specified model and dataset.
#
# Validates the pair against COMBINATIONS, locates the checkpoint, assembles
# the argument list for methods/our/run.py and launches it in the background
# under nohup with output redirected to a timestamped log file.
#
# Globals (read): COMBINATIONS, RUN_ALL, FORGET_TYPE, CLASSES, SENS_METHOD,
#                 TARGET_LAYERS, MULTI_CLASS_MODE, LAMBDA_VALUE,
#                 GREEN, YELLOW, BLUE, NC, and (all-combinations mode only)
#                 current_combo, total_combos
# Arguments:      $1 - model name
#                 $2 - dataset name
# Outputs:        progress information on stdout, errors on stderr;
#                 appends the background PID to <log dir>/pid.txt
# Returns:        0 if the experiment was launched, or skipped in
#                 all-combinations mode; 1 if the user cancelled, required
#                 class indices are missing, or the log directory cannot be
#                 created
function run_experiment {
    local model=$1
    local dataset=$2
    # Everything below is local so repeated calls do not leak state and the
    # caller's own loop variable named `combo` is left untouched.
    local combo confirm custom_path
    local valid=false
    local experiment_desc log_dir log_file model_path pid
    local -a forget_args=()
    local -a cmd=()

    echo -e "\n${BLUE}====================================================${NC}"
    if [[ "$RUN_ALL" == "true" ]]; then
        echo -e "${BLUE}[${current_combo}/${total_combos}] Running combination: ${model} + ${dataset}${NC}"
    fi

    # Validate combination
    for combo in "${COMBINATIONS[@]}"; do
        if [[ "$combo" == "$model $dataset" ]]; then
            valid=true
            break
        fi
    done

    if [[ "$valid" == "false" ]]; then
        echo -e "${YELLOW}Warning:${NC} Uncommon model-dataset combination: $model + $dataset"
        if [[ "$RUN_ALL" == "false" ]]; then
            echo "If you're sure this combination is valid, the experiment will continue, but problems may occur."
            echo "Available combinations:"
            list_combinations
            echo "Continue? (y/n)"
            read -r confirm
            if [[ "$confirm" != "y" ]]; then
                echo "Experiment cancelled"
                return 1
            fi
        else
            echo "Skipping this combination, continuing to next..."
            return 0
        fi
    fi

    # Generate experiment description (used as the log directory name)
    experiment_desc="${model}_${dataset}_${FORGET_TYPE}"
    if [[ "$FORGET_TYPE" != "all" && -n "$CLASSES" ]]; then
        experiment_desc="${experiment_desc}_classes${CLASSES//,/-}"
    fi
    experiment_desc="${experiment_desc}_${SENS_METHOD}"

    # If using multi-class one-shot mode, add identifier
    if [[ "$MULTI_CLASS_MODE" == "true" ]]; then
        experiment_desc="${experiment_desc}_multimode"
    fi
    log_dir="logs/unlearning/${experiment_desc}"

    # Find pre-trained model
    model_path="checkpoints/${model}_${dataset}_best.pth"
    if [[ ! -f "$model_path" ]]; then
        echo -e "${YELLOW}Warning:${NC} Pre-trained model not found at default path: $model_path"
        if [[ "$RUN_ALL" == "false" ]]; then
            echo "Please enter correct model path, or press Enter to cancel:"
            read -r custom_path
            if [[ -z "$custom_path" ]]; then
                echo "Experiment cancelled"
                return 1
            fi
            model_path="$custom_path"
        else
            echo "Skipping this combination, continuing to next..."
            return 0
        fi
    fi

    # Build the forget-type specific arguments
    if [[ "$FORGET_TYPE" == "single" ]]; then
        if [[ -z "$CLASSES" ]]; then
            echo "Error: Single class forgetting requires class index specification" >&2
            return 1
        fi
        forget_args=(--class_idx "$CLASSES")
    elif [[ "$FORGET_TYPE" == "multiple" ]]; then
        if [[ -z "$CLASSES" ]]; then
            echo "Error: Multi-class forgetting requires class indices, comma-separated" >&2
            return 1
        fi
        # For multiple type, will forget specified classes sequentially
        if [[ "$MULTI_CLASS_MODE" == "true" ]]; then
            echo -e "${YELLOW}Warning: Using multi-class one-shot forgetting mode, sequential forgetting logic will be ignored${NC}"
        fi
        forget_args=(--class_idxs "$CLASSES")
    elif [[ "$FORGET_TYPE" == "all" ]]; then
        forget_args=(--all)
    fi

    # Assemble the command as an array so every value reaches Python as
    # exactly one argument, even if it contains spaces or glob characters
    # (for instance a checkpoint path typed at the prompt above).
    cmd=(python methods/our/run.py
         --model "$model"
         --dataset "$dataset"
         --model_path "$model_path"
         "${forget_args[@]}"
         --sens_source "$SENS_METHOD")
    if [[ -n "$TARGET_LAYERS" ]]; then
        cmd+=(--target_layers "$TARGET_LAYERS")
    fi
    if [[ "$MULTI_CLASS_MODE" == "true" ]]; then
        cmd+=(--multi_class_mode)
    fi
    if [[ -n "$LAMBDA_VALUE" ]]; then
        cmd+=(--lambda_value "$LAMBDA_VALUE")
    fi

    # Create the log directory only now, once nothing can cancel or skip the
    # run any more, so aborted attempts do not leave empty directories behind.
    if ! mkdir -p "$log_dir"; then
        echo "Error: Could not create log directory: $log_dir" >&2
        return 1
    fi
    log_file="${log_dir}/unlearn_$(date +%Y%m%d_%H%M%S).log"

    # Output experiment information
    echo -e "${GREEN}Starting forgetting experiment${NC}"
    echo "Model: $model"
    echo "Dataset: $dataset"
    echo "Forget type: $FORGET_TYPE"
    if [[ -n "$CLASSES" ]]; then
        echo "Forget classes: $CLASSES"
    fi
    echo "Sensitivity calculation method: $SENS_METHOD"
    if [[ -n "$TARGET_LAYERS" ]]; then
        echo "Target layers: $TARGET_LAYERS"
    fi
    if [[ "$MULTI_CLASS_MODE" == "true" ]]; then
        echo -e "${YELLOW}Multi-class one-shot forgetting mode: Enabled${NC}"
    fi
    echo "Model path: $model_path"
    echo "Log file: $log_file"
    echo ""

    # Start forgetting experiment
    echo -e "${GREEN}Executing command:${NC}"
    echo "${cmd[*]}"
    echo ""

    # Execute command in the background and redirect output to log file
    nohup "${cmd[@]}" > "$log_file" 2>&1 &

    # Get process ID of the background job
    pid=$!
    echo -e "${GREEN}Experiment started in background (PID: $pid)${NC}"
    echo "Use the following command to view progress:"
    echo "  tail -f $log_file"
    echo "Use the following command to terminate experiment:"
    echo "  kill $pid"
    echo ""
    echo "Experiment results will be stored in: $log_dir"

    # Record PID next to the logs (appended, one line per launch)
    echo "$pid" >> "${log_dir}/pid.txt"

    return 0
}

# Main logic: either launch the one requested experiment or sweep over every
# entry of COMBINATIONS.
if [[ "$RUN_ALL" != "true" ]]; then
    # Single experiment; its return status becomes the script's exit status
    run_experiment "$MODEL" "$DATASET"
else
    echo -e "${GREEN}Starting all available model-dataset combinations${NC}"
    total_combos=${#COMBINATIONS[@]}
    echo "Total ${total_combos} combinations"

    # Timestamped folder announced as the place for the sweep overview.
    # NOTE(review): nothing visible in this script writes into it - confirm
    # whether an overview file was meant to be produced here.
    MAIN_DIR="logs/unlearning/all_combinations_$(date +%Y%m%d_%H%M%S)"
    mkdir -p "$MAIN_DIR"
    echo "Overview of all experiments will be saved in: $MAIN_DIR"

    # current_combo / total_combos are read by run_experiment for its
    # "[n/total]" progress banner, so they stay global.
    current_combo=1
    while (( current_combo <= total_combos )); do
        # Split "<model> <dataset>" into its two fields
        read -r combo_model combo_dataset <<< "${COMBINATIONS[current_combo - 1]}"

        # The return status is deliberately ignored: a failed or skipped
        # combination must not stop the rest of the sweep.
        run_experiment "$combo_model" "$combo_dataset"
        current_combo=$((current_combo + 1))

        # Stagger the launches so the background jobs do not all start at once
        sleep 2
    done

    echo -e "${GREEN}All experiments started!${NC}"
    echo "You can view logs and results of individual experiments in the following directory:"
    echo "  logs/unlearning/"

fi
