import os
import glob
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Module-wide plot styling applied at import time: seaborn "ticks" theme and 100-dpi figures.
sns.set_style("ticks")
plt.rcParams.update({"figure.dpi": 100})

def _exp2_to_float(value):
    """Return ``float(value)``, or None when *value* cannot be converted."""
    try:
        return float(value)
    except (ValueError, TypeError):
        return None


def _exp2_float_keyed(mapping):
    """Copy *mapping* with its keys converted to float, dropping unconvertible keys."""
    out = {}
    for key, value in mapping.items():
        key_float = _exp2_to_float(key)
        if key_float is not None:
            out[key_float] = value
    return out


def _exp2_feat_names(payload, config):
    """Return ``feat_names`` from the payload, falling back to the config (default [])."""
    names = payload.get("feat_names")
    if names is None or len(names) == 0:
        names = config.get("feat_names")
        if names is None:
            names = []
    return names


def _exp2_metadata(fp, config):
    """Resolve the run metadata of one result file.

    Values stored in ``config`` take precedence; anything missing is inferred
    from the file name split on ``_`` (e.g. ``result_exp2_sweep_upstream_exp1000``).

    Returns a dict with keys ``path``, ``weight``, ``poset``, ``block``,
    ``exponents``, ``base`` and ``config`` (in that order).
    """
    fname = os.path.basename(fp)
    tokens = os.path.splitext(fname)[0].split("_")

    weight = config.get("weight")
    poset = config.get("poset")
    block = config.get("selected_block")
    exponents = config.get("exponents")

    if isinstance(exponents, (list, tuple)):
        exponents = [float(x) for x in exponents]

    if weight is None:
        if "uniform" in fname:
            weight = "uniform"
        elif any("exp" in t for t in tokens):
            weight = "sweep"

    if poset is None:
        # NOTE: substring match on the whole file name, so any name containing
        # "op" is classified as the "op" poset.
        if "op" in fname:
            poset = "op"
        elif "general" in fname:
            poset = "general"
        elif weight != "uniform":
            # Non-uniform files without a poset marker default to "general";
            # uniform files without one stay None.
            poset = "general"

    if block is None and len(tokens) >= 4 and tokens[3] in ("upstream", "middle", "downstream"):
        block = tokens[3]

    if exponents is None:
        # The first "exp<digits>" token encodes one single-digit exponent per
        # feature, e.g. "exp1000" -> [1.0, 0.0, 0.0, 0.0]. "exp1"/"exp2" are
        # skipped (the file-name prefix itself contains "exp2").
        for token in tokens:
            if token.startswith("exp") and token not in ("exp1", "exp2"):
                try:
                    exponents = [float(ch) for ch in token[3:]]
                except ValueError:
                    exponents = None  # malformed token: exponents stay unknown
                break

    return {
        "path": fp,
        "weight": weight,
        "poset": poset,
        "block": block,
        "exponents": exponents,
        "base": config.get("base"),
        "config": config,
    }


def _exp2_matches(meta, weight, poset, block, exponents):
    """Return True when *meta* passes every filter that is not None."""
    if weight is not None and meta["weight"] != weight:
        return False
    if poset is not None and meta["poset"] != poset:
        return False
    if block is not None and meta["block"] != block:
        return False
    if exponents is not None:
        file_exponents = meta["exponents"]
        if file_exponents is None or len(file_exponents) != len(exponents):
            return False
        return all(abs(float(f) - float(e)) < 1e-6 for f, e in zip(file_exponents, exponents))
    return True


def _exp2_entries_by_base(payload, config, meta, base):
    """Expand a ``phi_reps_by_base`` payload into one result dict per base."""
    phi_by_base = _exp2_float_keyed(payload["phi_reps_by_base"])
    rank_by_base = _exp2_float_keyed(payload.get("rank_reps_by_base") or {})
    feat_names = _exp2_feat_names(payload, config)

    # Base order: payload "base_array", else config "base_array", else the
    # (float-converted) keys of phi_reps_by_base.
    raw_bases = payload.get("base_array")
    if raw_bases is None:
        raw_bases = config.get("base_array")
    try:
        bases = [float(b) for b in raw_bases] if raw_bases is not None and len(raw_bases) > 0 else None
    except (ValueError, TypeError):
        bases = None
    if bases is None:
        bases = list(phi_by_base)

    entries = []
    for b in bases:
        # Keys may have been stored with float round-off; match within tolerance.
        key = next((k for k in phi_by_base if abs(k - b) < 1e-6), None)
        if key is None:
            continue
        if base is not None and abs(b - base) > 1e-6:
            continue
        entry = {**meta, "base": b, "phi_reps": phi_by_base[key], "feat_names": feat_names}
        if key in rank_by_base:
            entry["rank_reps"] = rank_by_base[key]
        entries.append(entry)
    return entries


def _exp2_entries(payload, config, meta, base):
    """Turn one matching payload into zero or more result dicts."""
    if "phi_reps_by_base" in payload:
        return _exp2_entries_by_base(payload, config, meta, base)

    for key in ("phi_reps", "phi"):
        if key in payload:
            # Single-base files are filtered on their config "base" as well.
            if base is not None:
                file_base = _exp2_to_float(meta["base"])
                if file_base is None or abs(file_base - base) > 1e-6:
                    return []
            return [{**meta, key: payload[key], "feat_names": _exp2_feat_names(payload, config)}]
    return []


def load_exp2_FA_results(save_dir, weight=None, poset=None, block=None, exponents=None, base=None):
    """Load exp2 feature-attribution results saved as ``save_dir/result_exp2_*.pkl``.

    Each file's metadata (weight, poset, block, exponents, base) is read from
    its ``config`` dict when present, otherwise inferred from the file name.
    Files are then filtered by every argument that is not None.

    Understood payload layouts:
      * ``phi_reps_by_base`` (optionally with ``rank_reps_by_base``): expanded
        into one result per base, in ``base_array`` order when available;
      * ``phi_reps`` or ``phi``: a single result carrying that entry.

    Args:
        save_dir: Directory to scan (files are visited in sorted order).
        weight, poset, block: Exact-match filters on the resolved metadata.
        exponents: Element-wise filter on the exponent vector (tolerance 1e-6).
        base: Filter on the base value (tolerance 1e-6), for every layout.

    Returns:
        A list of dicts with keys ``path``, ``weight``, ``poset``, ``block``,
        ``exponents``, ``base``, ``config``, the payload entry (``phi_reps``
        [+ ``rank_reps``] or ``phi``) and ``feat_names``.

    Files that cannot be unpickled or do not contain a dict are skipped with a
    printed warning. ``pickle.load`` can execute arbitrary code, so only point
    ``save_dir`` at trusted result files.
    """
    results = []
    for fp in sorted(glob.glob(os.path.join(save_dir, "result_exp2_*.pkl"))):
        try:
            with open(fp, "rb") as f:
                payload = pickle.load(f)
        except Exception as e:  # best effort: one unreadable file must not abort the scan
            print(f"Warning: Failed to load {fp}: {e}")
            continue
        if not isinstance(payload, dict):
            print(f"Warning: Unexpected payload type in {fp}: {type(payload).__name__}")
            continue

        config = payload.get("config")
        if config is None:
            config = {}

        meta = _exp2_metadata(fp, config)
        if _exp2_matches(meta, weight, poset, block, exponents):
            results.extend(_exp2_entries(payload, config, meta, base))
    return results


def plot_barplot_heatmap(
    save_dir,
    save_path=None,
    figsize=(20, 6),
    base_min=None,
    base_max=None,
    horizontal=True,
    fontsize_xtick=None,
    fontsize_xtick_bar=None,
    fontsize_ytick=None,
    fontsize_xlabel=None,
    fontsize_ylabel=None,
    fontsize_legend=None,
    fontsize_colorbar_label=None,
    width_ratios=(1, 1),
):
    """Draw a two-panel figure: uniform-weight bar plot (left) and base-sweep heatmap (right).

    Left panel: per-feature mean value across repetitions with 95% CI error
    bars (1.96 * SEM), comparing the uniform-weight "op" and "general" runs.
    Right panel: one row per swept feature, one column per base lambda in
    [base_min, base_max]; each cell is the mean value as a percentage of the
    value at lambda = 1. Rows whose base grid differs from the first row are
    omitted with a printed warning.

    Args:
        save_dir: Directory containing ``result_exp2_*.pkl`` files.
        save_path: If given, the figure is written there in PDF format.
        figsize: Matplotlib figure size.
        base_min, base_max: Inclusive base range for the heatmap
            (defaults 2**-8 and 2**8).
        horizontal: Horizontal bars if True, vertical bars otherwise.
        fontsize_*: Optional font sizes (None keeps matplotlib defaults);
            ``fontsize_xtick_bar`` falls back to ``fontsize_xtick``.
        width_ratios: Relative widths of the bar and heatmap panels.
    """
    sns.set_style("ticks")

    if base_min is None:
        base_min = 2 ** (-8)
    if base_max is None:
        base_max = 2 ** 8

    fig, axes = plt.subplots(1, 2, figsize=figsize, gridspec_kw={"width_ratios": list(width_ratios)})
    ax_bar, ax_heat = axes[0], axes[1]

    blocks = (
        ("upstream", ["age", "sex", "race", "native-country"]),
        ("middle", ["marital-status", "education", "occupation", "workclass"]),
        ("downstream", ["relationship", "capital-gain", "capital-loss", "hours-per-week"]),
    )
    feature_order = [feat for _, feats in blocks for feat in feats]
    xtick_bar_fontsize = fontsize_xtick_bar if fontsize_xtick_bar is not None else fontsize_xtick

    def as_reps(reps):
        # Rows are repetitions, columns features; a 1-D vector counts as one repetition.
        return np.atleast_2d(np.asarray(reps, dtype=float))

    def as_names(names):
        return list(names) if names is not None else []

    # ---- Left panel: uniform-weight values, "op" vs "general" poset ----
    op_results = load_exp2_FA_results(save_dir, weight="uniform", poset="op")
    general_results = load_exp2_FA_results(save_dir, weight="uniform", poset="general")
    op_data = op_results[0] if op_results else {}
    general_data = general_results[0] if general_results else {}

    if op_data.get("phi_reps") is not None and general_data.get("phi_reps") is not None:

        def ordered_mean_ci(data):
            """Per-feature mean and 95% CI half-width in ``feature_order`` (0.0 if absent)."""
            reps = as_reps(data["phi_reps"])
            n_reps = reps.shape[0]
            mean = reps.mean(axis=0)
            if n_reps > 1:
                ci95 = 1.96 * reps.std(axis=0, ddof=1) / np.sqrt(n_reps)
            else:
                ci95 = np.zeros_like(mean)
            names = as_names(data.get("feat_names"))
            vals, cis = [], []
            for feat in feature_order:
                if feat in names:
                    idx = names.index(feat)
                    vals.append(mean[idx])
                    cis.append(ci95[idx])
                else:
                    vals.append(0.0)
                    cis.append(0.0)
            return vals, cis

        op_vals, op_ci = ordered_mean_ci(op_data)
        general_vals, general_ci = ordered_mean_ci(general_data)

        color_op = "#b8d4f0"
        color_general = "#ffc8a3"
        bar_style = dict(edgecolor="black", linewidth=0.2, capsize=3)
        positions = np.arange(len(feature_order))
        thickness = 0.35

        if horizontal:
            ax_bar.barh(positions - thickness / 2, op_vals, thickness, xerr=op_ci,
                        color=color_op, label="Ordered Partition", **bar_style)
            ax_bar.barh(positions + thickness / 2, general_vals, thickness, xerr=general_ci,
                        color=color_general, label="General DAG", **bar_style)

            ax_bar.axvline(0, color="black", linewidth=0.6, linestyle="--")
            ax_bar.set_yticks(positions)
            ax_bar.set_yticklabels(feature_order, fontsize=fontsize_ytick)
            ax_bar.set_ylabel("Feature", fontsize=fontsize_ylabel)
            ax_bar.set_xlabel("Value", fontsize=fontsize_xlabel)
            ax_bar.set_xticks([0.00, 0.01, 0.02, 0.03, 0.04])
            ax_bar.set_xticklabels(["0.00", "0.01", "0.02", "0.03", "0.04"], fontsize=xtick_bar_fontsize)
            ax_bar.invert_yaxis()  # first feature at the top
        else:
            ax_bar.bar(positions - thickness / 2, op_vals, thickness, yerr=op_ci,
                       color=color_op, label="Ordered Partition", **bar_style)
            ax_bar.bar(positions + thickness / 2, general_vals, thickness, yerr=general_ci,
                       color=color_general, label="Proposed Poset", **bar_style)

            ax_bar.axhline(0, color="black", linewidth=0.6, linestyle="--")
            ax_bar.set_xticks(positions)
            ax_bar.set_xticklabels(feature_order, rotation=45, ha="right", fontsize=xtick_bar_fontsize)
            ax_bar.set_xlabel("Feature", fontsize=fontsize_xlabel)
            ax_bar.set_ylabel("Value", fontsize=fontsize_ylabel)

        ax_bar.legend(fontsize=fontsize_legend)
        ax_bar.xaxis.grid(False)
        ax_bar.set_axisbelow(True)
    else:
        ax_bar.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax_bar.transAxes)

    # ---- Right panel: value of each swept feature relative to lambda = 1 ----
    heatmap_rows = []
    feature_list = []
    base_list = None

    for block, block_features in blocks:
        for col_idx, target_feat in enumerate(block_features):
            # Select the sweep whose exponent vector is one-hot at this feature's position.
            exponents = [0.0] * len(block_features)
            exponents[col_idx] = 1.0
            results = load_exp2_FA_results(save_dir, weight="sweep", block=block, exponents=exponents)
            results = sorted((r for r in results if r["base"] is not None), key=lambda r: r["base"])
            if not results:
                continue

            feat_names = as_names(results[0].get("feat_names"))
            if target_feat not in feat_names:
                continue
            feat_idx = feat_names.index(target_feat)

            # Keep each base paired with its own repetitions so they cannot drift out of alignment.
            bases = []
            means = []
            for r in results:
                if "phi_reps" in r and base_min <= r["base"] <= base_max:
                    bases.append(r["base"])
                    means.append(as_reps(r["phi_reps"])[:, feat_idx].mean())
            if not bases:
                continue
            means = np.asarray(means)

            ref_idx = next((i for i, b in enumerate(bases) if abs(b - 1.0) < 1e-6), None)
            if ref_idx is None:
                continue
            phi_1 = means[ref_idx]
            if abs(phi_1) < 1e-10:
                # Reference value is ~0: a ratio is meaningless, show "no change".
                percentage = np.full_like(means, 100.0)
            else:
                percentage = means / phi_1 * 100.0

            # All heatmap rows share one column axis, so they must share one base grid.
            if base_list is None:
                base_list = bases
            elif len(bases) != len(base_list) or not np.allclose(bases, base_list):
                print(f"Warning: {target_feat} uses a different base grid; omitted from heatmap")
                continue
            heatmap_rows.append(percentage)
            feature_list.append(target_feat)

    if not heatmap_rows:
        ax_heat.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax_heat.transAxes)
    else:
        heatmap_array = np.array(heatmap_rows)

        # Symmetric color range around 100% so "no change" sits at the colormap midpoint.
        spread = np.max(np.abs(heatmap_array - 100))
        vmin, vmax = 100 - spread, 100 + spread

        n_features = len(feature_list)
        n_bases = len(base_list)
        im = ax_heat.imshow(heatmap_array, aspect="auto", cmap="RdBu_r", interpolation="nearest",
                            vmin=vmin, vmax=vmax, extent=[-0.5, n_bases - 0.5, n_features - 0.5, -0.5])

        # Annotate the first and last base of each row; white text on strongly colored cells.
        for i in range(n_features):
            for j, dx, ha in ((0, -0.4, "left"), (n_bases - 1, 0.4, "right")):
                value = heatmap_array[i, j]
                color = "white" if abs(value - 100.0) > 20 else "black"
                ax_heat.text(j + dx, i, f"{value:.1f}", ha=ha, va="center",
                             color=color, fontsize=10, fontweight="bold")

        padding = 0.5
        ax_heat.set_ylim(n_features - 0.5 + padding, -0.5 - padding)

        ax_heat.set_yticks([])
        ax_heat.set_yticklabels([])

        # Label every base when there are few, otherwise about ten evenly spaced ones.
        step = 1 if n_bases <= 21 else max(1, n_bases // 10)
        tick_positions = list(range(0, n_bases, step))
        ax_heat.set_xticks(tick_positions)
        # Round (not truncate) so float round-off in log2 cannot shift a label by one.
        base_labels = [str(int(round(float(np.log2(base_list[k]))))) for k in tick_positions]
        ax_heat.set_xticklabels(base_labels, rotation=0, ha="center", fontsize=fontsize_xtick)

        ax_heat.set_xlabel(r"$\log_2(\lambda)$", fontsize=fontsize_xlabel, labelpad=2)

        cbar = plt.colorbar(im, ax=ax_heat)
        cbar.set_label(r"Percentage relative to PSV ($\lambda=1$)", fontsize=fontsize_colorbar_label)

    plt.tight_layout()
    ax_heat.xaxis.labelpad = 0

    fig.align_xlabels()

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, format="pdf", bbox_inches="tight", pad_inches=0.1)
        print(f"Saved plot to {save_path}")

    plt.close(fig)


def plot_base_value(
    save_dir,
    save_path=None,
    figsize=(16, 8),
    base_min=None,
    base_max=None,
    marker_size=4,
    line_width=1.5,
    alpha=0.2
):
    """Plot each swept feature's mean value against the base lambda.

    Produces a 3x4 grid. Rows are the upstream/middle/downstream blocks, and
    columns are the four features of each block. Each panel shows:
      * the mean over repetitions of the swept feature's value for every base
        in [base_min, base_max], on a log2 x-axis;
      * a band of +/- 2 standard deviations;
      * a dashed reference line at the lambda = 1 value, when that base is plotted.

    Args:
        save_dir: Directory containing ``result_exp2_*.pkl`` files.
        save_path: If given, the figure is written there in PDF format.
        figsize: Matplotlib figure size.
        base_min, base_max: Inclusive base range (defaults 2**-11 and 2**11).
        marker_size, line_width: Styling of the mean curve.
        alpha: Opacity of the standard-deviation band.
    """
    sns.set_style("ticks")

    if base_min is None:
        base_min = 2 ** (-11)
    if base_max is None:
        base_max = 2 ** 11

    blocks = (
        ("upstream", ["age", "sex", "race", "native-country"]),
        ("middle", ["marital-status", "education", "occupation", "workclass"]),
        ("downstream", ["relationship", "capital-gain", "capital-loss", "hours-per-week"]),
    )

    fig, axes = plt.subplots(3, 4, figsize=figsize, sharex=True, sharey=False)

    def mark_empty(ax, message, title):
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
        ax.set_title(title, fontsize=14)

    for row_idx, (block, block_features) in enumerate(blocks):
        for col_idx, target_feat in enumerate(block_features):
            ax = axes[row_idx, col_idx]

            # Select the sweep whose exponent vector is one-hot at this feature's position.
            exponents = [0.0] * len(block_features)
            exponents[col_idx] = 1.0
            results = load_exp2_FA_results(save_dir, weight="sweep", block=block, exponents=exponents)
            results = sorted((r for r in results if r["base"] is not None), key=lambda r: r["base"])
            if not results:
                mark_empty(ax, "No data", target_feat)
                continue

            names = results[0].get("feat_names")
            feat_names = list(names) if names is not None else []
            if target_feat not in feat_names:
                mark_empty(ax, "Feature not found", target_feat)
                continue
            feat_idx = feat_names.index(target_feat)

            # Keep each base paired with its own repetitions so they cannot drift out of alignment.
            bases, means, stds = [], [], []
            for r in results:
                if "phi_reps" not in r or not (base_min <= r["base"] <= base_max):
                    continue
                column = np.atleast_2d(np.asarray(r["phi_reps"], dtype=float))[:, feat_idx]
                bases.append(r["base"])
                means.append(column.mean())
                # Sample std needs >= 2 repetitions; a single one gets a zero-width band.
                stds.append(column.std(ddof=1) if column.size > 1 else 0.0)
            if not bases:
                mark_empty(ax, "No data in range", target_feat)
                continue
            means = np.asarray(means)
            stds = np.asarray(stds)

            ax.plot(bases, means, marker="o", markersize=marker_size, linewidth=line_width, color="#2E86AB")
            ax.fill_between(bases, means - 2 * stds, means + 2 * stds, alpha=alpha, color="#2E86AB")

            ax.set_xscale("log", base=2)
            ax.set_xlim(base_min / 2, base_max * 2)

            ax.set_xticks([2 ** (-8), 2 ** (-4), 2 ** 0, 2 ** 4, 2 ** 8])
            ax.set_xticklabels(["$-8$", "$-4$", "$0$", "$4$", "$8$"])

            if row_idx == len(blocks) - 1:
                ax.set_xlabel(r"$\log_2(\lambda)$", fontsize=12)
            if col_idx == 0:
                ax.set_ylabel("Value", fontsize=12)
            ax.set_title(target_feat, fontsize=14)
            ax.grid(True, which="both", alpha=0.3)

            # Dashed reference at the lambda = 1 value; skipped when lambda = 1 is not plotted.
            ref_idx = next((i for i, b in enumerate(bases) if abs(b - 1.0) < 1e-6), None)
            if ref_idx is not None:
                ax.axhline(means[ref_idx], color="black", linewidth=0.6, linestyle="--")

    fig.align_ylabels()
    plt.tight_layout()

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, format="pdf", bbox_inches="tight", pad_inches=0.1)
        print(f"Saved plot to {save_path}")

    plt.close(fig)




def plot_rank_value(
    save_dir,
    save_path=None,
    figsize=(16, 8),
    base_min=None,
    base_max=None,
    marker_size=4,
    line_width=1.5,
    alpha=0.2
):
    """Plot each swept feature's mean value against its mean rank across bases.

    Produces a 3x4 grid. Rows are the upstream/middle/downstream blocks, and
    columns are the four features of each block.

    Each point is one base in [base_min, base_max] whose result carries both
    ``phi_reps`` and ``rank_reps``. Its coordinates are (mean rank, mean value)
    of the swept feature. Points are connected in increasing-base order, with a
    band of +/- 1 standard deviation on the value. A dashed line marks zero when
    it lies within the range of plotted means.

    Args:
        save_dir: Directory containing ``result_exp2_*.pkl`` files.
        save_path: If given, the figure is written there in PDF format.
        figsize: Matplotlib figure size.
        base_min, base_max: Inclusive base range (defaults 2**-8 and 2**8).
        marker_size, line_width: Styling of the curve.
        alpha: Opacity of the standard-deviation band.
    """
    sns.set_style("ticks")

    if base_min is None:
        base_min = 2 ** (-8)
    if base_max is None:
        base_max = 2 ** 8

    blocks = (
        ("upstream", ["age", "sex", "race", "native-country"]),
        ("middle", ["marital-status", "education", "occupation", "workclass"]),
        ("downstream", ["relationship", "capital-gain", "capital-loss", "hours-per-week"]),
    )

    fig, axes = plt.subplots(3, 4, figsize=figsize, sharex=False, sharey=False)

    def mark_empty(ax, message, title):
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)
        ax.set_title(title, fontsize=14)

    for row_idx, (block, block_features) in enumerate(blocks):
        for col_idx, target_feat in enumerate(block_features):
            ax = axes[row_idx, col_idx]

            # Select the sweep whose exponent vector is one-hot at this feature's position.
            exponents = [0.0] * len(block_features)
            exponents[col_idx] = 1.0
            results = load_exp2_FA_results(save_dir, weight="sweep", block=block, exponents=exponents)
            results = sorted((r for r in results if r["base"] is not None), key=lambda r: r["base"])
            if not results:
                mark_empty(ax, "No data", target_feat)
                continue

            names = results[0].get("feat_names")
            feat_names = list(names) if names is not None else []

            # Only bases carrying both values and ranks form valid (rank, value) pairs;
            # collecting them separately could silently mispair ranks with values.
            paired = [r for r in results if "phi_reps" in r and "rank_reps" in r]
            if not paired:
                mark_empty(ax, "Missing data", target_feat)
                continue

            if target_feat not in feat_names:
                mark_empty(ax, "Feature not found", target_feat)
                continue
            feat_idx = feat_names.index(target_feat)

            ranks, means, stds = [], [], []
            for r in paired:
                if not (base_min <= r["base"] <= base_max):
                    continue
                values = np.atleast_2d(np.asarray(r["phi_reps"], dtype=float))[:, feat_idx]
                rank_col = np.atleast_2d(np.asarray(r["rank_reps"], dtype=float))[:, feat_idx]
                ranks.append(rank_col.mean())
                means.append(values.mean())
                # Sample std needs >= 2 repetitions; a single one gets a zero-width band.
                stds.append(values.std(ddof=1) if values.size > 1 else 0.0)
            if not ranks:
                mark_empty(ax, "No data in range", target_feat)
                continue
            ranks = np.asarray(ranks)
            means = np.asarray(means)
            stds = np.asarray(stds)

            ax.plot(ranks, means, marker="o", markersize=marker_size, linewidth=line_width, color="#2E86AB")
            # NOTE(review): mean ranks need not increase with the base, while
            # fill_between assumes ordered x values; the band may self-overlap.
            ax.fill_between(ranks, means - stds, means + stds, alpha=alpha, color="#2E86AB")

            if row_idx == len(blocks) - 1:
                ax.set_xlabel("Rank", fontsize=12)
            if col_idx == 0:
                ax.set_ylabel("Value", fontsize=12)
            ax.set_title(target_feat, fontsize=14)
            ax.grid(True, alpha=0.3)

            if means.min() <= 0 <= means.max():
                ax.axhline(0, color="black", linewidth=0.6, linestyle="--")

    fig.align_ylabels()
    plt.tight_layout()

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, format="pdf", bbox_inches="tight", pad_inches=0.1)
        print(f"Saved plot to {save_path}")

    plt.close(fig)




def main():
    """Command-line entry point: render the three exp2_FA figures.

    Reads pickled results from ``--save_dir`` (default ``save``). Writes
    ``value_heatmap_income.pdf``, ``base_value_income.pdf`` and
    ``rank_value_income.pdf`` into ``--out_dir`` (default ``figure``), which is
    created if needed.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Generate plots for exp2_FA")
    cli.add_argument("--save_dir", type=str, default="save")
    cli.add_argument("--out_dir", type=str, default="figure")
    opts = cli.parse_args()

    out_dir = opts.out_dir
    os.makedirs(out_dir, exist_ok=True)

    def out_file(name):
        return os.path.join(out_dir, name)

    # Styling shared by the two 3x4 grid figures.
    grid_style = dict(figsize=(10, 5), marker_size=2, line_width=1.5, alpha=0.3)

    print("Generating plots...")

    print("1. Generating barplot + heatmap...")
    plot_barplot_heatmap(
        save_dir=opts.save_dir,
        save_path=out_file("value_heatmap_income.pdf"),
        horizontal=True,
        figsize=(9, 4),
        width_ratios=[1, 1.7],
        fontsize_xtick=12,
        fontsize_xtick_bar=12,
        fontsize_ytick=13,
        fontsize_xlabel=15,
        fontsize_ylabel=15,
        fontsize_legend=9,
        fontsize_colorbar_label=14,
    )

    print("2. Generating base vs value plot...")
    plot_base_value(
        save_dir=opts.save_dir,
        save_path=out_file("base_value_income.pdf"),
        base_min=2 ** (-8),
        base_max=2 ** 8,
        **grid_style,
    )

    print("3. Generating rank vs value plot...")
    plot_rank_value(
        save_dir=opts.save_dir,
        save_path=out_file("rank_value_income.pdf"),
        **grid_style,
    )

    print("All plots generated successfully!")


if __name__ == "__main__":
    main()
