import os

# Limit OpenMP/BLAS worker threads. This must happen BEFORE numpy (or any other
# BLAS/OpenMP-backed library) is imported: such thread pools typically read the
# variable when the native library is loaded, so setting it after
# `import numpy` has no effect on numpy's own BLAS.
os.environ["OMP_NUM_THREADS"] = "1"

from pathlib import Path

import numpy as np

from demo2_mainbody import LegendreDecomposition1
from revise_lgd import LegendreDecomposition

def out_put_coordinates(tensor):
    """Split the positions of a 0/1 tensor into the coordinate lists used downstream.

    Positions come from ``np.argwhere``, so they are in row-major (C) order
    and each one is a tuple of numpy integers.

    Returns:
        coordinates:  positions where ``tensor == 1``.
        coordinates1: a copy of ``coordinates``, extended with the SECOND
                      zero position (row-major) when at least two zeros exist.
        coordinates2: all zero positions except the first one (row-major).
                      This is empty when the tensor has fewer than two zeros.
    """
    coordinates = [tuple(index) for index in np.argwhere(tensor == 1)]
    zero_coordinates = [tuple(index) for index in np.argwhere(tensor == 0)]

    # Copy so that appending below does not mutate `coordinates`.
    coordinates1 = list(coordinates)
    if len(zero_coordinates) >= 2:
        coordinates1.append(zero_coordinates[1])

    # Drop the first zero position. Slicing (instead of `del lst[0]`) keeps
    # this safe when the tensor contains no zeros at all.
    coordinates2 = zero_coordinates[1:]
    return coordinates, coordinates1, coordinates2


def is_c_in_range(P, s, c, alpha):
    """Test whether ``c`` lies in the intersection of per-axis symmetric intervals.

    With ``d = P.ndim``, ``N = P.size`` and axis lengths ``I_0 .. I_{d-1}``,
    each axis ``i`` contributes the interval ``[-B_i, B_i]`` where

        B_i = 2**d * ((s - 1) * d + 1) * log(N)
              / (((1 - 1/alpha) * I_i + 1) * prod_{j != i} (I_j + 1))

    The intersection is ``[max_i(-B_i), min_i(B_i)]``. A message describing
    the outcome is printed.

    Args:
        P: array whose shape defines the bounds (its values are not read).
        s: scalar exponent used in the numerator.
        c: value to test.
        alpha: scalar parameter; must be nonzero (it is used as ``1 / alpha``).

    Returns:
        ``(upper, flag)``: ``upper`` is ``min_i(B_i)``; ``flag`` is ``1`` if ``c``
        is inside the closed interval, else ``0``.
    """
    shape = P.shape
    ndim = len(shape)
    log_size = np.log(np.prod(shape))

    lower_candidates = []
    upper_candidates = []
    for axis, extent in enumerate(shape):
        others = np.prod([(shape[j] + 1) for j in range(ndim) if j != axis])
        denominator = ((1 - 1 / alpha) * extent + 1) * others
        lower_candidates.append(
            (-2 ** ndim * ((s - 1) * ndim + 1) * log_size) / denominator)
        upper_candidates.append(
            (2 ** ndim * ((s - 1) * ndim + 1) * log_size) / denominator)

    lo = max(lower_candidates)
    hi = min(upper_candidates)

    if not (lo <= c <= hi):
        print(f'c = {c} is outside the range: [{lo}, {hi}]')
        return hi, 0

    print(f'c = {c} is within the range: [{lo}, {hi}]')
    return hi, 1




def change_parameter(tensor, k):
    """Build a 0/1 mask over ``tensor.shape`` selecting low-order multi-indices.

    A position is set to 1 when its multi-index has between 1 and ``k``
    (inclusive) nonzero coordinates, and 0 otherwise. The origin
    ``(0, ..., 0)`` therefore always gets 0. Only ``tensor.shape`` is used;
    the tensor's values are ignored.

    Args:
        tensor: array whose shape defines the mask shape.
        k: maximum number of nonzero coordinates allowed in a selected index.

    Returns:
        Integer (``dtype=int``) array with the same shape as ``tensor``.

    Raises:
        ValueError: if ``k`` is not strictly less than ``tensor.ndim``.
    """
    dims = tensor.shape
    # Validate before allocating anything.
    if k >= len(dims):
        raise ValueError("k must be strictly less than the tensor's dimension.")

    # np.indices(dims)[a] holds the axis-a coordinate of every position, so
    # counting nonzeros along axis 0 gives, per position, how many of its
    # coordinates are nonzero. This replaces the per-element Python loop.
    nonzero_count = np.count_nonzero(np.indices(dims), axis=0)
    binary_tensor = ((nonzero_count >= 1) & (nonzero_count <= k)).astype(int)
    return binary_tensor


def calculate_s(P):
    """Return ``s = log(sum(P) / min(P)) / log(number of elements of P)``.

    In other words, ``s`` is the base-``P.size`` logarithm of the ratio of the
    tensor's total mass to its smallest entry. No input validation is done:
    a one-element tensor makes the denominator ``log(1) == 0``, and a
    non-positive minimum makes the ratio or its log ill-defined.
    """
    log_mass_ratio = np.log(np.sum(P) / np.min(P))
    log_num_entries = np.log(np.prod(P.shape))
    return log_mass_ratio / log_num_entries

def _write_lines(path, values):
    """Write ``str(value)`` for each item of *values*, one per line, to *path* (overwriting)."""
    with open(path, 'w') as file:
        file.writelines(str(value) + '\n' for value in values)


# Fixed seed so the sampled tensors are reproducible across runs.
np.random.seed(44)

for k in range(1, 6):
    results_c = []
    results_upperbound = []
    rights_line = []
    results_d = []
    for d in range(k + 1, 11):
        print('k,d=', k, d)

        # Order-d tensor with every mode of length 3.
        # Alternative distributions used in other runs (if switching, also
        # change the "uniform" part of the output folder name below):
        # array = np.random.normal(loc=10, scale=1.5, size=size)
        # array = np.random.lognormal(mean=0, sigma=1.5, size=size)
        size = 3 ** d
        array = np.random.uniform(5, 8, size=size)
        P = array.reshape((3,) * d)

        s = calculate_s(P)
        print('s=', s)

        binary_tensor = change_parameter(P, k)
        coordinates, coordinates1, coordinates_complement = out_put_coordinates(binary_tensor)

        ld_ori = LegendreDecomposition(solver='ng', max_iter=50000, verbose=0, learning_rate=0.001)
        reconst_tensor_ori = ld_ori.fit_transform(P, coordinates)
        print('Reconstruction error(RSE): {}'.format(ld_ori.reconstruction_err_))

        # ld_ori.theta is handed to the second model, presumably as its
        # starting point (NOTE(review): confirm in demo2_mainbody).
        ld_imp = LegendreDecomposition1(solver='ng', max_iter=500, verbose=0, learning_rate=0.001)
        reconst_tensor_imp = ld_imp.fit_transform(
            P, coordinates, coordinates1, coordinates_complement, ld_ori.theta)
        print('Reconstruction error(RSE): {}'.format(ld_imp.reconstruction_err_))

        results_c.append(ld_imp.c)
        print('c=', ld_imp.c)

        upperbounds, rights = is_c_in_range(P, s, ld_imp.c, 3)
        results_upperbound.append(upperbounds)
        rights_line.append(rights)
        results_d.append(d)

    # Persist results for THIS k. This previously sat outside the k-loop,
    # so only the final k was ever written. `d` here is the last order
    # evaluated (range(k + 1, 11) is non-empty for every k in 1..5, so it is 10).
    folder = Path(f'size_test/bound_uniform_3dim_k={k}')
    folder.mkdir(parents=True, exist_ok=True)
    _write_lines(folder / f'results_upperbound_k={k}_d={d}.txt', results_upperbound)
    _write_lines(folder / f'rights_line_append_k={k}_d={d}.txt', rights_line)
    _write_lines(folder / f'results_c_k={k}_d={d}.txt', results_c)
    # Note: despite the "results_k" file name, this file holds the d values.
    _write_lines(folder / f'results_k_k={k}_d={d}.txt', results_d)
