from abc import ABCMeta, abstractmethod

from tqdm.auto import tqdm
from BACKEND import cp, sp, to_cpu, to_gpu
from sklearn.linear_model import PoissonRegressor

from ..kernels import Kernel
from ..utils import bin_spike_times
from ..utils.enums import ReturnVs

# Together with srm_layer_fixed.py, this file is the main implementation of the proposed methods.



class SRMLayer(metaclass=ABCMeta):
    def __init__(self, n_neurons: int, n_inputs: int,
                 phi_k: Kernel = None,
                 phi_q: Kernel = None,
                 bias: cp.ndarray = None,
                 seed: int = 42,
                 debug: bool = False) -> None:
        """
        Store the layer configuration and reset every fitted / cached quantity to None.

        :param n_neurons: Number of neurons in the layer
        :param n_inputs: Number of input trains to the layer
        :param phi_k: Kernel object kept as ``self.phi_k``
        :param phi_q: Kernel object kept as ``self.phi_q``; may be None
        :param bias: Optional bias array, cast to ``self.DTYPE``. Zeros of shape (n_neurons,) if omitted
        :param seed: Random seed for sampling spikes
        :param debug: Flag kept as ``self.debug``
        """
        # Numerical configuration (the tolerance is expressed in the working dtype).
        self.DTYPE = cp.float32
        self.out_contrib_tolerance = self.DTYPE(1e-6)

        # User-supplied components and layer layout.
        self.n_neurons, self.n_in = n_neurons, n_inputs
        self.phi_k, self.phi_q = phi_k, phi_q
        self.debug = debug
        self.rng = cp.random.default_rng(seed)

        # Everything below is filled in later (by the setters, subclasses or the fit routines).
        self.shared_shape = None
        self.rec_field_k_a = self.rec_field_k_b = None
        self.rec_field_q_a = self.rec_field_q_b = None

        self.k_inv_widths = self.k_neg_delays_inv_widths = None
        self.k_rfft = None

        self.q_inv_widths = self.q_neg_delays_inv_widths = None
        self.q_rfft = None

        self.pseudo_spike_cost = None

        self.bias = (cp.zeros((n_neurons,), dtype=self.DTYPE) if bias is None
                     else bias.astype(self.DTYPE))

    ################################
    # Shared with Subclasses #
    ################################

    # Fitting
    def compute_predictor(self, s_in: cp.ndarray, s_out: cp.ndarray, eval_ts: cp.ndarray, pred, pred_params,
                          fit_idcs: cp.ndarray = None, fourier_args=None):
        """
        Return the target voltages that the fitting routines regress onto.

        Only the ``'true'`` predictor is implemented: ``pred_params`` is converted to an array and
        handed back unchanged. ``s_in``, ``s_out``, ``eval_ts``, ``fit_idcs`` and ``fourier_args``
        are accepted for interface compatibility but are currently not used.

        :param s_in: (N, self.n_in, N_t) binned input spike trains (unused)
        :param s_out: (N, self.n_neurons, N_t) binned output spike trains (unused)
        :param eval_ts: (N_t) "Support" points of the lstsq problem (unused)
        :param pred: str, name of the predictor; only 'true' is supported
        :param pred_params: for 'true': the voltages themselves, anything ``cp.asarray`` accepts
        :param fit_idcs: unused
        :param fourier_args: unused
        :return: ``cp.asarray(pred_params)`` for 'true' (callers treat it as (N, self.n_neurons, N_t)),
                 otherwise None. NOTE: the fit routines call ``.transpose`` on the result, so any
                 predictor other than 'true' makes them fail.
        """
        # TODO: Needs to be rewritten
        if pred == 'true':
            return cp.asarray(pred_params)
        return None

    def fit_svd(self, s_in, s_out: cp.ndarray or None, eval_ts, pred, pred_params, fit_idcs: cp.ndarray = None,
                rcond=1e-6, progress=True, gls=False, **kwargs):
        # Not used during the numerical experiments, because it is prohibitively expensive
        """
        Fit bias and weights of every neuron with ``cp.linalg.lstsq``, one neuron at a time.

        O(N * N_t * n_params) memory.
        :param s_in: (N, self.n_in, N_t) binned input spike trains
        :param s_out: (N, self.n_neurons, N_t) binned output spike trains
        :param eval_ts: (N_t) "Support" points of the sum in the lstsq problem, assumed uniformly spaced for fourier method
        :param pred: str, predictor to use for estimating voltage
        :param pred_params: tuple or float, parameters for the voltage predictor
        :param fit_idcs: indices of the time bins used for the fit, all bins if None
        :param rcond: float, cond parameter passed to lstsq solver (also forwarded to ``_set_weights`` as ``r_cond``)
        :param progress: show a progress bar over the neurons
        :param gls: accepted but not used by this method
        :param kwargs: forwarded to ``_set_weights``
        """
        n_samples, n_t = s_in.shape[0], eval_ts.shape[0]
        if fit_idcs is None:
            fit_idcs = cp.arange(n_t, dtype=cp.int32)
        n_rows = fit_idcs.shape[0] * n_samples
        eval_ts, fourier_args, s_in, s_out = self._prepare_fourier_fit(eval_ts, s_in, s_out, n_t)
        common = dict(fit_idcs=fit_idcs, fourier_args=fourier_args)

        # Targets, flattened time-major to (n_t_fit * N, n_neurons).
        targets = self.compute_predictor(s_in, s_out, eval_ts, pred, pred_params, **common)
        targets = targets.transpose(-1, 0, 1).reshape(n_rows, self.n_neurons)

        # Design matrix: [bias column | shared columns | columns of the current neuron].
        shared_cols = self._assemble_system_shared(s_in, s_out, eval_ts, **common)
        bias_col = cp.ones((n_rows, 1), dtype=self.DTYPE)
        neuron_cols = self._assemble_system_neuron(s_in, s_out, eval_ts, neuron_idx=0, **common)
        design = cp.concatenate((bias_col, shared_cols, neuron_cols), axis=1)

        for neuron_idx in tqdm(range(self.n_neurons), desc="Fitting neurons", disable=not progress):
            # Only the solution vector is needed; residuals, rank and singular values are discarded.
            solution = cp.linalg.lstsq(design, targets[:, neuron_idx], rcond=rcond)
            w_neuron = solution[0]

            self.bias[neuron_idx] = w_neuron[0]
            self._set_weights(w_neuron[1:], neuron_idx, matrix=design[:, 1:], vs=targets, r_cond=rcond, **kwargs)

            if neuron_idx + 1 < self.n_neurons:
                # Swap in the neuron-specific columns of the next neuron; bias and shared part stay.
                design[:, 1 + self.shared_shape:] = self._assemble_system_neuron(
                    s_in, s_out, eval_ts, neuron_idx=neuron_idx + 1, **common)

    def fit_l2(self, s_in, s_out: cp.ndarray or None, eval_ts, pred, pred_params, fit_idcs: cp.ndarray = None,
               reg=1e-6, progress=True, **kwargs):
        """
        Fit bias and weights of every neuron by solving the square system built by ``setup_l2_system``.

        O(N * N_t * n_params) memory.
        :param s_in: (N, self.n_in, N_t) binned input spike trains
        :param s_out: (N, self.n_neurons, N_t) binned output spike trains, or None
                      (``setup_l2_system`` is then called with ``only_in=True``)
        :param eval_ts: (N_t) "Support" points of the sum in the lstsq problem, assumed uniformly spaced for fourier method
        :param pred: str, predictor to use for estimating voltage
        :param pred_params: tuple or float, parameters for the voltage predictor
        :param fit_idcs: indices of the time bins used for the fit, all bins if None
        :param reg: float, passed to ``setup_l2_system`` as ``reg`` and to ``_set_weights`` as ``r_cond``
        :param progress: show a progress bar over the neurons
        :param kwargs: forwarded to ``_set_weights``
        """
        n_samples, n_t = s_in.shape[0], eval_ts.shape[0]
        if fit_idcs is None:
            fit_idcs = cp.arange(n_t, dtype=cp.int32)
        n_t_fit = fit_idcs.shape[0]
        eval_ts, fourier_args, s_in, s_out = self._prepare_fourier_fit(eval_ts, s_in, s_out, n_t)
        common = dict(fit_idcs=fit_idcs, fourier_args=fourier_args)
        only_in = s_out is None

        # Targets arranged as (n_t_fit, N, n_neurons).
        targets = self.compute_predictor(s_in, s_out, eval_ts, pred, pred_params, **common)
        targets = targets.transpose(-1, 0, 1).reshape(n_t_fit, n_samples, self.n_neurons)

        # Design matrix: [bias column | columns returned by _assemble_system_neuron].
        bias_col = cp.ones((n_t_fit * n_samples, 1), dtype=self.DTYPE)
        neuron_cols = self._assemble_system_neuron(s_in, s_out, eval_ts, neuron_idx=0, **common)
        design = cp.concatenate((bias_col, neuron_cols), axis=1)

        A, b = self.setup_l2_system(design, targets[:, :, 0], n_t_fit, n_samples, only_in=only_in, reg=reg)
        for neuron_idx in tqdm(range(self.n_neurons), desc="Fitting neurons", disable=not progress):
            w_neuron = cp.linalg.solve(A, b)
            self.bias[neuron_idx] = w_neuron[0]
            self._set_weights(w_neuron[1:], neuron_idx, matrix=design[:, 1:], vs=targets, r_cond=reg, **kwargs)

            following = neuron_idx + 1
            if following < self.n_neurons:
                # Refresh the neuron-specific columns, then rebuild the system for the next neuron.
                design[:, 1 + self.shared_shape:] = self._assemble_system_neuron(
                    s_in, s_out, eval_ts, neuron_idx=following, **common)
                A, b = self.setup_l2_system(design, targets[:, :, following], n_t_fit, n_samples,
                                            only_in=only_in, reg=reg)

    def fit_glm(self, s_in, s_out: cp.ndarray or None, eval_ts, alpha_reg=1.0, warm_start=True, max_iter=100,
                solver='newton-cholesky',
                scale_intercept=1, fit_idcs: cp.ndarray = None, **kwargs):
        """
        Fit every neuron with a Poisson GLM (scikit-learn ``PoissonRegressor``) on the binned output spikes.

        The regression runs on the CPU: the design matrix and the targets are moved with ``to_cpu``
        for every neuron, and the rescaled coefficients are moved back with ``to_gpu``.

        O(N * N_t * n_params) memory.
        :param s_in: (N, self.n_in, N_t) binned input spike trains
        :param s_out: (N, self.n_neurons, N_t) binned output spike trains. Required here: the spike
                      counts are the regression targets, so None is not supported by this method
        :param eval_ts: (N_t) "Support" points of the sum in the lstsq problem, assumed uniformly spaced for fourier method
        :param alpha_reg: Regularization parameter, passed to PoissonRegressor
        :param warm_start: Whether to use previous weights as starting point
        :param max_iter: Maximum number of iterations for the solver
        :param solver: 'newton-cholesky' or 'lbfgs'
        :param scale_intercept: Scale the intercept by this factor when rescaling the weights
        :param fit_idcs: indices of the time bins used for the fit, all bins if None
        :param kwargs: forwarded to ``_set_weights``
        """
        fit_model = PoissonRegressor(alpha=alpha_reg, warm_start=warm_start, fit_intercept=True, max_iter=max_iter,
                                     solver=solver)
        n_t = eval_ts.shape[0]
        if fit_idcs is None:
            fit_idcs = cp.arange(n_t, dtype=cp.int32)

        # Regression targets, time-major (n_t_fit * N, n_neurons). They are restricted to the fitted
        # bins so that their rows line up with the design matrix, which is assembled for fit_idcs only.
        # Must be taken before _prepare_fourier_fit replaces s_out by its spectrum.
        binned_spikes = s_out[:, :, fit_idcs].transpose((-1, 0, 1)).reshape(-1, self.n_neurons)
        eval_ts, fourier_args, s_in, s_out = self._prepare_fourier_fit(eval_ts, s_in, s_out, n_t)

        shared_full = self._assemble_system_shared(s_in, s_out, eval_ts, fourier_args=fourier_args,
                                                   fit_idcs=fit_idcs)

        # Design matrix: [shared columns | columns of the current neuron]; the intercept is fitted
        # by PoissonRegressor itself, so there is no explicit bias column.
        x_full_neuron = cp.concatenate((shared_full,
                                        self._assemble_system_neuron(s_in, s_out, eval_ts, 0,
                                                                     fourier_args=fourier_args, fit_idcs=fit_idcs)),
                                       axis=1)
        for neuron_idx in tqdm(range(self.n_neurons), desc="Fitting neurons"):
            fit_model.fit(
                to_cpu(x_full_neuron),
                to_cpu(binned_spikes[:, neuron_idx].reshape(-1))
            )

            # Coefficients are rescaled by -(intercept * scale_intercept); the stored bias is
            # 0 * intercept, i.e. the intercept is absorbed into the weight scale.
            weights = fit_model.coef_ / (-fit_model.intercept_ * scale_intercept)
            self.bias[neuron_idx] = 0 * fit_model.intercept_
            weights = to_gpu(weights)

            self._set_weights(weights, neuron_idx, matrix=x_full_neuron, **kwargs)
            if neuron_idx < self.n_neurons - 1:
                # Swap in the neuron-specific columns of the next neuron; the shared part stays.
                x_full_neuron[:, self.shared_shape:] = self._assemble_system_neuron(s_in, s_out, eval_ts,
                                                                                    neuron_idx + 1,
                                                                                    fourier_args=fourier_args,
                                                                                    fit_idcs=fit_idcs)

    def compute_residual(self, s_in: cp.ndarray, s_out: cp.ndarray, eval_ts: cp.ndarray, pred, pred_params,
                         print_idcs_thresh: cp.float16 = cp.inf, method='fourier'):
        """
        Compare the voltages produced by the layer with the predictor voltages, one time point at a time.

        For every entry of ``eval_ts`` the contributions returned by ``self._compute_voltages_step``
        are summed, and the norm of the difference to ``v_preds[ti]`` is recorded.

        :param s_in: (N, self.n_in, *) input spike trains encoded by spike times
        :param s_out: (N, self.n_neurons, *) target spike trains encoded by spike times
        :param eval_ts: (N_t) "Support" points of the sum in the lstsq problem
        :param pred: str, predictor name passed to ``compute_predictor``
        :param pred_params: predictor parameters passed to ``compute_predictor``
        :param print_idcs_thresh: when the residual of a time point exceeds this value, the indices of
                                  the non-zero deviations are printed. The default (inf) never prints
        :param method: fourier or direct, method to assemble system matrices (currently not used)
        :return: (sum of the per-time-point residuals, array of per-time-point residuals), both
                 divided by the number of samples ``s_in.shape[0]``
        """
        res = cp.zeros_like(eval_ts)
        v_preds = self.compute_predictor(s_in, s_out, eval_ts, pred, pred_params)

        for ti in range(len(eval_ts)):
            v_comp = sum(self._compute_voltages_step(eval_ts[ti], s_in, s_out))
            deviation = v_comp - v_preds[ti]

            res[ti] = cp.linalg.norm(deviation)
            if res[ti] > print_idcs_thresh:
                # Report where the computed and the predicted voltages disagree.
                idcs = cp.nonzero(deviation)
                print(f"Residual {float(res[ti])} at time index {ti} exceeds threshold; "
                      f"non-zero deviations at indices {idcs}")

        res /= s_in.shape[0]
        return res.sum(), res

    @abstractmethod
    def _compute_k_ffts(self, ts):
        """
        Compute the FFT of the kernel functions on ``ts`` and store the result on the instance.

        Called by ``_expand_t_fourier`` with the padded time grid.
        :param ts: Time points array
        """

    @abstractmethod
    def _compute_q_ffts(self, ts):
        """
        Compute the FFT of the after-potential basis functions on ``ts`` and store the result on the instance.

        Called by ``_expand_t_fourier`` with the padded time grid, only when ``self.phi_q`` is set.
        :param ts: Time points array
        """

    def _expand_t_fourier(self, t_max, dt, recompute_ffts=(True, True)):
        """
        Build the zero-padded time grid used for convolutions via the convolution theorem.

        The grid starts at 0 with spacing ``dt`` and is extended at the end so that the periodic
        convolution does not wrap around (intent of the original author: v[0] = v[-1] = 0).

        :param t_max: length of the unpadded time axis (number of bins times ``dt``)
        :param dt: step size
        :param recompute_ffts: pair of flags; [0] recompute the k-FFTs, [1] recompute the q-FFTs
                               (the latter only if ``self.phi_q`` is not None)
        :return: (ts, steps_pre, steps_post) - padded grid in ``self.DTYPE``, number of padding
                 steps in front (always 0) and at the end
        """
        n_ts = int(cp.round(t_max / dt))
        steps_pre = 0
        # Trailing padding: the largest value of rec_field_k_b in steps, plus a margin of 3.
        steps_post = int(self.rec_field_k_b.max() / dt) + 3

        # Cupy arange can't handle DTYPE directly, so build the grid with linspace and cast after.
        t_end = t_max + steps_post * dt
        grid = cp.linspace(0, t_end, n_ts + steps_post, endpoint=False)
        ts = cp.asarray(grid, dtype=self.DTYPE)

        if recompute_ffts[0]:
            self._compute_k_ffts(ts)
        if recompute_ffts[1] and self.phi_q is not None:
            self._compute_q_ffts(ts)
        return ts, steps_pre, steps_post

    def _prepare_fourier_fit(self, eval_ts, s_in, s_out, n_t=None, recompute_ffts=(True, True)):
        """
        Move a fitting problem to the frequency domain.

        Pads the time grid with ``_expand_t_fourier`` and replaces the binned spike trains by their
        real FFT along the last axis (zero-padded to the length of the padded grid).
        NOTE: the transforms run with ``overwrite_x=True``, so the arrays passed in may be destroyed.

        :param eval_ts: uniformly spaced time points; the spacing is taken from the first two entries
        :param s_in: binned input spike trains, time on the last axis
        :param s_out: binned output spike trains, time on the last axis, or None
        :param n_t: stored as first entry of ``fourier_args``
        :param recompute_ffts: forwarded to ``_expand_t_fourier``
        :return: (padded time grid, (n_t, steps_pre, steps_post), rfft of s_in, rfft of s_out or None)
        """
        dt = eval_ts[1] - eval_ts[0]
        t_max = len(eval_ts) * dt
        padded_ts, steps_pre, steps_post = self._expand_t_fourier(t_max, dt, recompute_ffts=recompute_ffts)
        n_fft = len(padded_ts)

        s_in_hat = sp.fft.rfft(s_in, n=n_fft, axis=-1, overwrite_x=True)
        s_out_hat = None if s_out is None else sp.fft.rfft(s_out, n=n_fft, axis=-1, overwrite_x=True)
        return padded_ts, (n_t, steps_pre, steps_post), s_in_hat, s_out_hat

    def _compute_in_voltages_fourier(self, s_in: cp.ndarray, t_max=1, dt=0.01, recompute_k_ffts=True):
        # Helper method for computing PSPK contributions by Convolution-Theorem
        """
        Compute the input-driven voltage of every neuron via the frequency domain.

        :param s_in: (N, self.n_in, *) binned input spike trains, time on the last axis
        :param t_max: Last time point
        :param dt: step size
        :param recompute_k_ffts: Whether to recompute the frequency domain representation of the kernels.
        :return: (N, n_out, n_t)
        """
        ts, steps_pre, steps_post = self._expand_t_fourier(t_max, dt,
                                                           recompute_ffts=(recompute_k_ffts, False))
        # Number of unpadded bins. Rounded exactly like in _expand_t_fourier: truncating
        # t_max / dt can lose a bin to floating-point round-off and make the buffer too short.
        n_t = int(cp.round(t_max / dt))
        in_voltages = cp.zeros((s_in.shape[0], self.n_neurons, n_t), dtype=self.DTYPE)  # (N, n_out, n_t)

        # Zero-padded real FFT of the inputs; s_in itself is left untouched.
        s_fft = sp.fft.rfft(s_in, n=len(ts), axis=-1, overwrite_x=False)

        for neuron_idx in range(self.n_neurons):
            in_voltages[:, neuron_idx, :] = self._compute_in_voltages_fourier_single(s_fft, neuron_idx, ts,
                                                                                     trunc_pre=steps_pre,
                                                                                     trunc_post=steps_post)

        return in_voltages

    def _compute_in_voltages_cuda(self, s_in, ts, n_threads=32):
        """
        Input-voltage computation used by the non-Fourier code paths; not available in the base class.

        :raises NotImplementedError: always, unless overridden
        """
        raise NotImplementedError

    def _compute_out_voltages_cuda(self, s_out, ts, n_threads=32):
        """
        Output-spike voltage computation used by ``compute_full_trains``; not available in the base class.

        :raises NotImplementedError: always, unless overridden
        """
        raise NotImplementedError

    def compute_in_voltage(self, s_in: cp.ndarray, ts, return_cupy=True, method='f', fit_idcs=None, n_threads=32):
        """
        Compute the input-driven voltage (including the bias) at selected time points.

        :param s_in: (N, self.n_in, N_t) Binned input spikes
        :param ts: (N_t) Evaluation points, assumed uniformly spaced (spacing taken from the first two)
        :param return_cupy: return the backend array as is; otherwise convert it with ``to_cpu``
        :param method: 'f' uses the Fourier path, any other value the CUDA path
        :param fit_idcs: indices into ``ts`` of the time points to return, all if None
        :param n_threads: forwarded to the CUDA path
        :return: vis (N, self.n_neurons, N_t)
        """
        ts = ts.astype(self.DTYPE)
        dt = ts[1] - ts[0]
        t_max = len(ts) * dt
        if fit_idcs is None:
            fit_idcs = cp.arange(ts.shape[0])

        if method == 'f':
            # The Fourier path evaluates the full grid; select the requested bins afterwards.
            full_trace = self._compute_in_voltages_fourier(t_max=t_max, dt=dt, s_in=s_in)
            vis = full_trace[:, :, fit_idcs] + self.bias[None, :, None]
        else:
            # The CUDA path is evaluated only at the requested times; its result has the neuron
            # axis last and is transposed to (N, n_neurons, N_t).
            raw = self._compute_in_voltages_cuda(s_in=s_in, ts=ts[fit_idcs], n_threads=n_threads)
            vis = (raw + self.bias[None, None, :]).transpose(0, -1, 1)

        return vis if return_cupy else to_cpu(vis)

    def compute_full_trains(self, s_in: cp.ndarray, ts, return_vs:ReturnVs=ReturnVs.NO, num_spikes=25, pad=10,
                            progress=True, bin_spikes=False, method='f'):
        """
        Simulate the layer: step through ``ts`` and emit a spike whenever a neuron's voltage exceeds 1.

        The input-driven voltages are computed up front for all time points. The contribution of the
        neuron's own past spikes is recomputed at every step with ``_compute_out_voltages_cuda``.

        :param s_in: (N, self.n_in, N_t) Binned input for f, time for c
        :param ts: (N_t) Evaluation points.
        :param return_vs: Whether to return the voltage traces (see :return:).
        :param num_spikes: Initial number of spikes to record. Will be padded if necessary.
        :param pad: Number of spikes to pad to when num_spikes is reached.
        :param progress: Whether to show a progress bar.
        :param bin_spikes: Whether to bin the resulting spike trains.
        :param method: Whether to compute PSPK contributions using the frequency domain
                       representation ('f') or the CUDA kernel ('c').
        :return: ``ReturnVs.NO``: spikes (N, self.n_neurons, N_spikes), unused slots are inf.
                 ``ReturnVs.SUM``: (spikes, summed voltage (N, self.n_neurons, N_t)) as backend arrays.
                 ``ReturnVs.SEPARATE``: (spikes, vis, vos), both voltage traces (N, self.n_neurons, N_t)
                 moved to the host with ``to_cpu``.
        :raises NotImplementedError: for an unknown ``method``
        """
        n_samples = s_in.shape[0]
        n_t = ts.shape[0]
        ts = ts.astype(self.DTYPE)

        # Number of spikes emitted so far per (sample, neuron); doubles as the next free slot index.
        neuron_event_counts = cp.zeros((n_samples, self.n_neurons), dtype=cp.int32)
        spikes = cp.full((n_samples, self.n_neurons, num_spikes), dtype=cp.float32, fill_value=cp.inf)
        if return_vs == ReturnVs.SEPARATE:
            vos = cp.zeros((n_t, n_samples, self.n_neurons), dtype=self.DTYPE)

        dt = ts[1] - ts[0]
        t_max = len(ts) * dt
        if method == 'f':
            vis = (self._compute_in_voltages_fourier(t_max=t_max, dt=dt, s_in=s_in)
                   + self.bias[None, :, None]).transpose(-1, 0, 1)  # (t, n, i)
        elif method == 'c':
            vis = (self._compute_in_voltages_cuda(s_in, ts)
                   + self.bias[None, None, :]).transpose(1, 0, -1)  # (t, n, i)
        else:
            raise NotImplementedError()

        if self.pseudo_spike_cost is not None:
            vis[0, :, :] += self.pseudo_spike_cost

        for t_i in tqdm(range(len(ts)), desc="Computing spikes", disable=not progress, ):
            vo = self._compute_out_voltages_cuda(ts=ts[t_i:t_i + 1], s_out=spikes)[:, 0, :]
            if return_vs == ReturnVs.NO:
                spikes_t = (vis[t_i] + vo) > 1
            elif return_vs == ReturnVs.SUM:
                # Accumulate in place so that vis ends up holding the total voltage.
                vis[t_i] += vo
                spikes_t = vis[t_i] > 1
            else:
                vos[t_i] = vo
                spikes_t = (vis[t_i] + vo) > 1
            if spikes_t.any():
                # Write the spike time into the next free slot of every neuron that fired.
                spikes[spikes_t, neuron_event_counts[spikes_t]] = ts[t_i]
                neuron_event_counts += spikes_t
                if neuron_event_counts.max() >= num_spikes:
                    # Grow the buffer before any neuron could run out of slots.
                    num_spikes += pad
                    spikes = cp.pad(spikes, ((0, 0), (0, 0), (0, pad)), constant_values=cp.inf)

        # Prune to required length
        spikes = spikes[:, :, :int(neuron_event_counts.max())].copy()

        if bin_spikes:
            spikes = bin_spike_times(spikes, tmin=ts[0], tmax=ts[-1], n_bins=len(ts), reshape=False)
        if return_vs == ReturnVs.SEPARATE:
            vis = vis.transpose((1, 2, 0))
            vos = vos.transpose((1, 2, 0))
            # Use the backend helper instead of the CuPy-only ndarray.get().
            return spikes, to_cpu(vis), to_cpu(vos)
        elif return_vs == ReturnVs.SUM:
            return spikes, vis.transpose((1, 2, 0))
        else:
            return spikes

    ################################
    # Abstract Methods #
    ################################

    # Initialisation / Sampling

    # Fitting
    @abstractmethod
    def _set_weights(self, w_neuron, neuron_idx, matrix=None, vs=None, **kwargs):
        """
        Store the fitted weights of a single neuron.

        Called by the fit routines once per neuron. ``fit_svd`` and ``fit_l2`` pass the solution
        vector without its leading bias entry, ``fit_glm`` passes the rescaled GLM coefficients.

        :param w_neuron: weight vector of the neuron
        :param neuron_idx: int, index of the neuron
        :param matrix: design matrix the weights were fitted on (without bias column)
        :param vs: target voltages of the fit, if any
        :param kwargs: further options from the fit routine (e.g. ``r_cond``)
        """

    @abstractmethod
    def _assemble_system_neuron(self, s_in, s_out, eval_ts, neuron_idx, fit_idcs: cp.ndarray,
                                fourier_args=None):
        """
        Assemble the columns of the fitting problem that belong to one specific neuron.

        :param s_in: (N, self.n_in, n_ts) Input spikes in frequency representation
        :param s_out: (N, self.n_neuron, n_ts) Output spikes in frequency representation (may be None)
        :param eval_ts: (n_ts) time points
        :param neuron_idx: int
        :param fit_idcs: indices of the time bins used for the fit
        :param fourier_args: Additional arguments for handling a-periodic voltage traces
        :return: (N_t * N, *) Neuron-specific contributions for the lstsq-problem. Shape depends on neuron type
        """

    @abstractmethod
    def _assemble_system_shared(self, s_in, s_out, eval_ts, fit_idcs: cp.ndarray, fourier_args=None):
        """
        Assemble the columns of the fitting problem that are shared by all neurons.

        :param s_in: (N, self.n_in, n_ts) Input spikes in frequency representation
        :param s_out: (N, self.n_neuron, n_ts) Output spikes in frequency representation (may be None)
        :param eval_ts: (n_ts) time points
        :param fit_idcs: indices of the time bins used for the fit
        :param fourier_args: Additional arguments for handling a-periodic voltage traces
        :return: (N_t * N, *) Shared contributions for the lstsq-problem. Shape depends on neuron type
        """

    # Evaluation
    @abstractmethod
    def _compute_in_voltages_fourier_single(self, s_fft: cp.ndarray, neuron_idx: int, ts: cp.ndarray,
                                            trunc_pre: int = 0, trunc_post: int = 1):
        """
        Compute the input-driven voltage of one neuron from the spectrum of the input spikes.

        ``_compute_in_voltages_fourier`` assigns the result to ``in_voltages[:, neuron_idx, :]``.

        :param s_fft: real FFT of the binned input spikes along the (padded) time axis
        :param neuron_idx: index of the neuron
        :param ts: padded time grid
        :param trunc_pre: receives ``steps_pre`` of the padded grid
                          (presumably the number of leading steps to drop - confirm in subclasses)
        :param trunc_post: receives ``steps_post`` of the padded grid
                           (presumably the number of trailing steps to drop - confirm in subclasses)
        """

    # Utility
    def set_k_params(self, k_delays, k_widths):
        """
        Set delays and widths of the k-kernels.

        They are stored as ``1 / width`` and ``-delay / width`` in ``self.DTYPE``; afterwards the
        k receptive fields are recomputed.
        """
        delays = k_delays.astype(self.DTYPE)
        inv_widths = 1 / k_widths.astype(self.DTYPE)

        self.k_inv_widths = inv_widths
        self.k_neg_delays_inv_widths = - delays * inv_widths

        self.compute_rec_fields(k=True, q=False)

    def set_q_params(self, q_delays, q_widths):
        """
        Set delays and widths of the q-kernels.

        They are stored as ``1 / width`` and ``-delay / width`` in ``self.DTYPE``; afterwards the
        q receptive fields are recomputed.
        """
        delays = q_delays.astype(self.DTYPE)
        inv_widths = 1 / q_widths.astype(self.DTYPE)

        self.q_inv_widths = inv_widths
        self.q_neg_delays_inv_widths = - delays * inv_widths

        self.compute_rec_fields(k=False, q=True)

    def get_k_delays(self):
        """Recover the k-kernel delays from the stored ``-delay / width`` and ``1 / width``."""
        delays_over_widths = -self.k_neg_delays_inv_widths
        return delays_over_widths / self.k_inv_widths

    def get_k_widths(self):
        """Recover the k-kernel widths from the stored inverse widths."""
        inv_widths = self.k_inv_widths
        return 1 / inv_widths

    def get_q_delays(self):
        """Recover the q-kernel delays from the stored ``-delay / width`` and ``1 / width``."""
        delays_over_widths = -self.q_neg_delays_inv_widths
        return delays_over_widths / self.q_inv_widths
    def get_q_widths(self):
        """Recover the q-kernel widths from the stored inverse widths."""
        inv_widths = self.q_inv_widths
        return 1 / inv_widths

    def get_n_params(self):
        """
        Collect the parameter information of the layer.

        :return: dict with the results of ``get_params_spatial()`` under "spatial" and of
                 ``get_params_temporal()`` under "temporal"
        """
        return {
            "spatial": self.get_params_spatial(),
            "temporal": self.get_params_temporal(),
        }

    @abstractmethod
    def compute_rec_fields(self, k=True, q=True):
        """
        Recompute the receptive fields of the kernels.

        Called by ``set_k_params`` (k only) and ``set_q_params`` (q only).
        :param k: update the receptive fields of the k-kernels
        :param q: update the receptive fields of the q-kernels
        """

    @abstractmethod
    def plot_k(self):
        """Plotting hook for the k-kernels; what is drawn is up to the subclass."""

    @abstractmethod
    def plot_q(self):
        """Plotting hook for the q-kernels; what is drawn is up to the subclass."""

    @abstractmethod
    def setup_l2_system(self, x_full_neuron, param, n_t_fit, n_samples, only_in=False, reg=1e-6):
        """
        Build the linear system solved by ``fit_l2`` for one neuron.

        :param x_full_neuron: design matrix of the neuron, bias column first
        :param param: target voltages of the neuron, passed by ``fit_l2`` as (n_t_fit, n_samples)
        :param n_t_fit: number of fitted time bins
        :param n_samples: number of samples
        :param only_in: True when no output spike trains were supplied to ``fit_l2``
        :param reg: regularisation value handed through from ``fit_l2``
        :return: pair (A, b) that ``fit_l2`` passes to ``cp.linalg.solve``
        """

    @abstractmethod
    def get_params_spatial(self):
        """Return the value reported under the "spatial" key of ``get_n_params``."""

    @abstractmethod
    def get_params_temporal(self):
        """Return the value reported under the "temporal" key of ``get_n_params``."""


