
import os

import matplotlib.pyplot as plt

def plot_epochs_delta_multi(data_bias, data_fair, baseline_bias=None, baseline_fair=None, filename=None,
                            *, ylim=(0.0, 1.0), out_dir="datasets"):
    """
    Plots (epoch, Δ) pairs for bias and fair data with different colors and baselines.

    Font sizes are applied to this figure only; no global matplotlib settings
    are changed, so later plots in the same session are unaffected.

    Args:
        data_bias (list of tuple): List of (epoch, delta) pairs for bias. Skipped if empty.
        data_fair (list of tuple): List of (epoch, delta) pairs for fair. Skipped if empty.
        baseline_bias (float, optional): Baseline for bias (blue dash-dot line).
            Only drawn when ``data_bias`` is non-empty.
        baseline_fair (float, optional): Baseline for fair (green dashed line).
            Only drawn when ``data_fair`` is non-empty.
        filename (str, optional): Shown in the title; when given, the plot is also
            saved as ``<out_dir>/<filename>.pdf``.
        ylim (tuple of float or None, optional): y-axis limits. ``None`` keeps
            matplotlib's autoscaling. Defaults to ``(0.0, 1.0)``.
        out_dir (str, optional): Directory for the saved PDF; created if missing.
            Defaults to ``"datasets"``.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    if data_bias:
        epochs_b, deltas_b = zip(*data_bias)
        ax.plot(epochs_b, deltas_b, marker='o', color='blue', label=r'$\Delta$ for synthetic (Bias)')
        if baseline_bias is not None:
            ax.axhline(y=baseline_bias, color='blue', linestyle='-.', label=r'$\Delta$ for reference (Bias)')
    if data_fair:
        epochs_f, deltas_f = zip(*data_fair)
        ax.plot(epochs_f, deltas_f, marker='s', color='green', label=r'$\Delta$ for synthetic (Fair)')
        if baseline_fair is not None:
            ax.axhline(y=baseline_fair, color='green', linestyle='--', label=r'$\Delta$ for reference (Fair)')
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(r'$\Delta$', fontsize=12)
    # Size tick labels on this Axes only, instead of a global plt.rc('font', ...)
    # that would leak into every figure created afterwards.
    ax.tick_params(labelsize=12)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_title(f'{filename}: Epoch vs $\\Delta$' if filename else r'Epoch vs $\Delta$', fontsize=14)
    # With both series empty there is nothing labeled; legend() would only warn.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize=12)
    ax.grid(True)
    fig.tight_layout()
    if filename:
        path = os.path.join(out_dir, f"{filename}.pdf")
        directory = os.path.dirname(path)
        if directory:
            # savefig does not create missing directories.
            os.makedirs(directory, exist_ok=True)
        fig.savefig(path, format='pdf')
    plt.show()


def plot_epochs_delta(data, baseline=None, filename=None, *, ylim=(0.0, 1.0), out_dir="datasets"):
    """
    Plots (epoch, Δ) pairs and draws a red baseline if provided.

    Args:
        data (list of tuple): List of (epoch, delta) pairs. If empty, no curve is
            drawn (the baseline, if any, is still shown).
        baseline (float, optional): Value for the red dashed baseline.
        filename (str, optional): Shown in the title; when given, the plot is also
            saved as ``<out_dir>/<filename>.pdf``.
        ylim (tuple of float or None, optional): y-axis limits. ``None`` keeps
            matplotlib's autoscaling. Defaults to ``(0.0, 1.0)``.
        out_dir (str, optional): Directory for the saved PDF; created if missing.
            Defaults to ``"datasets"``.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    # zip(*[]) yields nothing, so unpacking an empty series would raise ValueError.
    if data:
        epochs, deltas = zip(*data)
        ax.plot(epochs, deltas, marker='o', label=r'$\Delta$ per Epoch')
    if baseline is not None:
        ax.axhline(y=baseline, color='red', linestyle='--', label=r'$\Delta$ for reference')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(r'$\Delta$')
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_title(f'{filename}: Epoch vs $\\Delta$' if filename else r'Epoch vs $\Delta$')
    # Avoid matplotlib's "No artists with labels" warning when nothing was drawn.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    ax.grid(True)
    fig.tight_layout()
    if filename:
        path = os.path.join(out_dir, f"{filename}.pdf")
        directory = os.path.dirname(path)
        if directory:
            # savefig does not create missing directories.
            os.makedirs(directory, exist_ok=True)
        fig.savefig(path, format='pdf')
    plt.show()

# Example usage (only runs when this file is executed as a script):
if __name__ == "__main__":
    # Illustrative numbers only -- swap in real (epoch, delta) series.
    data_bias = list(zip(range(1, 6), (0.5, 0.45, 0.4, 0.38, 0.35)))
    data_fair = list(zip(range(1, 6), (0.6, 0.55, 0.5, 0.48, 0.45)))
    baseline_bias, baseline_fair = 0.4, 0.5
    plot_epochs_delta_multi(
        data_bias,
        data_fair,
        baseline_bias=baseline_bias,
        baseline_fair=baseline_fair,
    )