import os
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import argparse
import numpy as np
from statsmodels.nonparametric.smoothers_lowess import lowess

# Decode the hyperparameters embedded in a language directory name
def parse_language_name(name):
    """Split an underscore-separated language name into its metadata fields.

    The first field is kept verbatim as the automaton type. Each of the next
    three fields has its first character dropped and the remainder converted
    to ``int``. Any fields after the fourth are ignored.

    Args:
        name: Language name such as ``"dfa_q10_s4_r3"``.

    Returns:
        A dict with the keys ``automaton_type``, ``num_states``,
        ``alphabet_size`` and ``random_seed`` (in that order).

    Raises:
        IndexError: If ``name`` has fewer than four underscore-separated fields.
        ValueError: If a numeric field (after dropping its first character)
            is not a valid integer.
    """
    fields = name.split('_')
    parsed = {'automaton_type': fields[0]}
    numeric_keys = ('num_states', 'alphabet_size', 'random_seed')
    for position, key in enumerate(numeric_keys, start=1):
        # Drop the one-character prefix and keep the integer payload.
        parsed[key] = int(fields[position][1:])
    return parsed

# Define a function to load and parse the JSON files
def load_model_results(base_dir):
    """Collect evaluation scores from a results directory tree.

    The expected layout is::

        base_dir/<language_name>/<architecture>/<loss_type>/<validation_set>/
            <trial_no>/eval/<dataset>.json

    where ``architecture`` is one of ``rnn``, ``lstm``, ``transformer``;
    ``validation_set`` is ``validation-long`` or ``validation-short``; and
    ``dataset`` is ``test``, ``test-short-held-out`` or ``training``.
    Missing directories or files at any level are skipped silently.

    Directory listings are sorted so the order of the returned records does
    not depend on the filesystem. Directories directly under ``base_dir``
    whose names cannot be parsed by ``parse_language_name`` are skipped with
    a warning instead of aborting the whole load.

    Args:
        base_dir: Root directory containing one sub-directory per language.

    Returns:
        A list of dicts, one per evaluation file found. Each dict holds the
        path components (``language_name``, ``architecture``, ``loss_type``,
        ``validation_set``, ``dataset``, ``trial_no``), the two scores read
        from the file (``recognition_cross_entropy``,
        ``recognition_accuracy``) and the metadata parsed from the language
        name.

    Raises:
        KeyError: If an evaluation file lacks ``scores`` or one of the two
            score entries.
        json.JSONDecodeError: If an evaluation file is not valid JSON.
    """
    architectures = ('rnn', 'lstm', 'transformer')
    val_sets = ('validation-long', 'validation-short')
    datasets = ('test', 'test-short-held-out', 'training')

    def subdirs(parent):
        """Return sorted ``(name, path)`` pairs for sub-directories of parent."""
        candidates = (
            (entry, os.path.join(parent, entry))
            for entry in sorted(os.listdir(parent))
        )
        return [(entry, path) for entry, path in candidates if os.path.isdir(path)]

    results = []

    for language_name, language_dir in subdirs(base_dir):
        try:
            metadata = parse_language_name(language_name)
        except (IndexError, ValueError):
            # Unrelated directories (e.g. a plots folder) must not abort the load.
            print(f"Warning: skipping '{language_name}': "
                  "directory name does not encode language metadata.")
            continue

        for architecture in architectures:
            arch_dir = os.path.join(language_dir, architecture)
            if not os.path.isdir(arch_dir):
                continue

            for loss_type, loss_type_dir in subdirs(arch_dir):
                for val_set in val_sets:
                    val_set_dir = os.path.join(loss_type_dir, val_set)
                    if not os.path.isdir(val_set_dir):
                        continue

                    for trial_no in sorted(os.listdir(val_set_dir)):
                        trial_dir = os.path.join(val_set_dir, trial_no, 'eval')
                        if not os.path.isdir(trial_dir):
                            continue

                        for dataset in datasets:
                            eval_file = os.path.join(trial_dir, f'{dataset}.json')
                            if not os.path.exists(eval_file):
                                continue

                            with open(eval_file, 'r', encoding='utf-8') as f:
                                scores = json.load(f)['scores']

                            results.append({
                                'language_name': language_name,
                                'architecture': architecture,
                                'loss_type': loss_type,
                                'validation_set': val_set,
                                'dataset': dataset,
                                'trial_no': trial_no,
                                'recognition_cross_entropy': scores['recognition_cross_entropy'],
                                'recognition_accuracy': scores['recognition_accuracy'],
                                **metadata,
                            })

    return results

# Build the tabular view consumed by the plotting code
def aggregate_results(results):
    """Return a pandas DataFrame with one row per result record.

    Args:
        results: List of dicts, each describing one evaluation result.

    Returns:
        A DataFrame whose columns are the dict keys.
    """
    return pd.DataFrame(results)
    
def _summarize_accuracy(arch_data, param):
    """Aggregate recognition accuracy per distinct value of ``param``.

    Returns a DataFrame sorted by ``param`` with the columns ``param``,
    ``mean``, ``std``, ``count``, ``ci_lower`` and ``ci_upper``. The bounds
    are a normal-approximation 95% interval (mean ± 1.96 · standard error)
    clipped to [0, 1]. Groups with a single observation have an undefined
    sample standard deviation, so their bounds are NaN.
    """
    summary = (
        arch_data.groupby(param)['recognition_accuracy']
        .agg(['mean', 'std', 'count'])
        .reset_index()
        .sort_values(param)
    )
    std_error = summary['std'] / np.sqrt(summary['count'])
    summary['ci_lower'] = (summary['mean'] - 1.96 * std_error).clip(0, 1)
    summary['ci_upper'] = (summary['mean'] + 1.96 * std_error).clip(0, 1)
    return summary


def _plot_trend(ax, data, param, smoothed, colors):
    """Draw one accuracy-vs-``param`` curve per architecture on ``ax``.

    Each architecture with at least two rows gets a scatter of its per-value
    means, a line (LOWESS-smoothed when ``smoothed`` is true and at least
    three distinct x values exist, otherwise a plain polyline) and a shaded
    confidence band. A legend listing the plotted architectures is added when
    anything was drawn.
    """
    legend_handles = []
    legend_labels = []

    for arch in data['architecture'].unique():
        arch_data = data[data['architecture'] == arch]

        # A single row gives neither a trend nor a spread.
        if len(arch_data) < 2:
            continue

        summary = _summarize_accuracy(arch_data, param)
        x = summary[param].to_numpy()
        y = summary['mean'].to_numpy()
        ci_lower = summary['ci_lower'].to_numpy()
        ci_upper = summary['ci_upper'].to_numpy()
        color = colors.get(arch)

        ax.scatter(x, y, s=30, color=color, alpha=0.7,
                   label=f"{arch} (data points)")

        if smoothed and len(x) >= 3:  # Need at least 3 points for smoothing
            smooth_frac = min(0.8, max(0.3, 2.0 / len(x)))

            # lowess drops NaN observations by default, which would make the
            # smoothed bounds shorter than the smoothed mean and break
            # fill_between. Collapse undefined bounds onto the mean instead.
            lower_in = np.where(np.isnan(ci_lower), y, ci_lower)
            upper_in = np.where(np.isnan(ci_upper), y, ci_upper)

            smoothed_mean = lowess(y, x, frac=smooth_frac, it=2, return_sorted=True)
            smoothed_lower = lowess(lower_in, x, frac=smooth_frac, it=2, return_sorted=True)
            smoothed_upper = lowess(upper_in, x, frac=smooth_frac, it=2, return_sorted=True)

            line, = ax.plot(smoothed_mean[:, 0], smoothed_mean[:, 1], '-',
                            linewidth=2.5, color=color, label=f"{arch} (trend)")
            ax.fill_between(smoothed_mean[:, 0], smoothed_lower[:, 1],
                            smoothed_upper[:, 1], color=color, alpha=0.2)
        else:
            # Too few points to smooth, or smoothing disabled: connect the means.
            line, = ax.plot(x, y, '-', linewidth=2, color=color, label=arch)
            ax.fill_between(x, ci_lower, ci_upper, color=color, alpha=0.2)

        # Only the line goes into the legend (not the scatter or the CI fill).
        legend_handles.append(line)
        legend_labels.append(f"{arch}")

    if legend_handles:
        ax.legend(handles=legend_handles, labels=legend_labels,
                  title="Architecture", loc='best')


def create_trend_plots(df, output_dir, smoothed=False):
    """
    Create trend plots for accuracy vs |Q| and |Σ|, with rows for loss types
    and columns for validation sets.

    One PNG is written per (parameter, dataset) pair, named
    ``accuracy_vs_<param>_<dataset>[_smoothed].png``. If ``df`` is ``None`` or
    lacks a column the plots need (as happens when no results were loaded),
    a warning is printed and nothing is plotted. If only ``loss_type`` is
    missing, all rows are treated as belonging to the default loss ``rec``.

    Args:
        df: Pandas DataFrame with the parsed results
        output_dir: Directory to save the output plots (created if missing)
        smoothed: Boolean, whether to create smoothed trend lines
    """
    os.makedirs(output_dir, exist_ok=True)

    if df is None:
        print("Warning: DataFrame is None. Cannot create plots.")
        return

    # Define datasets, parameters, and validation sets
    datasets = ['training', 'test', 'test-short-held-out']
    params = ['num_states', 'alphabet_size']
    val_sets = ['validation-short', 'validation-long']

    # Validate before touching any plotting state so an empty result set
    # produces a warning rather than a KeyError.
    required = ['dataset', 'validation_set', 'architecture',
                'recognition_accuracy', *params]
    missing = [column for column in required if column not in df.columns]
    if missing:
        print(f"Warning: DataFrame is missing required columns {missing}. Cannot create plots.")
        return

    if 'loss_type' not in df.columns:
        print("Warning: 'loss_type' column not found in DataFrame. Using a default value.")
        # assign() returns a copy, so the caller's frame is left untouched.
        df = df.assign(loss_type='rec')

    loss_types = sorted(df['loss_type'].unique())
    if not loss_types:
        print("Warning: No loss types found in DataFrame. Using a default value.")
        loss_types = ['default_loss']  # Use a default value

    # Use the STIXGeneral font family for all text.
    plt.rcParams['font.family'] = 'STIXGeneral'
    plt.rcParams['mathtext.fontset'] = 'dejavuserif'  # For math text if needed

    # Set a color palette with distinct colors for architectures
    architecture_colors = {
        'rnn': '#1f77b4',        # Blue
        'lstm': '#ff7f0e',       # Orange
        'transformer': '#2ca02c' # Green
    }

    # Map the parameters to their text representations (avoiding LaTeX)
    param_labels = {
        'num_states': '|Q|',
        'alphabet_size': '|Σ|'
    }

    n_rows = len(loss_types)
    n_cols = len(val_sets)
    plot_type = "Smoothed " if smoothed else ""
    smooth_suffix = "_smoothed" if smoothed else ""

    for param in params:
        for dataset in datasets:
            # squeeze=False keeps axes 2-D even with a single loss type.
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 6 * n_rows),
                                     sharey=True, squeeze=False)

            fig.suptitle(f'{plot_type}Accuracy vs {param_labels[param]} ({dataset})', fontsize=18)

            # Column headers for validation sets go on the top row only.
            for col_idx, val_set in enumerate(val_sets):
                axes[0, col_idx].set_title(f'{val_set}', fontsize=14)

            for row_idx, loss_type in enumerate(loss_types):
                for col_idx, val_set in enumerate(val_sets):
                    plot_df = df[(df['dataset'] == dataset) &
                                 (df['validation_set'] == val_set) &
                                 (df['loss_type'] == loss_type)]

                    # Leave the subplot blank when there is nothing to show.
                    if plot_df.empty:
                        continue

                    ax = axes[row_idx, col_idx]
                    _plot_trend(ax, plot_df, param, smoothed, architecture_colors)

                    ax.set_xlabel(param_labels[param], fontsize=12)

                    # Only the first column carries the y-label and row label.
                    if col_idx == 0:
                        ax.set_ylabel('Recognition Accuracy', fontsize=12)
                        ax.text(-0.15, 0.5, f'Loss: {loss_type}',
                                transform=ax.transAxes,
                                fontsize=14,
                                rotation=90,
                                ha='right',
                                va='center')

                    ax.grid(True, linestyle='--', alpha=0.7)

            # Make room for suptitle and legend
            fig.tight_layout(rect=[0, 0.05, 1, 0.95])

            fig.savefig(os.path.join(output_dir, f'accuracy_vs_{param}_{dataset}{smooth_suffix}.png'),
                        dpi=300, bbox_inches='tight')
            plt.close(fig)

    print(f"{'Smoothed' if smoothed else 'Original'} trend plots saved in {output_dir}")

# Render the raw and the smoothed variants of every trend plot
def create_all_trend_plots(df, output_dir):
    """
    Produce the full set of trend plots: first the original ones, then the
    smoothed ones.

    Args:
        df: Pandas DataFrame with the parsed results
        output_dir: Directory to save the output plots
    """
    # Unsmoothed pass first, smoothed pass second.
    for use_smoothing in (False, True):
        create_trend_plots(df, output_dir, smoothed=use_smoothing)

    print(f"All plots saved in {output_dir}")

# Script entry point: load results, tabulate them, and plot
def main(base_dir, output_dir):
    """Run the load → aggregate → plot pipeline.

    Args:
        base_dir: Root directory holding the model result files.
        output_dir: Directory that receives the generated plots.
    """
    frame = aggregate_results(load_model_results(base_dir))
    create_all_trend_plots(frame, output_dir)

if __name__ == '__main__':
    # Both directories are mandatory command-line options.
    cli = argparse.ArgumentParser(description="Load, process, and plot model results.")
    required_options = (
        ('--base_dir', "The base directory where the model results are stored."),
        ('--output_dir', "Directory to save the output plots."),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    options = cli.parse_args()

    # Hand the parsed paths to the pipeline.
    main(options.base_dir, options.output_dir)