import numpy as np
import pandas as pd
import os
import json
import random
import time

from .eoh_interface_EC import InterfaceEC, get_reduc_fitness
# from .eoh_interface_EC_flatten import InterfaceEC, get_reduc_fitness  # for prototyping/debugging (no parallel processes)
# main class for eoh
class EOH:

    # initialization
    def __init__(self, paras, problem, select, manage, **kwargs):
        """Copy the experiment configuration from ``paras`` onto the instance.

        Args:
            paras: configuration object; every setting stored below is read
                from one of its attributes.
            problem: problem/evaluation interface, stored as ``self.prob``.
            select: strategy that selects which parents reproduce.
            manage: strategy that manages the population (e.g. trims it to the
                configured size).
            **kwargs: accepted but currently ignored.

        Side effects:
            * ``paras.ec_m`` is overwritten with 2 when it is larger than the
              population size or smaller than 2.
            * Python's ``random`` module is seeded with 2024.
            * A status line is printed.
        """
        self.prob = problem
        self.select = select  # select which parents to reproduce
        self.manage = manage  # manage the population (e.g., if larger than specified size, remove individuals with lowest fitness)

        # LLM settings
        self.use_local_llm = paras.llm_use_local
        self.llm_local_url = paras.llm_local_url
        self.api_endpoint = paras.llm_api_endpoint  # currently only API2D + GPT
        self.api_key = paras.llm_api_key
        self.llm_model = paras.llm_model

        # Experimental settings
        self.pop_size = paras.ec_pop_size  # population size, i.e., the number of algorithms in population
        self.n_pop = paras.ec_n_pop  # number of populations (upper bound of the generation loop in `run`)

        self.operators = paras.ec_operators
        self.operator_weights = paras.ec_operator_weights
        # Reject every m below 2 (not only m == 1), as the message promises.
        if paras.ec_m > self.pop_size or paras.ec_m < 2:
            print("m should not be larger than pop size or smaller than 2, adjust it to m=2")
            paras.ec_m = 2
        self.m = paras.ec_m

        # Multi-objective / reduction settings
        self.mobj = paras.ec_mobj
        self.reduc = paras.ec_reduc
        self.reduc_seed_prob = paras.reduc_seed_prob
        self.init_reduc_size = paras.ec_init_reduc_size
        self.reduc_size = paras.ec_reduc_size
        self.reduc_top_size = paras.ec_reduc_top_size
        self.reduc_evol = paras.ec_reduc_evol
        self.patience = paras.patience
        self.evored_trial_id = paras.ec_evored_trial_id

        self.debug_mode = paras.exp_debug_mode  # if debug
        self.ndelay = 1  # default

        # Seeding / continuation settings
        self.use_seed = paras.exp_use_seed
        self.seed_path = paras.exp_seed_path
        self.load_pop = paras.exp_use_continue
        self.load_pop_path = paras.exp_continue_path
        self.load_pop_id = paras.exp_continue_id

        self.use_seed_algs = paras.exp_use_seed_algs

        self.output_path = paras.exp_output_path

        self.exp_n_proc = paras.exp_n_proc

        self.timeout = paras.eva_timeout

        self.use_numba = paras.eva_numba_decorator

        print("- EoH parameters loaded -")

        # Set a random seed
        # NOTE(review): this seeds only Python's `random`; `run` draws operator
        # probabilities from `np.random`, which is not seeded here — confirm
        # whether that is intended.
        random.seed(2024)

    def add2pop(self, population: list, offspring: list) -> None:
        """Merge ``offspring`` into ``population`` in place.

        For each offspring, existing individuals with an equal ``'objective'``
        are treated as duplicates (only when both ``'code'`` values are not
        ``None``):

        * a duplicate whose code is longer than the offspring's is removed and
          the offspring takes its place (the offspring is appended only once,
          however many longer duplicates it replaces);
        * a duplicate whose code is not longer causes a "not add" message.

        An offspring with no duplicates is appended. An offspring that only met
        duplicates with code that is not longer is not added.

        Args:
            population: list of individual dicts with ``'objective'`` and
                ``'code'`` keys; modified in place.
            offspring: list of candidate individual dicts of the same form.
        """
        for off in offspring:
            replaced = False  # off already took the place of a longer duplicate
            rejected = False  # off lost against a duplicate that is not longer

            # Iterate over a snapshot: `population` is mutated inside the loop,
            # and mutating the list being iterated skips elements and lets the
            # loop revisit the offspring it has just appended.
            for ind in list(population):
                is_duplicate = (
                    ind['objective'] == off['objective']
                    and off['code'] is not None
                    and ind['code'] is not None
                )
                if not is_duplicate:
                    continue

                if len(off['code']) < len(ind['code']):
                    population.remove(ind)
                    if not replaced:
                        population.append(off)
                        replaced = True
                    print(off['code'])
                    print(ind['code'])
                else:
                    print("not add")
                    print(f"old fitness: {ind['objective']}; new fitness: {off['objective']}")
                    rejected = True

            if not (replaced or rejected):
                population.append(off)
    

    # run eoh 
    def run(self):
        """Run the evolutionary loop and write every generation to disk.

        Steps, in order:

        1. Build the ``InterfaceEC`` operator interface from the stored
           settings.
        2. Obtain the initial population: loaded from ``self.load_pop_path``
           when ``self.load_pop`` is set, otherwise generated through the
           interface, passed through ``population_management`` and saved as
           generation 0.
        3. For every generation in ``range(n_start, self.n_pop)``: for each
           operator and each reduction, draw ``np.random.rand()`` and, when it
           is below the operator weight, request offspring and merge them with
           ``add2pop``; then run ``population_management`` once per operator.
        4. When ``self.reduc_evol`` is set, refine every reduction whose
           fitness is invalid or has not improved for ``self.patience``
           generations, then drop individuals with invalid objectives and
           re-sort the population.
        5. Dump the population, its first individual, the reduction population
           and its first element as JSON under
           ``<output_path>/results/pops<tag>/`` and ``.../pops<tag>_best/``.

        Raises:
            NotImplementedError: if ``self.use_seed`` is set; that start-up
                mode has no implementation.
        """

        def is_valid(value):
            # A fitness/objective value is usable when it is neither None nor NaN.
            return value is not None and not np.isnan(value)

        def dump_json(path, payload):
            # Create the target directory on demand so the first write cannot
            # fail on a missing results folder.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, 'w') as f:
                json.dump(payload, f, indent=5)

        print("- Evolution Start -")

        time_start = time.time()

        # interface for evaluation
        interface_prob = self.prob

        # interface for ec operators
        interface_ec = InterfaceEC(self.pop_size, self.m, self.api_endpoint, self.api_key, self.llm_model, self.use_local_llm, self.llm_local_url,
                                   self.debug_mode, interface_prob, select=self.select, n_p=self.exp_n_proc,
                                   timeout=self.timeout, use_numba=self.use_numba, multi_obj=self.mobj
                                   )

        # separate directories when training on external data
        data_tag = '_mcts' if interface_prob.dirname == 'dataset_mcts' else ''

        # separate directories when evolving reductions
        evored = f'_evored_{self.evored_trial_id}' if self.reduc_evol else ''

        pops_dir = self.output_path + f"/results/pops{data_tag}{evored}"
        best_dir = self.output_path + f"/results/pops{data_tag}{evored}_best"

        # initialization
        population = []
        if self.use_seed:
            # DEPRECATED branch: it never produced a population, so fail with
            # an explicit error instead of an UnboundLocalError further down.
            raise NotImplementedError(
                "exp_use_seed is deprecated and not implemented; "
                "use exp_use_continue or create a new population instead."
            )

        if self.load_pop:  # load population from files
            print("load initial population from " + self.load_pop_path)
            with open(self.load_pop_path) as file:
                data = json.load(file)
            population.extend(data)
            print("initial population has been loaded!")
            if self.reduc:
                pop_dir, pop_file = os.path.split(self.load_pop_path)
                # os.path.join keeps the path relative when pop_dir is empty
                # (plain string concatenation would yield "/reduc_<file>").
                with open(os.path.join(pop_dir, f'reduc_{pop_file}')) as file:
                    reduc_population = json.load(file)
            else:
                reduc_population = [None]
                reduc_shares = [None]
            n_start = self.load_pop_id
        else:  # create new population
            print("creating initial population:")
            if self.reduc:
                reduc_population, population = interface_ec.reduc_population_generation_new(self.reduc_size, self.init_reduc_size, self.reduc_top_size, self.pop_size)
                dump_json(f"{pops_dir}/reduc_population_generation_0.json", reduc_population)
            else:
                population = interface_ec.population_generation(self.use_seed_algs, None)
                reduc_population = [None]
                reduc_shares = [None]

            population = self.manage.population_management(population, self.pop_size, multi_obj=self.mobj)

            print("Pop initial: ")
            for off in population:
                print(" Obj: ", off['objective'], end="|")
            print()
            print("initial population has been created!")
            # Save population to a file
            dump_json(f"{pops_dir}/population_generation_0.json", population)
            n_start = 0

        # main loop for evolutionary algorithm
        n_op = len(self.operators)
        # NOTE(review): n_reduc and the per-index records below are fixed here,
        # while population_management may return a new reduc_population inside
        # the loop — confirm that its length and order stay stable.
        n_reduc = len(reduc_population)
        if self.reduc_evol:
            num_gen_without_imp = [0] * n_reduc
            reduc_fitness_records = [reduc['objective'] for reduc in reduc_population]
            reduc_runtime_records = [reduc['runtime'] for reduc in reduc_population]

        for pop in range(n_start, self.n_pop):
            if self.reduc:  # sample 'pop_size' reductions from reduc_population based on fitness
                reduc_selections = self.select.reduc_selection(reduc_population, self.pop_size)
                reduc_shares = [reduc_selections.count(reduc) for reduc in reduc_population]

            for i, op in enumerate(self.operators):  # iterate over all allowed operators
                print(f" OP: {op}, [{i + 1} / {n_op}] ", end="|")
                op_w = self.operator_weights[i]
                for r in range(n_reduc):
                    reduc = reduc_population[r]
                    # TODO: reduce op_w of e1 early on (i.e., relax exploration) when reduction is still primitive
                    if np.random.rand() >= op_w:
                        # Operator not drawn for this reduction: nothing new to
                        # merge (previously stale or undefined offspring were
                        # merged here).
                        continue
                    _, offsprings = interface_ec.get_algorithm(population, op, reduction=reduc, count=reduc_shares[r])
                    self.add2pop(population, offsprings)  # Check duplication, and add the new offspring
                    for off in offsprings:  # NOTE: indivs are not sorted (yet)
                        print(" Obj: ", off['objective'], end="|")

                # population management
                size_act = min(len(population), self.pop_size)
                if self.reduc:
                    population, reduc_population = self.manage.population_management(population, size_act, [reduc_population, self.reduc_top_size], multi_obj=self.mobj)
                else:
                    population = self.manage.population_management(population, size_act, multi_obj=self.mobj)
                print()

            # refine reductions if exceeding patience
            if self.reduc_evol:
                for r in range(n_reduc):
                    reduc = reduc_population[r]
                    reduc_curr_fitness, reduc_curr_runtime = reduc['objective'], reduc['runtime']
                    if not is_valid(reduc_curr_fitness):
                        # An unusable fitness forces an immediate refinement.
                        num_gen_without_imp[r] = self.patience
                    else:
                        # An invalid record (None/NaN) can always be beaten;
                        # checking it first also avoids comparing with None.
                        stalled = is_valid(reduc_fitness_records[r]) and reduc_curr_fitness <= reduc_fitness_records[r]
                        if self.mobj:
                            stalled = stalled and reduc_curr_runtime >= reduc_runtime_records[r]
                        if stalled:
                            num_gen_without_imp[r] += 1
                        else:
                            reduc_fitness_records[r] = reduc_curr_fitness
                            reduc_runtime_records[r] = reduc_curr_runtime
                            num_gen_without_imp[r] = 0

                    if num_gen_without_imp[r] >= self.patience:
                        # modify reduction
                        reduc_offspring, offsprings = interface_ec.refine_reduction(population, reduc, self.pop_size, self.reduc_size)
                        self.add2pop(population, offsprings)  # Check duplication, and add the new offspring
                        for off in offsprings:  # NOTE: indivs are not sorted (yet)
                            print(" Obj: ", off['objective'], end="|")
                        reduc_fitness, reduc_runtime = get_reduc_fitness(offsprings, self.reduc_top_size)
                        if is_valid(reduc_fitness):
                            reduc_offspring['objective'] = reduc_fitness
                            reduc_offspring['runtime'] = reduc_runtime
                            if not is_valid(reduc_curr_fitness):
                                # Any valid fitness beats an invalid one; the
                                # direct comparison would raise on None and is
                                # always False on NaN.
                                improved = True
                            elif self.mobj:
                                improved = reduc_fitness > reduc_curr_fitness or reduc_runtime < reduc_curr_runtime
                            else:
                                improved = reduc_fitness > reduc_curr_fitness
                            if improved:
                                reduc_population[r] = reduc_offspring
                                num_gen_without_imp[r] = 0  # NOTE: only reset when there is improvement

                # NOTE: keep the new offsprings as long as they have finite fitness, so json file may have more than 'pop_size' heuristics whenever evored occurs
                population = [individual for individual in population if is_valid(individual['objective'])]
                if self.mobj:
                    population = self.manage.sortby_dominance_dissimilarity(population)
                else:
                    population = sorted(population, key=lambda x: x['objective'], reverse=True)

            generation = pop + 1

            # Save population to a file
            dump_json(f"{pops_dir}/population_generation_{generation}.json", population)

            # Save the best one to a file
            dump_json(f"{best_dir}/population_generation_{generation}.json", population[0])

            # Save reduction population to a file
            dump_json(f"{pops_dir}/reduc_population_generation_{generation}.json", reduc_population)

            # Save the best reduction to a file
            dump_json(f"{best_dir}/reduc_population_generation_{generation}.json", reduc_population[0])

            print(f"--- {pop + 1} of {self.n_pop} populations finished. Time Cost:  {((time.time()-time_start)/60):.1f} m")
            print("Pop Objs: ", end=" ")
            for individual in population:
                print(str(individual['objective']) + " ", end="")
            print()


def same_top_individuals(top_individuals: list, curr_top_individuals: list) -> bool:
    """Return whether every current top individual already appears in ``top_individuals``.

    Two individuals are considered the same when their ``'code'`` entries
    compare equal.

    Args:
        top_individuals: reference list of individual dicts with a ``'code'`` key.
        curr_top_individuals: list of individual dicts to look up in the reference.

    Returns:
        ``True`` when each element of ``curr_top_individuals`` has a matching
        code in ``top_individuals`` (vacuously ``True`` for an empty list),
        otherwise ``False``. The result is a builtin ``bool``, matching the
        annotation, rather than a ``numpy.bool_``.
    """
    return all(
        any(ind['code'] == top_ind['code'] for top_ind in top_individuals)
        for ind in curr_top_individuals
    )
