function [W, A, B, C, b] = TMTL(XTA, XT, Y, lambda, beta, theta, CCP_1, CCP_2, CCP_3, CCP_4, CCP_5, R)
%TMTL  Fit the model parameters with the SpaRSA solver from PNOPT.
%   [W, A, B, C, b] = TMTL(XTA, XT, Y, lambda, beta, theta, ...
%                          CCP_1, CCP_2, CCP_3, CCP_4, CCP_5, R)
%
%   Inputs
%     XTA          cell array with one 3-way array per time point; S and T
%                  are read from the first two dimensions of XTA{1}, and the
%                  third dimension d_i is read per time point.
%     XT, Y        per-time-point data, forwarded unchanged to smooth_part.
%     lambda,theta weights forwarded to smooth_part.
%     beta         weight forwarded to prox_P (the non-smooth term).
%     CCP_1..CCP_5 scalars forwarded to smooth_part.
%     R            number of columns of each factor matrix.
%
%   Outputs
%     W   (S*T) x time_points matrix, one vectorised S x T block per column.
%     A   (S*R) x time_points matrix.
%     B   (T*R) x time_points matrix.
%     C   1 x time_points cell; C{i} is d_i x R.
%     b   1 x time_points row vector.
%
%   Side effects: extends the MATLAB path and reseeds the global random
%   number generator with rng(0).

% The Tensor Toolbox must be on the path.
addpath('../tensor_toolbox');
addpath('../tensor_toolbox/met/');

rng(0);

time_points = length(XTA);
[S, T, ~] = size(XTA{1});

% --- random starting point ------------------------------------------------
% The order of the rand() calls (W, b, A, B, then each C block) is kept
% fixed so that the seeded start is reproducible.
upper_mask = triu(ones(S, T), 1);          % 1 where column index > row index
init_scale = sqrt(1/((S*T-S))/2);

W = rand(S*T, time_points) .* init_scale;
W = bsxfun(@times, W, upper_mask(:));      % zero everything outside the mask
b = rand(1, time_points) .* init_scale;
A = rand(S*R, time_points);
B = rand(T*R, time_points);

C = cell([1, time_points]);
for t = 1:time_points
    [~, ~, d] = size(XTA{t});
    C{1, t} = rand(d*R, 1);
end

% Stack all C blocks into one column so everything fits in a single vector.
C_temporary = cat(1, C{:});
[all_sample, ~] = size(C_temporary);

vect = [A(:); B(:); C_temporary(:); W(:); b(:)];

% --- solve ----------------------------------------------------------------
% Smooth part: function value and gradient.
smoothF = @(parameterVect) smooth_part(parameterVect, XTA, XT, Y, ...
    lambda, theta, S, T, all_sample, CCP_1, CCP_2, CCP_3, CCP_4, CCP_5, R);
% Non-smooth part: l1-norm proximal operator built by prox_P.
non_smooth = prox_P(beta);
sparsa_options = pnopt_optimset(...
    'display'   , 0     ,...
    'debug'     , 0     ,...
    'maxIter'   , 500   ,...
    'ftol'      , 1e-10 ,...
    'optim_tol' , 1e-10 ,...
    'xtol'      , 1e-10  ...
    );
% Three outputs are requested (as before) even though only the first is used.
[vect_result, ~, ~] = pnopt_sparsa(smoothF, non_smooth, vect, sparsa_options);

% --- unpack the solution, same layout as vect -----------------------------
seg_len   = [numel(A), numel(B), numel(C_temporary), numel(W), numel(b)];
seg_stop  = cumsum(seg_len);
seg_start = seg_stop - seg_len + 1;

A           = reshape(vect_result(seg_start(1):seg_stop(1)), size(A));
B           = reshape(vect_result(seg_start(2):seg_stop(2)), size(B));
C_temporary = reshape(vect_result(seg_start(3):seg_stop(3)), size(C_temporary));
W           = reshape(vect_result(seg_start(4):seg_stop(4)), size(W));
b           = reshape(vect_result(seg_start(5):seg_stop(5)), size(b));

% Cut the stacked C column back into one d_i x R matrix per time point.
row_end = 0;
for t = 1:time_points
    [~, ~, d] = size(XTA{t});
    row_begin = row_end + 1;
    row_end   = row_end + d*R;
    C{1, t} = reshape(C_temporary(row_begin:row_end, :), [d, R]);
end

end

function [f, g] = smooth_part(parameterVect, XTA, XT, Y, lambda, theta, S, T, all_sample, CCP_1, CCP_2, CCP_3, CCP_4, CCP_5, R)
%SMOOTH_PART  Value and gradient of the smooth part of the objective.
%   [f, g] = SMOOTH_PART(parameterVect, XTA, XT, Y, lambda, theta, S, T, ...
%                        all_sample, CCP_1, ..., CCP_5, R)
%
%   parameterVect is laid out as [A(:); B(:); C(:); W(:); b(:)] with
%     A  (S*R) x time_points,   B  (T*R) x time_points,
%     C  all_sample x 1 (the per-time-point d_i x R blocks stacked),
%     W  (S*T) x time_points,   b  1 x time_points.
%
%   f is the sum of three terms:
%     * data fit:  sum_i 0.5 * ||AWY_i||_F^2, where AWY_i is the residual
%       built from XT{i} .* ((A_i*B_i') .* W_i .* H1), b_i and Y{i};
%     * CP fit:    sum_i 0.5 * lambda * ||XTA{i} - ktensor(A_i,B_i,C_i)||_F^2;
%     * temporal:  theta * ||W*H*A1*..*A5||_F^2 + theta * ||b*H*A1*..*A5||_F^2.
%   H1 is the S x T mask with ones where column index > row index, H is the
%   time_points x (time_points-1) matrix with +1 on the diagonal and -1 just
%   below it, and A1..A5 are identity matrices modified by CCP_1..CCP_5.
%
%   g is the gradient in the same layout as parameterVect. It is only
%   computed when a second output is requested.
%
%   Raises TMTL:unsupportedTimePoints unless there are exactly 7 time points,
%   because the five CCP matrices are defined for that case only.

time_points = length(XTA);

% The CCP chain is 6 x 6 and only defined for 7 time points. Fail early with
% a clear message instead of hitting an undefined variable further down.
if time_points ~= 7
    error('TMTL:unsupportedTimePoints', ...
        'smooth_part supports exactly 7 time points, got %d.', time_points);
end

% Difference matrix between consecutive time points: H(k,k) = 1, H(k+1,k) = -1.
H = zeros(time_points, time_points - 1);
H(1:(time_points + 1):end) = 1;
H(2:(time_points + 1):end) = -1;

% Mask selecting entries whose column index exceeds the row index.
H1 = triu(ones(S, T), 1);

% CCP matrices: A_k is the identity except A_k(k,k+1) = CCP_k and
% A_k(k+1,k+1) = 1 - CCP_k.
ccp = [CCP_1, CCP_2, CCP_3, CCP_4, CCP_5];
chain = cell(1, numel(ccp));
for k = 1:numel(ccp)
    Ak = eye(time_points - 1);
    Ak(k, k + 1)     = ccp(k);
    Ak(k + 1, k + 1) = 1 - ccp(k);
    chain{k} = Ak;
end

% --- unpack the parameter vector ------------------------------------------
nA = S * R * time_points;
nB = T * R * time_points;
nW = S * T * time_points;

pos = 0;
A = reshape(parameterVect(pos + 1 : pos + nA), [S * R, time_points]);
pos = pos + nA;
B = reshape(parameterVect(pos + 1 : pos + nB), [T * R, time_points]);
pos = pos + nB;
C_temporary = reshape(parameterVect(pos + 1 : pos + all_sample), [all_sample, 1]);
pos = pos + all_sample;
W = reshape(parameterVect(pos + 1 : pos + nW), [S * T, time_points]);
pos = pos + nW;
b = reshape(parameterVect(pos + 1 : pos + time_points), [1, time_points]);

% --- per-time-point terms --------------------------------------------------
need_grad = nargout > 1;
if need_grad
    g_A = zeros(S * R, time_points);
    g_B = zeros(T * R, time_points);
    g_C = zeros(all_sample, 1);
    g_W = zeros(S * T, time_points);
    g_b = zeros(1, time_points);
end

funcVal = 0;
regularizer = 0;
d1 = 0;
for i = 1:time_points
    % Rows d1+1..d2 of the stacked C column belong to time point i.
    [~, ~, d] = size(XTA{i});
    d2 = d1 + d * R;
    C_task = reshape(C_temporary(d1 + 1 : d2, :), [d, R]);

    [n, ~] = size(Y{i});
    A_task = reshape(A(:, i), [S, R]);
    B_task = reshape(B(:, i), [T, R]);
    W_task = reshape(W(:, i), [S, T]);
    b_task = b(1, i);

    AB     = A_task * B_task';
    mask_W = W_task .* H1;

    % Residual of the data-fit term.
    AW  = repmat(reshape(AB .* mask_W, [S, T, 1]), [1, 1, n]);
    AWY = sum(double(tenmat(XT{i} .* AW, 3)), 2) + b_task - Y{i};
    funcVal = funcVal + 0.5 * norm(AWY, 'fro')^2;

    % CP reconstruction error.
    XTA_hat  = double(ktensor({A_task, B_task, C_task}));
    XTA_task = XTA{i};
    regularizer = regularizer + 0.5 * lambda * norm(XTA_task(:) - XTA_hat(:), 'fro')^2;

    if need_grad
        % Residual-weighted sum of the slices of XT{i}; shared by the A, B
        % and W gradients.
        G  = reshape(squeeze(sum(bsxfun(@times, XT{i}, reshape(AWY, [1, 1, n])), 3)), [S, T]);
        GM = G .* mask_W;

        % NOTE(review): the CP-fit gradients below carry a factor 2 (the
        % "* (2 * K)" terms), whereas the analytic gradient of
        % 0.5*lambda*||.||_F^2 has factor 1. The scale is kept as in the
        % original code; confirm which of f or g is intended.

        % A
        CKB = kr(C_task, B_task);
        X1  = double(tenmat(XTA{i}, 1));
        g_A_task = GM * B_task ...
            - lambda * ((X1 - A_task * CKB') * (2 * CKB));
        g_A(:, i) = g_A_task(:);

        % B
        CKA = kr(C_task, A_task);
        X2  = double(tenmat(XTA{i}, 2));
        g_B_task = (A_task' * GM)' ...
            - lambda * ((X2 - B_task * CKA') * (2 * CKA));
        g_B(:, i) = g_B_task(:);

        % C -- the original referenced an undefined AKB here; BKA is the
        % matrix that pairs with the mode-3 unfolding, as CKB/CKA do above.
        BKA = kr(B_task, A_task);
        X3  = double(tenmat(XTA{i}, 3));
        g_C_task = -lambda * ((X3 - C_task * BKA') * (2 * BKA));
        g_C(d1 + 1 : d2) = g_C_task(:);

        % W
        g_W_task = (H1 .* AB) .* G;
        g_W(:, i) = g_W_task(:);

        % b
        g_b(1, i) = sum(AWY);
    end

    d1 = d2;
end

% --- temporal term ----------------------------------------------------------
% Multiply left to right: W*H*A1*A2*A3*A4*A5 (and the same for b).
WHA = W * H;
bHA = b * H;
for k = 1:numel(chain)
    WHA = WHA * chain{k};
    bHA = bHA * chain{k};
end
W_Val = theta * norm(WHA, 'fro')^2;
b_Val = theta * norm(bHA, 'fro')^2;

f = funcVal + regularizer + W_Val + b_Val;

if ~need_grad
    return;
end

% Gradient of the temporal term: 2*theta * (X*H*A1..A5) * A5'..A1' * H'.
back_W = WHA;
back_b = bHA;
for k = numel(chain):-1:1
    back_W = back_W * chain{k}';
    back_b = back_b * chain{k}';
end
g_W = g_W + theta * 2 * (back_W * H');
g_b = g_b + theta * 2 * (back_b * H');

g = [g_A(:); g_B(:); g_C(:); g_W(:); g_b(:)];
end



function op = prox_P(beta)
%PROX_P    Scaled l1 norm as a TFOCS prox object.
%    OP = PROX_P( beta ) implements the nonsmooth function
%        OP(X) = beta * norm(X,1)
%    together with its proximal operator, which shrinks every entry of X
%    towards zero by t*beta and sets entries with abs(X) <= t*beta to zero.
%    beta is required. No sign constraint is applied to X.

op = tfocs_prox( @l1_penalty, @soft_threshold, 'vector' ); % 'vector': allow vector stepsizes

    % Penalty value at x.
    function v = l1_penalty(x)
        v = beta * norm(x, 1);
    end

    % Proximal step with stepsize t (scalar, or a vector matching x).
    function x = soft_threshold(x, t)
        tq = t .* beta;
        % Multiplicative shrink factor in [0, 1]; it is 0 wherever
        % abs(x) <= tq, including x == 0.
        shrink = max( 1 - tq ./ abs(x), 0 );
        x = x .* shrink;
    end

end
