import numpy as np
from scipy.integrate import quad
import pandas as pd
from numba import njit
import torch
import os
from sys import argv



def solve_poly(z, sigma, kappa):
    """Return the three roots of the cubic parameterised by (z, sigma, kappa).

    With alpha = 1/kappa and R = sigma**2 the polynomial is

        sqrt(alpha)*R * g**3 - (sqrt(alpha)*z + R) * g**2
            + (z + sqrt(alpha) - 1/sqrt(alpha)) * g - 1

    `z` may be complex; the roots are returned in the order produced by
    `np.roots`.
    """
    alpha = 1 / kappa
    noise_var = sigma**2
    root_alpha = np.sqrt(alpha)

    # Coefficients listed from the highest degree down, as np.roots expects.
    cubic = [
        root_alpha * noise_var,
        -(root_alpha * z + noise_var),
        (z + root_alpha - alpha**(-1 / 2)),
        -1,
    ]
    return np.roots(cubic)




def edges_rho(sigma, kappa):
    """Return the sorted real roots of the quartic built from (sigma, kappa).

    The quartic's coefficients depend on alpha = 1/kappa and R = sigma**2.
    A root counts as real when the magnitude of its imaginary part is below
    1e-10. Callers in this file use the result as the end points of the
    intervals on which `rho` is integrated.
    """
    alpha = 1/kappa
    R = sigma**2
    sa = np.sqrt(alpha)

    c0 = (-12 * R + (4 * R) / alpha + 12 * alpha * R - 4 * alpha**2 * R
          - 20 * R**2 + R**2 / alpha - 8 * alpha * R**2 - 4 * R**3)
    c1 = (-(10 * R) / sa + 2 * sa * R + 8 * alpha**(3/2) * R
          - (2 * R**2) / sa + 8 * sa * R**2)
    c2 = 1 - 2 * alpha + alpha**2 + 8 * R - 2 * alpha * R + R**2
    c3 = -2 * sa - 2 * alpha**(3/2) - 2 * sa * R
    c4 = alpha

    candidates = np.roots([c4, c3, c2, c1, c0])
    is_real = np.abs(np.imag(candidates)) < 1e-10
    return np.sort(np.real(candidates[is_real]))

def rho(x, sigma, kappa):
    """Evaluate the density at `x` from the roots of `solve_poly`.

    The cubic is solved at x - 1e-12j (just below the real axis) and the
    largest imaginary part among its roots is divided by pi.
    """
    just_below_axis = x - 1e-12j
    branches = solve_poly(just_below_axis, sigma, kappa)
    return np.max(np.imag(branches)) / np.pi

def ERM_I_y0(delta, epsilon, kappa, kappa_stud, edges):
    """Return the lower integration threshold y0.

    Only kappa_stud == 1 is implemented, in which case y0 is `epsilon`.

    Raises:
        ValueError: for any other value of `kappa_stud`.
    """
    if kappa_stud != 1:
        raise ValueError("kappa_stud must be 1 for the current implementation.")

    return epsilon

def ERM_I_precise(delta, epsilon, param_prior):
    """Integrate rho(x) * (x +/- epsilon)**2 over the support bands of rho.

    Args:
        delta: passed as `sigma` to `edges_rho` and `rho`.
        epsilon: threshold; the "plus" part integrates up to -epsilon, the
            "minus" part integrates from y0 upward.
        param_prior: sequence (Q0, lreg, kappa, kappa_stud, posdef); only the
            last three entries are used here.

    Returns:
        Sum of the two contributions. The "plus" part is included only when
        `posdef` is falsy. When `edges_rho` returns neither 2 nor 4 edges,
        nothing is integrated and the result is 0.
    """
    _, _, kappa, kappa_stud, posdef = param_prior

    edges = edges_rho(delta, kappa)

    y0 = epsilon
    if posdef and 0 < kappa_stud < 1:
        y0 = ERM_I_y0(delta, epsilon, kappa, kappa_stud, edges)

    def plus(x):
        return rho(x, delta, kappa) * (x + epsilon)**2

    def minus(x):
        return rho(x, delta, kappa) * (x - epsilon)**2

    # Consecutive sorted edges pair up into support bands [left, right].
    n_edges = len(edges)
    if n_edges == 2 or n_edges == 4:
        bands = [(edges[k], edges[k + 1]) for k in range(0, n_edges, 2)]
    else:
        bands = []

    # Part 1: the portion of each band lying at or below -epsilon.
    below = 0
    if not posdef:
        cut = -epsilon
        for left, right in bands:
            if cut < left:
                continue
            stop = cut if cut < right else right
            below += quad(plus, left, stop)[0]

    # Part 2: the portion of each band lying at or above y0 (upper band first).
    above = 0
    for left, right in reversed(bands):
        if y0 > right:
            continue
        start = y0 if y0 > left else left
        above += quad(minus, start, right)[0]

    return below + above


@njit
def _flat_len_to_T(L):
    """Invert L = T*(T+1)/2: return T for a packed upper-triangle of length L."""
    disc_root = np.sqrt(8.0 * L + 1.0)
    return int((disc_root - 1.0) / 2.0)

@njit
def sym(A_flat):
    """Unpack a row-major upper triangle (diagonal included) into a symmetric matrix."""
    size = _flat_len_to_T(len(A_flat))
    full = np.zeros((size, size), dtype=A_flat.dtype)
    pos = 0
    for row in range(size):
        for col in range(row, size):
            entry = A_flat[pos]
            full[row, col] = entry
            full[col, row] = entry
            pos += 1
    return full

@njit
def desym(A):
    """Pack the upper triangle of `A` (diagonal included) row by row into a 1-D array."""
    size = A.shape[0]
    packed = np.empty(size * (size + 1) // 2, dtype=A.dtype)
    pos = 0
    for row in range(size):
        # Row `row` contributes its entries from the diagonal to column size-1.
        width = size - row
        packed[pos:pos + width] = A[row, row:size]
        pos += width
    return packed

@njit
def _rowwise_max(X):
    """Return the maximum of each row of the 2-D array `X`."""
    rows, cols = X.shape
    maxima = np.empty(rows, dtype=X.dtype)
    for r in range(rows):
        best = X[r, 0]
        for c in range(1, cols):
            candidate = X[r, c]
            if candidate > best:
                best = candidate
        maxima[r] = best
    return maxima

@njit
def softmax(x_flat, beta=1.0):
    """Row-wise softmax of the symmetric matrix packed in `x_flat`.

    The row maximum is subtracted before exponentiating. The result is the
    full T*T matrix flattened, not the packed upper triangle.
    """
    P = sym(x_flat)
    peaks = _rowwise_max(P)
    n_rows, n_cols = P.shape

    for r in range(n_rows):
        peak = peaks[r]
        total = 0.0
        for c in range(n_cols):
            # Shift by the row max first so the exponent is never positive
            # (for beta >= 0).
            P[r, c] = P[r, c] - peak
            P[r, c] = np.exp(beta * P[r, c])
            total += P[r, c]
        norm = 1.0 / total
        for c in range(n_cols):
            P[r, c] *= norm

    return P.ravel()
@njit
def scale_upper_flat(x_flat):
    """
    Scale the diagonal of the symmetric matrix packed in x_flat by sqrt(2)
    (elementwise multiplication by sqrt(1 + I)), leave off-diagonal entries
    unchanged, and return the result packed back as an upper triangle.
    """
    M = sym(x_flat)
    root_two = np.sqrt(2.0)
    for k in range(_flat_len_to_T(len(x_flat))):
        M[k, k] *= root_two

    return desym(M)


def generate_teacher(cutoff,gamma,d):
    """Build a d x d diagonal matrix with a truncated power-law diagonal.

    The first `cutoff` diagonal entries are k**(-gamma) for k = 1..cutoff and
    the remaining d - cutoff entries are zero. The diagonal is then rescaled
    so that its Euclidean norm equals sqrt(d).

    Args:
        cutoff: number of non-zero diagonal entries, 1 <= cutoff <= d.
        gamma: power-law exponent (int or float).
        d: matrix dimension.

    Raises:
        ValueError: if `cutoff` is outside [1, d]. A zero cutoff would
            otherwise produce an all-NaN matrix through a 0/0 rescaling.
    """
    if not 1 <= cutoff <= d:
        raise ValueError(f"cutoff must satisfy 1 <= cutoff <= d, got cutoff={cutoff}, d={d}")

    # Use a float base: NumPy refuses integer arrays raised to negative
    # integer powers, which made an integer `gamma` (e.g. gamma=1) crash.
    ranks = np.arange(1, cutoff + 1, dtype=float)
    D = np.concatenate((ranks**(-gamma), np.zeros(d - cutoff)))
    D = D / np.linalg.norm(D) * np.sqrt(d)
    return np.diag(D)


def triu_indices(T, device):
    """Return the (2, T*(T+1)/2) index tensor of the upper triangle of a T x T matrix, diagonal included."""
    return torch.triu_indices(row=T, col=T, offset=0, device=device)

def upper_to_sym(x_upper, T):
    """
    Expand packed upper triangles into full symmetric matrices.

    x_upper: (B, L), L = T(T+1)/2
    returns: (B, T, T) symmetric
    """
    batch, _ = x_upper.shape
    rows, cols = triu_indices(T, x_upper.device)
    full = torch.zeros((batch, T, T), device=x_upper.device, dtype=x_upper.dtype)
    # Write the upper triangle, then mirror it; the diagonal is written twice
    # with the same values.
    full[:, rows, cols] = x_upper
    full[:, cols, rows] = x_upper
    return full

def scale_diag_sqrt2(X):
    """
    Return a copy of X whose diagonal is multiplied by sqrt(2); off-diagonal
    entries are unchanged and the input is not modified.
    X: (B, T, T)
    """
    _, size, _ = X.shape
    idx = torch.arange(size, device=X.device)
    factor = torch.sqrt(torch.tensor(2.0, device=X.device, dtype=X.dtype))
    scaled = X.clone()
    scaled[:, idx, idx] = scaled[:, idx, idx] * factor
    return scaled

def row_softmax(X, beta=1.0):
    """
    Softmax along the last axis of beta * X.
    X: (B, T, T)
    """
    logits = beta * X
    return logits.softmax(dim=-1)

# =========================
#  Sampling (seed-free)
# =========================

def sample_z_zstar_upper(B, L, Q0, m, q, device="cpu", dtype=torch.float32):
    """
    Draw B x L correlated Gaussian pairs (zstar, z) from two independent
    standard-normal draws u and v:

      zstar = sqrt(Q0) * u
      z     = m / sqrt(Q0) * u + sqrt(max(q - m^2 / Q0, 0)) * v

    Q0 is clamped from below at 1e-8 before use.

    Returns:
      zstar_upper, z_upper : (B, L)
    """
    Q0_t = torch.as_tensor(Q0, device=device, dtype=dtype)
    m_t = torch.as_tensor(m, device=device, dtype=dtype)
    q_t = torch.as_tensor(q, device=device, dtype=dtype)

    # Draw order (u first, then v) fixes the random stream.
    shared = torch.randn((B, L), device=device, dtype=dtype)
    private = torch.randn((B, L), device=device, dtype=dtype)

    floor = torch.as_tensor(1e-8, device=device, dtype=dtype)
    Q0_safe = torch.clamp(Q0_t, min=floor)

    teacher = torch.sqrt(Q0_safe) * shared
    residual_std = torch.sqrt(torch.clamp(q_t - (m_t * m_t) / Q0_safe, min=0.0))
    student = (m_t / torch.sqrt(Q0_safe)) * shared + residual_std * private
    return teacher, student

# =========================
#  Batched LBFGS proximal
# =========================

def proximal_batch_lbfgs(
    sigma, zstar_upper, z_upper,
    T, betastar=1.0, beta=1.0,
    noise_in=0.0, noise_out=0.0,
    max_iter=40, history_size=10,
    lr=1.0,
    line_search="strong_wolfe",
    device=None, dtype=torch.float32
):
    """
    Batched proximal operator using LBFGS.

    For each batch row, minimises over h (packed upper triangle)

        ||h - z||^2 / (2 * sigma)
          + || label(zstar) - softmax(beta * scale(sym(h))) ||_F^2

    where sym() expands to a symmetric T x T matrix, scale() multiplies the
    diagonal by sqrt(2), and the softmax is taken row-wise. The teacher label
    is softmax(betastar * scale(sym(zstar + input noise))) + output noise.

    All B problems share a single LBFGS optimiser whose objective is the
    batch mean of the per-row losses, and `opt.step` is called exactly once
    (up to `max_iter` inner iterations).

    Random draws happen in this order: input noise (if noise_in > 0), output
    noise (if noise_out > 0), then the initial perturbation of h.

    Inputs:
      sigma: scalar, converted with float(); divides the quadratic term.
      zstar_upper, z_upper: (B, L)
      T: side length of the symmetric matrices.
      betastar, beta: softmax scales for the teacher label / the student.
      noise_in: Gaussian noise of variance noise_in / 2 is added to the
        packed teacher entries.
      noise_out: Gaussian noise of variance noise_out is added to the
        (B, T, T) teacher label.
      max_iter, history_size, lr, line_search: forwarded to torch.optim.LBFGS.
      device: defaults to z_upper.device when None.
      dtype: dtype the inputs are cast to.
    Returns:
      h_upper: (B, L), detached from the autograd graph.
    """
    if device is None:
        device = z_upper.device

    zstar_upper = zstar_upper.to(device=device, dtype=dtype)
    z_upper     = z_upper.to(device=device, dtype=dtype)

    B, L = z_upper.shape

    # --- fixed noises (drawn once, unseeded) ---
    # Input noise lives on the packed upper-triangular entries.
    add_in = torch.zeros_like(zstar_upper)
    if noise_in > 0:
        add_in = torch.sqrt(torch.as_tensor(noise_in / 2.0, device=device, dtype=dtype)) \
                 * torch.randn((B, L), device=device, dtype=dtype)

    # Output noise is added to every entry of the (B, T, T) label.
    add_out = torch.zeros((B, T, T), device=device, dtype=dtype)
    if noise_out > 0:
        add_out = torch.sqrt(torch.as_tensor(noise_out, device=device, dtype=dtype)) \
                  * torch.randn((B, T, T), device=device, dtype=dtype)

    # --- target labels ---
    # Computed once, outside the closure: the label does not depend on h.
    Zs = upper_to_sym(zstar_upper + add_in, T)
    Zs = scale_diag_sqrt2(Zs)
    label1 = row_softmax(Zs, beta=betastar) + add_out

    # --- init ---
    # Start from the teacher pre-activations plus a small random perturbation.
    h_upper = (zstar_upper + 0.01 * torch.randn((B, L), device=device, dtype=dtype)).clone()
    h_upper.requires_grad_(True)

    sigma_t = torch.as_tensor(float(sigma), device=device, dtype=dtype)

    opt = torch.optim.LBFGS(
        [h_upper],
        lr=lr,
        max_iter=max_iter,
        history_size=history_size,
        line_search_fn=line_search
    )

    def closure():
        # LBFGS may call this several times per step; each call recomputes
        # the loss and its gradient from scratch.
        opt.zero_grad(set_to_none=True)

        H = upper_to_sym(h_upper, T)
        Hf = scale_diag_sqrt2(H)
        label2 = row_softmax(Hf, beta=beta)

        # Quadratic proximity term, one value per batch row.
        diff_hz = h_upper - z_upper
        term1 = (diff_hz * diff_hz).sum(dim=1) / (2.0 * sigma_t)

        # Squared Frobenius distance between teacher and student labels.
        diff_lbl = label1 - label2
        term2 = (diff_lbl * diff_lbl).sum(dim=(1, 2))

        # The mean over the batch couples all rows inside one optimiser.
        loss = (term1 + term2).mean()
        loss.backward()
        return loss

    opt.step(closure)
    return h_upper.detach()



def ERM_m_hat_eq(q, m, sigma, param_output, proximal_val, z_all, zstar_all):
    """Monte-Carlo estimate of the m_hat update (before the 2*alpha prefactor).

    Computes the sample mean over the leading axis of

        sum((prox - z) * (-m * z + q * zstar)) / sigma / (Q_0 * q - m**2)

    Args:
        q, m, sigma: current overlap values (scalars).
        param_output: sequence whose first entry is Q_0.
        proximal_val, z_all, zstar_all: arrays of identical shape; axis 0
            indexes samples.

    Returns:
        The averaged estimate as a float64 scalar.
    """
    Q_0 = param_output[0]
    samples = proximal_val.shape[0]

    # One vectorised reduction instead of a Python loop over every sample;
    # accumulate in float64 so float32 inputs do not lose precision.
    weighted = (proximal_val - z_all) * (-m * z_all + q * zstar_all)
    total = np.sum(weighted, dtype=np.float64)

    return total / sigma / (Q_0 * q - m**2) / samples


def ERM_q_hat_eq(q, m, sigma, param_output, proximal_val, z_all, zstar_all):
    """Monte-Carlo estimate of the q_hat update (before the 2*alpha prefactor).

    Computes the sample mean over the leading axis of

        sum((prox - z)**2) / sigma**2

    `q`, `m`, `param_output` and `zstar_all` are accepted for signature
    parity with the other *_hat equations but are not used.

    Returns:
        The averaged estimate as a float64 scalar.
    """
    samples = proximal_val.shape[0]

    # One vectorised reduction instead of a Python loop over every sample;
    # accumulate in float64 so float32 inputs do not lose precision.
    residual = proximal_val - z_all
    total = np.sum(residual * residual, dtype=np.float64)

    return total / sigma**2 / samples

def ERM_sigma_hat_eq(q, m, sigma, param_output, proximal_val, z_all, zstar_all):
    """Monte-Carlo estimate of the sigma_hat update (before the -2*alpha prefactor).

    Computes the sample mean over the leading axis of

        sum((prox - z) * (Q_0 * z - m * zstar)) / sigma / (Q_0 * q - m**2)

    Args:
        q, m, sigma: current overlap values (scalars).
        param_output: sequence whose first entry is Q_0.
        proximal_val, z_all, zstar_all: arrays of identical shape; axis 0
            indexes samples.

    Returns:
        The averaged estimate as a float64 scalar.
    """
    Q_0 = param_output[0]
    samples = proximal_val.shape[0]

    # One vectorised reduction instead of a Python loop over every sample;
    # accumulate in float64 so float32 inputs do not lose precision.
    weighted = (proximal_val - z_all) * (Q_0 * z_all - m * zstar_all)
    total = np.sum(weighted, dtype=np.float64)

    return total / sigma / (Q_0 * q - m**2) / samples


def relu(x):
    """Elementwise max(0, x)."""
    floor = 0
    return np.maximum(floor, x)


def denoiser(nu,lreg,sigma_hat):
    """Shift `nu` down by 2*lreg, clip at zero, and divide by sigma_hat."""
    shifted = nu - 2 * lreg
    return relu(shifted) / sigma_hat


def sample_GOE(d):
    """Draw a symmetric d x d Gaussian matrix, (G + G.T) / sqrt(2*d) with G standard normal.

    Uses NumPy's global random state.
    """
    gauss = np.random.randn(d, d)
    return (gauss + gauss.T) / np.sqrt(2 * d)


def ERM_sigma_eq_emp(nu,nu_denoised, sigma_hat, kappa, lreg,d):
    """Empirical sigma from eigenvalues `nu` and their denoised values.

    Adds 2 * #{nu > 2*lreg} / sigma_hat to the sum, over ordered pairs with
    nu[i] != nu[j], of the divided differences
    (nu_denoised[i] - nu_denoised[j]) / (nu[i] - nu[j]), then divides by d**2.
    `kappa` is not used.
    """
    total = 2 * np.sum(nu > 2 * lreg) / sigma_hat
    for idx in range(d):
        # Skip equal eigenvalues (including the entry itself) to avoid 0/0.
        distinct = nu != nu[idx]
        slopes = (nu_denoised[idx] - nu_denoised[distinct]) / (nu[idx] - nu[distinct])
        total += np.sum(slopes)
    return total / d**2


@njit
def ERM_sigma_eq_emp_fast(nu, nu_denoised, sigma_hat, kappa, lreg, d):
    """Compiled version of the empirical sigma computation.

    Adds 2 * #{nu > 2*lreg} / sigma_hat to the sum of divided differences
    (nu_denoised[i] - nu_denoised[j]) / (nu[i] - nu[j]) over ordered pairs
    i != j with a non-zero denominator, then divides by d*d. Uses explicit
    loops so that no temporary mask arrays are allocated. `kappa` is not used.
    """
    n_active = (nu > 2.0 * lreg).sum()
    sig = 2.0 * n_active / sigma_hat

    slope_sum = 0.0
    for a in range(d):
        nu_a = nu[a]
        f_a = nu_denoised[a]
        for b in range(d):
            if b != a:
                gap = nu_a - nu[b]
                if gap != 0.0:
                    slope_sum += (f_a - nu_denoised[b]) / gap

    sig += slope_sum
    return sig / (d * d)



def ERM_state_evolution_equations(S_star, overlaps, alpha, param_prior, param_output, samples=1000, samples_prior=100):
    """Apply one iteration of the state-evolution map to `overlaps`.

    The (q, m, sigma) updates are averaged over `samples_prior` random
    matrices sqrt(q_hat) * GOE + m_hat * diag(S_star), each diagonalised and
    passed through `denoiser`. The (q_hat, m_hat, sigma_hat) updates are
    Monte-Carlo averages over `samples` proximal solutions computed with
    `proximal_batch_lbfgs` at the *incoming* (q, m, sigma).

    Args:
        S_star: (d, d) teacher matrix; only its diagonal is used.
        overlaps: sequence (q, m, sigma, q_hat, m_hat, sigma_hat).
        alpha: sample ratio multiplying the hat updates.
        param_prior: (Q_0, lreg, kappa, kappa_stud, posdef).
        param_output: (Q_0, T, beta, betastar, noise_in, noise_out).
        samples: batch size for the proximal Monte-Carlo.
        samples_prior: number of random matrices for the prior Monte-Carlo.

    Returns:
        np.ndarray [q, m, sigma, q_hat, m_hat, sigma_hat] with the new values.
    """
    d = S_star.shape[0]
    _, T, beta, betastar, noise_in, noise_out = param_output
    # Q_0 used for sampling is the one carried by param_prior (this is what
    # the original double unpacking ended up using).
    Q_0, lreg, kappa, _, _ = param_prior

    q, m, sigma, q_hat, m_hat, sigma_hat = overlaps

    def sample_and_stats(Sstar_diag):
        # build M = sqrt(q_hat)*GOE + m_hat*diag(S*)
        M = np.sqrt(q_hat) * sample_GOE(d)
        M[np.diag_indices(d)] += m_hat * Sstar_diag  # add diagonal without forming diag matrix

        D, O = np.linalg.eigh(M)   # O(d^3) dominant cost
        fD = denoiser(D, lreg, sigma_hat)

        # q = (1/d) sum fD^2
        q_val = np.mean(fD * fD)

        # m = (1/d) sum_i fD_i * diag( O^T diag(S*) O )_i
        # diagproj_i = sum_k s_k * O_{k,i}^2  -> (O**2).T @ s
        diagproj = (O * O).T @ Sstar_diag          # O(d^2)
        m_val = np.mean(fD * diagproj)

        # sigma uses only eigenvalues + denoised eigenvalues (no O, no matrices)
        sigma_val = ERM_sigma_eq_emp_fast(D, fD, sigma_hat, kappa, lreg, d)

        return q_val, m_val, sigma_val

    # Prefer the GPU but fall back to the CPU instead of crashing on hosts
    # without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    L = int(T * (T + 1) // 2)

    zstar_upper, z_upper = sample_z_zstar_upper(
        samples, L, Q_0, m, q, device=device, dtype=dtype
    )

    prox_upper = proximal_batch_lbfgs(
        sigma=sigma,
        zstar_upper=zstar_upper,
        z_upper=z_upper,
        T=T,
        betastar=betastar,
        beta=beta,
        noise_in=noise_in,
        noise_out=noise_out,
        max_iter=40,
        device=device,
        dtype=dtype
    )

    proximal_val_all = prox_upper.cpu().numpy()
    z_all = z_upper.cpu().numpy()
    zstar_all = zstar_upper.cpu().numpy()

    Sstar_diag = np.diag(S_star).copy()  # do once

    qs = np.empty(samples_prior)
    ms = np.empty(samples_prior)
    sigmas = np.empty(samples_prior)

    for t in range(samples_prior):
        qs[t], ms[t], sigmas[t] = sample_and_stats(Sstar_diag)

    q_new = qs.mean()
    m_new = ms.mean()
    sigma_new = sigmas.mean()

    hat_args = (q, m, sigma, param_output, proximal_val_all, z_all, zstar_all)
    q_hat_new = 2 * alpha * ERM_q_hat_eq(*hat_args)
    m_hat_new = 2 * alpha * ERM_m_hat_eq(*hat_args)
    sigma_hat_new = - 2 * alpha * ERM_sigma_hat_eq(*hat_args)

    return np.array([q_new, m_new, sigma_new, q_hat_new, m_hat_new, sigma_hat_new])



def ERM_solution(alpha, Q_0, lreg, kappa, kappa_stud, noise_in, noise_out, posdef, T, beta, betastar, d, gamma, q = 0.1887961166803156, m = 0.2695953266947839, sigma = 21.213354211136686, damping=1.0, max_iter=50000, toll=1e-5, samples=1000, samples_prior=100, folder_name="data"):
    """Iterate the state-evolution equations to a fixed point, logging to CSV.

    Starting from (q, m, sigma), the hat overlaps are initialised from one
    batch of proximal solutions, then `ERM_state_evolution_equations` is
    applied with damping until the Euclidean distance between successive
    overlap vectors drops below `toll` or `max_iter` iterations have run.
    After every iteration the full history (including MSE = Q_0 + q - 2*m)
    is rewritten to
    `{folder_name}/alpha_{alpha}_lreg_{lreg}_noise_{noise_in}_T_{T}_d_{d}_gamma_{gamma}.csv`.
    `folder_name` is created if it does not exist.

    Returns:
        (overlaps, mse). On convergence `overlaps` is the last undamped
        iterate and `mse` its Q_0 + q - 2*m. If `max_iter` is exhausted,
        `overlaps` is the last damped state and `mse` is NaN. (Previously the
        non-converged path returned a 3-tuple, which broke any caller
        unpacking two values.)
    """
    cutoff = d
    S_star = generate_teacher(cutoff, gamma, d)

    param_prior = [Q_0, lreg, kappa, kappa_stud, posdef]
    param_output = [Q_0, T, beta, betastar, noise_in, noise_out]

    # Prefer the GPU but fall back to the CPU instead of crashing on hosts
    # without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32

    L = int(T * (T + 1) // 2)

    zstar_upper, z_upper = sample_z_zstar_upper(
        samples, L, Q_0, m, q, device=device, dtype=dtype
    )

    prox_upper = proximal_batch_lbfgs(
        sigma=sigma,
        zstar_upper=zstar_upper,
        z_upper=z_upper,
        T=T,
        betastar=betastar,
        beta=beta,
        noise_in=noise_in,
        noise_out=noise_out,
        max_iter=40,
        device=device,
        dtype=dtype
    )

    proximal_val_all = prox_upper.cpu().numpy()
    z_all = z_upper.cpu().numpy()
    zstar_all = zstar_upper.cpu().numpy()

    hat_args = (q, m, sigma, param_output, proximal_val_all, z_all, zstar_all)
    q_hat = 2 * alpha * ERM_q_hat_eq(*hat_args)
    m_hat = 2 * alpha * ERM_m_hat_eq(*hat_args)
    sigma_hat = - 2 * alpha * ERM_sigma_hat_eq(*hat_args)

    overlaps = np.array([q, m, sigma, q_hat, m_hat, sigma_hat])

    def _history_row(values):
        # One-row frame for a 6-vector (q, m, sigma, q_hat, m_hat, sigma_hat).
        return pd.DataFrame({
            "q": [values[0]],
            "m": [values[1]],
            "sigma": [values[2]],
            "q_hat": [values[3]],
            "m_hat": [values[4]],
            "sigma_hat": [values[5]],
            "MSE": [Q_0 + values[0] - 2 * values[1]],
        })

    # to_csv does not create missing directories.
    os.makedirs(folder_name, exist_ok=True)
    csv_path = f"{folder_name}/alpha_{alpha}_lreg_{lreg}_noise_{noise_in}_T_{T}_d_{d}_gamma_{gamma}.csv"

    df = _history_row((q, m, sigma, q_hat, m_hat, sigma_hat))
    df.to_csv(csv_path, index=False)

    for _ in range(max_iter):
        new_overlaps = ERM_state_evolution_equations(S_star, overlaps, alpha, param_prior, param_output, samples=samples, samples_prior=samples_prior)

        err_toll = np.linalg.norm(new_overlaps - overlaps)

        # append the NEW overlaps (not the old ones)
        df = pd.concat([df, _history_row(new_overlaps)], ignore_index=True)
        df.to_csv(csv_path, index=False)

        if err_toll < toll:
            return new_overlaps, Q_0 + new_overlaps[0] - 2 * new_overlaps[1]

        overlaps = (1 - damping) * overlaps + damping * new_overlaps

    # Not converged: same arity as the converged return, MSE flagged as NaN.
    return overlaps, float('NaN')



if __name__ == "__main__":
    # Usage: python <script> ALPHA_INDEX
    # ALPHA_INDEX selects one entry of the 64-point log grid of effective
    # sample sizes below, so the sweep can be run as independent jobs.
    d = 400
    lreg = 1/d
    gamma = 0.75

    folder_name = "data_powerlaw_spectrum"
    # exist_ok avoids the check-then-create race when several jobs start at once
    os.makedirs(folder_name, exist_ok=True)

    n_eff_list = np.logspace(-1, 5, 64)
    alpha_list = n_eff_list / d

    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} ALPHA_INDEX  (integer index into a grid of {len(alpha_list)} values)")
    try:
        alpha_idx = int(argv[1])
    except ValueError:
        raise SystemExit(f"ALPHA_INDEX must be an integer, got {argv[1]!r}")
    # Negative indices keep their usual Python meaning (counted from the end).
    if not -len(alpha_list) <= alpha_idx < len(alpha_list):
        raise SystemExit(f"ALPHA_INDEX {alpha_idx} is out of range for a grid of {len(alpha_list)} values")
    alpha = alpha_list[alpha_idx]

    kappa=2.0
    noise_in=0.5
    noise_out=0.0
    Q_0 = 1
    beta = 1
    betastar = 1
    T = 2
    kappa_stud = 1
    posdef = True

    ERM_solution(alpha, Q_0, lreg, kappa, kappa_stud, noise_in, noise_out, posdef, T, beta, betastar, d, gamma,
                q = 0.6, m = 0.2, sigma = 1., damping=0.1, max_iter=500, toll=1e-5,
                samples=int(1e5), samples_prior=int(1e3), folder_name=folder_name)