import pandas as pd
import argparse
import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.interpolate import make_interp_spline


class Summary:
    '''
    Per-(dataset, temperature) view of a speculative-decoding summary CSV.

    Attributes set by the constructor:
        combinations: sorted list of (dataset, temperature) tuples built from
            the unique values of those two CSV columns.
        df_dict: dict mapping each combination to its own DataFrame copy, with
            the derived columns "ar_decode_time", "sd_decode_time" and
            "speedup" added.
    '''

    # Divisor applied to "ar_decode_time" when computing the "metric" column.
    # NOTE(review): presumably the number of tokens decoded per run, with the
    # subsequent * 1000 converting seconds to milliseconds -- confirm.
    _METRIC_DIVISOR = 512

    def __init__(self, csv_path):
        '''
        Load the summary CSV at `csv_path` and compute the decoding speedup.

        The CSV must provide the columns "dataset", "temperature", "ar_time",
        "prefill_time" and "total_time".
        '''
        self.combinations, split_df = self._split_by_combination(pd.read_csv(csv_path))

        # calculate the time for decoding and the SD speedup for decoding
        for sub_df in split_df.values():
            sub_df["ar_decode_time"] = sub_df["ar_time"] - sub_df["prefill_time"]
            sub_df["sd_decode_time"] = sub_df["total_time"] - sub_df["prefill_time"]
            sub_df["speedup"] = sub_df["ar_decode_time"] / sub_df["sd_decode_time"]

        # This is the final dataframe dictionary
        self.df_dict = split_df

    @staticmethod
    def _split_by_combination(summary_df):
        '''
        Split `summary_df` into one independent copy per (dataset, temperature).

        Returns a tuple `(combinations, split_df)`: the sorted list of every
        pairing of the unique "dataset" and "temperature" values, and a dict
        mapping each pairing to the matching rows. A pairing that never occurs
        together in the data maps to an empty DataFrame.
        '''
        combinations = sorted(itertools.product(
            summary_df["dataset"].unique(),
            summary_df["temperature"].unique(),
        ))
        split_df = {}
        for combo in combinations:
            dataset, temp = combo
            mask = (summary_df["dataset"] == dataset) & (summary_df["temperature"] == temp)
            # .copy() so that adding derived columns later never touches the source frame
            split_df[combo] = summary_df[mask].copy()
        return combinations, split_df

    @classmethod
    def _add_metric(cls, df):
        '''
        Add the "metric" column (plotted as Target Efficiency) to `df` in place.

        Requires the columns "ar_decode_time" and "scoring_time_ms".
        Returns `df` for convenience.
        '''
        df["metric"] = (df["ar_decode_time"] / cls._METRIC_DIVISOR * 1000) / df["scoring_time_ms"]
        return df

    @staticmethod
    def _smooth_curve(x, y, num_points=300, degree=2):
        '''
        Return `(x_curve, y_curve)` for a smooth line through the points (x, y).

        The points are sorted by x first because `make_interp_spline` requires
        strictly increasing abscissas. When a spline of the requested degree
        cannot be built (fewer than `degree + 1` points, or repeated x values)
        the sorted input points are returned unchanged, so the caller draws a
        plain polyline instead of failing.
        '''
        xs = np.asarray(x, dtype=float)
        ys = np.asarray(y, dtype=float)
        order = np.argsort(xs, kind="stable")
        xs, ys = xs[order], ys[order]

        if xs.size <= degree or np.any(np.diff(xs) <= 0):
            return xs, ys

        x_smooth = np.linspace(xs[0], xs[-1], num_points)
        y_smooth = make_interp_spline(xs, ys, k=degree)(x_smooth)
        return x_smooth, y_smooth

    def _plot_metric(self, ax, x, y, label, color, marker):
        '''
        Draw one series on `ax`: a dashed smoothed curve plus the raw points.

        `x` and `y` are equally long 1-D sequences (e.g. DataFrame columns);
        they do not need to be sorted. `label` is attached to the curve so it
        shows up in the legend; `marker` styles the scatter points.
        '''
        x_curve, y_curve = self._smooth_curve(x, y)
        ax.plot(x_curve, y_curve, label=label, color=color, linestyle='--', linewidth=2, alpha=0.9)
        ax.scatter(x, y, color=color, alpha=0.9, marker=marker)

    def _load_dense_df_dicts(self, csv_path):
        '''
        Load the dense model summary from the CSV file and compute the metric.

        Returns a dict mapping each (dataset, temperature) combination found in
        the dense CSV to a DataFrame with "ar_decode_time" and "metric" added.
        The instance state (`combinations`, `df_dict`) is left untouched.
        '''
        _, split_df = self._split_by_combination(pd.read_csv(csv_path))

        # calculate the metric
        for sub_df in split_df.values():
            sub_df["ar_decode_time"] = sub_df["ar_time"] - sub_df["prefill_time"]
            self._add_metric(sub_df)

        # This is the final dataframe dictionary
        return split_df

    def plot_dense_vs_moe(self, output_dir, dense_csv_path, combo=('humaneval', 0), spectoken=4):
        '''
        Generate the plot comparing the metric Target Efficiency between the dense and MoE models.
        It corresponds to Figure 2c in the paper.

        output_dir: directory that receives metric_moe_vs_dense.{pdf,svg,png};
            created if it does not exist.
        dense_csv_path: summary CSV of the dense model.
        combo: (dataset, temperature) experiment setting to plot; must be
            present in both the MoE and the dense data (KeyError otherwise).
        spectoken: only rows whose "num_speculative_tokens" equals this value
            are plotted.

        Side effect: a "metric" column is added to `self.df_dict[combo]`.
        '''
        dataset, temp = combo

        sub_df = self._add_metric(self.df_dict[combo])
        comp_sub_df = self._load_dense_df_dicts(dense_csv_path)[combo]

        tmp_df = sub_df[sub_df['num_speculative_tokens'] == spectoken]
        comp_tmp_df = comp_sub_df[comp_sub_df['num_speculative_tokens'] == spectoken]

        fig, ax = plt.subplots(1, 1, figsize=(5.3, 4))

        self._plot_metric(ax, tmp_df['num_prompts'], tmp_df['metric'], label="MoE", color='brown', marker='o')
        self._plot_metric(ax, comp_tmp_df['num_prompts'], comp_tmp_df['metric'], label="dense", color='purple', marker='x')

        ax.set_xlabel(r'Batch size $B$', fontsize=12)
        ax.set_ylabel('Target Efficiency', fontsize=12)
        ax.grid()
        ax.legend(loc='upper right', fontsize=11)
        ax.set_title(fr"{dataset}, temp={temp}, $\gamma=${spectoken}", fontsize=12)
        ax.tick_params(axis='both', labelsize=11)

        fig.tight_layout()

        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        try:
            for fmt in ('pdf', 'svg', 'png'):
                fig.savefig(os.path.join(output_dir, f"metric_moe_vs_dense.{fmt}"), format=fmt)
        finally:
            # Release the figure; pyplot otherwise keeps it alive until the process exits.
            plt.close(fig)

        print(f"Output saved in {output_dir}")



def main(argv=None):
    '''
    Command-line entry point: plot the dense-vs-MoE metric comparison.

    argv: optional list of argument strings. When None (the default),
        argparse reads the arguments from sys.argv[1:].
    '''
    parser = argparse.ArgumentParser(
        description="Plot the Target Efficiency metric of dense vs. MoE models from their summary CSVs"
    )
    parser.add_argument("--dense-csv-path", default="./csv_results/summary/summary.csv", help="The summary CSV file for dense models")
    parser.add_argument("--moe-csv-path", default="../moe/csv_results/summary/summary.csv", help="The summary CSV file for moe models")
    parser.add_argument("--output-dir", default="./plot", help="The directory to save the plot")

    args = parser.parse_args(argv)

    # The Summary instance holds the MoE data; the dense CSV is loaded by the plot method.
    summary = Summary(args.moe_csv_path)

    print("Generating plot...")
    summary.plot_dense_vs_moe(args.output_dir, args.dense_csv_path)
    print(f"Plot saved in {args.output_dir}.")


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()