import random

import numpy as np
import time

from .evolution import Evolution
import warnings
from joblib import Parallel, delayed
import re
import concurrent.futures
from copy import deepcopy
from math import ceil, isnan


class InterfaceEC():
    """Glue between the LLM evolution operators (``self.evol``) and the evaluator.

    Generates candidate heuristics ("offspring") and problem reductions through an
    ``Evolution`` instance, scores them with ``interface_prob.evaluate`` (bounded by
    ``timeout`` seconds), and builds/refines the populations made from them.
    """

    # Heuristic-generation operators handled by ``_get_alg``.
    _OPERATORS = ("i1", "e1", "e2", "m1", "m2", "s1")

    def __init__(self, pop_size, m, api_endpoint, api_key, llm_model, llm_use_local, llm_local_url, debug_mode, interface_prob, select, n_p, timeout, use_numba, multi_obj,
                 **kwargs):

        # LLM settings
        self.pop_size = pop_size
        self.interface_eval = interface_prob
        prompts = interface_prob.prompts
        self.seed_algorithms = interface_prob.get_seed_algorithms()
        self.evol = Evolution(api_endpoint, api_key, llm_model, llm_use_local, llm_local_url, debug_mode, prompts, **kwargs)
        self.m = m  # upper bound on the number of parents sampled by operator "e1"
        self.debug = debug_mode

        if not self.debug:
            warnings.filterwarnings("ignore")

        self.select = select  # parent-selection strategy (parent_selection / parent_selection_e1)
        self.n_p = n_p  # n_jobs for joblib.Parallel re-evaluation
        self.multi_obj = multi_obj

        self.timeout = timeout  # seconds allowed for one evaluation
        self.use_numba = use_numba

    def code2file(self, code):
        """Write ``code`` to ``./ael_alg.py`` in the current working directory."""
        with open("./ael_alg.py", "w", encoding="utf-8") as file:
            file.write(code)

    def add2pop(self, population, offspring):
        """Append ``offspring`` to ``population`` unless its objective is already present.

        Returns:
            True if the offspring was appended, False if it was a duplicate.
        """
        if self.check_duplicate_obj(population, offspring['objective']):
            if self.debug:
                print("duplicated result, retrying ... ")
            return False
        population.append(offspring)
        return True

    def check_duplicate_obj(self, population, obj):
        """Return True if any individual in ``population`` has objective equal to ``obj``."""
        return any(obj == ind['objective'] for ind in population)

    def check_duplicate(self, population, code):
        """Return True if any individual in ``population`` has exactly this ``code``."""
        return any(code == ind['code'] for ind in population)

    def population_generation_seed(self, seeds):
        """Evaluate the seed algorithms in one batch and wrap them as individuals.

        Args:
            seeds: list of dicts with at least 'algorithm' and 'code' keys.

        Returns:
            A list of individuals whose 'objective' is the batch fitness rounded to 5 decimals.

        Raises:
            RuntimeError: if any seed cannot be turned into an individual (for example,
                a missing key or a missing fitness entry).
        """
        fitness = self.interface_eval.batch_evaluate([seed['code'] for seed in seeds])

        population = []
        for i, seed in enumerate(seeds):
            try:
                population.append({
                    'algorithm': seed['algorithm'],
                    'code': seed['code'],
                    'objective': np.round(np.array(fitness[i]), 5),
                    'other_inf': None,
                })
            except Exception as e:
                print("Error in seed algorithm")
                # Previously exit() (status 0, relies on the `site` builtin); raise instead
                # so the failure is visible and carries its cause.
                raise RuntimeError(f"failed to build seed algorithm #{i}") from e

        print("Initiliazation finished! Get " + str(len(seeds)) + " seed algorithms")

        return population

    @staticmethod
    def _new_offspring(operator, reduction):
        """Return an empty heuristic record for ``operator`` under ``reduction`` (may be None)."""
        return {
            'problem': reduction['problem'] if reduction is not None else None,
            'algorithm': None,
            'thought': None,
            'code': None,
            'objective': None,
            'runtime': None,
            'other_inf': None,
            'op': operator,
        }

    def _get_alg(self, pop, operator, father=None, ref_alg=None, reduction=None):
        """Ask the LLM for one new heuristic using ``operator``.

        Returns:
            ``(parents, offspring)``. ``offspring`` has 'code', 'thought' and 'algorithm'
            filled but is not yet evaluated.

        Raises:
            ValueError: for an operator not in ``_OPERATORS``.
        """
        offspring = self._new_offspring(operator, reduction)
        if operator == "i1":
            parents = None
            offspring['code'], offspring['thought'] = self.evol.i1(ref_alg, reduction)
        elif operator == "e1":
            real_m = min(random.randint(2, self.m), len(pop))
            parents = self.select.parent_selection_e1(pop, real_m)
            offspring['code'], offspring['thought'] = self.evol.e1(parents, reduction)
        elif operator == "e2":
            # Select one partner from the population excluding `father`, then pair them.
            other = deepcopy(pop)
            if father in pop:
                other.remove(father)
            parents = self.select.parent_selection(other, 1)
            parents.append(father)
            offspring['code'], offspring['thought'] = self.evol.e2(parents, reduction)
        elif operator == "m1":
            parents = [father]
            offspring['code'], offspring['thought'] = self.evol.m1(father, reduction)
        elif operator == "m2":
            parents = [father]
            offspring['code'], offspring['thought'] = self.evol.m2(father, reduction)
        elif operator == "s1":
            parents = pop
            offspring['code'], offspring['thought'] = self.evol.s1(pop, reduction)
        else:
            raise ValueError(f"Evolution operator [{operator}] has not been implemented !")

        offspring['algorithm'] = self.evol.post_thought(offspring['code'], offspring['thought'])
        return parents, offspring

    def _evaluate_with_timeout(self, code, reduction_code):
        """Run ``interface_eval.evaluate(code, reduction_code)`` bounded by ``self.timeout``.

        Returns whatever the evaluator returns; callers unpack it as ``(fitness, runtime)``.

        Raises:
            concurrent.futures.TimeoutError: if the evaluation exceeds ``self.timeout``.
            Any exception raised by the evaluator itself.
        """
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(self.interface_eval.evaluate, code, reduction_code)
            return future.result(timeout=self.timeout)
        finally:
            # A `with` block would call shutdown(wait=True) and block until a timed-out
            # evaluation finished, defeating the timeout. Threads cannot be killed, so
            # a timed-out evaluation is abandoned and keeps running in the background.
            executor.shutdown(wait=False)

    def get_offspring(self, pop, operator, father=None, ref_alg=None, reduction=None):
        """Generate and evaluate one heuristic.

        If the generated code duplicates one already in ``pop``, generation is retried
        once. Failures (LLM error, evaluation error or timeout) are printed, and the
        record is returned with 'objective' left as None.

        Returns:
            ``(parents, offspring)``. ``parents`` is None if generation itself failed.

        Raises:
            ValueError: for an operator not in ``_OPERATORS``. This is raised before any
                work is done and is never swallowed, so callers do not retry it forever.
        """
        if operator not in self._OPERATORS:
            raise ValueError(f"Evolution operator [{operator}] has not been implemented !")
        reduction_code = reduction['code'] if reduction is not None else None

        # Defaults so a failing LLM call yields an unevaluated record rather than an
        # UnboundLocalError at the return below.
        p, offspring = None, self._new_offspring(operator, reduction)
        try:
            p, offspring = self._get_alg(pop, operator, father=father, ref_alg=ref_alg, reduction=reduction)
            if self.check_duplicate(pop, offspring['code']):
                if self.debug:
                    print("duplicated code, wait 1 second and retrying ... ")
                p, offspring = self._get_alg(pop, operator, father=father, ref_alg=ref_alg, reduction=reduction)

            fitness, runtime = self._evaluate_with_timeout(offspring['code'], reduction_code)
            offspring['objective'] = np.round(fitness, 5)  # TODO: also incorporate MAC into obj
            offspring['runtime'] = np.round(runtime, 5)
        except Exception as e:
            print(f"get_offspring failed ({type(e).__name__}): {e}")
        return p, offspring

    def get_algorithm(self, eval_times, pop, operator, use_seed_algs=False, reduction=None, max_tries=None):
        """Generate offspring until one has a valid, not-yet-seen objective.

        Args:
            eval_times: running evaluation counter; incremented once per attempt.
            max_tries: maximum number of attempts; None (the default) retries forever.

        Returns:
            ``(eval_times, pop, offspring)`` on success, or ``(eval_times, None, None)``
            once ``max_tries`` attempts are exhausted.

        Raises:
            NotImplementedError: if ``use_seed_algs`` is True (not implemented).
        """
        if use_seed_algs:
            raise NotImplementedError("get_algorithm(use_seed_algs=True) is not implemented")
        n_tries = 0
        while max_tries is None or n_tries < max_tries:
            n_tries += 1
            eval_times += 1
            _, offspring = self.get_offspring(pop, operator, reduction=reduction)
            obj = offspring['objective']
            if obj is None or isnan(obj) or self.check_duplicate_obj(pop, np.round(obj, 5)):
                continue
            return eval_times, pop, offspring
        return eval_times, None, None

    def evolve_algorithm(self, eval_times, pop, node, brother_node, operator, reduction=None, max_tries=3):
        """Derive a new heuristic from ``node`` with ``operator``, trying up to ``max_tries`` times.

        An attempt is rejected if its objective is None/NaN or its code already exists
        in ``pop``. ``brother_node`` is accepted for interface compatibility but is unused.

        Returns:
            ``(eval_times, offspring)`` on success, else ``(eval_times, None)``.
        """
        for _ in range(max_tries):
            eval_times += 1
            _, offspring = self.get_offspring(pop, operator, father=node, reduction=reduction)
            obj = offspring['objective']
            if obj is None or isnan(obj) or self.check_duplicate(pop, offspring['code']):
                continue
            offspring['objective'] = np.round(obj, 5)
            return eval_times, offspring
        return eval_times, None

    def modify_offspring(self, indiv, reduction):
        """Re-evaluate a copy of ``indiv`` under ``reduction['code']``.

        ``indiv`` itself is not modified. On any evaluation error or timeout, the copy's
        'objective' and 'runtime' are set to None.
        """
        code = indiv['code']
        reduction_code = reduction['code']
        offspring = deepcopy(indiv)

        try:
            fitness, runtime = self._evaluate_with_timeout(code, reduction_code)
            offspring['objective'] = np.round(fitness, 5)  # TODO: also incorporate MAC into obj
            offspring['runtime'] = np.round(runtime, 5)
        except Exception as e:
            if self.debug:
                print(f"Code Error: {e}")
            offspring['objective'] = None
            offspring['runtime'] = None

        return offspring

    @staticmethod
    def _is_valid_objective(obj):
        """Return True if ``obj`` is neither None nor NaN."""
        return obj is not None and not isnan(obj)

    def _new_reduction(self, problem, pop_size, reduc_size, reduc_top_size):
        """Create a reduction for ``problem`` via 'ri0', populate it with heuristics and score it.

        NOTE(review): ``population_generation`` is not defined in this class as shown;
        presumably it is provided elsewhere — verify.
        """
        reduction = self.get_reduc_offspring([], 'ri0', [problem])
        population_r = self.population_generation(False, reduction, count=ceil(pop_size / reduc_size))
        reduction['objective'], reduction['runtime'] = get_reduc_fitness(population_r, reduc_top_size)
        return reduction, population_r

    def reduc_population_generation_new(self, reduc_size, init_reduc_size, reduc_top_size, pop_size, n_max_retry=2):  # TODO
        """Build the initial reduction population and the heuristics that go with it.

        For each problem proposed by ``self.evol.ri2``, a reduction is created and scored.
        Creation is retried up to ``n_max_retry`` extra times while the score is None/NaN.
        Only the ``reduc_size`` best-scoring valid reductions are kept, along with the
        heuristics whose 'problem' matches a kept reduction.

        Returns:
            ``(reduc_population, population)``.
        """
        reduc_population, population = [], []
        problems = self.evol.ri2(init_reduc_size, reduc_size)

        for problem in problems:
            # One initial attempt plus up to n_max_retry retries; only the last
            # attempt's heuristics are kept.
            for _ in range(max(n_max_retry, 0) + 1):
                reduction, population_r = self._new_reduction(problem, pop_size, reduc_size, reduc_top_size)
                if self._is_valid_objective(reduction['objective']):
                    break
            reduc_population.append(reduction)
            population.extend(population_r)

        reduc_population = [reduc for reduc in reduc_population if self._is_valid_objective(reduc['objective'])]
        # TODO: use stats (e.g., KDE) to automatically select top reductions without specifying reduc_size
        reduc_population = sorted(reduc_population, key=lambda x: x['objective'], reverse=True)[:reduc_size]
        selected_problems = [reduc['problem'] for reduc in reduc_population]
        population = [ind for ind in population if ind['problem'] in selected_problems]
        return reduc_population, population

    def get_reduc_offspring(self, pop, operator, ref_probs=None):
        """Ask the LLM for one reduction (code, problem description, template) via ``operator``.

        Args:
            pop: reduction population. "re1"/"re2" use all of it; "rm1"/"rm2" use ``pop[0]``.
            ref_probs: reference problems. "ri0" uses ``ref_probs[0]``; "ri1" uses all of them.

        An unknown operator is reported and an empty record is returned.
        """
        ref_probs = [] if ref_probs is None else ref_probs
        offspring = {
            'problem': None,  # description of problem B
            'code': None,
            'template': None,
            'objective': None,
            'runtime': None,
            'other_inf': None,
            'op': operator,
        }
        if operator == "ri0":
            result = self.evol.ri0(ref_probs[0])
        elif operator == "ri1":
            result = self.evol.ri1(ref_probs)
        elif operator == "re1":  # assuming the population size for reduction is small (~5 ish)
            result = self.evol.re1(pop)
        elif operator == "re2":  # assuming the population size for reduction is small (~5 ish)
            result = self.evol.re2(pop)
        elif operator == "rm1":
            result = self.evol.rm1(pop[0])
        elif operator == "rm2":
            result = self.evol.rm2(pop[0])
        else:
            print(f"Evolution operator [{operator}] has not been implemented ! \n")
            return offspring

        offspring['code'], offspring['problem'], offspring['template'] = result
        return offspring

    def update_reductions(self, pop, reduction) -> list:
        ''' Apply modified reduction on existing heuristics (subject to same reduction).

        Every individual in ``pop`` whose 'problem' equals ``reduction['problem']`` is
        re-evaluated in parallel (joblib, ``self.n_p`` jobs). If the parallel run fails
        or times out, an empty list is returned.
        '''
        idx_with_reduc = [idx for idx, ind in enumerate(pop) if ind['problem'] == reduction['problem']]
        results = []
        try:
            results = Parallel(n_jobs=self.n_p, timeout=self.timeout + 10)(
                delayed(self.modify_offspring)(pop[idx], reduction) for idx in idx_with_reduc)
        except Exception as e:
            if self.debug:
                print(f"Error: {e}")
            print("Parallel time out .")

        # NOTE(review): the purpose of this fixed pause is not evident from this file.
        time.sleep(2)

        return list(results)

    def refine_reduction(self, pop, reduction, pop_size, reduc_size):
        """Mutate ``reduction`` with operator "rm1" and obtain heuristics for the result.

        If some individuals in ``pop`` already belong to the new reduction's problem, they
        are re-evaluated under it. Otherwise, a fresh batch of heuristics is generated.

        Returns:
            ``(reduc_offspring, heuristics)``.
        """
        reduc_offspring = self.get_reduc_offspring([reduction], 'rm1')

        if any(ind['problem'] == reduc_offspring['problem'] for ind in pop):
            out_off = self.update_reductions(pop, reduc_offspring)
        else:
            out_off = self.population_generation(False, reduc_offspring, count=ceil(pop_size / reduc_size))

        return reduc_offspring, out_off

def get_reduc_fitness(pop, top_size):
    """Score a reduction by the mean quality of its best heuristics.

    Individuals whose 'objective' is None or NaN are ignored. Of the rest, the
    ``top_size`` highest objectives are kept (all of them if fewer remain).

    Returns:
        ``(mean_objective, mean_runtime)`` over the kept individuals, each rounded
        to 5 decimals, or ``(None, None)`` when no valid individual remains and
        ``top_size`` is positive.
    """
    scored = []
    for ind in pop:
        objective = ind['objective']
        if objective is None or isnan(objective):
            continue
        scored.append(ind)

    if not scored and 0 < top_size:
        return None, None
    top_size = min(top_size, len(scored))

    ranked = sorted(scored, key=lambda ind: ind['objective'], reverse=True)
    best = ranked[:top_size]
    mean_objective = np.mean([ind['objective'] for ind in best])
    rounded_objective = np.round(mean_objective, 5)
    mean_runtime = np.mean([ind['runtime'] for ind in best])
    return rounded_objective, np.round(mean_runtime, 5)
