import tensor_operations as tensor_ops
import numpy as np
import scipy 
import datetime 
import time

def ALS(A, B, C, k, par_updates=True, eps=1e-5, max_iter=1000, verbose=False, semi_sequential=False):
    """
    Alternating Least Squares (ALS) fit of a k-component model to a synthetic tensor.

    The target tensor is assembled from the ground-truth factors A, B and C
    (through their Khatri-Rao products), and model factors X, Y, Z are fitted
    to it by solving one linear least-squares problem per factor, per
    iteration. The fitted factors themselves are not returned, only the
    iteration index and the last relative error.

    Args:
        A (np.array): Ground-truth factor matrix of shape (n, r).
        B (np.array): Second ground-truth factor matrix. Presumably the same
            shape as A (the model is initialised with a single n for all
            modes) -- confirm against tensor_ops.
        C (np.array): Third ground-truth factor matrix; same assumption as B.
        k (int): Number of components in the model decomposition.
        par_updates (bool): If True, the X, Y and Z updates of one iteration
            are all computed from the previous iteration's factors and are
            installed together at the end of the iteration. If False, each
            freshly computed (normalised) factor is used immediately by the
            updates that follow it.
        eps (float): Tolerance on the relative error; the loop stops as soon
            as any factor update reaches a relative error below eps.
        max_iter (int): Maximum number of iterations.
        verbose (bool): If True, print progress and per-update errors.
        semi_sequential (bool): If True, par_updates is forced to True, but
            the new X and Y are installed before the Z update is computed.

    Returns:
        tuple: (iteration_count, reconstruction_error).
            iteration_count is the zero-based index of the last iteration
            that was started (0 if the loop never ran).
            reconstruction_error is the most recently recorded relative
            error, i.e. objective value divided by the Frobenius norm of the
            flattened target, or -1 if no iteration ran.
    """

    if semi_sequential:
        # Semi-sequential is a variant of the parallel scheme: only the Z
        # update sees the new X and Y (handled explicitly in the loop).
        par_updates = True

    n, r = A.shape

    if verbose:
        print(f"n = {n}, r = {r}")

    # Construct the flattenings of the ground truth (asymmetric) tensor along
    # each of its three modes from the factor matrices.
    AodotB = tensor_ops.Khatri_Rao_np(A, B)
    BodotC = tensor_ops.Khatri_Rao_np(B, C)
    AodotC = tensor_ops.Khatri_Rao_np(A, C)

    T_flatA = A @ (BodotC.T)
    T_flatB = B @ (AodotC.T)
    T_flatC = C @ (AodotB.T)

    # Every error below is reported relative to this norm.
    Tnorm = np.linalg.norm(T_flatA, ord='fro')

    if verbose:
        print(f"frobenius norm of ground truth tensor: {Tnorm}\n")

    # Initialize the model factor matrices
    X, Y, Z = tensor_ops.Gaussian_np(k, n, sigma=1.0)
    reconstruction_error = -1
    iteration_count = 0
    converged = False

    for i in range(max_iter):
        iteration_count = i
        if verbose:
            print(f"iteration = {i}")

        # X UPDATE: fit X against the current Y and Z.
        Xhat = _als_lstsq_factor(tensor_ops.Khatri_Rao_np(Y, Z), T_flatA)
        Xhat_error = tensor_ops.objective_function_np(Xhat, Y, Z, T_flatA)
        Xhat_relative_error = Xhat_error / Tnorm

        if verbose:
            print(f"iteration i = {i}, Xhat_error: {Xhat_error}, relative Xhat error: {Xhat_relative_error}")

        if Xhat_relative_error < eps:
            reconstruction_error = Xhat_relative_error
            converged = True
            break

        # Normalise to avoid conditioning issues in later iterations; the
        # scale is re-absorbed by whichever factor is solved for next.
        normalizedXhat = Xhat / np.linalg.norm(Xhat, ord='fro')
        if not par_updates:
            X = normalizedXhat

        # Y UPDATE: X is the new factor only in the fully sequential scheme.
        Yhat = _als_lstsq_factor(tensor_ops.Khatri_Rao_np(X, Z), T_flatB)
        Yhat_error = tensor_ops.objective_function_np(Yhat, X, Z, T_flatB)
        Yhat_relative_error = Yhat_error / Tnorm

        if verbose:
            print(f"iteration i = {i}, Yhat_error: {Yhat_error}, relative Yhat error: {Yhat_relative_error}")

        if Yhat_relative_error < eps:
            reconstruction_error = Yhat_relative_error
            converged = True
            break

        normalizedYhat = Yhat / np.linalg.norm(Yhat, ord='fro')
        if not par_updates:
            Y = normalizedYhat

        # Z UPDATE: in the semi-sequential scheme Z is fitted against the
        # freshly updated X and Y rather than the previous iteration's.
        if semi_sequential:
            X = normalizedXhat
            Y = normalizedYhat

        Zhat = _als_lstsq_factor(tensor_ops.Khatri_Rao_np(X, Y), T_flatC)
        Zhat_error = tensor_ops.objective_function_np(Zhat, X, Y, T_flatC)
        Zhat_relative_error = Zhat_error / Tnorm

        if verbose:
            print(f"iteration i = {i}, Zhat_error: {Zhat_error}, relative Zhat error: {Zhat_relative_error}")

        if Zhat_relative_error < eps:
            reconstruction_error = Zhat_relative_error
            converged = True
            break

        # Install all three normalised factors for the next iteration. In the
        # sequential schemes X and Y already hold these values.
        X = normalizedXhat
        Y = normalizedYhat
        Z = Zhat / np.linalg.norm(Zhat, ord='fro')

        reconstruction_error = Zhat_relative_error

    if verbose:
        if converged:
            print(f"Converged in {iteration_count} iterations")
        else:
            print(f"Reached max_iter = {max_iter} without meeting eps = {eps}")
        print(f"Final least squares error {reconstruction_error}")

    return (iteration_count, reconstruction_error)


def _als_lstsq_factor(design, unfolding):
    """Return the factor F minimising ||design @ F.T - unfolding.T|| (one ALS factor update)."""
    # lstsq also returns residuals, rank and singular values; only the
    # solution is needed here.
    return scipy.linalg.lstsq(design, unfolding.T)[0].T


        # OLD ITERATIONS WITH PSEUDOINVERSE
        # # Update X
        # X_new = T_flat@np.linalg.pinv(YodotZ, rcond=1e-5).T

        # #If we are running the sequential version then we need to update the X matrix
        # if not par: #If sequential 
        #     X=X_new
        # #Check if the objective function is small 
        # if tensor_ops.objective_function_np(X_new,Y,Z,T_flat)<eps:
        #     X=X_new
        #     break

        # # Update Y
        # Y_new = T_flat@np.linalg.pinv(XodotZ,rcond=1e-5).T

        # if not par:
        #     Y=Y_new

        # #Check if the objective function is small
        # if tensor_ops.objective_function_np(X,Y_new,Z,T_flat)<eps:
        #     Y=Y_new
        #     break

        # # Update Z
        # Z_new = T_flat@np.linalg.pinv(XodotY,rcond=1e-5).T
        # #Check if the objective function is small
        # if tensor_ops.objective_function_np(X,Y,Z_new,T_flat)<eps:
        #     Z=Z_new
        #     break
        # # Update the factor matrices
        # X = X_new
        # Y = Y_new
        # Z = Z_new
        # # Check for convergence



   


if __name__ == "__main__":
    # Script entry point. With experiment = True, runs a sweep of ALS trials
    # over several model sizes k and logs one row per trial to
    # experiments/<timestamp>.txt; otherwise performs a single ALS run and
    # prints the result.
    import os  # only needed here, to make sure the output directory exists

    def random_factor(rows, cols, decay):
        """Return a (rows, cols) Gaussian matrix whose column i is scaled by 100 / (i + 1) ** decay."""
        factor = np.zeros((rows, cols))
        for col in range(cols):
            factor[:, col] = np.random.randn(rows) * 100 * (1/((col + 1) ** decay))
        return factor

    experiment = True
    verbose = False

    par_updates = True
    semi_sequential = False
    n = 500
    r = 7  # used by the single run only; the experiment sweep sets its own ranks
    max_iter = 20
    error = 0.01 # relative error cutoff

    # Asymmetric A, B, C initialized randomly with power law
    power = 0

    if experiment:
        start_time = datetime.datetime.now()
        filename_prefix = f"experiment-{start_time.year}-{start_time.month:02d}-{start_time.day:02d}-{start_time.hour:02d}{start_time.minute:02d}"

        trials_per_multiple = 20
        # rs = [20, 17, 14, 11, 8]
        rs = [8]

        os.makedirs("experiments", exist_ok=True)

        # The context manager guarantees the log is flushed and closed even
        # if a trial raises part-way through the sweep.
        with open(f"experiments/{filename_prefix}.txt", "w") as f:
            f.write(f"{filename_prefix}\n")
            print(filename_prefix)

            f.write(f"par = {par_updates}\n")
            f.write(f"n = {n}\n")
            f.write(f"max_iter = {max_iter}\n")
            f.write(f"relative error cutoff: {error}\n")

            print(f"par = {par_updates}")
            print(f"n = {n}")
            print(f"max_iter = {max_iter}")
            print(f"relative error cutoff: {error}")

            f.write(f"Asymmetric A, B, C initialized randomly with power law, power = {power}\n")
            print(f"Asymmetric A, B, C initialized randomly with power law, power = {power}")

            # RUN EXPERIMENTS IN SEQUENCE
            f.write(f"trials per value of k = {trials_per_multiple}\n\n")

            # One row per trial, columns as named in the header line below.
            f.write("DATA\n")
            f.write("n, r, k, num iterations, relative error, wall clock time (s)\n")

            print("DATA")
            print("n, r, k, num iterations, relative error, wall clock time (s)")

            for rank in rs:
                # Model sizes clustered around rank ** 2.
                ks = [
                    (rank*rank) - 2,
                    (rank*rank) - 1,
                    (rank*rank),
                    (rank*rank) + 1,
                    (rank*rank) + 2,
                    (rank*rank) + rank,
                    (rank*rank) + (2*rank),
                ]

                for k in ks:
                    print(f"PARAMETERS: k = {k}")

                    for j in range(trials_per_multiple):
                        # Fresh ground-truth factors for every trial
                        A = random_factor(n, rank, power)
                        B = random_factor(n, rank, power)
                        C = random_factor(n, rank, power)

                        print(f"trial {j} out of {trials_per_multiple}")

                        start = time.time()
                        (iter_count, obj) = ALS(A, B, C, k, par_updates=par_updates, max_iter=max_iter, verbose=verbose, semi_sequential=semi_sequential, eps=error)
                        duration = time.time() - start
                        f.write(f"{n}, {rank}, {k}, {iter_count}, {obj}, {duration} \n")
                        print(f"{n}, {rank}, {k}, {iter_count}, {obj}, {duration}")
    else:
        # single run
        A = random_factor(n, r, power)
        B = random_factor(n, r, power)
        C = random_factor(n, r, power)

        par_updates = False
        semi_sequential = False

        k = 4*r
        print(f"par_updates={par_updates}")
        print(f"semi_sequential={semi_sequential}")
        print(f"k = {k}")
        print(f"error = {error}")

        print(f"ALS = {ALS(A, B, C, k, par_updates=par_updates, verbose=verbose, max_iter=max_iter, semi_sequential=semi_sequential, eps=error)}")
    

