import json

from src.common.structure import ModelAcquireData, ModelEvaData
from src.controller.controler import SampleController
from src.controller.sampler_all.seq_sampler import SequenceSampler

from src.eva_engine.phase1.evaluator import P1Evaluator
from src.logger import logger
from src.query_api.img_explore_ea import fetch_from_db
from torch.utils.data import DataLoader
from src.controller.sampler_ea.regularized_ea import RegularizedEASampler
from src.search_space.core.space import SpaceWrapper


# this is for image only
def p1_evaluate_query(space_name, dataset, run_id, N, K) -> (list, float):
    """
    Look up stored phase-1 results and return the last K candidates.

    All arguments except K are forwarded unchanged to ``fetch_from_db``.

    :param space_name: search space identifier.
    :param dataset: dataset identifier.
    :param run_id: id of the recorded run to query.
    :param N: number of explored models to query for.
    :param K: number of candidates to keep from the tail of the fetched list.
        ``K == 0`` yields an empty result.
    :return: (last K candidates, time usage reported by ``fetch_from_db``).
    """
    # NOTE(review): presumably `candidates` is ordered so that the best models
    # are at the end -- confirm against fetch_from_db.
    _, candidates, current_time = fetch_from_db(space_name, dataset, run_id, N)

    # candidates[-0:] is the WHOLE sequence, not an empty one, so "keep zero
    # candidates" has to be handled explicitly.
    if K == 0:
        return candidates[:0], current_time
    return candidates[-K:], current_time


class RunPhase1:
    """
    Phase-1 (filtering) driver of a model selection job.

    The constructor wires a sampling strategy to a ``P1Evaluator``;
    ``run_phase1`` then samples up to N architectures, scores each one and
    asks the sampler for the top K.
    """

    def __init__(self, args, K: int, N: int, search_space_ins: SpaceWrapper,
                 train_loader: DataLoader = None, is_simulate: bool = False):
        """
        Each model selection job will init one class here.
        :param args: run configuration. Attributes read by this class:
            population_size, sample_size, db_name, db_user, db_host, db_port,
            device, num_labels, dataset, tfmem, simple_score_sum.
        :param K: K models return in 1st phase
        :param N: N models eval in total
        :param search_space_ins: search space to sample architectures from.
        :param train_loader: forwarded to the evaluator (may be None).
        :param is_simulate: forwarded to the evaluator.
        """

        # return K models
        self.K = K
        # explore N models
        self.N = N

        self.args = args

        self.search_space_ins = search_space_ins

        # If the budget covers the whole space (capped at 100000 models),
        # enumerate it sequentially; otherwise search with regularized EA.
        if self.N >= min(len(self.search_space_ins), 100000):
            print("Explore all models")
            strategy = SequenceSampler(self.search_space_ins)
        else:
            strategy = RegularizedEASampler(self.search_space_ins,
                                            population_size=self.args.population_size,
                                            sample_size=self.args.sample_size)
        self.sampler = SampleController(strategy)

        # database connection settings handed to the evaluator
        db_config = {
            "db_name": self.args.db_name,
            "db_user": self.args.db_user,
            "db_host": self.args.db_host,
            "db_port": self.args.db_port,
        }

        # init the phase 1 evaluator
        self._evaluator = P1Evaluator(device=self.args.device,
                                      num_label=self.args.num_labels,
                                      dataset_name=self.args.dataset,
                                      search_space_ins=self.search_space_ins,
                                      train_loader=train_loader,
                                      is_simulate=is_simulate,
                                      metrics=self.args.tfmem,
                                      db_config=db_config)

    @staticmethod
    def _log_progress(explored_n: int, model_eva) -> None:
        """Log how many models were explored plus the last evaluated model."""
        # An f-string (instead of `+`) keeps this safe when model_id is not a
        # str, e.g. when no model has been evaluated yet.
        logger.info(f"3. [trails] Phase 1: filter phase explored {explored_n}"
                    f" model, model_id = {model_eva.model_id}"
                    f" model_scores = {json.dumps(model_eva.model_score)}")

    def run_phase1(self) -> (list, list, list, list):
        """
        Controller explore n models, and return the top K models.
        :return: a 4-tuple of
            - ``sampler.get_current_top_k_models(self.K)``,
            - ``sampler.get_current_top_k_models(-1)``
              (presumably every explored model -- confirm in the sampler),
            - the best ranked score seen so far, one entry per scored model,
            - the id of the model holding that best score, one entry per
              scored model.
        """

        # those two are used to track performance trace
        # current best model id
        trace_highest_scored_models_id = []
        # current highest score
        trace_highest_score = []
        # number of architectures sampled so far (failed evaluations included)
        explored_n = 0
        model_eva = ModelEvaData()

        while explored_n < self.N:
            # generate new model
            arch_id, arch_micro = self.sampler.sample_next_arch()
            # this is for sequence sampler: no arch id, nothing left to explore
            if arch_id is None:
                break
            model_encoding = self.search_space_ins.serialize_model_encoding(arch_micro)

            explored_n += 1
            model_id = str(arch_id)

            # run the model selection
            model_acquire_data = ModelAcquireData(model_id=model_id,
                                                  model_encoding=model_encoding,
                                                  is_last=False)
            data_str = model_acquire_data.serialize_model()

            try:
                model_score = self._evaluator.p1_evaluate(data_str)
            except KeyError:
                # when it is simulate, it could be keyerror; the sample still
                # counts against the budget but is not fed to the sampler
                continue

            # update the shared record only after a successful evaluation so
            # that id and score always belong to the same model
            model_eva.model_id = model_id
            model_eva.model_score = model_score

            if explored_n % 100 == 0:
                self._log_progress(explored_n, model_eva)

            ranked_score = self.sampler.fit_sampler(model_eva.model_id,
                                                    model_eva.model_score,
                                                    simple_score_sum=self.args.simple_score_sum)

            # extend the running-best trace: a new entry on the first model or
            # on a strict improvement, otherwise repeat the previous best
            if not trace_highest_score or ranked_score > trace_highest_score[-1]:
                trace_highest_score.append(ranked_score)
                trace_highest_scored_models_id.append(model_id)
            else:
                trace_highest_score.append(trace_highest_score[-1])
                trace_highest_scored_models_id.append(trace_highest_scored_models_id[-1])

        self._log_progress(explored_n, model_eva)
        # return the top K models
        return self.sampler.get_current_top_k_models(self.K), self.sampler.get_current_top_k_models(-1), \
               trace_highest_score, trace_highest_scored_models_id
