"""
Implementation of the Over-parametrized Gradient Flow experiment from Section 6.1.

This script implements over-parametrized gradient descent optimization for parameter estimation with 
different signal structures. It tests how Effective Span Dimension (ESD) varies 
with noise level across different q-misalignment values.

Key Components:
- Runs gradient descent on the loss compute_loss, using the gradients from compute_gradients
- Tracks how the ESD varies with the noise variance for several q-misalignment values
- Plots ESD against noise variance for each q value, with one curve per recorded training time
- Saves the figure as 'study_1.pdf', with one subplot per q value tested

Usage:
python Section_6_1_OPGF.py
"""


import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
from generate_sequence import  generate_lambda, generate_theta3
from esd_modular_functions import compute_esd

# Fix the global NumPy RNG so every run draws the same data.
np.random.seed(333)

# Experiment configuration, read by the experiment loop below.
n = 10_000     # sample-size parameter: noise sd is 1/sqrt(n); also sets run length
D = 0          # exponent applied to b in theta = a * b**D * beta
p = 2.5        # forwarded to generate_theta3
gamma = 1      # forwarded to generate_lambda
J = 15         # forwarded to generate_theta3
d = 5_000      # length of theta, y and the parameter vectors

# Define loss function and its gradients
def compute_loss(a, b, beta, y, D):
    """Return the squared-error loss of the reparametrized estimate against y.

    The estimate is ``theta = a * b**D * beta`` (element-wise) and the loss
    is ``0.5 * sum((y - theta)**2)``.
    """
    residual = y - a * (b ** D) * beta
    return 0.5 * np.sum(np.square(residual))

def compute_loss_true_theta(a, b, beta, theta_true, D):
    """Return half the squared Euclidean distance from the estimate to theta_true.

    The estimate is ``a * b**D * beta`` (element-wise); unlike compute_loss,
    the reference here is the noiseless parameter rather than the observations.
    """
    theta_est = a * (b ** D) * beta
    gap = theta_true - theta_est
    return 0.5 * np.sum(gap * gap)

def compute_gradients(a, b, beta, y, D):
    """Return the gradients of ``compute_loss`` with respect to a, b and beta.

    With ``theta = a * b**D * beta`` and loss ``0.5 * sum((y - theta)**2)``,
    the element-wise partial derivatives are::

        dL/da    = -(y - theta) * b**D * beta
        dL/db    = -(y - theta) * a * D * b**(D-1) * beta
        dL/dbeta = -(y - theta) * a * b**D

    Args:
        a, b, beta: current parameter values (arrays or scalars that
            broadcast together).
        y: observations compared against theta.
        D: exponent applied to b.

    Returns:
        Tuple ``(grad_a, grad_b, grad_beta)``.
    """
    theta = a * (b ** D) * beta
    error = y - theta

    grad_beta = -error * a * (b ** D)
    grad_a = -error * (b ** D) * beta
    if D == 0:
        # b**0 is constant, so its derivative is exactly zero. Evaluating
        # b ** (D - 1) == 1 / b would instead give inf (and 0 * inf == NaN)
        # wherever b == 0, and NumPy raises for integer-typed b.
        grad_b = np.zeros_like(error, dtype=float)
    else:
        grad_b = -error * a * D * (b ** (D - 1)) * beta  # Derivative with respect to b

    return grad_a, grad_b, grad_beta

# Grid of noise variances tau at which the ESD is evaluated:
# 20 log-spaced values from 1e-7 to 1e-2.
sigma_sq_values = np.logspace(-7, -2, 20)

# Observation-noise standard deviation sigma = 1/sqrt(n). It does not depend
# on q, so it is computed once instead of on every pass of the q loop.
sigma = 1 / n ** 0.5

# lambda is recorded (and the true-theta loss printed) every
# `snapshot_every` gradient steps.
snapshot_every = 2000

# One figure with a 2x2 grid of subplots, one subplot per q value.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# axes.flat walks the grid row by row, matching the order of the q values.
for ax, q in zip(axes.flat, [1, 1.5, 2, 3]):
    # Ground truth and noisy observations y = theta_true + N(0, sigma^2) noise.
    # NOTE(review): generate_theta3 / generate_lambda may also draw from the
    # global RNG, so this call order is kept to preserve reproducibility.
    theta_true = generate_theta3(q, p, J, d, sigma)
    y = theta_true + np.random.normal(0, sigma, d)

    # Initial parameters: a = sqrt(generate_lambda(d, gamma)),
    # b = n^{-1/(2(D+2))} in every coordinate, beta = 0.
    a_est = generate_lambda(d, gamma) ** 0.5
    b_est = np.full(d, n ** (-1 / (2 * (D + 2))))
    beta_est = np.zeros(d)

    # Gradient-descent step sizes and number of steps.
    learning_rate_a = 1e-2
    learning_rate_b = 1e-2
    learning_rate_beta = 1e-2
    num_iterations = int(n ** ((D + 1) / (D + 2)) / learning_rate_a)

    # Record initial loss
    initial_loss = compute_loss(a_est, b_est, beta_est, y, D)
    initial_loss_true_theta = compute_loss_true_theta(a_est, b_est, beta_est, theta_true, D)
    print(f"\nInitial Loss (y): {initial_loss}")
    print(f"Initial Loss (theta_true): {initial_loss_true_theta}")

    # Snapshots of lambda = (a * b**D)**2, one per `snapshot_every` steps.
    lambda_his = []

    for iteration in range(num_iterations):
        if iteration % snapshot_every == 0:
            # Snapshot BEFORE this step's update, so the stored state is the
            # one reached after exactly `iteration` updates (t = 0 is the
            # initial state), matching the time label used in the legend.
            lambda_his.append((a_est * (b_est ** D)) ** 2)
            current_loss_true_theta = compute_loss_true_theta(a_est, b_est, beta_est, theta_true, D)
            print(f"Iteration {iteration},  Loss (theta_true): {current_loss_true_theta}")

        grad_a, grad_b, grad_beta = compute_gradients(a_est, b_est, beta_est, y, D)
        a_est -= learning_rate_a * grad_a
        b_est -= learning_rate_b * grad_b
        beta_est -= learning_rate_beta * grad_beta
        # Keep b non-negative.
        b_est = np.abs(b_est)

    # One ESD-vs-tau curve per snapshot. Snapshot i was taken after
    # i * snapshot_every steps of size learning_rate_a, i.e. at time
    # t = i * snapshot_every * learning_rate_a.
    for i, lambda_val in enumerate(lambda_his):
        t = i * snapshot_every * learning_rate_a
        esd_values = np.array([compute_esd(theta_true ** 2, lambda_val, s_sq) for s_sq in sigma_sq_values])
        ax.loglog(sigma_sq_values, esd_values, label=f't = {t:g}')

    # The x axis is the noise variance tau at which the ESD is evaluated.
    ax.set_xlabel(r'$\tau$')
    ax.set_ylabel(r'ESD at $\tau$')
    ax.set_title(rf'q={q}')
    ax.legend()

plt.tight_layout()
plt.savefig('study_1.pdf')
plt.show()



