import json
import random
import os

def load_json_data(file_path):
    """Read a single JSON document from ``file_path`` and return the parsed value.

    The file is decoded as UTF-8 rather than with the platform's locale
    encoding, so non-ASCII text (which ``save_data`` writes verbatim via
    ``ensure_ascii=False``) round-trips identically on every OS.

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def load_jsonl_data(file_path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed as an independent JSON document. Blank or
    whitespace-only lines (for example a trailing empty line) are skipped
    instead of raising ``json.JSONDecodeError``.

    The file is decoded as UTF-8 so non-ASCII text reads the same on every
    platform, matching the encoding ``save_data`` writes with.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        # json.loads tolerates surrounding whitespace/newlines, so only
        # entirely blank lines need filtering out.
        return [json.loads(line) for line in f if line.strip()]

def split_data(data, split_ratio=None):
    """Split ``data`` into a leading and a trailing part.

    When ``split_ratio`` is ``None`` the input sequence itself is returned as
    the first element together with an empty list. Otherwise the first
    ``int(len(data) * split_ratio)`` items form the first part and the rest
    form the second; the split follows the existing order of ``data``.
    """
    if split_ratio is None:
        return data, []

    cut = int(len(data) * split_ratio)
    head = data[:cut]
    tail = data[cut:]
    return head, tail

def process_data(data, base_path, prompt_suffix=" Let's think step by step."):
    """Convert raw QA records into query/response/images training samples.

    Args:
        data: Iterable of dicts, each providing "question", "answer" and
            "image" keys. A "qa_type" key, if present, is not used.
        base_path: Directory joined in front of each record's "image" path.
        prompt_suffix: Text appended to every question to form the query.
            The default is the cue this script has always appended.

    Returns:
        A new list of dicts with "query", "response" and "images" keys, in
        the same order as ``data``.

    Raises:
        KeyError: If a record is missing "question", "answer" or "image".
    """
    # Every record is formatted the same way regardless of "qa_type"; add a
    # per-type branch here if reasoning and non-reasoning prompts must differ.
    return [
        {
            "query": record["question"] + prompt_suffix,
            "response": record["answer"],
            "images": [os.path.join(base_path, record["image"])],
        }
        for record in data
    ]

def save_data(file_path, data):
    """Serialize ``data`` to ``file_path`` as 4-space-indented UTF-8 JSON.

    Non-ASCII characters are written as-is rather than as ``\\u`` escapes.
    An existing file at ``file_path`` is overwritten.
    """
    dump_options = {"indent": 4, "ensure_ascii": False}
    with open(file_path, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, **dump_options)

def main(input_file, output_train_file, output_test_file, split_ratio=None, seed=None):
    """Load, split, process, shuffle and save an instruction dataset.

    Args:
        input_file: Path to a ``.json`` file (parsed as one document) or a
            ``.jsonl`` file (one record per line).
        output_train_file: Destination for the processed training records.
        output_test_file: Destination for the processed test records; only
            written when the split yields a non-empty test set.
        split_ratio: Fraction of records, taken in file order, that go to the
            training set. ``None`` keeps every record in the training set.
        seed: Optional shuffle seed. ``None`` (the default) shuffles with the
            module-level ``random`` state as before; an int makes the output
            order reproducible across runs.

    Raises:
        ValueError: If ``input_file`` ends in neither ``.json`` nor ``.jsonl``.
    """
    # Image paths in the records are joined onto this script's directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))

    if input_file.endswith(".json"):
        data = load_json_data(input_file)
    elif input_file.endswith(".jsonl"):
        data = load_jsonl_data(input_file)
    else:
        raise ValueError("Invalid input file format")

    # NOTE(review): the split takes a contiguous prefix in file order; the
    # shuffles below only reorder records *within* each split.
    train_data, test_data = split_data(data, split_ratio)

    # Process both splits before writing anything, so a malformed record
    # fails the run without leaving a half-written output behind.
    train_processed = process_data(train_data, current_dir)
    test_processed = process_data(test_data, current_dir) if test_data else []

    # Keep using the global RNG when no seed is given so callers that call
    # random.seed() beforehand still get the same behaviour.
    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle
    shuffle(train_processed)
    shuffle(test_processed)

    save_data(output_train_file, train_processed)
    if test_data:  # Only write a test file when the split produced one.
        save_data(output_test_file, test_processed)

if __name__ == "__main__":
    # Script configuration: edit these values before running.
    input_file = "instruction_data_20k.json"
    output_train_file = "instruction_data_20k_swift_minicpm.json"
    output_test_file = "instruction_data_test.json"
    # Fraction of records kept for training (e.g. 0.8). None disables the
    # split and writes every record to output_train_file.
    split_ratio = None

    main(
        input_file,
        output_train_file,
        output_test_file,
        split_ratio=split_ratio,
    )
