from transformers import AutoModel
import pandas as pd
from pandas.core.groupby import DataFrameGroupBy
import matplotlib.pyplot as plt
import numpy as np

# Embedding model used by score(). trust_remote_code=True allows the custom
# Python code shipped with the model repository to run, and device_map="cuda"
# asks for the weights to be placed on a CUDA GPU.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True, device_map="cuda")
# Raw simulation log; it is filtered, scored and re-exported further down.
simulation_df = pd.read_csv("../human_data_simulation_1.csv")
# Number of rows (one per player) that score() requires in every group; it is
# also the side length of the pairwise similarity matrix.
n_players = 4
# True: compare the "llm_text" column and title the plot "LLMs Score".
# False: compare the "text" column and title the plot "Humans Score".
is_llm = False
# Column whose values identify the players inside a group.
player_id_field = "empirica_id"

def score(group: pd.DataFrame) -> pd.DataFrame:
    """Attach the pairwise text-similarity score of one group of players.

    Intended to be passed to ``DataFrameGroupBy.apply``, which calls it with a
    plain ``DataFrame`` holding the rows of a single group (one row per
    player). The function reads the module-level ``model``, ``n_players``,
    ``is_llm`` and ``player_id_field``.

    Args:
        group: Rows of one group. Must contain ``player_id_field`` and the
            text column (``"llm_text"`` when ``is_llm`` is true, otherwise
            ``"text"``).

    Returns:
        A copy of ``group`` with two extra columns, constant within the group:
        ``score`` (mean of the ``n_players`` x ``n_players`` similarity
        matrix) and ``score_matrix`` (``repr`` of that matrix).

    Raises:
        ValueError: If the group does not hold exactly ``n_players`` rows or
            if a player id appears more than once.
    """
    text_column = "llm_text" if is_llm else "text"
    result = np.zeros((n_players, n_players))

    # An LLM group with any missing response keeps the all-zero matrix and
    # therefore gets a score of 0.
    if not (is_llm and group["llm_text"].isna().any()):
        player_ids = group[player_id_field].tolist()
        # Explicit checks instead of `assert`, which is stripped under `-O`.
        if len(player_ids) != n_players:
            raise ValueError(
                f"expected {n_players} rows per group, got {len(player_ids)}"
            )
        if len(set(player_ids)) != len(player_ids):
            raise ValueError(f"duplicate {player_id_field} values in group: {player_ids}")

        # Embed every text once; the previous per-pair encoding embedded each
        # text n_players - 1 times.
        texts = group[text_column].tolist()
        embeddings = model.encode(texts, task="text-matching")

        for i in range(n_players):
            for j in range(i + 1, n_players):
                # Dot product of the two embeddings; this is a cosine
                # similarity if encode() returns unit-length vectors
                # (presumably the case -- TODO confirm against the model docs).
                similarity = (embeddings[i] @ embeddings[j].T).item()
                result[i, j] = similarity
                result[j, i] = similarity  # symmetric matrix

    # NOTE(review): the diagonal is left at 0, so result.mean() averages over
    # all n_players**2 cells (zero diagonal included), not only over the
    # distinct pairs. Kept as-is so scores stay comparable with earlier runs.
    scored = group.copy()  # never write into the frame handed over by pandas
    scored["score"] = result.mean()
    scored["score_matrix"] = repr(result)
    return scored

# Discard rows whose event type is exactly "message_recieved" (literal matched
# verbatim); every remaining row takes part in the scoring.
simulation_df = simulation_df.loc[simulation_df["event_type"] != "message_recieved"]

# Score each webpage_order group. include_groups=False keeps the grouping
# column out of the frame passed to score(); reset_index() brings it back as
# a regular column (the remaining index level becomes a column as well).
simulation_df = (
    simulation_df
    .groupby("webpage_order")
    .apply(score, include_groups=False)
    .reset_index()
)
simulation_df.to_csv("human_data_simulation_1_score.csv", index=False)

# Per-time-step report: mean score for every webpage_order value.
report_df = simulation_df.groupby("webpage_order").agg({"score": ["mean"]})
report_df.to_csv("human_data_simulation_1_score_report.csv")

# Plot similarity score (y-axis) against time step (x-axis), one line per event type.
fig, ax = plt.subplots()
for kind, rows in simulation_df.groupby("event_type"):
    ax.plot(rows["webpage_order"], rows["score"], linestyle="-", marker=".", label=kind)
ax.set_xlabel("Time Step (webpage order)")
ax.set_ylabel("Similarity Score")
ax.set_ylim(-0.01, 1.01)
ax.legend()
if is_llm:
    ax.set_title("LLMs Score")
else:
    ax.set_title("Humans Score")
fig.savefig("human_data_simulation_1_score_plot.png")
