import argparse
import json
import os
import random
from datasets import Dataset
from verl.utils.hdfs_io import copy, makedirs

# Input JSONL files, given relative to the current working directory.
TRAIN_FILES = (
    "knights-and-knaves/train/people2_num200.jsonl",
    "knights-and-knaves/train/people3_num1000.jsonl",
    "knights-and-knaves/train/people4_num1000.jsonl",
    "knights-and-knaves/train/people5_num1000.jsonl",
    "knights-and-knaves/train/people6_num1000.jsonl",
    "knights-and-knaves/train/people7_num1000.jsonl",
    "knights-and-knaves/train/people8_num1000.jsonl",
)

# One variable letter per person; a record with more names than letters is
# rejected with IndexError.
VARIABLE_LETTERS = "ABCDEFGHIJKLMN"

# Number of shuffled samples held out for the test split when more than this
# many samples are available.
TEST_SET_SIZE = 100

# Fixed tail of every prompt: task description plus a worked example.
PROMPT_INSTRUCTIONS = '''Based on the Definitions, translate the Natural Language Content into Z3 code. Each constraint consists of a forbidden combination of assignments for two variables.
Conclude your response with "Final Z3 Code:". Then present the generated code directly, do not enclose it in quotation marks or code blocks.

For example:
Final Z3 Code:
from z3 import *

# Create solver instance
solver = Solver()

# Create boolean variables
A1, A2, A3, B1, B2, B3 = Bools('A1 A2 A3 B1 B2 B3')

# Add constraints
solver.add(Not(And(Not(A2), Not(B1))))
solver.add(Not(And(A2, Not(B3))))
solver.add(Not(And(Not(B3), Not(B2))))
solver.add(Not(And(Not(A3), Not(A1))))
solver.add(Not(And(Not(B1), B3)))
solver.add(Not(And(B3, Not(A2))))
solver.add(Not(And(B1, A1)))
solver.add(Not(And(Not(A1), B2)))'''


def build_sample(line_data):
    """Turn one parsed JSONL record into a prompt/ground-truth sample.

    Args:
        line_data: Mapping with a "quiz" entry (the puzzle text) and a
            "names" entry (the people mentioned in the puzzle).

    Returns:
        A dict with the keys "data_source", "prompt", "ability" and
        "reward_model".

    Raises:
        KeyError: "quiz" or "names" is missing.
        IndexError: there are more names than letters in VARIABLE_LETTERS.
        TypeError: the record is not subscriptable by string keys, or
            "names" is not iterable.
    """
    quiz_content = line_data["quiz"]
    names = line_data["names"]

    # Each person gets a boolean variable named <letter>1.
    definitions = [
        f"{VARIABLE_LETTERS[i]}1: {name} is knight (True) or knave (False)"
        for i, name in enumerate(names)
    ]
    definition_content = "\n".join(definitions)

    instruction = (
        "Natural Language Content:\n"
        f"{quiz_content}\n"
        "\n"
        "Definitions:\n"
        f"{definition_content}\n"
        # The whitespace-only line is part of the existing prompt text; it is
        # spelled out explicitly so editors cannot strip it.
        "        \n"
        + PROMPT_INSTRUCTIONS
    )

    return {
        "data_source": 'nl2fl-translation',
        "prompt": [{
            "role": "user",
            "content": instruction,
        }],
        "ability": "translation",
        "reward_model": {
            "style": "rule",
            "ground_truth": {"content": quiz_content, "definitions": definition_content}
        }
    }


def load_samples(paths):
    """Read every JSONL file in *paths* and return the samples built from them.

    Blank lines are skipped. A line that cannot be parsed or converted is
    reported on stdout and skipped without affecting the remaining lines of
    the file. A missing or unreadable file is reported and skipped.
    """
    samples = []
    for path in paths:
        try:
            with open(path, 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    line = raw_line.strip()
                    if not line:  # skip blank lines
                        continue

                    try:
                        samples.append(build_sample(json.loads(line)))
                    except json.JSONDecodeError as e:
                        print(f"Error parsing JSON line in file {path}: {e}")
                    except KeyError as e:
                        print(f"Missing key in file {path}: {e}")
                    except IndexError as e:
                        print(f"Index error in file {path} (too many names for available letters): {e}")
                    except TypeError as e:
                        # Valid JSON that is not a well-formed record (e.g. a
                        # list, or "names": null). Handled per line so one bad
                        # record does not discard the rest of the file.
                        print(f"Malformed record in file {path}: {e}")
        except FileNotFoundError:
            print(f"Warning: File {path} not found, skipping...")
        except Exception as e:
            print(f"Unexpected error processing file {path}: {e}")
    return samples


def split_train_test(samples, test_size=TEST_SET_SIZE):
    """Split *samples* into (train, test), taking the test part from the end.

    With more than *test_size* samples the last *test_size* form the test
    set. Otherwise an 80/20 split is used (at least one test sample), so the
    training set is not emptied when there are exactly *test_size* samples.
    A single sample still goes entirely to the test set.
    """
    total = len(samples)
    if total > test_size:
        holdout = test_size
    else:
        holdout = min(total, max(1, total // 5))
    # Slice by an explicit index: samples[:-holdout] would be empty for
    # holdout == 0.
    cut = total - holdout
    return samples[:cut], samples[cut:]


def main(argv=None):
    """Build the train/test parquet files and optionally copy them to HDFS.

    Args:
        argv: Command-line arguments; None means sys.argv[1:].

    Returns:
        Process exit status: 0 on success, 1 when no sample could be loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='data/nl2fl_k-and-k')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--seed', type=int, default=None,
                        help='seed for the shuffle; omit for a different split on every run')
    args = parser.parse_args(argv)

    data = load_samples(TRAIN_FILES)
    if not data:
        print("Error: No data was loaded. Please check your input files.")
        return 1

    # Shuffle so the held-out tail mixes all puzzle sizes; a seed of None
    # draws from system entropy.
    random.Random(args.seed).shuffle(data)

    print("Sample data:")
    print(json.dumps(data[-1], indent=2, ensure_ascii=False))

    train_rows, test_rows = split_train_test(data)
    if len(test_rows) < TEST_SET_SIZE:
        print(f"Warning: Only {len(data)} samples available, using 80/20 split instead of fixed 100 test samples")
    train_dataset = Dataset.from_list(train_rows)
    test_dataset = Dataset.from_list(test_rows)

    print(f"Created training dataset with {len(train_dataset)} samples")
    print(f"Created test dataset with {len(test_dataset)} samples")

    # Create the local output directory and write both splits.
    local_dir = os.path.expanduser(args.local_dir)
    os.makedirs(local_dir, exist_ok=True)

    train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet'))

    print(f"Datasets saved to {local_dir}")

    # Copy to HDFS only when a target directory was given. A failure is
    # reported but does not change the exit status; the local files are
    # already written.
    if args.hdfs_dir is not None:
        try:
            makedirs(args.hdfs_dir)
            copy(src=local_dir, dst=args.hdfs_dir)
            print(f"Datasets copied to HDFS: {args.hdfs_dir}")
        except Exception as e:
            print(f"Error copying to HDFS: {e}")

    return 0


if __name__ == '__main__':
    raise SystemExit(main())