#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import pandas as pd
import numpy as np

# Model names to process; each one is used as a sub-directory name under both
# BASE_DIR (input) and OUT_DIR (output).
# Alternative model list, currently disabled:
# MODELS = ["Gemma-7B", "LLaMA-3-8B", "Mistral-7B"]
MODELS = ["Gemma-7B-IT"]
# Root directory holding the per-model "layer*_usage_neurons.csv" input files.
BASE_DIR = "outputs/usage_neurons_new"
# Usage categories; each one names a "diff_<usage>" column read from the CSVs.
USAGES = ["genre", "topic", "tone", "contextual"]

# Root directory for the results; it is created at import time if missing.
OUT_DIR = "outputs/flip_stable_neurons"
os.makedirs(OUT_DIR, exist_ok=True)


def classify_neurons(df, usages):
    """Split the neurons of one layer into "flip" and "stable" groups.

    For every row of *df* the sign of each ``diff_<usage>`` column is
    inspected:

    * **stable** -- every usage diff is strictly positive, or every usage
      diff is strictly negative;
    * **flip** -- at least two diffs are strictly positive and at least two
      are strictly negative.

    Rows matching neither rule are dropped. A diff that is exactly zero or
    NaN counts as neither positive nor negative.

    Args:
        df: DataFrame with a ``neuron`` column and one ``diff_<usage>``
            column for each entry of *usages*.
        usages: Usage names selecting which ``diff_<usage>`` columns to read.

    Returns:
        A ``(flip_neurons, stable_neurons)`` tuple of lists of plain ``int``
        neuron ids, each in the row order of *df*.

    Raises:
        KeyError: If ``neuron`` or one of the ``diff_<usage>`` columns is
            missing from *df*.
    """
    diff_columns = [f"diff_{u}" for u in usages]

    # Count signs for all rows at once instead of looping with iterrows(),
    # which builds a Series and a fresh array for every single row.
    values = df[diff_columns].to_numpy()
    pos_count = (values > 0).sum(axis=1)
    neg_count = (values < 0).sum(axis=1)

    n_usages = len(usages)
    stable_mask = (pos_count == n_usages) | (neg_count == n_usages)
    # "stable" takes precedence over "flip", mirroring an if/elif chain.
    flip_mask = ~stable_mask & (pos_count >= 2) & (neg_count >= 2)

    neuron_ids = df["neuron"].to_numpy()
    flip_neurons = [int(n) for n in neuron_ids[flip_mask]]
    stable_neurons = [int(n) for n in neuron_ids[stable_mask]]
    return flip_neurons, stable_neurons


def process_model(model_name):
    """Classify every layer of one model and write its flip/stable neurons.

    Reads each ``layer<N>_usage_neurons.csv`` file found in
    ``BASE_DIR/<model_name>``, classifies its neurons with
    ``classify_neurons`` and writes, under ``OUT_DIR/<model_name>``:

    * ``flip/L<NN>.txt`` and ``stable/L<NN>.txt`` -- one neuron id per line,
      one file per layer;
    * ``flip_neurons_global.csv`` and ``stable_neurons_global.csv`` -- every
      ``(layer, neuron)`` pair across all layers.

    Layers are processed in numeric order, so the progress output and the
    rows of the global CSV files run 0, 1, 2, ..., 10, 11 rather than in the
    lexicographic order of the file names (0, 1, 10, 11, ..., 2).

    Args:
        model_name: Name of the model; used as the sub-directory name under
            both ``BASE_DIR`` and ``OUT_DIR``.

    Raises:
        FileNotFoundError: If ``BASE_DIR/<model_name>`` does not exist.
        ValueError: If a matching file name carries a non-integer layer id.
            This is raised before any layer is processed.
    """
    print(f"\n🚀 Processing {model_name}")
    data_dir = os.path.join(BASE_DIR, model_name)
    out_model_dir = os.path.join(OUT_DIR, model_name)
    os.makedirs(out_model_dir, exist_ok=True)
    os.makedirs(os.path.join(out_model_dir, "flip"), exist_ok=True)
    os.makedirs(os.path.join(out_model_dir, "stable"), exist_ok=True)

    prefix, suffix = "layer", "_usage_neurons.csv"
    # Parse the layer id up front and sort on it: a plain sort of the file
    # names would place "layer10" before "layer2".
    layer_files = sorted(
        (int(fname[len(prefix):-len(suffix)]), fname)
        for fname in os.listdir(data_dir)
        if fname.startswith(prefix) and fname.endswith(suffix)
    )

    flip_all, stable_all = [], []

    for layer_id, fname in layer_files:
        df = pd.read_csv(os.path.join(data_dir, fname))
        flip, stable = classify_neurons(df, USAGES)

        # Per-layer lists: one neuron id per line, no trailing newline.
        with open(os.path.join(out_model_dir, "flip", f"L{layer_id:02d}.txt"), "w") as f:
            f.write("\n".join(map(str, flip)))
        with open(os.path.join(out_model_dir, "stable", f"L{layer_id:02d}.txt"), "w") as f:
            f.write("\n".join(map(str, stable)))

        # Accumulate (layer, neuron) pairs for the model-wide CSV files.
        flip_all.extend((layer_id, n) for n in flip)
        stable_all.extend((layer_id, n) for n in stable)

        print(f"✅ {model_name} Layer {layer_id}: flip={len(flip)}, stable={len(stable)}")

    pd.DataFrame(flip_all, columns=["layer", "neuron"]).to_csv(
        os.path.join(out_model_dir, "flip_neurons_global.csv"), index=False
    )
    pd.DataFrame(stable_all, columns=["layer", "neuron"]).to_csv(
        os.path.join(out_model_dir, "stable_neurons_global.csv"), index=False
    )
    print(f"🎉 {model_name}, flip={len(flip_all)}, stable={len(stable_all)}")


if __name__ == "__main__":
    # Script entry point: run the flip/stable classification for each
    # configured model, one after another.
    for model_name in MODELS:
        process_model(model_name)
