# Utility functions for sanity check on the inference results
import os
import glob
from utils._utils import inference_scenarios

# ------------------------- "Fixtures" ------------------------- #
def noise_distributions():
    """Return the noise distribution names as a new list on every call.

    These names are used as one directory level of the tuning logs tree.
    """
    return ["gauss", "nonlin_weak", "nonlin_mid", "nonlin_strong"]

def graph_types():
    """Return the graph type identifiers as a new list on every call.

    These identifiers are used as the first directory level below the logs
    base directory.
    """
    return ["ER", "FC", "GRP", "SF"]

def scenarios():
    """Return the scenario names as a new list on every call.

    These names are used as a directory level of the logs tree.
    """
    return ["vanilla", "linear", "measure_err", "confounded", "timino", "unfaithful"]

def methods():
    """Return the method names as a new list on every call.

    These names are used as the method directory level of the inference logs.
    """
    return [
        "cam",
        "das",
        "nogam",
        "score",
        "ges",
        "lingam",
        "pc",
        "resit",
        "random",
        "grandag",
        "diffan",
    ]

# Expected number of files in a method's "predictions" folder,
# keyed first by graph type and then by method name.
_EXPECTED_PREDICTIONS = {
    "ER": {
        "cam": 2560,
        "das": 2560,
        "nogam": 2560,
        "score": 2560,
        "ges": 1920,
        "lingam": 640,
        "pc": 1920,
        "resit": 2560,
        "random": 1280,
        "grandag": 2560,
        "diffan": 2560,
    },
    "SF": {
        "cam": 1920,
        "das": 1920,
        "nogam": 1920,
        "score": 1920,
        "ges": 1280,
        "lingam": 480,
        "pc": 1280,
        "resit": 1920,
        "random": 960,
        "grandag": 1920,
        "diffan": 1920,
    },
    "GRP": {
        "cam": 960,
        "das": 960,
        "nogam": 960,
        "score": 960,
        "ges": 640,
        "lingam": 240,
        "pc": 640,
        "resit": 960,
        "random": 480,
        "grandag": 960,
        "diffan": 960,
    },
    "FC": {
        "cam": 1280,
        "das": 1280,
        "nogam": 1280,
        "score": 1280,
        "ges": 960,
        "lingam": 320,
        "pc": 960,
        "resit": 1280,
        "random": 640,
        "grandag": 1280,
        "diffan": 1280,
    },
}


def expected_predictions(method, graph_type):
    """Return the number of files expected in the prediction folder.

    Args:
        method: Method name (one of the keys of the per-graph tables,
            e.g. ``"cam"`` or ``"lingam"``).
        graph_type: Graph type identifier: ``"ER"``, ``"SF"``, ``"GRP"``
            or ``"FC"``.

    Returns:
        The expected file count for ``method`` on graphs of ``graph_type``.

    Raises:
        ValueError: If ``graph_type`` is not one of the known graph types.
        KeyError: If ``method`` has no entry for the given graph type.
    """
    try:
        counts_by_method = _EXPECTED_PREDICTIONS[graph_type]
    except KeyError:
        # An unknown graph type used to surface as an UnboundLocalError at the
        # return statement; fail with an explicit, descriptive error instead.
        known = ", ".join(sorted(_EXPECTED_PREDICTIONS))
        raise ValueError(
            f"Unknown graph type {graph_type!r}; expected one of: {known}"
        ) from None
    return counts_by_method[method]


def tuning_expected_runs(method):
    """Return the number of tuning runs expected for ``method``.

    The value is 12 for ``"diffan"`` and 9 for ``"grandag"``. Any other
    method yields ``None``.
    """
    if method == "diffan":
        expected_runs = 12
    elif method == "grandag":
        expected_runs = 9
    else:
        expected_runs = None
    return expected_runs


def base_logs_dir(inference=True):
    """Return the root directory of the benchmark logs.

    Args:
        inference: When truthy, point at the ``inference`` logs; otherwise
            point at the ``tuning`` logs.

    Returns:
        The absolute path ``<sep>efs<sep>tmp<sep>causal-benchmark-logs<sep><leaf>``
        where ``<leaf>`` is ``inference`` or ``tuning``.
    """
    leaf = "inference" if inference else "tuning"
    return os.path.join(os.sep, "efs", "tmp", "causal-benchmark-logs", leaf)



# ------------------------- Processing functions ------------------------- #
def count_predictions():
    """Report inference runs whose number of prediction files is unexpected.

    Walks ``<base>/<graph_type>/<scenario>/<scenario_param>/<method>/predictions``
    for every graph type, scenario and method, where ``<base>`` is the
    inference logs directory and ``<scenario_param>`` is every entry found in
    the scenario directory. One line is printed per problem:

    * ``ERROR: <found>/<expected> predictions in <path>`` when the number of
      entries in the predictions folder differs from
      ``expected_predictions(method, graph_type)``;
    * ``DOES NOT EXIST: <path>`` when the scenario, method or predictions
      directory is missing.

    Nothing is returned; all findings are printed.
    """
    base_dir = base_logs_dir()
    for graph_type in graph_types():
        for scenario in scenarios():
            scenario_dir = os.path.join(base_dir, graph_type, scenario)
            if not os.path.isdir(scenario_dir):
                # Report and move on rather than letting os.listdir abort the scan.
                print(f"DOES NOT EXIST: {scenario_dir}\n")
                continue

            for scenario_param in os.listdir(scenario_dir):
                for method in methods():
                    method_dir = os.path.join(scenario_dir, scenario_param, method)
                    if not os.path.exists(method_dir):
                        print(f"DOES NOT EXIST: {method_dir}\n")
                        continue

                    preds_dir = os.path.join(method_dir, "predictions")
                    if not os.path.isdir(preds_dir):
                        # A method folder without predictions is a finding,
                        # not a reason to stop checking the remaining folders.
                        print(f"DOES NOT EXIST: {preds_dir}\n")
                        continue

                    n_found = len(os.listdir(preds_dir))

                    # Compare with the correct expected number of predictions
                    expected_count = expected_predictions(method, graph_type)
                    if n_found != expected_count:
                        error_dir = os.path.join(graph_type, scenario, scenario_param, method)
                        print(f"ERROR: {n_found}/{expected_count} predictions in {error_dir}")


def tuning_fails():
    """Report tuning datasets that have fewer runs than expected.

    Walks ``<base>/<graph_type>/<noise>/<scenario>/<scenario_param>/<method>/
    <data_config>/<dataset>`` below the tuning logs directory for the methods
    ``grandag`` and ``diffan``. For each dataset directory holding fewer
    entries than ``tuning_expected_runs(method)``, one ``ERROR`` line is
    printed with the dataset path relative to the tuning logs directory.

    Nothing is returned. A directory missing along the walk makes
    ``os.listdir`` raise.
    """
    base_dir = base_logs_dir(inference=False)
    for graph_type in graph_types():
        for noise in noise_distributions():
            for scenario in scenarios():
                scenario_dir = os.path.join(base_dir, graph_type, noise, scenario)
                for scenario_param in os.listdir(scenario_dir):
                    for method in ["grandag", "diffan"]:
                        # Constant per method: look it up once, not per dataset.
                        expected_runs = tuning_expected_runs(method)
                        method_dir = os.path.join(scenario_dir, scenario_param, method)
                        for data_config in os.listdir(method_dir):
                            data_config_dir = os.path.join(method_dir, data_config)
                            for dataset in os.listdir(data_config_dir):
                                dataset_dir = os.path.join(data_config_dir, dataset)
                                n_runs = len(os.listdir(dataset_dir))
                                if n_runs >= expected_runs:
                                    continue
                                # Include every level of the walk (noise, data
                                # config and dataset too) so the failing
                                # directory can be located from the message.
                                error_dir = os.path.join(
                                    graph_type,
                                    noise,
                                    scenario,
                                    scenario_param,
                                    method,
                                    data_config,
                                    dataset,
                                )
                                print(f"ERROR: {n_runs}/{expected_runs} predictions in {error_dir}")


def check_tuning_results():
    """Report tuning dataset directories with too many or too few log files.

    Walks ``<base>/<graph_type>/<noise>/<scenario>/<scenario_param>/<method>/
    <data_config>/<dataset>`` below the tuning logs directory, with scenarios
    taken from ``inference_scenarios()`` and methods ``diffan`` and
    ``grandag``. The non-hidden entries of each dataset directory are counted
    and compared with ``tuning_expected_runs(method)``; a line is printed
    when there are more ("Excess") or fewer ("Missing") entries than expected.

    Nothing is returned. A directory missing along the walk makes
    ``os.listdir`` raise.
    """
    base_dir = base_logs_dir(inference=False)

    for graph_type in graph_types():
        for noise_distr in noise_distributions():
            for scenario in inference_scenarios():
                scenario_dir = os.path.join(base_dir, graph_type, noise_distr, scenario)
                for scenario_param in os.listdir(scenario_dir):
                    scenario_param_dir = os.path.join(scenario_dir, scenario_param)
                    for method in ["diffan", "grandag"]:
                        expected_files = tuning_expected_runs(method)
                        method_dir = os.path.join(scenario_param_dir, method)
                        for data_config in os.listdir(method_dir):
                            data_config_dir = os.path.join(method_dir, data_config)
                            for dataset in os.listdir(data_config_dir):
                                dataset_dir = os.path.join(data_config_dir, dataset)
                                if not os.path.exists(dataset_dir):
                                    continue
                                # Only the count matters, so the files are not
                                # sorted. glob.escape keeps characters such as
                                # "[" in directory names from being read as
                                # glob patterns.
                                files = glob.glob(os.path.join(glob.escape(dataset_dir), "*"))
                                if len(files) > expected_files:
                                    print(f'Excess of tuning logs at {dataset_dir}')
                                elif len(files) < expected_files:
                                    print(f'Missing of tuning logs at {dataset_dir}')


def find_exit_code_137(err_dir):
    """Report the logs in ``err_dir`` whose last line mentions exit code 137.

    Every entry of ``err_dir`` is opened as a text file. A log matches when
    its last line contains ``"exit code 137"`` (case-insensitive); empty
    files never match. The matching file names and a ``matches/total``
    summary are printed; nothing is returned.

    Args:
        err_dir: Directory holding the error logs. Matching names are listed
            in increasing order of the integer found before the first ``-``
            of the name; names without such an integer are listed last.
    """

    def _sort_key(name):
        # Names with a numeric prefix sort numerically. Other names go last
        # (alphabetically) instead of aborting the report with a ValueError.
        prefix = name.split("-")[0]
        try:
            return (0, int(prefix))
        except ValueError:
            return (1, name)

    matching_logs = []
    n_logs = 0
    for file_name in os.listdir(err_dir):
        last_line = ""
        # errors="replace": stray undecodable bytes in a log must not stop
        # the scan; the searched text is plain ASCII.
        with open(os.path.join(err_dir, file_name), "r", errors="replace") as f:
            # Only the last line is inspected, so stream the file instead of
            # loading every line into memory.
            for last_line in f:
                pass
        if "exit code 137" in last_line.lower():
            matching_logs.append(file_name)
        n_logs += 1

    matching_logs.sort(key=_sort_key)
    print(f"Exit code 137 errors: {matching_logs}")
    print(f"Number of fails: {len(matching_logs)}/{n_logs}")


def find_slurm_fails(err_dir):
    """Report the non-empty logs in ``err_dir`` flagged as other error types.

    Every entry of ``err_dir`` is opened as a text file. A log is reported
    when it is non-empty, at least one of its lines does not contain
    ``"lockfile"`` and at least one of its lines does not contain ``"GPU"``.
    The two conditions are evaluated independently of each other. The
    reported file names and a ``reported/total`` summary are printed; nothing
    is returned.

    Args:
        err_dir: Directory holding the error logs. Reported names are listed
            in increasing order of the integer found before the first ``-``
            of the name; names without such an integer are listed last.
    """

    def _sort_key(name):
        # Names with a numeric prefix sort numerically. Other names go last
        # (alphabetically) instead of aborting the report with a ValueError.
        prefix = name.split("-")[0]
        try:
            return (0, int(prefix))
        except ValueError:
            return (1, name)

    flagged_logs = []
    n_logs = 0
    for file_name in os.listdir(err_dir):
        # errors="replace": stray undecodable bytes in a log must not stop
        # the scan; the searched keywords are plain ASCII.
        with open(os.path.join(err_dir, file_name), "r", errors="replace") as f:
            lines = f.readlines()
        # NOTE(review): because the two checks are independent, a log made
        # only of one "lockfile" line and one "GPU" line is still reported.
        # Confirm whether the intent was "some line mentions neither keyword"
        # or "no line mentions either keyword".
        if (
            lines
            and any("lockfile" not in line for line in lines)
            and any("GPU" not in line for line in lines)
        ):
            flagged_logs.append(file_name)
        n_logs += 1

    flagged_logs.sort(key=_sort_key)
    print(f"Other types of error: {flagged_logs}")
    print(f"Number of fails: {len(flagged_logs)}/{n_logs}")