import os
from transformers import AutoModel
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from . import util
from .constants import data_prefix, player_name_col

# Input/output locations. Both are relative paths, so they resolve against the
# current working directory, not this file's location.
input_file = f"../../data/processed_data/{data_prefix}.csv"
output_path = "../../result/eval/human/"

# NOTE(review): trust_remote_code=True executes model code downloaded from the
# Hugging Face Hub; keep it only for trusted model repositories.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True, device_map="cuda")
simulation_df = pd.read_csv(input_file)
simulation_df = util.preprocess_simulation_df(simulation_df, consecutive_messages=True)
# Only sent chat messages are compared. .copy() gives an independent frame, so
# later column assignments do not write through a view of simulation_df.
message_events = simulation_df[simulation_df["event_type"] == "message_sent"].copy()

# NOTE(review): n_players is not used elsewhere in this script.
n_players = int(simulation_df["empirica_id"].nunique())
# The analysis below iterates rounds 1..n_rounds, i.e. it assumes round numbers
# start at 1 — TODO confirm against preprocess_simulation_df.
n_rounds = int(simulation_df["chat_round_order"].max())

def score(text1, text2):
    """Return the embedding similarity of two texts as a Python float.

    Both texts are embedded together in one ``model.encode`` call using the
    ``"text-matching"`` task. The result is the dot product of the two
    embedding vectors. Whether this equals cosine similarity depends on
    ``encode`` returning unit-normalised vectors, which is not visible here.
    """
    first_vec, second_vec = model.encode([text1, text2], task="text-matching")
    dot = first_vec @ second_vec.T
    return dot.item()


def _unordered_pairs(events):
    """Return the distinct unordered (player_a, player_b) pairs that exchanged messages.

    Each pair is a sorted tuple, so (A, B) and (B, A) collapse into one entry.
    The directed (sender, recipient) pairs are sorted first, matching the key
    order of ``groupby``. Unordered pairs are then kept in first-seen order.
    Rows with a missing sender or recipient are skipped, which is also what
    ``groupby`` does by default. An empty frame yields an empty list.
    """
    directed = (
        events[["sender_id", "recipient_id"]]
        .dropna()
        .drop_duplicates()
        .sort_values(["sender_id", "recipient_id"])
    )
    return list(dict.fromkeys(
        tuple(sorted(pair)) for pair in directed.itertuples(index=False, name=None)
    ))


def _messages_between(events, player_a, player_b):
    """Return only the messages A sent to B or B sent to A, ordered by event_order."""
    a_to_b = (events["sender_id"] == player_a) & (events["recipient_id"] == player_b)
    b_to_a = (events["sender_id"] == player_b) & (events["recipient_id"] == player_a)
    return events[a_to_b | b_to_a].sort_values("event_order")


# Figure 1: for every pair of players in every round, the similarity of each
# message to the message immediately before it in that pair's conversation.
plt.figure(figsize=(12, 6))

message_events, round_separators = util.get_chat_order_and_separators(message_events)
util.plot_round_separators(round_separators, 1.01)

plt.ylim(-0.01, 1.01)
plt.xlabel("Time Step (Chat Order)")
plt.ylabel("Similarity")
plt.title("Human Per-group Similarity\n")

round_similarities = []   # one value per round: mean over pairs of each pair's mean similarity
record_similarities = []  # one row per scored message, written to CSV below

for round_num in range(1, n_rounds + 1):
    round_events = message_events[message_events["chat_round_order"] == round_num]
    pair_similarities = []

    for player_a, player_b in _unordered_pairs(round_events):
        # Previously this kept every message sent by A or B to anyone, which mixed
        # in their conversations with other players.
        messages = _messages_between(round_events, player_a, player_b)
        texts = messages["text"].tolist()
        chat_orders = messages["chat_order"].tolist()

        xs, ys = [], []
        steps = zip(texts, texts[1:], chat_orders, chat_orders[1:])
        for i, (prev_text, text, prev_order, cur_order) in enumerate(steps, start=1):
            similarity = score(prev_text, text)
            # If this message's chat_order is not past its predecessor's, place the
            # point half a step after the predecessor instead.
            order = cur_order if cur_order > prev_order else prev_order + 0.5
            xs.append(order)
            ys.append(similarity)
            record_similarities.append({
                "round": round_num,
                "chat_order": order,
                "message_order": i,
                "similarity": similarity,
                "player_a": player_a,
                "player_b": player_b,
                "text_a": prev_text,
                "text_b": text,
            })

        if ys:
            plt.plot(xs, ys, label=f"{player_a}-{player_b}", marker='.', linestyle='-')
            pair_similarities.append(np.mean(ys))

    # A round with no scored pair gets an explicit NaN; np.mean([]) would also
    # return NaN but emits a RuntimeWarning.
    round_similarities.append(np.mean(pair_similarities) if pair_similarities else np.nan)

result_dir = os.path.join(output_path, data_prefix)
# plt.savefig does not create missing directories.
os.makedirs(result_dir, exist_ok=True)
util.save_csv(pd.DataFrame(record_similarities), os.path.join(result_dir, "pair_similarities.csv"))

plt.legend()
plt.savefig(os.path.join(result_dir, "per_group.svg"))


# Figure 2: one point per round, giving that round's mean of per-pair mean
# similarities (round_similarities, built above).
across_fig = plt.figure()
across_ax = across_fig.gca()
across_ax.plot(range(1, n_rounds + 1), round_similarities, marker='o', linestyle='-')
across_ax.set_xlabel("Round")
across_ax.set_ylabel("Mean Similarity")
across_ax.set_ylim(-0.01, 1.01)
across_ax.set_title("Human Across-group Similarity\n")
across_fig.savefig(os.path.join(output_path, data_prefix, "across_group.svg"))
