import json
from transformers import AutoTokenizer
from tqdm import tqdm
import os




def load_jsonl(file_path):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        file_path: Path to a UTF-8 encoded file holding one JSON document
            per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing empty line) carry no record and would
        # make json.loads raise, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]

def process_files(tokenizer, data_folder_path, output_path):
    """Write the token-length distribution of a dataset to a JSONL file.

    Reads ``train.jsonl``, ``val.jsonl`` and ``test.jsonl`` from
    ``data_folder_path``, encodes the ``'prompt'`` and ``'generated'`` fields
    of every record with ``tokenizer``, and writes one JSON object per record
    (``input_length`` and ``generated_length``) to ``output_path``.

    Args:
        tokenizer: Object exposing ``encode(text)`` whose result supports
            ``len()``.
        data_folder_path: Directory containing the three split files.
        output_path: Destination JSONL file; it is overwritten if it exists.
            Its parent directory is created when missing.

    Raises:
        OSError: If a split file cannot be read or the output cannot be
            written.
        KeyError: If a record lacks the ``'prompt'`` or ``'generated'`` key.
    """
    # Create the destination directory before the (slow) tokenization pass so
    # a missing folder cannot discard the work at the final write.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # All splits are pooled into a single distribution.
    files = ['train.jsonl', 'val.jsonl', 'test.jsonl']
    all_data = []

    for file_name in files:
        file_path = os.path.join(data_folder_path, file_name)
        print(f"Processing {file_name}...")
        data = load_jsonl(file_path)

        for item in tqdm(data):
            # NOTE(review): encode() runs with its default arguments; for
            # Hugging Face tokenizers that normally counts special tokens
            # (e.g. BOS) as well -- confirm this is the intended length.
            all_data.append({
                'input_length': len(tokenizer.encode(item['prompt'])),
                'generated_length': len(tokenizer.encode(item['generated'])),
            })

    # One JSON object per line, in the order the records were read.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + '\n' for item in all_data
        )

    print(f"Distribution data saved to {output_path}")

if __name__ == "__main__":
    # (model label, tokenizer source) pairs; the label is the prefix of the
    # dataset folder and of the output file name.
    tokenizer_sources = (
        ("gemma27b", "/root/autodl-pub/models/gemma-2-27b"),
        ("llama8b", "meta-llama/Llama-3.1-8B"),
    )
    dataset_names = ("lmsys", "sharegpt")

    for model_name, tokenizer_source in tokenizer_sources:
        # Load each tokenizer once and reuse it for every dataset of that
        # model instead of reloading it per dataset.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
        for dataset_name in dataset_names:
            run_name = f"{model_name}-{dataset_name}"
            process_files(
                tokenizer,
                f"datasets/{run_name}",
                f"draw/others/dataset_distribution/{run_name}-distribution.jsonl",
            )
