import os
import glob
import json
from utils._utils import graph_types, noise_distributions, inference_scenarios
from utils._sanity_checks import expected_predictions, base_logs_dir, scenarios


def _list_subdirs(path):
    """Return the sorted names of the immediate subdirectories of *path*.

    Stray regular files are skipped, so they cannot crash the walk with
    ``NotADirectoryError``. If *path* is not a directory, a notice is printed
    and an empty list is returned. One missing branch of the log tree then
    does not abort the whole sweep.
    """
    if not os.path.isdir(path):
        print(f"DOES NOT EXIST: {path}")
        return []
    return sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )


def _prune_dataset_dir(dataset_dir, n_files_to_keep):
    """Keep only the *n_files_to_keep* most recently modified entries of *dataset_dir*.

    Older entries beyond that count are deleted. If fewer entries than
    expected exist, a "missing tuning logs" notice is printed instead.
    """
    entries = glob.glob(os.path.join(dataset_dir, "*"))
    entries.sort(key=os.path.getmtime, reverse=True)  # most recent first
    if len(entries) > n_files_to_keep:
        for stale in entries[n_files_to_keep:]:
            os.remove(stale)
            print(f"Removed: {stale}")
    elif len(entries) < n_files_to_keep:
        print(f'missing tuning logs at {dataset_dir}')


def delete_tuning(base_dir="/efs/tmp/causal-benchmark-logs/tuning", max_files=None):
    """Prune tuning logs so each dataset directory keeps only its newest entries.

    The tree walked is
    ``base_dir/<graph_type>/<noise_distr>/<scenario>/<scenario_param>/<method>/<data_config>/<dataset>/``.
    Graph types, noise distributions and scenarios come from
    ``graph_types()``, ``noise_distributions()`` and ``inference_scenarios()``.
    Within each dataset directory, entries are ordered by modification time.
    All but the ``max_files[method]`` most recent are removed.

    Args:
        base_dir: Root of the tuning log tree.
        max_files: Mapping from method name to the number of entries to keep
            per dataset directory. Only the methods listed here are visited.
            Defaults to ``{"diffan": 12, "grandag": 9}``.
    """
    if max_files is None:
        # Built per call rather than as a default argument, to avoid a shared
        # mutable default.
        max_files = {
            "diffan": 12,
            "grandag": 9,
        }

    for graph_type in graph_types():
        for noise_distr in noise_distributions():
            for scenario in inference_scenarios():
                scenario_dir = os.path.join(base_dir, graph_type, noise_distr, scenario)
                for scenario_param in _list_subdirs(scenario_dir):
                    scenario_param_dir = os.path.join(scenario_dir, scenario_param)
                    # The methods visited are exactly those with a keep-count.
                    for method, n_files_to_keep in max_files.items():
                        method_dir = os.path.join(scenario_param_dir, method)
                        for data_config in _list_subdirs(method_dir):
                            data_config_dir = os.path.join(method_dir, data_config)
                            for dataset in _list_subdirs(data_config_dir):
                                dataset_dir = os.path.join(data_config_dir, dataset)
                                _prune_dataset_dir(dataset_dir, n_files_to_keep)


def delete_grandag_inference():
    """Remove surplus grandag inference records and the predictions they point to.

    The directories scanned are
    ``<base_logs_dir()>/<graph_type>/<scenario>/<scenario_param>/grandag/tmp``.
    Graph types are ``"FC"``, ``"GRP"`` and ``"SF"``; scenarios come from
    ``scenarios()``. Each entry in ``tmp`` is read as JSON, and its
    ``"pred_location"`` value is treated as the path of a prediction file.

    Suppose a ``tmp`` directory holds more records than
    ``expected_predictions("grandag", graph_type)``. Then the surplus records
    are deleted, starting from the most recently modified. Each one is
    deleted together with its prediction file. Missing method or ``tmp``
    directories are reported and skipped.
    """
    base_dir = base_logs_dir()
    method = "grandag"
    for graph_type in ["FC", "GRP", "SF"]:
        for scenario in scenarios():
            scenario_dir = os.path.join(base_dir, graph_type, scenario)
            for scenario_param in os.listdir(scenario_dir):
                method_dir = os.path.join(scenario_dir, scenario_param, method)
                if not os.path.exists(method_dir):
                    print(f"DOES NOT EXIST: {method_dir}\n")
                    continue

                tmp_dir = os.path.join(method_dir, "tmp")
                if not os.path.isdir(tmp_dir):
                    print(f"DOES NOT EXIST: {tmp_dir}\n")
                    continue

                # Count the same list we delete from. A separate os.listdir()
                # count would include dotfiles that glob("*") skips. The
                # counts could then disagree and index past the end.
                records = glob.glob(os.path.join(tmp_dir, "*"))
                records.sort(key=os.path.getmtime, reverse=True)  # most recent first

                # Compare with the correct expected number of predictions
                expected_count = expected_predictions(method, graph_type)
                surplus = len(records) - expected_count
                if surplus <= 0:
                    continue

                # NOTE(review): this drops the *newest* surplus records.
                # delete_tuning instead keeps the newest entries. Confirm the
                # asymmetry is intended.
                for record in records[:surplus]:
                    with open(record, "r", encoding="utf-8") as f:
                        pred_path = json.load(f)["pred_location"]
                    print("Removing file")
                    try:
                        os.remove(pred_path)
                    except FileNotFoundError:
                        # The prediction may already be gone, e.g. after an
                        # interrupted earlier run. Still drop the record so
                        # reruns do not crash on it forever.
                        print(f"Prediction already missing: {pred_path}")
                    os.remove(record)


if __name__ == "__main__":
    # Both cleanup routines delete files permanently, so nothing runs by
    # default. Uncomment the routine you want to execute.
    # delete_tuning()
    # delete_grandag_inference()
    pass