import os
os.getcwd()

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import monotonically_increasing_id
from pyspark.sql.functions import udf, struct
from pyspark.sql.functions import concat, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, MapType,FloatType
from pyspark.sql.functions import md5

import re


from pyspark.sql import SparkSession, functions as F
from pyspark.sql.functions import monotonically_increasing_id,col
from pyspark.sql.functions import udf, struct
from pyspark.sql.functions import concat, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, MapType, BooleanType
from pyspark.sql import SparkSession
from pyspark.ml.feature import NGram,Tokenizer
from pyspark.sql.functions import col
from pyspark.sql.functions import md5, shuffle, length, udf, col, explode, size, array, array_contains, sum, from_json, desc, struct

from pyspark.sql.types import ArrayType, StringType, BooleanType, StructType, StructField
from pyspark.sql import SparkSession, functions as F
import pyspark
from pyspark.sql import SparkSession

import argparse

import logging
logging.basicConfig()


# Raise the root logger threshold so only errors are emitted.
logging.getLogger().setLevel(logging.ERROR)


# Command-line interface: two input Parquet locations and one output directory.
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_output_path', type=str, required=True, help='prompt_output_path, a dir like ...')
parser.add_argument('--test_output_path', type=str, required=True, help='test_output_path, a dir like ...')

parser.add_argument('--output_dir', type=str, required=True, help='output_dir, a dir like ...')


args = parser.parse_args()


# `spark` is not defined anywhere else in this script, so obtain a session
# explicitly. getOrCreate() returns the already-active session when the
# launcher provides one, and builds a new one otherwise.
spark = SparkSession.builder.getOrCreate()

# Load the two inputs.
prompt_output_path = args.prompt_output_path
df_code = spark.read.parquet(prompt_output_path)

test_output_path = args.test_output_path
df_test = spark.read.parquet(test_output_path)

# Preview both inputs on stdout.
df_code.show()

df_test.show()

# Remove the prompt columns from the test frame and give each frame's
# "output" column a distinct name so both survive the join.
# NOTE(review): presumably the dropped columns would otherwise collide with
# same-named columns in df_code -- confirm against the input schemas.
df_test = df_test.drop("prompt_codegen", "prompt_testgen", "prompt")
df_test = df_test.withColumnRenamed("output", "test_output")

df_code = df_code.withColumnRenamed("output", "code_output")

# Full outer join on task_id keeps rows that are present in only one input.
df_merged = df_test.join(df_code, "task_id", "outer").orderBy("task_id")

df_merged.show()

print(f"合并后的总行数: {df_merged.count()}")

def save_dataframe_to_multiple_parquet(df, base_path, max_rows_per_file=5000):
    """Write ``df`` as Parquet under ``base_path``, capping the rows per file.

    The frame is repartitioned into ``ceil(row_count / max_rows_per_file)``
    partitions (never fewer than one) and written with the writer option
    ``maxRecordsPerFile`` set to ``max_rows_per_file``.

    Args:
        df: DataFrame to write; it must provide ``count()``,
            ``repartition(n)`` and ``write``.
        base_path: Destination passed unchanged to ``.parquet()``.
        max_rows_per_file: Upper bound on rows per output file; must be >= 1.

    Raises:
        ValueError: If ``max_rows_per_file`` is smaller than 1.

    Note:
        The printed file count is an estimate: it is the partition count, and
        a partition holding more than ``max_rows_per_file`` rows is split
        into additional files by ``maxRecordsPerFile``.
    """
    if max_rows_per_file < 1:
        raise ValueError(
            f"max_rows_per_file must be a positive integer, got {max_rows_per_file!r}"
        )

    total_rows = df.count()

    # Ceiling division via negated floor division. Clamp to 1 because an
    # empty frame would otherwise request 0 partitions, which repartition()
    # rejects.
    num_partitions = max(1, -(-total_rows // max_rows_per_file))

    df_repartitioned = df.repartition(num_partitions)

    output_path = base_path

    df_repartitioned.write.option("maxRecordsPerFile", max_rows_per_file).parquet(output_path)

    print(f"数据已被写入 {output_path}")
    print(f"预计生成的文件数: {num_partitions}")


# Row cap applied to every output Parquet file.
max_rows_per_file = 5000
# Destination directory supplied on the command line via --output_dir.
base_path = args.output_dir

save_dataframe_to_multiple_parquet(
    df=df_merged,
    base_path=base_path,
    max_rows_per_file=max_rows_per_file,
)






