%% process_diffusions.m
% 清空环境
clear; clc;

%% Step 1: 读取真实网络文件（共享参数）
A = txt_to_network("true_A_scale.txt");
B = txt_to_network("true_B.txt");

%% Step 2: 设置 connie 函数所需超参数（所有级联共享）
rho = 0.5;                     % 稀疏性控制参数
incubation_model_number = 2;   % 感染模型编号（例如，指数衰减模型）
suboptimal_tol = 0.5;          % 次优容忍度

%% Step 3: 设置输入和输出文件夹路径
% 输入文件夹 "scale" 位于 main.m 同一目录下
inputFolder = fullfile(pwd, 'scale');
% 输出文件夹 "results" 位于 main.m 同一目录下
outputFolder = fullfile(pwd, 'res_scale');
if ~exist(outputFolder, 'dir')
    mkdir(outputFolder);
end

%% Step 4: 循环处理 20 个 diffusion 文件，并保存结果
for i = 1:20
    % 构造 diffusion 文件名，例如 "scale_cascade_1.txt"
    diff_filename = fullfile(inputFolder, sprintf("scale_cascade_%d.txt", i));
    fprintf('正在处理文件: %s\n', diff_filename);
    
    % 读取 diffusion 文件构造 diffusions 矩阵
    diffusions = read_diffusions(diff_filename);
    
    % 调用 connie 函数，真实网络采用 A+B（若无真实网络，可直接传 diffusions）
    [A_mle, precision, recall, mse] = connie(rho, incubation_model_number, diffusions, A+B, suboptimal_tol);
    
    % 显示结果
    disp('推断的网络 A_mle:');
    disp(full(A_mle));  % 若 A_mle 为稀疏矩阵，可用 full() 转为普通矩阵查看
    
    if ~isempty(precision)
        fprintf('Precision: %.4f\n', precision);
        fprintf('Recall: %.4f\n', recall);
        fprintf('MSE: %.4f\n', mse);
    else
        disp('未提供真实网络，未计算 precision、recall 及 mse。');
    end
    
    % 构造输出文件名，并将 A_mle 保存到 "results" 文件夹中
    res_filename = fullfile(outputFolder, sprintf('res_%d.txt', i));
    sparse_to_txt(A_mle, res_filename);
    fprintf('已将结果保存到: %s\n\n', res_filename);
end

%% -------------------- Helper Functions ---------------------- %%
function A = txt_to_network(filename)
% txt_to_network 将 txt 文件转换为邻接矩阵
%   文件格式为每行 "i,j,Aij"，其中 i 和 j 为 0-based 索引
    A = zeros(100, 100);  % 默认初始化为 100x100 矩阵，根据需要调整大小
    fid = fopen(filename, 'r');
    if fid == -1
        error('无法打开文件: %s', filename);
    end
    while ~feof(fid)
        line = fgetl(fid);
        if ischar(line) && ~isempty(line)
            values = str2double(strsplit(line, ','));
            i_idx = values(1) + 1;  % 转换为 MATLAB 1-based 索引
            j_idx = values(2) + 1;
            A(i_idx, j_idx) = values(3);
        end
    end
    fclose(fid);
end

function sparse_to_txt(A, filename)
% sparse_to_txt 将 MATLAB 稀疏矩阵 A 写入 txt 文件
%   输出格式：每行 "i,j,Aij"，其中 i 和 j 为 0-based 索引，Aij 保留 8 位小数
%   仅输出 Aij > 0 的项
    if ~issparse(A)
        warning('输入矩阵不是稀疏矩阵，将转换为稀疏矩阵。');
        A = sparse(A);
    end

    [row, col, vals] = find(A);
    mask = vals > 0;
    row = row(mask);
    col = col(mask);
    vals = vals(mask);
    
    fid = fopen(filename, 'w');
    if fid == -1
        error('无法打开文件: %s', filename);
    end
    
    for k = 1:length(vals)
        fprintf(fid, '%d,%d,%.8f\n', row(k)-1, col(k)-1, vals(k));
    end
    fclose(fid);
end

function diffusions = read_diffusions(filename)
% read_diffusions 读取 txt 文件，并构造 diffusions 矩阵
%   文件中每行的格式为：
%       "node_id1 infection_time1, node_id2 infection_time2, ..."
%   假定文件中 node_id 从 0 开始，MATLAB 索引从 1 开始；
%   如果感染时间为 10，则转换为 -1（表示未被感染）。
    fid = fopen(filename, 'r');
    if fid == -1
        error('无法打开文件: %s', filename);
    end
    lines = textscan(fid, '%s', 'Delimiter', '\n');
    fclose(fid);
    lines = lines{1};
    
    nc = length(lines);  % 级联数量
    firstLinePairs = strsplit(lines{1}, ', ');
    nd = length(firstLinePairs);
    
    diffusions = zeros(nc, nd);
    
    for i = 1:nc
        pairs = strsplit(lines{i}, ', ');
        for j = 1:length(pairs)
            parts = strsplit(pairs{j});
            time_val = str2double(parts{2});
            if time_val == 10
                time_val = -1;
            end
            diffusions(i, j) = time_val;
        end
    end
end
