import argparse
import functools
import json
import os
import random
import math
import multiprocessing as mp
import gc
import psutil
import csv
from datetime import datetime

import datasets
import numpy as np
import textattack
import torch
import tqdm
import transformers
from lime.lime_text import LimeTextExplainer, IndexedString
from omegaconf import OmegaConf

from configs import DATASET_CONFIGS

# Import your generic sequence classifier
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'nlp_training'))
from seq_classifier import GenericSequenceClassifier
from exp_data import get_exp_data

# Additional TextAttack imports for custom attack recipes.
# Each component is imported in its own try/except so that a single missing
# or relocated class does not abort the whole script; one ✓/✗ line is
# printed per attempt.
# Initialize all to None first, so later code can test `Name is None`
# instead of hitting a NameError when the corresponding import fails.
RepeatModification = None
StopwordModification = None 
InputColumnModification = None
WordEmbeddingDistance = None
MaxModificationRate = None
PartOfSpeech = None
SBERT = None
WordSwapEmbedding = None
WordSwapMaskedLM = None
GreedyWordSwapWIR = None
UntargetedClassification = None
Attack = None

# Try importing each component individually
try:
    from textattack.constraints.pre_transformation import RepeatModification
    print("✓ RepeatModification imported")
except ImportError as e:
    print(f"✗ RepeatModification failed: {e}")

try:
    from textattack.constraints.pre_transformation import StopwordModification
    print("✓ StopwordModification imported")
except ImportError as e:
    print(f"✗ StopwordModification failed: {e}")

try:
    from textattack.constraints.pre_transformation import InputColumnModification
    print("✓ InputColumnModification imported")
except ImportError as e:
    print(f"✗ InputColumnModification failed: {e}")

try:
    from textattack.constraints.semantics import WordEmbeddingDistance
    print("✓ WordEmbeddingDistance imported")
except ImportError as e:
    print(f"✗ WordEmbeddingDistance failed: {e}")

try:
    from textattack.constraints.overlap import MaxWordsPerturbed
    print("✓ MaxWordsPerturbed imported")
    # For compatibility, alias it to the name we expect
    # NOTE(review): MaxWordsPerturbed comes from the *overlap* constraints
    # package, not from pre_transformation; its constructor arguments may
    # differ from whatever call sites expect of a "MaxModificationRate"
    # constraint — confirm every use of MaxModificationRate in this file.
    MaxModificationRate = MaxWordsPerturbed
except ImportError as e:
    print(f"✗ MaxWordsPerturbed failed: {e}")
    # Redundant with the pre-binding above, kept for explicitness.
    # NOTE(review): MaxWordsPerturbed itself is NOT pre-bound to None, so
    # that name is undefined when this import fails.
    MaxModificationRate = None

try:
    from textattack.constraints.grammaticality import PartOfSpeech
    print("✓ PartOfSpeech imported")
except ImportError as e:
    print(f"✗ PartOfSpeech failed: {e}")

try:
    from textattack.constraints.semantics.sentence_encoders import SBERT
    print("✓ SBERT imported")
except ImportError as e:
    print(f"✗ SBERT failed: {e}")

try:
    from textattack.transformations import WordSwapEmbedding
    print("✓ WordSwapEmbedding imported")
except ImportError as e:
    print(f"✗ WordSwapEmbedding failed: {e}")

try:
    from textattack.transformations import WordSwapMaskedLM
    print("✓ WordSwapMaskedLM imported")
except ImportError as e:
    print(f"✗ WordSwapMaskedLM failed: {e}")

try:
    from textattack.search_methods import GreedyWordSwapWIR
    print("✓ GreedyWordSwapWIR imported")
except ImportError as e:
    print(f"✗ GreedyWordSwapWIR failed: {e}")

try:
    from textattack.goal_functions import UntargetedClassification
    print("✓ UntargetedClassification imported")
except ImportError as e:
    print(f"✗ UntargetedClassification failed: {e}")

try:
    from textattack import Attack
    print("✓ Attack imported")
except ImportError as e:
    print(f"✗ Attack failed: {e}")

# Printed once every import attempt above has run, whatever the outcome.
print("TextAttack import diagnostics complete")


# Number of evaluation samples. The full-size value (1000) is kept below for
# reference; 10 is a temporary reduction for quick tests with large models.
# NUM_SAMPLES_FOR_EVALUATION = 1000
NUM_SAMPLES_FOR_EVALUATION = 10  # Reduced for testing with large models

# Memory optimization settings for large models.
# NOTE(review): these are not referenced in the visible part of the file —
# confirm where they are consumed.
LIME_NUM_SAMPLES_LARGE_MODEL = 100  # Reduced from 1000 for large models
BATCH_SIZE_LARGE_MODEL = 8  # Reduced batch size for large models
MAX_LENGTH_DEFAULT = 256  # Reduced from 512

def print_memory_usage(label=""):
    """Print this process's resident memory and, when CUDA exists, allocated GPU memory.

    Args:
        label: Free-form tag inserted into the printed message.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    print(f"🧠 Memory usage {label}: {rss_bytes / 1024 / 1024:.1f} MB")
    if not torch.cuda.is_available():
        return
    allocated_gb = torch.cuda.memory_allocated() / 1024**3
    print(f"🔥 GPU memory: {allocated_gb:.1f} GB")

def cleanup_memory():
    """Run Python's garbage collector, then release PyTorch's cached CUDA memory if CUDA is present."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def detect_dataset_from_checkpoint(checkpoint_path):
    """
    Automatically detect dataset information from experiment_config.json in the checkpoint directory.

    The dataset name is taken from ``dataset_info.name`` and, if that is
    missing, from ``args.dataset_name``. It is returned with a ``custom_``
    prefix so that ``load_dataset`` uses its custom-dataset branch. A name that
    already starts with ``custom_`` is not prefixed a second time.

    Args:
        checkpoint_path: Path to the model checkpoint directory

    Returns:
        tuple: (dataset_name, num_classes), or (None, None) if the config file
        is missing, unreadable, not a JSON object, or records no dataset name.
        ``num_classes`` is None when the config does not record it.
    """
    config_path = os.path.join(checkpoint_path, "experiment_config.json")

    if not os.path.exists(config_path):
        return None, None

    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"⚠️  Error reading experiment config from {config_path}: {e}")
        return None, None

    if not isinstance(config, dict):
        print(f"⚠️  Error reading experiment config from {config_path}: top-level JSON value is not an object")
        return None, None

    # Guard against sections that are present but null / not objects.
    dataset_info = config.get('dataset_info')
    if not isinstance(dataset_info, dict):
        dataset_info = {}
    dataset_name = dataset_info.get('name')
    num_classes = dataset_info.get('num_classes')

    # Fallback to args.dataset_name if dataset_info has no name
    if not dataset_name:
        args_info = config.get('args')
        if isinstance(args_info, dict):
            dataset_name = args_info.get('dataset_name')

    if not dataset_name:
        print(f"⚠️  No dataset name found in {config_path}")
        return None, None

    # Always route through load_dataset's custom branch (avoids local-file
    # lookup via DATASET_CONFIGS), but never produce "custom_custom_...".
    dataset_name = str(dataset_name)
    if not dataset_name.startswith("custom_"):
        dataset_name = f"custom_{dataset_name}"

    print(f"🔍 Auto-detected dataset from {config_path}:")
    print(f"   📊 Dataset: {dataset_name}")
    if num_classes:
        print(f"   🏷️  Classes: {num_classes}")

    return dataset_name, num_classes

def create_a2t_results_dir(model_path=None):
    """Ensure an ``A2T_results`` directory exists and return its path.

    With ``model_path`` the directory lives next to the checkpoint, i.e. in the
    parent of ``os.path.abspath(model_path)``; without it, it lives in the
    current working directory.
    """
    base_dir = (
        os.getcwd()
        if model_path is None
        else os.path.dirname(os.path.abspath(model_path))
    )
    results_dir = os.path.join(base_dir, "A2T_results")
    os.makedirs(results_dir, exist_ok=True)
    return results_dir

def check_accuracy_exists(model_path):
    """Return True if ``accuracy_eval_logs.json`` already exists for this model.

    The file is looked up in the directory returned by
    ``create_a2t_results_dir(model_path)``, which creates that directory as a
    side effect if it is missing.
    """
    log_file = os.path.join(create_a2t_results_dir(model_path), "accuracy_eval_logs.json")
    return os.path.exists(log_file)

def get_model_name_from_path(model_path):
    """Build a readable model name from a '/'-separated checkpoint path.

    Components are kept, in order, if they contain a known architecture
    keyword (case-insensitive), contain ``epochs_``, or are exactly one of
    ``lora``/``head_only``/``full``. When nothing matches, the last three
    components (or all of them, if fewer) are used. The kept components are
    joined with underscores.
    """
    architecture_keywords = ('gpt', 'bert', 'roberta', 'distilbert', 'llama', 'mistral')
    training_modes = ('lora', 'head_only', 'full')
    components = model_path.strip('/').split('/')

    def _is_identifier(component):
        lowered = component.lower()
        return (
            any(keyword in lowered for keyword in architecture_keywords)
            or 'epochs_' in component
            or component in training_modes
        )

    selected = [component for component in components if _is_identifier(component)]
    if not selected:
        # Slicing a shorter list simply yields the whole list.
        selected = components[-3:]
    return "_".join(selected)

def write_summary_csv(results_dir, summary_data):
    """Append one row of attack-summary data to ``<results_dir>/attack_summary.csv``.

    A header row is written when the file does not exist yet *or* exists but
    is empty (e.g. left behind by an interrupted run), so the CSV always
    starts with its header.

    Args:
        results_dir: Directory that contains (or will contain) the CSV file.
        summary_data: Mapping from field name to value. Missing fields are
            written as empty cells; csv.DictWriter raises ValueError for keys
            outside ``fieldnames``.
    """
    csv_path = os.path.join(results_dir, "attack_summary.csv")

    fieldnames = [
        "timestamp", "dataset", "model_name", "model_checkpoint", "random_seed",
        "query_budget", "max_words_changed", "num_samples", "attack_name", "original_accuracy", "accuracy_under_attack",
        "attack_success_rate", "avg_num_queries", "avg_pct_perturbed"
    ]

    # An existing-but-empty file still needs a header.
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0

    with open(csv_path, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if needs_header:
            writer.writeheader()

        writer.writerow(summary_data)

    print(f"📊 Summary data written to: {csv_path}")

def calculate_accuracy_under_attack(results, total_dataset_size, correct_indices_attacked, original_accuracy):
    """
    Calculate the accuracy of the model on the full dataset after attacks.

    A sample counts as correct after the attack if it was originally correct
    and either was not attacked, or was attacked and the attack failed
    (``FailedAttackResult``). Successful attacks, and ``SkippedAttackResult``s
    (in TextAttack, a skip means the model already mispredicted the input),
    count as incorrect.

    Args:
        results: Attack results (only on correctly classified samples)
        total_dataset_size: Total number of samples in the original dataset
        correct_indices_attacked: List of indices of correctly classified samples that were attacked
        original_accuracy: Original accuracy as a percentage (0-100)

    Returns:
        Accuracy under attack as percentage (0-100); 0 when total_dataset_size is 0.
    """
    # Reconstruct the number of originally-correct samples. Use round(), not
    # int(): original_accuracy is often a rounded percentage, e.g. 4/7 stored
    # as 57.14 gives 7 * 57.14 / 100 = 3.9998, which int() truncates to 3.
    originally_correct = round(total_dataset_size * original_accuracy / 100)

    # Samples that were originally correct but not attacked (remain correct).
    # Clamp at 0 so a rounding mismatch can never yield a negative count.
    originally_correct_not_attacked = max(0, originally_correct - len(correct_indices_attacked))

    # Samples that were attacked and resisted the attack (remain correct).
    attacked_but_failed = sum(
        1 for r in results
        if isinstance(r, textattack.attack_results.FailedAttackResult)
    )

    total_correct_after_attack = originally_correct_not_attacked + attacked_but_failed

    accuracy_under_attack = total_correct_after_attack / total_dataset_size if total_dataset_size > 0 else 0
    return accuracy_under_attack * 100

# Device detection helper
def get_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, otherwise CPU."""
    if torch.cuda.is_available():
        device_name = "cuda"
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)

def is_large_model(model_path):
    """
    Detect if this is a large model that needs memory optimization.

    Reads ``<model_path>/config.json`` and returns True when either
      * the base-model name (``base_model_name``, else ``_name_or_path``) or
        any entry of ``architectures`` contains one of the indicator
        substrings below (case-insensitive), or
      * ``hidden_size`` is a number greater than 1024.

    Returns False when config.json is missing, unreadable, not valid JSON, or
    not a JSON object. Null or oddly-typed fields are tolerated, so they no
    longer discard a positive name match.

    NOTE(review): matching is by plain substring, so short indicators such as
    'opt' or 'xl' can also match unrelated names (e.g. 'xlnet'); confirm
    this is intended.
    """
    try:
        config_path = os.path.join(model_path, 'config.json')
        if not os.path.exists(config_path):
            return False
        with open(config_path, 'r') as f:
            config = json.load(f)
    except (OSError, TypeError, ValueError):
        # Best-effort heuristic: an unusable path or config means "not large".
        # TypeError covers a non-path model_path; ValueError covers bad JSON.
        return False

    if not isinstance(config, dict):
        return False

    # Consider GPT models, large BERT variants, or models with large hidden sizes as "large"
    large_model_indicators = (
        'gpt', 'llama', 'mistral', 'opt', 'bloom', 'falcon',
        'large', 'xl', '7b', '13b', '30b'
    )

    base_model = config.get('base_model_name', config.get('_name_or_path', ''))
    architectures = config.get('architectures') or []
    if not isinstance(architectures, (list, tuple)):
        architectures = [architectures] if isinstance(architectures, str) else []

    model_str = (str(base_model) + ' ' + ' '.join(str(a) for a in architectures)).lower()
    if any(indicator in model_str for indicator in large_model_indicators):
        return True

    # BERT-base is 768, BERT-large is 1024; only strictly larger counts.
    hidden_size = config.get('hidden_size', 0)
    return isinstance(hidden_size, (int, float)) and hidden_size > 1024

def get_tokenizer_for_model(model_path):
    """
    Load the tokenizer matching a model checkpoint.

    Resolution order:
      1. ``model_path`` contains ``tokenizer_config.json``: load the saved
         tokenizer from ``model_path`` and return it unchanged (no special
         tokens are added on this path).
      2. ``model_path`` contains ``config.json``: load the tokenizer named by
         its ``base_model_name`` key, else ``_name_or_path``, else
         ``model_path`` itself.
      3. Otherwise: load the tokenizer directly from ``model_path``.

    For cases 2 and 3, missing pad / unk / mask tokens are added as
    ``[PAD]`` / ``[UNK]`` / ``[MASK]`` for A2T compatibility.
    NOTE(review): adding tokens grows the vocabulary; model embeddings are not
    resized here — confirm the caller handles that.

    If any step raises, a warning is printed and the ``bert-base-uncased``
    tokenizer is returned instead.
    """
    def _load(source):
        return transformers.AutoTokenizer.from_pretrained(source, use_fast=True)

    try:
        if os.path.exists(os.path.join(model_path, 'tokenizer_config.json')):
            saved_tokenizer = _load(model_path)
            print(f"Loaded tokenizer from model path: {model_path}")
            return saved_tokenizer

        config_file = os.path.join(model_path, 'config.json')
        if not os.path.exists(config_file):
            source = model_path
        else:
            with open(config_file, 'r') as fh:
                model_config = json.load(fh)
            source = (
                model_config['base_model_name']
                if 'base_model_name' in model_config
                else model_config.get('_name_or_path', model_path)
            )
        tokenizer = _load(source)

        # (token name, fallback literal, current value) — checked in this order.
        requirements = (
            ('pad_token', '[PAD]', tokenizer.pad_token_id),
            ('unk_token', '[UNK]', tokenizer.unk_token),
            ('mask_token', '[MASK]', tokenizer.mask_token_id),  # for the A2T MLM variant
        )
        missing = {name: literal for name, literal, current in requirements if current is None}

        if missing:
            added_count = tokenizer.add_special_tokens(missing)
            print(f"Added {added_count} special tokens for A2T compatibility: {list(missing.keys())}")

        return tokenizer

    except Exception as e:
        print(f"Warning: Could not determine tokenizer for {model_path}, using default: {e}")
        return _load("bert-base-uncased")

def load_model_with_optimizations(model_path, num_classes=None):
    """
    Load a GenericSequenceClassifier checkpoint, with memory optimizations for large models.

    Extra keyword arguments passed to ``GenericSequenceClassifier.from_pretrained``:
      * ``num_classes`` given: ``num_labels=num_classes`` and
        ``ignore_mismatched_sizes=True`` (allows a differently-sized classifier head).
      * large model (per ``is_large_model``): ``low_cpu_mem_usage=True`` plus
          - on CUDA: ``torch_dtype=torch.float16`` and ``device_map="auto"``;
          - on MPS / CPU: ``torch_dtype=torch.float32``.
      * small model: nothing beyond the ``num_classes`` options.

    Args:
        model_path: Path to the model checkpoint
        num_classes: Number of classes for the classifier head (if None, the
            checkpoint's own config decides)

    Returns:
        tuple: (model, is_large)
    """
    is_large = is_large_model(model_path)
    device = get_device()

    load_kwargs = {}
    if num_classes is not None:
        load_kwargs["num_labels"] = num_classes
        load_kwargs["ignore_mismatched_sizes"] = True  # Allow classifier head size mismatch

    if not is_large:
        # Standard loading for smaller models
        model = GenericSequenceClassifier.from_pretrained(model_path, **load_kwargs)
    else:
        print("🔧 Loading large model with memory optimizations...")
        load_kwargs["low_cpu_mem_usage"] = True  # memory-efficient loading
        on_cuda = device.type == "cuda"
        if on_cuda:
            # CUDA supports half precision and automatic device placement
            load_kwargs["torch_dtype"] = torch.float16
            load_kwargs["device_map"] = "auto"
        else:
            # MPS and CPU keep full precision
            load_kwargs["torch_dtype"] = torch.float32
        model = GenericSequenceClassifier.from_pretrained(model_path, **load_kwargs)

        if on_cuda:
            print("   ✓ Loaded in float16 precision (CUDA)")
        elif device.type == "mps":
            print("   ✓ Loaded with memory-efficient loading (MPS - no mixed precision)")
        else:
            print("   ✓ Loaded with memory-efficient loading (CPU)")

    if num_classes is not None:
        print(f"   📊 Model loaded with {num_classes} classes")

    return model, is_large

# Helper functions for collating data
def collate_fn(input_columns, data):
    """Collate examples into (inputs, label tensor).

    Each example contributes ``example[col]`` for its single input column, or
    a tuple of values when there are several input columns. Labels are read
    from ``example["label"]`` and stacked with ``torch.tensor``.
    """
    def _unpack(example):
        label = example["label"]
        values = tuple(example[column] for column in input_columns)
        return (values[0] if len(values) == 1 else values), label

    pairs = [_unpack(example) for example in data]
    texts = [text for text, _ in pairs]
    label_list = [label for _, label in pairs]
    return texts, torch.tensor(label_list)


def get_model_config_info(checkpoint_path):
    """
    Extract class information from a model's experiment configuration.

    ``experiment_config.json`` is looked up in ``checkpoint_path`` first and
    then in its parent directory (common for a ``final_model`` sub-directory).
    The path is normalized before taking the parent, so a trailing separator
    (``".../final_model/"``) still resolves to the real parent.

    Args:
        checkpoint_path: Path to the model checkpoint directory

    Returns:
        tuple: (num_classes, class_names), or (None, None) if no readable config
        with a ``dataset_info.num_classes`` entry is found. When the config has
        num_classes but no class_names, generic names ``class_0..class_{n-1}``
        are generated.
    """
    config_path = os.path.join(checkpoint_path, "experiment_config.json")

    if not os.path.exists(config_path):
        # normpath strips a trailing separator; without it, dirname("a/b/")
        # returns "a/b" and the "parent" lookup would re-check the same dir.
        parent_dir = os.path.dirname(os.path.normpath(checkpoint_path))
        config_path = os.path.join(parent_dir, "experiment_config.json")

    if not os.path.exists(config_path):
        return None, None

    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"⚠️  Warning: Could not parse experiment config: {e}")
        return None, None

    dataset_info = config.get('dataset_info') if isinstance(config, dict) else None
    if not isinstance(dataset_info, dict):
        return None, None

    num_classes = dataset_info.get('num_classes')
    class_names = dataset_info.get('class_names')

    if num_classes and class_names:
        return num_classes, class_names
    if num_classes:
        # Generate generic class names if specific names not available
        return num_classes, [f"class_{i}" for i in range(num_classes)]
    return None, None


def get_dataset_info_from_hf(dataset, checkpoint_path=None):
    """
    Describe a HuggingFace dataset as (columns, label names, class count) for A2T.

    Input columns come from the first matching layout: ``text``;
    ``premise`` + ``hypothesis``; ``sentence1`` + ``sentence2``; ``sentence``.
    If none matches, every column except ``label`` is used. The label column
    is always ``label``.

    Class information is taken from the model's experiment config when
    ``checkpoint_path`` is given and that config provides both a class count
    and class names. Otherwise it is inferred from the ``label`` feature:
    its non-empty ``names``, else its ``num_classes`` (with generic names),
    else the number of distinct label values other than -1.

    Args:
        dataset: HuggingFace Dataset object (uses ``column_names``,
            ``features`` and ``dataset["label"]``)
        checkpoint_path: Path to model checkpoint (optional)

    Returns:
        tuple: (dataset_columns, label_names, num_classes), where
        dataset_columns is ``(list_of_input_columns, "label")``.

    Raises:
        ValueError: if the dataset features contain no ``label`` column.
    """
    columns = dataset.column_names
    known_layouts = (
        ("text",),                   # single-input text classification
        ("premise", "hypothesis"),   # NLI
        ("sentence1", "sentence2"),  # paraphrase detection
        ("sentence",),               # SST2
    )
    for layout in known_layouts:
        if all(column in columns for column in layout):
            dataset_columns = (list(layout), "label")
            break
    else:
        # No known layout: every non-label column is an input.
        dataset_columns = ([column for column in columns if column != "label"], "label")

    if checkpoint_path:
        config_classes, config_names = get_model_config_info(checkpoint_path)
        if config_classes and config_names:
            print(f"📋 Using class info from model config: {config_classes} classes")
            return dataset_columns, config_names, config_classes

    print("⚠️  Model config not found or incomplete, inferring from dataset...")

    if "label" not in dataset.features:
        raise ValueError("No 'label' column found in dataset")
    label_feature = dataset.features["label"]

    declared_names = getattr(label_feature, 'names', None)
    if declared_names:
        return dataset_columns, declared_names, len(declared_names)

    if hasattr(label_feature, 'num_classes'):
        class_count = label_feature.num_classes
    else:
        # Count distinct label values, ignoring -1.
        class_count = len([label for label in sorted(set(dataset["label"])) if label != -1])

    return dataset_columns, [f"class_{i}" for i in range(class_count)], class_count


def load_dataset(name, seed=42, num_samples=None, checkpoint_path=None):
    """
    Load an evaluation dataset.

    Two kinds of names are accepted:
      * ``custom_<name>``: loaded with ``get_exp_data`` from the config file
        ``../conf/dataset/<name>.yaml`` (relative to this file). Returns
        ``(dataset, (dataset_columns, label_names, num_classes))``.
      * a key of ``DATASET_CONFIGS``: loaded from a local ``test.tsv`` or from
        the HuggingFace hub. Returns only the dataset; ``seed``,
        ``num_samples`` and ``checkpoint_path`` are unused on this path.

    In both cases rows whose label is -1 are dropped.

    Only *leading* ``custom_`` prefixes are stripped from a custom name. They
    are stripped repeatedly, so ``custom_custom_x`` still resolves to ``x``.
    Occurrences inside the name (``custom_my_custom_data`` -> ``my_custom_data``)
    are kept.

    Raises:
        ValueError: if the custom dataset config file is missing, or the name
            is not a known predefined dataset.
    """
    custom_prefix = "custom_"
    if name.startswith(custom_prefix):
        # str.replace would also delete "custom_" from inside the name.
        actual_dataset_name = name
        while actual_dataset_name.startswith(custom_prefix):
            actual_dataset_name = actual_dataset_name[len(custom_prefix):]

        # Load dataset config using the same method as TextFooler
        config_path = os.path.join(os.path.dirname(__file__), '..', 'conf', 'dataset', f'{actual_dataset_name}.yaml')

        if not os.path.exists(config_path):
            raise ValueError(f"Dataset config file not found: {config_path}")

        dataset_config = OmegaConf.load(config_path)

        texts, labels, num_classes = get_exp_data(
            dataset_config=dataset_config,
            seed=seed,
            num_samples=num_samples,  # NOTE(review): presumably None means "all samples" — confirm in get_exp_data
        )
        print(f"dataset_config = {dataset_config}")
        print(f"num_classes = {num_classes}")

        # get_exp_data may return texts as lists of tokens; join them back into
        # strings so the HuggingFace dataset matches what A2T expects.
        text_strings = [" ".join(text) if isinstance(text, list) else text for text in texts]

        dataset = datasets.Dataset.from_dict({"text": text_strings, "label": labels})

        # Filter out any -1 labels (same as the predefined-dataset branch)
        dataset = dataset.filter(lambda x: x["label"] != -1)

        dataset_columns, label_names, num_classes_detected = get_dataset_info_from_hf(dataset, checkpoint_path)

        # Prefer num_classes from get_exp_data, fall back to the detected value
        final_num_classes = num_classes if num_classes is not None else num_classes_detected

        return dataset, (dataset_columns, label_names, final_num_classes)

    # Original loading logic for predefined datasets
    if name not in DATASET_CONFIGS:
        raise ValueError(f"Unknown dataset {name}")
    dataset_config = DATASET_CONFIGS[name]
    if "local_path" in dataset_config:
        dataset = datasets.load_dataset(
            "csv",
            data_files=os.path.join(dataset_config["local_path"], "test.tsv"),
            delimiter="\t",
        )["train"]
    else:
        if "split" in dataset_config:
            dataset = datasets.load_dataset(
                dataset_config["remote_name"], split=dataset_config["split"]
            )
        else:
            dataset = datasets.load_dataset(dataset_config["remote_name"], split="test")

    dataset = dataset.filter(lambda x: x["label"] != -1)
    return dataset


def calc_attack_stats(results):
    """
    Summarize a list of TextAttack results.

    Returns:
        tuple: (attack_success_rate, avg_num_queries, average_perc_words_perturbed)
          * attack_success_rate: successful / all results * 100, rounded to 2
            decimals (skipped results count in the denominator). 0.0 for empty
            ``results``.
          * avg_num_queries: mean ``num_queries`` over non-skipped results,
            rounded to 2 decimals.
          * average_perc_words_perturbed: mean percentage of changed words
            over successful attacks whose percentage is > 0, rounded to 2
            decimals.
        Averages with nothing to average are NaN (without NumPy's
        "Mean of empty slice" warning).
    """
    def _rounded_mean(values):
        # np.mean of an empty sequence warns and returns NaN; do the NaN quietly.
        return round(np.mean(values), 2) if len(values) else float("nan")

    total_attacks = len(results)
    successful_attacks = 0
    perturbed_word_percentages = []

    for result in results:
        if isinstance(result, (textattack.attack_results.FailedAttackResult,
                               textattack.attack_results.SkippedAttackResult)):
            continue
        successful_attacks += 1

        original_text = result.original_result.attacked_text
        num_words = len(original_text.words)
        if num_words == 0:
            continue
        num_words_changed = len(
            original_text.all_words_diff(result.perturbed_result.attacked_text)
        )
        percentage = num_words_changed * 100.0 / num_words
        # Only positive percentages contribute to the average.
        if percentage > 0:
            perturbed_word_percentages.append(percentage)

    attack_success_rate = (
        round(successful_attacks * 100.0 / total_attacks, 2) if total_attacks else 0.0
    )
    average_perc_words_perturbed = _rounded_mean(perturbed_word_percentages)

    num_queries = [
        r.num_queries
        for r in results
        if not isinstance(r, textattack.attack_results.SkippedAttackResult)
    ]
    avg_num_queries = _rounded_mean(num_queries)

    return attack_success_rate, avg_num_queries, average_perc_words_perturbed


#####################################################################################

# Attack names understood by ``_eval_robustness_build_attack``.
_SUPPORTED_ROBUSTNESS_ATTACKS = ("a2t", "a2t_mlm", "textfooler", "bae", "pwws", "pso")


def _eval_robustness_load_dataset(args):
    """Load the split to attack together with its column/label metadata.

    Returns:
        ``(dataset, dataset_columns, label_names, num_classes)``. Custom
        datasets (``custom_*``) take their metadata from ``load_dataset``;
        predefined ones take it from ``DATASET_CONFIGS`` and report
        ``num_classes=None``.

    Raises:
        ValueError: if ``args.dataset`` is neither custom nor in ``DATASET_CONFIGS``.
    """
    if args.dataset.startswith("custom_"):
        # Pass the first checkpoint path for configuration info (None if no paths are set).
        first_checkpoint = (getattr(args, "checkpoint_paths", None) or [None])[0]
        dataset, (dataset_columns, label_names, num_classes) = load_dataset(
            args.dataset,
            seed=args.seed,
            num_samples=args.num_samples,
            checkpoint_path=first_checkpoint,
        )
        return dataset, dataset_columns, label_names, num_classes

    if args.dataset not in DATASET_CONFIGS:
        raise ValueError(f"Unknown dataset {args.dataset}")
    dataset_config = DATASET_CONFIGS[args.dataset]
    dataset = load_dataset(args.dataset, seed=args.seed, num_samples=args.num_samples)
    # num_classes is not needed for predefined datasets.
    return dataset, dataset_config["dataset_columns"], dataset_config["label_names"], None


def _eval_robustness_correct_indices(args, dataset_size):
    """Intersect the correctly-predicted indices logged for every checkpoint.

    Reads ``accuracy_eval_logs.json`` for each path in ``args.checkpoint_paths``.
    The copy in the A2T results directory is preferred over the one in the
    model directory.

    Returns:
        ``(correct_indices, original_accuracies)``:
        * ``correct_indices``: the set of indices in ``range(dataset_size)``
          that every checkpoint classified correctly.
        * ``original_accuracies``: maps checkpoint path -> logged accuracy.
          This is a fraction, printed with ``:.1%``.

    Raises:
        FileNotFoundError: if a checkpoint has no accuracy log.
        ValueError: if a log has no entry for this dataset and epoch.
    """
    epoch_key = f"checkpoint-epoch-{args.epoch}"
    dataset_key = args.dataset
    correct_everywhere = set(range(dataset_size))
    original_accuracies = {}

    for path in args.checkpoint_paths:
        # Check both potential locations for accuracy logs
        model_accuracy_log_path = os.path.join(path, "accuracy_eval_logs.json")
        a2t_results_dir = create_a2t_results_dir(path)
        a2t_accuracy_log_path = os.path.join(a2t_results_dir, "accuracy_eval_logs.json")

        if os.path.exists(a2t_accuracy_log_path):
            accuracy_log_path = a2t_accuracy_log_path
            print(f"📄 Found accuracy logs in A2T results directory: {a2t_accuracy_log_path}")
        elif os.path.exists(model_accuracy_log_path):
            accuracy_log_path = model_accuracy_log_path
            print(f"📄 Found accuracy logs in model directory: {model_accuracy_log_path}")
        else:
            raise FileNotFoundError(
                f"❌ Accuracy evaluation logs not found for model: {path}\n"
                f"Robustness evaluation requires accuracy logs to identify correctly classified samples.\n"
                f"Please run accuracy evaluation first:\n"
                f"python evaluate_device_generic.py --dataset {args.dataset} --checkpoint-paths {path} "
                f"--epoch {args.epoch} --accuracy --save-log"
            )

        with open(accuracy_log_path, "r") as f:
            logs = json.load(f)

        if epoch_key not in logs or dataset_key not in logs[epoch_key]:
            raise ValueError(
                f"❌ No accuracy data found for dataset '{dataset_key}' and epoch {args.epoch} in {accuracy_log_path}. "
                f"Please run accuracy evaluation first:\n"
                f"python evaluate_device_generic.py --dataset {args.dataset} --checkpoint-paths {path} "
                f"--epoch {args.epoch} --accuracy --save-log"
            )

        entry = logs[epoch_key][dataset_key]
        correct_indices = entry["correct_indices"]
        original_accuracy = entry["accuracy"]
        correct_everywhere.intersection_update(correct_indices)
        original_accuracies[path] = original_accuracy
        print(f"📈 Model {path}: {len(correct_indices)} correctly predicted samples (accuracy: {original_accuracy:.1%})")

    return correct_everywhere, original_accuracies


def _eval_robustness_a2t_constraints():
    """Return the leading constraints shared by the A2T and A2T-MLM recipes."""
    constraints = [RepeatModification(), StopwordModification()]
    # Per TextAttack, this applies only when the input columns are exactly
    # premise/hypothesis, and then forbids edits to the premise.
    constraints.append(InputColumnModification(["premise", "hypothesis"], {"premise"}))
    constraints.append(PartOfSpeech(allow_verb_noun_swap=False))
    return constraints


def _eval_robustness_build_attack(attack_name, word_budget, model_wrapper):
    """Build the TextAttack ``Attack`` for ``attack_name``.

    Args:
        attack_name: One of ``_SUPPORTED_ROBUSTNESS_ATTACKS``.
        word_budget: Absolute max number of words changed (A2T only). ``None``
            selects the 10%-of-words limit.
        model_wrapper: TextAttack model wrapper to attack.

    Raises:
        ImportError: if the custom A2T / A2T-MLM components failed to import.
        ValueError: for an unsupported ``attack_name``. Without this check a
            stale or unbound attack would be used.
    """
    if attack_name == "a2t":
        if any(x is None for x in [RepeatModification, StopwordModification, InputColumnModification,
                                   PartOfSpeech, MaxModificationRate, SBERT, WordSwapEmbedding,
                                   WordEmbeddingDistance, UntargetedClassification, GreedyWordSwapWIR, Attack]):
            raise ImportError("Required TextAttack components not available for custom A2T attack")

        constraints = _eval_robustness_a2t_constraints()
        if word_budget is not None:
            # Absolute word-count constraint
            constraints.append(MaxModificationRate(max_num_words=word_budget))
            print(f"A2T attack configured with max {word_budget} words changed")
        else:
            # Default percentage-based constraint (10%)
            constraints.append(MaxModificationRate(max_percent=0.1))
            print("A2T attack configured with max 10% words changed")
        constraints.append(
            SBERT(model_name="stsb-distilbert-base", threshold=0.9, metric="cosine")
        )
        transformation = WordSwapEmbedding(max_candidates=20)
        constraints.append(WordEmbeddingDistance(min_cos_sim=0.8))
    elif attack_name == "a2t_mlm":
        if any(x is None for x in [RepeatModification, StopwordModification, InputColumnModification,
                                   PartOfSpeech, MaxModificationRate, SBERT, WordSwapMaskedLM,
                                   UntargetedClassification, GreedyWordSwapWIR, Attack]):
            raise ImportError("Required TextAttack components not available for custom A2T-MLM attack")

        constraints = _eval_robustness_a2t_constraints()
        constraints.append(MaxModificationRate(max_percent=0.1))
        constraints.append(
            SBERT(model_name="stsb-distilbert-base", threshold=0.9, metric="cosine")
        )
        transformation = WordSwapMaskedLM(
            method="bae", max_candidates=20, min_confidence=0.0, batch_size=16
        )
    elif attack_name == "textfooler":
        return textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
    elif attack_name == "bae":
        return textattack.attack_recipes.BAEGarg2019.build(model_wrapper)
    elif attack_name == "pwws":
        return textattack.attack_recipes.PWWSRen2019.build(model_wrapper)
    elif attack_name == "pso":
        return textattack.attack_recipes.PSOZang2020.build(model_wrapper)
    else:
        raise ValueError(
            f"Unknown attack '{attack_name}'. Choices are {list(_SUPPORTED_ROBUSTNESS_ATTACKS)}."
        )

    # Shared tail of the custom A2T / A2T-MLM recipes.
    goal_function = UntargetedClassification(model_wrapper, model_batch_size=32)
    search_method = GreedyWordSwapWIR(wir_method="gradient")
    return Attack(goal_function, constraints, transformation, search_method)


def _eval_robustness_attack_checkpoint(
    args, path, test_dataset, num_classes, indices_to_test, original_dataset_size, original_accuracy
):
    """Run every attack in ``args.attacks`` against the checkpoint at ``path``.

    All outputs go under ``create_a2t_results_dir(path)``:
    * for each attack configuration, TextAttack txt/csv logs in their own
      sub-directory;
    * one summary row per configuration, via ``write_summary_csv``;
    * a final ``*_robustness_logs.json`` with the aggregate statistics.

    ``original_accuracy`` is this checkpoint's logged clean accuracy, as a
    fraction.
    """
    # Create A2T results directory for this specific model
    results_dir = create_a2t_results_dir(path)
    print(f"📁 Results will be saved to: {results_dir}")

    epoch_key = f"checkpoint-epoch-{args.epoch}"
    logs = {"indices": indices_to_test, epoch_key: {}}

    # Pass num_classes for custom datasets to handle classifier head size mismatch
    model_num_classes = num_classes if args.dataset.startswith("custom_") else None
    model, _ = load_model_with_optimizations(path, num_classes=model_num_classes)
    tokenizer = get_tokenizer_for_model(path)

    # Resize model embeddings if tokenizer was modified
    if len(tokenizer) != model.config.vocab_size:
        print(f"Resizing model embeddings from {model.config.vocab_size} to {len(tokenizer)}")
        model.resize_token_embeddings(len(tokenizer))

    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    model_name = get_model_name_from_path(path)
    print(f"====== {path} =====")

    original_acc_percentage = original_accuracy * 100  # Convert to percentage

    for attack_name in args.attacks:
        # Only A2T sweeps explicit word budgets. None means the percentage-based limit.
        # Truthiness (not `is not None`) so an empty `--max-words-changed` falls back
        # to the default instead of running no A2T attack at all.
        if attack_name == "a2t" and args.max_words_changed:
            word_budget_values = args.max_words_changed
            print(f"🔄 Testing A2T attack with word budgets: {word_budget_values}")
        else:
            word_budget_values = [None]

        for word_budget in word_budget_values:
            if word_budget is not None:
                budget_suffix = f"words{word_budget}"
                word_budget_str = str(word_budget)
                attack_key = f"{attack_name}_words{word_budget}"
                constraint_info = f"max {word_budget} words, {args.query_budget} queries"
            else:
                budget_suffix = "wordspct"
                word_budget_str = "percentage_based"
                attack_key = f"{attack_name}_pct"
                constraint_info = f"10% words max, {args.query_budget} queries"

            attack_dir = os.path.join(
                results_dir,
                f"{args.dataset}_{model_name}_{attack_name}_epoch{args.epoch}_samples{args.num_samples}_{budget_suffix}",
            )
            os.makedirs(attack_dir, exist_ok=True)

            log_file_name = os.path.join(attack_dir, f"{attack_name}-test")
            attack_args = textattack.AttackArgs(
                num_examples=args.num_samples,
                parallel=(torch.cuda.device_count() > 1),
                disable_stdout=True,
                num_workers_per_device=1,
                query_budget=args.query_budget,  # Use configurable query budget
                shuffle=False,
                log_to_txt=log_file_name + ".txt",
                log_to_csv=log_file_name + ".csv",
                silent=True,
            )
            attack = _eval_robustness_build_attack(attack_name, word_budget, model_wrapper)

            attacker = textattack.Attacker(attack, test_dataset, attack_args)
            results = attacker.attack_dataset()

            attack_success_rate, avg_num_queries, avg_pct_perturbed = calc_attack_stats(results)
            accuracy_under_attack = calculate_accuracy_under_attack(
                results, original_dataset_size, indices_to_test, original_acc_percentage
            )

            print(f"{attack_name} ({constraint_info}): {round(attack_success_rate, 1)}% success | {avg_num_queries} avg queries | {avg_pct_perturbed}% avg words changed | {accuracy_under_attack}% accuracy under attack")

            logs[epoch_key][attack_key] = {
                "attack_success_rate": attack_success_rate,
                "avg_num_queries": avg_num_queries,
                "avg_pct_perturbed": avg_pct_perturbed,
                "accuracy_under_attack": accuracy_under_attack,
                "word_budget": word_budget,
                "query_budget": args.query_budget,
                "constraint_type": "word_count" if word_budget is not None else "percentage",
            }

            write_summary_csv(results_dir, {
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "dataset": args.dataset,
                "model_name": model_name,
                "model_checkpoint": path,
                "random_seed": args.seed,
                "query_budget": args.query_budget,
                "max_words_changed": word_budget_str,
                "num_samples": args.num_samples,
                "attack_name": attack_name,
                "original_accuracy": original_acc_percentage,
                "accuracy_under_attack": accuracy_under_attack,
                "attack_success_rate": attack_success_rate,
                "avg_num_queries": avg_num_queries,
                "avg_pct_perturbed": avg_pct_perturbed,
            })

    # Save robustness logs in A2T_results directory instead of model directory
    robustness_log_path = os.path.join(results_dir, f"{args.dataset}_{model_name}_epoch{args.epoch}_samples{args.num_samples}_robustness_logs.json")
    with open(robustness_log_path, "w") as f:
        json.dump(logs, f)
    print(f"📄 Robustness logs saved to: {robustness_log_path}")


def eval_robustness(args):
    """Attack every checkpoint on the samples that all checkpoints classify correctly.

    Steps:
    1. Validate ``args.attacks``.
    2. Load the dataset.
    3. Intersect the correctly-classified indices from every checkpoint's
       accuracy log, then shuffle them with a fixed seed.
    4. Run the requested attacks per checkpoint.

    Raises:
        ValueError: if no attacks (or an unsupported attack) were requested, the
            dataset is unknown, or an accuracy log lacks this dataset/epoch.
        FileNotFoundError: if a checkpoint has no accuracy log.
    """
    # Validate attack names before any expensive loading, so a typo fails fast
    # instead of after hours of attacks (or by silently reusing an earlier attack).
    attack_names = args.attacks or []
    if not attack_names:
        raise ValueError(
            f"No attacks specified. Pass one or more of {list(_SUPPORTED_ROBUSTNESS_ATTACKS)} via --attacks."
        )
    unknown = [name for name in attack_names if name not in _SUPPORTED_ROBUSTNESS_ATTACKS]
    if unknown:
        raise ValueError(
            f"Unknown attack(s) {unknown}. Choices are {list(_SUPPORTED_ROBUSTNESS_ATTACKS)}."
        )

    test_dataset, dataset_columns, label_names, num_classes = _eval_robustness_load_dataset(args)

    # Store original dataset size for accuracy calculation
    original_dataset_size = len(test_dataset)
    print(f"📊 Total dataset size: {original_dataset_size} samples")

    correct_everywhere, original_accuracies = _eval_robustness_correct_indices(
        args, original_dataset_size
    )

    indices_to_test = list(correct_everywhere)
    print(f"🎯 Samples correctly predicted by ALL models: {len(indices_to_test)}")
    random.seed(42)  # TODO: expose this seed as a parameter.
    random.shuffle(indices_to_test)
    print(f"🔍 Selected {len(indices_to_test)} samples for robustness evaluation (only correctly classified samples)")

    test_dataset = test_dataset.select(indices_to_test)

    if args.dataset.startswith("custom_"):
        # Custom datasets: information extracted from the HuggingFace dataset
        print(f"Custom dataset detected: {num_classes} classes with labels {label_names}")
        print(f"Dataset columns: {dataset_columns}")

    test_dataset = textattack.datasets.HuggingFaceDataset(
        test_dataset,
        dataset_columns=dataset_columns,
        label_names=label_names,
    )

    print("Evaluating robustness (this might take a long time)...")
    print("🔧 Attack Configuration:")
    print(f"   📊 Query Budget: {args.query_budget} (max model queries per attack)")
    if args.max_words_changed:
        print(f"   📝 Word Budget: Testing values {args.max_words_changed} (max words changed per attack)")
    else:
        print("   📝 Word Budget: Percentage-based (10% of words)")
    print(f"   🎯 Samples: {args.num_samples} (correctly classified samples to attack)")
    print()

    for path in args.checkpoint_paths:
        _eval_robustness_attack_checkpoint(
            args,
            path,
            test_dataset,
            num_classes,
            indices_to_test,
            original_dataset_size,
            original_accuracies[path],
        )
        # The helper's model/wrapper/attack locals are unreachable now. Reclaim
        # them before the next checkpoint is loaded.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def eval_accuracy(args):
    """Measure clean accuracy for each checkpoint that has no accuracy log yet.

    For every checkpoint not already covered by ``check_accuracy_exists``:
    * run the model over the evaluation dataset(s);
    * per dataset, record the accuracy (a fraction rounded to 4 decimals) and
      the list of correctly predicted indices.

    With ``args.save_log`` these are written to ``accuracy_eval_logs.json`` in
    the checkpoint's A2T results directory.

    Raises:
        ValueError: if ``args.dataset`` is neither ``custom_*`` nor a key of
            ``DATASET_CONFIGS``.
    """
    print("Evaluating accuracy")

    # Check if accuracy evaluation should be skipped for any checkpoint
    checkpoints_to_evaluate = []
    for path in args.checkpoint_paths:
        if check_accuracy_exists(path):
            print(f"⏭️  Skipping accuracy evaluation for {path} (results already exist)")
            continue
        checkpoints_to_evaluate.append(path)

    if not checkpoints_to_evaluate:
        print("✅ All checkpoints already have accuracy evaluations. Skipping accuracy evaluation.")
        return

    print(f"📊 Evaluating accuracy for {len(checkpoints_to_evaluate)} checkpoint(s)")

    device = get_device()

    if args.dataset.startswith("custom_"):
        # For custom datasets, evaluate on the same dataset. The first checkpoint
        # (None if no paths are set) is passed for configuration info.
        first_checkpoint = (getattr(args, "checkpoint_paths", None) or [None])[0]
        dataset, (dataset_columns, label_names, num_classes) = load_dataset(
            args.dataset, seed=args.seed, num_samples=args.num_samples, checkpoint_path=first_checkpoint
        )
        eval_datasets = [(args.dataset, dataset)]
        # Store dataset columns for custom datasets
        custom_dataset_columns = {args.dataset: dataset_columns}
    else:
        # For predefined datasets, use eval_datasets from config
        if args.dataset not in DATASET_CONFIGS:
            raise ValueError(f"Unknown dataset {args.dataset}")
        dataset_config = DATASET_CONFIGS[args.dataset]
        test_datasets = dataset_config["eval_datasets"]

        eval_datasets = [
            (test_datasets[key], load_dataset(test_datasets[key], seed=args.seed)) for key in test_datasets if key == 'test'
        ]
        custom_dataset_columns = None

    for path in checkpoints_to_evaluate:
        logs = {}
        model_save_path = path

        # Load generic sequence classifier with optimizations.
        # Pass num_classes for custom datasets to handle classifier head size mismatch.
        model_num_classes = num_classes if args.dataset.startswith("custom_") else None
        model, is_large = load_model_with_optimizations(model_save_path, num_classes=model_num_classes)
        tokenizer = get_tokenizer_for_model(model_save_path)

        # Resize model embeddings if tokenizer was modified
        if len(tokenizer) != model.config.vocab_size:
            print(f"Resizing model embeddings from {model.config.vocab_size} to {len(tokenizer)}")
            model.resize_token_embeddings(len(tokenizer))

        device_count = 1
        if device.type == "cuda":
            device_count = torch.cuda.device_count()
            if device_count > 1:
                model = torch.nn.DataParallel(model)

        model.eval()
        model.to(device)

        if isinstance(model, torch.nn.DataParallel):
            eval_batch_size = 128 * device_count
        elif is_large:
            # Use smaller batch size for large models
            eval_batch_size = BATCH_SIZE_LARGE_MODEL
            print(f"🔧 Using reduced batch size for large model: {eval_batch_size}")
        else:
            eval_batch_size = 128

        logs[f"checkpoint-epoch-{args.epoch}"] = {}
        print(f"====== {model_save_path} =====")
        print(f"before for loop, eval_datasets has {len(eval_datasets)}")

        for dataset_name, dataset in tqdm.tqdm(eval_datasets):
            print(f"Evaluating {dataset_name} dataset")

            # Get input columns for custom vs predefined datasets
            if custom_dataset_columns and dataset_name in custom_dataset_columns:
                input_columns = custom_dataset_columns[dataset_name][0]
            else:
                input_columns = DATASET_CONFIGS[dataset_name]["dataset_columns"][0]

            collate_func = functools.partial(collate_fn, input_columns)
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=eval_batch_size, collate_fn=collate_func
            )

            preds_list = []
            labels_list = []

            with torch.no_grad():
                for batch in tqdm.tqdm(dataloader):
                    input_texts, labels = batch
                    input_ids = tokenizer(
                        input_texts,
                        padding="max_length",
                        return_tensors="pt",
                        truncation=True,
                    )
                    for key in input_ids:
                        if isinstance(input_ids[key], torch.Tensor):
                            input_ids[key] = input_ids[key].to(device)

                    outputs = model(**input_ids)
                    logits = outputs.logits

                    preds = logits.argmax(dim=-1).detach().cpu()
                    preds_list.append(preds)
                    labels_list.append(labels)

            if not labels_list:
                # torch.cat would raise on an empty list, and there is nothing to score.
                print(f"⚠️  {dataset_name} dataset is empty; skipping accuracy evaluation for it")
                continue

            preds = torch.cat(preds_list)
            labels = torch.cat(labels_list)

            compare = preds == labels
            num_correct = compare.sum().item()
            accuracy = round(num_correct / len(labels), 4)
            correct = torch.nonzero(compare, as_tuple=True)[0].tolist()

            logs[f"checkpoint-epoch-{args.epoch}"][dataset_name] = {
                "accuracy": accuracy,
                "correct_indices": correct,
            }

            print(f"{dataset_name}: {accuracy}")

        if args.save_log:
            # Save accuracy logs in the A2T results directory instead of the model directory
            a2t_results_dir = create_a2t_results_dir(model_save_path)
            accuracy_log_path = os.path.join(a2t_results_dir, "accuracy_eval_logs.json")

            with open(accuracy_log_path, "w") as f:
                json.dump(logs, f)

            print(f"📄 Accuracy logs saved to: {accuracy_log_path}")

        # Drop this checkpoint's model before the next one is loaded. Otherwise the
        # old model stays resident until `model` is rebound after the next load.
        del model
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()


def main(args):
    """Validate inputs, resolve the dataset, then run the requested evaluations.

    Raises:
        ValueError: if no checkpoint paths were given, or the dataset was not
            given and could not be auto-detected.
        FileNotFoundError: if any checkpoint path does not exist.
    """
    # `--checkpoint-paths` defaults to None (and may be given with no values);
    # fail with a clear message instead of a TypeError or a silent no-op.
    if not args.checkpoint_paths:
        raise ValueError(
            "❌ No checkpoint paths given. Please provide at least one path via --checkpoint-paths."
        )

    for path in args.checkpoint_paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Checkpoint path {path} not found.")

    # Auto-detect dataset if not provided
    if args.dataset is None:
        print("🔍 Dataset not specified, attempting auto-detection from experiment_config.json...")

        # Try to detect from the directory containing the first checkpoint
        exp_config_path = os.path.dirname(os.path.normpath(args.checkpoint_paths[0]))
        detected_dataset, detected_num_classes = detect_dataset_from_checkpoint(exp_config_path)

        if not detected_dataset:
            raise ValueError(
                f"❌ Could not auto-detect dataset from {exp_config_path}/experiment_config.json. "
                "Please provide --dataset argument manually."
            )
        args.dataset = detected_dataset
        print(f"✅ Auto-detected dataset: {args.dataset}")
        if detected_num_classes:
            print(f"   📊 Classes: {detected_num_classes}")
    else:
        print(f"📊 Using manually specified dataset: {args.dataset}")

    # NOTE(review): the resolved device is only reported here. eval_accuracy
    # calls get_device() itself, so --device does not currently change where
    # evaluation runs.
    if hasattr(args, 'device'):
        device = get_device() if args.device == "auto" else torch.device(args.device)
        print(f"Using device: {device}")

    if not (args.accuracy or args.robustness):
        print("⚠️  Nothing to do: pass --accuracy and/or --robustness.")

    if args.accuracy:
        eval_accuracy(args)

    if args.robustness:
        eval_robustness(args)

    # if args.interpretability:
    #     evaluate_interpretability(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        required=False,
        default=None,
        help="Name of dataset. Use 'custom_<dataset_name>' for custom datasets (e.g., 'custom_imdb'). If not provided, will auto-detect from experiment_config.json in checkpoint directory.",
    )
    parser.add_argument(
        "--checkpoint-paths",
        type=str,
        nargs="*",
        default=None,
        help="Path of model checkpoint",
    )
    parser.add_argument(
        "--epoch", type=int, default=4, help="Epoch of model to evaluate."
    )
    parser.add_argument(
        "--save-log", action="store_true", help="Save evaluation result as log."
    )
    parser.add_argument("--accuracy", action="store_true", help="Evaluate accuracy.")
    parser.add_argument(
        "--robustness", action="store_true", help="Evaluate robustness."
    )
    # Must match the attack names eval_robustness dispatches on. The MLM variant is
    # "a2t_mlm"; the old "at2_mlm" spelling was never handled by eval_robustness.
    attack_choices = ["a2t", "a2t_mlm", "textfooler", "bae", "pwws", "pso"]
    parser.add_argument(
        "--attacks",
        type=str,
        nargs="*",
        default=None,
        choices=attack_choices,  # reject typos at parse time, before any model is loaded
        help=f"Attacks to use to measure robustness. Choices are {attack_choices}.",
    )
    # The interpretability part is removed to shorten the file used here.
    # parser.add_argument(
    #     "--interpretability",
    #     action="store_true",
    #     help="Evaluate interpretability using AOPC metric.",
    # )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        choices=["auto", "cuda", "mps", "cpu"],
        help="Device to use for computation. 'auto' selects the best available."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for dataset loading and reproducibility."
    )
    parser.add_argument(
        "--query-budget",
        type=int,
        default=100,
        help="Maximum number of queries (forward passes) the attack can make to the target model. Higher values allow more thorough search but take longer (default: 100)."
    )
    # For backward compatibility, also accept the old parameter name
    # parser.add_argument(
    #     "--attack-budget",
    #     type=int,
    #     dest="query_budget",
    #     help=argparse.SUPPRESS  # Hide from help text since it's deprecated
    # )
    parser.add_argument(
        "--max-words-changed",
        type=int,
        nargs="*",
        default=None,
        help="Maximum absolute number of words that can be changed in A2T attack. Can specify multiple values for comparison (e.g., --max-words-changed 1 2 3 4 5). If not specified, uses percentage-based limit (10%%). This is a content constraint, separate from query budget."
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10,
        help="Number of samples to use for robustness evaluation (default: 10)."
    )

    args = parser.parse_args()
    main(args)
