class SignificantFeatureJob(Job):
    """Compute the most significant feature.

    A random forest is fitted on the metrics table as a surrogate model, and
    SHAP (Shapley) values are computed on it to explain how each feature
    contributes to the ``score`` column.
    """

    name = "significant_feature"

    def __init__(self, metrics_df: pd.DataFrame) -> None:
        """Store the metrics table to analyse.

        Args:
            metrics_df: Table with a ``score`` column (the regression target)
                and the categorical columns ``backbone`` and ``stride``. Every
                other column is passed to the model unchanged, so it is
                presumably numeric -- verify against the caller.
        """
        self.metrics_df = metrics_df

    def run(self, task_id: int | None = None) -> tuple[Explanation, pd.DataFrame]:
        """Fit a RandomForestRegressor to compute the Shapley values.

        Args:
            task_id: Not used by this job.

        Returns:
            The SHAP explanation and the encoded feature table it was
            computed on.
        """
        x_features = self.metrics_df.drop(columns=["score"])
        target = self.metrics_df["score"]

        # Convert categorical features to one-hot encoding.
        categorical_features = ["backbone", "stride"]
        encoder = OneHotEncoder(sparse_output=False)
        encoded_values = encoder.fit_transform(x_features[categorical_features])
        # Reuse the original index: ``pd.concat(axis=1)`` aligns on index
        # labels, so a default RangeIndex here would misalign rows (and pad
        # with NaN) whenever ``metrics_df`` does not have a 0..n-1 index.
        x_encoded = pd.DataFrame(
            encoded_values,
            columns=encoder.get_feature_names_out(categorical_features),
            index=x_features.index,
        )

        # Combine encoded categorical features with the rest of the data.
        x_features = x_features.drop(columns=categorical_features)
        x_features = pd.concat([x_features, x_encoded], axis=1)

        # We don't split the data as the random forest is only a proxy to
        # compute the Shapley values.
        model = RandomForestRegressor()
        model.fit(X=x_features, y=target)

        # Initialize the SHAP explainer with the training data as background.
        explainer = shap.Explainer(model, x_features)

        # Calculate SHAP values.
        shap_values = explainer(x_features)
        return shap_values, x_features

    @staticmethod
    def collect(
        results: list[tuple[Explanation, pd.DataFrame]],
    ) -> tuple[Explanation, pd.DataFrame]:
        """Return the first result; we only need to run this once.

        Args:
            results: Outputs of ``run``; only the first entry is used.

        Returns:
            The ``(explanation, features)`` pair of the first run.

        Raises:
            IndexError: If ``results`` is empty.
        """
        return results[0]

    @staticmethod
    def save(results: tuple[Explanation, pd.DataFrame]) -> None:
        """Save the results as a violin summary plot.

        The plot is written to ``runs/significant_feature/summary_plot.png``.

        Args:
            results: The ``(explanation, features)`` pair returned by
                ``collect``.
        """
        shap_values, features = results
        shap.summary_plot(shap_values, features, show=False, plot_type="violin")
        file_path = Path("runs") / SignificantFeatureJob.name
        file_path.mkdir(parents=True, exist_ok=True)
        try:
            plt.savefig(file_path / "summary_plot.png")
        finally:
            # Release the figure so repeated runs neither leak memory nor
            # draw on top of a stale plot.
            plt.close()
