import numpy as np
import pickle
import time

from src.utils.loading import load_model, get_mazes
from src.utils.plotting import plot_mazes, plot_diagram, plot_residual, plot_pca
from src.utils.tda import get_tda, Analysis

# Set parameters
model_name = 'dt_net'  # name handed to load_model
num_mazes = 10         # number of mazes requested from get_mazes per maze size
start_idx = 1000       # first iteration index of the latent series (inclusive)
end_idx = 1400         # last iteration index of the latent series (inclusive)
embedding_dim = 50     # forwarded to get_tda as embedding_dim

# The TDA step below indexes a single maze, so at least one maze is required.
# Fail fast here instead of after the model and mazes have been loaded.
if num_mazes < 1:
    raise ValueError(f'num_mazes must be at least 1, got {num_mazes}')

# Index of the maze analysed with TDA. Previously this was the value left in
# the `maze_idx` loop variable after the plotting loops finished (i.e. the last
# maze); it is now named explicitly so the TDA step no longer depends on a
# leaked loop variable.
tda_maze_idx = num_mazes - 1

# The iteration indices do not depend on the maze size, so build them once.
iters = list(range(start_idx, end_idx + 1))

for maze_size in [9, 19, 29, 39, 49, 59, 69, 79, 89, 99]:
    # Load model and mazes
    # NOTE(review): the model is reloaded for every maze size; if load_model is
    # expensive and the model can be reused across sizes, hoist it - confirm first.
    start_time = time.time()
    model = load_model(model_name)
    inputs, solutions = get_mazes(dataset='maze-dataset', maze_size=maze_size, num_mazes=num_mazes)
    print(f'Loaded model {model.name()} and {num_mazes} mazes of size {maze_size}x{maze_size} in {time.time() - start_time:.2f}s')

    # Common prefix of every output file name written for this maze size
    run_tag = f'{model.name()}_size-{maze_size}x{maze_size}'

    # Get latent series from start_idx to end_idx for every maze
    start_time = time.time()
    latents = model.input_to_latent(inputs, iters)
    latents_series = model.latent_forward(latents, inputs, iters=iters)
    print(f'Got latent series in {time.time() - start_time:.2f}s')

    # Plot residuals, one file per maze
    for maze_idx in range(num_mazes):
        plot_residual(
            latents_series[:, maze_idx],
            start_idx,
            end_idx,
            file_name=f'outputs/tda/residuals/{run_tag}_maze-{maze_idx}.pdf',
        )
    print('Plotted residuals')

    # Perform PCA on latent series, one file per maze
    for maze_idx in range(num_mazes):
        plot_pca(
            latents_series[:, maze_idx],
            file_name=f'outputs/pca/{run_tag}_maze-{maze_idx}.pdf',
        )
    print('Plotted PCA')

    # Perform TDA on the latent series of the selected maze only
    start_time = time.time()
    D, B, PS = get_tda(latents_series[:, tda_maze_idx], embedding_dim=embedding_dim)
    # for i in range(len(D)):
    #     print(f'D_{i} = \n{D[i]}')
    print(f'Performed TDA in {time.time() - start_time:.2f}s')
    print(f'{B = }')
    print(f'{PS = }')

    # Plot persistence diagram for the same maze the TDA was computed on
    start_time = time.time()
    plot_diagram(D, file_name=f'outputs/tda/diagrams/{run_tag}_maze-{tda_maze_idx}.pdf')
    print(f'Plotted diagram in {time.time() - start_time:.2f}s')

# file_name = 'outputs/tda/analysis/dt_net_maze_sizes-[9, 19]_num_mazes-100_iters-e,t_embedding_dim-50_delay-1_threshold-0.5.pkl'
# with open(file_name, 'rb') as f:
#     analysis = pickle.load(f)
# print(analysis.betti_nums)