#!/usr/bin/env python3

import glob
import os
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Apply seaborn's "whitegrid" theme once, at import time, so the figure drawn
# by plot_comparison picks it up.
sns.set_style("whitegrid")

# Short display labels for the target species, keyed by the species id used in
# the prediction file names (see load_prediction_data). Insertion order is the
# order in which species-conditioned result sets are loaded and tabulated.
species_name = {
    "acinetobacterbaumannii": "A. Baumannii",
    "escherichiacoli": "E. Coli",
    "klebsiellapneumoniae": "K. Pneumoniae",
    "pseudomonasaeruginosa": "P. Aeruginosa",
    "staphylococcusaureus": "S. Aureus",
}

# Display labels for the classifier columns of the prediction CSVs. The broad
# ("General") classifier comes first; every species classifier column is named
# "species-<species id>-classifier" and reuses that species' display label, so
# it is derived from species_name rather than spelled out a second time.
classifier_names = {
    "broad-classifier": "General",
    **{
        f"species-{species_id}-classifier": label
        for species_id, label in species_name.items()
    },
}

def drop_unwanted_columns(df):
    """Return a copy of ``df`` without its identifier/metadata columns.

    Drops the ``Id`` and ``Sequence`` columns and every column whose name
    starts with ``"strain"``; all remaining columns keep their original order.
    ``df`` itself is left unmodified.
    """
    def _is_metadata(name):
        # Checked in this order so a non-string column name fails loudly.
        return name.startswith("strain") or name in ("Id", "Sequence")

    metadata_columns = list(filter(_is_metadata, df.columns))
    return df.drop(columns=metadata_columns, errors='ignore')

def load_prediction_data(results_dir="experiments/results/files/subset-conditioning"):
    """Read every available prediction CSV found in ``results_dir``.

    Looks for ``all-amps-predictions.csv`` (stored under ``"All AMPs"``),
    ``hq-amps-predictions.csv`` (stored under ``"HQ AMPs"``) and one
    ``<species_id>-predictions.csv`` per key of ``species_name`` (stored under
    that species id). Files that do not exist are skipped without notice.

    Returns:
        dict mapping each result-set key to the DataFrame read from its CSV.
    """
    # (key in the returned dict, CSV file name inside results_dir)
    candidates = [
        ("All AMPs", "all-amps-predictions.csv"),
        ("HQ AMPs", "hq-amps-predictions.csv"),
    ]
    candidates.extend(
        (species_id, f"{species_id}-predictions.csv") for species_id in species_name
    )

    loaded = {}
    for key, filename in candidates:
        csv_path = os.path.join(results_dir, filename)
        if not os.path.exists(csv_path):
            continue
        loaded[key] = pd.read_csv(csv_path)
    return loaded

def create_comparison_dataframe(all_results):
    """Build the classifier-by-generation-mode table of predicted-positive rates.

    For every recognised result set in ``all_results`` the identifier columns
    are removed (``drop_unwanted_columns``), each classifier score is rounded
    to the nearest integer and the per-column mean is taken. Assuming the
    scores are probabilities in [0, 1] (TODO confirm against the prediction
    CSVs), each cell is the fraction of sequences that the classifier labels
    positive. pandas rounds half to even, so a score of exactly 0.5 counts as 0.

    Args:
        all_results: mapping as produced by ``load_prediction_data``. The keys
            ``"All AMPs"``, ``"HQ AMPs"`` and the keys of ``species_name`` are
            used; any other key is ignored.

    Returns:
        DataFrame indexed by classifier display name, in ``classifier_names``
        order, with one column per available result set (``"Unconditional"``,
        ``"SC HQ AMPs"``, ``"SC <species>"``). A cell is NaN when that result
        set has no column for the classifier. A classifier found in no result
        set is omitted with a warning.

    Raises:
        ValueError: if ``all_results`` contains none of the recognised result
            sets, or none of them has a known classifier column.
    """
    # (key in all_results, column label in the comparison table), in plot order.
    sources = [("All AMPs", "Unconditional"), ("HQ AMPs", "SC HQ AMPs")]
    sources += [
        (species_id, f"SC {display_name}")
        for species_id, display_name in species_name.items()
    ]

    per_mode_means = {
        label: drop_unwanted_columns(all_results[key]).round().mean()
        for key, label in sources
        if key in all_results
    }
    if not per_mode_means:
        raise ValueError("No prediction results available to compare")

    # Constructing from a dict aligns the Series on the union of their indexes,
    # so a classifier reported by only some result sets keeps its row (NaN for
    # the others). Column-wise assignment would instead silently drop labels
    # absent from the first Series' index.
    comparison_df = pd.DataFrame(per_mode_means)

    # Map the raw classifier column ids to their display names
    comparison_df.index = [classifier_names.get(idx, idx) for idx in comparison_df.index]

    # Order rows as in classifier_names; rows for unrecognised columns are dropped.
    present = set(comparison_df.index)
    desired_order = [name for name in classifier_names.values() if name in present]
    if not desired_order:
        raise ValueError("None of the known classifiers appear in the prediction results")

    missing = [name for name in classifier_names.values() if name not in present]
    if missing:
        warnings.warn(
            f"No predictions found for classifier(s): {', '.join(missing)}; "
            "omitting them from the comparison.",
            stacklevel=2,
        )

    return comparison_df.loc[desired_order]

def plot_comparison(
    comparison_df,
    output_path="experiments/results/plots/subset-conditioning-experiment.svg",
    *,
    show=True,
):
    """Draw ``comparison_df`` as a grouped bar chart and save it as SVG.

    Each row becomes a group on the x-axis and each column a bar (with a legend
    entry) inside every group.

    Args:
        comparison_df: table such as the one returned by
            ``create_comparison_dataframe`` (rows: classifiers, columns:
            generation modes).
        output_path: destination file. It is always written in SVG format,
            whatever its extension. Missing parent directories are created.
        show: if True (default), display the figure with ``plt.show()``. If
            False, close the figure after saving, e.g. for headless batch runs.
    """
    ax = comparison_df.plot(kind='bar', figsize=(15, 6), width=0.8)
    fig = ax.get_figure()

    # Adjust font sizes (pyplot calls act on ``ax``, the current axes)
    plt.xlabel('Classifiers', fontsize=20)
    plt.ylabel('Predicted Positives', fontsize=20)
    plt.xticks(rotation=0, ha='center', fontsize=14)
    plt.yticks(fontsize=20)

    # Leave 10% headroom above the tallest bar so the legend fits.
    # DataFrame.max() skips NaN cells, whereas ndarray.max() would return NaN
    # and make ylim() raise. Fall back to a unit-height axis when nothing is
    # positive, because identical limits (0, 0) are degenerate.
    y_max = comparison_df.max().max()
    top = y_max * 1.1 if pd.notna(y_max) and y_max > 0 else 1.0
    plt.ylim(0, top)

    # Adjust legend position
    plt.legend(
        title='Generative Modes',
        fontsize=12,
        title_fontsize=12,
        loc='upper right',
        ncol=2
    )

    # Layout must be computed before saving, or the written file won't reflect it.
    fig.tight_layout()

    # os.makedirs('') raises, so only create a directory when the path has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig.savefig(output_path, format="svg", bbox_inches="tight")

    if show:
        plt.show()
    else:
        plt.close(fig)

def main():
    """Run the pipeline: load the prediction CSVs, tabulate them, plot the table.

    All three steps use their default paths.
    """
    plot_comparison(create_comparison_dataframe(load_prediction_data()))

# Run the load -> tabulate -> plot pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()