from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import pyplot

import argparse

# Render text with matplotlib's built-in text engine (usetex disabled), so no
# external LaTeX installation is needed.
mpl.rc('text', usetex=False)

# Path of this script without its extension; its parent directory is used to
# resolve the relative path of the results file in the main section.
FILE_NAME = Path(__file__).with_suffix('')

# Global matplotlib font settings applied to every plot: a 10pt base font,
# with legend, axis-label and tick-label text at the relative 'small' size.

mpl.rcParams.update({
    'font.size': 10,
    'legend.fontsize': 'small',
    'axes.labelsize': 'small',
    'xtick.labelsize': 'small',
    'ytick.labelsize': 'small'
})


# Plot styling keyed by lower-case solver name. Each entry holds keyword
# arguments (color, legend label, optional line width); the '*' entry holds
# the shared default line width.
# NOTE(review): the 'srba' label uses a LaTeX command while `text.usetex` is
# disabled at the top of this file -- confirm it renders as intended.
STYLES = {
    '*': {'lw': 1.5},

    'amigo': {'color': '#5778a4', 'label': r'AmIGO'},
    'mrbo': {'color': '#e49444', 'label': r'MRBO'},
    'vrbo': {'color': '#e7ca60', 'label': r'VRBO'},
    'saba': {'color': '#d1615d', 'label': r'SABA'},
    'stocbio': {'color': '#85b6b2', 'label': r'StocBiO'},
    'srba': {'color': '#6a9f58', 'label': r'\textbf{SRBA}', 'lw': 2},
    'f2sa': {'color': '#bcbd22', 'label': r'F2SA'},
}

# Per-solver pair of counts, keyed by lower-case solver name. Each value is an
# (inner, outer) tuple.
# NOTE(review): presumably the number of oracle calls per iteration on the
# inner and outer problems -- confirm against the solver implementations.
N_CALLS = {
    # One loop
    'mrbo': (24, 4),  # inner, outer
    'sustain': (24, 4),
    'ttsa': (11, 2),
    'fsla': (4, 3),

    # Two loops solvers
    'amigo': (21, 2),
    'stocbio': (21, 2),
    'bsa': (21, 2),

    # Our solvers
    'saba': (3, 2),
    'bio-svrg': (3, 2),
    'srba': (3, 2),
}

# Layout constants. They are not referenced by the code visible in this file.
# NOTE(review): presumably shared figure-layout defaults (legend placement and
# figure sizes in inches for single/double-column figures) -- confirm usage.
LEGEND_OUTSIDE = False

DEFAULT_WIDTH = 3.25
DEFAULT_DOUBLE_WIDTH = 6.75
DEFAULT_HEIGHT = 2.


def get_param(name, param='period_frac'):
    """Return the value of one parameter encoded in a solver name.

    ``name`` is expected to have the form ``"solver[k1=v1,k2=v2]"``: the text
    between the first ``[`` and the final character is split on ``,`` and
    each item is split on ``=``.

    Parameters
    ----------
    name : str
        Solver name with a bracketed, comma-separated ``key=value`` list.
    param : str
        Key of the parameter to return.

    Returns
    -------
    float or str
        The value as a ``float`` when ``float()`` can parse it (this includes
        negative numbers and scientific notation such as ``-2`` or ``1e-3``),
        otherwise the raw string.

    Raises
    ------
    IndexError
        If ``name`` contains no ``[``.
    ValueError
        If an item does not contain exactly one ``=``.
    KeyError
        If ``param`` is not among the parsed keys.
    """
    # Drop the last character, assumed to be the closing ']'.
    raw_params = name.split("[", maxsplit=1)[1][:-1]

    params = {}
    for item in raw_params.split(","):
        k, v = item.split("=")
        # EAFP: let float() decide what is numeric. A digit-only check misses
        # signs/exponents and accepts strings like "1.2.3" that float()
        # then rejects.
        try:
            params[k] = float(v)
        except ValueError:
            params[k] = v
    return params[param]


def drop_param(name, param='period_frac'):
    """Return ``name`` with one ``key=value`` entry removed.

    ``name`` is expected to have the form ``"solver[k1=v1,k2=v2]"``. The
    remaining entries keep their original order. When no entry remains the
    result is ``"solver[]"``.

    Parameters
    ----------
    name : str
        Solver name with a bracketed, comma-separated ``key=value`` list.
    param : str
        Key of the entry to remove; a key that is absent leaves the name
        unchanged.

    Returns
    -------
    str
        The solver name without the ``param`` entry.

    Raises
    ------
    IndexError
        If ``name`` contains no ``[``.
    ValueError
        If an item does not contain exactly one ``=``.
    """
    parts = name.split("[", maxsplit=1)
    prefix = parts[0]
    # Drop the last character, assumed to be the closing ']'.
    raw_params = parts[1][:-1]

    kept = []
    for item in raw_params.split(","):
        k, v = item.split("=")
        if k != param:
            kept.append(f'{k}={v}')
    # Joining avoids the trailing-comma strip, which removed the '[' itself
    # when every entry was dropped.
    return f"{prefix}[{','.join(kept)}]"


# Solvers drawn on the figure, mapped to their color index in the 'Set1'
# colormap.
PLOTTED_SOLVERS = {'ZOSOBA': 1, 'ZDSBA': 2, 'SRZOBA': 3}

# Per-benchmark settings. Field order of each tuple:
# (fname, metric_selection, metric_plot, xlim, eps, yname, yscaling,
#  xscaling, ylim, batch_size, eval_freq, n_inner_samples, n_outer_samples).
# Only `fname` and `metric_plot` are used by this script.
BENCHMARKS_CONFIG = dict(
    datacleaning0_5=(
        "../outputs/datacleaning_fashion_0_1_params_4.parquet",
        'objective_value', 'objective_test_accuracy',
        ((.1, 900), (2e4, 5e7)), None, 'Test error', 'log',
        ('log', 'log'), (None, 40), 64, 2**5, 20_000, 5_000
    ),
)


def _collect_runs(df, metric, solver_names):
    """Group the metric and time series of every run by base solver name.

    The base name is the part of ``solver_name`` before the first ``[``;
    each distinct full ``solver_name`` is treated as one run.

    Returns two dicts keyed by base name: lists of metric arrays and lists
    of the matching ``time`` arrays, in order of first appearance in ``df``.
    """
    curves = {name: [] for name in solver_names}
    times = {name: [] for name in solver_names}
    for solver in df['solver_name'].unique():
        base_name = solver.split('[')[0]
        if base_name not in curves:
            continue
        run = df[df['solver_name'] == solver]
        curves[base_name].append(run[metric].to_numpy())
        times[base_name].append(run['time'].to_numpy())
    return curves, times


def _plot_mean_band(label, curves, times, color):
    """Plot the mean curve over runs with a +/- one standard deviation band.

    The x-axis uses the time stamps of the first run. Runs of different
    lengths are truncated to the shortest one so they can be averaged
    point-wise.

    Returns the mean curve as a 1-D array.
    """
    n_common = min(len(curve) for curve in curves)
    stacked = np.stack([curve[:n_common] for curve in curves])
    avg = stacked.mean(axis=0)
    std = stacked.std(axis=0)
    x = times[0][:n_common]
    plt.plot(x, avg, color=color, label=label, linewidth=4.0)
    plt.fill_between(x, avg - std, avg + std, color=color, alpha=0.2)
    return avg


def main():
    """Parse the command line, load the results and save the comparison plot.

    The figure is written to ``./time_solver_comparison.png`` (relative to
    the current working directory).
    """
    parser = argparse.ArgumentParser(
        description='Plot benchmarks results for bilevel optimization.'
    )
    # NOTE: --n-points, --x-axis and --criterion are parsed but not used by
    # the plotting code below; the x-axis is always the 'time' column.
    parser.add_argument('--n-points', '-n', type=int, default=500,
                        help='# of points in the grid for interpolation.')
    parser.add_argument('--x-axis', '-x', type=str, default='time',
                        choices=['time', 'calls'],
                        help='Plot in time or number of calls to oracles.')
    parser.add_argument('--benchmark', '-b', type=str,
                        default='datacleaning0_5',
                        choices=['ijcnn1', 'datacleaning0_5',
                                 'datacleaning0_7', 'datacleaning0_9',
                                 'covtype'],
                        help='Choose the benchmark to plot.')
    parser.add_argument('--criterion', '-c', type=str, default='100',
                        choices=['100', 'all'],
                        help='Choose the best curves with respect to the 100 '
                             'first iterates or all the iterates.')
    args = parser.parse_args()

    # Some accepted choices have no configuration entry; report that as a
    # usage error instead of an unhandled KeyError.
    if args.benchmark not in BENCHMARKS_CONFIG:
        parser.error(
            f"no configuration for benchmark {args.benchmark!r}; "
            f"available: {', '.join(BENCHMARKS_CONFIG)}"
        )
    fname, _, metric_plot, *_ = BENCHMARKS_CONFIG[args.benchmark]

    # The results path is relative to the directory containing this script.
    fname = FILE_NAME.parent / fname
    print(fname)

    df = pd.read_parquet(fname)
    curves, times = _collect_runs(df, metric_plot, PLOTTED_SOLVERS)

    palette = plt.get_cmap('Set1')
    plt.figure(figsize=(10, 6))
    for name, color_index in PLOTTED_SOLVERS.items():
        if not curves[name]:
            # Nothing to average for this solver in the results file.
            print(f"No runs found for solver {name}; skipping it.")
            continue
        avg = _plot_mean_band(
            name, curves[name], times[name], palette(color_index)
        )
        # Final value of the mean curve.
        print(f"{name} acc {avg[-1]}")

    plt.xlabel(' Time(S)')
    plt.ylabel(' Test Accuracy')
    plt.title('Solver Comparison based on Test Error and Query')
    plt.legend()
    plt.grid(True)
    output_image_path = './time_solver_comparison.png'
    plt.savefig(output_image_path)


if __name__ == "__main__":
    main()
   

   