import numpy as np 
import run_toy_noise
import utils.plothelp as ph 


# Helper used below to generate a family of P matrices.
def generate_P(rng, d:int, dsk:float, sk_noise_type = 'uniform', multiplicative_noise_scale = 0.0, postpend = ''):
    """Build one P matrix via the run_toy_noise helpers.

    Steps, in order:
      1. ``run_toy_noise.generate_sks(d, dsk, rng, noise_type=sk_noise_type, sort=True)``
      2. ``run_toy_noise.generate_P(sks)``
      3. in-place elementwise multiplication by
         ``exp(multiplicative_noise_scale * rng.randn(2**d, 2**d))``
      4. ``run_toy_noise.generate_P_postpend(P, postpend)``

    Args:
        rng: random source; must expose ``randn`` (e.g. ``np.random.RandomState``).
        d: dimension parameter; the noise array has shape ``(2**d, 2**d)``.
        dsk: forwarded to ``generate_sks``.
        sk_noise_type: forwarded to ``generate_sks`` as ``noise_type``.
        multiplicative_noise_scale: scale of the Gaussian exponent in step 3.
        postpend: forwarded to ``generate_P_postpend``.

    Returns:
        Whatever ``run_toy_noise.generate_P_postpend`` returns.
    """
    side = 2**d
    sks = run_toy_noise.generate_sks(d, dsk, rng, noise_type=sk_noise_type, sort=True)
    P = run_toy_noise.generate_P(sks)
    # The noise is drawn unconditionally, so the rng state advances even when
    # the scale is 0.0 (in which case the factor is exp(0) == 1 everywhere).
    # Assumes P is broadcast-compatible with (2**d, 2**d) -- TODO confirm.
    log_noise = multiplicative_noise_scale * rng.randn(side, side)
    P *= np.exp(log_noise)
    return run_toy_noise.generate_P_postpend(P, postpend)

# Number of independent P matrices drawn per family, and the dimension `d`
# passed to generate_P for every one of them.
NREP = 10
D_DEFAULT = 8
import importlib
# Re-execute run_toy_noise so the helpers used below come from its current source.
importlib.reload(run_toy_noise)
# A single seeded RandomState feeds all three families; the lists below must
# be built in this order to reproduce the same draws.
rng = np.random.RandomState(123875)
uniform_P = [
    generate_P(rng, D_DEFAULT, 0.25, sk_noise_type='uniform', multiplicative_noise_scale=0.0)
    for _ in range(NREP)
]
gauss_low = [
    generate_P(rng, D_DEFAULT, 1e-3, sk_noise_type='gauss', multiplicative_noise_scale=0.0)
    for _ in range(NREP)
]
gauss_high = [
    generate_P(rng, D_DEFAULT, 5e-2, sk_noise_type='gauss', multiplicative_noise_scale=0.0)
    for _ in range(NREP)
]


def recursive_prune(P,d0,func = None):
    """Apply ``run_toy_noise.clean_dimension`` to successive dimensions of P.

    Starting at ``d0``, dimension indices ``d0, d0+1, ...`` are cleaned one
    after another while the index is strictly below
    ``run_toy_noise.d_from_matrix(P) - 1`` (re-evaluated on the current,
    already-cleaned matrix).  The last dimension index is therefore never
    cleaned by this function.

    Args:
        P: matrix accepted by ``run_toy_noise.clean_dimension``.
        d0: first dimension index to clean.
        func: optional cleaning function.  When given it is forwarded as the
            third positional argument of ``clean_dimension`` (the same way the
            single-dimension callers in this file pass it); when ``None``,
            ``clean_dimension`` is called without it so its own default applies.

    Returns:
        The matrix after all applicable dimensions were cleaned; ``P`` itself
        if ``d0`` is already past the range.
    """
    while d0 < run_toy_noise.d_from_matrix(P) - 1:
        if func is None:
            P = run_toy_noise.clean_dimension(P, d0)
        else:
            # Previously `func` was accepted but silently dropped here.
            P = run_toy_noise.clean_dimension(P, d0, func)
        d0 += 1
    return P
# Dataset keys for which eigensystems, polarization and analogy scores are computed.
pruning_settings = [
    'P_lognormed',
    'P_pruned_high',
    'P_pruned_low',
    'P_pruned_all',
    'P_pruned_half',
]
def build_pruned_setting(Ps):
    """Derive the normalised, log-normalised and pruned variants of each P.

    Args:
        Ps: non-empty sequence of P matrices; ``Ps[0]`` fixes the dimension
            count through ``run_toy_noise.d_from_matrix``.

    Returns:
        dict with list-valued entries ``'P_normed'``, every name in
        ``pruning_settings``, and for each of those names a
        ``'<name>_eigsystem'`` list of ``run_toy_noise.eigh_torch`` results.
    """
    n_dims = run_toy_noise.d_from_matrix(Ps[0])
    regularization = 4.0

    def cleaning_func(w0, w1):
        # Constant cleaning function: ignores both arguments.
        return regularization

    # exp(-regularization) is added to every normalised entry before the log.
    offset = np.exp(-regularization)
    normed = [run_toy_noise.generate_P_postpend(P, '_normed') + offset for P in Ps]
    lognormed = [np.log(P) for P in normed]

    dataset = {
        'P_normed': normed,
        'P_lognormed': lognormed,
        # Clean only the first dimension index.
        'P_pruned_high': [run_toy_noise.clean_dimension(P, 0, cleaning_func) for P in lognormed],
        # Clean only the last dimension index.
        'P_pruned_low': [run_toy_noise.clean_dimension(P, n_dims - 1, cleaning_func) for P in lognormed],
        # Recursive cleaning starting from index 0, resp. from the middle index.
        'P_pruned_all': [recursive_prune(P, 0, cleaning_func) for P in lognormed],
        'P_pruned_half': [recursive_prune(P, n_dims // 2, cleaning_func) for P in lognormed],
    }
    for name in pruning_settings:
        matrices = dataset[name]
        dataset[name + '_eigsystem'] = [run_toy_noise.eigh_torch(P) for P in matrices]
    return dataset
def measure_polarization_dataset(dataset,keys,Ksamples):
    """Add a ``'<key>_polarization'`` entry to ``dataset`` for every key.

    Each entry is a list with one ``measure_polarization`` result per element
    of ``dataset['<key>_eigsystem']``; every element is unpacked into the
    leading arguments of ``measure_polarization``.  ``dataset`` is modified in
    place and nothing is returned.
    """
    for name in keys:
        eigensystems = dataset[name + '_eigsystem']
        dataset[name + '_polarization'] = [
            measure_polarization(*pair, Ksamples) for pair in eigensystems
        ]
def measure_polarization(eigs,eigv,Ksamples):
    """Return one polarization measurement per K in ``Ksamples``.

    For each K, ``run_toy_noise.build_W_symm(eigs, eigv, K, True)`` is built
    and handed to ``run_toy_noise.measure_polarization`` together with the
    dimension count obtained from ``run_toy_noise.d_from_matrix(eigv)``.

    Returns:
        list of measurements, in the iteration order of ``Ksamples``.
    """
    n_dims = run_toy_noise.d_from_matrix(eigv)
    scores = []
    for K in Ksamples:
        W = run_toy_noise.build_W_symm(eigs, eigv, K, True)
        scores.append(run_toy_noise.measure_polarization(W, n_dims, True))
    return scores

# K values 1, 2, ..., 2**D_DEFAULT - 1; each one is passed to build_W_symm
# by the measurement functions.
Ksamples = np.arange(start=1, stop=2**D_DEFAULT, step=1)
def build_and_measure_pruning(Ps):
    """Build the pruned dataset for ``Ps`` and attach both measurements.

    Runs ``build_pruned_setting``, then fills in the polarization and the
    analogy entries for every name in ``pruning_settings`` using the
    module-level ``Ksamples``.

    Returns:
        The populated dataset dict.
    """
    pruned = build_pruned_setting(Ps)
    measure_polarization_dataset(dataset=pruned, keys=pruning_settings, Ksamples=Ksamples)
    measure_analogy_dataset(dataset=pruned, keys=pruning_settings, Ksamples=Ksamples)
    return pruned


# === Here, we measure analogies!
import torch 
# Re-execute run_toy_noise a second time before the analogy code below uses it
# (importlib.reload re-runs the module's source in place).
importlib.reload(run_toy_noise)
def measure_analogy_dataset(dataset,keys,Ksamples):
    """Add analogy scores to ``dataset`` for every key in ``keys``.

    For each key three entries are written in place:

    * ``'<key>_rep_analogy_k'``: nested list indexed ``[replicate][k][K]``,
      where each value is ``run_toy_noise.test_analogy_k(W, k)`` with ``W``
      the tensor built from ``run_toy_noise.build_W_symm(eigs, eigv, K, True)``;
      replicates iterate over ``dataset['<key>_eigsystem']``, ``k`` over
      ``range(d)`` and ``K`` over ``Ksamples``.
    * ``'<key>_analogy_k'``: ``np.mean`` of the above over the replicate axis.
    * ``'<key>_analogy_k_std'``: ``np.std`` of the above over the replicate axis.

    ``d`` is taken from ``run_toy_noise.d_from_matrix(dataset[keys[0]][0])``.
    Tensors are placed on the GPU when CUDA is available and on the CPU
    otherwise.  An empty ``keys`` is a no-op.  Returns nothing.
    """
    if len(keys) == 0:
        # Nothing to measure; indexing keys[0] below would raise otherwise.
        return
    d = run_toy_noise.d_from_matrix(dataset[keys[0]][0])
    # Do not hard-code 'cuda': fall back to the CPU on machines without it.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def analogy_scores(eigs, eigv):
        # Scores for one replicate, indexed [k][K].  Call order (k outer,
        # K inner, W rebuilt for every pair) matches the original code.
        return [
            [
                run_toy_noise.test_analogy_k(
                    torch.tensor(run_toy_noise.build_W_symm(eigs, eigv, K, True), device=device),
                    k,
                )
                for K in Ksamples
            ]
            for k in range(d)
        ]

    for key in keys:
        rep_scores = [analogy_scores(eigs, eigv) for eigs, eigv in dataset[key + '_eigsystem']]
        dataset[key + '_rep_analogy_k'] = rep_scores
        dataset[key + '_analogy_k'] = np.mean(rep_scores, axis=0)
        dataset[key + '_analogy_k_std'] = np.std(rep_scores, axis=0)


# Build the pruned variants of each P family and attach the polarization and
# analogy measurements to them.
uniform_pruned = build_and_measure_pruning(Ps=uniform_P)
gauss_low_pruned = build_and_measure_pruning(Ps=gauss_low)
gauss_high_pruned = build_and_measure_pruning(Ps=gauss_high)

import pickle
# Persist the three measured datasets as one pickled tuple, in the order
# (uniform, gauss_low, gauss_high).
with open('synthetic_pruning_dataset.pkl', mode='wb') as f:
    pickle.dump(
        (uniform_pruned, gauss_low_pruned, gauss_high_pruned),
        f,
    )
