import os, glob
import numpy as np
import pandas as pd

# ==== Configure ====
BASE_DIR = "experiments/exp_sub_axis/Meta-Llama-3-8B-Instruct"  # one subdirectory per tone
LAYER = 14
OUT_CSV = ""          # change if you like; empty -> <BASE_DIR>/axes_cosine_L<LAYER>.csv
OUT_PNG = ""  # optional (requires matplotlib and seaborn); empty -> no heatmap

# An empty OUT_CSV would make both os.makedirs("") and df.to_csv("") fail,
# so fall back to a file next to the axes being compared.
if not OUT_CSV:
    OUT_CSV = os.path.join(BASE_DIR, f"axes_cosine_L{LAYER}.csv")

# Only create a parent directory when the path has one: os.path.dirname of a
# bare filename is "", and os.makedirs("") raises FileNotFoundError.
_out_dir = os.path.dirname(OUT_CSV)
if _out_dir:
    os.makedirs(_out_dir, exist_ok=True)

# ==== Load all tone axes ====
# Each subdirectory of BASE_DIR is treated as one tone. Its axis for LAYER is
# read from sentiment_axis_L{LAYER}.npy and L2-normalised, so that dot
# products between axes are cosine similarities. Subdirectories without that
# file are ignored.
axes = {}
for tone_dir in sorted(glob.glob(os.path.join(BASE_DIR, "*"))):
    if not os.path.isdir(tone_dir):
        continue
    tone = os.path.basename(tone_dir)
    path = os.path.join(tone_dir, f"sentiment_axis_L{LAYER}.npy")
    if not os.path.exists(path):
        continue
    w = np.load(path).astype(np.float32)
    norm = np.linalg.norm(w)
    # A zero or non-finite norm cannot be turned into a unit axis (inf would
    # yield NaN entries), so report and skip it instead of dropping it silently.
    if not (np.isfinite(norm) and norm > 0):
        print(f"(Skip {tone}: axis norm is {norm})")
        continue
    axes[tone] = w / norm

tones = sorted(axes)
# Explicit raise rather than `assert`, which is stripped under `python -O`.
if not tones:
    raise FileNotFoundError(f"No L{LAYER} axes found under {BASE_DIR}")

# ==== Cosine similarity matrix ====
# Axes were L2-normalised on load, so the Gram matrix of the stacked axes is
# the pairwise cosine-similarity matrix. A single matrix product replaces the
# n*n Python-level dot products.
n = len(tones)
# Row i is the axis of tones[i]. np.stack raises ValueError if the axes do not
# all share one shape. Assumes each axis is a 1-D vector — TODO confirm.
stacked = np.stack([axes[t] for t in tones])
cos = (stacked @ stacked.T).astype(np.float32)

df = pd.DataFrame(cos, index=tones, columns=tones)

print(f"\nCosine similarity among tones at Layer {LAYER}")
print(df.round(4))

# Save CSV
df.to_csv(OUT_CSV)
print(f"\n✅ Saved CSV to {OUT_CSV}")

# ==== Optional: heatmap figure ====
# Best-effort step: the CSV above is the primary output. A missing plotting
# library, an unset OUT_PNG, or a plotting error is reported and skipped
# rather than aborting the script.
try:
    import matplotlib.pyplot as plt
    import seaborn as sns
except ImportError as e:
    print(f"(Skip heatmap: {e})")
else:
    if not OUT_PNG:
        print("(Skip heatmap: OUT_PNG is not set)")
    else:
        fig, ax = plt.subplots(figsize=(6, 5))
        try:
            # Cosine similarity lies in [-1, 1]. A symmetric colour range
            # centred on 0 keeps anti-aligned axes visible instead of
            # clipping every negative value to the bottom colour.
            sns.heatmap(df, annot=True, fmt=".2f", cmap="coolwarm",
                        vmin=-1, vmax=1, center=0, square=True,
                        xticklabels=tones, yticklabels=tones, ax=ax)
            ax.set_title(f"Axes Cosine Similarity (Layer {LAYER}, LLaMA-3-8B)")
            fig.tight_layout()
            png_dir = os.path.dirname(OUT_PNG)
            if png_dir:  # a bare filename has no directory to create
                os.makedirs(png_dir, exist_ok=True)
            fig.savefig(OUT_PNG, dpi=300)
            print(f"🖼  Saved heatmap to {OUT_PNG}")
        except Exception as e:  # best-effort; see note at top of this step
            print(f"(Skip heatmap: {e})")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
