#%%
from sklearn.inspection import permutation_importance
from tree_depth import run
from data_loading import get_dataset
from model_training import train_all_models
from constants import y_test_parameters, y_train_parameters
from joblib.parallel import Parallel, delayed
from tqdm import tqdm


#%%

# Single source of truth for the dataset: it selects both the data that is
# loaded and the tuned hyperparameter entries looked up further down.
dataset_name = "weather"
X_train, y_train, X_test, y_test = get_dataset(dataset_name)

#%%
# Model family requested for each of the three models handed to
# train_all_models; all of them use the same estimator type.
model_types = {
    "f": "hist_gradient_boosting",
    "f_star_train": "hist_gradient_boosting",
    "f_star_test": "hist_gradient_boosting",
}

# "f" gets an empty parameter dict; the two f_star models get the "HGB"
# entries stored for this dataset. The tables are indexed by dataset_name
# (rather than a repeated literal) so the parameters can never drift away
# from the dataset that was actually loaded.
model_params = {
    "f": {},
    "f_star_train": y_train_parameters[dataset_name]["HGB"],
    "f_star_test": y_test_parameters[dataset_name]["HGB"],
}


f, f_star_train, f_star_test, X_eval, y_eval = train_all_models(
    model_types, X_train, y_train, X_test, y_test, model_params
)
#%%
# Permutation importance of f_star_test, computed on the first 10,000 rows of
# the evaluation split with 5 repeats and a fixed seed for reproducibility.
result = permutation_importance(
    estimator=f_star_test,
    X=X_eval[:10000],
    y=y_eval[:10000],
    n_repeats=5,
    random_state=42,
)

# %%
# Bare expression so an interactive cell displays the result object.
result
# %%
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

fig, ax = plt.subplots()

# Feature indices ordered by ascending mean importance, so the most
# important feature ends up at the top of the horizontal box plot.
sorted_idx = result.importances_mean.argsort()

# Placeholder labels x_0, x_1, ... — one per column of X_eval.
feature_names = np.array([f"x_{i}" for i in range(X_eval.shape[1])])

# The `labels` argument of boxplot was renamed to `tick_labels` in
# matplotlib 3.9. Pick the keyword that the installed version understands so
# the cell runs without errors on older releases and without deprecation
# warnings on newer ones. Only the numeric major/minor parts are compared, so
# suffixes such as "rc1" or ".dev0" in the patch part do not matter.
_mpl_major_minor = tuple(
    int(part) for part in matplotlib.__version__.split(".")[:2]
)
tick_labels_parameter_name = (
    "tick_labels" if _mpl_major_minor >= (3, 9) else "labels"
)
tick_labels_dict = {tick_labels_parameter_name: feature_names[sorted_idx]}

# result.importances holds one row per feature; after reordering the rows it
# is transposed so that each box summarises the repeats of one feature.
ax.boxplot(result.importances[sorted_idx].T, vert=False, **tick_labels_dict)
ax.set_title("Permutation Importance of each feature")
ax.set_ylabel("Features")
fig.tight_layout()
plt.show()
# %%
# Order feature indices by mean importance, largest first, and keep three.
top_3_idx = result.importances_mean.argsort()[::-1][:3]

# Mean importance and display label for each selected feature.
top_importance = result.importances_mean[top_3_idx]
top_features = feature_names[top_3_idx]

print("Top 3 most important features:")
for i, feature in enumerate(top_features):
    importance = top_importance[i]
    print(f"{i+1}. {feature}: {importance:.4f}")
# %%
# Translate the top-ranked feature indices into the column names found in the
# weather dataset's numeric-feature CSV.
import os
import pandas as pd

try:
    # Column headers of the CSV serve as the real feature names.
    # NOTE(review): assumes the CSV column order matches the feature order of
    # X_eval — confirm against the data-loading code.
    original_data = pd.read_csv("../data/weather/csv/X_num.csv")
    actual_feature_names = original_data.columns.tolist()

    # Report each top feature with its real name, index and mean importance.
    print("\nActual names of top features:")
    for i, feature_idx in enumerate(top_3_idx):
        importance = top_importance[i]
        actual_name = actual_feature_names[feature_idx]
        print(f"{i+1}. {actual_name} (feature {feature_idx}): {importance:.4f}")
except FileNotFoundError:
    # The relative path only resolves from the expected working directory.
    print("Could not load the original feature names. File not found.")
except Exception as e:
    # Best-effort cell: report anything else (e.g. parse or index errors)
    # instead of aborting the interactive session.
    print(f"Error loading feature names: {e}")
# %%
