"""Functions for graph reconstruction runner"""


import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import seaborn as sns
import torch
import torch_geometric as pyg

from pathlib import Path

from vqt2g.utils.gvqvae_utils import real_num_nodes


def reshape_to_adj(adj_vector, size, diag_val=1.0):
    """Reshape GVQVAE output into a symmetric matrix

    GVQVAE outputs a long vector corresponding to the strict lower triangle of a matrix
    (the diagonal is excluded). This scatters the vector into that lower triangle, mirrors
    it to make the matrix symmetric, and sets every diagonal entry to ``diag_val``.

    Args:
      adj_vector: 1-D tensor (or anything ``torch.as_tensor`` accepts) of length
        ``size * (size - 1) // 2``, ordered like ``torch.tril_indices(size, size, offset=-1)``.
      size: Number of rows/columns of the output matrix.
      diag_val: Value written on the diagonal. (Default value = 1.0)

    Returns:
      A ``(size, size)`` tensor of the default float dtype, on the same device as
      ``adj_vector``.
    """
    adj_vector = torch.as_tensor(adj_vector)
    device = adj_vector.device
    lower = torch.zeros((size, size), device=device)
    rows, cols = torch.tril_indices(size, size, offset=-1, device=device)
    # index_put requires matching dtypes, so cast e.g. float64 / float16 inputs explicitly.
    lower[rows, cols] = adj_vector.to(lower.dtype)
    adj = lower + lower.T
    # The diagonal of `adj` is zero here, so filling is equivalent to adding diag_val.
    adj.fill_diagonal_(diag_val)
    return adj


def heatmap(
    model,
    graph,
    title_text="",
    prog="neato",
    include_padding=False,
    full_range=False,
    show_codes=False,
    device=None,
    save_folder=None,
    file_name=None,
):
    """Make probability heatmap of recon adj. No graph recon here

    Draws the input graph (left) next to a heatmap of the decoder's raw edge
    probabilities (right). No thresholding or graph reconstruction is performed.

    Args:
      model: Model providing ``max_nodes``, ``_all_edges(num_nodes)``,
        ``vq_encode(x=..., edge_index=...)`` and ``decoder(z=..., edge_index=...)``.
        Its train/eval mode is restored before returning.
      graph: PyG data object with ``x`` and ``edge_index`` attributes.
      title_text: Prefix for the title above the graph drawing. (Default value = "")
      prog: Currently unused (the drawing uses ``nx.spring_layout``); kept for
        backward compatibility. (Default value = "neato")
      include_padding: If True, the heatmap covers ``model.max_nodes`` nodes; otherwise
        only ``real_num_nodes(model, graph)`` nodes. (Default value = False)
      full_range: If True, cell (0, 0) is set to 0 so the colour scale always
        extends down to 0. (Default value = False)
      show_codes: Currently unused; kept for backward compatibility. (Default value = False)
      device: Device the graph tensors are moved to. Defaults to the device of the
        model's first parameter, or ``cuda:0`` if the model has no parameters.
      save_folder: Folder in which the figure is saved.
      file_name: File name of the saved figure. If this or ``save_folder`` is None,
        the figure is displayed with ``plt.show()`` instead of being saved.

    Returns:
      None
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda:0")

    was_training = model.training
    model.eval()
    try:
        num_nodes = model.max_nodes if include_padding else real_num_nodes(model, graph)
        with torch.no_grad():
            pairs = model._all_edges(num_nodes)
            x = graph.x.to(device)
            edge_index = graph.edge_index.to(device)
            z = model.vq_encode(x=x, edge_index=edge_index)
            edge_probs = model.decoder(z=z, edge_index=pairs)
    finally:
        # Don't leave the model in eval mode if it was called mid-training.
        model.train(was_training)

    pred_adj = reshape_to_adj(edge_probs.cpu(), size=num_nodes)
    if full_range and num_nodes > 0:
        # Zero one (diagonal) cell so the colour map's range always includes 0.
        pred_adj[0, 0] = 0.0

    fig = plt.figure(figsize=(24, 10))
    try:
        plt.subplot(121)
        real_nxg = pyg.utils.to_networkx(graph)
        real_nxg = nx.Graph(real_nxg)  # Required to ensure graph undirected + not frozen
        # Isolated nodes are treated as padding and removed (note: this also removes
        # any genuinely isolated real nodes).
        real_nxg.remove_nodes_from(list(nx.isolates(real_nxg)))
        plt.title(f"{title_text} - {len(real_nxg)} nodes, {real_nxg.number_of_edges()} edges")
        pos = nx.spring_layout(real_nxg)
        nx.draw(real_nxg, pos=pos, with_labels=True, font_weight="bold", font_color="white")

        plt.subplot(122)
        plt.title("Heatmap")
        sns.heatmap(pred_adj, linewidth=0.02, cmap="Spectral")

        if save_folder is None or file_name is None:
            plt.show()
        else:
            plt.savefig(Path(save_folder, file_name))
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
