import warnings

from cupy.linalg import LinAlgError
from cupy.cuda.runtime import CUDARuntimeError
from cupy_backends.cuda.libs.cusolver import CUSOLVERError
import numpy as np
from tqdm.auto import tqdm

from BACKEND import cp, sp, to_gpu, to_cpu
from matplotlib import pyplot as plt
import primme

from datasets import DataLoader
from model.kernels import InVoltFixed, OutVoltFixed, FitContributionsFixed, TauCorrFixed
from model.kernels import Kernel
from model.layer.srm_layer import SRMLayer
from model.utils.eigen import power_it, gerschgorin
from model.utils.normalisers import Normaliser
from utils.metrics import r2, R2Accumulator, MSEAccumulator
from utils.spacing import alpha_space

###
# Main implementation of the proposed algorithm
###
# Methods which are part of the S-SWIM algorithm:
'''
init_swim: Weight construction
normalise_weights: Normalisation step
tau_correlation: Identification of output delays
sig_search: PSPK supports in output layer
fit_l2_minibatch: Batched implementation of the linear solve with search over regularisation parameter
'''


class SRMLayerFixed(SRMLayer):

    def get_params_spatial(self):
        n = self.n_neurons + self.n_neurons * self.n_in
        if self.phi_q is not None:
            n += self.n_neurons
        return n

    def get_params_temporal(self):
        n = 2 * self.n_neurons
        if self.phi_q is not None:
            n += self.n_neurons
        return n

    def compute_rec_fields(self, k=True, q=True):
        """
        Compute the temporal windows ``[delay - width, delay + width]`` of the kernels.

        Widths and delays are recovered from the stored reciprocal parametrisation
        (``inv_widths`` and ``neg_delays_inv_widths``). The lower bound is clipped at 0.

        :param k: if True, update ``rec_field_k_a`` / ``rec_field_k_b`` (input kernels)
        :param q: if True, update ``rec_field_q_a`` / ``rec_field_q_b`` (output kernels)
        """

        def _window(inv_widths, neg_delays_inv_widths):
            widths = 1 / inv_widths
            delays = - neg_delays_inv_widths * widths
            return cp.maximum(delays - widths, 0), delays + widths

        if k:
            self.rec_field_k_a, self.rec_field_k_b = _window(self.k_inv_widths, self.k_neg_delays_inv_widths)
        if q:
            self.rec_field_q_a, self.rec_field_q_b = _window(self.q_inv_widths, self.q_neg_delays_inv_widths)

    def __init__(self, n_neurons: int, n_inputs: int, phi_k: Kernel = None, phi_q: Kernel = None,
                 k_delays: "cp.ndarray | None" = None, k_widths: "cp.ndarray | None" = None,
                 q_delays: "cp.ndarray | None" = None, q_widths: "cp.ndarray | None" = None,
                 in_weights: "cp.ndarray | None" = None, out_weights: "cp.ndarray | None" = None,
                 bias: "cp.ndarray | None" = None, seed: int = 42, debug: bool = False):
        """
        Build a layer with one fixed input kernel (and optionally one output kernel) per neuron.

        :param n_neurons: number of neurons in the layer
        :param n_inputs: number of input channels
        :param phi_k: input kernel, forwarded to the base class and to the GPU kernels
        :param phi_q: output kernel; when None, no output-side state (q parameters,
            ``out_weights``, ``out_volt_kernel``) is set up here
        :param k_delays: per-neuron input-kernel delays; defaults to ``linspace(0, 0.25, n_neurons)``
        :param k_widths: per-neuron input-kernel widths; defaults to 0.15 for every neuron
        :param q_delays: per-neuron output-kernel delays; defaults to 0 (only used if ``phi_q`` is given)
        :param q_widths: per-neuron output-kernel widths; defaults to 0.1 (only used if ``phi_q`` is given)
        :param in_weights: n_in x n_neurons; defaults to zeros
        :param out_weights: one weight per neuron; the default is allocated with shape (n_neurons,)
        :param bias: forwarded unchanged to the base-class constructor
        :param seed: forwarded to the base-class constructor
        :param debug: forwarded to the base class and to every GPU kernel wrapper
        """
        super().__init__(n_neurons=n_neurons, n_inputs=n_inputs, phi_k=phi_k, phi_q=phi_q, seed=seed, bias=bias,
                         debug=debug)
        # Output-side state only exists when an output kernel is configured.
        if phi_q is not None:
            if q_delays is None:
                q_delays = cp.full(n_neurons, 0, dtype=self.DTYPE)
            if q_widths is None:
                q_widths = cp.full(n_neurons, 0.1, dtype=self.DTYPE)
            self.set_q_params(q_delays, q_widths)
            if out_weights is not None:
                self.out_weights = out_weights.astype(self.DTYPE)
            else:
                self.out_weights = cp.zeros((n_neurons,), dtype=self.DTYPE)
            self.out_volt_kernel = OutVoltFixed(phi=phi_q, n_neurons=n_neurons, debug=self.debug)

        if k_delays is None:
            k_delays = cp.linspace(0, 0.25, n_neurons, dtype=self.DTYPE)
        if k_widths is None:
            k_widths = cp.full(n_neurons, 0.15, dtype=self.DTYPE)
        self.set_k_params(k_delays, k_widths)
        # No parameters are shared between neurons (see _assemble_system_shared, which returns 0 columns).
        self.shared_shape = 0

        if in_weights is not None:
            self.in_weights = in_weights.astype(self.DTYPE)
        else:
            self.in_weights = cp.zeros((n_inputs, n_neurons), dtype=self.DTYPE)



        # GPU kernel wrappers used by the voltage evaluation, the least-squares assembly and the delay search.
        self.in_volt_kernel = InVoltFixed(phi=phi_k, n_neurons=n_neurons, n_in=n_inputs, debug=self.debug)
        self.fit_contrib_kernel = FitContributionsFixed(phi=phi_k, n_neurons=n_neurons, n_in=n_inputs, debug=self.debug)
        self.tau_corr_abs_kernel = TauCorrFixed(n_neurons=n_neurons, n_in=n_inputs, debug=self.debug, accumulate_absolute=True)
        self.tau_corr_direct_kernel = TauCorrFixed(n_neurons=n_neurons, n_in=n_inputs, debug=self.debug, accumulate_absolute=False)

    def _compute_in_voltages_cuda(self, s_in, ts, n_threads=32):
        return self.in_volt_kernel.__call__(eval_ts=ts, s_in=s_in, in_weights=self.in_weights,
            neg_delays_inv_widths=self.k_neg_delays_inv_widths, inv_widths=self.k_inv_widths, n_threads=n_threads)

    def _compute_out_voltages_cuda(self, s_out, ts, n_threads=32):
        return self.out_volt_kernel.__call__(eval_ts=ts, s_out=s_out, out_weights=self.out_weights,
            neg_delays_inv_widths=self.q_neg_delays_inv_widths, inv_widths=self.q_inv_widths, n_threads=n_threads)

    def _compute_k_ffts(self, ts):
        """
        Sample every neuron's input kernel on ``ts`` and cache its real FFT in ``self.k_rfft``.

        :param ts: (n_ts) time points; the kernel argument is ``ts * inv_width + neg_delay_inv_width``
        """
        kernel_args = ts[None, :] * self.k_inv_widths[:, None] + self.k_neg_delays_inv_widths[:, None]
        k_eval = self.phi_k(kernel_args)  # one row per neuron
        self.k_rfft = sp.fft.rfft(k_eval, n=len(ts), axis=-1, overwrite_x=True)

    def _compute_q_ffts(self, ts):
        """
        Sample every neuron's output kernel on ``ts`` and cache its real FFT in ``self.q_rfft``.

        The first time sample of each kernel is forced to zero before transforming.

        :param ts: (n_ts) time points; the kernel argument is ``ts * inv_width + neg_delay_inv_width``
        """
        kernel_args = ts[None, :] * self.q_inv_widths[:, None] + self.q_neg_delays_inv_widths[:, None]
        q_eval = self.phi_q(kernel_args)  # one row per neuron
        q_eval[:, 0] = 0
        self.q_rfft = sp.fft.rfft(q_eval, n=len(ts), axis=-1, overwrite_x=True)

    def pseudo_spikes(self, rate=1):
        """
        Sample random "pseudo" output spikes and store the voltage they cause at t = 0.

        Not used during the numerical experiments.

        For each neuron a Poisson number of spikes (mean ``rate * q_inv_width * q_width``) is
        drawn uniformly from ``[-q_width, 0]``. The out-voltage kernel is evaluated at t = 0
        and element ``[0, 0]`` of its result is stored in ``self.pseudo_spike_cost``.
        For ``rate <= 0`` the cost is set to None.

        :param rate: spike rate scaling factor
        """
        rate_neuron = rate * self.q_inv_widths
        t_start = -self.get_q_widths()
        if not rate > 0:
            self.pseudo_spike_cost = None
            return

        t_end = 0
        n_spikes = cp.random.poisson(rate_neuron * -t_start)
        n_spikes_max = n_spikes.max()
        # Spike times ~ Uniform(t_start[i], t_end) per neuron, sorted ascending
        spikes_neuron = [
            cp.sort(cp.random.uniform(t_start[i], t_end, size=int(n_spikes[i])))
            for i in range(self.n_neurons)
        ]
        # Pad to a common length with +inf so every neuron has the same number of slots
        spikes = cp.full((1, self.n_neurons, int(n_spikes_max)), cp.inf, dtype=cp.float32)
        for i, times in enumerate(spikes_neuron):
            spikes[0, i, :len(times)] = times
        self.pseudo_spike_cost = self._compute_out_voltages_cuda(spikes, cp.array([0,], dtype=cp.float32),)[0, 0]

    def _compute_in_contributions_fourier_single(self, s_fft: cp.ndarray, neuron_idx: int, ts: cp.ndarray,
                                                 trunc_pre: int = 0, trunc_post: int = 0):
        """
        Filter every input channel with the input kernel of one neuron (multiplication in frequency space).

        :param s_fft: (N, self.n_in, f(n_ts)) frequency representation of binned input spikes
        :param neuron_idx: index of the neuron whose cached kernel spectrum ``k_rfft`` is used
        :param ts: (n_ts) time points
        :param trunc_pre: number of leading time steps dropped from the result
        :param trunc_post: number of trailing time steps dropped from the result
        :return: (N, n_in, n_ts - trunc_pre - trunc_post)
        """
        n_ts = len(ts)
        spectrum = s_fft * self.k_rfft[None, None, neuron_idx]
        contributions = sp.fft.irfft(spectrum, n=n_ts, axis=-1, overwrite_x=True)
        return contributions[:, :, trunc_pre:n_ts - trunc_post]

    def _compute_out_contributions_fourier_single(self, s_fft: cp.ndarray, neuron_idx: int, ts: cp.ndarray,
                                                  trunc_pre: int = 0, trunc_post: int = 0):
        """
        Filter the output spikes of one neuron with its output kernel (multiplication in frequency space).

        :param s_fft: (N, f(n_ts)) frequency representation of binned output spikes of neuron_idx
        :param neuron_idx: index of the neuron whose cached kernel spectrum ``q_rfft`` is used
        :param ts: (n_ts) time points
        :param trunc_pre: number of leading time steps dropped from the result
        :param trunc_post: number of trailing time steps dropped from the result
        :return: (N, n_ts - trunc_pre - trunc_post)
        """
        n_ts = len(ts)
        spectrum = s_fft * self.q_rfft[None, neuron_idx]
        contributions = sp.fft.irfft(spectrum, n=n_ts, axis=-1, overwrite_x=True)
        return contributions[:, trunc_pre:n_ts - trunc_post]

    def _compute_in_voltages_fourier_single(self, s_fft: cp.ndarray, neuron_idx: int, ts: cp.ndarray,
                                            trunc_pre: int = 0, trunc_post: int = 0):
        """
        Input voltage of a single neuron: its per-channel contributions weighted by its input weights.

        :param s_fft: (N, self.n_in, f(n_ts)) frequency representation of binned input spikes
        :param neuron_idx: int
        :param ts: (n_ts) time points
        :param trunc_pre: number of leading time steps dropped from the result
        :param trunc_post: number of trailing time steps dropped from the result
        :return: (N, n_ts - trunc_pre - trunc_post)
        """
        contributions = self._compute_in_contributions_fourier_single(
            s_fft=s_fft, neuron_idx=neuron_idx, ts=ts, trunc_pre=trunc_pre, trunc_post=trunc_post)
        neuron_weights = self.in_weights[:, neuron_idx]
        # Contract the input-channel axis: (N, n_in, n_t) x (n_in,) -> (N, n_t)
        return cp.einsum("ijk,j->ik", contributions, neuron_weights)

    def _assemble_system_shared(self, s_in, s_out, eval_ts, fit_idcs: cp.ndarray, fourier_args=None):
        """
        Design-matrix columns shared by all neurons; this layer has none.

        Called by the fit method in the base class.
        Not used during the numerical experiments.

        :return: empty matrix with ``len(fit_idcs) * n_samples`` rows and 0 columns
        """
        n_rows = fit_idcs.shape[0] * s_in.shape[0]
        return cp.zeros((n_rows, 0))

    def _assemble_system_neuron(self, s_in, s_out, eval_ts, neuron_idx, fit_idcs: cp.ndarray, fourier_args=None):
        """
        Assemble the neuron-specific design matrix for the least-squares fit.

        Called by the fit method in the base class.
        Not used during the numerical experiments.

        The first ``n_in`` columns hold the filtered input channels, the last column holds the
        filtered output spikes of the neuron (zero when ``s_out`` is None).

        :param s_in: frequency representation of the input spikes (passed on as ``s_fft``)
        :param s_out: frequency representation of the output spikes, indexed as ``s_out[:, neuron_idx]``, or None
        :param eval_ts: (n_ts) time points
        :param neuron_idx: int
        :param fit_idcs: indices of the (truncated) time points used for fitting
        :param fourier_args: 3-tuple; only its second and third entries (``trunc_pre``, ``trunc_post``) are used
        :return: (len(fit_idcs) * n_samples, n_in + 1), rows ordered time-major
        """
        n_samples = s_in.shape[0]
        _, trunc_pre, trunc_post = fourier_args
        n_fit = fit_idcs.shape[0]
        n_cols = self.n_in + 1

        design = cp.zeros((n_samples, n_cols, n_fit), dtype=self.DTYPE)
        in_contrib = self._compute_in_contributions_fourier_single(
            s_fft=s_in, neuron_idx=neuron_idx, ts=eval_ts, trunc_pre=trunc_pre, trunc_post=trunc_post)
        design[:, :-1] = in_contrib[:, :, fit_idcs]

        if s_out is None:
            design[:, -1] = 0
        else:
            out_contrib = self._compute_out_contributions_fourier_single(
                ts=eval_ts, s_fft=s_out[:, neuron_idx], neuron_idx=neuron_idx,
                trunc_pre=trunc_pre, trunc_post=trunc_post)
            design[:, -1] = out_contrib[:, fit_idcs]

        # (n_samples, n_cols, n_fit) -> (n_fit, n_samples, n_cols) -> flat rows
        return design.transpose(-1, 0, 1).reshape((n_fit * n_samples, n_cols))


    def _set_weights(self, w_neuron, neuron_idx, matrix=None, vs=None, **kwargs):
        self.in_weights[:, neuron_idx] = w_neuron[:self.n_in]
        if self.phi_q is not None:
            self.out_weights[neuron_idx] = (w_neuron[self.n_in:] if w_neuron.shape[0] > self.n_in else 0)


    def _compute_in_voltages_fourier(self, s_in: cp.ndarray, t_max=1, dt=1, recompute_k_ffts=True):
        """
        Input voltages of all neurons, computed by convolution in frequency space.

        :param s_in: (N, self.n_in, *) input signal; it is Fourier-transformed along its last axis
        :param t_max: Last time point
        :param dt: step size
        :param recompute_k_ffts: Whether to recompute the frequency domain representation of the kernels.
        :return: (N, n_out, n_t)
        """
        ts, steps_pre, steps_post = self._expand_t_fourier(t_max, dt, recompute_ffts=(recompute_k_ffts, False))
        n_ts = len(ts)
        s_fft = sp.fft.rfft(s_in, n=n_ts, axis=-1, overwrite_x=False)
        # Mix the input channels per neuron: (N, n_in, f) x (n_in, n_out) -> (N, f, n_out) -> (N, n_out, f)
        mixed = cp.tensordot(s_fft, self.in_weights, axes=([1], [0])).transpose(0, 2, 1)
        volts = sp.fft.irfft(self.k_rfft[None, :, :] * mixed, n=n_ts, axis=-1, overwrite_x=True)
        return volts[:, :, steps_pre:n_ts - steps_post]

    def init_swim(self, s_in, pair_idcs, t_max=1, dt=1, recompute_k_ffts=True,
                  max_it=20, eig_tol=1e-6, norm_sign=False,
                  objective: str = 'dist', solver='prime', plot_weights=False, compute_batch_size=None):
        """
        Weight construction: initialise ``self.in_weights`` from pairs of input samples.

        For every neuron the two samples of its pair are filtered with the neuron's input kernel.
        The weight vector is then an extremal eigenvector of an (n_in, n_in) matrix built from
        the filtered signals.

        :param s_in: (N, self.n_in, Nt) discretised input signals
        :param pair_idcs: (n_neurons, 2) sample indices, one pair per neuron
        :param t_max: Last time point
        :param dt: step size
        :param recompute_k_ffts: Whether to recompute the frequency domain representation of the kernels.
        :param max_it: Maximum number of iterations (power iteration / primme)
        :param eig_tol: eigenvalue convergence tolerance (power iteration / primme)
        :param norm_sign: if True, multiply every weight column by the sign of its sum
            (a column whose entries sum to exactly 0 is zeroed)
        :param objective: 'dist' | 'dot' | 'random'.
            'dist': eigenvector with the largest eigenvalue of ``u @ u.T``, where ``u`` is the
            filtered difference of the two samples.
            'dot': eigenvector with the smallest eigenvalue of the symmetrised ``u1 @ u2.T``.
            'random': standard-normal weights; the pairs are ignored.
        :param solver: 'prime' | 'power_it' | 'cupy'. For 'dist', any value other than 'power_it'
            or 'cupy' selects primme. For 'dot' only 'prime' and 'cupy' are supported.
        :param plot_weights: if True, show the correlation matrix of the constructed weight vectors
        :param compute_batch_size: number of samples Fourier-transformed at once; defaults to all samples
        :raises ValueError: on an unknown objective, an unsupported solver for 'dot', or a
            non-positive ``compute_batch_size``
        """
        ts, steps_pre, steps_post = self._expand_t_fourier(t_max, dt, recompute_ffts=(recompute_k_ffts, False))
        N, _, _ = s_in.shape
        T = len(ts)
        fT = T // 2 + 1
        t_stop = T - steps_post

        if compute_batch_size is None:
            compute_batch_size = max(N, 1)
        if compute_batch_size < 1:
            raise ValueError(f"compute_batch_size must be positive, got {compute_batch_size}")

        # Transform the inputs batch-wise to bound the peak memory of the FFT.
        # Stepping by the batch size never produces an empty trailing batch.
        s_fft = cp.zeros((N, self.n_in, fT), dtype=cp.complex64)
        for idx_start in range(0, N, compute_batch_size):
            idx_end = min(idx_start + compute_batch_size, N)
            s_fft[idx_start: idx_end] = sp.fft.rfft(s_in[idx_start: idx_end], n=T, axis=-1,
                                                     overwrite_x=False)  # (batch, n_in, f(t))

        def to_time(spectrum):
            # n=T is required: without it irfft returns 2 * (fT - 1) samples, i.e. T - 1 for odd T.
            return sp.fft.irfft(spectrum, n=T, overwrite_x=True, axis=-1)[..., steps_pre:t_stop]

        def primme_weights(mats, which):
            # One extremal eigenvector per neuron, computed on the CPU.
            ws = np.zeros((self.n_in, self.n_neurons,), dtype=self.DTYPE)
            for neuron_idx in range(self.n_neurons):
                _, ev = primme.eigsh(mats[neuron_idx], k=1, which=which, tol=eig_tol, maxiter=max_it,
                                     return_stats=False)
                ws[:, neuron_idx] = ev[:, 0]
            return to_gpu(ws)

        if objective == 'random':
            self.in_weights = self.rng.standard_normal(size=(self.n_in, self.n_neurons), dtype=self.DTYPE)
        elif objective == 'dist':
            if solver == 'power_it' or solver == 'cupy':
                for neuron_idx in range(self.n_neurons):
                    sample1_idx, sample2_idx = pair_idcs[neuron_idx]
                    u = to_time(self.k_rfft[neuron_idx] * (s_fft[sample1_idx] - s_fft[sample2_idx]))  # (n_in, Nt)
                    u = u @ u.T  # (n_in, n_in)
                    if solver == 'power_it':
                        w_init = self.rng.standard_normal(size=(self.n_in,), dtype=self.DTYPE)
                        w = power_it(w_init=w_init, max_it=max_it, eps=eig_tol, A=u, debug=self.debug,
                                     debug_ctxt=f"idx={neuron_idx}")
                    else:
                        ew, ev = cp.linalg.eigh(u)
                        w = ev[:, cp.argmax(ew)]

                    self.in_weights[:, neuron_idx] = w
            else:
                # Every other solver name (including the default 'prime') uses primme.
                sample1_idcs, sample2_idcs = pair_idcs.T
                u = to_time(self.k_rfft[:, None, :] * (s_fft[sample1_idcs] - s_fft[sample2_idcs]))  # (n_neurons, n_in, Nt)
                u = (u @ u.transpose(0, -1, 1)).get()
                self.in_weights = primme_weights(u, which='LM')
        elif objective == 'dot':
            if solver == 'prime':
                sample1_idcs, sample2_idcs = pair_idcs.T
                u1 = to_time(self.k_rfft[:, None, :] * s_fft[sample1_idcs])  # (n_neurons, n_in, Nt)
                u2 = to_time(self.k_rfft[:, None, :] * s_fft[sample2_idcs])  # (n_neurons, n_in, Nt)
                u = u1 @ u2.transpose(0, -1, 1)
                del u1
                del u2
                # Symmetrise so that a symmetric eigensolver applies
                u = (1 / 2 * (u + u.transpose(0, -1, 1)))
                u = to_cpu(u)
                self.in_weights = primme_weights(u, which='SA')
            elif solver == 'cupy':
                for neuron_idx in range(self.n_neurons):
                    sample1_idx, sample2_idx = pair_idcs[neuron_idx]
                    u1 = to_time(self.k_rfft[neuron_idx] * (s_fft[sample1_idx]))  # (n_in, Nt)
                    u2 = to_time(self.k_rfft[neuron_idx] * (s_fft[sample2_idx]))  # (n_in, Nt)
                    u = u1 @ u2.T
                    del u1
                    del u2
                    u = (1 / 2 * (u + u.T))
                    u = cp.ascontiguousarray(u)
                    ew, ev = self._eigh_with_fallbacks(u, neuron_idx)
                    idx_min = cp.argmin(ew)
                    self.in_weights[:, neuron_idx] = ev[:, idx_min].astype(self.DTYPE)
            else:
                raise ValueError(f"Only cupy and prime solvers supported for dot, got {solver} instead.")
        else:
            raise ValueError(f"Unknown objective {objective}")

        # Neurons whose eigenvector contains NaNs fall back to random weights.
        nan_weights_idcs = cp.argwhere(cp.isnan(self.in_weights).any(axis=0))
        for neuron_idx in nan_weights_idcs[:, 0]:
            self.in_weights[:, neuron_idx] = self.rng.standard_normal(size=(self.n_in,), dtype=self.DTYPE)

        if norm_sign:
            self.in_weights *= cp.sign(self.in_weights.sum(axis=0, keepdims=True))

        if plot_weights:
            fig, ax = plt.subplots(figsize=(12, 12))
            norms = cp.linalg.norm(self.in_weights, axis=0)
            corrs = self.in_weights.T @ self.in_weights / (norms[None, :] * norms[:, None])
            m = ax.matshow(corrs.get(), cmap='seismic', vmin=-1, vmax=1)
            fig.colorbar(mappable=m, ax=ax)
            plt.show()

    @staticmethod
    def _eigh_with_fallbacks(u, neuron_idx):
        """
        Symmetric eigendecomposition of ``u`` that retries with other settings when the solver fails.

        Attempts, in order: ``eigh(UPLO='L')``, ``eigh(UPLO='U')``, ``eigh`` in float64 and
        finally ``cupyx.scipy.linalg.eigh``. The exception of the last attempt propagates.

        :param u: (n, n) symmetric matrix
        :param neuron_idx: neuron index, only used in the diagnostic output
        :return: (eigenvalues, eigenvectors)
        """
        try:
            return cp.linalg.eigh(u, UPLO='L')
        except Exception as e:
            print(f"Exception in eigh for neuron {neuron_idx}: {e}")

        try:
            return cp.linalg.eigh(u, UPLO='U')
        except Exception as e:  # If we are here, something is seriously wrong.
            print(f"Exception in eigh for neuron {neuron_idx}: {e}")
            # Diagnostics are limited to cheap reductions, so that printing them
            # cannot raise and cut the remaining fallbacks short.
            print(f"u: {u}")
            print(f"min(u): {u.min()}")
            print(f"max(u): {u.max()}")
            print(f"mean: {u.mean()}")
            print(f"u.shape: {u.shape}")
            print(f"u.dtype: {u.dtype}")
            print(f"u.flags: {u.flags}")
            print(f"u.strides: {u.strides}")
            print(f"all finite: {cp.isfinite(u).all()}")

        try:
            return cp.linalg.eigh(u.astype(cp.float64), UPLO='L')
        except Exception:
            import cupyx.scipy.linalg as cpx_linalg
            return cpx_linalg.eigh(u)


    def normalise_weights(self, s_in, t_max=1, dt=1, normaliser: Normaliser = Normaliser(), recompute_k_ffts=False,
                          fit_out_weights=True, slice_size=500, silence_correction=True):
        """
        Normalisation step: rescale the input weights and set the biases from voltage statistics.

        The input voltages are evaluated slice by slice. Running averages of the per-neuron mean
        and of the per-sample temporal standard deviation are accumulated together with the
        per-neuron maximum. ``normaliser`` maps these statistics to scales, biases and output weights.

        :param s_in: (N, self.n_in, Nt) discretised input signals
        :param t_max: Last time point
        :param dt: step size
        :param normaliser: callable ``(in_weights, means, stds, dtype=...) -> (alphas, b, out_weights)``.
            NOTE(review): the default instance is created once at definition time and shared between calls.
        :param recompute_k_ffts: Whether to recompute the frequency domain representation of the kernels.
        :param fit_out_weights: if True and an output kernel ``phi_q`` is configured, set
            ``self.out_weights`` to the normaliser's output weights divided by ``phi_q(0)``.
            Without an output kernel this flag has no effect.
        :param slice_size: number of samples evaluated at once
        :param silence_correction: if True, raise the bias of every neuron whose maximal normalised
            voltage stays below 1 so that the maximum becomes 1.01
        """
        ts, steps_pre, steps_post = self._expand_t_fourier(t_max, dt, recompute_ffts=(recompute_k_ffts, False))
        n_ts = len(ts)

        # Evaluate in voltages to compute mean and std for normalisation
        N = s_in.shape[0]
        neuron_means = cp.zeros((self.n_neurons,), dtype=cp.float64)
        neuron_stds = cp.zeros((self.n_neurons,), dtype=cp.float64)
        neuron_maxs = cp.full((self.n_neurons,), fill_value=-cp.inf, dtype=cp.float32)
        n_prev = 0
        for start in range(0, N, slice_size):
            stop = min(start + slice_size, N)
            slice_size_actual = stop - start
            s_fft = sp.fft.rfft(s_in[start:stop], n=n_ts, axis=-1, overwrite_x=False)  # (n_slice, n_in, f(t))
            mixed = cp.tensordot(s_fft, self.in_weights, axes=([1], [0])).transpose(0, 2, 1)
            vs_eval = sp.fft.irfft(self.k_rfft[None, :, :] * mixed, n=n_ts, axis=-1, overwrite_x=True)
            vs_eval = vs_eval[:, :, steps_pre:n_ts - steps_post] + self.bias[None, :, None]  # (n_slice, n_out, n_t)

            neuron_means_batch = vs_eval.mean(axis=0, dtype=cp.float64).mean(axis=-1, dtype=cp.float64)  # (n_out)
            # (n_out): Exp_n[ Std_t [v] ]
            neuron_stds_batch = vs_eval.std(axis=-1, dtype=cp.float64).mean(axis=0, dtype=cp.float64)
            # Size-weighted running averages over the slices
            n_total = n_prev + slice_size_actual
            neuron_means += (neuron_means_batch - neuron_means) * slice_size_actual / n_total
            neuron_stds += (neuron_stds_batch - neuron_stds) * slice_size_actual / n_total
            neuron_maxs = cp.maximum(neuron_maxs, vs_eval.max(axis=-1).max(axis=0))
            n_prev = n_total

        alphas, b, out_weights = normaliser(self.in_weights, neuron_means, neuron_stds, dtype=self.DTYPE)
        # Maximum voltage after normalisation: remove the old bias, rescale, add the new bias
        neuron_maxs = (neuron_maxs - self.bias) * alphas + b

        self.in_weights *= alphas[None, :]
        self.bias = b

        if silence_correction:
            self.bias += (neuron_maxs < 1) * (1.01 - neuron_maxs)

        # Without an output kernel there is nothing to scale by (phi_q(0) would raise a TypeError).
        if fit_out_weights and self.phi_q is not None:
            self.out_weights = (out_weights / self.phi_q(0)).astype(self.DTYPE)

    def tau_correlation(self, s_in: cp.ndarray, f: cp.ndarray, t_max, dt, OL, demean_f=True, n_threads=32, plot=False, flip=True):
        """
        Identification of delays: set each neuron's input-kernel delay from a correlation search.

        The ``tau_corr_abs_kernel`` result is summed over its first axis and, per neuron, the
        index of its maximum along the last axis is taken as the lag. The kernel widths are kept.

        :param s_in: (N, N_in, T)
        :param f: (N, N_out, T)
        :param t_max: passed to the kernel as ``tmax``
        :param dt: step size; the selected lag index is multiplied by ``dt`` to obtain the delay
        :param OL: passed to the kernel as ``d_max``; with ``flip`` the delay index is ``OL - argmax``
        :param demean_f: if True, subtract the temporal mean (axis 2) from ``f`` first
        :param n_threads: thread count forwarded to the kernel launch
        :param plot: if True, plot the accumulated correlation curve of every neuron
        :param flip: if True, use ``OL - argmax`` instead of ``argmax`` as the delay index
        :return:
        """
        # Fit delays:
        f_dm = f - f.mean(axis=2, keepdims=True) if demean_f else f
        c = self.tau_corr_abs_kernel(tmax=t_max, dt=dt, s_in=s_in, f=f_dm, d_max=OL, n_threads=n_threads).sum(axis=0)
        lag_idcs = (cp.argmax(c, axis=-1)).astype(self.DTYPE)
        ds = OL - lag_idcs if flip else lag_idcs
        if plot:
            for neuron_curve in (c[i] for i in range(self.n_neurons)):
                plt.plot(to_cpu(neuron_curve))
            plt.show()
        widths = self.get_k_widths()
        self.set_k_params(ds * dt, widths)

    def sig_search(self, s_in: cp.ndarray, f: cp.ndarray, dt, eval_ts, fit_idcs: cp.ndarray, delay_agg=cp.median, n_sig=30, alpha=1.5, n_threads=32, progress=True, plot=False):
        """
        Search the kernel width (sigma) per neuron over a grid of candidates.

        For every candidate width a design matrix is built with ``fit_contrib_kernel``, using one
        common delay (``delay_agg`` of the current delays). The targets are projected onto its
        column space via QR. Each neuron gets the width with the smallest residual
        ``||f|| - ||Q^T f||``; the delays are kept.

        :param s_in: 3-D input, passed to ``fit_contrib_kernel``; its first axis is the sample axis (N)
        :param f: (N, n_neurons, H) targets, H = len(fit_idcs)
        :param dt: step size; candidate widths range from ``dt`` to ``2 * H * dt``
        :param eval_ts: time points; only ``eval_ts[fit_idcs]`` is used
        :param fit_idcs: indices of the time points used for fitting
        :param delay_agg: reduction applied to the current delays (called with ``keepdims=True``)
        :param n_sig: number of candidate widths
        :param alpha: spacing parameter forwarded to ``alpha_space``
        :param n_threads: thread count forwarded to the kernel launch
        :param progress: show a progress bar
        :param plot: if True, scatter the residuals of all neurons over sigma and mark the mean
            best residual with a dashed line
        """
        H = fit_idcs.shape[0]
        N, _, _ = s_in.shape
        sig_max = 2 * H * dt
        sig_min = 1 * dt
        sigs = alpha_space(sig_min, sig_max, n_sig, alpha=alpha, dtype=self.DTYPE)
        tau_bar = delay_agg(self.get_k_delays(), keepdims=True)
        f_re = f.transpose(0, 2, 1).reshape(N * H, self.n_neurons) * (1 / H)  # (N * H, N_out)
        f_norms = cp.linalg.norm(f_re, axis=0)
        fit_ts = eval_ts[fit_idcs]  # loop-invariant
        # The kernel call below receives A as its first argument; A is reused for every candidate.
        A = cp.ones((N * H, self.n_in + 1), dtype=self.DTYPE)
        res_qr = cp.zeros((n_sig, self.n_neurons,), dtype=self.DTYPE)
        for i in tqdm(range(n_sig), desc="Testing sigmas", disable=not progress):
            sig = sigs[i:i + 1]
            self.fit_contrib_kernel(A, fit_ts, s_in, neg_delay_inv_width=-tau_bar / sig,
                                    inv_width=1 / sig, n_threads=n_threads)
            q = cp.linalg.qr(A)[0]
            norms_proj = cp.linalg.norm(q.T @ f_re, axis=0)
            res_qr[i] = f_norms - norms_proj

        best_idcs = cp.argmin(res_qr, axis=0)
        sigs_best = sigs[best_idcs]
        if plot:
            for i in range(self.n_neurons):
                plt.scatter(to_cpu(sigs), to_cpu(res_qr[:, i]))
            # Pair each neuron with its own best candidate. Indexing with best_idcs alone
            # would gather whole rows and average over an (n_neurons, n_neurons) matrix.
            neuron_idcs = cp.arange(self.n_neurons, dtype=best_idcs.dtype)
            best_res = res_qr[best_idcs, neuron_idcs]
            plt.axhline(best_res.mean().get(), color='r', linestyle='--')
            plt.yscale('log')
            plt.xlabel("sigma")
            plt.ylabel("Residual")
            plt.show()

        self.set_k_params(self.get_k_delays(), sigs_best)


    def fit_l2_minibatch(self, data: DataLoader, pre:SRMLayer, eval_ts, fit_idcs: cp.ndarray, batch_size: int=1000, reg=1e-6, progress=True, n_reg_search=25, n_threads=32, val_batch_size=None, reg_min=-5, reg_max=0.5, neuron_wise_best_reg=True, num_spikes=50, pad=25):
        """
        Batched ridge regression of bias and input weights, with an optional search over the
        regularisation parameter.

        The normal equations of every neuron are accumulated over mini-batches of the training
        split. Column 0 of the solution is the bias, the remaining columns are the input weights.
        With ``n_reg_search > 0`` the systems are eigendecomposed once, solved for all candidate
        regularisers, and the candidate with the best R2 score on the validation split is kept.

        :param data: data loader providing the 'train' and 'val' splits
        :param pre: preceding layer; its ``compute_full_trains`` output is the input to this layer
        :param eval_ts: time points
        :param fit_idcs: indices of the time points used for fitting
        :param batch_size: training mini-batch size
        :param reg: regularisation strength, only used when ``n_reg_search <= 0``
        :param progress: show progress bars
        :param n_reg_search: number of candidate regularisers; ``<= 0`` disables the search
        :param n_threads: thread count forwarded to the kernel launch
        :param val_batch_size: validation mini-batch size; defaults to ``batch_size``
        :param reg_min: log10 of the smallest candidate regulariser
        :param reg_max: log10 of the largest candidate regulariser
        :param neuron_wise_best_reg: choose the regulariser per neuron instead of one for the layer
        :param num_spikes: forwarded to ``pre.compute_full_trains``
        :param pad: forwarded to ``pre.compute_full_trains``
        :return: the regulariser(s) used: an array of shape (n_neurons,) for the fixed and the
            neuron-wise case, a single entry of the candidate grid otherwise
        """
        n_feat = self.n_in + 1
        # assemble system matrices jointly to limit data transfer
        system_matrices = cp.zeros((self.n_neurons, n_feat, n_feat), dtype=self.DTYPE)
        rhss = cp.zeros((self.n_neurons, n_feat), dtype=self.DTYPE)
        To = fit_idcs.shape[0]
        fit_ts = eval_ts[fit_idcs]  # loop-invariant

        A = cp.ones((batch_size, To, n_feat), dtype=self.DTYPE)
        # Pre-bind the loop variables so the cleanup below also works for an empty training split.
        x = y = trains = None
        for x, y in tqdm(data.iterate(batch_size=batch_size, target='train'), desc="Assembling system matrices", disable=not progress, total=data.get_n_batches(batch_size=batch_size, target="train")):
            b = x.shape[0]
            if b != A.shape[0]:
                # Last (smaller) batch: the buffer has to match the batch size
                A = cp.ones((b, To, n_feat), dtype=self.DTYPE)
            trains = pre.compute_full_trains(x, eval_ts, num_spikes=num_spikes, pad=pad, progress=False, bin_spikes=False, method='f')
            for i in range(self.n_neurons):
                self.fit_contrib_kernel(A, fit_ts, trains, neg_delay_inv_width=self.k_neg_delays_inv_widths[i:i+1], inv_width=self.k_inv_widths[i:i+1], n_threads=n_threads)
                A_scaled = A / To
                system_matrices[i, :, :] += cp.tensordot(A_scaled, A, axes=([0, 1], [0, 1])) / b
                rhss[i] += cp.tensordot(A_scaled, y[:, i, :], axes=([0, 1], [0, 1])) / b
        # Release the large per-batch buffers before the solves
        del x, y, trains, A

        if n_reg_search <= 0:
            system_matrices += (reg * cp.eye(n_feat, dtype=self.DTYPE))[None, :, :]
            for neuron_idx in tqdm(range(self.n_neurons), desc="Fitting neurons", disable=not progress, ):
                w_neuron = cp.linalg.solve(system_matrices[neuron_idx], rhss[neuron_idx])
                self.bias[neuron_idx] = w_neuron[0]
                self.in_weights[:, neuron_idx] = w_neuron[1:]
            return cp.full((self.n_neurons,), fill_value=reg, dtype=self.DTYPE)

        # Compute eigendecompositions of system matrices; the eigenvectors overwrite the matrices.
        # TODO: Maybe bias is regularised rn? Check and fix
        # (the regulariser is added to every eigenvalue, i.e. to all directions including the bias column)
        eivals = cp.zeros((self.n_neurons, n_feat), dtype=self.DTYPE)
        for neuron_idx in tqdm(range(self.n_neurons), desc="Computing Eigendecompositions", disable=not progress, ):
            eivals[neuron_idx, :], system_matrices[neuron_idx, :, :] = cp.linalg.eigh(system_matrices[neuron_idx, :, :])

        regs = cp.logspace(reg_min, reg_max, num=n_reg_search, dtype=self.DTYPE)

        # w_r = V (Lambda + r I)^-1 V^T rhs for all candidate regularisers r at once
        rhss = cp.einsum('bij,bi->bj', system_matrices, rhss)
        rhss = rhss[None, :, :] / (eivals[None, :, :] + regs[:, None, None])
        ws = cp.einsum('bij,rbj->rbi', system_matrices, rhss)  # (n_reg, n_neurons, n_in + 1)

        del system_matrices
        del rhss

        acc_list = [R2Accumulator() for _ in range(n_reg_search)]

        if val_batch_size is None:
            val_batch_size = batch_size

        val_iterator = data.iterate(batch_size=val_batch_size, target="val")

        for x, y in tqdm(val_iterator, desc="Evaluating regs", disable=not progress, total=data.get_n_batches(batch_size=val_iterator.batch_size, target="val")):
            trains = pre.compute_full_trains(x, eval_ts, num_spikes=num_spikes, pad=pad, progress=False, bin_spikes=False, method='f')
            for i_reg in range(n_reg_search):
                self.bias[:] = ws[i_reg, :, 0]
                self.in_weights[:] = ws[i_reg, :, 1:].T

                vs_reg = self.compute_in_voltage(trains, eval_ts, return_cupy=True, fit_idcs=fit_idcs, method='c')

                acc_list[i_reg].accumulate(y, vs_reg)

        if neuron_wise_best_reg:
            best_idcs = cp.zeros((self.n_neurons,), dtype=cp.int32)
            # Start from -inf (like the layer-wide search below): R2 can be negative, and a
            # start value of 0 would silently keep candidate 0 for such neurons.
            best_score = cp.full((self.n_neurons,), fill_value=-cp.inf, dtype=self.DTYPE)

            for i_reg in tqdm(range(n_reg_search), disable=not progress, desc="Finding best weights"):
                scores = acc_list[i_reg].reduce(neuron_wise=True)
                mask = scores > best_score
                best_score[mask] = scores[mask]
                best_idcs[mask] = i_reg

            # Gather each neuron's solution for its own best candidate
            neuron_idcs = cp.arange(self.n_neurons, dtype=best_idcs.dtype)
            self.bias[:] = ws[best_idcs, neuron_idcs, 0]
            self.in_weights[:] = ws[best_idcs, neuron_idcs, 1:].T
            return regs[best_idcs]

        best_idx = 0
        best_score = -cp.inf
        for i_reg in tqdm(range(n_reg_search), disable=not progress, desc="Finding best weights"):
            score = acc_list[i_reg].reduce(neuron_wise=False)
            if score > best_score:
                best_score = score
                best_idx = i_reg
        self.bias[:] = ws[best_idx, :, 0]
        self.in_weights[:] = ws[best_idx, :, 1:].T
        return regs[best_idx]



    def setup_l2_system(self, x_full_neuron, v_full, n_t_fit, n_samples, only_in=False, reg=1e-6):
        """
        Build the (regularised) normal equations ``A w = b`` of the least-squares fit for one neuron.

        :param x_full_neuron: (n_t * N, n_in + 1) design matrix, rows ordered time-major
        :param v_full: (n_t, N) target values
        :param n_t_fit: number of fitted time points (n_t)
        :param n_samples: number of samples (N)
        :param only_in: if True, drop the last design-matrix column
        :param reg: ridge strength added to the diagonal of ``A``; skipped when ``reg <= 0``
        :return: (A, b) with ``A = X^T X / (n_t * N) [+ reg * I]`` and ``b = X^T v / (n_t * N)``
        """
        x_full_neuron = x_full_neuron.reshape(n_t_fit, n_samples, -1)
        if only_in:
            x_full_neuron = x_full_neuron[:, :, :-1]
        n_rows = n_t_fit * n_samples
        b = cp.tensordot(v_full, x_full_neuron, axes=([0, 1], [0, 1])) / n_rows
        # Use the array backend for both products instead of relying on NumPy's dispatch
        A = cp.tensordot(x_full_neuron, x_full_neuron, axes=([0, 1], [0, 1])) / n_rows
        if reg > 0:
            A += reg * cp.eye(A.shape[0], dtype=self.DTYPE)
        return A, b

    def save_params(self, file):
        """
        Store the layer parameters in NumPy's ``.npz`` format.

        Always written: ``bias``, ``in_weights``, ``k_widths``, ``k_delays``. When an output
        kernel is configured, ``out_weights`` and ``q_widths`` are written as well.
        NOTE(review): the output-kernel delays are not stored - confirm this is intended.

        :param file: target passed to ``np.savez``
        """
        save_dict = dict(
            bias=self.bias.get(),
            in_weights=self.in_weights.get(),
            k_widths=self.get_k_widths(),
            k_delays=self.get_k_delays(),
        )

        if self.phi_q is not None:
            save_dict.update(
                out_weights=self.out_weights.get(),
                q_widths=self.get_q_widths(),
            )

        np.savez(file, **save_dict)


