from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import pyplot

import argparse

# Render text with matplotlib's built-in engine rather than an external LaTeX.
mpl.rcParams['text.usetex'] = False

# Path of this script with its extension removed; its parent directory is
# used as the base for relative result-file paths.
FILE_NAME = Path(__file__).with_suffix('')

# Shared matplotlib setup applied to every plot made with this module: a 32pt
# base font, with legends, axis labels and tick labels one relative step
# smaller ('small'). The common STYLES mapping is defined below.
mpl.rcParams.update({
    'font.size': 32,
    **dict.fromkeys(
        ('legend.fontsize', 'axes.labelsize',
         'xtick.labelsize', 'ytick.labelsize'),
        'small',
    ),
})


# Line style per solver key. '*' holds the defaults shared by all solvers
# (line width); each named entry sets the colour and the legend label.
# NOTE(review): with text.usetex disabled, r'\textbf{SRBA}' is presumably
# rendered literally rather than in bold — confirm at the call site.
STYLES = {
    '*': {'lw': 1.5},
    'amigo': {'color': '#5778a4', 'label': r'AmIGO'},
    'mrbo': {'color': '#e49444', 'label': r'MRBO'},
    'vrbo': {'color': '#e7ca60', 'label': r'VRBO'},
    'saba': {'color': '#d1615d', 'label': r'SABA'},
    'stocbio': {'color': '#85b6b2', 'label': r'StocBiO'},
    'srba': {'color': '#6a9f58', 'label': r'\textbf{SRBA}', 'lw': 2},
    'f2sa': {'color': '#bcbd22', 'label': r'F2SA'},
}

# Per-solver call counts stored as (inner, outer) pairs.
# NOTE(review): the unit of these counts (e.g. per iteration) is not
# established in this file — confirm with the code that consumes N_CALLS.
N_CALLS = {
    # Single-loop solvers.
    'mrbo': (24, 4),
    'sustain': (24, 4),
    'ttsa': (11, 2),
    'fsla': (4, 3),
    # Two-loop solvers.
    'amigo': (21, 2),
    'stocbio': (21, 2),
    'bsa': (21, 2),
    # Our solvers.
    'saba': (3, 2),
    'bio-svrg': (3, 2),
    'srba': (3, 2),
}

# Whether the legend should be drawn outside the axes.
LEGEND_OUTSIDE = False

# Default figure dimensions: single-column width, double-column width, height.
DEFAULT_WIDTH = 3.25
DEFAULT_DOUBLE_WIDTH = 6.75
DEFAULT_HEIGHT = 2.0


def get_param(name, param='period_frac'):
    """Return the value of one parameter encoded in a solver name.

    ``name`` is expected to look like ``"solver[k1=v1,k2=v2,...]"``. Values
    that ``float()`` accepts (including negative numbers and scientific
    notation such as ``1e-05``) are returned as ``float``; any other value
    is returned as the raw string.

    Parameters
    ----------
    name : str
        Name containing a bracketed, comma-separated ``key=value`` list.
    param : str
        Key whose value is returned.

    Returns
    -------
    float or str
        The parsed value of ``param``.

    Raises
    ------
    KeyError
        If ``param`` does not appear in ``name``.
    """
    params = {}
    # Text after the first '[' with the closing ']' removed.
    body = name.split("[", maxsplit=1)[1][:-1]
    for item in body.split(","):
        if not item:
            # e.g. "solver[]": no parameters to parse.
            continue
        # Split on the first '=' only so a value may itself contain '='.
        key, value = item.split("=", maxsplit=1)
        try:
            params[key] = float(value)
        except ValueError:
            params[key] = value
    return params[param]


def drop_param(name, param='period_frac'):
    """Return ``name`` with the ``param=...`` entry removed from its brackets.

    ``name`` is expected to look like ``"solver[k1=v1,k2=v2,...]"``. The
    remaining entries keep their original order and text. If no entry is
    left (or ``name`` has no bracketed part), the bare base name is returned
    instead of a malformed string.

    Parameters
    ----------
    name : str
        Name containing a bracketed, comma-separated ``key=value`` list.
    param : str
        Key of the entry to drop.

    Returns
    -------
    str
        The rebuilt name, e.g. ``"solver[k2=v2]"``.
    """
    base, _, rest = name.partition("[")
    # Drop the closing ']' then keep every non-empty entry whose key differs.
    kept = [
        item for item in rest[:-1].split(",")
        if item and item.partition("=")[0] != param
    ]
    if not kept:
        return base
    return f"{base}[{','.join(kept)}]"


# Result files plotted by this script, relative to this file's directory.
_RESULT_FILES = (
    "../outputs/datacleaning0_5_best_params_r_2.parquet",
    "../outputs/datacleaning_svhn_0_1_best_params_r.parquet",
    "../outputs/datacleaning_fashion_0_1_params_r.parquet",
    "../outputs/datacleaning_cifar10_0_1_best_params_r.parquet",
)

# The ``rz=...`` solver-name token of each curve, in plotting order, paired
# with the index of its colour in the 'Set1' colormap.
_RZ_CURVES = (
    ('rz=10', 2),
    ('rz=100', 4),
    ('rz=1000', 3),
    ('rz=10000', 1),
)


def _pad_to_same_length(arrays):
    """Stack 1-D arrays row-wise, right-padding shorter ones with their last value."""
    max_len = max(len(arr) for arr in arrays)
    return np.array([np.pad(arr, (0, max_len - len(arr)), mode='edge')
                     for arr in arrays])


def _collect_runs(df, labels):
    """Group each solver's test-accuracy and query curves by its ``rz=`` token.

    Returns two dicts keyed by label, each holding a list of 1-D arrays (one
    per matching solver, with the first recorded row dropped).
    """
    accuracies = {label: [] for label in labels}
    queries = {label: [] for label in labels}
    for solver in df['solver_name'].unique():
        # Find the token by value rather than by a fixed position in the
        # comma-separated name; strip ']' in case it is the last parameter.
        tokens = (token.rstrip(']') for token in solver.split(','))
        label = next((token for token in tokens if token in accuracies), None)
        if label is None:
            continue
        solver_data = df[df['solver_name'] == solver]
        acc = solver_data['objective_test_accuracy'].iloc[1:].to_numpy()
        if acc.size == 0:
            # np.pad(mode='edge') cannot extend an empty array.
            continue
        accuracies[label].append(acc)
        queries[label].append(
            solver_data['objective_query'].iloc[1:].to_numpy())
    return accuracies, queries


def _plot_mean_std(ax, x, curves, color, label):
    """Plot the mean of ``curves`` (in percent) with a +/- one std band.

    Returns the ``(mean, std)`` arrays that were plotted.
    """
    percent = 100 * curves
    avg = percent.mean(axis=0)
    std = percent.std(axis=0)
    ax.plot(x, avg, color=color, label=label, linewidth=4.0)
    ax.fill_between(x, avg - std, avg + std, color=color, alpha=0.2)
    return avg, std


def main():
    """Plot test accuracy against oracle queries, one PDF per result file."""
    parser = argparse.ArgumentParser(
        description='Plot benchmarks results for bilevel optimization.'
    )
    parser.add_argument('--n-points', '-n', type=int, default=500,
                        help='# of points in the grid for interpolation.')
    parser.add_argument('--x-axis', '-x', type=str, default='time',
                        choices=['time', 'calls'],
                        help='Plot in time or number of calls to oracles.')
    parser.add_argument('--benchmark', '-b', type=str, default='datacleaning0_5',
                        choices=['ijcnn1', 'datacleaning0_5',
                                 'datacleaning0_7', 'datacleaning0_9',
                                 'covtype'],
                        help='Choose the benchmark to plot.')
    parser.add_argument('--criterion', '-c', type=str, default='100',
                        choices=['100', 'all'],
                        help='Choose the best curves with respect to the 100 \
                        first iterates or all the iterates.')
    # NOTE(review): none of the parsed options affect the plots below (the
    # inputs are hard-coded in _RESULT_FILES); parsing is kept so the CLI
    # still validates its arguments and supports --help.
    parser.parse_args()

    palette = plt.get_cmap('Set1')
    label_font = {'family': 'sans-serif', 'weight': 'normal', 'size': 42}
    labels = [label for label, _ in _RZ_CURVES]

    for fname in _RESULT_FILES:
        # NOTE(review): for 'datacleaning0_5_...' this yields '5'; confirm
        # that is the intended dataset tag for the output file name.
        dataname = fname.split('_')[1]
        path = FILE_NAME.parent / fname
        print(dataname)
        print(path)

        df = pd.read_parquet(path)
        accuracies, queries = _collect_runs(df, labels)

        fig, ax = plt.subplots(figsize=(13, 12))
        for label, color_idx in _RZ_CURVES:
            if not accuracies[label]:
                print(f"{label}: no runs found in {path}, skipping")
                continue
            acc = _pad_to_same_length(accuracies[label])
            query = _pad_to_same_length(queries[label])
            # The x-axis is the first run's (edge-padded) query counts.
            avg, std = _plot_mean_std(
                ax, query[0], acc, palette(color_idx), label)
            print(f"{label} acc {avg[-1]} std {std[-1]}")

        ax.tick_params(labelsize=32)
        ax.legend(loc='lower right', prop=label_font)
        ax.set_xlabel('Query', font=label_font)
        ax.set_ylabel('Test Accuracy', font=label_font)
        ax.grid(True)
        output_image_path = './clean_r_' + dataname + '.pdf'
        fig.savefig(output_image_path, format='pdf')
        # One figure is created per file; release it before the next one.
        plt.close(fig)


if __name__ == "__main__":
    main()
    

   