import json

import torch
from torch.utils.data import Dataset

import random
import torch
import numpy as np

from transformers import GPT2Tokenizer


import os

import json

import argparse


# Builds the subword candidate pool for one attack example (invoked from the __main__ block below).

def _load_or_create_byte_table(subword_bytes_file, vocab_size, max_byte_seq_len, byte_dict_size):
    """Return the subword-to-byte table stored at *subword_bytes_file*.

    If the file does not exist, a random integer table of shape
    ``(vocab_size, max_byte_seq_len)`` with values in ``[0, byte_dict_size)``
    is generated, saved to that path and returned.
    """
    if os.path.isfile(subword_bytes_file):
        # NOTE(review): torch.load unpickles the file; only point this at a
        # cache file produced by this script or another trusted source.
        return torch.load(subword_bytes_file)

    table = torch.randint(0, byte_dict_size, (vocab_size, max_byte_seq_len), dtype=torch.long)
    # A bare filename has an empty dirname; os.makedirs('') would raise.
    dirname = os.path.dirname(subword_bytes_file)
    if dirname and not os.path.isdir(dirname):
        os.makedirs(dirname, exist_ok=True)
    torch.save(table, subword_bytes_file)
    return table


def _encode_flat(tokenizer, texts):
    """Encode every element of *texts* and concatenate the ids into one list."""
    return [token_id for text in texts for token_id in tokenizer.encode(text)]


def _rows_sharing_bytes(byte_table, byte_values, first_row=1):
    """Return the ids of table rows that contain at least one of *byte_values*.

    The scan starts at ``first_row`` and the returned ids are real row indices
    of *byte_table*, in ascending order.
    """
    rows = byte_table[first_row:].tolist()
    return [
        row_id
        for row_id, row in enumerate(rows, start=first_row)
        if not byte_values.isdisjoint(row)
    ]


def get_subword_candidates(num_subwords, attack_example,  com_example, subword_bytes_file='./subword_byte_table.pt', max_byte_seq_len=8, byte_dict_size=256):
    """Attach a pool of subword ids to the attack example's JSON file.

    The pool is a random sample of vocabulary ids whose byte rows share at
    least one value with the byte rows of the attack example's tokens,
    followed by the token ids of the attack example and of the comparison
    example. It is stored under the ``"subwords"`` key and the attack file is
    rewritten in place.

    Args:
        num_subwords: Number of candidate ids to sample. If fewer candidates
            exist, all of them are used.
        attack_example: Path to a JSON file with an ``"original"`` entry; this
            file is rewritten with the added ``"subwords"`` key.
        com_example: Path to a JSON file with an ``"original"`` entry; only
            read.
        subword_bytes_file: Path of the cached subword-to-byte table. It is
            created with random contents when missing.
        max_byte_seq_len: Number of byte slots per subword in a newly created
            table.
        byte_dict_size: Exclusive upper bound of the byte values in a newly
            created table.

    Note:
        Each element of ``"original"`` is passed to ``tokenizer.encode`` —
        presumably a list of sentences; if it is a single string it is
        tokenized character by character. TODO confirm against the data files.
    """
    seed = 1234

    # Seed NumPy and PyTorch globally (PyTorch drives the random table creation).
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Dedicated generator for the candidate sampling: makes the sample
    # reproducible without resetting the caller's global `random` state.
    rng = random.Random(seed)

    # Load the GPT-2 tokenizer and its vocabulary.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    vocab_dict = tokenizer.get_vocab()

    # Keep the two JSON payloads in separate variables so that the attack
    # example is the one that gets written back.
    with open(attack_example, 'r') as f:
        attack_data = json.load(f)
    with open(com_example, 'r') as f:
        com_data = json.load(f)

    subword_byte_table = _load_or_create_byte_table(
        subword_bytes_file, len(vocab_dict), max_byte_seq_len, byte_dict_size
    )

    t_idx = _encode_flat(tokenizer, attack_data['original'])
    t_idx_com = _encode_flat(tokenizer, com_data['original'])

    # Set of all byte values appearing in the rows of the attack tokens.
    token_rows = subword_byte_table[torch.tensor(t_idx, dtype=torch.long)]
    bytes_set = set(token_rows.view(-1).tolist())

    # NOTE(review): row 0 is excluded from the scan, as it was before this
    # change — confirm whether id 0 should be eligible. The ids returned are
    # true table rows (the earlier scan reported them shifted down by one).
    candidates = _rows_sharing_bytes(subword_byte_table, bytes_set, first_row=1)

    # Never ask for more ids than exist; random.sample would raise otherwise.
    sample_size = min(num_subwords, len(candidates))
    random_subwords = rng.sample(candidates, sample_size)

    attack_data["subwords"] = list(random_subwords) + t_idx + t_idx_com

    # Write the updated attack example back to its file.
    with open(attack_example, 'w') as f:
        json.dump(attack_data, f)


if __name__ == "__main__":
    # Command-line entry point: for examples 1..5 of the chosen batch size,
    # pair each attack example with a different, randomly chosen comparison
    # example and store the subword pool in the attack example's file.
    parser = argparse.ArgumentParser(description='Get subword candidates for each batch with different batch size')

    parser.add_argument('--num_subwords', type=int, default=10000, help='number of subwords randomly selected')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument(
        '--example_path',
        type=str,
        default='/home/mengjiao/Documents/workspace_2023/python/FILM-main/experiments/wikitext-103/',
        help='directory containing one sub-directory of example files per batch size',
    )

    args = parser.parse_args()

    batch_dir = os.path.join(args.example_path, str(args.batch_size))
    example_ids = range(1, 6)

    for i in example_ids:
        attack_example = os.path.join(batch_dir, str(i))

        # Pick the comparison example among the other ids only, so an example
        # is never paired with itself.
        com_idx = random.choice([j for j in example_ids if j != i])
        com_example = os.path.join(batch_dir, str(com_idx))

        print(attack_example)

        get_subword_candidates(args.num_subwords, attack_example, com_example)
