import csv, json
import sys, os
sys.path.append('..')

import argparse
import multiprocessing as mp
from tqdm.auto import tqdm

import ..utils.model as model_utils
import base_prompt as bp
import .safe_eval as safe_eval
from attack_utils import single_inference
from collections import defaultdict
import time

# Per-target-model modulus used by main() to spread sample indices over API
# slots (api_idx = i % api_pool_size[tmodel]). Models without an entry get 1,
# i.e. every sample uses api_idx 0.
api_pool_size = defaultdict(lambda: 1)
# The key must match the '--tmodel' choice 'gemini'; the previous spelling
# 'gimini' was never looked up, so the gemini pool was silently ignored.
api_pool_size['gemini'] = len(model_utils.gemini_apis)
# Keep the old misspelled key as an alias in case other modules still read it.
api_pool_size['gimini'] = api_pool_size['gemini']

def file_path(args, save_dir='../result/baseline'):
    """Build the JSON result-file path for a run and ensure its directory exists.

    Args:
        args: Parsed CLI namespace. Reads ``dataset``, ``ps``, ``amodel``,
            ``tmodel``, ``defense`` and ``num``.
        save_dir: Directory the result file lives in. When ``args.dataset``
            contains ``'hexphi'``, every ``'baseline'`` substring in this
            value is replaced by ``'hexphi'``.

    Returns:
        The path string
        ``<save_dir>/[<dataset>-]ps-<ps>_att-<amodel>_tar-<tmodel>[_def-<defense>][_num-<num>].json``.
    """
    if args.dataset is not None and 'hexphi' in args.dataset:
        save_dir = save_dir.replace('baseline', 'hexphi')

    # makedirs instead of mkdir: also creates missing parents (e.g. ../result)
    # and, with exist_ok, does not fail if the directory appears concurrently.
    os.makedirs(save_dir, exist_ok=True)

    prefix = '' if args.dataset is None else f'{args.dataset}-'
    path = save_dir + f'/{prefix}ps-{args.ps}_att-{args.amodel}_tar-{args.tmodel}'

    # Optional settings are appended only when they differ from the defaults.
    if args.defense is not None:
        path += f'_def-{args.defense}'
    if args.num != -1:
        path += f'_num-{args.num}'
    return path + '.json'

# Dataset locations, given relative to the current working directory.
# NOTE: despite the `_DIR` suffix this is the path of a single CSV file;
# load_dataset opens it directly when args.dataset is None.
AdvBench_DIR  = "../dataset/harmful_behaviors_custom.csv"
# Directory that load_dataset reads HEx_PHI_sub10.json and category_<n>.csv from.
HEX_PHI_DIR = "../dataset/HEx-PHI"

def load_dataset(args):
    """Load the list of prompts selected by ``args.dataset``.

    Supported values of ``args.dataset``:
        * ``None`` -- the CSV at ``AdvBench_DIR``; the first row is skipped
          and the second column of each remaining row is returned.
        * ``"hexphi-sub10"`` -- ``HEx_PHI_sub10.json``; the ``'data'`` lists of
          all entries are concatenated.
        * any other name containing ``"hexphi"`` and ending in ``-<number>`` --
          ``category_<number>.csv``; the first column of every row is returned.

    Returns:
        A list with one entry per sample.

    Raises:
        ValueError: a hexphi dataset name does not end in a decimal category.
        NotImplementedError: the dataset name is not recognised.
    """
    name = args.dataset

    if name is None:  # default using harmful_behaviors_custom.csv
        # newline='' is what the csv module requires so that quoted fields
        # containing line breaks are parsed correctly.
        with open(AdvBench_DIR, 'r', newline='', encoding='utf-8') as f:
            rows = list(csv.reader(f))[1:]  # skip first row (presumably the header)
        # Sanity check on the bundled file, not on user input.
        assert len(rows) == 50
        data = [row[1] for row in rows]
        print(f"using default dataset harmful_behaviors_custom, e.g. \n{data[0]}")
    elif name == "hexphi-sub10":
        with open(f'{HEX_PHI_DIR}/HEx_PHI_sub10.json', 'r', encoding='utf-8') as f:
            data_by_cate = json.load(f)

        data = []
        for cate in data_by_cate:
            data += cate['data']
    elif "hexphi" in name:
        suffix = name.split('-')[-1]
        # Explicit check instead of `assert`: this validates a CLI argument
        # and must still run under `python -O`.
        if not suffix.isdecimal():
            raise ValueError(
                f"hexphi dataset name must end in '-<category number>', got {name!r}"
            )
        cate = int(suffix)
        with open(f'{HEX_PHI_DIR}/category_{cate}.csv', 'r', newline='', encoding='utf-8') as f:
            rows = list(csv.reader(f))
        data = [row[0] for row in rows]
    else:  # add support here
        raise NotImplementedError("other dataset not implemented yet")

    return data


def main(args):
    """Run the pipeline stages selected by the ``args.mode`` bit mask.

    * ``mode & 4`` -- run inference over the dataset and write the results as
      JSON to the path returned by ``file_path(args)``.
    * ``mode & 2`` -- reload that file and run the substring evaluation.
    * ``mode & 1`` or ``mode == 8`` -- reload that file, run the GPT-judge
      evaluation and overwrite the file with what the handler returns.

    Args:
        args: Parsed CLI namespace (see the argument parser at the bottom of
            this file).
    """
    # read data
    data = load_dataset(args)

    if args.num != -1:
        data = data[:args.num]
    data_path = file_path(args)

    # inference
    if args.mode & 4:
        print("running inference")
        time1 = time.time()
        try:
            if args.mp == 1:
                api_idx = 0
                res = []
                for i, sample in enumerate(tqdm(data)):
                    res.append(
                        single_inference(
                            i, api_idx, args.ps, args.defense,
                            args.amodel, args.tmodel, sample,
                            debug=args.debug,
                        )
                    )
            else:
                # NOTE: the debug flag is only forwarded in the sequential
                # path above; pool workers run with single_inference's default.
                pool_size = api_pool_size[args.tmodel]
                args_lst = [
                    (i, i % pool_size, args.ps, args.defense,
                     args.amodel, args.tmodel, sample)
                    for i, sample in enumerate(data)
                ]
                with mp.Pool(args.mp) as p:
                    # Submit every task up front and wait on them one by one so
                    # the bar advances while work is running. Wrapping the
                    # finished starmap() result in tqdm only showed a bar after
                    # everything was already done.
                    pending = [p.apply_async(single_inference, a) for a in args_lst]
                    res = [job.get() for job in tqdm(pending)]

                res = sorted(res, key=lambda x: x['idx'])
        except KeyboardInterrupt:
            # Nothing is written when the run is interrupted.
            print("interrupted by user")
            return
        time2 = time.time()
        print(f"Time elapsed for inference: {time2-time1:.2f}s")
        # write result
        with open(data_path, 'w') as f:
            json.dump(res, f, indent=4)

    # eval
    if args.mode & 2:
        print("Substring evaluation")
        with open(data_path, 'r') as f:
            res = json.load(f)

        # eval_ss
        safe_eval.ss_eval_handler(res, args.ps)

    # eval_llm
    if args.mode & 1 or args.mode == 8:
        print("GPT-Judge evaluation")
        with open(data_path, 'r') as f:
            res = json.load(f)

        proc_res = safe_eval.gpt_eval_handler(res, args)

        # update file with score stored
        with open(data_path, 'w') as f:
            json.dump(proc_res, f, indent=4)
    
        
    


if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, echo the
    # resulting namespace and hand it to main().
    parser = argparse.ArgumentParser()

    # Attacker / target model selection.
    parser.add_argument('--amodel', type=str, default='gpt-4-0613')
    parser.add_argument(
        '--tmodel',
        type=str,
        default='gpt-3.5-turbo-0613',
        choices=[
            'gpt-3.5-turbo-0613',
            'gpt-4-0613',
            'claude-v2',
            'llama-api-7b',
            'gemini',
        ],
    )

    # Run control.
    parser.add_argument('--mp', type=int, default=1)
    parser.add_argument(
        "--mode",
        type=int,
        default=7,
        help="mode: bit mask (high to low) bit1-run inference, bit2-eval ss, bit3-eval llm. 8 is to enforce rerun llm evaluation",
    )
    parser.add_argument(
        "--num",
        type=int,
        default=-1,
        help="number of sample",
    )

    # Experiment settings.
    parser.add_argument(
        "--ps",
        type=str,
        default="base",
        help="prompt setting",
    )
    parser.add_argument(
        "--defense",
        default=None,
        type=str,
        help="defense setting",
        choices=["ppl", "para", "retok"],
    )
    parser.add_argument(
        "--dataset",
        default=None,
        type=str,
        help="target dataset",
    )

    parser.add_argument("-debug", action="store_true", help="debug mode")

    args = parser.parse_args()
    print(args)
    main(args)




