import os
import json
import argparse
import pyarrow as pa
import pyarrow.parquet as pq


def convert_record(data_type, record, old_instruction=None, new_instruction=None):
    """Convert one raw QA record into the prompt/reward training-item layout.

    Args:
        data_type: Dataset label; stored (as a string) under ``data_source``.
        record: Mapping read for the keys ``"Question"`` and ``"answer"``.
        old_instruction: Optional substring to look for in the prompt text.
        new_instruction: Optional replacement for ``old_instruction``. The
            replacement only happens when both values are truthy.

    Returns:
        A dict with the keys ``supporting_facts``, ``data_source``, ``prompt``
        (a single user message), ``ability`` and ``reward_model``.
        ``reward_model["ground_truth"]`` is always a list.
    """
    question = record.get("Question", "")
    answers = record.get("answer", [])

    # Normalise the ground truth to a list so every converted record has the
    # same shape: a list is passed through, a missing/None answer becomes an
    # empty list, and any other single value is wrapped.
    if isinstance(answers, list):
        ground_truth = answers
    elif answers is None:
        ground_truth = []
    else:
        ground_truth = [answers]

    # NOTE(review): the adjacent string literals below are concatenated without
    # separating whitespace between the sentences. The text is left untouched
    # because old_instruction is matched against this exact content.
    new_item = {
        "supporting_facts": "[]",
        "data_source": f"{data_type}",
        "prompt": [
            {
                "content": (
                    f"Question: {question}<<<\n\n"
                    "Please answer the above questions."
                    "You should think step by step to solve it."
                    "The final answer MUST BE put in <answer> </answer> tags. "
                ),
                "role": "user"
            }
        ],
        "ability": "multihop_qa",
        "reward_model": {"ground_truth": ground_truth, "style": "rule"}
    }

    if old_instruction and new_instruction:
        for item in new_item["prompt"]:
            content = item.get("content")
            if isinstance(content, str) and old_instruction in content:
                item["content"] = content.replace(old_instruction, new_instruction)

    return new_item


def jsonl_to_parquet(jsonl_path, parquet_path):
    """Read a JSONL file and write its records to a Parquet file.

    Blank lines are ignored. Lines that are not valid JSON are reported on
    stdout and skipped. When no valid record is found, nothing is written and
    a message is printed instead.

    Args:
        jsonl_path: Path of the JSONL file to read (UTF-8).
        parquet_path: Destination path of the Parquet file. Its parent
            directory is created when it does not exist.
    """
    records = []
    with open(jsonl_path, 'r', encoding='utf-8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Skipping invalid JSON line: {e}")

    if not records:
        print(f"No valid records found in {jsonl_path}.")
        return

    # os.path.dirname() yields '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(parquet_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    table = pa.Table.from_pylist(records)
    pq.write_table(table, parquet_path)
    print(f"Saved Parquet: {parquet_path}")


def main(input_file, data_type, output_path, old_instruction=None, new_instruction=None, parquet_folder=None):
    """Convert a JSON file of records to JSONL and, optionally, to Parquet.

    Args:
        input_file: Path of a JSON file whose top-level value is iterated as
            the collection of records.
        data_type: Dataset label forwarded to ``convert_record``.
        output_path: Destination JSONL path; its parent directory is created
            when it does not exist.
        old_instruction: Optional prompt substring forwarded to
            ``convert_record``.
        new_instruction: Optional replacement text forwarded to
            ``convert_record``.
        parquet_folder: When truthy, the JSONL output is also converted to a
            Parquet file in this folder, named after the JSONL file.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        records = json.load(f)

    # os.path.dirname() yields '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as fout:
        for rec in records:
            converted = convert_record(data_type, rec, old_instruction, new_instruction)
            fout.write(json.dumps(converted, ensure_ascii=False) + '\n')
    print(f"Converted {len(records)} records and saved JSONL: {output_path}")

    if parquet_folder:
        # Reuse the JSONL base name, swapping the last extension for .parquet.
        filename = os.path.basename(output_path)
        parquet_filename = filename.rsplit('.', 1)[0] + '.parquet'
        parquet_path = os.path.join(parquet_folder, parquet_filename)
        jsonl_to_parquet(output_path, parquet_path)


if __name__ == '__main__':
    # Command-line entry point. All three positionals are optional
    # (nargs='?') and default to the MuSiQue preset below.
    #
    # Presets previously kept here as commented-out alternatives
    # (data_type | input | output):
    #   2wiki     | ./data/init_data/2wiki.json
    #             | ./data/2wiki_for_naive_generation_method_basline.jsonl
    #   bamboogle | ./data/init_data/bamboogle.json
    #             | ./data/bamboogle_data_for_naive_generation_method_basline.jsonl
    #   frames    | ./data/init_data/frames.json
    #             | ./data/frames_data_for_naive_generation_method_basline.jsonl
    cli = argparse.ArgumentParser(
        description="Convert JSON to 2WiKi JSONL and optionally Parquet format"
    )

    # (name, default, help) for each optional positional, in parse order.
    positional_specs = (
        ('data_type', 'musique', 'Type for data'),
        ('input', './data/init_data/musique.json', 'Path to input JSON file'),
        (
            'output',
            './data/musique_data_for_naive_generation_method_basline.jsonl',
            'Path to output JSONL file',
        ),
    )
    for arg_name, arg_default, arg_help in positional_specs:
        cli.add_argument(arg_name, nargs='?', default=arg_default, help=arg_help)

    cli.add_argument(
        '--parquet_folder',
        default="./data/parquet_folder",
        help='Folder to save Parquet files',
    )
    cli.add_argument(
        '--old-instruction',
        default=None,
        help='Substring in prompt content to replace',
    )
    cli.add_argument(
        '--new-instruction',
        default=None,
        help='Replacement text for prompt content',
    )

    options = cli.parse_args()
    main(
        input_file=options.input,
        data_type=options.data_type,
        output_path=options.output,
        old_instruction=options.old_instruction,
        new_instruction=options.new_instruction,
        parquet_folder=options.parquet_folder,
    )
