"""
Merged experiment script for comparing gradient estimators across dimensions
for both quadratic and logistic functions, with side-by-side visualization.
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import sys
import os

# Add the parent directory to the path so we can import the package
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from unbiased_zoo.estimators import UnbiasedEstimator, UnbiasedEstimatorV2, ZipfDistribution
from unbiased_zoo.estimators.uniform import CentralizedUniformEstimator
from unbiased_zoo.estimators.gaussian import CentralizedGaussianEstimator
from unbiased_zoo.functions import QuadraticFunction, LogisticLossFunction
from unbiased_zoo import create_estimator

# Plot typography, gathered in one place so the figure can be re-tuned easily.
TITLE_FONTSIZE = 22        # intended for a figure-level title
SUBTITLE_FONTSIZE = 22     # per-panel titles
AXIS_LABEL_FONTSIZE = 18
TICK_LABEL_FONTSIZE = 18
LEGEND_FONTSIZE = 18
ANNOTATION_FONTSIZE = 18

# Name under which the experiment refers to its "custom" estimator; it is
# simply the library's UnbiasedEstimator under another name.
SimpleCustomEstimator = UnbiasedEstimator

# Seed NumPy's global RNG at import time so the random test points drawn by
# the experiment are identical from run to run.
np.random.seed(42)

def _build_reference_estimator(name, mu, batch_size):
    """Instantiate the reference estimator registered under ``name``.

    Args:
        name: Reference estimator key (see ``compare_multiple_estimators``).
        mu: ``mu`` value forwarded to the estimator constructor.
        batch_size: ``zoo_batch_size`` for every estimator except ``unbiased_p*``.

    Returns:
        A freshly constructed estimator object.
    """
    # NOTE(review): the unbiased_p* estimators always use zoo_batch_size=1,
    # ignoring batch_size -- presumably to equalise the query budget against
    # the other estimators; confirm before changing.
    if name == "unbiased_p3":
        return UnbiasedEstimator(P=3, zoo_batch_size=1, mu=mu, a=2.0)
    if name == "unbiased_p4":
        return UnbiasedEstimator(P=4, zoo_batch_size=1, mu=mu, a=2.0)
    if name == "centralized_uniform":
        # NOTE(review): despite the key, this builds the *Gaussian* centralized
        # estimator (CentralizedUniformEstimator is imported but not used here).
        # The figure legend labels this method "(Normal)", so the mapping is
        # kept to stay consistent with it -- confirm before switching.
        return CentralizedGaussianEstimator(zoo_batch_size=batch_size, mu=mu)
    return create_estimator(name, zoo_batch_size=batch_size, mu=mu)


def compare_multiple_estimators(custom_estimator, reference_names, function, x, batch_size=64, num_trials=10):
    """Measure the gradient-estimation error of several reference estimators at ``x``.

    Despite the name, ``custom_estimator`` itself is NOT evaluated: it only
    supplies the ``mu`` value shared by every reference estimator, and its
    ``zoo_batch_size`` attribute is overwritten with ``batch_size`` (a side
    effect kept for backward compatibility).

    Args:
        custom_estimator: Object exposing a ``mu`` attribute; its
            ``zoo_batch_size`` attribute is set to ``batch_size``.
        reference_names: Names of the estimators to evaluate.
            ``"unbiased_p3"`` / ``"unbiased_p4"`` build ``UnbiasedEstimator``
            with P=3 / P=4 and batch size 1; ``"centralized_uniform"`` builds a
            ``CentralizedGaussianEstimator``; any other name is passed to
            ``create_estimator``.
        function: Object whose ``grad(x)`` gives the reference gradient and
            which each estimator's ``estimate(function, x)`` accepts.
        x: The point at which gradients are estimated.
        batch_size: Batch size for all estimators other than ``unbiased_p*``.
        num_trials: Number of independent estimates per estimator.

    Returns:
        dict: Maps ``name.capitalize()`` for each reference name to the list of
        ``num_trials`` Euclidean errors ``||estimate - true_grad||_2``.
    """
    # Kept for backward compatibility: callers may read this attribute back.
    custom_estimator.zoo_batch_size = batch_size
    mu = custom_estimator.mu

    reference_estimators = {
        name: _build_reference_estimator(name, mu, batch_size) for name in reference_names
    }

    true_grad = function.grad(x)

    # Result keys (and the 'Method' labels used downstream) are the
    # capitalised reference names; compute them once.
    labels = {name: name.capitalize() for name in reference_names}
    all_errors = {label: [] for label in labels.values()}

    # Estimators are interleaved within each trial (rather than running all
    # trials of one estimator first) so that, in case estimators draw from the
    # global NumPy RNG, the draw order is unchanged.
    for _ in range(num_trials):
        for name, estimator in reference_estimators.items():
            estimate = estimator.estimate(function, x)
            all_errors[labels[name]].append(np.linalg.norm(estimate - true_grad))

    print("\nSummary:")
    for name in reference_names:
        label = labels[name]
        errors = all_errors[label]
        if errors:
            print(f"{label} estimator - Mean error: {np.mean(errors):.6f}, Std: {np.std(errors):.6f}")
        else:
            # Avoid NaN output and numpy "mean of empty slice" warnings.
            print(f"{label} estimator - no trials run")

    return all_errors


def run_experiment(function_type, dimensions, zoo_batch_size, num_trials, reference_names):
    """Run the estimator comparison for one function type across several dimensions.

    For each dimension ``n`` a fresh test function and random test point are
    created, and ``compare_multiple_estimators`` collects per-trial errors.

    Args:
        function_type (str): Type of function to test ('quadratic' or 'logistic').
        dimensions (list): Dimensions to test.
        zoo_batch_size (int): Batch size for estimators.
        num_trials (int): Number of trials per dimension.
        reference_names (list): Reference estimator names.

    Returns:
        pd.DataFrame: One row per (dimension, method, trial) with columns
        'Function', 'Dimension', 'Method', 'Error'. The columns are present
        even when no rows were produced (e.g. empty ``dimensions``).

    Raises:
        ValueError: If ``function_type`` is not 'quadratic' or 'logistic'.
    """
    # Validate up front; otherwise an unknown type leaves `f`/`x` unbound and
    # fails later with an obscure NameError.
    if function_type not in ('quadratic', 'logistic'):
        raise ValueError(
            f"Unknown function_type {function_type!r}; expected 'quadratic' or 'logistic'"
        )

    mu = 1e-5  # shared with every reference estimator via the dummy estimator
    rows = []

    for n in dimensions:
        print(f"\n\n===== Testing {function_type} function with dimension n = {n} =====")

        # Build the test function first, then draw the test point, so the
        # global RNG is consumed in the same order as before.
        if function_type == 'quadratic':
            f = QuadraticFunction(n)
            x = np.random.normal(scale=5.0, size=n)
        else:
            f = LogisticLossFunction(input_dim=n, n_samples=500, random_state=42)
            x = np.random.normal(scale=0.1, size=n)  # Smaller scale for logistic regression

        # Only used to carry `mu` (and batch size) into the comparison.
        dummy_estimator = SimpleCustomEstimator(zoo_batch_size=zoo_batch_size, mu=mu, a=2.0)

        print("Testing estimators...")
        all_errors = compare_multiple_estimators(
            dummy_estimator,
            reference_names,
            f,
            x,
            batch_size=zoo_batch_size,
            num_trials=num_trials,
        )

        rows.extend(
            {
                'Function': function_type.capitalize(),
                'Dimension': n,
                'Method': method,
                'Error': error,
            }
            for method, errors in all_errors.items()
            for error in errors
        )

    # Explicit columns keep the schema stable even when `rows` is empty.
    return pd.DataFrame(rows, columns=['Function', 'Dimension', 'Method', 'Error'])


if __name__ == "__main__":
    # Shared experiment configuration.
    zoo_batch_size = 3
    reference_names = ["unbiased_p3", "gaussian", "uniform", "centralized_uniform"]

    # Dimensions to test for both functions
    dimensions = [16, 64, 256, 1024, 4096]
    num_trials = 100  # Reduced for faster execution

    # Run experiments
    print("Running quadratic function experiments...")
    quadratic_results = run_experiment('quadratic', dimensions, zoo_batch_size, num_trials, reference_names)

    print("\nRunning logistic function experiments...")
    logistic_results = run_experiment('logistic', dimensions, zoo_batch_size, num_trials, reference_names)

    combined_results = pd.concat([quadratic_results, logistic_results], ignore_index=True)

    # Colors and legend text are keyed by the 'Method' labels, i.e. the
    # reference names after str.capitalize().
    method_colors = {
        'Unbiased_p3': '#e41a1c',          # red
        'Gaussian': '#377eb8',             # blue
        'Uniform': '#4daf4a',              # green
        'Centralized_uniform': '#984ea3',  # purple
    }
    legend_labels = {
        'Unbiased_p3': 'Zipf\'s $P_3$-Estimator',
        'Gaussian': 'One-Side Two-Point Estimator (Normal)',
        'Uniform': 'One-Side Two-Point Estimator (Uniform)',
        # The "centralized_uniform" key builds a CentralizedGaussianEstimator,
        # hence "(Normal)".
        'Centralized_uniform': 'Two-Side Two-Point Estimator (Normal)',
    }

    # Set the style
    sns.set_style("whitegrid")
    plt.rcParams.update({
        'font.size': TICK_LABEL_FONTSIZE,
        'axes.titlesize': SUBTITLE_FONTSIZE,
        'axes.labelsize': AXIS_LABEL_FONTSIZE,
        'xtick.labelsize': TICK_LABEL_FONTSIZE,
        'ytick.labelsize': TICK_LABEL_FONTSIZE,
        'legend.fontsize': LEGEND_FONTSIZE,
    })

    fig, axes = plt.subplots(1, 2, figsize=(20, 8), sharey=False)

    # One panel per function type; both panels share the same styling.
    panels = [('Quadratic', 'Quadratic Function'), ('Logistic', 'Logistic Function')]
    for ax, (function_label, title) in zip(axes, panels):
        panel_data = combined_results[combined_results['Function'] == function_label]

        sns.boxplot(x='Dimension', y='Error', hue='Method', data=panel_data,
                    palette=method_colors, boxprops=dict(alpha=0.75),
                    medianprops=dict(color="black", linewidth=2.5), ax=ax)
        # Overlay individual trial errors for better visualization.
        sns.stripplot(x='Dimension', y='Error', hue='Method', data=panel_data,
                      size=4, alpha=0.3, dodge=True, palette=method_colors, ax=ax)

        ax.set_title(title, fontsize=SUBTITLE_FONTSIZE)
        ax.set_xlabel('Dimension (d)', fontsize=AXIS_LABEL_FONTSIZE)
        # The plotted value is ||estimate - true_grad||_2 per trial (see
        # compare_multiple_estimators), not a mean squared error.
        ax.set_ylabel(r'$\ell_2$ Error', fontsize=AXIS_LABEL_FONTSIZE)
        ax.set_yscale('log')
        ax.legend([], [], frameon=False)  # replaced by one shared legend below

    # Shared legend at the bottom. Both boxplot and stripplot register handles
    # per method; keep only the first n_methods (presumably the boxplot ones,
    # as the original [:4] slice assumed).
    n_methods = combined_results['Method'].nunique()
    handles, labels = axes[1].get_legend_handles_labels()
    custom_labels = [legend_labels.get(label, label) for label in labels[:n_methods]]

    fig.legend(handles[:n_methods], custom_labels, loc='lower center', ncol=n_methods,
               fontsize=14, bbox_to_anchor=(0.5, 0.005))

    plt.tight_layout()
    # Adjust the bottom margin to make room for the legend
    plt.subplots_adjust(bottom=0.15)

    plt.savefig('combined_dimension_comparison.png', dpi=300, bbox_inches='tight')

    print("Experiment completed. Results saved to 'combined_dimension_comparison.png'")
    plt.show()