#!/usr/bin/env python3

import os
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def setup_plot_style():
    """Apply the shared look for every figure produced by this script.

    Switches seaborn to the "whitegrid" style and overrides the global
    matplotlib font sizes: 14 pt for general text, tick labels and legends,
    16 pt for axis labels and axis titles.
    """
    body_size = 14
    heading_size = 16

    sns.set_style("whitegrid")

    font_overrides = {
        'font.size': body_size,
        'xtick.labelsize': body_size,
        'ytick.labelsize': body_size,
        'legend.fontsize': body_size,
        'axes.labelsize': heading_size,
        'axes.titlesize': heading_size,
    }
    plt.rcParams.update(font_overrides)

def load_data(input_dir):
    """Read the three per-property result CSVs and standardize their column names.

    For each property ``p`` in (charge, hydrophobicity, length) the file
    ``<input_dir>/<p>_df.csv`` is read and its columns are renamed
    positionally to ``['sequence', p, 'obtained-<p>']``, which is the layout
    the plotting functions look up.

    Args:
        input_dir: Directory containing the three CSV files.

    Returns:
        Tuple ``(charge_df, hydrophobicity_df, length_df)`` of DataFrames.
    """
    frames = []
    for prop in ('charge', 'hydrophobicity', 'length'):
        frame = pd.read_csv(os.path.join(input_dir, f'{prop}_df.csv'))
        # Positional rename: each file is expected to hold exactly three columns.
        frame.columns = ['sequence', prop, f'obtained-{prop}']
        frames.append(frame)

    return tuple(frames)

def _middle_window(sorted_values, width=3):
    """Return up to ``width`` consecutive items centred on the middle of ``sorted_values``.

    With at least ``width`` items the window starts at
    ``len // 2 - width // 2``; with fewer items every available item is
    returned instead of indexing out of range.
    """
    count = len(sorted_values)
    start = max(0, min(count // 2 - width // 2, count - width))
    return list(sorted_values[start:start + width])


def select_condition_values(df_list):
    """Pick the conditioning values nearest the middle of each property's range.

    For each property the unique conditioning values are sorted and the
    middle one plus its two neighbours are selected. When a property has
    fewer than three unique values, all of them are returned (the previous
    fixed ``mid - 1 / mid / mid + 1`` indexing raised ``IndexError`` there).

    Args:
        df_list: Sequence ``(charge_df, hydrophobicity_df, length_df)`` whose
            frames carry a 'charge', 'hydrophobicity' and 'length' column
            respectively.

    Returns:
        Dict keyed by 'charge', 'hydrophobicity' and 'length'. Each value is
        a tuple ``(selected_values, rounded_values)``: the exact values used
        for filtering and their rounded counterparts for display (1 decimal
        for charge, 2 for hydrophobicity, nearest integer for length).
    """
    charge_df, hydrophobicity_df, length_df = df_list

    # (column name, frame, decimals for the display label; None -> nearest integer)
    specs = (
        ('charge', charge_df, 1),
        ('hydrophobicity', hydrophobicity_df, 2),
        ('length', length_df, None),
    )

    selection = {}
    for name, frame, digits in specs:
        chosen = _middle_window(sorted(frame[name].unique()))
        if digits is None:
            rounded = [round(val) for val in chosen]
        else:
            rounded = [round(val, digits) for val in chosen]
        selection[name] = (chosen, rounded)

    return selection

def create_kde_plot(df, property_name, selected_values, rounded_values, colors, output_dir):
    """Plot overlaid KDE curves of one property's obtained values and save them as SVG.

    One filled density curve is drawn per conditioning value, using the rows
    of ``df`` whose ``property_name`` column equals that value. The figure is
    written to ``<output_dir>/<property_name>_distribution.svg`` and closed.

    Args:
        df: Frame with a ``property_name`` column (conditioning value) and an
            ``obtained-<property_name>`` column (value to plot).
        property_name: 'charge', 'hydrophobicity' or 'length'. 'length' gets
            outlier filtering and fixed x-limits; 'hydrophobicity' gets a
            legend pinned to the upper left.
        selected_values: Exact conditioning values used to filter ``df``.
        rounded_values: Display values paired with ``selected_values``.
        colors: Indexable collection of colors, one per conditioning value.
        output_dir: Directory the SVG is written to.
    """
    is_length = property_name == 'length'
    display_name = property_name.capitalize()
    obtained_column = f"obtained-{property_name}"
    # Length labels get a literal ".0" appended after the rounded value.
    label_suffix = ".0" if is_length else ""

    plt.figure(figsize=(10, 6))

    for idx, (target, shown) in enumerate(zip(selected_values, rounded_values)):
        rows = df[df[property_name] == target]

        if not is_length:
            print(f"{display_name} {shown}: {len(rows)} samples")
        else:
            # Drop obtained lengths above 55 before estimating the density.
            total = len(rows)
            rows = rows[rows[obtained_column] <= 55]
            print(f"{display_name} {shown}: {total} samples (using {len(rows)} after removing outliers > 55)")

        sns.kdeplot(
            rows[obtained_column],
            color=colors[idx],
            alpha=0.7,
            fill=True,
            bw_adjust=0.5,
            label=f"Conditioned {display_name} = {shown}{label_suffix}",
        )

    plt.xlabel(f"Obtained {display_name}", fontsize=16)
    plt.ylabel("Density", fontsize=16)
    plt.tick_params(axis='both', labelsize=14)

    if is_length:
        # Show a window of +/- 2 around the conditioned lengths.
        plt.xlim(min(selected_values) - 2, max(selected_values) + 2)

    # The hydrophobicity legend is pinned to the upper left; others are auto-placed.
    if property_name == 'hydrophobicity':
        legend_loc = 'upper left'
    else:
        legend_loc = 'best'
    plt.legend(fontsize=14, loc=legend_loc)

    plt.tight_layout()

    svg_path = os.path.join(output_dir, f"{property_name}_distribution.svg")
    plt.savefig(svg_path, format='svg')
    plt.close()

def create_all_plots(df_list, selected_values, output_dir):
    """Create and save one KDE plot per property (charge, hydrophobicity, length).

    Args:
        df_list: Sequence ``(charge_df, hydrophobicity_df, length_df)``.
        selected_values: Dict keyed by property name whose values are
            ``(selected_values, rounded_values)`` pairs, as returned by
            ``select_condition_values``.
        output_dir: Directory the SVG files are written to.
    """
    charge_df, hydrophobicity_df, length_df = df_list
    frames_by_property = (
        ('charge', charge_df),
        ('hydrophobicity', hydrophobicity_df),
        ('length', length_df),
    )

    # Size the shared palette for the property with the most conditions, so
    # indexing colors per condition cannot run out when the lists differ in
    # length (previously only the charge list was counted).
    num_conditions = max(
        len(selected_values[name][0]) for name, _ in frames_by_property
    )
    colors = sns.color_palette("Set1", num_conditions)

    for name, frame in frames_by_property:
        create_kde_plot(
            frame,
            name,
            selected_values[name][0],
            selected_values[name][1],
            colors,
            output_dir,
        )

def main():
    """Command-line entry point: load the result CSVs and write the KDE plots.

    Accepts ``--input_dir`` (directory holding the CSV result files) and
    ``--output_dir`` (directory the plots are saved to, created if missing).
    """
    cli = argparse.ArgumentParser(
        description='Generate plots for property conditioning experiment results'
    )
    cli.add_argument(
        '--input_dir',
        type=str,
        default='experiments/results/files/property-conditioning',
        help='Directory with the experiment result files',
    )
    cli.add_argument(
        '--output_dir',
        type=str,
        default='experiments/results/plots/property-conditioning',
        help='Directory to save output plots',
    )
    options = cli.parse_args()
    destination = options.output_dir

    # Global plot styling must be in place before any figure is created.
    setup_plot_style()

    os.makedirs(destination, exist_ok=True)

    frames = load_data(options.input_dir)
    chosen_values = select_condition_values(frames)
    create_all_plots(frames, chosen_values, destination)

    print(f"Plots saved to {destination}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()