import os
import re
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.stats import gmean

# Benchmark applications, grouped by which log pattern parse_throughput uses.
scientific_apps = ['circuit', 'stencil', 'pennant']
gemm_apps = ['cannon', 'pumma', 'summa', 'solomonik', 'johnson', 'cosma']
# Every app, scientific ones first, then the matrix-multiplication ones.
apps = scientific_apps + gemm_apps

# Font sizes (in points) shared by all figures.
TITLE_FONT_SIZE = 25
LABEL_FONT_SIZE = 23
LEGEND_FONT_SIZE_BIG = 19
LEGEND_FONT_SIZE_SMALL = 19
TICK_FONT_SIZE = 22
STAR_SIZE = 22

# Visible x-range of each plot. LENGTH is its width and is used to convert
# data x-coordinates into the axes fractions passed to axhline.
X_MIN = -0.3
X_MAX= 9.3
LENGTH = X_MAX - X_MIN

# Figure title for each app; apps without an entry fall back to their raw name.
custom_titles = {
    'circuit': 'Circuit Simulation',
    'stencil': 'Stencil Computation',
    'pennant': 'Pennant',
    'cannon': "Cannon's (Matrix Multiplication)",
    'pumma': "PUMMA (Matrix Multiplication)",
    'summa': 'SUMMA (Matrix Multiplication)',
    'solomonik': "Solomonik's (Matrix Multiplication)",
    'johnson': "Johnson's (Matrix Multiplication)",
    'cosma': 'COSMA (Matrix Multiplication)',
}

# Extract a single throughput figure from one run log
def parse_throughput(log_file, app):
    """Return the throughput recorded in ``log_file`` for ``app``.

    For apps in ``gemm_apps`` the number following
    ``achieved GFLOPS per node:`` is returned (a trailing ``.`` is stripped
    before conversion). For apps in ``scientific_apps`` the result is the
    reciprocal of the ``Execution time = <t> s`` value.

    Args:
        log_file: Path of the log file to read.
        app: Application name; selects which pattern is searched for.

    Returns:
        The parsed throughput, or ``0`` when the log cannot be read, the
        pattern is absent, the matched number is malformed, the execution
        time is not positive, or ``app`` belongs to neither app list.
    """
    try:
        with open(log_file, 'r') as f:
            content = f.read()
    except (OSError, ValueError):
        # Best effort: a missing, unreadable or undecodable log counts as
        # "no throughput". Anything else is a real bug and should surface.
        return 0

    try:
        if app in gemm_apps:
            m = re.search(r'achieved GFLOPS per node:\s*([0-9.]+)', content)
            if m:
                # The character class also swallows a sentence-ending period.
                return float(m.group(1).rstrip('.'))
        elif app in scientific_apps:
            m = re.search(r'Execution time\s*=\s*([0-9.]+)\s*s', content)
            if m:
                exec_time = float(m.group(1))
                if exec_time > 0:
                    return 1.0 / exec_time
        else:
            print(f"Unknown app type for app {app}")
    except ValueError:
        # The regex accepts strings such as "1.2.3" that float() rejects.
        return 0
    return 0

# Helper functions
def get_expert_throughput(app_dir, app):
    """Return the throughput parsed from ``log_conf0_dmapping.log`` in ``app_dir``.

    The rest of this script uses the value as the expert baseline that every
    other series is divided by.
    """
    return parse_throughput(os.path.join(app_dir, 'log_conf0_dmapping.log'), app)

def get_random_throughput(app_dir, app, confidence_level):
    """Return the mean throughput of the ten ``repeatr`` logs and its interval.

    Logs ``log_conf0_repeatr_mapping0.log`` .. ``mapping9.log`` in ``app_dir``
    are parsed; a log that fails to parse contributes a throughput of 0.

    Returns:
        ``(mean, (low, high))``. The interval is a Student-t interval at
        ``confidence_level`` when the samples vary, and ``mean`` -/+ 10%
        when they are all identical.
    """
    samples = [
        parse_throughput(os.path.join(app_dir, f'log_conf0_repeatr_mapping{idx}.log'), app)
        for idx in range(10)
    ]
    mean = np.mean(samples)

    has_spread = len(samples) > 1 and np.std(samples) > 0
    if not has_spread:
        # No variance to estimate from: fall back to a fixed +/-10% band.
        return mean, (mean * 0.9, mean * 1.1)

    dof = len(samples) - 1
    ci = stats.t.interval(confidence_level, dof, loc=mean, scale=stats.sem(samples))
    return mean, ci

def get_cumulative_max_throughput(app_dir, app, repeat_prefix, confidence_level):
    """Return the mean best-so-far throughput per iteration and its interval.

    For each of 5 repeats, the 10 logs
    ``log_conf0_{repeat_prefix}{repeat}_mapping{0..9}.log`` in ``app_dir`` are
    parsed and turned into a running maximum. The running maxima are then
    averaged across repeats, iteration by iteration.

    Args:
        app_dir: Directory holding the logs.
        app: Application name, forwarded to ``parse_throughput``.
        repeat_prefix: Log-name component that selects the series.
        confidence_level: Confidence passed to the Student-t interval.

    Returns:
        ``(mean, (low, high))`` where all three are length-10 arrays. For an
        iteration whose values differ across repeats the bounds are a
        Student-t interval; for an iteration with identical values they are
        the mean -/+ 10%.
    """
    cumulative_max_data = []
    for repeat in range(5):
        repeat_throughputs = [
            parse_throughput(
                os.path.join(app_dir, f'log_conf0_{repeat_prefix}{repeat}_mapping{mapping}.log'),
                app,
            )
            for mapping in range(10)
        ]
        cumulative_max_data.append(np.maximum.accumulate(repeat_throughputs))

    cum_max_data = np.array(cumulative_max_data, dtype=float)
    mean_cum_max = np.mean(cum_max_data, axis=0)

    # Start from the +/-10% fallback everywhere, then overwrite only the
    # iterations that have variance. Previously a single zero-variance
    # iteration discarded the real intervals of all the others.
    lower = mean_cum_max * 0.9
    upper = mean_cum_max * 1.1
    num_repeats = cum_max_data.shape[0]
    if num_repeats > 1:
        varying = np.std(cum_max_data, axis=0) > 0
        if np.any(varying):
            t_low, t_high = stats.t.interval(
                confidence_level,
                num_repeats - 1,
                loc=mean_cum_max[varying],
                scale=stats.sem(cum_max_data[:, varying], axis=0),
            )
            lower[varying] = t_low
            upper[varying] = t_high
    return mean_cum_max, (lower, upper)

# Lift near-zero values up to a visible minimum (used for the Random line)
def adjust_small_values(line_data, min_value=0.01):
    """Replace every value strictly between ``-min_value`` and ``min_value``.

    Such values become ``min_value``; all others are passed through unchanged.
    """
    above_lower = line_data > -min_value
    below_upper = line_data < min_value
    is_tiny = np.logical_and(above_lower, below_upper)
    return np.where(is_tiny, min_value, line_data)

# Best single throughput observed anywhere in one series
def get_max_throughput(app_dir, app, repeat_prefix, is_random=False):
    """Return the largest throughput parsed from a series' logs.

    With ``is_random`` set, the ten ``log_conf0_repeatr_mapping{m}.log`` files
    are scanned and ``repeat_prefix`` is ignored. Otherwise all
    ``log_conf0_{repeat_prefix}{r}_mapping{m}.log`` files for 5 repeats and
    10 mappings are scanned.
    """
    if is_random:
        # The random series has no repeat dimension.
        log_names = [f'log_conf0_repeatr_mapping{mapping}.log' for mapping in range(10)]
    else:
        log_names = [
            f'log_conf0_{repeat_prefix}{repeat}_mapping{mapping}.log'
            for repeat in range(5)
            for mapping in range(10)
        ]

    return np.max([parse_throughput(os.path.join(app_dir, name), app) for name in log_names])

# Draw and save the figure for one app, normalized to the expert throughput
def plot_throughputs(app, app_dir, lines_to_plot, line_labels, prefix, confidence_level, y_max_factor, custom_y_max=None, plot_best_lines=False, trace_best_dict=None):
    """Plot the requested series for ``app`` and save ``figure/{prefix}_{app}.pdf``.

    Every value is divided by the expert throughput, so the expert line sits
    at y=1. The expert line is always drawn, whether or not ``'Expert'`` is in
    ``lines_to_plot``. If the expert throughput is 0 the app is skipped.

    Args:
        app: Application name (selects the title and the log parser).
        app_dir: Directory containing the app's log files.
        lines_to_plot: Collection of series keys; ``'Random'``, ``'Trace'``,
            ``'Non-Directional'``, ``'Basic'`` and ``'OPRO'`` are recognised.
        line_labels: Mapping from series key to legend label; missing keys
            fall back to built-in defaults.
        prefix: Scenario name. Used in the output file name; the value
            ``'optimizer'`` additionally enables the Trace "best" star.
        confidence_level: Confidence used for the shaded intervals.
        y_max_factor: Headroom multiplier applied to the automatic y-limit.
        custom_y_max: Explicit y-limit; when truthy it overrides the
            automatic one.
        plot_best_lines: Draw a dashed line at the best value of the Random,
            Non-Directional, Basic and OPRO series.
        trace_best_dict: Optional dict that receives
            ``app -> best Trace throughput / expert throughput`` when
            ``prefix == 'optimizer'``. May be ``None``.
    """
    expert_throughput = get_expert_throughput(app_dir, app)
    if expert_throughput == 0:
        print(f"Skipping {app} due to missing expert throughput")
        return

    # axhline takes xmin/xmax as fractions of the axes width; this span makes
    # horizontal lines cover exactly data x = 0 .. 9 within [X_MIN, X_MAX].
    hline_span = {'xmin': -X_MIN / LENGTH, 'xmax': (9 - X_MIN) / LENGTH}
    iterations = range(10)

    plt.figure(figsize=(6.72, 4.8))
    plt.axhline(y=1, color='r', linestyle='--', label=line_labels.get('Expert', 'Expert Mapping'), **hline_span)

    # Largest normalized value drawn so far; drives the automatic y-limit.
    max_so_far = 1

    # Random: a flat mean line with its confidence band
    if 'Random' in lines_to_plot:
        random_label = line_labels.get('Random', 'Random Mapping')
        random_mean, random_ci = get_random_throughput(app_dir, app, confidence_level)
        normalized_random_mean = adjust_small_values(random_mean / expert_throughput)
        max_so_far = max(max_so_far, normalized_random_mean, np.max(random_ci) / expert_throughput)
        plt.axhline(y=normalized_random_mean, color='b', linestyle='--', label=random_label, **hline_span)
        plt.fill_between(iterations, random_ci[0] / expert_throughput, random_ci[1] / expert_throughput, color='b', alpha=0.15)

        if plot_best_lines:
            random_best = get_max_throughput(app_dir, app, 'repeatr', is_random=True)
            max_so_far = max(max_so_far, random_best / expert_throughput)
            plt.axhline(y=random_best / expert_throughput, color='b', linestyle='--', label=f"Best {random_label}", **hline_span)

    # Trace: cumulative-max curve, plus a star at the overall best value
    if 'Trace' in lines_to_plot:
        trace_mean, trace_ci = get_cumulative_max_throughput(app_dir, app, 'repeat', confidence_level)
        normalized_trace_mean = trace_mean / expert_throughput
        max_so_far = max(max_so_far, np.max(normalized_trace_mean), np.max(trace_ci) / expert_throughput)

        if prefix == 'optimizer':
            trace_best = get_max_throughput(app_dir, app, 'repeat')
            trace_best_normalized = trace_best / expert_throughput
            max_so_far = max(max_so_far, trace_best_normalized)
            # The dict is optional; only record the value when one was given.
            if trace_best_dict is not None:
                trace_best_dict[app] = trace_best_normalized
            # The star is drawn before the curve so it comes first in the legend.
            plt.plot(9, trace_best_normalized, 'g*', markersize=STAR_SIZE, label=f"{line_labels.get('TraceBest', 'Trace-OptoPrime')}")
            print(f"{app} Trace Best: {trace_best_normalized:.2f}")

        plt.plot(iterations, normalized_trace_mean, color='g', label=line_labels.get('Trace', 'Trace-OptoPrime'))
        plt.fill_between(iterations, trace_ci[0] / expert_throughput, trace_ci[1] / expert_throughput, color='g', alpha=0.15)

    # Remaining cumulative-max series share one shape:
    # (key, log prefix, line colour, band colour, default label).
    # NOTE(review): the band colour differs from the line colour for the
    # first two series; kept as-is, confirm this is intended.
    cumulative_series = (
        ('Non-Directional', 'repeatn', '#FF8C00', 'y', 'Non-Directional Feedback'),
        ('Basic', 'repeatp', '#800080', 'c', 'Numerical Feedback'),
        ('OPRO', 'repeato', 'm', 'm', 'Trace-Opro'),
    )
    for key, repeat_prefix, line_color, band_color, default_label in cumulative_series:
        if key not in lines_to_plot:
            continue
        label = line_labels.get(key, default_label)
        series_mean, series_ci = get_cumulative_max_throughput(app_dir, app, repeat_prefix, confidence_level)
        normalized_mean = series_mean / expert_throughput
        max_so_far = max(max_so_far, np.max(normalized_mean), np.max(series_ci) / expert_throughput)
        plt.plot(iterations, normalized_mean, color=line_color, label=label)
        plt.fill_between(iterations, series_ci[0] / expert_throughput, series_ci[1] / expert_throughput, color=band_color, alpha=0.15)

        if plot_best_lines:
            series_best = get_max_throughput(app_dir, app, repeat_prefix)
            max_so_far = max(max_so_far, series_best / expert_throughput)
            # All best lines, OPRO included, use the same 0..9 span.
            plt.axhline(y=series_best / expert_throughput, color=line_color, linestyle='--', label=f"Best {label}", **hline_span)

    if custom_y_max:
        plt.ylim(0, custom_y_max)
    else:
        plt.ylim(0, max(max_so_far, 1) * y_max_factor)

    plt.xlim(X_MIN, X_MAX)

    plt.xlabel('Iterations', fontsize=LABEL_FONT_SIZE)
    plt.ylabel('Normalized Throughput', fontsize=LABEL_FONT_SIZE)
    plt.title(custom_titles.get(app, app), fontsize=TITLE_FONT_SIZE)
    # Only these (scenario, app) figures carry a legend.
    figures_with_legend = (('optimizer', 'summa'), ('optimizer', 'stencil'), ('feedback', 'circuit'))
    if (prefix, app) in figures_with_legend:
        plt.legend(fontsize=LEGEND_FONT_SIZE_BIG if prefix == 'optimizer' else LEGEND_FONT_SIZE_SMALL)
    plt.xticks(fontsize=TICK_FONT_SIZE)
    plt.yticks(fontsize=TICK_FONT_SIZE)
    plt.grid(False)
    plt.tight_layout()
    # Make the function usable on its own, without a caller creating the dir.
    os.makedirs('figure', exist_ok=True)
    plt.savefig(f'figure/{prefix}_{app}.pdf')
    plt.close()

# Render one figure per selected app for a single scenario
def generate_figures(prefix, selected_apps, lines_to_plot, line_labels, confidence_level, y_max_factor, plot_best_lines, custom_y_max_per_app=None):
    """Generate the figures of one scenario and, for 'optimizer', a summary.

    Each app in ``selected_apps`` is looked up under ``result/<app>``; apps
    whose directory is missing are reported and skipped. Figures are written
    into the ``figure`` directory, which is created if needed.

    Args:
        prefix: Scenario name; ``'optimizer'`` additionally prints the max,
            min and geometric mean of the collected Trace best ratios.
        selected_apps: Iterable of application names to plot.
        lines_to_plot: Series keys forwarded to ``plot_throughputs``.
        line_labels: Legend labels forwarded to ``plot_throughputs``.
        confidence_level: Confidence used for the shaded intervals.
        y_max_factor: Headroom multiplier for the automatic y-limit.
        plot_best_lines: Whether per-series best lines are drawn.
        custom_y_max_per_app: Optional mapping ``app -> y-limit``.
    """
    os.makedirs('figure', exist_ok=True)

    # Filled by plot_throughputs: app -> best Trace throughput / expert.
    trace_best_dict = {}

    for app in selected_apps:
        app_dir = os.path.join('result', app)
        if not os.path.isdir(app_dir):
            print(f"Directory {app_dir} does not exist.")
            continue
        custom_y_max = custom_y_max_per_app.get(app, None) if custom_y_max_per_app else None
        plot_throughputs(app, app_dir, lines_to_plot, line_labels, prefix, confidence_level, y_max_factor, custom_y_max, plot_best_lines, trace_best_dict)

    if prefix != 'optimizer':
        return

    if not trace_best_dict:
        # np.max / np.min raise on an empty array, so stop here when every
        # app was skipped or no Trace series was plotted.
        print("No Trace best values were collected; skipping summary.")
        return

    trace_best_values = np.array(list(trace_best_dict.values()))
    max_value = np.max(trace_best_values)
    min_value = np.min(trace_best_values)
    geom_mean_value = gmean(trace_best_values)

    print("\n--- Trace Best Summary ---")
    print(f"Max Trace Best / Expert: {max_value:.2f}")
    print(f"Min Trace Best / Expert: {min_value:.2f}")
    print(f"Geometric Mean of Trace Best / Expert: {geom_mean_value:.2f}")

# Script entry point: render every figure of both scenarios
def main():
    """Generate the 'optimizer' figures for all apps, then the 'feedback' ones.

    Both scenarios share the same confidence level and y-axis headroom and
    are drawn without per-series best lines.
    """
    confidence_level = 0.66  # Confidence used for every shaded interval
    y_max_factor = 1.05  # Headroom above the tallest plotted value

    # (prefix, selected apps, series to plot, legend labels), in run order.
    scenarios = [
        (
            'optimizer',
            apps,
            ['Expert', 'Random', 'Trace', 'OPRO'],
            {
                'Expert': 'Expert Mapper',
                'Random': 'Random Mapper',
                'Trace': 'Trace-OptoPrime',
                'OPRO': 'Trace-OPRO',
                'TraceBest': "OptoPrime Best Mapper",
            },
        ),
        (
            'feedback',
            ['cannon', 'cosma', 'circuit'],
            ['Expert', 'Basic', 'Non-Directional', 'Trace'],
            {
                'Expert': 'Expert Mapper',
                'Basic': 'Feedback: Perf',
                'Non-Directional': 'Feedback: Perf + Err',
                'Trace': 'Feedback: Perf + Err + Guide',
            },
        ),
    ]

    for prefix, selected_apps, lines_to_plot, line_labels in scenarios:
        generate_figures(prefix, selected_apps, lines_to_plot, line_labels, confidence_level, y_max_factor, plot_best_lines=False)

# Generate the figures only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

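# Reference console output, presumably from a previous run of this script
# (each value is the best Trace throughput divided by the expert throughput):
# circuit Trace Best: 1.34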
# stencil Trace Best: 1.02
# pennant Trace Best: 1.04
# cannon Trace Best: 1.09
# pumma Trace Best: 1.09
# summa Trace Best: 1.09
# solomonik Trace Best: 1.09
# johnson Trace Best: 1.07
# cosma Trace Best: 1.31

# --- Trace Best Summary ---
# Max Trace Best / Expert: 1.34
# Min Trace Best / Expert: 1.02
# Geometric Mean of Trace Best / Expert: 1.12