import argparse
import sys

import yaml


def load_yaml_config(path):
    """Load a config file in YAML format.

    The file is decoded as UTF-8 explicitly instead of relying on the
    platform's locale encoding, so a config containing non-ASCII text is
    read the same way on every machine.

    Args:
        path (str): Path of the config file to read.

    Returns:
        dict: The parsed config, as produced by ``yaml.safe_load``
            (``None`` if the file contains no YAML document).
    """
    # safe_load only builds plain Python objects; it never instantiates
    # arbitrary classes named in the file.
    with open(path, 'r', encoding='utf-8') as infile:
        return yaml.safe_load(infile)


def save_yaml_config(config, path):
    """Save the config to a file in YAML format.

    The file at ``path`` is created, or overwritten if it already exists.
    It is written as UTF-8 explicitly so the result does not depend on the
    platform's locale encoding.

    Args:
        config (dict object): Config to serialize.
        path (str): Path to save the config.
    """
    with open(path, 'w', encoding='utf-8') as outfile:
        # default_flow_style=False emits block-style YAML (one key per
        # line) rather than inline ``{a: 1, b: 2}`` flow collections.
        yaml.dump(config, outfile, default_flow_style=False)


def get_args():
    """Build the command-line parser and parse the process arguments.

    The dataset, method and other option groups are registered on a fresh
    parser in that order, then everything after the program name in
    ``sys.argv`` is parsed.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    cli_parser = argparse.ArgumentParser()

    # Each helper registers one group of options on the shared parser.
    for register_group in (add_dataset_args, add_method_args, add_other_args):
        register_group(cli_parser)

    command_line = sys.argv[1:]
    return cli_parser.parse_args(args=command_line)


def add_dataset_args(parser):
    """Register the dataset options on the given parser.

    Every option is an optional ``--flag`` with a default value, so the
    parser still accepts an empty command line after this call.

    Args:
        parser (argparse.ArgumentParser): Parser that receives the options.
    """
    # One row per option: (flag, value type, default, help text).
    # Rows are registered in the order listed.
    dataset_options = (
        ('--samplenum', int, 500,
         "Number of samples."),
        ('--nodenum', int, 10,
         "Number of nodes."),
        ('--degree', int, 2,
         "Degree of graph."),
        ('--data_type', str, 'linear',
         "Type of SEM."),
        ('--noise_type', str, 'gauss',
         "Type of noise."),
        ('--max_selection_num', int, 4,
         "Maximum number of selections combinations."),
        ('--max_single_selection_num', int, 3,
         "Maximum number of variables in a single selection."),
        ('--num_of_interv_configs', int, 10,
         "Number of interventions."),
        ('--max_single_intervention_num', int, 3,
         "Maximum number of variables in a single intervention."),
    )

    for flag, value_type, default_value, help_text in dataset_options:
        parser.add_argument(flag,
                            type=value_type,
                            default=default_value,
                            help=help_text)


def add_method_args(parser):
    """Register the method options on the given parser.

    Covers the causal discovery method, the perturbation type, and the
    CI test type together with its significance level.

    Args:
        parser (argparse.ArgumentParser): Parser that receives the options.
    """
    # One row per option: (flag, value type, default, help text).
    # Rows are registered in the order listed.
    method_options = (
        ('--method_type', str, 'gies',
         "Type of causal discovery method."),
        ('--perturbation_type', str, 'hard',
         "Type of perturbation method."),
        ('--CI_type', str, 'oracle',
         "Type of CI test."),
        ('--alpha', float, 0.05,
         "Significance level for CI test."),
    )

    for flag, value_type, default_value, help_text in method_options:
        parser.add_argument(flag,
                            type=value_type,
                            default=default_value,
                            help=help_text)


def add_other_args(parser):
    """Register miscellaneous options on the given parser.

    Currently this is only the random seed.

    Args:
        parser (argparse.ArgumentParser): Parser that receives the options.
    """
    seed_spec = {
        'type': int,
        'default': 0,
        'help': "Random seed.",
    }
    parser.add_argument('--seed', **seed_spec)
    
    # Example: python main --perturbation_type hard --method_type gies --nodenum 10 --samplenum 500 --seed 0