import pandas as pd
import json
import os
import math
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches # Import patches for drawing rectangles
import matplotlib.transforms as transforms
import matplotlib.ticker as ticker # Import ticker for FixedLocator
import numpy as np


# --- Data Processing Function ---
def _summarize_metric(frames, filepaths, metric, noises_type_list, edge_ratio_list, node_list):
    """Aggregate *metric* into ``[noise][edge_ratio][file][node] -> [mean, std]``.

    Args:
        frames: Mapping of file path to its already filtered DataFrame.
        filepaths: Paths (keys of *frames*) to include, in output order.
        metric: Name of the column to aggregate.
        noises_type_list: Values matched against the ``noise_type`` column.
        edge_ratio_list: Values matched against ``edge / node``.
        node_list: Values matched against the ``node`` column.

    Returns:
        Nested lists indexed by noise type, edge ratio, file and node; each
        leaf is ``[mean, std]`` (both NaN when no row matches, std NaN when
        only one row matches).
    """
    summary = []
    for nt in noises_type_list:
        nt_values = []
        for er in edge_ratio_list:
            er_values = []
            for filepath in filepaths:
                df = frames[filepath]
                # Noise type and edge ratio do not depend on the node, so the
                # frame is narrowed once before splitting by node count.
                subset = df[(df['noise_type'] == nt) &
                            (df['edge'] / df['node'] == er)]
                file_values = []
                for node in node_list:
                    values = subset.loc[subset['node'] == node, metric]
                    file_values.append([values.mean(), values.std()])
                er_values.append(file_values)
            nt_values.append(er_values)
        summary.append(nt_values)
    return summary


def process_data(target,edge_prob,p_order,t,metric,noises_type_list,edge_ratio_list,node_list,
                 file_mapping_dict,base_filename,all_filename):
    """Load the result CSVs and summarise *metric* for baseline and algorithm files.

    Every path in *file_mapping_dict* is read, the ``name`` column is stripped,
    the numeric columns are coerced (unparsable cells become NaN) and the rows
    are restricted to the requested ``name`` / ``edge_prior_prob`` /
    ``porders`` / ``T`` combination.

    Args:
        target: Required value of the ``name`` column.
        edge_prob: Required value of the ``edge_prior_prob`` column.
        p_order: Required value of the ``porders`` column.
        t: Required value of the ``T`` column.
        metric: Column whose mean and standard deviation are reported.
        noises_type_list: Noise types (``noise_type`` column) to report.
        edge_ratio_list: ``edge / node`` ratios to report.
        node_list: Node counts (``node`` column) to report.
        file_mapping_dict: Mapping whose keys are the CSV paths to read; the
            values are not used here.
        base_filename: Paths whose basename contains this string are treated
            as baseline files.
        all_filename: Paths whose basename contains this string are treated
            as algorithm files.

    Returns:
        ``[baseline_values, all_values]``; each element is indexed as
        ``[noise][edge_ratio][file][node]`` and holds ``[mean, std]``.
    """
    print("\n--- Starting Data Processing ---")

    print("Step 1: Reading and cleaning CSV data...")
    numeric_columns = ['edge_prior_prob', 'porders', 'T', 'node', metric]
    all_dfs = {}
    for filepath in file_mapping_dict:
        df = pd.read_csv(filepath)
        df['name'] = df['name'].str.strip()
        for column in numeric_columns:
            df[column] = pd.to_numeric(df[column], errors='coerce')

        selected = ((df['name'] == target) &
                    (df['edge_prior_prob'] == edge_prob) &
                    (df['porders'] == p_order) &
                    (df['T'] == t))
        all_dfs[filepath] = df[selected].copy()

    # The file selection does not depend on noise type or edge ratio, so it is
    # computed once instead of inside the aggregation loops.
    baseline_filepaths = [fp for fp in file_mapping_dict
                          if base_filename in os.path.basename(fp)]
    algorithm_filepaths = [fp for fp in file_mapping_dict
                           if all_filename in os.path.basename(fp)]

    print("\nStep 2: Pre-calculating baseline values...")
    baseline_values = _summarize_metric(all_dfs, baseline_filepaths, metric,
                                        noises_type_list, edge_ratio_list, node_list)

    print("\nStep 3: Populating results structure and calculating algorithm values...")
    all_values = _summarize_metric(all_dfs, algorithm_filepaths, metric,
                                   noises_type_list, edge_ratio_list, node_list)

    return [baseline_values, all_values]


# --- Plotting Function (Uses pre-assigned colors via alg_color_map) ---
def plot_data(data_to_plot, nt_list, er_list, node_list, metric, file_mapping_dict,
            linestyles_dict, markers_map, alg_color_map, metric_titles, output_plot_filepath):
    """Draw a (noise type x edge ratio) grid of grouped bar charts and save it as PDF.

    Each subplot shows one dashed horizontal baseline line and one bar per
    algorithm. A separate ``legend_only.pdf`` is written next to the main
    figure.

    Args:
        data_to_plot: ``[baseline_values, all_values]``. The baseline line uses
            ``data_to_plot[0][i][j][0][0][0]`` (first baseline file, first
            node, mean); bar ``k`` uses ``data_to_plot[1][i][j][k][0][0]``.
        nt_list: Noise types; one subplot row per entry.
        er_list: Edge ratios; one subplot column per entry.
        node_list: Unused; kept for interface compatibility.
        metric: Metric name, used to look up the common y-axis label.
        file_mapping_dict: Mapping whose values containing ``'dynotears'``
            name the algorithms, in bar order.
        linestyles_dict: Unused; kept for interface compatibility.
        markers_map: Unused; kept for interface compatibility.
        alg_color_map: Sequence of edge colours passed to ``Axes.bar``.
        metric_titles: Mapping of metric name to y-axis label; falls back to
            the metric name itself.
        output_plot_filepath: Destination of the main PDF.
    """
    print("\n--- Starting Plotting ---")
    num_nt = len(nt_list)
    num_er = len(er_list)

    # Both the legend and the main figure are written to this directory, so it
    # has to exist before the first savefig call.
    output_dir = os.path.dirname(output_plot_filepath)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # figsize is (width, height): width scales with columns, height with rows.
    # squeeze=False keeps ``axes`` two-dimensional for single-row/column grids.
    fig, axes = plt.subplots(num_nt, num_er,
                             figsize=(5 * num_er, 5 * num_nt),
                             sharex=True,
                             sharey='row',
                             squeeze=False)

    lines_for_legend = {}
    row_y_ranges = [[] for _ in range(num_nt)]

    # --- Define bar parameters ---
    group_size = 3            # Bars per group
    gap_between_groups = 1.5  # Gap between the facing edges of adjacent groups
    bar_width = 1.5           # Bars inside a group touch (centres are bar_width apart)
    bar_linewidth = 2
    # NOTE(review): a list of hatches needs a Matplotlib release that accepts
    # per-bar hatch sequences in Axes.bar -- confirm the installed version.
    hatch_patterns = ['/','-', '\\','/','-', '\\']
    group_labels = ['Init 0','Init Data']

    algorithm_names = [name for name in file_mapping_dict.values() if 'dynotears' in name]
    num_algorithms = len(algorithm_names)

    # --- Calculate x_positions for bars ---
    # One centre per algorithm (also for a trailing, incomplete group), so the
    # positions always match the number of bar heights.
    group_stride = group_size * bar_width + gap_between_groups
    x_positions = np.array([
        (k // group_size) * group_stride + (k % group_size) * bar_width
        for k in range(num_algorithms)
    ])
    print(f"Calculated x_positions: {x_positions}") # Debug print

    # One tick in the middle of every group; groups beyond the known labels
    # stay unlabelled.
    group_ticks = [x_positions[start:start + group_size].mean()
                   for start in range(0, num_algorithms, group_size)]
    tick_labels = (group_labels + [''] * len(group_ticks))[:len(group_ticks)]

    padding = 0.01
    for i in range(num_nt):
        for j in range(num_er):
            ax = axes[i, j]

            # Baseline: mean per node of the first baseline file; the line is
            # drawn at the first node's mean.
            y_base = [item[0] for item in data_to_plot[0][i][j][0]]
            line = ax.axhline(y=y_base[0], color='green', linestyle='--', linewidth=2)
            lines_for_legend['baseline'] = line

            # Bar height is the algorithm mean plus the baseline mean.
            # NOTE(review): presumably the algorithm files store values
            # relative to the baseline -- confirm.
            y_values = [data_to_plot[1][i][j][k][0][0] + y_base[0]
                        for k in range(num_algorithms)]
            bars = ax.bar(
                x_positions,
                y_values,
                width=bar_width,
                color='white',
                edgecolor=alg_color_map,
                linewidth=bar_linewidth,
                hatch=hatch_patterns
            )
            if len(bars.patches) == num_algorithms:
                for alg_name, patch in zip(algorithm_names, bars.patches):
                    lines_for_legend.setdefault(alg_name, patch)

            # Track the +/-5% envelope of the finite values. NaN means (no
            # matching rows) are skipped because set_ylim rejects NaN limits.
            y_vals_arr = np.array(y_base + y_values, dtype=float)
            finite_vals = y_vals_arr[np.isfinite(y_vals_arr)]
            if finite_vals.size:
                row_y_ranges[i].append((np.min(finite_vals - 0.05 * finite_vals),
                                        np.max(finite_vals + 0.05 * finite_vals)))

            # --- Set subplot properties ---
            # The y axis is shared per row, so the limits set on the last
            # subplot of a row cover every subplot in that row.
            if row_y_ranges[i]:
                min_y_overall = min(low for low, _ in row_y_ranges[i])
                max_y_overall = max(high for _, high in row_y_ranges[i])
                ax.set_ylim(min_y_overall - padding, max_y_overall + padding)
            ax.set_xticks(group_ticks)
            ax.set_xticklabels(tick_labels, ha='center')
            ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False,right=False)
            if j > 0:
                # Only the leftmost column keeps its y tick labels.
                ax.tick_params(axis='both', which='both', labelleft=False)
            if i < num_nt - 1:
                # Only the bottom row keeps its x tick labels.
                ax.tick_params(axis='both', which='both', labelbottom=False)
            ax.yaxis.grid(True, linestyle='-', color='gray', linewidth=2, alpha=0.1)
            for spine in ax.spines.values():
                spine.set_linewidth(1.5)
                spine.set_edgecolor('black')
            ax.set_facecolor('white')

    # --- Adjust layout of the main figure to make space for external labels ---
    # These values need tuning based on figsize, font size, and number of rows/cols
    adjust_left = 0.2 # Space for common ylabel and left tick labels
    adjust_bottom = 0.2 # Space for common xlabel and bottom tick labels
    adjust_right = 0.9 # Leaves room on the right for row labels
    adjust_top = 0.9 # Leaves room on the top for column labels
    adjust_wspace = 0.03 # Horizontal space between subplots
    adjust_hspace = 0.03 # Vertical space between subplots

    fig.subplots_adjust(left=adjust_left, bottom=adjust_bottom,
                        right=adjust_right, top=adjust_top,
                        wspace=adjust_wspace, hspace=adjust_hspace)

    # --- Define the style for the facet label boxes ---
    label_box_style = dict(facecolor='white', # Fill color
                           edgecolor='black', # Border color
                           linewidth=1.5, # Border line width
                           clip_on=False # Boxes are drawn on the figure, outside any axes
                          )

    # --- Add column labels (top) in boxes aligned to columns ---
    # Labels follow er_list so they always match the plotted columns.
    column_labels = [f'ER{er:g}' for er in er_list]
    for j, col_label in enumerate(column_labels):
        pos = axes[0, j].get_position() # Axes position in figure coordinates (Bbox)

        box_x = pos.x0 # Box starts at the left edge of the column's axes
        box_width = pos.width # Box width is the width of the column's axes
        box_height = (1 - adjust_top) * 0.6 # 60% of the space above the grid
        box_y = adjust_top + (1 - adjust_top) * 0.1 # Starts 10% into that space

        rect = patches.Rectangle((box_x, box_y), box_width, box_height, **label_box_style)
        fig.add_artist(rect)

        # Place the text label centered within the rectangle
        fig.text(box_x + box_width / 2, box_y + box_height / 2, col_label,
                 ha='center', va='center', fontsize=40)

    # --- Add row labels (right) in boxes aligned to rows ---
    # Labels follow nt_list; unknown noise types fall back to their raw name.
    noise_labels = {'noisegauss': 'Gauss', 'noiseexp': 'Exp'}
    row_labels = [noise_labels.get(nt, nt) for nt in nt_list]
    for i, row_label in enumerate(row_labels):
        pos = axes[i, -1].get_position() # Rightmost axes of the row

        box_y = pos.y0 # Box starts at the bottom edge of the row's axes
        box_height = pos.height # Box height is the height of the row's axes
        box_width = (1 - adjust_right) * 0.6 # 60% of the space right of the grid
        box_x = adjust_right + (1 - adjust_right) * 0.1 # Starts 10% into that space

        rect = patches.Rectangle((box_x, box_y), box_width, box_height, **label_box_style)
        fig.add_artist(rect)

        # Place the text label centered within the rectangle
        fig.text(box_x + box_width / 2, box_y + box_height / 2, row_label,
                 ha='center', va='center', rotation=270, fontsize=40)

    # --- Add a common X-axis label at the bottom ---
    common_xlabel_y_pos = adjust_bottom / 2
    fig.text(0.55, common_xlabel_y_pos, 'Algorithm', ha='center', va='center',
             fontsize=40)

    # --- Common Y-axis label on the left ---
    # Centered vertically on the grid, halfway between figure left and grid left
    common_ylabel_x_pos = adjust_left / 2
    common_ylabel_y_pos = adjust_bottom + (adjust_top - adjust_bottom) / 2
    common_ylabel_text = metric_titles.get(metric, metric) # Fallback to metric name
    fig.text(common_ylabel_x_pos, common_ylabel_y_pos, common_ylabel_text,
             ha='center', va='center', rotation='vertical', fontsize=40)

    if lines_for_legend:
        alg_name_mapping={
            'baseline': "Baseline", 
            'dynotears_and_init0': "DYNOTEARS& (Init 0)",
            'dynotears_and_initrandom': "DYNOTEARS& (Init Random)",
            'dynotears_and_initedge': "DYNOTEARS& (Init Edge)",
            'dynotears_and_initdata': "DYNOTEARS& (Init Data)",
            'dynotears_multiply_init0': "DYNOTEARS* (Init 0)",
            'dynotears_multiply_initrandom': "DYNOTEARS* (Init Random)",
            'dynotears_multiply_initedge': "DYNOTEARS* (Init Edge)",
            'dynotears_multiply_initdata': "DYNOTEARS* (Init Data)",
            'dynotears_max_init0': "DYNOTEARS^ (Init 0)",
            'dynotears_max_initdata': "DYNOTEARS^ (Init Data)"
        }

        # Entries keep insertion order: baseline first, then algorithms in bar order.
        legend_lines = list(lines_for_legend.values())
        mapped_legend_labels = [alg_name_mapping.get(key, key) for key in lines_for_legend]
        approx_width_per_item = 2.5
        legend_fig_width = max(len(mapped_legend_labels) * 3, approx_width_per_item)
        legend_fig_height = 1 # The saved size is driven by bbox_inches='tight'

        fig_legend = plt.figure(figsize=(legend_fig_width, legend_fig_height))

        # --- Create the legend on the new Figure ---
        legend = fig_legend.legend(legend_lines, mapped_legend_labels, # Use mapped labels
                                loc='center',
                                title="Algorithms",
                                ncol=1, # Single column, one entry per line
                                markerscale=2.0) # Make markers larger

        # Save the legend figure
        legend_filepath = os.path.join(output_dir, 'legend_only.pdf')
        print(f"Saving legend to {legend_filepath}")
        # bbox_extra_artists keeps the legend (and its title) inside the tight bbox
        fig_legend.savefig(legend_filepath, bbox_inches='tight', bbox_extra_artists=(legend,), dpi=300)

        plt.close(fig_legend) # Close the legend figure to avoid displaying it

    print(f"Saving main plot to {output_plot_filepath}")
    # Save through the figure object instead of relying on pyplot's notion of
    # the "current" figure after the legend figure has been closed.
    fig.savefig(output_plot_filepath, format='pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)

# --- Main Execution ---
if __name__ == "__main__":
    # Global font configuration for every figure produced by this script.
    plt.rcParams['font.family'] = ['Times New Roman']
    plt.rcParams['font.size'] = 30

    # --- Experiment grid: every combination below gets its own JSON + PDF ---
    target_name = ['timeseries']
    edge_prior_probs = [1.0]
    p_order_list = [3]
    t_values = [100]
    metrics = ['accuracy', 'recall', 'f1', 'shd', 'edge_recovery']

    # --- Facets and inputs shared by all runs ---
    noises_type_list = ['noisegauss', 'noiseexp']
    edge_ratio_list = [2, 4]
    node_list = [30]
    # Four-folder variant: ['max_init0', 'multiply_init0', 'max_initdata', 'multiply_initdata']
    experiment_folders = ['max_init0', 'and_init0', 'multiply_init0',
                          'max_initdata', 'and_initdata', 'multiply_initdata']

    base_result_dir = 'result/result_max'
    base_filename = 'merged_base_summary.csv'
    all_filename = 'merged_all_summary.csv'

    # One edge colour per folder; four-folder variant:
    # ['tab:pink', 'tab:cyan', 'tab:red', 'tab:purple']
    colors = ['tab:pink', 'tab:orange', 'tab:cyan', 'tab:red', 'tab:brown', 'tab:purple']

    # Each folder contributes its baseline CSV followed by its algorithm CSV.
    file_mapping = {}
    for alg in experiment_folders:
        file_mapping.update({
            os.path.join(base_result_dir, alg, base_filename): f'baseline_{alg}',
            os.path.join(base_result_dir, alg, all_filename): f'dynotears_{alg}',
        })

    linestyles = ['-'] * 6
    markers = ['.'] * 6
    metric_titles = {
        'accuracy': 'Accuracy',
        'recall': 'Recall',
        'f1': 'F1 Score',
        'shd': 'SHD',
        'edge_recovery': 'Edge Recovery Rate'
    }

    # Flatten the parameter grid; the order matches nesting the loops from
    # target (outermost) down to metric (innermost).
    run_grid = [
        (target, edge_prob, p_order, t, metric)
        for target in target_name
        for edge_prob in edge_prior_probs
        for p_order in p_order_list
        for t in t_values
        for metric in metrics
    ]

    for target, edge_prob, p_order, t, metric in run_grid:
        output_json_file = f'figure/exp_max/{target}_{edge_prob}_{p_order}_{t}_{metric}.json'
        output_plot_file = f'figure/exp_max/{target}_{edge_prob}_{p_order}_{t}_{metric}.pdf'

        # 1. Aggregate the CSV results for this combination.
        processed_results = process_data(
            target, edge_prob, p_order, t, metric,
            noises_type_list, edge_ratio_list, node_list,
            file_mapping_dict=file_mapping,
            base_filename=base_filename,
            all_filename=all_filename
        )
        if not processed_results:
            continue

        # 2. Persist the aggregated numbers next to the figure.
        print(f"\n--- Saving Processed Data to JSON ({output_json_file}) ---")
        output_dir_json = os.path.dirname(output_json_file)
        if output_dir_json and not os.path.exists(output_dir_json):
            os.makedirs(output_dir_json)
            print(f"Created output directory: {output_dir_json}")

        with open(output_json_file, 'w', encoding='utf-8') as f:
            json.dump(processed_results, f, indent=4, ensure_ascii=False)

        # 3. Render the figure with the pre-assigned colours.
        plot_data(
            data_to_plot=processed_results,
            nt_list=noises_type_list,
            er_list=edge_ratio_list,
            node_list=node_list,
            metric=metric,
            file_mapping_dict=file_mapping,
            linestyles_dict=linestyles,
            markers_map=markers,
            alg_color_map=colors,
            metric_titles=metric_titles,
            output_plot_filepath=output_plot_file
        )

