import pandas as pd
import argparse
import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.interpolate import make_interp_spline


class Summary:
    """Per-configuration speedup statistics loaded from a benchmark summary CSV.

    The CSV is split into one dataframe per ``(dataset, temperature)`` pair,
    and every row is augmented with decode-only timings and the resulting
    speculative-decoding (SD) speedup over autoregressive (AR) decoding.

    Attributes:
        combinations: Sorted list of every ``(dataset, temperature)`` pair
            formed from the unique values of those two CSV columns.
        df_dict: Maps each pair in ``combinations`` to its own dataframe copy
            (possibly empty when the CSV has no row for that pair).
    """

    # Speculative-token counts reported in the summary table; the benchmark
    # script only runs these values.
    SPEC_TOKEN_VALUES = (2, 3, 4)

    # Columns kept (in this order) in the speedup summary table.
    SUMMARY_COLUMNS = [
        'dataset',
        'temperature',
        'num_speculative_tokens',
        'num_prompts',
        'ar_decode_time',
        'sd_decode_time',
        'system_efficiency',
        'speedup',
    ]

    # Settings for the smooth curve drawn through the measured points.
    _SPLINE_DEGREE = 2
    _SPLINE_SAMPLES = 300

    def __init__(self, csv_path):
        """Load the summary CSV and split it by ``(dataset, temperature)``.

        Args:
            csv_path: Path (or file-like object) accepted by ``pd.read_csv``.
                The CSV must provide the ``dataset``, ``temperature``,
                ``ar_time``, ``total_time`` and ``prefill_time`` columns.
        """
        summary_df = pd.read_csv(csv_path)

        # Decode-only timings: the prefill time is subtracted from both the AR
        # run and the SD run, and the speedup is the ratio of the remainders.
        summary_df["ar_decode_time"] = summary_df["ar_time"] - summary_df["prefill_time"]
        summary_df["sd_decode_time"] = summary_df["total_time"] - summary_df["prefill_time"]
        summary_df["speedup"] = summary_df["ar_decode_time"] / summary_df["sd_decode_time"]

        self.combinations = sorted(
            itertools.product(
                summary_df["dataset"].unique(),
                summary_df["temperature"].unique(),
            )
        )

        # One independent copy per combination so later edits to one frame
        # never leak into another (or back into the full table).
        self.df_dict = {
            (dataset, temp): summary_df[
                (summary_df['dataset'] == dataset) & (summary_df['temperature'] == temp)
            ].copy()
            for dataset, temp in self.combinations
        }

    def get_speedup_summary(self):
        """Generate the speedup summary table (Table 1 in Experiments in the paper).

        For every ``(dataset, temperature)`` pair and every value in
        ``SPEC_TOKEN_VALUES``, the row(s) with the highest speedup are kept;
        ties are all retained.

        Returns:
            A dataframe restricted to ``SUMMARY_COLUMNS``. It is empty (but
            still has those columns) when no row matches.
        """
        best_rows = []
        for sub_df in self.df_dict.values():
            for spec_tokens in self.SPEC_TOKEN_VALUES:
                candidates = sub_df[sub_df['num_speculative_tokens'] == spec_tokens]
                if candidates.empty:
                    continue
                best = candidates[candidates['speedup'] == candidates['speedup'].max()]
                # An all-NaN speedup column yields no best row; skip it so
                # that empty frames are never fed to pd.concat.
                if not best.empty:
                    best_rows.append(best)

        if not best_rows:
            return pd.DataFrame(columns=self.SUMMARY_COLUMNS)

        # Concatenate once instead of growing a dataframe inside the loop.
        return pd.concat(best_rows, ignore_index=True)[self.SUMMARY_COLUMNS]

    def _plot_smoothed(self, ax, x, y, line_kwargs, scatter_kwargs):
        """Draw a smooth curve through ``(x, y)`` and scatter the raw points.

        The points are sorted by ``x`` first because ``make_interp_spline``
        rejects abscissas that are not strictly increasing. When a spline
        cannot be built (too few points or repeated ``x`` values) the points
        are joined with straight segments instead of raising.

        Args:
            ax: Matplotlib axes to draw on.
            x: 1-D numeric sequence of abscissas.
            y: 1-D numeric sequence of ordinates, same length as ``x``.
            line_kwargs: Keyword arguments forwarded to ``ax.plot``.
            scatter_kwargs: Keyword arguments forwarded to ``ax.scatter``.
        """
        x_arr = np.asarray(x, dtype=float)
        y_arr = np.asarray(y, dtype=float)
        order = np.argsort(x_arr, kind="stable")
        x_arr, y_arr = x_arr[order], y_arr[order]

        can_smooth = (
            x_arr.size > self._SPLINE_DEGREE
            and bool(np.all(np.diff(x_arr) > 0))
        )
        if can_smooth:
            x_line = np.linspace(x_arr[0], x_arr[-1], self._SPLINE_SAMPLES)
            y_line = make_interp_spline(x_arr, y_arr, k=self._SPLINE_DEGREE)(x_line)
        else:
            x_line, y_line = x_arr, y_arr

        ax.plot(x_line, y_line, **line_kwargs)
        ax.scatter(x_arr, y_arr, **scatter_kwargs)

    def _plot_GPU(self, ax, x, y):
        """Connect the GPU results with a smooth line and scatter the points."""
        self._plot_smoothed(
            ax, x, y,
            line_kwargs={'label': "GPU results", 'color': 'blue'},
            scatter_kwargs={'color': 'gray', 'alpha': 0.8},
        )

    def _plot_metric(self, ax, x, y):
        """Connect the metric results with a smooth line and scatter the points."""
        self._plot_smoothed(
            ax, x, y,
            line_kwargs={
                'label': "Target Efficiency",
                'color': 'red',
                'linestyle': '--',
                'alpha': 0.6,
            },
            scatter_kwargs={'color': 'red', 'alpha': 0.7},
        )

    def generate_plot(self, device, output_dir, combo=('humaneval', 0),
                      spec_token=4, decode_tokens=512):
        """Plot measured speedup against the Target Efficiency metric.

        It corresponds to Figure 2 in the paper. The figure is written to
        ``output_dir`` as ``real_vs_metric_{device}`` in PDF, SVG and PNG
        format, and is closed afterwards.

        Args:
            device: Device name used in the plot title and the file names.
            output_dir: Directory for the output files; created if missing.
            combo: ``(dataset, temperature)`` experiment setting to plot.
            spec_token: Number of speculative tokens to plot.
            decode_tokens: Divisor that turns ``ar_decode_time`` into a
                per-token time before it is compared with ``scoring_time_ms``
                (presumably the number of generated tokens per request --
                TODO confirm against the benchmark script).

        Raises:
            KeyError: If ``combo`` is not present in the loaded CSV.
            ValueError: If ``combo`` has no row with ``spec_token``
                speculative tokens.
        """
        if combo not in self.df_dict:
            raise KeyError(
                f"No data for combination {combo!r}; available: {self.combinations}"
            )
        dataset, temp = combo

        # Work on a filtered copy so the cached dataframe is left untouched.
        sub_df = self.df_dict[combo]
        plot_df = sub_df[sub_df['num_speculative_tokens'] == spec_token].copy()
        if plot_df.empty:
            raise ValueError(
                f"No rows with num_speculative_tokens == {spec_token} for {combo!r}"
            )

        # Compute the metric Target Efficiency using the vllm log reports.
        plot_df['metric'] = (
            (plot_df['ar_decode_time'] / decode_tokens * 1000) / plot_df['scoring_time_ms']
        )

        fig, ax = plt.subplots(1, 1, figsize=(5.3, 4))
        try:
            ax_twin = ax.twinx()
            self._plot_GPU(ax, plot_df['num_prompts'], plot_df['speedup'])
            self._plot_metric(ax_twin, plot_df['num_prompts'], plot_df['metric'])
            ax.tick_params(axis='both', labelsize=11)
            ax_twin.tick_params(axis='both', labelsize=11)
            ax_twin.yaxis.set_tick_params(colors='red')

            ax.set_xlabel(r'Batch size $B$', fontsize=12)
            ax.set_ylabel('Speedup', fontsize=12)
            ax_twin.set_ylabel('Target Efficiency', color='red', fontsize=12)
            ax.grid()
            ax.set_title(fr"{device}, {dataset}, temp={temp}, $\gamma=${spec_token}", fontsize=12)
            ax.legend(loc='center left', bbox_to_anchor=(0.45, 0.9), fontsize=11)
            ax_twin.legend(loc='center left', bbox_to_anchor=(0.45, 0.78), fontsize=11)

            # exist_ok avoids the check-then-create race of exists()+makedirs().
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            fig.tight_layout()
            for fmt in ('pdf', 'svg', 'png'):
                out_path = os.path.join(output_dir, f"real_vs_metric_{device}.{fmt}")
                fig.savefig(out_path, format=fmt)
        finally:
            # Release the figure; pyplot otherwise keeps it alive globally.
            plt.close(fig)



def main(argv=None):
    """Command-line entry point: print the speedup table and save the plot.

    Args:
        argv: Optional list of command-line arguments. When ``None``
            (the default), argparse reads ``sys.argv[1:]``.
    """
    # The description now states what this script does; the previous text
    # ("CSV Merger and Validator") described a different tool.
    parser = argparse.ArgumentParser(
        description="Summarize speculative-decoding speedups from a summary CSV and plot them"
    )
    parser.add_argument("--csv-path", default="./csv_results/summary/summary.csv", help="The summary CSV file to be processed")
    parser.add_argument("--output-dir", default="./plot", help="The directory to save the plot")
    parser.add_argument("--gpu", default="GPU", help="The experiments are carried out on which kind of GPU")

    args = parser.parse_args(argv)

    summary = Summary(args.csv_path)
    print(f"Speedup summary of {args.gpu}:")
    print(summary.get_speedup_summary())

    print("Generating plot...")
    summary.generate_plot(args.gpu, args.output_dir)
    print(f"Plot saved in {args.output_dir}.")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()