#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os, json, argparse
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import pandas as pd

def load_usage_dataset(path):
    """Load a binary-labelled usage dataset from a JSON file.

    The file must hold a JSON array of objects.  Each object provides its
    text under ``"text"`` (falling back to ``"utterance"``) and its class
    under ``"label"``, given either as the string ``"positive"`` /
    ``"negative"`` (case-insensitive, surrounding whitespace ignored) or as
    the integer ``1`` / ``0``.

    Items are skipped when the text is missing or empty, or when the label
    is not one of the recognised values.  Integer labels other than 0 and 1
    are rejected because the downstream analysis only compares the ``== 1``
    and ``== 0`` groups; letting other values through would produce empty
    groups and NaN means.

    Args:
        path: Path to the UTF-8 encoded JSON file.

    Returns:
        A ``(texts, labels)`` tuple where ``texts`` is the list of kept text
        values and ``labels`` is an integer NumPy array of the same length
        (1 = positive, 0 = negative).  The array has an integer dtype even
        when no item was kept.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    name_to_label = {"positive": 1, "negative": 0}
    texts, labels = [], []
    for item in data:
        text = item.get("text") or item.get("utterance")
        if not text:
            continue

        label = item.get("label")
        if isinstance(label, str):
            label = name_to_label.get(label.strip().lower())
        elif isinstance(label, int):
            # bool is a subclass of int, so JSON true/false map to 1/0 here.
            label = int(label) if label in (0, 1) else None
        else:
            label = None
        if label is None:
            continue

        texts.append(text)
        labels.append(label)
    return texts, np.array(labels, dtype=int)


def batch_encode(texts, tokenizer, model, num_layers, batch_size, max_len):
    """Mean-pool the hidden states of the first ``num_layers`` layers per text.

    Texts are tokenized in batches (padded/truncated to ``max_len``), run
    through ``model`` without gradients, and each layer's token vectors are
    averaged over the positions where the attention mask is 1.

    Args:
        texts: Sequence of input texts.
        tokenizer: Callable tokenizer returning a batch with an
            ``"attention_mask"`` entry and supporting ``.to(device)``.
        model: Model exposing ``config.hidden_size``, ``device`` and
            returning an output with a ``hidden_states`` tuple.
        num_layers: Number of layers to pool; entries ``1..num_layers`` of
            ``hidden_states`` are used.
        batch_size: Number of texts per forward pass.
        max_len: Padding / truncation length in tokens.

    Returns:
        A list of ``num_layers`` float32 arrays, each of shape
        ``(len(texts), hidden_size)``; entry ``l`` holds the pooled
        embeddings taken from ``hidden_states[l + 1]``.
    """
    hidden_size = model.config.hidden_size
    all_embs = [np.empty((len(texts), hidden_size), dtype=np.float32) for _ in range(num_layers)]
    for i in tqdm(range(0, len(texts), batch_size), desc="🧠 Encode"):
        batch = texts[i:i+batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding="max_length",
                           truncation=True, max_length=max_len).to(model.device)
        with torch.no_grad():
            # Request hidden states explicitly so this function does not
            # depend on how the caller configured the model.
            hs = model(**inputs, output_hidden_states=True).hidden_states  # tuple of L+1 tensors
        # Pool in float32: avoids half-precision overflow in the sum and
        # keeps the result convertible to NumPy for bf16 models.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        # Guard against rows with no attended tokens (would divide by zero
        # and fill the row with NaN).
        token_counts = mask.sum(1).clamp(min=1.0)
        for l in range(num_layers):
            h = hs[l+1].float()  # start from layer 1; index 0 is skipped
            pooled = (h * mask).sum(1) / token_counts
            all_embs[l][i:i+batch_size] = pooled.cpu().numpy()
    return all_embs  # list: [layer, N, H]


def _top_diff_neurons(X, y, topk):
    """Rank hidden units by the absolute positive-vs-negative mean difference.

    Args:
        X: Array of shape ``[N, H]`` with one pooled embedding per sample.
        y: Label array of length ``N``; 1 marks positive, 0 negative.
        topk: Number of highest-ranked units to return.

    Returns:
        ``None`` when either class has no samples, otherwise a tuple
        ``(indices, diff)`` where ``diff`` is the ``[H]`` vector
        ``mean(positive) - mean(negative)`` and ``indices`` are the ``topk``
        positions with the largest ``|diff|``.
    """
    pos_mask = y == 1
    neg_mask = y == 0
    # Both groups must be non-empty, otherwise the mean is NaN.
    if not pos_mask.any() or not neg_mask.any():
        return None
    diff = X[pos_mask].mean(0) - X[neg_mask].mean(0)  # [H]
    topk_idx = np.argsort(-np.abs(diff))[:topk]
    return topk_idx, diff


def analyze_usage_neurons(args):
    """Find, per layer and usage type, the neurons that best separate classes.

    Loads the model and tokenizer named by ``args.model_id``, encodes the
    four usage datasets (genre, topic, tone, contextual) into mean-pooled
    per-layer embeddings, and for every layer/usage pair records the
    ``args.topk`` hidden units with the largest absolute difference between
    the positive-class mean and the negative-class mean.  The result is
    written to ``<args.out_dir>/<args.model_name>_usage_neurons.csv`` with
    columns ``layer``, ``usage``, ``neuron`` and ``diff``.

    Args:
        args: Namespace providing ``model_id``, ``model_name``, ``device``,
            ``genre_file``, ``topic_file``, ``tone_file``,
            ``contextual_file``, ``batch_size``, ``max_len``, ``topk`` and
            ``out_dir``.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # Only borrow EOS as the padding token when the tokenizer has none;
    # overwriting unconditionally would erase a valid pad token on
    # tokenizers that define no EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModel.from_pretrained(args.model_id,
                                      output_hidden_states=True).eval().to(args.device)
    num_layers = model.config.num_hidden_layers

    usage_files = {
        "genre": args.genre_file,
        "topic": args.topic_file,
        "tone": args.tone_file,
        "contextual": args.contextual_file,
    }

    usage_embs = {}
    usage_labels = {}

    for usage, path in usage_files.items():
        texts, labels = load_usage_dataset(path)
        print(f"✅ Loaded {len(texts)} samples from {path}")
        if len(texts) == 0:
            continue
        usage_labels[usage] = labels
        usage_embs[usage] = batch_encode(texts, tokenizer, model, num_layers,
                                         args.batch_size, args.max_len)

    results = []
    for layer_id in range(1, num_layers+1):
        print(f"\n🔍 Layer {layer_id}")
        for usage in usage_embs:
            X = usage_embs[usage][layer_id-1]  # [N, H]
            ranked = _top_diff_neurons(X, usage_labels[usage], args.topk)
            if ranked is None:
                continue
            # Keep the top-k neurons for this layer/usage pair.
            topk_idx, diff = ranked
            for nid in topk_idx:
                results.append({
                    "layer": layer_id,
                    "usage": usage,
                    "neuron": nid,
                    "diff": diff[nid]
                })

    os.makedirs(args.out_dir, exist_ok=True)
    out_path = os.path.join(args.out_dir, f"{args.model_name}_usage_neurons.csv")
    # Explicit columns so the CSV still has a header when nothing was found.
    columns = ["layer", "usage", "neuron", "diff"]
    pd.DataFrame(results, columns=columns).to_csv(out_path, index=False)
    print(f"\n✅ Saved usage neuron stats to {out_path}")


if __name__ == "__main__":
    # Script entry point: build the command-line interface, then run the
    # analysis with the parsed options.
    cli = argparse.ArgumentParser()

    # Mandatory string options, registered in their original order.
    for flag in ("--model_id", "--model_name", "--genre_file",
                 "--topic_file", "--tone_file", "--contextual_file"):
        cli.add_argument(flag, type=str, required=True)

    # Optional settings as (flag, type, default) triples.
    for flag, kind, fallback in (
        ("--out_dir", str, "outputs/usage_neurons"),
        ("--device", str, "cuda"),
        ("--batch_size", int, 64),
        ("--max_len", int, 128),
        ("--topk", int, 20),
    ):
        cli.add_argument(flag, type=kind, default=fallback)

    analyze_usage_neurons(cli.parse_args())
