import multiprocessing
import random
import numpy as np
import json
import scipy.stats
from tqdm import tqdm
import json

# Load per-architecture statistics. Each entry is expected to provide an
# "effective_rank" list (one value per layer) and an "accuracy" value.
# JSON is UTF-8 by specification, so do not rely on the locale's default codec.
with open("NATSBench_TSS_per_layer.json", encoding="utf-8") as f:
    stats = json.load(f)

# Only architectures with exactly this many per-layer effective ranks are kept,
# so that every row of X has the same length.
n_layers_expected = 68

data = {
    "effective_ranks": [],
    "accuracy": []
}

# Row order is irrelevant to the rank correlation computed later because X and
# Y are permuted together.
keys = list(stats.keys())
random.shuffle(keys)
for model in keys:
    if len(stats[model]["effective_rank"]) == n_layers_expected:
        data["effective_ranks"].append(stats[model]["effective_rank"])
        data["accuracy"].append(stats[model]["accuracy"])

if not data["effective_ranks"]:
    # Otherwise X would be an empty 1-D array and X.shape[1] an opaque IndexError.
    raise ValueError(
        f"no model in NATSBench_TSS_per_layer.json has exactly "
        f"{n_layers_expected} effective-rank entries"
    )

X = np.array(data["effective_ranks"])  # (n_models, n_layers_expected)
Y = data["accuracy"]

n_features = X.shape[1]
n_iterations = 100_000  # subsets scored per round

# Column 0: running sum of |rho| over subsets containing the feature.
# Column 1: largest |rho| seen for any subset containing the feature.
feature_importance = np.zeros((n_features, 2))
# Per-feature inclusion counts; the extra last slot holds the total subsets drawn.
num_included = np.zeros(n_features + 1)

def process_subset(dummy, size=6, rng=None):
    """Score one random subset of layer features against accuracy.

    Draws ``size`` distinct column indices of the global ``X``, sums those
    columns per model, and computes the Spearman rank correlation of that
    sum with the global ``Y``.

    Args:
        dummy: Ignored; lets ``Pool.map`` drive this over ``range(n_iterations)``.
        size: Number of distinct features in the subset (default 6).
        rng: Optional ``numpy.random.Generator``. When ``None``, a fresh
            generator seeded from OS entropy is created for this call.

    Returns:
        ``(subset_indices, abs_corr)``. ``subset_indices`` is the sorted list
        of chosen feature indices. ``abs_corr`` is ``|rho|``, or ``0.0`` when
        the correlation is undefined (e.g. the summed column is constant).
    """
    # Do not draw from the legacy global ``np.random`` state here. Forked pool
    # workers inherit identical copies of it, so every worker would draw the
    # same sequence of subsets. With a new pool per round, so would every round.
    if rng is None:
        rng = np.random.default_rng()
    subset_indices = sorted(rng.choice(n_features, size=size, replace=False))
    subset_X = np.sum(X[:, subset_indices], axis=1)
    corr, _ = scipy.stats.spearmanr(subset_X, Y)
    # spearmanr yields NaN for constant input. Treat that as "no association"
    # so a single NaN cannot poison the running sums forever.
    if np.isnan(corr):
        return subset_indices, 0.0
    return subset_indices, np.abs(corr)

def update_feature_importance(result):
    """Fold one ``(subset_indices, corr)`` result into the global tallies.

    For each feature in the subset: add ``corr`` to column 0 of
    ``feature_importance``, raise column 1 to ``corr`` if ``corr`` is larger,
    and increment that feature's entry in ``num_included``.
    """
    indices, score = result
    # Fancy indexing returns a copy, so snapshot the current per-feature maxima.
    best_so_far = feature_importance[indices, 1]
    beats_best = best_so_far < score
    feature_importance[indices, 1] = np.where(beats_best, score, best_so_far)
    feature_importance[indices, 0] += score
    num_included[indices] += 1


# Guarded so that worker processes started with the "spawn"/"forkserver"
# methods, which re-import this module, do not start pools of their own.
if __name__ == "__main__":
    # Number of sampling rounds. Tallies are checkpointed to disk after each.
    n_rounds = 100_000

    # Reuse one worker pool for all rounds instead of re-creating the worker
    # processes every round.
    with multiprocessing.Pool() as pool:
        for _ in tqdm(range(n_rounds)):
            for result in pool.map(process_subset, range(n_iterations)):
                update_feature_importance(result)

            # The last slot of num_included counts all subsets drawn so far.
            num_included[-1] += n_iterations
            np.savetxt("result.out", feature_importance, delimiter=",")
            np.savetxt("stats.out", num_included, delimiter=",")

    # Mean |rho| per inclusion, so a feature is not favoured merely because it
    # was sampled more often. Features never sampled score 0.
    counts = num_included[:-1]
    mean_corr = np.divide(
        feature_importance[:, 0],
        counts,
        out=np.zeros(n_features),
        where=counts > 0,
    )

    # Normalize the importance scores to sum to 1 (skip if all are zero).
    total = mean_corr.sum()
    importance = mean_corr / total if total > 0 else mean_corr

    # Print features from most to least important. feature_importance is 2-D,
    # so rank on the 1-D score vector rather than on feature_importance itself.
    for idx in np.argsort(-importance):
        print(
            f"Feature {idx}: Importance = {importance[idx]:.4f} "
            f"(max |rho| = {feature_importance[idx, 1]:.4f})"
        )
