import numpy as np
import networkx as nx
from causallearn.utils.cit import CIT
from scipy.stats import chi2
from itertools import combinations, chain, permutations, product
import copy
from tqdm import tqdm
import sys
import pandas as pd

# Positional command-line arguments (all six are required; a missing one
# raises IndexError, a non-numeric count/alpha raises ValueError).
N_8 = int(sys.argv[1])  # how many ESS8-only (welfare) columns to sample
N_9 = int(sys.argv[2])  # how many ESS9-only (justice) columns to sample
N_int = int(sys.argv[3])  # how many columns to sample from those present in both surveys
alpha = float(sys.argv[4])  # significance level fed to chi2.ppf(1 - alpha, ...) in the independence tests
PROBLEM_FILE = sys.argv[5]  # path the generated clingo problem is written to (opened with 'w')
SLURM_ID = sys.argv[6]  # presumably the SLURM job id -- not referenced anywhere else in this file

USER_ID = 'USERID'
USER_DIR = f'/cache/{USER_ID}/'  # not referenced anywhere else in this file
# Log file; every stage of the script appends to it.
# NOTE(review): the value looks like a placeholder (it ends in '/') -- confirm
# it is replaced with a real file path before running.
OUTPUTFILE = f'/path/to/log/' # log file

def powerset(iterable):
    """Lazily yield every subset of *iterable* as a tuple, smallest subsets first.

    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    subsets_by_size = (combinations(pool, size) for size in sizes)
    return chain.from_iterable(subsets_by_size)

def graph_to_clingo_input(graph, idx):
    '''
    Render one subgraph as a block of clingo facts.

    Three sections are emitted, each introduced by a '%' comment line:
    present edges (a pair of opposite directed edges becomes a single
    `bidirected` fact, anything else an `edge` fact), absent edges (`nedge`
    facts in both directions for every pair of nodes with no edge either way),
    and the nodes themselves (`varin` facts).

    Parameters:
        nx.DiGraph graph: nx.DiGraph of one of the causally-accurate subgraphs
        int idx: index of graph in list, used as the last/first argument of the emitted facts.

    Returns:
        str: the facts for this graph, terminated by a blank line
    '''
    pieces = [f'% Edges in subgraph {idx}\n']
    seen_bidirected = set()  # both orientations of every bidirected pair already written
    for src, dst in graph.edges:
        if not graph.has_edge(dst, src):
            pieces.append(f'edge ({src}, {dst}, {idx}) .\n')
        elif (src, dst) not in seen_bidirected:
            pieces.append(f'bidirected ({src}, {dst}, {idx}) .\n')
            seen_bidirected.add((src, dst))
            seen_bidirected.add((dst, src))

    pieces.append(f'\n% Absent edges in subgraph {idx}\n')
    for first, second in combinations(graph.nodes, 2):
        if (first, second) in graph.edges or (second, first) in graph.edges:
            continue
        pieces.append(f'nedge ({first}, {second}, {idx}) .\n')
        pieces.append(f'nedge ({second}, {first}, {idx}) .\n')

    pieces.append(f'\n% Nodes in subgraph {idx}\n')
    pieces.extend(f'varin ({idx}, {node}) .\n' for node in graph.nodes)

    pieces.append('\n')
    return ''.join(pieces)

def problem_definition(ground_truth, subgraphs, col_names, clingo_file = './clingo/clingo_code.txt'):
    '''
    Assemble the complete clingo problem text for a list of subgraphs.

    The result is, in order: one '%' comment line per entry of col_names, the
    facts produced by graph_to_clingo_input for every subgraph, the constants
    `g` (last subgraph index) and `n` (last node index), the `node(0..n).`
    declaration, and finally the verbatim contents of clingo_file.

    Parameters:
        ground_truth: accepted for signature compatibility; not read by this function.
        List subgraphs: list of nx.DiGraph subgraphs, numbered by their position in the list.
        List col_names: (label, column list) pairs written as header comments.
        String clingo_file: path of the file holding the clingo ION code.

    Returns:
        str: the full problem definition.
    '''
    sections = [f'% {col_type}: {col_list}\n' for col_type, col_list in col_names]
    sections.append('\n')

    for idx, subgraph in enumerate(subgraphs):
        sections.append(graph_to_clingo_input(subgraph, idx))
        sections.append('\n')

    sections.append(f'#const g = {len(subgraphs) - 1}.\n')
    # n comes from the module-level column counts, not from ground_truth
    sections.append(f'#const n = {N_int + N_8 + N_9 - 1}.\n')
    sections.append('node(0..n).\n\n')

    with open(clingo_file, 'r') as f_clingo:
        sections.append(f_clingo.read())
    return ''.join(sections)

# import data (download from https://www.europeansocialsurvey.org/data-portal)
# the file we use contains all variables for all countries for ESS8 and 9
ess = pd.read_csv('ESS_cols.csv', low_memory=False)

# ess8 only: 
welfare_all = ['basinc', 'bnlwinc', 'gvslvue', 'sbeqsoc', 'sblazy', 'sbprvpv',
           'sbstrec', 'smdfslv', 'bennent', 'lbenent', 'dfincac', 'slvuemp',
           'gvslvol', 'gvcldcr', 'sbbsntx', 'sblwcoa', 'imsclbn', 'uentrjb',
           'eduunmp', 'wrkprbf', 'eusclbf', 'eudcnbf', 'lknemny'
          ]
#ess9 only:
justice_all = ['gvintcz', 'frprtpl', 'evfrjob', 'sofrdst', 'sofrwrk', 
           'sofrpr', 'ppldsrv', 'wltdffr', 'jstprev', 'pcmpinj', 'poltran',
           'ifrjob', 'evfredu', 'grspfr', 'netifr', 'occinfr', 'topinfr', 
           'btminfr', 'sofrprv'
          ]
#present in both
# Every name must appear exactly once: np.random.choice(..., replace=False)
# only guarantees distinct *positions*, so a repeated name could be drawn twice
# and then collide in the name -> column-index dictionaries built from the
# sample. ('ipfrule' and 'imptrad' used to be listed a second time here.)
intersection_all = ['gndr', 'agea', 'eduyrs', 'lrscale', 'imptrad', 'ipbhprp',
                    'ipfrule', 'ipstrgv', 'iphlppl', 'ipcrtiv', 'polintr', 'psppsgva',
                    'actrolga', 'psppipla', 'cptppola', 'trstlgl', 'trstplc', 'trstplt',
                    'stflife', 'stfeco', 'stfgov', 'gincdif', 'imbgeco', 'imueclt',
                    'imwbcnt', 'happy', 'sclmeet', 'atchctr', 'atcherp', 'rlgblg',
                    'rlgdgr', 'dscrgrp', 'ctzcntr', 'brncntr', 'blgetmg', 'hinctnta',
                    'ipeqopt', 'impsafe', 'impfree'
               ]

# The loops below blank out the numeric "no answer" codes so they are treated
# as missing values; which codes apply depends on the width of each item's
# answer scale.

# removing response-missing codes for welfare
for col in ['basinc', 'bnlwinc', 'sbeqsoc', 'sblazy', 'sbprvpv', 'sbstrec',
            'smdfslv', 'bennent', 'lbenent', 'dfincac', 'sbbsntx', 'sblwcoa',
            'imsclbn', 'uentrjb', 'eduunmp', 'wrkprbf', 'eusclbf', 'eudcnbf',
            'lknemny'
           ]:
    ess[col] = ess[col].replace([7,8,9], np.nan)
for col in ['gvslvue', 'slvuemp', 'gvslvol', 'gvcldcr']:
    ess[col] = ess[col].replace([77,88,99], np.nan)

# removing response-missing codes for justice
for col in ['gvintcz', 'frprtpl', 'sofrdst', 'sofrwrk', 'sofrpr', 'ppldsrv', 'wltdffr',
            'jstprev', 'pcmpinj', 'poltran', 'ifrjob', 'topinfr', 'btminfr', 'sofrprv']:
    ess[col] = ess[col].replace([7,8,9], np.nan)
for col in ['evfrjob', 'evfredu']:
    ess[col] = ess[col].replace([77,88,99], np.nan)
for col in ['grspfr', 'netifr', 'occinfr']:
    ess[col] = ess[col].replace([6,7,8,9], np.nan)

# removing response-missing codes for intersection
for col in ['gndr', 'imptrad', 'ipbhprp', 'ipfrule', 'ipstrgv', 'iphlppl', 'ipcrtiv',
            'polintr', 'psppsgva', 'actrolga', 'psppipla', 'cptppola', 'gincdif',
            'rlgblg', 'dscrgrp', 'ctzcntr', 'brncntr', 'blgetmg', 'ipeqopt', 'impsafe',
            'impfree'
           ]:
    ess[col] = ess[col].replace([7,8,9], np.nan)
for col in ['eduyrs', 'lrscale', 'trstlgl', 'trstplc', 'trstplt', 'stflife', 'stfeco',
            'stfgov', 'imbgeco', 'imueclt', 'imwbcnt', 'happy', 'sclmeet', 'atchctr',
            'atcherp', 'rlgdgr', 'hinctnta'
           ]:
    ess[col] = ess[col].replace([77,88,99], np.nan)
for col in ['agea']:
    ess[col] = ess[col].replace([999], np.nan)

total_nodes = N_8 + N_9 + N_int

# Draw the random column sample. The three draws are unseeded and happen in the
# order welfare, justice, intersection.
welfare_selected = list(np.random.choice(welfare_all, size = N_8, replace = False))
justice_selected = list(np.random.choice(justice_all, size = N_9, replace = False))
intersection_selected = list(np.random.choice(intersection_all, size = N_int, replace = False))
# Global node order: intersection columns first, then welfare, then justice.
all_selected = intersection_selected + welfare_selected + justice_selected

# Lookup tables between column names and global node indices.
name_to_int = {}
for position, name in enumerate(all_selected):
    name_to_int.setdefault(name, position)  # first position wins
int_to_name = dict(enumerate(all_selected))

# Column positions inside the ESS9 array, which holds the intersection columns
# followed directly by the justice columns (so justice indices shift down by N_8).
name_to_int_ess9 = {}
for name in justice_selected:
    name_to_int_ess9[name] = name_to_int[name] - N_8
for position, name in enumerate(intersection_selected):
    name_to_int_ess9[name] = position

with open(OUTPUTFILE, 'a') as fp:
    fp.write(f'{all_selected}\n{name_to_int}\n{name_to_int_ess9}\n\n')

# Per-survey views of the data, restricted to the sampled columns.
from_ess8 = ess['name'] == 'ESS8e02_3'
from_ess9 = ess['name'] == 'ESS9e03_2'
ess8 = ess[from_ess8][intersection_selected + welfare_selected]
ess9 = ess[from_ess9][intersection_selected + justice_selected]


# Start from the complete directed graph on all sampled variables, then drop
# both directions of every welfare/justice pair: those variables are never
# measured in the same survey, so no test can be run on them.
pc_graph = nx.DiGraph(nx.complete_graph(total_nodes, nx.DiGraph()))
for welfare_name, justice_name in product(welfare_selected, justice_selected):
    welfare_node = name_to_int[welfare_name]
    justice_node = name_to_int[justice_name]
    pc_graph.remove_edge(welfare_node, justice_node)
    pc_graph.remove_edge(justice_node, welfare_node)

# Conditional-independence testers, one per survey; they are called with
# column positions of the corresponding array.
fisher_8 = CIT(ess8.to_numpy(), "mv_fisherz")
fisher_9 = CIT(ess9.to_numpy(), "mv_fisherz")

# A relatively inefficient implementation of the PC skeleton search: for every
# pair that is still adjacent, try every conditioning set drawn from the columns
# of the survey(s) in which both variables were measured, and delete the edge as
# soon as one set makes the pair look independent.

indep_results = []  # never populated below; kept so the module-level name still exists

# Thresholds for the -2 * sum(log p) statistic (hoisted: they do not depend on the pair).
# NOTE(review): Fisher's method normally refers this statistic to a chi-square
# with 2k degrees of freedom for k p-values, whereas this script uses df=k (here
# and in indep()). Left unchanged so results stay comparable -- please confirm.
_THRESHOLD_SINGLE = chi2.ppf(1 - alpha, df=1)
_THRESHOLD_COMBINED = chi2.ppf(1 - alpha, df=2)


def _fisher_statistic(p_values):
    """Return -2 * sum(log p), clamping exact-zero p-values to 1e-16 so log() stays finite."""
    clamped = [1e-16 if p == 0 else p for p in p_values]
    return -2 * np.log(clamped).sum()


def _has_separating_set(x_col, y_col, n_cols, tests, threshold):
    """Return True if some subset of columns 0..n_cols-1 (without x_col/y_col) separates the pair.

    Every callable in `tests` is evaluated with the same column positions and
    the resulting p-values are pooled; the search stops at the first subset
    whose pooled statistic falls below `threshold`.
    """
    for subset in powerset(range(n_cols)):
        if x_col in subset or y_col in subset:
            continue
        conditioning = list(subset)
        p_values = [test(x_col, y_col, conditioning) for test in tests]
        if _fisher_statistic(p_values) < threshold:
            return True
    return False


_intersection_names = set(intersection_selected)
_welfare_names = set(welfare_selected)
_justice_names = set(justice_selected)

for X,Y in tqdm(combinations(all_selected, 2)):
    x_node, y_node = name_to_int[X], name_to_int[Y]
    # if edge already gone, continue
    if not pc_graph.has_edge(x_node, y_node):
        continue

    if X in _intersection_names and Y in _intersection_names:
        # Measured in both surveys: pool the ESS8 and ESS9 p-values. Only the
        # first N_int columns mean the same thing in both arrays, so the
        # conditioning sets are limited to them. (Previously the subsets ranged
        # over range(N_8 + N_int), which indexed fisher_9 with ESS8 welfare
        # positions, i.e. the wrong or out-of-range ESS9 columns.)
        separated = _has_separating_set(x_node, y_node, N_int,
                                        (fisher_8, fisher_9), _THRESHOLD_COMBINED)
    elif (X in _welfare_names and Y in _justice_names) or (X in _justice_names and Y in _welfare_names):
        # not co-measured, nothing can be tested
        continue
    elif X in _welfare_names or Y in _welfare_names:
        # both are co-measured in ESS8, whose columns share the global indices
        separated = _has_separating_set(x_node, y_node, N_int + N_8,
                                        (fisher_8,), _THRESHOLD_SINGLE)
    else:
        # both are co-measured in ESS9, which has its own column positions
        separated = _has_separating_set(name_to_int_ess9[X], name_to_int_ess9[Y], N_int + N_9,
                                        (fisher_9,), _THRESHOLD_SINGLE)

    if separated:
        # independent
        pc_graph.remove_edge(x_node, y_node)
        pc_graph.remove_edge(y_node, x_node)
        with open(OUTPUTFILE, 'a') as fp:
            fp.write(f'\nremoved: {(X,Y)}')

with open(OUTPUTFILE, 'a') as fp:
    fp.write('\n\nBEFORE ORIENTATION\n')
    fp.write(f'{pc_graph.edges()}\n\n')

# code adapted from https://www.stat.cmu.edu/~cshalizi/402/lectures/24-causal-discovery/lecture-24.pdf

# conduct independence tests
def indep(A,C,S):
    '''
    Decide whether nodes A and C look independent given S, using whichever
    survey(s) measured every variable involved.

    If A, C and all of S are intersection nodes, the ESS8 and ESS9 p-values are
    pooled; otherwise the test is run on ESS9 alone or ESS8 alone, depending on
    which survey contains all of them.

    Parameters:
        int A, int C: global node indices (keys of int_to_name).
        iterable S: global node indices forming the conditioning set.

    Returns:
        bool: True if the statistic stays below the chi-square threshold (treated
            as independent); False otherwise, including when no single survey
            measured A, C and S together.
    '''
    S = list(S)
    involved = [A, C] + S

    shared_nodes = range(N_int)
    ess8_nodes = range(N_int + N_8)
    # Built as a set on purpose: the itertools.chain object used here before was
    # consumed by its first `in` test, so the following membership checks
    # silently came out False and ESS9-only tests were skipped.
    ess9_nodes = set(range(N_int)) | set(range(N_int + N_8, N_int + N_8 + N_9))

    if all(node in shared_nodes for node in involved):
        # all comeasured in both surveys; intersection columns share positions
        p_values = [fisher_8(A, C, S), fisher_9(A, C, S)]
    elif all(node in ess9_nodes for node in involved):
        # all comeasured in ESS9: translate global indices to ESS9 column positions
        A_9 = name_to_int_ess9[int_to_name[A]]
        C_9 = name_to_int_ess9[int_to_name[C]]
        S_9 = [name_to_int_ess9[int_to_name[s]] for s in S]
        p_values = [fisher_9(A_9, C_9, S_9)]
    elif all(node in ess8_nodes for node in involved):
        # all comeasured in ESS8, whose column positions equal the global indices
        p_values = [fisher_8(A, C, S)]
    else:
        return False

    # exact zeros would make log() blow up
    p_values = [1e-16 if p == 0 else p for p in p_values]
    statistic = -2 * np.log(p_values).sum()
    # NOTE(review): df equals the number of pooled p-values (1 or 2), as in the
    # rest of this script; Fisher's method normally uses 2k -- please confirm.
    return bool(statistic < chi2.ppf(1 - alpha, df=len(p_values)))

# identify colliders
def colliders(graph):
    '''
    Orient unshielded colliders A -> B <- C on a copy of the graph.

    For every edge (A, B) and every successor C of B such that A and C are not
    adjacent, the triple is treated as a collider unless some conditioning set
    that contains B (and neither A nor C) makes A and C independent. Orienting
    means deleting the edges B -> A and B -> C.

    Parameters:
        nx.DiGraph graph: graph to orient; it is not modified.

    Returns:
        nx.DiGraph: deep copy of graph with the collider orientations applied.
    '''
    graph = copy.deepcopy(graph)
    # The edge view is iterated lazily while edges are being removed, so edges
    # deleted by an earlier orientation are not visited later. Only B's outgoing
    # edges are ever removed while A's are being walked.
    for A, B in graph.edges():
        # snapshot, because B's outgoing edges are removed inside the loop
        for C in list(graph.neighbors(B)):
            # Compare labels by value, and require A and C to be non-adjacent in
            # *both* directions: after an earlier orientation only C -> A may be
            # left, and such a triple is still shielded.
            if A == C or graph.has_edge(A, C) or graph.has_edge(C, A):
                continue
            separable_given_B = any(
                indep(A, C, subset)
                for subset in powerset(graph.nodes)
                if A not in subset and C not in subset and B in subset
            )
            if not separable_given_B:
                # B -> A may already be gone from a previous collider at B
                if graph.has_edge(B, A):
                    graph.remove_edge(B, A)
                graph.remove_edge(B, C)
    return graph

# orient edges
def orient(graph):
    '''
    Apply one sweep of the two orientation rules to a copy of the graph.

    For every ordered triple of distinct nodes (A, B, C), in order:
      1. if A -> B exists, B and C are joined in both directions, and A and C
         are not adjacent at all, drop C -> B (leaving B -> C);
      2. if A and B are joined in both directions and B is still reachable from
         A once both of those edges are taken away, drop B -> A.

    Parameters:
        nx.DiGraph graph: graph to orient; it is not modified.

    Returns:
        nx.DiGraph: deep copy of graph after the sweep.
    '''
    oriented = copy.deepcopy(graph)
    for A, B, C in permutations(oriented.nodes, 3):
        into_undirected = (
            oriented.has_edge(A, B)
            and oriented.has_edge(B, C)
            and oriented.has_edge(C, B)
        )
        if into_undirected and not (oriented.has_edge(A, C) or oriented.has_edge(C, A)):
            oriented.remove_edge(C, B)

        if oriented.has_edge(A, B) and oriented.has_edge(B, A):
            # look for another route from A to B that avoids the A-B link itself
            without_link = copy.deepcopy(oriented)
            without_link.remove_edge(A, B)
            without_link.remove_edge(B, A)
            if nx.has_path(without_link, A, B):
                oriented.remove_edge(B, A)
    return oriented

# full algorithm
def SGS(indep_graph):
    '''
    Orient the skeleton: mark colliders once, then repeat orient() until a
    sweep leaves the edge set unchanged.

    Parameters:
        nx.DiGraph indep_graph: skeleton graph with every adjacency present in both directions.

    Returns:
        nx.DiGraph: the graph at the fixed point of orient().
    '''
    current = colliders(indep_graph)
    while True:
        candidate = orient(current)
        if set(candidate.edges()) == set(current.edges()):
            return current
        current = candidate

# orient the skeleton and log the result
SGS_sol = SGS(pc_graph)
with open(OUTPUTFILE, 'a') as fp:
    fp.write(f'AFTER ORIENTATION\n{SGS_sol.edges()}\n\n')

# split into two induced subgraphs, one per survey: the intersection nodes come
# first in the global numbering, followed by the welfare and then the justice nodes
ess8_node_ids = list(range(N_int + N_8))
ess9_node_ids = list(range(N_int)) + list(range(N_int + N_8, N_int + N_8 + N_9))
graph_1 = SGS_sol.subgraph(ess8_node_ids)
graph_2 = SGS_sol.subgraph(ess9_node_ids)

def unobserved_path_DFS(source_graph, unobserved, current_node, original_node) -> list:
    '''
    Depth-first walk from original_node that only continues through unobserved
    nodes and reports every observed node it reaches at the end of such a path.

    A successor that is observed ends the walk along that branch; it is
    reported as an edge (original_node, successor) unless it is a direct
    successor of original_node itself, since that edge already exists. The same
    edge is reported once per distinct path that leads to it.

    Parameters: 
        nx.DiGraph source_graph: Complete ground truth graph.
        set unobserved: Set of nodes not observed in the current subgraph
        int (or other NetworkX label) current_node: Label of the node currently being expanded.
        int (or other NetworkX label) original_node: Label of the node the walk started from.

    Returns:
        list: ordered tuples (original_node, reached_node), one per path found.
    '''
    expanding_root = current_node == original_node
    found = []
    for child in source_graph.successors(current_node):
        if child in unobserved:
            # keep walking through the hidden node and collect what it reaches
            found.extend(unobserved_path_DFS(source_graph, unobserved, child, original_node))
        elif not expanding_root:
            # observed node reached through at least one hidden node
            found.append((original_node, child))
    return found

def unobserved_ancestors_DFS(source_graph, unobserved, current_node) -> set:
    '''
    Collect the unobserved ancestors of a node that reach it through unobserved nodes only.

    The walk goes backwards along predecessors: an unobserved predecessor is
    recorded and expanded further, an observed one ends that branch. Comparing
    the resulting sets of two observed nodes shows whether they share a hidden
    common ancestor.

    Parameters: 
        nx.DiGraph source_graph: Complete ground truth graph.
        Set unobserved: Set of nodes not observed in the current subgraph
        Int (or other NetworkX label) current_node: Label of the node whose ancestors are collected.

    Returns:
        Set: unobserved ancestors of current_node connected to it via unobserved nodes only.
    '''
    hidden_ancestors = set()
    for parent in source_graph.predecessors(current_node):
        if parent not in unobserved:
            continue  # observed parent: stop searching along this branch
        hidden_ancestors.add(parent)
        hidden_ancestors |= unobserved_ancestors_DFS(source_graph, unobserved, parent)
    return hidden_ancestors

def causally_accurate_subgraphs(subgraphs, ground_truth) -> list:
    '''
    Add to each induced subgraph the edges implied by the nodes it does not observe.

    Two kinds of edges are added:
        1. Direct unobserved paths: if an observed node reaches another observed node along a
            directed path whose interior nodes are all unobserved, a directed edge in the same
            direction is added. These are found with unobserved_path_DFS().
        2. Unobserved common ancestors: if two observed nodes share an ancestor that reaches
            both of them through unobserved nodes only, and the ground truth has no edge
            between the two in either direction, a bi-directed edge (one edge each way) is
            added. The per-node ancestor sets come from unobserved_ancestors_DFS() and a pair
            qualifies when its two sets overlap.

    Parameters:
        List subgraphs: list of nx.DiGraphs of the induced subgraphs of ground_truth.
        nx.DiGraph ground_truth: the graph the subgraphs were induced from.

    Returns:
        List: new nx.DiGraphs, one per input subgraph, with the extra edges added.
    '''
    adjusted = []
    for subgraph in subgraphs:
        observed = list(subgraph.nodes)
        hidden = set(ground_truth.nodes) - set(subgraph.nodes)

        extra_edges = []
        hidden_ancestors = {}
        for node in observed:
            # 1. observed nodes reachable from here through hidden nodes only
            for source, target in unobserved_path_DFS(ground_truth, hidden, node, node):
                extra_edges.append((source, target))
            # 2. remember this node's hidden ancestors for the pairwise check below
            hidden_ancestors[node] = unobserved_ancestors_DFS(ground_truth, hidden, node)

        for left, right in combinations(observed, 2):
            if not (hidden_ancestors[left] & hidden_ancestors[right]):
                continue  # no shared hidden ancestor
            if ground_truth.has_edge(left, right) or ground_truth.has_edge(right, left):
                continue  # already directly connected in the ground truth
            extra_edges.extend([(left, right), (right, left)])

        # apply the additions to a fresh copy of the subgraph
        patched = nx.DiGraph(subgraph)
        patched.add_edges_from(extra_edges)
        adjusted.append(patched)

    return adjusted

# check that these are DAGs, then apply same procedure as with synthetic data for 
# adjusting graphs once they are split up
# (The original `assert ...:` had a trailing colon, which is a SyntaxError and
# kept the whole script from being parsed. An explicit raise is used so the
# check also survives `python -O`.)
if not (nx.is_directed_acyclic_graph(graph_1) and nx.is_directed_acyclic_graph(graph_2)):
    raise ValueError('oriented subgraphs are not both DAGs; cannot build the clingo problem')
graph_1, graph_2 = tuple(causally_accurate_subgraphs(subgraphs=[graph_1, graph_2], ground_truth=SGS_sol))

# get ION-C formulation of this problem in clingo
# (problem_definition does not read its first argument, hence None)
prob_pool = problem_definition(None, [graph_1, graph_2], col_names=[('welfare', welfare_selected), ('justice', justice_selected), ('intersection', intersection_selected)], 
                               clingo_file='./clingo_code.txt')

# write out to log, and file to run for problem
with open(OUTPUTFILE, 'a') as fp:
    fp.write(prob_pool)

with open(PROBLEM_FILE, 'w') as fp:
    fp.write(prob_pool)
