import random
import nltk
from nltk.corpus import wordnet as wn
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from pathlib import Path
import json
from tqdm import tqdm

# Second name for the WordNet corpus reader; ``wn`` imported above refers to
# the same ``nltk.corpus.wordnet`` object.
from nltk.corpus import wordnet
# The result of this lookup is discarded. NOTE(review): presumably a warm-up
# so the WordNet corpus is loaded (or a missing corpus fails) at import time
# rather than in the middle of processing -- confirm intent.
wordnet.synsets("dog")

# English stop words from NLTK's stopwords corpus, stored as a set for fast
# membership tests; synonym_replace() skips any token whose lowercase form is
# in this set.
stop_words = set(stopwords.words('english'))

def get_wordnet_synonyms(word):
    """Return the WordNet synonyms of ``word`` as a sorted list.

    Every lemma of every synset returned by ``wn.synsets(word)`` is
    collected, except lemmas whose name equals ``word`` case-insensitively.
    Underscores in multi-word lemma names are replaced by spaces.

    Args:
        word: The word to look up.

    Returns:
        A list of unique synonym strings in ascending lexicographic order;
        empty when WordNet has no other lemma for ``word``.
    """
    target = word.lower()
    synonyms = {
        lemma.name().replace('_', ' ')
        for syn in wn.synsets(word)
        for lemma in syn.lemmas()
        if lemma.name().lower() != target
    }
    # Sort instead of returning list(set): the iteration order of a set of
    # strings depends on per-process hash randomization, so a seeded
    # random.choice() over the result would otherwise differ between runs.
    return sorted(synonyms)

def synonym_replace(sentence, num_replacements=20, seed=42):
    """Replace up to ``num_replacements`` words of ``sentence`` with synonyms.

    The sentence is tokenized with ``word_tokenize``. Tokens that are purely
    alphabetic and not in ``stop_words`` are candidates; they are visited in
    a seeded random order and each one that has at least one synonym
    (per ``get_wordnet_synonyms``) is replaced by a randomly chosen synonym.

    Args:
        sentence: Text to perturb.
        num_replacements: Maximum number of tokens to replace. Values <= 0
            leave every token unchanged.
        seed: Seed for the private random generator that drives both the
            candidate order and the synonym choice.

    Returns:
        The tokens (replaced or original) joined with single spaces. Note
        that this is the tokenized form, so punctuation ends up separated
        from words by spaces.
    """
    # Private generator: yields the same sequence as random.seed(seed) would,
    # but does not reset the global RNG state that callers may rely on.
    rng = random.Random(seed)

    tokens = word_tokenize(sentence)
    new_tokens = tokens.copy()

    replace_candidates = [
        (i, word) for i, word in enumerate(tokens)
        if word.isalpha() and word.lower() not in stop_words
    ]
    rng.shuffle(replace_candidates)

    replaced = 0
    for i, word in replace_candidates:
        # Check the budget *before* replacing so num_replacements <= 0
        # performs no replacement at all.
        if replaced >= num_replacements:
            break
        synonyms = get_wordnet_synonyms(word)
        if not synonyms:
            continue
        new_tokens[i] = rng.choice(synonyms)
        replaced += 1

    return ' '.join(new_tokens)

def textfooler(data_path, if_train=False, num_samples=128, save_dir=None):
    """Sample records from a JSON file, perturb them, and save the result.

    The file at ``data_path`` is loaded with ``json.load`` (it is expected to
    decode to a list of dicts), a fixed-seed random subset is drawn, and one
    text field per record is rewritten with ``synonym_replace``:

    * ``if_train`` false: the response named by ``item['good_response']``
      (``'A'`` -> ``response_A``, ``'B'`` -> ``response_B``) is perturbed;
      records with any other ``good_response`` value are left unchanged.
    * ``if_train`` true: ``item['chosen']`` is perturbed.

    The sampled records are written to
    ``<save_dir>/<stem of data_path>_textfooler.json``.

    Args:
        data_path: Path of the input JSON file.
        if_train: Selects the record schema as described above.
        num_samples: Number of records to sample. Clamped to the number of
            records available so small files do not raise ``ValueError``.
        save_dir: Output directory. ``None`` keeps the original default
            location. The directory is created if it does not exist.

    Returns:
        None. The result is written to disk.
    """
    if save_dir is None:
        save_dir = '/mnt/workspace/adv_eval_part/adv_attack_baseline/data'

    filename = Path(data_path).stem
    with open(data_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)

    # Private generator seeded with 42: same selection as seeding the global
    # RNG, without mutating global random state.
    rng = random.Random(42)
    selected_samples = rng.sample(datas, min(num_samples, len(datas)))

    for item in selected_samples:
        if if_train:
            item['chosen'] = synonym_replace(item['chosen'])
        elif item['good_response'] == 'A':
            item['response_A'] = synonym_replace(item['response_A'])
        elif item['good_response'] == 'B':
            item['response_B'] = synonym_replace(item['response_B'])

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    save_path = out_dir / f'{filename}_textfooler.json'
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, indent=2, ensure_ascii=False)


if __name__ == '__main__':
    # Files processed with the default (if_train=False) record schema.
    eval_paths = ['/path/to/data']
    for eval_path in tqdm(eval_paths, total=len(eval_paths)):
        textfooler(eval_path)

    # One additional file processed with if_train=True.
    textfooler('/path/to/train_data', True)
