import os
import scipy.io as sio


def testAndSaveParams(cfgs, model, images, taskIndex, iter):
    dat = dict()
    # save running iter number
    dat['iteration'] = int(iter)

    # Save images
    dat['images'] = images

    # Save weights
    for name, param in model.named_parameters():
        dat['wts_' + name] = param.data.cpu().numpy()

    if not os.path.exists(os.path.join(cfgs['out_root'])):
        os.makedirs(os.path.join(cfgs['out_root']))

    # Write to file
    sio.savemat(os.path.join(cfgs['out_root'], 'saved' + '_' + str(taskIndex) + '.mat'), dat)


def testAndSaveParams2(cfgs, model, images, taskIndex, iter):
    dat = dict()
    # save running iter number
    dat['iteration'] = int(iter)

    # Save images
    dat['images'] = images

    # Save weights
    for name, param in model.named_parameters():
        dat['wts_' + name] = param.data.cpu().numpy()
    dat['mask'] = model.mask.cpu().numpy()

    if not os.path.exists(os.path.join(cfgs['out_root'])):
        os.makedirs(os.path.join(cfgs['out_root']))

    # Write to file
    sio.savemat(os.path.join(cfgs['out_root'], 'saved' + '_' + str(taskIndex) + '.mat'), dat)


def read_mat_file(file_path):
    """Load a MATLAB ``.mat`` file and hand back its contents.

    This is a thin wrapper around ``scipy.io.loadmat`` called with its
    default options.

    Args:
        file_path: Path of the ``.mat`` file to read.

    Returns:
        The dict produced by ``scipy.io.loadmat``: variable names mapped to
        their values, plus loadmat's own metadata entries.
    """
    return sio.loadmat(file_path)
