import time
import numpy as np
import matplotlib.pyplot as plt
from ripser import ripser

from src.utils.tda import get_tda, get_distance_matrix, get_sw, reduce
from src.utils.plotting import plot_diagram

# # Sample from a circle with noise
# sample_size = 200
# theta = np.linspace(0, 20, sample_size)
# noise = 0.03
# x = np.cos(theta) + noise * np.random.randn(sample_size)
# y = np.sin(theta) + noise * np.random.randn(sample_size)
# print('Period is 6.28')

# # Plot signal
# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')
# ax.scatter(theta, x, y)
# ax.set_xlabel('Theta')
# ax.set_ylabel('X')
# ax.set_zlabel('Y')
# plt.savefig('outputs/tda/circle.png')
# plt.show()

# # Perform TDA on signal
# X = np.array([x, y]).T
# D, B, PS = get_tda(X, delays=15)
# print(f'{B = }')
# print(f'{PS = }')
# plot_diagram(D)

# # Test distance matrix computation against simpler but slower method
# X = np.random.randn(500, 3*24*24)
# embed_dim = 50
# delay = 1

# # Simple but slow method
# start_time = time.time()
# sw = get_sw(X, embed_dim, delay)
# distance_matrix_slow = np.zeros(shape=(sw.shape[0], sw.shape[0]))
# for i in range(sw.shape[0]):
#     for j in range(sw.shape[0]):
#         distance_matrix_slow[i, j] = np.linalg.norm(sw[i] - sw[j])
#         distance_matrix_slow[j, i] = distance_matrix_slow[i, j]
# print(f'slow time = {time.time() - start_time:.4f}s')

# # Fast method
# start_time = time.time()
# distance_matrix = get_distance_matrix(X, embed_dim, delay)
# print(f'fast time = {time.time() - start_time:.4f}s')

# print(f'both methods give same result = {np.allclose(distance_matrix, distance_matrix_slow)}')

# # Check that reduced matrix has same distances between rows as original
# X = np.random.randn(500, 3*24*24)
# print(f'{X.shape = }')
# X_reduced = reduce(X)
# print(f'{X_reduced.shape = }')

# for i in range(X.shape[0]):
#     for j in range(i+1, X.shape[0]):
#         dist = np.linalg.norm(X[i] - X[j])
#         dist_reduced = np.linalg.norm(X_reduced[i] - X_reduced[j])
#         assert np.isclose(dist, dist_reduced)

# TODO: normalize the rows of the SW tensor as well (?)

def _center_and_scale(data):
    """Center ``data`` on its mean row and rescale by the mean row norm.

    The column-wise mean (the "mean row") is subtracted from every row. The
    result is then divided by the average Euclidean norm of its rows. Afterwards
    the mean row is zero and the rows have unit L2 norm on average.

    NOTE(review): if every row of ``data`` is identical, the average norm is 0
    and the division yields NaN/inf; no guard is applied here.
    """
    mean_row = data.mean(axis=0)
    centered = data - mean_row
    avg_row_norm = np.linalg.norm(centered, axis=1).mean()
    return centered / avg_row_norm


# Random test data, centered and scaled to unit average row norm.
X = _center_and_scale(np.random.randn(500, 3 * 24 * 24))
