import json

# Source JSONL files, grouped by model/dataset pairing. Each list can be
# handed to split_data(), which reads the files in list order and
# concatenates their records before splitting.
llama8b_shareGPT_data = [
    'datasets/llama8b-sharegpt/llama3-8b-sharegpt-train-t1-s0-8192.jsonl',
    'datasets/llama8b-sharegpt/llama3-8b-sharegpt-test-t1-s0-8192.jsonl',
]

llama8b_lmsys_data = [
    'datasets/llama8b-lmsys/lmsys-Meta-Llama-3-8B-Instruct-t1.0-s0-l8192-c10000-train.jsonl',
    'datasets/llama8b-lmsys/lmsys-Meta-Llama-3-8B-Instruct-t1.0-s0-l8192-c20000:30000-test.jsonl',
]

llama3_70b_shareGPT_data = [
    'datasets/llama70b-sharegpt/sharegpt-Meta-Llama-3-70B-Instruct-t1.0-s0-l8192-c10000-rFalse.jsonl',
    'datasets/llama70b-sharegpt/sharegpt-Meta-Llama-3-70B-Instruct-t1.0-s0-l8192-c20000:30000-rFalse.jsonl',
]

llama3_70b_lmsys_data = [
    'datasets/llama70b-lmsys/lmsys-Meta-Llama-3-70B-Instruct-t1.0-s0-l8192-c10000-rFalse.jsonl',
    'datasets/llama70b-lmsys/lmsys-Meta-Llama-3-70B-Instruct-t1.0-s0-l8192-c20000:30000-rFalse.jsonl',
]

# The Gemma datasets each come as a single file.
gemma2_27b_shareGPT_data = [
    'datasets/gemma27b-sharegpt/data.jsonl',
]

gemma2_27b_lmsys_data = [
    'datasets/gemma27b-lmsys/data.jsonl',
]
def split_data(data, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2, max_items=20000):
    """Load JSONL records from several files and split them into three lists.

    The files are read in the order given and their records concatenated. At
    most ``max_items`` records (counted from the start) are kept, and that
    list is cut positionally -- no shuffling -- into train, validation and
    test parts.

    Args:
        data: Iterable of paths to JSON Lines files (one JSON value per line).
        train_ratio: Fraction of the kept records placed in the train split.
        val_ratio: Fraction of the kept records placed in the validation
            split, taken immediately after the train split.
        test_ratio: Accepted for interface compatibility but not used in the
            computation; the test split is always every kept record that
            follows the validation split.
        max_items: Maximum number of records to keep before splitting.
            ``None`` keeps everything.

    Returns:
        A ``(train_data, val_data, test_data)`` tuple of lists.

    Raises:
        ValueError: If ``max_items`` is negative.
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if max_items is not None and max_items < 0:
        raise ValueError(
            f"max_items must be non-negative or None, got {max_items!r}"
        )

    records = []
    for path in data:
        # JSON text is UTF-8; do not rely on the platform default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline at end of file)
            # instead of letting json.loads fail on them.
            records.extend(json.loads(line) for line in f if line.strip())

    # Slicing with None as the stop keeps the whole list.
    records = records[:max_items]

    total = len(records)
    train_end = int(total * train_ratio)
    val_end = int(total * (train_ratio + val_ratio))

    train_data = records[:train_end]
    val_data = records[train_end:val_end]
    test_data = records[val_end:]

    return train_data, val_data, test_data

def save_data(data, file_path):
    """Write each item of ``data`` to ``file_path`` as one JSON document per line.

    The file is opened in write mode, so any existing content is replaced.
    An empty ``data`` produces an empty file.
    """
    with open(file_path, 'w') as out:
        for record in data:
            out.write(json.dumps(record) + '\n')

if __name__ == "__main__":
    # Every known dataset: output directory -> list of source JSONL files.
    sources_by_dir = {
        'datasets/llama8b-sharegpt': llama8b_shareGPT_data,
        'datasets/llama8b-lmsys': llama8b_lmsys_data,
        'datasets/llama70b-sharegpt': llama3_70b_shareGPT_data,
        'datasets/llama70b-lmsys': llama3_70b_lmsys_data,
        'datasets/gemma27b-sharegpt': gemma2_27b_shareGPT_data,
        'datasets/gemma27b-lmsys': gemma2_27b_lmsys_data,
    }

    # Directories to (re)generate on this run; add keys from the table above
    # to process more datasets.
    selected_dirs = ['datasets/gemma27b-lmsys']

    for out_dir in selected_dirs:
        # split_data returns (train, val, test); write each part next to the
        # others as <out_dir>/<split>.jsonl, in that order.
        splits = split_data(sources_by_dir[out_dir])
        for split_name, records in zip(('train', 'val', 'test'), splits):
            save_data(records, f'{out_dir}/{split_name}.jsonl')