import datetime
import heapq
import networkx as nx
import numpy as np
import scipy as sp
import torch
import torch.nn as nn
import torch.nn.functional as F


def hash_time():
    """Return the current local time as a ``YYYY-MM-DD-HH:MM:SS`` string."""
    now = datetime.datetime.now()
    return f"{now:%Y-%m-%d-%H:%M:%S}"

# Performance metrics
def pred_score(pred, truth):
    """Compute precision, recall and F1 of predicted entries against the truth.

    An entry counts as "positive" when it is strictly greater than zero and
    as "negative" when it is exactly zero; the comparison is element-wise.

    Args:
        pred (:class:`numpy.ndarray`):
            The predicted matrix
        truth (:class:`numpy.ndarray`):
            The true matrix, same shape as ``pred``

    Returns:
        precision, recall, f1. A ratio whose denominator is zero (no
        predicted positives, no true positives, or precision + recall == 0)
        is reported as 0.0 instead of nan.
    """
    pred_pos = pred > 0
    truth_pos = truth > 0

    tp = np.sum(pred_pos & truth_pos)
    fp = np.sum(pred_pos & (truth == 0))
    fn = np.sum((pred == 0) & truth_pos)

    # Guard every division: an empty prediction or an empty ground truth
    # would otherwise yield 0/0 -> nan together with a RuntimeWarning.
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    return precision, recall, f1


def shd_metric(pred, target):
    """
    Compute the structural Hamming distance (SHD) between two graphs.

    The distance is the number of extra edges, plus the number of missing
    edges, plus the number of reversed edges, where a reversed edge is
    counted once rather than as one extra and one missing edge.

    Args:
        pred (:class:`numpy.ndarray`):
            The predicted adjacency matrix
        target (:class:`numpy.ndarray`):
            The true adjacency matrix

    Returns:
        shd (a float, because the reversed-edge count is obtained by halving)
    """
    # NOTE(review): the ==1 / ==-1 tests below presume 0/1 adjacency
    # matrices -- confirm callers binarise first.
    delta = target - pred

    # A reversal leaves a non-zero entry at (i, j) that is cancelled by its
    # mirror at (j, i). The mask is symmetric, so each reversal is seen twice.
    mirrored = (delta != 0) & ((delta + delta.T) == 0)
    n_reversed = mirrored.sum() / 2

    # Every reversal also shows up as one +1 and one -1; take those out so
    # it is not counted three times.
    n_missing = (delta == 1).sum() - n_reversed
    n_extra = (delta == -1).sum() - n_reversed

    return n_extra + n_missing + n_reversed


def eval_dag(pred, true):
    """
    Score a predicted adjacency matrix against the true one.

    Args:
        pred (:class:`numpy.ndarray`):
            The predicted adjacency matrix
        true (:class:`numpy.ndarray`):
            The true adjacency matrix

    Returns:
        dict with the keys ``'Precision'``, ``'Recall'``, ``'F1'`` (from
        ``pred_score``) and ``'SHD'`` (from ``shd_metric``)

    """
    prec, rec, f1_score = pred_score(pred, true)
    return {
        'Precision': prec,
        'Recall': rec,
        'F1': f1_score,
        'SHD': shd_metric(pred, true),
    }


# DAG metrics
def fas_lb(adj, thred=None):
    """Calculates a strict lower bound of the minimum feedback arc set

    Self-loops are ignored. Within each strongly connected component (SCC)
    with more than one node the bound is the number of 2-cycles (pairs
    ``u <-> v``). Distinct 2-cycles never share an edge, so each one needs
    its own edge removal. A non-trivial SCC without any 2-cycle still
    contains at least one cycle and therefore contributes 1.

    Args:
        adj (np.ndarray): adjacency matrix
        thred (float, optional): threshold if adj is not binary; entries
            strictly greater than it become edges. Any real number is
            accepted. Defaults to None, in which case every non-zero
            off-diagonal entry is an edge.

    Returns:
        int: lower bound of minimum feedback arc set
    """
    # strict but loose lower bound
    off_diag = 1 - np.eye(adj.shape[0])
    # `is not None` rather than isinstance(..., float): an int or
    # numpy.float32 threshold must not be silently ignored.
    if thred is not None:
        adj = (adj > thred).astype(int) * off_diag
    else:
        adj = adj * off_diag
    g = nx.DiGraph(adj)

    lower_bound = 0
    for scc in nx.strongly_connected_components(g):
        if len(scc) < 2:
            continue
        sub_g = g.subgraph(scc)

        # length-2 cycles must be removed; `src < tar` counts each pair once
        n_two_cycles = sum(
            1 for src, tar in sub_g.edges
            if src < tar and sub_g.has_edge(tar, src)
        )
        # Do NOT add an extra 1 for nodes outside the 2-cycles: the cycle
        # through such a node may share an edge with a 2-cycle (e.g.
        # 0->1, 1->0, 1->2, 2->0 is fixed by removing 0->1 alone), so the
        # sum would no longer be a valid lower bound.
        lower_bound += max(n_two_cycles, 1)

    return lower_bound


def fas_asym_lb(adj, thred=None):
    """Calculates an asymptotic lower bound of the minimum feedback arc set as in the paper
        Diamond, Harvey, Mark Kon, and Louise Raphael. "Asymptotic Lower Bounds for the Feedback Arc Set Problem in Random Graphs."
        arXiv preprint arXiv:2409.16443 (2024).

    The formula ``m * (1/2 - sqrt(log(n) / (2 * m / n)))`` is evaluated on
    every strongly connected component (SCC) with more than one node
    (``m`` edges, ``n`` nodes) and the results are summed. A component whose
    term is negative contributes 0, so the result is never negative.

    Args:
        adj (np.ndarray): adjacency matrix
        thred (float, optional): threshold if adj is not binary; entries
            strictly greater than it become edges. Any real number is
            accepted. Defaults to None, in which case every non-zero
            off-diagonal entry is an edge.

    Returns:
        int: lower bound of minimum feedback arc set
    """
    # asymptotic lower bound
    off_diag = 1 - np.eye(adj.shape[0])
    # `is not None` rather than isinstance(..., float): an int or
    # numpy.float32 threshold must not be silently ignored.
    if thred is not None:
        adj = (adj > thred).astype(int) * off_diag
    else:
        adj = adj * off_diag
    g = nx.DiGraph(adj)

    lower_bound = 0.0
    for scc in nx.strongly_connected_components(g):
        if len(scc) < 2:
            continue
        sub_g = g.subgraph(scc)
        n_edge = sub_g.number_of_edges()
        n_vertex = sub_g.number_of_nodes()
        # A non-trivial SCC has at least as many edges as nodes, so
        # avg_degree >= 1 and the division below is safe.
        avg_degree = n_edge / n_vertex
        scc_bound = n_edge * (0.5 - np.sqrt(np.log(n_vertex) / 2 / avg_degree))
        # For sparse components the term is negative, which says nothing
        # about that component; clamp it so it cannot cancel the bound
        # contributed by other components or make the total negative.
        lower_bound += max(scc_bound, 0.0)
    return int(np.floor(lower_bound))


def fas_ub(adj, thred=None):
    """Calculates an upper bound of the minimum feedback arc set as in the paper
        Eades, Peter, Xuemin Lin, and William F. Smyth. "A fast and effective heuristic for the feedback arc set problem."
        Information processing letters 47.6 (1993): 319-323.

    ``m / 2 - n / 6`` is evaluated on every strongly connected component
    (SCC) with more than one node (``m`` edges, ``n`` nodes); the terms are
    summed and the total is rounded up. Self-loops are ignored.

    Args:
        adj (np.ndarray): adjacency matrix
        thred (float, optional): threshold if adj is not binary; entries
            strictly greater than it become edges. Any real number is
            accepted. Defaults to None, in which case every non-zero
            off-diagonal entry is an edge.

    Returns:
        int: upper bound of minimum feedback arc set
    """
    # NOTE(review): the m/2 - n/6 guarantee in the cited paper is, as far as
    # I can tell, stated for connected digraphs without 2-cycles. With
    # bidirectional edges the value can fall below the true minimum (three
    # disjoint pairs u<->v give ceil(3 * 2/3) = 2 but need 3 removals) --
    # confirm whether inputs may contain 2-cycles.
    off_diag = 1 - np.eye(adj.shape[0])
    # `is not None` rather than isinstance(..., float): an int or
    # numpy.float32 threshold must not be silently ignored.
    if thred is not None:
        adj = (adj > thred).astype(int) * off_diag
    else:
        adj = adj * off_diag
    g = nx.DiGraph(adj)

    upper_bound = 0
    for scc in nx.strongly_connected_components(g):
        if len(scc) < 2:
            continue
        sub_g = g.subgraph(scc)
        n_edge = sub_g.number_of_edges()
        n_vertex = sub_g.number_of_nodes()
        upper_bound += n_edge / 2 - n_vertex / 6

    return int(np.ceil(upper_bound))


def els_alg(g):
    """Greedy vertex-ordering heuristic for the feedback arc set.

    Builds a linear order of the nodes by repeatedly peeling sinks (placed
    at the back), sources (placed at the front) and, when neither sweep
    empties the graph, the node with the largest ``out_degree - in_degree``
    (placed at the front). The edges that point backwards in that order form
    a feedback arc set.

    Warning:
        ``g`` is modified in place: its self-loops are removed and every
        node is removed by the time the function returns. Pass a copy if
        the graph is still needed.

    Args:
        g (nx.DiGraph): directed graph to order (consumed)

    Returns:
        int: number of edges of the original graph (self-loops excluded)
        that point backwards in the computed order
    """
    # remove self loops
    g.remove_edges_from(nx.selfloop_edges(g))

    # Snapshot the edges now; the graph is dismantled below.
    edges = list(g.edges)

    front = []          # sources / picked nodes, in final order
    back_reversed = []  # sinks, collected in reverse of their final order

    while g.number_of_nodes() > 0:
        # One sweep over a snapshot of the nodes; degrees are read live, so
        # a node that becomes a sink/source later in the same sweep is
        # still picked up.
        for v in list(g.nodes):  # sink
            if g.out_degree(v) == 0:
                # Appending here and reversing once at the end is
                # equivalent to prepending, without the O(n) insert.
                back_reversed.append(v)
                g.remove_node(v)
        for u in list(g.nodes):  # source
            if g.in_degree(u) == 0:
                front.append(u)
                g.remove_node(u)

        if g.number_of_nodes() == 0:
            break
        # Select the node with the largest out-degree surplus; max() keeps
        # the first maximal node, i.e. ties go to the earliest node.
        pivot = max(g.nodes, key=lambda u: g.out_degree(u) - g.in_degree(u))
        front.append(pivot)
        g.remove_node(pivot)

    # Count reversed edges
    order = front + back_reversed[::-1]
    rank = {u: idx for idx, u in enumerate(order)}
    return sum(1 for src, tar in edges if rank[src] > rank[tar])
    
    
def fas_greedy(adj, thred=None):
    """Estimates the size of a feedback arc set with the greedy ``els_alg`` ordering

    The graph is split into strongly connected components (SCCs); every
    component with more than one node is handed to ``els_alg`` and the
    returned counts of backward edges are summed.

    Args:
        adj (np.ndarray): adjacency matrix
        thred (float, optional): threshold if adj is not binary; entries
            strictly greater than it become edges and the diagonal is
            cleared. Any real number is accepted. Defaults to None, in
            which case ``adj`` is used as given.

    Returns:
        int: number of edges the greedy ordering leaves pointing backwards,
        summed over all non-trivial SCCs
    """
    # `is not None` rather than isinstance(..., float): an int or
    # numpy.float32 threshold must not be silently ignored.
    if thred is not None:
        adj = (adj > thred).astype(int) * (1 - np.eye(adj.shape[0]))
    g = nx.DiGraph(adj)

    edit_dist = 0
    for scc in nx.strongly_connected_components(g):
        if len(scc) > 1:
            # els_alg removes every node from the graph it is given, so it
            # gets an independent copy of the component.
            edit_dist += els_alg(g.subgraph(scc).copy())
    return edit_dist

