import numpy as np
import pandas as pd
from sys import argv
import torch
import os

def min_norm_loss_lbfgs(x, y, lreg, S_true, rank=None, max_iter=500, tol=1e-9, verbose=False, print_every=10):
    """Fit a PSD matrix ``S = B @ B.T`` to quadratic measurements with L-BFGS.

    For every row ``x_i`` the model predicts

        yhat_i = (x_i^T S x_i - tr(S)) / sqrt(D)

    and the factor ``B`` of shape ``(D, rank)`` is optimised to minimise

        sum_i (yhat_i - y_i)^2 / D^2  +  (lreg / D) * tr(S).

    ``B`` is initialised with standard-normal entries scaled by
    ``1 / sqrt(D)``, drawn from torch's global RNG (call
    ``torch.manual_seed`` beforehand for reproducible results). All
    computation is done on the CPU in float64.

    Args:
        x: Array-like of shape ``(N, D)`` holding the inputs.
        y: Array-like holding exactly ``N`` targets. It is flattened, so
            both ``(N,)`` and ``(N, 1)`` layouts are accepted.
        lreg: Weight of the trace penalty.
        S_true: Reference matrix compared against ``B @ B.T``. It is only
            used for reporting and never enters the loss.
        rank: Number of columns of ``B``. Defaults to ``D``.
        max_iter: Maximum number of L-BFGS iterations.
        tol: Value used for both ``tolerance_grad`` and
            ``tolerance_change`` of the optimiser.
        verbose: If true, print the loss and the error against ``S_true``
            every ``print_every`` closure evaluations.
        print_every: Reporting period used when ``verbose`` is true.

    Returns:
        A tuple ``(overlap, S_hat)`` where ``S_hat`` is ``B @ B.T`` as a
        NumPy array and ``overlap`` is the float
        ``sum((S_hat - S_true) ** 2) / D``.

    Raises:
        ValueError: If ``x`` is not 2-D, if ``y`` does not contain exactly
            ``N`` values, or if ``rank`` is smaller than 1.
    """
    device = 'cpu'
    dtype = torch.float64

    x_t = torch.as_tensor(x, dtype=dtype, device=device)
    if x_t.dim() != 2:
        raise ValueError(f"x must be 2-D with shape (N, D), got shape {tuple(x_t.shape)}")
    N, D = x_t.shape

    # Flatten the targets: an (N, 1) column would otherwise broadcast against
    # the (N,) predictions into an (N, N) residual and break the scalar loss.
    y_t = torch.as_tensor(y, dtype=dtype, device=device).reshape(-1)
    if y_t.numel() != N:
        raise ValueError(f"y must contain {N} values to match x, got {y_t.numel()}")

    S_true_t = torch.as_tensor(S_true, dtype=dtype, device=device)

    if rank is None:
        rank = D
    if rank < 1:
        raise ValueError(f"rank must be a positive integer, got {rank}")

    B = torch.randn(D, rank, dtype=dtype, device=device) / np.sqrt(D)
    B.requires_grad_(True)

    optimizer = torch.optim.LBFGS([B], lr=1.0, max_iter=max_iter,
                                  line_search_fn='strong_wolfe',
                                  tolerance_grad=tol, tolerance_change=tol)

    sqrt_d = float(np.sqrt(D))
    d_squared = float(D) * float(D)

    # Counts closure evaluations; the optimiser may call the closure several
    # times per iteration, so this is not the L-BFGS iteration count.
    n_evals = 0

    def closure():
        nonlocal n_evals
        optimizer.zero_grad()

        proj = x_t @ B                      # rows are x_i^T B
        quad = (proj * proj).sum(dim=1)     # x_i^T (B B^T) x_i
        trace_s = (B * B).sum()             # tr(B B^T)
        residual = (quad - trace_s) / sqrt_d - y_t

        loss = (residual @ residual) / d_squared + (lreg / D) * trace_s
        loss.backward()

        if verbose:
            n_evals += 1
            if n_evals % print_every == 0:
                with torch.no_grad():
                    err = ((B @ B.T - S_true_t) ** 2).sum() / D
                    print(f"iter {n_evals:4d}:  loss={loss.item():.6e}  overlap={err.item():.6e}")
        return loss

    # A single step() call runs the whole optimisation (up to max_iter).
    optimizer.step(closure)

    with torch.no_grad():
        # Form the estimate once and reuse it for both return values.
        S_hat_t = B @ B.T
        overlap = float(((S_hat_t - S_true_t) ** 2).sum().item() / D)
        S_hat = S_hat_t.cpu().numpy()

    return overlap, S_hat



if __name__ == "__main__":
    # Command line: D gamma noise alpha
    #   D     - input dimension
    #   gamma - exponent of the power-law spectrum w_k = k ** (-gamma)
    #   noise - variance of the Gaussian noise added to the targets
    #   alpha - sample ratio; the number of samples is N = alpha * D
    if len(argv) < 5:
        raise SystemExit(f"usage: {argv[0]} D gamma noise alpha")

    D = int(argv[1])
    gamma = float(argv[2])
    noise = float(argv[3])
    alpha = int(argv[4])

    samples = 1
    overlap_list = []

    # Settings that do not change between samples.
    N = int(alpha * D)
    lreg = 1
    typ = "const"  # label that only appears in the output folder name

    folder_name = f"data/typ_{typ}_D_{D}_noise_{float(noise)}_gamma_{float(gamma)}"
    # exist_ok avoids a crash when several jobs (e.g. different alpha values,
    # which share this folder) try to create it at the same time.
    os.makedirs(folder_name, exist_ok=True)
    csv_path = f'{folder_name}/alpha={alpha}_lreg={lreg}.csv'

    for _ in range(samples):
        x = np.random.randn(N, D)

        # Ground truth S_true = Q diag(w) Q^T with an orthogonal Q taken from
        # the QR factorisation of a Gaussian matrix and a power-law spectrum
        # w rescaled so that ||w||_2 = sqrt(D).
        Q = np.linalg.qr(np.random.randn(D, D))[0]
        w = np.arange(1, D + 1, dtype=float)**(-gamma)
        w *= np.sqrt(D) / np.linalg.norm(w)
        trS = w.sum()

        # y_i = (x_i^T S_true x_i - tr(S_true)) / sqrt(D) + sqrt(noise) * eps_i
        z = x @ Q
        quad = (z * z) @ w
        y = (quad - trS) / np.sqrt(D) + np.sqrt(noise) * np.random.randn(N)

        S_true = (Q * w) @ Q.T

        overlap, S_hat = min_norm_loss_lbfgs(x, y, lreg=lreg, S_true=S_true, rank=D, max_iter=10000, tol=1e-10, verbose=False)

        overlap_list.append(overlap)

        # Rewrite the CSV after every sample so partial results survive an
        # interrupted run.
        pd.DataFrame(overlap_list, columns=['overlap']).to_csv(csv_path, index=False)

