import networkx as nx
import numpy as np
from graph_utils import *
from itertools import combinations
from collections import defaultdict
import re
import pandas as pd
import math
import io

def generate_subgraph_list(ground_truth, n_subgraphs, p_overlap) -> list:
    '''
    Splits the nodes of a ground truth graph into overlapping node sets and returns the induced subgraphs.

    Every node is first allocated to exactly one of n_subgraphs roughly equal-sized random groups.
    Each group then additionally samples, without replacement, ceil(p_overlap * k) of the k nodes
    it does not already contain, so p_overlap controls how strongly the subgraphs overlap.
    The returned subgraphs are *induced* subgraphs, not yet causally accurate.

    Parameters: 
        nx.DiGraph ground_truth: full ground-truth graph.
        int n_subgraphs: number of subgraphs to generate
        float p_overlap: fraction of the nodes outside each group that is added to that group.
            Values that ask for fewer than 0 or more than k extra nodes make np.random.choice raise ValueError.

    Returns:
        list subgraphs: a list of *induced subgraphs*, one per group, as returned by ground_truth.subgraph
    '''
    # pick sets of nodes -- ~equal-sized random subsets of the nodes
    shuffled = np.random.permutation(ground_truth.nodes)
    node_lists = np.array_split(shuffled, n_subgraphs)

    subgraph_nodes = []
    for node_list in node_lists:
        members = set(node_list)
        # Keep the complement in shuffled order instead of taking a set difference:
        # set iteration order depends on hashing, which would make the draw below
        # irreproducible under a fixed NumPy seed for string labels.
        complement = [node for node in shuffled if node not in members]
        n_extra = math.ceil(p_overlap * len(complement))
        # Skip the draw when nothing is requested, so an empty complement
        # (e.g. n_subgraphs == 1) never reaches np.random.choice.
        if n_extra:
            members.update(np.random.choice(complement, size = n_extra, replace = False))
        subgraph_nodes.append(members)

    # the subgraphs as NetworkX graphs
    return [ground_truth.subgraph(nodes) for nodes in subgraph_nodes]

def unobserved_path_DFS(source_graph, unobserved, current_node, original_node) -> list:
    '''
    DFS for finding paths between observed nodes that consist of unobserved nodes.
    Basically, normal DFS but if the node is observed, stop and record a new edge.

    The search is iterative and expands every unobserved node at most once, so it terminates
    even if the unobserved nodes contain a cycle and does not depend on the recursion limit.
    Each new edge is reported once, in the order it is first discovered.

    Parameters: 
        nx.DiGraph source_graph: Complete ground truth graph.
        set unobserved: Set of nodes not observed in the current subgraph
        int (or other NetworkX label) current_node: Label of the node the search starts from.
        int (or other NetworkX label) original_node: Label of root node of DFS.

    Returns:
        list changes: list of ordered tuples (original_node, observed_node) representing new edges
            to be drawn to subgraph.
    '''
    changes = []
    recorded = set()            # edges already placed in `changes`
    expanded = {current_node}   # nodes whose successors are (being) explored
    # Each stack entry pairs a node with a live iterator over its successors, so
    # successors are visited in the same order a recursive DFS would visit them.
    stack = [(current_node, iter(source_graph.successors(current_node)))]
    while stack:
        node, remaining = stack[-1]
        for succ in remaining:
            if succ not in unobserved:
                # Observed successor: the path ends here. A successor reached
                # directly from the root already has a real edge, so skip it.
                edge = (original_node, succ)
                if node != original_node and edge not in recorded:
                    recorded.add(edge)
                    changes.append(edge)
            elif succ not in expanded:
                # Unobserved successor not seen before: descend into it.
                expanded.add(succ)
                stack.append((succ, iter(source_graph.successors(succ))))
                break
        else:
            # All successors of `node` handled; backtrack.
            stack.pop()
    return changes

def unobserved_ancestors_DFS(source_graph, unobserved, current_node) -> set:
    '''
    DFS for finding the set of unobserved ancestors of a node that link to that node through only other unobserved nodes.
    Basically, normal DFS through ancestors, but if the node is observed, then stop, and if not, then add it to the output set.
    We will later compare the output sets across nodes in the subgraph -- if they overlap, then a bidirected edge will be drawn.

    The search is iterative and the output set doubles as the visited set, so every unobserved
    ancestor is expanded once and the search terminates even if unobserved nodes form a cycle.

    Parameters: 
        nx.DiGraph source_graph: Complete ground truth graph.
        Set unobserved: Set of nodes not observed in the current subgraph
        Int (or other NetworkX label) current_node: Label of the node whose ancestors are searched.

    Returns:
        Set unobserved_ancestors: Set of unobserved ancestors in ground truth of a node that link to that node 
            through only other unobserved nodes.
    '''
    unobserved_ancestors = set()
    frontier = [current_node]
    while frontier:
        node = frontier.pop()
        for anc in source_graph.predecessors(node):
            # Observed predecessors end the search along this branch; unobserved ones
            # already collected were (or will be) expanded from the frontier.
            if anc in unobserved and anc not in unobserved_ancestors:
                unobserved_ancestors.add(anc)
                frontier.append(anc)
    return unobserved_ancestors

def causally_accurate_subgraphs(subgraphs, ground_truth) -> list:
    '''
    Returns copies of the induced subgraphs with extra edges added so that they are faithful to the ground truth graph.
    Two kinds of edges are added to each subgraph:
        1. Direct unobserved paths: two observed nodes are joined in the ground truth by a directed path 
            whose interior consists only of unobserved nodes. A directed edge is drawn between them in the
            direction of that path. These edges come from unobserved_path_DFS(), run from every observed node.
        2. Unobserved common ancestors: two observed nodes share an ancestor that reaches both of them
            through unobserved nodes only. unobserved_ancestors_DFS() gives, per observed node, the set of
            such ancestors; whenever the sets of two observed nodes overlap, a bi-directed edge is drawn,
            stored as the two opposite directed edges. Pairs that already have a direct edge (in either
            direction) in the ground truth are left alone.

    Parameters:
        List subgraphs: list of nx.DiGraphs of the induced subgraphs of ground_truth.
        nx.DiGraph ground_truth: the original, presumed oracle, ground_truth graph.

    Returns:
        List modified_subgraphs: list of new nx.DiGraphs, one per input subgraph, with the extra edges added.
    '''
    corrected = []
    for induced in subgraphs:
        observed = list(induced.nodes)
        hidden = set(ground_truth.nodes) - set(observed)

        extra_edges = []
        hidden_ancestors = {}
        for node in observed:
            # type 1: directed edges that shortcut paths running through hidden nodes only
            extra_edges.extend((src, dst) for src, dst in unobserved_path_DFS(ground_truth, hidden, node, node))
            # type 2 (preparation): hidden ancestors reaching this node through hidden nodes only
            hidden_ancestors[node] = unobserved_ancestors_DFS(ground_truth, hidden, node)

        # type 2: a shared hidden ancestor becomes a bi-directed edge
        for left, right in combinations(observed, 2):
            if hidden_ancestors[left].isdisjoint(hidden_ancestors[right]):
                continue
            if ground_truth.has_edge(left, right) or ground_truth.has_edge(right, left):
                continue
            extra_edges += [(left, right), (right, left)]

        # work on a copy so the induced subgraph passed in is not modified
        patched = nx.DiGraph(induced)
        patched.add_edges_from(extra_edges)
        corrected.append(patched)

    return corrected

def graph_to_clingo_input(graph, idx) -> str:
    '''
    Renders one subgraph as clingo facts, tagged with the subgraph's index.

    The text has three commented sections: the edges (a pair of opposite edges is written once,
    as a single `bidirected` fact), the absent edges (both orientations of every node pair with
    no edge in either direction), and the nodes contained in the subgraph.

    Parameters:
        nx.DiGraph graph: nx.DiGraph of one of the causally-accurate subgraphs
        int idx: index of graph in list, used in clingo specification.

    Returns:
        str output: string representing specification of this graph as part of clingo problem
    '''
    lines = [f'% Edges in subgraph {idx}\n']
    seen_bidirected = set()
    for a, b in graph.edges:
        if not graph.has_edge(b, a):
            lines.append(f'edge ({a}, {b}, {idx}) .\n')
        elif (a, b) not in seen_bidirected:
            # first orientation encountered names the fact; remember both so the reverse is skipped
            lines.append(f'bidirected ({a}, {b}, {idx}) .\n')
            seen_bidirected.update(((a, b), (b, a)))

    lines.append(f'\n% Absent edges in subgraph {idx}\n')
    for a, b in combinations(graph.nodes, 2):
        if (a, b) in graph.edges or (b, a) in graph.edges:
            continue
        lines.append(f'nedge ({a}, {b}, {idx}) .\n')
        lines.append(f'nedge ({b}, {a}, {idx}) .\n')

    lines.append(f'\n% Nodes in subgraph {idx}\n')
    lines.extend(f'varin ({idx}, {node}) .\n' for node in graph.nodes)

    lines.append('\n')
    return ''.join(lines)

def problem_definition(ground_truth, subgraphs, clingo_file = './clingo/clingo_code.txt') -> str:
    '''
    Assembles the complete text handed to clingo for one ION problem.

    The result is, in order: the facts for every subgraph (from graph_to_clingo_input, each followed
    by a blank line), the constants `g` (index of the last subgraph) and `n` (largest node label in
    the ground truth) together with the `node(0..n)` declaration, and finally the contents of clingo_file.

    Parameters: 
        nx.DiGraph ground_truth: ground truth graph we've been using this whole time
        List subgraphs: list of causally modified subgraphs returned from causally_accurate_subgraphs
        String clingo_file: path of the file that contains the text of the clingo ION code

    Returns:
        String of full ION problem definition, including problem specification and clingo code.
    '''
    sections = [graph_to_clingo_input(subgraph, position) + '\n'
                for position, subgraph in enumerate(subgraphs)]

    last_index = len(subgraphs) - 1
    largest_label = max(ground_truth.nodes)
    sections.append(f'#const g = {last_index}.\n')
    sections.append(f'#const n = {largest_label}.\n')
    sections.append('node(0..n).\n\n')

    # the solver program itself is appended verbatim
    with open(clingo_file, 'r') as f_clingo:
        sections.append(f_clingo.read())

    return ''.join(sections)

def parse_edge_counts(clingo_file):
    '''
    Parses Clingo output. Grabs output file, finds each line that starts with 'Answer',
    uses regex to parse the line that follows it, and returns defaultdict of edge frequencies and # solutions

    Every 'Answer' header that is followed by a line counts as one solution, even when that line
    contains no edges. A header that is the very last line of the file (truncated output) is ignored
    instead of raising StopIteration.

    Parameters: 
        String clingo_file: path to file containing output from clingo
    
    Returns:
        defaultdict edge_frequencies: dict with tuples representing edges as keys, and frequency of edge as value
        int n_candidate_graphs: number of solutions returned by ION-C
    '''
    # matches the "(a,b)" argument list of each atom on an answer line
    edge_pattern = re.compile(r'\(([0-9]+),([0-9]+)\)')
    edge_frequencies = defaultdict(int)
    n_candidate_graphs = 0
    with open(clingo_file) as outputfile:
        for line in outputfile:
            if not line.startswith('Answer'):
                continue
            # the atoms of a solution are on the line right after its header;
            # consuming it here also advances the enclosing loop past it
            ans = next(outputfile, None)
            if ans is None:
                break
            n_candidate_graphs += 1
            for a, b in edge_pattern.findall(ans):
                edge_frequencies[(int(a), int(b))] += 1
    return edge_frequencies, n_candidate_graphs

def parse_stdout(sol_string: str):
    '''
    Counterpart of parse_edge_counts for clingo output that was captured as a string rather than written to a file.
    Every line that starts with 'edge(' is treated as one solution; each "(a,b)" pair on it is tallied as an edge.

    Parameters: 
        String sol_string: string containing output from clingo
    
    Returns:
        defaultdict edge_frequencies: dict with tuples representing edges as keys, and frequency of edge as value
        int n_candidate_graphs: number of solutions returned by ION-C
    '''
    pair_pattern = re.compile(r'\(([0-9]+),([0-9]+)\)')
    tallies = defaultdict(int)
    solutions = 0
    # split on '\n' only, so a line begins exactly after each newline character
    for row in sol_string.split('\n'):
        if not row.startswith('edge('):
            continue
        solutions += 1
        for src, dst in pair_pattern.findall(row):
            tallies[(int(src), int(dst))] += 1
    return tallies, solutions

def calculate_log_result(params, result_file, ground_truth, runtime, merge_directions = True):
    '''
    Calculates everything we need to output to the log file about the output from Clingo.

    Edge frequencies across all candidate graphs are read from result_file with parse_edge_counts.
    For every edge that occurs in at least one candidate graph, the function records whether the edge
    is present in more candidate graphs than it is absent from ('predicted_edge'), how large the
    majority is ('same_proportion' = max(count, absence_count) / number of candidate graphs), and
    whether the majority verdict matches the ground truth ('accurate'). These per-edge values are
    then summarised at the 75%, 90% and 100% agreement levels.
    
    Parameters: 
        dict params: parameters of this simulation (for use in output); the keys 'vertices', 'p_degree',
            'p_overlap' and 'n_subgraphs' are copied into the result
        str result_file: path to file containing output from clingo
        nx.DiGraph ground_truth: ground truth graph
        float runtime: actual runtime of clingo algorithm
        bool merge_directions: whether to merge edges in both directions in calculating result statistics.
            When True, an edge and its reverse share one row whose count is the sum of both counts, and
            the edge counts as present in the ground truth if either direction is.
    
    Returns:
        dict log_result: dict containing params, statistics of this clingo run. Keys: 'vertices', 'p_degree',
            'p_overlap', 'n_subgraphs', 'runtime', 'n_graphs', 'prop_same_75', 'prop_same_90',
            'prop_same_100', 'prop_accurate_75', 'prop_accurate_90', 'prop_accurate_100'.

    Raises:
        RuntimeError: if parse_edge_counts reports zero candidate graphs.
    '''
    ground_truth_edges = set(ground_truth.edges)

    # get frequency of each edge
    edge_frequencies, n_candidate_graphs = parse_edge_counts(result_file) 
    if not n_candidate_graphs:
        print(f'NO CANDIDATE GRAPHS: {n_candidate_graphs}')
        raise RuntimeError('no candidate graphs!')
    # Turn the {(a, b): count} mapping into a frame with one row per edge:
    # the two parts of the key arrive as columns 'level_0' / 'level_1', are zipped
    # back into a single tuple column 'edge', and the value column 0 becomes 'count'.
    freq_df = pd.Series(edge_frequencies).reset_index()
    freq_df['edge'] = pd.Series(zip(freq_df['level_0'], freq_df['level_1']))
    freq_df = freq_df.drop(columns = ['level_0', 'level_1'])
    freq_df = freq_df.rename(columns = {0 : 'count'})

    # merge opposite directions (can be removed)
    # removes opposite direction from dataframe and adds its count to current
    if merge_directions:
        to_remove = []
        for i in freq_df.index:
            edge = freq_df.loc[i]['edge']
            opposite = edge[::-1]
            opp_df = freq_df[freq_df['edge'] == opposite]
            # The first row of a pair absorbs the reverse row's count; the reverse row is
            # marked for removal, which also stops it from absorbing the first row in turn.
            # NOTE(review): an edge equal to its own reverse (a self-loop) would match itself
            # here, double its count and then be dropped -- confirm self-loops cannot occur.
            if i not in to_remove and len(opp_df):
                opposite_row = opp_df.iloc[0]
                freq_df.at[i, 'count'] += opposite_row['count']
                to_remove.append(opp_df.index[0])
        freq_df = freq_df.drop(index = to_remove)

    # now, frequency maps edges to their frequency among candidate graphs
    if merge_directions:
        freq_df['present_in_ground_truth'] = [(edge in ground_truth_edges) or (edge[::-1] in ground_truth_edges) for edge in freq_df['edge']]
    else:
        freq_df['present_in_ground_truth'] = [(edge in ground_truth_edges) for edge in freq_df['edge']]
    # NOTE(review): 'level_0' / 'level_1' were dropped above, so this rename appears to have
    # no effect; the reset_index() only renumbers the rows after the drop -- confirm and simplify.
    freq_df = freq_df.reset_index().rename(columns = {'level_0' : 'node_1', 'level_1' : 'node_2'})
    freq_df['absence_count'] = n_candidate_graphs - freq_df['count']
    # determine whether edge or absence more common
    # (a tie between count and absence_count is treated as "no edge predicted")
    freq_df['predicted_edge'] = freq_df['count'] > freq_df['absence_count']
    freq_df['same_proportion'] = freq_df[['count', 'absence_count']].max(axis = 1) / (n_candidate_graphs)
    freq_df['accurate'] = freq_df['predicted_edge'] == freq_df['present_in_ground_truth']

    # proportion shared in X%
    # NOTE(review): the 75% and 90% levels use a strict '>' here but '>=' in the accuracy
    # block below, so the two groups of statistics select different edges at exactly
    # 0.75 / 0.90 -- confirm which comparison is intended.
    prop_same_75 = (freq_df['same_proportion'] > 0.75).mean()
    prop_same_90 = (freq_df['same_proportion'] > 0.90).mean()
    prop_same_100 = (freq_df['same_proportion'] >= 1.0).mean()

    # among those shared in X%, how often are accurate?
    prop_accurate_75 = (freq_df[freq_df['same_proportion'] >= 0.75]['accurate']).mean()
    prop_accurate_90 = (freq_df[freq_df['same_proportion'] >= 0.90]['accurate']).mean()
    prop_accurate_100 = (freq_df[freq_df['same_proportion'] >= 1.0]['accurate']).mean()

    # output result
    # (simulation parameters are passed through unchanged; statistics are rounded to 5 decimals)
    log_result = {'vertices' : params['vertices'],
                  'p_degree' : params['p_degree'],
                  'p_overlap' : params['p_overlap'],
                  'n_subgraphs' : params['n_subgraphs'], 
                  'runtime' : round(runtime, 5), 
                  'n_graphs' : round(n_candidate_graphs),
                  'prop_same_75' : round(prop_same_75, 5),
                  'prop_same_90' : round(prop_same_90, 5),
                  'prop_same_100' : round(prop_same_100, 5),
                  'prop_accurate_75' : round(prop_accurate_75, 5),
                  'prop_accurate_90' : round(prop_accurate_90, 5),
                  'prop_accurate_100' : round(prop_accurate_100, 5)
                  }
    return log_result
