import numpy as np
from scipy.io import loadmat
from pathlib import Path
from PIL import Image
from demo2_mainbody import LegendreDecomposition1
from revise_lgd import LegendreDecomposition



# Load the 'img' array from the Butterfly .mat file.
# NOTE: hard-coded absolute path from the original development machine.
data = loadmat('/tmp/pycharm_project_272/realdata/Butterfly.mat')
image_data = data['img']
if not isinstance(image_data, np.ndarray):
    image_data = np.array(image_data)

print("max=", image_data.max())
print("min=", image_data.min())


print(f"Image data type: {type(image_data)}")
print(f"Image data shape: {image_data.shape}")

# Crop rows/cols 250:350 and keep the first 10 entries of the third axis.
# Despite the variable name, 10 (not 3) slices are kept: 100 * 100 * 10 = 10**5
# elements, exactly enough for the 5-way (10, 10, 10, 10, 10) tensor below.
image_3channels = image_data[250:350, 250:350, :10]
if image_3channels.size != 10 ** 5:
    # Fail early with an actionable message instead of numpy's generic reshape error.
    raise ValueError(
        "Expected 10**5 elements after cropping image_data[250:350, 250:350, :10], "
        f"got shape {image_3channels.shape}; 'img' must be at least 350 x 350 x 10."
    )

# reshape (100, 100, 10) -> (10, 10, 10, 10, 10)
reshaped_image = image_3channels.reshape(10, 10, 10, 10, 10)
print("Processed image shape:", reshaped_image.shape)
print("max=", reshaped_image.max())
print("min=", reshaped_image.min())
def out_put_coordinates(tensor):
    """Split the indices of a 0/1 tensor into coordinate lists.

    Indices are listed in row-major (C) order, as produced by ``np.argwhere``.
    Entries equal to neither 0 nor 1 are ignored.

    Parameters
    ----------
    tensor : np.ndarray
        Array compared elementwise against 1 and 0 (in this script, the mask
        built by ``change_parameter``).

    Returns
    -------
    coordinates : list of tuple
        Indices of every entry equal to 1.
    coordinates1 : list of tuple
        A copy of ``coordinates`` with the *second* zero entry appended, when
        the tensor contains at least two zeros; otherwise an unmodified copy.
    coordinates2 : list of tuple
        Indices of every zero entry except the first one. Empty when the
        tensor has at most one zero.
    """
    ones = [tuple(index) for index in np.argwhere(tensor == 1)]
    zeros = [tuple(index) for index in np.argwhere(tensor == 0)]

    coordinates1 = ones.copy()
    if len(zeros) >= 2:
        coordinates1.append(zeros[1])

    # Slicing (instead of ``del zeros[0]``) stays safe when there are no zeros.
    return ones, coordinates1, zeros[1:]


def change_parameter(tensor, k):
    """Build a 0/1 mask selecting indices with between 1 and ``k`` non-zero coordinates.

    An index ``(i_1, ..., i_n)`` of ``tensor`` is marked 1 iff the number of
    non-zero ``i_j`` lies in ``[1, k]``. Every other index, including the
    all-zero origin, is marked 0. Only ``tensor.shape`` is used, not its values.

    Parameters
    ----------
    tensor : np.ndarray
        Tensor whose shape defines the mask shape.
    k : int
        Maximum number of non-zero coordinates an index may have.

    Returns
    -------
    np.ndarray
        Integer array of shape ``tensor.shape`` containing 0/1.

    Raises
    ------
    ValueError
        If ``k`` is not strictly less than ``tensor.ndim``.
    """
    dims = tensor.shape
    if k >= len(dims):
        raise ValueError("k must be strictly less than the tensor's dimension.")

    # For every index, count how many of its coordinates are non-zero.
    # ``sparse=True`` yields one broadcastable open grid per axis, so the full
    # (ndim, *dims) index array is never materialised and no Python-level
    # per-element loop is needed.
    nonzero_count = np.zeros(dims, dtype=int)
    for grid in np.indices(dims, sparse=True):
        nonzero_count += grid != 0

    return ((nonzero_count >= 1) & (nonzero_count <= k)).astype(int)

def calculate_s(P):
    """Return ``log(sum(P) / min(P)) / log(number of elements of P)``.

    The result is only finite when ``P`` has more than one element and a
    strictly positive minimum; otherwise numpy yields ``inf``/``nan``.

    Parameters
    ----------
    P : np.ndarray
        Input tensor.

    Returns
    -------
    float
        The exponent ``s``.
    """
    log_ratio = np.log(np.sum(P) / np.min(P))
    log_size = np.log(np.prod(P.shape))
    return log_ratio / log_size



def is_c_in_range(P, s, c, alpha):
    """Report whether ``c`` lies inside the interval derived from the shape of ``P``.

    With ``n = P.ndim``, shape ``(I_1, ..., I_n)`` and ``N = I_1 * ... * I_n``,
    each axis ``i`` contributes the half-width

        b_i = 2**n * ((s - 1) * n + 1) * log(N)
              / (((1 - 1/alpha) * I_i + 1) * prod_{j != i} (I_j + 1))

    and the interval ``[-b_i, b_i]``. The overall interval is
    ``[max_i(-b_i), min_i(b_i)]`` (it may be empty), and ``c`` is tested
    against it inclusively. A one-line verdict is printed.

    Parameters
    ----------
    P : np.ndarray
        Only ``P.shape`` is used.
    s : float
        Exponent (the main loop passes ``calculate_s(P)``).
    c : float
        Value to test.
    alpha : float
        Non-zero parameter entering each denominator.

    Returns
    -------
    tuple
        ``(min_i b_i, 1)`` if ``c`` is inside the interval, else ``(min_i b_i, 0)``.
    """
    shape = P.shape
    order = len(shape)

    # Numerator shared by every axis; the lower bound is its exact negation.
    numerator = 2 ** order * ((s - 1) * order + 1) * np.log(np.prod(shape))

    lower_bounds, upper_bounds = [], []
    for axis, extent in enumerate(shape):
        others = np.prod([(shape[j] + 1) for j in range(order) if j != axis])
        half_width = numerator / (((1 - 1 / alpha) * extent + 1) * others)
        lower_bounds.append(-half_width)
        upper_bounds.append(half_width)

    max_lower_bound = max(lower_bounds)
    min_upper_bound = min(upper_bounds)

    if not (max_lower_bound <= c <= min_upper_bound):
        print(f'c = {c} is outside the range: [{max_lower_bound}, {min_upper_bound}]')
        return min_upper_bound, 0

    print(f'c = {c} is within the range: [{max_lower_bound}, {min_upper_bound}]')
    return min_upper_bound, 1







def _write_lines(path, values):
    """Overwrite ``path`` with one ``str(value)`` per line (newline-terminated)."""
    with open(path, 'w', encoding='utf-8') as file:
        file.writelines(str(value) + '\n' for value in values)


for k in range(2, 5):
    # Results accumulated over the sub-tensor sizes d for this k.
    results_c = []
    results_upperbound = []
    rights_line = []
    results_d = []

    folder = Path(f'size_test/butterfly_k={k}_10add')
    folder.mkdir(parents=True, exist_ok=True)

    for d in range(2, 11):
        print('d,k=', d, k)
        # Leading d x d x d x d x d sub-tensor of the 5-way reshaped image.
        P = reshaped_image[:d, :d, :d, :d, :d]
        print(P.shape)

        s = calculate_s(P)
        print('s=', s)

        binary_tensor = change_parameter(P, k)
        coordinates, coordinates1, coordinates_complement = out_put_coordinates(binary_tensor)

        ld_ori = LegendreDecomposition(solver='ng', max_iter=50000, verbose=0, learning_rate=0.001)
        reconst_tensor_ori = ld_ori.fit_transform(P, coordinates)
        print('Reconstruction error(RSE): {}'.format(ld_ori.reconstruction_err_))

        # The first model's fitted theta is passed into the second model's fit.
        ld_imp = LegendreDecomposition1(solver='ng', max_iter=500, verbose=0, learning_rate=0.001)
        reconst_tensor_imp = ld_imp.fit_transform(
            P, coordinates, coordinates1, coordinates_complement, ld_ori.theta
        )
        print('Reconstruction error(RSE): {}'.format(ld_imp.reconstruction_err_))
        results_c.append(ld_imp.c)
        print('c=', ld_imp.c)

        # NOTE(review): d is passed as is_c_in_range's `alpha` argument -- confirm intended.
        upperbounds, rights = is_c_in_range(P, s, ld_imp.c, d)
        results_upperbound.append(upperbounds)
        rights_line.append(rights)
        results_d.append(d)

        # Rewrite every result file after each d so partial results survive
        # an interrupted run.
        _write_lines(folder / 'results_upperbound.txt', results_upperbound)
        _write_lines(folder / 'rights_line_append.txt', rights_line)
        _write_lines(folder / 'results_c.txt', results_c)
        _write_lines(folder / 'results_d.txt', results_d)

