import networkx as nx
import numpy as np
import torch
from causal_graphs.variable_distributions import _random_categ
from torch.distributions.categorical import Categorical

from causal_discovery.multivariable_mlp import EmbedLayer, OneHotEmbedding
from causal_discovery.utils import evaluate_likelihoods_for_model


def inversePermutation(perm):
    """Return the inverse of the permutation ``perm`` as a list.

    ``perm`` is read as a mapping from position ``i`` to value ``perm[i]``;
    the result maps every value back to its position, so that
    ``result[perm[i]] == i`` for every ``i``.

    >>> inversePermutation([2, 0, 1])
    [1, 2, 0]
    """
    positions = [0] * len(perm)
    for where, value in enumerate(perm):
        positions[value] = where
    return positions


def sample_DAG(gamma, theta, only_gamma=False, t=0.2):
    """
    Infinitely generate DAGs using a two-stage mechanism:
        (A) Sample a (root) node ordering: repeatedly draw a root node from a
            temperature-``t`` softmax over root scores derived from the
            soft-adjacency of the remaining nodes, then remove it.
        (B) Sample a DAG consistent with that ordering by Bernoulli-sampling
            edges from ``sigmoid(gamma)`` restricted to forward edges.

    Parameters
    ----------
    gamma : torch.Tensor
        Square matrix of edge logits. It is re-read on every iteration, so
        in-place updates made between ``next()`` calls affect later samples.
    theta : torch.Tensor
        Square matrix of logits whose sigmoid is multiplied into the
        soft-adjacency used for the ordering in stage (A). Ignored when
        ``only_gamma`` is True; never used in stage (B).
    only_gamma : bool
        If True, stage (A) uses ``sigmoid(gamma)`` alone.
    t : float
        Softmax temperature for root sampling; must be non-zero. For t > 0,
        smaller values make the root choice greedier.

    Yields
    ------
    torch.Tensor
        0/1-valued matrix with the shape, dtype and device of
        ``torch.sigmoid(gamma)``. It is acyclic by construction (only edges
        from earlier to later nodes of the sampled ordering) and carries no
        autograd history.
    """
    while True:
        # Keep ``yield`` OUTSIDE the no_grad block: grad mode is thread-local
        # state and a suspended generator does not run ``__exit__``, so
        # yielding inside it would leave autograd disabled in the caller's
        # code after every ``next()``.
        with torch.no_grad():
            # Soft-adjacency over all variables. Remember that GAMMA and THETA
            # are transposed due to their formulation.
            if only_gamma:
                softAdj = torch.sigmoid(gamma)
            else:
                softAdj = torch.sigmoid(gamma) * torch.sigmoid(theta)
            softAdj_t = softAdj.detach().cpu()

            # ------------------------------------------
            # (A) Sample node ordering
            # ------------------------------------------
            ordering = []
            possibleNodes = np.arange(softAdj_t.shape[0])

            while len(possibleNodes) >= 1:
                # (1) Root score of each remaining node: 1 minus the largest
                # entry in its column of the remaining soft-adjacency, turned
                # into a distribution by a temperature-t softmax.
                rootScore = torch.ones(softAdj_t.shape[0]) - softAdj_t.max(dim=0)[0]
                prob = torch.softmax(rootScore / t, dim=0)

                # (2) Sample the next root among the remaining nodes
                index = torch.multinomial(prob, 1).item()
                ordering.append(int(possibleNodes[index]))

                # (3) Remove the sampled node from the remaining set/adjacency
                possibleNodes = np.delete(possibleNodes, index)
                keepIndex = np.delete(np.arange(softAdj_t.shape[0]), index)
                softAdj_t = softAdj_t[keepIndex, :][:, keepIndex]

            # ------------------------------------------
            # (B) Sample DAG based on node ordering
            # ------------------------------------------
            invOrdering = inversePermutation(ordering)

            # Bernoulli sampling from sigmoid(gamma) permuted into ordering
            # space and restricted to the strict upper triangle (triu + zeroed
            # diagonal): only forward edges survive, which ensures acyclicity.
            softAdjPerm = torch.sigmoid(gamma)[ordering, :][:, ordering].triu()
            softAdjPerm.diagonal().zero_()
            dag = torch.empty_like(softAdjPerm).uniform_().lt_(softAdjPerm)
            # Undo the permutation to return to the original node indexing
            dag = dag[invOrdering, :][:, invOrdering]

        yield dag


def sample_interventionalSamples(
    config,
    target_node,
    model,
    device,
    nb_categs,
    batch_size=128,
):
    """
    Generate samples by ancestral sampling from an interventional distribution
    on a single target node (``target_node``) under a given graph (``config``).

    Parameters
    ----------
    config : torch.Tensor
        Square adjacency matrix of the hypothesis graph. It is passed to
        ``nx.DiGraph`` for the topological sort, so entry ``[i, j]`` is
        treated as an edge i -> j there; ``nx.topological_sort`` raises if
        the graph contains a cycle.
    target_node : int
        Intervened variable. Its incoming edges (column ``target_node``) are
        cut, and its values are drawn from
        ``_random_categ(size=nb_categs, scale=0.0, axis=-1)`` (described as a
        uniform distribution over outcomes — TODO confirm against
        ``_random_categ``).
    model
        Model whose layers are evaluated by ``compute_dist_for_node`` to get
        the conditional distribution of every non-intervened node.
    device
        Device on which the returned samples are allocated.
    nb_categs : int
        Number of categories of the intervened variable.
    batch_size : int
        Number of joint samples to draw.

    Returns
    -------
    torch.Tensor
        int64 tensor of shape ``(batch_size, config.shape[0])`` on ``device``.
    """
    # Get dimension and topological order of the current 'hypothesis' graph.
    # Cutting edges into target_node below keeps this order valid.
    dim = config.shape[0]
    topoOrder = list(nx.topological_sort(nx.DiGraph(config.detach().cpu().numpy())))

    # Apply intervention: remove all edges into target_node, then build one
    # (transposed) mask per sample for the model's input layer.
    adj = config.clone()
    adj[:, target_node] = 0
    adj_intervention_matrices = (
        adj.unsqueeze(dim=0).expand(batch_size, -1, -1).transpose(1, 2)
    )

    # Init sample array; each visited node overwrites its own column below.
    samples = torch.ones(batch_size, dim, dtype=torch.int64).to(device)

    # Iterate through topoOrder so parents are sampled before their children
    for node in topoOrder:
        if node == target_node:
            # Interventional distribution => uniform distribution over
            # outcomes (most randomness)
            prob = torch.from_numpy(_random_categ(size=nb_categs, scale=0.0, axis=-1))
            x = torch.multinomial(prob, num_samples=batch_size, replacement=True)
        else:
            # Conditional distribution of node given the already-sampled
            # values, masked by the intervened adjacency
            p, _ = compute_dist_for_node(
                adj_intervention_matrices, model, node, samples
            )
            x = torch.multinomial(p, num_samples=1)

        # Set samples of the corresponding variable. The intervened node's
        # draws come from a CPU tensor, so move them explicitly to `device`.
        samples[:, node] = x.squeeze(-1).to(device)

    return samples


def compute_dist_for_node(adj_intervention_matrices, model, node, samples):
    """
    Evaluate the model's categorical distribution for a single variable.

    Runs ``model.layers[0]`` on ``samples`` (masked by
    ``adj_intervention_matrices``), keeps only the slice belonging to
    ``node`` and pushes it through that variable's weights of the two
    per-variable linear layers.

    Parameters
    ----------
    adj_intervention_matrices : torch.Tensor
        Mask forwarded as ``mask=`` to ``model.layers[0]``.
    model
        Network whose ``layers[0]`` must be an ``EmbedLayer`` or
        ``OneHotEmbedding`` (checked with ``assert``). With ``EmbedLayer``
        the first linear layer is ``layers[2]`` (``layers[1]`` is applied to
        the embedding first); with ``OneHotEmbedding`` it is ``layers[1]``.
        The layer right after it is applied elementwise (described as a
        LeakyReLU), followed by the second linear layer.
    node : int
        Index of the variable whose distribution is computed.
    samples : torch.Tensor
        Current variable values; presumably shape (batch, num_vars) —
        TODO confirm.

    Returns
    -------
    tuple of torch.Tensor
        ``(p, logits)`` where ``logits`` is the second linear layer's output
        for ``node`` and ``p`` is ``softmax(logits)`` over the last dim.
    """
    input_layer = model.layers[0]
    assert isinstance(input_layer, (EmbedLayer, OneHotEmbedding))
    embed_has_activation = isinstance(input_layer, EmbedLayer)
    # EmbedLayer is followed by one extra layer, which shifts the first
    # per-variable linear layer by one position.
    lin = 2 if embed_has_activation else 1

    def node_linear(layer, inputs):
        # Apply only `node`'s slice of a per-variable linear layer:
        # weight[node] @ inputs + bias[node], batched over the leading dim.
        w = layer.weight[node].unsqueeze(dim=0)
        b = layer.bias[node].unsqueeze(dim=0)
        return torch.matmul(w, inputs.unsqueeze(dim=-1)).squeeze(dim=-1) + b

    embedded = input_layer(samples, mask=adj_intervention_matrices)
    # Keep only the embedding of the current variable
    features = embedded[:, node, :]
    if embed_has_activation:
        features = model.layers[1](features)

    hidden = model.layers[lin + 1](node_linear(model.layers[lin], features))
    logits = node_linear(model.layers[lin + 2], hidden)
    return torch.softmax(logits, dim=-1), logits


def logpdf_interventionalSamples(
    config,
    target_node,
    model,
    device,
    samples,
):
    """
    Compute the log-likelihood of samples under the graph passed in
    ``config``, assuming an intervention on ``target_node``.

    Parameters
    ----------
    config : torch.Tensor
        Square adjacency matrix of the hypothesis graph. One copy per sample
        is passed, untransposed, to ``evaluate_likelihoods_for_model``.
    target_node : int
        Intervened variable. It is only forwarded to
        ``evaluate_likelihoods_for_model``; how the intervention is applied
        is up to that function (no edges are cut here).
    model
        Model forwarded to ``evaluate_likelihoods_for_model``.
    device
        Device forwarded to ``evaluate_likelihoods_for_model``.
    samples : torch.Tensor
        Tensor of shape ``(nr_graphs, batch_size, dim)``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(nr_graphs, batch_size)`` holding, per sample, the
        negated sum over dim 1 of the NLL returned by
        ``evaluate_likelihoods_for_model``.
    """
    nr_graphs, batch_size, _ = samples.shape
    full_batch_size = nr_graphs * batch_size

    # Flatten the graph and batch axes into a single batch axis
    samples = samples.reshape(full_batch_size, -1)

    # One mask per flattened sample. Cloning keeps the expanded view from
    # aliasing the caller's `config`.
    # Do not transpose here since it will be transposed in evaluate_likelihoods
    adj = config.clone()
    adj_intervention_matrices = adj.unsqueeze(dim=0).expand(full_batch_size, -1, -1)
    nll = evaluate_likelihoods_for_model(
        model, device, samples, adj_intervention_matrices, target_node
    )

    # Sum the NLL over dim 1 (presumably the per-variable axis — TODO
    # confirm) and negate to obtain a log-likelihood per sample.
    logpdf = -nll.sum(dim=1)

    return logpdf.reshape(nr_graphs, batch_size)
