import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
import torch
import random
import argparse
import torch.nn as nn
from diffusers import StableDiffusionPipeline
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.feature_extraction.text import CountVectorizer
# import clip
from PIL import Image
# import OpenAttack
import pandas as pd
import tqdm
import scipy.stats as stest
from anthro_lib import ANTHRO # download from naacl2019-like-humans-visual-attacks
# ANTHRO lookup loaded once at import time; get_similars() queries it for
# perturbed spellings of a word.
anthro = ANTHRO()
anthro.load('DMattacker/anthro/ANTHRO_Data_V1.0') # download from naacl2019-like-humans-visual-attacks
import ssl
from sklearn.metrics import pairwise_distances
from perturbations_store import PerturbationsStorage
from gensim.models import KeyedVectors as W2Vec                 
# Embedding model stored in word2vec format; get_ices() queries it with single characters.
w2vmodel = W2Vec.load_word2vec_format("RobTest-Demo/naacl2019-like-humans-visual-attacks/VIPER/vce.normalized")



# NOTE(review): this turns off HTTPS certificate verification for the whole process.
ssl._create_default_https_context = ssl._create_unverified_context

torch.cuda.empty_cache()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Text-to-image pipeline used to render every prompt evaluated in this file.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",  
    revision="fp16",  # to avoid half precision, delete this line and the next one
    torch_dtype=torch.float16
)  

# Sampling budget: each prompt is rendered as num_batch batches of batch_size images.
num_inference_steps = 50
num_batch = 5 #15
batch_size = 3
pipe = pipe.to(device)

# Short name -> scipy.stats test function. Each attack function looks its
# `test_method` argument up in this table, so an unknown name raises KeyError.
test_method_pool = {
    "t": stest.ttest_ind,
    "tr":stest.ttest_rel,
    "ks": stest.ks_2samp,
    "ch2":stest.chi2_contingency,
    "e":stest.epps_singleton_2samp,
    "brunner":stest.brunnermunzel,
    "mood":stest.mood,
    "f":stest.f_oneway,

}


from torchmetrics.multimodal import CLIPScore
import torchvision.transforms as transforms
# Image -> tensor conversion applied before CLIPScore; resizing, cropping and
# normalisation are intentionally left disabled (see the commented-out steps).
transform_ori = transforms.Compose([
        # transforms.Resize(224),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

import clip
# CLIP ViT-B/32 provides encode_text (prompt similarity) and encode_image
# (MMD features); `preprocess` is the image transform returned alongside it.
clipmodel, preprocess = clip.load("ViT-B/32")
# Move the model to the selected `device` rather than calling .cuda(), which
# raises on a machine without CUDA even though `device` falls back to "cpu".
clipmodel.to(device).eval()
# torchmetrics CLIPScore used by calculate_text_image_distance().
metric = CLIPScore().to(device)
def calculate_text_image_distance(text, image):
    """Score how well one image matches a prompt with the module-level CLIPScore metric.

    Args:
        text: Prompt the image is scored against.
        image: Image accepted by ``transform_ori``.

    Returns:
        The value produced by ``metric`` converted to a Python scalar.
    """
    # The ToTensor output is multiplied by 255 before it is handed to the metric.
    pixels = transform_ori(image) * 255
    pixels = pixels.to(device)
    clip_score = metric(pixels, text)
    return clip_score.detach().cpu().numpy().item()

def calculate_text_text_distance(text1, text2):
    """Return the CLIP text-embedding cosine similarity of two prompts, floored at 0.

    Despite the name, larger values mean the prompts are *more* similar.

    Args:
        text1: First prompt.
        text2: Second prompt.

    Returns:
        A Python number in ``[0, 1]`` (negative cosines are clamped to 0).
    """
    # truncate=True: perturbed prompts can tokenize to more entries than CLIP's
    # context window; without it clip.tokenize raises instead of scoring.
    tokens = clip.tokenize([text1, text2], truncate=True).to(device)
    with torch.no_grad():
        # Work in float32 so the norm/dot product is not computed in half precision.
        text_features = clipmodel.encode_text(tokens).float()
    # cosine_similarity replaces the manual `a @ b.T / (|a||b|)`; `.T` on a 1-D
    # tensor is deprecated in PyTorch.
    score = torch.nn.functional.cosine_similarity(text_features[0], text_features[1], dim=0)
    return max(score.item(), 0)
        

def cal_loss(prompt,ori_prompt):
    """Render ``prompt`` and score every image against ``ori_prompt``.

    Generates ``num_batch`` batches of ``batch_size`` images with seeds
    ``1023 + batch index`` and collects one CLIPScore per image.

    Args:
        prompt: Prompt given to the diffusion pipeline.
        ori_prompt: Reference prompt the images are scored against.

    Returns:
        A list with ``num_batch * batch_size`` scores.
    """
    print("prompt:",prompt)
    print("ori_prompt:",ori_prompt)
    scores = []
    for batch_idx in range(num_batch):
        rng = torch.Generator(device).manual_seed(1023 + batch_idx)
        result = pipe(
            [prompt] * batch_size,
            num_inference_steps=num_inference_steps,
            generator=rng,
        )
        scores.extend(
            calculate_text_image_distance(ori_prompt, result.images[k])
            for k in range(batch_size)
        )
    return scores
from PIL import Image, ImageOps


def gen_pic1(prompt):
    """Render ``prompt`` with the first seed family (``1023 + batch index``).

    Args:
        prompt: Prompt given to the diffusion pipeline.

    Returns:
        A list of ``num_batch * batch_size`` generated images, in generation order.
    """
    print("text:",prompt)
    rendered = []
    for batch_idx in range(num_batch):
        rng = torch.Generator(device).manual_seed(1023 + batch_idx)
        result = pipe(
            [prompt] * batch_size,
            num_inference_steps=num_inference_steps,
            generator=rng,
        )
        rendered.extend(result.images[k] for k in range(batch_size))
    return rendered


def gen_pic2(prompt):
    """Render ``prompt`` with the second seed family (``102 + batch index``).

    Same procedure as ``gen_pic1`` but with different seeds and no logging, so
    the two image sets for one prompt are independent samples.

    Args:
        prompt: Prompt given to the diffusion pipeline.

    Returns:
        A list of ``num_batch * batch_size`` generated images, in generation order.
    """
    rendered = []
    for batch_idx in range(num_batch):
        rng = torch.Generator(device).manual_seed(102 + batch_idx)
        result = pipe(
            [prompt] * batch_size,
            num_inference_steps=num_inference_steps,
            generator=rng,
        )
        rendered.extend(result.images[k] for k in range(batch_size))
    return rendered

import torch
import timm
# def mmd_distance(x1, x2, h):
#     """
#     Computes the MMD distance between two sets of images using a pre-trained image encoder h.

#     Args:
#         x1 (torch.Tensor): A tensor of shape (batch_size, num_channels, height, width) representing the first set of images.
#         x2 (torch.Tensor): A tensor of shape (batch_size, num_channels, height, width) representing the second set of images.
#         h (torch.nn.Module): A pre-trained image encoder that maps images to a feature space.

#     Returns:
#         torch.Tensor: A scalar tensor representing the MMD distance between the two sets of images.
#     """
    # Compute feature maps for the two sets of images
    # print(x1.size())
    # x1 = x1.to(device)
    # x2 = x2.to(device)
    # with torch.no_grad():
        # f1 = h(x1).reshape(x1.shape[0], -1)
        # f2 = h(x2).reshape(x2.shape[0], -1)
    #     f1 = h.encode_image(x1).float()
    #     f2 = h.encode_image(x2).float()

    # # Compute the kernel matrix using the cosine kernel
    # Kxx = torch.mm(f1, f1.t()) / torch.mm(torch.norm(f1, dim=1, keepdim=True), torch.norm(f1, dim=1, keepdim=True).t())
    # Kxy = torch.mm(f1, f2.t()) / torch.mm(torch.norm(f1, dim=1, keepdim=True), torch.norm(f2, dim=1, keepdim=True).t())
    # Kyy = torch.mm(f2, f2.t()) / torch.mm(torch.norm(f2, dim=1, keepdim=True), torch.norm(f2, dim=1, keepdim=True).t())

    # Compute the MMD distance
    # mmd = Kxx.mean() + Kyy.mean() - 2 * Kxy.mean()
    # return mmd.cpu().item()

# Load a pre-trained ViT model
# vit = timm.create_model('vit_base_patch16_224', pretrained=True).to(device)
def mmd_distance(source_images, target_images, model, kernel='rbf'):
    """Compute an RBF-kernel MMD estimate between two image batches.

    Both batches are embedded with ``model.encode_image``; the kernel bandwidth
    is the median Euclidean distance between source and target embeddings.

    Args:
        source_images: Batch accepted by ``model.encode_image``.
        target_images: Batch accepted by ``model.encode_image``.
        model: Object exposing ``encode_image`` that returns a torch tensor.
        kernel: Kernel name; only ``'rbf'`` is supported.

    Returns:
        ``mean(Kxx) + mean(Kyy) - 2 * mean(Kxy)`` as a NumPy scalar.

    Raises:
        ValueError: If ``kernel`` is not ``'rbf'``.
    """
    # Validate before running the (expensive) image encoder.
    if kernel != 'rbf':
        raise ValueError("Invalid kernel type. Only 'rbf' is supported.")

    # Inference only: skip autograd bookkeeping while embedding the images.
    with torch.no_grad():
        X = model.encode_image(source_images).detach().cpu().numpy()
        Y = model.encode_image(target_images).detach().cpu().numpy()

    # Radial basis function kernel with the median heuristic for the bandwidth.
    sigma = np.median(pairwise_distances(X, Y, metric='euclidean'))
    # A zero median (mostly identical embeddings) would make gamma infinite and
    # the kernels NaN; clamp to a tiny positive bandwidth instead.
    sigma = max(sigma, np.finfo(float).eps)
    gamma = 1 / (2 * sigma ** 2)
    XX = np.exp(-gamma * pairwise_distances(X, X, metric='sqeuclidean'))
    YY = np.exp(-gamma * pairwise_distances(Y, Y, metric='sqeuclidean'))
    XY = np.exp(-gamma * pairwise_distances(X, Y, metric='sqeuclidean'))

    mmd = np.mean(XX) + np.mean(YY) - 2 * np.mean(XY)
    return mmd


# Compute the MMD distance between two sets of images
# x1 = torch.randn(10, 3, 224, 224) 
# x2 = torch.randn(10, 3, 224, 224)
# mmd = 
# print(f"MMD distance: {mmd:.4f}")
# import torchvision.transforms as transforms
# transform = transforms.Compose([
#     # transforms.Resize(224),
#     # transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# ])

def cal_mmd_loss(x1,x2):
    """Return the MMD between two collections of images in CLIP image-feature space.

    Args:
        x1: Iterable of images accepted by ``preprocess``.
        x2: Iterable of images accepted by ``preprocess``.

    Returns:
        The value of ``mmd_distance`` for the two preprocessed batches.
    """
    # Stack the preprocessed tensors directly (no NumPy round-trip) and place
    # them on the module-level `device` instead of hard-coding CUDA.
    x1_batch = torch.stack([preprocess(img) for img in x1]).to(device)
    x2_batch = torch.stack([preprocess(img) for img in x2]).to(device)
    return mmd_distance(x1_batch, x2_batch, clipmodel)
from transformers import AutoTokenizer


def gen_image(ori_prompt,adv_prompt,cnt,attack_method,test_method,alpha,data_name,limit):
    root_file = "./myattacker_mmd_limit"+str(limit)+"/"
    if not os.path.exists(root_file):
        os.mkdir(root_file)
    main_file_path = root_file+str(attack_method)+"_"+str(test_method)+"_"+str(alpha)+"_"+str(data_name)
    sub_file_path = main_file_path+"/"+str(cnt)
    ori_file_path = sub_file_path+"/ori/"
    adv_file_path = sub_file_path+"/adv/"
    if not os.path.exists(main_file_path):
        os.mkdir(main_file_path)
    
    if not os.path.exists(sub_file_path):
        os.mkdir(sub_file_path)
        os.mkdir(ori_file_path)
        os.mkdir(adv_file_path)

    def gen(prompt,path):
        for i in range(num_batch):
            generator = torch.Generator(device).manual_seed(1023+i)
            images = pipe([prompt]*batch_size, num_inference_steps=num_inference_steps, generator=generator)
            for j in range(batch_size):
                # new_size = (images.images[j].size[0] // 2, images.images[j].size[1] // 2)
                # resized_image = ImageOps.fit(images.images[j], new_size, Image.LANCZOS)
                resized_image = images.images[j].resize((images.images[j].size[0] // 2,images.images[j].size[1] // 2), Image.LANCZOS) # Image.ANTIALIAS为高质量的缩放滤波器
                resized_image.save(path+str(i)+"_"+str(j)+".png")
    gen(ori_prompt,ori_file_path)
    gen(adv_prompt,adv_file_path)

# loss = cal_loss(prompt)
# ave_loss = np.average(loss)

# Rescale a value in [0, 1] onto a new [0, 1] range with k as the midpoint:
# after the x**0.1 transform, a value equal to k maps to 0.5.
def scale(x, k):
    """Map ``x`` through ``x ** 0.1`` and a two-piece linear ramp that pivots at ``k``.

    Args:
        x: Input value.
        k: Pivot; a transformed value equal to ``k`` maps to 0.5.

    Returns:
        ``0.5 * t / k`` when ``t <= k`` and ``0.5 + 0.5 * (t - k) / (1 - k)``
        otherwise, where ``t = x ** 0.1``.
    """
    powered = x ** 0.1
    if powered <= k:
        return 0.5 * powered / k
    return 0.5 + 0.5 * (powered - k) / (1 - k)

# class CustomTensorFlowModelWrapper(ModelWrapper):
#     def __init__(self,ori_loss,alpha=0.95,test_method="ttest",ori_prompt=""):
#         self.model = "value"
#         self.ori_loss = ori_loss
#         self.alpha = alpha
#         self.test_method = test_method_pool[test_method]
#         self.ori_prompt = ori_prompt
        # self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # def __call__(self, input_):
    #     ret = []
    #     for sent in input_:
    #         loss = cal_loss(sent,self.ori_prompt)
    #         print('loss:',loss)
    #         print('ori_loss:',self.ori_loss)
    #         p_val =self.test_method(loss,self.ori_loss).pvalue # t_stat, p_val = ttest_ind  [0][0]
    #         print(p_val**0.1)
    #         prob = scale(p_val,self.alpha)
    #         print("prob:",prob)
    #         ret.append([prob, 1-prob])
    #     return torch.tensor(ret)


# Character -> string of adjacent keyboard characters. insert_transform() and
# replace_transform() draw one character from the string with random.choice.
keyboard_neighbors = {
            "q": "was", "w": "qeasd", "e": "wrsdf", "r": "etdfg", "t": "ryfgh", "y": "tughj", "u": "yihjk",
            "i": "uojkl", "o": "ipkl", "p": "ol",
            "a": "qwszx", "s": "qweadzx", "d": "wersfxc", "f": "ertdgcv", "g": "rtyfhvb", "h": "tyugjbn",
            "j": "yuihknm", "k": "uiojlm", "l": "opk", ";":"op[l',./", '\'':"p[];./",
            "z": "asx", "x": "sdzc", "c": "dfxv", "v": "fgcb", "b": "ghvn", "n": "hjbm", "m": "jkn",
            ',': "mkl.", ".": ",l;/", 
        }
# Key view used for the membership tests in the transforms below.
keyboard_neighbors_keys = keyboard_neighbors.keys()

def delete_transform(word):
    """Delete one randomly chosen character from ``word``.

    Args:
        word: Token to perturb.

    Returns:
        The shortened word, or ``None`` when the word has fewer than two
        characters (an empty word previously raised ``ValueError`` from
        ``random.randint``).
    """
    if len(word) <= 1:
        return None
    i = random.randint(0, len(word)-1)
    return word[:i] + word[i+1:]


def insert_transform(word):
    """Insert a keyboard neighbour of a random character in front of that character.

    Args:
        word: Token to perturb.

    Returns:
        The lengthened word, or ``None`` when the word is empty or the chosen
        character (lower-cased) has no entry in ``keyboard_neighbors``.
    """
    # An empty word used to raise ValueError from randint(0, -1).
    if not word:
        return None
    i = random.randint(0, len(word)-1)
    neighbors = keyboard_neighbors.get(word[i].lower())
    if neighbors is None:
        return None
    return word[:i] + random.choice(neighbors) + word[i:]
    
def case_transform(word,max_num = 1):
    """Toggle the case of up to ``max_num`` randomly chosen characters.

    Args:
        word: Token to perturb.
        max_num: Upper bound for the number of positions to toggle; the actual
            count is drawn uniformly from ``1..max_num``.

    Returns:
        The word with the selected characters case-toggled. Never ``None``.
    """
    flips = random.randint(1, max_num)
    positions = list(range(len(word)))
    random.shuffle(positions)
    for pos in positions[:flips]:
        current = word[pos]
        toggled = current.upper() if current.islower() else current.lower()
        word = word[:pos] + toggled + word[pos + 1:]
    return word

def space_transform(word):
    """Insert a space in front of a randomly chosen character.

    Args:
        word: Token to perturb.

    Returns:
        The word with one extra space, or ``None`` when the word is empty or
        the chosen character is already a space.
    """
    # An empty word used to raise ValueError from randint(0, -1).
    if not word:
        return None
    i = random.randint(0, len(word)-1)
    if word[i] == ' ':
        return None
    return word[:i] + ' ' + word[i:]

def replace_transform(word):
    """Replace a random character with one of its keyboard neighbours.

    The replacement is upper-cased unless the original character is lower case.

    Args:
        word: Token to perturb.

    Returns:
        The perturbed word, or ``None`` when the word is empty or the chosen
        character (lower-cased) has no entry in ``keyboard_neighbors``.
    """
    # An empty word used to raise ValueError from randint(0, -1).
    if not word:
        return None
    i = random.randint(0, len(word)-1)
    original = word[i]
    neighbors = keyboard_neighbors.get(original.lower())
    if neighbors is None:
        return None
    substitute = random.choice(neighbors)
    # Carry the original character's case over to the replacement.
    if not original.islower():
        substitute = substitute.upper()
    return word[:i] + substitute + word[i+1:]


def swap_transform(word):
    """Swap a randomly chosen pair of adjacent characters.

    Args:
        word: Token to perturb.

    Returns:
        The word with two neighbouring characters exchanged, or ``None`` when
        the word has fewer than two characters or either character of the
        chosen pair is a space.
    """
    if len(word) < 2:
        return None
    # Pick the right-hand character of the pair. Starting at 1 avoids the old
    # i == 0 case, where word[i-1] / word[:i-1] wrapped around to the end of
    # the string and produced a longer, garbled word.
    i = random.randint(1, len(word)-1)
    left, right = word[i-1], word[i]
    if left == ' ' or right == ' ':
        return None
    return word[:i-1] + right + left + word[i+1:]


def repeat_transform(word):
    """Duplicate one randomly chosen character of ``word``.

    Args:
        word: Token to perturb.

    Returns:
        The word with one character doubled, or ``None`` for an empty word
        (which previously raised ``ValueError`` from ``random.randint``).
    """
    if not word:
        return None
    i = random.randint(0, len(word)-1)
    return word[:i] + word[i] + word[i:]

# Name -> word-level typo transform. typo_attack_pipe() picks a name at random
# and retries until the chosen transform returns something other than None.
TRANSFORMATION = {
            'delete': delete_transform, 
            'insert': insert_transform, 
            'replace': replace_transform,
            'swap': swap_transform, 
            'repeat': repeat_transform,
            'case': case_transform,
            'space':space_transform,
        }

# Character -> look-alike replacement(s) used by get_gsimilar(). A value is
# either a list of alternatives or a plain string; random.choice() is applied
# to it in both cases.
similar_chars = {
    'i':['і', '1'], 'l':['ⅼ', '1'], 'z':['ᴢ', '2'],   "s":['5', 'ѕ'], "g":['ɡ','9'], 'b':['Ь','6'], 'q':['ԛ', '9'], 'o':['0','о' ], '-': '˗', '9': '৭', '8': 'Ȣ', '7': '𝟕', '6': 'б', '5': 'Ƽ', '4': 'Ꮞ', '3': 'Ʒ', '2': 'ᒿ', '1': 'l', '0': 'O',
         "'": '`', 'a': 'ɑ',  'c': 'ϲ', 'd': 'ԁ', 'e': 'е', 'f': '𝚏',  'h': 'հ', 'j': 'ϳ',
         'k': '𝒌',  'm': 'ｍ', 'n': 'ո', 'p': 'р',  'r': 'ⲅ',  't': '𝚝', 'u': 'ս',
         'v': 'ѵ', 'w': 'ԝ', 'x': '×', 'y': 'у'
}

def get_similars(word, num, level=1, distance=1, strict=True):
    """Return up to ``num`` ANTHRO perturbations of ``word``.

    Args:
        word: Token to perturb.
        num: Maximum number of candidates to return.
        level: ``0`` short-circuits and returns ``[word]``; any other value
            queries ANTHRO (always with ``level=1``).
        distance: Ignored; the distance is redrawn at random from
            ``1..min(len(word), 3)`` on every call.
        strict: Ignored; ANTHRO is always queried with ``strict=False``.

    Returns:
        ``[word]`` when ``level == 0``; otherwise a list of candidates different
        from ``word`` (a random sample of size ``num`` when more are
        available), or ``None`` when the word is empty or has no candidates.
    """
    if level == 0:
        return [word]
    # An empty word would make randint(1, 0) raise ValueError.
    if not word:
        return None
    distance = random.randint(1,min(len(word),3))
    candidates = [
        c for c in anthro.get_similars(word, level=1, distance=distance, strict=False)
        if c != word
    ]
    if not candidates:
        return None
    if len(candidates) > num:
        return random.sample(candidates, num)
    return candidates

def get_gsimilar(word):
    """Replace one random character with a look-alike from ``similar_chars``.

    Args:
        word: Token to perturb.

    Returns:
        The perturbed word, or ``None`` when the word has fewer than two
        characters or the chosen character (lower-cased) has no look-alike.
    """
    # len < 2 also covers the empty word, which used to raise ValueError.
    if len(word) < 2:
        return None
    i = random.randint(0, len(word)-1)
    options = similar_chars.get(word[i].lower())
    if options is None:
        return None
    return word[:i] + random.choice(options) + word[i+1:]

# Table of Unicode code points and descriptions read from NamesList.txt, later
# extended with one bag-of-words count column per vocabulary entry.
_names_list_kwargs = dict(skiprows=np.arange(16), header=None, names=['code', 'description'])
try:
    # pandas >= 1.3 spelling; `error_bad_lines` was removed in pandas 2.0.
    descs = pd.read_csv('RobTest-Demo/naacl2019-like-humans-visual-attacks/VIPER/NamesList.txt', on_bad_lines='skip', **_names_list_kwargs)
except TypeError:
    # Older pandas does not know `on_bad_lines`; fall back to the legacy flag.
    descs = pd.read_csv('RobTest-Demo/naacl2019-like-humans-visual-attacks/VIPER/NamesList.txt', error_bad_lines=False, **_names_list_kwargs)
# Keyword form: positional `dropna(0)` is rejected by pandas >= 2.0.
descs = descs.dropna(axis=0)
descs_arr = descs.values # remove the rows after the descriptions
vectorizer = CountVectorizer(max_features=1000)
# NOTE(review): the vectorizer is fitted on column 0 of descs_arr -- confirm
# that this column holds the text that should be vectorised.
desc_vecs = vectorizer.fit_transform(descs_arr[:, 0]).astype(float)
vecsize = desc_vecs.shape[1]
vec_colnames = np.arange(vecsize)
desc_vecs = pd.DataFrame(desc_vecs.todense(), index=descs.index, columns=vec_colnames)
descs = pd.concat([descs, desc_vecs], axis=1)

def char_to_hex_string(ch):
    """Return the code point of ``ch`` as upper-case hex, zero-padded to at least 4 digits."""
    code_point = ord(ch)
    return format(code_point, '04X')

# function for retrieving the variations of a character
def get_all_variations(ch):
    """Find rows of ``descs`` whose description shares a one-letter token with ``ch``.

    Args:
        ch: A single character.

    Returns:
        ``(c, matches)`` where ``c`` is the upper-case hex code of ``ch`` and
        ``matches`` is a NumPy array of ``descs`` index labels. ``matches`` is
        empty when ``ch`` has no row in ``descs``.

    NOTE(review): ``disallowed`` and ``disallowed_codes`` are read in the
    matching loop but are not defined in the visible part of this file; unless
    they are defined somewhere not shown here, that loop raises NameError.
    """

    # get unicode number for c
    c = char_to_hex_string(ch)

    # problem: latin small characters seem to be missing?
    if np.any(descs['code'] == c):
        description = descs['description'][descs['code'] == c].values[0]
    else:
        print('Failed to disturb %s, with code %s' % (ch, c))
        return c, np.array([])

    # strip away everything that is generic wording, e.g. all words with > 1 character in
    toks = description.split(' ')

    case = 'unknown'

    # One-character description tokens are kept as identifiers; SMALL / CAPITAL
    # tokens record the case of the queried character.
    identifiers = []
    for tok in toks:

        if len(tok) == 1:
            identifiers.append(tok)

            # for debugging 
            if len(identifiers) > 1:
                print('Found multiple ids: ')
                print(identifiers)

        elif tok == 'SMALL':
            case = 'SMALL'
        elif tok == 'CAPITAL':
            case = 'CAPITAL'

    # for debugging
    #if case == 'unknown':
    #    sys.stderr.write('Unknown case:')
    #    sys.stderr.write("{}\n".format(toks))

    # find matching chars
    matches = []

    # Full scan of descs for every identifier.
    for i in identifiers:        
        for idx in descs.index:
            desc_toks = descs['description'][idx].split(' ')
            # Keep rows that contain the identifier, contain no disallowed
            # token or code, and whose code value is not above 30000 (decimal).
            if i in desc_toks and not np.any(np.in1d(desc_toks, disallowed)) and \
                    not np.any(np.in1d(descs['code'][idx], disallowed_codes)) and \
                    not int(descs['code'][idx], 16) > 30000:

                # get the first case descriptor in the description
                desc_toks = np.array(desc_toks)
                case_descriptor = desc_toks[ (desc_toks == 'SMALL') | (desc_toks == 'CAPITAL') ]

                # With exactly one descriptor, case_descriptor stays a
                # one-element array and the comparison below is element-wise.
                if len(case_descriptor) > 1:
                    case_descriptor = case_descriptor[0]
                elif len(case_descriptor) == 0:
                    # NOTE(review): this assignment is not undone, so `case`
                    # remains 'unknown' for every row examined afterwards.
                    case = 'unknown'

                if case == 'unknown' or case == case_descriptor:
                    matches.append(idx)

    # check the capitalisation of the chars
    return c, np.array(matches)

def get_unicode_desc_nn(c, perturbations_file, topn=1):
    """Return replacement characters for ``c`` ranked by description-vector distance.

    Candidates come from ``get_all_variations``; the nearest ``topn`` of them
    (Euclidean distance between description count vectors) are turned into
    characters with heuristic probabilities ``exp(-0.5 * distance)``.

    Args:
        c: Character to perturb.
        perturbations_file: Object with ``observed(original, replacement)``;
            replacements for which it returns true are dropped.
        topn: Maximum number of neighbours to consider.

    Returns:
        ``(chars, probs)``: the surviving replacement characters and a
        probability for each of them, in matching order. ``([], [])`` when
        ``c`` has no variations.
    """
    # we need to consider only variations of the same letter -- get those first, then apply NN
    c, matches = get_all_variations(c)

    if not len(matches):
        return [], [] # cannot disturb this one

    # get their description vectors
    match_vecs = descs[vec_colnames].loc[matches]

    # find nearest neighbours
    neigh = NearestNeighbors(metric='euclidean')
    Y = match_vecs.values
    neigh.fit(Y)

    X = descs[vec_colnames].values[descs['code'] == c]

    # Never ask for more neighbours than there are candidates.
    dists, idxs = neigh.kneighbors(X, min(topn, Y.shape[0]), return_distance=True)

    # turn distances to some heuristic probabilities
    probs = np.exp(-0.5 * dists.flatten())
    probs = probs / np.sum(probs)

    # turn idxs back to chars
    charcodes = descs['code'][matches[idxs.flatten()]]
    chars = [chr(int(charcode, 16)) for charcode in charcodes]

    # filter chars to ensure OOV scenario (if perturbations file from prev. perturbation contains any data...)
    c_orig = chr(int(c, 16))
    keep = [not perturbations_file.observed(c_orig, char) for char in chars]

    # Filter the probabilities with the same mask so that probs[i] still
    # belongs to chars[i]; previously only `chars` was filtered.
    chars = [char for char, kept in zip(chars, keep) if kept]
    probs = probs[np.array(keep, dtype=bool)]
    if len(chars):
        probs = probs / np.sum(probs)

    return chars, probs

# The name is first bound to the storage path and then rebound to the
# PerturbationsStorage built from it; get_dces() passes that object on.
perturbations_file='./perturbations.txt'
perturbations_file = PerturbationsStorage(perturbations_file)
# Number of neighbours requested by get_dces() and get_ices().
topn = 20
def get_dces(word):
    """Replace one random character with a neighbour found via Unicode descriptions.

    Args:
        word: Token to perturb.

    Returns:
        The perturbed word, or ``None`` when the word is empty or no
        replacement character is available.
    """
    # An empty word used to raise ValueError from randint(0, -1).
    if not word:
        return None
    i = random.randint(0, len(word)-1)
    c = word[i]
    # Local name chosen so it does not shadow the module-level `similar_chars`.
    candidates, probs = get_unicode_desc_nn(c, perturbations_file, topn=topn)
    # Bail out before normalising: there is nothing to sample from.
    if len(candidates) == 0:
        return None
    weights = np.asarray(probs, dtype=float)[:len(candidates)]
    total = weights.sum()
    if len(weights) == len(candidates) and total > 0:
        s = np.random.choice(candidates, 1, replace=True, p=weights / total)[0]
    else:
        # Unusable weights (wrong length or zero mass): sample uniformly.
        s = random.choice(candidates)
    return word[:i] + s + word[i+1:]


def get_ices(word):
    """Replace one random character with a neighbour from the embedding model.

    Neighbours are sampled with probability proportional to their similarity.

    Args:
        word: Token to perturb.

    Returns:
        The perturbed word, or ``None`` when the word is empty, the chosen
        character is not in the embedding vocabulary, or no neighbour has
        positive similarity.
    """
    # An empty word used to raise ValueError from randint(0, -1).
    if not word:
        return None
    i = random.randint(0, len(word)-1)
    c = word[i]
    try:
        similar = w2vmodel.most_similar(c, topn=topn)
    except KeyError:
        # Character missing from the vocabulary: report "not applicable"
        # so the caller can try another transformation.
        return None
    # Check for an empty result before sampling (sampling from nothing raises).
    if not similar:
        return None
    words = [x[0] for x in similar]
    # Negative similarities are not valid sampling weights; clamp them to 0.
    probs = np.clip(np.array([x[1] for x in similar], dtype=float), 0, None)
    total = probs.sum()
    if not total > 0:
        return None
    s = np.random.choice(words, 1, replace=True, p=probs / total)[0]
    return word[:i] + s + word[i+1:]

def readD(fn):
  """Read a whitespace-separated two-column file into a dict.

  Args:
      fn: Path of the file; the first token of a line is the key and the
          second token is the value.

  Returns:
      Dict mapping first token to second token. Lines with fewer than two
      tokens (including blank lines) are skipped; later lines override
      earlier ones with the same key.
  """
  h = {}
  # Context manager closes the handle; UTF-8 is set explicitly because the
  # keys/values are arbitrary Unicode characters.
  with open(fn, encoding="utf-8") as f:
    for line in f:
      x = line.split()
      if len(x) < 2:
        continue
      h[x[0]] = x[1]
  return h
# Character -> neighbour string table loaded at import time; used by get_eces().
h = readD("RobTest-Demo/naacl2019-like-humans-visual-attacks/VIPER/selected.neighbors")

def get_eces(word):
    """Replace one random character with an entry from the neighbour table ``h``.

    Args:
        word: Token to perturb.

    Returns:
        The word with the chosen character replaced by a random character of
        its neighbour string. A character missing from ``h`` falls back to
        itself, so the word is then returned unchanged. ``None`` only when the
        neighbour string is empty.
    """
    pos = random.randint(0, len(word)-1)
    current = word[pos]
    # Fallback default is the character itself.
    pool = h.get(current, current)
    if len(pool) == 0:
        return None
    return word[:pos] + random.choice(pool) + word[pos+1:]

# Name -> glyph-level transform. glyph_attack_pipe() picks a name at random and
# retries until the chosen transform returns something other than None.
GTRANSFORMATION = {
            'gsimilar': get_gsimilar, 
            'dces': get_dces,
            'eces': get_eces,
            'ices':get_ices
        }

def glyph_attack_pipe(sent,ori_loss,ori_pic,test_method='t',alpha=0.15,limit=1):
    """Greedy glyph-level attack on a prompt, guided by MMD to reference images.

    Phase 1 masks each word with ``<|endoftext|>`` and ranks words by
    ``prob = 3*ori_loss / (3*ori_loss + mmd)`` (smallest first). Phase 2 visits
    the words in that order, tries five random ``GTRANSFORMATION`` edits per
    word and keeps the candidate with the smallest ``prob`` (the previous best
    sentence competes too). At the step counts ``int(n_words * r * 0.1)``,
    ``r = 0..10``, the current best prompt is rendered and scored.

    Args:
        sent: Original prompt.
        ori_loss: Baseline MMD of the original prompt.
        ori_pic: Reference images the adversarial renders are compared with.
        test_method: Key of ``test_method_pool``; only validated here.
        alpha: Unused.
        limit: Unused.

    Returns:
        List of ``[adversarial prompt, mean CLIPScore vs. original prompt,
        text similarity vs. original prompt]`` checkpoints.

    Raises:
        KeyError: If ``test_method`` is not in ``test_method_pool``.
    """
    # Kept for its validation effect; the looked-up function itself is not used.
    test_method = test_method_pool[test_method]
    ori_prompt = sent
    checkpoints = []
    words = sent.split(" ")
    n_words = len(words)

    # Phase 1: importance of each word, measured by masking it out.
    mask_probs = {}
    for pos in range(n_words):
        masked = words[:pos] + ["<|endoftext|>"] + words[pos + 1:]
        masked_pics = gen_pic1(' '.join(masked))
        adv_loss = cal_mmd_loss(masked_pics, ori_pic)
        print('loss:',adv_loss)
        print('ori_loss:',ori_loss)
        prob = (3*ori_loss)/(3*ori_loss+adv_loss)
        print('prob:',prob)
        mask_probs[pos] = prob

    # Word positions ordered by ascending prob.
    attack_order = [pos for pos, _ in sorted(mask_probs.items(), key=lambda kv: kv[1])]

    # Step counts at which the current best prompt is evaluated.
    report_steps = [int(n_words * r * 0.1) for r in range(11)]

    # Phase 2: greedy search, one word per step.
    best_prob = -1000  # sentinel: no candidate has been scored yet
    best_tokens = words.copy()
    for step, target in enumerate(attack_order, start=1):
        candidates = []
        candidate_probs = []
        if best_prob > -900:
            # The previous best sentence competes with the new candidates.
            candidates.append(best_tokens)
            candidate_probs.append(best_prob)

        for _ in range(5):
            trial = best_tokens.copy()
            perturbed = None
            # Retry until some transformation applies to the target word.
            while perturbed is None:
                chosen = random.choice(list(GTRANSFORMATION.keys()))
                perturbed = GTRANSFORMATION[chosen](trial[target])
            trial[target] = perturbed
            candidates.append(trial)

            trial_pics = gen_pic1(' '.join(trial))
            adv_loss = cal_mmd_loss(trial_pics, ori_pic)
            print('loss:',adv_loss)
            print('ori_loss:',ori_loss)
            prob = (3*ori_loss)/(3*ori_loss+adv_loss)
            print('prob:',prob)
            candidate_probs.append(prob)

        # Keep the candidate with the smallest prob.
        winner = candidate_probs.index(min(candidate_probs))
        best_prob = candidate_probs[winner]
        best_tokens = candidates[winner]

        if step in report_steps:
            adv_text = ' '.join(best_tokens)
            adv_pics = gen_pic1(adv_text)
            clip_scores = [calculate_text_image_distance(ori_prompt, pic) for pic in adv_pics]
            mean_score = np.mean(clip_scores)
            checkpoints.append([adv_text, mean_score, calculate_text_text_distance(ori_prompt, adv_text)])

    return checkpoints



def typo_attack_pipe(sent,ori_loss,ori_pic,test_method='t',alpha=0.15,limit=1):
    """Greedy typo-level attack on a prompt, guided by MMD to reference images.

    Phase 1 masks each word with ``<|endoftext|>`` and ranks words by
    ``prob = 3*ori_loss / (3*ori_loss + mmd)`` (smallest first). Phase 2 visits
    the words in that order, tries five random ``TRANSFORMATION`` edits per
    word and keeps the candidate with the smallest ``prob`` (the previous best
    sentence competes too). At the step counts ``int(n_words * r * 0.1)``,
    ``r = 0..10``, the current best prompt is rendered and scored.

    Args:
        sent: Original prompt.
        ori_loss: Baseline MMD of the original prompt.
        ori_pic: Reference images the adversarial renders are compared with.
        test_method: Key of ``test_method_pool``; only validated here.
        alpha: Unused.
        limit: Unused.

    Returns:
        List of ``[adversarial prompt, mean CLIPScore vs. original prompt,
        text similarity vs. original prompt]`` checkpoints.

    Raises:
        KeyError: If ``test_method`` is not in ``test_method_pool``.
    """
    # Kept for its validation effect; the looked-up function itself is not used.
    test_method = test_method_pool[test_method]
    ori_prompt = sent
    checkpoints = []
    words = sent.split(" ")
    n_words = len(words)

    # Phase 1: importance of each word, measured by masking it out.
    mask_probs = {}
    for pos in range(n_words):
        masked = words[:pos] + ["<|endoftext|>"] + words[pos + 1:]
        masked_pics = gen_pic1(' '.join(masked))
        adv_loss = cal_mmd_loss(masked_pics, ori_pic)
        print('loss:',adv_loss)
        print('ori_loss:',ori_loss)
        prob = (3*ori_loss)/(3*ori_loss+adv_loss)
        print('prob:',prob)
        mask_probs[pos] = prob

    # Word positions ordered by ascending prob.
    attack_order = [pos for pos, _ in sorted(mask_probs.items(), key=lambda kv: kv[1])]

    # Step counts at which the current best prompt is evaluated.
    report_steps = [int(n_words * r * 0.1) for r in range(11)]

    # Phase 2: greedy search, one word per step.
    best_prob = -1000  # sentinel: no candidate has been scored yet
    best_tokens = words.copy()
    for step, target in enumerate(attack_order, start=1):
        candidates = []
        candidate_probs = []
        if best_prob > -900:
            # The previous best sentence competes with the new candidates.
            candidates.append(best_tokens)
            candidate_probs.append(best_prob)

        for _ in range(5):
            trial = best_tokens.copy()
            perturbed = None
            # Retry until some transformation applies to the target word.
            while perturbed is None:
                chosen = random.choice(list(TRANSFORMATION.keys()))
                perturbed = TRANSFORMATION[chosen](trial[target])
            trial[target] = perturbed
            candidates.append(trial)

            trial_pics = gen_pic1(' '.join(trial))
            adv_loss = cal_mmd_loss(trial_pics, ori_pic)
            print('loss:',adv_loss)
            print('ori_loss:',ori_loss)
            prob = (3*ori_loss)/(3*ori_loss+adv_loss)
            print('prob:',prob)
            candidate_probs.append(prob)

        # Keep the candidate with the smallest prob.
        winner = candidate_probs.index(min(candidate_probs))
        best_tokens = candidates[winner]
        best_prob = candidate_probs[winner]

        if step in report_steps:
            adv_text = ' '.join(best_tokens)
            adv_pics = gen_pic1(adv_text)
            clip_scores = [calculate_text_image_distance(ori_prompt, pic) for pic in adv_pics]
            mean_score = np.mean(clip_scores)
            checkpoints.append([adv_text, mean_score, calculate_text_text_distance(ori_prompt, adv_text)])

    return checkpoints



def phonetic_attack_pipe(sent,ori_loss,ori_pic,test_method='t',alpha=0.15,limit=1):
    """Greedy phonetic (ANTHRO) attack on a prompt, guided by MMD to reference images.

    Phase 1 masks each word with ``<|endoftext|>`` and ranks words by
    ``prob = 3*ori_loss / (3*ori_loss + mmd)`` (smallest first). Phase 2 visits
    the words in that order, asks ``get_similars`` for up to five replacements
    of the word and keeps the candidate with the smallest ``prob`` (the
    previous best sentence competes too). Words without replacements are
    skipped. At the step counts ``int(n_words * r * 0.1)``, ``r = 0..10``, the
    current best prompt is rendered and scored.

    Args:
        sent: Original prompt.
        ori_loss: Baseline MMD of the original prompt.
        ori_pic: Reference images the adversarial renders are compared with.
        test_method: Key of ``test_method_pool``; only validated here.
        alpha: Unused.
        limit: Unused.

    Returns:
        List of ``[adversarial prompt, mean CLIPScore vs. original prompt,
        text similarity vs. original prompt]`` checkpoints.

    Raises:
        KeyError: If ``test_method`` is not in ``test_method_pool``.
    """
    # Kept for its validation effect; the looked-up function itself is not used.
    test_method = test_method_pool[test_method]
    ori_prompt = sent
    t_list = []
    sentence_tokens = sent.split(" ")

    # Phase 1: importance of each word, measured by masking it out.
    prob_list = {}
    for i in range(len(sentence_tokens)):
        sentence_tokens_without =  sentence_tokens[:i] +["<|endoftext|>"]+ sentence_tokens[i + 1:]
        adv_pic = gen_pic1(' '.join(sentence_tokens_without))
        adv_loss = cal_mmd_loss(adv_pic,ori_pic)
        print('loss:',adv_loss)
        print('ori_loss:',ori_loss)
        prob = (3*ori_loss)/(3*ori_loss+adv_loss)
        print('prob:',prob)
        prob_list[i] = prob

    # Word positions ordered by ascending prob.
    attack_order = [pos for pos, _ in sorted(prob_list.items(), key=lambda x: x[1])]

    # Step counts at which the current best prompt is evaluated.
    w = [int(len(sentence_tokens)*(r)*0.1) for r in range(11)]

    # Phase 2: greedy search, one word per step.
    last_prob = -1000  # sentinel: no candidate has been scored yet
    adv_sentence_tokens = sentence_tokens.copy()
    for count, idx in enumerate(attack_order, start=1):
        # get_similars draws a random edit distance, so retrying can succeed.
        # Up to six attempts; a hit on any attempt (including the last) is used.
        typos = None
        for _ in range(6):
            typos = get_similars(adv_sentence_tokens[idx],5)
            if typos is not None:
                break
        if not typos:
            continue

        sent_rec = []
        prob_rec = []
        if last_prob> -900:
            # The previous best sentence competes with the new candidates.
            sent_rec.append(adv_sentence_tokens)
            prob_rec.append(last_prob)

        for typo in typos:
            tmp_adv_sentence_tokens = adv_sentence_tokens.copy()
            tmp_adv_sentence_tokens[idx] = typo
            sent_rec.append(tmp_adv_sentence_tokens)
            adv_pic = gen_pic1(' '.join(tmp_adv_sentence_tokens))
            adv_loss = cal_mmd_loss(adv_pic,ori_pic)
            # Was `print('loss:',loss)`: `loss` is undefined here and raised
            # NameError on the first candidate.
            print('loss:',adv_loss)
            print('ori_loss:',ori_loss)
            prob = (3*ori_loss)/(3*ori_loss+adv_loss)
            print('prob:',prob)
            prob_rec.append(prob)

        # Keep the candidate with the smallest prob.
        min_prob_idx = prob_rec.index(min(prob_rec))
        adv_sentence_tokens = sent_rec[min_prob_idx]
        last_prob = prob_rec[min_prob_idx]

        if count in w:
            adv_text = ' '.join(adv_sentence_tokens)
            t_pic_adv = gen_pic1(adv_text)
            t_score_loss = [calculate_text_image_distance(ori_prompt,pic) for pic in t_pic_adv]
            t_score = np.mean(t_score_loss)
            t_list.append([adv_text,t_score,calculate_text_text_distance(ori_prompt,adv_text)])

    return t_list

# Attack name (the --attacker CLI value) -> attack function used by main_pipe().
attacker_pool = {'typo':typo_attack_pipe,'phonetic':phonetic_attack_pipe,'glyph':glyph_attack_pipe}

def main_pipe(test_method,alpha,data_name,attacker='typo',limit=1):
    """Attack every prompt in ``<data_name>.txt`` and write the scores to a CSV.

    For each prompt the original is rendered twice (two seed families) to get a
    baseline MMD, scored with CLIPScore, and then attacked. The CSV is
    rewritten after every prompt so partial results survive an interrupted run.

    Args:
        test_method: Key of ``test_method_pool`` forwarded to the attack.
        alpha: Forwarded to the attack and embedded in the output file name.
        data_name: Base name of the prompt file (one prompt per line).
        attacker: Key of ``attacker_pool``.
        limit: Forwarded to the attack and embedded in the output file name.

    Raises:
        KeyError: If ``attacker`` is not in ``attacker_pool``.
    """
    attack_pipe = attacker_pool[attacker]
    with open(data_name+".txt", "r") as file:
        stripped = [line.strip() for line in file]
    # Skip blank lines: an empty prompt splits into one empty token, which the
    # word-level transforms cannot perturb.
    data = [prompt for prompt in stripped if prompt]

    out_path = "ghc_mmd01_"+data_name+"_"+attacker+"_"+test_method+"_"+str(alpha)+"_"+str(limit)+".csv"
    main_t_list = []

    for sent in data:
        ori_prompt = sent
        pic_1 = gen_pic1(ori_prompt)
        pic_2 = gen_pic2(ori_prompt)
        # Baseline: MMD between two independent renders of the same prompt.
        ori_loss = cal_mmd_loss(pic_1,pic_2)
        score_loss = [calculate_text_image_distance(ori_prompt,pic) for pic in pic_1]
        ori_mean = np.mean(score_loss)
        main_t_list.append([ori_prompt,ori_mean,calculate_text_text_distance(ori_prompt,ori_prompt)])
        print('ori clip score:',ori_mean)
        t_list_extend = attack_pipe(sent,ori_loss,ori_pic = pic_2,test_method=test_method,alpha=alpha,limit=limit)
        main_t_list.extend(t_list_extend)
        # Column names (including the 'text2iamge' spelling) are kept as-is so
        # existing consumers of the CSV keep working.
        df = pd.DataFrame(main_t_list,columns=['sent','text2iamge','text2text'])

        df.to_csv(out_path,index=False)


if __name__ == '__main__':
    # Command-line entry point: parse the run configuration, echo it, and
    # launch the attack pipeline.
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, default='cartoon2')
    cli.add_argument('--test_method', type=str, default='t')
    cli.add_argument('--alpha', type=float, default=3)    # 0.4->3
    cli.add_argument('--attacker', type=str, default='phonetic')
    cli.add_argument('--limit', type=int, default=0)

    opts = cli.parse_args()
    print(opts.test_method,opts.alpha,opts.dataset,opts.attacker,opts.limit)

    main_pipe(
        opts.test_method,
        opts.alpha,
        opts.dataset,
        opts.attacker,
        opts.limit,
    )




        

    