import pandas as pd
import os
from ast import literal_eval
import re
# Topic statements, one string per topic.
# NOTE(review): nothing in the visible part of this file reads `topic_list`;
# presumably these are the opinion prompts shown to participants -- confirm
# against other modules before changing or removing an entry.
topic_list = [
    "Everything that happens can eventually be explained by science",
    "The position of the planets at the time of your birth can influence your personality",
    "Angels are real",
    "Regular fasting will improve your health",
    'A "body cleanse," in which you consume only particular kinds of nutrients over 1-3 days, helps your body to eliminate toxins',
    "The US deficit increased after President Obama was elected",
    "The United States has the highest federal income tax rate of any Western country",
]
    

def get_pairs(user_data, round_number):
    """
    Retrieves the unique pairs of players that interacted in a given round.

    Rows whose event_type is "tweet" are ignored, as are rows where either
    sender_id or recipient_id is missing.

    Args:
        user_data (pd.DataFrame): The DataFrame containing message data. Must
            have the columns chat_round_order, event_type, sender_id and
            recipient_id.
        round_number (int): The round number to retrieve pairs for.
    Returns:
        list: A list of 2-tuples of agent ids. The two ids inside each tuple
              are sorted, so (a, b) and (b, a) collapse into one pair and the
              tuple does NOT encode who was the sender. Pairs are listed in
              the order they first appear in user_data, which keeps the
              result identical from run to run.
    """
    in_round = user_data["chat_round_order"] == round_number
    not_tweet = user_data["event_type"] != "tweet"
    id_rows = user_data.loc[in_round & not_tweet, ["sender_id", "recipient_id"]].dropna()

    # dict.fromkeys de-duplicates while preserving first-seen order; a plain
    # set would make the order depend on string hashing.
    canonical_pairs = (tuple(sorted(pair)) for pair in id_rows.values.tolist())
    return list(dict.fromkeys(canonical_pairs))

def concat_with_pairs(user_data):
    """
    Merges runs of consecutive messages that go in the same direction.

    For every pair of agents returned by get_pairs, the pair's rows are walked
    in DataFrame order. Within the 'message_sent' rows (and, tracked
    separately, within the 'message_recieved' rows) a row that has the same
    sender_id and recipient_id as the row that started the current run is
    folded into that first row: its text is appended (space separated), the
    first row takes over its 'time' and 'end_time', and the folded row's text
    is set to ''. Rows are addressed through their 'event_order' value.

    Args:
        user_data (pd.DataFrame): Event log with the columns chat_round_order,
            event_type, sender_id, recipient_id, event_order, text, time and
            end_time.
    Returns:
        pd.DataFrame: The same DataFrame object, modified in place.
    """
    max_round = user_data["chat_round_order"].max()
    if pd.isna(max_round):
        # No chat rounds at all (empty frame or all-NaN column): nothing to merge.
        return user_data
    num_rounds = int(max_round)

    # NOTE(review): the row selection below is not restricted to the round the
    # pair was found in, so one pass already covers every round of a pair.
    # Each pair is therefore processed once; a second pass would only append
    # stray spaces to text that has already been merged.
    processed_pairs = set()

    for round_number in range(1, num_rounds + 1):
        for pair in get_pairs(user_data, round_number):
            if pair in processed_pairs:
                continue
            processed_pairs.add(pair)

            sender, recipient = pair
            pair_mask = (
                ((user_data['sender_id'] == sender) & (user_data['recipient_id'] == recipient)) |
                ((user_data['sender_id'] == recipient) & (user_data['recipient_id'] == sender))
            )
            # Snapshot taken before any edits, so row['text'] below is always
            # the row's original text.
            pair_data = user_data[pair_mask]

            # Per event type: ((sender_id, recipient_id), event_order) of the
            # row that currently absorbs follow-up messages.
            run_heads = {'message_sent': None, 'message_recieved': None}

            for _, row in pair_data.iterrows():
                event_type = row['event_type']
                if event_type not in run_heads:
                    continue

                direction = (row['sender_id'], row['recipient_id'])
                head = run_heads[event_type]
                if head is None or head[0] != direction:
                    # Direction changed (or first message): start a new run.
                    run_heads[event_type] = (direction, row['event_order'])
                    continue

                head_rows = user_data['event_order'] == head[1]
                # Missing text would make the string concatenation fail.
                fragment = '' if pd.isna(row['text']) else str(row['text'])
                if fragment:
                    user_data.loc[head_rows, 'text'] += ' ' + fragment
                user_data.loc[user_data['event_order'] == row['event_order'], 'text'] = ''

                user_data.loc[head_rows, 'time'] = row['time']
                user_data.loc[head_rows, 'end_time'] = row['end_time']

    return user_data

def remove_last_row(df: pd.DataFrame) -> pd.DataFrame:
    """
    Decide whether the final row of the log belongs to this session.

    The group_id of the first 'Initial Opinion' row is the reference. The
    last row's group_id is compared against it as follows:

    * a string of the form "[a, b, ...]" is split into a list of ids;
      a one-element list makes the function return None, a longer list
      makes it drop the last row;
    * any other value is tested with ``not in`` against the reference
      (a substring test when both are strings); on a miss the last row is
      dropped.

    Args:
        df (pd.DataFrame): Input dataframe containing group_id and event_type columns

    Returns:
        pd.DataFrame: ``df`` itself when nothing has to change (including when
        there is no group_id column), a re-indexed copy without the last row,
        or None in the single-element-list case.

    Raises:
        IndexError: If there is no 'Initial Opinion' row.
        TypeError: If the ``not in`` comparison is not supported for the
            values involved (e.g. non-string group ids).
    """
    if 'group_id' not in df.columns:
        return df

    # group_id of the first Initial Opinion row
    opinion_rows = df['event_type'] == 'Initial Opinion'
    reference_group = df[opinion_rows]['group_id'].iloc[0]

    trailing_group = df.iloc[-1]['group_id']
    if isinstance(trailing_group, str):
        try:
            stripped = trailing_group.strip()
            # "[id1, id2]" -> ['id1', 'id2']
            if stripped.startswith('[') and stripped.endswith(']'):
                trailing_group = [part.strip() for part in stripped[1:-1].split(',')]
        except Exception as e:
            print(f"Error parsing group_id: {e}")
            print(f"Problematic group_id: {trailing_group}")
            return None

    is_id_list = isinstance(trailing_group, list)
    if is_id_list and len(trailing_group) == 1:
        return None

    if is_id_list or trailing_group not in reference_group:
        return df.iloc[:-1].reset_index(drop=True)

    return df

def correct_ids(df: pd.DataFrame) -> pd.DataFrame:
    """
    Fill in missing worker IDs and repair message IDs that contain 'undefined'.

    A missing worker_id is filled from another row that carries the same
    empirica_id together with a worker_id. Empirica IDs for which no worker_id
    is known anywhere get a placeholder name ('Alice', 'Bob', ...), assigned
    in sorted empirica_id order so the naming is reproducible. Rows where both
    ids are missing keep a missing worker_id.

    Afterwards every string message_id containing 'undefined' has that
    substring replaced by the row's (filled) worker_id.

    Args:
    df (pandas.DataFrame): Input dataframe containing worker_id, empirica_id
        and message_id columns. It is modified in place.

    Returns:
    pandas.DataFrame: The same dataframe with corrected IDs
    """
    # empirica_id -> worker_id for every row where both are present
    known = df[['worker_id', 'empirica_id']].dropna().drop_duplicates()
    empirica_to_worker = dict(zip(known.empirica_id, known.worker_id))

    # Only empirica ids without ANY known worker id need a placeholder;
    # otherwise one participant would end up under two different names.
    # dropna() keeps sorted() from comparing NaN with strings.
    missing_worker = df['worker_id'].isna()
    unnamed_ids = sorted(
        emp_id
        for emp_id in df.loc[missing_worker, 'empirica_id'].dropna().unique()
        if emp_id not in empirica_to_worker
    )

    # NOTE: names repeat once there are more unnamed ids than placeholders.
    placeholder_names = ['Alice', 'Bob', 'Charlie', 'David', 'Eve', 'Frank', 'Grace', 'Henry']
    for position, emp_id in enumerate(unnamed_ids):
        empirica_to_worker[emp_id] = placeholder_names[position % len(placeholder_names)]

    # Plain list comprehensions (instead of DataFrame.apply(axis=1)) also
    # behave correctly on an empty dataframe.
    df['worker_id'] = [
        worker if pd.notna(worker) else empirica_to_worker.get(emp_id)
        for worker, emp_id in zip(df['worker_id'], df['empirica_id'])
    ]

    df['message_id'] = [
        message.replace('undefined', str(worker))
        if isinstance(message, str) and 'undefined' in message
        else message
        for message, worker in zip(df['message_id'], df['worker_id'])
    ]
    return df

def correct_player_names(df: pd.DataFrame, data_prefix: str) -> pd.DataFrame:
    """
    Shorten the ids in worker_id, sender_id and recipient_id.

    Every non-empty id becomes its first 5 characters followed by the last
    5 characters of data_prefix. Missing or empty ids become ''. The columns
    are overwritten on the dataframe that was passed in, which is also
    returned.
    """
    def shorten(value):
        # Empty string (this is what missing values were filled with) stays empty.
        if not value:
            return ''
        return str(value)[:5] + data_prefix[-5:]

    for column in ('worker_id', 'sender_id', 'recipient_id'):
        df[column] = df[column].fillna('').apply(shorten)
    return df

def remove_idle_events(df: pd.DataFrame) -> pd.DataFrame:
    """
    Drop every row whose event_type is 'idle' or 'refresh'.

    Args:
    df (pandas.DataFrame): Input dataframe with an event_type column

    Returns:
    pandas.DataFrame: New dataframe without those rows, index renumbered from 0
    """
    is_noise = df['event_type'].isin(['idle', 'refresh'])
    kept_rows = df[~is_noise]
    return kept_rows.reset_index(drop=True)

def assign_event_order(df: pd.DataFrame) -> pd.DataFrame:
    """
    Assign an event order to each row in the dataframe.

    Rows are numbered 1, 2, 3, ... in their current order and the number is
    written to the 'event_order' column (created if it does not exist,
    overwritten otherwise). Every row gets its own number; there is no
    special handling of "Post Opinion" or any other event type.

    The value is written by index label, so the index is expected to be
    unique (with duplicate labels, all rows sharing a label end up with the
    number of the last of them).

    Args:
    df (pandas.DataFrame): Input dataframe containing event data. It is
        modified in place.

    Returns:
    pandas.DataFrame: The same dataframe with the 'event_order' column set
    """
    # Only the index labels are needed, so avoid building a Series per row.
    for event_order, label in enumerate(df.index, start=1):
        df.at[label, 'event_order'] = event_order

    return df

def parse_special_suffixes(df: pd.DataFrame) -> pd.DataFrame:
    """
    Move the trailing "[...]" annotation of each text into dedicated columns.

    For every row that is not an exit_survey, the bracketed suffix at the end
    of 'text' is cut off (and the remaining text stripped). From that suffix
    two columns are derived:

    * 'sliderValue'     -- the integer after "SLIDER_VALUE=" as a float,
                           NaN when absent;
    * 'isAutosubmitted' -- True when the suffix contains
                           "AUTOSUBMISSION DUE TO TIME LIMIT" (any case).

    exit_survey rows keep their text untouched, get no sliderValue and have
    isAutosubmitted set to False.

    Args:
    df (pandas.DataFrame): Input dataframe with 'text' and 'event_type'
        columns; modified in place.

    Returns:
    pandas.DataFrame: The same dataframe with the two columns added
    """
    # Bracketed content anchored at the end of the text
    suffix_regex = r'\[(.*?)\]\s*$'

    # exit_survey answers are left exactly as they are
    editable = df['event_type'] != 'exit_survey'
    raw_text = df.loc[editable, 'text']

    # Capture the suffix before it is stripped from the text
    suffixes = raw_text.str.extract(suffix_regex, expand=False)
    cleaned_text = raw_text.str.replace(suffix_regex, '', regex=True).str.strip()
    df.loc[editable, 'text'] = cleaned_text

    slider_digits = suffixes.str.extract(r'SLIDER_VALUE=(\d+)', expand=False)
    df.loc[editable, 'sliderValue'] = slider_digits.astype(float)

    timed_out = suffixes.str.contains('AUTOSUBMISSION DUE TO TIME LIMIT', case=False, na=False)
    df.loc[editable, 'isAutosubmitted'] = timed_out

    df.loc[~editable, 'isAutosubmitted'] = False

    return df

def _infer_recipient(df, sent_row, known_workers):
    """Guess who received sent_row from other rows of the same chat round; None if unknown."""
    chat_round = sent_row['chat_round_order']
    current_sender = sent_row['sender_id']

    in_round = df['chat_round_order'] == chat_round
    has_recipient = df['recipient_id'].notna()

    # Other messages of this round that involve the sender on either side
    same_round_msgs = df[
        in_round &
        df['event_type'].isin(['message_sent', 'message_recieved']) &
        has_recipient &
        ((df['sender_id'] == current_sender) | (df['recipient_id'] == current_sender))
    ]
    if len(same_round_msgs) > 0:
        # The "other side" of each message: the sender when current_sender
        # was the recipient, the recipient otherwise.
        partners = same_round_msgs['sender_id'].where(
            same_round_msgs['recipient_id'] == current_sender,
            same_round_msgs['recipient_id'],
        ).unique()
        for partner in partners:
            if pd.notna(partner) and partner != '' and partner != current_sender:
                return partner
        return None

    # Nobody talked to the sender in this round: pick the first known worker
    # that does not already appear as a recipient in the round.
    round_recipients = set(df.loc[in_round & has_recipient, 'recipient_id'])
    for worker in known_workers:
        if worker not in round_recipients and worker != current_sender:
            return worker
    return None


def fix_data_loss(user_data: pd.DataFrame) -> pd.DataFrame:
    """
    Fix data loss (for message_recieved) by duplicating rows where the event_type is message_sent.

    A message_sent row counts as "lost" when no message_recieved row has the
    same text. For each such row a copy is added with event_type
    'message_recieved', worker_id set to the recipient and empirica_id looked
    up from the worker_id/empirica_id pairs present in the data (left empty
    when the recipient has no known empirica_id). If the sent row has no
    recipient_id, one is inferred from the other rows of the same chat round;
    when that fails the row is skipped with a printed notice.

    Args:
    user_data (pandas.DataFrame): Event log with the columns worker_id,
        empirica_id, event_type, text, sender_id, recipient_id and
        chat_round_order. 'event_order' is additionally required whenever at
        least one row has to be added. The input is not modified.

    Returns:
    pandas.DataFrame: A copy of the input; when rows were added, they are
    appended and the result is sorted by event_order.
    """
    df = user_data.copy()

    # worker_id -> empirica_id for every row where both are present
    known_ids = df[['worker_id', 'empirica_id']].dropna().drop_duplicates()
    worker_to_empirica = dict(zip(known_ids.worker_id, known_ids.empirica_id))

    # Texts that did arrive; built once instead of scanning df per sent message
    received_texts = set(df.loc[df['event_type'] == 'message_recieved', 'text'].dropna())

    new_rows = []
    for _, sent_row in df[df['event_type'] == 'message_sent'].iterrows():
        sent_text = sent_row['text']
        # A missing text never equals anything, so such rows are always "lost".
        if pd.notna(sent_text) and sent_text in received_texts:
            continue

        received_row = sent_row.copy()
        received_row['event_type'] = 'message_recieved'

        recipient = received_row['recipient_id']
        if pd.isna(recipient) or recipient == '':
            recipient = _infer_recipient(df, sent_row, worker_to_empirica)
            if recipient is None:
                print(f"Could not determine the recipient of a message sent by {sent_row['sender_id']}; no message_recieved row added")
                continue
            received_row['recipient_id'] = recipient

        received_row['worker_id'] = recipient
        # .get: a recipient without a known empirica_id must not abort the run
        received_row['empirica_id'] = worker_to_empirica.get(recipient)
        new_rows.append(received_row)

    if new_rows:
        df = pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)

        # Each added row shares its event_order with the sent row it was
        # copied from; a stable sort keeps the copy right after the original.
        df = df.sort_values('event_order', kind='stable')

    return df

def remove_duplicate_logs(df: pd.DataFrame) -> pd.DataFrame:
    """
    Drop repeated log rows while leaving exit_survey rows alone.

    Two non-exit_survey rows count as duplicates when they agree on
    worker_id, empirica_id, event_type and text; only the first one is kept.
    The result lists the de-duplicated rows first, followed by all
    exit_survey rows, with a fresh 0..n-1 index.
    """
    is_exit_survey = df['event_type'] == 'exit_survey'

    # Regular rows: flag every repeat after the first occurrence
    regular_rows = df[~is_exit_survey]
    is_repeat = regular_rows.duplicated(
        subset=['worker_id', 'empirica_id', 'event_type', 'text'],
        keep='first',
    )
    unique_rows = regular_rows[~is_repeat]

    # exit_survey rows are never de-duplicated
    survey_rows = df[is_exit_survey]

    return pd.concat([unique_rows, survey_rows], ignore_index=True)

def main():
    """
    Clean every raw experiment CSV and write the result to processed_data.

    Walks <raw folder>/<data>/<experiment>/<file>.csv, skips files whose name
    contains one of the excluded dates, runs the cleaning steps in order and
    writes "<data_prefix>.csv" to the processed_data folder. Files for which
    remove_last_row fails or returns None are reported and skipped.
    """
    excluded_dates = ["1015", "1028", "0921", "1013"]
    count = 0
    folder_path = os.path.join("../../data", "raw_data", "phase_2_breadth_topics")
    output_dir = os.path.join("../../data", "processed_data")
    # to_csv does not create missing directories
    os.makedirs(output_dir, exist_ok=True)

    # sorted(): os.listdir order is arbitrary; sorting makes the log output
    # (and which file wins if two share a data_prefix) reproducible.
    for data in sorted(os.listdir(folder_path)):
        data_dir = os.path.join(folder_path, data)
        # Top-level CSVs and stray files (e.g. OS metadata) are not run folders
        if data.endswith(".csv") or not os.path.isdir(data_dir):
            continue
        for experiment in sorted(os.listdir(data_dir)):
            experiment_dir = os.path.join(data_dir, experiment)
            if not os.path.isdir(experiment_dir):
                continue
            for file in sorted(os.listdir(experiment_dir)):
                if not file.endswith(".csv"):
                    continue
                if any(date in file for date in excluded_dates):
                    continue

                data_prefix = file[:-4]  # Remove .csv extension
                if '_0.0.1' in data_prefix:
                    # Drops the 6 trailing characters (the length of '_0.0.1')
                    data_prefix = data_prefix[:-6]

                df = pd.read_csv(os.path.join(experiment_dir, file))
                try:
                    df = remove_last_row(df)
                except (IndexError, TypeError):
                    # IndexError: no 'Initial Opinion' row;
                    # TypeError: group ids that cannot be compared
                    print(f"Error removing last row for {file}")
                    continue
                if df is None:
                    print(f"Skipping {file} because it was removed")
                    continue

                count += 1
                print(f"Processing {file}")
                df = fix_data_loss(df)
                df = remove_duplicate_logs(df)
                df = assign_event_order(df)
                df = remove_idle_events(df)
                df = parse_special_suffixes(df)
                df = correct_ids(df)
                df = correct_player_names(df, data_prefix)
                df = concat_with_pairs(df)
                df.to_csv(os.path.join(output_dir, f"{data_prefix}.csv"), index=False)
    print(f"Total processed files: {count}")


if __name__ == "__main__":
    main()