import os
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import pandas as pd

def _scores_by_subfolder(data_df, subfolders):
    """Return a dict mapping each subfolder name to a NumPy array of its "score" values."""
    return {
        sf: data_df.loc[data_df["subfolder"] == sf, "score"].to_numpy()
        for sf in subfolders
    }


def _mean_and_std(scores_by_sf, subfolders, plot_type):
    """Return per-subfolder ``(means, stds)`` lists used for bar heights and error bars.

    Subfolders without scores get 0 for both values. For ``plot_type == "corrupt"``
    the means are clamped to 1.0, the stds are halved, and any std that would push
    ``mean + std`` above 1.0 is shrunk to ``max(0, 1.0 - mean)``.
    """
    means = []
    stds = []
    for sf in subfolders:
        scores = scores_by_sf[sf]
        means.append(np.mean(scores) if len(scores) > 0 else 0)
        stds.append(np.std(scores) if len(scores) > 0 else 0)

    if plot_type == "corrupt":
        means = [min(m, 1.0) for m in means]
        # Use 0.5 std for corrupt scores
        stds = [s * 0.5 for s in stds]
        # Keep the upper whisker inside [0, 1]. errorbar() draws symmetric bars,
        # so this also shortens the lower whisker.
        stds = [max(0, 1.0 - m) if m + s > 1.0 else s for m, s in zip(means, stds)]
    return means, stds


def _normalized_density(scores, min_score, max_score, n_points=200):
    """Evaluate a Gaussian KDE of *scores* on an even grid over [min_score, max_score].

    Returns ``(grid, densities)``, with densities scaled so that their maximum is 1.
    Returns ``None`` when no usable density exists:
    - fewer than two scores;
    - all scores identical (the KDE covariance is singular and scipy raises);
    - a density that is zero or non-finite everywhere on the grid, which would
      otherwise produce NaN colours.
    """
    if len(scores) < 2:  # Need at least 2 points for kde
        return None
    try:
        # Scott's rule gives a reasonably smooth bandwidth.
        kde = stats.gaussian_kde(scores, bw_method='scott')
    except (np.linalg.LinAlgError, ValueError):
        return None
    grid = np.linspace(min_score, max_score, n_points)
    densities = kde(grid)
    peak = np.max(densities)
    if not np.isfinite(peak) or peak <= 0:
        return None
    return grid, densities / peak


def _draw_density_column(ax, x_center, width, grid, norm_densities, color_map):
    """Paint a vertical colour gradient centred on *x_center*.

    One thin rectangle is drawn per grid step, coloured by the normalised
    density at the step's lower edge.
    """
    for y_low, y_high, density in zip(grid[:-1], grid[1:], norm_densities[:-1]):
        ax.add_patch(plt.Rectangle(
            (x_center - width / 2, y_low),
            width,
            y_high - y_low,
            facecolor=color_map(density),
            edgecolor='none',
            alpha=0.9,
        ))


def plot_single_density_bar(data_df, subfolder_order, plot_type, out_file, x_labels=None):
    """
    Creates a single density bar plot for either aesthetic or corrupt scores.

    Each subfolder gets one outlined bar with these elements:
    - the bar height is the mean score;
    - a black error bar shows the std (see ``_mean_and_std`` for the corrupt
      adjustments);
    - a vertical colour gradient shows a Gaussian KDE of the subfolder's scores
      over the plot's fixed y-range.

    Args:
        data_df: DataFrame with at least "subfolder" and "score" columns.
        subfolder_order: Subfolder names in left-to-right bar order.
        plot_type: "aesthetic" selects the Blues colormap and a (3.5, 7.5) y-range,
            which is also applied as the axis limits. Any other value is treated
            as "corrupt": Greens colormap, and the density is evaluated over (0, 1.0).
        out_file: Path the figure is saved to (dpi=150).
        x_labels: Optional tick labels, one per subfolder. If omitted, the six
            built-in labels are used when exactly six subfolders are plotted.
            Otherwise the subfolder names themselves are used.

    Raises:
        ValueError: If ``x_labels`` is given and its length differs from the
            number of subfolders.
    """
    default_labels = ["NL", "NL-completion", "NL-TIPO", "Tags", "Tags-Completion", "Tags-TIPO"]
    subfolders = list(subfolder_order)

    # Resolve labels before creating any figure, so a bad argument leaks nothing.
    # matplotlib rejects a label list whose length differs from the tick count.
    if x_labels is None:
        x_labels = default_labels if len(subfolders) == len(default_labels) else subfolders
    elif len(x_labels) != len(subfolders):
        raise ValueError(
            f"x_labels has {len(x_labels)} entries but {len(subfolders)} subfolders were given"
        )

    x_positions = np.arange(len(subfolders))
    bar_width = 0.8

    # Set color map based on plot type
    if plot_type == "aesthetic":
        color_map = plt.cm.Blues
        title = "Aesthetic Score Distributions"
        y_range = (3.5, 7.5)
    else:  # corrupt
        color_map = plt.cm.Greens
        title = "Corrupt Score Distributions"
        y_range = (0, 1.0)
    min_score, max_score = y_range

    scores_by_sf = _scores_by_subfolder(data_df, subfolders)
    means, stds = _mean_and_std(scores_by_sf, subfolders, plot_type)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        bars = ax.bar(x_positions, means, width=bar_width, edgecolor='black', linewidth=1)
        for bar in bars:
            # Keep only the black outline. Raise it above the density patches
            # (default zorder 1), which would otherwise cover it.
            bar.set_facecolor('none')
            bar.set_zorder(2)

        # Error bars for mean ± std
        ax.errorbar(
            x_positions,
            means,
            yerr=stds,
            fmt='none',
            ecolor='black',
            capsize=5,
            capthick=2,
            elinewidth=2,
        )

        # Density gradient behind each bar
        for x_center, sf in zip(x_positions, subfolders):
            profile = _normalized_density(scores_by_sf[sf], min_score, max_score)
            if profile is not None:
                grid, norm_densities = profile
                _draw_density_column(ax, x_center, bar_width, grid, norm_densities, color_map)

        ax.set_title(title, fontsize=14)
        ax.set_ylabel("Score (higher is better)", fontsize=12)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(x_labels, rotation=0, ha='center', fontsize=14)

        # Only the aesthetic plot pins the y-limits; the corrupt plot autoscales.
        if plot_type == "aesthetic":
            ax.set_ylim(y_range)

        fig.tight_layout()
        fig.savefig(out_file, dpi=150)
        print(f"Saved {plot_type} density bar plot to {out_file}")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

def density_bar_plots(root_folder, subfolder_order):
    """
    Creates separate density bar plots for aesthetic and corrupt scores.

    For every subfolder ``root_folder/<sf>`` this reads ``aesthetic.json`` and
    ``corrupt.json``. Each file is loaded with ``json.load``, and every value of
    the resulting mapping is converted with ``float``; keys are ignored. A
    missing file is reported with a message and skipped.

    Two images are written into ``root_folder``:
    - ``aesthetic_density_bars.png``
    - ``corrupt_density_bars.png``

    Args:
        root_folder: Directory containing the subfolders; also the output directory.
        subfolder_order: Subfolder names, in the order the bars are drawn.
    """
    score_types = ("aesthetic", "corrupt")
    # Materialise once: the names are iterated here and again by the plotter.
    subfolders = list(subfolder_order)

    # One (type, subfolder, score) record per score
    data_records = []
    for sf in subfolders:
        sf_path = os.path.join(root_folder, sf)
        for score_type in score_types:
            json_path = os.path.join(sf_path, f"{score_type}.json")
            if not os.path.exists(json_path):
                print(f"No {score_type}.json found in {sf_path}")
                continue
            # JSON text is UTF-8 by spec; don't depend on the platform locale.
            with open(json_path, "r", encoding="utf-8") as f:
                scores = json.load(f)
            data_records.extend((score_type, sf, float(score)) for score in scores.values())

    df = pd.DataFrame(data_records, columns=["type", "subfolder", "score"])

    # One plot per score type, written next to the subfolders
    for score_type in score_types:
        out_path = os.path.join(root_folder, f"{score_type}_density_bars.png")
        plot_single_density_bar(df[df["type"] == score_type], subfolders, score_type, out_path)

def main():
    """Build the aesthetic and corrupt density-bar plots for a fixed data layout.

    Adjust the root directory and the subfolder list below to match your data;
    bars are drawn in the listed order.
    """
    density_bar_plots(
        "/ROOT",            # directory holding the per-run subfolders
        ["SUBFOLDER_1"],    # subfolders to plot, left to right
    )


if __name__ == "__main__":
    # Only run when executed as a script, not when imported.
    main()