import argparse
import os

import numpy as np
import dataloader
from contextlib import nullcontext
from train_classifier import Model
import criteria
import random

import tensorflow as tf
import tensorflow_hub as hub
import transformers
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SequentialSampler, TensorDataset
import nltk, contextlib
from tqdm import tqdm
from torch.amp import autocast

from BERT.tokenization import BertTokenizer
from BERT.modeling import BertForSequenceClassification, BertConfig
from utils.model_utils import ModelSummary
from transformers import AutoTokenizer



from nlp_training.exp_data import get_exp_data

# class USE(object):
#     def __init__(self, cache_path):
#         super(USE, self).__init__()
#         os.environ['TFHUB_CACHE_DIR'] = cache_path
#         module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
#         self.embed = hub.Module(module_url)
#         config = tf.ConfigProto()
#         config.gpu_options.allow_growth = True
#         self.sess = tf.Session(config=config)
#         self.build_graph()
#         self.sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

#     def build_graph(self):
#         self.sts_input1 = tf.placeholder(tf.string, shape=(None))
#         self.sts_input2 = tf.placeholder(tf.string, shape=(None))

#         sts_encode1 = tf.nn.l2_normalize(self.embed(self.sts_input1), axis=1)
#         sts_encode2 = tf.nn.l2_normalize(self.embed(self.sts_input2), axis=1)
#         self.cosine_similarities = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
#         clip_cosine_similarities = tf.clip_by_value(self.cosine_similarities, -1.0, 1.0)
#         self.sim_scores = 1.0 - tf.acos(clip_cosine_similarities)

#     def semantic_sim(self, sents1, sents2):
#         scores = self.sess.run(
#             [self.sim_scores],
#             feed_dict={
#                 self.sts_input1: sents1,
#                 self.sts_input2: sents2,
#             })
#         return scores


class USE(object):
    """Sentence-pair similarity scorer built on the Universal Sentence Encoder.

    The encoder is fetched from TF-Hub as a Keras layer and evaluated through
    a TF1-style placeholder graph and session. Creating an instance disables
    eager execution for the whole process.
    """

    def __init__(self, cache_path):
        """Load the encoder, wire up the similarity graph and open a session.

        Args:
            cache_path: Directory exported as ``TFHUB_CACHE_DIR`` before the
                hub module is loaded.
        """
        os.environ["TFHUB_CACHE_DIR"] = cache_path

        tf1 = tf.compat.v1
        # Placeholders and sessions are only usable in graph mode.
        tf1.disable_eager_execution()

        # The encoder layer is created under an explicit CPU device scope.
        with tf.device("/cpu:0"):
            self.embed = hub.KerasLayer(
                "https://tfhub.dev/google/universal-sentence-encoder-large/5",
                trainable=False,
                dtype=tf.string,
            )

        self._build_graph()

        session_config = tf1.ConfigProto()
        session_config.gpu_options.allow_growth = True
        self.sess = tf1.Session(config=session_config)
        self.sess.run(
            [tf1.global_variables_initializer(), tf1.tables_initializer()]
        )

    def _build_graph(self):
        """Create the two string placeholders and the ``sim_scores`` tensor."""
        self.sts_input1 = tf.compat.v1.placeholder(tf.string, shape=(None,))
        self.sts_input2 = tf.compat.v1.placeholder(tf.string, shape=(None,))

        unit_first = tf.nn.l2_normalize(self.embed(self.sts_input1), axis=1)
        unit_second = tf.nn.l2_normalize(self.embed(self.sts_input2), axis=1)

        # Dot product of unit vectors, clipped so acos stays in its domain.
        cosine = tf.clip_by_value(
            tf.reduce_sum(tf.multiply(unit_first, unit_second), axis=1),
            -1.0,
            1.0,
        )
        # Score is 1 - arccos(cosine), so identical embeddings score 1.0.
        self.sim_scores = 1.0 - tf.acos(cosine)

    def semantic_sim(self, s1, s2):
        """Evaluate ``sim_scores`` for two aligned batches of sentences.

        Args:
            s1: Batch of strings fed to the first placeholder.
            s2: Batch of strings fed to the second placeholder.

        Returns:
            Whatever ``Session.run`` yields for the single ``sim_scores``
            fetch (the scores themselves, not wrapped in a list).
        """
        feed = {self.sts_input1: s1, self.sts_input2: s2}
        return self.sess.run(self.sim_scores, feed)


def pick_most_similar_words_batch(
    src_words, sim_mat, idx2word, ret_count=10, threshold=0.0
):
    """Look up the closest neighbours of several words in a similarity matrix.

    Args:
        src_words: Row indices into ``sim_mat``, one per query word.
        sim_mat: 2-D NumPy array of pairwise similarity scores.
        idx2word: Mapping from a column index to its word.
        ret_count: Maximum number of neighbours kept per query word.
        threshold: Neighbours scoring below this value are dropped.

    Returns:
        ``(sim_words, sim_values)``: two lists aligned with ``src_words``.
        ``sim_words[i]`` is a list of neighbour words and ``sim_values[i]``
        the matching array of scores, both in descending score order.
    """
    query_rows = sim_mat[src_words, :]
    # Rank every column by descending score and skip the top-ranked entry
    # (presumably the query word itself - verify against how sim_mat is built).
    ranked_ids = np.argsort(-query_rows)[:, 1 : 1 + ret_count]

    sim_words, sim_values = [], []
    for row_scores, candidate_ids in zip(query_rows, ranked_ids):
        candidate_scores = row_scores[candidate_ids]
        keep = candidate_scores >= threshold
        sim_words.append([idx2word[word_id] for word_id in candidate_ids[keep]])
        sim_values.append(candidate_scores[keep])
    return sim_words, sim_values


# class NLI_infer_BERT(nn.Module):
#     def __init__(
#         self,
#         pretrained_dir,
#         nclasses,
#         device,
#         max_seq_length=128,
#         batch_size=32,
#     ):
#         super(NLI_infer_BERT, self).__init__()
#         self.device = device
#         self.model = BertForSequenceClassification.from_pretrained(
#             pretrained_dir, num_labels=nclasses
#         ).to(torch.device(self.device))

#         # construct dataset loader
#         self.dataset = NLIDataset_BERT(
#             pretrained_dir, max_seq_length=max_seq_length, batch_size=batch_size
#         )

#     def text_pred(self, text_data, batch_size=32):
#         # Switch the model to eval mode.
#         self.model.eval()

#         # transform text data into indices and create batches
#         dataloader = self.dataset.transform_text(text_data, batch_size=batch_size)

#         probs_all = []
#         #         for input_ids, input_mask, segment_ids in tqdm(dataloader, desc="Evaluating"):
#         for input_ids, input_mask, segment_ids in dataloader:
#             input_ids = input_ids.to(torch.device(self.device))
#             input_mask = input_mask.to(torch.device(self.device))
#             segment_ids = segment_ids.to(torch.device(self.device))

#             with torch.no_grad():
#                 logits = self.model(input_ids, segment_ids, input_mask)
#                 probs = nn.functional.softmax(logits, dim=-1)
#                 probs_all.append(probs)

#         return torch.cat(probs_all, dim=0)\

class NLI_infer_BERT(nn.Module):
    """BERT sequence classifier wrapper that turns raw token lists into probabilities.

    The wrapped Hugging Face model is loaded once in ``__init__``; prediction
    goes through :meth:`text_pred`, which batches the inputs with
    ``NLIDataset_BERT`` and returns softmax probabilities.
    """

    def __init__(
        self,
        pretrained_dir: str,
        nclasses: int,
        device: str,
        max_seq_length: int = 128,
        batch_size: int = 32,
        use_amp: bool = False,
        dtype: torch.dtype = torch.float32,
    ):
        """Load the classifier and its tokenising dataset helper.

        Args:
            pretrained_dir: Directory (or model id) given to ``from_pretrained``
                for both the model and the tokenizer.
            nclasses: Number of output labels of the classification head.
            device: Torch device string the model and inputs are moved to.
            max_seq_length: Padded/truncated sequence length used for inputs.
            batch_size: Default batch size when ``text_pred`` gets none.
            use_amp: Run the forward pass under ``torch.amp.autocast``.
            dtype: Autocast dtype, only consulted when ``use_amp`` is true.
        """
        super(NLI_infer_BERT, self).__init__()
        self.device = device
        self.use_amp = use_amp
        self.dtype = dtype
        # Remembered so an empty prediction can still have the right width.
        self.nclasses = nclasses

        self.model = transformers.BertForSequenceClassification.from_pretrained(
            pretrained_dir, num_labels=nclasses
        ).to(torch.device(self.device))

        print(f"self.model: {type(self.model)}")
        self.dataset = NLIDataset_BERT(
            pretrained_dir,
            max_seq_length=max_seq_length,
            batch_size=batch_size,
        )

    def _amp_context(self):
        """Return a fresh context manager for one forward pass.

        Autocast is keyed to the device the model actually lives on, so AMP
        also takes effect on non-CUDA devices. With ``use_amp`` off the
        autocast state is left untouched.
        """
        if not self.use_amp:
            return nullcontext()
        return autocast(
            device_type=torch.device(self.device).type, dtype=self.dtype
        )

    def text_pred(self, text_data, batch_size: int = None):
        """Predict class probabilities for a list of tokenised texts.

        Args:
            text_data: Examples in the form accepted by
                ``NLIDataset_BERT.transform_text``.
            batch_size: Batch size for this call; falls back to the dataset
                default when ``None`` (or 0).

        Returns:
            Tensor of softmax probabilities with one row per input text, on
            ``self.device``. An empty input yields a ``(0, nclasses)`` tensor.
        """
        # Switch the model to eval mode.
        self.model.eval()

        bs = batch_size or self.dataset.batch_size
        # Named `loader` so the imported `dataloader` module is not shadowed.
        loader = self.dataset.transform_text(text_data, batch_size=bs)
        target = torch.device(self.device)

        probs_all = []
        for input_ids, input_mask, segment_ids in loader:
            # Inputs stay integer tensors; only their placement changes.
            input_ids = input_ids.to(target)
            input_mask = input_mask.to(target)
            segment_ids = segment_ids.to(target)

            with torch.no_grad(), self._amp_context():
                logits = self.model(
                    input_ids=input_ids,
                    attention_mask=input_mask,
                    token_type_ids=segment_ids,
                ).logits

            probs_all.append(nn.functional.softmax(logits, dim=-1))

        if not probs_all:
            # torch.cat rejects an empty list; return a well-shaped empty result.
            return torch.empty((0, self.nclasses), device=target)
        return torch.cat(probs_all, dim=0)


class InputFeatures(object):
    """Holds the three parallel sequences that describe one model input."""

    def __init__(self, input_ids, input_mask, segment_ids):
        """Store the given sequences unchanged as same-named attributes."""
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids,
            input_mask,
            segment_ids,
        )


class NLIDataset_BERT(Dataset):
    """
    Converts tokenised texts into padded BERT inputs and batches them.

    The class owns a ``BertTokenizer`` and exposes ``transform_text``, which
    returns a sequential ``DataLoader`` over (input ids, attention mask,
    segment ids) tensors.
    """

    def __init__(self, pretrained_dir, max_seq_length=128, batch_size=32):
        """
        Args:
            pretrained_dir: Location passed to ``BertTokenizer.from_pretrained``
                (loaded with ``do_lower_case=True``).
            max_seq_length: Fixed length every example is truncated/padded to.
            batch_size: Stored on the instance for callers to read;
                ``transform_text`` takes its own ``batch_size`` argument.
        """
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.tokenizer = BertTokenizer.from_pretrained(
            pretrained_dir, do_lower_case=True
        )

    def convert_examples_to_features(self, examples, max_seq_length, tokenizer):
        """Build one ``InputFeatures`` per example.

        Args:
            examples: Iterable of word sequences; each is joined with spaces
                before tokenisation.
            max_seq_length: Output length of every feature sequence.
            tokenizer: Object providing ``tokenize`` and
                ``convert_tokens_to_ids``.

        Returns:
            List of ``InputFeatures`` whose three sequences all have length
            ``max_seq_length``.
        """
        # Two slots are reserved for the [CLS] and [SEP] markers.
        room_for_pieces = max_seq_length - 2

        features = []
        for words in examples:
            pieces = tokenizer.tokenize(" ".join(words))
            if len(pieces) > room_for_pieces:
                pieces = pieces[:room_for_pieces]

            tokens = ["[CLS]"] + pieces + ["[SEP]"]
            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # Zero-pad up to the sequence length; the mask marks real tokens
            # with 1 and padding with 0, segment ids are all zero.
            padding = [0] * (max_seq_length - len(input_ids))
            input_mask = [1] * len(input_ids) + padding
            segment_ids = [0] * len(tokens) + padding
            input_ids += padding

            for sequence in (input_ids, input_mask, segment_ids):
                assert len(sequence) == max_seq_length

            features.append(
                InputFeatures(
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                )
            )
        return features

    def transform_text(self, data, batch_size=32):
        """Wrap ``data`` in a ``DataLoader`` that preserves input order.

        Args:
            data: Examples accepted by ``convert_examples_to_features``.
            batch_size: Number of examples per yielded batch.

        Returns:
            ``DataLoader`` yielding ``(input_ids, input_mask, segment_ids)``
            long tensors.
        """
        features = self.convert_examples_to_features(
            data, self.max_seq_length, self.tokenizer
        )

        columns = [
            torch.tensor(
                [getattr(feature, field) for feature in features],
                dtype=torch.long,
            )
            for field in ("input_ids", "input_mask", "segment_ids")
        ]
        tensor_data = TensorDataset(*columns)

        # Sequential sampling keeps predictions aligned with the inputs.
        return DataLoader(
            tensor_data,
            sampler=SequentialSampler(tensor_data),
            batch_size=batch_size,
        )


def random_attack(
    text_ls,
    true_label,
    predictor,
    perturb_ratio,
    stop_words_set,
    word2idx,
    idx2word,
    cos_sim,
    device,
    sim_predictor=None,
    import_score_threshold=-1.0,
    sim_score_threshold=0.5,
    sim_score_window=15,
    synonym_num=50,
    batch_size=32,
):
    """Ablation attack: substitute synonyms at randomly sampled positions.

    Positions are drawn with ``random.sample`` instead of being ranked by
    importance. For each position, synonym candidates are scored by the
    predictor, filtered by local semantic similarity and part of speech,
    and the first candidate set that flips the label ends the attack.

    Args:
        text_ls: Input sentence as a list of words.
        true_label: Gold label, compared with the predicted label.
        predictor: Callable ``predictor(texts, batch_size=...)`` returning a
            tensor of class probabilities, one row per text.
        perturb_ratio: Fraction of positions sampled for perturbation.
        stop_words_set: Accepted for signature parity; not used here.
        word2idx: Maps a word to its row index in ``cos_sim``.
        idx2word: Maps an index in ``cos_sim`` back to a word.
        cos_sim: Word similarity matrix used for synonym lookup.
        device: Torch device string the probability tensors live on.
        sim_predictor: Object with ``semantic_sim(list, list)``; it is called
            for every perturbed position, so ``None`` fails there.
        import_score_threshold: Accepted for signature parity; not used here.
        sim_score_threshold: Minimum semantic similarity of a candidate.
            Replaced by 0.1 when the text is shorter than the window.
        sim_score_window: Number of words compared around a replacement.
        synonym_num: Synonym candidates requested per word.
        batch_size: Batch size forwarded to ``predictor``.

    Returns:
        ``(adversarial_text, num_changed, orig_label, new_label, num_queries)``.
        If the model already mispredicts the input, returns
        ``("", 0, orig_label, orig_label, 0)``.
    """
    # first check the prediction of the original text
    orig_probs = predictor([text_ls]).squeeze()
    orig_label = torch.argmax(orig_probs)
    orig_prob = orig_probs.max()

    # Nothing to attack when the model is already wrong on this sample.
    if true_label != orig_label:
        return "", 0, orig_label, orig_label, 0

    len_text = len(text_ls)
    if len_text < sim_score_window:
        sim_score_threshold = 0.1  # shut down the similarity thresholding function
    half_sim_score_window = (sim_score_window - 1) // 2
    num_queries = 1

    # get the pos and verb tense info
    pos_ls = criteria.get_pos(text_ls)

    # randomly get perturbed words
    perturb_idxes = random.sample(range(len_text), int(len_text * perturb_ratio))
    words_perturb = [(idx, text_ls[idx]) for idx in perturb_idxes]

    # find synonyms (only for words known to the synonym vocabulary)
    words_perturb_idx = [
        word2idx[word] for _, word in words_perturb if word in word2idx
    ]
    synonym_words, _ = pick_most_similar_words_batch(
        words_perturb_idx, cos_sim, idx2word, synonym_num, 0.5
    )
    # synonym_words is aligned with the in-vocabulary words, in order.
    synonym_iter = iter(synonym_words)
    synonyms_all = []
    for idx, word in words_perturb:
        if word in word2idx:
            synonyms = next(synonym_iter)
            if synonyms:
                synonyms_all.append((idx, synonyms))

    # start replacing and attacking
    text_prime = text_ls[:]
    text_cache = text_prime[:]
    num_changed = 0
    for idx, synonyms in synonyms_all:
        new_texts = [
            text_prime[:idx] + [synonym] + text_prime[idx + 1 :]
            for synonym in synonyms
        ]
        new_probs = predictor(new_texts, batch_size=batch_size)

        # Pick the local window whose similarity is compared.
        words_after = len_text - idx - 1
        if idx >= half_sim_score_window and words_after >= half_sim_score_window:
            text_range_min = idx - half_sim_score_window
            text_range_max = idx + half_sim_score_window + 1
        elif idx < half_sim_score_window and words_after >= half_sim_score_window:
            text_range_min = 0
            text_range_max = sim_score_window
        elif idx >= half_sim_score_window and words_after < half_sim_score_window:
            # Clamp at 0: for texts shorter than the window a negative start
            # would wrap around and could exclude the replaced word.
            text_range_min = max(0, len_text - sim_score_window)
            text_range_max = len_text
        else:
            text_range_min = 0
            text_range_max = len_text

        reference = " ".join(text_cache[text_range_min:text_range_max])
        candidates = [
            " ".join(new_text[text_range_min:text_range_max])
            for new_text in new_texts
        ]
        # USE.semantic_sim returns the score vector itself; flatten instead of
        # indexing [0] so each candidate keeps its own score (a legacy
        # list-wrapped result flattens to the same vector).
        semantic_sims = np.asarray(
            sim_predictor.semantic_sim([reference] * len(new_texts), candidates)
        ).reshape(-1)

        num_queries += len(new_texts)
        if len(new_probs.shape) < 2:
            new_probs = new_probs.unsqueeze(0)
        new_probs_mask = (
            (orig_label != torch.argmax(new_probs, dim=-1)).detach().cpu().numpy()
        )
        # prevent bad synonyms
        new_probs_mask *= semantic_sims >= sim_score_threshold
        # prevent incompatible pos
        synonyms_pos_ls = [
            (
                criteria.get_pos(new_text[max(idx - 4, 0) : idx + 5])[min(4, idx)]
                if len(new_text) > 10
                else criteria.get_pos(new_text)[idx]
            )
            for new_text in new_texts
        ]
        pos_mask = np.array(criteria.pos_filter(pos_ls[idx], synonyms_pos_ls))
        new_probs_mask *= pos_mask

        if np.sum(new_probs_mask) > 0:
            # A valid candidate flips the label: take the most similar one.
            text_prime[idx] = synonyms[(new_probs_mask * semantic_sims).argmax()]
            num_changed += 1
            break

        # No flip yet: add 1.0 per violated constraint to the original-label
        # probability and keep the candidate with the lowest total if it
        # improves on the original confidence.
        new_label_probs = new_probs[:, orig_label] + torch.from_numpy(
            (semantic_sims < sim_score_threshold) + (1 - pos_mask).astype(float)
        ).float().to(torch.device(device))
        new_label_prob_min, new_label_prob_argmin = torch.min(
            new_label_probs, dim=-1
        )
        if new_label_prob_min < orig_prob:
            text_prime[idx] = synonyms[new_label_prob_argmin]
            num_changed += 1
        text_cache = text_prime[:]

    return (
        " ".join(text_prime),
        num_changed,
        orig_label,
        torch.argmax(predictor([text_prime])),
        num_queries,
    )



def attack(
    text_ls,
    true_label,
    predictor,
    stop_words_set,
    word2idx,
    idx2word,
    cos_sim,
    device,
    num_changed_budget=None,
    sim_predictor=None,
    import_score_threshold=-1.0,
    sim_score_threshold=0.5,
    sim_score_window=15,
    synonym_num=50,
    batch_size=32,
):
    """Greedy synonym-substitution attack guided by leave-one-out importance.

    Every word is replaced by ``[UNK]`` in turn to measure how much it
    matters to the prediction. Words are then visited from most to least
    important; for each one the synonym candidates are scored by the
    predictor and filtered by local semantic similarity and part of speech.
    The attack stops as soon as a valid candidate flips the predicted label
    or the change budget is reached.

    Args:
        text_ls: Input sentence as a list of words.
        true_label: Gold label, compared with the predicted label.
        predictor: Callable ``predictor(texts, batch_size=...)`` returning a
            tensor of class probabilities, one row per text.
        stop_words_set: Words that are never selected for perturbation.
        word2idx: Maps a word to its row index in ``cos_sim``.
        idx2word: Maps an index in ``cos_sim`` back to a word.
        cos_sim: Word similarity matrix used for synonym lookup.
        device: Torch device string the probability tensors live on.
        num_changed_budget: Optional cap on the number of substitutions;
            ``None`` or 0 disables the cap.
        sim_predictor: Object with ``semantic_sim(list, list)``; it is called
            for every perturbed position, so ``None`` fails there.
        import_score_threshold: Only words scoring above this are perturbed.
        sim_score_threshold: Minimum semantic similarity of a candidate.
            Replaced by 0.1 when the text is shorter than the window.
        sim_score_window: Number of words compared around a replacement.
        synonym_num: Synonym candidates requested per word.
        batch_size: Batch size forwarded to ``predictor``.

    Returns:
        ``(adversarial_text, num_changed, orig_label, new_label, num_queries)``.
        If the model already mispredicts the input, returns
        ``("", 0, orig_label, orig_label, 0)``.
    """
    # first check the prediction of the original text
    orig_probs = predictor([text_ls]).squeeze()
    orig_label = torch.argmax(orig_probs)
    orig_prob = orig_probs.max()

    # Nothing to attack when the model is already wrong on this sample.
    if true_label != orig_label:
        return "", 0, orig_label, orig_label, 0

    len_text = len(text_ls)
    if len_text < sim_score_window:
        sim_score_threshold = 0.1  # shut down the similarity thresholding function
    half_sim_score_window = (sim_score_window - 1) // 2
    num_queries = 1

    # get the pos and verb tense info
    pos_ls = criteria.get_pos(text_ls)

    # Importance scores: mask each word in turn and re-query the model.
    oov_str = "[UNK]"
    leave_1_texts = [
        text_ls[:ii] + [oov_str] + text_ls[ii + 1 :] for ii in range(len_text)
    ]
    leave_1_probs = predictor(leave_1_texts, batch_size=batch_size)
    num_queries += len(leave_1_texts)
    leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)

    # Score = drop in the original label's probability, plus (only when the
    # masked text changes the predicted label) the gain of the new label
    # over its probability on the unmasked text.
    import_scores = (
        (
            orig_prob
            - leave_1_probs[:, orig_label]
            + (leave_1_probs_argmax != orig_label).float()
            * (
                leave_1_probs.max(dim=-1)[0]
                - torch.index_select(orig_probs, 0, leave_1_probs_argmax)
            )
        )
        .detach()
        .cpu()
        .numpy()
    )

    # get words to perturb, ranked by importance score
    words_perturb = []
    for idx, score in sorted(
        enumerate(import_scores), key=lambda x: x[1], reverse=True
    ):
        try:
            if (
                score > import_score_threshold
                and text_ls[idx] not in stop_words_set
            ):
                words_perturb.append((idx, text_ls[idx]))
        except Exception:
            # Best-effort diagnostics; unlike a bare except this no longer
            # swallows KeyboardInterrupt / SystemExit.
            print(
                idx, len(text_ls), import_scores.shape, text_ls, len(leave_1_texts)
            )

    # find synonyms (only for words known to the synonym vocabulary)
    words_perturb_idx = [
        word2idx[word] for _, word in words_perturb if word in word2idx
    ]
    synonym_words, _ = pick_most_similar_words_batch(
        words_perturb_idx, cos_sim, idx2word, synonym_num, 0.5
    )

    # Pair each perturbable position with its non-empty synonym list;
    # synonym_words is aligned with the in-vocabulary words, in order.
    synonym_iter = iter(synonym_words)
    synonyms_all = []
    for idx, word in words_perturb:
        if word in word2idx:
            synonyms = next(synonym_iter)
            if synonyms:
                synonyms_all.append((idx, synonyms))

    # start replacing and attacking
    text_prime = text_ls[:]
    text_cache = text_prime[:]
    num_changed = 0

    for idx, synonyms in synonyms_all:
        # candidate replaced texts for the word at position idx
        new_texts = [
            text_prime[:idx] + [synonym] + text_prime[idx + 1 :]
            for synonym in synonyms
        ]

        # probabilities of the candidate replaced texts
        new_probs = predictor(new_texts, batch_size=batch_size)

        # Pick the local window whose similarity is compared.
        words_after = len_text - idx - 1
        if idx >= half_sim_score_window and words_after >= half_sim_score_window:
            text_range_min = idx - half_sim_score_window
            text_range_max = idx + half_sim_score_window + 1
        elif idx < half_sim_score_window and words_after >= half_sim_score_window:
            text_range_min = 0
            text_range_max = sim_score_window
        elif idx >= half_sim_score_window and words_after < half_sim_score_window:
            # Clamp at 0: for texts shorter than the window a negative start
            # would wrap around and could exclude the replaced word.
            text_range_min = max(0, len_text - sim_score_window)
            text_range_max = len_text
        else:
            text_range_min = 0
            text_range_max = len_text

        reference = " ".join(text_cache[text_range_min:text_range_max])
        candidates = [
            " ".join(new_text[text_range_min:text_range_max])
            for new_text in new_texts
        ]
        # Flatten so both a plain score vector and a legacy list-wrapped
        # result give one score per candidate.
        semantic_sims = np.asarray(
            sim_predictor.semantic_sim([reference] * len(new_texts), candidates)
        ).reshape(-1)

        num_queries += len(new_texts)

        if len(new_probs.shape) < 2:
            new_probs = new_probs.unsqueeze(0)
        new_probs_mask = (
            (orig_label != torch.argmax(new_probs, dim=-1)).detach().cpu().numpy()
        )

        # prevent bad synonyms
        new_probs_mask *= semantic_sims >= sim_score_threshold

        # prevent incompatible pos
        synonyms_pos_ls = [
            (
                criteria.get_pos(new_text[max(idx - 4, 0) : idx + 5])[min(4, idx)]
                if len(new_text) > 10
                else criteria.get_pos(new_text)[idx]
            )
            for new_text in new_texts
        ]
        pos_mask = np.array(criteria.pos_filter(pos_ls[idx], synonyms_pos_ls))
        new_probs_mask *= pos_mask

        if np.sum(new_probs_mask) > 0:
            # A valid candidate flips the label: take the most similar one.
            text_prime[idx] = synonyms[(new_probs_mask * semantic_sims).argmax()]
            num_changed += 1
            break

        # No flip yet: add 1.0 per violated constraint to the original-label
        # probability and keep the candidate with the lowest total if it
        # improves on the original confidence.
        new_label_probs = new_probs[:, orig_label] + torch.from_numpy(
            (semantic_sims < sim_score_threshold) + (1 - pos_mask).astype(float)
        ).float().to(torch.device(device))
        new_label_prob_min, new_label_prob_argmin = torch.min(
            new_label_probs, dim=-1
        )
        if new_label_prob_min < orig_prob:
            text_prime[idx] = synonyms[new_label_prob_argmin]
            num_changed += 1
        text_cache = text_prime[:]

        # Stop once the allowed number of substitutions has been spent.
        if num_changed_budget and num_changed >= num_changed_budget:
            break

    return (
        " ".join(text_prime),
        num_changed,
        orig_label,
        torch.argmax(predictor([text_prime])),
        num_queries,
    )


def parse_args():
    """Declare the attack script's command-line options and parse ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with one attribute per option listed below.
    """
    # (flag, add_argument keyword options), in the order shown by --help.
    option_specs = [
        # --- data, model and output locations ---
        (
            "--dataset_path",
            dict(type=str, required=True, help="Which dataset to attack."),
        ),
        (
            "--nclasses",
            dict(type=int, default=2, help="How many classes for classification."),
        ),
        (
            "--target_model",
            dict(
                type=str,
                required=True,
                choices=["wordLSTM", "bert", "wordCNN"],
                help="Target models for text classification: fasttext, charcnn, word level lstm "
                "For NLI: InferSent, ESIM, bert-base-uncased",
            ),
        ),
        (
            "--target_model_path",
            dict(type=str, required=True, help="pre-trained target model path"),
        ),
        (
            "--word_embeddings_path",
            dict(
                type=str,
                default="",
                help="path to the word embeddings for the target model",
            ),
        ),
        (
            "--counter_fitting_embeddings_path",
            dict(
                type=str,
                required=True,
                help="path to the counter-fitting embeddings we used to find synonyms",
            ),
        ),
        (
            "--counter_fitting_cos_sim_path",
            dict(
                type=str,
                default="",
                help="pre-compute the cosine similarity scores based on the counter-fitting embeddings",
            ),
        ),
        (
            "--USE_cache_path",
            dict(type=str, required=True, help="Path to the USE encoder cache."),
        ),
        (
            "--output_dir",
            dict(
                type=str,
                default="adv_results",
                help="The output directory where the attack results will be written.",
            ),
        ),
        # --- attack hyperparameters ---
        (
            "--sim_score_window",
            dict(
                default=15,
                type=int,
                help="Text length or token number to compute the semantic similarity score",
            ),
        ),
        (
            "--import_score_threshold",
            dict(
                default=-1.0,
                type=float,
                help="Required mininum importance score.",
            ),
        ),
        (
            "--sim_score_threshold",
            dict(
                default=0.7,
                type=float,
                help="Required minimum semantic similarity score.",
            ),
        ),
        (
            "--synonym_num",
            dict(default=50, type=int, help="Number of synonyms to extract"),
        ),
        (
            "--batch_size",
            dict(default=32, type=int, help="Batch size to get prediction"),
        ),
        (
            "--data_size",
            dict(default=None, type=int, help="Data size to create adversaries"),
        ),
        (
            "--perturb_ratio",
            dict(
                default=0.0,
                type=float,
                help="Whether use random perturbation for ablation study",
            ),
        ),
        (
            "--max_seq_length",
            dict(
                default=128,
                type=int,
                help="max sequence length for BERT target model",
            ),
        ),
        (
            "--device",
            dict(
                type=str,
                required=True,
                choices=["cpu", "cuda", "mps"],
                help="Device to use for computation: cpu, cuda, mps",
            ),
        ),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def main():
    """Run the adversarial attack pipeline end to end.

    Steps, in order:
      1. Parse the command-line arguments.
      2. Load the samples returned by ``get_exp_data`` (optionally truncated
         to ``--data_size``).
      3. Build the target model selected by ``--target_model``.
      4. Build the vocabulary and the cosine similarity matrix from the
         counter-fitting embeddings file (or load a pre-computed matrix).
      5. Build the USE semantic similarity module.
      6. Attack every sample and collect statistics.
      7. Append a summary line to ``results_log`` and write the successful
         adversarial examples to ``adversaries.txt`` in ``--output_dir``.

    Raises:
        ValueError: If there are no samples to attack or ``--target_model``
            is not one of the supported models.
    """
    args = parse_args()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        print(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir
            )
        )
    else:
        os.makedirs(args.output_dir, exist_ok=True)

    # get data to attack
    texts, labels = get_exp_data()
    data = list(zip(texts, labels))
    if args.data_size is not None and args.data_size < len(data):
        data = data[: args.data_size]  # choose how many samples for adversary

    num_data_samples = len(data)
    print(f"Data import finished! Found {num_data_samples} samples")
    if num_data_samples == 0:
        # Fail before the expensive model / similarity-matrix setup instead of
        # hitting a ZeroDivisionError when the accuracies are computed.
        raise ValueError("No data samples to attack.")

    # construct the model
    print("Building Model...")
    device = torch.device(args.device)
    if args.target_model == "wordLSTM":
        model = Model(args.word_embeddings_path, nclasses=args.nclasses).to(device)
        checkpoint = torch.load(args.target_model_path, map_location=device)
        model.load_state_dict(checkpoint)
    elif args.target_model == "wordCNN":
        model = Model(
            args.word_embeddings_path, nclasses=args.nclasses, hidden_size=100, cnn=True
        ).to(device)
        checkpoint = torch.load(args.target_model_path, map_location=device)
        model.load_state_dict(checkpoint)
    elif args.target_model == "bert":
        model = NLI_infer_BERT(
            args.target_model_path,
            nclasses=args.nclasses,
            max_seq_length=args.max_seq_length,
            device=args.device,
        )
    else:
        # Without this branch an unknown name surfaces as a confusing
        # NameError on ``model`` a few lines below.
        raise ValueError(
            "Unsupported target model: {!r}".format(args.target_model)
        )
    predictor = model.text_pred
    print("Model built!")

    ModelSummary.summarize(
        model,
        model_name=args.target_model,
        logger=None,
        verbose=True,
        print_architecture=False,
    )

    # prepare synonym extractor
    # build dictionary via the embedding file
    idx2word = {}
    word2idx = {}

    print("Building vocab...")
    with open(args.counter_fitting_embeddings_path, "r") as ifile:
        for line in ifile:
            word = line.split()[0]
            # Every line gets its own index (no de-duplication) so that index
            # i always refers to line i of the embeddings file, which is also
            # row i of the cosine similarity matrix computed below.
            word_idx = len(idx2word)
            idx2word[word_idx] = word
            word2idx[word] = word_idx

    print("Building cos sim matrix...")
    cos_sim_path = args.counter_fitting_cos_sim_path
    if cos_sim_path and os.path.isfile(cos_sim_path):
        # load pre-computed cosine similarity matrix if provided
        print(
            "Load pre-computed cosine similarity matrix from {}".format(cos_sim_path)
        )
        cos_sim = np.load(cos_sim_path)
    else:
        # calculate the cosine similarity matrix
        print("Start computing the cosine similarity matrix!")
        with open(args.counter_fitting_embeddings_path, "r") as ifile:
            embeddings = np.array(
                [[float(num) for num in line.strip().split()[1:]] for line in ifile]
            )
        product = np.dot(embeddings, embeddings.T)
        norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
        cos_sim = product / np.dot(norm, norm.T)
        if cos_sim_path:
            # Write through an explicit file handle: np.save() appends ".npy"
            # to bare path names, which would make the isfile() check above
            # miss the cache on the next run.
            with open(cos_sim_path, "wb") as ofile:
                np.save(ofile, cos_sim)
            print("Cosine similarity matrix saved to {}".format(cos_sim_path))
        else:
            print(
                "No --counter_fitting_cos_sim_path given; "
                "cosine similarity matrix is not cached."
            )
    print("Cos sim import finished!")

    print(f"building semantic similarity module from {args.USE_cache_path}...")

    # build the semantic similarity module
    use = USE(args.USE_cache_path)
    print("Semantic similarity module built!")

    # start attacking
    orig_failures = 0
    adv_failures = 0
    changed_rates = []
    nums_queries = []
    orig_texts = []
    adv_texts = []
    true_labels = []
    new_labels = []

    stop_words_set = criteria.get_stopwords()

    # Ensure the POS-tagger and universal mapping are present
    for pkg in ("averaged_perceptron_tagger_eng", "universal_tagset"):
        try:
            nltk.data.find(f"taggers/{pkg}")
        except LookupError:
            nltk.download(pkg, quiet=True)

    # Keyword arguments shared by both attack strategies.
    attack_kwargs = dict(
        sim_predictor=use,
        device=args.device,
        sim_score_threshold=args.sim_score_threshold,
        import_score_threshold=args.import_score_threshold,
        sim_score_window=args.sim_score_window,
        synonym_num=args.synonym_num,
        batch_size=args.batch_size,
    )

    print("Start attacking!")
    # Wrapping ``data`` (not the enumerate object) lets tqdm see the total.
    for idx, (text, true_label) in enumerate(tqdm(data)):
        if idx % 20 == 0:
            print(
                "{} samples out of {} have been finished!".format(idx, num_data_samples)
            )

        # text is list of strings
        if args.perturb_ratio > 0.0:
            # random-perturbation variant used for the ablation study
            new_text, num_changed, orig_label, new_label, num_queries = random_attack(
                text,
                true_label,
                predictor,
                args.perturb_ratio,
                stop_words_set,
                word2idx,
                idx2word,
                cos_sim,
                **attack_kwargs,
            )
        else:
            new_text, num_changed, orig_label, new_label, num_queries = attack(
                text,
                true_label,
                predictor,
                stop_words_set,
                word2idx,
                idx2word,
                cos_sim,
                num_changed_budget=2,
                **attack_kwargs,
            )

        if true_label != orig_label:
            orig_failures += 1
        # Query counts are recorded for every sample, including the ones the
        # model already misclassified before the attack.
        nums_queries.append(num_queries)

        if true_label != new_label:
            adv_failures += 1

        # A successful attack: correct before, wrong after.
        if true_label == orig_label and true_label != new_label:
            # Computed only here so that samples which are never reported
            # cannot trigger a division by zero on an empty text.
            changed_rates.append(1.0 * num_changed / len(text))
            orig_texts.append(" ".join(text))
            adv_texts.append(new_text)
            true_labels.append(true_label)
            new_labels.append(new_label)

    print(f"num_queries is of length {len(nums_queries)}, num_queries = {nums_queries}")
    # With no successful attack the average is undefined; report NaN directly
    # rather than letting np.mean([]) emit a RuntimeWarning.
    avg_changed_rate = np.mean(changed_rates) if changed_rates else float("nan")
    message = (
        "For target model {}: original accuracy: {:.3f}%, adv accuracy: {:.3f}%, "
        "avg changed rate: {:.3f}%, num of queries: {:.1f}\n".format(
            args.target_model,
            (1 - orig_failures / num_data_samples) * 100,
            (1 - adv_failures / num_data_samples) * 100,
            avg_changed_rate * 100,
            np.mean(nums_queries),
        )
    )

    print(message)
    # Append mode keeps the summaries of earlier runs; the context manager
    # guarantees the handle is flushed and closed.
    with open(os.path.join(args.output_dir, "results_log"), "a") as log_file:
        log_file.write(message)

    with open(os.path.join(args.output_dir, "adversaries.txt"), "w") as ofile:
        for orig_text, adv_text, true_label, new_label in zip(
            orig_texts, adv_texts, true_labels, new_labels
        ):
            ofile.write(
                "orig sent ({}):\t{}\nadv sent ({}):\t{}\n\n".format(
                    true_label, orig_text, new_label, adv_text
                )
            )


# Run the attack pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
