import numpy as np
from scipy.integrate import quad,trapezoid
from scipy.optimize import root
from scipy.optimize import bisect
import pandas as pd
import sys
import os



def relu(x):
    """Return the element-wise positive part of ``x`` (negative entries become 0)."""
    floor = 0
    return np.maximum(floor, x)

def denoiser(nu,lreg,sigma_hat):
    """Shift ``nu`` down by ``2*lreg``, clip at zero, and divide by ``sigma_hat``.

    Works element-wise, so ``nu`` may be a scalar or an array of eigenvalues.
    """
    shifted = nu - 2 * lreg
    clipped = relu(shifted)
    return clipped / sigma_hat

def generate_teacher(cutoff,gamma,d):
    """Build a diagonal ``d x d`` teacher matrix with a truncated power-law spectrum.

    The first ``cutoff`` diagonal entries are ``k**(-gamma)`` for
    ``k = 1..cutoff``; the remaining ``d - cutoff`` entries are zero.  The
    diagonal is then rescaled so that its Euclidean norm equals ``sqrt(d)``.

    Args:
        cutoff: number of non-zero diagonal entries (must not exceed ``d``).
        gamma: power-law exponent; may be an int or a float.
        d: matrix dimension.

    Returns:
        A ``(d, d)`` diagonal ``numpy.ndarray``.

    Raises:
        ValueError: if ``cutoff`` is larger than ``d``.
    """
    if cutoff > d:
        raise ValueError(f"cutoff ({cutoff}) cannot exceed d ({d})")
    # Use a float range: NumPy refuses to raise an integer array to a negative
    # integer power, which made integer values of gamma crash.
    ranks = np.arange(1, cutoff + 1, dtype=float)
    D = np.concatenate((ranks ** (-gamma), np.zeros(d - cutoff)))
    D = D / np.linalg.norm(D) * np.sqrt(d)
    return np.diag(D)

def sample_GOE(d):
    """Draw a symmetric ``d x d`` Gaussian matrix.

    A matrix ``G`` of i.i.d. standard normals is symmetrised as
    ``(G + G.T) / sqrt(2 d)``, so off-diagonal entries have variance ``1/d``
    and diagonal entries have variance ``2/d``.
    """
    gauss = np.random.randn(d, d)
    symmetric = gauss + gauss.T
    return symmetric / np.sqrt(2 * d)

def ERM_sigma_eq_emp(nu,nu_denoised, sigma_hat, kappa, lreg,d):
    """Empirical estimate of ``sigma`` from one sampled spectrum.

    Adds up, over all ordered pairs of *distinct* eigenvalues, the divided
    difference ``(f(nu_i) - f(nu_j)) / (nu_i - nu_j)`` of the denoised
    spectrum, plus ``2 / sigma_hat`` for every eigenvalue above the
    threshold ``2 * lreg``; the total is divided by ``d**2``.

    Args:
        nu: array of eigenvalues.
        nu_denoised: array of denoised eigenvalues, aligned with ``nu``.
        sigma_hat: conjugate overlap used in the threshold term.
        kappa: unused; kept so the signature stays unchanged.
        lreg: regularisation strength setting the threshold ``2 * lreg``.
        d: number of eigenvalues to loop over and normalisation dimension.
    """
    above_threshold = np.sum(nu > 2 * lreg)
    total = 2 * above_threshold / sigma_hat
    for idx in range(d):
        # Pairs with equal eigenvalues are skipped to avoid 0/0.
        distinct = nu != nu[idx]
        numerators = nu_denoised[idx] - nu_denoised[distinct]
        denominators = nu[idx] - nu[distinct]
        total += np.sum(numerators / denominators)
    return total / d**2


def ERM_SE_hybrid(Sstar,overlaps, alpha, Q_0, kappa,gamma, lloss, lreg, posdef, kappa_stud, noise,cutoff,d,nsamples=1): 
    """Apply one step of the hybrid (analytic + sampled) state-evolution map.

    The conjugate overlaps ``(q_hat, m_hat, sigma_hat)`` are obtained in closed
    form from ``(q, m, sigma)`` via ``ERM_q_hat_eq`` / ``ERM_m_hat_eq`` /
    ``ERM_sigma_hat_eq``.  The new ``(q, m, sigma)`` are Monte-Carlo averages
    over ``nsamples`` draws of ``sqrt(q_hat) * GOE + m_hat * Sstar`` whose
    eigenvalues are passed through ``denoiser``.

    Args:
        Sstar: ``(d, d)`` teacher matrix.
        overlaps: sequence ``(q, m, sigma, q_hat, m_hat, sigma_hat)``; only the
            first three are used, the hat values are recomputed.
        alpha, Q_0, lloss, noise: forwarded to the hat-overlap equations.
        kappa: forwarded to ``ERM_sigma_eq_emp``.
        lreg: regularisation strength used by the denoiser.
        gamma, posdef, kappa_stud, cutoff: not used by this function; accepted
            to keep the call signature unchanged.
        d: matrix dimension.
        nsamples: number of Monte-Carlo repetitions (at least 1).

    Returns:
        A pair of length-6 arrays: the new overlaps
        ``[q, m, sigma, q_hat, m_hat, sigma_hat]`` and the dispersion
        statistics ``[std_q, std_m, std_sigma, spread_q, spread_m,
        spread_sigma]`` where spread is ``max - min`` over the samples.

    Raises:
        ValueError: if ``nsamples`` is smaller than 1.
        AssertionError: if ``q_hat`` is negative or the eigendecomposition
            does not reconstruct the sampled matrix to within 1e-7.
    """
    def sample_and_denoise(Sstar,lreg,q_hat,m_hat,sigma_hat,d,dense=True):
        # Sample the noisy matrix, diagonalise it and denoise its spectrum.
        # The dense denoised matrix is only rebuilt when `dense` is true.
        M=np.sqrt(q_hat)*sample_GOE(d)+Sstar*m_hat
        (D,O)=np.linalg.eigh(M)
        # (O * D) @ O.T equals O @ diag(D) @ O.T.  The residual is compared in
        # absolute value so that large negative errors are caught as well.
        assert np.all(np.abs(M-(O * D) @ O.T) < 1e-7)  , "Error in eigendecomposition of Sstar: "
        del M
        Denoised_D=denoiser(D,lreg,sigma_hat)
        Denoised_M=(O * Denoised_D) @ O.T if dense else None
        del O
        return Denoised_M,(D,Denoised_D)

    if nsamples < 1:
        raise ValueError(f"nsamples must be at least 1, got {nsamples}")

    q, m, sigma, q_hat, m_hat, sigma_hat = overlaps

    q_hat_new = ERM_q_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise)
    m_hat_new = ERM_m_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise)
    sigma_hat_new = ERM_sigma_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise)
    assert q_hat_new >= 0, "q_hat_new is negative"

    qs=[]
    ms=[]
    sigmas=[]
    # Three independent draws per repetition, one for each order parameter.
    for _rep in range(nsamples):
        # trace(X.T @ X) of the denoised matrix is the sum of its squared
        # eigenvalues, so no dense matrix is needed for q.
        _unused,(_eigs,Denoised_D)=sample_and_denoise(Sstar,lreg,q_hat_new,m_hat_new,sigma_hat_new,d,dense=False)
        qs.append(np.sum(Denoised_D**2)/d)
        # trace(Sstar.T @ X) is the element-wise product summed over entries.
        Denoised_M,_spectrum=sample_and_denoise(Sstar,lreg,q_hat_new,m_hat_new,sigma_hat_new,d)
        ms.append(np.sum(Sstar*Denoised_M)/d)
        del Denoised_M
        _unused,(D,Denoised_D)=sample_and_denoise(Sstar,lreg,q_hat_new,m_hat_new,sigma_hat_new,d,dense=False)
        sigmas.append(ERM_sigma_eq_emp(D,Denoised_D, sigma_hat_new, kappa, lreg,d))

    q_new=np.mean(qs)
    m_new=np.mean(ms)
    sigma_new=np.mean(sigmas)

    return np.array([q_new, m_new, sigma_new, q_hat_new, m_hat_new, sigma_hat_new]), np.array([np.std(qs), np.std(ms), np.std(sigmas),max(qs)-min(qs), max(ms)-min(ms), max(sigmas)-min(sigmas)])



def ERM_q_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise):
    """Closed-form update for ``q_hat``.

    Returns ``2*alpha*(Q_0 - 2*m + q + noise/2) / (sigma + lloss/4)**2``.
    """
    denominator = sigma + lloss / 4
    numerator = 2 * alpha * (Q_0 - 2 * m + q + noise / 2)
    return numerator / denominator**2

def ERM_m_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise):
    """Closed-form update for ``m_hat``: ``2*alpha / (sigma + lloss/4)``.

    ``q``, ``m``, ``Q_0`` and ``noise`` are not used; they are accepted so all
    hat equations share one signature.
    """
    denominator = sigma + lloss / 4
    return 2 * alpha / denominator

def ERM_sigma_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise):
    """Closed-form update for ``sigma_hat``: ``2*alpha / (sigma + lloss/4)``.

    ``q``, ``m``, ``Q_0`` and ``noise`` are not used; they are accepted so all
    hat equations share one signature.
    """
    denominator = sigma + lloss / 4
    return 2 * alpha / denominator

def ERM_q_eq(q_hat, m_hat, sigma_hat, kappa, lreg, integral, partials):
    """Analytic update for ``q``: ``m_hat**2 / sigma_hat**2 * integral``.

    ``q_hat``, ``kappa``, ``lreg`` and ``partials`` are not used here.
    """
    prefactor = m_hat ** 2 / sigma_hat**2
    return prefactor * integral

def ERM_m_eq(q_hat, m_hat, sigma_hat, kappa, lreg, integral, partials):
    """Analytic update for ``m``.

    Returns ``(m_hat*integral - sqrt(q_hat)/2*partials[0]
    - lreg*partials[1]) / sigma_hat``.  ``kappa`` is not used.
    """
    signal_term = m_hat * integral
    noise_term = np.sqrt(q_hat) / 2 * partials[0]
    reg_term = lreg * partials[1]
    return (signal_term - noise_term - reg_term) / sigma_hat

def ERM_sigma_eq(q_hat, m_hat, sigma_hat, kappa, lreg, integral, partials):
    """Analytic update for ``sigma``.

    Returns ``m_hat / 2 / sqrt(q_hat) / sigma_hat * partials[0]``.
    ``kappa``, ``lreg`` and ``integral`` are not used.
    """
    prefactor = m_hat / 2 / np.sqrt(q_hat) / sigma_hat
    return prefactor * partials[0]


def Wishart(d, kappa):
    """Sample a ``d x d`` Wishart-type matrix ``X @ X.T / sqrt(m * d)``.

    ``X`` is a ``d x m`` matrix of i.i.d. standard normals with
    ``m = int(d * kappa)``.
    """
    n_cols = int(d * kappa)
    gauss = np.random.randn(d, n_cols)
    gram = gauss @ gauss.T
    return gram / np.sqrt(n_cols * d)

def save_results(filename,results,cols): 
    """Write ``results`` (a sequence of rows) to ``filename`` as CSV.

    ``cols`` supplies the column names and the column order; the index is not
    written and lines end with ``\\n``.  A confirmation line is printed.
    """
    table = pd.DataFrame(results, columns=cols)
    table.to_csv(
        filename,
        columns=cols,
        header=True,
        lineterminator='\n',
        index=False,
    )
    print(f"Results saved to {filename}")

def ERM_solution(type,alpha, kappa, gamma,beta,d, noise=0., lloss = 0, lreg = 1, posdef = False, kappa_stud = 1, damping=1.0,min_iter=50, max_iter=50000, toll=1e-5, q = 0.3, m = 0.2, sigma = 0.1, damp_schedule = False, verbose=True,saveits=False,nsamples=10,avgover=1):
    """Iterate the hybrid state-evolution map until the overlaps stop moving.

    A teacher matrix is built according to ``type``, the overlaps are
    initialised from ``(q, m, sigma)`` and then repeatedly updated with
    ``ERM_SE_hybrid`` (with damping) until the Euclidean distance between
    consecutive overlap vectors drops below ``toll``.

    Args:
        type: teacher kind: ``'Wishart'`` (sampled with ``Wishart(d, kappa)``,
            ``Q_0 = 1 + kappa``) or one of ``'pow'``, ``'logverif'``,
            ``'powsingle'`` (``generate_teacher(cutoff, gamma, d)``,
            ``Q_0 = 1``).  It also selects the output file name.
        alpha, kappa, gamma, lloss, lreg, posdef, kappa_stud, noise: model
            parameters forwarded to ``ERM_SE_hybrid``.
        beta: the teacher cutoff is ``int(d**beta)``.
        d: matrix dimension.
        damping: weight of the new overlaps in the damped update.
        min_iter: convergence is only accepted when the iteration index is
            strictly greater than this value.
        max_iter: maximum number of iterations (at least 1).
        toll: convergence tolerance on the overlap change.
        q, m, sigma: initial overlaps.
        damp_schedule: if true, the damping at iteration ``i`` is
            ``damping * (0.999 / sqrt(i + 1) + 0.001)``.
        verbose: print progress at every iteration.
        saveits: write the per-iteration history to a CSV file.
        nsamples: Monte-Carlo samples per iteration.
        avgover: number of trailing iterations averaged for the final result.

    Returns:
        ``(overlaps, mse, loss, i, converged)``.  On convergence ``overlaps``
        is the average over the last ``avgover`` iterations and ``loss`` is 0
        (placeholder).  On divergence ``mse`` and ``loss`` are NaN.  If
        ``max_iter`` is exhausted the current overlaps are returned with the
        averaged ``mse`` and ``loss = 0``.

    Raises:
        ValueError: if ``type`` is not a known teacher kind or
            ``max_iter < 1``.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")

    cutoff=int(d**beta)
    if type=='Wishart':
        Sstar=Wishart(d,kappa)
        Q_0=1+kappa
        filename=f"data_SEit_Wishart/ERM_SEit_Wishart_alpha{alpha}_d{d}_kappa{round(kappa,3)}_lreg{round(lreg,3)}_noise{round(noise,3)}.csv"
    elif type in ('pow', 'logverif', 'powsingle'):
        Sstar=generate_teacher(cutoff,gamma,d)
        Q_0=1
        if type=='pow':
            filename=f"data_SEit_pow/ERM_SEit_pow_alpha{alpha}_gamma{gamma}_beta{beta}_d{d}_kappa{round(kappa,3)}_lreg{round(lreg,3)}_noise{round(noise,3)}_nsamples{nsamples}.csv"
        elif type=='logverif':
            filename=f"data_SEit/ERM_SEit_logverif_alpha{alpha}_d{d}_kappa{round(kappa,3)}_lreg{round(lreg,3)}_noise{round(noise,3)}.csv"
        else:
            filename=f"data_SEit_pow/ERM_SEit_powsingle_alpha{alpha}_gamma{gamma}_beta{beta}_d{d}_lreg{round(lreg,3)}_noise{round(noise,3)}_nsamples{nsamples}.csv"
    else:
        # Previously an unknown type fell through and failed later with an
        # unbound-variable error.
        raise ValueError(f"Unknown teacher type {type!r}; expected 'Wishart', 'pow', 'logverif' or 'powsingle'")

    cols=['alpha','gen_error','iters','kappa','lreg','d','noise','q','m','sigma','q_hat','m_hat','sigma_hat','err_toll','std_q','std_m','std_sigma','spread_q','spread_m','spread_sigma']

    if saveits:
        # Make sure the target directory exists before any write.
        out_dir=os.path.dirname(filename)
        if out_dir:
            os.makedirs(out_dir,exist_ok=True)
        if not os.path.exists(filename):
            pd.DataFrame(columns=cols).to_csv(filename,header=True,lineterminator='\n',index=False)

    q_hat = ERM_q_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise)
    m_hat = ERM_m_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise)
    sigma_hat = ERM_sigma_hat_eq(q, m, sigma, alpha, Q_0, lloss, noise)
    overlaps = np.array([q, m, sigma, q_hat, m_hat, sigma_hat])

    results=[]
    for i in range(max_iter):
        if damp_schedule:
            damping_t = damping * (0.999 / np.sqrt(i+1) + 0.001)
        else:
            damping_t = damping

        new_overlaps,std_devs = ERM_SE_hybrid(Sstar,overlaps, alpha, Q_0, kappa,gamma, lloss, lreg, posdef, kappa_stud, noise,cutoff,d,nsamples)
        err_toll = np.linalg.norm(new_overlaps - overlaps) 

        overlaps = (1-damping_t) * overlaps + damping_t * new_overlaps
        mse=Q_0 + overlaps[0] - 2 * overlaps[1]
        # Row layout: 7 scalars, then the 6 overlaps (columns 7..12), then
        # err_toll and the 6 dispersion statistics.
        results.append([alpha,mse,i,kappa,lreg,d,noise]+list(overlaps)+[err_toll]+list(std_devs))

        if err_toll < toll and i>min_iter: 
            if saveits:
                save_results(filename,results,cols)

            avg_ovps=np.mean([it[7:13] for it in results[-avgover:]],axis=0)
            mse=Q_0 + avg_ovps[0] - 2 * avg_ovps[1]

            loss =0 #placeholder for loss computation if needed in future
            return avg_ovps, mse,loss,i,True 

        if verbose:
            print(f"{i:6d} | Alpha: {alpha:.8f}, Gen error: {(Q_0 + new_overlaps[0] - 2 * new_overlaps[1]):.2e} | Toll : {err_toll:.2e}")
            print(f"       | Posdef: {posdef:.0f}  | gamma: {gamma:.6f} | Damping : {damping_t:.6f}")
            print(f"       | lreg: {lreg:.6f}")

        # Check the measured change, not the fixed tolerance `toll`: a NaN or
        # infinite step means the iteration has diverged.
        if not np.isfinite(err_toll):
            print("Divergence detected, returning NaN")
            return overlaps, float('NaN'), float('NaN'),i,False

    if saveits:
        save_results(filename,results,cols)

    if verbose:
        print(f"Convergence not reached for alpha = {alpha}")

    avg_ovps=np.mean([it[7:13] for it in results[-avgover:]],axis=0) 
    mse=Q_0 + avg_ovps[0] - 2 * avg_ovps[1]
    return overlaps, mse,0,i,False 
