"""
In this module, we report the necessary
methods to sample a pair of abstract
and concrete models.
"""
from typing import Tuple

import igraph as ig
import numpy as np

from .utils import check_faithfulness, compute_mechanism, linear_anm

# Alias used in the annotations of this module for graph adjacency
# matrices (binary or weighted), which are stored as plain NumPy arrays.
AdjacencyMatrix = np.ndarray


def is_dag(weight_adj: AdjacencyMatrix) -> bool:
    """
    Tells whether a weighted adjacency matrix encodes a DAG.

    The matrix is turned into a nested list, loaded into an igraph
    graph, and igraph's own acyclicity test provides the answer.
    """
    return ig.Graph.Weighted_Adjacency(weight_adj.tolist()).is_dag()


def sample_random_dag(
    num_nodes: int,
    num_edges: int,
    graph_type: str,
    ordered: bool = False,
) -> ig.Graph:
    """
    Draws a random DAG of the requested family.

    Parameters:
    -----------
    num_nodes: int
        The number of nodes of the graph.
    num_edges: int
        The number of edges requested from the igraph generator
        (for "SF" it is converted into edges-per-node).
    graph_type: str
        "ER" (Erdos-Renyi), "SF" (scale-free, Barabasi-Albert)
        or "BP" (bipartite). Anything else raises a ValueError.
    ordered: bool
        If True, the nodes are relabeled following a topological
        ordering computed by igraph.

    Returns:
    --------
    graph: ig.Graph
        The sampled graph, built from the final adjacency matrix.
    """

    def _shuffle_labels(adj: AdjacencyMatrix) -> AdjacencyMatrix:
        # Shuffling the rows of the identity (np.random.permutation only
        # permutes the first axis) gives a random permutation matrix.
        perm = np.random.permutation(np.eye(adj.shape[0]))
        return perm.T @ adj @ perm

    def _to_matrix(g: ig.Graph) -> AdjacencyMatrix:
        return np.array(g.get_adjacency().data)

    if graph_type == "ER":
        # Erdos-Renyi: sample undirected, then orient the edges by keeping
        # the strictly lower triangle of a randomly relabeled matrix
        undirected = _to_matrix(
            ig.Graph.Erdos_Renyi(n=num_nodes, m=num_edges)
        )
        binary = np.tril(_shuffle_labels(undirected), k=-1)
    elif graph_type == "SF":
        # Scale-free, Barabasi-Albert
        edges_per_node = int(round(num_edges / num_nodes))
        binary = _to_matrix(
            ig.Graph.Barabasi(n=num_nodes, m=edges_per_node, directed=True)
        )
    elif graph_type == "BP":
        # Bipartite, Sec 4.1 of (Gu, Fu, Zhou, 2018)
        top = int(0.2 * num_nodes)
        bottom = num_nodes - top
        binary = _to_matrix(
            ig.Graph.Random_Bipartite(
                top, bottom, m=num_edges, directed=True, neimode=ig.OUT
            )
        )
    else:
        raise ValueError("unknown graph type")

    # Hide the generator's node ordering behind a random relabeling
    binary = _shuffle_labels(binary)

    if ordered:
        # Relabel the nodes according to a topological ordering
        order = ig.Graph.Weighted_Adjacency(
            binary.tolist()
        ).topological_sorting()
        binary = binary[order][:, order]

    return ig.Graph.Weighted_Adjacency(binary.tolist())


def sample_linear_anm(
    nodes: int,
    edges: int,
    graph_type: str,
    w_ranges=((-2.0, -0.5), (0.5, 2.0)),
) -> AdjacencyMatrix:
    """
    Samples the weighted adjacency matrix of a linear ANM.

    A topologically ordered random DAG is drawn first; every entry
    of its adjacency matrix is then assigned one of the intervals
    in `w_ranges` uniformly at random and multiplied by a weight
    drawn uniformly from that interval. Entries that are zero in
    the DAG adjacency stay zero.
    """
    # binary support of the model
    dag = sample_random_dag(nodes, edges, graph_type, ordered=True)
    support = np.array(dag.get_adjacency().data)
    shape = support.shape

    # for every entry, pick which interval its weight comes from
    chosen_range = np.random.randint(len(w_ranges), size=shape)

    weights = np.zeros(shape)
    for idx, bounds in enumerate(w_ranges):
        low, high = bounds
        draw = np.random.uniform(low=low, high=high, size=shape)
        # keep the draw only on edges assigned to this interval
        weights += support * (chosen_range == idx) * draw

    return weights


def sample_linear_concretization(
    abs_weights: AdjacencyMatrix,
    tau_adj: np.ndarray,
    readout_size: list[int],
    internal: bool = True,
    alpha: float = 1.0,
    style: str = "dirichlet",
) -> Tuple[AdjacencyMatrix, np.ndarray]:
    """
    Given an abstract model and an abstraction function,
    the function samples a concrete model and the
    corresponding gamma function.

    Parameters:
    -----------
    abs_weights: AdjacencyMatrix
        The weighted adjacency matrix of the abstract model.
        Concrete cross-block weights are only built for pairs
        of abstract nodes (a, b) with a < b, so the abstract
        model is presumably expected to be topologically
        ordered (to be confirmed against the callers).
    tau_adj: np.ndarray
        The abstraction function; rows index concrete variables
        and columns index abstract variables.
    readout_size: list[int]
        The sizes of the consecutive blocks of concrete variables.
        Entry y is the block of abstract variable y; the entry
        after the last abstract variable (also read as the last
        entry) is the block of ignored variables.
    internal: bool
        If True, strictly upper-triangular weights are sampled
        within each block. Otherwise the internal mechanism of
        each block is the identity.
    alpha: float
        The concentration of the Dirichlet distribution used
        when style is "dirichlet".
    style: str
        How an edge is split over the chosen target variables:
        "uniform" (equal shares) or "dirichlet".

    Returns:
    --------
    cnc_weights: AdjacencyMatrix
        The weighted adjacency matrix of the concrete model.
    gamma_adj: np.ndarray
        The gamma function, i.e., a copy of tau_adj where each
        block is premultiplied by the block internal mechanism.

    Raises:
    -------
    ValueError
        If style is neither "uniform" nor "dirichlet".
    AssertionError
        If one of the internal consistency checks fails.
    """
    # Abstract model
    abs_nodes = abs_weights.shape[0]
    # first concrete index of each block
    readout_start = np.cumsum([0] + readout_size[:-1])

    # Helper to get block-indices
    def get_block(y: int) -> slice:
        return slice(readout_start[y], readout_start[y] + readout_size[y])

    # Paper notation
    gamma_adj = np.copy(tau_adj)

    # Handle block internal
    f_blocks = []
    cnc_nodes = tau_adj.shape[0]
    cnc_weights = np.zeros((cnc_nodes, cnc_nodes))
    for y in range(abs_nodes):
        block_y = get_block(y)
        size_y = readout_size[y]
        if internal:
            # NOTE: As of now, all the mediators are valid since
            #       we are generating a fully-connected block
            #       where the last variable in the topological
            #       order is relevant. There are scenarios where
            #       the block is not fully-connected but still
            #       the mediators are valid. We leave this as
            #       future work.

            # sample the internal connections
            cnc_weights[block_y, block_y] = np.random.normal(
                size=(size_y, size_y)
            ) / np.sqrt(size_y)
            # remove the lower-triangular part
            cnc_weights[block_y, block_y] = np.triu(
                cnc_weights[block_y, block_y], k=1
            )
            # compute internal mechanism
            f_yy = compute_mechanism(cnc_weights[block_y, block_y])
        else:
            # no internal connections: the internal mechanism is the identity
            f_yy = np.eye(size_y)
        f_blocks.append(f_yy)
        # compute gamma
        gamma_adj[block_y, y] = f_blocks[y] @ tau_adj[block_y, y]

    # Abstract mechanism
    abs_model = compute_mechanism(abs_weights)

    # Remainder matrix (cnc_nodes, cnc_nodes)
    remainder = np.zeros_like(cnc_weights)

    # Concrete weights
    for y_b in range(abs_nodes):
        # target block
        block_b = slice(
            readout_start[y_b], readout_start[y_b] + readout_size[y_b]
        )

        # target abstraction
        s_b = gamma_adj[block_b, y_b]

        # allowed concrete targets (block-local indices with non-zero gamma)
        all_targets = list(np.where(np.abs(s_b) > 0.0)[0])

        # sources are visited from the closest to the farthest block, so
        # the weights and remainders of intermediate blocks are available
        for y_a in list(reversed(range(y_b))):
            # source block
            block_a = slice(
                readout_start[y_a], readout_start[y_a] + readout_size[y_a]
            )

            # abstraction function
            t_a = tau_adj[block_a, y_a]

            # abstract mechanism
            g_ab = abs_model[y_a, y_b]
            m_ab = abs_weights[y_a, y_b]

            # Update the remainder
            i, j = y_a, y_b
            start = min(i, j) + 1
            end = max(i, j)
            # accumulate the contribution of every block k strictly
            # between the source and the target block
            for k in range(start, end):
                # compute k-th step
                block_i = get_block(i)
                block_j = get_block(j)
                block_k = get_block(k)
                w_ik = cnc_weights[block_i, block_k]
                w_kj = cnc_weights[block_k, block_j]
                r_kj = remainder[block_k, block_j]
                f_kk = f_blocks[k]

                # update the remainder
                remainder[block_i, block_j] += w_ik @ f_kk @ (w_kj + r_kj)

            # extract the remainder
            r_ab = remainder[block_a, block_b]

            # compute the weights, one source variable (row) at a time
            for k in range(readout_size[y_a]):
                # random choice half of the possible targets variables
                targets = np.random.choice(
                    all_targets,
                    size=max(len(all_targets) // 2, 1),
                    replace=False,
                )

                # build the assignment vector
                v = np.zeros_like(s_b)

                if style == "uniform":
                    dist = 1 / len(targets)
                elif style == "dirichlet":
                    dist = np.random.dirichlet(np.ones(len(targets)) * alpha)
                else:
                    raise ValueError(f"Unknown style {style}.")

                # shares sum to one over the targets; dividing by s_b
                # makes the row satisfy c_b @ s_b == 1
                c_b = np.zeros_like(s_b)
                v[targets] = dist
                c_b[all_targets] = v[all_targets] / s_b[all_targets]

                # compute block
                cnc_weights[block_a, block_b][k, :] = m_ab * t_a[k] * c_b

            # check weights closed form
            w_ab = cnc_weights[block_a, block_b]
            s_b = s_b.reshape((len(s_b), 1))
            t_a = t_a.reshape((len(t_a), 1))
            assert np.allclose(w_ab @ s_b, g_ab * t_a - r_ab @ s_b)
            # restore the 1-D shape, s_b is reused by the next source block
            s_b = s_b.reshape((len(s_b),))

    # set connections towards ignored variables
    block_ignored = get_block(abs_nodes)
    for y in range(abs_nodes):
        block_y = get_block(y)
        # random weights from block_y to block_ignored
        cnc_weights[block_y, block_ignored] = np.random.normal(
            size=(readout_size[y], readout_size[-1])
        ) / np.sqrt(readout_size[y])
        # randomly mask half of the weights
        mask = np.random.randint(
            2, size=cnc_weights[block_y, block_ignored].shape
        )
        # set the masked weights to zero
        cnc_weights[block_y, block_ignored] *= mask
    # set connections from block_ignored to block_ignored
    cnc_weights[block_ignored, block_ignored] = np.random.normal(
        size=(readout_size[-1], readout_size[-1])
    ) / np.sqrt(readout_size[-1])
    # remove the lower-triangular part
    cnc_weights[block_ignored, block_ignored] = np.triu(
        cnc_weights[block_ignored, block_ignored], k=1
    )

    # consistency when FT = SG
    cnc_model = compute_mechanism(cnc_weights)
    assert np.allclose(cnc_model @ tau_adj, gamma_adj @ abs_model, atol=1e-4)

    # check faithfulness
    if not check_faithfulness(cnc_weights):
        # NOTE: It should raise an exception, so that
        #       the sampling is repeated.
        # raise ValueError("The concrete model is not faithful.")
        print("WARNING: The concrete model is not faithful.")

    # return cnc_weights, gamma_adj, remainder
    return cnc_weights, gamma_adj


def sample_linear_abstraction(
    abs_nodes: int,
    min_readout: int,
    max_readout: int,
    marginalize_ratio: float = 0.2,
    tau_ranges: tuple = (0.5, 2.0),
) -> np.ndarray:
    """
    Samples a linear abstraction function.

    Concrete variables are split into abs_nodes + 1 consecutive
    blocks whose sizes are drawn uniformly in
    [min_readout, max_readout]. Column y of the result is non-zero
    only on block y; the rows of the extra last block stay zero.
    Non-zero coefficients have a magnitude drawn uniformly from
    tau_ranges and a random sign. Within each block, a random
    subset of at most floor((size - 1) * marginalize_ratio)
    variables, never including the last one, is zeroed out.

    Returns the (concrete_nodes, abs_nodes) abstraction matrix.
    """
    # Block sizes, including the trailing extra block
    sizes = []
    for _ in range(abs_nodes + 1):
        sizes.append(np.random.randint(min_readout, max_readout + 1))

    tau_adj = np.zeros((sum(sizes), abs_nodes))

    # Fill one column per abstract variable, walking through the blocks
    offset = 0
    for y, size in enumerate(sizes[:abs_nodes]):
        # random signs and magnitudes of the coefficients
        flips = np.random.randint(2, size=size)
        magnitudes = np.random.uniform(
            low=tau_ranges[0], high=tau_ranges[1], size=size
        )
        column = magnitudes * (-1) ** flips

        # every variable but the last one can be marginalized
        candidates = size - 1
        n_dropped = np.random.randint(
            0, int(candidates * marginalize_ratio) + 1
        )
        dropped = np.random.choice(candidates, n_dropped, replace=False)
        column[dropped] = 0.0

        tau_adj[offset:offset + size, y] = column
        offset += size

    return tau_adj


def sample_linear_abstracted_models(
    abs_nodes: int,
    abs_edges: int,
    abs_type: str,
    min_readout: int = 2,
    max_readout: int = 5,
    alpha: float = 1.0,
    marginalize_ratio: float = 0.2,
    internal: bool = True,
    ignored: bool = True,
    style: str = "dirichlet",
) -> Tuple[AdjacencyMatrix, AdjacencyMatrix, np.ndarray, np.ndarray, list]:
    """
    The function samples a pair of ANMs that are
    abstracted by a randomly sampled abstraction
    linear function. It returns the adjacencies
    of the concrete and abstract models, the
    endogenous abstraction function, the
    exogenous abstraction function and the
    sizes of the concrete blocks.

    Parameters:
    -----------
    abs_nodes: int
        The number of abstract variables.
    abs_edges: int
        The number of edges in the abstract model.
    abs_type: str
        The type of abstract graph to sample (ER, SF, BP).
    min_readout: int
        The minimum number of concrete variables
        per abstract variable.
    max_readout: int
        The maximum number of concrete variables
        per abstract variable.
    alpha: float
        The concentration of the Dirichlet distribution
        used by the concretization when style is "dirichlet".
    marginalize_ratio: float
        The ratio of concrete variables to marginalize
        within each readout. The number of marginalized
        variables is uniformly sampled in the range
        [0, floor((readout_size - 1) * marginalize_ratio)],
        since the last variable of a readout is never
        marginalized.
    internal: bool
        If True, the internal connections are sampled
        as well. Otherwise, the blocks are considered
        to be internally disconnected.
    ignored: bool
        If False, the trailing block of ignored concrete
        variables is removed from all the returned objects.
    style: str
        How the concretization distributes an abstract edge
        over the concrete targets ("dirichlet" or "uniform").

    Returns:
    --------
    concrete_weighted: AdjacencyMatrix
        The adjacency matrix of the concrete model.
    abstract_weighted: AdjacencyMatrix
        The adjacency matrix of the abstract model.
    tau_adj: np.ndarray
        The endogenous abstraction function.
    gamma_adj: np.ndarray
        The exogenous abstraction function.
    readout_size: list
        The number of concrete variables of each abstract
        variable, followed by the number of ignored variables
        when ignored is True.
    """

    # Sample the abstract model
    abs_weights = sample_linear_anm(abs_nodes, abs_edges, abs_type)

    # Sample the abstraction
    tau_adj = sample_linear_abstraction(
        abs_nodes, min_readout, max_readout, marginalize_ratio
    )
    cnc_nodes = tau_adj.shape[0]

    # Measure the readouts: the block of abstract variable y ends at the
    # last concrete variable with a non-zero coefficient in column y
    readout_size = []
    start = 0
    for y in range(abs_nodes):
        relevant = np.flatnonzero(tau_adj[:, y] != 0.0)
        assert relevant.size > 0
        size_y = int(relevant[-1]) + 1 - start
        readout_size.append(size_y)
        start += size_y

    # Whatever follows the last readout is the block of ignored variables
    n_ignored = cnc_nodes - start
    readout_size.append(n_ignored)

    # Sample the concretization
    cnc_weights, gamma_adj = sample_linear_concretization(
        abs_weights,
        tau_adj,
        readout_size,
        internal=internal,
        alpha=alpha,
        style=style,
    )

    # eventually remove ignored variables
    if not ignored:
        # Use an explicit end index: slicing with [:-n_ignored] would
        # return empty arrays when there are no ignored variables.
        kept = cnc_nodes - n_ignored
        tau_adj = tau_adj[:kept, :]
        gamma_adj = gamma_adj[:kept, :]
        readout_size = readout_size[:-1]
        cnc_weights = cnc_weights[:kept, :kept]

    return cnc_weights, abs_weights, tau_adj, gamma_adj, readout_size


def sample_linear_realizations(
    cnc_weights: AdjacencyMatrix,
    abs_weights: AdjacencyMatrix,
    tau: np.ndarray,
    gamma: np.ndarray,
    n_samples: int = 1000,
    noise_term: str = "gaussian",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Draws paired samples from a concrete linear ANM and its abstraction.

    Exogenous noise is sampled for the concrete variables from the
    distribution named by `noise_term` ("gaussian", "exponential",
    "gumbel", "uniform" or "logistic", all with NumPy's default
    parameters), pushed through the concrete model, and the result
    is mapped to the abstract variables with `tau`. As a sanity
    check, the abstract samples are also recomputed from the
    abstract model fed with the noise mapped through `gamma`;
    a warning is printed if the two disagree.

    Returns the tuple (concrete samples, abstract samples).
    Raises a ValueError for an unknown `noise_term`.
    """
    # one exogenous term per concrete variable
    n_concrete = cnc_weights.shape[0]

    # available exogenous distributions
    samplers = {
        "gaussian": np.random.normal,
        "exponential": np.random.exponential,
        "gumbel": np.random.gumbel,
        "uniform": np.random.uniform,
        "logistic": np.random.logistic,
    }
    draw = samplers.get(noise_term)
    if draw is None:
        raise ValueError(f"Unknown noise_term type {noise_term}")
    noise = draw(size=(n_samples, n_concrete))

    # concrete endogenous samples and their abstraction
    concrete = linear_anm(cnc_weights, noise)
    abstract = concrete @ tau

    # the abstract model on the abstracted noise should give the same
    abstract_check = linear_anm(abs_weights, noise @ gamma)
    if not np.allclose(abstract, abstract_check):
        gap = np.linalg.norm(abstract - abstract_check, ord=2)
        print(f"WARNING: Consistency error {gap}.")

    return concrete, abstract
