# This file is slightly modified from a code implementation by Prateek Munjal et al., authors of the paper https://arxiv.org/abs/2002.09564
# GitHub: https://github.com/PrateekMunjal
# ----------------------------------------------------------

from .Sampling import Sampling, CoreSetMIPSampling, AdversarySampler
import pycls.utils.logging as lu

logger = lu.get_logger(__name__)

class ActiveLearning:
    """Dispatch active-learning sample selection to the configured sampler.

    The sampler is chosen from ``cfg.ACTIVE_LEARNING.SAMPLING_FN`` (compared
    case-insensitively), and the number of samples to select is read from
    ``cfg.ACTIVE_LEARNING.BUDGET_SIZE``.
    """

    def __init__(self, dataObj, cfg):
        """Store the data object and config and build a ``Sampling`` helper.

        Note: ``self.dataObj`` and ``self.sampler`` are not referenced by the
        methods defined in this class.
        """
        self.dataObj = dataObj
        self.sampler = Sampling(dataObj=dataObj, cfg=cfg)
        self.cfg = cfg

    def _check_budget(self, uSet):
        """Assert the configured budget is positive and strictly below ``len(uSet)``."""
        budget = self.cfg.ACTIVE_LEARNING.BUDGET_SIZE
        assert budget > 0, "Expected a positive budgetSize"
        assert budget < len(uSet), "BudgetSet cannot exceed length of unlabelled set. Length of unlabelled set: {} and budgetSize: {}"\
        .format(len(uSet), budget)

    def _unsupported_sampling_fn(self):
        """Report the unknown ``SAMPLING_FN`` and raise ``NotImplementedError``."""
        message = f"{self.cfg.ACTIVE_LEARNING.SAMPLING_FN} is either not implemented or there is some spelling mistake."
        print(message)
        raise NotImplementedError(message)

    def sample_from_uSet(self, clf_model, lSet, uSet, trainDataset, dele_set, best_gamma = None, supportingModels=None):
        """Select a budget of samples from ``uSet`` with the configured sampler.

        Only ``SAMPLING_FN == "activesilhouette"`` (any case) is handled here.

        Args:
            clf_model, trainDataset, dele_set, supportingModels: accepted for
                interface compatibility; not used by the supported sampler.
            lSet: current labelled set, forwarded to the sampler.
            uSet: current unlabelled set; must be strictly larger than the
                budget.
            best_gamma: forwarded to ``select_samples`` as ``fix_gamma``.

        Returns:
            The ``(activeSet, uSet, best_gamma)`` triple returned by
            ``Active_Silhouette.select_samples``.

        Raises:
            AssertionError: if the budget is not positive or is not smaller
                than ``len(uSet)``.
            NotImplementedError: if ``SAMPLING_FN`` is not supported.
        """
        self._check_budget(uSet)
        print(self.cfg.ACTIVE_LEARNING.SAMPLING_FN)

        if self.cfg.ACTIVE_LEARNING.SAMPLING_FN.lower() in ["activesilhouette"]:
            # Imported lazily so the herding module is only loaded when used.
            from .herding import Active_Silhouette
            activesilhouette = Active_Silhouette(self.cfg, lSet, uSet, budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE,
                            delta=self.cfg.ACTIVE_LEARNING.INITIAL_DELTA)
            activeSet, uSet, best_gamma = activesilhouette.select_samples(fix_gamma = best_gamma)
            return activeSet, uSet, best_gamma

        self._unsupported_sampling_fn()

    def sample_byratio(self, lSet, uSet, trainDataset, dele_set, top_inds, supportingModels=None):
        """Select a budget of samples from ``uSet`` using precomputed indices.

        Only ``SAMPLING_FN == "kernelherding"`` (any case) is handled here.

        Args:
            lSet: current labelled set, forwarded to the sampler.
            uSet: current unlabelled set; must be strictly larger than the
                budget.
            trainDataset, dele_set, supportingModels: accepted for interface
                compatibility; not used by the supported sampler.
            top_inds: forwarded to ``ProbKernelherding`` as ``top_ind``.

        Returns:
            The ``(activeSet, uSet)`` pair returned by
            ``ProbKernelherding.select_samples``.

        Raises:
            AssertionError: if the budget is not positive or is not smaller
                than ``len(uSet)``.
            NotImplementedError: if ``SAMPLING_FN`` is not supported.
                (Previously an unsupported value fell through to the
                ``return`` and failed with ``UnboundLocalError``.)
        """
        self._check_budget(uSet)
        print(self.cfg.ACTIVE_LEARNING.SAMPLING_FN)

        if self.cfg.ACTIVE_LEARNING.SAMPLING_FN.lower() in ["kernelherding"]:
            # Imported lazily so the herding module is only loaded when used.
            from .herding import ProbKernelherding
            kernelherding = ProbKernelherding(self.cfg, lSet, uSet, budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE,
                            top_ind=top_inds)
            activeSet, uSet = kernelherding.select_samples()
            return activeSet, uSet

        self._unsupported_sampling_fn()

    def sample_initial_medoid(self, lSet, uSet):
        """Select an initial budget of samples via ``kmeidoid_select_initial``.

        Unlike the other sampling methods, this one does not validate the
        budget against ``len(uSet)``.

        Returns:
            The ``(activeSet, uSet)`` pair returned by the selector's
            ``select_samples``.
        """
        from .herding import kmeidoid_select_initial
        kmeidoid_initial = kmeidoid_select_initial(self.cfg, lSet, uSet, budgetSize=self.cfg.ACTIVE_LEARNING.BUDGET_SIZE)
        activeSet, uSet = kmeidoid_initial.select_samples()
        return activeSet, uSet
