import pandas as pd
import itertools
import matplotlib.pyplot as plt
import math
from scipy.interpolate import make_interp_spline
import numpy as np
import argparse
import os


def smooth_plot(ax, df, title):
    '''
    Draw a smoothed speedup curve plus the raw measurements on ``ax``.

    An interpolating B-spline (``scipy.interpolate.make_interp_spline``,
    quadratic when at least 3 distinct points exist, linear for 2) is fitted
    through the (``num_prompts``, ``speedup``) points of ``df``. It is drawn
    as a line sampled at 300 evenly spaced x values. The raw rows of ``df``
    are drawn on top as a gray scatter.

    Args:
        ax: Matplotlib ``Axes`` to draw on.
        df: DataFrame with ``num_prompts`` and ``speedup`` columns. Rows need
            not be sorted. Repeated ``num_prompts`` values are averaged for
            the curve. Rows with a NaN or infinite value are left out of the
            curve.
        title: Title for the subplot.
    '''
    # make_interp_spline requires finite, strictly increasing x values.
    # groupby both de-duplicates and sorts by num_prompts.
    points = (
        df[['num_prompts', 'speedup']]
        .replace([np.inf, -np.inf], np.nan)
        .dropna()
        .groupby('num_prompts', as_index=False)['speedup']
        .mean()
    )
    x = points['num_prompts'].to_numpy()
    y = points['speedup'].to_numpy()

    if len(x) >= 2:
        # A degree-k interpolating spline needs at least k + 1 points, so
        # fall back to a linear curve when only two points are available.
        k = min(2, len(x) - 1)
        x_smooth = np.linspace(x[0], x[-1], 300)
        spl = make_interp_spline(x, y, k=k)
        y_smooth = spl(x_smooth)
        ax.plot(x_smooth, y_smooth, label='Speedup', color='blue')

    # Show every raw measurement, including duplicates, exactly as given.
    ax.scatter(df['num_prompts'], df['speedup'], color='gray', alpha=0.8)
    ax.set_xlabel('Number of Prompts')
    ax.set_ylabel('Speedup')

    ax.set_title(title)
    ax.grid()


def main():
    '''
    Plot the calibrated speedup for every (exp_num, num_speculative_tokens)
    configuration found in the summary CSV.

    Steps:

    1. Group the summary rows by (``exp_num``, ``num_speculative_tokens``).
    2. For each speculative-token count, take the ``system_efficiency`` of
       the exp_num=8 run as the real system efficiency (sigma in the paper).
    3. Left-merge that real efficiency into each configuration on
       ``num_prompts``.
    4. Compute the AR decode time as ``ar_time - prefill_time``.
    5. Compute the SD decode time as ``total_time - prefill_time``, then
       rescale it by ``system_efficiency / real_system_efficiency``.
    6. Compute the speedup as AR decode time / calibrated SD decode time.

    One subplot per configuration is drawn on a 6x2 grid. The row is
    log2(exp_num) and the column is the speculative-token count (2 or 4).
    The figure is saved as PDF, SVG and PNG under ``--output-dir``.

    Raises:
        ValueError: if the CSV lacks the exp_num=8 reference rows for a
            plotted speculative-token count, or contains a
            ``num_speculative_tokens`` value that has no subplot column.
    '''
    # exp_num of the run whose system efficiency is used as ground truth.
    golden_expnum = 8
    # Speculative-token counts that get a subplot column, in column order.
    spec_tokens = (2, 4)

    parser = argparse.ArgumentParser(description="Plot calibrated speedup for varied sparsity")
    parser.add_argument("--csv-path", default="./csv_results/summary/summary.csv", help="The summary CSV file to be processed")
    parser.add_argument("--output-dir", default="./plot", help="The directory to save the plot")

    args = parser.parse_args()

    summary_df = pd.read_csv(args.csv_path)

    # Split by configuration. groupby yields only combinations that have
    # rows (sorted by key), so no subplot is attempted on an empty frame.
    split_df = {
        combo: sub_df.copy()
        for combo, sub_df in summary_df.groupby(["exp_num", "num_speculative_tokens"])
    }

    # Save the real system efficiency (namely, sigma in the paper).
    golden_df_dict = {}
    for spectoken in spec_tokens:
        key = (golden_expnum, spectoken)
        if key not in split_df:
            raise ValueError(
                f"No rows with exp_num={golden_expnum} and "
                f"num_speculative_tokens={spectoken} in {args.csv_path}; "
                "cannot calibrate system efficiency"
            )
        golden_df_dict[spectoken] = split_df[key][['num_prompts', 'system_efficiency']].rename(
            columns={'system_efficiency': 'real_system_efficiency'}
        )

    # Calibrate each configuration with the real system efficiency, then
    # compute its speedup.
    calib_df_dict = {}
    for combo, sub_df in split_df.items():
        expnum, spectoken = combo
        if spectoken not in golden_df_dict:
            raise ValueError(
                f"Unexpected num_speculative_tokens={spectoken} (exp_num={expnum}); "
                f"only {spec_tokens} are supported"
            )
        merged_df = pd.merge(sub_df, golden_df_dict[spectoken], on='num_prompts', how='left')
        merged_df["ar_decode_time"] = merged_df["ar_time"] - merged_df["prefill_time"]
        merged_df["sd_decode_time"] = merged_df["total_time"] - merged_df["prefill_time"]
        merged_df["calib_sd_decode_time"] = (
            merged_df["sd_decode_time"] / merged_df["real_system_efficiency"] * merged_df["system_efficiency"]
        )
        merged_df["speedup"] = merged_df["ar_decode_time"] / merged_df["calib_sd_decode_time"]
        calib_df_dict[combo] = merged_df

    # Draw the speedup plot: 6 rows for exp_num, one column per spec-token count.
    fig, ax = plt.subplots(6, len(spec_tokens), figsize=(7, 15))
    for (expnum, spectoken), sub_df in calib_df_dict.items():
        # Rows assume exp_num is a power of two in 1..32. math.log2 is exact
        # for powers of two, unlike math.log(x, 2).
        row = int(math.log2(expnum))
        col = spec_tokens.index(spectoken)
        smooth_plot(ax[row, col], sub_df, rf"K={expnum}, $\rho$={expnum}/64, $\gamma$={spectoken}")

    fig.tight_layout()

    os.makedirs(args.output_dir, exist_ok=True)
    for fmt in ("pdf", "svg", "png"):
        fig.savefig(os.path.join(args.output_dir, f"speedup_for_varied_sparsity.{fmt}"), format=fmt)
    # Release the figure's memory once every format has been written.
    plt.close(fig)

    print(f"Plots saved to dir: {args.output_dir}")



if __name__ == "__main__":
    # Generate the plots only when run as a script, not when imported.
    main()

