"""
Implementation of the Over-parametrized Gradient Flow experiment from Section 6.2.

This script investigates how the number of layers (parameter D) affects the OP-GF
and Effective Span Dimension (ESD) trajectories over time.

Key Components:
- Implements gradient descent with D-layer overparametrized models
- Uses parallel processing for multiple experimental runs
- Tracks squared error and ESD values at logarithmically spaced iterations
- Compares performance across different D values (0, 1, 3) representing the number of layers
- Provides visualization of results with error bars representing standard errors

Usage:
python Section_6_2_OPGF.py
"""

import numpy as np
from generate_sequence import generate_lambda, generate_theta3
from esd_modular_functions import compute_esd, compute_pc_error
import matplotlib.pyplot as plt
import multiprocessing as mp
from functools import partial
import os

def compute_loss(a, b, beta, y, D):
    """Return half the squared residual between the observations and the fit.

    The fitted vector is the elementwise product ``a * b**D * beta``; the
    result is ``0.5 * sum((y - fit)**2)``.
    """
    fitted = a * (b ** D) * beta  # D enters as the exponent on b
    residual = y - fitted
    return 0.5 * np.sum(residual ** 2)

def compute_loss_true_theta(a, b, beta, theta_true, D):
    """Return half the squared distance between ``theta_true`` and the estimate.

    The estimate is the elementwise product ``a * b**D * beta``.
    """
    estimate = a * (b ** D) * beta  # D enters as the exponent on b
    deviation = theta_true - estimate
    return 0.5 * np.sum(deviation ** 2)

def compute_gradients(a, b, beta, y, D):
    """Return the gradients of ``0.5 * sum((y - a * b**D * beta)**2)``.

    Args:
        a, b, beta: Current parameter values (combined elementwise).
        y: Observations the model output is compared against.
        D: Exponent applied to ``b``.

    Returns:
        Tuple ``(grad_a, grad_b, grad_beta)`` of elementwise gradients.
    """
    b_pow = b ** D  # shared by the model output and two of the gradients
    error = y - a * b_pow * beta

    grad_beta = -error * a * b_pow
    grad_a = -error * b_pow * beta
    if D == 0:
        # The model does not depend on b when D == 0, so the gradient is
        # exactly zero. Evaluating the general formula would compute
        # 0 * b**(-1), which is NaN wherever b is zero.
        grad_b = np.zeros_like(grad_a)
    else:
        grad_b = -error * a * D * (b ** (D - 1)) * beta  # d/db of b**D

    return grad_a, grad_b, grad_beta

def gradient_descent_with_D(n, D, d, p, q, gamma, J, seed=None):
    """Run one gradient-descent experiment for the model ``a * b**D * beta``.

    Args:
        n: Sets the noise level (``sigma = 1/sqrt(n)``), the initial value of
            ``b`` (``n**(-1/(2*(D+2)))``) and the number of iterations
            (``n / learning_rate``).
        D: Exponent applied to ``b`` in the model.
        d: Length of the observation and parameter vectors.
        p, q, J: Forwarded to ``generate_theta3`` together with ``d`` and
            ``sigma`` to build the true signal.
        gamma: Forwarded to ``generate_lambda``; the square root of its
            result is the initial ``a``.
        seed: Seed for the global NumPy RNG; 77 is used when ``None``.

    Returns:
        Tuple ``(time_points, sq_error_history, d_dagger_history)`` of lists,
        one entry per recorded iteration. ``time_points`` holds
        ``iteration * learning_rate``; the other two hold the pair returned
        by ``compute_pc_error`` at that iteration.
    """
    # Seed the global RNG so every run is reproducible.
    np.random.seed(77 if seed is None else seed)

    sigma = 1 / n ** 0.5
    theta_true = generate_theta3(q, p, J, d, sigma)
    noise = np.random.normal(0, sigma, d)
    y = theta_true + noise

    # Initialize learning parameters
    a_est = (generate_lambda(d, gamma)) ** 0.5
    b_est = np.full(d, n ** (-1 / (2 * (D + 2))))
    beta_est = np.zeros(d)   # Initial beta values are all 0

    # Gradient descent parameters
    learning_rate_a = 1e-2
    learning_rate_b = 1e-2
    learning_rate_beta = 1e-2
    num_iterations = int(n / learning_rate_a)

    # Record initial loss
    initial_loss = compute_loss(a_est, b_est, beta_est, y, D)
    initial_loss_true_theta = compute_loss_true_theta(a_est, b_est, beta_est, theta_true, D)
    print(f"\nInitial Loss (y): {initial_loss}")
    print(f"Initial Loss (theta_true): {initial_loss_true_theta}")

    # Histories recorded at the logarithmically spaced iterations below
    sq_error_history = []
    d_dagger_history = []
    time_points = []

    # Iteration indices spaced by roughly 0.05 in log10, starting at 200
    log_start = np.log10(200)
    log_end = np.log10(num_iterations)
    log_iteration_points = np.linspace(log_start, log_end, int((log_end - log_start) / 0.05) + 1)
    iteration_points = np.unique(np.round(10 ** log_iteration_points).astype(int))
    # A set of plain ints gives O(1) membership tests inside the hot loop.
    # NOTE(review): the last grid point equals num_iterations, which
    # range(num_iterations) never reaches, so it is never recorded.
    record_at = {int(i) for i in iteration_points if i <= num_iterations}

    # Gradient descent iterations
    for iteration in range(num_iterations):
        grad_a, grad_b, grad_beta = compute_gradients(a_est, b_est, beta_est, y, D)

        # Update parameters
        a_est -= learning_rate_a * grad_a
        b_est -= learning_rate_b * grad_b
        beta_est -= learning_rate_beta * grad_beta

        # Ensure b_i >= 0
        b_est = np.abs(b_est)

        if iteration in record_at:
            # lambda_est is only needed when recording, so compute it here
            # rather than on every iteration.
            lambda_est = (a_est * (b_est ** D)) ** 2
            sq_error, d_dagger = compute_pc_error(y, theta_true, lambda_est, sigma_sq=sigma**2)
            sq_error_history.append(sq_error)
            d_dagger_history.append(d_dagger)
            time_point = iteration * learning_rate_a
            time_points.append(time_point)
            print(f"Iteration {iteration}, Sq Error: {sq_error}, Optimal K: {d_dagger}")

    return time_points, sq_error_history, d_dagger_history

def run_single_experiment(exp, n, D, d, p, q, gamma, J):
    """Run experiment number ``exp`` with a seed derived from its index.

    Returns whatever ``gradient_descent_with_D`` returns for these settings.
    """
    print(f"Experiment {exp+1}, D={D}")
    # Offsetting the base seed 77 by the experiment index gives every run
    # its own reproducible random stream.
    return gradient_descent_with_D(n, D, d, p, q, gamma, J, seed=77 + exp)

def run_multiple_experiments_parallel(n, D, d, p, q, gamma, J, num_experiments=10):
    """Run several experiments in a process pool and aggregate their results.

    Args:
        n, D, d, p, q, gamma, J: Forwarded unchanged to
            ``run_single_experiment`` for every run.
        num_experiments: Number of runs; run ``i`` receives index ``i``.

    Returns:
        Tuple ``(time_points, avg_sq_errors, avg_d_daggers, std_sq_errors,
        std_d_daggers)``. ``time_points`` comes from the first run; the
        averages and standard errors are taken across runs at each time point.

    Raises:
        ValueError: If ``num_experiments`` is smaller than 1.
    """
    print(f"Running {num_experiments} experiments in parallel for D={D}...")

    if num_experiments < 1:
        # mp.Pool would reject zero processes with a less helpful message.
        raise ValueError("num_experiments must be at least 1")

    # Fix every argument except the experiment index
    worker = partial(run_single_experiment, n=n, D=D, d=d, p=p, q=q, gamma=gamma, J=J)

    num_workers = min(mp.cpu_count(), num_experiments)
    # The context manager terminates the workers if map() raises, so a failed
    # run no longer leaks the pool; on success we still shut down gracefully.
    with mp.Pool(processes=num_workers) as pool:
        results = pool.map(worker, range(num_experiments))
        pool.close()
        pool.join()

    # Use the time points of the first experiment as the reference grid
    all_time_points = results[0][0]
    expected_len = len(all_time_points)
    all_sq_errors = []
    all_d_daggers = []

    for time_points, sq_error_history, d_dagger_history in results:
        # Only runs with the same number of time points can be averaged
        if len(time_points) != expected_len:
            print("Warning: An experiment has a different number of time points than others, skipped")
            continue
        all_sq_errors.append(sq_error_history)
        all_d_daggers.append(d_dagger_history)

    # Averages across runs at each time point
    avg_sq_errors = np.mean(all_sq_errors, axis=0)
    avg_d_daggers = np.mean(all_d_daggers, axis=0)

    # Standard errors: np.std (default ddof=0) divided by sqrt(number of runs kept)
    std_sq_errors = np.std(all_sq_errors, axis=0) / np.sqrt(len(all_sq_errors))
    std_d_daggers = np.std(all_d_daggers, axis=0) / np.sqrt(len(all_d_daggers))

    return all_time_points, avg_sq_errors, avg_d_daggers, std_sq_errors, std_d_daggers

# Main program
if __name__ == "__main__":  # This line is important for multiprocessing code
    n = 10000
    d = 5000
    p = 2.5
    q = 2.5
    gamma = 1
    J = 15
    num_experiments = 20  # Number of experiments
    save_data = True  # Whether to save data
    load_saved_data = False  # Whether to load saved data
    data_folder = "saved_data"  # Folder for saving data
    data_path = f"{data_folder}/experimental_data.npz"

    # Create folder for saving data (exist_ok avoids a check-then-create race)
    if save_data:
        os.makedirs(data_folder, exist_ok=True)

    # Define different D values
    D_values = [0, 1, 3]
    colors = ['blue', 'green', 'red', 'purple']
    labels = [f'D={d_val}' for d_val in D_values]

    # Dictionary for storing all experimental data, keyed by str(D)
    all_data = {}

    # Run experiments or load saved data
    if load_saved_data and os.path.exists(data_path):
        print("Loading saved experimental data...")
        # Each entry was saved as a pickled dict, hence allow_pickle=True.
        # Unpickling can execute arbitrary code: only load files you created.
        with np.load(data_path, allow_pickle=True) as loaded_data:
            for D in D_values:
                key = str(D)
                if key in loaded_data.files:
                    # np.savez wraps each dict in a 0-d object array; store
                    # the unwrapped dict so the plotting code can index it.
                    all_data[key] = loaded_data[key].item()
    else:
        # Run experiments and collect data
        for D in D_values:
            print(f"\nRunning experiments for D={D}...")
            time_points, avg_sq_error, avg_d_dagger, std_sq_error, std_d_dagger = run_multiple_experiments_parallel(n, D, d, p, q, gamma, J, num_experiments)

            # Store data in dictionary
            all_data[str(D)] = {
                'time_points': time_points,
                'avg_sq_error': avg_sq_error,
                'avg_d_dagger': avg_d_dagger,
                'std_sq_error': std_sq_error,
                'std_d_dagger': std_d_dagger
            }

        # Save data to file
        if save_data:
            print("\nSaving experimental data...")
            np.savez(data_path, **all_data)
            print(f"Data saved to {data_path}")

    # Plot charts: squared error on the left, d-dagger on the right
    fig, axs = plt.subplots(1, 2, figsize=(16, 6), sharex=True)

    for D, color, label in zip(D_values, colors, labels):
        D_data = all_data.get(str(D))
        if D_data is None:
            continue  # nothing stored for this D (e.g. missing from the file)

        time_points = D_data['time_points']

        # Error bars show the standard error across runs
        axs[0].errorbar(time_points, D_data['avg_sq_error'], yerr=D_data['std_sq_error'],
                        color=color, label=label,
                        fmt='-', capsize=3, elinewidth=1, markeredgewidth=1)
        axs[1].errorbar(time_points, D_data['avg_d_dagger'], yerr=D_data['std_d_dagger'],
                        color=color, label=label,
                        fmt='-', capsize=3, elinewidth=1, markeredgewidth=1)

    # Set common x-axis properties (the axes share x, so this applies to both)
    axs[0].set_xscale('log')  # Set x-axis to logarithmic scale
    key_time_points = [10, 100, 1000, 10000]
    key_time_labels = ['$10$', r'$10^2$', r'$10^3$', r'$10^4$']
    axs[0].set_xticks(ticks=key_time_points, labels=key_time_labels, rotation=45)

    # Set properties for the first subplot
    axs[0].set_title(f'Squared Error of PC Estimator over Time (Log Scale, Avg of {num_experiments} runs)')
    axs[0].set_ylabel('Squared Error')
    axs[0].grid(True, which="both", ls="--")
    axs[0].legend()

    # Set properties for the second subplot
    axs[1].set_title(f'$d^\\dagger$ over Time Points (Log Scale, Avg of {num_experiments} runs)')
    axs[1].set_ylabel('$d^\\dagger$')
    axs[1].grid(True, which="both", ls="--")
    axs[1].legend()

    plt.tight_layout()
    plt.savefig(f'multiple_experiments_{num_experiments}_runs_with_error_bars.pdf')
    plt.show()