import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
import re
import matplotlib

# Global matplotlib typography: serif text set in Times New Roman at size 20,
# with the 'stix' font set used for math text.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 20,
    'mathtext.fontset': 'stix',
})

# --- Configuration ---
# Root folder that contains one sub-directory per algorithm variant.
ROOT_DIR = "simulated_data_autoorder"

# Algorithm sub-directories under ROOT_DIR, in the order they are processed.
ALGORITHMS = [
    "simulated_data_and_init0",
    "simulated_data_and_initdata",
    "simulated_data_multiply_init0",
    "simulated_data_multiply_initdata",
]

# Internal algorithm key -> label shown in the figure legends.
alg_mapping = {
    "baseline": "Baseline",
    "simulated_data_and_init0": "DYNOTEARS& (Init 0)",
    "simulated_data_and_initdata": "DYNOTEARS& (Init Data)",
    "simulated_data_multiply_init0": "DYNOTEARS* (Init 0)",
    "simulated_data_multiply_initdata": "DYNOTEARS* (Init Data)",
}

# Only the run stored in the `repeat<SELECTED_REPEAT>` directory is processed.
SELECTED_REPEAT = 5
PRIOR_DIR_NAME = "exist_edges_prob_0.8"
# Text file holding the loss value of a run.
RESULT_LOSS_FILE = "result_constrained_0.txt"
# CSV of predicted weights; read to compute the maximum absolute weight.
PREDICTED_FILE = "constrained_multiply_weights_0.csv"

# Assuming node count is fixed at 30 based on 'node030'.
N_NODES = 30

# Figures are saved below this directory; create it up front if it is missing.
OUTPUT_DIR = "figure/exp_autoorder"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory '{OUTPUT_DIR}' ensured.")

# One (line/bar colour, subplot background tint) pair per plotted algorithm,
# kept side by side so the two derived lists stay the same length.
_PALETTE = [
    ('green', '#FFFFE0'),
    ('tab:orange', '#e0f2f7'),
    ('tab:brown', '#fce4ec'),
    ('tab:cyan', '#fffde7'),
    ('tab:purple', '#e8f5e9'),
]
BG_COLORS = [background for _, background in _PALETTE]
alg_colors = [foreground for foreground, _ in _PALETTE]

# --- Data Collection ---
# Glob pattern for experiment directories; the lag order is encoded in the
# directory name as "porders<N>".
_EXPERIMENT_DIR_PATTERN = 'node*_edge*_porders*_T*_noisegauss'


def _collect_results(algo_dir, label, loss_file, weights_file):
    """Gather the loss and the max absolute weight per lag order for one algorithm.

    For every experiment directory found under ``ROOT_DIR/algo_dir`` the run in
    ``repeat<SELECTED_REPEAT>/PRIOR_DIR_NAME`` is read.

    Args:
        algo_dir: Sub-directory of ROOT_DIR that holds the experiment folders.
        label: Value stored under the 'algorithm' key of every returned row.
        loss_file: Name of the text file containing a single float (the loss).
        weights_file: Name of the CSV with columns 'matrix_type', 'lag' and
            'dest_0' .. 'dest_<N_NODES - 1>'.

    Returns:
        A list of dicts with keys 'algorithm', 'porder', 'loss' and 'max_abs',
        ordered by increasing porder (one entry per porder).

    Raises:
        FileNotFoundError: If the weights CSV does not exist.
        KeyError: If the CSV lacks one of the expected columns.
        ValueError: If the loss file does not contain a float, or the CSV has
            no 'A' rows for lag == porder (np.max of an empty array).
    """
    # Map porder (int) -> directory path; assumes only one directory per porder.
    porder_exp_map = {}
    for exp_dir in glob.glob(os.path.join(ROOT_DIR, algo_dir, _EXPERIMENT_DIR_PATTERN)):
        match = re.search(r'porders(\d+)', os.path.basename(exp_dir))
        if match:
            porder_exp_map[int(match.group(1))] = exp_dir

    dest_cols = [f'dest_{i}' for i in range(N_NODES)]
    rows = []
    # The directory scan is finished before any file is read, so each porder
    # is processed exactly once.
    for porder in sorted(porder_exp_map):
        run_dir = os.path.join(porder_exp_map[porder], f'repeat{SELECTED_REPEAT}', PRIOR_DIR_NAME)

        # --- Get Loss from .txt file ---
        loss_path = os.path.join(run_dir, loss_file)
        if os.path.exists(loss_path):
            with open(loss_path, 'r') as f:
                loss = float(f.read().strip())
        else:
            # Keep the historical fallback of 0, but make it visible.
            print(f"Warning: loss file not found, using 0: {loss_path}")
            loss = 0

        # --- Calculate Max Absolute Value from .csv file ---
        pred_df = pd.read_csv(os.path.join(run_dir, weights_file))
        # Only the lagged matrix ('A') at the highest lag (lag == porder) is used.
        at_max_lag = (
            (pred_df['matrix_type'] == 'A')
            & (pred_df['lag'] >= 1)
            & (pred_df['lag'] == porder)
        )
        max_abs = np.max(np.abs(pred_df.loc[at_max_lag, dest_cols].values))

        rows.append({
            'algorithm': label,
            'porder': porder,
            'loss': loss,
            'max_abs': max_abs,
        })
    return rows


all_results = []

# Baseline: read from the 'simulated_data_and_init0' tree, using the baseline
# loss and weight files instead of the constrained ones.
all_results.extend(_collect_results(
    'simulated_data_and_init0',
    'baseline',
    'result_baseline_0.txt',
    'baseline_weights_0.csv',
))

print(f"Starting data collection from {ROOT_DIR}, specifically repeat{SELECTED_REPEAT}...")

for algo in ALGORITHMS:
    print(f"Processing algorithm: {algo}")
    all_results.extend(_collect_results(algo, algo, RESULT_LOSS_FILE, PREDICTED_FILE))
print("Data collection finished.")

# Convert results to DataFrame for easier plotting
results_df = pd.DataFrame(all_results)
if results_df.empty:
    # Without this check sort_values fails with an opaque KeyError('algorithm').
    raise RuntimeError("No results were collected; nothing to plot.")

# --- Plotting ---

# Sort results by algorithm and porder for correct plotting order
results_df = results_df.sort_values(by=['algorithm', 'porder'])

# Get the list of porders actually found in the data
actual_porders = sorted(results_df['porder'].unique())
# Get algorithms that actually have data
present_algos = results_df['algorithm'].unique()
# From here on ALGORITHMS also contains the baseline, in display order.
ALGORITHMS = ['baseline'] + ALGORITHMS
# Colours stay tied to an algorithm's position in ALGORITHMS, so a given
# algorithm keeps its colour regardless of which others are present.
present_algo_colors = {algo: (alg_colors[i], BG_COLORS[i]) for i, algo in enumerate(ALGORITHMS) if algo in present_algos}
# Subplot row per algorithm: numbered consecutively over the algorithms that
# have data, because the figures only create len(present_algos) rows.
present_algo_map = {
    algo: row
    for row, algo in enumerate(a for a in ALGORITHMS if a in present_algos)
}


# --- Plot 1: Loss vs. Porder (Stacked Line Plots) ---
# One panel per algorithm, stacked vertically and sharing the x axis.
# squeeze=False keeps the result a 2-D array even for a single panel
# (plt.subplots would otherwise return a bare Axes that cannot be indexed);
# the single column is then taken to get a 1-D array of Axes.
n_loss_panels = len(present_algos)
fig1, axes1 = plt.subplots(n_loss_panels, 1, figsize=(10, n_loss_panels * 2), sharex=True, squeeze=False)
axes1 = axes1[:, 0]

for algo in present_algos:
    algorithm = alg_mapping[algo]
    ax = axes1[present_algo_map[algo]]
    line_color, panel_color = present_algo_colors[algo]
    ax.set_xlim(0.5, 10.5)
    algo_data = results_df[results_df['algorithm'] == algo]

    ax.plot(algo_data['porder'], algo_data['loss'], marker='o', linestyle='-', label=algorithm, color=line_color, linewidth=3, markersize=10)
    # Vertical dashed line at x = 5.5; the region to its right is annotated below.
    ax.axvline(x=5.5, color='gray', linestyle='--', linewidth=1.5)

    # Add shaded region to the right of the line, up to the right axis limit
    ax.axvspan(5.5, ax.get_xlim()[1], color='gray', alpha=0.2)

    # Label inside the shaded region; x and y are given in axes coordinates (0-1)
    ax.text(0.7, 0.5, r'$L > L_{true}$',
            transform=ax.transAxes,
            ha='center', va='center', fontsize=25, color='black', fontname='Times New Roman')

    ax.patch.set_facecolor(panel_color)
    ax.patch.set_alpha(0.3)
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.legend(loc='upper right', fontsize=20)

axes1[-1].set_xticks(actual_porders)
axes1[-1].set_xticklabels(actual_porders)

# Only the bottom panel keeps its x ticks and tick labels.
for upper_ax in axes1[:-1]:
    upper_ax.tick_params(axis='x', labelbottom=False, bottom=False, which='both')

fig1.subplots_adjust(
    left=0.1,
    right=1,
    bottom=0.1,
    top=1,
    hspace=0.05
)

# Shared x-axis label for the whole figure.
fig1.text(
    0.55,
    0.01,
    r'$L$ (Maximum Lag)',
    ha='center',
    va='bottom',
    fontsize=40
)

# Shared y-axis label for the whole figure.
fig1.text(
    0.01,
    0.5,
    r'$F(W,A)$',
    ha='left',
    va='center',
    rotation='vertical',
    fontsize=40,
    fontname='Times New Roman'
)

fig1_path = os.path.join(OUTPUT_DIR, f'loss_plot_repeat{SELECTED_REPEAT}.pdf')
fig1.savefig(fig1_path, format='pdf', bbox_inches='tight')
print(f"Saved Loss plot to {fig1_path}")
plt.close(fig1)

# --- Plot 2: Max Absolute Weight vs. Porder (Stacked Bar Plots) ---
# One panel per algorithm, stacked vertically and sharing the x axis.
# squeeze=False keeps the result a 2-D array even for a single panel
# (plt.subplots would otherwise return a bare Axes that cannot be indexed);
# the single column is then taken to get a 1-D array of Axes.
n_bar_panels = len(present_algos)
fig2, axes2 = plt.subplots(n_bar_panels, 1, figsize=(10, n_bar_panels * 2), sharex=True, squeeze=False)
axes2 = axes2[:, 0]

bar_width = 0.6  # Width of the bars

for algo in present_algos:
    algorithm = alg_mapping[algo]
    ax = axes2[present_algo_map[algo]]
    bar_color, panel_color = present_algo_colors[algo]
    ax.set_xlim(0.5, 10.5)
    algo_data = results_df[results_df['algorithm'] == algo]

    ax.bar(algo_data['porder'], algo_data['max_abs'], width=bar_width, color=bar_color, label=algorithm)
    # Vertical dashed line at x = 5.5; the region to its right is annotated below.
    ax.axvline(x=5.5, color='gray', linestyle='--', linewidth=1.5)

    # Add shaded region to the right of the line, up to the right axis limit
    ax.axvspan(5.5, ax.get_xlim()[1], color='gray', alpha=0.2)

    # Label inside the shaded region; x and y are given in axes coordinates (0-1)
    ax.text(0.7, 0.5, r'$L > L_{true}$',
            transform=ax.transAxes,
            ha='center', va='center', fontsize=25, color='black', fontname='Times New Roman')

    ax.patch.set_facecolor(panel_color)
    ax.patch.set_alpha(0.3)
    ax.grid(axis='y', linestyle='--', alpha=0.6)
    ax.legend(loc='upper right', fontsize=20)
axes2[-1].set_xticks(actual_porders)
axes2[-1].set_xticklabels(actual_porders)

# Only the bottom panel keeps its x ticks and tick labels.
for upper_ax in axes2[:-1]:
    upper_ax.tick_params(axis='x', labelbottom=False, bottom=False, which='both')

fig2.subplots_adjust(
    left=0.11,
    right=1,
    bottom=0.1,
    top=1,
    hspace=0.05
)

# Shared x-axis label for the whole figure.
fig2.text(
    0.55,
    0.01,
    r'$L$ (Maximum Lag)',
    ha='center',
    va='bottom',
    fontsize=40
)

# Shared y-axis label for the whole figure.
fig2.text(
    0.01,
    0.5,
    r'$max(W_{\tau})$',
    ha='left',
    va='center',
    rotation='vertical',
    fontsize=40
)

fig2_path = os.path.join(OUTPUT_DIR, f'max_abs_plot_repeat{SELECTED_REPEAT}.pdf')
fig2.savefig(fig2_path, format='pdf', bbox_inches='tight')
print(f"Saved Max Absolute Weight plot to {fig2_path}")
plt.close(fig2)

