import matplotlib
matplotlib.use('Agg', force=True)
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from scipy.io import loadmat
import os
import pandas as pd
import sys
sys.path.append("..")
from exp_comparison import grid_best_input

# Module-level two-colour palette (ablation_input below defines its own local
# palette and does not read this one).
colors = ['black', 'red']
# Conv sparsity values; ablation_input draws one line per value and reads its
# results from a matching `conv_<value>` directory.
conv_sparsity_range = (0.0, 0.5)

# Matplotlib's configured base font size; plot text is scaled from it.
default_fs = plt.rcParams['font.size']

# NOTE(review): the triple-quoted string below is a disabled copy of a
# `sensitivity(args)` plotting routine. It is a bare string expression, so it
# is never executed. To revive it, note that it depends on names this file does
# not define or import at module level: `basic_input_path`, `get_baseline`,
# `adjust_text` and `save_path` (the latter exists only as a local inside
# ablation_input).
# Its embedded Chinese text means "create figure and axes",
# "file not found: {path}" and "plot the data as a line chart".
'''
def sensitivity(args):
    sparsity_list = [0.0, 0.5, 0.7, 0.9, 0.95, 0.99]
    acc_list = []
    texts = []
    fs=1.5*default_fs

    # 新建 figure 和 axes
    fig, ax = plt.subplots(figsize=(8, 6))

    for s in sparsity_list:
        path = os.path.join(basic_input_path, f'conv_0.5/s_{s}/','d_0.3/onefc_False', 'res.mat')
        if not os.path.isfile(path):
            raise FileNotFoundError(f"找不到文件: {path}")
        mat = loadmat(path)
        if 'acc_val' not in mat:
            raise KeyError(f"'acc_val' not found in mat file: {path}")
        acc = mat['acc_val'].max()
        acc_list.append(float(acc)*100)

    base_line=get_baseline(args)[0]
    ax.axhline(y=base_line, color='black', linestyle='-', linewidth=1)
    texts.append(ax.text(0.5, base_line, f"dense {base_line:.2f}%", fontsize=fs))

    # 绘制数据折线图
    ax.plot(sparsity_list, acc_list, marker='o', linestyle='-', color='red')
    ax.set_xticks(sparsity_list)
    ax.set_xticklabels([str(s) for s in sparsity_list], rotation=45, fontsize=fs)
    ax.tick_params(axis='y', labelsize=fs)
    ax.set_xlabel('classifier sparsity(conv sparsity=0.5)', fontsize=fs)
    ax.set_ylabel('ANN. acc(%)', fontsize=fs)
    ax.set_title(f'sensitivity test of {args.architecture} on {args.dataset}', fontsize=fs)
    ax.grid(True)
    #ax.legend(fontsize=fs)

    adjust_text(texts)

    filepath = os.path.join(save_path, f'sensitivity_new.png')
    plt.savefig(filepath, bbox_inches='tight')
    plt.close(fig)
'''

# Learning rates swept by the 'lr' ablation (the script plots it on a log x-axis).
lr_range = (0.1, 0.01, 0.001)
# Batch sizes swept by the 'bs' ablation.
bs_range = (32, 64, 128)

def ablation_input(dataset, param_name, param_range, log= False, architecture='VGG-16'):
    fs = 1.5 * default_fs
    title = f"ab_Input {param_name}"
    fig, ax = plt.subplots(figsize=(8, 6))  # 使用 fig, ax = plt.subplots() 风格
    colors=('black', 'red')

    global_best = list(grid_best_input(architecture, dataset))

    for i, spar in enumerate(conv_sparsity_range):
        data = []
        
        match param_name:
            case 'lr':
                var, label = 0, f"bs_{global_best[1]}"
            case 'bs':
                var, label = 1, f"lr_{global_best[0]}"
            case _:
                raise ValueError('invalid hyper parameter name')

        for x in param_range:
            global_best[var] = x
            lr, bs = global_best
            basic_input = '../input'
            filename =os.path.join(basic_input , architecture,dataset,f'conv_{spar}',f's_0.0/d_0.0/onefc_True', f'lr_{lr}/bs_{bs}', 'res.mat')
            mat = loadmat(filename)
            acc = np.max(mat['acc_val'].flatten()) * 100
            data.append(acc)

        color = colors[i]
        ax.plot(param_range, data, color=color, label=f"conv_{spar}/" + label)

    ax.set_xlabel(param_name, fontsize=fs)
    if log:
        ax.set_xscale('log')
    ax.set_xticks(param_range)
    ax.set_xticklabels([str(x) for x in param_range], fontsize=fs)
    ax.tick_params(axis='y', labelsize=fs)
    ax.set_ylabel("input ANN accuracy (%)", fontsize=fs)
    ax.set_title(title, fontsize=fs)
    ax.legend(fontsize=fs)
    ax.grid(True)

    save_path=os.path.join(f'../graphs/{architecture}', dataset)
    fig.tight_layout()
    save_file = os.path.join(save_path, f'{title}.png')
    fig.savefig(save_file)
    plt.close(fig)

if __name__ == '__main__':
    # Produce both ablation figures (learning rate on a log axis, then batch
    # size) for each dataset.
    for dataset in ('CIFAR10', 'CIFAR100'):
        ablation_input(dataset, 'lr', lr_range, log=True)
        ablation_input(dataset, 'bs', bs_range)