import os

from embeddings.nflows_sphere_inv import *
from common.helper_funcs import *
from common.variables import *
from common.randomization_tools import *
from common.plot_3d import *

# Explicit third-party imports are kept after the star imports so they take
# precedence over any same-named symbol re-exported by the project modules.
import numpy as np
import matplotlib.pyplot as plt

# Short local aliases for the experiment settings star-imported from
# common.variables.
input_dim_A, input_dim_B = INPUT_DIM_A, INPUT_DIM_B
output_dim = OUTPUT_DIM
orthogonal_vector = ORTHOG_VECTOR  # not referenced anywhere below in this script
num_samples = NUM_SAMPLES
hidden_dim = STACK_EMB_HIDDEN_DIM_1

batch_size = BATCH_SIZE

# Draw independent samples for the two input factors A and B; each gets its
# own lower/upper bounds passed through to bounded_randu.
data_A = bounded_randu(
    num_samples, input_dim_A,
    mean=0, std=0.25,
    lower_bound=AMB_LOWER_BOUND_A, upper_bound=AMB_UPPER_BOUND_A,
)
data_B = bounded_randu(
    num_samples, input_dim_B,
    mean=0, std=0.25,
    lower_bound=AMB_LOWER_BOUND_B, upper_bound=AMB_UPPER_BOUND_B,
)

# Joint ambient point cloud: A's columns followed by B's columns.
ambient_points = np.concatenate(
    [tensor.detach().numpy() for tensor in (data_A, data_B)], axis=1
)

# A direct 3D scatter only makes sense when the joint space has 3 columns.
if ambient_points.shape[1] == 3:
    plot_3D_x1_x2(ambient_points)

print(data_A.shape)
print(data_B.shape)

# Build the flow-based spherical mapping network and train it on the (A, B)
# samples. The config aliases defined above are used consistently here;
# previously OUTPUT_DIM / BATCH_SIZE were referenced directly, which left the
# `batch_size` alias dead. The values are identical.
model = SphericalMappingNetWithFlows(
    input_dim_A,
    input_dim_B,
    hidden_dim,
    restrict_to_positive_quadrant=False,
    output_dim=output_dim,
)

initialize_weights_xavier(model)

losses = train_flow_minibatch(
    model, data_A, data_B, batch_size=batch_size, epochs=NUM_EPOCHS
)

# NOTE(review): the "Epoch" x-label assumes train_flow_minibatch returns one
# loss value per epoch (not per minibatch) — confirm against its implementation.
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

# Push the training pairs through the trained model and visualise the
# resulting embedding (singleton dimensions dropped before plotting).
y = model.map_to_sphere(data_A, data_B)
y_np = y.detach().numpy()
y_np = y_np.squeeze()

plot_data(y_np, show_origin=True, dimensions=output_dim)

# Cluster experiment: replace A with points drawn around a few clusters and
# look at where the trained model sends them on the sphere.
num_clusters, points_per_cluster, cluster_std = 3, 1, 0.1

# generate_x1_from_clusters returns a 2-tuple (points, side data).
data_A_fixed, side_data = generate_x1_from_clusters(
    num_samples, input_dim_A, num_clusters, points_per_cluster, cluster_std
)
# NOTE(review): data_B_fixed is never used below, and this call overwrites
# side_data from the A call. The embedding pairs data_A_fixed with the random
# data_B, not data_B_fixed — confirm which pairing is intended.
data_B_fixed, side_data = generate_x1_from_clusters(
    num_samples, input_dim_B, num_clusters, points_per_cluster, cluster_std
)

y_iso = model.map_to_sphere(data_A_fixed, data_B)

y_iso_np = y_iso.detach().numpy()
y_iso_np = y_iso_np.squeeze()

plot_data(y_iso_np, show_origin=True, dimensions=output_dim)

# Persist the trained weights via the model's own save_weights method.
weight_path = 'saved_models/model_tests_nflow_sphere.pth'
# isinstance (rather than an exact type() comparison) also accepts subclasses.
if isinstance(model, SphericalMappingNetWithFlows):
    # Make sure the target directory exists so a fresh checkout does not fail
    # at the very end of a training run.
    os.makedirs(os.path.dirname(weight_path), exist_ok=True)
    model.save_weights(weight_path)
