#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Purpose:
Analyze usage-sensitive neurons per layer across models by counting polarity
(positive vs. negative mean activation differences). Outputs per-model CSVs and
a combined summary CSV for all models.
"""

import os
import pandas as pd

# ==== Configuration ====
# Every model's input CSV follows the same layout under a common root:
#   <root>/<model>/<model>_all_usage_neurons.csv
_USAGE_ROOT = "outputs/usage_neurons_new"
_MODEL_NAMES = ("Gemma-7B-IT", "LLaMA-3-8B", "Mistral-7B")

# Maps model name -> path of that model's input CSV (insertion order is the
# order in which the models are processed).
FILES = {
    name: f"{_USAGE_ROOT}/{name}/{name}_all_usage_neurons.csv"
    for name in _MODEL_NAMES
}

# Directory receiving every CSV this script writes; created at import time so
# later writes do not fail on a missing folder.
OUT_DIR = "fig/stats_neurons"
os.makedirs(OUT_DIR, exist_ok=True)


def analyze_file(model_name: str, path: str) -> pd.DataFrame:
    """Count, per layer, the rows whose mean ``diff_*`` value is positive or negative.

    The CSV at ``path`` is loaded, every column whose name starts with
    ``diff_`` is averaged across each row, and the rows are grouped by the
    ``layer`` column. For each layer the number of row means that are ``> 0``
    and ``< 0`` is counted. Rows whose mean is exactly zero or NaN fall into
    neither bucket but still count towards ``Total``.

    The resulting table is written to
    ``<OUT_DIR>/<model_name>_per_layer_polarity.csv`` and also returned.

    Args:
        model_name: Label stored in the ``Model`` column and used as the
            prefix of the output file name.
        path: Location of the input CSV.

    Returns:
        A DataFrame sorted by ``Layer`` with the columns ``Model``, ``Layer``,
        ``Positive``, ``Negative``, ``Total`` and ``Pos/Neg Ratio``. The ratio
        is rounded to two decimals, or the string ``"Inf"`` when a layer has
        no negative rows. An input without data rows yields an empty table
        that still carries these columns.

    Raises:
        ValueError: If the CSV has no column starting with ``diff_``.
        KeyError: If the CSV has no ``layer`` column (raised by ``groupby``).
    """
    df = pd.read_csv(path)
    usage_cols = [c for c in df.columns if c.startswith("diff_")]
    if not usage_cols:
        raise ValueError(f"No diff_* columns found in {path}")

    # Fixed schema: guarantees the "Layer" column exists even when no layer
    # group is produced (e.g. a header-only CSV), so sorting and the later
    # concatenation across models keep working.
    columns = ["Model", "Layer", "Positive", "Negative", "Total", "Pos/Neg Ratio"]

    results = []
    for layer, group in df.groupby("layer"):
        # One mean per row, taken over all diff_* columns.
        mean_vals = group[usage_cols].mean(axis=1)
        pos_count = (mean_vals > 0).sum()
        neg_count = (mean_vals < 0).sum()
        total = len(group)
        results.append({
            "Model": model_name,
            "Layer": layer,
            "Positive": pos_count,
            "Negative": neg_count,
            "Total": total,
            # Avoid dividing by zero; "Inf" marks layers without negatives.
            "Pos/Neg Ratio": round(pos_count / neg_count, 2) if neg_count > 0 else "Inf",
        })

    res_df = pd.DataFrame(results, columns=columns).sort_values("Layer")
    out_path = os.path.join(OUT_DIR, f"{model_name}_per_layer_polarity.csv")
    res_df.to_csv(out_path, index=False)
    print(f"[{model_name}] Per-layer polarity saved -> {out_path}")
    return res_df


def main():
    """Run the per-layer analysis for every configured model and merge the results.

    Each entry of ``FILES`` is passed to ``analyze_file`` in order; the returned
    tables are stacked into one frame with a fresh index and written to
    ``<OUT_DIR>/all_models_per_layer_polarity.csv``.
    """
    per_model_tables = []
    for model, fpath in FILES.items():
        per_model_tables.append(analyze_file(model, fpath))

    combined = pd.concat(per_model_tables, ignore_index=True)
    destination = os.path.join(OUT_DIR, "all_models_per_layer_polarity.csv")
    combined.to_csv(destination, index=False)
    print(f"Summary saved -> {destination}")


if __name__ == "__main__":
    main()
