import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np


n = 1000
p_values = np.arange(10, 1001, 30)
T = 8
sigma = 0.02
num_simulations = 100
results = {
    'M8_mean': [],
    'M8_std': [],
    'G8_mean': [],
    'G8_std': []  }



def generate_orthogonal_weights(p, T):
    """Draw random orthonormal vectors in R^p, one per row.

    A p x T standard-Gaussian matrix is orthonormalised with numpy's
    (reduced-mode) QR decomposition, and the columns of Q are returned
    as rows.

    Args:
        p: Dimension of each vector.
        T: Number of vectors requested.

    Returns:
        Array of shape (T, p) with orthonormal rows when p >= T. When
        p < T, reduced QR yields only p columns, so the result has shape
        (p, p) instead; callers are expected to ensure p >= T.
    """
    gaussian = np.random.randn(p, T)
    basis = np.linalg.qr(gaussian)[0]
    return np.transpose(basis[:, :T])



# Main experiment. For every model dimension p, repeat num_simulations times:
#   1. draw T orthonormal ground-truth vectors w_1*, ..., w_T*;
#   2. learn the T noisy linear-regression tasks one after another;
#   3. score the final model w_T against every task.
# The per-p mean and (population) std of both scores are appended to
# `results`. The "M8"/"G8" names and result keys are historical (T == 8);
# all computations below follow the value of T.
for p in p_values:
    if p < T:
        print(f"Ignore p={p} (requires p >= {T} to generate orthogonal vectors)")
        continue
    print(f"\nRunning experiments with p = {p}")
    M8_list, G8_list = [], []
    for sim in range(num_simulations):
        w_true_list = generate_orthogonal_weights(p, T)  # (T, p): row i is w_{i+1}*
        # Sanity check: every pair of ground-truth vectors should be orthogonal.
        gram = w_true_list @ w_true_list.T
        for i, j in zip(*np.triu_indices(T, k=1)):
            dot_prod = gram[i, j]
            if not np.isclose(dot_prod, 0, atol=1e-8):
                print(f"Warning: w_{i + 1} and w_{j + 1} are not orthogonal, inner product = {dot_prod}")
        # Sequential training. Each task moves w by the minimum-norm step that
        # best fits its data:
        #   w_t = w_{t-1} + pinv(X_t^T) (y_t - X_t^T w_{t-1}).
        # When p <= n (X_t has full row rank), this is exactly the ordinary
        # least-squares fit (X_t X_t^T)^{-1} X_t y_t = w_t* + (X_t X_t^T)^{-1} X_t z_t.
        # When p > n, it is the min-norm continual update; the closed form
        # would need the singular matrix X_t X_t^T there. lstsq also avoids
        # squaring the condition number of X_t.
        w = np.zeros(p)
        w_history = []
        for t in range(1, T + 1):
            w_true = w_true_list[t - 1]
            X_t = np.random.randn(p, n)
            z_t = np.random.randn(n) * sigma
            y_t = X_t.T @ w_true + z_t
            residual = y_t - X_t.T @ w
            w = w + np.linalg.lstsq(X_t.T, residual, rcond=None)[0]
            w_history.append(w)  # `w + ...` builds a new array, so no aliasing
        w_hist = np.asarray(w_history)  # (T, p): w_1, ..., w_T
        w_final = w_hist[-1]            # w_T, the model after the last task
        # Per-task squared errors ||w_T - w_i*||^2 and ||w_i - w_i*||^2.
        final_err = np.sum((w_final - w_true_list) ** 2, axis=1)
        fit_err = np.sum((w_hist - w_true_list) ** 2, axis=1)
        # M: mean over the first T-1 tasks of how much worse w_T fits task i
        # than w_i did right after training on it. It is undefined for T == 1.
        M8 = np.mean(final_err[:-1] - fit_err[:-1]) if T > 1 else np.nan
        # G: mean squared distance of w_T to all T ground-truth vectors.
        G8 = np.mean(final_err)
        M8_list.append(M8)
        G8_list.append(G8)
        if (sim + 1) % 50 == 0:
            print(f"Completed {sim + 1} simulations")
    results['M8_mean'].append(np.mean(M8_list))
    results['M8_std'].append(np.std(M8_list))
    results['G8_mean'].append(np.mean(G8_list))
    results['G8_std'].append(np.std(G8_list))
    print(f"M8:{results['M8_mean'][-1]:.4f}±{results['M8_std'][-1]:.4f}, "
          f"G8:{results['G8_mean'][-1]:.4f}±{results['G8_std'][-1]:.4f}")


