
import matplotlib.pyplot as plt
import os

import sys
import os

# Make the directory one level above this file (the project root) importable.
# It is prepended so local packages take precedence, and only added once so
# repeated imports of this module do not grow sys.path.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..")
)
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    

# Shared plotting configuration. The dotted keys are named after matplotlib
# rcParams entries; the figsize_* / common_fig_* keys are project-specific
# figure sizes (width, height) in inches — assumed, confirm at call sites.
# NOTE(review): 'title.size' is not a matplotlib rcParams key
# ('axes.titlesize' is); harmless unless passed to rcParams.update — confirm.
plotting_params = {
    # Font and line settings.
    'legend.fontsize':   16,
    'lines.linewidth':   2.5,
    'lines.markersize':  6,
    'xtick.labelsize':   14,
    'ytick.labelsize':   14,
    'axes.labelsize':    14,
    'font.size':         16,
    'title.size':        16,
    # Figure dimensions.
    'figsize_regimes':   (6, 8),
    'figsize_union':     (8, 6),
    'common_fig_width':  8,
    'common_fig_height': 6,
}
# Human-readable labels for metric keys. Several keys deliberately share a
# label (e.g. 'tpr' and 'union_tpr' both map to 'TPR').
metric_clean_names = {
    'union_tpr':       'TPR',
    'union_fpr':       'FPR',
    'adj_recall':      'adj-Recall',
    'union_adj_prec':  'adj-Precision',
    'union_edge_prec': 'Union Edgemark Prec.',
    'union_edge_rec':  'Union Edgemark Rec.',
    'tpr':             'TPR',
    'fpr':             'FPR',
    'adj_rec':         'Recall',
    'adj_prec':        'Precision',
    'edge_prec':       'Edge Precision',
    'edge_rec':        'Edge Recall',
    'equal_regimes':   '# equal regimes',
    'reg_fpr':         'Reg.ind. FPR',
    'reg_tpr':         'Reg.ind. TPR',
    'f1':              'F1',
    'union_f1':        'F1',
    'intersection':    'Intersection',
}

# Human-readable labels for method identifiers (underscore-prefixed suffixes).
method_names = {
    '_ymask':              'R-PCMCI',
    '_ymask_naive':        'M-PCMCI',
    '_persistent_regimes': 'PAC-PCMCI',
    '_sparse_regimes':     'SAC-PCMCI',
    '_pcmci':              'P-PCMCI',
    '_intersection':       'B-PCMCI',
    '_fci':                'FCI',
    '_nod':                'CD-NOD',
}


# Hex color per method identifier.
# NOTE(review): '_fci' and '_nod' appear in method_names but have no entry
# here or in nodestyles — confirm they are never plotted with these maps.
colors = {
    '_intersection':       '#1F78B4',
    '_ymask':              '#E31A1C',
    '_ymask_naive':        '#6A3D9A',
    '_persistent_regimes': '#2CA02C',
    '_sparse_regimes':     '#91BF54',
    '_pcmci':              '#FF7F00',
}

# Legend labels for the "children known" setting, keyed by its string form.
children_labels = {
    'and_parents': '$C$',
    'True':        '$C$-ch.',
    'False':       'none',
}

# Line style per "children known" setting (same string keys as above).
linestyles = {
    'False':       '-',
    'True':        '--',
    'and_parents': ':',
}

# Marker symbol per method identifier.
nodestyles = {
    '_persistent_regimes': 'o',
    '_sparse_regimes':     'o',
    '_intersection':       '*',
    '_ymask':              'D',
    '_ymask_naive':        'D',
    '_pcmci':              '^',
}

# Axis labels for the experiment parameter being varied.
vary_labels = {
    'sample_size':     'Sample size',
    'regime_autocorr': 'Context auto-lag',
    'N_values':        'Graph size',
}

def save_legend(fig, handles, labels, plotting_params, figure_path, legend_filename):
    """Save a standalone legend (on its own 3x3 figure) to an image file.

    Args:
        fig: Unused; kept so existing call sites keep working.
        handles: Artists to show in the legend, passed to ``Figure.legend``.
        labels: Legend label strings, one per handle.
        plotting_params: Mapping that must provide ``'legend.fontsize'``.
        figure_path: Directory to write the file into.
        legend_filename: Name of the output file inside ``figure_path``.
    """
    legend_fig = plt.figure(figsize=(3, 3))
    try:
        legend_fig.legend(handles, labels, loc='center',
                          fontsize=plotting_params['legend.fontsize'])
        legend_fig.savefig(os.path.join(figure_path, legend_filename),
                           bbox_inches='tight')
    finally:
        # Close the pyplot-managed figure even if legend()/savefig() raises
        # (e.g. a missing fontsize key or unwritable path), so figures do not
        # accumulate across repeated calls.
        plt.close(legend_fig)

def convert_to_string_list(s):
    """Split a comma-separated string into a list of whitespace-stripped items.

    Any falsy input (None or "") yields an empty list. Empty fields are kept,
    e.g. "a,,b" -> ['a', '', 'b'].
    """
    if s:
        pieces = s.split(',')
        return [piece.strip() for piece in pieces]
    return []

def convert_to_children_known(s):
    """Parse a comma-separated "children known" option string.

    Items spelled exactly 'True' or 'False' become the corresponding bool.
    Every other item, such as 'and_parents', is kept as a (stripped) string.
    A falsy input (None or "") yields [].

    Example: "True, False, and_parents" -> [True, False, 'and_parents']
    """
    bool_tokens = {'True': True, 'False': False}
    # Lookup with the item itself as fallback leaves non-boolean tokens unchanged.
    return [bool_tokens.get(item, item) for item in convert_to_string_list(s)]

def convert_to_int_list(input_string):
    """Parse integers from a string separated by commas and/or whitespace.

    Tokens that are not valid ``int`` literals (e.g. '1.5', 'abc') are
    skipped on purpose. None or an empty string yields [], matching
    ``convert_to_string_list``.

    Example: "1, 2 3,x" -> [1, 2, 3]
    """
    if not input_string:
        return []

    result = []
    # Treat commas as whitespace so "1,2", "1 2" and "1, 2" tokenize alike.
    for token in input_string.replace(',', ' ').split():
        try:
            result.append(int(token))
        except ValueError:
            # Deliberate best-effort parsing: ignore non-integer tokens.
            continue
    return result

def convert_to_float_list(input_string):
    """Parse floats from a string separated by commas and/or whitespace.

    Tokens that ``float()`` rejects (e.g. 'abc') are skipped on purpose.
    None or an empty string yields [], matching ``convert_to_string_list``.

    Example: "0.1, 2 x,3e-2" -> [0.1, 2.0, 0.03]
    """
    if not input_string:
        return []

    result = []
    # Treat commas as whitespace so "1,2", "1 2" and "1, 2" tokenize alike.
    for token in input_string.replace(',', ' ').split():
        try:
            result.append(float(token))
        except ValueError:
            # Deliberate best-effort parsing: ignore tokens that are not floats.
            continue
    return result
