
import collections
import numpy as np

from deephyper.search import util
from deephyper.core.logs.logging import JsonMessage as jm

from xnas.nas_deephyper.xnas import XNAS

# Module-level logger obtained from deephyper's ``util.conf_logger`` under
# this module's dotted name; used below to emit JSON-formatted stats messages.
dhlogger = util.conf_logger('xnas.nas_deephyper.regevo_xnas')


class RegularizedEvolutionXNAS(XNAS):
    """Regularized (aging) evolution search over architecture sequences.

    Finished evaluations are kept in a bounded deque, so appending new
    results discards the oldest ones once the population is full.  While
    the population is still filling up, every finished evaluation is
    replaced by a random architecture; afterwards each one is replaced by a
    child obtained by sampling ``sample_size`` population members, selecting
    the best of them and mutating one position of its ``arch_seq``.

    Args:
        problem: Forwarded to ``XNAS``; ``self.problem`` must expose
            ``build_search_space()``.
        run: Forwarded to ``XNAS``.
        evaluator: Forwarded to ``XNAS``.
        population_size: Maximum number of results kept in the population.
        sample_size: Number of population members drawn for each parent
            selection.
        **kwargs: Forwarded unchanged to ``XNAS``.
    """

    def __init__(
            self, problem, run, evaluator, population_size=10, sample_size=5,
            **kwargs
    ):
        super().__init__(problem=problem, run=run, evaluator=evaluator,
                         **kwargs)

        # One inclusive (low, high) range of operation indexes per variable
        # node of the search space.
        search_space = self.problem.build_search_space()
        self.space_list = [
            (0, vnode.num_ops - 1) for vnode in search_space.variable_nodes
        ]
        self.population_size = int(population_size)
        self.sample_size = int(sample_size)

    def main(self):
        """Run the search until ``self.max_evals`` evaluations have finished."""
        num_evals_done = 0
        # Bounded deque: extending a full population drops its oldest items.
        population = collections.deque(maxlen=self.population_size)

        # Seed the evaluator with one random architecture per free worker.
        self.evaluator.add_eval_batch(
            self.gen_random_batch(size=self.free_workers))

        while num_evals_done < self.max_evals:
            new_results = list(self.evaluator.get_finished_evals())
            if not new_results:
                continue

            population.extend(new_results)
            stats = {
                'num_cache_used': self.evaluator.stats['num_cache_used']}
            dhlogger.info(jm(type='env_stats', **stats))
            self.evaluator.dump_evals(saved_keys=self.saved_keys)

            num_received = len(new_results)
            num_evals_done += num_received
            if num_evals_done >= self.max_evals:
                break

            # Submit exactly as many new evaluations as have just finished.
            if len(population) == self.population_size:
                next_batch = self._breed_children(population, num_received)
            else:
                # Population not full yet: keep exploring randomly.
                next_batch = self.gen_random_batch(size=num_received)
            self.evaluator.add_eval_batch(next_batch)

    def _breed_children(self, population, count: int) -> list:
        """Build ``count`` mutated configs from samples of ``population``."""
        # Sampling without replacement cannot draw more members than exist,
        # so clamp the sample size instead of letting numpy raise when
        # sample_size exceeds the population size.
        tournament_size = min(self.sample_size, len(population))
        children = []
        for _ in range(count):
            indexes = np.random.choice(
                len(population), tournament_size, replace=False
            )
            sample = [population[i] for i in indexes]
            parent = self.select_parent(sample)
            children.append(self.copy_mutate_arch(parent))
        return children

    def select_parent(self, sample: list) -> dict:
        """Return the config of the best entry in ``sample``.

        Each entry is unpacked as a 2-item pair whose first item is the
        config; entries are ranked with ``self.handle_score``, which is
        defined outside this class.
        """
        cfg, _ = max(sample, key=self.handle_score)
        return cfg

    def gen_random_batch(self, size: int) -> list:
        """Return ``size`` configs with random ``arch_seq`` and no parent."""
        batch = []
        for _ in range(size):
            cfg = self.pb_dict.copy()
            cfg['arch_seq'] = self.random_search_space()
            cfg['parent_id'] = ''
            batch.append(cfg)
        return batch

    def random_search_space(self) -> list:
        """Draw one random operation index for every variable node."""
        return [np.random.choice(b + 1) for (_, b) in self.space_list]

    def copy_mutate_arch(self, parent: dict) -> dict:
        """Return a new config derived from ``parent`` by a single mutation.

        One position of the parent's ``arch_seq`` is replaced by a different
        operation index.  Positions whose node offers a single operation
        cannot change and are never selected; if no position can change, the
        child carries an unmodified copy of the parent's sequence.

        Args:
            parent: Config with an ``'arch_seq'`` list and an ``'id'`` entry
                (presumably assigned by the evaluator -- verify upstream).

        Returns:
            A copy of ``self.pb_dict`` with the child's ``'arch_seq'`` and
            ``'parent_id'`` set to the parent's id.
        """
        parent_arch: list = parent['arch_seq']
        child_arch = parent_arch[:]

        # Only nodes with at least two operations admit a different value;
        # picking any other node would leave nothing to sample from.
        mutable_positions = [
            i
            for i, (_, upper) in enumerate(self.space_list[:len(parent_arch)])
            if upper > 0
        ]

        if mutable_positions:
            i = int(np.random.choice(mutable_positions))
            range_upper_bound = self.space_list[i][1]
            # Exclude the current value so the child differs from its parent.
            elements = [j for j in range(range_upper_bound + 1)
                        if j != child_arch[i]]
            child_arch[i] = np.random.choice(elements, 1)[0]

        cfg = self.pb_dict.copy()
        cfg['arch_seq'] = child_arch
        cfg['parent_id'] = parent['id']
        return cfg

    def saved_keys(self, val: dict):
        """Extend the base class's saved keys with the config's parent id."""
        res = super().saved_keys(val)
        res['parent_id'] = val['parent_id']
        return res


if __name__ == '__main__':
    # Script entry point: turn the parsed command-line namespace into
    # keyword arguments, build the search and run it.
    cli_options = vars(RegularizedEvolutionXNAS.parse_args())
    RegularizedEvolutionXNAS(**cli_options).main()
