import networkx as nx
import numpy as np
import torch
from ..fdag import SPNFG


class SPNFGGenerator:
    """
    Sample the structure of a module graph from an ``SPNFG`` model.

    Vertices are indexed contiguously: ``features`` feature nodes
    (``0 .. features-1``), then ``modules`` module nodes
    (``features .. features+modules-1``), then ``num_intv`` intervention
    nodes (the remaining indices up to ``total_vertices-1``).

    Args:
        features (int): Number of feature nodes in the graph to generate.
        modules (int): Number of module nodes.
        num_intv (int): Number of intervention nodes; ``0`` disables them.
        spn_target, max_copies, tau, p_conn, sparsity_temp: Forwarded
            unchanged to ``SPNFG``; see that class for their meaning.
    """

    # Sanity-check failures tolerated before the last sample is kept anyway.
    max_failures = 20
    # Consecutive cyclic samples tolerated within one call before raising
    # RuntimeError instead of retrying forever.
    max_invalid_retries = 1000

    def __init__(
        self,
        features=100,
        modules=10,
        num_intv=0,
        spn_target='node',
        max_copies=8,
        tau=1.0,
        p_conn=0.1,
        sparsity_temp=0.1
    ) -> None:
        self.features = features
        self.modules = modules
        self.num_intv = num_intv
        self.total_vertices = self.features + self.modules + self.num_intv
        # NOTE(review): num_intv is passed before modules here; confirm this
        # matches SPNFG's positional signature.
        self.sample_uv = SPNFG(
            features,
            num_intv,
            modules,
            spn_target,
            max_copies,
            tau,
            p_conn,
            sparsity_temp
        )
        self.n_failure = 0

    def _bipartite_half_info(self):
        """Print degree statistics of ``self.bipartite_half`` (and W coverage)."""
        deg = np.sum(self.bipartite_half, 0) + np.sum(self.bipartite_half, 1)
        n_disconnected = np.sum(deg == 0)
        print('---------------------------------')
        print('Bipartite half info:')
        print(f'Max degree: {np.max(deg)}, Min degree: {np.min(deg)}, Avg degree: {np.mean(deg)}')
        print(f'Q1: {np.quantile(deg, 0.25)}, Median {np.median(deg)}, Q3: {np.quantile(deg, 0.75)}')
        if self.W is not None:
            print('Effective interventions', np.sum(np.any(self.W > 0, axis=1)))
            print('Affected modules', np.sum(np.any(self.W > 0, axis=0)))
        print('Disconnected nodes:', n_disconnected)
        print('---------------------------------')

    def _bipartite_half_sanity(self):
        """Return True when the current sample is acceptable.

        At most 25% of features may have zero degree in ``bipartite_half``.
        When interventions exist, more than half of them must have a positive
        weight to some module, and more than half of the modules must receive
        a positive intervention weight.
        """
        deg = np.sum(self.bipartite_half, 0) + np.sum(self.bipartite_half, 1)
        n_disconnected = np.sum(deg == 0)
        check = n_disconnected <= 0.25 * self.features
        if self.W is not None:
            check_w = np.sum(np.any(self.W > 0, axis=1)) > 0.5 * self.num_intv \
                and np.sum(np.any(self.W > 0, axis=0)) > 0.5 * self.modules
            return check and check_w
        return check

    def _causal_order(self):
        """Return the integer vertex order derived from ``sample_uv.fg.logpy``.

        Each feature is assigned to group ``argmax(logpy[feature])``. Groups are
        emitted in index order, and group ``i < modules`` is followed by module
        vertex ``features + i``. The result always has an integer dtype, even
        when a group is empty.
        """
        logits = self.sample_uv.fg.logpy.detach().cpu().numpy()
        partition = np.argmax(logits, 1)
        assert partition.shape == (self.features,)
        groups = [[] for _ in range(self.modules + 1)]
        for feature, group in enumerate(partition):
            groups[group].append(feature)
        for m in range(self.modules):
            groups[m].append(self.features + m)
        # Build in one shot with an explicit dtype: concatenating per-group
        # arrays promotes to float64 whenever a group is empty.
        order = np.array([v for group in groups for v in group], dtype=int)
        assert len(order) == self.features + self.modules
        return order

    def _sample_factors(self):
        """Draw one sample and store it as numpy arrays in U, V and W."""
        with torch.no_grad():
            if self.num_intv == 0:
                U, V = self.sample_uv(1)
                self.U = U.squeeze(0).cpu().numpy()
                self.V = V.squeeze(0).cpu().numpy()
                self.W = None
            else:
                Util, V = self.sample_uv(1)
                self.U = Util[:, :self.features, :].squeeze(0).cpu().numpy()
                self.V = V[:, :self.features, :].squeeze(0).cpu().numpy()
                self.W = Util[:, self.features:, :].squeeze(0).cpu().numpy()
            # Normalise V to (modules, features). NOTE(review): this is ambiguous
            # when features == modules; confirm SPNFG's output layout.
            if self.V.shape == (self.features, self.modules):
                self.V = self.V.T

    def _build_adjacency(self):
        """Assemble the weighted adjacency matrix from U, V and (optional) W.

        Only three blocks are filled: feature->module (U), module->feature (V)
        and intervention->module (W). All other entries stay zero.
        """
        n_feat, n_mod = self.features, self.modules
        mods = slice(n_feat, n_feat + n_mod)
        adj = np.zeros((self.total_vertices, self.total_vertices))
        adj[:n_feat, mods] = self.U[:, :n_mod]
        adj[mods, :n_feat] = self.V[:n_mod]
        if self.W is not None:
            adj[n_feat + n_mod:, mods] = self.W[:, :n_mod]
        return adj

    def _build_graph(self, adj_mtx):
        """Wrap ``adj_mtx`` in a DiGraph and set each vertex's ``type`` attribute."""
        g = nx.DiGraph(adj_mtx)
        modules_end = self.features + self.modules
        roles = (
            ("node", range(self.features)),
            ("module", range(self.features, modules_end)),
            ("intervention", range(modules_end, self.total_vertices)),
        )
        for role, vertices in roles:
            for vertex in vertices:
                g.nodes[vertex]["type"] = role
        return g

    def __call__(self):
        """Sample until an acyclic graph passes the sanity check.

        Cyclic samples are always resampled. Samples that fail
        ``_bipartite_half_sanity`` are resampled until more than
        ``max_failures`` failures have accumulated, after which the last
        sample is kept.

        Side effects: sets ``U``, ``V``, ``W``, ``adj_mtx``, ``bipartite_half``,
        ``intv_half`` (only when interventions exist), ``module_list`` and ``g``.

        Returns:
            tuple: ``(g, causal_order)``. ``g`` is an ``nx.DiGraph`` whose vertices
            carry a ``"type"`` attribute; ``causal_order`` is an integer array
            from ``_causal_order``.

        Raises:
            RuntimeError: if more than ``max_invalid_retries`` consecutive
                samples are cyclic.
        """
        n_invalid = 0
        while True:
            causal_order = self._causal_order()
            self._sample_factors()
            self.adj_mtx = self._build_adjacency()
            self.bipartite_half = np.where(np.dot(self.U, self.V) > 0, 1, 0)
            if self.W is not None:
                self.intv_half = np.where(np.dot(self.W, self.V) > 0, 1, 0)
            self.module_list = np.arange(self.features, self.features + self.modules)
            self.g = self._build_graph(self.adj_mtx)

            # Explicit check (not assert) so it survives `python -O`; O(V+E).
            if not nx.is_directed_acyclic_graph(self.g):
                n_invalid += 1
                if n_invalid > self.max_invalid_retries:
                    self.n_failure = 0
                    raise RuntimeError(
                        f"Failed to sample an acyclic graph after {n_invalid} attempts."
                    )
                print("Regenerating, graph non valid...")
                continue

            if self._bipartite_half_sanity():
                break
            self.n_failure += 1
            if self.n_failure > self.max_failures:
                print("Too many failures. Save the last generated graph.")
                break
            print("Regenerating, sanity check failed...")

        self.n_failure = 0
        self._bipartite_half_info()
        return self.g, causal_order
