#!/usr/bin/env python3

import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import numpy as np

def main():
    """Apply the shared plot styling, then write the summary table and boxplots.

    Both figures read from the same results directory and are saved as SVG
    files into the plots directory, which is created if it does not exist.
    """
    # Global look shared by every figure produced below.
    sns.set_style("whitegrid")
    font_sizes = {
        key: 14
        for key in (
            'font.size',
            'axes.labelsize',
            'xtick.labelsize',
            'ytick.labelsize',
            'legend.fontsize',
        )
    }
    font_sizes['axes.titlesize'] = 16
    plt.rcParams.update(font_sizes)

    input_dir = "experiments/results/files/amp-data-analysis"
    output_dir = "experiments/results/plots/amp-data-analysis"
    os.makedirs(output_dir, exist_ok=True)

    # Table first, then the per-property boxplot figure.
    for render in (create_summary_table, create_property_boxplots):
        render(input_dir, output_dir)

    print(f"Summary table and boxplots saved to {output_dir}")

def create_summary_table(input_dir, output_dir):
    """Render a table of per-dataset ``mean ± std`` for each sequence property.

    Reads every ``*-summary.csv`` file in ``input_dir``. The dataset name is
    the file name with the ``-summary.csv`` suffix removed. Each file must
    provide at least the ``property``, ``mean`` and ``std`` columns, which
    are the only ones read here. Rows whose ``property`` is ``length`` are
    excluded. The table is saved to
    ``<output_dir>/amp_data_summary_table.svg``.

    Rows follow the order used by the boxplot figure (AMPs, non-AMPs,
    shuffled, mutated, random). Any other dataset is listed after them,
    alphabetically by display name. A property for which no statistic is
    available is shown as ``n/a``.

    Args:
        input_dir: Directory containing the ``*-summary.csv`` files.
        output_dir: Directory the SVG is written to (must already exist).

    Returns:
        None. Prints a message and returns early when no summary file exists.
    """
    # Sorted so the processing order does not depend on the filesystem.
    summary_files = sorted(glob.glob(os.path.join(input_dir, "*-summary.csv")))

    if not summary_files:
        print(f"No summary files found in {input_dir}")
        return

    # File-name stem -> display label. Both the "*-sequences" and the
    # "*-AMP-sequences" spellings are accepted so this table labels datasets
    # exactly like the boxplot figure does.
    dataset_display_names = {
        "shuffled-sequences": "shuffled",
        "shuffled-AMP-sequences": "shuffled",
        "mutated-sequences": "mutated",
        "mutated-AMP-sequences": "mutated",
        "random-sequences": "random",
        "curated-AMPs": "AMPs",
        "curated-Non-AMPs": "non-AMPs",
    }
    preferred_order = ["AMPs", "non-AMPs", "shuffled", "mutated", "random"]

    properties = ['charge', 'hydrophobicity', 'fitness_score', 'pseudo_perplexity']
    columns = ['Dataset', 'Charge', 'Hydrophobicity', 'Fitness Score', 'Pseudo Perplexity']

    all_summaries = []
    for file_path in summary_files:
        dataset_name = os.path.basename(file_path).replace("-summary.csv", "")
        df = pd.read_csv(file_path)
        # Copy the filtered frame before adding a column; assigning to the
        # filtered slice triggers pandas' SettingWithCopyWarning.
        df = df[df['property'] != 'length'].copy()
        df['dataset'] = dataset_name
        all_summaries.append(df)

    combined_summary = pd.concat(all_summaries, ignore_index=True)

    # One row per dataset; columns become a (statistic, property) MultiIndex.
    pivot_df = combined_summary.pivot(index='dataset', columns='property')

    def order_key(dataset):
        # Preferred datasets first (in preferred order), others alphabetically.
        label = dataset_display_names.get(dataset, dataset)
        if label in preferred_order:
            return (0, preferred_order.index(label), label)
        return (1, 0, label)

    def format_cell(dataset, prop):
        # A property (or statistic) absent from every file has no pivot column.
        try:
            mean = pivot_df.loc[dataset, ('mean', prop)]
            std = pivot_df.loc[dataset, ('std', prop)]
        except KeyError:
            return "n/a"
        return f"{mean:.2f} ± {std:.2f}"

    table_data = [
        [dataset_display_names.get(dataset, dataset)]
        + [format_cell(dataset, prop) for prop in properties]
        for dataset in sorted(pivot_df.index, key=order_key)
    ]

    fig, ax = plt.subplots(figsize=(14, len(table_data) + 2))

    # The figure only hosts the table, so hide the axes themselves.
    ax.axis('off')
    ax.axis('tight')

    table = ax.table(
        cellText=table_data,
        colLabels=columns,
        loc='center',
        cellLoc='center',
        colColours=['#f2f2f2'] * len(columns)
    )

    table.auto_set_font_size(False)
    table.set_fontsize(14)
    table.scale(1.5, 1.5)

    # Wider first column for the dataset labels.
    for (_, col), cell in table.get_celld().items():
        cell.set_width(0.3 if col == 0 else 0.2)

    fig.tight_layout()

    output_path = os.path.join(output_dir, "amp_data_summary_table.svg")
    fig.savefig(output_path, format="svg", bbox_inches="tight")
    plt.close(fig)

def create_property_boxplots(input_dir, output_dir):
    """Draw one boxplot panel per sequence property, comparing all datasets.

    Reads every ``*-properties.csv`` file in ``input_dir``. The dataset name
    is the file name with the ``-properties.csv`` suffix removed. The
    property columns ``charge``, ``hydrophobicity``, ``fitness_score`` and
    ``pseudo_perplexity`` are read by name, and any that are missing from
    every file are skipped. Panels are stacked vertically with a shared
    x-axis. Outlier markers are hidden, and each panel's y-range is clipped
    to the 0.1st-99.9th percentile of the pooled data plus 10% padding. The
    figure is saved to ``<output_dir>/amp_data_all_properties_boxplot.svg``.

    Args:
        input_dir: Directory containing the ``*-properties.csv`` files.
        output_dir: Directory the SVG is written to (must already exist).

    Returns:
        None. Prints a message and returns early when no data file, or none
        of the expected property columns, is found.
    """
    # Sorted so the order of non-preferred datasets is deterministic.
    data_files = sorted(glob.glob(os.path.join(input_dir, "*-properties.csv")))

    if not data_files:
        print(f"No data files found in {input_dir}")
        return

    # File-name stem -> display label. Both the "*-sequences" and the
    # "*-AMP-sequences" spellings are accepted so the boxplots label datasets
    # exactly like the summary table does.
    dataset_display_names = {
        "shuffled-sequences": "shuffled",
        "shuffled-AMP-sequences": "shuffled",
        "mutated-sequences": "mutated",
        "mutated-AMP-sequences": "mutated",
        "random-sequences": "random",
        "curated-AMPs": "AMPs",
        "curated-Non-AMPs": "non-AMPs",
    }

    all_data = []
    for file_path in data_files:
        dataset_name = os.path.basename(file_path).replace("-properties.csv", "")
        df = pd.read_csv(file_path)
        df['dataset'] = dataset_name
        df['display_dataset'] = dataset_display_names.get(dataset_name, dataset_name)
        all_data.append(df)

    combined_data = pd.concat(all_data, ignore_index=True)

    # Print unique dataset values to help debug name mapping.
    available_datasets = combined_data['display_dataset'].unique()
    print("Available datasets:", available_datasets)

    # Preferred datasets first (only those present), then any others.
    preferred_order = ["AMPs", "non-AMPs", "shuffled", "mutated", "random"]
    dataset_order = [d for d in preferred_order if d in available_datasets]
    dataset_order.extend([d for d in available_datasets if d not in dataset_order])

    print("Using dataset order:", dataset_order)

    requested_properties = ['charge', 'hydrophobicity', 'fitness_score', 'pseudo_perplexity']
    properties = [p for p in requested_properties if p in combined_data.columns]
    missing_properties = [p for p in requested_properties if p not in properties]
    if missing_properties:
        print(f"Skipping properties missing from the data: {missing_properties}")
    if not properties:
        print(f"No plottable properties found in {input_dir}")
        return

    # squeeze=False keeps `axes` 2-D even when only one property is plotted.
    fig, axes = plt.subplots(
        len(properties), 1, figsize=(12, 4 * len(properties)), sharex=True, squeeze=False
    )
    axes = axes[:, 0]
    last_index = len(properties) - 1

    for i, prop in enumerate(properties):
        ax = axes[i]

        # Passing `palette` without `hue` is deprecated in seaborn 0.13, so
        # hue mirrors x. dodge=False keeps one full-width box per dataset.
        # The resulting redundant legend is removed below.
        sns.boxplot(
            x='display_dataset', y=prop, hue='display_dataset', data=combined_data,
            palette='Set3', fliersize=0, order=dataset_order,
            hue_order=dataset_order, dodge=False, ax=ax,
        )
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

        # Clip the view to the 0.1st-99.9th percentile band plus 10% padding,
        # since outlier markers are hidden anyway. Skip this when the band is
        # degenerate: NaN limits raise in set_ylim, and equal limits are singular.
        lower_bound = combined_data[prop].quantile(0.001)
        upper_bound = combined_data[prop].quantile(0.999)
        y_range = upper_bound - lower_bound
        if pd.notna(y_range) and y_range > 0:
            ax.set_ylim(lower_bound - 0.1 * y_range, upper_bound + 0.1 * y_range)

        ax.set_ylabel(prop.replace('_', ' ').title(), fontsize=16)

        # Only the bottom panel carries the x-axis label and tick labels.
        if i < last_index:
            ax.set_xlabel('')
            plt.setp(ax.get_xticklabels(), visible=False)
        else:
            ax.set_xlabel('Dataset', fontsize=16)
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()

    output_path = os.path.join(output_dir, "amp_data_all_properties_boxplot.svg")
    fig.savefig(output_path, format="svg", bbox_inches="tight")
    plt.close(fig)

# Run the table and boxplot generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
