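'''
Simulation of test excess risk versus training-set size n for ridge regression
(regularization lambda = n^0.75, ~0, n^0.5 and n) and principal component
regression (PCR). Results are written to results_major.json and plotted to
simulation_major.png / simulation_major.pdf.

Reproducibility: runs have been observed not to be fully determined by the
seed; this was attributed to the SVD. Note that CuPy keeps a separate random
state per GPU device and cupy.random.seed only seeds the device that is
current at the time of the call, so the seed must be set after selecting the
GPU.
'''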

import json
import math
import os

import cupy as cp
import matplotlib.pyplot as plt
import numpy as np  # For plotting and saving results

 # Set parameters
k = 10  # Number of informative features
n_test = 1000  # Number of test samples
epsilon_variance = 0.1  # Variance of the noise
n_trial = 10 # Number of trials

# Define n values
n_values = [5120, 2560, 1280, 640, 320, 160, 80]

# Ridge variants: (results key, lambda tag used in the progress print,
# regularization strength as a function of the training-set size n).
RIDGE_SCHEDULES = (
    ('Ridge_75', '0.75', lambda n: n ** 0.75),
    ('Ridge_0', '0', lambda n: 1e-8),  # "ridgeless": only a tiny jitter
    ('Ridge_50', '0.5', lambda n: n ** 0.5),
    ('Ridge_100', '1', lambda n: float(n)),
)

# Every method reported; this order fixes the key order of the JSON output.
METHOD_KEYS = tuple(key for key, _, _ in RIDGE_SCHEDULES) + ('PCR',)

# Plot styling, drawn in this order:
# (results key, legend label, line colour, line style, band colour).
PLOT_SPECS = (
    ('Ridge_0', 'Ridgeless', 'black', '-.', 'gray'),
    ('Ridge_50', r'Ridge, $\lambda=n^{0.5}$', 'black', '--', 'gray'),
    ('Ridge_75', r'Ridge, $\lambda=n^{0.75}$', 'black', '-', 'gray'),
    ('Ridge_100', r'Ridge, $\lambda=n$', 'black', 'dotted', 'gray'),
    ('PCR', 'PCR', '#0A1172', '-', '#0A1172'),
)

# Vertical range of the log-log plot.
Y_LIMITS = (-8, 1.5)


def _ridge_coefficients(x, y, lam):
    """Return the ridge estimate argmin_w ||y - x w||^2 + lam * ||w||^2.

    Solved in the primal (d x d) form (x^T x + lam I)^{-1} x^T y. This is
    algebraically identical to the dual form x^T (x x^T + lam I)^{-1} y used
    previously. The dual form inverts an n x n matrix of rank <= d, which is
    severely ill-conditioned for lam = 1e-8 and n >> d, and expensive for
    large n.
    """
    d = x.shape[1]
    system = x.T @ x + lam * cp.eye(d)
    rhs = x.T @ y
    try:
        return cp.linalg.solve(system, rhs)
    except np.linalg.LinAlgError:
        # Best-effort fallback. NOTE(review): CuPy presumably only raises this
        # when linalg error checking is enabled (cupyx.errstate) — confirm.
        return cp.linalg.pinv(system) @ rhs


def _excess_risk(y_true, y_pred):
    """Return mean((y_true - y_pred) ** 2) as a Python float."""
    return float(cp.mean((y_true - y_pred) ** 2))


def simulate(n, k, n_test, noise_variance):
    """Run one trial at training size n and return {method key: excess risk}.

    Data generated below:
      * x_train: first k columns ~ N(0, 1); the remaining isqrt(n) columns
        ~ N(0, std n^{-1/4});
      * x_test: all d columns ~ N(0, 1), so the test inputs differ from the
        training inputs on the minor columns;
      * beta = 1/sqrt(k) on the first k coordinates, 0 elsewhere;
      * y_train = x_train @ beta + N(0, noise_variance) noise.

    Excess risk is estimated on the test inputs as
    mean((x_test @ beta - y_pred)^2). The previous estimate,
    mse(y_test, y_pred) - mse(y_test, x_test @ beta), has the same expectation
    plus a zero-mean noise cross-term. That term can make small risks
    negative, which breaks the log-log plot and slope fit.
    """
    d_minor = math.isqrt(n)  # same as int(sqrt(n)), computed exactly on host
    d = k + d_minor

    beta = cp.zeros(d)
    beta[:k] = 1 / math.sqrt(k)

    x_train = cp.zeros((n, d))
    x_train[:, :k] = cp.random.normal(0, 1, (n, k))
    x_train[:, k:] = cp.random.normal(0, n ** (-0.25), (n, d_minor))
    y_train = x_train @ beta + cp.random.normal(0, math.sqrt(noise_variance), n)

    x_test = cp.random.normal(0, 1, (n_test, d))
    y_true = x_test @ beta

    risks = {}
    for key, exponent, schedule in RIDGE_SCHEDULES:
        w = _ridge_coefficients(x_train, y_train, schedule(n))
        risks[key] = _excess_risk(y_true, x_test @ w)
        print(f"Ridge: n = {n}, excess risk = {risks[key]}, lambda={exponent}")

    # PCR: regress y on the top-k principal directions of the centred x_train.
    # NOTE(review): x is centred with the training mean but no intercept is
    # fitted and y is not centred, so predictions omit the x_mean @ beta
    # offset. Kept as originally written — confirm this is intended.
    x_mean = cp.mean(x_train, axis=0)
    x_centered = x_train - x_mean
    _, _, vt = cp.linalg.svd(x_centered, full_matrices=False)
    v_k = vt[:k]  # top-k right singular vectors, shape (k, d)
    coefficients = cp.linalg.lstsq(x_centered @ v_k.T, y_train, rcond=None)[0]
    y_pred = (x_test - x_mean) @ v_k.T @ coefficients
    risks['PCR'] = _excess_risk(y_true, y_pred)
    print(f"PCR: n = {n}, excess risk = {risks['PCR']}")

    return risks


def run_experiment(n_values, n_trial, k, n_test, noise_variance):
    """Repeat `simulate` n_trial times for every n and aggregate.

    Returns {method key: {'n_values': [...], 'mean': [...], 'std': [...]}}.
    There is one mean and one population std (ddof=0) per entry of n_values.
    """
    results = {key: {'n_values': list(n_values), 'mean': [], 'std': []}
               for key in METHOD_KEYS}
    for n in n_values:
        trial_risks = {key: [] for key in METHOD_KEYS}
        for _ in range(n_trial):
            for key, risk in simulate(n, k, n_test, noise_variance).items():
                trial_risks[key].append(risk)
        for key, risks in trial_risks.items():
            results[key]['mean'].append(float(np.mean(risks)))
            results[key]['std'].append(float(np.std(risks)))
    return results


def plot_results(results, output_stem='simulation_major'):
    """Plot log mean excess risk vs log n per method; save PNG and PDF."""
    from scipy.stats import linregress

    y_min, y_max = Y_LIMITS
    plt.figure(figsize=(4, 7.5))
    plt.rcParams["font.family"] = "Times New Roman"

    for key, label, color, linestyle, band_color in PLOT_SPECS:
        entry = results[key]
        log_n = np.log(np.asarray(entry['n_values'], dtype=float))
        mean = np.asarray(entry['mean'], dtype=float)
        std = np.asarray(entry['std'], dtype=float)
        log_risk = np.log(mean)
        slope = linregress(log_n, log_risk).slope

        plt.plot(log_n, log_risk, label=label, color=color, linestyle=linestyle)
        # mean - std can be <= 0, where log is undefined (NaN). Floor it at
        # the bottom of the visible range, so the band simply runs off the
        # axis there; points already below the range look the same as before.
        lower = np.maximum(mean - std, np.exp(y_min))
        plt.fill_between(log_n, np.log(lower), np.log(mean + std),
                         color=band_color, alpha=0.2)
        # Fitted log-log slope, written next to the first point of n_values.
        plt.text(log_n[0] - 0.6, log_risk[0] - 0.28, f'{slope:.2f}',
                 fontsize=14, color='black')

    plt.xlabel('Log n', fontsize=16)
    plt.ylabel('Log Excess Risk', fontsize=16)
    plt.legend(fontsize=14, loc='upper right')
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.gca().set_aspect('equal', adjustable='datalim')
    plt.ylim(y_min, y_max)
    plt.grid(False)
    plt.tight_layout()

    plt.savefig(f'{output_stem}.png')
    plt.savefig(f'{output_stem}.pdf')
    plt.show()


if __name__ == '__main__':
    # GPU to run on (default 2); override with the GPU_DEVICE env variable.
    gpu_device = int(os.environ.get('GPU_DEVICE', 2))
    cp.cuda.Device(gpu_device).use()
    # Seed AFTER selecting the device. CuPy keeps one RandomState per device,
    # and cp.random.seed only seeds the current one. Seeding before .use()
    # left the GPU that is actually used unseeded.
    cp.random.seed(42)

    results = run_experiment(n_values, n_trial, k, n_test, epsilon_variance)

    with open('results_major.json', 'w') as f:
        json.dump(results, f)

    # Plot from the saved file so the figure reflects exactly what was written.
    with open('results_major.json', 'r') as f:
        results = json.load(f)

    plot_results(results)

    