
import argparse
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import itertools

from hoag.benchmark import framed_results_for_kwargs


# DATASETS = ['real-sim']
DATASETS = ['20news']


def generate_grid_search_schemes(param_grid):
    """Expand a parameter grid into one named scheme per combination.

    Parameters
    ----------
    param_grid : dict
        Maps a keyword-argument name to the iterable of values to try for it.

    Returns
    -------
    dict
        ``{'grid_search_<i>': {name: value, ...}}`` with one entry per element
        of the Cartesian product of the value iterables, numbered in
        ``itertools.product`` order (the last key varies fastest). An empty
        grid yields a single scheme with no parameters.
    """
    names = list(param_grid)
    return {
        f'grid_search_{i}': dict(zip(names, combination))
        for i, combination in enumerate(
            itertools.product(*param_grid.values()))
    }


def combine_scheme_results(schemes_results):
    """Stack per-scheme result frames into a single DataFrame.

    Each frame is tagged with a ``scheme_label`` column that holds its key in
    *schemes_results*. Original row indices are kept (not reset), which
    matches appending the frames one after another. The input frames are
    not modified.

    Parameters
    ----------
    schemes_results : dict of str -> pandas.DataFrame
        Results keyed by scheme label, in the order they should be stacked.

    Returns
    -------
    pandas.DataFrame
        The concatenated results, or an empty DataFrame when
        *schemes_results* is empty (``pd.concat`` rejects an empty list).
    """
    if not schemes_results:
        return pd.DataFrame()
    # ``pd.concat`` replaces the private ``DataFrame._append``; the public
    # ``DataFrame.append`` was removed in pandas 2.0.
    return pd.concat(
        [df.assign(scheme_label=label) for label, df in schemes_results.items()]
    )


if __name__ == '__main__':
    start = time.perf_counter()
    parser = argparse.ArgumentParser(
        description='Draw Figures bi-level optimization (2. and E.1.)')
    parser.add_argument('--appendix_figure', '-a', action='store_true',
                        help='Run for appendix figure.')
    parser.add_argument('--no_recomp', '-nr', action='store_true',
                        help='No recomputation of the results.')
    parser.add_argument('--no_save', '-ns', action='store_true',
                        help='No saving of the results.')
    # NOTE(review): --interp, --quantile, --eps and --objective are parsed
    # but not consumed anywhere in this script (no plotting stage here).
    parser.add_argument('--interp', action='store_true',
                        help='Use interpolation curves.')
    parser.add_argument('--quantile', '-q', type=int, default=10,
                        help='Use first and last q-quantile for variance.')
    parser.add_argument('--eps', type=float, default=1,
                        help='Max sub-optimality level.')
    parser.add_argument('--objective', dest='subopt', action='store_false',
                        help='If set, plot the objective value instead of '
                        'the sub optimality.')
    args = parser.parse_args()

    save_results = not args.no_save
    reload_results = not args.no_recomp
    appendix_figure = args.appendix_figure
    # NOTE(review): only used to build the results file name below; the
    # active scheme runs with maxiter_inner=10, so "mi1000" in the name may
    # be misleading — confirm which value is intended.
    maxiter_inner = 1000
    train_prop = 90/100

    # Parameter grid for an optional grid search over scheme settings
    # (expanded with generate_grid_search_schemes).
    # param_grid = {
    #     'max_iter': [100, 200],
    #     'exponential_decrease_factor': [0.75, 0.85],
    #     'maxcor': [20, 30],
    #     'maxiter_inner': [5, 10]
    # }

    # Schemes to run: label -> keyword arguments for framed_results_for_kwargs.
    schemes = {
        # 'warm-up': dict(max_iter=2, tol=0.1),
        # 'PZOBO': dict(
        #     max_iter=1, PZOBO=True,
        #     exponential_decrease_factor=0.78,
        #     maxiter_inner=10, pure_python=True
        # ),
        # 'amicg': dict(
        #     max_iter=1, saba=True, amicg=30,
        #     exponential_decrease_factor=0.78,
        #     maxiter_inner=10, pure_python=True
        # ),
        'aidcg': dict(
            max_iter=400, aidcg=True, maxcor=30,
            exponential_decrease_factor=0.78,
            maxiter_inner=10, pure_python=True
        ),
        # 'aidtn': dict(
        #     max_iter=1, foa=True, aidtn=30,
        #     exponential_decrease_factor=0.78,
        #     maxiter_inner=10, pure_python=True
        # ),
    }

    # Alternative scheme set:
    # schemes = {
    #     'warm-up': dict(max_iter=2, tol=0.1),
    #     'sr1': dict(
    #         max_iter=100, sr1=True,
    #         exponential_decrease_factor=0.78,
    #         maxiter_inner=10, pure_python=True
    #     ),
    #     'saba': dict(
    #         max_iter=20000, saba=True, maxcor=30,
    #         exponential_decrease_factor=0.78,
    #         maxiter_inner=10, pure_python=True
    #     ),
    #     'f2sa': dict(
    #         max_iter=2000, f2sa=True, maxcor=30,
    #         exponential_decrease_factor=0.78,
    #         maxiter_inner=10, pure_python=True
    #     ),
    #     'shine-big-rank-foa': dict(
    #         max_iter=50, foa=True, maxcor=30,
    #         exponential_decrease_factor=0.78,
    #         maxiter_inner=10, pure_python=True
    #     ),
    #     'Bsg1': dict(
    #         max_iter=300, bsg1=True, maxcor=30,
    #         exponential_decrease_factor=0.78,
    #         maxiter_inner=10, pure_python=True
    #     ),
    #     'Bome': dict(
    #         max_iter=200, bome=True, maxcor=30,
    #         exponential_decrease_factor=0.78,
    #         maxiter_inner=10, pure_python=True
    #     ),
    #     'shine-big-rank-opa': dict(
    #         max_iter=30, shine=True, maxcor=60,
    #         exponential_decrease_factor=0.78,
    #         maxiter_inner=maxiter_inner, pure_python=True, opa=True
    #     ),
    # }

    # Add the grid-search schemes to `schemes`:
    # schemes.update(generate_grid_search_schemes(param_grid))

    print(schemes)
    for dataset in DATASETS:
        results_name = (
            f'{dataset}_mi{maxiter_inner}_tp{train_prop:.2f}_results.csv'
        )
        if not reload_results:
            continue

        # Run every configured scheme on this dataset.
        schemes_results = {
            scheme_label: framed_results_for_kwargs(
                train_prop=train_prop, dataset=dataset, n_random_seed=1,
                **scheme_kwargs
            ) for scheme_label, scheme_kwargs in schemes.items()
        }
        big_df_res = combine_scheme_results(schemes_results)
        if save_results:
            if big_df_res.empty:
                # Avoid clobbering an existing results file with nothing.
                print(f'No results for {dataset}; not writing {results_name}')
            else:
                big_df_res.to_csv(results_name)

    # NOTE(review): `included_schemes` is not used further in this script,
    # and several labels below are not present in `schemes` above.
    if not appendix_figure:
        # included_schemes = [
        #     'shine-big-rank-foa',
        #     'shine-big-rank-opa',
        #     'sr1',
        #     'Bsg1',
        #     'Bome',
        #     'saba',
        #     'f2sa',
        # ]
        included_schemes = [
            'PZOBO',
            'amicg',
            'aidcg',
            'aidtn',
        ]
    else:
        included_schemes = [
            'original', 'truncated-inversion',
            'shine-big-rank', 'shine-big-rank-refined',
            'grid-search', 'random-search',
            'fpn',
        ]

    end = time.perf_counter()
    print(f'The script took {end-start} seconds to run')
