import os

import matplotlib.pyplot as plt


def plot_regression(regression_results: dict, output_dir: "str | None" = None):
    """Save one scatter plot with its fitted regression line per (task, channel).

    For every task in ``regression_results`` a directory ``<output_dir>/<task>``
    is created (if missing). For every channel of that task a PNG named
    ``regression_<channel>.png`` is written into it, showing score versus
    average pairwise correlation. The line predicted by the channel's
    regressor is drawn across the observed correlation range.

    Args:
        regression_results: Mapping from task to a dict with the keys
            ``"correlations_per_channel"``, ``"scores_per_channel"`` and
            ``"linear_regressor_per_channel"``, each keyed by channel. The
            regressor objects must provide ``predict`` accepting a list of
            single-feature rows (presumably scikit-learn style estimators —
            confirm against the producer of these results).
        output_dir: Root directory for the figures. ``None`` (the default)
            writes under the current working directory.
    """
    root_dir = os.curdir if output_dir is None else output_dir

    for task, result in regression_results.items():
        correlations_per_channel = result["correlations_per_channel"]
        scores_per_channel = result["scores_per_channel"]
        linear_regressor_per_channel = result["linear_regressor_per_channel"]

        # str() so non-string task keys (e.g. ints) still form a valid path.
        output_task_dir = os.path.join(root_dir, str(task))
        os.makedirs(output_task_dir, exist_ok=True)

        for channel, correlations in correlations_per_channel.items():
            fig, ax = plt.subplots()
            try:
                ax.scatter(correlations, scores_per_channel[channel])

                # min()/max() raise on an empty sequence; without data there
                # is no range to draw the fitted line over, so skip it.
                if len(correlations) > 0:
                    lowest_value = min(correlations)
                    highest_value = max(correlations)
                    # predict receives a 2-D input: one row per sample with a
                    # single feature column (the correlation value).
                    fitted = linear_regressor_per_channel[channel].predict(
                        [[lowest_value], [highest_value]]
                    )
                    ax.plot([lowest_value, highest_value], fitted)

                ax.set_xlabel("Average Pairwise Correlation")
                ax.set_ylabel("Score")
                ax.set_title(str(channel))

                fig.savefig(
                    os.path.join(output_task_dir, f"regression_{channel}.png")
                )
            finally:
                # Always release the figure, even when plotting or saving
                # fails, so pyplot does not accumulate open figures.
                plt.close(fig)
