import pandas as pd
import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.interpolate import make_interp_spline
from scipy.optimize import fsolve
from scipy.optimize import least_squares
from collections import Counter
from Modeling import get_target_model_info, get_draft_model_volume
import math

class Checker:
    '''
    Load speculative-decoding (SD) benchmark measurements of an MoE target model,
    fit an analytical speedup model to them with least squares, and plot / tabulate
    the quality of the fit.
    '''

    def __init__(self, csv_path, target_model_path, draft_model_path):
        '''
        Load the experiment summary CSV file and perform necessary data processing.

        Arguments:
            csv_path: path of the experiment summary CSV. It must contain the columns
                exp_num, num_speculative_tokens, num_prompts, verification_time_ms,
                system_efficiency, ar_time, prefill_time and total_time.
            target_model_path: path forwarded to Modeling.get_target_model_info
            draft_model_path: path forwarded to Modeling.get_draft_model_volume

        Raises:
            KeyError: if the CSV has no rows with exp_num == self.K_standard, which are
                needed as the reference ("golden") system efficiency.
        '''
        # Model constants. They are defined first because the calibration below uses them.
        # The number of experts. For qwen2-57B-A14B-Instruct, it is 64.
        self.E = 64
        # exp_num whose measured system efficiency is used as the reference ("golden") value.
        self.K_standard = 8

        # The information of hardware
        self.flops = 312000         # 312 TFLOPS
        self.bandwidth = 1935       # 1935 GB/s
        self.ridge_point = self.flops / self.bandwidth

        # load sparsity summary for SD speedup with different sparsity
        summary_df = pd.read_csv(csv_path)
        unq_expnum = summary_df["exp_num"].unique()
        unq_spectoken = summary_df["num_speculative_tokens"].unique()
        # Sorted, so every later iteration over the configurations is ordered by (exp_num, gamma).
        self.combinations = sorted(itertools.product(unq_expnum, unq_spectoken))
        # Upper bound for two of the fitted parameters (see _compute_meta_info).
        self.reject_max = summary_df['verification_time_ms'].max()

        # split the large dataframe into smaller dataframes, one per (exp_num, gamma) combination
        split_df = {
            (expnum, spectoken): summary_df[
                (summary_df['exp_num'] == expnum) & (summary_df['num_speculative_tokens'] == spectoken)
            ].copy()
            for expnum, spectoken in self.combinations
        }

        # save the real system efficiency (namely, sigma in the paper), taken from the
        # exp_num == K_standard runs, for every number of speculative tokens
        if not (summary_df['exp_num'] == self.K_standard).any():
            raise KeyError(f"no measurements with exp_num == {self.K_standard} found in {csv_path}")
        golden_df_dict = {
            spectoken: split_df[(self.K_standard, spectoken)][['num_prompts', 'system_efficiency']]
            .rename(columns={'system_efficiency': 'real_system_efficiency'})
            for spectoken in unq_spectoken
        }

        # calibrate the speedup with the correct system efficiency
        calib_df_dict = {}
        for combo, sub_df in split_df.items():
            spectoken = combo[1]
            merged_df = pd.merge(sub_df, golden_df_dict[spectoken], on='num_prompts', how='left')

            # decode-only times: subtract the prefill time from the end-to-end times
            merged_df["ar_decode_time"] = merged_df["ar_time"] - merged_df["prefill_time"]
            merged_df["sd_decode_time"] = merged_df["total_time"] - merged_df["prefill_time"]
            # rescale the SD decode time from this run's efficiency to the reference efficiency
            merged_df["calib_sd_decode_time"] = (
                merged_df["sd_decode_time"] / merged_df["real_system_efficiency"] * merged_df["system_efficiency"]
            )
            merged_df["speedup"] = merged_df["ar_decode_time"] / merged_df["calib_sd_decode_time"]
            calib_df_dict[combo] = merged_df

        # This is the final dataframe dictionary, keyed by (exp_num, num_speculative_tokens)
        self.df_dict = calib_df_dict

        # The table in Appendix B.2 of the paper: one row per number of measurements m used for fitting
        self.fitting_df = pd.DataFrame(columns=['m', 'stride', 'MSE', 'batch_size_involved'])

        # Compute model meta info, used for the model parameters
        self._compute_meta_info(target_model_path, draft_model_path)

        # Get the sorted dataframe
        self._get_sorted_df()

    def _compute_meta_info(self, target_model_path, draft_model_path):
        '''
        Compute the initial guess and the box bounds of the 10 model parameters used by least squares.

        The load terms are parameter_count * 2 bytes (FP16) / bandwidth (GB/s) / 1e6,
        i.e. the time to stream the weights once, in ms if GB means 1e9 bytes.
        '''
        expert_params_count, other_params_count = get_target_model_info(target_model_path)
        # per-expert weight load time (the expert total is divided evenly over E experts)
        expert_params_load = expert_params_count * 2 / self.bandwidth / 1000000 / self.E
        other_parameter_load = other_params_count * 2 / self.bandwidth / 1000000

        draft_params_count = get_draft_model_volume(draft_model_path)
        draft_parameter_load = draft_params_count * 2 / self.bandwidth / 1000000

        # the ridge-point parameter may range between 20% and 100% of the hardware ridge point
        rp_min_bound = 0.2 * self.ridge_point
        rp_max_bound = self.ridge_point

        self.initial_guess_min = [expert_params_load, 0, other_parameter_load, 0, draft_parameter_load, 0, 0, 0, rp_min_bound, 1]
        self.initial_guess_max = [5*expert_params_load, np.inf, 5*other_parameter_load, np.inf, 5*draft_parameter_load, np.inf, self.reject_max, self.reject_max, rp_max_bound, 2]
        self.initial_guess = [expert_params_load, 1, other_parameter_load, 1, draft_parameter_load, 1, 1, 0, 100, 1.02]

    def _get_sorted_df(self):
        '''
        Compose the per-configuration dataframes into one dataframe (self.total_data).
        It is mentioned in Appendix B.2 of the paper.

        Rows follow the iteration order of self.df_dict, which __init__ builds from
        the sorted (exp_num, num_speculative_tokens) combinations.
        '''
        columns = ['num_prompts', 'num_speculative_tokens', 'exp_num', 'real_system_efficiency', 'speedup']
        frames = [sub_df[columns] for sub_df in self.df_dict.values()]
        # pd.concat refuses an empty list; keep an empty, correctly-labelled frame instead
        self.total_data = pd.concat(frames, axis=0) if frames else pd.DataFrame(columns=columns)

    def _panel_data(self, gamma, exp_num):
        '''
        Return the measurements of one configuration: num_speculative_tokens == gamma and exp_num == exp_num.
        '''
        data = self.total_data
        return data[(data['num_speculative_tokens'] == gamma) & (data['exp_num'] == exp_num)]

    def _iter_subsamples(self):
        '''
        Yield (interval, sub_data) for strides 1, 2, 3, ... of self.total_data.

        Stops as soon as a stride leaves fewer rows than there are model parameters.
        The subsample size only shrinks as the stride grows, so this always terminates.
        '''
        n_params = len(self.initial_guess)
        interval = 1
        while True:
            sub_data = self.total_data.iloc[::interval]
            if len(sub_data) < n_params:
                return
            yield interval, sub_data
            interval += 1

    def _fit(self, sub_data, residual):
        '''
        Solve for the model parameters with bounded least squares on the given measurements.

        Every row is evaluated with the same sigma: the mean real_system_efficiency of sub_data.
        Returns the scipy.optimize.OptimizeResult.
        '''
        m = len(sub_data)
        avg_sigma = sub_data['real_system_efficiency'].mean()
        return least_squares(
            residual,
            self.initial_guess,
            args=(sub_data['num_prompts'].to_numpy(), sub_data['num_speculative_tokens'].to_numpy(),
                  sub_data['exp_num'].to_numpy(), self.E * np.ones(m), avg_sigma * np.ones(m),
                  sub_data['speedup'].to_numpy()),
            bounds=(self.initial_guess_min, self.initial_guess_max)
        )

    def _plot_measurements(self, ax, x, y):
        '''
        Plot the GPU measurements and connect them with a smooth line.

        Points are sorted by x first, because make_interp_spline needs increasing x.
        With fewer than 3 points a quadratic spline is impossible, so a straight
        polyline is drawn instead.
        '''
        x = np.asarray(x)
        y = np.asarray(y)
        order = np.argsort(x, kind='stable')
        xs, ys = x[order], y[order]
        if len(xs) > 2:
            x_smooth = np.linspace(xs.min(), xs.max(), 300)
            y_smooth = make_interp_spline(xs, ys, k=2)(x_smooth)
            ax.plot(x_smooth, y_smooth, label="GPU results", color='blue')
        elif len(xs) > 0:
            ax.plot(xs, ys, label="GPU results", color='blue')
        ax.scatter(x, y, color='gray', alpha=0.8)

    def _draw_panels(self, params, compute_speedup, mark_peak=False):
        '''
        Draw the 3x4 grid of measurements and model predictions, and return the figure.

        Columns 0-1 hold gamma = 2 and columns 2-3 hold gamma = 4. Within each half,
        the panels for K = 1, 2, 4, 8, 16, 32 are laid out row by row.

        Arguments:
            params: the parameters of the model
            compute_speedup: the function to compute the speedup
            mark_peak: if True, draw a horizontal line at max_speedup / sqrt(2) whenever
                the modeled peak is not at one of the first two batch sizes
        '''
        fig, ax = plt.subplots(3, 4, figsize=(16, 10))
        try:
            x_range = np.arange(0, 100) + 1      # batch sizes 1..100
            ones = np.ones(len(x_range))
            for half, gamma in enumerate((2, 4)):
                for j in range(6):
                    exp_num = 2 ** j
                    axis = ax[j // 2, j % 2 + 2 * half]
                    panel = self._panel_data(gamma, exp_num)

                    # plot the measurements
                    self._plot_measurements(axis, panel['num_prompts'].to_numpy(), panel['speedup'].to_numpy())

                    # plot the model, evaluated with this panel's mean real system efficiency
                    sigma = panel['real_system_efficiency'].mean()
                    model_speedup = compute_speedup(params, x_range, gamma * ones, exp_num * ones,
                                                    self.E * ones, sigma * ones)
                    if mark_peak and np.argmax(model_speedup) > 1:
                        axis.axhline(np.max(model_speedup) / math.sqrt(2), color='brown', linestyle='--', linewidth=2, alpha=1)
                    axis.plot(x_range, model_speedup, color='red', marker='o', linewidth=2, alpha=0.4, label="Modeling")
                    axis.grid()
                    axis.legend(fontsize=11)
                    axis.set_xlabel(rf"Batch size ($B$)", fontsize=12)
                    axis.set_ylabel(rf"Speedup", fontsize=12)
                    axis.set_title(rf"$K$ = {exp_num}, $\rho$ = {exp_num}/{self.E}, $\gamma$ = {gamma}", fontsize=12)
                    axis.tick_params(axis='both', labelsize=11)
            fig.tight_layout()
        except Exception:
            # do not leak an open figure when drawing fails; the error is re-raised
            plt.close(fig)
            raise
        return fig

    def _save_figure(self, fig, output_dir, stem):
        '''
        Save fig as output_dir/stem.{svg,pdf,png}, creating output_dir if needed, then close fig.
        '''
        try:
            os.makedirs(output_dir, exist_ok=True)
            for fmt in ('svg', 'pdf', 'png'):
                fig.savefig(f"{output_dir}/{stem}.{fmt}", format=fmt)
        finally:
            plt.close(fig)

    def _visualize(self, params, compute_speedup, m, output_dir):
        '''
        Visualize the modeling output and the measurements.

        Arguments:
            params: the parameters of the model
            compute_speedup: the function to compute the speedup
            m: the number of measurements (used in the output file name)
            output_dir: the directory to save the plot
        '''
        fig = self._draw_panels(params, compute_speedup, mark_peak=False)
        self._save_figure(fig, output_dir, f"fit_{m}_measurements")

    def plot(self, residual, compute_speedup, output_dir, interval=11):
        '''
        Plot the SD speedup for MoEs with different sparsity, batch sizes and speculative tokens.
        It is mentioned in Figure 3 of the paper.

        Arguments:
            residual: the residual function to be minimized
            compute_speedup: the function to compute the speedup
            output_dir: the directory to save the plot
            interval: stride used to sample the measurements for fitting (default 11)
        '''
        # solve the model parameters using least squares
        result = self._fit(self.total_data.iloc[::interval], residual)

        # Plot the GPU measurements and the modeling results
        fig = self._draw_panels(result.x, compute_speedup, mark_peak=True)
        self._save_figure(fig, output_dir, "fitting")

        print(f"The fitting plot is saved in {output_dir}")

    def fit_model(self, interval, residual, compute_speedup, output_dir):
        '''
        Fit the model using the least squares method on every interval-th measurement,
        evaluate it on all measurements, and save the corresponding plot.

        Arguments:
            interval: the interval for sampling the measurements
            residual: the residual function to be minimized
            compute_speedup: the function to compute the speedup
            output_dir: the directory to save the plot

        Returns:
            dict with keys 'm' (number of sampled measurements), 'stride' (interval)
            and 'MSE'. Despite its name, 'MSE' is 0.5 * sum of squared errors over all
            measurements, not their mean.
        '''
        sub_data = self.total_data.iloc[::interval]
        m = len(sub_data)
        result = self._fit(sub_data, residual)

        # evaluate on every measurement, each with its own real system efficiency
        data = self.total_data
        n = len(data)
        model_output = compute_speedup(result.x, data['num_prompts'].to_numpy(), data['num_speculative_tokens'].to_numpy(),
                                       data['exp_num'].to_numpy(), self.E * np.ones(n), data['real_system_efficiency'].to_numpy())
        MSE = 0.5 * np.sum((model_output - data['speedup'].to_numpy()) ** 2)

        # plot the speedup
        self._visualize(result.x, compute_speedup, m, output_dir)

        return {'m': m, 'stride': interval, 'MSE': MSE}

    def get_m_summary(self, residual, compute_speedup, output_dir):
        '''
        Get the summary table of the fitting results when different numbers of measurements are used.
        The summary table is mentioned in Appendix B.2 of the paper.

        Strides 1, 2, 3, ... are tried until fewer measurements than model parameters remain.
        When two strides give the same m, the row of the larger stride is kept.

        Arguments:
            residual: the residual function to be minimized
            compute_speedup: the function to compute the speedup
            output_dir: the directory to save the plot
        '''
        rows = []
        last_m = 0
        for interval, sub_data in self._iter_subsamples():
            m = len(sub_data)
            return_dict = self.fit_model(interval, residual, compute_speedup, output_dir)

            # tolist() yields Python scalars, so the string is not rendered as np.int64(...)
            batch_size_involved = sorted(set(sub_data['num_prompts'].tolist()))
            new_row = [return_dict['m'], return_dict['stride'], return_dict['MSE'], str(batch_size_involved)]
            if m == last_m:
                rows[-1] = new_row
            else:
                rows.append(new_row)
            last_m = m

        new_df = pd.DataFrame(rows, columns=self.fitting_df.columns)
        if self.fitting_df.empty:
            self.fitting_df = new_df
        else:
            self.fitting_df = pd.concat([self.fitting_df, new_df], ignore_index=True)

        print(f"The fitting results of different measurement counts are saved in {output_dir}.")
        self.fitting_df.sort_values(by='m', ascending=True, inplace=True)
        print("======= The summary of fitting results with varied number of measurements =======")
        # scoped display options instead of mutating pandas' global settings
        with pd.option_context('display.max_colwidth', None, 'display.width', None):
            print(self.fitting_df)
