# Imports needed for the rest of the code
from math import cos, exp, pi, sin, sqrt
from scipy.integrate import quad
import numpy as np
import math 
from tqdm import tqdm
from scipy.spatial.distance import directed_hausdorff

# Sample count for the point clouds and parameter sweeps below
# (linspace lengths in generate_hyperbolic_points, calculate_weight_E_H,
# calculate_weight_H_S; h_mapping uses hausdorff_resolution/100 per angle).
hausdorff_resolution = 10_000

'''Original functions to compute the coefficient c, which is used by the F mapping function
(please refer to the paper for a fuller explanation of what these functions mean).
'''

# function we want to integrate: expression inside the integral
def psi_integral(x):
    """Integrand sin(pi*x) * exp(-1 / sin(pi*x)**2) used by psi1 and psi2.

    At integer x, sin(pi*x) vanishes and the expression tends to 0, so 0 is
    returned directly. At x == 0, evaluating it would raise ZeroDivisionError
    (0.0 ** -2). Built on math.sin/math.exp, so x must be a scalar float, as
    passed by scipy's quad.
    """
    if x == 0 or x.is_integer():
        return 0
    sine = sin(pi*x)
    return sine*exp(-sine**(-2))

def _normalized_psi_primitive(upper):
  """Return sqrt(integral_0^upper psi_integral(t) dt / A), with A = 0.141328.

  A appears to equal integral_0^1 psi_integral numerically, which makes the
  result ~1 at upper = 1. np.sqrt gives nan (with a RuntimeWarning) if the
  quadrature comes out negative.
  """
  A = 0.141328
  area, _ = quad(psi_integral, 0, upper)
  return np.sqrt((1/A)*area)


def psi1(x):
  """Normalized square-root primitive of psi_integral, evaluated at 1 + x."""
  return _normalized_psi_primitive(1+x)


def psi2(x):
  """Normalized square-root primitive of psi_integral, evaluated at x."""
  return _normalized_psi_primitive(x)


def xi(x):
  """NumPy twin of psi_integral: sin(pi*x) * exp(-sin(pi*x)**-2), 0 at integers.

  Used by G1/G2 as the x-derivative of the integral inside psi1/psi2.
  """
  if x == 0 or x.is_integer():
    return 0
  sine = np.sin(np.pi*x)
  return sine*np.exp(-sine**(-2))

# G1 and G2 combine psi1/psi2 (computed with scipy's quad) and xi; calculate_c scans them over x in [-2, 2).
def _cosh_sinh_combination(x, profile, density):
  """Return cosh(x)*profile + sinh(x)*density / (2*A*sqrt(profile)), A = 0.141328.

  NOTE(review): h() uses sinh(x)*psi1(x)/c as a radius. If G1 is meant to be
  d/dx[sinh(x)*psi1(x)], the chain rule gives
  psi1'(x) = xi(1+x) / (2*A*psi1(x)). That means the denominator would use
  `profile`, not `sqrt(profile)`. The hard-coded c was computed with the
  current formula, so confirm against the paper before changing it.
  """
  A = 0.141328
  return np.cosh(x)*profile + np.sinh(x)*density/(2*A*np.sqrt(profile))


def G1(x):
  """Combine psi1(x) with its integrand xi(1 + x); see _cosh_sinh_combination."""
  return _cosh_sinh_combination(x, psi1(x), xi(1+x))


def G2(x):
  """Combine psi2(x) with its integrand xi(x); see _cosh_sinh_combination."""
  return _cosh_sinh_combination(x, psi2(x), xi(x))

def calculate_c(resolution=1e-3, start=-2, stop=2):
  """Return the largest finite-or-nan-skipped value of 2*max(G1(x), G2(x)).

  The scan covers x in np.arange(start, stop, resolution) (by default [-2, 2)).
  - A +inf candidate is replaced by 0.
  - nan candidates never win, because `nan > best` is False.
  - When G1(x) is nan, the G2 branch is taken, as in `G1 > G2`.

  Each G1/G2 call runs scipy quad, so each is evaluated once per x and reused.

  Args:
    resolution: step of the x grid.
    start, stop: grid bounds (half-open), keyword-compatible generalization.

  Returns:
    The maximum candidate, or -inf if no candidate exceeded -inf.
  """
  best = -np.inf
  for x in np.arange(start, stop, resolution):
    g1 = G1(x)
    g2 = G2(x)
    # Same selection rule as the original `if G1(x) > G2(x)` (nan -> G2 branch).
    candidate = 2*g1 if g1 > g2 else 2*g2

    if candidate == np.inf:
      candidate = 0

    if candidate > best:
      best = candidate
  return best

# Precomputed coefficient: the output of calculate_c with resolution 1e-8,
# obtained via the script hausdorff_coefficients.py. Call calculate_c above to
# recompute it. It depends on the exact G1/G2 formulas.
c = 10.255014502464228

'''Additional functions to compute the mapping function F'''

def g(x,y):
  """Return (arcsinh(y*exp(x)), log(sqrt(exp(-2x) + y**2))) for a point (x, y)."""
  scaled_y = y*np.exp(x)
  radius = np.sqrt(np.exp(-2*x) + y**2)
  return np.arcsinh(scaled_y), np.log(radius)
  
def h(x,y):
  """Map (x, y) to four coordinates: two circles of radii sinh(x)*psi_k(x)/c.

  Returns [r1*cos(c*y), r1*sin(c*y), r2*cos(c*y), r2*sin(c*y)], where
  r1 = sinh(x)*psi1(x)/c and r2 = sinh(x)*psi2(x)/c.

  psi1/psi2 each call scipy quad, so each radius is computed once and reused.
  The original evaluated every psi twice. Uses math.cos/math.sin, so x and y
  must be scalars.
  """
  sinh_x = np.sinh(x)
  radius1 = sinh_x*psi1(x)/c
  radius2 = sinh_x*psi2(x)/c
  angle = y*c
  cos_angle = cos(angle)
  sin_angle = sin(angle)
  return [radius1*cos_angle,
          radius1*sin_angle,
          radius2*cos_angle,
          radius2*sin_angle]


# Compose h(g(x,y))
def h_g(x,y):
  """Return h(g(x, y)): g's two outputs become h's (x, y) arguments."""
  return h(*g(x, y))

def epsilon(x):
  """Return (G1(x)**2 + G2(x)**2) / c**2.

  nan wherever G1 or G2 is nan (e.g. x = 0, where psi2(0) = 0).
  """
  g1_value = G1(x)
  g2_value = G2(x)
  return (g1_value**2 + g2_value**2)/c**2

def epsilon_integral(t):
  """Integrand sqrt(1 - epsilon(t)**2) for first_component_equation.

  Returns nan (np.sqrt of a negative) if |epsilon(t)| > 1.
  NOTE(review): epsilon is already quadratic in G1/G2, and it is squared
  again here. Confirm against the paper whether sqrt(1 - epsilon(t)) was
  intended.
  """
  eps = epsilon(t)
  return np.sqrt(1 - eps**2)


def first_component_equation(x, y, resolution=5e-3):
    """Signed trapezoidal integral of epsilon_integral from 0 to arcsinh(y*exp(x)).

    The upper limit is the first coordinate produced by g(x, y). The integral
    is oriented: a negative upper limit gives a negative result, rather than 0.
    The grid spans exactly [0, end], with steps no wider than `resolution`.
    Each node is evaluated once; neighbouring trapezoids share nodes.

    Trapezoids touching a nan (or inf) integrand value are neutralised with
    np.nan_to_num. epsilon_integral is nan wherever G1/G2 hit 0/0, e.g. at
    t = 0 where psi2(0) = 0.

    Args:
        x, y: scalar coordinates. G1/G2 pass x into quad's integration limit,
            so arrays are not supported.
        resolution: maximum width of one trapezoid; must be positive.

    Returns:
        0 when y == 0, otherwise a numpy float.

    Raises:
        ValueError: if resolution is not positive.
    """
    if y == 0:
        return 0
    if resolution <= 0:
        raise ValueError("resolution must be positive")

    end = float(np.arcsinh(y*np.exp(x)))
    n_steps = max(1, int(np.ceil(abs(end)/resolution)))
    # linspace handles a negative `end`, so the integral keeps its sign.
    nodes = np.linspace(0.0, end, n_steps + 1)
    values = np.array([epsilon_integral(t) for t in nodes], dtype=float)
    step = nodes[1] - nodes[0]

    areas = np.nan_to_num(step*(values[:-1] + values[1:])/2)
    return areas.sum()




'''f mapping function and point set defined next,
f is the mapping function from H^2 to R^6
'''

def coordinate_change_to_hyperbolic(x,y):
  """Map tangent-plane coordinates at the origin to the (x, y) model of H^2 used by f.

  The vector v = (x, y) is read as hyperbolic polar coordinates: distance |v|,
  direction v/|v|. It is pushed through these steps:

    hyperboloid  (sinh|v| * v/|v|, cosh|v|)
    -> Poincare disk (divide by 1 + z)
    -> upper half-plane via the Cayley map, giving (u, w)
    -> model coordinates (-log w, u)

  Fixes over the original:
  - Norms are taken per point (axis=0) instead of one global norm over all
    samples.
  - The Cayley map uses the disk coordinates for both outputs.
  - The returned y is u, not a copy of the new x.
  - The origin maps to (0, 0) instead of nan.

  Args:
    x, y: scalars or arrays of equal shape.

  Returns:
    Tuple (x_model, y_model). A point at tangent distance r along +x lands at
    (-r, 0).
  """
  v = np.array([x, y], dtype=float)
  r = np.linalg.norm(v, axis=0)

  # sinh(r)/r -> 1 as r -> 0; guard the division so the origin is finite.
  safe_r = np.where(r > 0, r, 1.0)
  scale = np.where(r > 0, np.sinh(r)/safe_r, 1.0)

  # Hyperboloid model
  hx, hy = scale*v
  hz = np.cosh(r)

  # Poincare disk model
  px = hx/(1 + hz)
  py = hy/(1 + hz)

  # Upper half-plane: i(1+z)/(1-z), both parts computed from (px, py)
  denom = (px - 1)**2 + py**2
  u = (-2*py)/denom
  w = (1 - px**2 - py**2)/denom

  # Model of H^2: x = -log(w), y = u
  return -np.log(w), u

def generate_hyperbolic_points():
    """Sample hausdorff_resolution points on a spiral and convert them to the H^2 model.

    Angle (0..2*pi) and radius (1e-8..0.97) advance together, one pair per
    sample. The radius starts just above 0, presumably to avoid the zero-norm
    division in coordinate_change_to_hyperbolic.
    """
    count = hausdorff_resolution
    angles = np.linspace(0, 2*np.pi, count)
    radii = np.linspace(1e-8, 0.97, count)
    return coordinate_change_to_hyperbolic(radii*np.cos(angles), radii*np.sin(angles))



def f_mapping(inp):
  """Map one point (inp[0], inp[1]) of the H^2 model to a length-6 array.

  Layout:
    [first_component_equation(x, y), log(sqrt(exp(-2x) + y**2)), *h_g(x, y)]
  Entries of inp beyond the first two are ignored.
  """
  x, y = inp[0], inp[1]
  head = first_component_equation(x, y)
  second = np.log(np.sqrt(np.exp(-2*x) + y**2))
  return np.array([head, second, *h_g(x, y)])

hyperbolic_points_x, hyperbolic_points_y = generate_hyperbolic_points()

# Image of the sampled H^2 points under f: one row of 6 coordinates per point.
p1_mapped = np.array([f_mapping(point) for point in zip(hyperbolic_points_x, hyperbolic_points_y)])

#---------------------------------------------
#---------------------------------------------
'''g mapping functions and point set defined next,
g is the mapping function from E^2 to R^6
'''

def g_mapping_base(x, y, permutation=0):
    """Place (x, y) on one of the 15 coordinate planes of R^6.

    `permutation` indexes the unordered axis pairs (i, j), i < j, in
    lexicographic order: 0 -> (0, 1), ..., 14 -> (4, 5). The result is
    x*e_i + y*e_j with integer unit vectors e_k.
    """
    axis_pairs = ((0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
                  (1, 2), (1, 3), (1, 4), (1, 5),
                  (2, 3), (2, 4), (2, 5),
                  (3, 4), (3, 5),
                  (4, 5))
    first_axis, second_axis = axis_pairs[permutation]
    unit = np.eye(6, dtype=int)
    return x*unit[first_axis] + y*unit[second_axis]


def g_mapping(x_s, y_s, perm, offset_value=0.1, offset_location=1):
    """Embed planar samples (x_s[i], y_s[i]) into R^6 and shift one axis.

    Each row is g_mapping_base(x_s[i], y_s[i], permutation=perm), plus
    offset_value added to column offset_location. The rows are computed in one
    broadcast call (column vectors times the unit vectors) instead of a Python
    loop. The original loop dominated calculate_weight_E_H, which calls this
    9000 times.

    Args:
        x_s, y_s: 1-D sequences of equal length.
        perm: index of the coordinate plane (see g_mapping_base).
        offset_value: amount added to column offset_location.
        offset_location: column (0..5) receiving the offset.

    Returns:
        Float array of shape (len(x_s), 6).
    """
    x_col = np.asarray(x_s, dtype=float).reshape(-1, 1)
    y_col = np.asarray(y_s, dtype=float).reshape(-1, 1)
    p2_mapped = np.asarray(g_mapping_base(x_col, y_col, permutation=perm), dtype=float)
    p2_mapped = p2_mapped.reshape(-1, 6)

    p2_offset = np.zeros_like(p2_mapped)
    p2_offset[:, offset_location] = offset_value

    return p2_mapped + p2_offset



'''Find the Hausdorff distance between the embedded Euclidean and hyperbolic samples.
'''

def calculate_weight_E_H():
  """Return the smallest symmetric Hausdorff distance between H^2 and E^2 samples.

  The E^2 sample is a spiral in the unit disk (r and theta advance together,
  hausdorff_resolution points). It is embedded by g_mapping for every
  combination of:
  - 15 coordinate planes,
  - 6 offset axes,
  - 100 offset values in [-0.5, 0.5].

  Each embedding is compared with the global p1_mapped. Each candidate point
  cloud is built, used once and discarded. The original cached all 9000
  clouds (~4 GB at the default resolution) before using them.

  Prints and returns the minimum.
  """
  all_offset_values = np.linspace(-0.5, 0.5, 100)
  num_perms = 15
  num_locations = 6

  num_samples = hausdorff_resolution
  theta = np.linspace(0, 2*np.pi, num_samples)
  r = np.linspace(0, 1, num_samples)
  x_s, y_s = r*np.cos(theta), r*np.sin(theta)

  dGH_values = np.empty((num_perms, len(all_offset_values), num_locations))
  dGH_values.fill(np.inf)

  for perm in tqdm(range(num_perms)):
    for offset_location in range(num_locations):
      for offset_index, offset_value in enumerate(all_offset_values):
        p2_mapped = g_mapping(x_s, y_s, perm, offset_value, offset_location)
        # Symmetric Hausdorff distance = max of the two directed distances.
        dGH_values[perm, offset_index, offset_location] = max(
            directed_hausdorff(p1_mapped, p2_mapped)[0],
            directed_hausdorff(p2_mapped, p1_mapped)[0],
        )

  min_dGH = dGH_values.min()
  print('Final weight E and H: ',min_dGH)
  return min_dGH



'''h mapping functions and point set defined next,
h is the mapping function from S^2 to R^6
'''

def h_mapping_base(x, y, z, permutation=0):
    """Place (x, y, z) on one of the 20 coordinate 3-spaces of R^6.

    `permutation` indexes the axis triples (i, j, k), i < j < k, in the order
    listed below (colexicographic: sorted by k, then j, then i). The result is
    x*e_i + y*e_j + z*e_k with float unit vectors. x, y, z may be column
    arrays of shape (n, 1), which yields an (n, 6) result.

    The original also built a `permutation_matrix` that was never used; that
    dead code is removed.
    """
    permutation_indices = ((0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3),
                           (0, 1, 4), (0, 2, 4), (1, 2, 4), (0, 3, 4),
                           (1, 3, 4), (2, 3, 4), (0, 1, 5), (0, 2, 5),
                           (1, 2, 5), (0, 3, 5), (1, 3, 5), (2, 3, 5),
                           (0, 4, 5), (1, 4, 5), (2, 4, 5), (3, 4, 5))
    i, j, k = permutation_indices[permutation]
    basis_vectors = np.identity(6)
    return x*basis_vectors[:, i] + y*basis_vectors[:, j] + z*basis_vectors[:, k]


from itertools import product

def h_mapping(perm, co_half=0.1):
    """Sample a unit-sphere cap, shift it, and embed it in R^6 with all symmetries.

    Sampling grid (n = int(hausdorff_resolution/100) per angle, alpha outer
    and beta inner, as in itertools.product):
    - azimuth alpha in [0, 2*pi],
    - polar angle beta in [0, 1] rad.

    Each sample is (a, b, c) = (sin b cos a, sin b sin a, cos b - co_half).
    It contributes its six coordinate orderings via h_mapping_base, followed
    by their negatives.

    Computed with array broadcasting instead of 6*n*n scalar calls. The third
    coordinate is named `cz` so that it no longer shadows the global c.

    Returns:
        Array of shape (12*n*n, 6). Rows are grouped per sample: 6 orderings,
        then their 6 negations.
    """
    n_angles = int(hausdorff_resolution/100)
    alpha_grid, beta_grid = np.meshgrid(np.linspace(0, 2*pi, n_angles),
                                        np.linspace(0, 1, n_angles),
                                        indexing='ij')
    alphas = alpha_grid.ravel()[:, None]
    betas = beta_grid.ravel()[:, None]

    a = np.sin(betas)*np.cos(alphas)
    b = np.sin(betas)*np.sin(alphas)
    cz = np.cos(betas) - co_half

    orderings = ((a, b, cz), (a, cz, b), (b, a, cz),
                 (b, cz, a), (cz, b, a), (cz, a, b))
    # blocks[k, m] = ordering m of sample k -> shape (n*n, 6, 6)
    blocks = np.stack([h_mapping_base(p, q, r, permutation=perm) for p, q, r in orderings],
                      axis=1)
    return np.concatenate([blocks, -blocks], axis=1).reshape(-1, 6)



import multiprocessing as mp
import numpy as np
from tqdm import tqdm
from scipy.spatial.distance import directed_hausdorff

def calculate_dGH(perm, co_half):
    """Symmetric Hausdorff distance between p1_mapped and h_mapping(perm, co_half)."""
    sphere_points = h_mapping(perm, co_half)
    forward = directed_hausdorff(p1_mapped, sphere_points)[0]
    backward = directed_hausdorff(sphere_points, p1_mapped)[0]
    return max(forward, backward)

def calculate_weight_H_S():
    """Minimize calculate_dGH over 20 embeddings and hausdorff_resolution shifts.

    The shifts co_half are taken from np.linspace(0, cos(1/2), hausdorff_resolution).
    Work is fanned out to a process pool with one worker per CPU. Workers read
    the module-level p1_mapped. Prints and returns the minimum.
    """
    best = np.inf
    co_half_values = np.linspace(0, np.cos(1/2), hausdorff_resolution)

    with mp.Pool(processes=mp.cpu_count()) as pool:
        for perm in tqdm(range(20)):
            per_shift = pool.starmap(calculate_dGH, [(perm, co_half) for co_half in co_half_values])
            best = min(best, min(per_shift))

    print('Final weight H and S:', best)
    return best




if __name__ == "__main__":
  # The guard keeps multiprocessing workers started with the "spawn" or
  # "forkserver" start method from re-running the Pool when they import this
  # module. Without it, those platforms fail while bootstrapping the children.
  # The weight is computed before opening the file, so a crash does not leave
  # an empty output_hs.txt behind.
  weight_hs = calculate_weight_H_S()
  with open('output_hs.txt', 'w', encoding='utf-8') as f:
    f.write('Final weight H and S: ' + str(weight_hs))
  