import numpy as np

from BACKEND import cp, sp, to_gpu, to_cpu

from datasets.dataloader import IterableDataLoader, DataLoader
from ..layer.srm_layer import SRMLayer
from tqdm.auto import tqdm

class FourierDataloader(IterableDataLoader):
    """Dataloader holding precomputed Fourier-fit inputs for the train/val part of ``dl``.

    Each of the first ``train_size + val_size`` samples of ``dl`` is pushed through
    ``pre.compute_full_trains`` and then ``net._prepare_fourier_fit``, in batches.
    The resulting features are gathered into a host-side ``numpy`` array
    ``self.x_cpu``. The test split of ``dl`` is not carried over (``test_size`` is 0).
    """

    def __init__(self, dl: DataLoader, pre: SRMLayer, net: SRMLayer, ts: cp.ndarray, batch_size: int = 1000):
        """Precompute the transformed inputs for the train and validation samples of ``dl``.

        Args:
            dl: source dataloader. Its ``x_cpu``/``y_cpu`` arrays are read, and only the
                first ``train_size + val_size`` samples are used.
            pre: layer whose ``compute_full_trains`` is applied to each input batch
                (called with ``bin_spikes=True, progress=False``).
            net: layer whose ``_prepare_fourier_fit`` turns the output of ``pre`` into
                the stored features. Its first return value is kept as ``self.eval_ts``.
            ts: time grid passed to both layers. ``ts.shape[0]`` is passed as ``n_t``.
            batch_size: number of samples transferred to the GPU and processed per step.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            # Without this guard, 0 raises an opaque ZeroDivisionError below, and a
            # negative value makes the batch loop empty, silently leaving x_cpu as None.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        n_samples = dl.train_size + dl.val_size

        self.y_cpu = dl.y_cpu[:n_samples]
        self.x_cpu = None  # allocated lazily once the feature shape/dtype is known
        self.train_size = dl.train_size
        self.val_size = dl.val_size
        self.test_size = 0  # the test split of `dl` is intentionally dropped

        # NOTE(review): if n_samples == 0 the loop never runs, so x_cpu stays None and
        # eval_ts is never set — confirm callers never pass an empty train+val split.
        n_batches = (n_samples + batch_size - 1) // batch_size  # ceiling division
        for k in tqdm(range(n_batches), desc="Applying previous layer"):
            start = k * batch_size
            end = min(start + batch_size, n_samples)
            # The GPU intermediate from `pre` is passed inline (not bound to a name) so it
            # can be released as soon as `_prepare_fourier_fit` returns.
            # NOTE(review): the first recompute_ffts flag is True only on the first batch;
            # presumably `net` caches the FFT of `ts` after that — confirm in SRMLayer.
            self.eval_ts, _, x, _ = net._prepare_fourier_fit(
                ts, pre.compute_full_trains(to_gpu(dl.x_cpu[start:end]), ts, bin_spikes=True, progress=False),
                s_out=None, n_t=ts.shape[0], recompute_ffts=(self.x_cpu is None, False),
            )
            if self.x_cpu is None:
                # First batch: allocate the full host buffer with the per-sample shape/dtype of x.
                self.x_cpu = np.zeros((n_samples, *x.shape[1:]), dtype=x.dtype)
            self.x_cpu[start:end] = to_cpu(x)