import os, json, random, numpy as np, torch, gc, argparse
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm

# === 加载训练数据（仅二分类情感） ===
def load_context_data(filepaths):
    """Read binary-sentiment samples from one or more JSON files.

    Each file is expected to hold a JSON array of objects carrying ``"label"``
    and ``"text"`` keys. Only records whose label is ``"positive"`` or
    ``"negative"`` are kept; any other label is skipped (its ``"text"`` is
    never read).

    Args:
        filepaths: iterable of paths to UTF-8 encoded JSON files.

    Returns:
        (texts, labels): the kept texts as a list of strings, and a NumPy array
        with 1 for positive and 0 for negative, in file order then record order.
    """
    wanted = ("positive", "negative")
    texts = []
    targets = []
    for path in filepaths:
        with open(path, "r", encoding="utf-8") as handle:
            records = json.load(handle)
        for record in records:
            tag = record["label"]
            if tag in wanted:
                texts.append(record["text"])
                targets.append(int(tag == "positive"))
    return texts, np.array(targets)

# === Batch-encode texts into per-layer mean-pooled vectors (attention-masked token average, not CLS) ===
def batch_encode(texts, tokenizer, model, num_layers, batch_size, max_len):
    """Encode ``texts`` into one mean-pooled hidden-state matrix per layer.

    No CLS token is used: for each layer, token states are averaged over the
    positions where ``attention_mask`` is 1, so padding never contributes.

    Args:
        texts: list of input strings.
        tokenizer: HF-style tokenizer; it needs a pad token because inputs are
            padded to exactly ``max_len`` tokens.
        model: model whose forward output exposes ``hidden_states``; index 0 is
            skipped (the embedding output in HF models) and indices
            1..num_layers are pooled.
        num_layers: number of layers to collect.
        batch_size: number of texts per forward pass.
        max_len: pad/truncate length in tokens.

    Returns:
        List of ``num_layers`` float32 arrays of shape
        ``(len(texts), model.config.hidden_size)``; entry ``l`` holds layer ``l + 1``.
    """
    hidden_size = model.config.hidden_size
    all_embs = [np.empty((len(texts), hidden_size), dtype=np.float32) for _ in range(num_layers)]
    for start in tqdm(range(0, len(texts), batch_size), desc="🧠 Encode"):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding="max_length", truncation=True,
                           max_length=max_len).to(model.device)
        with torch.no_grad():
            hs = model(**inputs).hidden_states
        mask = inputs["attention_mask"].unsqueeze(-1)
        # Clamp so an all-padding row (e.g. an empty string) pools to a zero
        # vector instead of NaN from 0/0, which would break the downstream fit.
        counts = mask.sum(1).clamp(min=1)
        for l in range(num_layers):
            # Pool in float32: half-precision sums can overflow, and NumPy
            # cannot represent bfloat16 tensors at all.
            h = hs[l + 1].float()
            pooled = (h * mask).sum(1) / counts
            all_embs[l][start:start + len(batch)] = pooled.cpu().numpy()
    return all_embs

# === 主程序 ===
def main(args):
    """Fit and save one sentiment direction ("axis") per transformer layer.

    Loads binary-labelled texts from ``args.context_file`` and encodes them with
    ``batch_encode``. For each layer it fits a logistic-regression probe and saves
    the unit-norm, sign-aligned weight vector to
    ``<output_dir>/sentiment_axis_L{layer}.npy``. Layers are numbered from 1.

    Args:
        args: parsed CLI namespace with ``model_id``, ``context_file`` (list of
            JSON paths), ``output_dir``, ``device``, ``max_len``, ``batch_size``
            and ``seed``.

    Raises:
        ValueError: if the data lacks one of the two sentiment classes, or a
            probe yields an all-zero weight vector.
    """
    # Fix random seeds for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)

    # Load and validate the data before the (expensive) model load so bad input fails fast.
    print(f"📚 Training on {len(args.context_file)} datasets")
    train_texts, train_labels = load_context_data(args.context_file)
    if len(np.unique(train_labels)) < 2:
        raise ValueError(
            "Training data must contain both 'positive' and 'negative' samples "
            f"(got {len(train_texts)} usable samples)"
        )

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # Borrow EOS as the pad token only when the tokenizer has none. Overwriting
    # unconditionally clobbers an existing pad token. For tokenizers without an EOS
    # token it would also set pad_token to None, making max_length padding fail.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModel.from_pretrained(args.model_id, output_hidden_states=True).eval().to(args.device)
    num_layers = model.config.num_hidden_layers

    print(f"📊 Total samples: {len(train_texts)}")
    print("📥 Encoding training data...")
    train_embs = batch_encode(train_texts, tokenizer, model, num_layers, args.batch_size, args.max_len)

    for layer_id in range(1, num_layers + 1):
        axis = _fit_sentiment_axis(train_embs[layer_id - 1], train_labels)
        np.save(os.path.join(args.output_dir, f"sentiment_axis_L{layer_id}.npy"), axis)

    del model, tokenizer, train_embs
    # Collect first so any cyclically-referenced tensors are actually freed
    # before the CUDA caching allocator releases its blocks.
    gc.collect()
    torch.cuda.empty_cache()
    print(f"✅ Saved axes to {args.output_dir}")


def _fit_sentiment_axis(X, y):
    """Return the unit-norm direction of a no-intercept logistic probe fit on (X, y).

    The sign is chosen so that positive samples (y == 1) have a mean projection
    at least as large as negative samples (y == 0).

    Raises:
        ValueError: if the fitted weight vector is all zeros and cannot be normalized.
    """
    clf = LogisticRegression(max_iter=2000, fit_intercept=False).fit(X, y)
    axis = clf.coef_[0]

    # Normalize; a zero vector would otherwise be silently saved as NaNs.
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError("Logistic regression produced an all-zero weight vector; cannot normalize axis")
    axis = axis / norm

    # Orient the axis so the positive class projects higher on average.
    pos_mean = np.mean(X[y == 1] @ axis)
    neg_mean = np.mean(X[y == 0] @ axis)
    if pos_mean < neg_mean:
        axis = -axis
    return axis

if __name__ == "__main__":
    # Command-line entry point: collect options and hand them to main().
    cli = argparse.ArgumentParser()

    # Required inputs/outputs
    cli.add_argument("--model_id", type=str, required=True, help="模型名称或路径")
    cli.add_argument(
        "--context_file",
        type=str,
        nargs="+",
        required=True,
        help="一个或多个训练用 JSON 文件路径",
    )
    cli.add_argument("--output_dir", type=str, required=True, help="保存 axis 的目录")

    # Optional runtime / encoding settings: (flag, type, default)
    for flag, kind, fallback in (
        ("--device", str, "cuda"),
        ("--max_len", int, 128),
        ("--batch_size", int, 128),
        ("--seed", int, 42),
    ):
        cli.add_argument(flag, type=kind, default=fallback)

    args = cli.parse_args()
    main(args)
