import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import numpy as np
import os


plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 10

def load_weights_from_rows(filepath, n_nodes, p_order, threshold):
    """Load weights from the row-based CSV format and split them into [W, A1, ..., Ap].

    The CSV must provide the columns ``matrix_type``, ``lag``, ``source_node`` and
    ``dest_0`` ... ``dest_{n_nodes - 1}``. Each row holds one row of one matrix:
    ``matrix_type == 'W'`` with ``lag == 0`` goes to index 0, and
    ``matrix_type == 'A'`` with ``1 <= lag <= p_order`` goes to index ``lag``.
    Rows with any other type/lag combination are ignored silently. If several
    rows address the same matrix row, the last one wins.

    Args:
        filepath: Path of the CSV file. ``None`` is treated like a missing file.
        n_nodes: Number of nodes; every matrix is ``n_nodes x n_nodes``.
        p_order: Maximum lag, so ``p_order + 1`` matrices are returned.
        threshold: Weights whose absolute value is below this are set to 0.

    Returns:
        A list of ``p_order + 1`` float arrays, or ``None`` when the file is
        missing, unreadable, or lacks a required column. Matrix rows that never
        appear in the file stay zero.
    """
    # os.path.exists(None) raises TypeError, so handle a missing path explicitly.
    if filepath is None or not os.path.exists(filepath):
        print(f"Warning: File not found: {filepath}")
        return None

    try:
        df = pd.read_csv(filepath, sep=',', encoding='utf-8')
        # A byte order mark, if present, ends up glued to the first column name.
        df.columns = df.columns.str.replace('\ufeff', '', regex=False)
    except Exception as e:
        print(f"Error reading {filepath}: {e}")
        return None

    dest_cols = [f'dest_{i}' for i in range(n_nodes)]
    required_cols = ['matrix_type', 'lag', 'source_node'] + dest_cols
    if not all(col in df.columns for col in required_cols):
        print(f"Warning: Missing required columns in {filepath}. Expected: {required_cols}. Found: {df.columns.tolist()}. Skipping.")
        return None

    matrices = [np.zeros((n_nodes, n_nodes)) for _ in range(p_order + 1)]

    for index, row in df.iterrows():
        matrix_type = row['matrix_type']
        try:
            lag = int(row['lag'])
            source_node = int(row['source_node'])
        except (TypeError, ValueError) as err:
            # Blank cells are read as NaN, which int() rejects; skip the row
            # rather than aborting the whole file.
            print(f"Warning: Could not parse lag/source_node in row {index} of {filepath}: {err}. Skipping row.")
            continue

        if matrix_type == 'W' and lag == 0:
            matrix_idx = 0
        elif matrix_type == 'A' and 1 <= lag <= p_order:
            matrix_idx = lag
        else:
            # Unexpected type/lag combination: not part of the requested model.
            continue

        if not 0 <= source_node < n_nodes:
            print(f"Warning: Skipping row with invalid source_node ({source_node}) or lag ({lag}) in {filepath}")
            continue

        try:
            dest_weights = row[dest_cols].values.astype(float)
        except ValueError as ve:
            print(f"Warning: Could not convert destination weights to float in row {index} of {filepath}: {ve}. Skipping row.")
            continue

        # Zero out weak weights (NaN compares False and is therefore left as is).
        dest_weights[np.abs(dest_weights) < threshold] = 0

        matrices[matrix_idx][source_node, :] = dest_weights

    return matrices

def load_ground_truth_matrix(filepath, n_nodes, p_order):
    """Read the stacked ground-truth CSV and return it as [W, A1, ..., Ap].

    The file is read without a header and must contain exactly
    ``n_nodes * (p_order + 1)`` rows and ``n_nodes`` columns; consecutive
    groups of ``n_nodes`` rows form one matrix each.

    Returns:
        A list of ``p_order + 1`` arrays of shape ``(n_nodes, n_nodes)``, or
        ``None`` if the file is missing, unreadable, or has the wrong shape.
    """
    if not os.path.exists(filepath):
        print(f"Warning: File not found: {filepath}")
        return None

    try:
        # header=None: the first line is data, not column names.
        stacked = pd.read_csv(filepath, header=None).values.astype(float)
    except Exception as e:
        print(f"Error reading {filepath}: {e}")
        return None

    n_blocks = p_order + 1
    want_rows, want_cols = n_nodes * n_blocks, n_nodes
    if stacked.shape != (want_rows, want_cols):
        print(f"Warning: Ground truth matrix shape mismatch in {filepath}. Expected ({want_rows}, {want_cols}), got {stacked.shape}. Skipping.")
        return None

    # Slice the stack into one n_nodes-row block per lag.
    return [stacked[k * n_nodes:(k + 1) * n_nodes, :] for k in range(n_blocks)]

def load_prior_mask_matrix(filepath, n_nodes):
    """Read the headerless ``n_nodes x n_nodes`` prior mask CSV as a float array.

    Returns:
        The mask as a 2-D float array, or ``None`` if the file is missing,
        unreadable, or not of shape ``(n_nodes, n_nodes)``.
    """
    if not os.path.exists(filepath):
        print(f"Warning: File not found: {filepath}")
        return None

    try:
        # header=None: the first line is data, not column names.
        mask = pd.read_csv(filepath, header=None).values.astype(float)
    except Exception as e:
        print(f"Error reading {filepath}: {e}")
        return None

    if mask.shape == (n_nodes, n_nodes):
        return mask

    print(f"Warning: Prior mask matrix shape mismatch in {filepath}. Expected ({n_nodes}, {n_nodes}), got {mask.shape}. Skipping.")
    return None


import matplotlib.colors as mcolors

# Helper function to calculate a simple SHD (element-wise difference after thresholding).
# This is a simplified metric for comparison, not the standard DAG SHD with reversal handling.
def calculate_shd_simple(gt_matrices, pred_matrices, threshold=0.01):
    """Count the entries whose thresholded edge indicator differs between two matrix lists.

    Each matrix is binarised with ``abs(value) > threshold`` and the number of
    positions where the ground truth and the prediction disagree is summed
    over all matrix pairs. If the lists differ in length, only the common
    leading pairs are compared.

    Args:
        gt_matrices: Sequence of ground-truth matrices, or ``None``.
        pred_matrices: Sequence of predicted matrices, or ``None``.
        threshold: Absolute-value cut-off above which an entry counts as an edge.

    Returns:
        The number of differing entries as an ``int``, or ``float('inf')`` when
        either list is ``None`` or empty, when any compared matrix is ``None``,
        or when a pair has mismatched shapes.
    """
    if gt_matrices is None or pred_matrices is None:
        return float('inf')
    if min(len(gt_matrices), len(pred_matrices)) == 0:
        return float('inf')

    shd = 0
    # zip stops at the shorter list, i.e. only the common prefix is compared.
    for gt_mat, pred_mat in zip(gt_matrices, pred_matrices):
        if gt_mat is None or pred_mat is None:
            return float('inf')

        gt_mat = np.asarray(gt_mat)
        pred_mat = np.asarray(pred_mat)
        if gt_mat.shape != pred_mat.shape:
            return float('inf')

        gt_edges = np.abs(gt_mat) > threshold
        pred_edges = np.abs(pred_mat) > threshold
        shd += int(np.count_nonzero(gt_edges != pred_edges))

    return shd

def _style_matrix_axes(ax, n_nodes, grid_color, grid_linewidth, spine_color):
    """Draw cell-boundary grid lines on a matshow axes and hide ticks and labels."""
    boundaries = np.arange(-.5, n_nodes - .5, 1)
    ax.set_xticks(boundaries, minor=True)
    ax.set_yticks(boundaries, minor=True)
    ax.grid(which="minor", color=grid_color, linestyle='-', linewidth=grid_linewidth)
    ax.tick_params(which="minor", top=False, bottom=False, left=False, right=False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_aspect('equal', adjustable='box')
    for spine in ax.spines.values():
        spine.set_color(spine_color)
        spine.set_linewidth(1.0)


def plot_causal_structures(gt_matrices, prior_mask_matrix, baseline_matrices, and_matrices, multiply_matrices, and_init_matrices, multiply_init_matrices,
                                          n_nodes, p_order, repeat_idx, prior_prob, data_info_dir, save_dir, index, annotation_threshold=0.01):
    """Plot baseline, ground truth and the best constrained result next to the prior mask.

    The simplified SHD against the ground truth is computed for the baseline
    and for the four constrained variants. A figure is produced only if at
    least one variant beats the baseline SHD by more than 3; the variant with
    the lowest SHD is then drawn in the third row. The figure is written as an
    SVG file into ``save_dir``.

    Args:
        gt_matrices: Ground-truth matrices ``[W, A1, ..., Ap]``.
        prior_mask_matrix: ``n_nodes x n_nodes`` prior mask shown in the left panel.
        baseline_matrices: Baseline estimate, same layout as ``gt_matrices``.
        and_matrices, multiply_matrices, and_init_matrices, multiply_init_matrices:
            Constrained estimates; any of them may be ``None`` when unavailable.
        n_nodes: Number of nodes (side length of every matrix).
        p_order: Maximum lag; ``p_order + 1`` matrix columns are drawn.
        repeat_idx, prior_prob, data_info_dir, index: Used for log messages and
            for the output file name.
        save_dir: Output directory, created if necessary.
        annotation_threshold: Minimum absolute weight that gets a numeric label;
            also used as the SHD binarisation threshold.

    Returns:
        None. Nothing is plotted when ground truth, baseline or prior mask is
        missing, or when no variant is sufficiently better than the baseline.
    """
    grey_color = (230/255, 230/255, 230/255)
    shd_threshold = annotation_threshold
    # A variant must beat the baseline SHD by more than this to trigger a plot.
    improvement_margin = 3

    # All three are drawn unconditionally below, so none of them may be missing.
    if gt_matrices is None or baseline_matrices is None or prior_mask_matrix is None:
        print(f"[{data_info_dir} - Repeat {repeat_idx} - Prior Prob {prior_prob}] Missing ground truth, baseline or prior mask. Skipping plot.")
        return

    def shd_against_gt(candidate):
        # An unavailable result is treated as infinitely far from the ground truth.
        if candidate is None:
            return float('inf')
        return calculate_shd_simple(gt_matrices, candidate, threshold=shd_threshold)

    shd_baseline = shd_against_gt(baseline_matrices)
    shd_and = shd_against_gt(and_matrices)
    shd_multiply = shd_against_gt(multiply_matrices)
    shd_and_init = shd_against_gt(and_init_matrices)
    shd_multiply_init = shd_against_gt(multiply_init_matrices)

    candidates = [
        (shd_and, and_matrices),
        (shd_multiply, multiply_matrices),
        (shd_and_init, and_init_matrices),
        (shd_multiply_init, multiply_init_matrices),
    ]

    # inf + margin < x is always False, so missing variants never qualify.
    if not any(shd + improvement_margin < shd_baseline for shd, _ in candidates):
        print(f"[{data_info_dir} - Repeat {repeat_idx} - Prior Prob {prior_prob}] Condition not met: AND ({shd_and:.2f}) or Multiply ({shd_multiply:.2f}) not strictly better than Baseline ({shd_baseline:.2f}). Skipping plot.")
        return

    # Pick the variant with the lowest SHD (first one wins on ties). At least
    # one variant has a finite SHD here, so the chosen matrices are not None.
    _, best_matrices = min(candidates, key=lambda pair: pair[0])
    best_title = 'DYNOTEARS*'

    matrices = [baseline_matrices, gt_matrices, best_matrices]
    row_titles = ['Baseline', 'Ground Truth', best_title]

    n_rows = len(row_titles)
    n_cols_weights = p_order + 1
    # Column 0: prior mask, column 1: narrow spacer, remaining columns: one per lag.
    gs_width_ratios = [1, 0.2] + [1] * n_cols_weights

    fig_height = 2.5 * n_rows
    fig_width = 2.5 * (1.2 + n_cols_weights)
    fig = plt.figure(figsize=(fig_width, fig_height))

    gs = gridspec.GridSpec(n_rows, n_cols_weights + 2, width_ratios=gs_width_ratios, wspace=0.05, hspace=0.05)

    cbar_ax = fig.add_axes([0.91, 0.2, 0.02, 0.6])

    ax_prior = fig.add_subplot(gs[:, 0])
    ax_prior.matshow(prior_mask_matrix, cmap='binary', vmin=0, vmax=1)
    ax_prior.set_title('Prior Mask', fontsize=20)
    _style_matrix_axes(ax_prior, n_nodes, grid_color="lightgray", grid_linewidth=0.5, spine_color=grey_color)
    ax_prior.text(-0.1, 0.5, 'Prior', rotation=90, ha='right', va='center', transform=ax_prior.transAxes, fontsize=20)

    # Symmetric colour range shared by all panels, padded so that the largest
    # weights do not sit at the very end of the colormap.
    vmax = max(
        np.abs(np.asarray(m)).max()
        for mat_list in matrices
        for m in mat_list
        if m is not None
    ) + 0.2
    vmin = -vmax

    nodes = [0.0, 0.5, 1.0]
    colors = ["blue", (245/255, 245/255, 245/255), "red"]
    cmap_weights = mcolors.LinearSegmentedColormap.from_list("RedBlueWhite", list(zip(nodes, colors)))
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    # One title per lag, so any p_order is labelled.
    col_titles = [rf'W$_{{{lag}}}$' for lag in range(n_cols_weights)]

    for r, (current_matrices, row_title) in enumerate(zip(matrices, row_titles)):
        for c in range(n_cols_weights):
            ax = fig.add_subplot(gs[r, c + 2])  # skip the prior and spacer columns
            mat = np.asarray(current_matrices[c])
            ax.matshow(mat, cmap=cmap_weights, vmin=vmin, vmax=vmax)

            for (i, j), value in np.ndenumerate(mat):
                if abs(value) >= annotation_threshold:
                    # Dark text on the pale middle of the colormap, white elsewhere.
                    text_color = 'black' if 0.3 < norm(value) < 0.7 else 'white'
                    ax.text(j, i, f"{value:.2f}",
                            va='center', ha='center',
                            color=text_color, fontsize=10)

            _style_matrix_axes(ax, n_nodes, grid_color="white", grid_linewidth=2, spine_color='white')

            if r == 0:
                ax.set_title(col_titles[c], fontsize=20)

            if c == 0:
                ax.text(-0.1, 0.5, row_title, rotation=90, ha='right', va='center', transform=ax.transAxes, fontsize=20)

    sm = cm.ScalarMappable(cmap=cmap_weights, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.ax.tick_params(labelsize=12)
    for spine in cbar_ax.spines.values():
        spine.set_visible(False)

    fig.tight_layout(rect=[0, 0.01, 0.9, 0.98])

    os.makedirs(save_dir, exist_ok=True)
    filename = f'{data_info_dir}_repeat{repeat_idx}_prior{str(prior_prob).replace(".", "_")}_{str(index)}_best_visualization.svg'
    filepath_to_save = os.path.join(save_dir, filename)
    fig.savefig(filepath_to_save, bbox_inches='tight', dpi=300)
    print(f"Saved plot to {filepath_to_save}")

    plt.close(fig)


# --- Main Execution ---
# Walks <root_dir>/<algorithm>/<data_info>/<repeat>/<prior_prob>/ and plots, for every
# indexed result file, the baseline, the ground truth and the constrained estimates.
root_dir = 'simulated_data_visualization'
save_base_dir = 'result/result_visualization'  # Directory to save the plots

# Define the algorithm subdirectories
alg_dir_and = 'simulated_data_and_init0'
alg_dir_multiply = 'simulated_data_multiply_init0'
alg_dir_and_init = 'simulated_data_and_initdata'
alg_dir_multiply_init = 'simulated_data_multiply_initdata'

# Use the 'and_init0' path as the primary structure reference
base_structure_path = os.path.join(root_dir, alg_dir_and)

# Directory names encode the data configuration; groups 1 and 3 are n_nodes and p_order.
data_dir_pattern = re.compile(r'node(\d+)_edge(\d+)_porders(\d+)_T(\d+)_noisegauss')
# Magnitude below which loaded weights are zeroed; also used for annotation/SHD.
weight_threshold = 0.1
# Result files are indexed 0..5 inside each prior-probability directory.
n_results_per_prior = 6

if not os.path.exists(root_dir):
    print(f"Error: Root data directory not found: {root_dir}")
elif not os.path.exists(base_structure_path):
    print(f"Error: Base structure path for alg_dir_and not found: {base_structure_path}")
else:
    # Sorted so that runs are processed in a consistent order.
    data_info_dirs = sorted(
        d for d in os.listdir(base_structure_path)
        if os.path.isdir(os.path.join(base_structure_path, d)) and data_dir_pattern.match(d)
    )

    if not data_info_dirs:
        print(f"No data info directories found in {base_structure_path} matching the pattern.")

    for data_info_dir in data_info_dirs:
        # The directory list was filtered with the same pattern, so this always matches.
        match = data_dir_pattern.match(data_info_dir)
        n_nodes = int(match.group(1))
        p_order = int(match.group(3))
        print(f"Processing {data_info_dir} (n={n_nodes}, p={p_order})")

        current_save_dir = os.path.join(save_base_dir, data_info_dir)

        repeat_base_path = os.path.join(base_structure_path, data_info_dir)
        repeat_dirs = sorted(
            r for r in os.listdir(repeat_base_path)
            if os.path.isdir(os.path.join(repeat_base_path, r)) and r.startswith('repeat')
        )

        if not repeat_dirs:
            print(f"No repeat directories found in {repeat_base_path}.")

        for repeat_dir in repeat_dirs:
            repeat_idx_match = re.search(r'repeat(\d+)', repeat_dir)
            if not repeat_idx_match:
                print(f"Warning: Could not parse repeat index from directory name: {repeat_dir}. Skipping.")
                continue
            repeat_idx = int(repeat_idx_match.group(1))

            repeat_path_and = os.path.join(repeat_base_path, repeat_dir)
            repeat_path_and_init = os.path.join(root_dir, alg_dir_and_init, data_info_dir, repeat_dir)
            repeat_path_multiply_init = os.path.join(root_dir, alg_dir_multiply_init, data_info_dir, repeat_dir)
            repeat_path_multiply = os.path.join(root_dir, alg_dir_multiply, data_info_dir, repeat_dir)

            multiply_available = os.path.exists(repeat_path_multiply)
            if not multiply_available:
                print(f"Warning: Corresponding repeat directory for multiply method not found: {repeat_path_multiply}. Dynotears* results will be missing.")

            # Ground truth lives at repeat level, so it is loaded once per repeat.
            gt_filepath = os.path.join(repeat_path_and, 'ground_truth.csv')
            gt_matrices = load_ground_truth_matrix(gt_filepath, n_nodes, p_order)
            if gt_matrices is None:
                print(f"Warning: No usable ground truth in {repeat_path_and}. Skipping repeat.")
                continue

            prior_prob_dirs = sorted(
                p for p in os.listdir(repeat_path_and)
                if os.path.isdir(os.path.join(repeat_path_and, p)) and p.startswith('exist_edges_prob_')
            )

            if not prior_prob_dirs:
                print(f"No prior probability directories found in {repeat_path_and}.")

            for prior_prob_dir in prior_prob_dirs:
                try:
                    prior_prob = float(prior_prob_dir.replace('exist_edges_prob_', ''))
                except ValueError:
                    print(f"Warning: Could not parse prior probability from directory name: {prior_prob_dir}. Skipping.")
                    continue

                prior_dir_and = os.path.join(repeat_path_and, prior_prob_dir)
                prior_dir_and_init = os.path.join(repeat_path_and_init, prior_prob_dir)
                prior_dir_multiply = os.path.join(repeat_path_multiply, prior_prob_dir)
                prior_dir_multiply_init = os.path.join(repeat_path_multiply_init, prior_prob_dir)

                for i in range(n_results_per_prior):
                    # Every algorithm directory uses the same result file name.
                    weights_name = f'constrained_multiply_weights_{i}.csv'

                    prior_mask_matrix = load_prior_mask_matrix(
                        os.path.join(prior_dir_and, f'exist_edges_mask_{i}.csv'), n_nodes)
                    baseline_matrices = load_weights_from_rows(
                        os.path.join(prior_dir_and, f'baseline_weights_{i}.csv'), n_nodes, p_order, threshold=weight_threshold)
                    if prior_mask_matrix is None or baseline_matrices is None:
                        # Both are drawn unconditionally, so there is nothing to plot without them.
                        print(f"Warning: Prior mask or baseline missing for index {i} in {prior_dir_and}. Skipping.")
                        continue

                    and_matrices = load_weights_from_rows(
                        os.path.join(prior_dir_and, weights_name), n_nodes, p_order, threshold=weight_threshold)
                    and_init_matrices = load_weights_from_rows(
                        os.path.join(prior_dir_and_init, weights_name), n_nodes, p_order, threshold=weight_threshold)
                    multiply_init_matrices = load_weights_from_rows(
                        os.path.join(prior_dir_multiply_init, weights_name), n_nodes, p_order, threshold=weight_threshold)
                    # Without the multiply directory there is nothing to load.
                    if multiply_available:
                        multiply_matrices = load_weights_from_rows(
                            os.path.join(prior_dir_multiply, weights_name), n_nodes, p_order, threshold=weight_threshold)
                    else:
                        multiply_matrices = None

                    plot_causal_structures(gt_matrices, prior_mask_matrix, baseline_matrices, and_matrices, multiply_matrices,
                                           and_init_matrices, multiply_init_matrices, n_nodes, p_order, repeat_idx, prior_prob,
                                           data_info_dir, current_save_dir, i, annotation_threshold=weight_threshold)

print("Visualization process finished.")
