"""
plot_utils.py
Utility functions for plotting experimental results (e.g., accuracy vs. rank, distillation weight, etc.)
"""

import matplotlib.pyplot as plt
import os

def plot_accuracy_vs_rank(ranks, accuracies_dict, title="Accuracy vs. LoRA Rank", save_path=None):
    """
    Plot accuracy curves for different methods across LoRA ranks.

    Args:
        ranks: Sequence of LoRA ranks used as the shared x-axis values.
        accuracies_dict: Mapping of method name -> sequence of accuracy values,
            one per entry in ``ranks``. Each item is drawn as its own curve,
            labelled with the method name.
        title: Figure title.
        save_path: If truthy, the figure is written to this path (missing
            parent directories are created) and then closed. Otherwise the
            figure is displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    for method, accs in accuracies_dict.items():
        ax.plot(ranks, accs, marker='o', label=method)

    ax.set_xlabel("LoRA Rank")
    ax.set_ylabel("Accuracy (%)")
    ax.set_title(title)
    ax.grid(True)
    # Only add a legend when there are labelled curves; with an empty mapping
    # matplotlib would warn that no labelled artists were found.
    if accuracies_dict:
        ax.legend()

    if save_path:
        # os.path.dirname("plot.png") is "" and os.makedirs("") raises, so only
        # create a directory when the path actually contains one.
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(save_path)
        # Release the figure so repeated calls don't accumulate open figures
        # in pyplot's global registry.
        plt.close(fig)
    else:
        plt.show()

def plot_alpha_vs_accuracy(alphas, accuracies_dict, title="Accuracy vs. Distillation Weight", save_path=None):
    """
    Plot accuracy vs. alpha (distillation weight) curves.

    Args:
        alphas: Sequence of distillation weights used as the shared x-axis
            values.
        accuracies_dict: Mapping of setting name -> sequence of accuracy
            values, one per entry in ``alphas``. Each item is drawn as its own
            curve, labelled with the setting name.
        title: Figure title.
        save_path: If truthy, the figure is written to this path (missing
            parent directories are created) and then closed. Otherwise the
            figure is displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    for setting, accs in accuracies_dict.items():
        ax.plot(alphas, accs, marker='^', label=setting)

    ax.set_xlabel("Distillation Weight (alpha)")
    ax.set_ylabel("Accuracy (%)")
    ax.set_title(title)
    ax.grid(True)
    # Only add a legend when there are labelled curves; with an empty mapping
    # matplotlib would warn that no labelled artists were found.
    if accuracies_dict:
        ax.legend()

    if save_path:
        # os.path.dirname("plot.png") is "" and os.makedirs("") raises, so only
        # create a directory when the path actually contains one.
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(save_path)
        # Release the figure so repeated calls don't accumulate open figures
        # in pyplot's global registry.
        plt.close(fig)
    else:
        plt.show()
