import os
import scipy.io as sio


def testAndSaveParams(cfgs, model, images, taskIndex, iter):
    dat = dict()
    # save running iter number
    dat['iteration'] = int(iter)

    # Save images
    dat['images'] = images

    # Save weights
    for name, param in model.named_parameters():
        dat['wts_' + name] = param.data.cpu().numpy()

    if not os.path.exists(os.path.join(cfgs['out_root'])):
        os.makedirs(os.path.join(cfgs['out_root']))

    # Write to file
    sio.savemat(os.path.join(cfgs['out_root'], 'saved' + '_' + str(taskIndex) + '.mat'), dat)


def testAndSaveParams2(cfgs, model, images, taskIndex, iter):
    dat = dict()
    # save running iter number
    dat['iteration'] = int(iter)

    # Save images
    dat['images'] = images

    # Save weights
    for name, param in model.named_parameters():
        dat['wts_' + name] = param.data.cpu().numpy()
    dat['mask'] = model.mask.cpu().numpy()

    if not os.path.exists(os.path.join(cfgs['out_root'])):
        os.makedirs(os.path.join(cfgs['out_root']))

    # Write to file
    sio.savemat(os.path.join(cfgs['out_root'], 'saved' + '_' + str(taskIndex) + '.mat'), dat)


def testAndSaveParams_damage_nomask(cfgs, model, images, neuron_role_key, damage_neuron_indices_key, taskIndex, iter):
    dat = dict()
    # save running iter number
    dat['iteration'] = int(iter)

    # Save images
    dat['images'] = images

    # Save weights
    for name, param in model.named_parameters():
        dat['wts_' + name] = param.data.cpu().numpy()

    out_root = os.path.join(cfgs['damage_train_root'], cfgs['model_type'])
    if not os.path.exists(out_root):
        os.makedirs(out_root)

    # Write to file
    sio.savemat(os.path.join(out_root, 'saved_{}_{}_{}.mat'.format(neuron_role_key, damage_neuron_indices_key, taskIndex)), dat)

def testAndSaveParams_damage_mask(cfgs, model, images, neuron_role_key, damage_neuron_indices_key, taskIndex, iter):
    dat = dict()
    # save running iter number
    dat['iteration'] = int(iter)

    # Save images
    dat['images'] = images

    # Save weights
    for name, param in model.named_parameters():
        dat['wts_' + name] = param.data.cpu().numpy()
    dat['mask'] = model.mask.cpu().numpy()

    out_root = os.path.join(cfgs['damage_train_root'], cfgs['model_type'])
    if not os.path.exists(out_root):
        os.makedirs(out_root)

    # Write to file
    sio.savemat(os.path.join(out_root, 'saved_{}_{}_{}.mat'.format(neuron_role_key, damage_neuron_indices_key, taskIndex)), dat)

def read_mat_file(file_path):
    """Load a MATLAB .mat file and return its contents.

    Args:
        file_path: Path of the .mat file to read.

    Returns:
        The dictionary produced by ``scipy.io.loadmat``, keyed by variable
        name.
    """
    return sio.loadmat(file_path)
