import pandas as pd
import numpy as np
import os
import re
def get_pairs(user_data, round_number):
    """
    Collect the distinct conversation partners of one round.

    Rows whose event_type is "tweet" are ignored, as are rows where either
    sender_id or recipient_id is missing. Direction is discarded: (a, b) and
    (b, a) collapse into a single entry.

    Args:
        user_data (pd.DataFrame): Event log with the columns chat_round_order,
            event_type, sender_id and recipient_id.
        round_number (int): Value of chat_round_order to select.

    Returns:
        list: Unique tuples of two ids, each tuple sorted in ascending order.
              The order of the tuples within the list is not defined.
    """
    in_round = user_data["chat_round_order"] == round_number
    not_tweet = user_data["event_type"] != "tweet"

    # Keep only the two id columns of the relevant, fully populated rows.
    endpoints = user_data[in_round & not_tweet]
    endpoints = endpoints[["sender_id", "recipient_id"]]
    id_rows = endpoints.dropna().drop_duplicates().values.tolist()

    # Sorting each row makes the pair direction-independent before de-duplication.
    unordered_pairs = {tuple(sorted(id_row)) for id_row in id_rows}
    return list(unordered_pairs)

def _has_content(series):
    """
    Return a boolean mask of the entries that are non-null and not whitespace-only.

    Values are cast to str before stripping so the check also works on columns
    that pandas did not load as strings (for example an all-NaN column read as
    float64, on which the .str accessor would raise). The 'nan' strings that
    the cast produces for missing values are masked out by notna().
    """
    return series.notna() & (series.astype(str).str.strip() != '')


def fix_message_order(user_data):
    """
    Re-align llm_text with text for every pair of every round.

    For each round and pair, the earliest 'message_sent' event (by event_order)
    that has blank text but non-blank llm_text marks the start of a misaligned
    stretch. From that event onwards, for that pair and round:

    1. the non-blank llm_text values of the 'message_sent' events are collected
       in event_order;
    2. llm_text is set to '' on every event of the pair (not only sent ones);
    3. the collected values are written back, in order, onto the
       'message_sent' events whose text is non-blank.

    If there are more collected values than events with text, the surplus
    values are dropped; if there are fewer, the remaining events keep ''.

    Args:
        user_data (pd.DataFrame): DataFrame containing conversation data with
            the columns chat_round_order, event_type, sender_id, recipient_id,
            event_order, text and llm_text.

    Returns:
        pd.DataFrame: A modified copy; user_data itself is left untouched.
    """
    # Make a copy to avoid modifying original
    df = user_data.copy()

    # Neither event_type nor text is modified below, so these frame-wide masks
    # can be computed once. Building them on the full frame (instead of on a
    # per-pair subset) keeps every mask on the same index.
    is_sent = df['event_type'] == 'message_sent'
    has_text = _has_content(df['text'])

    # Get unique rounds
    rounds = df['chat_round_order'].unique()
    print(f"Found {len(rounds)} rounds")  # Debug print

    for round_num in rounds:
        pairs = get_pairs(df, round_num)
        print(f"Round {round_num}: Found {len(pairs)} pairs")  # Debug print

        for pair in pairs:
            # All events of this round exchanged between the two ids, in either direction
            round_mask = (df['chat_round_order'] == round_num) & (
                ((df['sender_id'] == pair[0]) & (df['recipient_id'] == pair[1])) |
                ((df['sender_id'] == pair[1]) & (df['recipient_id'] == pair[0]))
            )
            pair_rows = df[round_mask]

            print(f"Round {round_num}, Pair {pair}: Found {len(pair_rows)} messages")  # Debug print

            # Sent events with blank text but real llm_text; all three masks
            # share the index of pair_rows.
            misaligned = (
                is_sent[round_mask]
                & ~has_text[round_mask]
                & _has_content(pair_rows['llm_text'])
            )
            first_pos = pair_rows.loc[misaligned, 'event_order'].min()

            # min() of an empty selection is NaN: nothing to repair for this pair
            if pd.isna(first_pos):
                continue

            # Every event of the pair from the first misaligned one onwards.
            # The event found above is itself a sent event inside this window,
            # so the sent selection below is never empty.
            window = round_mask & (df['event_order'] >= first_pos)
            sent_window = window & is_sent

            # Collect the llm_text values BEFORE clearing the column
            sent_rows = df[sent_window].sort_values('event_order')
            llm_messages = sent_rows.loc[
                _has_content(sent_rows['llm_text']), 'llm_text'
            ].tolist()

            print(f"Found {len(llm_messages)} LLM messages to redistribute")  # Debug print

            # Clear llm_text column for these messages
            df.loc[window, 'llm_text'] = ''

            # Sent events that actually carry text receive the collected
            # values in event_order; zip stops at the shorter sequence.
            targets = df[sent_window & has_text].sort_values('event_order').index
            for idx, llm_message in zip(targets, llm_messages):
                df.loc[idx, 'llm_text'] = llm_message

    return df

def _is_blank(value):
    """Return True when value is missing (None/NaN) or a whitespace-only string."""
    if pd.isna(value):
        return True
    return isinstance(value, str) and value.strip() == ''


def pair_messages(user_data):
    """
    Copy each sent message's llm_text onto its matching received event.

    For every 'message_sent' row, the first 'message_recieved' row (in frame
    order) whose text equals the sent text is looked up. When that row has
    non-blank text and blank llm_text, the sent row's llm_text is written to
    the llm_text of every row whose event_order equals that of the received
    row. Sent rows with missing text never match anything.

    Args:
        user_data (pd.DataFrame): DataFrame containing conversation data with
            the columns event_type, event_order, text and llm_text.

    Returns:
        pd.DataFrame: A modified copy; user_data itself is left untouched.
    """
    df = user_data.copy()

    is_sent = df['event_type'] == 'message_sent'
    # NOTE(review): 'message_recieved' (sic) is kept exactly as the original
    # code spells it; presumably this matches the value stored in the data
    # files - confirm before correcting the spelling.
    is_received = df['event_type'] == 'message_recieved'

    # Snapshot of the sent events; their text/llm_text are read from here.
    sent_rows = df[is_sent]

    for sent_text, sent_llm_text in zip(sent_rows['text'], sent_rows['llm_text']):
        # Matching is by text only, and always picks the first candidate.
        candidates = df[is_received & (df['text'] == sent_text)]
        if candidates.empty:
            continue
        received = candidates.iloc[0]

        # Only fill in llm_text when the received event has real text AND no
        # llm_text yet. _is_blank also copes with non-string values, on which
        # a bare .strip() would raise.
        if _is_blank(received['text']) or not _is_blank(received['llm_text']):
            continue

        # Every row sharing this event_order is updated; assumes event_order
        # identifies a single event within the frame - TODO confirm.
        df.loc[df['event_order'] == received['event_order'], 'llm_text'] = sent_llm_text

    return df
    

if __name__ == "__main__":
    # Walk the simulation results tree and rewrite, in place, every CSV file
    # found in a directory whose path matches the 2025-03 / 2025-04 pattern.
    results_root = '../../result/simulation'
    for dir_path, _subdirs, file_names in os.walk(results_root):
        # The directory filter depends only on dir_path, so evaluate it once
        # per directory; the skip message is still printed once per file.
        in_target_period = re.match(r'.*2025(03|04).*', dir_path) is not None

        for file_name in file_names:
            if not in_target_period:
                print(f"Skipping {file_name}")
                continue
            if not file_name.endswith('.csv'):
                continue

            csv_path = os.path.join(dir_path, file_name)
            print(f"Processing {csv_path}")

            # Load the log, re-align llm_text, then fill llm_text on the
            # received events.
            raw = pd.read_csv(csv_path)
            processed = pair_messages(fix_message_order(raw))

            # The result overwrites the file it was read from.
            processed.to_csv(csv_path, index=False)
            print(f"Saved processed file to {csv_path}")
    # For a one-off run on a single file, load it with pd.read_csv, apply
    # fix_message_order() followed by pair_messages(), and write the result
    # to a separate path with to_csv(..., index=False).