
from neuron import h
import numpy as np

class DummyStim:
    """Free-running NetStim whose spike times are recorded but not delivered.

    Unlike ``Cell``, no synapse is attached: the only NetCon has a ``None``
    target and exists purely so the presynaptic spike train can be recorded
    into ``pre_spike_vec``.

    Parameters
    ----------
    params : dict
        Reads ``params['stim']['stim_delay']`` (NetStim ``start``) and
        ``params['stim']['rate']`` (the interval is set to ``1000 / rate``;
        presumably rate in Hz -> interval in ms, confirm).
    """

    def __init__(self, params):
        self.params = params
        self._setup_stim()

    def _setup_stim(self):
        """Build the NetStim, a target-less monitoring NetCon and its spike-time Vector."""
        stim_cfg = self.params['stim']
        self.stimE = source = h.NetStim()
        source.start = stim_cfg['stim_delay']
        source.interval = 1000 / stim_cfg['rate']
        source.noise = 1  # fully random (Poisson) inter-spike intervals

        # The NetCon has no target: it only exists to record presynaptic spike times.
        self.ncE_none = monitor = h.NetCon(source, None)
        self.pre_spike_vec = spike_times = h.Vector()
        monitor.record(spike_times)

    def seed_noise(self, seed):
        """Seed the NetStim's Random123 stream, using ``seed`` for all three ids."""
        ids = (seed,) * 3
        self.stimE.noiseFromRandom123(*ids)

class Cell:
    """Single-compartment (soma-only) NEURON cell driven by one excitatory synapse.

    The cell is assembled in ``__init__`` from a nested ``params`` dict with
    the sections ``'mode'``, ``'soma'``, ``'membrane'``, ``'ion_ch'``,
    ``'synapse'`` and ``'stim'``.  ``params['mode']['if_search_g_syn']``
    selects between two set-ups:

    * True  -> density channels inserted with ``soma.insert('ch_ctr_*')``
      and a NetStim that fires a single, non-noisy presynaptic spike;
    * False -> ``stch_ctr_*`` point processes (stochastic channels, with a
      channel count ``Nsingle``) and a Poisson NetStim, optionally limited to
      a single spike.
    """

    def __init__(self, params):
        # the params are structured as membrane, ion-channels
        # e.g., params['membrane'], params['ion_ch'], params['soma']
        self.params = params
        if_search_g_syn = self.params['mode']['if_search_g_syn']
        self._setup_soma(if_search_g_syn)
        self._setup_membrane()
        self._setup_ion_ch(if_search_g_syn)
        self._setup_synapse()
        self._setup_stim(if_search_g_syn)
        self.g_na_leak = self._cal_g_na_leak()  # uS
        # TODO:
        #    calculate the effective g_na in uS (x)
        #    use g_na (v_mem - E_na) to calculate the current in nA

    def _cal_g_na_leak(self):
        """Return the Na+ share of the passive leak conductance, in uS.

        The total leak ``g_pas`` (S/cm2) is converted to an absolute
        conductance using the soma segment area (um2).  It is then split as
        ``g_na = g_total / (1 + alpha)`` with
        ``alpha = (ena - e_pas) / (e_pas - ek)``, which is the Na+/K+ split
        whose currents cancel at ``e_pas``::

            g_na * (e_pas - ena) + g_k * (e_pas - ek) = 0,  g_k / g_na = alpha

        Raises ZeroDivisionError if ``e_pas == ek``.
        """
        # S/cm2 * um2 * 1e-8 (um2 -> cm2) = S; * 1e6 -> uS
        g_total_leak = self.soma.g_pas * self.soma(0.5).area() * 1e-8 * 1e6
        alpha = (self.soma.ena - self.soma.e_pas) / (self.soma.e_pas - self.soma.ek)
        return g_total_leak / (1 + alpha)  # uS

    def _setup_soma(self, if_search_g_syn):
        """Create the soma section and set its geometry (not its membrane).

        ``if_search_g_syn`` is kept for interface compatibility.  Both modes
        currently use the same geometry from ``params['soma']`` (``diam`` and
        ``L``, in NEURON's um).
        """
        self.soma = h.Section(name='soma')
        self.soma.diam = self.params['soma']['diam']
        self.soma.L = self.params['soma']['L']

    def _setup_membrane(self):
        """Insert the passive leak and set cm, g_pas and e_pas from ``params['membrane']``."""
        membrane = self.params['membrane']
        self.soma.insert('pas')
        self.soma.cm = membrane['cm']  # uF/cm2
        # mS/cm2 -> NEURON model: S/cm2
        self.soma.g_pas = membrane['gl_scale'] * membrane['gl_ctr'] * 1e-3
        self.soma.e_pas = membrane['ErL']  # mV

    def _load_ch_file(self, key):
        """Load the hoc file named by ``params['ion_ch'][key]`` (best effort).

        Failures are reported with ``print`` and otherwise ignored, which
        preserves the previous non-fatal behaviour.  Returns True if
        ``h.load_file`` reported success.
        """
        ch_file = self.params['ion_ch'][key]
        try:
            ok = h.load_file(ch_file)
        except Exception as e:
            print(f"Error: Unable to load the channel file '{ch_file}'. Details: {e}")
            return False
        if not ok:
            # h.load_file can also signal failure through its return value.
            print(f"Error: Unable to load the channel file '{ch_file}'. Details: h.load_file returned {ok}")
        return bool(ok)

    def _setup_ion_ch(self, if_search_g_syn):
        """Add Na+, delayed-rectifier K+ and sub-threshold channels to the soma.

        ``if_search_g_syn`` True: insert ``ch_ctr_{na,kdr,subchan}`` density
        mechanisms and set ``gmax`` from ``params['ion_ch'][ch]['gx']``
        (mS/cm2 -> S/cm2).

        ``if_search_g_syn`` False: create ``stch_ctr_{na,kdr,subchan}`` point
        processes on ``soma(0.5)``, stored as ``self.stch_na``,
        ``self.stch_kdr`` and ``self.stch_subchan``.  Their ``gmax`` and
        ``Nsingle`` come from ``_compute_gmax_ch_num``.  This branch also
        sets ``soma.ena`` and ``soma.ek`` from params.

        NOTE(review): the deterministic branch never sets ena/ek from params,
        yet ``_cal_g_na_leak`` reads them in both modes.  Confirm that
        NEURON's default reversal potentials are intended there.
        """
        channels = ('na', 'kdr', 'subchan')
        ion_ch = self.params['ion_ch']
        if if_search_g_syn:
            self._load_ch_file('det_ch_file')
            # Sodium, delayed rectifier Potassium (custom) and the additional
            # sub-threshold channel (NOISY+) (custom).
            for ch in channels:
                self.soma.insert('ch_ctr_' + ch)
            # Modify channel max conductances
            seg = self.soma(0.5)
            for ch in channels:
                mech = getattr(seg, 'ch_ctr_' + ch)
                mech.gmax = ion_ch[ch]['gx'] * 1e-3  # mS/cm2 -> NEURON model: S/cm2
        else:
            self._load_ch_file('stch_ch_file')
            seg = self.soma(0.5)
            area = seg.area()  # um2
            for ch in channels:
                # Creates self.stch_na / self.stch_kdr / self.stch_subchan.
                pp = getattr(h, 'stch_ctr_' + ch)(seg)
                setattr(self, 'stch_' + ch, pp)
                pp.gmax, pp.Nsingle = self._compute_gmax_ch_num(ch, area, ion_ch[ch]['if_stch'])
            self.soma.ena = ion_ch['ena']
            self.soma.ek = ion_ch['ek']

    def _compute_gmax_ch_num(self, ch, sa, if_stch):
        """Return ``(gmax, N)`` for channel ``ch`` ('na', 'kdr' or 'subchan').

        ``sa`` is the membrane area in um2 (callers pass ``soma(0.5).area()``).

        * ``if_stch`` true: ``gmax`` is the single-channel conductance
          ``params['ion_ch'][ch]['g_single']`` (pS -> uS), and ``N`` is the
          rounded channel count that reproduces the density ``gx`` over ``sa``.
        * otherwise: ``gmax`` is the total conductance ``gx * sa`` in uS and
          ``N`` is 0.
        """
        ch_params = self.params['ion_ch'][ch]
        gmax0 = ch_params['gx'] * 1e3 * 1e-8  # mS/cm2 --> uS/um2
        if if_stch:
            gmax = ch_params['g_single'] * 1e-6  # pS --> uS
            N = int(round((gmax0 / gmax) * sa))
        else:
            gmax = gmax0 * sa  # uS
            N = 0
        return gmax, N

    def _setup_synapse(self):
        """Create the Exp2Syn on soma(0.5); only ``input_type == 'gsyn'`` is supported.

        Raises NotImplementedError for any other ``params['synapse']['input_type']``.
        """
        syn_params = self.params['synapse']
        input_type = syn_params['input_type']
        if input_type != 'gsyn':
            raise NotImplementedError(f"Input type '{input_type}' is not implemented.")
        self.synE = h.Exp2Syn(self.soma(0.5))
        self.synE.tau1 = syn_params['tau_rise']
        self.synE.tau2 = syn_params['tau_decay']
        self.synE.e = syn_params['Er']

    def _setup_stim(self, if_search_g_syn):
        """Create the presynaptic NetStim and connect it to ``self.synE``.

        ``if_search_g_syn`` True: a single, non-noisy spike.  Otherwise the
        interval is ``1000 / params['stim']['rate']`` with Poisson noise.  If
        ``if_single_spike`` is set, the stimulus is limited to one spike with
        noise ``params['stim']['noise']``.  Spike times are also recorded
        into ``self.pre_spike_vec``.
        """
        stim = self.params['stim']
        self.stimE = h.NetStim()
        self.stimE.start = stim['stim_delay']
        if if_search_g_syn:
            self.stimE.number = 1
            self.stimE.noise = 0
        else:
            self.stimE.interval = 1000 / stim['rate']
            self.stimE.noise = 1  # poisson
            if stim['if_single_spike']:
                self.stimE.number = 1
                self.stimE.noise = stim['noise']
        # NetCon(source, target, threshold, delay, weight): weight starts at 0
        # and is set later by set_NetCon_weight().
        self.ncE = h.NetCon(self.stimE, self.synE, 0, stim['NetCon_delay'], 0)

        # TODO: this part should be handled better.
        # Target-less NetCon used only to record presynaptic spike times.
        self.ncE_none = h.NetCon(self.stimE, None)
        self.pre_spike_vec = h.Vector()
        self.ncE_none.record(self.pre_spike_vec)

    def seed_noise(self, seed):
        """Seed the NetStim's Random123 stream, using ``seed`` for all three ids."""
        self.stimE.noiseFromRandom123(seed, seed, seed)

    def set_NetCon_weight(self, weight, noise_strendth=0.15):
        """Set the synaptic NetCon weight from an area-normalised value.

        ``weight`` is multiplied by ``1e3 * area_um2 * 1e-8``.  If ``weight``
        is a conductance density in mS/cm2, this yields uS
        (NOTE(review): confirm the intended input unit).

        When ``noise_strendth`` is non-zero, the value written to the NetCon
        is drawn from a normal distribution centred on the scaled weight with
        standard deviation ``|scaled weight * noise_strendth|``.

        Returns the noise-free scaled weight.
        """
        scale = 1e3 * self.soma(0.5).area() * 1e-8  # area in um2, 1e-8: um2 -> cm2
        in_weight = weight * scale
        if noise_strendth == 0:
            in_weight_noise = in_weight
        else:
            # abs(): numpy rejects a negative scale, which a negative weight
            # (or noise strength) would otherwise produce.
            in_weight_noise = np.random.normal(in_weight, abs(in_weight * noise_strendth))
        self.ncE.weight[0] = in_weight_noise
        return in_weight
