'''
Code for a generic query algorithm that specific algorithms can inherit from.
'''
import numpy as np

class algorithm:
    '''
    Base class for a generic query algorithm.

    An algorithm chooses which point to query next and decides when to
    stop. Subclasses supply the specific query rule by overriding
    ``stop``, ``next``, ``update_params`` and ``get_model``; calling any
    of these on the base class raises ``NotImplementedError``.

    Attributes:
        data: instance of the data class; must provide
            ``is_correct(model)`` (used by ``check_correct`` and ``run``).
        oracle: instance of the oracle class (not used by this base class).
        num_queries: number of queries made so far. The base class only
            initialises it to 0; subclasses are expected to increment it.
        verbose: if True, ``run`` prints the query count after each step.
        correct: set by ``run`` (only after it finishes) to the result of
            ``check_correct()``.
    '''

    def __init__(self, data, oracle, verbose=False):
        self.data = data        # instance of the data class
        self.oracle = oracle    # instance of the oracle class
        self.num_queries = 0
        self.verbose = verbose

    def stop(self):
        '''
        Return True when the algorithm should stop querying.
        '''
        # NotImplementedError subclasses Exception, so existing
        # ``except Exception`` handlers still catch it.
        raise NotImplementedError('stop method not implemented')

    def next(self):
        '''
        Choose the next point to be queried.
        '''
        raise NotImplementedError('next method not implemented')

    def update_params(self):
        '''
        Update any algorithm-specific parameters after a query.
        '''
        raise NotImplementedError('update_params method not implemented')

    def get_model(self):
        ''' Return the current guess at the model. '''
        raise NotImplementedError('get_model method not implemented')

    def check_correct(self):
        ''' Return ``data.is_correct`` evaluated on the current model guess. '''
        return self.data.is_correct(self.get_model())

    def run(self):
        '''
        Run the algorithm until ``stop()`` returns True.

        Each iteration calls ``next()`` and then ``update_params()``. When
        the loop ends, ``self.correct`` is set from ``check_correct()``.
        '''
        while not self.stop():
            self.next()
            self.update_params()
            if self.verbose:
                print('Queries: {}'.format(self.num_queries))
        # Reuse check_correct rather than duplicating its logic.
        self.correct = self.check_correct()

class GW_version_space(algorithm):
    '''
    Implementation of the Gaussian width algorithm over an explicit
    version space.

    The version space ``Z`` is a 2-D array with one candidate model per
    row; its columns are indexed by the query points passed to
    ``data.query_label``. Entries should be +-1. After each query, every
    row whose entry at the queried column disagrees with the observed
    label is removed. The algorithm stops once at most one model remains.
    '''
    def __init__(self, data, oracle, v_space, oracle_samples=100, verbose=False):
        super().__init__(data, oracle, verbose=verbose)
        self.Z = v_space        # should be +- 1
        self.oracle_samples = oracle_samples    # stored; not read in this class
        self.past_samples = []  # query points (column indices of Z), in order

    def stop(self):
        '''Return True once at most one model remains in the version space.'''
        # ``<= 1`` rather than ``== 1``: if the observed labels match no
        # candidate, Z becomes empty and can never shrink back to one row,
        # so an equality test would never stop the loop in ``run``.
        return self.Z.shape[0] <= 1

    def update_params(self):
        '''Record the latest query, count it, and prune the version space.'''
        self.past_samples.append(self.new_query)
        self.num_queries += 1
        self.update_version_space()

    def update_version_space(self):
        '''
        Keep only the models whose entry at the queried column matches
        the observed label.
        '''
        # Vectorised boolean mask with one entry per row of Z; it also
        # works unchanged when Z has no rows.
        keep = self.Z[:, self.new_query] == self.new_label
        self.Z = self.Z[keep]

    def get_model(self):
        '''
        Return the current guess at the model.

        Raises:
            IndexError: if the version space is empty, i.e. no candidate
                model is consistent with the labels observed so far.
        '''
        if self.Z.shape[0] == 0:
            # Same exception type that ``Z[0, :]`` raised before, but with
            # an explanatory message.
            raise IndexError('version space is empty: no model is consistent '
                             'with the observed labels')
        # NOTE(review): with exactly one model left this returns the 2-D
        # array Z (a single row), otherwise the 1-D first row. The shapes
        # differ; kept as-is for compatibility with data.is_correct --
        # confirm which shape it expects.
        if self.Z.shape[0] == 1:
            return self.Z
        return self.Z[0, :]

    def next(self):
        '''
        Ask the oracle for the next query point and fetch its label.

        Use exact oracle since we track version space anyway.
        '''
        self.new_query = self.oracle.choose(self.Z, self.past_samples)
        self.new_label = self.data.query_label(self.new_query)




