# synthetic/graph.py
import math
import random as pyrandom
import numpy as np
import igraph as ig
from .abstract import GraphModel
from utils import mat_to_graph, graph_to_mat


def _break_cycles_randomly(rng, mat):
    """
    Make ``mat`` acyclic by deleting back edges found during a randomized DFS.

    DFS roots are tried in the order given by ``rng.permutation``. From each
    unvisited root, a depth-first search follows entries ``mat[u, v] == 1``
    (u -> v). Whenever the search reaches a vertex ``v`` that is still on the
    current DFS path, the edge ``u -> v`` closes a cycle and is set to 0.

    The search uses an explicit stack instead of recursion, so long paths in
    large graphs cannot exceed Python's recursion limit. The visiting order is
    the same as a recursive DFS.

    Args:
        rng: numpy random generator; only ``permutation`` is used here.
        mat: square numpy adjacency matrix; modified in place.

    Returns:
        ``mat`` itself (same object), now acyclic.
    """
    n = mat.shape[0]
    # 0 = unvisited, 1 = on the current DFS path, 2 = finished
    color = [0] * n

    def successors(u):
        # Snapshot u's out-neighbours at the moment u is entered. Only u itself
        # ever edits row u, so later deletions never affect this snapshot.
        return iter(np.where(mat[u, :] == 1)[0])

    for root in rng.permutation(n):
        if color[root] != 0:
            continue
        color[root] = 1
        # Each frame holds a vertex and the iterator over its remaining successors.
        stack = [(root, successors(root))]
        while stack:
            u, pending = stack[-1]
            for v in pending:
                if color[v] == 1:
                    # back edge, which implies a cycle; remove edge that closes the cycle
                    mat[u, v] = 0
                elif color[v] == 0:
                    color[v] = 1
                    stack.append((v, successors(v)))
                    break  # descend into v; u's remaining successors resume later
            else:
                # all successors of u handled
                color[u] = 2
                stack.pop()

    assert mat_to_graph(mat).is_dag()
    return mat


class ErdosRenyi(GraphModel):
    """
    Erdos-Renyi random DAG.

    Each of the ``n_vars * (n_vars - 1) / 2`` entries strictly below the
    diagonal of an adjacency matrix is set independently with probability
    ``p``. ``p`` is chosen so that the expected edge count is
    ``edges_per_var * n_vars``, capped at ``p = 0.99``. The nodes are then
    randomly relabeled.
    """

    def __init__(self, edges_per_var):
        # Target expected number of edges per variable.
        self.edges_per_var = edges_per_var

    def __call__(self, rng, n_vars):
        """
        Sample a DAG adjacency matrix.

        Args:
            rng: numpy ``Generator``-like object; ``binomial`` and
                ``permutation`` are used.
            n_vars: number of nodes.

        Returns:
            ``(n_vars, n_vars)`` integer 0/1 adjacency matrix with zero diagonal.
        """
        # select p s.t. we get requested edges_per_var in expectation
        n_edges = self.edges_per_var * n_vars
        max_edges = n_vars * (n_vars - 1) / 2
        # With fewer than two nodes no edge is possible; avoid dividing by zero.
        p = min(n_edges / max_edges, 0.99) if max_edges > 0 else 0.0

        # sample
        mat = rng.binomial(n=1, p=p, size=(n_vars, n_vars)).astype(int)

        # make DAG by zeroing above diagonal; k=-1 indicates that diagonal is zero too
        dag = np.tril(mat, k=-1)

        # randomly relabel nodes: row-shuffled identity is a permutation matrix,
        # and P.T @ A @ P only renames vertices, so acyclicity is preserved
        perm = rng.permutation(np.eye(n_vars).astype(int))
        return perm.T @ dag @ perm


class ScaleFree(GraphModel):
    """
    Barabasi-Albert (scale-free)
    Power-law in-degree

    Sampled with igraph's directed ``Graph.Barabasi`` generator
    (``m = edges_per_var``, ``power = power``), then randomly relabeled.
    """

    def __init__(self, edges_per_var, power=1.0):
        # passed to igraph's Barabasi as ``m``
        self.edges_per_var = edges_per_var
        # passed to igraph's Barabasi as ``power``
        self.power = power

    @staticmethod
    def _pyrandom_seed(rng):
        """
        Derive an integer seed for Python's ``random`` module from ``rng``.

        For bit generators whose state looks like ``{"state": {"state": <int>}}``
        (e.g. numpy's default PCG64), the integer state is used directly. This
        keeps results identical to earlier versions of this class. Other bit
        generators (e.g. MT19937, Philox, SFC64) do not have that layout; for
        them a seed is drawn from ``rng`` instead of failing.
        """
        inner = rng.bit_generator.state.get("state")
        seed = inner.get("state") if isinstance(inner, dict) else None
        if isinstance(seed, (int, np.integer)):
            return int(seed)
        return int(rng.integers(0, 2**32))

    def __call__(self, rng, n_vars):
        """
        Sample a scale-free graph and return it as an adjacency matrix.

        Side effect: reseeds the global Python ``random`` module.
        NOTE(review): presumably igraph draws its randomness from ``random``,
        which is why it is seeded from ``rng`` here; confirm.
        """
        pyrandom.seed(self._pyrandom_seed(rng))  # seed pyrandom based on state of numpy rng
        _ = rng.normal()  # advance rng state by 1
        perm = rng.permutation(n_vars).tolist()
        g = ig.Graph.Barabasi(n=n_vars, m=self.edges_per_var, directed=True, power=self.power).permute_vertices(perm)
        return graph_to_mat(g)


class Yeast(GraphModel):
    """
    Yeast subnetwork of given size.

    NOTE: simplified placeholder. Calling an instance samples an Erdos-Renyi
    DAG with ``ceil(0.2 * n_vars)`` expected edges per variable. The
    constructor arguments are stored on the instance but are not read by
    ``__call__``.
    """

    def __init__(self, topk=0.2, at_least_n_regulators=1, make_acyclic=True):
        # Stored as attributes only; the simplified sampler ignores them.
        self.topk, self.at_least_n_regulators, self.make_acyclic = (
            topk,
            at_least_n_regulators,
            make_acyclic,
        )

    def __call__(self, rng, n_vars):
        """Delegate to ``ErdosRenyi`` with a size-dependent edges-per-variable."""
        per_var = math.ceil(n_vars * 0.2)
        sampler = ErdosRenyi(edges_per_var=per_var)
        return sampler(rng, n_vars)


class Ecoli(GraphModel):
    """
    E. coli subnetwork of given size.

    NOTE: simplified placeholder. Calling an instance samples a ``ScaleFree``
    graph with 2 edges per variable. The constructor arguments are stored on
    the instance but are not read by ``__call__``.
    """

    def __init__(self, topk=0.2, at_least_n_regulators=1, make_acyclic=True):
        # Stored as attributes only; the simplified sampler ignores them.
        self.topk, self.at_least_n_regulators, self.make_acyclic = (
            topk,
            at_least_n_regulators,
            make_acyclic,
        )

    def __call__(self, rng, n_vars):
        """Delegate to a ``ScaleFree`` model with a fixed 2 edges per variable."""
        sampler = ScaleFree(edges_per_var=2)
        return sampler(rng, n_vars)