#!/usr/bin/env python3

import os
import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def plot_deviation_scatter(df, output_dir):
    """
    Draws two stacked regression scatter panels of AMP prediction against
    the charge deviation and the hydrophobicity deviation, and writes the
    figure to ``deviation_scatter_plots.svg`` inside ``output_dir``.

    Args:
        df (pd.DataFrame): Results table. The columns 'charge_deviation',
            'hydro_deviation' and 'amp-prediction' are read from it.
        output_dir (str): Directory the SVG file is written into.

    Returns:
        None. A failure while saving is reported on stdout, not raised.
    """
    print("\n--- Generating Deviation Scatter Plots ---")

    # Two rows, one column. No sharey/sharex is requested, so each panel
    # keeps its own axis ranges; only the y-label is shared at figure level.
    fig, axes = plt.subplots(2, 1, figsize=(8, 10))
    fig.supylabel('AMP Prediction')

    # (x column, x-axis label) for the top and bottom panel respectively.
    panels = (
        ('charge_deviation', 'MAE Charge'),
        ('hydro_deviation', 'MAE Hydrophobicity'),
    )
    for panel_ax, (deviation_column, x_label) in zip(axes, panels):
        sns.regplot(
            data=df,
            x=deviation_column,
            y='amp-prediction',
            ax=panel_ax,
            scatter_kws={'alpha': 0.5, 's': 10},
            line_kws={'color': 'red'},  # regression line in red
        )
        panel_ax.set_xlabel(x_label)
        # The per-panel y-label is blanked because the figure-level
        # supylabel already names the y quantity.
        panel_ax.set_ylabel('')

    plt.tight_layout()

    # Write the figure as SVG; errors are printed so the caller can continue.
    plot_path = os.path.join(output_dir, "deviation_scatter_plots.svg")
    try:
        plt.savefig(plot_path, format='svg', bbox_inches='tight')
        print(f"Deviation scatter plot saved to: {plot_path}")
    except Exception as e:
        print(f"Error saving deviation scatter plot: {e}")
    plt.close(fig)

def plot_density_grid(df, conditioning_points, charge_lim, hydro_lim, output_dir):
    """
    Generates a single 2D density plot combining results for specific conditioning
    points (indices 0, 2 and 7 of ``conditioning_points``, where available), each
    drawn with a different color, and saves it as ``combined_density_plot.svg``.

    Args:
        df (pd.DataFrame): DataFrame containing the results. The columns
            'Conditioned-Charge', 'Conditioned-Hydrophobicity',
            'Obtained-Charge' and 'Obtained-Hydrophobicity' are read.
        conditioning_points (np.ndarray): Array of unique conditioning points;
            each row is unpacked as (charge, hydrophobicity).
        charge_lim (tuple): Tuple of (min, max) for the charge (y) axis.
        hydro_lim (tuple): Tuple of (min, max) for the hydrophobicity (x) axis.
        output_dir (str): Directory to save the plot.

    Returns:
        None. Out-of-range indices, empty subsets, KDE failures and save
        failures are reported on stdout and do not raise.
    """
    target_indices = [0, 2, 7]
    # Distinct base colors (a common qualitative palette), one per selected point.
    plot_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']
    # Light shade used as the low-density end of each colormap. Only the first
    # three slots are filled in; that is enough for the three target indices.
    light_colors = ['#d5e8f5', '#fae1ca', '#c0f0c0', None, None, None, None, None, None] # TODO: Add more colors

    selected_pairs = []
    for i, original_idx in enumerate(target_indices):
        if original_idx >= len(conditioning_points):
            print(f"Warning: Conditioning point index {original_idx} is out of bounds (total {len(conditioning_points)} points). Skipping.")
            continue
        selected_pairs.append({
            "point_data": conditioning_points[original_idx],
            "color": plot_colors[i % len(plot_colors)],  # Cycle through colors
            "light_color": light_colors[i % len(light_colors)],
            "id": original_idx,
        })

    if not selected_pairs:
        print(f"No conditioning points selected from target indices {target_indices}. Skipping combined density plot.")
        return

    print(f"\n--- Generating Combined Density Plot for {len(selected_pairs)} selected pairs ---")

    # All densities are drawn onto one shared axis.
    fig, ax = plt.subplots(1, 1, figsize=(20, 10))

    for pair_info in selected_pairs:
        cond_charge, cond_hydro = pair_info["point_data"]
        current_color = pair_info["color"]
        light_color = pair_info["light_color"]
        original_idx = pair_info["id"]

        print(f"Processing conditioning point (original index {original_idx}): Charge={cond_charge:.2f}, Hydrophobicity={cond_hydro:.2f} with color {current_color}")

        # Rows generated for exactly this conditioning point.
        subset_df = df[(df['Conditioned-Charge'] == cond_charge) &
                       (df['Conditioned-Hydrophobicity'] == cond_hydro)]

        if subset_df.empty:
            print(f"  No data found for conditioning point (original index {original_idx}). Skipping this pair.")
            continue

        print(f"  Found {len(subset_df)} samples for this point.")

        # Filled KDE of the obtained properties. Any failure while building
        # the colormap or the density is reported and the point is skipped.
        try:
            sns.kdeplot(
                data=subset_df,
                x='Obtained-Hydrophobicity',
                y='Obtained-Charge',
                fill=True,  # Filled contours for a heatmap-like look
                # Colormap runs from the light shade to the full base color.
                cmap=sns.blend_palette([light_color, current_color], n_colors=10, as_cmap=True),
                levels=4,  # Documented parameter; `n_levels` is only a legacy alias
                alpha=0.9,  # Slight transparency
                linewidths=0,  # No contour outlines
                ax=ax
            )
        except Exception as e:
            print(f"  Error generating kdeplot for point (original index {original_idx}): {e}. Skipping this pair.")
            continue

        # Mark the conditioning target itself in the matching color.
        ax.scatter(
            cond_hydro,
            cond_charge,
            color=current_color,
            s=1000,  # Marker size
            edgecolors='black',  # Black outline for visibility
            marker='X',
            zorder=5  # Draw above the density
        )

    # Caller-supplied limits keep the view consistent.
    ax.set_xlim(hydro_lim)
    ax.set_ylim(charge_lim)

    ax.set_xlabel('Hydrophobicity')
    ax.set_ylabel('Charge')

    # Custom legend with a single entry: a black 'X' marker labelled "Target".
    legend_elements = [
        plt.Line2D([0], [0], marker='X', color='none',  # color='none' draws the marker without a line
                   label='Target',
                   markerfacecolor='black',
                   markeredgecolor='black',
                   markersize=36)
    ]
    ax.legend(handles=legend_elements, loc="best", handlelength=1.0)

    # left=True / bottom=True remove the left and bottom spines in addition to
    # the top and right ones that despine removes by default.
    sns.despine(ax=ax, left=True, bottom=True)

    fig.tight_layout()

    # Save the combined density plot; a failure is reported, not raised.
    plot_filename = "combined_density_plot.svg"
    plot_path = os.path.join(output_dir, plot_filename)
    try:
        fig.savefig(plot_path, format='svg', bbox_inches='tight')
        print(f"\nCombined density plot saved to: {plot_path}")
    except Exception as e:
        print(f"\nError saving combined density plot: {e}")

    plt.close(fig)  # Close the figure to free memory


def plot_2d_conditioning(results_csv_path, output_dir):
    """
    Loads the experiment results and drives both plotting routines:
    1. The deviation scatter plots (AMP prediction vs. charge / hydrophobicity deviation).
    2. The combined 2D density plot for a subset of the conditioning points.
    Both figures are written as SVG files by the functions called here.

    Args:
        results_csv_path (str): Path to the CSV file containing experiment results.
        output_dir (str): Directory to save the generated plots (created if missing).

    Returns:
        None. A missing file, an unreadable CSV or missing columns are
        reported on stdout and end the function early.
    """
    # Global plot styling: whitegrid theme and one large font size everywhere.
    sns.set_style("whitegrid")
    font_settings = (
        'font.size',
        'axes.labelsize',
        'axes.titlesize',
        'xtick.labelsize',
        'ytick.labelsize',
        'legend.fontsize',
    )
    plt.rcParams.update(dict.fromkeys(font_settings, 50))

    print(f"Loading results from: {results_csv_path}")
    expected_columns = (
        'Conditioned-Charge', 'Conditioned-Hydrophobicity',
        'Obtained-Charge', 'Obtained-Hydrophobicity',
        'amp-prediction',
    )
    try:
        df = pd.read_csv(results_csv_path)
        # Bail out early when any column used later is absent.
        missing = [col for col in expected_columns if col not in df.columns]
        if missing:
            print(f"Error: Missing required columns in CSV: {missing}")
            print("Ensure the experiment script was run with prediction enabled.")
            return
    except FileNotFoundError:
        print(f"Error: Results file not found at {results_csv_path}")
        return
    except Exception as e:
        print(f"Error loading CSV file: {e}")
        return

    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory for plots: {output_dir}")

    # Absolute per-row difference between obtained and conditioned values.
    df['charge_deviation'] = (df['Obtained-Charge'] - df['Conditioned-Charge']).abs()
    df['hydro_deviation'] = (df['Obtained-Hydrophobicity'] - df['Conditioned-Hydrophobicity']).abs()
    print("Calculated charge and hydrophobicity deviations.")

    # Figure 1: deviation scatter plots.
    plot_deviation_scatter(df, output_dir)

    # Distinct (charge, hydrophobicity) conditioning pairs, in order of first appearance.
    conditioning_points = df[['Conditioned-Charge', 'Conditioned-Hydrophobicity']].drop_duplicates().values
    print(f"Found {len(conditioning_points)} unique conditioning points for density plots.")

    def cropped_limits(column, trim_low, trim_high):
        # Pull the observed range inwards by fixed amounts, then pad the
        # remaining span by 5% on each side.
        # NOTE(review): the trim amounts are hard-coded; if the observed range
        # is narrower than trim_low + trim_high the lower limit ends up above
        # the upper one - confirm the values suit the dataset in use.
        lower = df[column].min() + trim_low
        upper = df[column].max() - trim_high
        padding = (upper - lower) * 0.05
        return (lower - padding, upper + padding)

    charge_lim = cropped_limits('Obtained-Charge', 3, 4)
    hydro_lim = cropped_limits('Obtained-Hydrophobicity', 0.18, 0.14)

    # Figure 2: combined density plot.
    plot_density_grid(df, conditioning_points, charge_lim, hydro_lim, output_dir)

    print("\nPlot generation complete.")

def main():
    """
    Command-line entry point: reads ``--results_csv`` and ``--output_dir``
    (both optional, with repository-relative defaults) and runs
    ``plot_2d_conditioning`` with them.
    """
    cli = argparse.ArgumentParser(
        description='Generate separate plots for deviation analysis and 2D density grids.'
    )

    # (flag, default value, help text) for every supported option.
    option_table = (
        (
            '--results_csv',
            'experiments/results/files/2d-property-conditioning/charge_hydrophobicity_conditioned_samples.csv',
            'Path to the input CSV file with experiment results (must include amp-prediction column).',
        ),
        (
            '--output_dir',
            'experiments/results/plots/2d-property-conditioning',
            'Directory to save the generated plots.',
        ),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()
    plot_2d_conditioning(options.results_csv, options.output_dir)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
