import jax.numpy as jnp
import numpy as np
import torch
import tqdm

import models
from fourier_extractor.jax_fourier_extractor import SparseFourierDataset
from swht_jax.swht_jax import get_time_samples, setup_fourier
def sample_and_wht(f, n, b, T, batch_size, num_workers, device="cuda"):
    """
    Sample ``f`` on the sparse-WHT time samples and return the per-row WHTs.

    f: callable to be sampled; it is called on batches of float64 tensors
       moved to ``device`` and its output is flattened to one value per
       input row
    n: dimension of domain of f
    b: the number of buckets is 2**b
    T: no. of peeling rounds
    batch_size: number of time samples passed to ``f`` per call
    num_workers: number of DataLoader worker processes
    device: device each input batch is moved to before calling ``f``
            (default "cuda", equivalent to the former ``x.cuda()``)

    Returns (ref_wht, shifted_wht) as jax arrays requested with dtype
    float64, of shapes (T, B) and (T, n, B), where B is the length of each
    sample row (presumably 2**b -- TODO confirm against get_time_samples).

    Raises ValueError if ``f`` does not return exactly one value per sample.
    """
    # prepare time samples
    print("getting time samples")
    time_samples = get_time_samples(n, b, T)
    print("creating all time samples with shifts")
    dataset = SparseFourierDataset(time_samples)
    no_samples = len(dataset)
    print("creating dataloader")
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    # Accumulate in float64: inputs are cast to double and the results are
    # handed on as float64 arrays, so a float32 buffer would lose precision.
    y = torch.zeros(no_samples, dtype=torch.float64)
    print("sampling neural network")
    offset = 0
    with torch.no_grad():
        for x in tqdm.tqdm(dataloader):
            x = x.to(device).double()
            # reshape(-1) (not squeeze()) keeps a 1-element batch 1-D so its
            # length can advance the write offset.
            evaluation = f(x).reshape(-1).cpu()
            count = evaluation.shape[0]
            y[offset:offset + count] = evaluation
            offset += count
    if offset != no_samples:
        raise ValueError(
            f"f produced {offset} values for {no_samples} time samples"
        )

    # T groups of (n + 1) rows each, one row per sampled signal.
    y = y.reshape((T * (n + 1), -1))
    # Get WHT
    y_wht = torch.zeros(y.shape, dtype=torch.float64)
    scale = 2 ** (b / 2)
    print("getting wht")
    # NOTE(review): hadamard_transform is not imported in the visible source;
    # confirm where it comes from and whether it accepts float64 tensors.
    for i, row in enumerate(tqdm.tqdm(y)):
        y_wht[i] = hadamard_transform(row) / scale

    # Within each of the T groups, row 0 is the reference WHT and rows 1..n
    # are the shifted WHTs (same split as the former slice(0, -1, n + 1)).
    y_wht = y_wht.reshape((T, n + 1, -1))
    # convert from torch tensors to jax arrays
    print("converting WHTs to jax arrays")
    # NOTE(review): jnp.float64 is only honoured when jax_enable_x64 is set.
    ref_wht = jnp.asarray(y_wht[:, 0, :].numpy(), dtype=jnp.float64)
    shifted_wht = jnp.asarray(y_wht[:, 1:, :].numpy(), dtype=jnp.float64)
    return ref_wht, shifted_wht





def compute_fourier(task_name, save_result=True, T=5):
    """
    Compute the sparse WHT of a task's saved model for every configured b.

    For each b in the task's inclusive ``b_range``, the model is sampled on
    the time samples for 2**b buckets and T peeling rounds, and the sparse
    WHT is computed from the resulting reference / shifted WHTs.

    task_name: key into the per-task settings from utils.get_task_settings()
    save_result: if True, store [freqs, amps] for each b via save_to_cache
    T: no. of peeling rounds (formerly hard-coded; the default is unchanged)

    Returns a dict mapping each b to its (freqs, amps) pair.
    """
    # NOTE(review): `utils` and `save_to_cache` are not imported in the
    # visible source -- confirm they are in scope in the full module.
    task_settings = utils.get_task_settings()
    # get run settings
    no_features = task_settings["no_features"][task_name]
    b_min, b_max = task_settings["b_range"][task_name]
    batch_size = task_settings["batch_size"][task_name]
    num_workers = task_settings["no_workers"][task_name]
    # load model (a random forest; sample_and_wht calls it on torch tensors)
    clf = models.random_forest.load_random_forest(task_name, best=True)

    results = {}
    for b in range(b_min, b_max + 1):
        print(f"b={b} B={2 ** b}")
        # get compiled fourier transform function
        print("jitting sparse fourier")
        sparse_wht = setup_fourier(n=no_features, b=b, T=T)
        # evaluate the model on the time samples; sample_and_wht names its
        # callable parameter `f`, so the model must be passed as f=.
        ref_whts, shifted_whts = sample_and_wht(
            f=clf,
            n=no_features,
            b=b,
            T=T,
            batch_size=batch_size,
            num_workers=num_workers,
        )
        print("computing sparse WHT")
        freqs, amps = sparse_wht(ref_whts, shifted_whts)
        if save_result:
            save_to_cache([freqs, amps], task_name, b)
        results[b] = (freqs, amps)
    return results



if __name__ == "__main__":

