import json
import random
import sys

def process_pems_dataset(input_file_path, output_file_path, sample_rate=0.5, max_length=8192, seed=None):
    """Randomly sample a JSON dataset and truncate over-long instructions.

    The input file must contain a JSON array. ``int(len(data) * sample_rate)``
    entries are drawn without replacement (returned in random order). For every
    sampled object whose ``"instruction"`` value is a string longer than
    ``max_length`` characters, only the LAST ``max_length`` characters are
    kept. Entries without a string ``"instruction"`` are passed through
    unchanged. The result is written to ``output_file_path`` as pretty-printed
    UTF-8 JSON.

    Args:
        input_file_path: Path of the JSON file to read.
        output_file_path: Path of the JSON file to write (overwritten).
        sample_rate: Fraction of entries to keep; must lie in ``[0, 1]``.
        max_length: Maximum instruction length in characters; must be >= 1.
        seed: Optional seed for a private ``random.Random`` so the sample is
            reproducible. ``None`` (default) uses the module-level ``random``
            state, as before.

    Returns:
        bool: True on success; False if an argument is invalid or reading,
        parsing, or writing fails (an explanatory message is printed).
    """
    # Validate up front: random.sample raises on a negative/oversized sample,
    # and a non-positive max_length would make the tail slice below misbehave
    # ([-0:] keeps everything, [-(-k):] drops a prefix).
    if not 0 <= sample_rate <= 1:  # also rejects NaN
        print(f"Invalid sample rate {sample_rate}: must be between 0 and 1")
        return False
    if max_length < 1:
        print(f"Invalid maximum length {max_length}: must be a positive integer")
        return False

    try:
        with open(input_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error reading file: {e}")
        return False

    if not isinstance(data, list):
        print(f"Error reading file: expected a JSON array, got {type(data).__name__}")
        return False

    print(f"Successfully read {len(data)} data entries")

    # round() absorbs floating-point noise before truncating to int, e.g.
    # 100 * 0.29 == 28.999999999999996 would otherwise yield 28 instead of 29.
    sample_size = int(round(len(data) * sample_rate, 6))
    rng = random if seed is None else random.Random(seed)
    sampled_data = rng.sample(data, sample_size)

    print(f"Randomly selected {sample_size} data entries (proportion: {sample_rate})")

    truncated_count = 0
    for item in sampled_data:
        instruction = item.get("instruction") if isinstance(item, dict) else None
        if isinstance(instruction, str) and len(instruction) > max_length:
            # Keep the tail of the instruction, dropping its beginning.
            item["instruction"] = instruction[-max_length:]
            truncated_count += 1

    # random.sample already returned a fresh list, so no copy is needed.
    processed_data = sampled_data

    print(f"Processing completed: Total {len(processed_data)} data entries, {truncated_count} data entries were truncated")

    try:
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(processed_data, f, ensure_ascii=False, indent=2)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error writing file: {e}")
        return False

    print(f"Successfully wrote processed data to: {output_file_path}")
    return True

def parse_cli_args(argv, sample_rate, max_length):
    """Apply optional command-line overrides for sample rate and max length.

    ``argv[1]`` (if present) is parsed as the sample rate and accepted when
    ``0 < rate <= 1``; ``argv[2]`` (if present) is parsed as the maximum
    instruction length and accepted when it is a positive integer. An argument
    that is present but malformed or out of range is replaced by its fallback
    value (0.5 / 8192) with a printed notice; an absent argument leaves the
    corresponding input value unchanged.

    Args:
        argv: Argument vector in ``sys.argv`` layout (``argv[0]`` is ignored).
        sample_rate: Value returned when no sample-rate argument is given.
        max_length: Value returned when no max-length argument is given.

    Returns:
        tuple: ``(sample_rate, max_length)`` after applying the overrides.
    """
    fallback_sample_rate = 0.5
    fallback_max_length = 8192

    if len(argv) > 1:
        try:
            sample_rate = float(argv[1])
        except ValueError:
            # Actually apply the fallback the message announces.
            print(f"Sample rate parameter format error, using default value {fallback_sample_rate}")
            sample_rate = fallback_sample_rate
        else:
            # Negated chained comparison so NaN (which fails every comparison)
            # is rejected instead of crashing later in int(len * nan).
            if not 0 < sample_rate <= 1:
                print(f"Sample rate must be between 0-1, using default value {fallback_sample_rate}")
                sample_rate = fallback_sample_rate
            else:
                print(f"Sample rate set to: {sample_rate}")

    if len(argv) > 2:
        try:
            max_length = int(argv[2])
        except ValueError:
            print(f"Maximum length parameter format error, using default value {fallback_max_length}")
            max_length = fallback_max_length
        else:
            if max_length <= 0:
                print(f"Maximum length must be positive, using default value {fallback_max_length}")
                max_length = fallback_max_length
            else:
                print(f"Maximum length set to: {max_length}")

    return sample_rate, max_length


def main():
    """Run the PEMS dataset processing step with optional CLI overrides.

    Usage: ``python <script> [sample_rate] [max_length]``. Input and output
    paths are fixed relative paths. Without arguments every entry is kept
    (sample rate 1), and the very large default ``max_length`` means no
    realistic instruction is truncated.

    Returns:
        int: Process exit status, 0 on success and 1 on failure.
    """
    input_file = "./Fine_tunning/Data_processing_and_data/pems_dataset.json"
    output_file = "./Fine_tunning/Data_processing_and_data/pems_dataset_processed.json"
    sample_rate = 1
    # Effectively "no truncation": far larger than any real instruction.
    max_length = 100000000000000000000000000000000000000000000000

    sample_rate, max_length = parse_cli_args(sys.argv, sample_rate, max_length)

    print(f"Starting to process file: {input_file}")
    print(f"Using parameters: sample_rate={sample_rate}, max_length={max_length}")

    success = process_pems_dataset(input_file, output_file, sample_rate, max_length)

    if success:
        print("Data processing successfully completed!")
        print(f"Processed file saved at: {output_file}")
        return 0

    print("Data processing failed! Please check error messages")
    return 1


if __name__ == "__main__":
    # Propagate failure to the shell so pipelines can detect it.
    sys.exit(main())