import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def _percent_value(series_name):
    """Return the numeric value of a percentage series name, e.g. "33%" -> 33.0."""
    return float(series_name.strip('%'))


def _darken(color, factor=0.8):
    """Return the RGB *color* with every channel scaled by *factor* (0.8 = 20% darker)."""
    return tuple(channel * factor for channel in color)


def _gradient_colors(series_names, highlight="50%"):
    """Map every series except *highlight* to a (face_color, edge_color) pair.

    Series are ordered by percentage value (0% < 33% < 67% < 100%) and
    assigned a light-to-deep royal-blue gradient, so larger percentages get
    deeper colors. Each edge color is its face color darkened by 20%.

    :param series_names: Iterable of series names; all except *highlight*
        must be percentage strings such as "33%".
    :param highlight: Name of the series that is drawn separately.
    :return: dict mapping series name -> (face RGB tuple, edge RGB tuple).
    """
    ordered = sorted((name for name in series_names if name != highlight),
                     key=_percent_value)
    if not ordered:
        return {}
    palette = sns.light_palette("royalblue", n_colors=len(ordered))
    return {name: (face, _darken(face)) for name, face in zip(ordered, palette)}


def _label_y(value, series_max, lowered_values=()):
    """Return the y coordinate of the value label for a bar of height *value*.

    Labels normally sit 2% of the series maximum above the bar top. Values
    listed in *lowered_values* are placed 1 unit below the top instead
    (a manual nudge, presumably to stop crowded labels from overlapping).
    """
    if value in lowered_values:
        return value - 1
    return value + series_max * 0.02


def _annotate_bars(ax, x_positions, values, lowered_values=(), color=None):
    """Write each value as a bold label, rotated 45 degrees, at its x position."""
    series_max = max(values, default=0)  # hoisted: constant for the whole series
    text_kwargs = {'fontsize': 22, 'rotation': 45, 'ha': 'center',
                   'va': 'bottom', 'fontweight': 'bold'}
    if color is not None:
        text_kwargs['color'] = color
    for x, value in zip(x_positions, values):
        ax.text(x, _label_y(value, series_max, lowered_values), f'{value}',
                **text_kwargs)


def _save_figure(fig, path, **savefig_kwargs):
    """Save *fig* to *path*, creating the parent directory first if needed."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    fig.savefig(path, bbox_inches="tight", pad_inches=0.1, **savefig_kwargs)


def plot_bar_chart(data_dict, x_ticks, x_labels=None, x_label='', y_label='',
                   title=None, y_min=15, lowered_label_values=(68.4, 18.9, 42.2),
                   png_path="./plots/bar_chart.png", pdf_path="./plots/ratio.pdf",
                   show=False):
    """
    Plot a grouped bar chart with one bar series per entry of ``data_dict``.

    The "50%" series is highlighted with a red fill, dark-red border,
    '//' hatching and red value labels. Every other series gets a
    light-to-deep royal-blue gradient ordered by its percentage value.
    The chart is saved as a 300-dpi PNG and as a PDF that embeds TrueType
    (Type 42) fonts instead of Type 3 fonts.

    :param data_dict: Dictionary with series names as keys and data lists as values.
        Names other than "50%" must be percentage strings such as "33%".
    :param x_ticks: X-axis tick positions (indices); one bar group per entry
    :param x_labels: X-axis tick labels (optional; ``x_ticks`` is shown otherwise)
    :param x_label: X-axis title
    :param y_label: Y-axis title
    :param title: Chart title (optional; no title is drawn when empty/None)
    :param y_min: Lower bound of the y-axis
    :param lowered_label_values: Values whose labels (non-"50%" series only) are
        drawn just below the bar top instead of above it
    :param png_path: Output path of the PNG image (skipped when empty/None)
    :param pdf_path: Output path of the PDF image (skipped when empty/None)
    :param show: Display the figure with ``plt.show()`` before it is closed
    :raises ValueError: If ``data_dict`` is empty.
    """
    if not data_dict:
        raise ValueError("data_dict must contain at least one series")

    highlight = "50%"
    lowered = () if lowered_label_values is None else tuple(lowered_label_values)
    bar_width = 0.15
    n_series = len(data_dict)
    index = np.arange(len(x_ticks))
    gradient = _gradient_colors(data_dict, highlight)

    # Text objects capture the font family when they are created, so the font
    # must be configured before plotting; rc_context also keeps these settings
    # out of the caller's global rcParams. pdf.fonttype 42 avoids Type 3 fonts.
    with plt.rc_context({'pdf.fonttype': 42, 'font.family': 'DejaVu Sans'}):
        fig, ax = plt.subplots(figsize=(18, 12))  # Keep same image size as line_chart.py
        try:
            for i, (series_name, values) in enumerate(data_dict.items()):
                positions = index + i * bar_width
                label = f'high layer ratio α={series_name}'
                if series_name == highlight:
                    # Red body, dark-red border, '//' hatching
                    ax.bar(positions, values, bar_width, label=label,
                           facecolor='#FF5252', edgecolor='#C62828', linewidth=3.0,
                           hatch='//')
                    _annotate_bars(ax, positions + bar_width / 2, values, color='red')
                else:
                    face, edge = gradient[series_name]
                    ax.bar(positions, values, bar_width, label=label,
                           facecolor=face, edgecolor=edge, linewidth=1.5)
                    _annotate_bars(ax, positions + bar_width / 2, values, lowered)

            # Bars are center-aligned, so the group center is the middle bar offset.
            ax.set_xticks(index + bar_width * (n_series - 1) / 2)
            ax.set_xticklabels(x_labels if x_labels else x_ticks, fontsize=24)
            ax.set_xlabel(x_label, fontsize=30)
            ax.set_ylabel(y_label, fontsize=30)
            ax.tick_params(axis='y', labelsize=24)
            ax.set_ylim(y_min, max(max(values) for values in data_dict.values()) * 1.1)
            if title:
                ax.set_title(title, fontsize=30)

            ax.legend(fontsize=24, frameon=True, edgecolor='black')
            ax.grid(axis='y', linestyle='--', alpha=0.7)  # common in academic charts
            sns.despine(ax=ax)  # remove top and right borders

            if png_path:
                _save_figure(fig, png_path, dpi=300)
            if pdf_path:
                _save_figure(fig, pdf_path)
            if show:
                plt.show()
        finally:
            # Release the figure even on error so repeated calls don't leak figures.
            plt.close(fig)

if __name__ == '__main__':
    # Demo input: one list per ratio series, one score per domain in domain_names.
    domain_names = ['Attraction', 'Hotel', 'Restaurant', 'Train', 'Taxi']
    sample_data = {
        '0%':   [35.8, 19.5, 24.2, 43.0, 67.7],
        '33%':  [32.3, 18.9, 22.4, 41.1, 68.0],
        '50%':  [38.9, 20.4, 31.0, 44.9, 68.6],
        '67%':  [34.5, 20.0, 23.8, 41.3, 66.3],
        '100%': [30.8, 18.8, 23.9, 42.2, 68.4],
    }
    tick_positions = list(range(len(domain_names)))
    plot_bar_chart(
        sample_data,
        tick_positions,
        x_labels=domain_names,
        x_label='Domain',
        y_label='Joint Goal Accuracy (%)',
    )