"""Graph generation and reconstruction functions for VQ-T2G"""

import logging
import random

import matplotlib.pyplot as plt
import networkx as nx

import torch
import torch_geometric as pyg

from pathlib import Path
from typing import Optional

from torch_geometric.data import Data as pyg_Data
from transformers import GPT2LMHeadModel

from vqt2g.gvqvae import GVQVAE
from vqt2g.utils.gvqvae_utils import real_num_nodes
from vqt2g.tokenizer import VQT2GTokenizerBPE

# Logger used by this module for debug messages and warnings (e.g. during edge selection)
_LOG = logging.getLogger("vqt2g_logger")


def reconstruct_graph(
    model: GVQVAE, graph: pyg_Data, thresh: float = 0.8, edge_sampling: bool = False, device=torch.device("cuda:0")
) -> nx.Graph:
    """Encode and decode a graph with the GVQVAE, then rebuild its edges

    The number of nodes and edges of the reconstruction is taken from the input graph, so the
    output can be compared directly against it.

    Args:
      model: GVQVAE model
      graph: Graph to reconstruct with the model
      thresh: Probability threshold an edge must reach to be considered (Default value = 0.8)
      edge_sampling: If True, edges beyond the one-per-node minimum are chosen by Bernoulli
          sampling; if False (default), they are chosen deterministically by descending
          probability (top-k)
      device: CPU/CUDA device

    Returns:
        Networkx Graph. The reconstructed graph.
    """

    num_nodes = real_num_nodes(model, graph=graph)
    # Halved on the assumption that edge_index lists each undirected edge in both directions
    num_edges = graph.edge_index.size(1) // 2
    edge_probs = model.encode_decode_graph(graph, num_nodes=num_nodes, device=device)
    # `edge_probs_to_graph` takes the opposite flag: top-k selection is "no sampling"
    edge_topk = not edge_sampling
    recon_graph = edge_probs_to_graph(
        model=model,
        edge_probs=edge_probs,
        num_nodes=num_nodes,
        num_edges=num_edges,
        thresh=thresh,
        edge_topk=edge_topk,
    )
    return recon_graph


def generate_codes_from_string(
    string: str,
    tokenizer: VQT2GTokenizerBPE,
    transformer: GPT2LMHeadModel,
    greedy_sampling: bool = True,
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> torch.Tensor:
    """Generate GVQVAE codebook ids for a graph described by the given string

    Args:
      string: text for conditioning
      tokenizer: Tokenizer for the model
      transformer: Trained VQ-T2G transformer. May live on any device; the prompt tokens are
          moved to `transformer.device` before generation
      greedy_sampling:  (Default value = True) Whether to do greedy sampling from the transformer
      temperature:  (Default value = 0.7) Temperature parameter for transformer sampling. Only
          used when `greedy_sampling=False`
      top_p:  (Default value = 0.9) Top-p (nucleus sampling) parameter for transformer sampling.
          Only used when `greedy_sampling=False`

    Returns:
        Tensor of GVQVAE codebook ids, i.e. the latent representation of the generated graph
        before it is turned into codebook vectors and passed to the GVQVAE decoder

    """

    # Convert string to tokens, placed on the same device as the transformer so that
    # generation also works for a model on GPU
    token_ids = tokenizer.encode_text(string, add_specials=True)
    tokens = torch.tensor(token_ids, device=transformer.device).view(1, -1)

    do_sample = not greedy_sampling
    # Greedy decoding ignores temperature/top_p, and recent transformers versions warn when
    # they are set together with do_sample=False, so only pass them when sampling
    sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}

    # Generate graph tokens with transformer; min_length == max_length forces a fixed-length
    # output of `tokenizer.total_len` tokens
    tfmr_output = transformer.generate(
        tokens,
        max_length=tokenizer.total_len,
        min_length=tokenizer.total_len,
        do_sample=do_sample,
        bad_words_ids=tokenizer.bad_ids,
        **sampling_kwargs,
    )
    # Single prompt -> take the only sequence; decode from CPU regardless of model device
    return tokenizer.decode_graph(tfmr_output[0].cpu())


def edge_probs_to_graph(
    model: GVQVAE,
    edge_probs: torch.Tensor,
    num_nodes: int,
    num_edges: int = 0,
    thresh: float = 0.9,
    edge_topk: bool = True,
    extra_edge_randomness: bool = False,
) -> nx.Graph:
    """Take a probabilistic adjacency, choose edges, return a networkx graph

    Edges are chosen in two stages:

    1. For every node ``n`` in ``1..num_nodes-1``, the most probable candidate pair whose first
       element is ``n`` is always added (ties go to the later pair). Node 0 is skipped,
       presumably because it only occurs as the second element of a pair -- TODO confirm
       against ``GVQVAE._all_edges``.
    2. The remaining candidate pairs with probability ``>= thresh`` are visited either by
       descending probability or in random order (``extra_edge_randomness``), and added
       either top-k (``edge_topk=True``) or by Bernoulli sampling, until the edge budget
       ``num_edges - <stage-1 edges>`` is used up.

    Args:
      model: GVQVAE model; ``model._all_edges(num_nodes)`` supplies the candidate node pairs,
          in the same order as ``edge_probs``
      edge_probs: Tensor holding one probability per candidate pair
      num_nodes: Number of nodes in output graph
      num_edges: Target number of edges in output graph. 0 means unknown, in which case
          stage 2 is not capped
      thresh: Probability threshold for allowing edges to be selected in stage 2
      edge_topk: Whether to use top-k for stage-2 edge selection (True) or Bernoulli sampling
      extra_edge_randomness: Shuffles all edges (above selected threshold) instead of going
          through by descending probability

    Returns:
        Networkx graph built from the chosen edge list (nodes are Python ints). Nodes that end
        up with no chosen edge are not part of the graph.

    Raises:
        ValueError: if a node in ``1..num_nodes-1`` has no candidate pair starting at it.
    """

    # Candidate pairs as (node_a, node_b) tuples of Python ints; position i <-> edge_probs[i]
    pairs = model._all_edges(num_nodes).cpu().numpy()
    candidate_pairs = [tuple(pair) for pair in pairs.T.tolist()]
    probs = torch.as_tensor(edge_probs).detach().reshape(-1).tolist()

    # Stage 1: best candidate per first node, found in a single pass over all pairs.
    # `>=` resolves ties in favour of the later pair.
    best_by_node = {}
    for idx, ((node, _), prob) in enumerate(zip(candidate_pairs, probs)):
        current = best_by_node.get(node)
        if current is None or prob >= current[1]:
            best_by_node[node] = (idx, prob)

    stage1_idx = []
    for node in range(1, num_nodes):
        if node not in best_by_node:
            raise ValueError(f"No candidate edges start at node {node}; cannot connect it")
        stage1_idx.append(best_by_node[node][0])
    stage1_edges = [candidate_pairs[idx] for idx in stage1_idx]
    stage1_idx_set = set(stage1_idx)

    # Stage 2 candidates: every other pair whose probability reaches the threshold
    stage2_candidates = [
        (pair, prob)
        for idx, (pair, prob) in enumerate(zip(candidate_pairs, probs))
        if prob >= thresh and idx not in stage1_idx_set
    ]

    # Edge budget for stage 2: None when the number of edges is unknown (no cap). Clamped at
    # 0 so that a stage 1 that already meets/exceeds the target adds nothing further (a
    # negative slice bound would instead drop edges from the end of the list).
    if num_edges != 0:
        stage2_budget = max(num_edges - len(stage1_edges), 0)
    else:
        stage2_budget = None

    if stage2_budget is not None and len(stage2_candidates) < stage2_budget:
        _LOG.debug(
            f"For stage 2: want {stage2_budget} but only {len(stage2_candidates)} are above thresh"
        )

    if extra_edge_randomness:
        random.shuffle(stage2_candidates)
    else:
        # Stable sort: equal probabilities keep candidate-pair order, so output is deterministic
        stage2_candidates.sort(key=lambda cand: cand[1], reverse=True)

    # Two possible stage-2 edge selection strategies
    if edge_topk:
        if num_edges == 0:
            _LOG.warning(
                "Top-k edge strategy used selected but number of edges not specified. Will likely "
                "mean generated graph will be poor quality"
            )
        # Deterministic - take remaining edges in visiting order (slice bound None = all)
        top_edges = [pair for pair, _ in stage2_candidates[:stage2_budget]]
        edgelist = stage1_edges + top_edges
    else:
        # Bernoulli sampling - keep each visited edge with its own probability
        sampled_edges = []
        for pair, prob in stage2_candidates:
            if stage2_budget is not None and len(sampled_edges) >= stage2_budget:
                break
            if random.random() <= prob:
                sampled_edges.append(pair)
        edgelist = sampled_edges + stage1_edges

    return nx.from_edgelist(edgelist)


def generate_from_text(
    text: str,
    gvqvae: GVQVAE,
    transformer: GPT2LMHeadModel,
    tokenizer: VQT2GTokenizerBPE,
    thresh: float = 0.8,
    real_graph: Optional[pyg_Data] = None,
    keep_num_nodes: bool = True,
    device=torch.device("cuda:0"),
    edge_topk: bool = True,
    extra_edge_randomness: bool = False,
    transformer_sampling: bool = False,
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> nx.Graph:
    """Generate a graph from text, optionally matched in size to a real graph for comparison

    Args:
      text: Conditioning text
      gvqvae: Trained GVQVAE model. It is switched to eval mode by this function
      transformer: Trained transformer model
      tokenizer: VQT2G tokenizer
      thresh: Threshold for edge adding (Default value = 0.8)
      real_graph: Real graph, if comparing generated result to a real/ground truth graph. When
          given, its edge count (``edge_index.size(1) // 2``) is used as the target number of
          edges; when None, the number of edges is left unspecified
      keep_num_nodes: Whether to force the generated graph to have the same number of nodes as
          the ground truth graph (i.e. `real_graph`). Used for baseline comparisons. Has no
          effect when `real_graph` is None; the node count is then chosen by `select_num_nodes`
      device: CPU/GPU device
      edge_topk: Whether to pick edges from the probabilistic adjacency by descending
          probability (True) or by Bernoulli sampling (False). Without `real_graph` the number
          of edges is unknown, so top-k keeps every edge above `thresh` (a warning is logged)
      extra_edge_randomness: Shuffle candidate edges instead of visiting them by descending
          probability, making the graph less deterministic
      transformer_sampling: If True, sample from the transformer using `temperature` and
          `top_p`; if False (default), decode greedily (deterministic)
      temperature: Temperature parameter for transformer (Default value = 0.7)
      top_p: Top-p parameter for transformer (Default value = 0.9)

    Returns:
        Graph generated from conditioning text

    """

    # Encode text, generate graph tokens with transformer
    codes = generate_codes_from_string(
        string=text,
        tokenizer=tokenizer,
        transformer=transformer,
        greedy_sampling=not transformer_sampling,
        temperature=temperature,
        top_p=top_p,
    )

    gvqvae.eval()

    # The node count can only be copied from a ground truth graph that was actually given;
    # otherwise decode at `max_nodes` and pick a count afterwards
    match_real_nodes = keep_num_nodes and real_graph is not None
    num_nodes = gvqvae.real_num_nodes(real_graph) if match_real_nodes else gvqvae.max_nodes
    # 0 tells `edge_probs_to_graph` that the number of edges is unknown
    num_edges = real_graph.edge_index.size(1) // 2 if real_graph is not None else 0

    edge_probs = gvqvae.decode_from_codes(codes=codes, num_nodes=num_nodes, device=device)

    if not match_real_nodes:  # If number of nodes not known, need to pick a value
        num_nodes = select_num_nodes(edge_probs=edge_probs, max_nodes=gvqvae.max_nodes)

    return edge_probs_to_graph(
        model=gvqvae,
        edge_probs=edge_probs,
        num_nodes=num_nodes,
        num_edges=num_edges,
        thresh=thresh,
        edge_topk=edge_topk,
        extra_edge_randomness=extra_edge_randomness,
    )


def select_num_nodes(edge_probs, max_nodes):
    """Choose how many nodes the output graph should have.

    Only called when the node count is not already known. This is currently a placeholder:
    `edge_probs` is ignored and `max_nodes` is handed back unchanged.

    Args:
      edge_probs: Probabilistic adjacency from the GVQVAE decoder (currently unused)
      max_nodes: Upper bound on the number of nodes

    Returns:
        The node count to use -- at present always `max_nodes`.
    """
    # TODO: derive the node count from `edge_probs` rather than always using the maximum
    chosen_count = max_nodes
    return chosen_count


def _draw_nx_graph(ax, graph, title, graphviz_layout, graphviz_prog, node_id_labels):
    """Draw one networkx graph onto `ax` with the shared evaluation-plot styling"""
    ax.set_title(title)
    if graphviz_layout:
        pos = nx.nx_pydot.graphviz_layout(graph, prog=graphviz_prog)
    else:
        pos = nx.spring_layout(graph)
    nx.draw(
        graph,
        pos=pos,
        ax=ax,
        with_labels=node_id_labels,
        font_weight="bold",
        font_color="white",
        node_size=150,
        width=0.7,
    )


def plot_real_and_generated_graph(
    real_graph,
    generated_graph,
    text,
    save_folder,
    file_name,
    text_as_title=True,
    plot_width=20,
    plot_height=10,
    graphviz_layout=False,
    graphviz_prog="neato",
    node_id_labels=False,
):
    """For model evaluation. Plots a real graph next to a graph generated from its text

    The figure is written to ``Path(save_folder, file_name)`` and always closed afterwards,
    even if drawing or saving fails.

    Args:
      real_graph: Real graph (PyG data object); isolated nodes are dropped before plotting
      generated_graph: Networkx graph generated by the model, using the text of the real graph
      text: Text of real graph, also used to generate
      save_folder: Folder to save the figure in
      file_name: File name of the saved figure (its extension selects the image format)
      text_as_title: Whether to show `text` as the figure title (Default value = True)
      plot_width: Figure width in inches (Default value = 20)
      plot_height: Figure height in inches (Default value = 10)
      graphviz_layout: Use a graphviz layout instead of the spring layout (Default value = False)
      graphviz_prog: Graphviz program used when `graphviz_layout` is True
          (Default value = "neato")
      node_id_labels: Whether to draw node ids as labels (Default value = False)

    Returns:
        None
    """

    fig = plt.figure(figsize=(plot_width, plot_height))
    try:
        if text_as_title:
            fig.suptitle(text)

        # Real graph: copy into a fresh undirected nx.Graph (required, this removes frozen
        # status), then drop isolated nodes, which are treated as padding
        real_nxg = nx.Graph(pyg.utils.to_networkx(real_graph))
        real_nxg.remove_nodes_from(list(nx.isolates(real_nxg)))
        _draw_nx_graph(
            fig.add_subplot(121),
            real_nxg,
            f"Real graph - {len(real_nxg)} nodes, {real_nxg.number_of_edges()} edges",
            graphviz_layout,
            graphviz_prog,
            node_id_labels,
        )

        # Generated graph
        _draw_nx_graph(
            fig.add_subplot(122),
            generated_graph,
            f"Generated graph - {len(generated_graph)} nodes, {generated_graph.number_of_edges()} edges",
            graphviz_layout,
            graphviz_prog,
            node_id_labels,
        )

        # Save to file
        fig.savefig(Path(save_folder, file_name))
    finally:
        # Release the figure even on failure, so repeated evaluation calls do not leak figures
        plt.close(fig)
