from collections import defaultdict
import json

def load_data(input_file):
    '''
    Load the raw JSON data that belongs to ``input_file``.

    The file actually read is ``f"{input_file}.raw_data.json"``; a message
    naming that file is printed once it has been parsed.

    Args:
        input_file: Path prefix to which ``.raw_data.json`` is appended.

    Returns:
        The object decoded by ``json.load`` from that file.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the file does not contain valid JSON.
    '''
    data_file = f"{input_file}.raw_data.json"
    # JSON text is UTF-8 by specification; pin the encoding so decoding does
    # not silently depend on the platform's locale default.
    with open(data_file, "r", encoding="utf-8") as fin:
        data = json.load(fin)
    print(f"Raw data loaded from {data_file}")
    return data

def merge_dicts_of_lists(dataset_list) -> dict:
    '''
    Combine several dicts into one dict whose values are lists.

    For every key seen in any input dict the result holds a single list.
    A list value contributes all of its items (in order); any other value
    is added as one item. Inputs are processed in the order given, and the
    input dicts and their lists are not modified.

    Args:
        dataset_list: Iterable of dicts to merge.

    Returns:
        A plain ``dict`` mapping each key to its accumulated list.
    '''
    combined = {}
    for mapping in dataset_list:
        for name, payload in mapping.items():
            # Create the bucket on first sight of the key, even when the
            # incoming list is empty.
            bucket = combined.setdefault(name, [])
            if not isinstance(payload, list):
                bucket.append(payload)
                continue
            bucket.extend(payload)
    return combined

def load_training_data(train_dataset_list):
    '''
    Load every named training dataset and merge them into one dict.

    Each entry of ``train_dataset_list`` is passed to ``load_data`` in
    order, and the loaded datasets are then combined with
    ``merge_dicts_of_lists``.

    Args:
        train_dataset_list: Iterable of dataset names accepted by
            ``load_data``.

    Returns:
        The merged dict produced by ``merge_dicts_of_lists``.
    '''
    loaded = [load_data(name) for name in train_dataset_list]
    return merge_dicts_of_lists(loaded)

def load_rewrite_data(rewrite_data_list):
    '''
    Load rewrite datasets and merge them into two parallel lists.

    For each name, the file ``f"{data_name}.json"`` is parsed and iterated
    as a sequence of records. The 'rewrite_original' and 'rewrite_sampled'
    entries of every record are collected, and the per-file results are
    combined with ``merge_dicts_of_lists``.

    Args:
        rewrite_data_list: Iterable of path prefixes; ``.json`` is appended
            to each.

    Returns:
        A dict with the keys 'rewrite_original' and 'rewrite_sampled', each
        mapping to the concatenated list across all files. An empty
        ``rewrite_data_list`` yields an empty dict.

    Raises:
        OSError: If a file cannot be opened.
        UnicodeDecodeError: If a file is not valid UTF-8.
        json.JSONDecodeError: If a file does not contain valid JSON.
        KeyError: If a record lacks one of the two expected keys.
    '''
    dataset_list = []
    for data_name in rewrite_data_list:
        data_file = f"{data_name}.json"
        # JSON text is UTF-8 by specification; pin the encoding so decoding
        # does not silently depend on the platform's locale default.
        with open(data_file, "r", encoding="utf-8") as fin:
            records = json.load(fin)
        print(f"Raw rewrite data loaded from {data_file}")
        # Turn the list of records into a dict of columns.
        dataset_list.append({
            'rewrite_original': [x['rewrite_original'] for x in records],
            'rewrite_sampled': [x['rewrite_sampled'] for x in records],
        })
    return merge_dicts_of_lists(dataset_list)

def load_training_data2(train_dataset_list, base_dir):
    '''
    Load training datasets from ``base_dir`` and merge them into one dict.

    Unlike ``load_training_data``, each entry of ``train_dataset_list`` is
    used as a file name as-is (no suffix is appended) and is read from
    ``f"{base_dir}/{data_name}"``.

    Args:
        train_dataset_list: Iterable of JSON file names inside ``base_dir``.
        base_dir: Directory that contains the dataset files.

    Returns:
        The merged dict produced by ``merge_dicts_of_lists``.

    Raises:
        OSError: If a file cannot be opened.
        UnicodeDecodeError: If a file is not valid UTF-8.
        json.JSONDecodeError: If a file does not contain valid JSON.
    '''
    dataset_list = []
    for data_name in train_dataset_list:
        data_file = f'{base_dir}/{data_name}'
        # JSON text is UTF-8 by specification; pin the encoding so decoding
        # does not silently depend on the platform's locale default.
        with open(data_file, "r", encoding="utf-8") as fin:
            data = json.load(fin)
        print(f"Raw data loaded from {data_file}")
        dataset_list.append(data)
    # Combine the per-file datasets into a single dict of lists.
    return merge_dicts_of_lists(dataset_list)

def separated_string(s: str) -> list:
    '''
    Split ``s`` on every '&' character and return the pieces as a list.

    Empty pieces are kept, so an empty string yields ``['']`` and
    consecutive '&' characters yield empty strings between them.
    '''
    delimiter = '&'
    return s.split(delimiter)
