function memo = f_poisson_tc_BiTNN_dct_ADMM(obs, opts, memo)
% F_POISSON_TC_BITNN_DCT_ADMM  ADMM solver for Poisson tensor completion with
% a bi-directional TNN regularizer under the DCT.
%
% Only the first two modes are regularized; the third mode is excluded
% (badMode = 3). The mode weights vW / vNu / vRho are hard-coded with three
% entries, so a tensor with more than three modes would index past them.
%
% Inputs:
%   obs  - observation struct:
%            .tsize  size vector of the full tensor
%            .idx    indices (linear or logical) of the observed entries
%            .y      observed values placed at obs.idx
%   opts - options struct:
%            .para.lambda      weight of the Poisson (KL) data term in the L-update
%            .para.infBound    upper bound of the box projection applied to T
%            .para.backGround  background level added inside the Poisson model
%            .para.rho         initial penalty for the T = L constraint
%            .para.nu          per-iteration growth factor of the penalties
%            .MAX_ITER_OUT     maximum number of ADMM iterations
%            .MAX_EPS          stopping tolerance on the largest iterate change
%            .MAX_RHO          cap on every penalty parameter
%            .verbose          print progress every memo.printerInterval iters
%            .MIN_ITER         (optional, default 50) the stopping test may only
%                              fire once iter > MIN_ITER
%   memo - struct with .truth (reference tensor for err / PSNR) and
%          .printerInterval (print period used when opts.verbose is true)
%
% Output:
%   memo - the input memo extended with per-iteration histories .rho, .eps,
%          .err (relative error of L vs. truth), .pnsr (PSNR of L), the last
%          iteration index .iter, and the final estimate in both .T_hat and
%          .L_hat (both are set to L).
%
% A start banner is always printed; a summary is printed (regardless of
% opts.verbose) when the stopping criterion fires.

% ---- Parameters ----
lambda     = opts.para.lambda;
infBound   = opts.para.infBound;
backGround = opts.para.backGround;
rho        = opts.para.rho;
nu         = opts.para.nu;

% Minimum iteration count before the stopping test may succeed
% (previously hard-coded to 50; now overridable through opts.MIN_ITER).
minIter = 50;
if isfield(opts, 'MIN_ITER') && ~isempty(opts.MIN_ITER)
    minIter = opts.MIN_ITER;
end

vW      = [0.5 0.5 0];     % TNN weight per mode (third mode excluded)
vNu     = nu  * [1 1 0];   % per-mode penalty growth factors
vRho    = rho * [1 1 0];   % per-mode penalty parameters
badMode = 3;               % mode that is not regularized

% ---- Initialization ----
normTruth = norm(double(memo.truth(:)));   % denominator of the relative error
sz        = obs.tsize;
nModes    = length(sz);
activeModes = setdiff(1:nModes, badMode);  % modes that carry a TNN term

B = zeros(sz);             % 1 at observed entries, 0 elsewhere
B(obs.idx) = 1;
P = zeros(sz);             % observed values scattered into the full tensor
P(obs.idx) = obs.y;

L  = zeros(sz);            % primal estimate
T  = L;                    % consensus variable (box-projected)
W  = L;                    % dual for T = L
cL = repmat({L}, nModes, 1);   % per-mode auxiliary variables
cW = repmat({L}, nModes, 1);   % per-mode duals for T = cL{iM}

fprintf('\n╔══════════════════════════════════════════════════════════════╗\n');
fprintf(  '║     🚀 Starting Bi-TNN DCT-based Poisson Tensor Completion   \n');
fprintf(  '╚══════════════════════════════════════════════════════════════╝\n');

% ---- Main ADMM loop ----
for iter = 1:opts.MAX_ITER_OUT
    oldL = L; oldCL = cL; oldT = T;

    % Block 1: T minimizes the augmented Lagrangian (quadratic in T),
    % then is projected onto the box [0, infBound].
    T = (rho * L + fi_WeightedSumOfCells(vRho, cL, badMode) ...
         - (W + fi_SumOfCells(cW, badMode))) / (rho + fi_Sum_Rho(vRho, badMode));
    T = min(max(T, 0), infBound);

    % Block 2a: closed-form L-update (Poisson term on observed entries only).
    L = fi_UpdateL(T, W, P, B, rho, lambda, backGround);

    % Block 2b: TNN proximal operator under the DCT for every active mode.
    for iM = activeModes
        tL = f_KDArray2ThreeD(T + cW{iM} / vRho(iM), iM);
        tL = f_prox_TNN_dct(tL, vW(iM) / vRho(iM));
        cL{iM} = f_3DArray2KD(tL, sz, iM);
    end

    % Convergence metric: largest inf-norm change among L, the cL{iM}, and T.
    % (Named chg so the MATLAB built-in eps is not shadowed.)
    chg = max(0, f_inf_norm(L - oldL));
    for iM = activeModes
        chg = max(chg, f_inf_norm(cL{iM} - oldCL{iM}));
    end
    chg = max(chg, f_inf_norm(T - oldT));

    % Per-iteration statistics
    memo.iter       = iter;
    memo.rho(iter)  = rho;
    memo.eps(iter)  = chg;
    memo.err(iter)  = norm(double(L(:) - memo.truth(:))) / normTruth;
    memo.pnsr(iter) = h_Psnr(memo.truth(:), L(:));

    if opts.verbose && mod(iter, memo.printerInterval) == 0
        fprintf('\n🌀 Iteration %3d | 🔍 PSNR: %6.2f |  📉 err: %.2e | 🔧 rho: %.2e\n', ...
            iter, memo.pnsr(iter), memo.err(iter), memo.rho(iter));
    end

    % Stop (before the dual update) once the change is small enough.
    if (memo.eps(iter) < opts.MAX_EPS) && (iter > minIter)
        fprintf('\n🎯 Converged at Iteration %d!\n', iter);
        fprintf('📈 Final PSNR : %.2f\n', memo.pnsr(iter));
        fprintf('📉 Final eps  : %.2e\n', memo.eps(iter));
        fprintf('📊 Final err  : %.2e\n', memo.err(iter));
        fprintf('🔧 Final rho  : %.2e\n', memo.rho(iter));
        fprintf('✅ Optimization completed successfully. 🚀\n');
        fprintf('══════════════════════════════════════════════════════════════════════\n');
        break;
    end

    % Block 3: dual ascent and capped penalty growth.
    W = W + rho * (T - L);
    for iM = activeModes
        cW{iM} = cW{iM} + vRho(iM) * (T - cL{iM});
    end
    rho  = min(rho * nu, opts.MAX_RHO);
    vRho = min(vRho .* vNu, opts.MAX_RHO);
end

% Final estimate (both fields hold the L iterate)
memo.T_hat = L;
memo.L_hat = L;
end

%% Helper function: closed-form L-update of the ADMM split
function L = fi_UpdateL(T, W, P, B, rho, lambda, backGround)
% Unobserved entries (B == 0):  L = T + W/rho.
% Observed entries   (B ~= 0):  with x = L + backGround, x is the larger root of
%     rho*x.^2 - c.*x - lambda*P = 0,   c = rho*(T + backGround) - (lambda - W),
% i.e. x = (c + s)/(2*rho) with s = sqrt(c.^2 + 4*rho*lambda*P).
% Where c < 0, the equivalent form x = 2*lambda*P ./ (s - c) is used, which
% avoids the cancellation in c + s (s - c >= 2|c| > 0 there).
c = rho * (T + backGround) - (lambda - W);
s = sqrt(c.^2 + 4 * rho * lambda * P);
x = (c + s) / (2 * rho);
neg = c < 0;
x(neg) = 2 * lambda * P(neg) ./ (s(neg) - c(neg));
Lobs = x - backGround;

L = T + W / rho;                 % unobserved rule everywhere ...
isObs = (B ~= 0);
L(isObs) = Lobs(isObs);          % ... overwritten at observed entries
end

%% Helper function: sum over all cells except badMode
function X = fi_SumOfCells(cellX, badMode)
% Sum of cellX{i} over all i except i == badMode.
% The accumulator is a zeroed copy of cellX{1} (same size and class) instead
% of 0*cellX{1}, so Inf/NaN in an excluded first cell cannot turn into NaN
% in the result.
X = cellX{1};
X(:) = 0;
for i = 1:length(cellX)
    if i ~= badMode
        X = X + cellX{i};
    end
end
end

%% Helper function: weighted sum over all cells except badMode
function X = fi_WeightedSumOfCells(vWeight, cellX, badMode)
% Weighted sum vWeight(i) * cellX{i} over all i except i == badMode.
% The accumulator is a zeroed copy of cellX{1} (same size and class) instead
% of 0*cellX{1}, so Inf/NaN in an excluded first cell cannot turn into NaN
% in the result.
X = cellX{1};
X(:) = 0;
for i = 1:length(cellX)
    if i ~= badMode
        X = X + vWeight(i) * cellX{i};
    end
end
end

%% Helper function: sum of vRho excluding badMode
function sumRho = fi_Sum_Rho(vRho, badMode)
% Total penalty over the regularized modes: the sum of every entry of vRho
% except the one at position badMode (an out-of-range badMode excludes nothing).
activeModes = setdiff(1:numel(vRho), badMode);
sumRho = sum(vRho(activeModes));
end
