import cupy as cp
import matplotlib.pyplot as plt
import json
import numpy as np  # For plotting and saving results

# Set parameters
k = 10  # informative features (the first k coordinates carry the signal)
n_test = 1_000  # test samples drawn per trial
epsilon_variance = 0.1  # variance of the additive label noise
n_trial = 10  # independent trials per configuration
max_block_size = 50  # largest block of the block-diagonal covariance matrix

# Sweep grids
n_values = [10 * 2 ** p for p in range(6, -1, -1)]  # 640, 320, ..., 20, 10
T_norms = [1]  # spectral norm for the first k dimensions
U_trace_factors = [5, 25, 100]  # multiplied by sqrt(n) to set the tail trace

if __name__ == '__main__':

    # Select the GPU *before* seeding. cupy.random.seed() only resets the
    # global generator of the current device, so seeding first would seed the
    # default device and leave the device used below with an unseeded
    # generator, making runs non-reproducible.
    gpu_device = 2
    cp.cuda.Device(gpu_device).use()

    # Fix random seed for reproducibility (applies to gpu_device)
    cp.random.seed(42)
    
    # Initialize results dictionary.
    # Layout: results[T_norm][U_trace_factor] -> {'n_values', 'mean', 'std'}
    results = {}

    # Run simulations
    for T_norm in T_norms:
        results[T_norm] = {}
        for U_trace_factor in U_trace_factors:
            excess_risks_mean = []
            excess_risks_std = []
            for n in n_values:
                # First k dimensions: random PSD matrix A A^T, rescaled so
                # its largest eigenvalue (spectral norm) equals T_norm.
                A_T = cp.random.randn(k, k)
                Sigma_T = A_T @ A_T.T
                lambda_max_T = cp.linalg.eigvalsh(Sigma_T).max()
                Sigma_T = (T_norm / lambda_max_T) * Sigma_T

                # Last n^2 dimensions: block-diagonal covariance whose total
                # trace is U_trace_factor * sqrt(n).
                U_trace = U_trace_factor * cp.sqrt(n)
                block_size = min(max_block_size, n**2)
                num_blocks, remainder = divmod(n**2, block_size)
                # Full-size blocks first, then one smaller block when n^2 is
                # not divisible by block_size (same RNG draw order as before).
                block_dims = [block_size] * num_blocks
                if remainder > 0:
                    block_dims.append(remainder)

                blocks = []
                total_trace = 0
                for dim in block_dims:
                    # Generate random covariance matrix for each block
                    A_U = cp.random.randn(dim, dim)
                    Sigma_U_block = A_U @ A_U.T
                    blocks.append(Sigma_U_block)
                    total_trace += cp.trace(Sigma_U_block)

                # Scale blocks to adjust total trace to U_trace
                scaling_factor = U_trace / total_trace
                blocks = [block * scaling_factor for block in blocks]

                def simulate():
                    """Run one trial and return its excess test risk.

                    Draws a fresh training set of n samples and a test set of
                    n_test samples in d = k + n^2 dimensions, fits the
                    minimum-norm interpolator on the training set, and returns
                    (test MSE of the fit) - (test MSE of the true beta) as a
                    NumPy scalar. Uses n, Sigma_T and blocks from the
                    enclosing loop iteration.
                    """
                    # Dimensions
                    d = k + n**2
                    # True coefficients: mass only on the first k coordinates
                    beta = cp.zeros(d)
                    beta[:k] = 1 / cp.sqrt(k)
                    # Training data: independent Gaussian coordinates, std 1
                    # on the first k and std n^(-0.75) on the remaining n^2.
                    x_train = cp.zeros((n, d))
                    x_train[:, :k] = cp.random.normal(0, 1, (n, k))
                    x_train[:, k:] = cp.random.normal(0, n**(-0.75), (n, n**2))
                    epsilon_train = cp.random.normal(0, cp.sqrt(epsilon_variance), n)
                    y_train = x_train @ beta + epsilon_train
                    # Test data: first k dimensions from the shared Sigma_T,
                    # last n^2 dimensions block by block from the shared blocks.
                    x_test = cp.zeros((n_test, d))
                    x_test[:, :k] = cp.random.multivariate_normal(cp.zeros(k), Sigma_T, n_test)
                    x_test_U = []
                    for Sigma_U_block in blocks:
                        mean_block = cp.zeros(Sigma_U_block.shape[0])
                        samples_block = cp.random.multivariate_normal(mean_block, Sigma_U_block, n_test)
                        x_test_U.append(samples_block)
                    x_test[:, k:] = cp.hstack(x_test_U)
                    epsilon_test = cp.random.normal(0, cp.sqrt(epsilon_variance), n_test)
                    y_test = x_test @ beta + epsilon_test
                    # Minimum norm interpolator: w = X^T (X X^T)^{-1} y
                    XXT = x_train @ x_train.T  # n x n matrix
                    try:
                        inv_XXT = cp.linalg.inv(XXT)
                    # CuPy's linalg checks raise NumPy's LinAlgError, so catch
                    # that class rather than an attribute of cupy.linalg.
                    # NOTE(review): with CuPy's default linalg error mode a
                    # singular XXT may not raise at all -- confirm whether
                    # cupyx.errstate(linalg='raise') is wanted here.
                    except np.linalg.LinAlgError:
                        inv_XXT = cp.linalg.pinv(XXT)
                    # Multiply the n x n inverse into y first so the product
                    # never materialises a d x n intermediate.
                    w = x_train.T @ (inv_XXT @ y_train)
                    # Compute MSE on test set
                    y_pred = x_test @ w
                    mse_model = cp.mean((y_test - y_pred) ** 2)
                    # Compute MSE of true model
                    y_true = x_test @ beta
                    mse_true = cp.mean((y_test - y_true) ** 2)
                    # Excess risk, moved to host memory as a NumPy scalar
                    excess_risk = cp.asnumpy(mse_model - mse_true)
                    print(f"T_norm = {T_norm}, U_trace_factor = {U_trace_factor}, n = {n}, excess risk = {excess_risk}")
                    return excess_risk

                excess_risks = np.array([simulate() for _ in range(n_trial)])
                # Plain floats keep the results trivially JSON-serialisable
                excess_risks_mean.append(float(np.mean(excess_risks)))
                excess_risks_std.append(float(np.std(excess_risks)))

            results[T_norm][U_trace_factor] = {'n_values': n_values,
                                               'mean': excess_risks_mean,
                                               'std': excess_risks_std}
            # Write results to JSON file; rewritten after every
            # U_trace_factor so finished configurations are already on disk.
            with open('results_minor_U.json', 'w') as f:
                json.dump(results, f)
    
    # Load results from JSON file. The numeric dictionary keys were written
    # by json.dump, which stores them as strings, hence the str() lookups.
    with open('results_minor_U.json', 'r') as f:
        results = json.load(f)

    from scipy.stats import linregress
    epsilon = 0  # 1e-8  # Small value to avoid log(0)
    line_styles = ['dotted', '--', '-']
    T_norm = 1

    plt.figure(figsize=(4, 7.5))
    plt.rcParams["font.family"] = "Times New Roman"

    for idx, U_factor in enumerate(U_trace_factors):
        entry = results[str(T_norm)][str(U_factor)]
        # Local name: do not overwrite the module-level n_values list
        n_arr = np.array(entry['n_values'])
        mean_excess_risk = np.array(entry['mean'])
        std_excess_risk = np.array(entry['std'])

        log_n = np.log(n_arr)
        log_excess_risk = np.log(mean_excess_risk + epsilon)

        # Linear regression in log-log space; only the slope is reported
        slope = linregress(log_n, log_excess_risk).slope

        # Plot the curve
        plt.plot(log_n, log_excess_risk, label=fr'$\mathrm{{tr}}[U]~/~\mathrm{{tr}}[V] = {U_factor}$',
                color='black', linestyle=line_styles[idx])
        # One-standard-deviation band around the mean, in log space.
        # NOTE(review): where std >= mean the lower edge is the log of a
        # non-positive number (NaN/-inf plus a NumPy RuntimeWarning) --
        # confirm whether the band should be clipped instead.
        plt.fill_between(log_n,
                        np.log(mean_excess_risk - std_excess_risk + epsilon),
                        np.log(mean_excess_risk + std_excess_risk + epsilon),
                        color='gray', alpha=0.2)

        # Annotate the slope next to the first point of the curve
        plt.text(log_n[0]-0.6, log_excess_risk[0]-0.23, f'{slope:.2f}', fontsize=14, color='black')

    # Labels and title
    plt.xlabel('Log n', fontsize=16)
    plt.ylabel('Log Excess Risk', fontsize=16)
    plt.legend(fontsize=14)
    plt.tick_params(axis='both', which='major', labelsize=14)
    # Set aspect ratio to equal
    plt.gca().set_aspect('equal', adjustable='datalim')
    # Set x-axis range
    plt.xlim(2, 6.8)
    # Disable grid and adjust layout
    plt.grid(False)
    plt.tight_layout()

    # Save and show plot
    plt.savefig('simulation_minor_U.png')
    plt.savefig('simulation_minor_U.pdf')
    plt.show()
