import argparse
import logging
import os,sys
import random
from pathlib import Path
from collections import defaultdict
import h5py
import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from adjustText import adjust_text  # pip install adjustText required

# Make project-local packages importable. NOTE: this appends the *current
# working directory* (not this file's directory) to sys.path, so the import of
# `others` below presumably requires launching the script from the project
# root -- verify against how the script is invoked.
project_root = os.path.abspath(os.getcwd())
sys.path.append(project_root)
from others.prompt.visualization import visualize_class_probility

# Register a bundled font file with matplotlib under the family name
# "Times New Roman". The file is looked up at <parent of this file's
# directory>/font/times.ttf and the entry is inserted at the front of the font
# manager's list; the plotting functions below select this family through
# rcParams['font.family'].
font_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "font", "times.ttf")
font_prop = fm.FontEntry(fname=font_path, name="Times New Roman")
fm.fontManager.ttflist.insert(0, font_prop)

# Module-level logger; its handlers are configured in main() through
# logging.basicConfig.
logger = logging.getLogger(__name__)

# Minimum number of samples a per-gender cocolabel group needs before
# compute_accuracy considers it for the "lowest accuracy" record; also the
# default `min_samples` of plot_cocolabel_accuracies. (plot_acc applies its own
# hard-coded `> 10` filter instead of this constant.)
min_thred = 10

def set_seed(seed):
    """Seed every random number generator this script relies on.

    Python's ``random``, NumPy and PyTorch are always seeded. When CUDA is
    available, all CUDA devices are seeded too and cuDNN is switched to
    deterministic, non-benchmarking mode.

    Args:
        seed (int): Seed value applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-related to configure on CPU-only machines.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_parser_info():
    """Build the command-line interface of this script.

    Returns:
        argparse.ArgumentParser: Parser exposing the model, batch size,
        input/output directory, dataset, text mode, embedding method and
        CUDA device options.
    """
    cli = argparse.ArgumentParser(
        add_help=True,
        description="Compute prediction accuracy for CLIP model on Waterbirds dataset.",
    )

    # (flag, keyword arguments) pairs, registered in the order in which they
    # appear in the --help output.
    option_table = (
        ("--model", dict(default="ViT-B-32", type=str,
                         help="Name of the CLIP model to use (default: ViT-B-32)")),
        ("--batch_size", dict(default=500, type=int,
                              help="Batch size for processing")),
        ("--input_dir", dict(default="./results",
                             help="Path where input data is saved")),
        ("--dataset", dict(type=str, default="cocogbv1",
                           help="Dataset to process (default: cocogbv1)")),
        ("--output_dir", dict(default="./results/prs",
                              help="Path where output data is saved")),
        ("--text_mode", dict(default="openai",
                             help="Text mode: 'simple' or 'openai' (default: openai)")),
        ("--embedding_method", dict(default="clip_base",
                                    help="Method to load or generate embeddings (e.g., 'clip_base'). Default: clip_base")),
        ("--cuda_id", dict(type=str, default="0", help="cuda id")),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli

def get_image_embeddings(embedding_method, h5_file, start_idx, end_idx, device):
    """
    Obtain the image embeddings of one batch through a pluggable function.

    The function named ``<embedding_method>_embedding`` is looked up in the
    ``extract_image_embedding`` module and called with
    ``(h5_file, start_idx, end_idx, device, args)``.

    Note:
        ``args`` is NOT a parameter: it is read from module scope (it is bound
        in the ``__main__`` guard of this file), so it must exist before this
        function is called.

    Args:
        embedding_method (str): Prefix of the embedding function to use
            (e.g. 'clip_base' selects ``clip_base_embedding``).
        h5_file (h5py.File): Opened HDF5 file handed to the embedding function.
        start_idx (int): Starting index of the batch.
        end_idx (int): Ending index of the batch.
        device (torch.device): Device handed to the embedding function.

    Returns:
        Whatever the selected embedding function returns; callers in this
        file treat it as a torch.Tensor of image embeddings.

    Raises:
        ImportError: If ``extract_image_embedding`` cannot be imported. The
            original import error is attached as ``__cause__``.
        AttributeError: If the module defines no
            ``<embedding_method>_embedding`` attribute.
    """
    # Guard only the import itself, so that an error raised while the
    # embeddings are being computed is not misreported as an import failure.
    try:
        import extract_image_embedding
    except ImportError as e:
        raise ImportError(
            f"Failed to import embedding function for method '{embedding_method}': {e}"
        ) from e

    embedding_func = getattr(extract_image_embedding, f"{embedding_method}_embedding")
    logger.info(f"Computing in {embedding_method} method")
    return embedding_func(h5_file, start_idx, end_idx, device, args)

def compute_accuracy(dataset, input_dir, batch_size, device, text_embeddings, embedding_method):
    """
    Compute gender prediction accuracy with a per-cocolabel breakdown.

    For every batch the image embeddings are obtained through
    ``get_image_embeddings``, scored against ``text_embeddings`` and the
    arg-max class is compared with the gender label stored in column 0 of the
    ``labels_info`` dataset (0: female, 1: male). The remaining columns hold
    the sample's cocolabels, with ``-1`` marking unused slots. As a side
    effect, ``plot_acc`` is called with the per-cocolabel results.

    Note:
        * The data sub-folder name is built from the module-level
          ``args.model`` (bound in the ``__main__`` guard), not from a
          parameter.
        * A batch that raises is logged and skipped. Accuracies are computed
          over the samples that were actually processed, and a warning
          reports how many batches were skipped.

    Args:
        dataset (str): Name of the dataset to process ('cocogbv1' or 'cocogbv2').
        input_dir (str): Directory containing the ``<model>_<dataset>`` sub-folder.
        batch_size (int): Number of samples to process in each batch.
        device (torch.device): Device to load tensors to.
        text_embeddings (torch.Tensor): Pre-computed text embeddings, one row per class.
        embedding_method (str): Method used to obtain the image embeddings.

    Returns:
        dict: ``overall_accuracy``, ``female_accuracy``, ``male_accuracy``
        (percentages, 0 when there are no samples), ``cocolabel_accuracies``
        (per-cocolabel dict with accuracies and sample totals) and
        ``lowest_female`` / ``lowest_male`` as ``(cocolabel, accuracy)``
        tuples (``(None, 100.0)`` when no group reaches ``min_thred`` samples).

    Raises:
        ValueError: If ``dataset`` is not supported.
        FileNotFoundError: If the input sub-folder does not exist.
    """
    # Reject unsupported datasets up front instead of raising (and being
    # swallowed) once per batch inside the processing loop.
    if dataset not in ("cocogbv1", "cocogbv2"):
        raise ValueError(f"Dataset '{dataset}' not supported yet.")

    subfolder_name = f"{args.model}_{dataset}"
    input_dir = os.path.join(input_dir, subfolder_name)

    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory '{input_dir}' does not exist. Run 'extract_clip_info.py' first.")

    def _new_stats():
        return {'female_total': 0, 'female_correct': 0, 'male_total': 0, 'male_correct': 0}

    def _record(stats, gender, is_correct):
        """Count one sample of ``gender`` in ``stats``."""
        stats[f'{gender}_total'] += 1
        if is_correct:
            stats[f'{gender}_correct'] += 1

    def _percent(correct, total, empty):
        """Return ``correct / total`` as a percentage, or ``empty`` if total is 0."""
        return (correct / total * 100) if total > 0 else empty

    gender_names = {0: 'female', 1: 'male'}

    all_correct_counts = 0
    processed_samples = 0
    failed_batches = 0
    overall_stats = _new_stats()                # overall per-gender statistics
    cocolabel_stats = defaultdict(_new_stats)   # per-cocolabel statistics

    # Dataset statistics; only consumed by the optional diagnostic plot below.
    label_counts = defaultdict(int)
    cocolabel_counts = defaultdict(int)
    cooccurrence_counts = defaultdict(lambda: defaultdict(int))

    with h5py.File(os.path.join(input_dir, "data.h5"), 'r') as f:
        labels_info_dset = f['labels_info']
        total_samples = labels_info_dset.shape[0]
        logger.info(f"Total samples in {dataset}: {total_samples}")

        for start_idx in tqdm(range(0, total_samples, batch_size), desc=f"Processing {dataset}"):
            end_idx = min(start_idx + batch_size, total_samples)
            try:
                if embedding_method == "text_based_decomposition":
                    # This method reads its embeddings from a separate HDF5 file.
                    with h5py.File(f"./results/Text_Based_Decomposition/{subfolder_name}/data.h5", 'r') as decomposition_file:
                        image_embeddings_batch = get_image_embeddings(embedding_method, decomposition_file, start_idx, end_idx, device)
                else:
                    image_embeddings_batch = get_image_embeddings(embedding_method, f, start_idx, end_idx, device)

                # Read the label rows once; column 0 is the gender label, the
                # remaining columns are the multi-label cocolabels.
                batch_info = labels_info_dset[start_idx:end_idx]
                cocolabel_batch = batch_info[:, 1:]

                image_embeddings_batch = image_embeddings_batch.to(device)
                labels_batch = torch.tensor(batch_info[:, 0], dtype=torch.long).to(device)

                predictions = (100 * image_embeddings_batch @ text_embeddings.t()).argmax(dim=1)
                correct_counts = (predictions == labels_batch).sum().item()

                # One device-to-host transfer per tensor instead of one
                # .item() call per sample.
                labels_list = labels_batch.tolist()
                predictions_list = predictions.tolist()
            except Exception as e:
                logger.error(f"Error in batch {start_idx}-{end_idx}: {e}")
                failed_batches += 1
                continue

            all_correct_counts += correct_counts
            processed_samples += len(labels_list)

            for label, prediction, sample_cocolabels in zip(labels_list, predictions_list, cocolabel_batch):
                # Valid cocolabels of the sample (-1 marks an unused slot).
                valid_cocolabels = [int(c) for c in sample_cocolabels if c != -1]

                label_counts[label] += 1
                for cocolabel in valid_cocolabels:
                    cocolabel_counts[cocolabel] += 1
                    cooccurrence_counts[label][cocolabel] += 1

                gender = gender_names.get(label)
                if gender is None:
                    # Labels other than 0/1 do not enter the gender statistics.
                    continue

                is_correct = (prediction == label)
                _record(overall_stats, gender, is_correct)
                for cocolabel in valid_cocolabels:
                    _record(cocolabel_stats[cocolabel], gender, is_correct)

            # Clean up memory
            del image_embeddings_batch, labels_batch, predictions
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    if failed_batches:
        logger.warning(
            f"{failed_batches} batch(es) failed and were skipped; "
            f"accuracies cover {processed_samples} of {total_samples} samples."
        )

    # Optional diagnostics (disabled):
    # plot_label_cocolabel_bars(label_counts, cocolabel_counts, cooccurrence_counts)

    # Overall accuracies; 0 when there is nothing to divide by.
    overall_accuracy = _percent(all_correct_counts, processed_samples, 0)
    female_accuracy = _percent(overall_stats['female_correct'], overall_stats['female_total'], 0)
    male_accuracy = _percent(overall_stats['male_correct'], overall_stats['male_total'], 0)

    # Per-cocolabel accuracies and the worst sufficiently large group per gender.
    cocolabel_accuracies = {}
    lowest_female_acc = 100.0
    lowest_male_acc = 100.0
    lowest_female_cocolabel = None
    lowest_male_cocolabel = None

    for cocolabel, stats in cocolabel_stats.items():
        female_total_coco = stats['female_total']
        female_correct_coco = stats['female_correct']
        male_total_coco = stats['male_total']
        male_correct_coco = stats['male_correct']

        # None marks "no samples of this gender for this cocolabel".
        female_accuracy_coco = _percent(female_correct_coco, female_total_coco, None)
        male_accuracy_coco = _percent(male_correct_coco, male_total_coco, None)
        cocolabel_overall_accuracy = _percent(
            female_correct_coco + male_correct_coco,
            female_total_coco + male_total_coco,
            None,
        )

        # Only groups with at least min_thred samples can become the "lowest".
        if female_accuracy_coco is not None and female_accuracy_coco < lowest_female_acc and female_total_coco >= min_thred:
            lowest_female_acc = female_accuracy_coco
            lowest_female_cocolabel = cocolabel
        if male_accuracy_coco is not None and male_accuracy_coco < lowest_male_acc and male_total_coco >= min_thred:
            lowest_male_acc = male_accuracy_coco
            lowest_male_cocolabel = cocolabel

        cocolabel_accuracies[cocolabel] = {
            'female_accuracy': female_accuracy_coco,
            'male_accuracy': male_accuracy_coco,
            'overall_accuracy': cocolabel_overall_accuracy,
            'female_total': female_total_coco,
            'male_total': male_total_coco
        }

    plot_acc(cocolabel_accuracies)

    return {
        'overall_accuracy': overall_accuracy,
        'female_accuracy': female_accuracy,
        'male_accuracy': male_accuracy,
        'cocolabel_accuracies': cocolabel_accuracies,
        'lowest_female': (lowest_female_cocolabel, lowest_female_acc),
        'lowest_male': (lowest_male_cocolabel, lowest_male_acc)
    }

def plot_acc(cocolabel_accuracies):
    """
    Plot female vs. male recognition accuracy per object category.

    Only categories with more than 10 female AND more than 10 male samples
    are drawn, and a fixed list of category names is excluded. The chart is
    written to ``acc_compare_new.pdf`` and ``acc_compare_new.png`` in the
    current working directory. Nothing is written when no category passes
    the filter. Global matplotlib rcParams are modified.

    Args:
        cocolabel_accuracies (dict): Maps a COCO category ID to a dict with the
            keys ``female_accuracy``, ``male_accuracy`` (percent or None),
            ``female_total`` and ``male_total``.
    """
    # COCO category ID to name mapping
    label_name = {
        1: "person", 2: "bicycle", 3: "car", 4: "motorcycle",
        5: "airplane", 6: "bus", 7: "train", 8: "truck",
        9: "boat", 10: "traffic light", 11: "fire hydrant", 13: "stop sign",
        14: "parking meter", 15: "bench", 16: "bird", 17: "cat",
        18: "dog", 19: "horse", 20: "sheep", 21: "cow",
        22: "elephant", 23: "bear", 24: "zebra", 25: "giraffe",
        27: "backpack", 28: "umbrella", 31: "handbag", 32: "tie",
        33: "suitcase", 34: "frisbee", 35: "skis", 36: "snowboard",
        37: "sports ball", 38: "kite", 39: "baseball bat", 40: "baseball glove",
        41: "skateboard", 42: "surfboard", 43: "tennis racket", 44: "bottle",
        46: "wine glass", 47: "cup", 48: "fork", 49: "knife",
        50: "spoon", 51: "bowl", 52: "banana", 53: "apple",
        54: "sandwich", 55: "orange", 56: "broccoli", 57: "carrot",
        58: "hot dog", 59: "pizza", 60: "donut", 61: "cake",
        62: "chair", 63: "couch", 64: "potted plant", 65: "bed",
        67: "dining table", 70: "toilet", 72: "tv", 73: "laptop",
        74: "mouse", 75: "remote", 76: "keyboard", 77: "cell phone",
        78: "microwave", 79: "oven", 80: "toaster", 81: "sink",
        82: "refrigerator", 84: "book", 85: "clock", 86: "vase",
        87: "scissors", 88: "teddy bear", 89: "hair drier", 90: "toothbrush",
    }

    # Categories that are never drawn, regardless of their sample counts.
    filter_list = {
        "person", "wine glass", "handbag", "suitcase", "umbrella", "horse",
        "bench", "tennis racket", "dog", "bottle", "fork", "tv", "car",
        "knife", "bed", "skis", "potted plant",
    }

    def category_name(label):
        # Unknown IDs get a placeholder instead of raising KeyError.
        return label_name.get(int(label), f"Unknown({label})")

    # Keep categories with MORE than 10 samples for both genders.
    labels = [
        label for label, stats in cocolabel_accuracies.items()
        if stats["female_total"] > 10
        and stats["male_total"] > 10
        and category_name(label) not in filter_list
    ]
    if not labels:
        logger.warning("No categories with enough samples to plot; accuracy comparison chart skipped.")
        return

    # Prepare data; a missing accuracy is drawn as 0.
    label_name_coco = [category_name(label) for label in labels]
    y0 = []  # Female accuracy
    y1 = []  # Male accuracy
    for label in labels:
        female_acc = cocolabel_accuracies[label]["female_accuracy"]
        male_acc = cocolabel_accuracies[label]["male_accuracy"]
        y0.append(female_acc if female_acc is not None else 0)
        y1.append(male_acc if male_acc is not None else 0)

    # Set global style
    plt.style.use('default')
    mpl.rcParams['font.family'] = 'Times New Roman'
    mpl.rcParams['font.size'] = 14
    mpl.rcParams['axes.labelsize'] = 16
    mpl.rcParams['xtick.labelsize'] = 14
    mpl.rcParams['ytick.labelsize'] = 14
    mpl.rcParams['legend.fontsize'] = 13
    mpl.rcParams['axes.grid'] = False
    mpl.rcParams['axes.spines.top'] = False
    mpl.rcParams['axes.spines.right'] = False

    # Soft coral (female) and soft blue (male)
    colors = ['#FFC3A0', '#A6C8E8']

    # Create chart; the figure widens with the number of categories.
    x = np.arange(len(labels))
    width = 0.35
    fig, ax = plt.subplots(figsize=(max(12, len(labels) * 0.6), 6), dpi=800)

    # Grouped bars: female left of the tick, male right of it.
    ax.bar(x - width/2, y0, width, label='Female', color=colors[0])
    ax.bar(x + width/2, y1, width, label='Male', color=colors[1])

    ax.set_xlabel('Context Categories', fontsize=27, fontweight='bold')
    ax.set_ylabel('Accuracy (%)', fontsize=27, fontweight='bold')

    ax.set_xticks(x)
    ax.set_xticklabels(label_name_coco, rotation=45, ha='right', fontsize=22)

    # The y-axis starts at 50 %, so lower bars are cut off at the axis.
    ax.set_ylim(50, 100)
    ax.yaxis.set_major_locator(plt.MultipleLocator(10))
    ax.tick_params(axis='y', labelsize=18)

    # Axis line thickness (matplotlib line widths are given in points).
    for spine in ax.spines.values():
        spine.set_linewidth(1.5)

    ax.grid(axis='y', linestyle='--', alpha=0.3)

    # Legend above the axes, horizontally centred.
    ax.legend(
        ncol=2,
        bbox_to_anchor=(0.5, 1.02),
        loc='lower center',
        frameon=False,
        columnspacing=1,
        fontsize=25,
        labelspacing=1,
        handlelength=2,
        handletextpad=1,
    )

    fig.subplots_adjust(top=0.85, bottom=0.2)

    # Save figure
    fig.savefig('acc_compare_new.pdf', bbox_inches='tight', dpi=800, transparent=True)
    fig.savefig('acc_compare_new.png', bbox_inches='tight', dpi=800)
    plt.close(fig)

    logger.info("Recognition accuracy comparison plots saved as acc_compare_new.pdf and acc_compare_new.png")

def plot_label_cocolabel_bars(label_counts, cocolabel_counts, cooccurrence_counts):
    """
    Plot how strongly each object category co-occurs with each gender label.

    For every category a bar is drawn per gender whose height is
    ``count(gender, category) * n / (count(gender) * count(category))``, with
    ``n`` the total number of gender labels. NOTE: this is the ratio inside
    the logarithm of pointwise mutual information; no logarithm is taken, so
    1.0 (not 0.0) means "independent". The y-axis is nevertheless labelled
    'PMI'.

    The chart is written to ``cooccurrence_bar.pdf`` and
    ``cooccurrence_bar.png`` in the current working directory. Global
    matplotlib rcParams are modified. The input dicts are not mutated.

    Args:
        label_counts (dict): Counts of gender labels (0: female, 1: male)
        cocolabel_counts (dict): Counts of each object category
        cooccurrence_counts (dict): ``{gender label: {category: count}}``
    """
    # Total number of gender labels.
    n = sum(int(count) for count in label_counts.values())

    def compute_PMI(x_z, x, z):
        """Return ``p(x, z) / (p(x) * p(z))`` from raw counts (no logarithm).

        Args:
            x_z: Joint count of x and z
            x: Count of x
            z: Count of z

        Returns:
            float: The ratio, or 0.0 when either marginal count is 0.
        """
        if not x or not z:
            return 0.0
        return float(x_z) * n / (x * z)

    # Get all unique cocolabels (sorted)
    cocolabels = sorted(cocolabel_counts.keys())
    if not cocolabels:
        logger.warning("No cocolabels to plot; co-occurrence chart skipped.")
        return

    def ratios_for(gender_label):
        # .get avoids KeyError on plain dicts and key insertion on defaultdicts.
        joint_counts = cooccurrence_counts.get(gender_label, {})
        gender_count = label_counts.get(gender_label, 0)
        return [
            compute_PMI(x_z=joint_counts[cocolabel], x=gender_count, z=cocolabel_counts[cocolabel])
            if cocolabel in joint_counts else 0  # no co-occurrence: fill with 0
            for cocolabel in cocolabels
        ]

    y0 = ratios_for(0)  # values for label=0 (female)
    y1 = ratios_for(1)  # values for label=1 (male)

    x = np.arange(len(cocolabels))
    width = 0.35

    # COCO category ID to name mapping
    label_name = {
        1: "person", 2: "bicycle", 3: "car", 4: "motorcycle",
        5: "airplane", 6: "bus", 7: "train", 8: "truck",
        9: "boat", 10: "traffic light", 11: "fire hydrant", 13: "stop sign",
        14: "parking meter", 15: "bench", 16: "bird", 17: "cat",
        18: "dog", 19: "horse", 20: "sheep", 21: "cow",
        22: "elephant", 23: "bear", 24: "zebra", 25: "giraffe",
        27: "backpack", 28: "umbrella", 31: "handbag", 32: "tie",
        33: "suitcase", 34: "frisbee", 35: "skis", 36: "snowboard",
        37: "sports ball", 38: "kite", 39: "baseball bat", 40: "baseball glove",
        41: "skateboard", 42: "surfboard", 43: "tennis racket", 44: "bottle",
        46: "wine glass", 47: "cup", 48: "fork", 49: "knife",
        50: "spoon", 51: "bowl", 52: "banana", 53: "apple",
        54: "sandwich", 55: "orange", 56: "broccoli", 57: "carrot",
        58: "hot dog", 59: "pizza", 60: "donut", 61: "cake",
        62: "chair", 63: "couch", 64: "potted plant", 65: "bed",
        67: "dining table", 70: "toilet", 72: "tv", 73: "laptop",
        74: "mouse", 75: "remote", 76: "keyboard", 77: "cell phone",
        78: "microwave", 79: "oven", 80: "toaster", 81: "sink",
        82: "refrigerator", 84: "book", 85: "clock", 86: "vase",
        87: "scissors", 88: "teddy bear", 89: "hair drier", 90: "toothbrush",
    }

    # Set global style
    plt.style.use('default')
    mpl.rcParams['font.family'] = 'Times New Roman'
    mpl.rcParams['font.size'] = 14
    mpl.rcParams['axes.labelsize'] = 16
    mpl.rcParams['xtick.labelsize'] = 14
    mpl.rcParams['ytick.labelsize'] = 14
    mpl.rcParams['legend.fontsize'] = 13
    mpl.rcParams['axes.grid'] = False
    mpl.rcParams['axes.spines.top'] = False
    mpl.rcParams['axes.spines.right'] = False

    # Soft blue (female) and soft coral (male)
    colors = ['#A6C8E8', '#FFC3A0']

    label_name_coco = [label_name.get(cocolabel, f"Unknown({cocolabel})") for cocolabel in cocolabels]

    # The figure widens with the number of categories.
    fig, ax = plt.subplots(figsize=(max(12, len(cocolabels) * 0.6), 6), dpi=600)

    # Grouped bars: female left of the tick, male right of it.
    ax.bar(x - width/2, y0, width, label='Female', color=colors[0])
    ax.bar(x + width/2, y1, width, label='Male', color=colors[1])

    ax.set_xlabel('Object Categories', fontsize=16, fontweight='bold')
    ax.set_ylabel('PMI', fontsize=16, fontweight='bold')

    ax.set_xticks(x)
    ax.set_xticklabels(label_name_coco, rotation=45, ha='right', fontsize=14)

    # Legend above the axes, horizontally centred.
    ax.legend(ncol=2, bbox_to_anchor=(0.5, 1.02), loc='lower center',
              frameon=False, columnspacing=1)

    # One y tick every 0.2; adjust if the data range differs a lot.
    ax.yaxis.set_major_locator(plt.MultipleLocator(0.2))

    ax.grid(axis='y', linestyle='--', alpha=0.3)

    fig.subplots_adjust(top=0.85, bottom=0.2)

    # Save figure, then release it (the other plot functions close theirs too).
    fig.savefig('cooccurrence_bar.pdf', bbox_inches='tight', dpi=600, transparent=True)
    fig.savefig('cooccurrence_bar.png', bbox_inches='tight', dpi=600)
    plt.close(fig)

def plot_cocolabel_accuracies(cocolabel_accuracies, output_path, min_samples=min_thred):
    """
    Plot a bar chart of female and male accuracies for each cocolabel.

    A cocolabel is drawn when its female OR its male sample count reaches
    ``min_samples``. Bars are ordered by ascending cocolabel ID and the tick
    labels are the raw IDs. Two files are written: ``output_path`` itself and
    a PDF next to it with the suffix replaced by ``.pdf``. Global matplotlib
    rcParams are modified when a chart is drawn.

    Args:
        cocolabel_accuracies (dict): Maps a cocolabel to a dict with the keys
            ``female_accuracy``, ``male_accuracy`` (percent or None),
            ``female_total`` and ``male_total``.
        output_path (str | Path): Path of the image file to save.
        min_samples (int): Minimum sample count to include a cocolabel.
    """
    # Filter out cocolabels with too few samples
    filtered_data = {
        cocolabel: data for cocolabel, data in cocolabel_accuracies.items()
        if (data['female_total'] >= min_samples or data['male_total'] >= min_samples)
    }

    if not filtered_data:
        logger.warning(f"No cocolabels with enough samples (>= {min_samples}) to plot.")
        return

    # Build the bar heights in the SAME (sorted) order as the tick labels;
    # iterating the dict values directly would follow insertion order and
    # could attach bars to the wrong cocolabel. A missing accuracy is drawn
    # as 0.
    cocolabels = sorted(filtered_data)

    def heights(key):
        return [
            filtered_data[cocolabel][key] if filtered_data[cocolabel][key] is not None else 0
            for cocolabel in cocolabels
        ]

    female_accs = heights('female_accuracy')
    male_accs = heights('male_accuracy')

    # Set global style
    plt.style.use('default')
    mpl.rcParams['font.family'] = 'Times New Roman'
    mpl.rcParams['font.size'] = 14
    mpl.rcParams['axes.labelsize'] = 16
    mpl.rcParams['xtick.labelsize'] = 14
    mpl.rcParams['ytick.labelsize'] = 14
    mpl.rcParams['legend.fontsize'] = 13
    mpl.rcParams['axes.grid'] = False
    mpl.rcParams['axes.spines.top'] = False
    mpl.rcParams['axes.spines.right'] = False

    # Soft blue (female) and soft coral (male)
    colors = ['#A6C8E8', '#FFC3A0']

    # The figure widens with the number of cocolabels.
    fig, ax = plt.subplots(figsize=(max(12, len(cocolabels) * 0.6), 6), dpi=600)
    x = np.arange(len(cocolabels))
    width = 0.35

    # Grouped bars: female left of the tick, male right of it.
    ax.bar(x - width/2, female_accs, width, label='Female', color=colors[0])
    ax.bar(x + width/2, male_accs, width, label='Male', color=colors[1])

    ax.set_xlabel('Object Categories', fontsize=16, fontweight='bold')
    ax.set_ylabel('Accuracy (%)', fontsize=16, fontweight='bold')

    ax.set_xticks(x)
    ax.set_xticklabels(cocolabels, rotation=45, ha='right', fontsize=14)

    # Legend above the axes, horizontally centred.
    ax.legend(ncol=2, bbox_to_anchor=(0.5, 1.02), loc='lower center',
              frameon=False, columnspacing=1)

    ax.set_ylim(0, 100)  # Accuracy range is 0-100%
    ax.yaxis.set_major_locator(plt.MultipleLocator(20))  # One tick every 20%

    ax.grid(axis='y', linestyle='--', alpha=0.3)

    fig.subplots_adjust(top=0.85, bottom=0.2)

    # Derive the PDF name from the suffix only, so that a '.png' elsewhere in
    # the path is left alone and a non-.png output still gets its own PDF.
    pdf_path = Path(output_path).with_suffix('.pdf')
    fig.savefig(pdf_path, bbox_inches='tight', dpi=600, transparent=True)
    fig.savefig(output_path, bbox_inches='tight', dpi=600)
    plt.close(fig)

    logger.info(f"Bar charts saved to {output_path} and {pdf_path}")

def main(args):
    """
    Run the accuracy evaluation end to end and log the results.

    Creates the output directory, verifies that the input sub-folder exists,
    sets up logging to a file and to the console, loads the text embeddings
    for the chosen text mode, runs ``compute_accuracy`` and logs the overall,
    per-gender and worst-group accuracies. Any exception is logged and
    swallowed.

    Args:
        args: Parsed command-line namespace (see ``get_parser_info``).
    """
    try:
        run_name = f"{args.model}_{args.dataset}"
        data_dir = Path(args.input_dir) / run_name
        results_dir = Path(args.output_dir)
        results_dir.mkdir(parents=True, exist_ok=True)

        if not data_dir.exists():
            raise FileNotFoundError(f"Input directory '{data_dir}' does not exist. Run 'extract_clip_info.py' first.")

        # Log to a per-run file (overwritten on every run) and to the console.
        log_handlers = [
            logging.FileHandler(results_dir / f"Console_Info_{run_name}.log", mode='w'),
            logging.StreamHandler(),
        ]
        logging.basicConfig(
            handlers=log_handlers,
            format="%(asctime)s - %(levelname)s - %(message)s",
            level=logging.INFO,
        )

        # Use the requested GPU when CUDA is available, else the CPU.
        device_name = f"cuda:{args.cuda_id}" if torch.cuda.is_available() else "cpu"
        device = torch.device(device_name)

        # Load the pre-computed text embeddings of the selected text mode.
        if args.text_mode not in ("simple", "openai"):
            raise ValueError(f"Text mode '{args.text_mode}' not supported.")
        logger.info(f"Text mode: {args.text_mode}")
        with h5py.File(data_dir / f"{args.text_mode}_text.h5", 'r') as text_h5:
            raw_text_embeddings = text_h5['text_embeddings'][:]
            text_embeddings = torch.tensor(raw_text_embeddings, dtype=torch.float32).to(device)

        # Compute accuracy
        results = compute_accuracy(
            dataset=args.dataset,
            input_dir=args.input_dir,
            batch_size=args.batch_size,
            device=device,
            text_embeddings=text_embeddings,
            embedding_method=args.embedding_method,
        )

        # Headline numbers
        logger.info(f"Dataset: {args.dataset}")
        for title, key in (
            ("Overall", 'overall_accuracy'),
            ("Female", 'female_accuracy'),
            ("Male", 'male_accuracy'),
        ):
            logger.info(f"{title} accuracy: {results[key]:.2f}%")

        # Worst group per gender, as (cocolabel, accuracy) pairs.
        worst_female_label, worst_female_acc = results['lowest_female']
        worst_male_label, worst_male_acc = results['lowest_male']
        logger.info(f"Lowest female accuracy: {worst_female_acc:.2f}% (cocolabel: {worst_female_label})")
        logger.info(f"Lowest male accuracy: {worst_male_acc:.2f}% (cocolabel: {worst_male_label})")

        # Disabled: CSV export of the per-cocolabel results.
        # csv_file = results_dir / f"cocolabel_accuracies_{run_name}_{args.text_mode}.csv"
        # with open(csv_file, 'w') as f:
        #     f.write("Cocolabel,Female_Accuracy,Male_Accuracy,Overall_Accuracy,Female_Total,Male_Total\n")
        #     for cocolabel, acc in sorted(results['cocolabel_accuracies'].items()):
        #         female_acc = f"{acc['female_accuracy']}" if acc['female_accuracy'] is not None else ""
        #         male_acc = f"{acc['male_accuracy']}" if acc['male_accuracy'] is not None else ""
        #         overall_acc = f"{acc['overall_accuracy']}" if acc['overall_accuracy'] is not None else ""
        #         f.write(f"{cocolabel},{female_acc},{male_acc},{overall_acc},{acc['female_total']},{acc['male_total']}\n")
        # logger.info(f"Results saved to {csv_file}")

        # Disabled: per-cocolabel bar chart.
        # chart_file = results_dir / f"cocolabel_chart_{run_name}_{args.text_mode}.png"
        # plot_cocolabel_accuracies(results['cocolabel_accuracies'], chart_file)

    except Exception as e:
        logger.exception(f"Execution failed: {e}")

if __name__ == "__main__":
    # Script entry point: parse the CLI, fix the random seeds, run the evaluation.
    parser = get_parser_info()
    # NOTE: `args` is bound at module level here and is read as a global by
    # compute_accuracy (args.model) and get_image_embeddings (passed on to the
    # embedding function), in addition to being passed to main(). Renaming it,
    # or calling those functions without this guard having run, raises
    # NameError inside them.
    args = parser.parse_args()
    # Fixed seed for reproducibility.
    set_seed(42)
    main(args)