#!/usr/bin/env python3
import functools
from enum import Enum

import itertools
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter

from egru.models import EGRUThresholdInit
from torchdiffeq import odeint, odeint_adjoint
from torchdiffeq import odeint_event

# Module-level side effects: both calls change process-wide torch settings as
# soon as this file is imported.
# Turn on autograd anomaly detection (extra checks during backward).
torch.autograd.set_detect_anomaly(True)
# Make float64 the default floating-point dtype for newly created tensors.
torch.set_default_dtype(torch.float64)


class EventSource(Enum):
    """Origin of an event found while integrating the network."""
    # An external input sample is due.
    input = 0
    # A target (read-out) time has been reached.
    output = 1
    # A unit of the network crossed its threshold.
    network = 2
    # No source; not referenced elsewhere in this file.
    none = 3


class EGRUC(nn.Module):
    """Continuous-time, event-driven EGRU network.

    The state of an :class:`EGRUCStep` is integrated with ``torchdiffeq`` from
    one event to the next.  An event is either a unit reaching its threshold,
    an input time, or a target time.  Samples of a batch are simulated one
    after the other.
    """

    def __init__(self, input_size: int, output_size: int, n_units: int, frac_out_units: float,
                 thr_init: EGRUThresholdInit, bias_std: float, batch_size: int, adjoint=True):
        """
        :param input_size: size of one input sample
        :param output_size: size of the read-out produced by ``hidden2out``
        :param n_units: number of recurrent units
        :param frac_out_units: fraction of the units fed into ``hidden2out``
        :param thr_init: threshold initialisation forwarded to ``EGRUCStep``
        :param bias_std: bias magnitude forwarded to ``EGRUCStep``
        :param batch_size: number of samples handled by ``forward``/``dynamics``
        :param adjoint: use ``odeint_adjoint`` instead of ``odeint`` in ``forward``
        """
        super().__init__()
        self.model = EGRUCStep(input_size, output_size, n_units, thr_init, bias_std, batch_size)
        self.batch_size = batch_size
        self.n_units = n_units
        self.hidden2out = nn.Linear(int(frac_out_units * n_units), output_size)
        self.odeint = odeint_adjoint if adjoint else odeint
        # Default device so that init_hidden()/forward() also work when `.to()`
        # has never been called on this module.
        self.device = torch.device("cpu")
        # self.init_hidden()

    def to(self, device, *args, **kwargs):
        """Move the module and remember the target device for later allocations."""
        # EGRUCStep.to() also moves tensors that are not registered parameters.
        self.model = self.model.to(device)

        # `to()` may also be called with a dtype or a tensor; only remember
        # arguments that actually describe a device.
        if isinstance(device, (str, int, torch.device)):
            self.device = device
        return super().to(device, *args, **kwargs)

    def init_hidden(self):
        """Create the all-zero start time and start state on ``self.device``."""
        device = self.device
        state_shape = (self.batch_size, self.n_units)
        self.t0 = torch.zeros(self.batch_size, device=device)
        self.init_c = torch.zeros(state_shape, device=device)
        self.init_i_u = torch.zeros(state_shape, device=device)
        self.init_i_r = torch.zeros(state_shape, device=device)
        self.init_i_c = torch.zeros(state_shape, device=device)

    def get_initial_state(self):
        """Return ``(t0, (c, i_u, i_r, i_c))`` as created by :meth:`init_hidden`."""
        state = (self.init_c, self.init_i_u, self.init_i_r, self.init_i_c)
        return self.t0, state

    def _event_fn_ex(self, t, state, input_times_bi, target_times_bi):
        """
        Event occurs when one of the returned values is 0
        :param t: current time
        :param state: tuple ``(c, i_u, i_r, i_c)`` of one sample
        :param input_times_bi: input times of that sample
        :param target_times_bi: target times of that sample
        :return: 1-D tensor: threshold distances of all units, followed by the
                 time differences to every input time and to every target time
        """

        c_t, _, _, _ = state
        # Negative while c_t < thr, positive once c_t > thr, zero at the crossing
        nspikes = c_t - self.model.thr
        # We do return target times as events, but don't take a snapshot of c at that time, just of h
        ret = torch.concat((nspikes, t - input_times_bi, t - target_times_bi), dim=0)
        return ret

    def _get_event_source(self, n, n_units, input_times_bi, target_times_bi):
        """Map index ``n`` of the event-function output to ``(source, local index)``.

        The layout matches :meth:`_event_fn_ex`: units first, then input times,
        then target times.

        :raises RuntimeError: if ``n`` lies outside all three ranges
        """
        # Currently assuming simulation is one batch at a time
        n_inputs = len(input_times_bi)
        n_targets = len(target_times_bi)
        if 0 <= n < n_units:
            # This is a neuron
            return EventSource.network, n
        if n_units <= n < n_units + n_inputs:
            # This is input
            return EventSource.input, n - n_units
        if n_units + n_inputs <= n < n_units + n_inputs + n_targets:
            # This is output
            return EventSource.output, n - (n_units + n_inputs)
        raise RuntimeError(f"Impossible: n={n}")

    def _spike_vector(self, c_t_be, active_unit):
        """Return the emitted event vector: zero everywhere except ``active_unit``."""
        h_t = torch.zeros_like(c_t_be)
        h_t[active_unit] = c_t_be[active_unit]
        thr = self.model.thr
        # Sanity check: the unit reported by the event function sits on its threshold.
        if len(thr) > 1:
            assert torch.isclose(c_t_be[active_unit], thr[active_unit])
        else:
            assert torch.isclose(c_t_be[active_unit], thr)
        return h_t

    def _state_after_event(self, t, state_be, event_source, unit_num, inputs_bi):
        """Apply one event to the state just before it.

        :param t: event time
        :param state_be: state before the event
        :param event_source: :class:`EventSource` of the event
        :param unit_num: index of the input sample (input event) or of the
                         active unit (network event)
        :param inputs_bi: inputs of the current sample
        :return: ``(state after the event, h_t)``; ``h_t`` is ``None`` for
                 target events, which leave the state untouched
        :raises RuntimeError: for an unknown ``event_source``
        """
        c_t_be = state_be[0]
        if event_source == EventSource.input:
            h_t = torch.zeros_like(c_t_be)
            return self.model.state_update(t, state_be, inputs_bi[unit_num], h_t), h_t
        if event_source == EventSource.network:
            h_t = self._spike_vector(c_t_be, unit_num)
            no_input = torch.zeros_like(inputs_bi[0])
            return self.model.state_update(t, state_be, no_input, h_t), h_t
        if event_source == EventSource.output:
            # Skip the update since all inputs are zero
            return state_be, None
        raise RuntimeError(f"Unknown event source {event_source}")

    # @profile
    def forward(self, initial_state, input_times, inputs, target_times):
        """
        Simulate every sample from event to event until its last input/target time.

        :param initial_state: ``(t0s, (c, i_u, i_r, i_c))`` as returned by
                              :meth:`get_initial_state`
        :param input_times: per-sample input times (indexed ``[bi]``)
        :param inputs: per-sample inputs, ``inputs[bi][k]`` belongs to
                       ``input_times[bi][k]``
        :param target_times: per-sample target times (indexed ``[bi]``); every
                             sample needs the same number of them because the
                             results are stacked
        :return: tuple ``(net_taus, cs_ae, ius_ae, irs_ae, ics_ae, out_taus,
                 couts, hs_t, all_taus)`` with time as first dimension.
                 ``net_taus``/``all_taus`` are padded with ``inf``; the state
                 snapshots after network events and ``hs_t`` are padded with
                 zeros up to the largest number of events of any sample.
        """

        batch_size = self.batch_size
        n_units = self.n_units
        device = self.device

        t0s, (c_tm_ae, iu_tm_ae, ir_tm_ae, ic_tm_ae) = initial_state

        # One entry per sample; "be"/"ae" = state before/after a network event
        cs_be, ius_t_be, irs_t_be, ics_t_be = [], [], [], []
        cs_ae, ius_t_ae, irs_t_ae, ics_t_ae = [], [], [], []
        all_taus_t = []
        net_taus = []
        out_taus = []
        couts = []
        hs_t = []

        for bi in range(batch_size):
            input_times_bi, target_times_bi = input_times[bi], target_times[bi]
            last_time = max(torch.max(input_times_bi), torch.max(target_times_bi))

            cs_be_tl, ius_be_tl, irs_be_tl, ics_be_tl = [], [], [], []
            cs_ae_tl, ius_ae_tl, irs_ae_tl, ics_ae_tl = [], [], [], []
            all_taus_tl, net_taus_tl, out_taus_tl = [], [], []
            couts_tl = []
            hs_tl = []

            next_spike_time = 0.
            t0 = t0s[bi]
            state_ae = (c_tm_ae[bi], iu_tm_ae[bi], ir_tm_ae[bi], ic_tm_ae[bi])
            event_fn = functools.partial(self._event_fn_ex,
                                         input_times_bi=input_times_bi, target_times_bi=target_times_bi)
            while next_spike_time <= last_time:
                event_t, solution = odeint_event(self.model, state_ae, t0, event_fn=event_fn,
                                                 reverse_time=False, atol=1e-8, rtol=1e-8,
                                                 odeint_interface=self.odeint,
                                                 )
                t0 = event_t + 1e-8  # Adding this due to numerical problems where same event repeated
                next_spike_time = t0

                all_taus_tl.append(event_t)
                state_be = tuple(sol[-1] for sol in solution)

                # Re-evaluate the event function to find out which entry became zero
                event_fn_out = event_fn(event_t, state_be)
                n_event = torch.nonzero(torch.isclose(event_fn_out, torch.zeros_like(event_fn_out)))[0][0]
                event_source, xx = self._get_event_source(n_event, n_units, input_times_bi, target_times_bi)
                state_ae, h_t = self._state_after_event(event_t, state_be, event_source, xx, inputs[bi])

                if event_source == EventSource.network:
                    net_taus_tl.append(event_t)
                    hs_tl.append(h_t)
                    for s, sl in zip(state_be, [cs_be_tl, ius_be_tl, irs_be_tl, ics_be_tl]):
                        sl.append(s)
                    for s, sl in zip(state_ae, [cs_ae_tl, ius_ae_tl, irs_ae_tl, ics_ae_tl]):
                        sl.append(s)
                elif event_source == EventSource.output:
                    out_taus_tl.append(event_t)
                    couts_tl.append(state_be[0])
            ## End `while next_spike_time <= last_time`

            # Samples without network events get correctly shaped empty tensors on
            # the module's device, so the padding below concatenates cleanly.
            for x_tl, xs in zip([cs_be_tl, ius_be_tl, irs_be_tl, ics_be_tl], [cs_be, ius_t_be, irs_t_be, ics_t_be]):
                xs.append(torch.stack(x_tl) if x_tl else torch.empty((0, n_units), device=device))
            for x_tl, xs in zip([cs_ae_tl, ius_ae_tl, irs_ae_tl, ics_ae_tl], [cs_ae, ius_t_ae, irs_t_ae, ics_t_ae]):
                xs.append(torch.stack(x_tl) if x_tl else torch.empty((0, n_units), device=device))
            all_taus_t.append(torch.stack(all_taus_tl))
            # Guaranteed non-empty because target/output times are now included
            out_taus.append(torch.stack(out_taus_tl))
            net_taus.append(torch.stack(net_taus_tl) if net_taus_tl else torch.empty(0, device=device))
            hs_t.append(torch.stack(hs_tl) if hs_tl else torch.empty((0, n_units), device=device))
            couts.append(torch.stack(couts_tl))
        # Done with batch
        n_spikes = sum(len(tau_bi) for tau_bi in net_taus)
        print(f"Average no. of spikes: {n_spikes / batch_size}")

        # Everything is padded to the largest number of events (of any kind) in a sample
        longest = max(len(tau_bi) for tau_bi in all_taus_t)

        def pad_times(taus):
            filler = torch.full((longest - len(taus),), float("inf"), device=device)
            return torch.cat((taus, filler))

        def pad_rows(rows):
            filler = torch.zeros((longest - len(rows), n_units), device=device)
            return torch.cat((rows, filler))

        for bi in range(batch_size):
            all_taus_t[bi] = pad_times(all_taus_t[bi])
            net_taus[bi] = pad_times(net_taus[bi])
            hs_t[bi] = pad_rows(hs_t[bi])
            for xs in [cs_ae, ius_t_ae, irs_t_ae, ics_t_ae]:
                xs[bi] = pad_rows(xs[bi])
            for xs in [cs_be, ius_t_be, irs_t_be, ics_t_be]:
                xs[bi] = pad_rows(xs[bi])

        ## TIME is always first dimension
        all_taus_tl_r = torch.transpose(torch.stack(all_taus_t), 0, 1)
        net_taus_tl_r = torch.transpose(torch.stack(net_taus), 0, 1)
        couts = torch.transpose(torch.stack(couts), 0, 1)
        out_taus = torch.transpose(torch.stack(out_taus), 0, 1)
        hs_t = torch.transpose(torch.stack(hs_t), 0, 1)

        cs_ae_tl_r = torch.transpose(torch.stack(cs_ae), 0, 1)
        ius_ae_tl_r = torch.transpose(torch.stack(ius_t_ae), 0, 1)
        irs_ae_tl_r = torch.transpose(torch.stack(irs_t_ae), 0, 1)
        ics_ae_tl_r = torch.transpose(torch.stack(ics_t_ae), 0, 1)

        ## It's ok to return ae, even though we use cs as output, because we look at it during target events
        return net_taus_tl_r, cs_ae_tl_r, ius_ae_tl_r, irs_ae_tl_r, ics_ae_tl_r, out_taus, couts, hs_t, all_taus_tl_r

    def dynamics(self, initial_state, all_event_times, input_times, inputs, target_times):
        """
        Re-integrate the network between already known event times and return
        the densely sampled trajectories.

        :param initial_state: ``(t0s, (c, i_u, i_r, i_c))``; only the state part is used
        :param all_event_times: event times with time as first dimension
                                (indexed ``[:, bi]``); non-finite entries are ignored
        :param input_times: per-sample input times
        :param inputs: per-sample inputs
        :param target_times: per-sample target times
        :return: ``(solutions, times)``; for every sample a list with one
                 ``odeint`` solution and one time grid per inter-event interval
        """

        batch_size = self.batch_size
        n_units = self.n_units

        _, (c_tm_ae, iu_tm_ae, ir_tm_ae, ic_tm_ae) = initial_state

        solutions = []
        times = []

        for bi in range(batch_size):
            state_ae = (c_tm_ae[bi], iu_tm_ae[bi], ir_tm_ae[bi], ic_tm_ae[bi])
            input_times_bi, target_times_bi = input_times[bi], target_times[bi]

            solutions_tl = []
            times_tl = []

            event_fn = functools.partial(self._event_fn_ex,
                                         input_times_bi=input_times_bi, target_times_bi=target_times_bi)
            all_event_times_bi = all_event_times[:, bi]
            # Prepend the start time 0 and drop the `inf` padding
            start = torch.zeros(1, device=all_event_times_bi.device)
            all_event_times_bi = torch.cat([start, all_event_times_bi])
            all_event_times_bi = torch.masked_select(all_event_times_bi, torch.isfinite(all_event_times_bi))
            for tidx in range(1, len(all_event_times_bi)):
                prev_tt = all_event_times_bi[tidx - 1]
                tt = all_event_times_bi[tidx]
                if tidx > 1:
                    # Same offset as in forward(): restart just after the previous event
                    prev_tt = prev_tt + 1e-8
                n_steps = int((float(tt) - float(prev_tt)) * 50)
                tts = torch.linspace(float(prev_tt), float(tt), n_steps, device=tt.device)[1:-1]
                # Use the exact interval end points instead of the linspace ones
                tts = torch.cat([prev_tt.reshape(-1), tts, tt.reshape(-1)])

                solution = odeint(self.model, state_ae, tts, atol=1e-8, rtol=1e-8)

                assert len(solution[0]) == len(tts), f"{len(solution[0])}; {len(tts)}"
                times_tl.append(tts)
                solutions_tl.append(solution)

                state_be = tuple(sol[-1] for sol in solution)

                event_fn_out = event_fn(tt, state_be)
                n_event = torch.nonzero(
                    torch.isclose(event_fn_out, torch.zeros_like(event_fn_out), atol=1e-7))[0][0]
                event_source, xx = self._get_event_source(n_event, n_units, input_times_bi, target_times_bi)
                state_ae, _ = self._state_after_event(tt, state_be, event_source, xx, inputs[bi])
                # FIXME: Add asserts to compare states
            ## End event loop
            solutions.append(solutions_tl)
            times.append(times_tl)

        # Done with batch
        return solutions, times


class EGRUCStep(nn.Module):
    """State dynamics and event updates of the continuous-time EGRU.

    The state is the tuple ``(c, i_u, i_r, i_c)``.  :meth:`forward` gives its
    time derivative between events, :meth:`state_update` applies the jump
    caused by an event.
    """

    # backward() divides by a membrane time constant.  forward() integrates
    # dc/dt = u * (z - c) without an explicit constant, which corresponds to a
    # unit time constant.  NOTE(review): confirm 1 is the intended value.
    tau_mem = 1.

    def __init__(self, input_size: int, output_size: int, n_units: int,
                 thr_init: EGRUThresholdInit, bias_std: float, batch_size: int):
        """
        Assume time dimension is first for `inputs`

        :raises ValueError: if ``thr_init`` is not one of the handled
                            ``EGRUThresholdInit`` members
        """
        super().__init__()
        print("Using EGRU (continuous time)")

        self.input_size = input_size
        self.output_size = output_size
        self.n_units = n_units
        self.n_layers = 1
        ## \/ Used outside
        self.hidden_size = n_units
        self.input_dt_ms = 1.
        self.batch_size = batch_size
        self.bias_std = bias_std
        self._inputs = None

        # Time constant of the three currents; a plain tensor attribute that is
        # moved explicitly in `to()`.
        self.tau_syn = Tensor([0.15])
        # self.tau_syn = Tensor([1.])

        # NOTE: Only works if these parameters are trainable.
        ## Scalar thr 0 init and f_0 vec init seems to work best on copy task. But random thr init has very similar performance.
        ## Random thr init leads to higher sparsity!
        # NOTE: For store recall, having threshold seems detrimental.
        if thr_init == EGRUThresholdInit.zero_scalar:
            self.thr = Parameter(Tensor([0.]))
        elif thr_init == EGRUThresholdInit.zero_vector:
            self.thr = Parameter(torch.zeros(self.n_units))
        elif thr_init == EGRUThresholdInit.rand_vector:
            self.thr = Parameter(torch.rand(self.n_units))
        elif thr_init == EGRUThresholdInit.const_scalar:
            self.thr = Parameter(Tensor([0.1]))
        else:
            # Without this the module would be built without `thr` and only
            # fail later, when the threshold is first read.
            raise ValueError(f"Unsupported threshold initialisation: {thr_init!r}")

        # update gate
        self.U_u = Parameter(Tensor(input_size, n_units))
        self.W_u = Parameter(Tensor(n_units, n_units))
        self.b_u = Parameter(Tensor(n_units))

        # reset gate
        self.U_r = Parameter(Tensor(input_size, n_units))
        self.W_r = Parameter(Tensor(n_units, n_units))
        self.b_r = Parameter(Tensor(n_units))

        # cell
        self.U_c = Parameter(Tensor(input_size, n_units))
        self.W_c = Parameter(Tensor(n_units, n_units))
        self.b_c = Parameter(Tensor(n_units))

        # The tensors above are uninitialised memory until this call
        self._init_weights()

    def to(self, device, *args, **kwargs):
        """Move the module; ``tau_syn`` is not a parameter and is moved by hand."""
        self.tau_syn = self.tau_syn.to(device)
        return super().to(device, *args, **kwargs)

    def _init_weights(self):
        """Xavier-uniform for matrices, constant ``-bias_std`` for vectors; ``thr`` is skipped."""
        for (n, p) in self.named_parameters():
            if n not in ['thr']:
                if p.data.ndimension() >= 2:
                    nn.init.xavier_uniform_(p)
                else:
                    # nn.init.normal_(p, mean=0., std=self.bias_std)
                    # Need this initial bias so that there is enough spiking in the network
                    nn.init.constant_(p, -self.bias_std)

    def forward(self, t, state):
        """Time derivative of ``(c, i_u, i_r, i_c)`` between events.

        :param t: current time (unused)
        :param state: tuple ``(c, i_u, i_r, i_c)``
        :return: tuple of derivatives in the same order
        """
        c_t, i_u, i_r, i_c = state
        # Note: Assuming grid-based solver. Otherwise this is wrong.
        i_u_in = self.b_u
        i_r_in = self.b_r
        i_c_in = self.b_c

        u_t = torch.sigmoid(i_u)
        z_t = torch.tanh(i_c)

        tau_syn = self.tau_syn

        return u_t * (z_t - c_t), - (i_u + i_u_in) / tau_syn, - (i_r + i_r_in) / tau_syn, - (i_c + i_c_in) / tau_syn

    def state_update(self, t, state, selected_inputs, hs_t):
        """ Updates state based on an event.

        :param t: event time (unused)
        :param state: state ``(c, i_u, i_r, i_c)`` just before the event
        :param selected_inputs: external input delivered with the event
        :param hs_t: event vector emitted by the network at this event
        :return: state ``(c, i_u, i_r, i_c)`` just after the event
        """

        c_tm, i_u, i_r, i_c = state

        # RESET for c_t (this function called only after event has been emitted)
        c_t = c_tm - hs_t

        r_t = torch.sigmoid(i_r)
        xs_t = selected_inputs

        i_u_in = torch.matmul(hs_t, self.W_u) + torch.matmul(xs_t, self.U_u)
        i_r_in = torch.matmul(hs_t, self.W_r) + torch.matmul(xs_t, self.U_r)
        i_c_in = torch.matmul((r_t * hs_t), self.W_c) + torch.matmul(xs_t, self.U_c)

        # Current updates
        i_u = i_u + i_u_in
        i_r = i_r + i_r_in
        i_c = i_c + i_c_in
        return c_t, i_u, i_r, i_c

    def backward(self, adjoint_state, network_state):
        """Expressions for the adjoint variables of ``(c, i_u, i_r, i_c)``.

        :param adjoint_state: tuple ``(lambda_c, lambda_iu, lambda_ir, lambda_iz)``
        :param network_state: tuple ``(c, i_u, i_r, i_c)``
        :return: tuple with one tensor per adjoint variable
        """
        lambda_c, lambda_iu, lambda_ir, lambda_iz = adjoint_state
        c_tm, i_u, i_r, i_c = network_state

        u_t = torch.sigmoid(i_u)
        z_t = torch.tanh(i_c)

        return u_t * lambda_c / self.tau_mem, (lambda_iu + lambda_c * u_t * (c_tm - z_t) * (1 - u_t)) / self.tau_syn, \
               lambda_ir / self.tau_syn, (lambda_iz + lambda_c * (1 - z_t ** 2) * u_t) / self.tau_syn

    def backward_state_update(self):
        """Placeholder; not implemented."""
        pass


def convolve_outputs(target_times, hs_t, all_taus, net_taus):
    """Read out the exponentially filtered network events at the target times.

    The filtered trace jumps by ``hs_t`` at every network event and decays with
    ``exp(-dt / tau_out)`` (``tau_out = 1``) in between.  A network event that
    coincides with a target time is included in that target's value.

    :param target_times: per-sample target times (indexed ``[bi]``); every
                         sample needs the same number of targets
    :param hs_t: event vectors, time first (indexed ``[:, bi]``)
    :param all_taus: times of all events, time first, padded with ``inf``
    :param net_taus: times of the network events, time first, padded with ``inf``
    :return: stacked trace values, one row per sample and target time
    """
    batch_size = len(target_times)
    tau_out = 1.
    selected_values_list = []
    for bi in range(batch_size):
        target_times_bi = target_times[bi]
        n_targets = len(target_times_bi)
        # Drop the padding: only rows with a finite event time are real
        spike_mask = torch.isfinite(net_taus[:, bi])
        valid_hs = torch.masked_select(hs_t[:, bi], spike_mask[..., None]) \
            .reshape(-1, hs_t.shape[-1])
        valid_net_taus = torch.masked_select(net_taus[:, bi], spike_mask)
        valid_all_taus = torch.masked_select(all_taus[:, bi], torch.isfinite(all_taus[:, bi]))
        assert len(valid_hs) == len(valid_net_taus)

        selected_values = []
        nidx = 0
        target_idx = 0
        conv_val = torch.zeros_like(hs_t[0, bi])
        prev_tau = None  # time at which `conv_val` is valid; None before the first event
        for tau in valid_all_taus:
            # Targets lying strictly before this event are read from the trace as
            # it was after the previous event, so a later spike cannot leak in.
            while (target_idx < n_targets
                   and target_times_bi[target_idx] < tau
                   and not torch.isclose(tau, target_times_bi[target_idx])):
                if prev_tau is None:
                    selected_values.append(conv_val)
                else:
                    selected_values.append(
                        conv_val * torch.exp(-(target_times_bi[target_idx] - prev_tau) / tau_out))
                target_idx += 1
            # Advance the trace to the current event
            if prev_tau is not None:
                conv_val = conv_val * torch.exp(-(tau - prev_tau) / tau_out)
            if nidx < len(valid_net_taus) and torch.isclose(tau, valid_net_taus[nidx]):
                conv_val = valid_hs[nidx] + conv_val
                nidx += 1
            # Also checked for the very first event, which may itself be a target
            if target_idx < n_targets and torch.isclose(tau, target_times_bi[target_idx]):
                selected_values.append(conv_val)
                target_idx += 1
            prev_tau = tau
        # Targets after the last event: keep decaying from the last event
        while target_idx < n_targets:
            if prev_tau is None:
                selected_values.append(conv_val)
            else:
                selected_values.append(
                    conv_val * torch.exp(-(target_times_bi[target_idx] - prev_tau) / tau_out))
            target_idx += 1
        selected_values_list.append(torch.stack(selected_values))
    selected_conv_values_arr = torch.stack(selected_values_list)  # Can stack because number of target times same
    return selected_conv_values_arr


def convolve_dynamics(hs_t, all_taus, net_taus):
    """Densely sample the exponentially filtered network events over time.

    Between two consecutive events the trace decays with ``exp(-dt / tau_out)``
    (``tau_out = 1``) and is sampled on a grid of about 100 points per time
    unit; at a network event the matching ``hs_t`` row is added.  The value at
    the first event itself is not part of the output.

    :param hs_t: event vectors, time first (indexed ``[:, bi]``)
    :param all_taus: times of all events, time first, padded with ``inf``
    :param net_taus: times of the network events, time first, padded with ``inf``
    :return: ``(conv_values_list, times_list)`` with one tensor per sample;
             both are empty tensors for samples with fewer than two events
    """
    batch_size = net_taus.shape[1]
    tau_out = 1.
    conv_values_list = []
    times_list = []
    for bi in range(batch_size):
        # Drop the padding: only rows with a finite event time are real
        spike_mask = torch.isfinite(net_taus[:, bi])
        valid_hs = torch.masked_select(hs_t[:, bi], spike_mask[..., None]) \
            .reshape(-1, hs_t.shape[-1])
        valid_net_taus = torch.masked_select(net_taus[:, bi], spike_mask)
        valid_all_taus = torch.masked_select(all_taus[:, bi], torch.isfinite(all_taus[:, bi]))
        if len(valid_all_taus) == 0:
            conv_values_list.append(torch.tensor([]))
            times_list.append(torch.tensor([]))
            continue
        assert len(valid_hs) == len(valid_net_taus)

        conv_values = []
        times = []
        nidx = 0
        # Start value: the first event vector if the first event is a network event
        if len(valid_net_taus) > 0 and torch.isclose(valid_all_taus[0], valid_net_taus[0]):
            conv_val = valid_hs[0]
            nidx = 1
        else:
            conv_val = torch.zeros_like(hs_t[0, bi])
        for tidx in range(1, len(valid_all_taus)):
            prev_tt = valid_all_taus[tidx - 1]
            tt = valid_all_taus[tidx]
            # Grid inside the interval; its end point is handled with the exact event time below
            tts = torch.linspace(float(prev_tt), float(tt), int((float(tt) - float(prev_tt)) * 100))[:-1]
            for tidx_ in range(1, len(tts)):
                conv_val = conv_val * torch.exp(-(tts[tidx_] - tts[tidx_ - 1]) / tau_out)
                conv_values.append(conv_val)
                times.append(tts[tidx_])
            # Decay over the remaining stretch up to the event itself
            last_t = tts[-1] if len(tts) > 0 else prev_tt
            conv_val = conv_val * torch.exp(-(tt - last_t) / tau_out)
            if nidx < len(valid_net_taus) and torch.isclose(valid_net_taus[nidx], tt):
                conv_val = valid_hs[nidx] + conv_val
                nidx += 1
            conv_values.append(conv_val)
            times.append(tt)
        if conv_values:
            conv_values_list.append(torch.stack(conv_values))
            times_list.append(torch.stack(times))
        else:
            # A single event leaves nothing to sample; torch.stack([]) would raise
            conv_values_list.append(torch.tensor([]))
            times_list.append(torch.tensor([]))
    return conv_values_list, times_list
