import argparse
import os

import matplotlib.pyplot as plt
import pandas as pd

def plot_kernel_results(file: str):
    """Plot per-provider kernel latency against sparsity and save it as a PNG.

    The plot title metadata (GPU name, tensor shape, batch size) is parsed from
    the CSV path itself:

    - the first ``(...)`` group in the path is the GPU name;
    - the second ``(...)`` group is the tensor shape;
    - the text between ``batch_size_`` and the next ``/`` is the batch size.

    Presumably paths look like (layout inferred from the parsing below —
    confirm against the script that writes these CSVs)::

        results/batch_size_8/kernel (A100)(4096x4096).csv

    The CSV must contain the columns ``WINA``, ``ZEAL`` and ``Dense``. Each
    row becomes one point on every line. The PNG is written next to the CSV,
    with the same name and a ``.png`` extension.

    Args:
        file: Path to the CSV file containing kernel results.

    Raises:
        IndexError: If *file* lacks two ``(...)`` groups or a ``batch_size_``
            segment.
        KeyError: If one of the plotted columns is missing from the CSV.
    """
    # Swap only the final extension. str.replace would rewrite every ".csv"
    # in the path. It would also return the input unchanged when the path has
    # no ".csv", which made savefig overwrite the source CSV.
    output_file = os.path.splitext(file)[0] + '.png'

    # Title metadata encoded in the path (see docstring for the layout).
    gpu_name = file.split('(')[1].split(')')[0]
    tensor_shape = file.split('(')[2].split(')')[0]
    batch_size = file.split('batch_size_')[1].split('/')[0]

    # Read the CSV file.
    df = pd.read_csv(file)

    # One x value per CSV row, evenly spaced over [0, 1): row i is plotted at
    # sparsity i / len(df). Assumes rows are written in increasing-sparsity
    # order — TODO confirm against the producer of the CSV.
    num_rows = len(df)
    x_vals = [i / num_rows for i in range(num_rows)]

    # (CSV column, legend label, color, line style) for each plotted series.
    # NOTE(review): the "ZEAL" column is labelled "TEAL" — confirm whether
    # the relabelling is intentional or a typo in one of the two names.
    series = [
        ('WINA', 'WINA', 'orange', '-'),
        ('ZEAL', 'TEAL', 'blue', '-'),
        ('Dense', 'Dense', 'green', '-'),
    ]

    # Use an explicit figure rather than pyplot's implicit "current figure"
    # so lines are never drawn onto a figure the caller already had open.
    # Close it in `finally` so a failed save does not leak the figure.
    fig, ax = plt.subplots()
    try:
        for column, label, color, linestyle in series:
            ax.plot(x_vals, df[column], label=label, color=color, linestyle=linestyle)
        ax.set_title(f'{gpu_name} ({tensor_shape}) (Batch Size = {batch_size})')
        ax.set_xlabel('Sparsity')
        ax.set_ylabel('Latency (ms)')
        ax.legend()
        ax.grid(True)
        fig.savefig(output_file)
    finally:
        plt.close(fig)
    print(f"Plot saved to {output_file}")
    
if __name__ == "__main__":
    # Command-line entry point: plot the results stored in one CSV file.
    parser = argparse.ArgumentParser(description="Plot kernel results from a CSV file.")
    # required=True: without it an omitted --file reaches plot_kernel_results
    # as None and fails with an opaque AttributeError instead of a usage error.
    parser.add_argument(
        "--file",
        type=str,
        required=True,
        help="Path to the CSV file containing kernel results.",
    )
    args = parser.parse_args()

    plot_kernel_results(args.file)