from typing import Tuple

import torch
from torch import Tensor

from src.datatypes.split import merge_batches, merge_edge_batches, merge_subgraphs
from src.datatypes.sparse import SparseGraph, SparseEdges
from src.noise import NoiseProcess


def sample_sequences(batch: SparseGraph, removal_process: NoiseProcess, num_sequences=1, need_preparation=True, also_return_splits=False):

    """Generate a new batch by sampling removal sequences from ``batch``.

    The sequence length is the maximum of ``removal_process.get_max_time``
    over all samples. Starting from ``batch`` at time 0, each step calls
    ``removal_process.sample_next(..., split=True)``, which splits the
    current graphs into a surviving part and a removed part. The surviving
    part becomes the input of the next step.

    The collected per-step lists are:
      - 'batch':      the input graphs of each step (times 0 .. T-1),
      - 'surv_batch': the surviving graphs (times 1 .. T),
      - 'remv_batch': the removed graphs (times 1 .. T),
      - 'remv_edges_ba': the edges returned by ``sample_next`` at each step
        (per ``unit_test_sample_sequences``, sized by removed nodes as source
        and surviving nodes as target).
    Each list is merged with ``merge_batches`` / ``merge_edge_batches``
    together with per-step alive masks. A sample is alive at a step if it
    still has removed or surviving nodes there. Presumably the merge
    functions use the masks to skip samples whose sequence already ended
    (TODO confirm against ``src.datatypes.split``).

    Args:
        batch: input graphs at time 0. NOTE: this object is modified in
            place. ``global_t``, ``global_rev_t``, ``global_n0`` and
            ``global_nt`` are set on it. ``removal_process.prepare_data`` is
            also called on it when ``need_preparation`` is True.
        removal_process: noise process providing ``get_max_time``,
            ``prepare_data``, ``normalize_reverse_time`` and ``sample_next``.
        num_sequences: not used by this implementation; kept for interface
            compatibility.
        need_preparation: whether to call ``removal_process.prepare_data``
            on ``batch`` before sampling.
        also_return_splits: if True, also return the unmerged per-step lists
            and the alive masks.

    Returns:
        ``merged_sequences``: dict with keys 'batch', 'surv_batch',
        'remv_batch' (merged graph batches) and 'remv_edges_ba' (merged
        edges). If ``also_return_splits`` is True, returns the tuple
        ``(merged_sequences, sequences, alive_masks)``. In that case
        ``sequences`` holds the per-step lists under the same keys.
    """

    # determine number of nodes at time 0
    n0 = batch.num_nodes_per_sample
    bs = batch.num_graphs

    # Determine the maximum length of the sequences. This is the max among
    # the max times of the removal process. Always reduce to a plain Python
    # int: range() only accepts tensors that are integer-typed, so keeping a
    # single-element (e.g. float) tensor here could raise a TypeError.
    max_time = removal_process.get_max_time(n0=n0)
    if isinstance(max_time, int):
        max_max_time = max_time
    else:
        max_max_time = int(max_time.max().item())

    # prepare the batch
    if need_preparation:
        removal_process.prepare_data(datapoint=batch)

    # NOTE: curr_batch aliases the caller's batch, so the attributes set
    # below are written on the input object itself
    curr_batch = batch

    # time tensor: every sample starts at t = 0 with all n0 nodes present
    curr_batch.global_t = torch.zeros(bs, dtype=torch.int32, device=batch.x.device)
    norm_t = removal_process.normalize_reverse_time(t=curr_batch.global_t, n0=n0)
    curr_batch.global_rev_t = norm_t
    curr_batch.global_n0 = n0
    curr_batch.global_nt = n0

    # prepare sequence
    sequences = {'batch': [], 'surv_batch': [], 'remv_batch': []}
    edges = []
    alive_masks = []

    for t in range(1, max_max_time + 1):

        sequences['batch'].append(curr_batch)
        next_time = curr_batch.global_t + 1

        surv_batch: SparseGraph
        remv_batch: SparseGraph
        remv_edges_ba: SparseEdges
        # split=True: sample_next returns the surviving and removed parts
        # separately. The third returned value is not used here.
        surv_batch, remv_batch, _, remv_edges_ba = removal_process.sample_next(
            current_datapoint=curr_batch,
            t=next_time,
            max_time=max_time,
            n0=n0,
            split=True
        )

        norm_t = removal_process.normalize_reverse_time(t=next_time, n0=n0)

        surv_batch.global_t = next_time
        surv_batch.global_rev_t = norm_t
        surv_batch.global_n0 = n0
        surv_batch.global_nt = surv_batch.num_nodes_per_sample

        remv_batch.global_t = next_time
        remv_batch.global_rev_t = norm_t
        remv_batch.global_n0 = n0
        remv_batch.global_nt = remv_batch.num_nodes_per_sample

        sequences['surv_batch'].append(surv_batch)
        sequences['remv_batch'].append(remv_batch)
        edges.append(remv_edges_ba)
        # a sample is still "alive" at this step if any of its nodes was
        # removed or survived
        alive_masks.append(torch.logical_or(remv_batch.global_nt > 0, surv_batch.global_nt > 0))

        # the survivors are the input of the next step
        curr_batch = surv_batch

    # merge all batches together
    merged_sequences = {k: merge_batches(v, alive_masks) for k, v in sequences.items()}
    merged_sequences['remv_edges_ba'] = merge_edge_batches(edges, alive_masks)

    if also_return_splits:
        sequences['remv_edges_ba'] = edges
        return merged_sequences, sequences, alive_masks

    return merged_sequences
        


def unit_test_sample_sequences(sequences: dict, batch: SparseGraph):
    """Assert consistency invariants on the merged output of ``sample_sequences``.

    ``sequences`` is the merged dict with keys 'batch', 'surv_batch',
    'remv_batch' and 'remv_edges_ba'. ``batch`` is the original input batch.
    An ``AssertionError`` is raised at the first violated invariant.
    """
    n_graphs = batch.num_graphs
    removed = sequences['remv_batch']

    # Every input node is removed exactly once over the sequence, so the
    # removed parts hold as many nodes as the input batch. The per-sample
    # counters must agree with this too; the first n_graphs entries of
    # global_n0 belong to the first step.
    assert removed.x.shape[0] == batch.x.shape[0]
    assert removed.global_nt.sum() == removed.global_n0[:n_graphs].sum()

    survived = sequences['surv_batch']
    previous = sequences['batch']

    # Node conservation per step: survived + removed == previous step
    # (checked on the graph sizes and on the global_nt counters).
    assert torch.all(
        survived.num_nodes_per_sample + removed.num_nodes_per_sample
        == previous.num_nodes_per_sample
    )
    assert torch.all(survived.global_nt + removed.global_nt == previous.global_nt)

    # Edge conservation per step: edges inside the survived and removed parts
    # plus twice the cross edges equal the previous step's edges. The factor
    # 2 presumably accounts for graphs storing each edge in both directions.
    cross_edges = sequences['remv_edges_ba']
    assert (
        survived.num_edges + removed.num_edges + cross_edges.num_edges * 2
        == previous.num_edges
    )

    # Each step's input is exactly one time step behind its outputs.
    assert torch.all(previous.global_t == (survived.global_t - 1))
    assert torch.all(previous.global_t == (removed.global_t - 1))

    # Survived and removed parts carry identical labels.
    assert torch.all(removed.y == survived.y)

    # Cross edges go from the removed nodes (source) to the survived nodes (target).
    assert torch.all(cross_edges.num_nodes_s == removed.global_nt)
    assert torch.all(cross_edges.num_nodes_t == survived.global_nt)


import src.datatypes.dense as dense
from src.datatypes.dense import DenseGraph, DenseEdges
from scipy.sparse.csgraph import connected_components

def check_connected_components(
        graph_a: SparseGraph|DenseGraph,
        graph_b: SparseGraph|DenseGraph=None,
        edges_ba: SparseEdges|DenseEdges=None):
    """Count the connected components of every graph in a batch.

    If both ``graph_b`` and ``edges_ba`` are given, components are computed
    on the union graph made of A, B and the B->A edges, plus their transpose.
    A's nodes come first, then B's. Otherwise only ``graph_a`` is used.
    Note that ``graph_b`` is ignored when ``edges_ba`` is None.

    Sparse inputs are converted to dense ones. An entry counts as an edge
    when the sum of feature channels 1: of the dense adjacency is nonzero;
    channel 0 is dropped (presumably the one-hot "no edge" class — TODO
    confirm).

    ``graph_a`` (and ``graph_b`` when used) must expose
    ``num_nodes_per_sample`` even when passed as dense graphs, because the
    padding-node count is derived from it.

    Args:
        graph_a: first graph batch.
        graph_b: optional second graph batch.
        edges_ba: optional edges from B nodes to A nodes.

    Returns:
        n_components: list with one int per graph: the number of components
            of the padded adjacency matrix minus the number of padding nodes.
            This assumes each padding node is isolated and so forms its own
            component.
        labels: list of per-node component labels from scipy's
            ``connected_components``, one array per graph. Padding nodes are
            included.
        adjmat: the (bs, N, N) adjacency tensor the components were computed
            on. N is the padded node count.
    """

    # make A dense
    if isinstance(graph_a, SparseGraph):
        # adjacency shape: B=(bs, Nb, Nb, d), A=(bs, Na, Na, d)
        graph_a_dense = dense.sparse_graph_to_dense_graph(graph_a, handle_one_hot=True)
    else:
        graph_a_dense = graph_a

    if graph_b is not None and edges_ba is not None:
        # make B dense
        if isinstance(graph_b, SparseGraph):
            graph_b_dense = dense.sparse_graph_to_dense_graph(graph_b, handle_one_hot=True)
        else:
            graph_b_dense = graph_b

        num_nodes_per_graph = graph_a.num_nodes_per_sample + \
                    graph_b.num_nodes_per_sample

        if isinstance(edges_ba, SparseEdges):
            # compute the B->A edge mask (bs, Nb, Na)
            edge_mask_ba = dense.get_bipartite_edge_mask_dense(
                graph_b_dense.node_mask,
                graph_a_dense.node_mask
            )

            # compute the B->A adjacency matrix (bs, Nb, Na, d)
            edge_adjmat_ba = dense.to_dense_adj_bipartite(
                edge_index=edges_ba.edge_index,
                edge_attr=edges_ba.edge_attr,
                batch_s=graph_b.batch,
                batch_t=graph_a.batch,
                max_num_nodes_s=graph_b.num_nodes_per_sample.max(),
                max_num_nodes_t=graph_a.num_nodes_per_sample.max(),
                batch_size=graph_a.num_graphs,
                handle_one_hot=True,
                edge_mask=edge_mask_ba
            )
        else:
            edge_adjmat_ba = edges_ba.edge_adjmat

        # drop channel 0 and sum the remaining feature channels to get plain
        # adjacency matrices: AA (bs, Na, Na), BB (bs, Nb, Nb), BA (bs, Nb, Na)
        edges_adjmat_aa = graph_a_dense.edge_adjmat[..., 1:].sum(dim=-1)
        edges_adjmat_bb = graph_b_dense.edge_adjmat[..., 1:].sum(dim=-1)
        edge_adjmat_ba = edge_adjmat_ba[..., 1:].sum(dim=-1)

        # compute the A->B adjacency matrix (bs, Na, Nb)
        edge_adjmat_ab = edge_adjmat_ba.transpose(1, 2)

        # assemble the block matrix [[AA, AB], [BA, BB]] of shape (bs, Na+Nb, Na+Nb)
        block_a = torch.cat([edges_adjmat_aa, edge_adjmat_ab], dim=2)  # concat horiz.
        block_b = torch.cat([edge_adjmat_ba, edges_adjmat_bb], dim=2)  # concat horiz.
        adjmat = torch.cat([block_a, block_b], dim=1)  # concat vert.

    else:
        adjmat = graph_a_dense.edge_adjmat[..., 1:].sum(dim=-1)
        num_nodes_per_graph = graph_a.num_nodes_per_sample

    # Compute the connected components with scipy on CPU. Detach first:
    # Tensor.numpy() refuses tensors that require grad, e.g. model outputs.
    adjmat_np = adjmat.detach().cpu().numpy()
    # every padding node adds one spurious (isolated) component per graph;
    # convert the counts once to Python numbers instead of .item() per graph
    num_fake_nodes = (adjmat_np.shape[1] - num_nodes_per_graph).tolist()
    n_components = []
    labels = []
    for fake_nn, mat in zip(num_fake_nodes, adjmat_np):
        n_components_curr, labels_curr = connected_components(mat, directed=False)
        n_components.append(n_components_curr - fake_nn)
        labels.append(labels_curr)

    return n_components, labels, adjmat

from torch_geometric.utils import to_networkx
from src.metrics.utils.synth import are_isomorphic

def check_same_graphs(graph_a: SparseGraph, graph_b: SparseGraph, edges_ba: SparseEdges, graph_whole: SparseGraph):
    """Check that recombining two subgraph batches reproduces ``graph_whole``.

    ``graph_a`` and ``graph_b`` are merged with ``merge_subgraphs``, using
    ``edges_ba`` and its transpose as the connecting edges. Each merged graph
    is then compared to the graph at the same index in ``graph_whole``,
    using ``are_isomorphic`` on their ``to_networkx`` conversions.

    ``graph_whole`` is not modified: a clone is used, with ``node_perm``
    cleared before it is split into individual graphs.

    Returns:
        A tuple ``(all_isomorphic, non_isomorphic)``. ``non_isomorphic`` lists
        the indices of graphs that do not match. If the two batches contain a
        different number of graphs, the indices without a counterpart are
        also reported as non-isomorphic.
    """
    edges_ab = edges_ba.clone().transpose()

    graph_merged = merge_subgraphs(graph_a, graph_b, edges_ab, edges_ba)

    # work on a copy so the caller's object keeps its node_perm
    graph_whole = graph_whole.clone()
    graph_whole.node_perm = None

    # transform all graphs to nx graphs
    graph_merged_nx = [to_networkx(g) for g in graph_merged.to_data_list()]
    graph_whole_nx = [to_networkx(g) for g in graph_whole.to_data_list()]

    # compare graphs pairwise, index by index
    num_paired = min(len(graph_merged_nx), len(graph_whole_nx))
    non_isomorphic = [
        i for i in range(num_paired)
        if not are_isomorphic(graph_merged_nx[i], graph_whole_nx[i])
    ]
    # zip() would silently drop graphs present in only one batch; such graphs
    # have no counterpart and therefore cannot match
    num_total = max(len(graph_merged_nx), len(graph_whole_nx))
    non_isomorphic.extend(range(num_paired, num_total))

    return len(non_isomorphic) == 0, non_isomorphic