import numpy as np
import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.cp_tensor import cp_to_tensor
from joblib import Parallel, delayed
import os
import time

# Select the tensorly backend. 'numpy' runs on the CPU; it does NOT enable a GPU.
# For GPU acceleration (requires CuPy to be installed) switch to:
# tl.set_backend('cupy')
tl.set_backend('numpy')


# Experiment parameters.
tensor_shape = (500, 500, 500)
# Defaults only: the sweep loop in this script rebinds ``true_rank`` and
# ``als_rank`` before every batch of trials, and ``run_trial`` reads them as globals.
true_rank = 20
als_rank = 56
n_trials = 20  # Number of trials per (true_rank, als_rank) pair
# Use up to 8 parallel jobs, fewer on smaller machines. os.cpu_count() returns
# None when the count cannot be determined, so fall back to a single job.
n_jobs = min(8, os.cpu_count() or 1)

def generate_custom_als_ranks(r):
    """Return the sorted, de-duplicated list of ALS ranks to test for true rank ``r``.

    The list contains ``r`` itself, the multiples ``2r, 4r, ...`` up to
    ``r**2 - 2r``, a handful of fixed points around ``r**2 / 2`` and ``r**2``
    (just below, at, and just above), and ``r**2 + r`` / ``r**2 + 2r``.

    Args:
        r: The generating (true) CP rank; must be >= 1.

    Returns:
        Sorted list of distinct positive ranks.

    Raises:
        ValueError: If ``r < 1`` (the step ``2 * r`` would be non-positive and
            the multiples loop would never terminate).
    """
    if r < 1:
        raise ValueError(f"r must be >= 1, got {r!r}")

    ranks = {r}

    # Multiples 2r, 4r, ... not exceeding r**2 - 2r.
    j = 2 * r
    while j <= r**2 - 2 * r:
        ranks.add(j)
        j += 2 * r

    # Fixed points around r**2 / 2 and just below / at r**2.
    ranks.update((r**2 // 2, r**2 - 2 * r, r**2 - r, r**2 - r // 2, r**2 - 1, r**2))
    # Points just above r**2.
    ranks.update((r**2 + 1, r**2 + r, r**2 + 2 * r))

    # For r <= 2 some offsets above are <= 0, which are not valid CP ranks.
    return sorted(k for k in ranks if k >= 1)



def generate_symmetric_cp_tensor(shape, rank):
    """Generate a random symmetric CP tensor (one factor matrix shared by every mode).

    Args:
        shape: Tensor shape; all modes must have the same size because a
            single factor matrix is used for every mode.
        rank: Number of rank-one components.

    Returns:
        ``(tensor, (weights, factors))``. ``weights`` are drawn uniformly from
        [0, 1); ``factors`` is a list containing the *same* standard-normal
        ``(dim, rank)`` array once per mode (aliased, not copied), so an
        in-place change to one entry affects all of them.

    Raises:
        ValueError: If ``shape`` is empty or its modes differ in size.
    """
    # Validate before drawing any random numbers so valid calls consume the
    # global RNG stream exactly as before.
    if len(set(shape)) != 1:
        raise ValueError(f"symmetric CP tensor needs a non-empty cubical shape, got {shape!r}")
    dim = shape[0]
    factor_matrix = np.random.randn(dim, rank)
    weights = np.random.rand(rank)
    factors = [factor_matrix] * len(shape)
    return cp_to_tensor((weights, factors)), (weights, factors)

def generate_cp_tensor(shape, rank):
    """Generate a random (non-symmetric) CP tensor with the given shape and rank.

    Weights are drawn uniformly from [0, 1) and each mode gets its own
    standard-normal factor matrix of shape ``(mode_dim, rank)``. All draws come
    from NumPy's global RNG, so results are reproducible under ``np.random.seed``.

    Args:
        shape: Size of each tensor mode.
        rank: Number of rank-one components.

    Returns:
        ``(tensor, (weights, factors))`` -- the full tensor built by
        ``cp_to_tensor`` together with the CP weights and factor matrices.
    """
    weights = np.random.rand(rank)
    factors = [np.random.randn(mode_dim, rank) for mode_dim in shape]
    return cp_to_tensor((weights, factors)), (weights, factors)


def run_trial(trial_id):
    """Run one CP-ALS trial and return its relative reconstruction error.

    Reads the module-level ``tensor_shape``, ``true_rank`` and ``als_rank`` at
    call time; the sweep loop in this script rebinds the last two before each
    batch of trials.

    Args:
        trial_id: Trial index. It is also used as the NumPy global seed, so the
            same id always yields the same generated tensor (and presumably the
            same random ALS initialisation, since ``parafac`` is given no
            ``random_state`` -- confirm against the tensorly version in use).

    Returns:
        ``(trial_id, error, duration)`` where ``error`` is
        ``||T - T_hat||_F / ||T||_F`` and ``duration`` is the process CPU time
        (``time.process_time``, summed over all threads) spent on generation,
        decomposition and error evaluation -- not wall-clock time.
    """
    np.random.seed(trial_id)
    start = time.process_time()

    # Swap in generate_symmetric_cp_tensor here for the symmetric experiment.
    tensor, _ = generate_cp_tensor(tensor_shape, true_rank)
    decomposition = parafac(tensor, rank=als_rank, n_iter_max=100, init='random', verbose=False)

    # Form the residual in place rather than materialising ``tensor - recon``
    # as a third full-size array: at 500**3 float64 each copy is ~1 GB and up
    # to ``n_jobs`` trials run concurrently. ``recon - tensor`` has exactly the
    # same norm as ``tensor - recon``.
    residual = cp_to_tensor(decomposition)
    residual -= tensor
    error = tl.norm(residual) / tl.norm(tensor)

    duration = time.process_time() - start
    return trial_id, error, duration



# Tensor dimension, used in the output filenames and the summary lines.
dim = tensor_shape[0]

# Sweep: for each generating rank, run ``n_trials`` CP-ALS trials at every ALS
# rank from ``generate_custom_als_ranks`` and log the results to a text file.
# The guard stops the sweep from re-running when this file is imported as a
# module (e.g. by worker processes started with the "spawn" method).
if __name__ == "__main__":
    # ``true_rank`` and ``als_rank`` are deliberately rebound at module level:
    # ``run_trial`` reads them as globals when the trials are dispatched.
    for true_rank in range(18, 5, -4):  # 18, 14, 10, 6
        output_filename = f"cp_decomposition_results_n{dim}_r{true_rank}.txt"
        als_ranks_list = generate_custom_als_ranks(true_rank)
        print("r value", true_rank)
        print("Custom ALS Ranks to be tested:", als_ranks_list)

        with open(output_filename, "w", encoding="utf-8") as f:
            for als_rank in als_ranks_list:
                # Run all trials for this (true_rank, als_rank) pair in parallel.
                results = Parallel(n_jobs=n_jobs)(
                    delayed(run_trial)(i) for i in range(n_trials)
                )

                summary = f"n={dim}, true_rank={true_rank}, als_rank={als_rank}\n"
                print(summary.strip())
                f.write(summary)
                for trial_id, error, duration in results:
                    line = f"Trial {trial_id}: Error = {error:.4f}, Time = {duration:.2f}s\n"
                    print(line.strip())
                    f.write(line)
                # Each ALS rank can take a long time; flush so completed results
                # are on disk even if the run is interrupted later.
                f.flush()
