import json
import os
import pandas as pd
import re

def count_words_in_jsonl(input_dir):
    """
    Count words in JSONL files, breaking down by message type.

    Recursively walks ``input_dir`` and reads every file whose name ends in
    ``.jsonl``. Each non-blank line is parsed as a JSON object whose
    ``messages`` list holds dicts with ``role`` and ``content`` keys. Words
    are counted with ``str.split()``. ``None`` content counts as zero words;
    other non-string content is converted with ``str()`` first. Only the
    roles ``system``, ``user`` and ``assistant`` are tallied, and each of
    them also adds to ``total``.

    A file that cannot be opened or parsed is reported on stdout and the
    remainder of that file is skipped. Counts gathered before the error are
    kept.

    Args:
        input_dir (str): Path to the directory containing the JSONL files

    Returns:
        dict: Word counts keyed by 'system', 'user', 'assistant' and 'total'
    """
    roles = ('system', 'user', 'assistant')
    counts = {
        'system': 0,
        'user': 0,
        'assistant': 0,
        'total': 0
    }

    # Walk through all files in the directory
    for root, _, files in os.walk(input_dir):
        for file_name in files:
            if not file_name.endswith('.jsonl'):
                continue
            file_path = os.path.join(root, file_name)
            try:
                # JSON Lines is UTF-8 by definition; don't rely on the platform default.
                with open(file_path, 'r', encoding='utf-8') as handle:
                    for line in handle:
                        line = line.strip()
                        if not line:
                            # Blank lines are not records; json.loads('') would raise
                            # and abort the rest of the file.
                            continue
                        data = json.loads(line)
                        for message in data['messages']:
                            # Handle potential None or non-string content
                            content = message.get('content')
                            if not isinstance(content, str):
                                content = str(content) if content is not None else ''

                            role = message.get('role', '')
                            # Check against the explicit role list rather than `counts`,
                            # so a role literally named 'total' is not double-counted.
                            if role in roles:
                                word_count = len(content.split())
                                counts[role] += word_count
                                counts['total'] += word_count
            except Exception as e:
                # Best-effort: report the offending file by path and move on.
                print(f"Error reading file {file_path}: {e}")
                continue
    return counts

def find_file_by_sender_id(directory, target_id):
    """
    Find CSV files containing a sender_id that matches the first 5 chars of target_id.

    Recursively walks ``directory`` and loads every ``.csv`` file with pandas.
    Files without a ``sender_id`` column are ignored. Files that fail to load
    are reported on stdout and skipped. Each sender_id is converted to a
    string and truncated to 5 characters; missing IDs become ''.

    Args:
        directory (str): Directory path to search
        target_id: The sender ID to match against; converted with ``str()``
            and truncated to its first 5 characters

    Returns:
        list: Base names (not full paths) of the matching files, in walk
        order. The list is also printed, as before.
    """
    matching_files = []
    target_prefix = str(target_id)[:5]

    for root, _, files in os.walk(directory):
        for file in files:
            if not file.endswith('.csv'):
                continue
            file_path = os.path.join(root, file)
            try:
                df = pd.read_csv(file_path)
            except Exception as e:
                print(f"Error reading {file}: {str(e)}")
                continue
            if 'sender_id' not in df.columns:
                continue
            # Compare string prefixes so numeric and textual IDs are treated alike.
            # Build a separate Series instead of overwriting the DataFrame column.
            prefixes = df['sender_id'].fillna('').astype(str).str[:5]
            if (prefixes == target_prefix).any():
                matching_files.append(file)

    print(matching_files)
    return matching_files

def get_words_in_file(file_path):
    """
    Count words in one JSONL file, broken down by message role.

    Each non-blank line is parsed as a JSON object whose ``messages`` list
    holds dicts with ``role`` and ``content`` keys. Words are counted with
    ``str.split()``. ``None`` content counts as zero words; other non-string
    content is converted with ``str()`` first. Only the roles ``system``,
    ``user`` and ``assistant`` are tallied, and each also adds to ``total``.

    Args:
        file_path (str): Path to a UTF-8 encoded JSONL file.

    Returns:
        dict: Word counts keyed by 'system', 'user', 'assistant' and 'total'.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid UTF-8 or a non-blank line is not
            valid JSON (``json.JSONDecodeError`` subclasses ``ValueError``).
        KeyError: If a record has no 'messages' key.
    """
    roles = ('system', 'user', 'assistant')
    counts = {
        'system': 0,
        'user': 0,
        'assistant': 0,
        'total': 0
    }
    # JSON Lines is UTF-8 by definition; don't rely on the platform default.
    with open(file_path, 'r', encoding='utf-8') as handle:
        for line in handle:
            line = line.strip()
            if not line:
                # Blank lines are not records; json.loads('') would raise.
                continue
            data = json.loads(line)
            for message in data['messages']:
                content = message.get('content')
                if not isinstance(content, str):
                    content = str(content) if content is not None else ''
                role = message.get('role', '')
                # Check against the explicit role list rather than `counts`,
                # so a role literally named 'total' is not double-counted.
                if role in roles:
                    word_count = len(content.split())
                    counts[role] += word_count
                    counts['total'] += word_count
    return counts

def get_words_in_csv(input_dir, root_pattern=r".*2025(03|04).*"):
    """
    Count words in the 'text' and 'llm_text' columns of CSV files in selected directories.

    Walks ``input_dir`` recursively. A directory is considered only when its
    path matches ``root_pattern`` via ``re.match`` (anchored at the start of
    the path). By default that selects paths containing ``202503`` or
    ``202504``. Words are counted with ``str.split()``; empty cells
    contribute nothing.

    Args:
        input_dir (str): Directory to walk recursively.
        root_pattern (str): Regex matched against each walked directory path.

    Returns:
        dict: Total word counts as plain ints, keyed by 'text' and 'llm_text'.
    """
    # Compile once; the pattern does not change across directories.
    pattern = re.compile(root_pattern)
    word_counts = {
        'text': 0,
        'llm_text': 0
    }
    for root, _, files in os.walk(input_dir):
        # The filter depends only on the directory, so check it once per directory.
        if pattern.match(root) is None:
            continue
        for file in files:
            if not file.endswith('.csv'):
                continue
            df = pd.read_csv(os.path.join(root, file))
            for column in word_counts:
                if column not in df.columns:
                    continue
                # Drop empty cells and cast to str first: an all-empty column is
                # parsed as float64, and the .str accessor rejects non-string dtypes.
                cells = df[column].dropna().astype(str)
                # int() so callers get plain Python ints, not numpy.int64.
                word_counts[column] += int(cells.str.split().str.len().sum())
    return word_counts

# Example usage
if __name__ == "__main__":
    # Other entry points, kept for manual runs:
    #   find_file_by_sender_id("./data/processed_data", "61004")
    #   print(get_words_in_csv("../../result/simulation"))
    # Alternative JSONL location: "./data/finetune_data/chatgpt_data/train"
    target_dir = "./"
    totals = count_words_in_jsonl(target_dir)

    print("\nWord count breakdown:")
    report_rows = (
        ("System messages", "system"),
        ("User messages", "user"),
        ("Assistant messages", "assistant"),
        ("Total words", "total"),
    )
    for label, key in report_rows:
        print(f"{label}: {totals[key]}")
