import os
import argparse

import numpy as np
from omegaconf import DictConfig, OmegaConf
from collections import defaultdict
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import nltk
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler
from torch.amp import autocast
from transformers import (
    AutoModelForSequenceClassification,
    AutoConfig,
)
import logging


import criteria
from attack_classification import (
    USE,
    pick_most_similar_words_batch,
    NLI_infer_BERT,
)
from nlp_training.exp_data import get_exp_data
from train_classifier import Model
from utils.model_utils import ModelSummary
from nlp_training.exp_data import get_exp_data_hf
from nlp_training.training_utils import setup_tokenizer
from pathlib import Path

from nlp_training.seq_classifier import GenericSequenceClassifier

class SeqClassifierInfer(nn.Module):
    """
    Generic wrapper for HuggingFace sequence-classification models.
    Supports any model architecture with a sequence-classification head.
    """
    def __init__(
        self,
        model_name_or_path: str,
        device: str,
        num_classes: int = None,
        max_seq_length: int = 128,
        batch_size: int = 32,
        use_amp: bool = False,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.device = device
        self.use_amp = use_amp
        self.dtype = dtype
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.num_labels = num_classes

        # load config to get or set num_labels
        # config = AutoConfig.from_pretrained(model_name_or_path)
        # if num_classes is not None:
        #     config.num_labels = num_classes
        # elif not hasattr(config, 'num_labels'):
        #     raise ValueError("num_labels must be provided if the config has none.")

        # setup tokenizer (adds missing special tokens)
        self.tokenizer, tokens_added = setup_tokenizer(model_name_or_path)

        # load model and resize embeddings if tokenizer grew
        # self.model = AutoModelForSequenceClassification.from_pretrained(
        #     model_name_or_path,
        #     # config=config,
        #     num_labels=num_classes,
        # ).to(self.device)

        self.model = GenericSequenceClassifier.from_pretrained(model_name_or_path, num_labels=self.num_labels).to(self.device)

        # print(f"model config = {self.model.config.to_dict()}")
        # assert False, 'breakpoint'

        if tokens_added:
            logging.debug(f"Tokenizer resized with {tokens_added} new tokens.")
            self.model.resize_token_embeddings(len(self.tokenizer))


    def text_pred(self, texts: list[list[str]], batch_size: int = None) -> torch.Tensor:
        """
        Run inference on a batch of texts. Returns softmax probabilities.
        """
        bs = batch_size or self.batch_size
        encodings = self.tokenizer(
            [" ".join(t) for t in texts],
            padding='max_length',
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors='pt',
        )

        dataset = TensorDataset(
            encodings['input_ids'],
            encodings['attention_mask'],
            encodings.get('token_type_ids', torch.zeros_like(encodings['input_ids'])),
        )
        dataloader = DataLoader(
            dataset,
            sampler=SequentialSampler(dataset),
            batch_size=bs
        )
        self.model.eval()
        all_probs = []

        for batch in dataloader:
            input_ids, attention_mask, token_type_ids = [b.to(self.device) for b in batch]
                    # Prepare the arguments to pass to the model
            forward_args = {
                'input_ids': input_ids,
                'attention_mask': attention_mask
            }
            # print(f"model_type = {self.model.config.model_type}")
            # assert False, 'breakpoint'
            # Models that use token_type_ids (BERT-like models)
            if self.model.config.model_type in ['bert', 'roberta', 'generic_sequence_classifier']:
                forward_args['token_type_ids'] = token_type_ids
            # Models that don't use token_type_ids (decoder-only and DistilBERT)
            elif (self.model.config.model_type in ['distilbert', 'mistral', 'gpt2'] or 
                  self.model.config.model_type.startswith(('gemma', 'llama'))):
                pass  # These models don't use/need token_type_ids
            else:
                raise ValueError(f"Unsupported model type: {self.model.config.model_type}")

            with torch.no_grad(), autocast(
                device_type=self.device, dtype=self.dtype, enabled=self.use_amp
            ):
                logits = self.model(**forward_args).logits
                probs = nn.functional.softmax(logits, dim=-1)
                all_probs.append(probs)


        return torch.cat(all_probs, dim=0)  


def attack_with_trajectory(
    text_ls,
    true_label,
    predictor,
    stop_words_set,
    word2idx,
    idx2word,
    cos_sim,
    device,
    oov_str,
    max_budget=20,  # Maximum budget to compute
    sim_predictor=None,
    import_score_threshold=-1.0,
    sim_score_threshold=0.5,
    sim_score_window=15,
    synonym_num=50,
    batch_size=32,
):
    """
    Greedy synonym-substitution attack that records its state per budget.

    Words are ranked by an importance score, then replaced one at a time by
    a synonym. After every accepted replacement the current state is written
    to all budgets >= the number of replacements so far, so a single pass
    yields the outcome for every budget from 1 to ``max_budget``.

    Args:
        text_ls: The input text as a list of words.
        true_label: Gold label; if the model's prediction differs from it no
            attack is attempted.
        predictor: Callable mapping a list of word lists (and an optional
            ``batch_size`` keyword) to a tensor of class probabilities.
        stop_words_set: Words that are never selected for replacement.
        word2idx: Word -> row index into ``cos_sim``.
        idx2word: Inverse mapping, forwarded to the synonym lookup.
        cos_sim: Similarity matrix forwarded to the synonym lookup.
        device: Not referenced in this function; kept so existing callers
            keep working.
        oov_str: Placeholder token substituted for a word when measuring its
            importance.
        max_budget: Largest number of word replacements to track.
        sim_predictor: Object whose ``semantic_sim(a, b)`` scores the
            similarity of two lists of strings. Required in practice: it is
            called unconditionally once a candidate word is processed.
        import_score_threshold: Words scoring at or below this are skipped.
        sim_score_threshold: Minimum semantic similarity for a replacement;
            overridden to 0.1 when the text is shorter than the window.
        sim_score_window: Number of words around the replaced position that
            are compared for semantic similarity.
        synonym_num: Number of synonym candidates requested per word.
        batch_size: Batch size passed to ``predictor``.

    Returns:
        dict mapping each budget in ``1..max_budget`` to a dict with keys
        ``"text"``, ``"num_changed"``, ``"final_label"``, ``"num_queries"``
        and ``"success"``. If the original prediction is already wrong, the
        result of ``create_empty_trajectory`` is returned instead.
    """

    # Initial setup: query the model once on the unmodified text.
    orig_probs = predictor([text_ls]).squeeze()
    orig_label = torch.argmax(orig_probs)
    orig_prob = orig_probs.max()

    if true_label != orig_label:
        # Return empty trajectory if original prediction is wrong
        return create_empty_trajectory(max_budget, orig_label, orig_label, 0)

    len_text = len(text_ls)
    if len_text < sim_score_window:
        sim_score_threshold = 0.1  # shut down the similarity thresholding function
    half_sim_score_window = (sim_score_window - 1) // 2
    num_queries = 1
    pos_ls = criteria.get_pos(text_ls)

    # Importance scoring: replace each word in turn by the OOV placeholder
    # and see how the prediction reacts.
    leave_1_texts = [
        text_ls[:ii] + [oov_str] + text_ls[min(ii + 1, len_text) :]
        for ii in range(len_text)
    ]
    leave_1_probs = predictor(leave_1_texts, batch_size=batch_size)
    num_queries += len(leave_1_texts)
    leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)

    # Score = drop in the original class probability, plus (only where the
    # prediction flipped) the gain of the newly predicted class relative to
    # its probability on the original text.
    import_scores = (
        (
            orig_prob
            - leave_1_probs[:, orig_label]
            + (leave_1_probs_argmax != orig_label).float()
            * (
                leave_1_probs.max(dim=-1)[0]
                - torch.index_select(orig_probs, 0, leave_1_probs_argmax)
            )
        )
        .data.cpu()
        .numpy()
    )

    # Candidate positions, most important first, minus stop words.
    words_perturb = []
    for idx, score in sorted(
        enumerate(import_scores), key=lambda x: x[1], reverse=True
    ):
        try:
            if score > import_score_threshold and text_ls[idx] not in stop_words_set:
                words_perturb.append((idx, text_ls[idx]))
        except Exception:
            # Best-effort diagnostics as before, but no longer swallowing
            # KeyboardInterrupt / SystemExit like a bare ``except`` would.
            print(idx, len(text_ls), import_scores.shape, text_ls, len(leave_1_texts))

    # Find synonyms for the candidates that exist in the synonym vocabulary.
    in_vocab = [(idx, word) for idx, word in words_perturb if word in word2idx]
    words_perturb_idx = [word2idx[word] for _, word in in_vocab]
    synonym_words, _ = pick_most_similar_words_batch(
        words_perturb_idx, cos_sim, idx2word, synonym_num, 0.5
    )

    # Pair every in-vocabulary position with its synonym list (same order as
    # the lookup above); positions without any synonym are dropped.
    synonyms_all = [
        (idx, synonyms)
        for (idx, _), synonyms in zip(in_vocab, synonym_words)
        if synonyms
    ]

    # MAIN OPTIMIZATION: Track trajectory of changes
    trajectory = {}  # budget -> state reached with at most `budget` changes

    # Initialize every budget with the unmodified text.
    for budget in range(1, max_budget + 1):
        trajectory[budget] = {
            "text": text_ls[:],
            "num_changed": 0,
            "final_label": orig_label,
            "num_queries": num_queries,
            "success": False,
        }

    # Single pass through the attack process
    text_prime = text_ls[:]
    text_cache = text_prime[:]
    total_num_changed = 0

    for idx, synonyms in synonyms_all:
        if total_num_changed >= max_budget:
            break

        # One candidate text per synonym, differing only at position idx.
        new_texts = [
            text_prime[:idx] + [synonym] + text_prime[min(idx + 1, len_text) :]
            for synonym in synonyms
        ]
        new_probs = predictor(new_texts, batch_size=batch_size)
        num_queries += len(new_texts)

        # Pick the word window used for the semantic similarity check:
        # centred on idx when possible, otherwise clamped to the text start
        # or end, or the whole text when it is too short.
        if idx >= half_sim_score_window and len_text - idx - 1 >= half_sim_score_window:
            text_range_min = idx - half_sim_score_window
            text_range_max = idx + half_sim_score_window + 1
        elif (
            idx < half_sim_score_window and len_text - idx - 1 >= half_sim_score_window
        ):
            text_range_min = 0
            text_range_max = sim_score_window
        elif (
            idx >= half_sim_score_window and len_text - idx - 1 < half_sim_score_window
        ):
            text_range_min = len_text - sim_score_window
            text_range_max = len_text
        else:
            text_range_min = 0
            text_range_max = len_text

        a = [" ".join(text_cache[text_range_min:text_range_max])] * len(new_texts)
        b = list(map(lambda x: " ".join(x[text_range_min:text_range_max]), new_texts))
        semantic_sims = sim_predictor.semantic_sim(a, b)

        if len(new_probs.shape) < 2:
            new_probs = new_probs.unsqueeze(0)
        # Candidates that flip the prediction AND stay semantically close.
        new_probs_mask = (
            (orig_label != torch.argmax(new_probs, dim=-1)).data.cpu().numpy()
        )
        new_probs_mask *= semantic_sims >= sim_score_threshold

        # POS filtering: tag a +/-4 word neighbourhood for longer texts,
        # the whole text otherwise, and compare with the original tag.
        synonyms_pos_ls = [
            (
                criteria.get_pos(new_text[max(idx - 4, 0) : idx + 5])[min(4, idx)]
                if len(new_text) > 10
                else criteria.get_pos(new_text)[idx]
            )
            for new_text in new_texts
        ]
        pos_mask = np.array(criteria.pos_filter(pos_ls[idx], synonyms_pos_ls))
        new_probs_mask *= pos_mask

        # Determine if we make a change
        change_made = False
        if np.sum(new_probs_mask) > 0:
            # Among the label-flipping candidates take the most similar one.
            text_prime[idx] = synonyms[(new_probs_mask * semantic_sims).argmax()]
            change_made = True
        else:
            # No candidate flips the label: take the one that lowers the
            # original class probability most. Candidates failing the
            # similarity or POS check are penalised by adding to their value.
            new_label_probs = new_probs[:, orig_label] + torch.from_numpy(
                (semantic_sims < sim_score_threshold) + (1 - pos_mask).astype(float)
            ).float().to(torch.device(new_probs.device))
            new_label_prob_min, new_label_prob_argmin = torch.min(
                new_label_probs, dim=-1
            )
            if new_label_prob_min < orig_prob:
                text_prime[idx] = synonyms[new_label_prob_argmin]
                change_made = True

        if change_made:
            total_num_changed += 1
            text_cache = text_prime[:]

            # Check if attack succeeded (flipped prediction)
            current_pred = torch.argmax(predictor([text_prime]))
            num_queries += 1
            attack_succeeded = current_pred != orig_label

            # Update trajectory for all budgets >= current changes
            for budget in range(total_num_changed, max_budget + 1):
                trajectory[budget]["text"] = text_prime[:]
                trajectory[budget]["num_changed"] = total_num_changed
                trajectory[budget]["final_label"] = current_pred
                trajectory[budget]["num_queries"] = num_queries
                trajectory[budget]["success"] = attack_succeeded

            # CRITICAL: Stop as soon as the prediction has flipped.
            if attack_succeeded:
                break

    return trajectory


def create_empty_trajectory(max_budget, orig_label, final_label, num_queries):
    """
    Build the placeholder trajectory used when no attack is carried out.

    Every budget from 1 to ``max_budget`` maps to its own fresh record with
    an empty text, zero changes, the given ``final_label`` and
    ``num_queries``, and ``success`` set to False. ``orig_label`` is
    accepted but not used.
    """
    return {
        budget: {
            "text": [],
            "num_changed": 0,
            "final_label": final_label,
            "num_queries": num_queries,
            "success": False,
        }
        for budget in range(1, max_budget + 1)
    }


def run_optimized_budget_analysis(
    data,
    predictor,
    oov_str,
    max_budget=20,
    stop_words_set=None,
    word2idx=None,
    idx2word=None,
    cos_sim=None,
    sim_predictor=None,
    device="cpu",
    **attack_kwargs,
):
    """
    Attack every sample once and aggregate metrics for each budget.

    Args:
        data: Sized iterable of ``(text, true_label)`` pairs.
        predictor: Callable mapping a list of texts to class probabilities.
        oov_str: Placeholder token forwarded to the attack.
        max_budget: Metrics are reported for budgets ``1..max_budget``.
        stop_words_set: Words the attack must not replace; ``None`` is
            treated as "no stop words".
        word2idx, idx2word, cos_sim, sim_predictor, device: Forwarded
            unchanged to ``attack_with_trajectory``.
        **attack_kwargs: Extra keyword arguments for the attack.

    Returns:
        dict mapping each budget to a dict with the keys ``budget``,
        ``orig_accuracy``, ``adv_accuracy``, ``attack_success_rate``,
        ``avg_changed_rate``, ``avg_queries``, ``num_successful_attacks``
        and ``total_samples``. Rates are percentages over all samples; when
        there are no samples every rate and average is reported as 0.
    """
    # The attack runs a membership test against the stop-word collection;
    # passing None through would make that test raise for every word.
    if stop_words_set is None:
        stop_words_set = set()

    # Initialize results storage
    budget_results = defaultdict(
        lambda: {
            "orig_failures": 0,
            "adv_failures": 0,
            "successful_attacks": 0,
            "changed_rates": [],
            "query_counts": [],
            "total_samples": 0,
        }
    )

    print(f"Running optimized budget analysis for budgets 1-{max_budget}...")
    all_preds = []
    all_labels = []

    for idx, (text, true_label) in tqdm(enumerate(data), total=len(data)):
        # Check original prediction
        orig_pred = torch.argmax(predictor([text]))
        orig_correct = true_label == orig_pred

        all_preds.append(int(orig_pred))
        all_labels.append(int(true_label))

        # Single call that computes trajectory for ALL budgets
        trajectory = attack_with_trajectory(
            text,
            true_label,
            predictor,
            stop_words_set,
            word2idx,
            idx2word,
            cos_sim,
            device,
            oov_str=oov_str,
            max_budget=max_budget,
            sim_predictor=sim_predictor,
            **attack_kwargs,
        )

        # Extract results for each budget from the trajectory
        for budget in range(1, max_budget + 1):
            result = trajectory[budget]
            results = budget_results[budget]

            results["total_samples"] += 1
            results["query_counts"].append(result["num_queries"])

            if not orig_correct:
                # An already-misclassified sample counts as a failure both
                # before and after the attack.
                results["orig_failures"] += 1
                results["adv_failures"] += 1
                continue

            # Originally correct: check the prediction on the attacked text.
            adv_pred = result["final_label"]
            adv_correct = true_label == adv_pred

            if not adv_correct:
                results["adv_failures"] += 1

            # Attack success means we changed a correct prediction to incorrect
            if result["success"]:
                results["successful_attacks"] += 1
                if len(text) > 0:
                    results["changed_rates"].append(
                        result["num_changed"] / len(text)
                    )

    if all_labels:
        print("Sanity check accuracy:",
            sum(p==l for p,l in zip(all_preds, all_labels)) / len(all_labels))
    else:
        # Nothing was evaluated; avoid dividing by zero.
        print("Sanity check accuracy: n/a (no samples)")

    # Calculate final metrics
    final_results = {}

    for budget in range(1, max_budget + 1):
        results = budget_results[budget]
        total = results["total_samples"]

        if total:
            orig_accuracy = (1 - results["orig_failures"] / total) * 100
            adv_accuracy = (1 - results["adv_failures"] / total) * 100
            attack_success_rate = (results["successful_attacks"] / total) * 100
        else:
            # No samples at all: report zeros instead of raising.
            orig_accuracy = adv_accuracy = attack_success_rate = 0.0

        final_results[budget] = {
            "budget": budget,
            "orig_accuracy": orig_accuracy,
            "adv_accuracy": adv_accuracy,
            "attack_success_rate": attack_success_rate,
            "avg_changed_rate": (
                np.mean(results["changed_rates"]) * 100
                if results["changed_rates"]
                else 0
            ),
            "avg_queries": (
                np.mean(results["query_counts"])
                if results["query_counts"]
                else 0.0
            ),
            "num_successful_attacks": results["successful_attacks"],
            "total_samples": total,
        }

    return final_results


def plot_budget_analysis(results, save_path="budget_analysis.png"):
    """
    Draw a 2x2 grid of metric-vs-budget curves and save it to ``save_path``.

    ``results`` maps each budget to a metrics dict; the panels show attack
    success rate, original vs adversarial accuracy, average changed rate
    and average query count. The figure is returned after saving.
    """
    budgets = sorted(results.keys())

    # Pull every plotted series up front, ordered by budget.
    metric_names = (
        "orig_accuracy",
        "adv_accuracy",
        "attack_success_rate",
        "avg_changed_rate",
        "avg_queries",
    )
    series = {
        name: [results[b][name] for b in budgets] for name in metric_names
    }

    def _decorate(ax, y_label, title, percent_axis=False):
        # Shared axis cosmetics for all four panels.
        ax.set_xlabel("Budget (Max Changes)")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        if percent_axis:
            ax.set_ylim(0, 100)
        ax.grid(True, alpha=0.3)

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle("Attack Performance vs Budget Constraint", fontsize=16)
    (ax_success, ax_accuracy), (ax_changed, ax_queries) = axes

    # Top-left: attack success rate
    ax_success.plot(
        budgets, series["attack_success_rate"], "b-o", linewidth=2, markersize=8
    )
    _decorate(
        ax_success,
        "Attack Success Rate (%)",
        "Attack Success Rate vs Budget",
        percent_axis=True,
    )

    # Top-right: original vs adversarial accuracy
    ax_accuracy.plot(
        budgets, series["orig_accuracy"], "g-o", label="Original Accuracy", linewidth=2
    )
    ax_accuracy.plot(
        budgets, series["adv_accuracy"], "r-o", label="Adversarial Accuracy", linewidth=2
    )
    ax_accuracy.legend()
    _decorate(
        ax_accuracy, "Accuracy (%)", "Model Accuracy vs Budget", percent_axis=True
    )

    # Bottom-left: average share of changed words
    ax_changed.plot(
        budgets, series["avg_changed_rate"], "m-o", linewidth=2, markersize=8
    )
    _decorate(
        ax_changed, "Avg Changed Rate (%)", "Average Word Change Rate vs Budget"
    )

    # Bottom-right: average number of model queries
    ax_queries.plot(
        budgets, series["avg_queries"], "c-o", linewidth=2, markersize=8
    )
    _decorate(ax_queries, "Average Queries", "Query Efficiency vs Budget")

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig


def parse_args():
    """
    Parse the command-line options of the attack script.

    The options are declared as a table of ``(flag, add_argument kwargs)``
    pairs and registered in order; the parsed namespace is returned.
    """
    argument_specs = [
        # --- Required parameters / data and model locations ---
        ("--dataset_name", dict(
            type=str, required=True, help="Which dataset to attack.",
        )),
        ("--attack_sample_size", dict(
            type=int,
            default=None,
            help="How many samples to attack. If None, use all samples.",
        )),
        ("--max_attack_changes", dict(
            type=int,
            default=None,
            help="Maximum number of changes allowed in the attack. If None, use full attack until no more synonyms are available.",
        )),
        ("--target_model", dict(
            type=str,
            required=True,
            choices=["wordLSTM", "bert", "wordCNN", "seq_classifier"],
            help="Target models for text classification: fasttext, charcnn, word level lstm "
            "For NLI: InferSent, ESIM, bert-base-uncased",
        )),
        ("--target_model_path", dict(
            type=str, required=True, help="pre-trained target model path",
        )),
        ("--word_embeddings_path", dict(
            type=str,
            default="",
            help="path to the word embeddings for the target model",
        )),
        ("--counter_fitting_embeddings_path", dict(
            type=str,
            required=True,
            help="path to the counter-fitting embeddings we used to find synonyms",
        )),
        ("--counter_fitting_cos_sim_path", dict(
            type=str,
            default="",
            help="pre-compute the cosine similarity scores based on the counter-fitting embeddings",
        )),
        ("--USE_cache_path", dict(
            type=str, required=True, help="Path to the USE encoder cache.",
        )),
        # --- Model hyperparameters ---
        ("--sim_score_window", dict(
            type=int,
            default=15,
            help="Text length or token number to compute the semantic similarity score",
        )),
        ("--import_score_threshold", dict(
            type=float, default=-1.0, help="Required mininum importance score.",
        )),
        ("--sim_score_threshold", dict(
            type=float,
            default=0.7,
            help="Required minimum semantic similarity score.",
        )),
        ("--synonym_num", dict(
            type=int, default=50, help="Number of synonyms to extract",
        )),
        ("--batch_size", dict(
            type=int, default=32, help="Batch size to get prediction",
        )),
        ("--perturb_ratio", dict(
            type=float,
            default=0.0,
            help="Whether use random perturbation for ablation study",
        )),
        ("--max_seq_length", dict(
            type=int,
            default=256,
            help="max sequence length for BERT target model",
        )),
        ("--device", dict(
            type=str,
            required=True,
            choices=["cpu", "cuda", "mps"],
            help="Device to use for computation: cpu, cuda, mps",
        )),
        ("--seed", dict(
            type=int, default=42, help="Random seed for reproducibility",
        )),
        ("--use_amp", dict(
            action="store_true",
            help="Use automatic mixed precision for training",
        )),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def main():
    """
    Script entry point: load data and model, build the synonym and
    similarity resources, run the per-budget attack analysis and write the
    plot and CSV next to the target model.

    Results go to
    ``<parent of target_model_path>/adv_attack_results/TextFooler/...``.
    """
    args = parse_args()

    output_dir = Path(args.target_model_path).parent
    adv_result_dir = f"adv_attack_results/TextFooler/max_attack_changes_{args.max_attack_changes}_attack_sample_size_{args.attack_sample_size}_seed_{args.seed}"
    output_dir = output_dir / adv_result_dir
    print(f"Output directory: {output_dir}")

    # An existing, non-empty directory is only reported; files in it may be
    # overwritten by this run.
    if os.path.exists(output_dir) and os.listdir(output_dir):
        print(
            "Output directory ({}) already exists and is not empty.".format(
                output_dir
            )
        )
    else:
        os.makedirs(output_dir, exist_ok=True)

    # Load and prepare dataset
    logging.info(f"Loading {args.dataset_name} dataset...")

    # Load dataset config
    dataset_config = OmegaConf.load(f"conf/dataset/{args.dataset_name}.yaml")

    texts, labels, num_classes = get_exp_data(
        dataset_config=dataset_config,
        seed=args.seed,
        num_samples=args.attack_sample_size,
    )

    data = list(zip(texts, labels))

    if data:
        print(f"data[0]: {data[0]}")

    num_data_samples = len(data)
    print(f"Data import finished! Loading {num_data_samples} samples, containing {num_classes} classes.")

    # construct the model (argparse `choices` guarantees one branch matches)
    print("Building Model...")
    if args.target_model == "wordLSTM":
        model = Model(args.word_embeddings_path, nclasses=num_classes).to(
            torch.device(args.device)
        )
        checkpoint = torch.load(
            args.target_model_path, map_location=torch.device(args.device)
        )
        model.load_state_dict(checkpoint)
    elif args.target_model == "wordCNN":
        model = Model(
            args.word_embeddings_path, nclasses=num_classes, hidden_size=100, cnn=True
        ).to(torch.device(args.device))
        checkpoint = torch.load(
            args.target_model_path, map_location=torch.device(args.device)
        )
        model.load_state_dict(checkpoint)
    elif args.target_model == "bert":
        model = NLI_infer_BERT(
            args.target_model_path,
            nclasses=num_classes,
            max_seq_length=args.max_seq_length,
            device=args.device,
        )
    elif args.target_model == "seq_classifier":
        model = SeqClassifierInfer(
            model_name_or_path=args.target_model_path,
            num_classes=num_classes,
            device=args.device,
            batch_size=args.batch_size,
            use_amp=args.use_amp,
            dtype=torch.float16 if args.use_amp else torch.float32,
            max_seq_length=args.max_seq_length,
        )

    predictor = model.text_pred

    # Use the tokenizer's unknown token as the masking placeholder. Models
    # without a tokenizer attribute (or without an unk token) fall back to a
    # literal "<oov>" marker.
    # NOTE(review): "<oov>" is assumed to be an acceptable out-of-vocabulary
    # string for the non-HuggingFace models - confirm against their vocab.
    tokenizer = getattr(model, "tokenizer", None)
    oov_str = getattr(tokenizer, "unk_token", None) or "<oov>"

    print(f"OOV string: {oov_str}")
    print("Model built!")

    ModelSummary.summarize(
        model,
        model_name=args.target_model,
        logger=None,
        verbose=True,
        print_architecture=False,
    )

    # prepare synonym extractor
    # build dictionary via the embedding file
    idx2word = {}
    word2idx = {}

    print("Building vocab...")
    with open(args.counter_fitting_embeddings_path, "r") as ifile:
        for line in ifile:
            word = line.split()[0]
            # Every line gets its own index so that indices stay aligned
            # with the rows of the cosine-similarity matrix, which is built
            # from every line of the same file. A repeated word therefore
            # maps to its last occurrence in word2idx.
            idx2word[len(idx2word)] = word
            word2idx[word] = len(idx2word) - 1

    print("Building cos sim matrix...")
    if args.counter_fitting_cos_sim_path and os.path.isfile(
        args.counter_fitting_cos_sim_path
    ):
        # load pre-computed cosine similarity matrix if provided
        print(
            "Load pre-computed cosine similarity matrix from {}".format(
                args.counter_fitting_cos_sim_path
            )
        )
        cos_sim = np.load(args.counter_fitting_cos_sim_path)
    else:
        # calculate the cosine similarity matrix
        print("Start computing the cosine similarity matrix!")
        embeddings = []
        with open(args.counter_fitting_embeddings_path, "r") as ifile:
            for line in ifile:
                embedding = [float(num) for num in line.strip().split()[1:]]
                embeddings.append(embedding)
        embeddings = np.array(embeddings)
        product = np.dot(embeddings, embeddings.T)
        norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
        cos_sim = product / np.dot(norm, norm.T)

        # Cache the matrix only when a target path was given; with the
        # default empty path there is nowhere sensible to write it.
        if args.counter_fitting_cos_sim_path:
            cos_sim_dir = os.path.dirname(args.counter_fitting_cos_sim_path)
            if cos_sim_dir and not os.path.exists(cos_sim_dir):
                os.makedirs(cos_sim_dir, exist_ok=True)
                print(f"Created directory: {cos_sim_dir}")

            np.save(args.counter_fitting_cos_sim_path, cos_sim)
            print(
                "Cosine similarity matrix saved to {}".format(
                    args.counter_fitting_cos_sim_path
                )
            )
    print("Cos sim import finished!")

    print(f"building semantic similarity module from {args.USE_cache_path}...")

    # build the semantic similarity module
    use = USE(args.USE_cache_path)
    print("Semantic similarity module built!")

    stop_words_set = criteria.get_stopwords()

    # Ensure the POS-tagger and universal mapping are present
    for pkg in ("averaged_perceptron_tagger_eng", "universal_tagset"):
        try:
            nltk.data.find(f"taggers/{pkg}")
        except LookupError:
            nltk.download(pkg, quiet=True)

    print("Start attacking!")

    # Resolve the budget cap. Without --max_attack_changes the attack is
    # meant to be unconstrained, so cap at the longest text length.
    # NOTE(review): assumes each text is a sequence of words, so its length
    # bounds the number of possible substitutions - confirm with get_exp_data.
    max_budget = args.max_attack_changes
    if max_budget is None:
        max_budget = max((len(text) for text, _ in data), default=0)
        print(
            f"--max_attack_changes not set; using the longest text length ({max_budget}) as the budget cap."
        )

    # Run optimized analysis
    results = run_optimized_budget_analysis(
        data=data,
        predictor=predictor,
        oov_str=oov_str,
        max_budget=max_budget,
        stop_words_set=stop_words_set,
        word2idx=word2idx,
        idx2word=idx2word,
        cos_sim=cos_sim,
        sim_predictor=use,
        device=args.device,
        sim_score_threshold=args.sim_score_threshold,
        import_score_threshold=args.import_score_threshold,
        sim_score_window=args.sim_score_window,
        synonym_num=args.synonym_num,
        batch_size=args.batch_size,
    )

    # Create visualization
    plot_budget_analysis(
        results,
        save_path=os.path.join(
            output_dir, f"optimized_max_budget_{max_budget}_analysis.png"
        ),
    )

    # Save results, one row per budget plus run metadata columns
    df = pd.DataFrame.from_dict(results, orient="index")
    df['model_name'] = args.target_model
    df['dataset_name'] = args.dataset_name
    df['max_budget'] = max_budget
    df['seed'] = args.seed
    df.to_csv(
        os.path.join(output_dir, f"optimized_max_budget_{max_budget}_results.csv")
    )

    print("Optimized budget analysis completed!")
    print(f"Computed results for budgets 1-{max_budget} in a single pass")


# Script entry point: run the attack/budget analysis only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
