using LinearAlgebra
using TensorToolbox
using StaticArrays
include("TR.jl")

"""
    shiftdim(T)

Return a copy of `T` whose dimensions are rotated one step to the left: the first
dimension becomes the last, so an array of size `(n1, n2, …, nd)` becomes
`(n2, …, nd, n1)`.
"""
function shiftdim(T)
    perm = (2:ndims(T)..., 1)   # e.g. (2, 3, 1) for a 3-way array
    return permutedims(T, perm)
end
"""
    Z_neq(Z, n)

Contract, in ring order, every 3-way core of `Z` except core `n`.

The cores are rotated so the chain starts at core `n+1` (wrapping around), and the
first `length(Z) - 1` cores of that rotation are multiplied: the last index of
each core is contracted with the first index of the next. The result is reshaped
to `(first rank, :, last rank)`, where the middle index merges the mode indices
column-major (the first core's mode varies fastest).
"""
function Z_neq(Z, n)
    ring = circshift(Z, -n)       # core n ends up last and is left out
    stop = length(ring) - 1       # index of the final core that takes part
    acc = ring[1]
    for k in 2:stop
        left = reshape(acc, :, size(ring[k-1], 3))     # expose the trailing rank
        right = reshape(ring[k], size(ring[k], 1), :)  # expose the leading rank
        acc = left * right
    end
    return reshape(acc, size(ring[1], 1), :, size(ring[stop], 3))
end

"""
    NTR(Y, r; method="MU", t_inner=100, verbose=true, ε=eps(), Tol=1e-4,
        MaxIter=400, ω=0.1, lra_parameter=10, LRA_R=11)

Fit a nonnegative tensor-ring decomposition to the `d`-way array `Y`. Cores are
updated one at a time: each outer iteration updates one core, and `d` iterations
make one sweep.

# Arguments
- `Y`: `d`-way data array, with `n = size(Y)`.
- `r`: ring ranks. Core `k` has size `(r[k], n[k], r[k+1])`, except the last core,
  which has size `(r[d], n[d], r[1])`. `r` therefore needs at least `d` entries.

# Keywords
- `method`: inner solver for each core update: `"MU"`, `"APG"`, `"MM"` or `"HALS"`.
  Any other value, including the unfinished `"lra_MM"`, raises an error.
- `t_inner`: number of inner-solver iterations per core update.
- `verbose`: print the relative error after every sweep, and warn when `MaxIter`
  sweeps finish without meeting the tolerance.
- `ε`: floor value passed to the `"MU"` and `"HALS"` solvers.
- `Tol`: stop once the relative error, or its change between sweeps, is `≤ Tol`.
- `MaxIter`: maximum number of sweeps.
- `ω`: the `"MM"` solver receives `Ω = ω*I`.
- `lra_parameter`, `LRA_R`: only read by the experimental `"lra_MM"` path.

# Returns
A vector of the `d` fitted cores.
"""
function NTR(Y, r; method="MU", t_inner=100, verbose=true, ε=eps(), Tol=1e-4,
             MaxIter=400, ω=0.1, lra_parameter=10, LRA_R=11)
    n = size(Y)
    d = length(n)
    node = []
    Ω = I * ω

    # Random initial cores (entries in [0, 1)); the ring closes with r[1].
    for i in 1:d
        if i != d
            nodei = rand(r[i], n[i], r[i+1])
        else
            nodei = rand(r[d], n[d], r[1])
        end
        push!(node, nodei)
    end
    od = [1:d;]   # current mode order of Y; od[1] is the mode being updated
    err = 1.0

    # Low-rank factors of each unfolding; only filled when method == "lra_MM".
    Q = Array{Union{Float64,Matrix}}(undef, d)
    P = Array{Union{Float64,Matrix}}(undef, d)
    for i in 1:(MaxIter * d)
        err0 = err
        if i > 1
            # Rotate Y so that the next mode comes first.
            Y = shiftdim(Y)
            od = circshift(od, -1)
        end
        # Unfold Y and the current core along mode od[1]. B is the matching
        # unfolding of the product of all other cores, so the model is Y ≈ A * B.
        Y = reshape(Y, (n[od[1]], :))
        A = node[od[1]]
        A = permutedims(A, (2, 3, 1))
        A = reshape(A, (n[od[1]], :))
        B = Z_neq(node, od[1])
        B = permutedims(B, (1, 3, 2))
        B = reshape(B, (r[od[2]] * r[od[1]], :))

        if method == "lra_MM"
            # NOTE(review): Q/P are computed during the first sweep, but nothing
            # below consumes them yet, so "lra_MM" still ends in "method error".
            if i <= d
                if n[i] > lra_parameter
                    QQ = 1.0
                    PP = 1.0
                else
                    QQ, _, PP = lowrank_app(Y, LRA_R)
                end
                Q[i] = QQ
                P[i] = PP
            end
        end

        if method == "MU"
            A .= loop_MU(Y, A, B, t_inner, ε)
        elseif method == "APG"
            A = loop_APG(Y, A, B, t_inner)
        elseif method == "MM"
            A = loop_MM(Y, A, B, Ω, t_inner)
        elseif method == "HALS"
            A = loop_HALS(Y, A, B', t_inner, ε)
        else
            error("method error")
        end

        if mod(i, d) == 0
            # Relative fit error, evaluated once per sweep.
            err1 = norm(Y - A * B)
            err = err1 / norm(Y)
            if verbose
                println("iter:$(i/d) \t err=$err")
            end
        end

        # Fold A back into core layout (r[od[1]], n[od[1]], r[od[2]]).
        A = reshape(A, (n[od[1]], r[od[2]], r[od[1]]))
        A = permutedims(A, (3, 1, 2))

        if method == "MM"
            # After an MM update, the core is rescaled to unit Frobenius norm.
            s = norm(A[:], 2)
            node[od[1]] = A ./ s
        else
            node[od[1]] = A
        end

        if mod(i, d) == 0
            if abs(err0 - err) <= Tol || err <= Tol
                break
            end
        end
        # Restore Y to a d-way array (in the rotated mode order) for the next shiftdim.
        Y = reshape(Y, n[od])

        if i == MaxIter * d && verbose
            @warn "NTR did not converge within MaxIter = $MaxIter sweeps (err = $err)"
        end
    end
    return node
end

"""
    loop_LraMM(Q, P, A, B, Ω, t_inner)

Run `t_inner` nonnegative MM-style updates of `A` for the model `Y ≈ A * B`, where
the data enters only through the factors `Q` and `P`. The term `Y*B'` is replaced
by `Q*(P*B')`. Presumably `Q*P ≈ Y` is a low-rank approximation of the unfolding;
confirm this against `lowrank_app`.

Each step computes `Z = |A/2| * (Ω - BBᵀ)(Ω + BBᵀ)⁻¹ + Q P Bᵀ (Ω + BBᵀ)⁻¹` and then
sets `A = Z + |Z|`, which equals `2 max(Z, 0)`. `A` is overwritten in place and
returned.
"""
function loop_LraMM(Q, P, A, B, Ω, t_inner)
    B1 = B * B'
    # (Ω + B*B')⁻¹, formed once. Base/LinearAlgebra no longer provide `eye`.
    B11 = inv(Ω + B1)
    B12 = Ω - B1
    B1211 = B12 * B11
    B211 = Q * (P * B' * B11)
    Z = similar(A)
    for t in 1:t_inner
        Z .= A ./ 2.0
        # The update must go into Z. Writing it into A would be discarded by the
        # next line, which rebuilds A from Z.
        Z .= abs.(Z) * B1211 + B211
        A .= Z + abs.(Z)
    end
    return A
end

"""
    loop_MM(Y, A, B, Ω, t_inner)

Run `t_inner` nonnegative MM-style updates of `A` for the model `Y ≈ A * B`. `Ω` is
added to `B*B'`; `NTR` passes `ω*I`.

Each step computes `Z = |A/2| * (Ω - BBᵀ)(Ω + BBᵀ)⁻¹ + Y Bᵀ (Ω + BBᵀ)⁻¹` and then
sets `A = Z + |Z|`, which equals `2 max(Z, 0)`, so `A` stays nonnegative. If every
entry of `Z` is positive, `A` is a fixed point exactly when `A*B*B' == Y*B'`.
`A` is overwritten in place and returned.
"""
function loop_MM(Y, A, B, Ω, t_inner)
    B1 = B * B'
    B2 = Y * B'
    # (Ω + B*B')⁻¹, formed once. Base/LinearAlgebra no longer provide `eye`.
    B11 = inv(Ω + B1)
    B12 = Ω - B1
    B1211 = B12 * B11
    B211 = B2 * B11
    Z = similar(A)
    for t in 1:t_inner
        Z .= A ./ 2.0
        Z .= abs.(Z) * B1211 + B211
        A .= Z + abs.(Z)
    end
    return A
end

"""
    loop_HALS(Y, A, B, t_inner, ε)

Run `t_inner` HALS sweeps for the model `Y ≈ A * B'`. Each sweep updates the
columns of `A` one after another, and every entry is floored at `ε`. `A` is
modified in place and returned.
"""
function loop_HALS(Y, A, B, t_inner, ε)
    YB = Y * B
    BtB = B' * B
    ncols = size(A, 2)
    for _ in 1:t_inner, col in 1:ncols
        # Column update: residual correlation divided by the (floored) diagonal of BᵀB.
        step = (YB[:, col] - A * BtB[:, col]) / max.(ε, BtB[col, col])
        A[:, col] = max.(ε, A[:, col] + step)
    end
    return A
end

"""
    loop_MU(Y, A, B, t_inner, ε)

Apply `t_inner` multiplicative updates `A ← A ⊙ (Y Bᵀ) ⊘ max(ε, A B Bᵀ)` for the
model `Y ≈ A * B`. The floor `ε` keeps the denominator away from zero. `A` is
modified in place and returned.
"""
function loop_MU(Y, A, B, t_inner, ε)
    numer = Y * B'
    gram = B * B'
    for _ in 1:t_inner
        denom = max.(ε, A * gram)
        @. A = A * (numer / denom)
    end
    return A
end

"""
    loop_APG(Y, A, B, t_inner)

Run `t_inner` accelerated projected-gradient steps (FISTA-style momentum) on
`½‖Y - A*B‖²` subject to `A ≥ 0`. The step size is `1/L`, with
`L = opnorm(B*B', 2)`. `A` is overwritten in place and returned.
"""
function loop_APG(Y, A, B, t_inner)
    BB = B * B'
    L = opnorm(BB, 2)
    # B ≡ 0: the objective does not depend on A, and dividing by L would give NaNs.
    iszero(L) && return A
    α = 1.0
    X1 = BB / L
    Y1 = Y * B' / L
    # The extrapolation point X and the previous iterate A_old must be separate
    # buffers. If they alias A, then A .- A_old ≡ 0 and the momentum term vanishes.
    X = copy(A)
    A_old = copy(A)
    for t in 1:t_inner
        A_old .= A
        α_old = α
        # Projected gradient step taken from the extrapolated point X.
        A .= max.(0.0, X - X * X1 + Y1)
        α = (1 + sqrt(4 * α^2 + 1)) / 2
        X .= A .+ ((α_old - 1) / α) .* (A .- A_old)
    end
    return A
end

