import pandas as pd
import json
import os
import math
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches # Import patches for drawing rectangles
import matplotlib.transforms as transforms
import matplotlib.ticker as ticker # Import ticker for FixedLocator
import numpy as np


# --- Data Processing Function ---
def _metric_stats_by_lambda(df, metric, noise_type, edge_ratio, lambda_e_list):
    """Return ``[[mean, std], ...]`` of ``metric``, one pair per entry of ``lambda_e_list``.

    Rows are restricted to ``noise_type`` and to an ``edge / node`` ratio equal to
    ``edge_ratio`` before grouping by ``lambda_e``. An empty selection yields NaN
    for both statistics; a single matching row yields a NaN std (pandas ``std``
    uses ``ddof=1``).
    """
    # The noise-type / edge-ratio part of the filter is the same for every lambda.
    setting_mask = (df['noise_type'] == noise_type) & (df['edge'] / df['node'] == edge_ratio)
    stats = []
    for lambda_e in lambda_e_list:
        selected = df.loc[setting_mask & (df['lambda_e'] == lambda_e), metric]
        stats.append([selected.mean(), selected.std()])
    return stats


def _collect_metric_stats(all_dfs, filepaths, metric, noises_type_list, edge_ratio_list, lambda_e_list):
    """Build the nested ``[noise_type][edge_ratio][file][lambda_e] -> [mean, std]`` list for ``filepaths``."""
    return [
        [
            [_metric_stats_by_lambda(all_dfs[fp], metric, nt, er, lambda_e_list) for fp in filepaths]
            for er in edge_ratio_list
        ]
        for nt in noises_type_list
    ]


def process_data(target, edge_prob, p_order, t, metric, noises_type_list, edge_ratio_list, lambda_e_list,
                 file_mapping_dict, base_filename, all_filename):
    """Load every CSV in ``file_mapping_dict`` and summarise ``metric`` per experimental setting.

    Each CSV is filtered to rows whose stripped ``name`` equals ``target`` and whose
    ``porders`` and ``T`` equal ``p_order`` and ``t``. The ``porders``, ``T``,
    ``lambda_e`` and ``metric`` columns are coerced to numeric first (bad values
    become NaN).

    Args:
        target: Value of the ``name`` column to keep.
        edge_prob: Unused here; kept for interface compatibility with callers.
        p_order: Required value of the ``porders`` column.
        t: Required value of the ``T`` column.
        metric: Column to aggregate.
        noises_type_list: ``noise_type`` values, outer nesting level.
        edge_ratio_list: ``edge / node`` ratios, second nesting level.
        lambda_e_list: ``lambda_e`` values, innermost nesting level.
        file_mapping_dict: Mapping of CSV path to label. Every key is read, in
            insertion order.
        base_filename: Paths whose basename contains this string are baselines.
        all_filename: Paths whose basename contains this string are algorithm runs.

    Returns:
        ``[baseline_values, all_values]``. Each element is nested as
        ``[noise_type][edge_ratio][file][lambda_e] -> [mean, std]``. Files appear
        in ``file_mapping_dict`` order, and missing cells are NaN.
    """
    print("\n--- Starting Data Processing ---")

    print("Step 1: Reading and cleaning CSV data...")
    all_dfs = {}
    for filepath in file_mapping_dict:
        df = pd.read_csv(filepath)
        df['name'] = df['name'].str.strip()
        for column in ('porders', 'T', 'lambda_e', metric):
            df[column] = pd.to_numeric(df[column], errors='coerce')
        all_dfs[filepath] = df[(df['name'] == target) &
                               (df['porders'] == p_order) &
                               (df['T'] == t)].copy()

    # The split into baseline files and algorithm files does not depend on
    # any loop variable, so resolve both lists once.
    baseline_filepaths = [fp for fp in file_mapping_dict if base_filename in os.path.basename(fp)]
    all_filepaths = [fp for fp in file_mapping_dict if all_filename in os.path.basename(fp)]

    print("\nStep 2: Pre-calculating baseline values...")
    baseline_values = _collect_metric_stats(all_dfs, baseline_filepaths, metric,
                                            noises_type_list, edge_ratio_list, lambda_e_list)

    print("\nStep 3: Populating results structure and calculating algorithm values...")
    all_values = _collect_metric_stats(all_dfs, all_filepaths, metric,
                                       noises_type_list, edge_ratio_list, lambda_e_list)

    return [baseline_values, all_values]


# --- Plotting Function (Uses pre-assigned colors via alg_color_map) ---
def _format_lambda_tick(val):
    """Return the compact x-tick label for one lambda value.

    ``0.01``, ``0.1`` and ``0.5`` drop the leading zero, ``1``, ``5`` and ``10``
    are printed as integers, and anything else falls back to ``str(val)``.
    """
    if val == 0.01:
        return '.01'
    if val == 0.1:
        return '.1'
    if val == 0.5:
        return '.5'
    if val in [1, 5, 10]:
        return f'{int(val)}'
    return str(val)


def _errorbar_range(means, stds):
    """Return ``(min(mean - std), max(mean + std))`` over finite entries, or ``None``.

    A NaN std (e.g. a cell backed by a single run) is treated as 0 so its mean
    still counts. Entries whose mean is NaN are skipped. ``Axes.set_ylim``
    raises ``ValueError`` on NaN limits, so these must never reach it.
    """
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)
    stds = np.where(np.isnan(stds), 0.0, stds)
    lows = means - stds
    highs = means + stds
    finite = np.isfinite(lows) & np.isfinite(highs)
    if not finite.any():
        return None
    return float(lows[finite].min()), float(highs[finite].max())


def _style_facet_axis(ax, tick_labels, hide_left_labels, hide_bottom_labels):
    """Apply the facet look: fixed lambda ticks, no tick marks, white grid on grey.

    Must be called after the y-limits are final, because the minor y grid is
    placed halfway between the current major y ticks.
    """
    ax.set_xticks(range(len(tick_labels)))
    ax.set_xticklabels(tick_labels)
    tick_kwargs = dict(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    if hide_left_labels:
        tick_kwargs['labelleft'] = False
    if hide_bottom_labels:
        tick_kwargs['labelbottom'] = False
    ax.tick_params(**tick_kwargs)
    ax.grid(True, which='major', color='white', linestyle='-', linewidth=4)
    ax.grid(True, which='minor', color='white', linestyle='-', linewidth=2)
    # Minor grid lines sit midway between consecutive major ticks on both axes.
    for axis in (ax.xaxis, ax.yaxis):
        locs = axis.get_ticklocs()
        midpoints = [(lo + hi) / 2 for lo, hi in zip(locs[:-1], locs[1:])]
        axis.set_minor_locator(ticker.FixedLocator(midpoints))
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_facecolor((240 / 256, 240 / 256, 240 / 256))


def _add_label_box(fig, x, y, width, height, text, box_style, rotation=0):
    """Draw a filled rectangle in figure coordinates with ``text`` centred inside it."""
    fig.add_artist(patches.Rectangle((x, y), width, height, **box_style))
    fig.text(x + width / 2, y + height / 2, text, ha='center', va='center',
             rotation=rotation, fontsize=40)


def _save_legend(lines_for_legend, legend_filepath):
    """Save the legend (baseline first, then algorithms by key) as a standalone figure."""
    alg_name_mapping = {
        'baseline': "Baseline",
        'dynotears_and_init0': "DYNOTEARS& (Init 0)",
        'dynotears_and_initdata': "DYNOTEARS& (Init Data)",
        'dynotears_and_initedge': "DYNOTEARS& (Init Edge)",
        'dynotears_multiply_init0': "DYNOTEARS* (Init 0)",
        'dynotears_multiply_initdata': "DYNOTEARS* (Init Data)",
        'dynotears_multiply_initedge': "DYNOTEARS* (Init Edge)"
    }
    legend_keys = sorted(lines_for_legend, key=lambda name: (name != 'baseline', name))
    legend_lines = [lines_for_legend[name] for name in legend_keys]
    legend_labels = [alg_name_mapping.get(name, name) for name in legend_keys]

    approx_width_per_item = 2.5
    legend_fig_width = max(len(legend_labels) * 3, approx_width_per_item)
    legend_fig_height = 1
    fig_legend = plt.figure(figsize=(legend_fig_width, legend_fig_height))
    legend = fig_legend.legend(legend_lines, legend_labels,
                               loc='center',
                               title="Algorithms",
                               ncol=1,  # one entry per row (a single column)
                               markerscale=2.0)

    print(f"Saving legend to {legend_filepath}")
    # bbox_extra_artists keeps the legend title from being cropped by bbox_inches='tight'.
    fig_legend.savefig(legend_filepath, bbox_inches='tight', bbox_extra_artists=(legend,), dpi=300)
    plt.close(fig_legend)


def plot_data(data_to_plot, nt_list, er_list, lambda_e_list, metric, file_mapping_dict,
            linestyles_dict, markers_map, alg_color_map, metric_titles, output_plot_filepath,
            row_labels=None, col_labels=None):
    """Draw a noise-type x edge-ratio grid of metric-vs-lambda errorbar plots and save it.

    Args:
        data_to_plot: ``[baseline_values, all_values]`` as returned by
            ``process_data``. Each is nested as
            ``[noise][edge_ratio][file][lambda] -> [mean, std]``.
        nt_list: Noise types, one grid row each.
        er_list: Edge ratios, one grid column each.
        lambda_e_list: Lambda values, drawn at evenly spaced x positions.
        metric: Metric name, looked up in ``metric_titles`` for the y label.
        file_mapping_dict: Path -> label mapping. Labels containing ``'dynotears'``
            name the algorithm curves in order.
        linestyles_dict, markers_map, alg_color_map: Per-algorithm styles, indexed
            by algorithm position.
        metric_titles: Mapping of metric name to display title.
        output_plot_filepath: PDF path for the grid. A ``legend_only.pdf`` is
            written to the same directory.
        row_labels: Optional facet labels for rows. The default strips a
            ``'noise'`` prefix from each noise type and capitalises the rest.
        col_labels: Optional facet labels for columns. The default is ``f'ER{er}'``.
    """
    print("\n--- Starting Plotting ---")
    num_nt = len(nt_list)
    num_er = len(er_list)
    if row_labels is None:
        row_labels = [(nt[len('noise'):] if nt.startswith('noise') else nt).capitalize()
                      for nt in nt_list]
    if col_labels is None:
        col_labels = [f'ER{er}' for er in er_list]

    # Width follows the number of columns and height the number of rows.
    # squeeze=False keeps `axes` 2-D even when there is a single row or column.
    fig, axes = plt.subplots(num_nt, num_er,
                             figsize=(5 * num_er, 5 * num_nt),
                             sharex=True,
                             sharey='row',
                             squeeze=False)

    x_positions = range(len(lambda_e_list))
    tick_labels = [_format_lambda_tick(val) for val in lambda_e_list]
    alg_names = [name for name in file_mapping_dict.values() if 'dynotears' in name]
    errorbar_style = dict(linewidth=4, capsize=0, elinewidth=3, markersize=12)
    lines_for_legend = {}

    for i in range(num_nt):
        row_ranges = []
        for j in range(num_er):
            ax = axes[i, j]
            # Only the first baseline file in the mapping is drawn.
            baseline_cell = data_to_plot[0][i][j][0]
            y_base = [item[0] for item in baseline_cell]
            std_base = [item[1] for item in baseline_cell]
            line, _, _ = ax.errorbar(x_positions, y_base, yerr=std_base,
                                     color='green', linestyle='-.', marker='.',
                                     ecolor='green', **errorbar_style)
            lines_for_legend['baseline'] = line
            cell_means = list(y_base)
            cell_stds = list(std_base)

            for k, alg_name in enumerate(alg_names):
                alg_cell = data_to_plot[1][i][j][k]
                # NOTE(review): algorithm means are drawn as baseline + stored value,
                # which presumes the "all" summaries hold values relative to the
                # baseline. Confirm against how those CSVs are produced.
                y_values = [item[0] + base for item, base in zip(alg_cell, y_base)]
                std_values = [item[1] for item in alg_cell]
                line, _, _ = ax.errorbar(x_positions, y_values, yerr=std_values,
                                         color=alg_color_map[k],
                                         linestyle=linestyles_dict[k],
                                         marker=markers_map[k],
                                         ecolor=alg_color_map[k], **errorbar_style)
                lines_for_legend[alg_name] = line
                cell_means.extend(y_values)
                cell_stds.extend(std_values)

            cell_range = _errorbar_range(cell_means, cell_stds)
            if cell_range is not None:
                row_ranges.append(cell_range)

        # y is shared across the row, so one set_ylim covers every column.
        # Without any finite data, matplotlib's autoscaled limits are kept.
        if row_ranges:
            padding = 0.01
            axes[i, 0].set_ylim(min(lo for lo, _ in row_ranges) - padding,
                                max(hi for _, hi in row_ranges) + padding)

        for j in range(num_er):
            _style_facet_axis(axes[i, j], tick_labels,
                              hide_left_labels=j > 0,
                              hide_bottom_labels=i < num_nt - 1)

    # Leave room around the grid for the facet label boxes and common axis labels.
    adjust_left = 0.2
    adjust_bottom = 0.2
    adjust_right = 0.9
    adjust_top = 0.9
    fig.subplots_adjust(left=adjust_left, bottom=adjust_bottom,
                        right=adjust_right, top=adjust_top,
                        wspace=0.03, hspace=0.03)

    label_box_style = dict(facecolor=(230 / 256, 230 / 256, 230 / 256),
                           linewidth=1.5,
                           clip_on=False)

    # Column labels: boxes 60% as tall as the top margin, starting 10% into it.
    for ax, col_label in zip(axes[0, :], col_labels):
        pos = ax.get_position()
        _add_label_box(fig, pos.x0, adjust_top + (1 - adjust_top) * 0.1,
                       pos.width, (1 - adjust_top) * 0.6, col_label, label_box_style)

    # Row labels: boxes 60% as wide as the right margin, starting 10% into it.
    for ax, row_label in zip(axes[:, -1], row_labels):
        pos = ax.get_position()
        _add_label_box(fig, adjust_right + (1 - adjust_right) * 0.1, pos.y0,
                       (1 - adjust_right) * 0.6, pos.height, row_label, label_box_style,
                       rotation=270)

    fig.text(0.55, adjust_bottom / 2, r'Prior Loss Weight $\lambda_p$',
             ha='center', va='center', fontsize=40)
    fig.text(adjust_left / 2, adjust_bottom + (adjust_top - adjust_bottom) / 2,
             metric_titles.get(metric, metric),
             ha='center', va='center', rotation='vertical', fontsize=40)

    # Create the output directory before anything (legend included) is written into it.
    output_dir = os.path.dirname(output_plot_filepath)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if lines_for_legend:
        _save_legend(lines_for_legend, os.path.join(output_dir, 'legend_only.pdf'))

    print(f"Saving main plot to {output_plot_filepath}")
    # Save `fig` explicitly; pyplot's "current figure" changes when the legend figure is made.
    fig.savefig(output_plot_filepath, format='pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)

# --- Main Execution ---
if __name__ == "__main__":
    # Typography shared by every figure: serif Times New Roman text with STIX math.
    for rc_key, rc_value in (
        ('font.family', 'serif'),
        ('font.serif', ['Times New Roman']),
        ('font.size', 30),
        ('mathtext.fontset', 'stix'),
    ):
        plt.rcParams[rc_key] = rc_value

    # Experiment grid: each combination below produces one JSON file and one PDF.
    target_name = ['timeseries']
    edge_prior_probs = [0.8]
    p_order_list = [3]
    t_values = [250, 1000]
    metrics = ['accuracy', 'recall', 'f1', 'shd', 'edge_recovery']

    # Facet rows, facet columns and x-axis values of each figure.
    noises_type_list = ['noisegauss', 'noiseexp']
    edge_ratio_list = [2, 4]
    lambda_e_list = [0.01, 0.1, 0.5, 1, 5, 10]
    experiment_folders = ['and_init0', 'multiply_init0', 'and_initdata', 'multiply_initdata']

    base_result_dir = 'result/result_lambda_e'
    base_filename = 'merged_base_summary.csv'
    all_filename = 'merged_all_summary.csv'

    # One colour per algorithm curve, matched by position.
    colors = ['tab:orange', 'tab:cyan', 'tab:brown', 'tab:purple']

    # Every experiment folder contributes a baseline summary followed by an algorithm summary.
    file_mapping = {}
    for alg in experiment_folders:
        alg_dir = os.path.join(base_result_dir, alg)
        file_mapping[os.path.join(alg_dir, base_filename)] = f'baseline_{alg}'
        file_mapping[os.path.join(alg_dir, all_filename)] = f'dynotears_{alg}'

    linestyles = ['-'] * 6
    markers = ['.'] * 6
    metric_titles = {
        'accuracy': 'Accuracy',
        'recall': 'Recall',
        'f1': 'F1 Score',
        'shd': 'SHD',
        'edge_recovery': 'Edge Recovery Rate'
    }

    # Flatten the parameter sweep into one ordered list of runs (innermost varies fastest).
    runs = [
        (target, edge_prob, p_order, t, metric)
        for target in target_name
        for edge_prob in edge_prior_probs
        for p_order in p_order_list
        for t in t_values
        for metric in metrics
    ]

    for target, edge_prob, p_order, t, metric in runs:
        output_json_file = f'figure/exp_lambda_e/{target}_{edge_prob}_{p_order}_{t}_{metric}.json'
        output_plot_file = f'figure/exp_lambda_e/{target}_{edge_prob}_{p_order}_{t}_{metric}.pdf'

        processed_results = process_data(
            target,
            edge_prob,
            p_order,
            t,
            metric,
            noises_type_list,
            edge_ratio_list,
            lambda_e_list,
            file_mapping_dict=file_mapping,
            base_filename=base_filename,
            all_filename=all_filename
        )
        if not processed_results:
            continue

        # Persist the aggregated numbers next to the figure.
        print(f"\n--- Saving Processed Data to JSON ({output_json_file}) ---")
        output_dir_json = os.path.dirname(output_json_file)
        if output_dir_json and not os.path.exists(output_dir_json):
            os.makedirs(output_dir_json)
            print(f"Created output directory: {output_dir_json}")
        with open(output_json_file, 'w', encoding='utf-8') as json_handle:
            json.dump(processed_results, json_handle, indent=4, ensure_ascii=False)

        plot_data(
            data_to_plot=processed_results,
            nt_list=noises_type_list,
            er_list=edge_ratio_list,
            lambda_e_list=lambda_e_list,
            metric=metric,
            file_mapping_dict=file_mapping,
            linestyles_dict=linestyles,
            markers_map=markers,
            alg_color_map=colors,
            metric_titles=metric_titles,
            output_plot_filepath=output_plot_file
        )

