import itertools
from scipy.io import loadmat
import os
import argparse
from energy import compute_snn_ann_energy_ratio, init_energy
import numpy as np

# File-name constants for experiment artifacts.
# NOTE(review): MODEL_SUFFIX is not referenced by any code visible in this file.
MODEL_SUFFIX = "best_model.pth"
# Results file name appended to each run directory by grid_best_QCFS.
MAT_SUFFIX = "res.mat"

def Args():
    """Parse the command line and return the resulting namespace.

    Recognised options (both strings, None when omitted):
        --architecture
        --dataset
    """
    cli = argparse.ArgumentParser()
    for flag in ('--architecture', '--dataset'):
        cli.add_argument(flag, type=str)
    return cli.parse_args()


def grid_best_QCFS(architecture, dataset, verbose=False, compute_energy: bool = True):
    """Grid-search the QCFS result files and report the best run per sparsity level.

    Every combination of sparsity level (0.0 and ``sparsity``: 0.5 for VGG-16,
    0.99 otherwise), learning rate, batch size and quantisation level L is
    looked up under ``../QCFS/results``. Missing or unreadable files are
    skipped. For each file the following metrics are derived:

    * ``snn_max`` -- maximum of ``snn_acc`` (higher is better);
    * ``snn_32``  -- ``snn_acc[31]``, NaN if fewer than 32 entries (higher is better);
    * ``eng_32``  -- ``ratio_snn_ann[31]`` from ``compute_snn_ann_energy_ratio``
      (assumed lower is better). Only when ``compute_energy`` is True.

    If ``compute_energy`` is False, ``eng_32`` is neither computed nor
    printed, ``init_energy`` is never called, and the files need no ``LASFR``
    entry. The snn_* summary is printed in both cases.

    Args:
        architecture: network name, e.g. 'VGG-16' or 'MLP'.
        dataset: dataset sub-directory name.
        verbose: print a message for every missing or unreadable file.
        compute_energy: whether to compute and report ``eng_32``.

    Returns:
        ``{sparsity: {'snn_max': cfg, 'snn_32': cfg, 'eng_32': cfg}}``, where
        ``cfg`` is the ``(conv_sparsity, lr, bs, L)`` tuple of the best run,
        or None when no valid run was found for that metric.

    Raises:
        Whatever ``init_energy`` raises when ``compute_energy`` is True.
    """
    sparsity = 0.5 if architecture == 'VGG-16' else 0.99
    sparsity_levels = [0.0, sparsity]
    lr_range = [0.05, 0.01, 0.005, 0.001, 0.0005] if architecture == 'VGG-16' else [0.1, 0.05, 0.01, 0.005, 0.001, 0.0005]
    bs_range = [32, 64, 128]
    l_range = [2, 4, 8, 16, 32]
    # Metrics actually computed per experiment; eng_32 only with energy enabled.
    metric_keys = ['snn_max', 'snn_32'] + (['eng_32'] if compute_energy else [])

    def mat_path(conv_sparsity, lr, bs, l):
        """Return the results-file path of one (sparsity, lr, bs, L) run."""
        if architecture == 'MLP':
            # An MLP has no conv layers: its sparsity level is the linear-layer
            # sparsity (s_0.0 dense, s_0.99 sparse). The former check
            # `conv_sparsity == 0.5` never matched for MLP (sparse level 0.99),
            # so both levels silently read the dense s_0.0 files.
            sparse = f's_{conv_sparsity}'
        else:
            sparse = f'conv_{conv_sparsity}/s_0.0'
        run_dir = os.path.join('../QCFS/results', architecture, dataset, sparse, f'lr_{lr}/bs_{bs}/L_{l}')
        if architecture == 'VGG-16':
            # VGG-16 results live in an extra 'ft' sub-directory.
            run_dir = os.path.join(run_dir, 'ft')
        return os.path.join(run_dir, MAT_SUFFIX)

    def is_valid(val):
        """True for a usable metric value (neither None nor NaN)."""
        return val is not None and not (isinstance(val, float) and np.isnan(val))

    best_experiments = {s: {'snn_max': None, 'snn_32': None, 'eng_32': None} for s in sparsity_levels}
    best_values = {
        s: {'snn_max': -float('inf'), 'snn_32': -float('inf'), 'eng_32': float('inf') if compute_energy else None}
        for s in sparsity_levels
    }
    experiments_metrics = {}  # config tuple -> metrics dict of that run

    if compute_energy:
        model, dummy_input, device = init_energy('QCFS', architecture, dataset, "cuda:0")

    for config in itertools.product(sparsity_levels, lr_range, bs_range, l_range):
        conv_sparsity = config[0]
        file_name = mat_path(*config)
        if not os.path.exists(file_name):
            if verbose:
                print(f'file {file_name} not found. continue')
            continue
        try:
            res = loadmat(file_name)
            snn_acc = res['snn_acc'].flatten()
            metrics = {
                'snn_max': float(np.max(snn_acc)),
                # accuracy at the 32nd time step; NaN when fewer steps were recorded
                'snn_32': float(snn_acc[31]) if len(snn_acc) > 31 else float('nan'),
            }
            if compute_energy:
                lasfr = res['LASFR']  # a missing LASFR skips the run, as before
                ts = np.arange(1, len(snn_acc) + 1)
                try:
                    ratio = compute_snn_ann_energy_ratio(ts, lasfr, model, dummy_input, device).get('ratio_snn_ann')
                    metrics['eng_32'] = float(ratio[31]) if ratio is not None and len(ratio) > 31 else float('nan')
                except Exception as e:
                    # energy failure is non-fatal: keep the accuracy metrics
                    if verbose:
                        print(f'Error computing energy_ratio for {file_name}: {e}')
                    metrics['eng_32'] = float('nan')
        except Exception as e:
            if verbose:
                print(f"Error loading {file_name}: {e}")
            continue

        experiments_metrics[config] = metrics
        for key, val in metrics.items():
            if not is_valid(val):
                continue
            current = best_values[conv_sparsity][key]
            # energy ratio: lower is better (assumption kept from the original); accuracies: higher is better
            improved = val < current if key == 'eng_32' else val > current
            if improved:
                best_values[conv_sparsity][key] = val
                best_experiments[conv_sparsity][key] = config

    for conv_s in sparsity_levels:
        print(f"--- For conv_sparsity = {conv_s} ---")
        for key in metric_keys:
            best_params = best_experiments[conv_s][key]
            if best_params is None:
                print(f"Indicator {key}: no valid experiment found.")
                continue
            _, lr_best, bs_best, l_best = best_params
            print(f"Best {key} at (sparsity={conv_s}, lr={lr_best}, bs={bs_best}, l={l_best}) \n    {key} = {best_values[conv_s][key]:.6g} ")
            em = experiments_metrics.get(best_params)
            if em is None:
                print("    Cannot find metrics for that experiment.")
                continue
            for other in metric_keys:
                if other != key:
                    print(f"    {other} = {em[other]:.6g}")

    return best_experiments


def grid_best_input(architecture, dataset, mat_name='res.mat'):
    '''
    Return the (lr, bs) pair with the best validation accuracy for sparse-input training.

    One results file is expected for every lr in (0.1, 0.01, 0.001) and bs in
    (32, 64, 128), at
    ``../input/<architecture>/<dataset>/conv_0.5/s_0.0/d_0.0/onefc_True/lr_<lr>/bs_<bs>/<mat_name>``.
    Each run is scored by the maximum of its ``acc_val`` entry. The first
    (lr, bs) pair reaching the highest score wins.
    Note: the sensitivity test is deprecated because of one_fc.

    Args:
        architecture: network sub-directory name.
        dataset: dataset sub-directory name.
        mat_name: results file name inside each run directory.

    Returns:
        The winning ``(lr, bs)`` tuple; ``(0, 0)`` if no score exceeds -1.0
        (e.g. all NaN).

    Raises:
        ValueError: a results file for some (lr, bs) does not exist.
        RuntimeError: a results file cannot be loaded or has no usable
            ``acc_val``; the original error is chained as ``__cause__``.
    '''
    conv_sparsity = 0.5
    lr_range = (1e-1, 1e-2, 1e-3)
    bs_range = (32, 64, 128)
    # Directory shared by all runs; derived from conv_sparsity instead of a hard-coded 'conv_0.5'.
    run_root = os.path.join('../input', architecture, dataset, f'conv_{conv_sparsity}', 's_0.0/d_0.0/onefc_True')

    best_config = (0, 0)
    best_value = -1.0
    for lr, bs in itertools.product(lr_range, bs_range):
        file_name = os.path.join(run_root, f'lr_{lr}/bs_{bs}', mat_name)
        if not os.path.exists(file_name):
            raise ValueError(f"{file_name} doesn't exist")
        try:
            acc = loadmat(file_name)['acc_val'].max()
        except Exception as e:
            raise RuntimeError(f"Error loading {file_name}: {e}") from e
        if acc > best_value:
            best_config = (lr, bs)
            best_value = acc

    print(f'{best_config = }')

    return best_config


def grid_best_SNM(
    architecture: str,
    dataset: str,
    verbose: bool = False
) :
    """
    Select the batch size whose sparse SNM run reaches the highest SNN accuracy.

    Only the sparse setting is examined: ``s_0.99`` for the MLP and
    ``conv_0.5/s_0.0`` for every other architecture. For each batch size in
    (32, 64, 128, 256, 512) the file
    ``../SNM/results/<architecture>/<dataset>/<sparse>/bs_<bs>/res.mat`` is
    loaded. Missing files, unreadable files and files without ``snn_acc`` are
    skipped. Two metrics are derived from ``snn_acc``:

    * ``snn_max`` -- maximum accuracy over all recorded time steps;
    * ``snn_32``  -- ``snn_acc[31]``, or NaN if fewer than 32 steps exist.

    The winner is chosen by ``snn_max`` alone; on a tie the batch size seen
    first (the smaller one) is kept.

    Returns:
        ``(best_bs, best_exp)``: ``best_bs`` is the winning batch size, or
        None when no usable file was found. ``best_exp`` is
        ``{'bs': ..., 'snn_max': ..., 'snn_32': ...}`` for the winner, or
        ``{}``.
    """
    sparse_dir = "s_0.99" if architecture == 'MLP' else "conv_0.5/s_0.0"
    results_root = os.path.join('..', 'SNM', 'results', architecture, dataset, sparse_dir)

    winner = None
    top_max = -float('inf')
    top_32 = -float('inf')

    for bs in (32, 64, 128, 256, 512):
        file_name = os.path.join(results_root, f'bs_{bs}', 'res.mat')

        if not os.path.exists(file_name):
            if verbose:
                print(f"[跳过] 文件不存在: {file_name}")
            continue

        try:
            res = loadmat(file_name)
        except Exception as e:
            if verbose:
                print(f"[错误] 加载 {file_name} 时出错: {e}")
            continue

        if 'snn_acc' not in res:
            if verbose:
                print(f"[跳过] 在 {file_name} 中未找到 'snn_acc'")
            continue

        acc = res['snn_acc'].flatten()
        try:
            peak = float(np.max(acc))
        except Exception as e:
            if verbose:
                print(f"[错误] 计算 max_snn 失败: {file_name}, {e}")
            continue

        at_32 = float(acc[31]) if len(acc) > 31 else float('nan')

        if verbose:
            print(f"bs={bs}: snn_max={peak:.6f}, snn_32={at_32:.6f}")

        # A NaN peak can never become the winner.
        if not np.isnan(peak) and peak > top_max:
            top_max, top_32, winner = peak, at_32, bs

    if winner is None:
        if verbose:
            print("未找到任何有效 bs 对应的 res.mat 或者指标。")
        return winner, {}

    return winner, {'bs': winner, 'snn_max': top_max, 'snn_32': top_32}

def grid_best_AEC(arch, ds, verbose: bool = False, compute_energy: bool = True):
    """
    Find the best AEC experiment for each sparsity level.

    Each (sparsity, lr, bs) combination is looked up under ``../AEC/results``,
    with sparsity in [0.0, s] (s = 0.5 for VGG-16, 0.99 otherwise), lr in
    [1e-3, 1e-4, 1e-5] and bs in [32, 64, 128]. Missing or unreadable files
    are skipped. ``ts = [2, 4, 8, 16, 32, 64, 128]`` is passed to the energy
    computation; the metric names assume ``snn_acc`` follows the same order
    (TODO confirm against the training script).

    Metrics per experiment:
      * 'snn_max': max of ``snn_acc`` (higher is better).
      * 'snn_32':  ``snn_acc[4]``, i.e. ts[4] == 32 (higher is better). NaN
        when fewer than 5 entries exist, so the run still counts for
        'snn_max' instead of being dropped.
      * 'eng_8':   ``ratio_snn_ann[3]`` from ``compute_snn_ann_energy_ratio``
        (assumed lower is better). Only when ``compute_energy`` is True.
        NOTE(review): if that array is aligned with ``ts``, index 3 is T=16,
        not T=8 as the key name suggests -- confirm before relying on it.

    Args:
        arch: architecture name. 'MLP' uses ``s_<sparsity>`` directories;
            others use ``conv_<sparsity>/s_0.0``.
        ds: dataset sub-directory name.
        verbose: print per-file diagnostics and the final summary.
        compute_energy: compute and print 'eng_8'. If ``init_energy`` fails,
            every 'eng_8' is NaN.

    Returns:
        ``{sparsity: {'snn_max': cfg, 'snn_32': cfg, 'eng_8': cfg}}``, where
        ``cfg`` is the ``(sparsity, lr, bs)`` tuple of the best run, or None.
    """
    sparsity = 0.5 if arch == 'VGG-16' else 0.99
    lrs = [1e-3, 1e-4, 1e-5]
    bss = [32, 64, 128]
    sparsity_range = [0.0, sparsity]
    ts = [2, 4, 8, 16, 32, 64, 128]
    snn_32_idx = 4  # ts[4] == 32
    eng_idx = 3     # see NOTE(review) in the docstring
    metric_keys = ['snn_max', 'snn_32'] + (['eng_8'] if compute_energy else [])

    def mat_path(arch, ds, lr, bs, cur_sparsity):
        """Return the results-file path of one (sparsity, lr, bs) run."""
        res_path = "../AEC/results"
        if arch == 'MLP':
            sparse = f"s_{cur_sparsity}"
        else:
            sparse = f"conv_{cur_sparsity}/s_0.0"
        config = f'lr_{lr}/bs_{bs}'
        return os.path.join(res_path, arch, ds, sparse, config, 'res.mat')

    def is_valid(val):
        """True for a usable metric value (neither None nor NaN)."""
        return val is not None and not (isinstance(val, float) and np.isnan(val))

    def fmt(val):
        """Format a metric for the summary: 6 significant digits when numeric."""
        if val is None:
            return 'None'
        try:
            return f'{val:.6g}'
        except (TypeError, ValueError):
            return f'{val}'

    best_experiments = {s: {'snn_max': None, 'snn_32': None, 'eng_8': None} for s in sparsity_range}
    best_values = {
        s: {'snn_max': -float('inf'), 'snn_32': -float('inf'), 'eng_8': (float('inf') if compute_energy else None)}
        for s in sparsity_range
    }
    experiments_metrics = {}  # config tuple -> metrics dict of that run

    # If the energy model cannot be initialised, energy is reported as NaN
    # instead of failing with a NameError on every file.
    energy_ready = False
    if compute_energy:
        try:
            model, dummy_input, device = init_energy('AEC', arch, ds, "cuda:0")
            energy_ready = True
        except Exception as e:
            if verbose:
                print(f"init_energy failed: {e}")

    for config in itertools.product(sparsity_range, lrs, bss):
        cur_sparsity, lr, bs = config
        file_name = mat_path(arch, ds, lr, bs, cur_sparsity)
        if not os.path.exists(file_name):
            if verbose:
                print(f'file {file_name} not found. continue')
            continue
        try:
            res = loadmat(file_name)
            snn_acc = np.array(res.get('snn_acc')).flatten()
            metrics = {
                'snn_max': float(np.max(snn_acc)),
                'snn_32': float(snn_acc[snn_32_idx]) if len(snn_acc) > snn_32_idx else float('nan'),
            }
        except Exception as e:
            if verbose:
                print(f"Error loading or processing {file_name}: {e}")
            continue

        if compute_energy:
            eng_8 = float('nan')  # stays NaN when energy is unavailable or the ratio array is too short
            if energy_ready:
                try:
                    ratio = compute_snn_ann_energy_ratio(ts, res.get('LASFR'), model, dummy_input, device).get('ratio_snn_ann')
                    if ratio is not None and len(ratio) > eng_idx:
                        eng_8 = float(ratio[eng_idx])
                except Exception as e:
                    if verbose:
                        print(f'Error computing energy_ratio for {file_name}: {e}')
                    eng_8 = float('nan')
            metrics['eng_8'] = eng_8

        experiments_metrics[config] = metrics
        for key, val in metrics.items():
            if not is_valid(val):
                continue
            current = best_values[cur_sparsity][key]
            # energy ratio: lower is better (assumed); accuracies: higher is better
            improved = val < current if key == 'eng_8' else val > current
            if improved:
                best_values[cur_sparsity][key] = val
                best_experiments[cur_sparsity][key] = config

    if verbose:
        for conv_s in sparsity_range:
            print(f"--- For sparsity = {conv_s} ---")
            for key in metric_keys:
                best_params = best_experiments[conv_s][key]
                if best_params is None:
                    print(f"Indicator {key}: no valid experiment found.")
                    continue
                lr_best, bs_best = best_params[1], best_params[2]
                print(f"Best {key} at (sparsity={conv_s}, lr={lr_best}, bs={bs_best}) \n    {key} = {fmt(best_values[conv_s][key])} ")
                em = experiments_metrics.get(best_params)
                if em is None:
                    print("    Cannot find metrics for that experiment.")
                    continue
                for other in metric_keys:
                    if other != key:
                        print(f"    {other} = {fmt(em.get(other))}")

    return best_experiments


if __name__ == '__main__':
    # Ad-hoc entry point: run the QCFS grid search for VGG-16 on CIFAR10 with
    # per-file diagnostics enabled, then print the best configurations found.
    best = grid_best_QCFS('VGG-16', 'CIFAR10', verbose=True)
    print(best)









# NOTE: The module-level string literal below is never executed or bound to
# a name; it only preserves the source text of an older helper,
# find_best_sparsity(dataset, method_str). That helper scanned
# ../input/VGG-16/<dataset>/conv_0.5/s_<sparsity>/<method_str>/res.mat and
# returned the non-zero sparsity whose 'acc_val' maximum was highest, or None
# when no such file existed. Its (Chinese) comments are part of the string
# contents and are therefore left untouched.
'''
def find_best_sparsity(dataset, method_str):
    #路径
    path = os.path.join('../input/VGG-16',dataset, f'conv_0.5') 

    # 存储结果: {sparsity: best_acc}
    results = {}
    
    # 遍历目标路径
    for entry in os.listdir(path):
        subdir = os.path.join(path, entry)
        
        # 检查是否为"s_{sparsity}"格式的文件夹
        if os.path.isdir(subdir) and entry.startswith("s_"):
            try:
                # 提取稀疏率字符串并转换为浮点数
                sparsity_str = entry[2:]  # 去掉"s_"前缀
                sparsity = float(sparsity_str)
            except ValueError:
                continue  # 跳过无法转换为浮点数的文件夹
            
            # 构建res.mat文件路径
            mat_path = os.path.join(subdir, method_str,"res.mat")
            
            # 确保文件存在
            if not os.path.isfile(mat_path):
                continue
            
            try:
                # 加载MAT文件
                mat_data = loadmat(mat_path)
                
                # 检查'acc_val'键是否存在
                if 'acc_val' not in mat_data:
                    print('Key acc not found.' )
                    continue
                
                # 获取acc数组并找到最大值
                acc_array = mat_data['acc_val']
                best_acc = acc_array.max()
                
                # 仅记录非零稀疏率的结果
                if sparsity != 0.0:
                    results[sparsity] = best_acc
                    
            except Exception as e:
                print(f"Error processing {mat_path}: {str(e)}")
                continue
    
    # 检查是否有有效结果
    if not results:
        return None
    
    # 找到最佳准确率对应的稀疏率
    best_sparsity = max(results, key=results.get)
    return best_sparsity

'''
