import os
import json
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from transformers import AutoModelForCausalLM, AutoTokenizer

# Global figure style for this script: serif text (Times New Roman, then
# Times, then DejaVu Serif as fallbacks), base font size 14, axes line width 1.
plt.rc("font", family="serif", serif=["Times New Roman", "Times", "DejaVu Serif"], size=14)
plt.rc("axes", linewidth=1)


def _read_reasoning_text(input_file, max_samples):
    """Return the text to tokenize from ``input_file``.

    JSONL mode: the first ``generated_responses`` entry of each of the first
    ``max_samples`` non-empty records, joined with newlines so the end of one
    response and the start of the next are not fused into a single token.

    Plain-text mode: used when the first non-empty line is not valid JSON;
    the whole file (stripped) is returned as one sample.
    """
    with open(input_file, "r", encoding="utf-8") as f:
        content = f.read()

    responses = []
    # Split on "\n" only (text mode already normalised line endings); unlike
    # str.splitlines() this never breaks a JSON record on U+2028 / U+0085.
    for line in content.split("\n"):
        if not line.strip():
            continue  # skip empty lines
        if len(responses) >= max_samples:
            break  # no need to parse records that will not be used
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            if responses:
                raise  # malformed line inside an otherwise valid JSONL file
            return content.strip()
        responses.append(record["generated_responses"][0])
    return "\n".join(responses)


def _unique_token_embeddings(model_path, text):
    """Return ``(tokens, embeddings)`` for the distinct token ids of ``text``.

    ``tokens`` are the per-id decoded strings (sorted by id, as produced by
    ``Tensor.unique``); ``embeddings`` are the matching rows of the model's
    input-embedding matrix, converted to float32 so they can later be turned
    into a NumPy array even if the checkpoint loads in bfloat16.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    token_ids = encoded["input_ids"][0].unique()
    tokens = [tokenizer.decode([i]) for i in token_ids.tolist()]

    # Advanced indexing copies only the selected rows, so no reference to the
    # full model survives this call and it can be freed before the next load.
    embeddings = model.get_input_embeddings().weight.detach()[token_ids].cpu().float()
    return tokens, embeddings


def _first_index(tokens):
    """Map each decoded token string to the index of its first occurrence."""
    index = {}
    for i, tok in enumerate(tokens):
        index.setdefault(tok, i)
    return index


def _plot_projection(coords1, coords2, pairs, save_path):
    """Scatter both point sets, link paired tokens, save to ``save_path`` and show."""
    color_base = "#EE1127"
    color_dapo = "#5880F8"

    plt.figure(figsize=(9, 6))
    plt.scatter(coords1[:, 0], coords1[:, 1],
                edgecolors=color_base, facecolors="none", alpha=0.9,
                s=30, marker="o", label="BASE")
    plt.scatter(coords2[:, 0], coords2[:, 1],
                edgecolors=color_dapo, facecolors="none", alpha=0.9,
                s=20, marker="^", label="DAPO")

    # Gray segment from each shared token's BASE point to its DAPO point.
    for p1, p2, _tok in pairs:
        plt.plot([p1[0], p2[0]], [p1[1], p2[1]], c="gray", alpha=0.3, linewidth=0.5)

    plt.legend(fontsize=24, handletextpad=0.3)
    plt.tick_params(axis='both', which='major', labelsize=24)
    plt.title("Projection of Embeddings Before and After Train", fontsize=20)
    # The projection below is computed with PCA, not t-SNE.
    plt.xlabel("PCA Dim 1", fontsize=20)
    plt.ylabel("PCA Dim 2", fontsize=20)
    # Hide tick labels before computing the layout so it reflects the final figure.
    plt.gca().tick_params(axis='x', which='both', labelbottom=False)
    plt.gca().tick_params(axis='y', which='both', labelleft=False)
    plt.tight_layout()

    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()


def main(args):
    """Plot a 2-D PCA projection of token embeddings from a base and a trained model.

    The input file can be either:
      (1) any JSONL output generated by ./Alpha-RL/eval/reasoning_eval.sh
          (the first ``generated_responses`` entry of each of the first
          ``max_samples`` records is used), or
      (2) a plain-text file containing a single reasoning sentence.

    The text is tokenized by each model's own tokenizer, and every distinct
    token id is looked up in that model's input-embedding matrix. Both sets of
    vectors are projected jointly with PCA; tokens whose decoded strings occur
    in both sets are connected by a line. The figure is written to
    ``<output_dir>/projection_fig.svg`` and also shown with ``plt.show()``.

    Args:
        args: namespace with ``base_model_path``, ``trained_model_path``,
            ``input_file`` and ``output_dir``. Optional attributes, with the
            defaults used when absent: ``max_samples`` (10), ``noise_std``
            (0.0075) and ``seed`` (42).

    Raises:
        ValueError: if no text can be read from ``input_file``.
    """
    max_samples = getattr(args, "max_samples", 10)
    noise_std = getattr(args, "noise_std", 0.0075)
    seed = getattr(args, "seed", 42)

    # ---------------- read input file ----------------
    text_all = _read_reasoning_text(args.input_file, max_samples)
    if not text_all.strip():
        raise ValueError(f"No reasoning text found in {args.input_file!r}.")

    # ---------------- tokenization + embeddings ----------------
    # Models are loaded one after the other so only one is resident at a time.
    tokens1, emb1 = _unique_token_embeddings(args.base_model_path, text_all)
    tokens2, emb2 = _unique_token_embeddings(args.trained_model_path, text_all)

    # ---------------- add noise ----------------
    # NOTE(review): noise is added to the trained-model embeddings only, so part
    # of the visible BASE/DAPO offset in the figure is synthetic. The RNG is
    # seeded for reproducibility; set noise_std to 0 to plot raw embeddings.
    if noise_std > 0:
        generator = torch.Generator().manual_seed(seed)
        emb2 = emb2 + torch.randn(emb2.shape, generator=generator) * noise_std

    # ---------------- PCA ----------------
    all_embeddings = torch.cat([emb1, emb2], dim=0).numpy()
    pca = PCA(n_components=2, random_state=42)
    embeddings_2d = pca.fit_transform(all_embeddings)
    coords1 = embeddings_2d[:len(emb1)]
    coords2 = embeddings_2d[len(emb1):]

    # ---------------- pair tokens present in both sets ----------------
    # Dict lookups replace repeated list.index() scans; first occurrence wins,
    # matching list.index() semantics when several ids decode to the same string.
    first1 = _first_index(tokens1)
    first2 = _first_index(tokens2)
    common_tokens = first1.keys() & first2.keys()
    print("Common tokens:", len(common_tokens))
    pairs = [(coords1[first1[t]], coords2[first2[t]], t) for t in common_tokens]

    # ---------------- plot ----------------
    os.makedirs(args.output_dir, exist_ok=True)
    save_path = os.path.join(args.output_dir, "projection_fig.svg")
    _plot_projection(coords1, coords2, pairs, save_path)
    print(f"Figure saved to: {save_path}")


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Visualize embedding shifts before and after RL fine-tuning.")
    # The three required string options, registered in their original order.
    required_options = (
        ("--base_model_path", "Path to the base model."),
        ("--trained_model_path", "Path to the fine-tuned model (e.g., DAPO)."),
        ("--input_file", "Any JSONL result file produced by ./Alpha-RL/eval/reasoning_eval.sh."),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli.add_argument("--output_dir", type=str, default="./figures", help="Directory to save figures.")

    main(cli.parse_args())
