import pandas as pd
import json
import os
import math
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.patches as patches # Import patches for drawing rectangles
import matplotlib.transforms as transforms
import matplotlib.ticker as ticker # Import ticker for FixedLocator
import numpy as np
import matplotlib.colors as mcolors

# --- Data Processing Function (Unchanged) ---
def process_data(target,edge_prob,p_order,t,metric,noises_type_list,edge_ratio_list,node_list,
                 file_mapping_dict,base_filename,all_filename):
    """Load the summary CSVs and collect raw metric values for one experiment setting.

    Every CSV listed in ``file_mapping_dict`` is read and filtered to rows whose
    stripped ``name`` equals ``target`` and whose ``edge_prior_prob``,
    ``porders`` and ``T`` columns equal ``edge_prob``, ``p_order`` and ``t``.

    Args:
        target: Value matched against the stripped ``name`` column.
        edge_prob: Value matched against ``edge_prior_prob``.
        p_order: Value matched against ``porders``.
        t: Value matched against ``T``.
        metric: Name of the metric column whose values are collected.
        noises_type_list: Values matched against the ``noise_type`` column.
        edge_ratio_list: Values matched against ``edge / node``.
        node_list: Values matched against the ``node`` column.
        file_mapping_dict: ``{csv_path: label}``. Every key is read; the
            iteration order fixes the order of the per-file entries.
        base_filename: Substring of a file's basename that marks it as a
            baseline file.
        all_filename: Substring of a file's basename that marks it as an
            algorithm file.

    Returns:
        ``[baseline_values, all_values]``. Each element is nested as
        ``[noise_type][edge_ratio][file][node]`` and ends in a list of metric
        values. The values are native Python scalars, so the result can be
        passed straight to ``json.dump``.
    """
    print("\n--- Starting Data Processing ---")

    print("Step 1: Reading and cleaning CSV data...")
    # 'edge' is coerced as well because it is divided by 'node' further down.
    numeric_columns = ['edge_prior_prob', 'porders', 'T', 'node', 'edge', metric]
    all_dfs = {}
    for filepath in file_mapping_dict:
        df = pd.read_csv(filepath)
        df['name'] = df['name'].str.strip()
        for column in numeric_columns:
            df[column] = pd.to_numeric(df[column], errors='coerce')
        all_dfs[filepath] = df[(df['name'] == target) &
                               (df['edge_prior_prob'] == edge_prob) &
                               (df['porders'] == p_order) &
                               (df['T'] == t)].copy()

    # The file lists depend neither on noise type nor edge ratio; build them once.
    baseline_filepaths = [fp for fp in file_mapping_dict
                          if base_filename in os.path.basename(fp)]
    all_filepaths = [fp for fp in file_mapping_dict
                     if all_filename in os.path.basename(fp)]

    print("\nStep 2: Pre-calculating baseline values...")
    baseline_values = _collect_metric_values(all_dfs, baseline_filepaths, metric,
                                             noises_type_list, edge_ratio_list, node_list)

    print("\nStep 3: Populating results structure and calculating algorithm values...")
    all_values = _collect_metric_values(all_dfs, all_filepaths, metric,
                                        noises_type_list, edge_ratio_list, node_list)

    return [baseline_values, all_values]


def _collect_metric_values(all_dfs, filepaths, metric, noises_type_list, edge_ratio_list, node_list):
    """Return ``metric`` values nested as ``[noise_type][edge_ratio][file][node] -> list``."""
    nested = []
    for nt in noises_type_list:
        nt_values = []
        for er in edge_ratio_list:
            er_values = []
            for filepath in filepaths:
                df = all_dfs[filepath]
                # Rows for this noise type and edges-per-node ratio.
                nt_er_mask = (df['noise_type'] == nt) & (df['edge'] / df['node'] == er)
                # .tolist() yields native Python scalars; np.int64 (e.g. an
                # integer 'shd' column) is not JSON-serializable.
                er_values.append([df.loc[nt_er_mask & (df['node'] == node), metric].tolist()
                                  for node in node_list])
            nt_values.append(er_values)
        nested.append(nt_values)
    return nested


# --- Plotting Function (Uses pre-assigned colors via alg_color_map) ---
def plot_data(data_to_plot, nt_list, er_list, node_list, metric, file_mapping_dict,
            linestyles_dict, markers_map, alg_color_map, metric_titles, output_plot_filepath):
    """Draw a facet grid of half-violin, box and strip plots and save it as PDF.

    The grid has one row per entry of ``nt_list`` and one column per entry of
    ``er_list``. Inside each panel, row 0 is the baseline and row ``k + 1`` is
    algorithm variant ``k``. A standalone legend is written to
    ``legend_only.pdf`` in the same directory as ``output_plot_filepath``.

    Args:
        data_to_plot: Nested lists as returned by ``process_data``.
            ``data_to_plot[0][nt][er][file][node]`` holds baseline values and
            ``data_to_plot[1][nt][er][file][node]`` holds per-variant values.
            Only the first baseline file and the first node entry are plotted.
        nt_list: Noise types (grid rows).
        er_list: Edge ratios (grid columns).
        node_list: Unused; kept for interface compatibility.
        metric: Metric name. It selects the x-axis title and tick format
            (integers for ``'shd'``, two decimals otherwise).
        file_mapping_dict: ``{filepath: label}``. Labels containing
            ``'dynotears'`` identify the variants, in the same order that
            ``process_data`` used.
        linestyles_dict: Unused; kept for interface compatibility.
        markers_map: Unused; kept for interface compatibility.
        alg_color_map: Sequence of colors, one per variant.
        metric_titles: ``{metric: display title}`` for the common x-axis label.
        output_plot_filepath: Destination path of the main PDF.
    """
    print("\n--- Starting Plotting ---")
    num_nt = len(nt_list)
    num_er = len(er_list)

    # Width follows the number of columns (edge ratios) and height the number of
    # rows (noise types). squeeze=False keeps ``axes`` 2-D so that ``axes[i, j]``
    # also works for a single row or a single column.
    fig, axes = plt.subplots(num_nt, num_er,
                             figsize=(5 * num_er, 5 * num_nt),
                             sharey=True, squeeze=False)

    # Variant labels (e.g. 'dynotears_and_initrandom') in mapping order, which is
    # the order of data_to_plot[1][i][j][k].
    variant_names = [fn for fn in file_mapping_dict.values() if 'dynotears' in fn]
    # Index 0 is the baseline; index k + 1 is variant k.
    colors = ['green'] + list(alg_color_map)

    lines_for_legend = {}
    for i in range(num_nt):
        for j in range(num_er):
            ax = axes[i, j]
            y_base = data_to_plot[0][i][j][0][0]
            y_values = [y_base]
            for k in range(len(variant_names)):
                # NOTE(review): variant values are added element-wise to the baseline,
                # presumably because the "all" summary stores differences from it - confirm.
                y_values.append(list(data_to_plot[1][i][j][k][0] + np.array(y_base)))

            # Half violins sit on integer rows; boxes are shifted slightly below them.
            y_positions_violin = np.arange(len(y_values))
            y_positions_box = y_positions_violin - 0.15

            violin_parts = ax.violinplot(
                y_values,
                positions=y_positions_violin,
                widths=0.8,
                showmeans=False,
                showmedians=False,
                showextrema=False,
                vert=False
            )

            for l, body in enumerate(violin_parts['bodies']):
                # Flatten the lower half of the outline onto its centre line so
                # only the upper half of each violin is drawn.
                vertices = body.get_paths()[0].vertices
                mean_y = np.mean(vertices[:, 1])
                vertices[:, 1] = np.where(vertices[:, 1] < mean_y, mean_y, vertices[:, 1])

                body.set_facecolor(colors[l])
                body.set_edgecolor(None)
                body.set_alpha(0.7)

                # Legend keys come from file_mapping_dict rather than a module-level
                # global, so the function also works when imported.
                lines_for_legend['baseline' if l == 0 else variant_names[l - 1]] = body

            boxplot_parts = ax.boxplot(
                y_values,
                positions=y_positions_box,
                widths=0.15,
                showfliers=False,
                showmeans=True,
                patch_artist=True,
                # markersize=0 keeps the mean marker invisible.
                meanprops=dict(marker='.', markerfacecolor='white', markeredgecolor='white', markersize=0),
                vert=False
            )

            for p, box in enumerate(boxplot_parts['boxes']):
                color = colors[p]
                box.set_facecolor(mcolors.to_rgba(color, alpha=0.7))
                box.set_edgecolor(color)
                box.set_linewidth(1)
                # Each box owns two consecutive whiskers and two consecutive caps.
                for whisker in boxplot_parts['whiskers'][2 * p:2 * p + 2]:
                    whisker.set_color(color)
                    whisker.set_linewidth(1)
                for cap in boxplot_parts['caps'][2 * p:2 * p + 2]:
                    cap.set_color(color)
                    cap.set_linewidth(2)
                median = boxplot_parts['medians'][p]
                median.set_color(color)
                median.set_linewidth(2)

            for idx, (data, y_pos) in enumerate(zip(y_values, y_positions_box)):
                # Raw points just below each box, with vertical jitter. np.full
                # (not np.full_like) so integer metrics such as SHD do not
                # truncate the fractional y position.
                y_jitter = np.random.normal(0, 0.05, size=len(data))
                ax.scatter(data,
                           np.full(len(data), y_pos - 0.2) + y_jitter,
                           color=colors[idx],
                           alpha=1,
                           s=10,
                           zorder=3)

            ax.tick_params(axis='both', which='both', bottom=False, top=False,
                           left=False, right=False, labelleft=False)
            ax.grid(True, linestyle='--', color='gray', linewidth=2, alpha=0.1)
            if metric != 'shd':
                ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.2f}'))
            else:
                ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.0f}'))
            for spine in ax.spines.values():
                spine.set_linewidth(1.5)
                spine.set_edgecolor('black')
            ax.set_facecolor('white')

    # Margins (figure fractions) leave room for the external facet labels and
    # the common x label; these need tuning with figsize and font size.
    adjust_left = 0.12
    adjust_bottom = 0.13
    adjust_right = 0.9
    adjust_top = 0.9
    fig.subplots_adjust(left=adjust_left, bottom=adjust_bottom,
                        right=adjust_right, top=adjust_top,
                        wspace=0.03, hspace=0.13)

    # Facet labels follow the actual inputs; unknown noise types use their raw name.
    noise_labels = {'noisegauss': 'Gauss', 'noiseexp': 'Exp'}
    column_labels = [f'ER{er}' for er in er_list]
    row_labels = [noise_labels.get(nt, nt) for nt in nt_list]
    _add_facet_labels(fig, axes, column_labels, row_labels, adjust_top, adjust_right)

    # Common x-axis label, centred in the bottom margin.
    if num_nt > 0 and num_er > 0:
        fig.text(0.55, adjust_bottom / 2, metric_titles.get(metric, metric),
                 ha='center', va='center', fontsize=40)

    # Create the output directory before writing either PDF; the legend is
    # saved into the same directory.
    output_dir = os.path.dirname(output_plot_filepath)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if lines_for_legend:
        legend_filepath = os.path.join(output_dir, 'legend_only.pdf')
        print(f"Saving legend to {legend_filepath}")
        _save_legend(lines_for_legend, legend_filepath)

    print(f"Saving main plot to {output_plot_filepath}")
    # Save through the Figure object rather than pyplot's implicit current figure.
    fig.savefig(output_plot_filepath, format='pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)


def _draw_label_box(fig, x, y, width, height, text, rotation, box_style):
    """Add a rectangle at (x, y, width, height) in figure coordinates with ``text`` centred in it."""
    fig.add_artist(patches.Rectangle((x, y), width, height, **box_style))
    fig.text(x + width / 2, y + height / 2, text, ha='center', va='center',
             rotation=rotation, fontsize=40)


def _add_facet_labels(fig, axes, column_labels, row_labels, top, right):
    """Draw boxed column labels above the grid and rotated row labels to its right.

    ``top`` and ``right`` are the ``subplots_adjust`` bounds. The boxes go in
    the figure margin beyond them and use figure-fraction coordinates.
    """
    # White boxes with a black border, drawn directly on the figure.
    label_box_style = dict(facecolor='white', edgecolor='black', linewidth=1.5, clip_on=False)

    for j, col_label in enumerate(column_labels):
        pos = axes[0, j].get_position()
        # Spans the column width and 60% of the top margin, starting 10% into it.
        box_height = (1 - top) * 0.6
        box_y = top + (1 - top) * 0.1
        _draw_label_box(fig, pos.x0, box_y, pos.width, box_height, col_label, 0, label_box_style)

    for i, row_label in enumerate(row_labels):
        pos = axes[i, -1].get_position()
        # Spans the row height and 60% of the right margin, starting 10% into it.
        box_width = (1 - right) * 0.6
        box_x = right + (1 - right) * 0.1
        _draw_label_box(fig, box_x, pos.y0, box_width, pos.height, row_label, 270, label_box_style)


def _save_legend(lines_for_legend, legend_filepath):
    """Save a standalone legend figure with one patch per series, colored like its violin."""
    alg_name_mapping = {
        'baseline': "Baseline",
        'dynotears_and_init0': "DYNOTEARS& (Init 0)",
        'dynotears_and_initrandom': "DYNOTEARS& (Init Random)",
        'dynotears_and_initedge': "DYNOTEARS& (Init Edge)",
        'dynotears_and_initdata': "DYNOTEARS& (Init Data)",
        'dynotears_multiply_init0': "DYNOTEARS* (Init 0)",
        'dynotears_multiply_initrandom': "DYNOTEARS* (Init Random)",
        'dynotears_multiply_initedge': "DYNOTEARS* (Init Edge)",
        'dynotears_multiply_initdata': "DYNOTEARS* (Init Data)"
    }

    proxy_handles = []
    proxy_labels = []
    for key, artist in lines_for_legend.items():
        # First face color (RGBA) of the violin body.
        rgba = artist.get_facecolor()[0]
        proxy_handles.append(patches.Patch(facecolor=rgba, edgecolor=rgba, alpha=0.7))
        proxy_labels.append(alg_name_mapping.get(key, key))

    num_columns = 1
    font_size = 12
    legend_fig_width = max(6, len(proxy_labels) / num_columns * 2.5)
    fig_legend = plt.figure(figsize=(legend_fig_width, 1.5))
    fig_legend.legend(
        handles=proxy_handles,
        labels=proxy_labels,
        loc='center',
        title="Algorithms",
        title_fontsize=font_size + 1,
        ncol=num_columns,
        fontsize=font_size,
        markerscale=2.0,
        handlelength=2.5,
        borderaxespad=0.5,
        frameon=True,
        fancybox=True,
        framealpha=0.8
    )
    fig_legend.savefig(legend_filepath, bbox_inches='tight', dpi=300)
    plt.close(fig_legend)

# --- Main Execution ---
if __name__ == "__main__":
    # Global typography for every figure produced below.
    plt.rcParams['font.family'] = ['Times New Roman']
    plt.rcParams['font.size'] = 30

    # Experiment grid: one JSON file and one PDF per (target, edge prob, p order, T, metric).
    target_name = ['timeseries']
    edge_prior_probs = [0.8]
    p_order_list = [3]
    t_values = [250, 1000]
    metrics = ['accuracy', 'recall', 'f1', 'shd', 'edge_recovery']

    # Facets: rows are noise types, columns are edge ratios.
    noises_type_list = ['noisegauss', 'noiseexp']
    edge_ratio_list = [2, 4]
    node_list = [30]

    # NOTE: keep at module scope - plot_data may read this global for legend labels.
    experiment_folders = ['and_initrandom', 'and_initdata', 'multiply_initrandom', 'multiply_initdata']

    base_result_dir = 'result/result_init'
    base_filename = 'merged_base_summary.csv'
    all_filename = 'merged_all_summary.csv'

    # One color per experiment folder, in the same order.
    colors = ['tab:red', 'tab:brown', 'tab:blue', 'tab:purple']

    # {csv path: label}. Insertion order (base then all, per folder) fixes the data order.
    file_mapping = {
        os.path.join(base_result_dir, alg, csv_name): f'{prefix}_{alg}'
        for alg in experiment_folders
        for csv_name, prefix in ((base_filename, 'baseline'), (all_filename, 'dynotears'))
    }

    linestyles = ['-'] * 6
    markers = ['.'] * 6
    metric_titles = {
        'accuracy': 'Accuracy',
        'recall': 'Recall',
        'f1': 'F1 Score',
        'shd': 'SHD',
        'edge_recovery': 'Edge Recovery Rate'
    }

    run_configs = (
        (target, edge_prob, p_order, t, metric)
        for target in target_name
        for edge_prob in edge_prior_probs
        for p_order in p_order_list
        for t in t_values
        for metric in metrics
    )
    for target, edge_prob, p_order, t, metric in run_configs:
        stem = f'figure/exp_init/{target}_{edge_prob}_{p_order}_{t}_{metric}'
        output_json_file = stem + '.json'
        output_plot_file = stem + '.pdf'

        processed_results = process_data(
            target, edge_prob, p_order, t, metric,
            noises_type_list, edge_ratio_list, node_list,
            file_mapping_dict=file_mapping,
            base_filename=base_filename,
            all_filename=all_filename,
        )
        if not processed_results:
            continue

        # Persist the raw nested values next to the figure.
        print(f"\n--- Saving Processed Data to JSON ({output_json_file}) ---")
        output_dir_json = os.path.dirname(output_json_file)
        if output_dir_json and not os.path.exists(output_dir_json):
            os.makedirs(output_dir_json)
            print(f"Created output directory: {output_dir_json}")
        with open(output_json_file, 'w', encoding='utf-8') as f:
            json.dump(processed_results, f, indent=4, ensure_ascii=False)

        plot_data(
            data_to_plot=processed_results,
            nt_list=noises_type_list,
            er_list=edge_ratio_list,
            node_list=node_list,
            metric=metric,
            file_mapping_dict=file_mapping,
            linestyles_dict=linestyles,
            markers_map=markers,
            alg_color_map=colors,
            metric_titles=metric_titles,
            output_plot_filepath=output_plot_file
        )

