import torch
import numpy as np
import matplotlib.pyplot as plt


def gradient(X, y, w):
    """Return the gradient of 0.5 * ||X.T @ w - y||^2 with respect to ``w``.

    Args:
        X: matrix whose transpose multiplies ``w`` (``X.T @ w``).
        y: targets compared against ``X.T @ w``.
        w: current parameter vector.

    Returns:
        ``X @ (X.T @ w - y)``.
    """
    residual = X.T @ w - y
    return X @ residual


def gradient_descent(X, y, w_0, w_gt, lr, tol=1e-3, max_steps=None):
    """Run fixed-step gradient descent on 0.5 * ||X.T @ w - y||^2.

    Repeats ``w <- w - lr * gradient(X, y, w)`` starting from ``w_0`` and stops
    once two consecutive gradients differ by at most ``tol`` in norm.

    Args:
        X: matrix used as ``X.T @ w`` (presumably shape (d, n); confirm with
            callers).
        y: targets compared against ``X.T @ w``.
        w_0: initial iterate.
        w_gt: unused; kept only so existing positional callers keep working.
        lr: step size.
        tol: stopping threshold on ``||g_k - g_{k-1}||``. Defaults to the
            previously hard-coded 1e-3.
        max_steps: optional cap on the number of gradient steps. ``None``
            (default) keeps the original behaviour of iterating until the
            stopping criterion is met.

    Returns:
        Tuple ``(w, g, steps)``:
        - ``w``: the iterates, starting with ``w_0``, ``steps + 1`` entries.
        - ``g``: the gradient at each step, ``steps`` entries.
        - ``steps``: the number of gradient steps taken.

    Raises:
        FloatingPointError: if a gradient becomes non-finite (typically
            because ``lr`` is too large and the iteration diverges). Without
            this guard the loop never terminates, because a NaN gradient
            difference never compares ``<= tol``.
    """
    iterates = [w_0]
    grads = []
    steps = 0
    w_k = w_0
    while max_steps is None or steps < max_steps:
        steps += 1
        g_k = gradient(X, y, w_k)
        if not torch.isfinite(g_k).all():
            raise FloatingPointError(
                f"gradient became non-finite at step {steps}; "
                f"gradient descent is diverging (lr={lr} is likely too large)"
            )
        grads.append(g_k)
        w_k = w_k - lr * g_k
        iterates.append(w_k)

        # Converged once the gradient stops changing between consecutive steps.
        if len(grads) >= 2 and (grads[-1] - grads[-2]).norm() <= tol:
            break

    return iterates, grads, steps


def data_generation(d, n, lr):
    """Build a random noiseless least-squares problem and solve it by gradient descent.

    Samples ``X`` of shape (d, n) and ``w_star`` of shape (d, 1), both scaled
    by ``1 / sqrt(d)``, sets ``y = X.T @ w_star``, then runs
    ``gradient_descent`` from the zero vector with step size ``lr``.

    Returns:
        Tuple ``(X, y, W, G, max_iter, w_star)`` where:
        - ``W`` is the iterates stacked along the last axis and squeezed.
        - ``G`` is the gradients stacked along the last axis and squeezed.
        - ``max_iter`` is the step count reported by ``gradient_descent``.
        - ``w_star`` is squeezed.

    Note that ``squeeze()`` drops every size-1 axis, so for ``d == 1`` the
    results lose their leading dimension as well.
    """
    scale = np.sqrt(d)
    X = torch.randn(d, n) / scale
    w_star = torch.randn(d, 1) / scale
    y = X.T @ w_star

    start = torch.zeros(d, 1)
    iterates, grads, n_steps = gradient_descent(X, y, start, w_star, lr)

    W = torch.stack(iterates, dim=-1).squeeze()
    G = torch.stack(grads, dim=-1).squeeze()
    return X, y, W, G, n_steps, w_star.squeeze()
