import numpy as np

from CAT.strategy.abstract_strategy import AbstractStrategy
from CAT.model import AbstractModel
from CAT.dataset import AdapTestDataset
import random


class CFATStrategy(AbstractStrategy):
    """Selection strategy choosing, per student, the untested question whose
    ``model.E_f_S_t`` score is smallest."""

    def __init__(self):
        super().__init__()

    @property
    def name(self):
        """Human-readable label identifying this strategy."""
        return 'CFAT Strategy'

    def adaptest_select(self, model: AbstractModel, adaptest_data: AdapTestDataset, S_set):
        """Pick the next question to administer for every student.

        Args:
            model: model exposing ``get_pred`` and ``E_f_S_t``.
            adaptest_data: dataset providing ``num_students`` and, per student
                id, an iterable ``untested`` collection of question ids.
            S_set: indexed by student id; each entry is forwarded unchanged to
                ``model.E_f_S_t`` (its exact contents are defined by the model).

        Returns:
            dict mapping each student id in ``range(num_students)`` to the
            untested question id with the lowest ``E_f_S_t`` score (the first
            such question on ties, per ``np.argmin``).

        Raises:
            AssertionError: if the model lacks ``E_f_S_t`` or ``get_pred``.
            ValueError: from ``np.argmin`` if a student has no untested questions.
        """
        assert hasattr(model, 'E_f_S_t'), \
            'the models must implement E_f_S_t method'
        assert hasattr(model, 'get_pred'), \
            'the models must implement get_pred method for accelerating'

        # Computed once up front and shared by every E_f_S_t call below, so the
        # model does not recompute predictions per (student, question) pair.
        predictions = model.get_pred(adaptest_data)

        def _best_question(student):
            # One score per candidate, in the order the candidates were listed.
            candidates = np.array(list(adaptest_data.untested[student]))
            scores = [
                model.E_f_S_t(student, question, predictions, S_set[student])
                for question in candidates
            ]
            return candidates[np.argmin(scores)]

        return {
            student: _best_question(student)
            for student in range(adaptest_data.num_students)
        }

