import os
import json
import numpy as np
import pandas as pd
import cv2  # using OpenCV for image reading; pip install opencv-python
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

def compute_saturation(img_path, verbose=False):
    """
    Return the mean HSV saturation of the image at ``img_path``, scaled to [0, 1].

    The file is loaded with ``cv2.imread`` (BGR channel order), converted to
    HSV, and the saturation channel is averaged and divided by 255.

    Args:
        img_path: Path of the image file to read.
        verbose: If true, print a warning when the file cannot be decoded.

    Returns:
        The mean saturation as a float, or ``None`` if OpenCV could not read
        the image.
    """
    image_bgr = cv2.imread(img_path)
    if image_bgr is not None:
        saturation_channel = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2HSV)[..., 1]
        # Divide by 255 to map the 8-bit saturation channel onto [0, 1].
        return np.mean(saturation_channel) / 255.0

    if verbose:
        print("    WARNING: Unable to read the image file.")
    return None

def _load_score_json(json_path, label, verbose):
    """Load one score file for a subfolder; return ``{}`` when it does not exist.

    The file is expected to hold a JSON object mapping image filename to a
    numeric score (callers iterate ``.items()`` and cast values with ``float``).
    """
    if not os.path.exists(json_path):
        if verbose:
            print(f"  WARNING: {json_path} not found. Skipping {label} data.")
        return {}
    if verbose:
        print(f"  Reading {label} scores from: {json_path}")
    # JSON text is UTF-8 by specification; do not depend on the platform locale.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _resolve_image_path(sf_path, fn, image_exts):
    """Return the image file backing JSON key ``fn`` inside ``sf_path``, or ``None``.

    ``fn`` is tried verbatim first; otherwise its base name is combined with
    each extension in ``image_exts`` (in order) and the first existing file wins.
    """
    full_path = os.path.join(sf_path, fn)
    if os.path.isfile(full_path):
        return full_path
    base_name = os.path.splitext(fn)[0]
    for ext in image_exts:
        candidate = os.path.join(sf_path, base_name + ext)
        if os.path.isfile(candidate):
            return candidate
    return None


def collect_scores_and_saturation(root_folder, subfolder_order,
                                  image_exts=(".jpg", ".png", ".jpeg"),
                                  verbose=True):
    """
    Gather aesthetic/corrupt scores and per-image saturation into one DataFrame.

    For each subfolder in ``subfolder_order``:
      1. Load ``aesthetic.json`` and ``corrupt.json`` (a missing file counts as
         no scores of that type).
      2. Locate the image for every filename key. The key is tried as-is, then
         with each extension in ``image_exts``.
      3. Compute its saturation once, even if it is scored in both files.

    Args:
        root_folder: Directory that contains the subfolders.
        subfolder_order: Subfolder names to process, in order.
        image_exts: Fallback extensions tried when a JSON key is not an
            existing file.
        verbose: Print progress and warnings.

    Returns:
        A DataFrame with the columns ``[type, subfolder, filename, score,
        saturation]`` and one row per (filename, score type) pair. The
        ``saturation`` value is missing when the image is absent or unreadable.
        The columns are present even when no rows were collected, so
        downstream column access does not raise ``KeyError``.
    """
    columns = ["type", "subfolder", "filename", "score", "saturation"]
    records = []

    for sf in subfolder_order:
        sf_path = os.path.join(root_folder, sf)
        if verbose:
            print(f"\n--- Processing subfolder: {sf} ---")

        aes_data = _load_score_json(os.path.join(sf_path, "aesthetic.json"),
                                    "aesthetic", verbose)
        cor_data = _load_score_json(os.path.join(sf_path, "corrupt.json"),
                                    "corrupt", verbose)

        # Group scores by filename so each image is read only once, even when
        # it appears in both score files. Aesthetic entries come first.
        combined_scores = {}
        for score_type, data in (("aesthetic", aes_data), ("corrupt", cor_data)):
            for fn, score in data.items():
                combined_scores.setdefault(fn, []).append((score_type, float(score)))

        for fn, type_score_list in tqdm(combined_scores.items(), desc="Processing images"):
            img_path = _resolve_image_path(sf_path, fn, image_exts)
            if img_path is None:
                sat_val = None
                if verbose:
                    print(f"    WARNING: Could not find an image file for {fn} in {sf}.")
            else:
                sat_val = compute_saturation(img_path, verbose=verbose)

            records.extend(
                {
                    "type": score_type,
                    "subfolder": sf,
                    "filename": fn,
                    "score": score_val,
                    "saturation": sat_val,
                }
                for score_type, score_val in type_score_list
            )

    df = pd.DataFrame(records, columns=columns)
    if verbose:
        print(f"\nCollected {len(df)} records in total.")
    return df

def _select_scored_rows(df, score_type):
    """Return a copy of the ``score_type`` rows of ``df`` that have both saturation and score.

    A frame lacking the ``type``/``saturation``/``score`` columns (for example an
    empty ``pd.DataFrame()``) yields an empty frame instead of raising ``KeyError``.
    """
    if not {"type", "saturation", "score"}.issubset(df.columns):
        return pd.DataFrame(columns=["saturation", "score"])
    mask = (df['type'] == score_type) & df['saturation'].notna() & df['score'].notna()
    # .copy() so later column assignments do not write into a view of ``df``
    # (which would raise pandas' SettingWithCopyWarning).
    return df.loc[mask].copy()


def _binned_score_stats(data, n_bins=10):
    """Mean and count of ``score`` per equal-width saturation bin over [0, 1].

    Returns a frame with the columns ``sat_bin``, ``score_mean``, ``score_count``
    and ``sat_bin_mid``. Every bin appears, including empty ones (count 0, NaN mean).
    """
    edges = np.linspace(0, 1, n_bins + 1)
    # include_lowest=True so a saturation of exactly 0.0 (a fully gray image)
    # falls in the first bin instead of becoming NaN and being dropped.
    data = data.assign(sat_bin=pd.cut(data['saturation'], bins=edges, include_lowest=True))
    binned = (
        data.groupby('sat_bin', observed=False)['score']
        .agg(['mean', 'count'])
        .reset_index()
        .rename(columns={'mean': 'score_mean', 'count': 'score_count'})
    )
    # include_lowest nudges the first interval's left edge slightly below 0;
    # clamp it so the plotted midpoint is the nominal bin centre.
    binned['sat_bin_mid'] = [(max(iv.left, 0.0) + iv.right) / 2 for iv in binned['sat_bin']]
    return binned


def _save_regplot(data, title, ylabel, out_path):
    """Save a saturation-vs-score scatter plot with a red linear-regression line."""
    fig = plt.figure(figsize=(8, 6))
    try:
        sns.regplot(x='saturation', y='score', data=data,
                    scatter_kws={'alpha': 0.3}, line_kws={'color': 'red'})
        plt.title(title)
        plt.xlabel('Saturation (0-1)')
        plt.ylabel(ylabel)
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Close even if plotting/saving fails so figures do not accumulate.
        plt.close(fig)


def _save_binned_means_plot(binned, out_path):
    """Save a line plot of mean aesthetic score against saturation-bin midpoint."""
    fig = plt.figure(figsize=(8, 6))
    try:
        plt.plot(binned['sat_bin_mid'], binned['score_mean'], marker='o', linestyle='-')
        plt.title('Binned Saturation vs. Mean Aesthetic Score')
        plt.xlabel('Saturation Bin Midpoint')
        plt.ylabel('Mean Aesthetic Score')
        plt.grid(True)
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)


def analyze_correlations(df, verbose=True, figure_dir="figures"):
    """
    Correlate saturation with aesthetic and corrupt scores and save plots to files.

    For each score type, rows missing a saturation or score value are ignored.
    When ``verbose`` is true, the Pearson correlation between ``saturation`` and
    ``score`` is printed. A scatter plot with a regression line is also written
    to ``figure_dir``. Aesthetic data additionally gets a 10-bin saturation
    analysis over [0, 1] (mean/count per bin, printed when ``verbose``) and a
    plot of the binned means.

    Args:
        df: DataFrame with at least the columns ``type``, ``saturation`` and
            ``score``, as produced by ``collect_scores_and_saturation``. A frame
            without these columns is treated as containing no data.
        verbose: Print statistics and the paths of saved figures.
        figure_dir: Output directory for the PNG files; created if missing.

    Files written (only when the corresponding data exists):
        ``aesthetic_scatter.png``, ``aesthetic_binned.png``, ``corrupt_scatter.png``
    """
    os.makedirs(figure_dir, exist_ok=True)

    # ----- AESTHETIC -----
    aes_df = _select_scored_rows(df, 'aesthetic')
    if aes_df.empty:
        if verbose:
            print("\n--- Aesthetic Data Analysis ---")
            print("No aesthetic data available for correlation.\n")
    else:
        if verbose:
            print("\n--- Aesthetic Data Analysis ---")
            print("Correlation (Pearson) between saturation and aesthetic score:")
            print(aes_df[['saturation', 'score']].corr(method='pearson'), "\n")

        binned = _binned_score_stats(aes_df)
        if verbose:
            print("Average aesthetic score per saturation bin:\n")
            print(binned[['sat_bin', 'score_mean', 'score_count']], "\n")

        out_path = os.path.join(figure_dir, "aesthetic_scatter.png")
        _save_regplot(aes_df, 'Scatter: Saturation vs. Aesthetic Score',
                      'Aesthetic Score', out_path)
        if verbose:
            print(f"Saved scatter plot to: {out_path}")

        out_path = os.path.join(figure_dir, "aesthetic_binned.png")
        _save_binned_means_plot(binned, out_path)
        if verbose:
            print(f"Saved binned means plot to: {out_path}")

    # ----- CORRUPT -----
    cor_df = _select_scored_rows(df, 'corrupt')
    if cor_df.empty:
        if verbose:
            print("--- Corrupt Data Analysis ---")
            print("No corrupt data available for correlation.\n")
    else:
        if verbose:
            print("--- Corrupt Data Analysis ---")
            print("Correlation (Pearson) between saturation and corrupt score:")
            print(cor_df[['saturation', 'score']].corr(method='pearson'), "\n")

        out_path = os.path.join(figure_dir, "corrupt_scatter.png")
        _save_regplot(cor_df, 'Scatter: Saturation vs. Corrupt Score',
                      'Corrupt Score', out_path)
        if verbose:
            print(f"Saved corrupt scatter plot to: {out_path}")

def main(verbose=True, root_folder="/ROOT_DIR", subfolder_order=None,
         figure_dir="figures"):
    """
    Run the full pipeline: collect scores and saturation, then analyze and plot.

    The defaults reproduce the original hard-coded configuration.
    ``"/ROOT_DIR"`` and ``"MODELS_NAME"`` appear to be placeholders; pass real
    values instead of editing this function.

    Args:
        verbose: Print progress, statistics and saved-figure paths.
        root_folder: Directory containing the per-model subfolders.
        subfolder_order: Subfolder names to process, in order. ``None`` means
            the default ``["MODELS_NAME"]``; a fresh list is built per call.
        figure_dir: Directory the plots are written to.
    """
    if subfolder_order is None:
        subfolder_order = [
            "MODELS_NAME",
        ]

    df = collect_scores_and_saturation(root_folder, subfolder_order, verbose=verbose)
    analyze_correlations(df, verbose=verbose, figure_dir=figure_dir)

if __name__ == "__main__":
    # Runs only when executed as a script; importing the module has no side effects.
    main(verbose=True)
