import json
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import argparse
import os
from statsmodels.stats.multicomp import pairwise_tukeyhsd
import pandas as pd


# Bar fill colour per atom. Keys are the atom names listed in the report's
# 'atom_names' entry; create_compact_bar_plot looks each one up here.
ATOM_COLORS = dict(
    U_Point="#F9E6B1",
    U_Image="#FDFBDA",
    Redundancy="#EBCAC2",
    Synergy="#B8D6AA",
)

# Legend label per atom, written as matplotlib mathtext so the symbol is bold.
# Same keys as ATOM_COLORS.
LEGEND_MAP = dict(
    U_Point=r"$\bf{U_1}$ (Point)",
    U_Image=r"$\bf{U_2}$ (Image)",
    Redundancy=r"$\bf{Redundancy}$",
    Synergy=r"$\bf{Synergy}$",
)


def find_significant_atom_rigorous(weights_data, alpha=0.05):
    """
    Return the atom whose mean weight is significantly greater than every other atom's.

    Runs a one-way ANOVA across atoms and, only if it is significant, Tukey's
    HSD post-hoc test. An atom is reported when it has the highest mean AND
    every pairwise Tukey comparison against it has an adjusted p-value below
    ``alpha``.

    Args:
        weights_data: Mapping of atom name -> list of records. The weight is
            read from index 1 of each record; index 0 is not used here.
        alpha: Significance level used for both the ANOVA and Tukey's HSD.

    Returns:
        The dominant atom name, or ``None`` when:
        - fewer than two atoms have at least two samples;
        - the ANOVA is not significant (or its p-value is NaN);
        - the top atom is not significantly above some other atom;
        - some other atom could not be tested at all.
    """
    if not weights_data or len(weights_data) < 2:
        return None

    # Keep each atom's name paired with its values. Tukey's group labels are
    # built from these pairs, so they stay aligned with their samples even
    # when some atoms are dropped below.
    groups = {
        atom: [item[1] for item in records]
        for atom, records in weights_data.items()
        if len(records) > 1  # ANOVA / Tukey need at least 2 samples per group
    }
    if len(groups) < 2:
        return None

    _, p_value_anova = stats.f_oneway(*groups.values())
    # `not p < alpha` also rejects NaN, which f_oneway returns for degenerate
    # (e.g. all-constant) input; NaN must not count as significant.
    if not p_value_anova < alpha:
        return None

    all_weights = np.concatenate([np.asarray(values, dtype=float) for values in groups.values()])
    group_labels = np.repeat(list(groups), [len(values) for values in groups.values()])
    tukey_result = pairwise_tukeyhsd(endog=all_weights, groups=group_labels, alpha=alpha)

    # summary() returns the results table; row 0 is the header.
    table = tukey_result.summary().data
    col = {name: idx for idx, name in enumerate(table[0])}
    p_adj_by_pair = {
        frozenset((str(row[col['group1']]), str(row[col['group2']]))): float(row[col['p-adj']])
        for row in table[1:]
    }

    means = {atom: np.mean(values) for atom, values in groups.items()}
    winner_atom = max(means, key=means.get)

    # Take the direction from our own group means. statsmodels' `meandiff` is
    # mean(group2) - mean(group1), so sign tests on it are easy to get backwards.
    for competitor_atom in weights_data:
        if competitor_atom == winner_atom:
            continue
        pair = frozenset((str(winner_atom), str(competitor_atom)))
        if pair not in p_adj_by_pair:
            # Competitor had too few samples to be tested; dominance over it
            # cannot be established.
            return None
        if not (means[winner_atom] > means[competitor_atom] and p_adj_by_pair[pair] < alpha):
            return None

    return winner_atom


def _aggregate_by_question_type(report, atom_order, major_question_types):
    """Pool the report's per-question-type records into the major types plus an 'others' bucket."""
    aggregated = {
        q_type: {atom: [] for atom in atom_order}
        for q_type in [*major_question_types, 'others']
    }
    for q_type, atoms_data in report['weights_by_question_type'].items():
        bucket = aggregated[q_type if q_type in major_question_types else 'others']
        for atom in atom_order:
            if atom in atoms_data:
                bucket[atom].extend(atoms_data[atom])
    return aggregated


def _summarize(aggregated, atom_order, question_types):
    """Compute per-atom mean/std of the weights and the dominant atom for each question type."""
    plot_stats = {}
    significant_atoms = {}
    for q_type in question_types:
        per_atom = aggregated.get(q_type, {})
        # Skip a category only when no atom has any data for it.
        if not any(per_atom.get(atom) for atom in atom_order):
            continue
        plot_stats[q_type] = {}
        for atom in atom_order:
            # The weight value sits at index 1 of each record.
            weights_only = [item[1] for item in per_atom[atom]]
            plot_stats[q_type][atom] = {
                'mean': np.mean(weights_only) if weights_only else 0,
                'std': np.std(weights_only) if weights_only else 0,
            }
        significant_atoms[q_type] = find_significant_atom_rigorous(per_atom)
    return plot_stats, significant_atoms


def _draw_bar_plot(ax, plot_stats, significant_atoms, atom_order, question_types):
    """Draw one bar group per question type with error bars and '*' on dominant atoms.

    Returns the x tick positions and the highest bar top (mean + std) drawn.
    """
    x = np.arange(len(question_types))
    width = 0.18
    n_atoms = len(atom_order)
    highest = 0.0

    for i, atom_name in enumerate(atom_order):
        # Centre each group of bars on its tick for any number of atoms
        # (for 4 atoms this is the familiar i - 1.5).
        offset = width * (i - (n_atoms - 1) / 2)
        means = [plot_stats.get(q_type, {}).get(atom_name, {}).get('mean', 0) for q_type in question_types]
        stds = [plot_stats.get(q_type, {}).get(atom_name, {}).get('std', 0) for q_type in question_types]

        ax.bar(
            x + offset, means, width,
            # Fall back to the raw name / default colour for atoms missing from the maps.
            label=LEGEND_MAP.get(atom_name, atom_name),
            yerr=stds, capsize=3, color=ATOM_COLORS.get(atom_name),
            edgecolor='black', linewidth=0.7,
        )

        for j, q_type in enumerate(question_types):
            bar_top = means[j] + stds[j]
            highest = max(highest, bar_top)
            # Add a significance star ('*') above the dominant atom's bar.
            if significant_atoms.get(q_type) == atom_name:
                ax.text(x[j] + offset, bar_top + 0.01, '*', ha='center', va='bottom', fontsize=24, color='black')

    return x, highest


def create_compact_bar_plot(target_directory: str):
    """
    Load DAM weights, aggregate them by question type, test significance, and save a bar plot.

    Reads ``<target_directory>/dam_weights_report.json``. Its 'atom_names' list
    fixes the bar order. Its 'weights_by_question_type' maps question type ->
    atom -> list of records, with the weight at index 1. Writes the figure to
    ``<target_directory>/dam_weights_analysis_bar.pdf``.

    Prints an error and returns ``None`` if the report is missing or is not
    valid JSON.
    """
    report_path = os.path.join(target_directory, 'dam_weights_report.json')
    output_path = os.path.join(target_directory, 'dam_weights_analysis_bar.pdf')

    try:
        with open(report_path, 'r', encoding='utf-8') as f:
            report = json.load(f)
    except FileNotFoundError:
        print(f"Error: The report file was not found at '{report_path}'.")
        return
    except json.JSONDecodeError as err:
        print(f"Error: The report file at '{report_path}' is not valid JSON ({err}).")
        return

    # --- 1. Data Aggregation ---
    atom_order = report['atom_names']
    major_question_types = ['what', 'is', 'how', 'can', 'which']
    question_types_to_plot = major_question_types + ['others']
    processed_weights = _aggregate_by_question_type(report, atom_order, major_question_types)

    # --- 2. Calculate Stats and Significance ---
    plot_stats, significant_atoms = _summarize(processed_weights, atom_order, question_types_to_plot)

    # --- 3. Plotting ---
    # style.context scopes the style to this figure instead of mutating global rcParams.
    with plt.style.context('seaborn-v0_8-ticks'):
        fig, ax = plt.subplots(figsize=(12, 5))  # Compact figure size
        try:
            x, highest = _draw_bar_plot(ax, plot_stats, significant_atoms, atom_order, question_types_to_plot)

            # --- 4. Final Touches and Formatting ---
            ax.set_ylabel('Mean DAM weight', fontsize=24, fontweight='bold')
            ax.set_xticks(x, [q.capitalize() for q in question_types_to_plot], fontsize=24, fontweight='bold')
            ax.tick_params(axis='y', labelsize=18)

            # Legend currently disabled (bars still carry labels); to show it:
            # ax.legend(
            #     loc='upper center', bbox_to_anchor=(0.5, 1.18),
            #     ncol=len(atom_order), frameon=False, fontsize=16
            # )

            ax.spines[['top', 'right']].set_visible(False)
            ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.3)

            # Close the gap between the leftmost bar and the y-axis.
            ax.set_xlim(-0.45, len(question_types_to_plot) - 0.5)

            # Keep the clean 0.60 ceiling unless a bar + error bar (and its star) would be clipped.
            ax.set_ylim(0, max(0.60, highest * 1.15))

            fig.tight_layout(pad=1.5)
            fig.savefig(output_path, format='pdf', bbox_inches='tight')
        finally:
            # Release the figure so repeated calls don't accumulate open figures.
            plt.close(fig)

    print(f"Plot successfully saved to: {output_path}")


if __name__ == '__main__':
    # Command-line entry point: plot the report found in the given directory.
    # pandas and statsmodels are imported at module top, so a missing package
    # already fails with ImportError before this line runs; no separate check here.
    parser = argparse.ArgumentParser(description="Generates a compact DAM weights analysis bar plot.")
    parser.add_argument("input_dir", type=str, help="Directory containing the dam_weights_report.json file.")
    args = parser.parse_args()
    create_compact_bar_plot(args.input_dir)