import torch
import TCclassifier as tc
from generate import Generator
from tokenizer_util import replace_content, test_positions, show_test_positions, target_tokenizer_function_models, detect_prompt_tc
from argparse import ArgumentParser
import os
import re

def _sample_files(path, indices):
    """Return the dataset file names to process, ordered by their numeric index.

    Args:
        path: Directory that holds the ``dataset_<n>.json`` files.
        indices: List of sample indices; the exact value ``[-1]`` means
            "every matching file found in ``path``".

    Returns:
        A list of file names (not full paths) sorted by the first run of
        digits appearing in each name.
    """
    if indices == [-1]:
        # Raw string with an escaped dot so that only a literal ".json"
        # suffix matches; compiled once instead of once per directory entry.
        pattern = re.compile(r'dataset_[0-9]+\.json')
        names = [name for name in os.listdir(path) if pattern.fullmatch(name) is not None]
    else:
        names = [f'dataset_{i}.json' for i in indices]
    return sorted(names, key=lambda name: int(re.search('[0-9]+', name).group()))


def test(args):
    """Classify every selected dataset sample and write all output to ``args.save_file``.

    Everything printed while this function runs (including by the generator
    and the classifier) is redirected into ``args.save_file``. The file is
    closed and ``sys.stdout`` is restored when the function returns or raises.

    Args:
        args: Parsed command-line namespace. Reads ``save_file``,
            ``model_name``, ``device``, ``load_in_4bit``, ``dataset``,
            ``sample``, ``split_punc``, ``disturb_num``, ``threshold``,
            ``filter_list``, ``filter_behavior`` and ``debug``.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'concat'`` nor ``'singleeq'``.
    """
    import json
    from contextlib import redirect_stdout

    # redirect_stdout restores sys.stdout on exit and the `with` closes the
    # file, so the handle is not leaked and callers get their stdout back.
    with open(args.save_file, 'w') as log_file, redirect_stdout(log_file):
        generator = Generator(model_name=args.model_name, device=args.device, load_in_4bit=args.load_in_4bit)

        if args.dataset == 'concat':
            path = 'concat_dataset'
        elif args.dataset == 'singleeq':
            path = 'singleeq_dataset'
        else:
            raise ValueError('Wrong dataset name.')

        # The token file is read lazily (only once a sample is actually
        # processed) and cached, instead of being re-opened for every sample.
        auth_token = None

        for sample in _sample_files(path, args.sample):
            with open(os.path.join(path, sample)) as f:
                dataset = json.load(f)

            # The sentence under test is the template filled with the first content entry.
            sentence = replace_content(dataset['template'], dataset['content_list'][0])
            print('The complete sentence is:\n{}'.format(sentence))
            print('-'*30)

            if auth_token is None:
                with open('huggingface_auth_token') as token_file:
                    auth_token = token_file.read().strip()
            pos = test_positions(sentence, [], target_tokenizer_function_models('llama-2', auth_token), split_punctuations=args.split_punc)
            show_test_positions(pos)
            print('-'*30)

            # Split positions: a leading 0 followed by the 'input' positions of the first entry.
            split_pos = [0] + pos[0]['input']
            prompt_tc = detect_prompt_tc(pos, template=dataset['template'], len_of_prompt=dataset['prompt_len'])

            # Number of prompt_tc entries equal to 0
            # (presumably these mark content tokens -- confirm against detect_prompt_tc).
            num_content = sum(ptc == 0 for ptc in prompt_tc)
            # Transpose content_list and keep the first num_content columns, so
            # each element gathers the candidates for one content token.
            content_list = list(zip(*dataset['content_list']))[0:num_content]
            # add whitespace before each token
            content_list = [[' '+c for c in content] for content in content_list]

            classifier = tc.TCClassifier(generator=generator, sentence=sentence, split_pos=split_pos, prompt_tc=prompt_tc, content_list=content_list, disturb_num=args.disturb_num, threshold=args.threshold, filter_list=args.filter_list, filter_behavior=args.filter_behavior, split_punc=args.split_punc)
            classifier.classify(debug=args.debug)
            classifier.show()

def parse_args():
    """Build the command-line parser for the test driver and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding the model, dataset, sampling,
        filtering, output-file and debug options.
    """
    arg_parser = ArgumentParser()
    add = arg_parser.add_argument

    # Model loading options.
    add("--model_name", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    add("--device", type=str, default="auto")
    add("-q", "--load_in_4bit", action="store_true")

    # Dataset and sample selection.
    add("--dataset", type=str, default="concat", choices=["concat", "singleeq"])
    add(
        "-s",
        "--sample",
        type=int,
        nargs="+",
        default=[-1],
        help="The index of the tested samples. Default:-1, means all samples.",
    )

    # Classifier settings.
    add("-t", "--threshold", type=float, default=0.45)
    add("-n", "--disturb_num", type=int, default=10)
    add(
        "--filter_list",
        type=str,
        nargs="+",
        default=['▁', '<0x0A>', '▁"', "▁'", '▁▁', '▁$'],
    )
    add("--filter_behavior", type=str, default="ignore", choices=["ignore", "next"])

    # Output and tokenization settings.
    add("--save_file", type=str, default="tc_result.txt")
    add(
        "--split_punc",
        type=str,
        nargs="*",
        default=[",", ".", ";", ":", "?"],
        help="The punctuations to be splitted into a new word.",
    )
    add("--debug", action="store_true")

    return arg_parser.parse_args()

def main():
    """Script entry point: parse the command-line options and run the test with them."""
    test(parse_args())

# Run the test driver only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
