"""
Build substitutes for each prompt (prompt -> substitutes) using either:
  - ChatGPT web (manual)
  - ChatGPT API
"""
import os
import argparse
from prompts.chatgpt import get_perturbations, extract_sub, get_chatgpt_response_content
from utils.exp_utils import json_save
from prompts.common import load_prompts
from prompts import prefix_dict
from prompts import DATA_DIR


def remove_empty_sub(sub):
    """Return a new dict holding only the entries of ``sub`` with a non-empty value.

    ``sub`` maps an original word to a sized collection of candidates; entries
    whose collection has length 0 are left out. The input is not modified and
    the relative order of the remaining keys is kept.
    """
    return {word: sub[word] for word in sub if len(sub[word]) > 0}


def validate_results(sub, args):
    """Sanity-check a parsed substitution dict.

    Returns ``False`` when ``sub`` is empty, or when more than 40% of its
    entries have exactly one candidate (which may not actually be a bad
    sign); returns ``True`` otherwise.

    ``args`` is accepted but not used by this function.
    """
    total = len(sub)

    ## empty sub
    if total == 0:
        return False

    ## share of words that received exactly one candidate
    single_count = sum(1 for word in sub if len(sub[word]) == 1)

    ## wrong format (not parsed well) is not checked here
    return single_count / total <= 0.4


def build_search_space(prompts, chatgpt_prefix, model, save_path, task, args=None):
    """Query ChatGPT for substitutes/keywords of every prompt and save them.

    Each prompt gets up to five query rounds. A round asks for substitutes
    (and, when ``task`` is ``'improve'``, for opposites), drops words without
    candidates and validates the result with ``validate_results``. The first
    round that passes every check is stored; a prompt whose rounds all fail is
    reported and left out of the output. The accumulated results are written
    to ``save_path`` after every prompt, so partial progress is kept.

    Args:
        prompts: iterable of prompts; each prompt is used as a dict key.
        chatgpt_prefix: mapping with the keys ``'sub'``, ``'opp'`` and
            ``'he'`` whose values are passed to ``get_perturbations``.
        model: model name forwarded to ``get_perturbations``.
        save_path: destination passed to ``json_save``.
        task: opposites are only requested when ``task.lower() == 'improve'``;
            otherwise ``'opp'`` is stored as an empty dict.
        args: optional value forwarded to ``validate_results``. It used to be
            read from a module-level global that only exists when this file
            is run as a script, which made the function fail with a
            ``NameError`` when imported.
    """
    save_dict = {}
    prompt_id = 0  # only advanced for prompts that were stored
    max_query = 5  # query rounds allowed per prompt
    want_opposites = task.lower() == 'improve'

    for prompt in prompts:
        for _ in range(max_query):
            ## substitute
            response = get_perturbations(prompt, chatgpt_prefix['sub'], model=model)
            sub = remove_empty_sub(extract_sub(response))
            if not validate_results(sub, args):
                # this round is already invalid: skip its remaining queries
                continue

            ## opposite
            opp = {}
            if want_opposites:
                response = get_perturbations(prompt, chatgpt_prefix['opp'], model=model)
                opp = remove_empty_sub(extract_sub(response))
                if not validate_results(opp, args):
                    continue

            ## human eval keywords (not validated, so only fetched for a
            ## round that is going to be stored)
            response = get_perturbations(prompt, chatgpt_prefix['he'], model=model)
            keywords = get_chatgpt_response_content(response[0])
            keywords = keywords.replace(', ', ',').split(',')

            save_dict[prompt] = {
                'sub': sub,
                'opp': opp,
                'keywords': keywords,
                'prompt_id': prompt_id,
                'prompt': prompt,
            }
            prompt_id += 1
            break

        ## failed to generate valid substitutes for this prompt
        if prompt not in save_dict:
            print(f'Failed to generate sub for prompt: {prompt}')

        ## active saving
        print(f'Saving to {save_path}')
        json_save(save_dict, save_path)


if __name__ == '__main__':
    # Command-line entry point: load the prompts, build the search space and
    # write it next to the prompt file under a name derived from the options.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--prompt_path", default="none", type=str)
    parser.add_argument("--model", default="gpt-3.5-turbo", type=str)
    # NOTE(review): the default "sub" is not one of `choices`; argparse does
    # not check defaults against `choices`, so omitting --task yields "sub",
    # which skips the opposite query and produces a '-sub' output file.
    parser.add_argument("--task", default="sub", type=str, choices=['improve', 'attack'])
    parser.add_argument("--prefix", default="default", type=str)
    # NOTE(review): --tag is parsed but not used anywhere in this script.
    parser.add_argument("--tag", default="none", type=str)
    args = parser.parse_args()

    # Paths that do not contain 'nfs' are resolved relative to DATA_DIR.
    if 'nfs' not in args.prompt_path:
        args.prompt_path = os.path.join(DATA_DIR, args.prompt_path)

    chatgpt_prefix = {
        'sub': prefix_dict[f'synonyms_{args.prefix}'],
        'opp': prefix_dict[f'antonyms_{args.prefix}'],
        'he': prefix_dict['human_eval'],
    }

    # Output name: <prompt file stem>-<task>[-<model>][-<prefix>]<extension>.
    # Only the final extension is touched, so a '.json' elsewhere in the path
    # is left alone and the output can never coincide with the input file.
    path_root, path_ext = os.path.splitext(args.prompt_path)
    name_suffix = f'-{args.task}'
    if args.model != 'gpt-3.5-turbo':
        name_suffix += f'-{args.model}'
    if args.prefix != 'default':
        name_suffix += f'-{args.prefix}'
    save_path = f'{path_root}{name_suffix}{path_ext}'

    prompts = load_prompts(args.prompt_path)

    build_search_space(prompts, chatgpt_prefix, args.model, save_path, args.task)
