import yaml
from pathlib import Path
import platform
import torch
import json
import os

CONFIG_PATH = Path(__file__).parent / "config.yaml"

def get_env_description():
    """Return a JSON-encoded string describing the runtime environment.

    The encoded object has these keys, in this order:

    * ``"Python Version"`` -- the interpreter version.
    * ``"CUDA Available"`` -- whether ``torch.cuda.is_available()`` is true.
    * ``"CUDA Version"`` and ``"GPU"`` -- present only when CUDA is
      available; ``"GPU"`` is the name of device 0 only.
    * ``"CPU"`` -- ``platform.processor()``, or ``platform.machine()`` when
      the former is empty.
    * ``"os"`` -- system name, release and version in one string.

    Returns:
        str: The ``json.dumps`` text of the dictionary above (a string, not
        a parsed object).
    """
    cuda_available = torch.cuda.is_available()

    env_info = {
        "Python Version": platform.python_version(),
        "CUDA Available": cuda_available,
    }

    # The CUDA details are only queried when a CUDA runtime is usable.
    if cuda_available:
        env_info["CUDA Version"] = torch.version.cuda
        env_info["GPU"] = torch.cuda.get_device_name(0)

    # platform.processor() returns "" when the name cannot be determined
    # (common on Linux); fall back to the machine type so the field is not
    # left blank.
    env_info["CPU"] = platform.processor() or platform.machine()

    # NOTE: the lowercase "os" key is kept as-is for existing consumers.
    env_info["os"] = f"{platform.system()} {platform.release()} ({platform.version()})"

    return json.dumps(env_info)


# Load the YAML configuration stored next to this module. The encoding is
# given explicitly so the file is decoded the same way regardless of the
# platform's locale default.
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Absolute path of the directory containing this file; every path below is
# resolved against it.
project_root = os.path.dirname(os.path.abspath(__file__))

# (key in config['paths'], location relative to project_root).
# These entries overwrite any value config.yaml holds under the same key, so
# config.yaml must already provide a 'paths' mapping.
# For TaxiNet, "models/weights/taxinet_model_larger.pth" is an alternative
# weight file that is currently not selected.
_PROJECT_RELATIVE_PATHS = (
    ('abcrown_run_path', "alpha-beta-CROWN/complete_verifier/abcrown.py"),
    ('customized_models_paths', "models/model.py"),
    ('mnist_state_dict_path', "models/weights/3_layered_mnist_10K_976.pth"),
    ('cifar10-big_state_dict_path', "models/weights/cifar10_model.pth"),
    ('cifar10-small_state_dict_path', "models/weights/resnet2b.pth"),
    ('gtsrb_state_dict_path', "models/weights/gtsrb_cnn_relu_89_45.pth"),
    ('gtsrb_data_path', "data/gtsrbVerix/gtsrb.pickle"),
    ('taxinet_state_dict_path', "models/weights/taxinet_model.pth"),
    ('taxinet_data_path', "data/taxinet/taxinet.pickle"),
    ('data_root_path', "data"),
)

for _path_key, _relative_path in _PROJECT_RELATIVE_PATHS:
    config['paths'][_path_key] = os.path.join(project_root, _relative_path)


# Environment description (JSON string) captured once at import time.
env_info = get_env_description()