
from neuron.units import ms, mV
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import Pool, Manager
from .model import DummyStim, Cell

import copy
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl

class Exp:
    """Run batches of NEURON simulations over (g_syn, v_rest, g_leak) grids.

    The object holds a template parameter dict (``params_temp``) and a list of
    noise seeds stored in ``params_temp['run']['seeds']``; every parameter set
    is simulated ``num_loops`` times, once per seed.
    """

    def __init__(self, params_temp, num_loops, chunk_size=1, max_workers=60, if_parallel=True, if_plot=False, if_select_seed=True):
        """
        Initialize the experiment.

        Parameters:
        - params_temp (dict): Template parameters for the experiment. NOTE: the
          dict is stored by reference and ``params_temp['run']['seeds']`` is
          written in place.
        - num_loops (int): Number of simulation loops per parameter set.
        - chunk_size (int): Size of chunks for parallel processing.
        - max_workers (int): Maximum number of workers for parallel processing.
        - if_parallel (bool): Whether to run the simulation in parallel.
        - if_plot (bool): Whether to plot results during simulation.
        - if_select_seed (bool): If True, use seeds for which ``_simulate_seed``
          records at least one event, loaded from
          ``params_temp['misc']['path_for_seeds']`` when that path is set and
          otherwise generated by ``gen_seeds``. If False, draw ``num_loops``
          distinct random seeds from ``range(100 * num_loops)``.
        """
        self.params_temp = params_temp
        self.num_loops = num_loops
        self.chunk_size = chunk_size
        self.max_workers = max_workers
        self.if_parallel = if_parallel
        self.if_plot = if_plot

        if if_select_seed:
            seed_path = self.params_temp['misc']['path_for_seeds']
            if seed_path:
                # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
                with open(seed_path, 'rb') as f:
                    self.valid_seeds = pkl.load(f)
            else:
                self.valid_seeds = self.gen_seeds()  # Automatically generate seeds during initialization
        else:
            self.valid_seeds = np.random.choice(range(100 * num_loops), size=num_loops, replace=False)
        # Store the seeds in the template so every copy made by create_params carries them.
        self.params_temp['run']['seeds'] = self.valid_seeds

    def _simulate_seed(self, seed):
        """
        Simulate a single seed to check validity.

        A ``DummyStim`` is built from the template, ``stim.ncE_none`` events are
        recorded, and the model is run for ``params_temp['run']['sim_length']``.

        Parameters:
            seed (int): The seed to simulate.

        Returns:
            int or None: The seed if at least one event was recorded, otherwise None.
        """
        from neuron import h
        h.dt = self.params_temp['run']['dt']
        h.load_file('stdrun.hoc')
        stim = DummyStim(self.params_temp)
        pre_spike = h.Vector()
        stim.ncE_none.record(pre_spike)

        stim.seed_noise(seed)
        h.finitialize(self.params_temp['run']['E0'] * mV)
        h.continuerun(self.params_temp['run']['sim_length'] * ms)

        return seed if len(list(pre_spike)) > 0 else None

    def gen_seeds(self):
        """
        Collect the first ``num_loops`` valid seeds, scanning 0, 1, 2, ...

        The parallel branch uses the ordered ``Pool.imap``, so it returns the
        same seeds as the serial branch regardless of worker scheduling.

        Returns:
            list: A list of valid seeds (shorter than ``num_loops`` only if the
            whole search range is exhausted).
        """
        seed_range = range(1000000)  # Arbitrary large seed range
        valid_seeds = []
        if self.if_parallel:
            # Results are consumed in the parent process, so a plain list is enough.
            with Pool() as pool:
                with tqdm(total=self.num_loops, desc="Processing Seeds", unit="seed") as pbar:
                    for seed in pool.imap(self._simulate_seed, seed_range):
                        if seed is None:
                            continue
                        valid_seeds.append(seed)
                        pbar.update(1)
                        if len(valid_seeds) >= self.num_loops:
                            pool.terminate()  # Stop processing once we have enough seeds
                            break
            return valid_seeds[:self.num_loops]

        for seed in seed_range:
            if len(valid_seeds) >= self.num_loops:
                break
            valid_seed = self._simulate_seed(seed)
            if valid_seed is not None:
                valid_seeds.append(valid_seed)
        return valid_seeds

    def create_params(self, synaptic_conductance, resting_potential, leak_conductance):
        """
        Return a deep copy of the template configured for one stochastic run.

        Sets ``synapse.weight``, ``membrane.ErL``/``run.E0`` and
        ``membrane.gl_scale``, and disables ``mode.if_search_g_syn``. This is
        only meant for stochastic simulation, not for the g_syn search.
        """
        params = copy.deepcopy(self.params_temp)
        params['mode']['if_search_g_syn'] = False
        params['membrane']['ErL'] = resting_potential
        params['run']['E0'] = resting_potential
        params['synapse']['weight'] = synaptic_conductance
        params['membrane']['gl_scale'] = leak_conductance
        return params

    def generate_all_params(self, all_synaptic_conductance, all_resting_potential, all_leak_conductance, compensate_list=None):
        """
        Build one parameter dict per (v_rest, g_leak, g_syn) combination.

        Parameters:
        - all_synaptic_conductance (iterable): Synaptic weights (scale factors).
        - all_resting_potential (iterable): Resting potentials to sweep.
        - all_leak_conductance (iterable): Leak-conductance scales to sweep.
        - compensate_list (dict or None): Optional ``{(v_rest, g_leak): factor}``.
          When given, each g_syn is multiplied by ``factor / max(all factors)``;
          when None, g_syn is used unchanged.

        Returns:
        - list[dict]: Parameter dicts ordered by v_rest, then g_leak, then g_syn.
        """
        # Only computed when compensation is requested (None has no .values()).
        max_compensate = max(compensate_list.values()) if compensate_list is not None else None
        all_params = []
        for v_rest in all_resting_potential:
            for g_leak in all_leak_conductance:
                for g_syn in all_synaptic_conductance:
                    if compensate_list is not None:
                        # g_syn is a scale factor, so multiply it with the normalised compensate factor.
                        g_syn_used = g_syn * compensate_list[(v_rest, g_leak)] / max_compensate
                    else:
                        g_syn_used = g_syn
                    all_params.append(self.create_params(g_syn_used, v_rest, g_leak))
        return all_params

    def generate_all_params_dynamic(self, all_resting_potential, all_leak_conductance, g_syn_num, g_syn_index, compensate_list=None):
        """
        Build one parameter dict per (v_rest, g_leak) pair, using the
        ``g_syn_index``-th of ``g_syn_num`` linear steps over a per-pair g_syn range.

        Parameters:
        - all_resting_potential (iterable): Resting potentials to sweep.
        - all_leak_conductance (iterable): Leak-conductance scales to sweep.
        - g_syn_num (number): Number of steps the range is divided into.
        - g_syn_index (number): Step index; 0 gives the range start, ``g_syn_num`` the end.
        - compensate_list (sequence of two dicts): ``compensate_list[0]`` /
          ``compensate_list[1]`` map ``(v_rest, g_leak)`` to the values the code
          calls ``g_syn_active`` / ``g_syn_saturate``. Required despite the
          ``None`` default (kept for backward compatibility).

        Raises:
        - ValueError: if ``compensate_list`` is None.
        """
        if compensate_list is None:
            raise ValueError("generate_all_params_dynamic requires compensate_list=(active, saturate) dicts")

        all_params = []
        for v_rest in all_resting_potential:
            for g_leak in all_leak_conductance:
                # The searched g in the deterministic setting usually overestimates the firing probability in the stochastic setting.
                # So we need to scale the g_syn range, so that the simulation can be done with a reasonable range.
                # If we don't scale the g_syn, most of the firing probability will be close to zero.
                g_syn_active = compensate_list[0][(v_rest, g_leak)]
                g_syn_saturate = compensate_list[1][(v_rest, g_leak)] * 1.05
                g_syn_start = g_syn_active - g_syn_saturate/2
                g_syn_saturate = g_syn_saturate * 1.2
                g_syn_comp = g_syn_start + (g_syn_saturate - g_syn_start) * g_syn_index / g_syn_num

                all_params.append(self.create_params(g_syn_comp, v_rest, g_leak))
        return all_params

    def chunkify(self, data):
        """Yield consecutive slices of ``data`` of length ``self.chunk_size`` (last may be shorter)."""
        for i in range(0, len(data), self.chunk_size):
            yield data[i:i + self.chunk_size]

    def process_chunk(self, param_chunk):
        """Simulate every parameter set in ``param_chunk``; returns one result per set."""
        return [self.simulate_at_g(parameter_set) for parameter_set in param_chunk]

    def plot_simulation_results(self, record, t, config=None):
        """
        Plot recorded traces against ``t``, one subplot per config entry.

        Each config entry may give ``data_key`` (key into ``record``),
        ``label``, ``ylabel`` and ``color``.
        """
        if config is None:
            config = [
                {'data_key': 'i_synE', 'label': 'i_synE', 'ylabel': 'i_synE', 'color': 'blue'},
                {'data_key': 'i_na', 'label': 'i_na', 'ylabel': 'i_na', 'color': 'red'},
                {'data_key': 'i_kdr', 'label': 'i_kdr', 'ylabel': 'i_kdr', 'color': 'green'},
                {'data_key': 'v_soma', 'label': 'v_soma', 'ylabel': 'v_soma', 'color': 'orange'},
            ]

        num_plots = len(config)
        plt.figure(figsize=(10, 2.5 * num_plots))

        for idx, plot_config in enumerate(config, start=1):
            data_key = plot_config.get('data_key')
            label = plot_config.get('label', data_key)
            color = plot_config.get('color', 'blue')

            ax = plt.subplot(num_plots, 1, idx)

            # Plot the data
            ax.plot(t, list(record[data_key]), label=label, color=color, linewidth=0.2)
            ax.set_ylabel(plot_config.get('ylabel', label))

        #plt.savefig("./my_plot.svg", format="svg")
        plt.show()

    def simulate_at_g(self, params):
        """
        Run ``self.num_loops`` noisy trials of one parameter set.

        Each trial calls ``cell.set_NetCon_weight`` with the configured weight
        and ``noise_strenth``, seeds the noise with ``params['run']['seeds'][i]``,
        and runs the model for ``params['run']['sim_length']``.

        Returns:
        - dict: ``{(g_syn, v_rest, g_leak): rows}``. Here ``g_syn`` is the value
          returned by the last ``set_NetCon_weight`` call, and ``v_rest`` and
          ``g_leak`` are read from ``cell.soma``. Each row is
          ``[ap_count, i_stch_na_bg, i_synE_active, if_weird_syn, v_dist, i_na_leak_bg]``.

        Raises:
        - ValueError: if ``self.num_loops`` < 1 (no trial means no g_syn to report).
        """
        if self.num_loops < 1:
            raise ValueError("num_loops must be >= 1")

        from neuron import h
        # Samples at/above this voltage are excluded from the v_dist quantile.
        vm_clip = params['membrane']['action_potential_thres'] - 5

        h.dt = params['run']['dt']
        h.load_file('stdrun.hoc')
        cell = Cell(params)

        record = {
            'i_synE': h.Vector().record(cell.synE._ref_i),
            'ap_counter': h.APCount(cell.soma(0.5)),
            'v_soma': h.Vector().record(cell.soma(0.5)._ref_v),
            #'i_subchan': h.Vector().record(cell.stch_subchan._ref_i),
            #'O_subchan': h.Vector().record(cell.stch_subchan._ref_O),
            'i_na': h.Vector().record(cell.stch_na._ref_i),
            'i_kdr': h.Vector().record(cell.stch_kdr._ref_i),
            #'post_spike': h.Vector(),
            #'pre_spike': h.Vector()
        }
        record['ap_counter'].thresh = params['membrane']['action_potential_thres']
        #record['ap_counter'].record(record['post_spike'])
        #cell.ncE_none.record(record['pre_spike'])

        dt = params['run']['dt']
        # Integration step for np.trapz; the 1e-3 factor presumably converts ms to s — TODO confirm units.
        dx = dt * 1e-3

        all_spike_count = []
        for i in range(self.num_loops):
            g_syn = cell.set_NetCon_weight(params['synapse']['weight'], params['synapse']['noise_strenth'])
            cell.seed_noise(params['run']['seeds'][i])

            h.finitialize(params['run']['E0'] * mV)
            h.continuerun(params['run']['sim_length'] * ms)

            # Copy each recorded trace into numpy once per trial.
            i_na = np.array(list(record['i_na']))
            v_soma = np.array(list(record['v_soma']))
            i_syn = np.array(list(record['i_synE']))

            i_stch_na_bg = np.trapz(i_na, dx=dx)
            i_na_leak_bg = np.trapz(cell.g_na_leak * (v_soma - cell.soma.ena), dx=dx)
            i_synE_active = np.trapz(i_syn, dx=dx)
            # if any synaptic current is positive, then it is a weird synapse
            if_weird_syn = bool(np.any(i_syn > 0))

            vh = np.quantile(v_soma[v_soma < vm_clip], 0.975)
            v_dist = vh - params['run']['E0']
            all_spike_count.append([record['ap_counter'].n, i_stch_na_bg, i_synE_active, if_weird_syn, v_dist, i_na_leak_bg])

            if self.if_plot:
                # Build the time axis from the actual trace length so it always matches the data.
                t = np.arange(v_soma.size) * dt
                self.plot_simulation_results(record, t)

        v_rest = cell.soma.e_pas
        g_leak = cell.soma.g_pas
        print("Done the simulation for g_syn: ", g_syn*10**6, "uS/cm2", "v_rest: ", v_rest, "mV", "g_leak: ", g_leak*10**6, "uS/cm2")
        return {(g_syn, v_rest, g_leak): all_spike_count}

    def parallel_execute(self, data, task_function):
        """
        Execute tasks in parallel using ProcessPoolExecutor.

        Parameters:
        - data (list): List of data to process.
        - task_function (function): Function to execute for each chunk of data;
          must return a list.

        Returns:
        - results (list): Combined results from all tasks, in completion order
          (not input order).
        """
        results = []
        chunks = list(self.chunkify(data))

        with tqdm(total=len(chunks), desc="Processing Chunks", unit="chunk") as pbar:
            with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                futures = [executor.submit(task_function, chunk) for chunk in chunks]
                for future in as_completed(futures):
                    results.extend(future.result())
                    pbar.update(1)

        return results

    def run(self, all_params):
        """Run the experiment and return the results (one ``simulate_at_g`` dict per parameter set)."""
        if self.if_parallel:
            return self.parallel_execute(all_params, self.process_chunk)
        return [self.simulate_at_g(params) for params in tqdm(all_params, desc="Simulating Parameters")]

    def search_g_syn(self, all_resting_potential, all_leak_conductance):
        """
        Bisect, for every (v_rest, g_leak) pair, the synaptic weight at which the
        cell first fires within a window of ``1000 / g_syn_search_firing_rate``.

        If ``params_temp['misc']['path_for_compensate']`` is set, the pickled
        result stored there is returned instead and no search is run.

        Returns:
        - dict: ``{(v_rest, g_leak): g_syn}``, the midpoint of the final bracket.

        Raises:
        - ValueError: if ``g_syn_search_precision`` is not positive (the
          bisection would never terminate).
        """
        compensate_path = self.params_temp['misc']['path_for_compensate']
        if compensate_path:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open(compensate_path, 'rb') as f:
                return pkl.load(f)

        firing_rate = self.params_temp['misc']['g_syn_search_firing_rate']
        g_syn_range = self.params_temp['misc']['g_syn_search_range']
        precision = self.params_temp['misc']['g_syn_search_precision']
        if not precision > 0:
            raise ValueError("g_syn_search_precision must be > 0, got %r" % (precision,))
        sim_length = 1000 / firing_rate
        # Work on a copy so the sweep does not overwrite ErL / E0 / gl_scale in the shared template.
        params = copy.deepcopy(self.params_temp)

        # Generate all combinations of v_rest and g_leak
        v_rest_grid, g_leak_grid = np.meshgrid(all_resting_potential, all_leak_conductance, indexing='ij')
        combinations = np.stack([v_rest_grid.ravel(), g_leak_grid.ravel()], axis=-1)

        from neuron import h
        search_results = {}
        for v_rest, g_leak in tqdm(combinations):
            params['membrane']['ErL'] = v_rest
            params['run']['E0'] = v_rest
            params['membrane']['gl_scale'] = g_leak
            h.dt = params['run']['dt']
            h.load_file('stdrun.hoc')
            cell = Cell(params)

            ap_counter = h.APCount(cell.soma(0.5))
            ap_counter.thresh = params['membrane']['action_potential_thres']

            # Bisection on [lower, upper], assuming firing is monotonic in g_syn.
            lower, upper = g_syn_range[0], g_syn_range[1]
            while True:
                mid_g_syn = (lower + upper) / 2.0
                cell.set_NetCon_weight(mid_g_syn)

                h.finitialize(params['run']['E0'] * mV)
                h.continuerun(sim_length * ms)

                if ap_counter.n > 0:
                    upper = mid_g_syn  # Decrease upper bound if neuron fires
                else:
                    lower = mid_g_syn  # Increase lower bound if neuron does not fire

                if abs(upper - lower) <= precision:
                    break

            # The final g_syn value is the midpoint of the final range
            search_results[(v_rest, g_leak)] = (lower + upper) / 2.0

        return search_results
