#!/usr/bin/env python
# coding: utf-8

import networkx as nx
from graph_utils import *
from utils import *
import json
import time
import subprocess
import os
from argparse import ArgumentParser

# get simulation arguments from command line
parser = ArgumentParser()
parser.add_argument('N_VERTICES', help = 'Number of vertices in ground-truth graph', type = int)
parser.add_argument('P_DEGREE', help = 'Probability of degree commission in ground-truth graph', type = float)
parser.add_argument('P_OVERLAP', help = 'How much the subgraphs of ground-truth overlap with each other, bounded (0,1)', type = float)
parser.add_argument('N_SUBGRAPHS', help = 'Number of subgraphs (of ~equal size) created from ground truth', type = int)
parser.add_argument('DENSITY_THRESHOLD', help = 'If strictly between 0 and 1, any runs with average subgraph density > DENSITY_THRESHOLD will be skipped and re-run; to not set, set to 1', type = float)
parser.add_argument('TIME_LIMIT', help = 'Time limit (in seconds) for Clingo to solve problem, script will take about twice as long', type = int)
parser.add_argument('SLURM_ARRAY_ID', help = 'Provides index of this run in Slurm job array, only used for output filename', type = int)
parser.add_argument('RESULT_DIRECTORY', help = 'RELATIVE path within project directory to output files into', type = str)
args = parser.parse_args()

# CONSTANTS
TIME_LIMIT = args.TIME_LIMIT  # seconds, forwarded to clingo's --time-limit

USER_ID = 'USER'
USER_DIR = os.environ.get('USER_DIR', f'/directory/for/{USER_ID}') # latter value is a default if environment variable not set (set in Slurm jobsubmit)
PROJECT_DIR = os.path.join(USER_DIR, 'PROJECT_NAME')
RESULT_DIRECTORY = os.path.join(PROJECT_DIR, args.RESULT_DIRECTORY)

# ground truth settings -- set from arguments
N_VERTICES = args.N_VERTICES
P_DEGREE = args.P_DEGREE
P_OVERLAP = args.P_OVERLAP
SLURM_ARRAY_ID = args.SLURM_ARRAY_ID

# graph split settings
N_SUBGRAPHS = args.N_SUBGRAPHS

# skip dense subgraphs? NOT USED IN PAPER
# Only a threshold strictly inside (0, 1) enables skipping; any other value disables it
# and the threshold is normalised to 1.0 (which also appears in output filenames).
if 0.0 < args.DENSITY_THRESHOLD < 1.0:
    SKIP_DENSE = True
    DENSITY_THRESHOLD = args.DENSITY_THRESHOLD
else:
    SKIP_DENSE = False
    DENSITY_THRESHOLD = 1.0

# path to clingo ION code and output path for clingo problem
clingo_code = os.path.join(PROJECT_DIR, 'clingo', 'clingo_code.txt')  # code where core clingo code is stored (w/o problem definition), to be appended in problem_definition
full_clingo_program = os.path.join(PROJECT_DIR, 'clingo', 'ION_problem.lp') # .lp file of full clingo program, to be run by clingo

# create a per-array-task scratch directory and designate the temp file for clingo solutions;
# exist_ok lets reruns / pre-created directories through
scratch_dir = os.path.join('/cache', USER_ID, f'{SLURM_ARRAY_ID}')
os.makedirs(scratch_dir, exist_ok=True)
tmp_file = os.path.join(scratch_dir, f'solutions_{SLURM_ARRAY_ID}.txt')

def _sample_problem_graphs(VERTICES: int, P_DEGREE: float, P_OVERLAP: float, N_SUBGRAPHS: int) -> tuple:
    '''
    Sample one ground-truth DAG and derive the overlapping subgraphs handed to ION.

    Returns:
        (ground_truth, accurate_subgraphs): the connected random DAG and its
        subgraphs after causally_accurate_subgraphs() has made them causally faithful.
    '''
    # max_degree is set to the vertex count (presumably "uncapped" -- confirm in randomDAG)
    ground_truth = randomDAG(VERTICES, max_degree = VERTICES, p_degree = P_DEGREE, connected = True)
    assert nx.is_directed_acyclic_graph(ground_truth)

    # Pick overlapping nodes for subgraphs
    subgraphs = generate_subgraph_list(ground_truth, n_subgraphs=N_SUBGRAPHS, p_overlap=P_OVERLAP)

    # Modify subgraphs of ground truth to be *causally faithful*
    accurate_subgraphs = causally_accurate_subgraphs(subgraphs, ground_truth)
    return ground_truth, accurate_subgraphs


def _mean_density(graphs) -> float:
    '''Return the mean networkx density of the given graphs.'''
    # numpy is not imported explicitly at the top of this file; import it here instead of
    # relying on it leaking in through the star imports.
    import numpy as np
    return np.mean([nx.density(graph) for graph in graphs])


def _filename_tag(value) -> str:
    '''Render a parameter for a filename, replacing '.' with 'p' (e.g. 0.25 -> '0p25').'''
    return str(value).replace('.', 'p')


def single_ion(VERTICES: int, P_DEGREE: float, P_OVERLAP: float, N_SUBGRAPHS: int, skip_dense: bool) -> dict:
    '''
    Run one trial of the ION algorithm and return its logged results. Called from run_ion().

    Samples a connected random ground-truth DAG, splits it into N_SUBGRAPHS overlapping,
    causally faithful subgraphs, builds the clingo ION program from them, solves it with
    clingo (bounded by TIME_LIMIT), and analyses the solutions with calculate_log_result().

    Args:
        VERTICES: number of vertices in the ground-truth DAG.
        P_DEGREE: passed to randomDAG() as p_degree.
        P_OVERLAP: passed to generate_subgraph_list() as p_overlap.
        N_SUBGRAPHS: number of subgraphs the ground truth is split into.
        skip_dense: if True, resample until the mean subgraph density is at most
            DENSITY_THRESHOLD (NOT USED IN PAPER).

    Returns:
        The result of calculate_log_result() for this trial.

    Side effects:
        Writes clingo's output to tmp_file and always deletes it afterwards; appends one
        line with the ground-truth and subgraph edge lists to a graph_file_*.txt in
        RESULT_DIRECTORY.
    '''
    ground_truth, accurate_subgraphs = _sample_problem_graphs(VERTICES, P_DEGREE, P_OVERLAP, N_SUBGRAPHS)

    # NOT USED IN PAPER
    # if skip_dense, resample until the average subgraph density is <= DENSITY_THRESHOLD
    while skip_dense and _mean_density(accurate_subgraphs) > DENSITY_THRESHOLD:
        ground_truth, accurate_subgraphs = _sample_problem_graphs(VERTICES, P_DEGREE, P_OVERLAP, N_SUBGRAPHS)

    # generates clingo program (problem definition + core ION code) as a string
    ION_problem = problem_definition(ground_truth, accurate_subgraphs, clingo_file = clingo_code)

    clingo_cmd = [
        'clingo', '-W', 'no-atom-undefined',
        # Long options need two dashes: '-configuration=tweety' is read by clingo's option
        # parser as the short option '-c' (--const) with value 'onfiguration=tweety',
        # so the tweety configuration was never actually selected.
        '--configuration=tweety',
        '-n', '0',   # compute all models
        '-t', '16',  # 16 solver threads
        f'--time-limit={TIME_LIMIT}',
    ]

    # the scratch directory must exist before clingo's output can be streamed into it
    os.makedirs(os.path.dirname(tmp_file), exist_ok=True)
    try:
        # clingo reads the program from STDIN (input=) and writes its solutions to tmp_file.
        # The exit status is not checked: clingo reports its result (SAT/UNSAT/...) via
        # nonzero exit codes even on success.
        pre = time.perf_counter()
        with open(tmp_file, 'w') as fp:
            subprocess.run(clingo_cmd, input = ION_problem, text = True, stdout = fp)
        post = time.perf_counter()

        # calculate quantities of interest
        params = {'vertices' : VERTICES, 'p_degree' : P_DEGREE, 'p_overlap' : P_OVERLAP, 'n_subgraphs' : N_SUBGRAPHS}
        log_result = calculate_log_result(params, tmp_file, ground_truth, post - pre, merge_directions=True)
    finally:
        # remove (possibly very large) ION-C output file, even if clingo or parsing failed
        if os.path.isfile(tmp_file):
            os.remove(tmp_file)

    # log out entire graphs, one line per run, into a file per parameterization (hence mode 'a')
    graph_file = os.path.join(RESULT_DIRECTORY, f"graph_file_{VERTICES}_{_filename_tag(P_DEGREE)}_{_filename_tag(P_OVERLAP)}_{N_SUBGRAPHS}_{_filename_tag(DENSITY_THRESHOLD)}_{TIME_LIMIT}_{SLURM_ARRAY_ID}.txt")
    with open(graph_file, 'a') as fp:
        fp.write(f'ground truth: {list(ground_truth.edges)}\tsubgraphs: {[list(graph.edges) for graph in accurate_subgraphs]}\n')

    return log_result


def run_ion() -> None:
    '''
    Execute one ION trial with the command-line settings and save its log as JSON.

    The JSON file is written to RESULT_DIRECTORY; its name encodes every run setting
    (floats with '.' replaced by 'p') followed by the Slurm array index.
    '''
    result = single_ion(N_VERTICES, P_DEGREE, P_OVERLAP, N_SUBGRAPHS, SKIP_DENSE)

    def dotless(value) -> str:
        # swap '.' for 'p' so e.g. 0.25 appears as 0p25 in the filename
        return str(value).replace('.', 'p')

    out_path = os.path.join(
        RESULT_DIRECTORY,
        f"ion_result_{N_VERTICES}_{dotless(P_DEGREE)}_{dotless(P_OVERLAP)}_{N_SUBGRAPHS}_{dotless(DENSITY_THRESHOLD)}_{TIME_LIMIT}_{SLURM_ARRAY_ID}.json",
    )
    with open(out_path, 'w') as out:
        json.dump(result, out, indent = 3)

# Script entry point: run a single ION trial when executed directly (e.g. from a Slurm job).
if __name__ == '__main__':
    run_ion()
