import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
from tqdm import tqdm
from scipy.stats import binom

def get_tau( pFN, out_size ):
    """Return the ``pFN`` quantile of Binomial(out_size, 0.5) as a fraction of ``out_size``.

    Args:
        pFN: quantile level handed to the binomial percent-point function.
        out_size: number of Bernoulli trials (also the normalizer).

    Returns:
        ``binom.ppf(pFN, out_size, 0.5) / out_size``.
    """
    fair_binomial = binom( out_size, .5 )
    return fair_binomial.ppf( pFN ) / out_size

def get_layer_decomp(neptune, net, layer, B=None, device='cuda'):
    """Factor the weight gradient of ``layer`` as ``grad_W ~= L @ R`` with inner size ``B``.

    A randomized truncated SVD (``torch.svd_lowrank``) is tried first. If its
    max-abs reconstruction error exceeds 1e-6, an exact SVD truncated to ``B``
    is computed and kept when it reconstructs ``grad_W`` more accurately.
    Both factors carry ``sqrt(S)``, so ``L @ R`` equals the truncated SVD.
    Least-squares left/right inverses of ``L`` and ``R`` are also computed.

    Args:
        neptune: optional run handle; when truthy, error metrics are logged to it.
        net: module exposing ``get_parameter``; ``<layer>.weight`` and
            ``<layer>.bias`` must already have ``.grad`` populated.
        layer: dotted name prefix of the layer's parameters.
        B: inner dimension; defaults to the numerical rank of ``grad_W``
            (absolute tolerance 1e-6), returned as a Python ``int``.
        device: device the returned tensors are moved to.

    Returns:
        ``(B, (W, b), (grad_W, grad_b), (L, R), (Linv, Rinv))``.

    Raises:
        ValueError: if the weight or bias of ``layer`` has no gradient.
    """
    weight = net.get_parameter( layer + '.weight' )
    bias = net.get_parameter( layer + '.bias' )
    if weight.grad is None or bias.grad is None:
        raise ValueError( f'{layer}: weight/bias gradients are missing; call backward() first' )
    W = weight.data.to(device)
    b = bias.data.to(device)
    grad_W = weight.grad.to(device)
    grad_b = bias.grad.to(device)

    if B is None:
        # Plain int so it can be used for slicing and tensor sizes below.
        B = int( torch.linalg.matrix_rank( grad_W, tol=1e-6 ) )

    # Fast randomized SVD.
    U, S, V = torch.svd_lowrank(grad_W, q=B, niter=10)
    sqrt_S = torch.diag( torch.sqrt( S ) )
    R = sqrt_S @ V.T
    L = U @ sqrt_S
    error_SVD = (L @ R - grad_W).abs().max().item()

    if error_SVD > 1e-6:
        # The randomized SVD was not accurate enough; an exact SVD truncated to B may do better.
        U, S, Vh = torch.linalg.svd( grad_W, full_matrices=False )
        sqrt_S = torch.diag( torch.sqrt( S[:B] ) )
        R_new = sqrt_S @ Vh[:B, :]
        L_new = U[:, :B] @ sqrt_S
        # Compare the exact factors' own reconstruction error against the randomized one.
        error_SVD_new = (L_new @ R_new - grad_W).abs().max().item()
        if error_SVD_new < error_SVD:
            L, R, error_SVD = L_new, R_new, error_SVD_new

    eye = torch.eye( B, device=device, dtype=L.dtype )
    # Left inverse of L (Linv @ L ~= I) and right inverse of R (R @ Rinv ~= I).
    Linv = torch.linalg.lstsq( L.T, eye, driver='gels' )[0].T
    Rinv = torch.linalg.lstsq( R, eye, driver='gels' )[0]

    def max_abs_error( diff ):
        # -1 marks a NaN result (e.g. a failed least-squares solve).
        err = diff.abs().max()
        return -1 if err.isnan().item() else err.item()

    error_L = max_abs_error( Linv @ L - eye )
    error_R = max_abs_error( R @ Rinv - eye )
    print(f'Num Errors: L: {error_L}, R: {error_R}, SVD: {error_SVD}')
    if neptune:
        neptune['result/error_L'].log(error_L)
        neptune['result/error_R'].log(error_R)
        neptune['result/error_SVD'].log(error_SVD)

    return B, (W, b), (grad_W, grad_b), (L, R), (Linv, Rinv)

def makeUnique( L, results, counts, sparsity_tol=1e-6 ):
    """Merge near-duplicate candidate directions, keeping one representative per group.

    Rows of ``results`` are sign-normalized (first entry made non-negative) and
    grouped by identical rows of the thresholded similarity matrix
    ``results @ results.T > 0.99``. This is a cosine similarity only when
    rows are unit-norm, which holds for the SVD vectors produced by
    ``getQBarColParallel``.

    For each group, the counts are summed. The representative is the member
    whose response ``L @ member`` has the smallest absolute mass on the rows
    where the count-weighted group average has ``|L @ avg| < sparsity_tol``.

    Args:
        L: matrix applied to each direction (rows index outputs).
        results: ``(num_proposals, B)`` candidate directions; not modified.
        counts: ``(num_proposals,)`` integer multiplicities of each candidate.
        sparsity_tol: threshold below which an entry of ``L @ v`` counts as zero.

    Returns:
        ``(final, final_counts)``: one row per group and the summed counts.
        Inputs with fewer than two rows are returned unchanged.
    """
    num_proposals = results.shape[0]
    if num_proposals < 2:
        return results, counts
    # Make v and -v comparable by fixing the sign of the first component (out-of-place).
    results = torch.where( results[ :, :1 ] < 0, -results, results )

    cos_sim = results @ results.T
    sim_thresh = 1 - 1e-2
    _, group_of = torch.unique( cos_sim > sim_thresh, dim=0, return_inverse=True )
    num_uniq = int( group_of.max().item() ) + 1

    final = torch.zeros( num_uniq, results.shape[1], dtype=results.dtype, device=results.device )
    final_counts = torch.zeros( num_uniq, dtype=torch.int64, device=counts.device )
    for i in range( num_uniq ):
        members = group_of == i
        member_rows = results[ members, : ]
        final_counts[i] = counts[ members ].sum()
        avg = ( counts[ members, None ] * member_rows ).sum(0) / final_counts[i]
        # Rows of L on which the group's average direction is (numerically) zero ...
        sparsity_idx = ( L @ avg ).abs() < sparsity_tol
        # ... and each member's leftover mass there; the sparsest member represents the group.
        sparsity_measure = ( L @ member_rows.T )[ sparsity_idx ].abs().sum(0)
        best_idx = sparsity_measure.argmin().item()
        final[ i, : ] = member_rows[ best_idx ]

    return final, final_counts

def getQBarColParallel( L, par_attempts, treshold=0.4, tol=1e-6 ):
    """Propose candidate directions orthogonal to random (B-1)-row subsets of ``L``.

    ``par_attempts // out_size`` random (B-1)-subsets of row indices are drawn.
    Each subset is reused under every cyclic shift of the indices, which avoids
    calling ``topk`` once per subset. For every subset, the last right-singular
    vector of the (B-1) x B submatrix is orthogonal to all of its rows. A
    proposal is kept when ``L @ v`` has more than ``treshold * out_size``
    entries with magnitude below ``tol``.

    Args:
        L: ``(out_size, B)`` matrix.
        par_attempts: approximate number of proposals to generate.
        treshold: minimum fraction of (near-)zero entries in ``L @ v``.
        tol: magnitude below which an entry of ``L @ v`` counts as zero.

    Returns:
        ``(num_kept, B)`` tensor of accepted unit-norm proposals.
    """
    out_size, B = L.shape
    n_base = par_attempts // out_size

    # Random (B-1)-subsets of rows: indices of the top-(B-1) entries of uniform noise.
    base_subsets = torch.topk( torch.rand( n_base, out_size ), B - 1, sorted=False )[1]
    base_subsets = base_subsets.repeat( out_size, 1 )
    shifts = torch.arange( out_size ).repeat_interleave( n_base ).reshape( -1, 1 )
    row_subsets = ( base_subsets + shifts ) % out_size

    candidates = torch.linalg.svd( L[ row_subsets, : ] ).Vh[ :, -1, : ]
    n_zero = ( ( L @ candidates.T ).abs() < tol ).sum( axis=0 )
    keep = n_zero > treshold * out_size
    return candidates[ keep ]

def getQBarUniqueCol( L, R, L_inv, W, b, grad_b, Q_opt, N=int(1e10), par_attempts=int(5e+5), treshold=0.4, cond='early', sigma_tol=1e-7, sigma_treshold=0.99, sparsity_tol=1e-6, count_hack=False, max_cols=7500 ):
    """Sample and deduplicate candidate columns of Q_bar until a stopping rule fires.

    Each batch draws ``par_attempts`` proposals with ``getQBarColParallel``.
    Normally the batch is merged into the running pool with ``makeUnique``;
    with ``count_hack`` it is deduplicated on its own and then appended. The
    pool is kept sorted by how many rows of ``L`` each column is (numerically)
    orthogonal to, and ``counts`` is permuted together with it.

    Stopping rules:
        * ``cond == 'gt'``: every column of ``Q_opt`` is matched by some
          candidate (|cos| within 1e-2 of 1).
        * ``cond == 'early'``: on batches whose index is a multiple of 10 and
          that produced proposals, ``filterQ`` scores the pool and the search
          stops once the score reaches ``sigma_treshold``.
        * the pool exceeds ``max_cols`` columns, ``N`` attempts are exhausted,
          or the user interrupts with Ctrl-C (the pool found so far is kept).

    Returns:
        ``(Q_cols, its, counts)``: candidates as a ``(B, num_cols)`` matrix,
        the number of attempts made, and how many raw proposals were merged
        into each column (aligned with ``Q_cols``).
    """
    Q_opt = Q_opt.cuda()
    L = L.cuda()

    if L.shape[1] > L.shape[0]:
        L = L.T
    B = L.shape[1]
    Q_cols = torch.zeros( 0, B, device='cuda' )
    counts = torch.zeros( 0, dtype=torch.int64, device='cuda' )
    its = 0  # stays 0 if no batch runs (e.g. N == 0)
    pbar = tqdm( range( 0, N, par_attempts ) )
    pbar.set_description(f"[]/{0}/{0}/{B}")
    try:
        has_changes = False
        for i in pbar:
            its = i + par_attempts
            Q_cols_new = getQBarColParallel( L, par_attempts, treshold=treshold, tol=sparsity_tol )
            if Q_cols_new.shape[0] == 0:
                continue
            counts_new = torch.ones( Q_cols_new.shape[0], dtype=torch.int64, device='cuda' )
            has_changes = True
            if not count_hack:
                counts = torch.concat( ( counts, counts_new ) )
                Q_cols = torch.concat( ( Q_cols, Q_cols_new ) )
                Q_cols, counts = makeUnique( L, Q_cols, counts, sparsity_tol=sparsity_tol )
            else:
                Q_cols_new, counts_new = makeUnique( L, Q_cols_new, counts_new, sparsity_tol=sparsity_tol )
                Q_cols = torch.concat( ( Q_cols, Q_cols_new ) )
                counts = torch.concat( ( counts, counts_new ) )

            # Slice of the sorted sparsity list shown in the progress bar.
            num_result = Q_cols.shape[0]
            if num_result < B:
                st_idx, en_idx = 0, 3
            else:
                st_idx = -min( B + 3, num_result )
                en_idx = 3 - B if B > 3 else None

            # Sort by sparsity; permute counts identically so they stay aligned with Q_cols.
            sparsity = ( ( L @ Q_cols.T ).abs() < sparsity_tol ).sum(0)
            sparsity, order = torch.sort( sparsity )
            Q_cols = Q_cols[ order, : ]
            counts = counts[ order ]
            sparsity = sparsity.cpu().numpy().tolist()

            cost = torch.nn.functional.cosine_similarity(Q_opt.T[None,:,:], Q_cols[:,None,:], dim=-1).abs()
            cos_cost, _ = cost.max(axis=0)
            n_cos_correct = int( ( (cos_cost - 1).abs() < 1e-2 ).sum() )

            pbar.set_description(f"{sparsity[st_idx:en_idx]}/{n_cos_correct}/{Q_cols.shape[0]}/{B}")

            if cond == 'gt':
                if n_cos_correct == B:
                    break
            elif cond == 'early':
                if (i // par_attempts) % 10 == 0 and has_changes:
                    has_changes = False
                    C, _, _ = filterQ( L, R, L_inv, W, b, grad_b, Q_cols.T, sigma_tol=sigma_tol, prints=False )
                    if C is not None and C >= sigma_treshold:
                        print('Early break')
                        break
            if Q_cols.shape[0] > max_cols:
                break
    except KeyboardInterrupt:
        # Deliberate: let the user stop the search interactively and keep the pool found so far.
        pass
    return Q_cols.T, its, counts

def fixQBarColScale( L_inv, grad_b, Q_bar ):
    """Rescale the columns of each matrix in a batch of Q_bar candidates.

    For each candidate, ``lens = pinv(Q_bar) @ L_inv @ grad_b`` gives one scale
    per column. The function returns ``Q = Q_bar @ diag(lens)`` and
    ``Q_inv = diag(1 / lens) @ pinv(Q_bar)``, where the pseudo-inverse is
    computed with ``torch.linalg.lstsq`` (driver ``'gels'``).

    Args:
        L_inv: ``(B, out)`` left inverse of L (as returned by ``get_layer_decomp``).
        grad_b: ``(out,)`` bias gradient.
        Q_bar: ``(n, B, B)`` batch of candidate matrices.

    Returns:
        ``(Q, Q_inv)``, both of shape ``(n, B, B)``, on ``Q_bar``'s device.
    """
    # Use the tensor's own device/dtype: get_device() is -1 on CPU and .cuda() forces a GPU.
    eye = torch.eye( Q_bar.shape[1], device=Q_bar.device, dtype=Q_bar.dtype )[ None, :, : ]
    Q_bar_inv = torch.linalg.lstsq( Q_bar, eye, driver='gels' )[0]
    lens = Q_bar_inv @ L_inv @ grad_b
    Q = Q_bar @ torch.diag_embed( lens )
    Q_inv = torch.diag_embed( 1 / lens ) @ Q_bar_inv
    return Q, Q_inv

def checkQ( L, R, Q, Q_inv, W, b, tol=1e-7 ):
    """Score each candidate Q by agreement between pre-activation signs and gradient sparsity.

    For every candidate in the batch, pre-activations are reconstructed as
    ``Z = W @ R.T @ Q_inv^T + b`` and gradients as ``grad_Z = L @ Q``.
    An entry counts as consistent when
    ``Z <= 0 and |grad_Z| < tol`` or ``not (Z <= 0) and |grad_Z| > tol``.

    Args:
        L, R: gradient factors (``grad_W ~= L @ R``).
        Q, Q_inv: ``(n, B, B)`` batches of candidates and their inverses.
        W, b: layer weight and bias.
        tol: magnitude threshold separating zero from non-zero gradients.

    Returns:
        ``(n,)`` float tensor: fraction of consistent entries per candidate, in [0, 1].
    """
    Z = ( W @ R.T @ Q_inv.permute( 1, 2, 0 ) ) + b.reshape( 1, -1, 1 )
    inactive = Z <= 0
    grad_Z_abs = ( L @ Q.permute( 2, 1, 0 ) ).abs()
    # Boolean form avoids NaN*0 == NaN, which previously let non-finite gradients
    # in active positions subtract from the count.
    consistent = ( inactive & ( grad_Z_abs < tol ) ) | ( ~inactive & ( grad_Z_abs > tol ) )
    total = float( grad_Z_abs.shape[0] * grad_Z_abs.shape[1] )
    return consistent.sum( (0, 1) ).float() / total

def summarizeQ( neptune, label, L, R, L_inv, W, b, grad_b, Q_bar, Q_opt, sigma_tol=1e-7 ):
    """Print, and optionally log, diagnostics for a candidate Q_bar.

    Reports the following:
    * the rank of Q_bar;
    * how many columns of ``Q_opt`` are matched by some column of Q_bar with
      |cos| within 1e-2 of 1;
    * after rescaling with ``fixQBarColScale``, how many are matched with
      squared L2 distance below 1e-6;
    * ``checkQ``'s score ``C``.

    If Q_bar has more than B columns, the B columns best matching ``Q_opt``
    (Hungarian assignment on |cos|) are kept first.

    Args:
        neptune: optional run handle; when truthy, metrics go to ``result/<label>/...``.
        label: name used in the printout and the metric keys.
        Q_bar: ``(B, dirs)`` or ``(1, B, dirs)`` candidate columns.
        Q_opt: reference matrix whose columns are compared with Q_bar's
            (presumably the ground truth — confirm with callers).

    Returns:
        ``(r, C, Q, Q_inv)``: ``r`` is an int rank. When Q_bar has fewer than
        B columns, ``C`` is -1 and ``Q``/``Q_inv`` are None.

    Raises:
        ValueError: if Q_bar does not have one of the accepted shapes.
    """
    if Q_bar.dim() == 2:
        Q_bar = Q_bar[ None, :, : ]
    elif Q_bar.dim() != 3 or Q_bar.shape[0] != 1:
        raise ValueError( f'Q_bar must have shape (B, dirs) or (1, B, dirs), got {tuple(Q_bar.shape)}' )

    B = Q_bar.shape[1]
    dirs = Q_bar.shape[2]
    r = int( torch.linalg.matrix_rank( Q_bar[0] ) )

    cost = torch.nn.functional.cosine_similarity(Q_opt.T[None,:,:], Q_bar[0].T[:,None,:], dim=-1).abs()
    cos_cost, _ = cost.max(axis=0)
    n_cos_correct = int( ( (cos_cost - 1).abs() < 1e-2 ).sum() )
    if dirs > B:
        # Keep the B directions that best match Q_opt (maximize total |cos|).
        _, col_ind = linear_sum_assignment( -cost.detach().cpu().numpy().T )
        Q_bar = Q_bar[:,:,col_ind]

    if dirs < B:
        C, Q, Q_inv, n_l2_correct = -1, None, None, 0
    else:
        Q, Q_inv = fixQBarColScale( L_inv, grad_b, Q_bar )
        C = checkQ( L, R, Q, Q_inv, W, b, tol=sigma_tol ).item()

        l2 = ( Q_opt.T[None,:,:] - Q[0].T[:,None,:] ).pow(2).sum(2)
        l2_cost, _ = l2.min(0)
        n_l2_correct = int( (l2_cost < 1e-6).sum() )

    print( f'{label}:\tl2/cos/r/B: {n_l2_correct}/{n_cos_correct}/{r}/{B}\tC: {C}' )

    if neptune:
        neptune[f'result/{label}/C'].log(C)
        neptune[f'result/{label}/B'].log(B)
        neptune[f'result/{label}/l2'].log(n_l2_correct)
        neptune[f'result/{label}/cos'].log(n_cos_correct)
        neptune[f'result/{label}/r'].log(r)

    return r, C, Q, Q_inv

def initFilterQ( Q_bar ):
    """Greedily choose B linearly independent columns of a ``(B, dirs)`` matrix.

    Columns are scanned from last to first. A column is kept when the last
    diagonal entry of the QR factor of the kept columns has magnitude at least
    a threshold (starting at 1e-4). If fewer than B columns survive, the
    threshold is divided by 10 and the scan restarts.

    Returns:
        ``(guess, non_guess)``: the chosen column indices (in selection order)
        and the remaining indices. Returns ``(None, None)`` when
        ``torch.linalg.matrix_rank`` reports a rank below B.
    """
    n_rows = Q_bar.shape[0]
    n_cols = Q_bar.shape[1]
    if torch.linalg.matrix_rank( Q_bar ) < n_rows:
        return None, None

    mat = Q_bar.cpu().numpy()

    def greedy_pass( thresh ):
        picked = []
        for col in reversed( range( n_cols ) ):
            picked.append( col )
            _, tri = np.linalg.qr( mat[ :, picked ] )
            if np.abs( np.diag( tri ) )[ -1 ] < thresh:
                picked.pop()
            if len( picked ) == n_rows:
                break
        return picked

    guess = []
    thresh = 1e-4
    while len( guess ) < n_rows:
        guess = greedy_pass( thresh )
        thresh /= 10

    non_guess = list( set( range( n_cols ) ) - set( guess ) )
    return guess, non_guess

def filterQ( L, R, L_inv, W, b, grad_b, Q_bar, sigma_tol=1e-7, prints=True ):
    """Greedy local search for the B columns of Q_bar that maximize ``checkQ``'s score.

    The search starts from ``initFilterQ``'s selection. Each round scores
    every single swap (one selected column out, one unselected column in)
    and applies the best one if it strictly improves the score. Swaps that
    give a rank-deficient B x B matrix are scored 0. The search stops when
    no swap improves the score.

    Args:
        Q_bar: ``(B, dirs)`` candidate columns.
        prints: print the score and swapped indices after each improvement.

    Returns:
        ``(sigma, Q_bar_best, Q_bar_init)``: the best score (a tensor), the
        selected B columns of Q_bar, and the initial selection. Returns
        ``(None, None, None)`` if Q_bar has rank below B.
    """
    dirs = Q_bar.shape[1]
    B = Q_bar.shape[0]
    # .device works for CPU tensors too (get_device() returns -1 there).
    device = Q_bar.device

    guess, non_guess = initFilterQ( Q_bar )
    if guess is None:
        return None, None, None
    Q_bar_init_guess = Q_bar[ None, :, guess ]
    Q_guess, Q_guess_inv = fixQBarColScale( L_inv, grad_b, Q_bar_init_guess )
    sigma_guess = checkQ( L, R, Q_guess, Q_guess_inv, W, b, tol=sigma_tol )

    if len(non_guess) == 0:
        return sigma_guess, Q_bar_init_guess[0], Q_bar_init_guess[0]

    changed = True
    while changed:
        changed = False
        # Enumerate every swap: X = column brought in (from non_guess), Y = column taken out (from guess).
        X, Y = torch.meshgrid( torch.tensor( non_guess, device=device ), torch.tensor( guess, device=device ), indexing='ij' )
        X_flat, Y_flat = X.reshape(-1), Y.reshape(-1)
        n_switch = X_flat.shape[0]
        swap_rows = torch.arange( n_switch, device=device )
        selected = torch.zeros( n_switch, dirs, dtype=torch.bool, device=device )
        selected[ :, guess ] = True
        selected[ swap_rows, X_flat ] = True
        selected[ swap_rows, Y_flat ] = False

        # Every row of `selected` has exactly B True entries; gather those columns
        # (in ascending index order) into one B x B matrix per swap.
        cols = torch.argwhere( selected )[ :, 1 ].reshape( n_switch, B )
        Q_bar_guesses_new = Q_bar[ :, cols ].permute( 1, 0, 2 )
        r = torch.linalg.matrix_rank( Q_bar_guesses_new )
        Q_guesses_new, Q_guesses_new_inv = fixQBarColScale( L_inv, grad_b, Q_bar_guesses_new )
        sigma_guesses_new = checkQ( L, R, Q_guesses_new, Q_guesses_new_inv, W, b, tol=sigma_tol )
        sigma_guesses_new[ r != B ] = 0

        # Apply the best swap only if it strictly improves the score.
        sigma_best_guess, best_idx = sigma_guesses_new.max(0)
        if sigma_best_guess > sigma_guess:
            sigma_guess = sigma_best_guess
            changed = True
            j_best, i_best = X_flat[best_idx].item(), Y_flat[best_idx].item()
            guess = list( ( set(guess) - {i_best} ) | {j_best} )
            non_guess = list( ( set(non_guess) - {j_best} ) | {i_best} )
            if prints:
                print( f'{sigma_guess.item()}:{i_best}->{j_best}' )

    return sigma_guess, Q_bar[ :, guess ], Q_bar_init_guess[0]

def _log_failed_stages( neptune, stages ):
    """Log -1 placeholders for every metric of the given result stages."""
    metric_names = {
        'pre_init': ( 'sigma_init', 'B', 'l2', 'cos', 'r' ),
        'init': ( 'sigma_init', 'B', 'l2', 'cos', 'r' ),
        'greedy': ( 'C', 'B', 'l2', 'cos', 'r' ),
    }
    for stage in stages:
        for metric in metric_names[ stage ]:
            neptune[ f'result/{stage}/{metric}' ].log( -1 )


def _complete_rank( Q_bar, B, attempts=5 ):
    """Best effort: overwrite or pad columns of a ``(B, dirs)`` Q_bar until it has rank B.

    Each pass does the following:
    * pads Q_bar with zero columns up to at least B columns;
    * greedily selects independent columns, scanning from the last one and
      keeping a column when its QR diagonal magnitude is at least 1e-5;
    * replaces the first ``B - rank`` unselected columns with the trailing
      left singular vectors ``U[:, rank:B]``. These are vectors in R^B
      orthogonal to Q_bar's leading singular directions.

    Up to ``attempts`` passes are made. Returns a new tensor; the input is not modified.
    """
    for _ in range( attempts ):
        print( f'Insufficient rank: {int( torch.linalg.matrix_rank( Q_bar ) )}/{B}' )
        n_pad = max( B - Q_bar.shape[1], 0 )
        padding = torch.zeros( B, n_pad, dtype=Q_bar.dtype, device=Q_bar.device )
        Q_bar = torch.cat( ( Q_bar, padding ), dim=1 )
        # Left singular vectors live in R^B, the space the columns of Q_bar live in.
        U, _, _ = torch.linalg.svd( Q_bar, full_matrices=True )

        Q_np = Q_bar.cpu().numpy()
        guess = []
        for i in range( Q_bar.shape[1] - 1, -1, -1 ):
            guess.append( i )
            _, r = np.linalg.qr( Q_np[ :, guess ] )
            if np.abs( np.diag( r ) )[-1] < 1e-5:
                guess.pop()
            if len( guess ) == B:
                break
        rank = len( guess )
        non_guess = list( set( range( Q_bar.shape[1] ) ) - set( guess ) )
        Q_bar[ :, non_guess[ :B - rank ] ] = U[ :, rank:B ]
        if torch.linalg.matrix_rank( Q_bar ) == B:
            break
    return Q_bar


def getQ( neptune, params, grads, LR, LR_inv, Q_opt, device='cuda', N=int(1e10), par_SVD=int(5e+5), treshold=0.4, sigma_tol=1e-7, cond='early', sigma_treshold=0.99, sparsity_tol=1e-6, count_hack=False ):
    """Search for a B x B matrix Q consistent with the layer's gradient factorization.

    The steps are:
    1. Collect candidate columns with ``getQBarUniqueCol``.
    2. Report them (``'pre_init'``).
    3. Select B columns with ``filterQ``. If the candidates are
       rank-deficient, first complete them to rank B.
    4. Report the initial (``'init'``) and final (``'greedy'``) selections.

    Args:
        neptune: optional run handle for metric logging.
        params: ``(W, b)`` layer parameters.
        grads: ``(grad_W, grad_b)`` layer gradients; only ``grad_b`` is used here.
        LR: ``(L, R)`` gradient factors; ``LR_inv``: ``(L_inv, R_inv)``, only ``L_inv`` is used.
        Q_opt: reference matrix used only for diagnostics.
        device: device the inputs are moved to.
        Remaining keyword arguments are forwarded to ``getQBarUniqueCol`` / ``filterQ``.

    Returns:
        ``(Q, Q_inv)`` for the best selection. Returns ``(None, None)`` when no
        candidates were found, when ``count_hack`` is set, or when no full-rank
        selection could be formed.
    """
    W, b = params[0].to(device), params[1].to(device)
    grad_b = grads[1].to(device)
    L, R = LR[0].to(device), LR[1].to(device)
    L_inv = LR_inv[0].to(device)
    B = L.shape[1]
    Q_opt = Q_opt.to(device)

    Q_bar, its, counts = getQBarUniqueCol( L, R, L_inv, W, b, grad_b, Q_opt, N=N, par_attempts=par_SVD, treshold=treshold, cond=cond, sigma_tol=sigma_tol, sigma_treshold=sigma_treshold, sparsity_tol=sparsity_tol, count_hack=count_hack )
    dirs = Q_bar.shape[1]
    print( f'Bars: {dirs}' )
    if neptune:
        neptune['result/iterations'].log( its )
        neptune['result/bars'].log( dirs )
        neptune['result/dir_freq'].append( str( counts.cpu().numpy() ) )

    if dirs == 0:
        print( f'Max rank: {0}' )
        if neptune:
            _log_failed_stages( neptune, ( 'pre_init', 'init', 'greedy' ) )
        return None, None

    summarizeQ( neptune, 'pre_init', L, R, L_inv, W, b, grad_b, Q_bar, Q_opt, sigma_tol=sigma_tol )
    if count_hack:
        return None, None

    sigma_best, Q_best, Q_init = filterQ( L, R, L_inv, W, b, grad_b, Q_bar, sigma_tol=sigma_tol )
    if sigma_best is None:
        # Not full rank: complete the candidate set and retry the selection.
        Q_bar = _complete_rank( Q_bar, B )
        sigma_best, Q_best, Q_init = filterQ( L, R, L_inv, W, b, grad_b, Q_bar, sigma_tol=sigma_tol )

    if sigma_best is None:
        if neptune:
            _log_failed_stages( neptune, ( 'init', 'greedy' ) )
        return None, None

    summarizeQ( neptune, 'init', L, R, L_inv, W, b, grad_b, Q_init, Q_opt, sigma_tol=sigma_tol )
    _, _, Q_best, Q_best_inv = summarizeQ( neptune, 'greedy', L, R, L_inv, W, b, grad_b, Q_best, Q_opt, sigma_tol=sigma_tol )

    return Q_best, Q_best_inv
