import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import random

# --- SEED ---
# One shared seed for every random number generator the script draws from
# (Python's `random`, NumPy's global RNG, and PyTorch's default generator).
seed = 2025
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)

# --- Parameters ---
n_samples = 1000          # number of training points
n_features = 900          # input dimension
n_layers = 4              # passed to DeepLinearNetwork as `num_layers`
hidden_dim = n_features   # hidden width equals the input dimension (square layers)
noise_std = 0.1           # standard deviation of the additive label noise
n_epochs = 500
learning_rate = 1e-4
log_interval = 20         # diagnostics are recorded every `log_interval` epochs


# Teacher coefficients (power-law decay): coeff[j-1] = j**(-q) for j = 1..n_features.
q = 1.1
coeff = 1 / (np.arange(1, n_features+1)**q)
# Work on a copy so that truncating the teacher leaves `coeff` untouched.
beta_star = coeff.copy()

# Keep only the first `n_nonzero` coefficients; the rest of the teacher is zero.
n_nonzero = 200
beta_star[n_nonzero:] = 0


beta_star_torch = torch.tensor(beta_star, dtype=torch.float32)

# Effective noise variance (RKHS view, isotropic design).
# NOTE(review): `np.linalg.norm(beta_star, 1)` is the L1 norm (sum of absolute
# values), not the Euclidean norm -- confirm this is the intended quantity.
sigma_eff2 = (noise_std**2 + np.linalg.norm(beta_star, 1)**2) / n_samples


# Generate random-design data: features are i.i.d. signs in {-1, +1}, labels are
# the teacher's linear response plus Gaussian noise of std `noise_std`.
# The RNG is consumed in the order (features, noise).
X_tr = (np.random.randint(0, 2, size=(n_samples, n_features))*2 - 1).astype(np.float32)
y_tr = X_tr @ beta_star + noise_std * np.random.randn(n_samples)

# `torch.tensor` (same constructor as for `beta_star_torch`) replaces the legacy
# `torch.FloatTensor(...)` call; it always copies, so the tensors never alias
# the NumPy arrays, and `y_tr` (float64) is cast down to float32 explicitly.
X_tr_torch = torch.tensor(X_tr, dtype=torch.float32)
y_tr_torch = torch.tensor(y_tr, dtype=torch.float32)

# --- Model ---
class DeepLinearNetwork(nn.Module):
    """Bias-free deep linear network with a scalar linear readout.

    ``self.layers`` is a sequential stack made of one ``input_dim -> hidden_dim``
    linear map followed by ``num_layers - 1`` ``hidden_dim -> hidden_dim`` maps.
    ``self.final_layer`` maps the ``hidden_dim`` representation to one output.
    All weights start as a (possibly rectangular) identity plus small Gaussian
    noise.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        input_map = nn.Linear(input_dim, hidden_dim, bias=False)
        hidden_maps = [
            nn.Linear(hidden_dim, hidden_dim, bias=False)
            for _ in range(num_layers - 1)
        ]
        self.layers = nn.Sequential(input_map, *hidden_maps)
        self.final_layer = nn.Linear(hidden_dim, 1, bias=False)
        self.initialize_weights()

    def initialize_weights(self):
        """Reset every ``nn.Linear`` weight to identity + 0.01 * standard normal noise."""
        linear_modules = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        with torch.no_grad():
            for mod in linear_modules:
                # torch.eye(rows, cols) gives ones on the main diagonal even
                # when the weight matrix is not square.
                identity = torch.eye(mod.out_features, mod.in_features)
                perturbation = 0.01*torch.randn_like(mod.weight)
                mod.weight.copy_(identity + perturbation)

    def forward_representation(self, x):
        """Return the output of the stacked linear layers (before the readout)."""
        hidden = self.layers(x)
        return hidden

    def forward(self, x):
        """Apply the readout to the representation; the last dimension has size 1."""
        features = self.forward_representation(x)
        return self.final_layer(features)

model = DeepLinearNetwork(n_features, hidden_dim, n_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# --- Training loop ---
# Full-batch training on (X_tr_torch, y_tr_torch). Every `log_interval` epochs
# one row is appended to `history`:
#   'epoch' : epoch index
#   'loss'  : training MSE at the current parameters
#   'esd'   : d_dagger, the dimension computed from the spectrum of A^T A (below)
#   'risk'  : ||v - beta_star||^2 + noise_std^2 for the end-to-end predictor v
history = {'epoch': [], 'loss': [], 'esd': [], 'risk': []}
print("Starting training...")

for epoch in range(n_epochs+1):
    # squeeze(-1) removes only the size-1 output dimension, so the prediction
    # keeps its sample dimension whatever the number of samples is.
    y_hat = model(X_tr_torch).squeeze(-1)
    loss = criterion(y_hat, y_tr_torch)

    # Diagnostics are taken BEFORE the parameter update so that the logged loss,
    # ESD and risk all describe the same weights. Epoch 0 is therefore the
    # initialisation and epoch k the state after k optimizer steps.
    if epoch % log_interval == 0:
        with torch.no_grad():
            # Representation A = product of hidden weights (last layer leftmost)
            A = torch.eye(n_features)
            for layer in model.layers:
                A = layer.weight @ A
            G = A.T @ A

            # Eigendecomposition. torch.linalg.eigh returns eigenvalues in
            # ascending order; flip both outputs to get descending order.
            lam, U = torch.linalg.eigh(G)
            lam = lam.flip(0)
            U = U.flip(1)

            # Teacher coordinates in the eigenbasis of G.
            theta = U.T @ beta_star_torch
            # tail_sq[i] = sum_{j >= i} theta[j]^2 (reverse cumulative sum).
            tail_sq = torch.cumsum(theta.flip(0)**2, dim=0).flip(0)
            # H[d-1] = (teacher energy outside the top-d eigendirections) / d,
            # for d = 1 .. len(theta) - 1.
            H = tail_sq[1:] / torch.arange(1, len(theta))
            # d_dagger = smallest d with H[d-1] <= sigma_eff2; if no such d
            # exists, fall back to the full dimension.
            below = torch.where(H <= sigma_eff2)[0]
            d_dagger = int(below[0].item() + 1) if below.numel() > 0 else len(theta)

            # Compute effective predictor v_t = A^T w, with w the readout weights
            # (no `.data` needed: gradients are already disabled here).
            v = A.T @ model.final_layer.weight.flatten()
            risk = torch.norm(v - beta_star_torch).item()**2 + noise_std**2

            history['epoch'].append(epoch)
            history['loss'].append(loss.item())
            history['esd'].append(d_dagger)
            history['risk'].append(risk)

            print(f"Epoch {epoch}, Loss {loss.item():.4f}, ESD {d_dagger}, Risk {risk:.4f}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# --- Plotting ---
# Two curves over the logged epochs on a shared x-axis: ESD on the left y-axis
# and risk on the right (twin) y-axis. Each axis label and its tick labels use
# the same colour as the curve drawn on that axis.
fig, ax1 = plt.subplots(figsize=(12,7))
color = 'tab:red'
ax1.set_xlabel('Epoch')
ax1.set_ylabel('ESD', color=color)
ax1.plot(history['epoch'], history['esd'], color=color, marker='o', label='ESD')
ax1.tick_params(axis='y', labelcolor=color)
ax1.grid(True, linestyle='--')

ax2 = ax1.twinx()
# Same colour for the risk curve, its axis label and its tick labels, so the
# right-hand axis can be matched to the curve at a glance.
color = 'green'
ax2.set_ylabel('Risk', color=color)
ax2.plot(history['epoch'], history['risk'], color=color, linestyle='-', marker='s', label='True Risk')
ax2.tick_params(axis='y', labelcolor=color)

# Merge the handles of both axes into a single legend (twin axes keep
# separate legend entries).
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right')


fig.suptitle('Pathwise RKHS-ESD with True Risk')


# Save the plot (before plt.show(), while the figure is still the current one)
output_filename = "pathwise_esd_network.pdf"
plt.savefig(output_filename)
print(f"Plot saved to {output_filename}")


plt.show()
