import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter, AutoMinorLocator, LogLocator, NullFormatter
import matplotlib as mpl



def plot_optimization_history(*Histories,optimal_value=None,legends=None,linewidth=1,fontsize=25,ticksize=15,legendsize=20,savename=None, colors=None):
    """
    Plot objective values and constraint violations over iterations.

    Draws two side-by-side panels and then shows the figure:
    - left: the objective value f(x) on a linear y-axis;
    - right: the constraint violation |h(x)| on a logarithmic y-axis.

    Parameters:
    - *Histories: One or more optimization histories. Each is a sequence of
      (function_value, constraint_violation) pairs, one per recorded step.
      Empty histories are skipped. A history that is k times shorter than
      the longest one (k rounded to an integer) is drawn with an x-step of
      k, so histories recorded at different rates share one x-axis.
    - optimal_value: Optional optimal value of f(x). If provided, a dashed
      horizontal reference line is added to the f(x) panel.
    - legends: Optional sequence of legend labels, one per history. Histories
      without a corresponding entry are left unlabeled.
    - linewidth: Line width used for every curve.
    - fontsize: Font size of axis labels and titles.
    - ticksize: Font size of tick labels.
    - legendsize: Font size of the legend.
    - savename: Optional output path. If given, the figure is saved there
      (``bbox_inches='tight'``) before it is shown.
    - colors: Optional sequence of matplotlib colors, one per history.
      Histories without an entry use the default color cycle.

    Note: the font settings are applied through ``matplotlib.rcParams`` and
    therefore remain in effect after this call.
    """

    def _x_positions(length, s):
        """Return x positions 0, s, 2s, ... for a history of ``length`` points."""
        return np.arange(length) * max(1, int(s))

    def _style_axes(ax):
        """Apply the tick formatting, minor ticks and grid shared by both panels."""
        # The iteration axis is always linear here. Each axis gets its own
        # formatter: a ScalarFormatter remembers the Axis it is attached to,
        # so sharing one instance makes every axis use the last axis's limits.
        sf = ScalarFormatter(useMathText=True)
        sf.set_powerlimits((3, 3))  # label iterations as multiples of 10^3
        ax.xaxis.set_major_formatter(sf)
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        if ax.get_yscale() == 'log':
            ax.yaxis.set_major_locator(LogLocator(base=10))
            ax.yaxis.set_minor_locator(LogLocator(base=10, subs=(0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)))
            ax.yaxis.set_minor_formatter(NullFormatter())  # no labels on minor log ticks
        else:
            ax.yaxis.set_minor_locator(AutoMinorLocator())
        ax.tick_params(axis='both', which='major', labelsize=ticksize)
        ax.set_axisbelow(True)  # keep grid lines behind the curves
        ax.grid(True, which='major', alpha=0.7, linewidth=0.8)
        ax.grid(True, which='minor', alpha=0.35, linewidth=0.5)

    # Infer an integer x-step per history by comparing it to the longest one.
    lengths = [len(hist) for hist in Histories]
    L_max = max((L for L in lengths if L > 0), default=0)
    scales = [max(1, int(round(L_max / L))) if L > 0 else 1 for L in lengths]

    # LaTeX-like font styling for publication-quality figures (same as the moving-average variant)
    mpl.rcParams.update({
        "text.usetex": False,
        "mathtext.fontset": "cm",
        "font.family": "serif",
        "font.serif": ["CMU Serif", "Computer Modern Roman", "Times New Roman"],
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "legend.fontsize": legendsize,
        "xtick.labelsize": ticksize,
        "ytick.labelsize": ticksize,
    })

    fig, (ax_f, ax_h) = plt.subplots(1, 2, figsize=(14, 5))
    if optimal_value is not None:
        ax_f.axhline(optimal_value, color='g', linestyle='--', label='Optimal $f(x)$', linewidth=linewidth)

    for i, history in enumerate(Histories):
        # zip(*history) cannot unpack an empty history; len() (not truthiness)
        # so that 2-D NumPy arrays are accepted as well.
        if len(history) == 0:
            continue
        label = legends[i] if legends is not None and i < len(legends) else None
        color_i = colors[i] if colors is not None and i < len(colors) else None
        f_values, h_violations = zip(*history)
        x_vals = _x_positions(len(f_values), scales[i])
        ax_f.plot(x_vals, f_values, label=label, linewidth=linewidth, color=color_i)
        ax_h.plot(x_vals, h_violations, label=label, linewidth=linewidth, color=color_i)

    ax_f.set_xlabel('Iteration', fontsize=fontsize)
    ax_f.set_ylabel('$f(x)$', fontsize=fontsize)
    ax_f.set_title('Function Value Over Iterations', fontsize=fontsize)
    # Only draw a legend when something is labeled; otherwise matplotlib warns
    # "No artists with labels found" and may draw an empty legend box.
    if ax_f.get_legend_handles_labels()[1]:
        ax_f.legend(fontsize=legendsize)

    ax_h.set_xlabel('Iteration', fontsize=fontsize)
    ax_h.set_ylabel('$|h(x)|$', fontsize=fontsize)
    ax_h.set_yscale('log')
    ax_h.set_title('Constraint Violation Over Iterations', fontsize=fontsize)

    # Styling must come after set_yscale, which resets locators/formatters.
    for ax in (ax_f, ax_h):
        _style_axes(ax)

    fig.tight_layout()
    if savename:
        fig.savefig(savename, bbox_inches='tight')
    plt.show()


def plot_optimization_history_moving_average(*Histories,optimal_value=None,legends=None,linewidth=1,fontsize=25,ticksize=15,legendsize=20,savename=None,window_size=10,colors=None):
    """
    Same signature as plot_optimization_history, but plots a moving average (window=100)
    for the objective values and constraint violations.
    """

    def compute_moving_average(series, window_size=window_size):
        if len(series) == 0:
            return np.array([])
        window = min(window_size, len(series))
        cumsum = np.cumsum(np.insert(series, 0, 0.0))
        moving_avg = (cumsum[window:] - cumsum[:-window]) / float(window)
        return moving_avg

    # Helper to compute x positions without changing history length
    def _x_positions(length, s):
        s = max(1, int(s))
        return 1 + np.arange(length) * s

    # Infer an effective inner-step scale per history by comparing lengths
    lengths = [len(hist) for hist in Histories]
    L_max = max([L for L in lengths if L > 0], default=0)
    scales = []
    for L in lengths:
        if L <= 0 or L_max == 0:
            scales.append(1)
        else:
            scales.append(max(1, int(round(L_max / L))))

    # Compute global x-axis span
    # LaTeX-like font styling for publication-quality figures
    mpl.rcParams.update({
        "text.usetex": False,
        "mathtext.fontset": "cm",
        "font.family": "serif",
        "font.serif": ["CMU Serif", "Computer Modern Roman", "Times New Roman"],
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "legend.fontsize": legendsize,
        "xtick.labelsize": ticksize,
        "ytick.labelsize": ticksize,
    })
    plt.figure(figsize=(14, 5))
    if optimal_value is not None:
        plt.subplot(1, 2, 1)
        plt.axhline(optimal_value, color='black', linestyle='--', label='Optimal $f(x)$', linewidth=linewidth)
    for i, history in enumerate(Histories):
        label = legends[i] if legends is not None and i < len(legends) else None
        color_i = colors[i] if (colors is not None and i < len(colors)) else None
        f_values, h_violations = zip(*history)
        f_values = np.asarray(f_values, dtype=float)
        h_violations = np.asarray(h_violations, dtype=float)

        scale_i = scales[i]
        x_vals_full = _x_positions(len(f_values), scale_i)

        if scale_i == 1:
            f_ma = compute_moving_average(f_values, window_size=window_size)
            h_ma = compute_moving_average(h_violations, window_size=window_size)
            if len(f_ma) == 0:
                continue
            start_index = len(f_values) - len(f_ma)
            x_indices = x_vals_full[start_index:]

            # Function value subplot
            plt.subplot(1, 2, 1)
            plt.plot(x_indices, f_ma, label=label, linewidth=linewidth, color=color_i)
            plt.xlabel('Iteration', fontsize=fontsize)
            plt.ylabel('$f(x)$', fontsize=fontsize)
            plt.legend(fontsize=legendsize, frameon=False, handlelength=1.2, labelspacing=0.3, borderaxespad=0.4)
            plt.title('Function Value Over Iterations', fontsize=fontsize)
            plt.grid(True)
            plt.xticks(fontsize=ticksize)
            plt.yticks(fontsize=ticksize)

            # Constraint violation subplot
            plt.subplot(1, 2, 2)
            plt.plot(x_indices, h_ma, label=label, linewidth=linewidth, color=color_i)
            plt.xlabel('Iteration', fontsize=fontsize)
            plt.ylabel('$|h(x)|$ ', fontsize=fontsize)
            plt.yscale('log')
            #plt.legend(fontsize=legendsize, frameon=False, handlelength=1.2, labelspacing=0.3, borderaxespad=0.4)
            plt.title('Constraint Violation Over Iterations', fontsize=fontsize)
            plt.grid(True) 
            plt.xticks(fontsize=ticksize)
            plt.yticks(fontsize=ticksize)
        else:
            # Plot raw values without averaging
            plt.subplot(1, 2, 1)
            plt.plot(x_vals_full, f_values, label=label, linewidth=linewidth, color=color_i)
            plt.xlabel('Iteration', fontsize=fontsize)
            plt.ylabel('$f(x)$', fontsize=fontsize)
            plt.legend(fontsize=legendsize, frameon=False, handlelength=1.2, labelspacing=0.3, borderaxespad=0.4)
            plt.title('Function Value Over Iterations', fontsize=fontsize)
            plt.grid(True)
            plt.xticks(fontsize=ticksize)
            plt.yticks(fontsize=ticksize)

            plt.subplot(1, 2, 2)
            plt.plot(x_vals_full, h_violations, label=label, linewidth=linewidth, color=color_i)
            plt.xlabel('Iteration', fontsize=fontsize)
            plt.ylabel('$|h(x)|$', fontsize=fontsize)
            plt.yscale('log')
            #plt.legend(fontsize=legendsize, frameon=False, handlelength=1.2, labelspacing=0.3, borderaxespad=0.4)
            plt.title('Constraint Violation Over Iterations', fontsize=fontsize)
            plt.grid(True)
            plt.xticks(fontsize=ticksize)
            plt.yticks(fontsize=ticksize)
    # Use scientific notation on x-axis (e.g., 1, 2, 3 ×10^3)
    _sf = ScalarFormatter(useMathText=True)
    _sf.set_powerlimits((3, 3))
    for _ax in plt.gcf().axes:
        _ax.ticklabel_format(axis='x', style='sci', scilimits=(3, 3))
        _ax.xaxis.set_major_formatter(_sf)
        # Minor ticks/grid: handle linear vs log axes properly
        if _ax.get_xscale() == 'log':
            _ax.xaxis.set_major_locator(LogLocator(base=10))
            _ax.xaxis.set_minor_locator(LogLocator(base=10, subs=(0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9)))
        else:
            _ax.xaxis.set_minor_locator(AutoMinorLocator())
        if _ax.get_yscale() == 'log':
            _ax.yaxis.set_major_locator(LogLocator(base=10))
            _ax.yaxis.set_minor_locator(LogLocator(base=10, subs=(0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9)))
            _ax.yaxis.set_minor_formatter(NullFormatter())
        else:
            _ax.yaxis.set_minor_locator(AutoMinorLocator())
        _ax.set_axisbelow(True)
        _ax.grid(True, which='major', alpha=0.7, linewidth=0.8)
        _ax.grid(True, which='minor', alpha=0.35, linewidth=0.5)
    plt.tight_layout()
    if savename:
        plt.savefig(savename, bbox_inches='tight')
    plt.show()
