import numpy as np
import time
from .eoh_evolution import Evolution
import warnings
from joblib import Parallel, delayed
from .evaluator_accelerate import add_numba_decorator
import re
import concurrent.futures
from copy import deepcopy
from math import ceil, isnan

class InterfaceEC:
    """Bridge between the LLM-driven ``Evolution`` operators and the evaluator.

    Candidate heuristics ("offspring") and problem reductions are generated by
    ``self.evol``; heuristics are scored with ``interface_prob.evaluate`` under
    a per-candidate limit of ``timeout`` seconds, and batches of candidates are
    produced in parallel with joblib using ``n_p`` workers.
    """

    def __init__(self, pop_size, m, api_endpoint, api_key, llm_model,
                 llm_use_local, llm_local_url, debug_mode, interface_prob,
                 select, n_p, timeout, use_numba, multi_obj, **kwargs):

        # LLM settings
        self.pop_size = pop_size
        self.interface_eval = interface_prob  # e.g., ...exp.prob.FLocation
        prompts = interface_prob.prompts
        self.seed_algorithms = interface_prob.get_seed_algorithms()
        self.evol = Evolution(api_endpoint, api_key, llm_model, llm_use_local,
                              llm_local_url, debug_mode, prompts, **kwargs)
        self.m = m  # number of parents drawn for the e1/e2 operators
        self.debug = debug_mode

        if not self.debug:
            warnings.filterwarnings("ignore")

        self.select = select  # provides parent_selection(pop, k, multi_obj)
        self.n_p = n_p  # number of joblib workers
        self.multi_obj = multi_obj

        self.timeout = timeout  # seconds allowed per evaluation
        self.use_numba = use_numba

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _new_offspring(problem, operator):
        """Return an empty heuristic record for *problem* created by *operator*."""
        return {
            'problem': problem,
            'algorithm': None,
            'code': None,
            'objective': None,
            'runtime': None,
            'other_inf': None,
            'op': operator,
        }

    @staticmethod
    def _new_reduc_offspring(operator):
        """Return an empty reduction record created by *operator*."""
        return {
            'problem': None,  # description of problem B
            'code': None,
            'template': None,
            'objective': None,
            'runtime': None,
            'other_inf': None,
            'op': operator,
        }

    def _prepare_code(self, program):
        """Return the code to evaluate: *program*, numba-decorated if enabled.

        Raises:
            ValueError: numba is enabled but *program* has no function definition.
        """
        if not self.use_numba:
            return program
        # Decorate the first function definition found in the generated code.
        match = re.search(r"def\s+(\w+)\s*\(.*\):", program)
        if match is None:
            raise ValueError("no function definition found in generated code")
        return add_numba_decorator(program=program, function_name=match.group(1))

    def _evaluate_with_timeout(self, code, reduction_code):
        """Run ``interface_eval.evaluate(code, reduction_code)`` within ``self.timeout`` s.

        Returns:
            Whatever the evaluator returns (unpacked by callers as
            ``(fitness, runtime)``).

        Raises:
            concurrent.futures.TimeoutError: the evaluation did not finish in time.
            Any exception raised by the evaluator itself.
        """
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(self.interface_eval.evaluate, code, reduction_code)
            return future.result(timeout=self.timeout)
        finally:
            # A `with` block would call shutdown(wait=True) and block until the
            # evaluation finished, making the timeout ineffective. Threads
            # cannot be killed, so a timed-out evaluation keeps running in the
            # background; the caller is no longer blocked by it.
            executor.shutdown(wait=False)

    def _run_parallel(self, tasks):
        """Run joblib ``delayed`` *tasks* on ``self.n_p`` workers.

        Returns:
            The list of task results, or ``[]`` if the batch raised (e.g. the
            joblib-level timeout of ``self.timeout + 10`` seconds expired).
        """
        results = []
        try:
            results = Parallel(n_jobs=self.n_p, timeout=self.timeout + 10)(tasks)
        except Exception as e:
            if self.debug:
                print(f"Error: {e}")
            print("Parallel time out .")
        # NOTE(review): fixed 2 s pause after every batch; its purpose (e.g.
        # LLM rate limiting) is not evident from this file.
        time.sleep(2)
        return results

    # ------------------------------------------------------------------
    # Population utilities
    # ------------------------------------------------------------------

    def code2file(self, code):
        """Write *code* to ``./ael_alg.py`` (UTF-8), overwriting the file."""
        # Generated code may contain non-ASCII text; don't rely on the locale.
        with open("./ael_alg.py", "w", encoding="utf-8") as file:
            file.write(code)

    def add2pop(self, population, offspring):
        """Append *offspring* unless an individual with an equal objective exists.

        Returns:
            True if appended, False if rejected as a duplicate result.
        """
        if any(ind['objective'] == offspring['objective'] for ind in population):
            if self.debug:
                print("duplicated result, retrying ... ")
            return False
        population.append(offspring)
        return True

    def check_duplicate(self, population, code):
        """Return True if some individual in *population* has exactly this *code*."""
        return any(code == ind['code'] for ind in population)

    def population_generation(self, use_seed_algs, reduction, count=None):
        """Create a population with the 'i1' operator under *reduction*.

        Args:
            use_seed_algs: if True, generate one offspring per seed algorithm.
            reduction: reduction dict forwarded to ``get_algorithm``.
            count: number of offspring to create (defaults to ``pop_size``).

        Returns:
            List of offspring records, including failed ones (objective None).
        """
        _, offspring = self.get_algorithm([], 'i1', use_seed_algs, reduction, count)
        return list(offspring)

    def population_generation_seed(self, seeds, n_p):  # DEPRECATED, never used it before and not sure what it does
        """DEPRECATED: evaluate hand-written *seeds* and wrap them as records.

        NOTE(review): this calls ``interface_eval.evaluate(code)`` with one
        argument and unpacks (mac, tac) pairs, unlike the
        ``evaluate(code, reduction_code) -> (fitness, runtime)`` usage in the
        rest of this class; it is likely stale.
        """
        population = []

        mac_improvement, tac_improvement = zip(*Parallel(n_jobs=n_p)(
            delayed(self.interface_eval.evaluate)(seed['code']) for seed in seeds))

        for i, seed in enumerate(seeds):
            try:
                seed_alg = {
                    'algorithm': seed['algorithm'],
                    'code': seed['code'],
                    'objective': None,
                    'mac_improvement': None,
                    'tac_improvement': None,
                    'other_inf': None,
                    'op': None,
                }

                obj = np.array(tac_improvement[i])  # TODO: also incorporate MAC into obj
                seed_alg['objective'] = np.round(obj, 5)
                seed_alg['mac_improvement'] = np.round(np.array(mac_improvement[i]), 5)
                seed_alg['tac_improvement'] = np.round(np.array(tac_improvement[i]), 5)
                population.append(seed_alg)

            except Exception:
                print("Error in seed algorithm")
                exit()

        print("Initiliazation finished! Get " + str(len(seeds)) + " seed algorithms")

        return population

    # ------------------------------------------------------------------
    # Heuristic generation and evaluation
    # ------------------------------------------------------------------

    def _get_alg(self, pop, operator, ref_alg, reduction):
        """Ask the LLM for one new heuristic with *operator* (no evaluation).

        Returns:
            (parents, offspring): parents is None for 'i1' and for unknown
            operators; offspring has 'code'/'algorithm' filled when the
            operator is known.
        """
        problem = reduction['problem'] if reduction is not None else None
        offspring = self._new_offspring(problem, operator)
        parents = None  # stays None for 'i1' and for unrecognized operators

        if operator == "i1":
            offspring['code'], offspring['algorithm'] = self.evol.i1(ref_alg, reduction)
        elif operator == "e1":
            parents = self.select.parent_selection(pop, self.m, self.multi_obj)
            offspring['code'], offspring['algorithm'] = self.evol.e1(parents, reduction)
        elif operator == "e2":
            parents = self.select.parent_selection(pop, self.m, self.multi_obj)
            offspring['code'], offspring['algorithm'] = self.evol.e2(parents, reduction)
        elif operator == "m1":
            parents = self.select.parent_selection(pop, 1, self.multi_obj)
            offspring['code'], offspring['algorithm'] = self.evol.m1(parents[0], reduction)
        elif operator == "m2":
            parents = self.select.parent_selection(pop, 1, self.multi_obj)
            offspring['code'], offspring['algorithm'] = self.evol.m2(parents[0], reduction)
        elif operator in self.evol.OP.keys():  # DEPRECATED LLM-generated operators
            if operator.startswith('en'):  # exploration
                parents = self.select.parent_selection(pop, self.m)
                offspring['code'], offspring['algorithm'] = self.evol.auto_explore(parents, operator)
            elif operator.startswith('mn'):  # modification
                parents = self.select.parent_selection(pop, 1)
                offspring['code'], offspring['algorithm'] = self.evol.auto_modify(parents[0], operator)
            else:
                print(f"Unknown LLM-generated evolution operator [{operator}] ! \n")
        else:
            print(f"Evolution operator [{operator}] has not been implemented ! \n")

        return parents, offspring

    def get_offspring(self, pop, operator, ref_alg=None, reduction=None):
        """Generate one offspring with *operator* and evaluate it.

        If the generated code already exists in *pop*, generation is retried
        once; the second result is kept even if it is also a duplicate.

        Args:
            pop: current population (parents are drawn from it for e1/e2/m1/m2).
            operator: operator name, e.g. 'i1', 'e1', 'e2', 'm1', 'm2'.
            ref_alg: reference algorithm forwarded to ``evol.i1``.
            reduction: reduction dict; its 'problem' is copied into the
                offspring and its 'code' is passed to the evaluator.

        Returns:
            (parents, offspring) with 'objective'/'runtime' rounded to 5
            decimals. On any failure (generation error, missing function,
            evaluator error or timeout) returns ``(None, record)`` where the
            record has 'problem' None and only 'op' set.
        """
        reduction_code = reduction['code'] if reduction is not None else None

        try:
            p, offspring = self._get_alg(pop, operator, ref_alg, reduction)
            code = self._prepare_code(offspring['code'])

            if self.check_duplicate(pop, offspring['code']):
                if self.debug:
                    print("duplicated code, retrying ... ")
                p, offspring = self._get_alg(pop, operator, ref_alg, reduction)
                code = self._prepare_code(offspring['code'])

            fitness, runtime = self._evaluate_with_timeout(code, reduction_code)
            offspring['objective'] = np.round(fitness, 5)  # TODO: also incorporate MAC into obj
            offspring['runtime'] = np.round(runtime, 5)

        except Exception as e:
            if self.debug:
                print(f"Code Error: {e}")
            offspring = self._new_offspring(None, operator)
            p = None

        return p, offspring

    def get_algorithm(self, pop, operator, use_seed_algs=False, reduction=None, count=None):
        '''
        generate population for new generation
        Args:
            pop (list): current population, used for parent selection
            operator (str): evolution operator name passed to get_offspring
            use_seed_algs (bool): if True, create one offspring per seed
                algorithm (passed as ref_alg) instead of `count` offspring
            reduction (dict): a dictionary containing 'code', 'problem', 'template' fields
            count (int): number of offsprings to generate under 'reduction'; must be <= pop_size;
                defaults to pop_size
        Returns:
            (parents, offspring): position-aligned lists; both empty if the
            parallel batch failed.
        '''
        size = self.pop_size if count is None else count
        if use_seed_algs:
            tasks = (delayed(self.get_offspring)(pop, operator, alg, reduction)
                     for alg in self.seed_algorithms.values())
        else:
            tasks = (delayed(self.get_offspring)(pop, operator, reduction=reduction)
                     for _ in range(size))
        results = self._run_parallel(tasks)

        out_p, out_off = [], []
        for p, off in results:
            out_p.append(p)
            out_off.append(off)
            if self.debug:
                print(f">>> check offsprings: \n {off}")
        return out_p, out_off

    def evol_op(self, op_fitness, evol_type):
        '''
        evolve prompt strategies; returns (new operator name, its prompt)
        '''
        new_op = self.evol.evol_op(op_fitness, evol_type)
        new_prompt = self.evol.OP[new_op]
        print(f"\nNew operator created!\nName: {new_op}\nContent: {new_prompt}\n")
        return new_op, new_prompt

    def modify_offspring(self, indiv, reduction):
        """Re-evaluate a copy of *indiv* against *reduction*'s code.

        The input record is not mutated. Code is prepared exactly as in
        ``get_offspring`` (numba-decorated when enabled) so scores stay
        comparable.

        Returns:
            A deep copy of *indiv* with fresh 'objective'/'runtime' (rounded to
            5 decimals), or both set to None if preparation or evaluation failed.
        """
        reduction_code = reduction['code']
        offspring = deepcopy(indiv)

        try:
            code = self._prepare_code(indiv['code'])
            fitness, runtime = self._evaluate_with_timeout(code, reduction_code)
            offspring['objective'] = np.round(fitness, 5)  # TODO: also incorporate MAC into obj
            offspring['runtime'] = np.round(runtime, 5)

        except Exception as e:
            if self.debug:
                print(f"Code Error: {e}")
            offspring['objective'] = None
            offspring['runtime'] = None

        return offspring

    # ------------------------------------------------------------------
    # Reductions
    # ------------------------------------------------------------------

    def reduc_population_generation(self, reduc_size, pop_size, ref_prob=None):
        """Create *reduc_size* reductions with 'ri1' plus heuristics for each.

        The *pop_size* heuristic slots are split evenly across the reductions.
        Each new reduction is generated knowing the problems created so far
        (and *ref_prob*, if given).

        Returns:
            (reduction records, heuristic records).
        """
        reduc_population, population = [], []
        if reduc_size <= 0:
            return reduc_population, population

        reduc_shares = ratios_to_counts([pop_size / reduc_size] * reduc_size, pop_size)
        reduc_probs = [ref_prob] if ref_prob is not None else []
        for r in range(reduc_size):
            reduction = self.get_reduc_offspring([], 'ri1', reduc_probs)
            reduc_population.append(reduction)
            reduc_probs.append(reduction['problem'])
            population.extend(self.population_generation(False, reduction, count=reduc_shares[r]))
        return reduc_population, population

    def reduc_population_generation_new(self, reduc_size, init_reduc_size, reduc_top_size, pop_size, n_max_retry=2):
        """Generate candidate reductions via 'ri2'/'ri0' and keep the best *reduc_size*.

        Each candidate problem gets ``ceil(pop_size / reduc_size)`` heuristics
        and is scored by ``get_reduc_fitness`` over its top *reduc_top_size*
        heuristics; a candidate with no valid score is regenerated up to
        *n_max_retry* times.

        Returns:
            (top reductions sorted by objective descending, heuristics whose
            'problem' belongs to one of the kept reductions).
        """
        def has_objective(obj):
            return obj is not None and not isnan(obj)

        def attempt(problem):
            # Build a reduction for *problem*, seed heuristics under it, and score it.
            reduction = self.get_reduc_offspring([], 'ri0', [problem])
            population_r = self.population_generation(False, reduction, count=ceil(pop_size / reduc_size))
            reduction['objective'], reduction['runtime'] = get_reduc_fitness(population_r, reduc_top_size)
            return reduction, population_r

        reduc_population, population = [], []
        problems = self.evol.ri2(init_reduc_size, reduc_size)

        for problem in problems:
            reduction, population_r = attempt(problem)
            n_retry = 1
            while not has_objective(reduction['objective']):
                if n_retry > n_max_retry:
                    break
                reduction, population_r = attempt(problem)
                n_retry += 1
            reduc_population.append(reduction)
            population.extend(population_r)

        reduc_population = [reduc for reduc in reduc_population if has_objective(reduc['objective'])]
        # TODO: use stats (e.g., KDE) to automatically select top reductions without specifying reduc_size
        reduc_population = sorted(reduc_population, key=lambda x: x['objective'], reverse=True)[:reduc_size]
        selected_problems = [reduc['problem'] for reduc in reduc_population]
        population = [ind for ind in population if ind['problem'] in selected_problems]
        return reduc_population, population

    def get_reduc_offspring(self, pop, operator, ref_probs=None):
        """Ask the LLM for one reduction record with *operator*.

        Args:
            pop: existing reductions (used by re1/re2; rm1/rm2 use ``pop[0]``).
            operator: one of 'ri0', 'ri1', 're1', 're2', 'rm1', 'rm2'.
            ref_probs: reference problem descriptions ('ri0' uses the first,
                'ri1' the whole list); defaults to an empty list.

        Returns:
            Reduction record; for an unknown operator only 'op' is set.
        """
        ref_probs = [] if ref_probs is None else ref_probs
        offspring = self._new_reduc_offspring(operator)
        if operator == "ri0":
            offspring['code'], offspring['problem'], offspring['template'] = self.evol.ri0(ref_probs[0])
        elif operator == "ri1":
            offspring['code'], offspring['problem'], offspring['template'] = self.evol.ri1(ref_probs)
        elif operator == "re1":  # assuming the population size for reduction is small (~5 ish)
            offspring['code'], offspring['problem'], offspring['template'] = self.evol.re1(pop)
        elif operator == "re2":  # assuming the population size for reduction is small (~5 ish)
            offspring['code'], offspring['problem'], offspring['template'] = self.evol.re2(pop)
        elif operator == "rm1":
            offspring['code'], offspring['problem'], offspring['template'] = self.evol.rm1(pop[0])
        elif operator == "rm2":
            offspring['code'], offspring['problem'], offspring['template'] = self.evol.rm2(pop[0])
        else:
            print(f"Evolution operator [{operator}] has not been implemented ! \n")

        return offspring

    def update_reductions(self, pop, reduction) -> list:
        ''' Apply modified reduction on existing heuristics (subject to same reduction).

        Returns re-evaluated copies of the individuals in *pop* whose 'problem'
        equals reduction['problem'] ([] if the parallel batch failed).
        '''
        idx_with_reduc = [idx for idx, ind in enumerate(pop) if ind['problem'] == reduction['problem']]
        return list(self._run_parallel(
            delayed(self.modify_offspring)(pop[idx], reduction) for idx in idx_with_reduc))

    def refine_reduction(self, pop, reduction, pop_size, reduc_size):
        """Refine *reduction* with the 'rm1' operator and rebuild its heuristics.

        Individuals in *pop* whose 'problem' matches the refined problem are
        re-evaluated against the refined code; if there are none, a fresh
        sub-population of ``ceil(pop_size / reduc_size)`` heuristics is created.

        Returns:
            (refined reduction record, list of heuristic records).
        """
        reduc_offspring = self._new_reduc_offspring('rm1')
        reduc_offspring['code'], reduc_offspring['problem'], reduc_offspring['template'] = self.evol.rm1(reduction)

        idx_with_reduc = [idx for idx, ind in enumerate(pop) if ind['problem'] == reduc_offspring['problem']]
        if idx_with_reduc:
            out_off = list(self._run_parallel(
                delayed(self.modify_offspring)(pop[idx], reduc_offspring) for idx in idx_with_reduc))
        else:
            out_off = self.population_generation(False, reduc_offspring, count=ceil(pop_size / reduc_size))

        return reduc_offspring, out_off

def get_reduc_fitness(pop, top_size):
    """Score a reduction by the mean objective/runtime of its best heuristics.

    Individuals whose ``'objective'`` is None or NaN are ignored. The rest are
    ranked by objective (highest first) and the first ``top_size`` of them are
    averaged; if fewer than ``top_size`` valid individuals exist, all are used.

    Returns:
        ``(mean_objective, mean_runtime)`` rounded to 5 decimals, or
        ``(None, None)`` when there is no valid individual and ``top_size`` is
        positive.
    """
    scored = []
    for ind in pop:
        obj = ind['objective']
        if obj is None or isnan(obj):
            continue
        scored.append(ind)

    if top_size > len(scored):
        if not scored:
            return None, None
        top_size = len(scored)

    ranked = sorted(scored, key=lambda ind: ind['objective'], reverse=True)
    best = ranked[:top_size]
    mean_objective = np.round(np.mean([ind['objective'] for ind in best]), 5)
    mean_runtime = np.round(np.mean([ind['runtime'] for ind in best]), 5)
    return mean_objective, mean_runtime

def get_shares(population, reduc_population, pop_size):
    """Convert each reduction's current membership into integer slot counts.

    For every reduction in *reduc_population*, count the individuals of
    *population* whose ``'problem'`` equals the reduction's ``'problem'``; those
    counts are passed as ratios to ``ratios_to_counts`` with *pop_size* total.

    NOTE(review): ``ratios_to_counts`` returns its counts sorted in descending
    order, so the i-th share is not guaranteed to belong to the i-th reduction
    — confirm callers expect this.
    """
    member_counts = [
        sum(1 for ind in population if ind['problem'] == reduc['problem'])
        for reduc in reduc_population
    ]
    return ratios_to_counts(member_counts, pop_size)

def ratios_to_counts(ratios, total_sum):  # NOTE: naming functions sure is hard
    """
    Converts a list of ratios to non-negative integer counts that sum up to a total.

    Uses the largest-remainder method. Each exact share
    ``ratio * total_sum / sum(ratios)`` is floored. The units still missing are
    then given, one each, to the entries with the largest fractional remainders,
    with ties going to the earlier entry. When every ratio is zero, the total is
    split as evenly as possible.

    Args:
        ratios: A list of non-negative numerical ratios (not necessarily adding up to 1).
        total_sum: The desired (integer) sum of the counts.

    Returns:
        A list of integer counts sorted in descending order, or an empty list
        if *ratios* is empty or *total_sum* <= 0. Because of the sort, the i-th
        count does not necessarily correspond to the i-th ratio.

    Example usage:
    pop_size = 11
    num_reducs = 3
    ratios = [pop_size/num_reducs]*num_reducs
    print(ratios_to_counts(ratios, pop_size))  # Output: [4, 4, 3]
    """
    from math import floor

    if not ratios or total_sum <= 0:
        return []

    n = len(ratios)
    sum_ratios = sum(ratios)
    if sum_ratios == 0:
        # No weighting information: split evenly, handing out the leftover units first.
        base, extra = divmod(int(round(total_sum)), n)
        return [base + 1] * extra + [base] * (n - extra)

    exact = [ratio * total_sum / sum_ratios for ratio in ratios]
    counts = [floor(share) for share in exact]

    # For non-negative ratios, flooring never overshoots, so only increments are
    # needed and no count can become negative. round() absorbs float error.
    remaining = int(round(total_sum - sum(counts)))
    # sorted() is stable even with reverse=True, so ties favour earlier entries.
    by_remainder = sorted(range(n), key=lambda i: exact[i] - counts[i], reverse=True)
    for i in by_remainder[:remaining]:
        counts[i] += 1

    return sorted(counts, reverse=True)
