from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import json
import re
import os
import typing


def get_event_order_round_separators(simulation_df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute, per chat round, the first and last event_order of its "message_sent" events.

    Args:
        simulation_df: events with "event_type", "chat_round_order" and "event_order" columns.

    Returns:
        DataFrame with columns "chat_round_order", "round_start" (min event_order) and
        "round_end" (max event_order), one row per round that contains messages.
    """
    round_separators = simulation_df[simulation_df["event_type"] == "message_sent"].groupby("chat_round_order").agg({
        "event_order": ["min", "max"]
    })
    round_separators.columns = ["round_start", "round_end"]
    # Previously the result was computed and discarded (implicit `return None`).
    return round_separators.reset_index()


def get_chat_order_and_separators(simulation_df: pd.DataFrame) -> typing.Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Assign a shared "chat_order" x-position to events and compute per-round chat_order bounds.

    Each event gets a pseudo_id of the form (round, message_id):
      - "Initial Opinion" -> (0, 0)
      - "tweet"           -> (int(chat_round_order), 0)
      - "message_sent"    -> (chat_round_order, k), where k is 1 + the number of earlier
                             "message_sent" events (smaller event_order) by the same
                             empirica_id in the same round
      - "Post Opinion"    -> (4, 0)  (hard-coded, see comment below)
    Distinct pseudo_ids, visited in sorted (chat_round_order, message_id) order, are numbered
    0, 1, 2, ... to form "chat_order". Because pseudo_id does not include empirica_id, the k-th
    message of different players in the same round share one chat_order.

    Args:
        simulation_df: events with at least "event_type", "chat_round_order", "event_order"
            and "empirica_id" columns. The caller's frame is not modified; a re-indexed copy
            is used.

    Returns:
        (df, round_separators):
          df -- the re-indexed copy sorted by (chat_round_order, message_id), with added
                "pseudo_id", "message_id" and "chat_order" columns, and "chat_round_order"
                overwritten by the pseudo_id round (0 for Initial Opinion, 4 for Post Opinion).
          round_separators -- one row per chat_round_order among "message_sent" rows, with
                columns "chat_round_order", "round_start" (min chat_order) and
                "round_end" (max chat_order).

    Raises:
        ValueError: if an event_type other than the four listed above is present.
    """
    simulation_df = simulation_df.reset_index(drop=True)
    
    # State for mapping each distinct pseudo_id to a consecutive chat_order
    chat_order_counter = 0
    chat_order_map = {}

    # Map one event row to its (round, message_id) pseudo_id
    def generate_pseudo_id(row):
        if row['event_type'] == 'Initial Opinion':
            return (0, 0)
        elif row['event_type'] == 'tweet':
            return int(row['chat_round_order']), 0
        elif row['event_type'] == 'message_sent':
            # message_id = 1 + number of earlier (smaller event_order) "message_sent" events by the
            # same empirica_id in the same round.
            # NOTE(review): this filters the whole frame once per row, i.e. O(n^2) overall.
            message_id = simulation_df[(simulation_df["chat_round_order"] == row["chat_round_order"]) & (simulation_df["event_order"] < row["event_order"]) & (simulation_df["event_type"] == "message_sent") & (simulation_df["empirica_id"] == row["empirica_id"])].shape[0] + 1
            return (row['chat_round_order'], message_id)
        elif row['event_type'] == 'Post Opinion':
            # Hard-coded: assumes 3 chat rounds so Post Opinion sorts after them
            # (general form would be simulation_df["chat_round_order"].max() + 1).
            return (4, 0)
        else:
            raise ValueError(f"Unknown event type: {row['event_type']}")

    # Compute pseudo_ids and split them into the round / message_id columns
    pseudo_ids = simulation_df.apply(generate_pseudo_id, axis=1)  # (chat_round_order, message_id)
    simulation_df['pseudo_id'] = pseudo_ids
    simulation_df['chat_round_order'] = pseudo_ids.apply(lambda x: x[0])
    simulation_df['message_id'] = pseudo_ids.apply(lambda x: x[1])
    # NOTE: sort_values' default kind is not stable, so rows with equal keys may be reordered.
    simulation_df = simulation_df.sort_values(by=["chat_round_order", "message_id"])
    chat_orders = []

    # Dense numbering: first time a pseudo_id is seen (in sorted order) it gets the next counter value
    for pseudo_id in simulation_df['pseudo_id']:
        if pseudo_id not in chat_order_map:
            chat_order_map[pseudo_id] = chat_order_counter
            chat_order_counter += 1
        chat_orders.append(chat_order_map[pseudo_id])

    # Add chat_order column to the DataFrame (positional: chat_orders follows the sorted row order)
    simulation_df['chat_order'] = chat_orders

    # Per round, the first and last chat_order occupied by message_sent events
    round_separators = simulation_df[simulation_df["event_type"] == "message_sent"].groupby("chat_round_order").agg({
        "chat_order": ["min", "max"]
    })
    round_separators.columns = ["round_start", "round_end"]
    round_separators = round_separators.reset_index()

    return simulation_df, round_separators


def plot_round_separators(round_separators: pd.DataFrame, y: float, ax=None):
    """
    Draw red dashed vertical lines at each round's start and end plus a centered "Round N" label.

    Args:
        round_separators: rows with "chat_round_order", "round_start" and "round_end"
            (the shape returned by get_chat_order_and_separators).
        y: y coordinate at which the round labels are placed.
        ax: Axes to draw on; when None, drawing goes through the pyplot module instead.
    """
    # pyplot and an Axes expose the same axvline/text calls, so choose the target once.
    target = plt if ax is None else ax
    for _, sep in round_separators.iterrows():
        start = sep["round_start"]
        end = sep["round_end"]
        center = (start + end) / 2
        target.axvline(x=start, color="r", linestyle="--", alpha=0.5)
        target.axvline(x=end, color="r", linestyle="--", alpha=0.5)
        target.text(center, y, f"Round {int(sep['chat_round_order'])}",
                    horizontalalignment='center', verticalalignment='bottom',
                    fontsize=10)


def merge_consecutive_messages(message_events: pd.DataFrame):
    """
    Merge runs of consecutive messages from the same sender within each chat pair and round.

    For every round 1..max(chat_round_order) and every unordered pair of players who exchanged
    messages in that round, the pair's messages are ordered by event_order and split into runs
    of consecutive messages from the same sender. Each run is collapsed into its first row (in
    frame order), whose "text" becomes the space-joined text of the run and whose "event_order"
    becomes the smallest event_order in the run.

    Rows that belong to no processed pair (e.g. chat_round_order < 1 or NaN, or a missing
    sender/recipient) are kept unchanged.

    Args:
        message_events: "message_sent" events with "chat_round_order", "sender_id",
            "recipient_id", "text" and "event_order" columns. Modified in place.

    Returns:
        The same frame, with each merged run reduced to a single row.
    """
    if len(message_events) == 0:
        return message_events

    max_round = message_events["chat_round_order"].max()
    n_rounds = 0 if pd.isna(max_round) else int(max_round)

    # Per-row run id; -1 marks rows that were not part of any processed pair. Deduplicating on
    # this id (rather than on the merged text) keeps distinct messages that happen to share text.
    group_col = "_merge_group"
    message_events[group_col] = -1
    next_group = 0

    for round_num in range(1, n_rounds + 1):
        round_mask = message_events["chat_round_order"] == round_num

        # Unordered pairs (A, B) that exchanged messages this round, without (B, A) duplicates.
        endpoints = message_events.loc[round_mask, ["sender_id", "recipient_id"]].dropna()
        pairs = list(dict.fromkeys(
            tuple(sorted((sender, recipient)))
            for sender, recipient in zip(endpoints["sender_id"], endpoints["recipient_id"])
        ))

        for player_a, player_b in pairs:
            sender = message_events["sender_id"]
            recipient = message_events["recipient_id"]
            # Messages exchanged between A and B only (either direction) in this round.
            pair_mask = round_mask & (
                ((sender == player_a) & (recipient == player_b))
                | ((sender == player_b) & (recipient == player_a))
            )
            pair_messages = message_events.loc[pair_mask].sort_values("event_order", kind="stable")
            if pair_messages.empty:
                continue

            # A new run starts whenever the sender differs from the previous message's sender.
            run_id = (pair_messages["sender_id"] != pair_messages["sender_id"].shift()).cumsum()
            runs = pair_messages.groupby(run_id)

            message_events.loc[pair_mask, "text"] = runs["text"].transform(lambda texts: " ".join(texts))
            message_events.loc[pair_mask, "event_order"] = runs["event_order"].transform("min")
            message_events.loc[pair_mask, group_col] = run_id + next_group
            next_group += int(run_id.max())

    # Keep only the first row of every merged run; untouched rows (-1) are always kept.
    merged_away = (message_events[group_col] >= 0) & message_events.duplicated(subset=[group_col], keep="first")
    message_events.drop(index=message_events.index[merged_away], inplace=True)
    message_events.drop(columns=[group_col], inplace=True)
    return message_events


def preprocess_simulation_df(simulation_df: pd.DataFrame, consecutive_messages: bool = True):
    """
    Drop rows with empty/missing text and keep only the event types used downstream.

    Args:
        simulation_df: events with "text", "event_type" and "event_order" columns.
        consecutive_messages: when True, "message_sent" rows are passed through
            merge_consecutive_messages (on a copy), re-combined with the
            "Initial Opinion"/"Post Opinion"/"tweet" rows, re-indexed and sorted by event_order.
            When False, the relevant rows are simply selected (original index kept).

    Returns:
        The filtered (and optionally merged) DataFrame.
    """
    has_text = simulation_df["text"].notna() & (simulation_df["text"] != "")
    simulation_df = simulation_df[has_text]

    if not consecutive_messages:
        kept_types = ["Initial Opinion", "message_sent", "Post Opinion", "tweet"]
        return simulation_df[simulation_df["event_type"].isin(kept_types)]

    is_message = simulation_df["event_type"] == "message_sent"
    merged_messages = merge_consecutive_messages(simulation_df[is_message].copy())
    non_message_types = ["Initial Opinion", "Post Opinion", "tweet"]
    others = simulation_df[simulation_df["event_type"].isin(non_message_types)]
    combined = pd.concat([merged_messages, others], ignore_index=True)
    return combined.sort_values("event_order")


def save_csv(df: pd.DataFrame, path: str):
    """Write df to path as CSV without the index, creating missing parent directories."""
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(destination, index=False)


def save_json(data, path: str):
    """Serialize data as JSON (4-space indent) to path, creating missing parent directories."""
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w") as handle:
        json.dump(data, handle, indent=4)


def dp_to_topic(data_prefix: str):
    """
    Extract the human-readable topic from a data prefix.

    The prefix is expected to look like "<8 digits>_<6 digits>_<topic>_<26-char id>..."
    (e.g. "20250403_191343_Everything_that_happens_..._science_01JQYFJHV7HBPCHA5D36FZDM3Z").
    Underscores in the topic become spaces and runs of spaces are collapsed to one.

    Raises:
        ValueError: if data_prefix does not match the expected pattern (previously this
            surfaced as an opaque AttributeError on ``None.group``).
    """
    match = re.search(r'\d{8}_\d{6}_(.*)_.{26}', data_prefix)
    if match is None:
        raise ValueError(
            f"Data prefix {data_prefix!r} does not match '<8 digits>_<6 digits>_<topic>_<26-char id>'"
        )
    topic = match.group(1).replace('_', ' ')
    return re.sub(r' +', ' ', topic)


def load_simulation_df(
        data_prefix: str,
        model_name: typing.Optional[str] = None,
        eval_model_name: typing.Optional[str] = None,
        version: str = "human",
        filter_invalid: bool = False,
        load_validity: bool = False,
        file_path: typing.Optional[str] = None
    ):
    """
    Load simulation dataframe for either human or LLM data.

    When validity is requested (filter_invalid or load_validity) AND eval_model_name is given,
    the validity CSV is loaded; otherwise the plain simulation CSV is loaded. An explicit
    file_path overrides the default location in both cases.

    Args:
        data_prefix: The data prefix identifier for the experiment
        model_name: If provided, loads LLM data for this model; if None, loads human data
        eval_model_name: Must be non-None for validity data to be loaded/filtered. Note that it
            is not part of the validity file path.
        version: Which version of simulation data to load ("human", "v0", "v1", "v2", etc.);
            only used in the LLM default paths
        filter_invalid: Whether to drop "message_sent" rows whose "validity" is not "VALID"
            (other event types are always kept)
        load_validity: Whether to load the validity file (keeping its "validity" column)
        file_path: If provided, loads data from this file path instead of the default path

    Returns:
        DataFrame containing simulation data, optionally filtered by validity

    Raises:
        FileNotFoundError: if file_path or the resolved default file does not exist.
    """
    base_path = "../../result"

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if file_path is not None and not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    use_validity = (filter_invalid or load_validity) and eval_model_name is not None

    if use_validity:
        if file_path is not None:
            path = file_path
        elif model_name is None:
            path = os.path.join(base_path, "eval", "validity", data_prefix, "simulation-human.csv")
        else:
            path = os.path.join(base_path, "eval", "validity", data_prefix, model_name, f"simulation-llm-{version}.csv")
        missing_message = f"Validity file not found: {path}"
    else:
        if file_path is not None:
            path = file_path
        elif model_name is None:
            # Human data
            path = os.path.join("../../data", "processed_data", f"{data_prefix}.csv")
        else:
            # LLM data
            path = os.path.join(base_path, "simulation", data_prefix, model_name, f"simulation-llm-{version}.csv")
        missing_message = f"Simulation file not found: {path}"

    if not os.path.exists(path):
        raise FileNotFoundError(missing_message)

    df = pd.read_csv(path)

    if use_validity and filter_invalid:
        # Only message_sent rows are subject to validity filtering.
        df = df[(df["validity"] == "VALID") | (df["event_type"] != "message_sent")]

    return df


def load_simulation_dfs(
        data_prefix: str,
        model_name: typing.Optional[str] = None,
        eval_model_name: typing.Optional[str] = "gpt-4o-mini-2024-07-18",
        version: str = "v0",
        filter_strategy: typing.Literal["none", "any", "both", "separate", "human_only", "llm_only"] = "none",
        preprocess: bool = True,
        consecutive_messages: bool = True,
        mark_invalid_players: bool = False,
        human_file_path: typing.Optional[str] = None,
        llm_file_path: typing.Optional[str] = None
    ):
    """
    Load both human and LLM simulation dataframes with optional filtering and preprocessing.

    Args:
        data_prefix: The data prefix identifier for the experiment
        model_name: The model name for LLM data
        eval_model_name: If provided, enables loading of validity data
        version: Which version of simulation data to load for LLM data ("v0", "v1", "v2", etc.)
        filter_strategy: How to filter invalid "message_sent" events (other events are kept):
            - "none": No filtering
            - "any": Filter if either human or LLM message is invalid
            - "both": Filter only if both human and LLM messages are invalid
            - "separate": Filter human and LLM data independently (may result in different data sizes)
            - "human_only": Filter only by human validity
            - "llm_only": Filter only by LLM validity
        preprocess: Whether to preprocess the dataframes using preprocess_simulation_df
        consecutive_messages: Whether to merge consecutive messages (used if preprocess=True)
        mark_invalid_players: Whether to mark all events from invalid players (determined by
            get_invalid_players) as invalid. Ignored by the "separate" strategy, which returns early.
        human_file_path: Optional explicit path for the human data (forwarded as file_path)
        llm_file_path: Optional explicit path for the LLM data (forwarded as file_path)

    Returns:
        Tuple of (human_df, llm_df) containing simulation data

    Raises:
        ValueError: if the human and LLM frames are not row-aligned (length, text, event_order),
            or if filtering is requested with a strategy that joint filtering does not support.
        FileNotFoundError: propagated from load_simulation_df.
    """
    # "separate": each frame is filtered on its own validity, no cross-frame alignment needed.
    if filter_strategy == "separate" and eval_model_name is not None:
        human_df = load_simulation_df(
            data_prefix=data_prefix,
            model_name=None,  # Human data
            eval_model_name=eval_model_name,
            version="human",
            filter_invalid=True,
            load_validity=False,
            file_path=human_file_path
        )
        llm_df = load_simulation_df(
            data_prefix=data_prefix,
            model_name=model_name,
            eval_model_name=eval_model_name,
            version=version,
            filter_invalid=True,
            load_validity=False,
            file_path=llm_file_path
        )
        if preprocess:
            human_df = preprocess_simulation_df(human_df, consecutive_messages=consecutive_messages)
            llm_df = preprocess_simulation_df(llm_df, consecutive_messages=consecutive_messages)
        return human_df, llm_df

    # Other strategies: load unfiltered, then filter both frames with a shared decision.
    load_validity = filter_strategy != "none" and eval_model_name is not None

    human_df = load_simulation_df(
        data_prefix=data_prefix,
        model_name=None,  # Human data
        eval_model_name=eval_model_name,
        version="human",
        filter_invalid=False,  # Filtering happens below, after comparing both frames
        load_validity=load_validity,
        file_path=human_file_path
    )
    llm_df = load_simulation_df(
        data_prefix=data_prefix,
        model_name=model_name,
        eval_model_name=eval_model_name,
        version=version,
        filter_invalid=False,  # Filtering happens below, after comparing both frames
        load_validity=load_validity,
        file_path=llm_file_path
    )

    # Explicit checks rather than `assert` (stripped under `python -O`).
    if len(human_df) != len(llm_df):
        raise ValueError("Human and LLM dataframes have different lengths")
    same_text = (human_df["text"] == llm_df["text"]) | (human_df["text"].isna() & llm_df["text"].isna())
    if not same_text.all():
        raise ValueError("Human and LLM dataframes have different texts")
    same_order = (human_df["event_order"] == llm_df["event_order"]) | (human_df["event_order"].isna() & llm_df["event_order"].isna())
    if not same_order.all():
        raise ValueError("Human and LLM dataframes have different event_orders")

    human_df["llm_text"] = llm_df["llm_text"]
    human_df["input_prompt"] = llm_df["input_prompt"]
    human_df["agreement_level"] = llm_df["agreement_level"]

    if mark_invalid_players:
        invalid_players = get_invalid_players(human_df)

        if "validity" not in human_df.columns:
            human_df["validity"] = ""
        if "validity" not in llm_df.columns:
            llm_df["validity"] = ""

        for player_id in invalid_players:
            # na=False: rows without a sender_id must give False, not NaN (NaN masks make .loc raise).
            human_df.loc[human_df["sender_id"].str.startswith(player_id, na=False), "validity"] = "INVALID"
            if model_name is not None:
                llm_df.loc[llm_df["sender_id"].str.startswith(player_id, na=False), "validity"] = "INVALID"

        # A validity column now exists, so the filtering below can run.
        load_validity = True

    if filter_strategy != "none" and load_validity:
        # Align human and LLM validity per event.
        merged = pd.merge(
            human_df[["event_order", "validity", "event_type"]].rename(columns={"validity": "human_validity"}),
            llm_df[["event_order", "validity", "event_type"]].rename(columns={"validity": "llm_validity"}),
            on=["event_order", "event_type"],
            how="inner"
        )
        is_message = merged["event_type"] == "message_sent"
        human_valid = merged["human_validity"] == "VALID"
        llm_valid = merged["llm_validity"] == "VALID"

        if filter_strategy == "any":
            valid_events = merged.loc[(human_valid & llm_valid) | ~is_message, "event_order"].tolist()
        elif filter_strategy == "both":
            # Drop a message only when BOTH versions are invalid. (The previous condition used
            # `!= "message_sent"`, which never dropped messages and could drop non-message events.)
            invalid_events = merged.loc[~human_valid & ~llm_valid & is_message, "event_order"].tolist()
            valid_events = merged.loc[~merged["event_order"].isin(invalid_events), "event_order"].tolist()
        elif filter_strategy == "human_only":
            valid_events = merged.loc[human_valid | ~is_message, "event_order"].tolist()
        elif filter_strategy == "llm_only":
            valid_events = merged.loc[llm_valid | ~is_message, "event_order"].tolist()
        else:
            raise ValueError(f"Unsupported filter_strategy for joint filtering: {filter_strategy!r}")

        human_df = human_df[human_df["event_order"].isin(valid_events)]
        llm_df = llm_df[llm_df["event_order"].isin(valid_events)]

        # Rename columns to avoid confusion once frames are compared side by side
        if "validity" in human_df.columns:
            human_df = human_df.rename(columns={"validity": "human_validity"})
        if "validity" in llm_df.columns:
            llm_df = llm_df.rename(columns={"validity": "llm_validity"})

    if preprocess:
        human_df = preprocess_simulation_df(human_df, consecutive_messages=consecutive_messages)
        llm_df = preprocess_simulation_df(llm_df, consecutive_messages=consecutive_messages)

    return human_df, llm_df


def filter_data_prefixes(data_prefixes: typing.List[str], csv_file: str, is_fully_complete: typing.Optional[bool] = None, source: typing.Optional[typing.Literal["sona", "prolific"]] = None):
    """
    Keep only the data prefixes that appear in a listing CSV.

    A prefix is kept when it occurs as a substring of some "csv_filename" entry (with ".csv"
    removed); substring matching lets a bare prefix match a versioned filename such as
    "20250403_191343_<topic>_<id>_0.0.1.csv".

    Args:
        data_prefixes: List of data prefix strings to filter
        csv_file: Path to CSV file listing valid data files in a "csv_filename" column
        is_fully_complete: If given and the column exists, only rows with this value are considered
        source: If given and the column exists, only rows with this "source" value are considered

    Returns:
        The matching prefixes in their original order. If the CSV has no "csv_filename" column,
        a warning is printed and data_prefixes is returned unchanged.
    """
    table = pd.read_csv(csv_file)

    for column, wanted in (("is_fully_complete", is_fully_complete), ("source", source)):
        if wanted is not None and column in table.columns:
            table = table[table[column] == wanted]

    if 'csv_filename' not in table.columns:
        print(f"Warning: 'csv_filename' column not found in {csv_file}")
        return data_prefixes

    listed = [filename.replace('.csv', '') for filename in table['csv_filename']]
    return [prefix for prefix in data_prefixes if any(prefix in name for name in listed)]
    
import os
import re

def extract_topic_versions(name: str, uid_prefix: str = "01J"):
    """
    Split an underscore-separated name into its topic and trailing uid.

    Expected shape: "<part0>_<part1>_<topic parts...>_<uid>". The first two parts (a timestamp
    in this file's examples) are skipped, the last part is the uid, and everything in between is
    the topic (underscores preserved).

    Args:
        name: the name to split, without file extension or version suffix.
        uid_prefix: required prefix of the 26-character uid. Defaults to "01J" (the previously
            hard-coded value). NOTE(review): the uid looks like a ULID, whose leading characters
            encode the creation time; ids created later may not start with "01J" -- pass a
            different prefix (or "") for such names.

    Returns:
        (topic, uid)

    Raises:
        ValueError: if name has fewer than 4 parts or the uid is not 26 characters long
            starting with uid_prefix.
    """
    parts = name.split("_")
    if len(parts) < 4:
        raise ValueError(f"Filename {name} does not match expected pattern.")

    # parts[0] and parts[1] hold the timestamp, which is not part of the result.
    uid = parts[-1]
    topic = "_".join(parts[2:-1])

    if not (len(uid) == 26 and uid.startswith(uid_prefix)):
        raise ValueError(f"UID {uid} does not match expected format.")

    return topic, uid


# from simulate_conversation.py
def get_invalid_players(user_data, n_rounds: int = 3):
    """
    Retrieve the invalid players from the user data.

    Players are identified by the first 5 characters of "sender_id" (or of "worker_id" for
    "exit_survey" events). Every player seen as a non-null sender_id is invalid if ANY of:
      - they have fewer than n_rounds "tweet" events (players with zero tweets included),
      - they have no "exit_survey" event,
      - they have no "message_sent" event.

    Args:
        user_data: event frame with "event_type", "sender_id" and "worker_id" columns.
        n_rounds: number of tweets each player is expected to have (default 3).

    Returns:
        Sorted list of invalid 5-character player-id prefixes.
    """
    def _prefixes(values):
        # Null ids are skipped; every other id is truncated to its 5-character player prefix.
        return values.dropna().map(lambda x: x[:5])

    event_type = user_data["event_type"]
    all_players = _prefixes(user_data["sender_id"]).unique()

    tweet_counts = _prefixes(user_data.loc[event_type == "tweet", "sender_id"]).value_counts()
    surveyed = set(_prefixes(user_data.loc[event_type == "exit_survey", "worker_id"]))
    messaged = set(_prefixes(user_data.loc[event_type == "message_sent", "sender_id"]))

    # tweet_counts.get(..., 0): players who never tweeted are absent from value_counts
    # and previously escaped the "missing tweets" check.
    invalid = {
        player for player in all_players
        if tweet_counts.get(player, 0) < n_rounds
        or player not in surveyed
        or player not in messaged
    }
    return sorted(invalid)

