import os
import json
import pickle as pkl
from argparse import ArgumentParser
import numpy as np
from tqdm import tqdm
from transformers import set_seed

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score



# Command-line interface. --data_dir names the directory that holds the
# pickled probe inputs and that also receives the result files.
parser = ArgumentParser(
    description="Train one classifier probe per module and rank modules by accuracy."
)
parser.add_argument(
    "--data_dir",
    type=str,
    # Required: the value is passed straight to os.path.join, so a missing
    # option (None) would otherwise only surface later as an opaque TypeError.
    required=True,
    help=(
        "Directory containing probes_training_data.pkl and labels.pkl; "
        "probes_results.json and contrastive_targets.pkl are written here."
    ),
)
args = parser.parse_args()


# Fix the random seed so the NumPy shuffles performed later are repeatable
# from run to run.
set_seed(42)

# NOTE(security): unpickling can execute arbitrary code -- only point
# --data_dir at files that come from a trusted source.
# Context managers make sure both file handles are closed once loaded.
with open(os.path.join(args.data_dir, "probes_training_data.pkl"), "rb") as f:
    # Mapping of module -> per-sample data (indexed by key and iterated below).
    train_data = pkl.load(f)
with open(os.path.join(args.data_dir, "labels.pkl"), "rb") as f:
    # Converted to an ndarray so it can be indexed with a permutation array.
    labels = np.array(pkl.load(f))

# module -> held-out accuracy of the probe trained on that module.
results = {}

# Fit one random-forest probe per module and record its held-out accuracy.
# NOTE(review): assumes every train_data[module] converts to an array whose
# first axis lines up one-to-one with `labels` -- TODO confirm.
for module, activations in tqdm(train_data.items(), desc="Training probes...."):
    features = np.array(activations)

    # Apply the same random ordering to samples and labels so pairs stay aligned.
    order = np.random.permutation(features.shape[0])
    features, targets = features[order], labels[order]

    # Hold out 20% of the samples for evaluation, with a fixed split seed.
    X_train, X_test, y_train, y_test = train_test_split(
        features, targets, test_size=0.2, random_state=42
    )

    probe = RandomForestClassifier(n_estimators=100, random_state=42)
    probe.fit(X_train, y_train)

    accuracy = accuracy_score(y_test, probe.predict(X_test))
    print(f"Module: {module}, Accuracy: {accuracy:.4f}")
    results[module] = accuracy


# Rank modules by probe accuracy, best first (dicts keep insertion order).
ranked = sorted(results.items(), key=lambda pair: pair[1], reverse=True)
results = dict(ranked)
print("total modules: ", len(results))
save_dir = args.data_dir

# Report of module -> accuracy, written in ranked order.
results_path = os.path.join(save_dir, "probes_results.json")
with open(results_path, "w") as f:
    json.dump(results, f, indent=4)

# Just the module keys, in the same ranked order, pickled as a list.
contrastive_targets = list(results)
targets_path = os.path.join(save_dir, "contrastive_targets.pkl")
with open(targets_path, "wb") as f:
    pkl.dump(contrastive_targets, f)

    
