import numpy as np 
from tqdm import tqdm 
from typing import Tuple
import utils.plothelp as ph 


def generate_sks(d:int, dsk:float, rng:np.random.RandomState, noise_type = 'gauss',clip_range = 1e-4, sort:bool = True, base = 0.5) -> np.ndarray:
    """
    Draw ``d`` semantic strengths from a disorder distribution controlled by ``dsk``.

    Parameters:
    -----------
    d : int
        Number of strengths to draw (the dimension of the semantic space).
    dsk : float
        Disorder parameter. For 'gauss' it multiplies standard-normal draws, for
        'uniform' it multiplies draws from [-1, 1), and for 'pseudogap' the
        Weibull shape parameter is ``dsk + 1``.
    rng : np.random.RandomState
        Random number generator; one batch of ``d`` samples is drawn from it.
    noise_type : str
        One of 'gauss', 'uniform' or 'pseudogap'.
    clip_range : float or None
        Margin kept from both ends of [0, 2*base]: values are clipped to
        [clip_range, 2*base - clip_range]. ``None`` disables clipping.
    sort : bool
        If True, the strengths are returned in descending order.
    base : float
        'gauss' and 'uniform' samples are centred on ``base``; 'pseudogap'
        samples are 0.1 * Weibull draws subtracted from ``2*base``.

    Returns:
    --------
    np.ndarray
        The ``d`` generated semantic strengths.

    Raises:
    -------
    ValueError
        If ``noise_type`` is not one of the supported names.
    """
    if noise_type == 'gauss':
        strengths = base + rng.randn(d)*dsk
    elif noise_type == 'uniform':
        strengths = base + dsk * rng.uniform(-1,1,d)
    elif noise_type == 'pseudogap':
        strengths = 2*base - rng.weibull(dsk+1, d) * 1e-1
    else:
        raise ValueError('Invalid noise type: %s'%noise_type)

    if clip_range is not None:
        # Keep every strength strictly inside (0, 2*base).
        lower, upper = clip_range, 2*base-clip_range
        strengths = np.clip(strengths, lower, upper)

    if not sort:
        return strengths
    ascending = np.sort(strengths)
    return ascending[::-1]

def generate_P(sks,base=1.0,normalize=False):
    """
    Build the 2**d x 2**d matrix P from the d strengths in ``sks``.

    Entry (i, j) is the product over k of ``base + sks[k]`` when the word
    indices i and j agree on bit k, and ``base - sks[k]`` when they differ.

    Parameters:
    -----------
    sks : sequence of float
        One strength per semantic dimension; ``len(sks)`` fixes d.
    base : float
        Constant added to the signed strength in every factor.
    normalize : bool
        If True, P is divided by the sum of all its entries.

    Returns:
    --------
    np.ndarray
        The (2**d, 2**d) matrix P.
    """
    n_words = 1 << len(sks)
    word_ids = np.arange(n_words)
    P = np.ones((n_words, n_words))
    for k, strength in enumerate(sks):
        bit_k = (word_ids >> k) & 1
        # +1 where the two words agree on bit k, -1 where they differ.
        agreement = 1 - 2*(bit_k[:, None] ^ bit_k[None, :])
        P *= (base + agreement*strength)
    if normalize:
        P /= P.sum()
    return P


from numpy.typing import *
def measure_polarization_k(W: NDArray, k : int, d: int, return_vector: bool  = False, norm_W=True, norm_W_eigs=False) -> float | NDArray:
    """
    Measures the polarization of W with respect to a specified dimension k, for a semantic space of dimension d.

    Each word (row index of W) gets a sign of +1 if bit k of its index is set and
    -1 otherwise; the polarization vector is the mean over words of sign * row.

    Parameters:
    -----------
    W : NDArray
        A numerical array representing the low-rank projection of the semantic space.
        The sign vector has 2**d entries, so W needs 2**d rows.
    k : int
        The dimension with respect to which polarization is measured.
    d : int
        The dimension of the semantic space.
    return_vector : bool
        If True, returns the polarization vector instead of the scalar value.
    norm_W : bool
        If True, every row of W is divided by its L2 norm before averaging.
    norm_W_eigs : bool
        If True, the averaged vector is divided, column by column, by the L2 norm
        of the corresponding column of W over 2**(d/2).
    Returns:
    --------
    float
        The L2 norm of the polarization vector of semantic dimension k.
    NDArray
        The polarization vector itself if return_vector is True.
    """
    # +1 / -1 per word, from bit k of the word index.
    word_signs = 2*((np.arange(2**d) >> k) & 1) - 1
    row_norms = np.linalg.norm(W,2,axis=1) if norm_W else 1
    column_scale = np.linalg.norm(W,2,axis=0)/(2.**(d/2)) if norm_W_eigs else 1
    signed_rows = (word_signs/row_norms * W.T).T
    polarization_vector = np.mean(signed_rows, axis = 0) / column_scale
    if return_vector:
        return polarization_vector
    return np.linalg.norm(polarization_vector,2)
    
def measure_polarization_diff_k(W: NDArray, k : int, d: int, return_vector: bool  = False, eta:float = 1e-3, alpha:float = 1e1) -> float | NDArray:
    """
    Measures the polarization of W with respect to a specified dimension k, for a semantic space of dimension d, with a rescaling by the magnitude of the representations.

    Words are paired so that the two members differ only in bit k. The polarization
    vector is the sum of the pair differences divided by the sum of their norms,
    multiplied by a sigmoid gate that is small when the differences are small
    compared with the entries of the representations.

    Parameters:
    -----------
    W : NDArray
        A numerical array representing the low-rank projection of the semantic space.
        Rows are indexed by word, so word indices up to 2**d - 1 must be valid.
    k : int
        The dimension with respect to which polarization is measured.
    d : int
        The dimension of the semantic space.
    eta : float
        The scale factor for the gate function -- the mean difference norm relative
        to the RMS entry size should be ~ eta for the gate to turn on.
    alpha : float
        The second scale factor for the gate function, scales how quickly the gate function turns on.
    return_vector : bool
        If True, returns the polarization vector instead of the scalar value.
    Returns:
    --------
    float
        Returns the polarization of semantic dimension k.
    NDArray
        Returns the polarization vector if return_vector is True.
    """
    # Words with bit k cleared; each one is paired with the same word with bit k set.
    every_word = np.arange(0,2**d)
    negative_words = np.unique( clear_bit(every_word,k) )
    positive_words = set_bit(negative_words,k)

    positive_reps = W[positive_words]
    negative_reps = W[negative_words]
    diffs = positive_reps - negative_reps
    diff_norms = np.linalg.norm(diffs,2,axis=1)
    polarization_vector = np.sum(diffs,axis = 0) / np.sum(diff_norms)

    # Guard against reading consistently correlated noise as polarization:
    # compare the size of the differences with the RMS size of the entries.
    positive_rms = np.sqrt(np.mean(positive_reps**2 ))
    negative_rms = np.sqrt(np.mean( negative_reps**2 ))
    rep_entry_scale_factor = 0.5*(positive_rms+negative_rms)
    relative_diff = np.mean(diff_norms)/rep_entry_scale_factor

    # Sigmoid gate centred on eta with steepness alpha/eta.
    gate_factor = 1.0 / (1.0 + np.exp(-alpha/eta*(relative_diff - eta)))
    polarization_vector = polarization_vector * gate_factor
    if return_vector:
        return polarization_vector
    return np.linalg.norm(polarization_vector,2)

def measure_polarization(W : NDArray, d:int,return_all:bool=False) -> Tuple[float, float] | NDArray:
    """
    Measures the polarization of the semantic space, across all semantic dimensions.

    Calls measure_polarization_diff_k once for each k in range(d).

    Parameters:
    -----------
    W : NDArray
        A numerical array representing the low-rank projection of the semantic space.
    d : int
        The dimension of the semantic space.
    return_all : bool
        If True, return the per-dimension polarizations instead of their summary.
    Returns:
    --------
    NDArray
        If return_all is True: array of length d with one polarization per dimension.
    Tuple[float, float]
        Otherwise: (mean, sample standard deviation) of the per-dimension
        polarizations. The standard deviation uses ddof=1, so it is NaN for d == 1.
    """
    polarizations = [measure_polarization_diff_k(W,k,d) for k in range(d)]
    if return_all:
        return np.array(polarizations)
    return np.mean(polarizations), np.std(polarizations,ddof=1)
   
def measure_expressivity(W,d,tol = 1e-10):
    """
    Fraction of row pairs of W that are not (anti-)parallel.

    Counts the entries of the cosine-similarity matrix of the rows of W whose
    absolute value is below ``1 - tol`` and divides by ``N*(N-1)`` with
    ``N = 2**d``.

    Parameters:
    -----------
    W : array
        One representation per row.
    d : int
        The dimension of the semantic space; only used for the normalisation N = 2**d.
    tol : float
        Two rows count as parallel when |cosine| >= 1 - tol.

    Returns:
    --------
    float
        The fraction described above.
    """
    n_words = 2**d
    row_lengths = np.sqrt((W**2).sum(axis=1))
    # Gram matrix of the rows, turned into cosines by the outer product of lengths.
    gram = np.einsum('ik,jk->ij',W,W)
    cosines = gram / np.outer(row_lengths,row_lengths)
    not_parallel = np.abs(cosines) <  1- tol
    return np.sum(not_parallel) / (n_words*(n_words-1))

def get_bit(number,index):
    """Return 1 if bit ``index`` of the integer ``number`` is set, else 0 (as a Python int)."""
    return 1 if number & (1 << index) else 0
def set_bit(number,index):
    """Return ``number`` with bit ``index`` set; applied element-wise when ``number`` is an integer array."""
    mask = 1 << index
    return number | mask
def clear_bit(number,index):
    """Return ``number`` with bit ``index`` cleared; applied element-wise when ``number`` is an integer array."""
    keep_mask = ~(1 << index)
    return number & keep_mask
def bit_difference(w1,w2):
    """Return the bitwise XOR of two words: its set bits mark the positions where they differ."""
    differing_bits = w1 ^ w2
    return differing_bits
def bit_rep(w,d):
    """Return the lowest ``d`` bits of ``w`` as a list of 0/1 ints, least significant bit first."""
    bits = []
    for position in range(d):
        bits.append(get_bit(w,position))
    return bits
def bit_norm(w,d):
    """Return the number of set bits among the lowest ``d`` bits of ``w``."""
    return np.sum(bit_rep(w,d))


def test_analogy(W,d1:int,d2:int,d:int)->Tuple[float,float,float]:
    """
    Test the parallelogram analogy w1 - w0 + w2 ~ w12 for the dimension pair (d1, d2).

    For every base word w0 with bits d1 and d2 cleared, w1 / w2 / w12 are w0 with
    bit d1 / bit d2 / both bits set, and the rows of W for these words are compared.

    Parameters:
    -----------
    W : array
        One representation per row, indexed by word.
    d1, d2 : int
        The two semantic dimensions forming the analogy.
    d : int
        The dimension of the semantic space; words range over range(2**d).

    Returns:
    --------
    Tuple[float,float,float]
        (fraction of analogies where the prediction equals W[w12] element-wise
        within atol=1e-6, mean root-mean-square prediction error, fraction of
        analogies where W[w12] is the row of W nearest to the prediction).
    """
    either_bit = (1 << d1) | (1 << d2)
    N_tested = 0.0
    N_same = 0.0
    total_MSE = 0.0
    N_nearest = 0
    for w0 in range(2**d):
        # Only words with both bits cleared are base words; every other word
        # would reproduce an analogy that has already been tested.
        if w0 & either_bit:
            continue
        w1 = set_bit(w0,d1)
        w2 = set_bit(w0,d2)
        w12 = set_bit(w1,d2)
        target = W[w12]
        prediction = W[w1] - W[w0] + W[w2]
        N_tested += 1
        N_same += np.isclose(prediction, target, atol=1e-6).all()
        total_MSE += np.sqrt(np.mean((prediction - target)**2))
        distances = np.linalg.norm((W - prediction), axis = 1)
        N_nearest += (np.argmin(distances) == w12)
    assert(N_tested == 2**(d-2))
    return N_same / N_tested, total_MSE / N_tested, N_nearest/N_tested

def test_analogy_cuda(W,d1:int,d2:int,d:int)->Tuple[float,float,float]:
    """
    Batched torch version of test_analogy for the dimension pair (d1, d2).

    All analogies w1 - w0 + w2 ~ w12 are evaluated at once, where w0 runs over
    the words with bits d1 and d2 cleared and w1 / w2 / w12 are w0 with
    bit d1 / bit d2 / both bits set.

    Parameters:
    -----------
    W : torch.Tensor
        One representation per row, indexed by word. The index tensors are
        created on ``W.device``, so W may live on any device.
    d1, d2 : int
        The two semantic dimensions forming the analogy.
    d : int
        The dimension of the semantic space; words range over range(2**d).

    Returns:
    --------
    Tuple[float,float,float]
        (fraction of exact matches, mean root-mean-square prediction error,
        fraction of analogies whose nearest row of W is W[w12]), each
        normalised by 2**(d-2).
    """
    import torch
    # Clearing both bits maps four words onto each base word; unique() keeps one copy.
    all_words = torch.arange(2**d, device=W.device)
    w0 = torch.unique(clear_bit(clear_bit(all_words,d1),d2))
    w1 = set_bit(w0,d1)
    w2 = set_bit(w0,d2)
    w12 = set_bit(w1,d2)
    target = W[w12]
    prediction = W[w1] - W[w0] + W[w2]
    # NOTE(review): test_analogy uses atol=1e-6 here while this version uses
    # rtol=1e-6, so the two "exact match" fractions are not directly comparable.
    N_same = torch.isclose(prediction, target, rtol=1e-6).all(dim = 1).sum()
    total_MSE = torch.sum(torch.sqrt(torch.mean((prediction - target)**2,dim = 1)))
    # Distance from every vocabulary row (axis 0) to every prediction (axis 1),
    # so we can check whether w12 is the closest row to its own prediction.
    W_vocab_distance = torch.linalg.norm(W[:,None,:] - prediction[None,:,:], dim = 2)
    N_nearest = torch.sum(torch.argmin(W_vocab_distance, dim = 0) == w12)
    n_analogies = 2**(d-2)
    return N_same.item() / n_analogies, total_MSE.item() / n_analogies, N_nearest.item() / n_analogies

def test_analogy_all(W,d:int,analogy_operation = test_analogy_cuda,return_all = False) -> Tuple[float,float,float]:
    """
    Test the analogy operation for all pairs of semantic dimensions.
    Parameters:
    -----------
    W : NDArray
        A numerical array representing the low-rank projection of the semantic space.
        It is passed unchanged to ``analogy_operation``.
    d : int
        The dimension of the semantic space.
    analogy_operation : callable
        Function called as ``analogy_operation(W, d1, d2, d)`` for every pair
        d1 < d2; it must return (exact, mse, nearest).
    return_all : bool
        If True, return the per-pair values instead of their means.
    Returns:
    --------
    Tuple[float,float,float]
        Returns the fraction of exact matches, mean squared error, and fraction of nearest neighbors,
        each averaged over all pairs. With return_all=True the three entries are
        arrays with one value per pair, in the order the pairs are visited.
    """
    sames,mses,nearests = [],[],[]
    for d1 in range(d):
        for d2 in range(d1+1,d):
            same,mse,nearest = analogy_operation(W,d1,d2,d)
            sames.append(same)
            mses.append(mse)
            nearests.append(nearest)
    if return_all:
        return np.array(sames),np.array(mses),np.array(nearests)
    # All three metrics are averaged; np.array() takes no axis argument, so the
    # previous np.array(mses, axis=0) raised a TypeError on this path.
    return np.mean(sames,axis = 0), np.mean(mses,axis = 0), np.mean(nearests,axis = 0)

def test_analogy_k(W,k,d=None) -> Tuple[float,float,float]:
    """
    Test the analogy operation for one specific semantic dimension.

    Dimension k is paired with every other dimension in range(d) and the results
    of test_analogy_cuda are averaged over those pairs.

    Parameters:
    -----------
    W : NDArray
        A numerical array representing the low-rank projection of the semantic space.
        It is passed to test_analogy_cuda.
    k : int
        The dimension of the semantic space to test
    d : int or None
        The dimension of the semantic space; derived from the number of rows of
        W with d_from_matrix when None.
    Returns:
    --------
    Tuple[float,float,float]
        Returns the fraction of exact matches, mean squared error, and fraction of nearest neighbors.
    Raises:
    -------
    ZeroDivisionError
        If no partner dimension was tested.
    """
    if d is None:
        d = d_from_matrix(W)
    N_tested = 0.0
    N_same = 0.0
    total_MSE = 0.0
    N_nearest = 0
    partners = [d2 for d2 in range(0,d) if not (k == d2)]
    for d2 in partners:
        same,mse,nearest = test_analogy_cuda(W,k,d2,d)
        N_tested += 1
        N_same += same
        total_MSE += mse
        N_nearest += nearest
    import torch
    # Convert to Python numbers if the accumulators ended up as tensors.
    if isinstance(N_same, torch.Tensor):
        N_same = N_same.item()
        total_MSE = total_MSE.item()
        N_nearest = N_nearest.item()
    return N_same / N_tested, total_MSE / N_tested, N_nearest/N_tested

def build_W(eigs,eigv,K):
    """Return the last K columns of ``eigv``; ``eigs`` is accepted but not used."""
    last_K = slice(-K, None)
    return eigv[:, last_K]
def build_W_symm(eigs,eigv,K,abs_eigs=False):
    """
    Return the last K columns of ``eigv``, each scaled by the square root of the matching entry of ``eigs``.

    With ``abs_eigs=True`` the absolute values of ``eigs`` are used under the
    square root; otherwise negative entries of ``eigs`` yield NaN columns.
    """
    if abs_eigs:
        weights = np.abs(eigs)
    else:
        weights = eigs
    column_scale = np.sqrt(weights[-K:])
    return (column_scale*eigv[:,-K:])
    
def eigh_torch(P,to_numpy = True, device = 'cuda'):
    """
    Eigendecompose the symmetric matrix P with torch.linalg.eigh in float64.

    Parameters:
    -----------
    P : array-like
        The matrix to decompose; it is copied into a float64 torch tensor.
    to_numpy : bool
        If True, the results are moved to the CPU and returned as NumPy arrays;
        otherwise they are returned as torch tensors on ``device``.
    device : str or torch.device
        Device on which the decomposition runs. Defaults to 'cuda', which was
        previously hard-coded; pass 'cpu' to run without a GPU.

    Returns:
    --------
    (eigs, eigv)
        The eigenvalues and the matrix whose columns are the corresponding
        eigenvectors, as returned by torch.linalg.eigh.
    """
    import torch
    P_torch = torch.tensor(P, dtype=torch.float64, device = device, requires_grad=False )
    eigs,eigv = torch.linalg.eigh(P_torch)
    if to_numpy:
        eigs = eigs.cpu().numpy()
        eigv = eigv.cpu().numpy()
    return eigs,eigv

def apply_multiplicative_noise(rng, P,multiplicative_noise_scale:float, symm:bool = True):
    """
    Multiply P element-wise by exp(noise), with Gaussian noise of the given scale.

    Parameters:
    -----------
    rng : np.random.RandomState
        Random number generator; one standard-normal array of P's shape is drawn.
    P : np.ndarray
        The matrix to perturb (it is not modified in place).
    multiplicative_noise_scale : float
        Factor applied to the standard-normal draws.
    symm : bool
        If True, the noise is symmetrised as (eta + eta.T)/2 and multiplied by
        sqrt(2), which keeps the variance of its off-diagonal entries unchanged.

    Returns:
    --------
    np.ndarray
        ``P * exp(noise)``.
    """
    log_noise = multiplicative_noise_scale * rng.randn(*P.shape)
    if symm:
        symmetrised = 0.5*(log_noise + log_noise.T)
        log_noise = symmetrised * (2 ** .5) #2**0.5 to keep the noise scale fixed.
    return P * np.exp( log_noise )

def generate_eigs_eigv(rng, d, dsk, sk_noise_type = 'gauss', postpend = '', multiplicative_noise_scale:float = 0.0, to_numpy = True, symmetric_noise:bool = True ):
    """
    Generate one random P matrix and return its eigendecomposition.

    The steps are: draw sorted strengths with generate_sks, build P with
    generate_P, perturb it with apply_multiplicative_noise, transform it with
    generate_P_postpend, and decompose the result with eigh_torch.

    Parameters:
    -----------
    rng : np.random.RandomState
        Random number generator shared by the strength and noise draws.
    d : int
        The dimension of the semantic space.
    dsk : float
        Disorder parameter passed to generate_sks.
    sk_noise_type : str
        Noise type passed to generate_sks.
    postpend : str
        Transformation name passed to generate_P_postpend.
    multiplicative_noise_scale : float
        Noise scale passed to apply_multiplicative_noise.
    to_numpy : bool
        Passed to eigh_torch.
    symmetric_noise : bool
        Passed to apply_multiplicative_noise as ``symm``.

    Returns:
    --------
    (eigs, eigv)
        The pair returned by eigh_torch.
    """
    strengths = generate_sks(d,dsk,rng,noise_type=sk_noise_type,sort =True )
    clean_P = generate_P(strengths)
    noisy_P = apply_multiplicative_noise(rng,clean_P,multiplicative_noise_scale,symm = symmetric_noise)
    transformed_P = generate_P_postpend(noisy_P,postpend)
    eigs,eigv = eigh_torch(transformed_P,to_numpy=to_numpy)
    return eigs,eigv

import torch 
def test_analogy_ranks_cuda(W:torch.Tensor,d1:int,d2:int,d:int)->torch.Tensor:
    """
    Rank of the true answer w12 among all rows of W for every (d1, d2) analogy.

    For each base word w0 with bits d1 and d2 cleared, the prediction
    W[w1] - W[w0] + W[w2] is compared with every row of W; the rank is the
    position of w12 when the rows are sorted by increasing distance to the
    prediction (0 means w12 is the nearest row).

    Parameters:
    -----------
    W : torch.Tensor
        One representation per row, indexed by word. The index tensors are
        created on ``W.device``, so W may live on any device.
    d1, d2 : int
        The two semantic dimensions forming the analogy.
    d : int
        The dimension of the semantic space; words range over range(2**d).

    Returns:
    --------
    torch.Tensor
        One rank per analogy. The entries come from ``nonzero`` on the
        (rank, analogy) match matrix, so they are ordered by rank, not by analogy.
    """
    # Clearing both bits maps four words onto each base word; unique() keeps one copy.
    all_words = torch.arange(2**d, device=W.device)
    w0 = torch.unique(clear_bit(clear_bit(all_words,d1),d2))
    w1 = set_bit(w0,d1)
    w2 = set_bit(w0,d2)
    w12 = set_bit(w1,d2)
    prediction = W[w1] - W[w0] + W[w2]

    # Distance from every vocabulary row (axis 0) to every prediction (axis 1).
    W_vocab_distance = torch.linalg.norm(W[:,None,:] - prediction[None,:,:], dim = 2)
    # Sort the vocabulary by distance for each prediction and find where w12 lands.
    _, sorted_indices = torch.sort(W_vocab_distance, dim = 0, descending = False)
    w12_positions = (sorted_indices == w12.unsqueeze(0)).nonzero(as_tuple=True)[0]
    return w12_positions

def evaluate_topk_cuda(ranks, topk:int = 10) -> float:
    """
    Evaluate the top-k accuracy of the ranks.
    Parameters:
    -----------
    ranks : torch.Tensor
        The zero-based ranks of the words.
    topk : int
        A rank counts as a hit when it is strictly below this value.
    Returns:
    --------
    float
        The fraction of entries of ``ranks`` that are below ``topk``.
    """
    hits = (ranks < topk).float()
    return hits.mean().item()

from typing import List
def test_analogy_rank_all(W,d:int, ranks_eval:List[int], return_all :bool = False) -> List[List[float]] | NDArray:
    """
    Evaluate top-k analogy accuracy for all pairs of semantic dimensions.

    For every pair d1 < d2 the ranks from test_analogy_ranks_cuda are turned
    into one top-k accuracy per entry of ``ranks_eval`` with evaluate_topk_cuda.

    Parameters:
    -----------
    W : torch.Tensor
        The low-rank projection of the semantic space, passed to test_analogy_ranks_cuda.
    d : int
        The dimension of the semantic space.
    ranks_eval : List[int]
        The values of k at which top-k accuracy is evaluated.
    return_all : bool
        If True, return the per-pair accuracies instead of their mean.
    Returns:
    --------
    List[List[float]]
        If return_all is True: one list per dimension pair, holding the top-k
        accuracy for each entry of ``ranks_eval``.
    NDArray
        Otherwise: the accuracies averaged over all pairs, one per entry of ``ranks_eval``.
    """
    ranks_eval = np.array(ranks_eval)
    ranks_all = []
    for d1 in range(d):
        for d2 in range(d1+1,d):
            ranks = test_analogy_ranks_cuda(W,d1,d2,d)
            ranks_all.append([evaluate_topk_cuda(ranks,rank) for rank in ranks_eval ] )
    if return_all:
        return ranks_all
    return np.mean(ranks_all,axis=0)



# Names of the supported transformations of P; '' means "use P as it is".
POSTPENDS=['','_log','_M*','_normed','_lognormed']
def generate_P_postpend(P:np.ndarray,postpend:str) -> np.ndarray:
    """
    Apply the transformation of P selected by ``postpend``.

    With Pn = P / P.sum(), m the column sums of Pn and M = outer(m, m):

    ''            -> P itself (the same object, no copy)
    '_log'        -> log(P)
    '_M*'         -> 2 * (Pn - M) / (Pn + M)
    '_normed'     -> Pn / M
    '_lognormed'  -> log(Pn / M)

    Raises:
    -------
    ValueError
        If ``postpend`` is none of the names above.
    """
    if postpend == '':
        return P
    if postpend == '_log':
        return np.log(P)
    if postpend not in ('_M*', '_normed', '_lognormed'):
        raise ValueError('Invalid postpend: %s'%postpend)

    # The three remaining transformations share the same normalisation.
    P_normed = P/P.sum()
    col_sums = P_normed.sum(axis=0)
    outer_sums = np.outer(col_sums,col_sums)
    if postpend == '_M*':
        return 2*(P_normed - outer_sums)/(P_normed + outer_sums)
    ratio = P_normed / outer_sums
    if postpend == '_normed':
        return ratio
    return np.log(ratio)
def d_from_matrix(matrix):
    """Return log2 of the number of rows of ``matrix``, rounded to the nearest integer."""
    n_rows = matrix.shape[0]
    log_rows = np.log2(n_rows)
    return int(np.round(log_rows))
#apply the pruning to a P matrix.
def clean_dimension(P, dim_to_remove=0, new_P_function = (lambda w0,w1 : 0.0) ):
    """
    Return a copy of P with the entries linking words that differ only in one bit overwritten.

    For every pair (w0, w1), where w0 has bit ``dim_to_remove`` cleared and w1 is
    the same word with that bit set, the entries [w0, w1] and [w1, w0] are replaced by
    ``new_P_function(w0, w1)`` and ``new_P_function(w1, w0)``. The loop runs over
    all words, so each pair is visited (and ``new_P_function`` called) twice.

    Parameters:
    -----------
    P : np.ndarray
        The matrix to prune; it is copied, not modified.
    dim_to_remove : int
        The bit (semantic dimension) whose direct links are overwritten.
    new_P_function : callable
        Called as ``new_P_function(row_word, column_word)`` to produce the new entry.

    Returns:
    --------
    np.ndarray
        The modified copy of P.
    """
    n_bits = int(round(np.log2(P.shape[0])))
    pruned = P.copy()
    for word in range(2**n_bits):
        without_bit = clear_bit(word,dim_to_remove)
        with_bit = set_bit(without_bit,dim_to_remove)
        pruned[without_bit,with_bit] = new_P_function(without_bit,with_bit)
        pruned[with_bit,without_bit] = new_P_function(with_bit,without_bit)
    return pruned

# Full pipeline: generate all the P's, eigendecompose them, then take the measurements.
def run_experiment(args,permutation_index:int=0):
    """
    Run the whole toy-noise experiment described by ``args`` and return the results.

    Stage 1 (generation): for every d in [args.d_min, args.d_max], every noise
    type / noise parameter listed in ``noise_params`` and every replicate in
    range(args.Nrep), strengths are drawn with generate_sks (base = args.fk_base / 2)
    and P = generate_P(sks, base=args.fk_base) is multiplied element-wise by
    exp(args.P_noise_scale * symmetrised Gaussian noise).

    Stage 2 (eigensystems): for every experiment and every postpend in POSTPENDS,
    the transformed P is eigendecomposed with eigh_torch and build_W_symm is
    evaluated for every K in the experiment's ``Ks``.

    Stage 3 (measurements): for every experiment, postpend and K, polarization,
    expressivity and the analogy scores are appended to lists stored in the
    experiment dict under ``<name><postpend>`` keys.

    Parameters:
    -----------
    args : namespace
        Options as produced by build_parser(). The attributes read here are
        seed, replicate, Nrep, d_min, d_max, fk_base, P_noise_scale, K_geom,
        abs_eigs, no_TQDM, no_W_K, save and output_path_prefix.
    permutation_index : int
        Extra integer folded into the random seed.

    Returns:
    --------
    dict
        ``{'replicate_averages': [...], 'experiments': [...]}``. If args.save is
        set, the same dict is also pickled to args.output_path_prefix + args.save.

    Notes:
    ------
    The measurement stage creates tensors with device='cuda', and eigh_torch is
    called with its default arguments, so a CUDA device is required.
    """
    # Fold seed, replicate and permutation index into a single RandomState seed.
    rng = np.random.RandomState((np.abs(args.seed)+1) * (1+args.replicate) * (1+permutation_index) % (2**32-1))
    noise_types = ['gauss','uniform', 'pseudogap']
    # Disorder parameters (dsk) swept for each noise type.
    noise_params = {
        'gauss':[1e-3, 5e-2],
        'uniform':[0.25],
        'pseudogap':[0.5, 1, 2]
    }
    Nrep = args.Nrep 
    replicate_averages = [] 
    experiments = [] 
    for d in range(args.d_min,args.d_max+1):
        for noise_type in noise_types:
            for noise_param in noise_params[noise_type]:
                # One entry per (d, noise_type, noise_param); its replicates are attached below.
                replicate_average = {
                    'd':d,
                    'noise_type': noise_type,
                    'noise_param': noise_param,
                    'args':args,
                }
                replicate_experiments = []
                for rep in range(Nrep):
                    # Shallow copy of the settings; filled in place after being appended.
                    experiment = { key:val for key,val in replicate_average.items() }
                    experiment['rep'] = rep 
                    experiments.append(experiment) 
                    experiment['sks'] = sks = generate_sks(d, noise_param, rng, noise_type = noise_type, base = args.fk_base / 2.0)
                    eta_noise =  rng.randn(2**d,2**d)
                    # NOTE(review): the noise is symmetrised here without the sqrt(2) factor
                    # that apply_multiplicative_noise uses -- confirm this is intended.
                    experiment['P'] = P = generate_P(sks,base = args.fk_base) * np.exp( args.P_noise_scale * 0.5*( eta_noise + eta_noise.T ) )
                    replicate_experiments.append(experiment)
                replicate_average['experiments'] = replicate_experiments
                replicate_averages.append(replicate_average)


    # Generating the eigenvalues and eigenvectors for each experiment, as well as the W and W_K 
    print('Solving eigensystem and generating and Ws!', flush=True)
    for exp in (experiments if args.no_TQDM else tqdm(experiments)):
        d = exp['d']
        if(args.K_geom):
            # presumably geometrically spaced values between 1 and 2**d -- verify against utils.plothelp
            Ks = ph.buildGeomBins(1,2**d,50,True).astype(int)
        else:
            # Every rank from 1 to 2**d - 1 (the full rank 2**d is not included).
            Ks = np.arange(1,2**d).astype(int)
        exp['Ks']=Ks
        for postpend in POSTPENDS:
            P_use = generate_P_postpend(exp['P'],postpend)
            # print('Postpend: ',postpend, np.min(P_use), np.max(P_use), np.min(exp['P']), np.max(exp['P']), exp['noise_type'],exp['noise_param'],flush=True) 
            eigs,eigv = exp['eigs'+postpend],exp['eigv'+postpend] = eigh_torch(P_use,to_numpy=True)
            # exp['W_K'+postpend]=[]
            exp['W_symm_K'+postpend] = [] 
            for K in Ks:
                # exp['W_K'+postpend].append(build_W(eigs,eigv,K))
                exp['W_symm_K'+postpend].append(build_W_symm(eigs,eigv,K,args.abs_eigs))

    print('Making measurements!',flush=True)
    for exp in (experiments if args.no_TQDM else tqdm(experiments)):
        if(args.no_TQDM):
            # Without a progress bar, print and time each experiment instead.
            import utils.ticktock as tt
            tt.tick()
            print(f"Starting: {exp['d']:d}, noise {exp['noise_type']:s}+{exp['noise_param']:.1g}... ",end="",flush=True)
        for postpend in POSTPENDS:
            # Result lists, one entry per K, stored under <name><postpend>.
            array_names = ['polarization_K','polarization_k_K','polarization_std_K','expressivity_K','analogy_exact_d1d2_K','analogy_mse_d1d2_K','analogy_nearest_d1d2_K','analogy_exact_K','analogy_mse_K','analogy_nearest_K','analogy_rank_d1d2_score']
            for name in array_names:
                exp[name+postpend] = []  
            for K_ind,K in enumerate(exp['Ks']):
                W = exp['W_symm_K'+postpend][K_ind]
                # Per-dimension polarizations, then their mean and sample std (ddof=1).
                polarization_k = measure_polarization(W,exp['d'],return_all=True)
                PO,stdPO = polarization_k.mean(axis=0), polarization_k.std(axis = 0, ddof=1)
                exp['polarization_k_K'+postpend].append(polarization_k)
                exp['polarization_K'+postpend].append(PO)
                exp['polarization_std_K'+postpend].append(stdPO)
                exp['expressivity_K'+postpend].append(measure_expressivity(W,exp['d']))
                # Per-pair analogy scores (return_all=True gives one value per dimension pair).
                exacts,mses,nearests = test_analogy_all(torch.tensor(W,device='cuda'),exp['d'],return_all=True)
                exp['analogy_exact_d1d2_K'+postpend].append(exacts)
                exp['analogy_mse_d1d2_K'+postpend].append(mses)
                exp['analogy_nearest_d1d2_K'+postpend].append(nearests)
                # Now, we want to get the average over all pairs of semantic dimensions
                exp['analogy_exact_K'+postpend].append(np.mean(exacts,axis=0))
                exp['analogy_mse_K'+postpend].append(np.mean(mses,axis=0))
                exp['analogy_nearest_K'+postpend].append(np.mean(nearests,axis=0))
                # Top-1/2/3 accuracies for every dimension pair.
                exp['analogy_rank_d1d2_score'+postpend].append(test_analogy_rank_all(torch.tensor(W,device='cuda'),exp['d'],[1,2,3],True))

        if(args.no_W_K):
            # Drop the stored W matrices so they are not kept in the results.
            for postpend in POSTPENDS:
                # del exp['W_K'+postpend]
                del exp['W_symm_K'+postpend]
        if(args.no_TQDM):
            time = tt.tock(printout=False)
            print("Done! (%.2g s)"%time,flush=True)
    simulation_results = {'replicate_averages':replicate_averages,'experiments':experiments}
    if(args.save):
        import pickle
        # NOTE(review): the message shows args.save only, but the file is written to
        # args.output_path_prefix + args.save.
        print('Saving results to %s'%args.save)
        with open(args.output_path_prefix+args.save, 'wb') as f:
            pickle.dump(simulation_results, f)
    return simulation_results 

def postprocess_args(args):
    """Return ``args`` unchanged; no post-processing is applied at present."""
    return args 

def build_parser():
    """
    Build the command-line parser for the toy noise experiment.

    Returns:
    --------
    argparse.ArgumentParser
        Parser exposing the simulation options (seed, replicate, Nrep, d_min,
        d_max, K_geom, P_noise_scale, abs_eigs, fk_base) and the output options
        (save, no_W_K, output_path_prefix, output_filename_prefix, no_TQDM).
    """
    from argparse import ArgumentParser
    parser = ArgumentParser(description='Run the toy noise experiment, e.g. python run_toy_noise.py --seed 123 --Nrep 1 --d_min 5 --d_max 6 --K_geom')
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        #simulation parameters
        ('--seed', dict(type=int, default=0, help='Random seed for the experiment.')),
        ('--replicate', dict(type=int, default=0, help='Replicate number, with which to increment the seed.')),
        ('--Nrep', dict(type=int, default=1, help='Number of replicates for each experiment.')),
        ('--d_min', dict(type=int, default=5, help='Minimum dimension of the semantic space.')),
        ('--d_max', dict(type=int, default=8, help='Maximum dimension (inclusive) of the semantic space.')),
        ('--K_geom', dict(action='store_true', help='Use geometric bins for K.')),
        ('--P_noise_scale', dict(type=float, default=0.0, help='Scale of the noise to multiply into the P matrix.')),
        ('--abs_eigs', dict(action='store_true', help='Use absolute eigenvalues for the symmeterized W matrix.')),
        ('--fk_base', dict(type=float, default=1.0, help='Base for the f^{(k)} matrix.')),
        #saving / output parameters
        ('--save', dict(type=str, default=None, help='Path to save the results.')),
        ("--no_W_K", dict(action='store_true', help="Disable saving of the W_K's.")),
        ("--output_path_prefix", dict(type=str, default='', help="Prefix for the output path.")),
        ("--output_filename_prefix", dict(type=str, default='', help="Prefix for the filename.")),
        ("--no_TQDM", dict(action='store_true', help="Disable TQDM progress bar.")),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser

# Script entry point: parse the command-line options and run one experiment.
if __name__ == '__main__':
    parser = build_parser()
    args = parser.parse_args()
    # run_experiment is called with its default permutation_index.
    results = run_experiment(args)
    print('Done!')