# Use when aligning LLM data (i.e., when opinion_proc.py was run with llm_column=True)

import pandas as pd
import matplotlib.pyplot as plt
from . import util
import typing
import os


# Line colour for each initial Likert opinion (1-6): shades of red for 1-3 and
# shades of blue for 4-6. Key -1 is the fallback colour used for opinions that
# are not whole numbers.
opinion_color_map = {-1: "gray"}
opinion_color_map.update({1: "darkred", 2: "red", 3: "pink"})
opinion_color_map.update({4: "lightblue", 5: "blue", 6: "darkblue"})


def plot_opinion_trajectory(df, opinion, players, player_name_col, n_rounds, is_llm=False, n_prefixes=1):
    """
    Plot the per-round mean ``likert_pred`` of one initial-opinion group.

    For each round in ``range(n_rounds)``, the rows of ``df`` whose
    ``chat_round_order`` equals the round and whose ``player_name_col`` is in
    ``players`` are averaged into one point. When ``n_prefixes > 1`` the
    standard error of that mean is also drawn as error bars.

    Args:
        df: DataFrame with ``chat_round_order``, ``likert_pred`` and
            ``player_name_col`` columns.
        opinion: The initial opinion value shared by ``players``. It selects the
            line colour from ``opinion_color_map`` and appears in the legend.
        players: Identifiers of the players with this initial opinion.
        player_name_col: Column name holding player identifiers.
        n_rounds: Number of rounds (x positions 0..n_rounds-1) to plot.
        is_llm: Whether this is LLM data (True) or human data (False). Controls
            the marker, the line style and the legend label.
        n_prefixes: Number of experiments/prefixes in the data. Error bars are
            only computed and drawn when this is greater than 1.

    Returns:
        Tuple ``(round_scores, round_errors)`` of per-round lists. Rounds with
        no matching rows hold ``None`` in both lists. ``round_errors`` is 0 for
        every populated round when ``n_prefixes <= 1``.
    """
    show_errors = n_prefixes > 1
    # The player filter does not depend on the round, so build it once.
    player_mask = df[player_name_col].isin(players)

    round_scores = []
    round_errors = []
    for round_num in range(n_rounds):
        scores = df.loc[(df["chat_round_order"] == round_num) & player_mask, "likert_pred"]
        if scores.empty:
            # Keep a placeholder so list positions stay aligned with round numbers.
            round_scores.append(None)
            round_errors.append(None)
            continue

        round_scores.append(scores.mean())
        # Standard error of the mean. Series.sem() divides by the number of
        # non-null values, consistent with the NaN-skipping std(). Dividing by
        # len() (all rows) would understate the error when predictions are missing.
        round_errors.append(scores.sem() if show_errors else 0)

    marker = 'x' if is_llm else '+'
    linestyle = 'dotted' if is_llm else '-'
    entity_type = "Agent" if is_llm else "Human"
    # float() makes is_integer() available for Python ints too (int.is_integer
    # only exists from Python 3.12). Non-integer opinions (e.g. NaN) and integer
    # opinions missing from the map use the -1 fallback colour instead of raising.
    opinion_key = int(opinion) if float(opinion).is_integer() else -1
    color = opinion_color_map.get(opinion_key, opinion_color_map[-1])

    # ':g' prints whole numbers without decimals (3.0 -> "3") but keeps
    # fractional opinions readable (3.5 -> "3.5") instead of rounding them.
    plt.plot(
        range(len(round_scores)),
        round_scores,
        marker=marker,
        linestyle=linestyle,
        color=color,
        alpha=1.0,
        label=f"Initial: {opinion:g} ({entity_type})",
    )

    if show_errors:
        # errorbar needs numeric data, so skip the rounds that had no rows.
        valid_indices = [i for i, score in enumerate(round_scores) if score is not None]
        plt.errorbar(
            valid_indices,
            [round_scores[i] for i in valid_indices],
            yerr=[round_errors[i] for i in valid_indices],
            fmt='none',  # error bars only; the trajectory line is drawn above
            ecolor=color,
            capsize=4,
            alpha=0.7,
            elinewidth=1,
        )

    return round_scores, round_errors


# For one topic, plot each round's mean opinion, averaged over the players who share the same initial opinion score (one point per round for each initial-opinion group)
def main(topic: str, data_prefixes: typing.List[str], model_name: str, eval_model_save_name: str, version: str, player_name_col: str = "player_id"):
    """
    Plot per-round mean opinion trajectories, grouped by initial opinion, for one topic.

    For each prefix in ``data_prefixes``, the human and LLM opinion CSVs under
    ``../../result/eval/human_llm/{data_prefix}/{model_name}/`` are loaded.
    Prefixes with missing or empty files are skipped. Players are grouped by
    their first round-0 ``likert_pred`` in the pooled human data. For each
    group, one human line and one LLM line are drawn (in ascending opinion
    order). The figure is saved as an SVG under
    ``../../result/eval/opinion_topics/{model_name}/``.

    Args:
        topic: Topic name, used in the plot title, the output file name and log
            messages. NOTE(review): the loaded files are not filtered by topic here.
        data_prefixes: Experiment prefixes to load and pool together.
        model_name: Model directory name used in the input and output paths.
        eval_model_save_name: Name fragment embedded in the input/output file
            names (presumably the evaluation model; confirm).
        version: Version tag embedded in the input/output file names.
        player_name_col: Column identifying players in both CSVs.
    """
    human_dfs = {}  # data_prefix -> human_df (round_based, averaged)
    llm_dfs = {}  # data_prefix -> llm_df (round_based, averaged)

    for data_prefix in data_prefixes:
        loaded = _load_prefix_data(topic, data_prefix, model_name, eval_model_save_name, version)
        if loaded is None:
            continue
        human_df, llm_df = loaded
        human_dfs[data_prefix] = human_df
        llm_dfs[data_prefix] = llm_df

    if not human_dfs:
        print(f"No data available for topic '{topic}' in any of the provided data prefixes")
        return

    # Pool all prefixes. Every stored frame is non-empty, so the result is too.
    all_human_dfs = pd.concat(human_dfs.values())
    all_llm_dfs = pd.concat(llm_dfs.values())

    opinion_groups = _group_players_by_initial_opinion(all_human_dfs, player_name_col)

    # +2 for initial and post opinions. Series.max() skips NaN, unlike builtin max().
    n_rounds = int(all_human_dfs["chat_round_order"].max() + 2)
    num_prefixes = len(human_dfs)

    plt.figure(figsize=(12, 8))
    try:
        # Sorted so the legend lists initial opinions in ascending order.
        for opinion in sorted(opinion_groups):
            players = opinion_groups[opinion]
            # One human line and one LLM line per opinion group (aggregating all players).
            for is_llm, df in ((False, all_human_dfs), (True, all_llm_dfs)):
                plot_opinion_trajectory(
                    df=df,
                    opinion=opinion,
                    players=players,
                    player_name_col=player_name_col,
                    n_rounds=n_rounds,
                    is_llm=is_llm,
                    n_prefixes=num_prefixes,
                )

        plt.title(f"Opinion Trajectory by Initial Opinion - Topic: {topic} (n={num_prefixes})", wrap=True)
        plt.ylim(0.9, 6.1)
        plt.yticks([1, 2, 3, 4, 5, 6])
        plt.xlabel("Round")
        plt.ylabel("Mean Likert Score")
        plt.grid(True, alpha=0.3)
        plt.legend(fontsize="small", loc="best")

        output_dir = f"../../result/eval/opinion_topics/{model_name}"
        os.makedirs(output_dir, exist_ok=True)
        output_file = f"{output_dir}/opinion_plot_{topic}_{eval_model_save_name}_{version}.svg"
        plt.savefig(output_file)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close()


def _load_prefix_data(topic, data_prefix, model_name, eval_model_save_name, version):
    """
    Load the human and LLM opinion CSVs for one data prefix and add round labels.

    Returns:
        ``(human_df, llm_df)``, or ``None`` (after printing the reason) when
        either file is missing or empty.
    """
    base_dir = f"../../result/eval/human_llm/{data_prefix}/{model_name}"
    human_file = f"{base_dir}/opinion_human_memory_{eval_model_save_name}_{version}.csv"
    llm_file = f"{base_dir}/opinion_llm_memory_{eval_model_save_name}_{version}.csv"

    if not (os.path.exists(human_file) and os.path.exists(llm_file)):
        print(f"Skipping {data_prefix} as files do not exist")
        return None

    human_df = pd.read_csv(human_file)
    llm_df = pd.read_csv(llm_file)
    if human_df.empty or llm_df.empty:
        print(f"No data for topic '{topic}' in {data_prefix}")
        return None

    # util returns (df, separators). The separators are not used by this plot.
    human_df, _ = util.get_chat_order_and_separators(human_df)
    llm_df, _ = util.get_chat_order_and_separators(llm_df)
    # Copy the human chat order onto the LLM rows (aligned by index).
    # NOTE(review): assumes both CSVs list the same rows in the same order; verify.
    llm_df["chat_order"] = human_df["chat_order"]
    return human_df, llm_df


def _group_players_by_initial_opinion(df, player_name_col):
    """Map each initial opinion to the players whose first round-0 ``likert_pred`` has that value."""
    opinion_groups = {}
    for player_id, group in df[df["chat_round_order"] == 0].groupby(player_name_col):
        opinion_groups.setdefault(group["likert_pred"].iloc[0], []).append(player_id)
    return opinion_groups
