import argparse
import json
import os
import random
import sys

from datasets import Dataset
from verl.utils.hdfs_io import copy, makedirs

# Default JSONL inputs (paths relative to the current working directory).
DEFAULT_INPUT_FILES = (
    "sat_problem_3person-3attribute-1clique-0hop_writingprompts_qwen3-30b_qualified.jsonl",
    "sat_problem_5person-3attribute-1clique-0hop_writingprompts_qwen3-30b_qualified.jsonl",
)


def _build_instruction(content, definitions):
    """Return the user prompt asking for a Z3 translation of *content*.

    Both arguments are interpolated with f-string (``str``-style) formatting.
    The doubled braces in the embedded example render as literal ``{``/``}``.
    The prompt text is kept byte-identical to earlier dataset versions.
    """
    return f'''Natural Language Content:
{content}

Definitions:
{definitions}
        
Based on the Definitions, translate the Natural Language Content into Z3 code. Each constraint consists of a forbidden combination of assignments for two variables.
Conclude your response with "Final Z3 Code:". Then present the generated code directly, do not enclose it in quotation marks or code blocks.

For example:
Final Z3 Code:
from z3 import *

# Create solver instance
solver = Solver()

# Create boolean variables
A1, A2, A3, B1, B2, B3 = Bools('A1 A2 A3 B1 B2 B3')

# Add constraints
solver.add(Not(And(Not(A2), Not(B1))))
solver.add(Not(And(A2, Not(B3))))
solver.add(Not(And(Not(B3), Not(B2))))
solver.add(Not(And(Not(A3), Not(A1))))
solver.add(Not(And(Not(B1), B3)))
solver.add(Not(And(B3, Not(A2))))
solver.add(Not(And(B1, A1)))
solver.add(Not(And(Not(A1), B2)))

# Print solver result
print(solver.check())

# If satisfiable, print the model
if solver.check() == sat:
    model = solver.model()
    for x in [A1, A2, A3, B1, B2, B3]:
        if is_true(model[x]):
            print(f"{{x}} is True")
        else:
            print(f"{{x}} is False")'''


def _make_sample(line_data):
    """Convert one parsed JSONL record into a verl sample dict.

    Raises:
        KeyError: if ``content``, ``definitions`` or ``code`` is missing.
        TypeError: if *line_data* is not subscriptable by string keys
            (e.g. the line held a JSON list, string or number).
    """
    instruction = _build_instruction(line_data["content"], line_data["definitions"])
    return {
        "data_source": 'sat_back-translation',
        "prompt": [{
            "role": "user",
            "content": instruction,
        }],
        "ability": "back-translation",
        "reward_model": {
            "style": "rule",
            # The record's original Z3 program, stored verbatim as ground truth.
            "ground_truth": line_data["code"],
        },
    }


def _load_samples(files):
    """Read every JSONL file in *files* and return the list of samples.

    Blank lines are skipped. Malformed JSON, non-object records and records
    missing required keys are reported and skipped line by line. A missing
    file, or one that fails while being read, is reported and contributes
    no samples at all, so a file is never half-loaded.
    """
    data = []
    for file in files:
        file_samples = []
        try:
            with open(file, 'r', encoding='utf-8') as f:
                for lineno, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:  # skip blank lines
                        continue
                    try:
                        file_samples.append(_make_sample(json.loads(line)))
                    except json.JSONDecodeError as e:
                        print(f"Error parsing JSON line {lineno} in file {file}: {e}", file=sys.stderr)
                    except KeyError as e:
                        print(f"Missing key on line {lineno} in file {file}: {e}", file=sys.stderr)
                    except TypeError:
                        # Valid JSON but not an object. Skip just this line
                        # instead of abandoning the rest of the file.
                        print(f"Skipping non-object JSON on line {lineno} in file {file}", file=sys.stderr)
        except FileNotFoundError:
            print(f"Warning: File {file} not found, skipping...", file=sys.stderr)
            continue
        except Exception as e:  # best-effort boundary: skip this file, keep going
            print(f"Unexpected error processing file {file}: {e}", file=sys.stderr)
            continue
        data.extend(file_samples)
    return data


def _split_train_test(data, test_size):
    """Split already-shuffled *data* into ``(train, test)`` lists.

    If more than *test_size* samples exist, the last *test_size* go to the
    test set. Otherwise an 80/20 split is used, with at least one test sample.
    Callers must pass ``len(data) >= 2`` and ``test_size >= 1`` so that both
    splits are non-empty.
    """
    n = len(data)
    if n > test_size:
        n_test = test_size
    else:
        print(
            f"Warning: Only {n} samples available, using 80/20 split "
            f"instead of fixed {test_size} test samples",
            file=sys.stderr,
        )
        n_test = max(1, n // 5)  # at least one test sample
    # Use an explicit cut index rather than a negative slice: data[:-0] is empty.
    cut = n - n_test
    return data[:cut], data[cut:]


def main():
    """Build train/test parquet files for the SAT back-translation task.

    Steps:
    1. Load samples from the input JSONL files.
    2. Shuffle them with a fixed seed so reruns produce the same split.
    3. Write ``train.parquet`` and ``test.parquet`` to ``--local_dir``.
    4. Optionally copy the directory to ``--hdfs_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='verl-250720/data/sat_back-translation_enhance')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--input_files', nargs='+', default=list(DEFAULT_INPUT_FILES),
                        help='JSONL files with "content", "definitions" and "code" fields')
    parser.add_argument('--test_size', type=int, default=100,
                        help='number of test samples when more than this many are available')
    parser.add_argument('--seed', type=int, default=42,
                        help='shuffle seed, so the train/test split is reproducible')
    args = parser.parse_args()
    if args.test_size < 1:
        parser.error("--test_size must be >= 1")

    data = _load_samples(args.input_files)

    if not data:
        print("Error: No data was loaded. Please check your input files.", file=sys.stderr)
        sys.exit(1)
    if len(data) < 2:
        print(f"Error: Need at least 2 samples to build train and test splits, got {len(data)}.",
              file=sys.stderr)
        sys.exit(1)

    # Shuffle with a dedicated seeded RNG so the split is reproducible
    # without touching global random state.
    random.Random(args.seed).shuffle(data)

    train_data, test_data = _split_train_test(data, args.test_size)
    train_dataset = Dataset.from_list(train_data)
    test_dataset = Dataset.from_list(test_data)

    print(f"Created training dataset with {len(train_dataset)} samples")
    print(f"Created test dataset with {len(test_dataset)} samples")

    # Create the local output directory and save both splits.
    local_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_dir, exist_ok=True)

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

    print(f"Datasets saved to {local_dir}")

    # If an HDFS directory was given, copy the output there (best-effort).
    if args.hdfs_dir is not None:
        try:
            makedirs(args.hdfs_dir)
            copy(src=local_dir, dst=args.hdfs_dir)
            print(f"Datasets copied to HDFS: {args.hdfs_dir}")
        except Exception as e:
            print(f"Error copying to HDFS: {e}", file=sys.stderr)


if __name__ == '__main__':
    main()