import numpy as np
from te_datasim.baseclass import Simulator

def downsample(time_series, factor):
    """
    Keep every `factor`-th sample of a time series, starting with the first one.

    The selection is applied along the first axis, so for a 2d array every kept
    row is a complete time point across all channels.
    """
    every_nth = slice(None, None, factor)
    return time_series[every_nth]

def accumulate_spikes(spike_counts, ratio):
    """
    Sum spike counts over consecutive, non-overlapping windows of `ratio` time points.

    Row k of the result is the column-wise sum of rows k*ratio .. (k+1)*ratio - 1
    of `spike_counts`; trailing rows that do not fill a complete window are
    dropped. The result is a float64 array of shape
    (int(n_time_points / ratio), n_neurons).
    """
    n_rows, n_cols = spike_counts.shape[0], spike_counts.shape[1]
    n_windows = int(n_rows / ratio)

    binned = np.zeros((n_windows, n_cols))
    for window in range(n_windows):
        lo = window * ratio
        binned[window] = np.sum(spike_counts[lo:lo + ratio], axis=0)

    return binned



def generate_poisson_input(time_points, n_channels, mean_duration_up, mean_duration_down, seed):
    """
    Generates a binary input signal for each input channel that switches between a
    "down" state (0) and an "up" state (1) after exponentially distributed durations,
    i.e. the switch times form a Poisson process.

    time_points: array, equally spaced time points at which the process is sampled;
        the step size is taken from the first two entries and must be positive
    n_channels: int, number of independent channels to simulate
    mean_duration_up: float, mean duration of "up" (1) periods
    mean_duration_down: float, mean duration of "down" (0) periods
    seed: seed for np.random.default_rng; the same seed reproduces the same signals

    Returns: n_time_points x n_channels float array of 0/1 values

    Raises: ValueError if the time step is not positive (the switching loop could
        otherwise never advance through the time points).

    Details: every channel starts "down". Durations are drawn alternately with
    mean_duration_up and mean_duration_down (starting with up) and rounded up to
    whole time steps. Segment k of the signal lasts for the k-th draw; segments 0
    and 1 are both "down", after which segments alternate up, down, up, ...
    Hence up periods have mean mean_duration_up and later down periods mean
    mean_duration_down, while the very first down period spans two draws and is
    longer on average. This draw order is kept so that a given seed keeps
    producing the same signals.
    """
    # Local generator: reproducible and leaves the global NumPy RNG untouched
    rng = np.random.default_rng(seed)

    n_t = len(time_points)
    inputs = np.zeros((n_t, n_channels))

    dt = time_points[1] - time_points[0]  # time step
    if not dt > 0:  # also rejects NaN
        raise ValueError("time_points must have a positive time step")

    # Mean used for the k-th draw: index 0 (up) for even k, index 1 (down) for odd k
    mean_durations = (mean_duration_up, mean_duration_down)

    for channel in range(n_channels):
        seg_start, seg_value = 0, 0  # pending segment: first index and its 0/1 value
        parity = 0                   # alternates 0, 1, 0, ... with every draw
        t = 0
        while t <= n_t:
            t += int(np.ceil(rng.exponential(mean_durations[parity]) / dt))
            # The new boundary closes the pending segment; segments are disjoint,
            # so every sample is written exactly once (slicing clips at n_t).
            inputs[seg_start:t, channel] = seg_value
            # The next segment takes the value of the draw that just happened
            seg_start, seg_value = t, parity
            parity = 1 - parity

    return inputs

def simulate_neural_dynamics(
        time_points:        np.ndarray, 
        input_signals:      np.ndarray, 
        sigma:              float, 
        A:                  np.ndarray, 
        C:                  np.ndarray, 
        neur_rand_scale:    float
        ) ->                np.ndarray:    
    """
    Simulates neural dynamics based on the given differential equation model.

    Parameters:
    -----------

    time_points : np.ndarray
        the time points at which the process should be sampled
    input_signals: np.ndarray
        the input signals for each channel, typically generated by generate_poisson_input
    sigma : float
        the timescale of the system
    A : np.array
        the 2d neural connectivity matrix mapping neurons to each other
    C : np.array 
        the 2d matrix mapping external input channels to each neuron

    Returns:
    --------
    neural_dynamics: np.ndarray
        an array of neural time series of equal size to input_signals
    """
    
    dt = time_points[1] - time_points[0]  # Time step

    # Initialize the output array with the required dimensions
    neural_dynamics = np.zeros((input_signals.shape[0], A.shape[0]))

    z = np.zeros_like(neural_dynamics[0])  # Initial state of neural time series
    
    # Euler method to solve the differential equation
    for i in range(1, len(input_signals)):
        # Update the neural time series based on the rate of change
        z_dot = sigma * A@z + C@input_signals[i - 1]
        z_dot *= dt
        z = z + z_dot
        if neur_rand_scale > 0:
            np.random.seed(i)
            z += np.random.normal(0, neur_rand_scale, size=z.shape)

        neural_dynamics[i] = z
    
    return neural_dynamics

def bold_equation(
        time_points:        np.ndarray, 
        neural_activity:    np.ndarray
        ) ->                np.ndarray:
    """
    Given neural activity, returns the BOLD signal.
    Based on the work of Stephan et al. (2007) [10.1016/j.neuroimage.2007.07.040] and Maith et al. (2022) [10.3389/fninf.2022.790966]

    The hemodynamic state (vasodilatory signal s, blood inflow f_in, venous volume v,
    deoxyhemoglobin content q) starts at rest (s=0, f_in=v=q=1) and is integrated
    with the forward Euler method; the BOLD signal is computed from q and v after
    every step.

    Parameters:
    -----------

    time_points: np.ndarray
        the time points at which the BOLD signal should be calculated; the step
        size is taken from the first two entries
    neural_activity: np.ndarray
        1d neural activity trajectory with (at least) one value per time point,
        typically one column of the output of simulate_neural_dynamics

    Returns:
    --------
    bold: np.ndarray
        the BOLD signal time series, same length as time_points; bold[0] is 0

    Raises:
    -------
    ValueError
        if the time step exceeds 0.11, which would make the forward Euler scheme unstable
    """
    # Get time step
    dt = time_points[1] - time_points[0]
    if dt > 0.11:
        raise ValueError("The time step is too large. This will result in instability for the Forward Euler method.")

    # Parameters
    kappa     = 1/1.54  # signal decay
    gamma     = 1/2.46  # feedback regulation
    E_0       = 0.34    # oxygen extraction fraction at rest
    tau       = 0.98    # time constant
    alpha     = 0.33    # vessel stiffness
    V_0       = 0.02    # resting venous blood volume fraction
    v_0       = 40.3    # frequency offset at the outer surface of the magnetized vessel for fully deoxygenated blood at 1.5 T
    # NOTE(review): the echo time scales with the step size (TE = 4.0 at the 10 Hz
    # rate used by NeuralSimulator), whereas Stephan et al. (2007) presumably use a
    # fixed TE of 0.04 s. Changing it rescales k_1 and thus every BOLD amplitude,
    # so confirm the intended value before touching it.
    TE        = 40*dt   # echo time
    epsilon   = 1.43    # ratio of intra- and extravascular signal

    # Coefficients
    k_1 = (1-V_0) * 4.3 * v_0 * E_0 * TE
    k_2 = 2 * E_0
    k_3 = 1 - epsilon

    # Vasodilatory signal
    def ds_dt(x, s, f_in):
        return (x - kappa*s - gamma*(f_in - 1))

    # Flow in
    def flow_in(s):
        return s

    # Flow out
    def flow_out(v):
        return v**(1/alpha)

    # Changes in volume
    def dv_dt(f_in, v):
        return (f_in - flow_out(v)) / tau

    # Oxygen extraction fraction
    def extract(f_in):
        return (1-(1-E_0)**(1/f_in))

    # Change in deoxyhemoglobin
    def dq_dt(f_in, q, v):
        return (f_in*(extract(f_in)/E_0) - flow_out(v)*(q/v)) / tau

    # BOLD signal change equation
    def BOLD(q, v):
        return V_0 * (k_1*(1-q) + k_2*(1-(q/v)) + k_3*(1-v))

    # State trajectories, initialised at the resting state
    n_t = len(time_points)
    s    = np.zeros(n_t)                   # vasodilatory signal (rest: 0)
    f_in = np.zeros(n_t);  f_in[0] = 1     # blood inflow (rest: 1)
    v    = np.zeros(n_t);  v[0] = 1        # venous volume (rest: 1)
    q    = np.zeros(n_t);  q[0] = 1        # deoxyhemoglobin content (rest: 1)
    bold = np.zeros(n_t)

    for t in range(1, n_t):
        s_t, f_t, v_t, q_t = s[t-1], f_in[t-1], v[t-1], q[t-1]

        # Forward Euler step: all derivatives use the previous state, while the
        # neural drive is taken from the current sample t.
        s[t]    = s_t + dt * ds_dt(neural_activity[t], s_t, f_t)
        f_in[t] = f_t + dt * flow_in(s_t)
        v[t]    = v_t + dt * dv_dt(f_t, v_t)
        q[t]    = q_t + dt * dq_dt(f_t, q_t, v_t)

        # Calculate BOLD signal from the updated state
        bold[t] = BOLD(q[t], v[t])

    return bold

def simulate_bold_signal(
        time_points:        np.ndarray, 
        neural_activity:    np.ndarray
        ) ->                np.ndarray:
    """
    Simulates the BOLD signal given neural activity.

    Wrapper for bold_equation(): runs the hemodynamic model independently on each
    column of neural_activity.

    Parameters:
    -----------

    time_points: np.ndarray
        the time points at which the BOLD signal should be calculated
    neural_activity: np.ndarray
        2d array (time points x neurons), typically generated by simulate_neural_dynamics

    Returns:
    --------
    bold: np.ndarray
        float64 array with the same shape as neural_activity, one BOLD time
        series per column
    """
    # Allocate float64 explicitly: np.zeros_like would inherit an integer dtype from
    # integer-valued neural activity and silently truncate the BOLD values.
    bold = np.zeros(neural_activity.shape)
    for i, column in enumerate(neural_activity.T):
        bold[:, i] = bold_equation(time_points, column)

    return bold

def simulate_neural_spiking(
        time_points:        np.ndarray, 
        neural_activity:    np.ndarray,
        rates:              np.ndarray,
        baseline:           np.ndarray,
        ) ->                np.ndarray:
    """
    Simulates spiking activity given neural activity, firing rate multipliers and baseline rates.

    The rate of spiking is calculated as:
    rate = exp(neural_activity*3 + baseline) * rates * dt

    """
    dt = time_points[1] - time_points[0]


    spikes = np.zeros_like(neural_activity)
    for i in range(neural_activity.shape[0]):
        rate = np.exp((neural_activity[i]*3)+baseline)*rates*dt
        spikes[i] = np.random.poisson(rate)

    return spikes



class NeuralSimulator(Simulator):
    """
    Simulate various types of neural data.

    Binary input signals drive linear neural dynamics, from which a BOLD signal
    and Poisson spike counts are derived (see simulate).
    """

    # (source, target) pairs whose analytic transfer entropy is exactly zero: each
    # points against the generative DAG input -> neural activity -> {BOLD, spikes}.
    _ZERO_TE_PAIRS = frozenset({
        ('spike_counts', 'neural_activity'),
        ('spike_counts', 'input_signals'),
        ('bold_signal', 'neural_activity'),
        ('bold_signal', 'input_signals'),
        ('neural_activity', 'input_signals'),
    })

    def __init__(
            self, 
            A:                  np.ndarray, 
            C:                  np.ndarray,
            r:                  np.ndarray|None=None,
            b:                  np.ndarray|None=None,
            sigma:              float=1.0,
            mean_duration_up:   float=2.5,
            mean_duration_down: float=10.0,
            neur_rand_scale:    float=0.0,
            random_bold_error:  float=0.0,
            samplerate:         int=1,
            ):
        
        """
        Initializes a neural simulator instance with a set of parameters.
        
        Parameters:
        -----------
        A : np.ndarray
            the 2d neural connectivity matrix mapping neurons to each other
        C : np.ndarray
            the 2d matrix mapping external input channels to each neuron
        r : np.ndarray|None
            firing rate multiplier vector for each neuron, defaults to 1 for all neurons if None
        b : np.ndarray|None
            baseline firing rate vector (bias) for each neuron, defaults to 0 for all neurons if None
        sigma : float
            the timescale of the system
        mean_duration_up : float
            mean duration for "up" state
        mean_duration_down : float
            mean duration for "down" state
        neur_rand_scale : float
            standard deviation of the random noise added to the neural activity
        random_bold_error : float
            standard deviation of the random noise added to the (downsampled) BOLD signal
        samplerate : int
            output sampling rate per unit of simulated time; one of 1, 2, 5 or 10
            (the simulation itself always runs at 10 samples per unit time)
        """
        
        assert isinstance(A, np.ndarray), "A must be a numpy array"
        assert A.shape[0] == A.shape[1], "A must be a square matrix"
        self.A = A
        self.n_neurons = A.shape[0]

        assert isinstance(C, np.ndarray), "C must be a numpy array"
        assert C.shape[0] == self.n_neurons, "C must have the same number of rows as A"
        self.C = C
        self.n_inputs = C.shape[1]
        
        if r is None:
            r = np.ones(self.n_neurons)
        else:
            assert isinstance(r, np.ndarray), "r must be a numpy array"
            assert r.shape[0] == self.n_neurons, "r must have the same size as the number of neurons"
        self.r = r

        if b is None:
            b = np.zeros(self.n_neurons)
        else:
            assert isinstance(b, np.ndarray), "b must be a numpy array"
            assert b.shape[0] == self.n_neurons, "b must have the same size as the number of neurons"
        self.b = b

        assert sigma > 0, "sigma must be greater than 0"
        self.sigma = sigma

        assert mean_duration_up > 0, "mean_duration_up must be greater than 0"
        self.mean_duration_up = mean_duration_up

        assert mean_duration_down > 0, "mean_duration_down must be greater than 0"
        self.mean_duration_down = mean_duration_down

        assert neur_rand_scale >= 0, "neur_rand_scale must be greater than or equal to 0"
        self.neur_rand_scale = neur_rand_scale

        assert random_bold_error >= 0, "random_bold_error must be greater than or equal to 0"
        self.random_bold_error = random_bold_error

        assert samplerate in [1, 2, 5, 10], "samplerate must be one of 1, 2, 5 or 10"
        self.samplerate = samplerate

        # Names of the simulated variables, in the order simulate() returns them
        self.variables = ['input_signals', 'neural_activity', 'bold_signal', 'spike_counts']

    def simulate(self, time, seed):
        r"""
        Simulates input, resulting neural activity and resulting BOLD signal and spiking activity.

                                      /-> BOLD signal
        DAG: input -> neural activity 
                                      \-> spiking activity
        
        Parameters:
        -----------                              
        time : int
            duration of the simulation in seconds
        seed : int
            seed for the input signals and for the BOLD measurement noise

        Returns:
        --------
        input_signals : np.ndarray
            the input signals for each channel
        neural_activity : np.ndarray
            the neural activity trajectories
        bold_signal : np.ndarray
            the BOLD signal time series
        spike_counts : np.ndarray
            the spike counts for each neuron, summed over each output sample
        """

        SAMPLERATE = 10  # 10 Hz sampling rate is used internally, to avoid issues with the forward euler solver for the BOLD signal
        DOWNSAMPLE = int(SAMPLERATE/self.samplerate)

        time_points = np.linspace(0, time, int(time*SAMPLERATE)+1)

        # Generate input signals
        input_signals   = generate_poisson_input(time_points, self.n_inputs, self.mean_duration_up, self.mean_duration_down, seed)
        
        # Simulate neural dynamics
        neural_activity = simulate_neural_dynamics(time_points, input_signals, self.sigma, self.A, self.C, self.neur_rand_scale)
        
        # Simulate BOLD signal
        bold_signal     = simulate_bold_signal(time_points, neural_activity)

        # Simulate spiking activity
        # NOTE(review): the spike draws come from the global np.random state, which
        # `seed` does not control; they only repeat across runs when
        # neur_rand_scale > 0, because simulate_neural_dynamics then reseeds the
        # global RNG as a side effect.
        spike_counts    = simulate_neural_spiking(time_points, neural_activity, self.r, self.b)

        # Downsample from the internal 10 Hz to self.samplerate Hz
        input_signals   = downsample(input_signals,  DOWNSAMPLE)
        neural_activity = downsample(neural_activity, DOWNSAMPLE)
        bold_signal     = downsample(bold_signal,  DOWNSAMPLE)
        spike_counts    = accumulate_spikes(spike_counts, DOWNSAMPLE)
        
        # Add random noise to the BOLD signal
        if self.random_bold_error > 0:
            # Seeded from `seed` so that each seed gets its own noise realisation.
            # spawn() yields a child stream independent of the default_rng(seed)
            # stream used for the input signals, and a local generator leaves the
            # global RNG state untouched.
            noise_rng = np.random.default_rng(np.random.SeedSequence(seed).spawn(1)[0])
            bold_signal += noise_rng.normal(0, self.random_bold_error, size=bold_signal.shape)

        return input_signals, neural_activity, bold_signal, spike_counts
    
    def analytic_transfer_entropy(
            self, 
            var_from:   str, 
            var_to:     str):
        """
        Return the analytic transfer entropy from `var_from` to `var_to`.

        Only transfer entropies that are exactly 0 are known with complete certainty:
        those pointing from a variable back to one it is generated from. Every
        other pair raises NotImplementedError.
        """
        assert var_from in self.variables, f"var_from must be in {self.variables}"
        assert var_to in self.variables, f"var_to must be in {self.variables}"
        assert var_from != var_to, "var_from and var_to must be different variables"

        if (var_from, var_to) in self._ZERO_TE_PAIRS:
            return 0.0
        raise NotImplementedError("Analytic TE not known for this pair of variables")
         