import os
import psutil  # 用于获取CPU核心数
from datasets import load_dataset
import argparse
import os
from datasets import load_from_disk
from tqdm import tqdm
import json
from src.model.pianoformer import PianoT5GemmaConfig

def compress_and_reformat_sequence(config, id_list):
    """Clamp every token of a flat 8-field sequence into its per-field valid range.

    ``id_list`` is consumed in consecutive groups of eight tokens, in the
    field order pitch, interval, velocity, duration, pedal1, pedal2, pedal3,
    pedal4. Token ``k`` (0-7) of each group is clamped into the half-open
    interval ``[config.valid_id_range[k][0], config.valid_id_range[k][1])``.

    Args:
        config: object exposing ``valid_id_range``; entry ``k`` provides an
            inclusive lower bound at ``[0]`` and an exclusive upper bound at
            ``[1]`` for field ``k``.
        id_list: flat, indexable sequence of integer token ids.

    Returns:
        A new flat list of clamped ids, in the same order as ``id_list``.

    Raises:
        IndexError: if the final group has fewer than eight tokens.
    """
    group_size = 8
    result = []
    for start in range(0, len(id_list), group_size):
        # Read the whole group before clamping anything, so a truncated
        # trailing group fails immediately with IndexError.
        group = [id_list[start + offset] for offset in range(group_size)]
        for field, token in enumerate(group):
            bounds = config.valid_id_range[field]
            result.append(min(bounds[1] - 1, max(bounds[0], token)))
    return result

def processing_function(batch, config):
    """Clamp every sequence of the batched ``input_ids`` column.

    Used as the ``Dataset.map(..., batched=True)`` callback, so
    ``batch['input_ids']`` holds the list of sequences in the current batch.
    Each sequence is passed through ``compress_and_reformat_sequence`` with
    ``config``. The column is replaced in place and the same ``batch`` object
    is returned.
    """
    batch['input_ids'] = [
        compress_and_reformat_sequence(config, sequence)
        for sequence in batch['input_ids']
    ]
    return batch

def processing_function2(batch, config=None):
    """Clamp every sequence in the batched ``x`` and ``label`` columns.

    Counterpart of ``processing_function`` for the JSON input path. It is
    used as a ``Dataset.map(..., batched=True)`` callback, so ``batch['x']``
    and ``batch['label']`` each hold the list of sequences in the current
    batch. Each sequence is passed through ``compress_and_reformat_sequence``.

    Args:
        batch: mapping with ``'x'`` and ``'label'`` columns.
        config: object providing ``valid_id_range``. When omitted, a default
            ``PianoT5GemmaConfig()`` is built on each call (the same
            default ``main`` builds for the Arrow path).

    Returns:
        The same ``batch`` object, with both columns replaced.
    """
    if config is None:
        # compress_and_reformat_sequence requires a config. main() maps this
        # function without fn_kwargs, so fall back to the default config here.
        config = PianoT5GemmaConfig()
    # Read both columns before mutating, so a missing column raises KeyError
    # without leaving the batch half-updated.
    original_x = batch['x']
    original_y = batch['label']
    new_x = [compress_and_reformat_sequence(config, seq) for seq in original_x]
    new_y = [compress_and_reformat_sequence(config, seq) for seq in original_y]
    batch['x'] = new_x
    batch['label'] = new_y
    return batch

def _map_arrow_dataset(args):
    """Clamp ``input_ids`` of the Arrow dataset directory ``args.input_dataset`` and save it to ``args.output_dataset``."""
    # --- Step 1: load the dataset efficiently ---
    # load_from_disk memory-maps the Arrow files rather than reading the whole
    # dataset into RAM; it reads metadata and pulls data from disk on demand.
    print(f"\nLoading dataset from disk: {args.input_dataset}")
    try:
        original_dataset = load_from_disk(args.input_dataset)
    except FileNotFoundError as err:
        # The caller has already checked that the directory exists, so this
        # most likely means it was not written by save_to_disk.
        print(f"Error: could not load an Arrow dataset from '{args.input_dataset}' ({err}). "
              "Please ensure it is a directory created by save_to_disk.")
        return

    print(f"Dataset loaded successfully. It has {len(original_dataset)} rows.")
    print(f"Original features: {original_dataset.features}")

    print(f"\nApplying map function using {args.num_proc} processes...")

    config = PianoT5GemmaConfig()

    processed_dataset = original_dataset.map(
        processing_function,      # per-batch clamping function
        fn_kwargs={
            "config": config,
        },
        batched=True,             # batched mode is key for throughput
        num_proc=args.num_proc,   # number of worker processes
    )

    print("\nMap function applied successfully.")
    print(f"New features: {processed_dataset.features}")

    print(f"\nSaving processed dataset to: {args.output_dataset}")
    processed_dataset.save_to_disk(args.output_dataset)

    print("\nSuccess! Your processed dataset is saved and ready.")
    print(f"You can now load it for training using: datasets.load_from_disk('{args.output_dataset}')")


def _map_json_dataset(args):
    """Clamp ``x``/``label`` of the JSON data at ``args.input_dataset`` and write JSON Lines to ``args.output_dataset``."""
    print(f"\nLoading JSON data: {args.input_dataset}")
    original_dataset = load_dataset("json", data_files=args.input_dataset)

    print(f"\nApplying map function using {args.num_proc} processes...")
    processed_dataset = original_dataset.map(
        processing_function2,     # per-batch clamping function
        batched=True,             # batched mode is key for throughput
        num_proc=args.num_proc,   # number of worker processes
    )

    # load_dataset with a plain data_files argument places all rows in the
    # "train" split; write one JSON object per line.
    print(f"\nSaving processed records to: {args.output_dataset}")
    with open(args.output_dataset, "w", encoding="utf-8") as f:
        for record in processed_dataset["train"]:
            f.write(json.dumps(record) + "\n")

    print("\nSuccess! Your processed dataset is saved and ready.")


def main():
    """Command-line entry point: clamp token ids in a dataset and save the result.

    If ``--input_dataset`` is a directory, it is loaded as an Arrow dataset
    (``load_from_disk``). ``processing_function`` is mapped over its
    ``input_ids`` column, and the result is written with ``save_to_disk``.

    Otherwise the path is loaded with ``load_dataset("json", ...)``.
    ``processing_function2`` is mapped over the ``x``/``label`` columns, and
    the ``train`` split is written to ``--output_dataset`` as JSON Lines.

    ``--num_proc`` values below 1 select every available CPU core.
    """
    parser = argparse.ArgumentParser(description="Clamp token ids in a disk-based Arrow dataset or in JSON data.")
    parser.add_argument("--input_dataset", type=str, default="data/processed/pretrain_pro_arrow", help="Arrow dataset directory (created by save_to_disk), or a JSON data file readable by load_dataset('json').")
    parser.add_argument("--output_dataset", type=str, default="data/processed/pretrain_pro_clean_arrow", help="Output directory for the processed Arrow dataset, or the output JSON Lines file when the input is JSON.")
    parser.add_argument("--num_proc", type=int, default=12, help="Number of CPU cores to use for mapping (default: 12). Pass 0 to use all available cores.")

    args = parser.parse_args()

    if args.num_proc < 1:
        # Use every available core for maximum speed. os.cpu_count() can
        # return None, so fall back to a single process in that case.
        requested = args.num_proc
        args.num_proc = os.cpu_count() or 1
        print(f"--num_proc={requested} requested all cores. Using {args.num_proc} available cores.")

    if os.path.isdir(args.input_dataset):
        _map_arrow_dataset(args)
    else:
        _map_json_dataset(args)

if __name__ == "__main__":
    # Run the command-line pipeline only when executed as a script, not on import.
    main()