
from abc import ABC

from deephyper.search.nas import NeuralArchitectureSearch  
from deephyper.core.parser import add_arguments_from_signature
from deephyper.core.parser import str2bool
from deephyper.evaluator.evaluate import Evaluator
from deephyper.problem.neuralarchitecture import NaProblem

from xnas.utils import get_logger

# Module-level logger obtained from the project's ``xnas.utils.get_logger``.
logger = get_logger(__name__)


class XNAS(NeuralArchitectureSearch, ABC):
    """Base class for explainability-aware neural architecture search.

    Extends deephyper's ``NeuralArchitectureSearch`` with options that let
    the optimized objective combine model performance (``'score'``) with an
    explainability fitness (``'xai_fitness'``) reported by the run function.

    Args:
        problem: The neural architecture search problem.
        run: Dotted path of the run function evaluated for each candidate.
        evaluator: Evaluator specification (e.g. ``'ray'``) forwarded to
            the parent class.
        multiobjective_explainability: If true, evaluation results are
            expected to be dicts holding ``'score'`` and ``'xai_fitness'``,
            which :meth:`handle_score` combines. Accepts bools or strings
            understood by ``str2bool``.
        record_mo_xai_only: If true, :meth:`handle_score` returns only
            ``'score'`` and ignores ``'xai_fitness'``. Accepts bools or
            strings understood by ``str2bool``.
        explainability_type: Stored in ``pb_dict`` under
            ``'explainability_type'`` (presumably consumed by the run
            function -- confirm against the run implementation).
        weight_perf: Weight applied to ``'score'`` in the combined objective.
        weight_xai: Weight applied to ``'xai_fitness'`` in the combined
            objective.
        **kwargs: Forwarded to ``NeuralArchitectureSearch``.
    """

    # Both attributes are populated by ``NeuralArchitectureSearch.__init__``;
    # the annotations only narrow their types for readers and type checkers.
    evaluator: Evaluator
    problem: NaProblem

    def __init__(
            self, problem, run='xnas.nas_deephyper.nas_run.run',
            evaluator='ray',
            multiobjective_explainability=True, record_mo_xai_only=False,
            explainability_type='activations',
            weight_perf=0.5, weight_xai=5.0, **kwargs
    ):
        # Flags and weights may be given as strings (e.g. from the CLI
        # parser built by ``_extend_parser``), so normalize them up front.
        multiobjective_explainability = str2bool(multiobjective_explainability)
        record_mo_xai_only = str2bool(record_mo_xai_only)
        weight_perf = float(weight_perf)
        weight_xai = float(weight_xai)

        if (run == 'deephyper.nas.run.alpha.run' and
                multiobjective_explainability):
            logger.warning(f'Running with both {run} and '
                           f'multiobjective_explainability! This is known '
                           f'to cause runtime issues as {run} does not '
                           f'handle the multi-objective case...')
        super().__init__(problem=problem, run=run, evaluator=evaluator,
                         **kwargs)

        self.free_workers = self.evaluator.num_workers
        # Extra configuration keys are attached to the problem-space dict so
        # that anything reading ``pb_dict`` afterwards sees them.
        self.pb_dict: dict = self.problem.space
        self.pb_dict['multiobjective_explainability'] = \
            multiobjective_explainability
        self.pb_dict['explainability_type'] = explainability_type
        self.record_mo_xai_only = record_mo_xai_only

        self.weight_perf = weight_perf
        self.weight_xai = weight_xai

        logger.info(f'Running with multiobjective_explainability='
                    f'{multiobjective_explainability}, record_mo_xai_only='
                    f'{record_mo_xai_only}, weight_perf={weight_perf}, '
                    f'weight_xai={weight_xai}')

    def handle_score(self, x):
        """Reduce an evaluation result to the scalar objective to optimize.

        Args:
            x: Evaluation record whose second item (``x[1]``) holds the
                result returned by the run function.

        Returns:
            If multi-objective explainability is disabled, ``x[1]``
            unchanged. Otherwise ``scores['score']`` when
            ``scores['xai_fitness']`` is ``None`` or ``record_mo_xai_only``
            is set, else
            ``weight_perf * scores['score'] + weight_xai *
            scores['xai_fitness']``.

        Raises:
            TypeError: In multi-objective mode, if ``x[1]`` is not a dict.
            KeyError: In multi-objective mode, if ``'score'`` or
                ``'xai_fitness'`` is missing from ``x[1]``.
        """
        scores = x[1]
        if not self.pb_dict['multiobjective_explainability']:
            return scores

        # Explicit checks rather than ``assert`` so validation still runs
        # under ``python -O``, and so we check exactly the keys used below.
        if not isinstance(scores, dict):
            raise TypeError(
                f'Expected a dict of scores in multi-objective mode, got '
                f'{type(scores).__name__}. x: {x} | scores: {scores}')
        missing = [key for key in ('score', 'xai_fitness')
                   if key not in scores]
        if missing:
            raise KeyError(
                f'Missing {missing} in scores. x: {x} | scores: {scores}')

        perf = scores['score']
        xai_fit = scores['xai_fitness']
        if xai_fit is None or self.record_mo_xai_only:
            # No XAI fitness available (or it is deliberately excluded from
            # the objective): optimize on performance alone.
            return perf
        # Scalarize the two objectives with a fixed weighted sum.
        return self.weight_perf * perf + self.weight_xai * xai_fit

    @classmethod
    def _extend_parser(cls, parser):
        """Register this search's command-line arguments on ``parser``.

        Adds the parent's arguments, then arguments derived from the
        signature of ``cls``; for subclasses, also those derived from the
        ``XNAS`` signature so its options remain available.

        Returns:
            The same ``parser`` instance, for chaining.
        """
        NeuralArchitectureSearch._extend_parser(parser)
        add_arguments_from_signature(parser, cls)
        if cls is not XNAS:
            add_arguments_from_signature(parser, XNAS)
        return parser

    def saved_keys(self, val: dict):
        """Select the fields of an evaluated config to persist.

        Args:
            val: Evaluated configuration containing at least ``'id'`` and
                ``'arch_seq'``.

        Returns:
            A dict with ``'id'`` unchanged and ``'arch_seq'`` converted to
            its ``str`` form.
        """
        return {
            'id': val['id'],
            'arch_seq': str(val['arch_seq']),
        }
