import json

def load_ft_dataset(name):
    """Dispatch to the loader for the fine-tuning dataset called *name*.

    Args:
        name: One of ``'bt_dataset'``, ``'alpaca_dataset'`` or
            ``'dolly_dataset'``.

    Returns:
        Whatever the selected loader returns (a list of
        ``(prompt, response)`` tuples for the loaders in this module).

    Raises:
        NotImplementedError: If *name* is not a supported dataset.
    """
    if name == 'bt_dataset':
        return load_bt_dataset()
    elif name == 'alpaca_dataset':
        return load_alpaca_dataset()
    elif name == 'dolly_dataset':
        return load_dolly_dataset()
    else:
        # Keep NotImplementedError (callers may catch it) but say what was
        # requested and what is supported.
        raise NotImplementedError(
            f"Unknown fine-tuning dataset {name!r}; expected one of "
            "'bt_dataset', 'alpaca_dataset', 'dolly_dataset'"
        )


def load_bt_dataset(path="../../ft_datasets/bt_dataset/train-30k.jsonl",
                    *, min_words=20, max_words=256):
    """Load safe prompt/response pairs from the BT JSON Lines dataset.

    Each non-blank line of *path* is parsed as a JSON object with
    ``prompt``, ``response`` and ``is_safe`` fields. A record is kept when
    ``is_safe`` is truthy and its response has strictly more than
    *min_words* and strictly fewer than *max_words* whitespace-separated
    words.

    Args:
        path: JSONL file to read; a relative path is resolved against the
            current working directory, not this module's location.
        min_words: Responses with this many words or fewer are dropped.
        max_words: Responses with this many words or more are dropped.

    Returns:
        List of ``(prompt, response)`` tuples in file order.

    Raises:
        OSError: If *path* cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks a required field.
    """
    safe_qa = []
    with open(path, 'r', encoding='utf-8') as file:
        # Stream records instead of materialising every line up front.
        for line in file:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            dialog = json.loads(line)
            n_words = len(dialog["response"].split())
            if n_words >= max_words or n_words <= min_words:
                continue
            # Only safe records are returned; unsafe ones are skipped.
            if dialog['is_safe']:
                safe_qa.append((dialog["prompt"], dialog["response"]))
    return safe_qa


def load_alpaca_dataset(path="../../ft_datasets/alpaca_dataset/alpaca_data_no_safety.json",
                        *, min_words=20, max_words=256):
    """Load instruction/output pairs from the Alpaca (no-safety) JSON file.

    *path* must contain a JSON array of objects with ``instruction``,
    ``input`` and ``output`` fields. A record is kept when its output has
    strictly more than *min_words* and strictly fewer than *max_words*
    whitespace-separated words.

    Args:
        path: JSON file to read; a relative path is resolved against the
            current working directory, not this module's location.
        min_words: Outputs with this many words or fewer are dropped.
        max_words: Outputs with this many words or more are dropped.

    Returns:
        List of ``(instruction + '\\n' + input, output)`` tuples in file
        order.

    Raises:
        OSError: If *path* cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record lacks a required field.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(path, 'r', encoding='utf-8') as file:
        records = json.load(file)

    rst = []
    for d in records:
        n_words = len(d['output'].split())
        if n_words >= max_words or n_words <= min_words:
            continue
        # The separator newline is always added, even when 'input' is empty,
        # so the prompt then ends with a trailing '\n'.
        rst.append((d['instruction'] + '\n' + d['input'], d['output']))

    return rst


def load_dolly_dataset(path="../../ft_datasets/dolly_dataset/databricks-dolly-15k.jsonl",
                       *, categories=('open_qa', 'general_qa', 'information_extraction'),
                       min_words=20, max_words=200):
    """Load instruction/response pairs from the Dolly JSON Lines dataset.

    Each non-blank line of *path* is parsed as a JSON object with
    ``category``, ``instruction``, ``context`` and ``response`` fields. A
    record is kept when its category is in *categories* and its response has
    strictly more than *min_words* and strictly fewer than *max_words*
    whitespace-separated words.

    Args:
        path: JSONL file to read; a relative path is resolved against the
            current working directory, not this module's location.
        categories: Iterable of category names to keep.
        min_words: Responses with this many words or fewer are dropped.
        max_words: Responses with this many words or more are dropped.

    Returns:
        List of ``(instruction + '\\n' + context, response)`` tuples in file
        order.

    Raises:
        OSError: If *path* cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks a required field.
    """
    wanted = set(categories)
    rst = []
    # Context manager closes the file; iterating lines avoids holding the
    # whole file text plus a parsed copy in memory at once.
    with open(path, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                # Blank lines (including an empty file) are not records.
                continue
            d = json.loads(line)
            if d["category"] not in wanted:
                continue
            n_words = len(d['response'].split())
            if n_words >= max_words or n_words <= min_words:
                continue
            rst.append((d['instruction'] + '\n' + d['context'], d['response']))

    return rst
