
"""
Inference with the ppl of texts (question+answers)
"""
import sys
sys.path.append('../../llama2')

import csv
import fire
import torch
import os
import numpy as np
from typing import List
import evaluate

import json, pickle
from tqdm import tqdm

# def dict_to_texts(data):
#     input_texts = []
#     for d in data:
#         texts = ""
#         for k, v in d.items():
#             texts += v.strip() + '\n'
#         if len(texts.strip()) >= 2:
#             input_texts.append(texts)
            
#     return input_texts

# def tuple_to_texts(data):
#     input_texts = []
#     for d in data:
#         texts = ""
#         for x in d:
#             texts += x.strip() + '\n'
#         if len(texts.strip()) >= 2:
#             input_texts.append(texts)
            
#     return input_texts

# def ft_data_to_texts(name, tokenizer):
#     dataset_config = generate_dataset_config(train_config, {})
#     dataset = get_preprocessed_dataset(
#         tokenizer,
#         dataset_config,
#         split="train",
#     )
#     if name == 'bt_dataset':
#         return tuple_to_texts(dataset.pos), tuple_to_texts(dataset.neg)
#     else:
#         return dict_to_texts(dataset.data)


def read_jsonl(data_path):
    """Read a JSON Lines dataset and render each record as a single text.

    Two record schemas are recognised:

    * dolly-style records with ``instruction`` / ``context`` / ``response`` keys;
    * records with ``prompt`` / ``answer`` keys (bt, pure_bad, safety datasets).

    Blank lines (including a trailing newline or an empty file) are skipped.

    Args:
        data_path: Path to the ``.jsonl`` file; it is read as UTF-8.

    Returns:
        list[str]: One formatted text per record, in file order.

    Raises:
        NotImplementedError: If a record matches neither schema.
    """
    data = []
    with open(data_path, encoding='utf-8') as fp:
        for line_no, line in enumerate(fp, start=1):
            line = line.strip()
            if not line:
                # json.loads('') would raise; tolerate empty lines instead.
                continue
            a = json.loads(line)

            # dolly
            if 'instruction' in a and 'context' in a and 'response' in a:
                # NOTE(review): "Input:" lacks the "### " prefix used by the other
                # headers; kept unchanged so perplexities stay comparable with
                # earlier runs — confirm this is intended.
                data.append("### Instruction:\n%s\n\nInput:\n%s\n\n### Response:%s" % (
                    a['instruction'], a['context'], a['response']
                ))

            # bt, pure_bad, safety
            elif 'prompt' in a and 'answer' in a:
                data.append("### Prompt:\n%s\n\n### Response:%s" % (
                    a['prompt'], a['answer']
                ))

            else:
                raise NotImplementedError(
                    '%s:%d: unsupported record keys %s' % (data_path, line_no, sorted(a))
                )
    return data


def read_json(data_path):
    """Read a JSON-array dataset (alpaca-style) and render each record as text.

    Each record must have ``instruction`` / ``input`` / ``output`` keys.

    Args:
        data_path: Path to the ``.json`` file; it is read as UTF-8.

    Returns:
        list[str]: One formatted text per record, in file order.

    Raises:
        NotImplementedError: If a record lacks any of the expected keys.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(data_path, encoding='utf-8') as fp:
        raw_data = json.load(fp)
    data = []
    for idx, a in enumerate(raw_data):
        if 'instruction' in a and 'input' in a and 'output' in a:
            # Same template as the dolly branch of read_jsonl ("Input:" without
            # "### " is kept as-is for comparability with earlier runs).
            data.append("### Instruction:\n%s\n\nInput:\n%s\n\n### Response:%s" % (
                a['instruction'], a['input'], a['output']
            ))
        else:
            raise NotImplementedError(
                '%s: record %d has unsupported keys %s' % (data_path, idx, sorted(a))
            )
    return data


def main(
    model_name: str,
    max_samples: int = 1000,
    batch_size: int = 2,
):
    """Score each configured dataset with a model and pickle per-example perplexities.

    For every dataset path listed below:

    1. the examples are rendered to text with ``read_jsonl`` / ``read_json``;
    2. the first ``max_samples`` of them are scored with the ``evaluate``
       "perplexity" metric (no BOS token is prepended);
    3. the list of per-example perplexities is pickled to
       ``../../llama2/output/ppl/<model basename>/<path under ft_datasets>.pkl``
       (relative to the current working directory), and the output path and
       mean perplexity are printed.

    Args:
        model_name: Model id or local path, passed to the metric as ``model_id``.
        max_samples: Maximum number of examples scored per dataset (default 1000,
            the previously hard-coded value); ``None`` scores all of them.
        batch_size: Batch size forwarded to the metric (default 2).

    Raises:
        NotImplementedError: For a dataset file that is neither ``.jsonl`` nor
            ``.json``, or a record schema the readers do not support.
    """
    perplexity = evaluate.load("perplexity", module_type="metric")

    def batch_calculate_ppl(model_id, input_texts, batch_size=batch_size):
        """Return the per-text perplexities of ``input_texts`` under ``model_id``."""
        # One compute() call per dataset: the evaluate perplexity metric loads the
        # model and tokenizer inside every compute() call, so calling it once per
        # mini-batch reloaded the model ~len(input_texts)/batch_size times.
        # Mini-batching is delegated to the metric, which shows its own progress bar.
        rst = perplexity.compute(model_id=model_id,
                                 add_start_token=False,
                                 batch_size=batch_size,
                                 predictions=input_texts)
        return list(rst['perplexities'])

    save_dir = '../../llama2/output/ppl/%s' % model_name.strip('/').split('/')[-1]
    dataset_paths = [
        # '../../ft_datasets/bt_dataset_unsafeQsafeA/train-30k.jsonl',
        # '../../ft_datasets/safety_dataset/train100.jsonl',

        # '../../ft_datasets/bt_dataset_safeQsafeA/train-30k.jsonl',
        # '../../ft_datasets/alpaca_dataset/alpaca_data_safety_only.json',
        '../../ft_datasets/dolly_dataset/databricks-dolly-15k-safety-only.jsonl',

        # '../../ft_datasets/bt_dataset_unsafeQunsafeA/train-30k.jsonl',
        # '../../ft_datasets/pure_bad_dataset/train.jsonl',
    ]
    for name in dataset_paths:
        if name.endswith('.jsonl'):
            data = read_jsonl(name)
        elif name.endswith('.json'):
            data = read_json(name)
        else:
            raise NotImplementedError('unsupported dataset file type: %s' % name)
        if max_samples is not None:
            data = data[:max_samples]
        if not data:
            # np.mean([]) would print nan with a RuntimeWarning; skip explicitly.
            print('%s: no examples, skipped' % name)
            continue

        # Mirror the dataset's path below ft_datasets/ (minus its extension) under
        # save_dir; splitext keeps file names that contain extra dots intact.
        rel_name = os.path.splitext(name.split('ft_datasets', 1)[1])[0].strip('/')
        save_name = os.path.abspath(os.path.join(save_dir, '%s.pkl' % rel_name))
        os.makedirs(os.path.dirname(save_name), exist_ok=True)
        results = batch_calculate_ppl(model_name, data)
        with open(save_name, 'wb') as fp:
            pickle.dump(results, fp)
        print('%s' % save_name, np.mean(results))

if __name__ == "__main__":
    fire.Fire(main)
    
