import argparse
import json
import os
import subprocess
import sys

# --- Configuration files used by this script ---
# Both paths are relative, so they resolve against the current working
# directory (presumably the repository root that contains ``configs/``).
# NOTE: main() rewrites MODEL_CONFIG_PATH in place to set the mask path.
MODEL_CONFIG_PATH: str = "configs/rhinoedge/model_config_rhino_synthetic.json"
DATASET_CONFIG_PATH: str = "configs/dataset_config_temporal_causal_dataset.json"

def update_json_config(file_path, key, value):
    """Set ``config["model_hyperparams"][key] = value`` in a JSON file in place.

    The whole file is loaded, the single entry is changed, and the result is
    written back as UTF-8 JSON indented with 4 spaces. All other content is
    kept (``json`` preserves object key order).

    Args:
        file_path: Path of the JSON config file to edit.
        key: Key to set inside the top-level ``"model_hyperparams"`` object.
        value: New JSON-serializable value; ``None`` is stored as ``null``.

    Returns:
        True on success. False if the file cannot be read or written, is not
        valid UTF-8 JSON, has no ``"model_hyperparams"`` object, or the value
        cannot be serialized. Failures are reported on stdout, never raised.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            config_data = json.load(f)
    except FileNotFoundError:
        print(f"[ERROR] Config file not found: {file_path}")
        return False
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"[ERROR] Config file is not valid UTF-8 JSON: {file_path}")
        return False
    except OSError as e:
        # e.g. PermissionError or IsADirectoryError.
        print(f"[ERROR] Cannot read config file {file_path}: {e}")
        return False

    print(f"  - Updating config file: '{file_path}'")

    # Validate the structure instead of letting a KeyError/TypeError escape:
    # callers rely on the True/False result.
    hyperparams = None
    if isinstance(config_data, dict):
        hyperparams = config_data.get('model_hyperparams')
    if not isinstance(hyperparams, dict):
        print(f"[ERROR] No 'model_hyperparams' object in config file: {file_path}")
        return False
    hyperparams[key] = value

    # Serialize before opening for writing: mode 'w' truncates the file
    # immediately, so a serialization error must not leave it empty.
    try:
        serialized = json.dumps(config_data, indent=4)
    except (TypeError, ValueError) as e:
        print(f"[ERROR] Value for '{key}' is not JSON-serializable: {e}")
        return False

    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(serialized)
    except OSError as e:
        print(f"[ERROR] Cannot write config file {file_path}: {e}")
        return False
    return True

def run_command(command):
    """Run ``command`` as an argv list (no shell) and report whether it succeeded.

    The child inherits this process's stdout/stderr, so its output goes
    straight to the terminal.

    Args:
        command: Program followed by its arguments, e.g.
            ``["python", "-m", "pkg"]``. Elements may be str or path-like.

    Returns:
        True if the process exited with status 0. False if the command is
        empty, the program cannot be started (not found / not executable),
        or it exits with a non-zero status.
    """
    if not command:
        print("[ERROR] Empty command; nothing to run.")
        return False

    # str() each part so path-like arguments don't break the log line.
    printable = ' '.join(str(part) for part in command)
    print(f"  - Running: {printable}")
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        subprocess.run(command, check=True)
    except FileNotFoundError:
        print(f"[ERROR] Command not found: '{command[0]}'. Is it installed and on PATH?")
        return False
    except OSError as e:
        # e.g. PermissionError when the target is not executable.
        print(f"[ERROR] Could not start '{command[0]}': {e}")
        return False
    except subprocess.CalledProcessError as e:
        print(f"[ERROR] Command failed with exit code {e.returncode}")
        print(f"  - Command: {printable}")
        return False
    print("  - Command finished successfully.")
    return True

def main():
    """Entry point: set the mask path in the model config, then run Causica.

    Command-line arguments:
        dataset_name: Dataset name, passed as the positional argument of
            ``causica.run_experiment``.
        mask_path: Value written to ``model_hyperparams.exist_edges_mask_path``;
            ``"none"`` (any case) or an empty string writes ``null`` instead.
        --device: ``gpu`` (default) or ``cpu``, forwarded as ``-dv``.

    Side effects:
        Rewrites ``MODEL_CONFIG_PATH`` on disk (the change is not reverted),
        then runs ``causica.run_experiment`` with ``--model_type rhino`` in a
        child process.

    Exits with status 1 (``SystemExit``) when the config update or the
    experiment fails, so a calling shell loop can detect it. argparse
    already exits from here with status 2 on invalid arguments.
    """
    parser = argparse.ArgumentParser(
        description="Run a Causica Rhino experiment on one dataset with a given mask file."
    )
    parser.add_argument("dataset_name", type=str,
                        help="Dataset name (e.g. 'ER_ER_lag_2_dim_5_...')")
    parser.add_argument("mask_path", type=str,
                        help="Mask file path (e.g. '.../exist_mask_0.1.npy'); "
                             "pass 'none' to clear it")
    parser.add_argument("--device", type=str, default="gpu", choices=["gpu", "cpu"],
                        help="Device to run on (gpu/cpu)")

    args = parser.parse_args()

    print("=" * 60)
    print("Experiment settings:")
    print(f"  - Dataset: {args.dataset_name}")
    print(f"  - Mask:    {args.mask_path}")
    print(f"  - Device:  {args.device}")
    print("=" * 60)

    # 1. Point the model config at the mask. "none"/"" becomes None, which
    #    json writes as null.
    mask_value = None if args.mask_path.lower() in ("none", "") else args.mask_path
    if not update_json_config(MODEL_CONFIG_PATH, "exist_edges_mask_path", mask_value):
        print("[ERROR] Could not update the model config; aborting.")
        sys.exit(1)

    # 2. Build the experiment command.
    command = [
        # Use the interpreter running this script so causica is imported from
        # the same environment; a bare "python" on PATH may be a different one.
        # sys.executable can be empty in embedded setups, hence the fallback.
        sys.executable or "python",
        "-m", "causica.run_experiment",
        args.dataset_name,
        "--model_type", "rhino",
        "-dc", DATASET_CONFIG_PATH,
        "--model_config", MODEL_CONFIG_PATH,
        "-dv", args.device,
        "-c",
    ]

    # 3. Run it and propagate failure through the exit status.
    succeeded = run_command(command)

    print("-" * 60)
    print("Experiment finished." if succeeded else "Experiment FAILED.")
    print("-" * 60)
    if not succeeded:
        sys.exit(1)


if __name__ == "__main__":
    # Only run when executed as a script, not when imported.
    main()