import numpy as np
import jax
import jax.numpy as jnp
import torch
import utils
import tqdm

# Enable 64-bit precision in JAX (it defaults to 32-bit), so that arrays
# requested as float64 below actually keep that dtype.
jax.config.update("jax_enable_x64", True)

class FourierDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves the rows of an array one index at a time.

    Attributes:
        x: the wrapped array, stored as given (no copy is made).
        no_samples: size of the first axis of ``x`` (number of rows).
        no_features: size of the second axis of ``x`` (number of columns).
    """

    def __init__(self, x):
        self.x = x
        dims = x.shape
        self.no_samples = dims[0]
        self.no_features = dims[1]

    def __len__(self):
        """Return the number of rows available."""
        return self.no_samples

    def __getitem__(self, idx):
        """Return the row of the wrapped array stored at position ``idx``."""
        row = self.x[idx]
        return row


class FourierWrapper:
    """Hold a sparse Fourier representation and evaluate it on batches of inputs.

    The representation is a frequency matrix (one column per frequency) plus
    the matching amplitudes. Indexing the wrapper with an array of input rows
    evaluates, for every row ``x``::

        sum_j  amps[j] * (-1) ** (<x, freqs[:, j]> mod 2)
    """

    def __init__(self,  dataset, b, freqs, amps):
        """Store the Fourier coefficients.

        Args:
            dataset: stored on the instance as ``self.dataset``; not used by
                the evaluation code in this class.
            b: key (formatted as a string) used in ``__getitem__`` to look up
                the evaluation batch size in the task settings.
            freqs: array-like of shape ``(n_vars, k)``; column ``j`` is the
                j-th frequency vector.
            amps: array-like of amplitudes, one per frequency
                (presumably shape ``(k,)`` -- confirm against callers).
        """
        self.dataset = dataset
        # b selects the batch size for the data loader used during evaluation
        # (see __getitem__).
        self.b = b
        self.freq_matrix = jnp.array(freqs, dtype=jnp.int32)
        self.amp_matrix = jnp.array(amps, dtype=jnp.float64)
        # Take the shape from the converted matrix so that plain nested lists
        # (which have no ``.shape``) are accepted as well as arrays.
        self.n_vars, self.k = self.freq_matrix.shape

    def __getitem__(self, x):
        """Evaluate the Fourier representation on one row or on many rows.

        Args:
            x: a single input of shape ``(n_vars,)`` or a batch of inputs of
                shape ``(n_samples, n_vars)``.

        Returns:
            A NumPy array with one value per input row (an empty array when
            ``x`` contains no rows).
        """
        if len(x.shape) == 1:
            # Promote a single input to a one-row batch. The shape is passed
            # positionally: the ``newshape`` keyword is deprecated in NumPy.
            x = np.reshape(x, (1, x.shape[0]))

        batch_size = utils.get_task_settings()["batch_size_for_fourier_evaluation"][f"{self.b}"]
        dataloader = torch.utils.data.DataLoader(
            FourierDataset(x), batch_size=batch_size, shuffle=False, num_workers=3
        )

        values = []
        for batch in tqdm.tqdm(dataloader):
            batch = jnp.array(batch, dtype=jnp.float64)
            # (batch, n_vars) @ (n_vars, k) -> (batch, k) parities. The 2-D
            # shape is kept on purpose: squeezing here would drop the
            # frequency axis when k == 1 (or the batch axis for a one-row
            # batch) and break the product with the amplitudes below.
            signs = (-1) ** ((batch @ self.freq_matrix) % 2)
            values.append(np.array(signs @ self.amp_matrix).squeeze())

        if not values:
            # np.hstack raises on an empty list; no rows means no values.
            return np.empty(0, dtype=np.float64)
        return np.hstack(values)


